style: apply black and isort formatting across entire codebase
- Standardize quote style (single to double quotes) - Reorder and group imports alphabetically - Fix line breaks and indentation for consistency - Apply PEP 8 formatting standards Also updated Makefile to exclude both venv and .venv from code quality checks. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -2,9 +2,9 @@
|
||||
"""Database backup utility that uses project configuration."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
import sqlite3
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
@@ -19,10 +19,10 @@ def get_database_path():
|
||||
"""Extract database path from DATABASE_URL."""
|
||||
db_url = settings.database_url
|
||||
|
||||
if db_url.startswith('sqlite:///'):
|
||||
if db_url.startswith("sqlite:///"):
|
||||
# Remove sqlite:/// prefix and handle relative paths
|
||||
db_path = db_url.replace('sqlite:///', '')
|
||||
if db_path.startswith('./'):
|
||||
db_path = db_url.replace("sqlite:///", "")
|
||||
if db_path.startswith("./"):
|
||||
db_path = db_path[2:] # Remove ./ prefix
|
||||
return db_path
|
||||
else:
|
||||
@@ -76,7 +76,7 @@ def backup_database():
|
||||
print("[BACKUP] Starting database backup...")
|
||||
print(f"[INFO] Database URL: {settings.database_url}")
|
||||
|
||||
if settings.database_url.startswith('sqlite'):
|
||||
if settings.database_url.startswith("sqlite"):
|
||||
return backup_sqlite_database()
|
||||
else:
|
||||
print("[INFO] For PostgreSQL databases, use pg_dump:")
|
||||
|
||||
@@ -26,20 +26,19 @@ Usage:
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import SessionLocal
|
||||
from models.database.content_page import ContentPage
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DEFAULT PAGE CONTENT
|
||||
# ============================================================================
|
||||
@@ -458,6 +457,7 @@ DEFAULT_PAGES = [
|
||||
# SCRIPT FUNCTIONS
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def create_default_pages(db: Session) -> None:
|
||||
"""
|
||||
Create default platform content pages.
|
||||
@@ -475,13 +475,14 @@ def create_default_pages(db: Session) -> None:
|
||||
# Check if page already exists (platform default with this slug)
|
||||
existing = db.execute(
|
||||
select(ContentPage).where(
|
||||
ContentPage.vendor_id == None,
|
||||
ContentPage.slug == page_data["slug"]
|
||||
ContentPage.vendor_id == None, ContentPage.slug == page_data["slug"]
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
print(f" ⏭️ Skipped: {page_data['title']} (/{page_data['slug']}) - already exists")
|
||||
print(
|
||||
f" ⏭️ Skipped: {page_data['title']} (/{page_data['slug']}) - already exists"
|
||||
)
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
@@ -519,7 +520,9 @@ def create_default_pages(db: Session) -> None:
|
||||
if created_count > 0:
|
||||
print("✅ Default platform content pages created successfully!\n")
|
||||
print("Next steps:")
|
||||
print(" 1. View pages at: /about, /contact, /faq, /shipping, /returns, /privacy, /terms")
|
||||
print(
|
||||
" 1. View pages at: /about, /contact, /faq, /shipping, /returns, /privacy, /terms"
|
||||
)
|
||||
print(" 2. Vendors can override these pages through the vendor dashboard")
|
||||
print(" 3. Edit platform defaults through the admin panel\n")
|
||||
else:
|
||||
@@ -530,6 +533,7 @@ def create_default_pages(db: Session) -> None:
|
||||
# MAIN EXECUTION
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def main():
|
||||
"""Main execution function."""
|
||||
print("\n🚀 Starting Default Content Pages Creation Script...\n")
|
||||
|
||||
@@ -12,17 +12,18 @@ from pathlib import Path
|
||||
# Add project root to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
import sqlite3
|
||||
from datetime import datetime, UTC
|
||||
import os
|
||||
import sqlite3
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Get database URL
|
||||
database_url = os.getenv('DATABASE_URL', 'wizamart.db')
|
||||
db_path = database_url.replace('sqlite:///', '')
|
||||
database_url = os.getenv("DATABASE_URL", "wizamart.db")
|
||||
db_path = database_url.replace("sqlite:///", "")
|
||||
|
||||
print(f"📦 Creating inventory entries in {db_path}...")
|
||||
|
||||
@@ -30,12 +31,14 @@ conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Get products without inventory
|
||||
cursor.execute("""
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT p.id, p.vendor_id, p.product_id
|
||||
FROM products p
|
||||
LEFT JOIN inventory i ON p.id = i.product_id
|
||||
WHERE i.id IS NULL
|
||||
""")
|
||||
"""
|
||||
)
|
||||
products_without_inventory = cursor.fetchall()
|
||||
|
||||
if not products_without_inventory:
|
||||
@@ -47,7 +50,8 @@ print(f"📦 Creating inventory for {len(products_without_inventory)} products..
|
||||
|
||||
# Create inventory entries
|
||||
for product_id, vendor_id, sku in products_without_inventory:
|
||||
cursor.execute("""
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO inventory (
|
||||
vendor_id,
|
||||
product_id,
|
||||
@@ -57,15 +61,17 @@ for product_id, vendor_id, sku in products_without_inventory:
|
||||
created_at,
|
||||
updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
vendor_id,
|
||||
product_id,
|
||||
'Main Warehouse',
|
||||
100, # Total quantity
|
||||
0, # Reserved quantity
|
||||
datetime.now(UTC),
|
||||
datetime.now(UTC)
|
||||
))
|
||||
""",
|
||||
(
|
||||
vendor_id,
|
||||
product_id,
|
||||
"Main Warehouse",
|
||||
100, # Total quantity
|
||||
0, # Reserved quantity
|
||||
datetime.now(UTC),
|
||||
datetime.now(UTC),
|
||||
),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
|
||||
@@ -76,7 +82,9 @@ cursor.execute("SELECT COUNT(*) FROM inventory")
|
||||
total_count = cursor.fetchone()[0]
|
||||
print(f"\n📊 Total inventory entries: {total_count}")
|
||||
|
||||
cursor.execute("SELECT product_id, location, quantity, reserved_quantity FROM inventory LIMIT 5")
|
||||
cursor.execute(
|
||||
"SELECT product_id, location, quantity, reserved_quantity FROM inventory LIMIT 5"
|
||||
)
|
||||
print("\n📦 Sample inventory:")
|
||||
for row in cursor.fetchall():
|
||||
available = row[2] - row[3]
|
||||
|
||||
@@ -12,18 +12,20 @@ from pathlib import Path
|
||||
# Add project root to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from app.core.database import SessionLocal
|
||||
from models.database.vendor import Vendor
|
||||
from models.database.content_page import ContentPage
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import SessionLocal
|
||||
from models.database.content_page import ContentPage
|
||||
from models.database.vendor import Vendor
|
||||
|
||||
|
||||
def create_landing_page(
|
||||
vendor_subdomain: str,
|
||||
template: str = "default",
|
||||
title: str = None,
|
||||
content: str = None
|
||||
content: str = None,
|
||||
):
|
||||
"""
|
||||
Create a landing page for a vendor.
|
||||
@@ -38,9 +40,7 @@ def create_landing_page(
|
||||
|
||||
try:
|
||||
# Find vendor
|
||||
vendor = db.query(Vendor).filter(
|
||||
Vendor.subdomain == vendor_subdomain
|
||||
).first()
|
||||
vendor = db.query(Vendor).filter(Vendor.subdomain == vendor_subdomain).first()
|
||||
|
||||
if not vendor:
|
||||
print(f"❌ Vendor '{vendor_subdomain}' not found!")
|
||||
@@ -49,10 +49,11 @@ def create_landing_page(
|
||||
print(f"✅ Found vendor: {vendor.name} (ID: {vendor.id})")
|
||||
|
||||
# Check if landing page already exists
|
||||
existing = db.query(ContentPage).filter(
|
||||
ContentPage.vendor_id == vendor.id,
|
||||
ContentPage.slug == "landing"
|
||||
).first()
|
||||
existing = (
|
||||
db.query(ContentPage)
|
||||
.filter(ContentPage.vendor_id == vendor.id, ContentPage.slug == "landing")
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing:
|
||||
print(f"⚠️ Landing page already exists (ID: {existing.id})")
|
||||
@@ -74,7 +75,8 @@ def create_landing_page(
|
||||
vendor_id=vendor.id,
|
||||
slug="landing",
|
||||
title=title or f"Welcome to {vendor.name}",
|
||||
content=content or f"""
|
||||
content=content
|
||||
or f"""
|
||||
<h2>About {vendor.name}</h2>
|
||||
<p>{vendor.description or 'Your trusted shopping destination for quality products.'}</p>
|
||||
|
||||
@@ -96,7 +98,7 @@ def create_landing_page(
|
||||
published_at=datetime.now(timezone.utc),
|
||||
show_in_footer=False,
|
||||
show_in_header=False,
|
||||
display_order=0
|
||||
display_order=0,
|
||||
)
|
||||
|
||||
db.add(landing_page)
|
||||
@@ -141,10 +143,13 @@ def list_vendors():
|
||||
print(f" Code: {vendor.vendor_code}")
|
||||
|
||||
# Check if has landing page
|
||||
landing = db.query(ContentPage).filter(
|
||||
ContentPage.vendor_id == vendor.id,
|
||||
ContentPage.slug == "landing"
|
||||
).first()
|
||||
landing = (
|
||||
db.query(ContentPage)
|
||||
.filter(
|
||||
ContentPage.vendor_id == vendor.id, ContentPage.slug == "landing"
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if landing:
|
||||
print(f" Landing Page: ✅ ({landing.template})")
|
||||
@@ -165,7 +170,7 @@ def show_templates():
|
||||
("default", "Clean professional layout with 3-column quick links"),
|
||||
("minimal", "Ultra-simple centered design with single CTA"),
|
||||
("modern", "Full-screen hero with animations and features"),
|
||||
("full", "Maximum features with split-screen hero and stats")
|
||||
("full", "Maximum features with split-screen hero and stats"),
|
||||
]
|
||||
|
||||
for name, desc in templates:
|
||||
@@ -209,7 +214,7 @@ if __name__ == "__main__":
|
||||
success = create_landing_page(
|
||||
vendor_subdomain=vendor_subdomain,
|
||||
template=template,
|
||||
title=title if title else None
|
||||
title=title if title else None,
|
||||
)
|
||||
|
||||
if success:
|
||||
|
||||
@@ -46,16 +46,22 @@ def create_platform_pages():
|
||||
print("1. Creating Platform Homepage...")
|
||||
|
||||
# Check if already exists
|
||||
existing = db.query(ContentPage).filter_by(vendor_id=None, slug="platform_homepage").first()
|
||||
existing = (
|
||||
db.query(ContentPage)
|
||||
.filter_by(vendor_id=None, slug="platform_homepage")
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
print(f" ⚠️ Skipped: Platform Homepage - already exists (ID: {existing.id})")
|
||||
print(
|
||||
f" ⚠️ Skipped: Platform Homepage - already exists (ID: {existing.id})"
|
||||
)
|
||||
else:
|
||||
try:
|
||||
homepage = content_page_service.create_page(
|
||||
db,
|
||||
slug="platform_homepage",
|
||||
title="Welcome to Our Multi-Vendor Marketplace",
|
||||
content="""
|
||||
db,
|
||||
slug="platform_homepage",
|
||||
title="Welcome to Our Multi-Vendor Marketplace",
|
||||
content="""
|
||||
<p class="lead">
|
||||
Connect vendors with customers worldwide. Build your online store and reach millions of shoppers.
|
||||
</p>
|
||||
@@ -64,14 +70,14 @@ def create_platform_pages():
|
||||
with minimal effort and maximum impact.
|
||||
</p>
|
||||
""",
|
||||
template="modern", # Uses platform/homepage-modern.html
|
||||
vendor_id=None, # Platform-level page
|
||||
is_published=True,
|
||||
show_in_header=False, # Homepage is not in menu (it's the root)
|
||||
show_in_footer=False,
|
||||
display_order=0,
|
||||
template="modern", # Uses platform/homepage-modern.html
|
||||
vendor_id=None, # Platform-level page
|
||||
is_published=True,
|
||||
show_in_header=False, # Homepage is not in menu (it's the root)
|
||||
show_in_footer=False,
|
||||
display_order=0,
|
||||
meta_description="Leading multi-vendor marketplace platform. Connect with thousands of vendors and discover millions of products.",
|
||||
meta_keywords="marketplace, multi-vendor, e-commerce, online shopping, platform"
|
||||
meta_keywords="marketplace, multi-vendor, e-commerce, online shopping, platform",
|
||||
)
|
||||
print(f" ✅ Created: {homepage.title} (/{homepage.slug})")
|
||||
except Exception as e:
|
||||
@@ -88,10 +94,10 @@ def create_platform_pages():
|
||||
else:
|
||||
try:
|
||||
about = content_page_service.create_page(
|
||||
db,
|
||||
slug="about",
|
||||
title="About Us",
|
||||
content="""
|
||||
db,
|
||||
slug="about",
|
||||
title="About Us",
|
||||
content="""
|
||||
<h2>Our Mission</h2>
|
||||
<p>
|
||||
We're on a mission to democratize e-commerce by providing powerful,
|
||||
@@ -121,13 +127,13 @@ def create_platform_pages():
|
||||
<li><strong>Excellence:</strong> We strive for the highest quality in everything we do</li>
|
||||
</ul>
|
||||
""",
|
||||
vendor_id=None,
|
||||
is_published=True,
|
||||
show_in_header=True, # Show in header navigation
|
||||
show_in_footer=True, # Show in footer
|
||||
display_order=1,
|
||||
vendor_id=None,
|
||||
is_published=True,
|
||||
show_in_header=True, # Show in header navigation
|
||||
show_in_footer=True, # Show in footer
|
||||
display_order=1,
|
||||
meta_description="Learn about our mission to democratize e-commerce and empower entrepreneurs worldwide.",
|
||||
meta_keywords="about us, mission, vision, values, company"
|
||||
meta_keywords="about us, mission, vision, values, company",
|
||||
)
|
||||
print(f" ✅ Created: {about.title} (/{about.slug})")
|
||||
except Exception as e:
|
||||
@@ -144,10 +150,10 @@ def create_platform_pages():
|
||||
else:
|
||||
try:
|
||||
faq = content_page_service.create_page(
|
||||
db,
|
||||
slug="faq",
|
||||
title="Frequently Asked Questions",
|
||||
content="""
|
||||
db,
|
||||
slug="faq",
|
||||
title="Frequently Asked Questions",
|
||||
content="""
|
||||
<h2>Getting Started</h2>
|
||||
|
||||
<h3>How do I create a vendor account?</h3>
|
||||
@@ -204,13 +210,13 @@ def create_platform_pages():
|
||||
and marketing tools.
|
||||
</p>
|
||||
""",
|
||||
vendor_id=None,
|
||||
is_published=True,
|
||||
show_in_header=True, # Show in header navigation
|
||||
show_in_footer=True,
|
||||
display_order=2,
|
||||
vendor_id=None,
|
||||
is_published=True,
|
||||
show_in_header=True, # Show in header navigation
|
||||
show_in_footer=True,
|
||||
display_order=2,
|
||||
meta_description="Find answers to common questions about our marketplace platform.",
|
||||
meta_keywords="faq, frequently asked questions, help, support"
|
||||
meta_keywords="faq, frequently asked questions, help, support",
|
||||
)
|
||||
print(f" ✅ Created: {faq.title} (/{faq.slug})")
|
||||
except Exception as e:
|
||||
@@ -221,16 +227,18 @@ def create_platform_pages():
|
||||
# ========================================================================
|
||||
print("4. Creating Contact Us page...")
|
||||
|
||||
existing = db.query(ContentPage).filter_by(vendor_id=None, slug="contact").first()
|
||||
existing = (
|
||||
db.query(ContentPage).filter_by(vendor_id=None, slug="contact").first()
|
||||
)
|
||||
if existing:
|
||||
print(f" ⚠️ Skipped: Contact Us - already exists (ID: {existing.id})")
|
||||
else:
|
||||
try:
|
||||
contact = content_page_service.create_page(
|
||||
db,
|
||||
slug="contact",
|
||||
title="Contact Us",
|
||||
content="""
|
||||
db,
|
||||
slug="contact",
|
||||
title="Contact Us",
|
||||
content="""
|
||||
<h2>Get in Touch</h2>
|
||||
<p>
|
||||
We'd love to hear from you! Whether you have questions about our platform,
|
||||
@@ -271,13 +279,13 @@ def create_platform_pages():
|
||||
Email: <a href="mailto:press@marketplace.com">press@marketplace.com</a>
|
||||
</p>
|
||||
""",
|
||||
vendor_id=None,
|
||||
is_published=True,
|
||||
show_in_header=True, # Show in header navigation
|
||||
show_in_footer=True,
|
||||
display_order=3,
|
||||
vendor_id=None,
|
||||
is_published=True,
|
||||
show_in_header=True, # Show in header navigation
|
||||
show_in_footer=True,
|
||||
display_order=3,
|
||||
meta_description="Get in touch with our team. We're here to help you succeed.",
|
||||
meta_keywords="contact, support, email, phone, address"
|
||||
meta_keywords="contact, support, email, phone, address",
|
||||
)
|
||||
print(f" ✅ Created: {contact.title} (/{contact.slug})")
|
||||
except Exception as e:
|
||||
@@ -290,14 +298,16 @@ def create_platform_pages():
|
||||
|
||||
existing = db.query(ContentPage).filter_by(vendor_id=None, slug="terms").first()
|
||||
if existing:
|
||||
print(f" ⚠️ Skipped: Terms of Service - already exists (ID: {existing.id})")
|
||||
print(
|
||||
f" ⚠️ Skipped: Terms of Service - already exists (ID: {existing.id})"
|
||||
)
|
||||
else:
|
||||
try:
|
||||
terms = content_page_service.create_page(
|
||||
db,
|
||||
slug="terms",
|
||||
title="Terms of Service",
|
||||
content="""
|
||||
db,
|
||||
slug="terms",
|
||||
title="Terms of Service",
|
||||
content="""
|
||||
<p><em>Last updated: January 1, 2024</em></p>
|
||||
|
||||
<h2>1. Acceptance of Terms</h2>
|
||||
@@ -361,13 +371,13 @@ def create_platform_pages():
|
||||
<a href="mailto:legal@marketplace.com">legal@marketplace.com</a>.
|
||||
</p>
|
||||
""",
|
||||
vendor_id=None,
|
||||
is_published=True,
|
||||
show_in_header=False, # Too legal for header
|
||||
show_in_footer=True, # Show in footer
|
||||
display_order=10,
|
||||
vendor_id=None,
|
||||
is_published=True,
|
||||
show_in_header=False, # Too legal for header
|
||||
show_in_footer=True, # Show in footer
|
||||
display_order=10,
|
||||
meta_description="Read our terms of service and platform usage policies.",
|
||||
meta_keywords="terms of service, terms, legal, policy, agreement"
|
||||
meta_keywords="terms of service, terms, legal, policy, agreement",
|
||||
)
|
||||
print(f" ✅ Created: {terms.title} (/{terms.slug})")
|
||||
except Exception as e:
|
||||
@@ -378,16 +388,18 @@ def create_platform_pages():
|
||||
# ========================================================================
|
||||
print("6. Creating Privacy Policy page...")
|
||||
|
||||
existing = db.query(ContentPage).filter_by(vendor_id=None, slug="privacy").first()
|
||||
existing = (
|
||||
db.query(ContentPage).filter_by(vendor_id=None, slug="privacy").first()
|
||||
)
|
||||
if existing:
|
||||
print(f" ⚠️ Skipped: Privacy Policy - already exists (ID: {existing.id})")
|
||||
else:
|
||||
try:
|
||||
privacy = content_page_service.create_page(
|
||||
db,
|
||||
slug="privacy",
|
||||
title="Privacy Policy",
|
||||
content="""
|
||||
db,
|
||||
slug="privacy",
|
||||
title="Privacy Policy",
|
||||
content="""
|
||||
<p><em>Last updated: January 1, 2024</em></p>
|
||||
|
||||
<h2>1. Information We Collect</h2>
|
||||
@@ -453,13 +465,13 @@ def create_platform_pages():
|
||||
<a href="mailto:privacy@marketplace.com">privacy@marketplace.com</a>.
|
||||
</p>
|
||||
""",
|
||||
vendor_id=None,
|
||||
is_published=True,
|
||||
show_in_header=False, # Too legal for header
|
||||
show_in_footer=True, # Show in footer
|
||||
display_order=11,
|
||||
vendor_id=None,
|
||||
is_published=True,
|
||||
show_in_header=False, # Too legal for header
|
||||
show_in_footer=True, # Show in footer
|
||||
display_order=11,
|
||||
meta_description="Learn how we collect, use, and protect your personal information.",
|
||||
meta_keywords="privacy policy, privacy, data protection, gdpr, personal information"
|
||||
meta_keywords="privacy policy, privacy, data protection, gdpr, personal information",
|
||||
)
|
||||
print(f" ✅ Created: {privacy.title} (/{privacy.slug})")
|
||||
except Exception as e:
|
||||
|
||||
@@ -17,29 +17,30 @@ This script is idempotent - safe to run multiple times.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import (print_environment_info, settings,
|
||||
validate_production_settings)
|
||||
from app.core.database import SessionLocal
|
||||
from app.core.config import settings, print_environment_info, validate_production_settings
|
||||
from app.core.environment import is_production, get_environment
|
||||
from models.database.user import User
|
||||
from models.database.admin import AdminSetting
|
||||
from middleware.auth import AuthManager
|
||||
from app.core.environment import get_environment, is_production
|
||||
from app.core.permissions import PermissionGroups
|
||||
|
||||
from middleware.auth import AuthManager
|
||||
from models.database.admin import AdminSetting
|
||||
from models.database.user import User
|
||||
|
||||
# =============================================================================
|
||||
# HELPER FUNCTIONS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def print_header(text: str):
|
||||
"""Print formatted header."""
|
||||
print("\n" + "=" * 70)
|
||||
@@ -71,6 +72,7 @@ def print_error(text: str):
|
||||
# INITIALIZATION FUNCTIONS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_admin_user(db: Session, auth_manager: AuthManager) -> User:
|
||||
"""Create or get the platform admin user."""
|
||||
|
||||
@@ -206,9 +208,7 @@ def create_admin_settings(db: Session) -> int:
|
||||
for setting_data in default_settings:
|
||||
# Check if setting already exists
|
||||
existing = db.execute(
|
||||
select(AdminSetting).where(
|
||||
AdminSetting.key == setting_data["key"]
|
||||
)
|
||||
select(AdminSetting).where(AdminSetting.key == setting_data["key"])
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not existing:
|
||||
@@ -281,6 +281,7 @@ def verify_rbac_schema(db: Session) -> bool:
|
||||
# MAIN INITIALIZATION
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def initialize_production(db: Session, auth_manager: AuthManager):
|
||||
"""Initialize production database with essential data."""
|
||||
|
||||
@@ -362,6 +363,7 @@ def print_summary(db: Session):
|
||||
# MAIN ENTRY POINT
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point."""
|
||||
|
||||
@@ -405,6 +407,7 @@ def main():
|
||||
print_header("❌ INITIALIZATION FAILED")
|
||||
print(f"\nError: {e}\n")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ Usage: python route_diagnostics.py
|
||||
"""
|
||||
|
||||
import sys
|
||||
from typing import List, Dict
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
def check_route_order():
|
||||
@@ -29,23 +29,23 @@ def check_route_order():
|
||||
html_routes = []
|
||||
|
||||
for route in routes:
|
||||
if hasattr(route, 'path'):
|
||||
if hasattr(route, "path"):
|
||||
path = route.path
|
||||
methods = getattr(route, 'methods', set())
|
||||
methods = getattr(route, "methods", set())
|
||||
|
||||
# Determine if JSON or HTML based on common patterns
|
||||
if 'login' in path or 'dashboard' in path or 'products' in path:
|
||||
if "login" in path or "dashboard" in path or "products" in path:
|
||||
# Check response class if available
|
||||
if hasattr(route, 'response_class'):
|
||||
if hasattr(route, "response_class"):
|
||||
response_class = str(route.response_class)
|
||||
if 'HTML' in response_class:
|
||||
if "HTML" in response_class:
|
||||
html_routes.append((path, methods))
|
||||
else:
|
||||
json_routes.append((path, methods))
|
||||
else:
|
||||
# Check endpoint name/tags
|
||||
endpoint = getattr(route, 'endpoint', None)
|
||||
if endpoint and 'page' in endpoint.__name__.lower():
|
||||
endpoint = getattr(route, "endpoint", None)
|
||||
if endpoint and "page" in endpoint.__name__.lower():
|
||||
html_routes.append((path, methods))
|
||||
else:
|
||||
json_routes.append((path, methods))
|
||||
@@ -57,8 +57,10 @@ def check_route_order():
|
||||
# Check for specific vendor info route
|
||||
vendor_info_found = False
|
||||
for route in routes:
|
||||
if hasattr(route, 'path'):
|
||||
if '/{vendor_code}' == route.path and 'GET' in getattr(route, 'methods', set()):
|
||||
if hasattr(route, "path"):
|
||||
if "/{vendor_code}" == route.path and "GET" in getattr(
|
||||
route, "methods", set()
|
||||
):
|
||||
vendor_info_found = True
|
||||
print("✅ Found vendor info endpoint: GET /{vendor_code}")
|
||||
break
|
||||
@@ -100,14 +102,14 @@ def test_vendor_endpoint():
|
||||
print(f" Content-Type: {response.headers.get('content-type', 'N/A')}")
|
||||
|
||||
# Check if response is JSON
|
||||
content_type = response.headers.get('content-type', '')
|
||||
if 'application/json' in content_type:
|
||||
content_type = response.headers.get("content-type", "")
|
||||
if "application/json" in content_type:
|
||||
print("✅ Response is JSON")
|
||||
data = response.json()
|
||||
print(f" Vendor: {data.get('name', 'N/A')}")
|
||||
print(f" Code: {data.get('vendor_code', 'N/A')}")
|
||||
return True
|
||||
elif 'text/html' in content_type:
|
||||
elif "text/html" in content_type:
|
||||
print("❌ ERROR: Response is HTML, not JSON!")
|
||||
print(" This confirms the route ordering issue")
|
||||
print(" The HTML page route is catching the API request")
|
||||
|
||||
@@ -26,39 +26,39 @@ This script is idempotent when run normally.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from decimal import Decimal
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select, delete
|
||||
|
||||
from app.core.database import SessionLocal
|
||||
from app.core.config import settings
|
||||
from app.core.environment import is_production, get_environment
|
||||
from models.database.user import User
|
||||
from models.database.vendor import Vendor, VendorUser, Role
|
||||
from models.database.vendor_domain import VendorDomain
|
||||
from models.database.vendor_theme import VendorTheme
|
||||
from models.database.customer import Customer, CustomerAddress
|
||||
from models.database.product import Product
|
||||
from models.database.marketplace_product import MarketplaceProduct
|
||||
from models.database.marketplace_import_job import MarketplaceImportJob
|
||||
from models.database.order import Order, OrderItem
|
||||
from models.database.admin import PlatformAlert
|
||||
from middleware.auth import AuthManager
|
||||
|
||||
# =============================================================================
|
||||
# MODE DETECTION (from environment variable set by Makefile)
|
||||
# =============================================================================
|
||||
import os
|
||||
|
||||
SEED_MODE = os.getenv('SEED_MODE', 'normal') # normal, minimal, reset
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import SessionLocal
|
||||
from app.core.environment import get_environment, is_production
|
||||
from middleware.auth import AuthManager
|
||||
from models.database.admin import PlatformAlert
|
||||
from models.database.customer import Customer, CustomerAddress
|
||||
from models.database.marketplace_import_job import MarketplaceImportJob
|
||||
from models.database.marketplace_product import MarketplaceProduct
|
||||
from models.database.order import Order, OrderItem
|
||||
from models.database.product import Product
|
||||
from models.database.user import User
|
||||
from models.database.vendor import Role, Vendor, VendorUser
|
||||
from models.database.vendor_domain import VendorDomain
|
||||
from models.database.vendor_theme import VendorTheme
|
||||
|
||||
SEED_MODE = os.getenv("SEED_MODE", "normal") # normal, minimal, reset
|
||||
|
||||
# =============================================================================
|
||||
# DEMO DATA CONFIGURATION
|
||||
@@ -147,6 +147,7 @@ THEME_PRESETS = {
|
||||
# HELPER FUNCTIONS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def print_header(text: str):
|
||||
"""Print formatted header."""
|
||||
print("\n" + "=" * 70)
|
||||
@@ -178,6 +179,7 @@ def print_error(text: str):
|
||||
# SAFETY CHECKS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def check_environment():
|
||||
"""Prevent running demo seed in production."""
|
||||
|
||||
@@ -196,9 +198,7 @@ def check_environment():
|
||||
def check_admin_exists(db: Session) -> bool:
|
||||
"""Check if admin user exists."""
|
||||
|
||||
admin = db.execute(
|
||||
select(User).where(User.role == "admin")
|
||||
).scalar_one_or_none()
|
||||
admin = db.execute(select(User).where(User.role == "admin")).scalar_one_or_none()
|
||||
|
||||
if not admin:
|
||||
print_error("No admin user found!")
|
||||
@@ -214,6 +214,7 @@ def check_admin_exists(db: Session) -> bool:
|
||||
# DATA DELETION (for reset mode)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def reset_all_data(db: Session):
|
||||
"""Delete ALL data from database (except admin user)."""
|
||||
|
||||
@@ -258,13 +259,14 @@ def reset_all_data(db: Session):
|
||||
# SEEDING FUNCTIONS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_demo_vendors(db: Session, auth_manager: AuthManager) -> List[Vendor]:
|
||||
"""Create demo vendors with users."""
|
||||
|
||||
vendors = []
|
||||
|
||||
# Determine how many vendors to create based on mode
|
||||
vendor_count = 1 if SEED_MODE == 'minimal' else settings.seed_demo_vendors
|
||||
vendor_count = 1 if SEED_MODE == "minimal" else settings.seed_demo_vendors
|
||||
vendors_to_create = DEMO_VENDORS[:vendor_count]
|
||||
users_to_create = DEMO_VENDOR_USERS[:vendor_count]
|
||||
|
||||
@@ -320,7 +322,9 @@ def create_demo_vendors(db: Session, auth_manager: AuthManager) -> List[Vendor]:
|
||||
db.add(vendor_user_link)
|
||||
|
||||
# Create vendor theme
|
||||
theme_colors = THEME_PRESETS.get(vendor_data["theme_preset"], THEME_PRESETS["modern"])
|
||||
theme_colors = THEME_PRESETS.get(
|
||||
vendor_data["theme_preset"], THEME_PRESETS["modern"]
|
||||
)
|
||||
theme = VendorTheme(
|
||||
vendor_id=vendor.id,
|
||||
theme_name=vendor_data["theme_preset"],
|
||||
@@ -355,7 +359,9 @@ def create_demo_vendors(db: Session, auth_manager: AuthManager) -> List[Vendor]:
|
||||
return vendors
|
||||
|
||||
|
||||
def create_demo_customers(db: Session, vendor: Vendor, auth_manager: AuthManager, count: int) -> List[Customer]:
|
||||
def create_demo_customers(
|
||||
db: Session, vendor: Vendor, auth_manager: AuthManager, count: int
|
||||
) -> List[Customer]:
|
||||
"""Create demo customers for a vendor."""
|
||||
|
||||
customers = []
|
||||
@@ -367,10 +373,11 @@ def create_demo_customers(db: Session, vendor: Vendor, auth_manager: AuthManager
|
||||
customer_number = f"CUST-{vendor.vendor_code}-{i:04d}"
|
||||
|
||||
# Check if customer already exists
|
||||
existing_customer = db.query(Customer).filter(
|
||||
Customer.vendor_id == vendor.id,
|
||||
Customer.email == email
|
||||
).first()
|
||||
existing_customer = (
|
||||
db.query(Customer)
|
||||
.filter(Customer.vendor_id == vendor.id, Customer.email == email)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_customer:
|
||||
customers.append(existing_customer)
|
||||
@@ -412,19 +419,22 @@ def create_demo_products(db: Session, vendor: Vendor, count: int) -> List[Produc
|
||||
product_id = f"{vendor.vendor_code}-PROD-{i:03d}"
|
||||
|
||||
# Check if this product already exists
|
||||
existing_product = db.query(Product).filter(
|
||||
Product.vendor_id == vendor.id,
|
||||
Product.product_id == product_id
|
||||
).first()
|
||||
existing_product = (
|
||||
db.query(Product)
|
||||
.filter(Product.vendor_id == vendor.id, Product.product_id == product_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_product:
|
||||
products.append(existing_product)
|
||||
continue # Skip creation, product already exists
|
||||
|
||||
# Check if marketplace product already exists
|
||||
existing_mp = db.query(MarketplaceProduct).filter(
|
||||
MarketplaceProduct.marketplace_product_id == marketplace_product_id
|
||||
).first()
|
||||
existing_mp = (
|
||||
db.query(MarketplaceProduct)
|
||||
.filter(MarketplaceProduct.marketplace_product_id == marketplace_product_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_mp:
|
||||
marketplace_product = existing_mp
|
||||
@@ -485,6 +495,7 @@ def create_demo_products(db: Session, vendor: Vendor, count: int) -> List[Produc
|
||||
# MAIN SEEDING
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def seed_demo_data(db: Session, auth_manager: AuthManager):
|
||||
"""Seed demo data for development."""
|
||||
|
||||
@@ -501,7 +512,7 @@ def seed_demo_data(db: Session, auth_manager: AuthManager):
|
||||
sys.exit(1)
|
||||
|
||||
# Step 3: Reset data if in reset mode
|
||||
if SEED_MODE == 'reset':
|
||||
if SEED_MODE == "reset":
|
||||
print_step(3, "Resetting data...")
|
||||
reset_all_data(db)
|
||||
|
||||
@@ -513,20 +524,13 @@ def seed_demo_data(db: Session, auth_manager: AuthManager):
|
||||
print_step(5, "Creating demo customers...")
|
||||
for vendor in vendors:
|
||||
create_demo_customers(
|
||||
db,
|
||||
vendor,
|
||||
auth_manager,
|
||||
count=settings.seed_customers_per_vendor
|
||||
db, vendor, auth_manager, count=settings.seed_customers_per_vendor
|
||||
)
|
||||
|
||||
# Step 6: Create products
|
||||
print_step(6, "Creating demo products...")
|
||||
for vendor in vendors:
|
||||
create_demo_products(
|
||||
db,
|
||||
vendor,
|
||||
count=settings.seed_products_per_vendor
|
||||
)
|
||||
create_demo_products(db, vendor, count=settings.seed_products_per_vendor)
|
||||
|
||||
# Commit all changes
|
||||
db.commit()
|
||||
@@ -558,17 +562,18 @@ def print_summary(db: Session):
|
||||
print(f" Subdomain: {vendor.subdomain}.{settings.platform_domain}")
|
||||
|
||||
# Query custom domains separately
|
||||
custom_domain = db.query(VendorDomain).filter(
|
||||
VendorDomain.vendor_id == vendor.id,
|
||||
VendorDomain.is_active == True
|
||||
).first()
|
||||
custom_domain = (
|
||||
db.query(VendorDomain)
|
||||
.filter(VendorDomain.vendor_id == vendor.id, VendorDomain.is_active == True)
|
||||
.first()
|
||||
)
|
||||
|
||||
if custom_domain:
|
||||
# Try different possible field names (model field might vary)
|
||||
domain_value = (
|
||||
getattr(custom_domain, 'domain', None) or
|
||||
getattr(custom_domain, 'domain_name', None) or
|
||||
getattr(custom_domain, 'name', None)
|
||||
getattr(custom_domain, "domain", None)
|
||||
or getattr(custom_domain, "domain_name", None)
|
||||
or getattr(custom_domain, "name", None)
|
||||
)
|
||||
if domain_value:
|
||||
print(f" Custom: {domain_value}")
|
||||
@@ -584,8 +589,12 @@ def print_summary(db: Session):
|
||||
print(f" Email: {vendor_data['email']}")
|
||||
print(f" Password: {vendor_data['password']}")
|
||||
if vendor:
|
||||
print(f" Login: http://localhost:8000/vendor/{vendor.vendor_code}/login")
|
||||
print(f" or http://{vendor.subdomain}.localhost:8000/vendor/login")
|
||||
print(
|
||||
f" Login: http://localhost:8000/vendor/{vendor.vendor_code}/login"
|
||||
)
|
||||
print(
|
||||
f" or http://{vendor.subdomain}.localhost:8000/vendor/login"
|
||||
)
|
||||
print()
|
||||
|
||||
print(f"\n🛒 Demo Customer Credentials:")
|
||||
@@ -600,7 +609,9 @@ def print_summary(db: Session):
|
||||
print("─" * 70)
|
||||
for vendor in vendors:
|
||||
print(f" {vendor.name}:")
|
||||
print(f" Path-based: http://localhost:8000/vendors/{vendor.vendor_code}/shop/")
|
||||
print(
|
||||
f" Path-based: http://localhost:8000/vendors/{vendor.vendor_code}/shop/"
|
||||
)
|
||||
print(f" Subdomain: http://{vendor.subdomain}.localhost:8000/")
|
||||
print()
|
||||
|
||||
@@ -621,6 +632,7 @@ def print_summary(db: Session):
|
||||
# MAIN ENTRY POINT
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point."""
|
||||
|
||||
@@ -647,6 +659,7 @@ def main():
|
||||
print_header("❌ SEEDING FAILED")
|
||||
print(f"\nError: {e}\n")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@@ -11,11 +11,11 @@ Usage:
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Dict
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
def count_files(directory: str, pattern: str) -> int:
|
||||
@@ -26,10 +26,14 @@ def count_files(directory: str, pattern: str) -> int:
|
||||
count = 0
|
||||
for root, dirs, files in os.walk(directory):
|
||||
# Skip __pycache__ and other cache directories
|
||||
dirs[:] = [d for d in dirs if d not in ['__pycache__', '.pytest_cache', '.git', 'node_modules']]
|
||||
dirs[:] = [
|
||||
d
|
||||
for d in dirs
|
||||
if d not in ["__pycache__", ".pytest_cache", ".git", "node_modules"]
|
||||
]
|
||||
|
||||
for file in files:
|
||||
if pattern == '*' or file.endswith(pattern):
|
||||
if pattern == "*" or file.endswith(pattern):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@@ -41,26 +45,26 @@ def get_tree_structure(directory: str, exclude_patterns: List[str] = None) -> st
|
||||
|
||||
# Try to use system tree command first
|
||||
try:
|
||||
if sys.platform == 'win32':
|
||||
if sys.platform == "win32":
|
||||
# Windows tree command
|
||||
result = subprocess.run(
|
||||
['tree', '/F', '/A', directory],
|
||||
["tree", "/F", "/A", directory],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
encoding='utf-8',
|
||||
errors='replace'
|
||||
encoding="utf-8",
|
||||
errors="replace",
|
||||
)
|
||||
return result.stdout
|
||||
else:
|
||||
# Linux/Mac tree command with exclusions
|
||||
exclude_args = []
|
||||
if exclude_patterns:
|
||||
exclude_args = ['-I', '|'.join(exclude_patterns)]
|
||||
exclude_args = ["-I", "|".join(exclude_patterns)]
|
||||
|
||||
result = subprocess.run(
|
||||
['tree', '-F', '-a'] + exclude_args + [directory],
|
||||
["tree", "-F", "-a"] + exclude_args + [directory],
|
||||
capture_output=True,
|
||||
text=True
|
||||
text=True,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
return result.stdout
|
||||
@@ -71,10 +75,19 @@ def get_tree_structure(directory: str, exclude_patterns: List[str] = None) -> st
|
||||
return generate_manual_tree(directory, exclude_patterns)
|
||||
|
||||
|
||||
def generate_manual_tree(directory: str, exclude_patterns: List[str] = None, prefix: str = "") -> str:
|
||||
def generate_manual_tree(
|
||||
directory: str, exclude_patterns: List[str] = None, prefix: str = ""
|
||||
) -> str:
|
||||
"""Generate tree structure manually when tree command is not available."""
|
||||
if exclude_patterns is None:
|
||||
exclude_patterns = ['__pycache__', '.pytest_cache', '.git', 'node_modules', '*.pyc', '*.pyo']
|
||||
exclude_patterns = [
|
||||
"__pycache__",
|
||||
".pytest_cache",
|
||||
".git",
|
||||
"node_modules",
|
||||
"*.pyc",
|
||||
"*.pyo",
|
||||
]
|
||||
|
||||
output = []
|
||||
path = Path(directory)
|
||||
@@ -86,7 +99,7 @@ def generate_manual_tree(directory: str, exclude_patterns: List[str] = None, pre
|
||||
# Skip excluded patterns
|
||||
skip = False
|
||||
for pattern in exclude_patterns:
|
||||
if pattern.startswith('*'):
|
||||
if pattern.startswith("*"):
|
||||
# File extension pattern
|
||||
if item.name.endswith(pattern[1:]):
|
||||
skip = True
|
||||
@@ -106,7 +119,9 @@ def generate_manual_tree(directory: str, exclude_patterns: List[str] = None, pre
|
||||
if item.is_dir():
|
||||
output.append(f"{prefix}{current_prefix}{item.name}/")
|
||||
extension = " " if is_last else "│ "
|
||||
subtree = generate_manual_tree(str(item), exclude_patterns, prefix + extension)
|
||||
subtree = generate_manual_tree(
|
||||
str(item), exclude_patterns, prefix + extension
|
||||
)
|
||||
if subtree:
|
||||
output.append(subtree)
|
||||
else:
|
||||
@@ -127,20 +142,36 @@ def generate_frontend_structure() -> str:
|
||||
|
||||
# Templates section
|
||||
output.append("")
|
||||
output.append("╔══════════════════════════════════════════════════════════════════╗")
|
||||
output.append("║ JINJA2 TEMPLATES ║")
|
||||
output.append("║ Location: app/templates ║")
|
||||
output.append("╚══════════════════════════════════════════════════════════════════╝")
|
||||
output.append(
|
||||
"╔══════════════════════════════════════════════════════════════════╗"
|
||||
)
|
||||
output.append(
|
||||
"║ JINJA2 TEMPLATES ║"
|
||||
)
|
||||
output.append(
|
||||
"║ Location: app/templates ║"
|
||||
)
|
||||
output.append(
|
||||
"╚══════════════════════════════════════════════════════════════════╝"
|
||||
)
|
||||
output.append("")
|
||||
output.append(get_tree_structure("app/templates"))
|
||||
|
||||
# Static assets section
|
||||
output.append("")
|
||||
output.append("")
|
||||
output.append("╔══════════════════════════════════════════════════════════════════╗")
|
||||
output.append("║ STATIC ASSETS ║")
|
||||
output.append("║ Location: static ║")
|
||||
output.append("╚══════════════════════════════════════════════════════════════════╝")
|
||||
output.append(
|
||||
"╔══════════════════════════════════════════════════════════════════╗"
|
||||
)
|
||||
output.append(
|
||||
"║ STATIC ASSETS ║"
|
||||
)
|
||||
output.append(
|
||||
"║ Location: static ║"
|
||||
)
|
||||
output.append(
|
||||
"╚══════════════════════════════════════════════════════════════════╝"
|
||||
)
|
||||
output.append("")
|
||||
output.append(get_tree_structure("static"))
|
||||
|
||||
@@ -148,11 +179,21 @@ def generate_frontend_structure() -> str:
|
||||
if os.path.exists("docs"):
|
||||
output.append("")
|
||||
output.append("")
|
||||
output.append("╔══════════════════════════════════════════════════════════════════╗")
|
||||
output.append("║ DOCUMENTATION ║")
|
||||
output.append("║ Location: docs ║")
|
||||
output.append("║ (also listed in tools structure) ║")
|
||||
output.append("╚══════════════════════════════════════════════════════════════════╝")
|
||||
output.append(
|
||||
"╔══════════════════════════════════════════════════════════════════╗"
|
||||
)
|
||||
output.append(
|
||||
"║ DOCUMENTATION ║"
|
||||
)
|
||||
output.append(
|
||||
"║ Location: docs ║"
|
||||
)
|
||||
output.append(
|
||||
"║ (also listed in tools structure) ║"
|
||||
)
|
||||
output.append(
|
||||
"╚══════════════════════════════════════════════════════════════════╝"
|
||||
)
|
||||
output.append("")
|
||||
output.append("Note: Documentation is also included in tools structure")
|
||||
output.append(" for infrastructure/DevOps context.")
|
||||
@@ -160,9 +201,15 @@ def generate_frontend_structure() -> str:
|
||||
# Statistics section
|
||||
output.append("")
|
||||
output.append("")
|
||||
output.append("╔══════════════════════════════════════════════════════════════════╗")
|
||||
output.append("║ STATISTICS ║")
|
||||
output.append("╚══════════════════════════════════════════════════════════════════╝")
|
||||
output.append(
|
||||
"╔══════════════════════════════════════════════════════════════════╗"
|
||||
)
|
||||
output.append(
|
||||
"║ STATISTICS ║"
|
||||
)
|
||||
output.append(
|
||||
"╚══════════════════════════════════════════════════════════════════╝"
|
||||
)
|
||||
output.append("")
|
||||
|
||||
output.append("Templates:")
|
||||
@@ -174,8 +221,8 @@ def generate_frontend_structure() -> str:
|
||||
output.append(f" - JavaScript files: {count_files('static', '.js')}")
|
||||
output.append(f" - CSS files: {count_files('static', '.css')}")
|
||||
output.append(" - Image files:")
|
||||
for ext in ['png', 'jpg', 'jpeg', 'gif', 'svg', 'webp', 'ico']:
|
||||
count = count_files('static', f'.{ext}')
|
||||
for ext in ["png", "jpg", "jpeg", "gif", "svg", "webp", "ico"]:
|
||||
count = count_files("static", f".{ext}")
|
||||
if count > 0:
|
||||
output.append(f" - .{ext}: {count}")
|
||||
|
||||
@@ -200,41 +247,58 @@ def generate_backend_structure() -> str:
|
||||
output.append("=" * 78)
|
||||
output.append("")
|
||||
|
||||
exclude = ['__pycache__', '*.pyc', '*.pyo', '.pytest_cache', '*.egg-info', 'templates']
|
||||
exclude = [
|
||||
"__pycache__",
|
||||
"*.pyc",
|
||||
"*.pyo",
|
||||
".pytest_cache",
|
||||
"*.egg-info",
|
||||
"templates",
|
||||
]
|
||||
|
||||
# Backend directories to include
|
||||
backend_dirs = [
|
||||
('app', 'Application Code'),
|
||||
('middleware', 'Middleware Components'),
|
||||
('models', 'Database Models'),
|
||||
('storage', 'File Storage'),
|
||||
('tasks', 'Background Tasks'),
|
||||
('logs', 'Application Logs'),
|
||||
("app", "Application Code"),
|
||||
("middleware", "Middleware Components"),
|
||||
("models", "Database Models"),
|
||||
("storage", "File Storage"),
|
||||
("tasks", "Background Tasks"),
|
||||
("logs", "Application Logs"),
|
||||
]
|
||||
|
||||
for directory, title in backend_dirs:
|
||||
if os.path.exists(directory):
|
||||
output.append("")
|
||||
output.append("╔══════════════════════════════════════════════════════════════════╗")
|
||||
output.append(
|
||||
"╔══════════════════════════════════════════════════════════════════╗"
|
||||
)
|
||||
output.append(f"║ {title.upper().center(62)} ║")
|
||||
output.append(f"║ Location: {directory + '/'.ljust(51)} ║")
|
||||
output.append("╚══════════════════════════════════════════════════════════════════╝")
|
||||
output.append(
|
||||
"╚══════════════════════════════════════════════════════════════════╝"
|
||||
)
|
||||
output.append("")
|
||||
output.append(get_tree_structure(directory, exclude))
|
||||
|
||||
# Statistics section
|
||||
output.append("")
|
||||
output.append("")
|
||||
output.append("╔══════════════════════════════════════════════════════════════════╗")
|
||||
output.append("║ STATISTICS ║")
|
||||
output.append("╚══════════════════════════════════════════════════════════════════╝")
|
||||
output.append(
|
||||
"╔══════════════════════════════════════════════════════════════════╗"
|
||||
)
|
||||
output.append(
|
||||
"║ STATISTICS ║"
|
||||
)
|
||||
output.append(
|
||||
"╚══════════════════════════════════════════════════════════════════╝"
|
||||
)
|
||||
output.append("")
|
||||
|
||||
output.append("Python Files by Directory:")
|
||||
total_py_files = 0
|
||||
for directory, title in backend_dirs:
|
||||
if os.path.exists(directory):
|
||||
count = count_files(directory, '.py')
|
||||
count = count_files(directory, ".py")
|
||||
total_py_files += count
|
||||
output.append(f" - {directory}/: {count} files")
|
||||
|
||||
@@ -242,11 +306,11 @@ def generate_backend_structure() -> str:
|
||||
|
||||
output.append("")
|
||||
output.append("Application Components (if in app/):")
|
||||
components = ['routes', 'services', 'schemas', 'exceptions', 'utils']
|
||||
components = ["routes", "services", "schemas", "exceptions", "utils"]
|
||||
for component in components:
|
||||
component_path = f"app/{component}"
|
||||
if os.path.exists(component_path):
|
||||
count = count_files(component_path, '.py')
|
||||
count = count_files(component_path, ".py")
|
||||
output.append(f" - app/{component}: {count} files")
|
||||
|
||||
output.append("")
|
||||
@@ -264,48 +328,58 @@ def generate_tools_structure() -> str:
|
||||
output.append("=" * 78)
|
||||
output.append("")
|
||||
|
||||
exclude = ['__pycache__', '*.pyc', '*.pyo', '.pytest_cache', '*.egg-info']
|
||||
exclude = ["__pycache__", "*.pyc", "*.pyo", ".pytest_cache", "*.egg-info"]
|
||||
|
||||
# Tools directories to include
|
||||
tools_dirs = [
|
||||
('alembic', 'Database Migrations'),
|
||||
('scripts', 'Utility Scripts'),
|
||||
('docker', 'Docker Configuration'),
|
||||
('docs', 'Documentation'),
|
||||
("alembic", "Database Migrations"),
|
||||
("scripts", "Utility Scripts"),
|
||||
("docker", "Docker Configuration"),
|
||||
("docs", "Documentation"),
|
||||
]
|
||||
|
||||
for directory, title in tools_dirs:
|
||||
if os.path.exists(directory):
|
||||
output.append("")
|
||||
output.append("╔══════════════════════════════════════════════════════════════════╗")
|
||||
output.append(
|
||||
"╔══════════════════════════════════════════════════════════════════╗"
|
||||
)
|
||||
output.append(f"║ {title.upper().center(62)} ║")
|
||||
output.append(f"║ Location: {directory + '/'.ljust(51)} ║")
|
||||
output.append("╚══════════════════════════════════════════════════════════════════╝")
|
||||
output.append(
|
||||
"╚══════════════════════════════════════════════════════════════════╝"
|
||||
)
|
||||
output.append("")
|
||||
output.append(get_tree_structure(directory, exclude))
|
||||
|
||||
# Configuration files section
|
||||
output.append("")
|
||||
output.append("")
|
||||
output.append("╔══════════════════════════════════════════════════════════════════╗")
|
||||
output.append("║ CONFIGURATION FILES ║")
|
||||
output.append("╚══════════════════════════════════════════════════════════════════╝")
|
||||
output.append(
|
||||
"╔══════════════════════════════════════════════════════════════════╗"
|
||||
)
|
||||
output.append(
|
||||
"║ CONFIGURATION FILES ║"
|
||||
)
|
||||
output.append(
|
||||
"╚══════════════════════════════════════════════════════════════════╝"
|
||||
)
|
||||
output.append("")
|
||||
|
||||
output.append("Root configuration files:")
|
||||
config_files = [
|
||||
('Makefile', 'Build automation'),
|
||||
('requirements.txt', 'Python dependencies'),
|
||||
('pyproject.toml', 'Python project config'),
|
||||
('setup.py', 'Python setup script'),
|
||||
('setup.cfg', 'Setup configuration'),
|
||||
('alembic.ini', 'Alembic migrations config'),
|
||||
('mkdocs.yml', 'MkDocs documentation config'),
|
||||
('Dockerfile', 'Docker image definition'),
|
||||
('docker-compose.yml', 'Docker services'),
|
||||
('.dockerignore', 'Docker ignore patterns'),
|
||||
('.gitignore', 'Git ignore patterns'),
|
||||
('.env.example', 'Environment variables template'),
|
||||
("Makefile", "Build automation"),
|
||||
("requirements.txt", "Python dependencies"),
|
||||
("pyproject.toml", "Python project config"),
|
||||
("setup.py", "Python setup script"),
|
||||
("setup.cfg", "Setup configuration"),
|
||||
("alembic.ini", "Alembic migrations config"),
|
||||
("mkdocs.yml", "MkDocs documentation config"),
|
||||
("Dockerfile", "Docker image definition"),
|
||||
("docker-compose.yml", "Docker services"),
|
||||
(".dockerignore", "Docker ignore patterns"),
|
||||
(".gitignore", "Git ignore patterns"),
|
||||
(".env.example", "Environment variables template"),
|
||||
]
|
||||
|
||||
for file, description in config_files:
|
||||
@@ -315,9 +389,15 @@ def generate_tools_structure() -> str:
|
||||
# Statistics section
|
||||
output.append("")
|
||||
output.append("")
|
||||
output.append("╔══════════════════════════════════════════════════════════════════╗")
|
||||
output.append("║ STATISTICS ║")
|
||||
output.append("╚══════════════════════════════════════════════════════════════════╝")
|
||||
output.append(
|
||||
"╔══════════════════════════════════════════════════════════════════╗"
|
||||
)
|
||||
output.append(
|
||||
"║ STATISTICS ║"
|
||||
)
|
||||
output.append(
|
||||
"╚══════════════════════════════════════════════════════════════════╝"
|
||||
)
|
||||
output.append("")
|
||||
|
||||
output.append("Database Migrations:")
|
||||
@@ -327,7 +407,9 @@ def generate_tools_structure() -> str:
|
||||
if migration_count > 0:
|
||||
# Get first and last migration
|
||||
try:
|
||||
migrations = sorted([f for f in os.listdir("alembic/versions") if f.endswith('.py')])
|
||||
migrations = sorted(
|
||||
[f for f in os.listdir("alembic/versions") if f.endswith(".py")]
|
||||
)
|
||||
if migrations:
|
||||
output.append(f" - First: {migrations[0][:40]}...")
|
||||
if len(migrations) > 1:
|
||||
@@ -341,9 +423,9 @@ def generate_tools_structure() -> str:
|
||||
output.append("Scripts:")
|
||||
if os.path.exists("scripts"):
|
||||
script_types = {
|
||||
'.py': 'Python scripts',
|
||||
'.sh': 'Shell scripts',
|
||||
'.bat': 'Batch scripts',
|
||||
".py": "Python scripts",
|
||||
".sh": "Shell scripts",
|
||||
".bat": "Batch scripts",
|
||||
}
|
||||
for ext, desc in script_types.items():
|
||||
count = count_files("scripts", ext)
|
||||
@@ -356,8 +438,8 @@ def generate_tools_structure() -> str:
|
||||
output.append("Documentation:")
|
||||
if os.path.exists("docs"):
|
||||
doc_types = {
|
||||
'.md': 'Markdown files',
|
||||
'.rst': 'reStructuredText files',
|
||||
".md": "Markdown files",
|
||||
".rst": "reStructuredText files",
|
||||
}
|
||||
for ext, desc in doc_types.items():
|
||||
count = count_files("docs", ext)
|
||||
@@ -368,7 +450,7 @@ def generate_tools_structure() -> str:
|
||||
|
||||
output.append("")
|
||||
output.append("Docker:")
|
||||
docker_files = ['Dockerfile', 'docker-compose.yml', '.dockerignore']
|
||||
docker_files = ["Dockerfile", "docker-compose.yml", ".dockerignore"]
|
||||
docker_exists = any(os.path.exists(f) for f in docker_files)
|
||||
if docker_exists:
|
||||
output.append(" ✓ Docker configuration present")
|
||||
@@ -394,25 +476,44 @@ def generate_test_structure() -> str:
|
||||
|
||||
# Test files section
|
||||
output.append("")
|
||||
output.append("╔══════════════════════════════════════════════════════════════════╗")
|
||||
output.append("║ TEST FILES ║")
|
||||
output.append("║ Location: tests/ ║")
|
||||
output.append("╚══════════════════════════════════════════════════════════════════╝")
|
||||
output.append(
|
||||
"╔══════════════════════════════════════════════════════════════════╗"
|
||||
)
|
||||
output.append(
|
||||
"║ TEST FILES ║"
|
||||
)
|
||||
output.append(
|
||||
"║ Location: tests/ ║"
|
||||
)
|
||||
output.append(
|
||||
"╚══════════════════════════════════════════════════════════════════╝"
|
||||
)
|
||||
output.append("")
|
||||
|
||||
exclude = ['__pycache__', '*.pyc', '*.pyo', '.pytest_cache', '*.egg-info']
|
||||
exclude = ["__pycache__", "*.pyc", "*.pyo", ".pytest_cache", "*.egg-info"]
|
||||
output.append(get_tree_structure("tests", exclude))
|
||||
|
||||
# Configuration section
|
||||
output.append("")
|
||||
output.append("")
|
||||
output.append("╔══════════════════════════════════════════════════════════════════╗")
|
||||
output.append("║ TEST CONFIGURATION ║")
|
||||
output.append("╚══════════════════════════════════════════════════════════════════╝")
|
||||
output.append(
|
||||
"╔══════════════════════════════════════════════════════════════════╗"
|
||||
)
|
||||
output.append(
|
||||
"║ TEST CONFIGURATION ║"
|
||||
)
|
||||
output.append(
|
||||
"╚══════════════════════════════════════════════════════════════════╝"
|
||||
)
|
||||
output.append("")
|
||||
|
||||
output.append("Test configuration files:")
|
||||
test_config_files = ['pytest.ini', 'conftest.py', 'tests/conftest.py', '.coveragerc']
|
||||
test_config_files = [
|
||||
"pytest.ini",
|
||||
"conftest.py",
|
||||
"tests/conftest.py",
|
||||
".coveragerc",
|
||||
]
|
||||
for file in test_config_files:
|
||||
if os.path.exists(file):
|
||||
output.append(f" ✓ {file}")
|
||||
@@ -420,16 +521,22 @@ def generate_test_structure() -> str:
|
||||
# Statistics section
|
||||
output.append("")
|
||||
output.append("")
|
||||
output.append("╔══════════════════════════════════════════════════════════════════╗")
|
||||
output.append("║ STATISTICS ║")
|
||||
output.append("╚══════════════════════════════════════════════════════════════════╝")
|
||||
output.append(
|
||||
"╔══════════════════════════════════════════════════════════════════╗"
|
||||
)
|
||||
output.append(
|
||||
"║ STATISTICS ║"
|
||||
)
|
||||
output.append(
|
||||
"╚══════════════════════════════════════════════════════════════════╝"
|
||||
)
|
||||
output.append("")
|
||||
|
||||
# Count test files
|
||||
test_file_count = 0
|
||||
if os.path.exists("tests"):
|
||||
for root, dirs, files in os.walk("tests"):
|
||||
dirs[:] = [d for d in dirs if d != '__pycache__']
|
||||
dirs[:] = [d for d in dirs if d != "__pycache__"]
|
||||
for file in files:
|
||||
if file.startswith("test_") and file.endswith(".py"):
|
||||
test_file_count += 1
|
||||
@@ -439,13 +546,13 @@ def generate_test_structure() -> str:
|
||||
|
||||
output.append("")
|
||||
output.append("By Category:")
|
||||
categories = ['unit', 'integration', 'system', 'e2e', 'performance']
|
||||
categories = ["unit", "integration", "system", "e2e", "performance"]
|
||||
for category in categories:
|
||||
category_path = f"tests/{category}"
|
||||
if os.path.exists(category_path):
|
||||
count = 0
|
||||
for root, dirs, files in os.walk(category_path):
|
||||
dirs[:] = [d for d in dirs if d != '__pycache__']
|
||||
dirs[:] = [d for d in dirs if d != "__pycache__"]
|
||||
for file in files:
|
||||
if file.startswith("test_") and file.endswith(".py"):
|
||||
count += 1
|
||||
@@ -455,12 +562,12 @@ def generate_test_structure() -> str:
|
||||
test_function_count = 0
|
||||
if os.path.exists("tests"):
|
||||
for root, dirs, files in os.walk("tests"):
|
||||
dirs[:] = [d for d in dirs if d != '__pycache__']
|
||||
dirs[:] = [d for d in dirs if d != "__pycache__"]
|
||||
for file in files:
|
||||
if file.startswith("test_") and file.endswith(".py"):
|
||||
filepath = os.path.join(root, file)
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip().startswith("def test_"):
|
||||
test_function_count += 1
|
||||
@@ -494,19 +601,19 @@ def main():
|
||||
structure_type = sys.argv[1].lower()
|
||||
|
||||
generators = {
|
||||
'frontend': ('frontend-structure.txt', generate_frontend_structure),
|
||||
'backend': ('backend-structure.txt', generate_backend_structure),
|
||||
'tests': ('test-structure.txt', generate_test_structure),
|
||||
'tools': ('tools-structure.txt', generate_tools_structure),
|
||||
"frontend": ("frontend-structure.txt", generate_frontend_structure),
|
||||
"backend": ("backend-structure.txt", generate_backend_structure),
|
||||
"tests": ("test-structure.txt", generate_test_structure),
|
||||
"tools": ("tools-structure.txt", generate_tools_structure),
|
||||
}
|
||||
|
||||
if structure_type == 'all':
|
||||
if structure_type == "all":
|
||||
for name, (filename, generator) in generators.items():
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Generating {name} structure...")
|
||||
print('=' * 60)
|
||||
print("=" * 60)
|
||||
content = generator()
|
||||
with open(filename, 'w', encoding='utf-8') as f:
|
||||
with open(filename, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
print(f"✅ {name.capitalize()} structure saved to {filename}")
|
||||
print(f"\n{content}\n")
|
||||
@@ -514,7 +621,7 @@ def main():
|
||||
filename, generator = generators[structure_type]
|
||||
print(f"Generating {structure_type} structure...")
|
||||
content = generator()
|
||||
with open(filename, 'w', encoding='utf-8') as f:
|
||||
with open(filename, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
print(f"\n✅ Structure saved to {filename}\n")
|
||||
print(content)
|
||||
|
||||
@@ -20,22 +20,26 @@ Requirements:
|
||||
* Customer: username=customer, password=customer123, vendor_id=1
|
||||
"""
|
||||
|
||||
import requests
|
||||
import json
|
||||
from typing import Dict, Optional
|
||||
|
||||
import requests
|
||||
|
||||
BASE_URL = "http://localhost:8000"
|
||||
|
||||
|
||||
class Color:
|
||||
"""Terminal colors for pretty output"""
|
||||
GREEN = '\033[92m'
|
||||
RED = '\033[91m'
|
||||
YELLOW = '\033[93m'
|
||||
BLUE = '\033[94m'
|
||||
MAGENTA = '\033[95m'
|
||||
CYAN = '\033[96m'
|
||||
BOLD = '\033[1m'
|
||||
END = '\033[0m'
|
||||
|
||||
GREEN = "\033[92m"
|
||||
RED = "\033[91m"
|
||||
YELLOW = "\033[93m"
|
||||
BLUE = "\033[94m"
|
||||
MAGENTA = "\033[95m"
|
||||
CYAN = "\033[96m"
|
||||
BOLD = "\033[1m"
|
||||
END = "\033[0m"
|
||||
|
||||
|
||||
def print_section(name: str):
|
||||
"""Print section header"""
|
||||
@@ -43,66 +47,70 @@ def print_section(name: str):
|
||||
print(f" {name}")
|
||||
print(f"{'═' * 60}{Color.END}")
|
||||
|
||||
|
||||
def print_test(name: str):
|
||||
"""Print test name"""
|
||||
print(f"\n{Color.BOLD}{Color.BLUE}🧪 Test: {name}{Color.END}")
|
||||
|
||||
|
||||
def print_success(message: str):
|
||||
"""Print success message"""
|
||||
print(f"{Color.GREEN}✅ {message}{Color.END}")
|
||||
|
||||
|
||||
def print_error(message: str):
|
||||
"""Print error message"""
|
||||
print(f"{Color.RED}❌ {message}{Color.END}")
|
||||
|
||||
|
||||
def print_info(message: str):
|
||||
"""Print info message"""
|
||||
print(f"{Color.YELLOW}ℹ️ {message}{Color.END}")
|
||||
|
||||
|
||||
def print_warning(message: str):
|
||||
"""Print warning message"""
|
||||
print(f"{Color.MAGENTA}⚠️ {message}{Color.END}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ADMIN AUTHENTICATION TESTS
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_admin_login() -> Optional[Dict]:
|
||||
"""Test admin login and cookie configuration"""
|
||||
print_test("Admin Login")
|
||||
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/api/v1/admin/auth/login",
|
||||
json={"username": "admin", "password": "admin123"}
|
||||
json={"username": "admin", "password": "admin123"},
|
||||
)
|
||||
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
cookies = response.cookies
|
||||
|
||||
|
||||
if "access_token" in data:
|
||||
print_success("Admin login successful")
|
||||
print_success(f"Received access token: {data['access_token'][:20]}...")
|
||||
else:
|
||||
print_error("No access token in response")
|
||||
return None
|
||||
|
||||
|
||||
if "admin_token" in cookies:
|
||||
print_success("admin_token cookie set")
|
||||
print_info("Cookie path should be /admin (verify in browser)")
|
||||
else:
|
||||
print_error("admin_token cookie NOT set")
|
||||
|
||||
return {
|
||||
"token": data["access_token"],
|
||||
"user": data.get("user", {})
|
||||
}
|
||||
|
||||
return {"token": data["access_token"], "user": data.get("user", {})}
|
||||
else:
|
||||
print_error(f"Login failed: {response.status_code}")
|
||||
print_error(f"Response: {response.text}")
|
||||
return None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print_error(f"Exception during admin login: {str(e)}")
|
||||
return None
|
||||
@@ -111,13 +119,13 @@ def test_admin_login() -> Optional[Dict]:
|
||||
def test_admin_cannot_access_vendor_api(admin_token: str):
|
||||
"""Test that admin token cannot access vendor API"""
|
||||
print_test("Admin Token on Vendor API (Should Block)")
|
||||
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{BASE_URL}/api/v1/vendor/TESTVENDOR/products",
|
||||
headers={"Authorization": f"Bearer {admin_token}"}
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
|
||||
if response.status_code in [401, 403]:
|
||||
data = response.json()
|
||||
print_success("Admin correctly blocked from vendor API")
|
||||
@@ -129,7 +137,7 @@ def test_admin_cannot_access_vendor_api(admin_token: str):
|
||||
else:
|
||||
print_warning(f"Unexpected status code: {response.status_code}")
|
||||
return False
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print_error(f"Exception: {str(e)}")
|
||||
return False
|
||||
@@ -138,13 +146,13 @@ def test_admin_cannot_access_vendor_api(admin_token: str):
|
||||
def test_admin_cannot_access_customer_api(admin_token: str):
|
||||
"""Test that admin token cannot access customer account pages"""
|
||||
print_test("Admin Token on Customer API (Should Block)")
|
||||
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{BASE_URL}/shop/account/dashboard",
|
||||
headers={"Authorization": f"Bearer {admin_token}"}
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
|
||||
# Customer pages may return HTML or JSON error
|
||||
if response.status_code in [401, 403]:
|
||||
print_success("Admin correctly blocked from customer pages")
|
||||
@@ -155,7 +163,7 @@ def test_admin_cannot_access_customer_api(admin_token: str):
|
||||
else:
|
||||
print_warning(f"Unexpected status code: {response.status_code}")
|
||||
return False
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print_error(f"Exception: {str(e)}")
|
||||
return False
|
||||
@@ -165,46 +173,47 @@ def test_admin_cannot_access_customer_api(admin_token: str):
|
||||
# VENDOR AUTHENTICATION TESTS
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_vendor_login() -> Optional[Dict]:
|
||||
"""Test vendor login and cookie configuration"""
|
||||
print_test("Vendor Login")
|
||||
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/api/v1/vendor/auth/login",
|
||||
json={"username": "vendor", "password": "vendor123"}
|
||||
json={"username": "vendor", "password": "vendor123"},
|
||||
)
|
||||
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
cookies = response.cookies
|
||||
|
||||
|
||||
if "access_token" in data:
|
||||
print_success("Vendor login successful")
|
||||
print_success(f"Received access token: {data['access_token'][:20]}...")
|
||||
else:
|
||||
print_error("No access token in response")
|
||||
return None
|
||||
|
||||
|
||||
if "vendor_token" in cookies:
|
||||
print_success("vendor_token cookie set")
|
||||
print_info("Cookie path should be /vendor (verify in browser)")
|
||||
else:
|
||||
print_error("vendor_token cookie NOT set")
|
||||
|
||||
|
||||
if "vendor" in data:
|
||||
print_success(f"Vendor: {data['vendor'].get('vendor_code', 'N/A')}")
|
||||
|
||||
|
||||
return {
|
||||
"token": data["access_token"],
|
||||
"user": data.get("user", {}),
|
||||
"vendor": data.get("vendor", {})
|
||||
"vendor": data.get("vendor", {}),
|
||||
}
|
||||
else:
|
||||
print_error(f"Login failed: {response.status_code}")
|
||||
print_error(f"Response: {response.text}")
|
||||
return None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print_error(f"Exception during vendor login: {str(e)}")
|
||||
return None
|
||||
@@ -213,13 +222,13 @@ def test_vendor_login() -> Optional[Dict]:
|
||||
def test_vendor_cannot_access_admin_api(vendor_token: str):
|
||||
"""Test that vendor token cannot access admin API"""
|
||||
print_test("Vendor Token on Admin API (Should Block)")
|
||||
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{BASE_URL}/api/v1/admin/vendors",
|
||||
headers={"Authorization": f"Bearer {vendor_token}"}
|
||||
headers={"Authorization": f"Bearer {vendor_token}"},
|
||||
)
|
||||
|
||||
|
||||
if response.status_code in [401, 403]:
|
||||
data = response.json()
|
||||
print_success("Vendor correctly blocked from admin API")
|
||||
@@ -231,7 +240,7 @@ def test_vendor_cannot_access_admin_api(vendor_token: str):
|
||||
else:
|
||||
print_warning(f"Unexpected status code: {response.status_code}")
|
||||
return False
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print_error(f"Exception: {str(e)}")
|
||||
return False
|
||||
@@ -240,13 +249,13 @@ def test_vendor_cannot_access_admin_api(vendor_token: str):
|
||||
def test_vendor_cannot_access_customer_api(vendor_token: str):
|
||||
"""Test that vendor token cannot access customer account pages"""
|
||||
print_test("Vendor Token on Customer API (Should Block)")
|
||||
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{BASE_URL}/shop/account/dashboard",
|
||||
headers={"Authorization": f"Bearer {vendor_token}"}
|
||||
headers={"Authorization": f"Bearer {vendor_token}"},
|
||||
)
|
||||
|
||||
|
||||
if response.status_code in [401, 403]:
|
||||
print_success("Vendor correctly blocked from customer pages")
|
||||
return True
|
||||
@@ -256,7 +265,7 @@ def test_vendor_cannot_access_customer_api(vendor_token: str):
|
||||
else:
|
||||
print_warning(f"Unexpected status code: {response.status_code}")
|
||||
return False
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print_error(f"Exception: {str(e)}")
|
||||
return False
|
||||
@@ -266,42 +275,40 @@ def test_vendor_cannot_access_customer_api(vendor_token: str):
|
||||
# CUSTOMER AUTHENTICATION TESTS
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_customer_login() -> Optional[Dict]:
|
||||
"""Test customer login and cookie configuration"""
|
||||
print_test("Customer Login")
|
||||
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/api/v1/public/vendors/1/customers/login",
|
||||
json={"username": "customer", "password": "customer123"}
|
||||
json={"username": "customer", "password": "customer123"},
|
||||
)
|
||||
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
cookies = response.cookies
|
||||
|
||||
|
||||
if "access_token" in data:
|
||||
print_success("Customer login successful")
|
||||
print_success(f"Received access token: {data['access_token'][:20]}...")
|
||||
else:
|
||||
print_error("No access token in response")
|
||||
return None
|
||||
|
||||
|
||||
if "customer_token" in cookies:
|
||||
print_success("customer_token cookie set")
|
||||
print_info("Cookie path should be /shop (verify in browser)")
|
||||
else:
|
||||
print_error("customer_token cookie NOT set")
|
||||
|
||||
return {
|
||||
"token": data["access_token"],
|
||||
"user": data.get("user", {})
|
||||
}
|
||||
|
||||
return {"token": data["access_token"], "user": data.get("user", {})}
|
||||
else:
|
||||
print_error(f"Login failed: {response.status_code}")
|
||||
print_error(f"Response: {response.text}")
|
||||
return None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print_error(f"Exception during customer login: {str(e)}")
|
||||
return None
|
||||
@@ -310,13 +317,13 @@ def test_customer_login() -> Optional[Dict]:
|
||||
def test_customer_cannot_access_admin_api(customer_token: str):
|
||||
"""Test that customer token cannot access admin API"""
|
||||
print_test("Customer Token on Admin API (Should Block)")
|
||||
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{BASE_URL}/api/v1/admin/vendors",
|
||||
headers={"Authorization": f"Bearer {customer_token}"}
|
||||
headers={"Authorization": f"Bearer {customer_token}"},
|
||||
)
|
||||
|
||||
|
||||
if response.status_code in [401, 403]:
|
||||
data = response.json()
|
||||
print_success("Customer correctly blocked from admin API")
|
||||
@@ -328,7 +335,7 @@ def test_customer_cannot_access_admin_api(customer_token: str):
|
||||
else:
|
||||
print_warning(f"Unexpected status code: {response.status_code}")
|
||||
return False
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print_error(f"Exception: {str(e)}")
|
||||
return False
|
||||
@@ -337,13 +344,13 @@ def test_customer_cannot_access_admin_api(customer_token: str):
|
||||
def test_customer_cannot_access_vendor_api(customer_token: str):
|
||||
"""Test that customer token cannot access vendor API"""
|
||||
print_test("Customer Token on Vendor API (Should Block)")
|
||||
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{BASE_URL}/api/v1/vendor/TESTVENDOR/products",
|
||||
headers={"Authorization": f"Bearer {customer_token}"}
|
||||
headers={"Authorization": f"Bearer {customer_token}"},
|
||||
)
|
||||
|
||||
|
||||
if response.status_code in [401, 403]:
|
||||
data = response.json()
|
||||
print_success("Customer correctly blocked from vendor API")
|
||||
@@ -355,7 +362,7 @@ def test_customer_cannot_access_vendor_api(customer_token: str):
|
||||
else:
|
||||
print_warning(f"Unexpected status code: {response.status_code}")
|
||||
return False
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print_error(f"Exception: {str(e)}")
|
||||
return False
|
||||
@@ -364,17 +371,17 @@ def test_customer_cannot_access_vendor_api(customer_token: str):
|
||||
def test_public_shop_access():
|
||||
"""Test that public shop pages are accessible without authentication"""
|
||||
print_test("Public Shop Access (No Auth Required)")
|
||||
|
||||
|
||||
try:
|
||||
response = requests.get(f"{BASE_URL}/shop/products")
|
||||
|
||||
|
||||
if response.status_code == 200:
|
||||
print_success("Public shop pages accessible without auth")
|
||||
return True
|
||||
else:
|
||||
print_error(f"Failed to access public shop: {response.status_code}")
|
||||
return False
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print_error(f"Exception: {str(e)}")
|
||||
return False
|
||||
@@ -383,10 +390,10 @@ def test_public_shop_access():
|
||||
def test_health_check():
|
||||
"""Test health check endpoint"""
|
||||
print_test("Health Check")
|
||||
|
||||
|
||||
try:
|
||||
response = requests.get(f"{BASE_URL}/health")
|
||||
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
print_success("Health check passed")
|
||||
@@ -395,7 +402,7 @@ def test_health_check():
|
||||
else:
|
||||
print_error(f"Health check failed: {response.status_code}")
|
||||
return False
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print_error(f"Exception: {str(e)}")
|
||||
return False
|
||||
@@ -405,19 +412,16 @@ def test_health_check():
|
||||
# MAIN TEST RUNNER
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all tests"""
|
||||
print(f"\n{Color.BOLD}{Color.CYAN}{'═' * 60}")
|
||||
print(f" 🔒 COMPLETE AUTHENTICATION SYSTEM TEST SUITE")
|
||||
print(f"{'═' * 60}{Color.END}")
|
||||
print(f"Testing server at: {BASE_URL}")
|
||||
|
||||
results = {
|
||||
"passed": 0,
|
||||
"failed": 0,
|
||||
"total": 0
|
||||
}
|
||||
|
||||
|
||||
results = {"passed": 0, "failed": 0, "total": 0}
|
||||
|
||||
# Health check first
|
||||
print_section("🏥 System Health")
|
||||
results["total"] += 1
|
||||
@@ -427,12 +431,12 @@ def main():
|
||||
results["failed"] += 1
|
||||
print_error("Server not responding. Is it running?")
|
||||
return
|
||||
|
||||
|
||||
# ========================================================================
|
||||
# ADMIN TESTS
|
||||
# ========================================================================
|
||||
print_section("👤 Admin Authentication Tests")
|
||||
|
||||
|
||||
# Admin login
|
||||
results["total"] += 1
|
||||
admin_auth = test_admin_login()
|
||||
@@ -441,7 +445,7 @@ def main():
|
||||
else:
|
||||
results["failed"] += 1
|
||||
admin_auth = None
|
||||
|
||||
|
||||
# Admin cross-context tests
|
||||
if admin_auth:
|
||||
results["total"] += 1
|
||||
@@ -449,18 +453,18 @@ def main():
|
||||
results["passed"] += 1
|
||||
else:
|
||||
results["failed"] += 1
|
||||
|
||||
|
||||
results["total"] += 1
|
||||
if test_admin_cannot_access_customer_api(admin_auth["token"]):
|
||||
results["passed"] += 1
|
||||
else:
|
||||
results["failed"] += 1
|
||||
|
||||
|
||||
# ========================================================================
|
||||
# VENDOR TESTS
|
||||
# ========================================================================
|
||||
print_section("🏪 Vendor Authentication Tests")
|
||||
|
||||
|
||||
# Vendor login
|
||||
results["total"] += 1
|
||||
vendor_auth = test_vendor_login()
|
||||
@@ -469,7 +473,7 @@ def main():
|
||||
else:
|
||||
results["failed"] += 1
|
||||
vendor_auth = None
|
||||
|
||||
|
||||
# Vendor cross-context tests
|
||||
if vendor_auth:
|
||||
results["total"] += 1
|
||||
@@ -477,25 +481,25 @@ def main():
|
||||
results["passed"] += 1
|
||||
else:
|
||||
results["failed"] += 1
|
||||
|
||||
|
||||
results["total"] += 1
|
||||
if test_vendor_cannot_access_customer_api(vendor_auth["token"]):
|
||||
results["passed"] += 1
|
||||
else:
|
||||
results["failed"] += 1
|
||||
|
||||
|
||||
# ========================================================================
|
||||
# CUSTOMER TESTS
|
||||
# ========================================================================
|
||||
print_section("🛒 Customer Authentication Tests")
|
||||
|
||||
|
||||
# Public shop access
|
||||
results["total"] += 1
|
||||
if test_public_shop_access():
|
||||
results["passed"] += 1
|
||||
else:
|
||||
results["failed"] += 1
|
||||
|
||||
|
||||
# Customer login
|
||||
results["total"] += 1
|
||||
customer_auth = test_customer_login()
|
||||
@@ -504,7 +508,7 @@ def main():
|
||||
else:
|
||||
results["failed"] += 1
|
||||
customer_auth = None
|
||||
|
||||
|
||||
# Customer cross-context tests
|
||||
if customer_auth:
|
||||
results["total"] += 1
|
||||
@@ -512,29 +516,31 @@ def main():
|
||||
results["passed"] += 1
|
||||
else:
|
||||
results["failed"] += 1
|
||||
|
||||
|
||||
results["total"] += 1
|
||||
if test_customer_cannot_access_vendor_api(customer_auth["token"]):
|
||||
results["passed"] += 1
|
||||
else:
|
||||
results["failed"] += 1
|
||||
|
||||
|
||||
# ========================================================================
|
||||
# RESULTS
|
||||
# ========================================================================
|
||||
print_section("📊 Test Results")
|
||||
print(f"\n{Color.BOLD}Total Tests: {results['total']}{Color.END}")
|
||||
print_success(f"Passed: {results['passed']}")
|
||||
if results['failed'] > 0:
|
||||
if results["failed"] > 0:
|
||||
print_error(f"Failed: {results['failed']}")
|
||||
|
||||
if results['failed'] == 0:
|
||||
|
||||
if results["failed"] == 0:
|
||||
print(f"\n{Color.GREEN}{Color.BOLD}🎉 ALL TESTS PASSED!{Color.END}")
|
||||
print(f"{Color.GREEN}Your authentication system is properly isolated!{Color.END}")
|
||||
print(
|
||||
f"{Color.GREEN}Your authentication system is properly isolated!{Color.END}"
|
||||
)
|
||||
else:
|
||||
print(f"\n{Color.RED}{Color.BOLD}⚠️ SOME TESTS FAILED{Color.END}")
|
||||
print(f"{Color.RED}Please review the output above.{Color.END}")
|
||||
|
||||
|
||||
# Browser tests reminder
|
||||
print_section("🌐 Manual Browser Tests")
|
||||
print("Please also verify in browser:")
|
||||
|
||||
@@ -9,10 +9,11 @@ Tests:
|
||||
- Transfer ownership
|
||||
"""
|
||||
|
||||
import requests
|
||||
import json
|
||||
from pprint import pprint
|
||||
|
||||
import requests
|
||||
|
||||
BASE_URL = "http://localhost:8000"
|
||||
ADMIN_TOKEN = None # Will be set after login
|
||||
|
||||
@@ -25,8 +26,8 @@ def login_admin():
|
||||
f"{BASE_URL}/api/v1/admin/auth/login", # ✅ Changed: added /admin/
|
||||
json={
|
||||
"username": "admin", # Replace with your admin username
|
||||
"password": "admin123" # Replace with your admin password
|
||||
}
|
||||
"password": "admin123", # Replace with your admin password
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
@@ -44,7 +45,7 @@ def get_headers():
|
||||
"""Get authorization headers."""
|
||||
return {
|
||||
"Authorization": f"Bearer {ADMIN_TOKEN}",
|
||||
"Content-Type": "application/json"
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
|
||||
@@ -60,13 +61,11 @@ def test_create_vendor_with_both_emails():
|
||||
"subdomain": "testdual",
|
||||
"owner_email": "owner@testdual.com",
|
||||
"contact_email": "contact@testdual.com",
|
||||
"description": "Test vendor with separate emails"
|
||||
"description": "Test vendor with separate emails",
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/api/v1/admin/vendors",
|
||||
headers=get_headers(),
|
||||
json=vendor_data
|
||||
f"{BASE_URL}/api/v1/admin/vendors", headers=get_headers(), json=vendor_data
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
@@ -79,7 +78,7 @@ def test_create_vendor_with_both_emails():
|
||||
print(f" Username: {data['owner_username']}")
|
||||
print(f" Password: {data['temporary_password']}")
|
||||
print(f"\n🔗 Login URL: {data['login_url']}")
|
||||
return data['id']
|
||||
return data["id"]
|
||||
else:
|
||||
print(f"❌ Failed: {response.status_code}")
|
||||
print(response.text)
|
||||
@@ -97,13 +96,11 @@ def test_create_vendor_single_email():
|
||||
"name": "Test Single Email Vendor",
|
||||
"subdomain": "testsingle",
|
||||
"owner_email": "owner@testsingle.com",
|
||||
"description": "Test vendor with single email"
|
||||
"description": "Test vendor with single email",
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/api/v1/admin/vendors",
|
||||
headers=get_headers(),
|
||||
json=vendor_data
|
||||
f"{BASE_URL}/api/v1/admin/vendors", headers=get_headers(), json=vendor_data
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
@@ -113,12 +110,12 @@ def test_create_vendor_single_email():
|
||||
print(f" Owner Email: {data['owner_email']}")
|
||||
print(f" Contact Email: {data['contact_email']}")
|
||||
|
||||
if data['owner_email'] == data['contact_email']:
|
||||
if data["owner_email"] == data["contact_email"]:
|
||||
print("✅ Contact email correctly defaulted to owner email")
|
||||
else:
|
||||
print("❌ Contact email should have defaulted to owner email")
|
||||
|
||||
return data['id']
|
||||
return data["id"]
|
||||
else:
|
||||
print(f"❌ Failed: {response.status_code}")
|
||||
print(response.text)
|
||||
@@ -133,13 +130,13 @@ def test_update_vendor_contact_email(vendor_id):
|
||||
|
||||
update_data = {
|
||||
"contact_email": "newcontact@business.com",
|
||||
"name": "Updated Vendor Name"
|
||||
"name": "Updated Vendor Name",
|
||||
}
|
||||
|
||||
response = requests.put(
|
||||
f"{BASE_URL}/api/v1/admin/vendors/{vendor_id}",
|
||||
headers=get_headers(),
|
||||
json=update_data
|
||||
json=update_data,
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
@@ -163,8 +160,7 @@ def test_get_vendor_details(vendor_id):
|
||||
print("=" * 60)
|
||||
|
||||
response = requests.get(
|
||||
f"{BASE_URL}/api/v1/admin/vendors/{vendor_id}",
|
||||
headers=get_headers()
|
||||
f"{BASE_URL}/api/v1/admin/vendors/{vendor_id}", headers=get_headers()
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
@@ -196,8 +192,7 @@ def test_transfer_ownership(vendor_id, new_owner_user_id):
|
||||
|
||||
# First, get current owner info
|
||||
response = requests.get(
|
||||
f"{BASE_URL}/api/v1/admin/vendors/{vendor_id}",
|
||||
headers=get_headers()
|
||||
f"{BASE_URL}/api/v1/admin/vendors/{vendor_id}", headers=get_headers()
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
@@ -211,13 +206,13 @@ def test_transfer_ownership(vendor_id, new_owner_user_id):
|
||||
transfer_data = {
|
||||
"new_owner_user_id": new_owner_user_id,
|
||||
"confirm_transfer": True,
|
||||
"transfer_reason": "Testing ownership transfer feature"
|
||||
"transfer_reason": "Testing ownership transfer feature",
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/api/v1/admin/vendors/{vendor_id}/transfer-ownership",
|
||||
headers=get_headers(),
|
||||
json=transfer_data
|
||||
json=transfer_data,
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
|
||||
@@ -22,15 +22,17 @@ import argparse
|
||||
import ast
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
class Severity(Enum):
|
||||
"""Validation severity levels"""
|
||||
|
||||
ERROR = "error"
|
||||
WARNING = "warning"
|
||||
INFO = "info"
|
||||
@@ -39,6 +41,7 @@ class Severity(Enum):
|
||||
@dataclass
|
||||
class Violation:
|
||||
"""Represents an architectural rule violation"""
|
||||
|
||||
rule_id: str
|
||||
rule_name: str
|
||||
severity: Severity
|
||||
@@ -52,6 +55,7 @@ class Violation:
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
"""Results of architecture validation"""
|
||||
|
||||
violations: List[Violation] = field(default_factory=list)
|
||||
files_checked: int = 0
|
||||
rules_applied: int = 0
|
||||
@@ -82,7 +86,7 @@ class ArchitectureValidator:
|
||||
print(f"❌ Configuration file not found: {self.config_path}")
|
||||
sys.exit(1)
|
||||
|
||||
with open(self.config_path, 'r') as f:
|
||||
with open(self.config_path, "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
print(f"📋 Loaded architecture rules: {config.get('project', 'unknown')}")
|
||||
@@ -126,7 +130,7 @@ class ArchitectureValidator:
|
||||
continue
|
||||
|
||||
content = file_path.read_text()
|
||||
lines = content.split('\n')
|
||||
lines = content.split("\n")
|
||||
|
||||
# API-001: Check for Pydantic model usage
|
||||
self._check_pydantic_usage(file_path, content, lines)
|
||||
@@ -147,8 +151,8 @@ class ArchitectureValidator:
|
||||
return
|
||||
|
||||
# Check for response_model in route decorators
|
||||
route_pattern = r'@router\.(get|post|put|delete|patch)'
|
||||
dict_return_pattern = r'return\s+\{.*\}'
|
||||
route_pattern = r"@router\.(get|post|put|delete|patch)"
|
||||
dict_return_pattern = r"return\s+\{.*\}"
|
||||
|
||||
for i, line in enumerate(lines, 1):
|
||||
# Check for dict returns in endpoints
|
||||
@@ -166,47 +170,51 @@ class ArchitectureValidator:
|
||||
if re.search(dict_return_pattern, func_line):
|
||||
self._add_violation(
|
||||
rule_id="API-001",
|
||||
rule_name=rule['name'],
|
||||
rule_name=rule["name"],
|
||||
severity=Severity.ERROR,
|
||||
file_path=file_path,
|
||||
line_number=j + 1,
|
||||
message="Endpoint returns raw dict instead of Pydantic model",
|
||||
context=func_line.strip(),
|
||||
suggestion="Define a Pydantic response model and use response_model parameter"
|
||||
suggestion="Define a Pydantic response model and use response_model parameter",
|
||||
)
|
||||
|
||||
def _check_no_business_logic_in_endpoints(self, file_path: Path, content: str, lines: List[str]):
|
||||
def _check_no_business_logic_in_endpoints(
|
||||
self, file_path: Path, content: str, lines: List[str]
|
||||
):
|
||||
"""API-002: Ensure no business logic in endpoints"""
|
||||
rule = self._get_rule("API-002")
|
||||
if not rule:
|
||||
return
|
||||
|
||||
anti_patterns = [
|
||||
(r'db\.add\(', "Database operations should be in service layer"),
|
||||
(r'db\.commit\(\)', "Database commits should be in service layer"),
|
||||
(r'db\.query\(', "Database queries should be in service layer"),
|
||||
(r'db\.execute\(', "Database operations should be in service layer"),
|
||||
(r"db\.add\(", "Database operations should be in service layer"),
|
||||
(r"db\.commit\(\)", "Database commits should be in service layer"),
|
||||
(r"db\.query\(", "Database queries should be in service layer"),
|
||||
(r"db\.execute\(", "Database operations should be in service layer"),
|
||||
]
|
||||
|
||||
for i, line in enumerate(lines, 1):
|
||||
# Skip service method calls (allowed)
|
||||
if '_service.' in line or 'service.' in line:
|
||||
if "_service." in line or "service." in line:
|
||||
continue
|
||||
|
||||
for pattern, message in anti_patterns:
|
||||
if re.search(pattern, line):
|
||||
self._add_violation(
|
||||
rule_id="API-002",
|
||||
rule_name=rule['name'],
|
||||
rule_name=rule["name"],
|
||||
severity=Severity.ERROR,
|
||||
file_path=file_path,
|
||||
line_number=i,
|
||||
message=message,
|
||||
context=line.strip(),
|
||||
suggestion="Move database operations to service layer"
|
||||
suggestion="Move database operations to service layer",
|
||||
)
|
||||
|
||||
def _check_endpoint_exception_handling(self, file_path: Path, content: str, lines: List[str]):
|
||||
def _check_endpoint_exception_handling(
|
||||
self, file_path: Path, content: str, lines: List[str]
|
||||
):
|
||||
"""API-003: Check proper exception handling in endpoints"""
|
||||
rule = self._get_rule("API-003")
|
||||
if not rule:
|
||||
@@ -222,40 +230,41 @@ class ArchitectureValidator:
|
||||
if isinstance(node, ast.FunctionDef):
|
||||
# Check if it's a route handler
|
||||
has_router_decorator = any(
|
||||
isinstance(d, ast.Call) and
|
||||
isinstance(d.func, ast.Attribute) and
|
||||
getattr(d.func.value, 'id', None) == 'router'
|
||||
isinstance(d, ast.Call)
|
||||
and isinstance(d.func, ast.Attribute)
|
||||
and getattr(d.func.value, "id", None) == "router"
|
||||
for d in node.decorator_list
|
||||
)
|
||||
|
||||
if has_router_decorator:
|
||||
# Check if function body has try/except
|
||||
has_try_except = any(
|
||||
isinstance(child, ast.Try)
|
||||
for child in ast.walk(node)
|
||||
isinstance(child, ast.Try) for child in ast.walk(node)
|
||||
)
|
||||
|
||||
# Check if function calls service methods
|
||||
has_service_call = any(
|
||||
isinstance(child, ast.Call) and
|
||||
isinstance(child.func, ast.Attribute) and
|
||||
'service' in getattr(child.func.value, 'id', '').lower()
|
||||
isinstance(child, ast.Call)
|
||||
and isinstance(child.func, ast.Attribute)
|
||||
and "service" in getattr(child.func.value, "id", "").lower()
|
||||
for child in ast.walk(node)
|
||||
)
|
||||
|
||||
if has_service_call and not has_try_except:
|
||||
self._add_violation(
|
||||
rule_id="API-003",
|
||||
rule_name=rule['name'],
|
||||
rule_name=rule["name"],
|
||||
severity=Severity.WARNING,
|
||||
file_path=file_path,
|
||||
line_number=node.lineno,
|
||||
message=f"Endpoint '{node.name}' calls service but lacks exception handling",
|
||||
context=f"def {node.name}(...)",
|
||||
suggestion="Wrap service calls in try/except and convert to HTTPException"
|
||||
suggestion="Wrap service calls in try/except and convert to HTTPException",
|
||||
)
|
||||
|
||||
def _check_endpoint_authentication(self, file_path: Path, content: str, lines: List[str]):
|
||||
def _check_endpoint_authentication(
|
||||
self, file_path: Path, content: str, lines: List[str]
|
||||
):
|
||||
"""API-004: Check authentication on endpoints"""
|
||||
rule = self._get_rule("API-004")
|
||||
if not rule:
|
||||
@@ -264,24 +273,28 @@ class ArchitectureValidator:
|
||||
# This is a warning-level check
|
||||
# Look for endpoints without Depends(get_current_*)
|
||||
for i, line in enumerate(lines, 1):
|
||||
if '@router.' in line and ('post' in line or 'put' in line or 'delete' in line):
|
||||
if "@router." in line and (
|
||||
"post" in line or "put" in line or "delete" in line
|
||||
):
|
||||
# Check next 5 lines for auth
|
||||
has_auth = False
|
||||
for j in range(i, min(i + 5, len(lines))):
|
||||
if 'Depends(get_current_' in lines[j]:
|
||||
if "Depends(get_current_" in lines[j]:
|
||||
has_auth = True
|
||||
break
|
||||
|
||||
if not has_auth and 'include_in_schema=False' not in ' '.join(lines[i:i+5]):
|
||||
if not has_auth and "include_in_schema=False" not in " ".join(
|
||||
lines[i : i + 5]
|
||||
):
|
||||
self._add_violation(
|
||||
rule_id="API-004",
|
||||
rule_name=rule['name'],
|
||||
rule_name=rule["name"],
|
||||
severity=Severity.WARNING,
|
||||
file_path=file_path,
|
||||
line_number=i,
|
||||
message="Endpoint may be missing authentication",
|
||||
context=line.strip(),
|
||||
suggestion="Add Depends(get_current_user) or similar if endpoint should be protected"
|
||||
suggestion="Add Depends(get_current_user) or similar if endpoint should be protected",
|
||||
)
|
||||
|
||||
def _validate_service_layer(self, target_path: Path):
|
||||
@@ -296,7 +309,7 @@ class ArchitectureValidator:
|
||||
continue
|
||||
|
||||
content = file_path.read_text()
|
||||
lines = content.split('\n')
|
||||
lines = content.split("\n")
|
||||
|
||||
# SVC-001: No HTTPException in services
|
||||
self._check_no_http_exception_in_services(file_path, content, lines)
|
||||
@@ -307,38 +320,45 @@ class ArchitectureValidator:
|
||||
# SVC-003: DB session as parameter
|
||||
self._check_db_session_parameter(file_path, content, lines)
|
||||
|
||||
def _check_no_http_exception_in_services(self, file_path: Path, content: str, lines: List[str]):
|
||||
def _check_no_http_exception_in_services(
|
||||
self, file_path: Path, content: str, lines: List[str]
|
||||
):
|
||||
"""SVC-001: Services must not raise HTTPException"""
|
||||
rule = self._get_rule("SVC-001")
|
||||
if not rule:
|
||||
return
|
||||
|
||||
for i, line in enumerate(lines, 1):
|
||||
if 'raise HTTPException' in line:
|
||||
if "raise HTTPException" in line:
|
||||
self._add_violation(
|
||||
rule_id="SVC-001",
|
||||
rule_name=rule['name'],
|
||||
rule_name=rule["name"],
|
||||
severity=Severity.ERROR,
|
||||
file_path=file_path,
|
||||
line_number=i,
|
||||
message="Service raises HTTPException - use domain exceptions instead",
|
||||
context=line.strip(),
|
||||
suggestion="Create custom exception class (e.g., VendorNotFoundError) and raise that"
|
||||
suggestion="Create custom exception class (e.g., VendorNotFoundError) and raise that",
|
||||
)
|
||||
|
||||
if 'from fastapi import HTTPException' in line or 'from fastapi.exceptions import HTTPException' in line:
|
||||
if (
|
||||
"from fastapi import HTTPException" in line
|
||||
or "from fastapi.exceptions import HTTPException" in line
|
||||
):
|
||||
self._add_violation(
|
||||
rule_id="SVC-001",
|
||||
rule_name=rule['name'],
|
||||
rule_name=rule["name"],
|
||||
severity=Severity.ERROR,
|
||||
file_path=file_path,
|
||||
line_number=i,
|
||||
message="Service imports HTTPException - services should not know about HTTP",
|
||||
context=line.strip(),
|
||||
suggestion="Remove HTTPException import and use domain exceptions"
|
||||
suggestion="Remove HTTPException import and use domain exceptions",
|
||||
)
|
||||
|
||||
def _check_service_exceptions(self, file_path: Path, content: str, lines: List[str]):
|
||||
def _check_service_exceptions(
|
||||
self, file_path: Path, content: str, lines: List[str]
|
||||
):
|
||||
"""SVC-002: Check for proper exception handling"""
|
||||
rule = self._get_rule("SVC-002")
|
||||
if not rule:
|
||||
@@ -346,19 +366,21 @@ class ArchitectureValidator:
|
||||
|
||||
for i, line in enumerate(lines, 1):
|
||||
# Check for generic Exception raises
|
||||
if re.match(r'\s*raise Exception\(', line):
|
||||
if re.match(r"\s*raise Exception\(", line):
|
||||
self._add_violation(
|
||||
rule_id="SVC-002",
|
||||
rule_name=rule['name'],
|
||||
rule_name=rule["name"],
|
||||
severity=Severity.WARNING,
|
||||
file_path=file_path,
|
||||
line_number=i,
|
||||
message="Service raises generic Exception - use specific domain exception",
|
||||
context=line.strip(),
|
||||
suggestion="Create custom exception class for this error case"
|
||||
suggestion="Create custom exception class for this error case",
|
||||
)
|
||||
|
||||
def _check_db_session_parameter(self, file_path: Path, content: str, lines: List[str]):
|
||||
def _check_db_session_parameter(
|
||||
self, file_path: Path, content: str, lines: List[str]
|
||||
):
|
||||
"""SVC-003: Service methods should accept db session as parameter"""
|
||||
rule = self._get_rule("SVC-003")
|
||||
if not rule:
|
||||
@@ -366,16 +388,16 @@ class ArchitectureValidator:
|
||||
|
||||
# Check for SessionLocal() creation in service files
|
||||
for i, line in enumerate(lines, 1):
|
||||
if 'SessionLocal()' in line and 'class' not in line:
|
||||
if "SessionLocal()" in line and "class" not in line:
|
||||
self._add_violation(
|
||||
rule_id="SVC-003",
|
||||
rule_name=rule['name'],
|
||||
rule_name=rule["name"],
|
||||
severity=Severity.ERROR,
|
||||
file_path=file_path,
|
||||
line_number=i,
|
||||
message="Service creates database session internally",
|
||||
context=line.strip(),
|
||||
suggestion="Accept db: Session as method parameter instead"
|
||||
suggestion="Accept db: Session as method parameter instead",
|
||||
)
|
||||
|
||||
def _validate_models(self, target_path: Path):
|
||||
@@ -391,11 +413,11 @@ class ArchitectureValidator:
|
||||
continue
|
||||
|
||||
content = file_path.read_text()
|
||||
lines = content.split('\n')
|
||||
lines = content.split("\n")
|
||||
|
||||
# Check for mixing SQLAlchemy and Pydantic
|
||||
for i, line in enumerate(lines, 1):
|
||||
if re.search(r'class.*\(Base.*,.*BaseModel.*\)', line):
|
||||
if re.search(r"class.*\(Base.*,.*BaseModel.*\)", line):
|
||||
self._add_violation(
|
||||
rule_id="MDL-002",
|
||||
rule_name="Separate SQLAlchemy and Pydantic models",
|
||||
@@ -404,7 +426,7 @@ class ArchitectureValidator:
|
||||
line_number=i,
|
||||
message="Model mixes SQLAlchemy Base and Pydantic BaseModel",
|
||||
context=line.strip(),
|
||||
suggestion="Keep SQLAlchemy models and Pydantic models separate"
|
||||
suggestion="Keep SQLAlchemy models and Pydantic models separate",
|
||||
)
|
||||
|
||||
def _validate_exceptions(self, target_path: Path):
|
||||
@@ -418,11 +440,11 @@ class ArchitectureValidator:
|
||||
continue
|
||||
|
||||
content = file_path.read_text()
|
||||
lines = content.split('\n')
|
||||
lines = content.split("\n")
|
||||
|
||||
# EXC-002: Check for bare except
|
||||
for i, line in enumerate(lines, 1):
|
||||
if re.match(r'\s*except\s*:', line):
|
||||
if re.match(r"\s*except\s*:", line):
|
||||
self._add_violation(
|
||||
rule_id="EXC-002",
|
||||
rule_name="No bare except clauses",
|
||||
@@ -431,7 +453,7 @@ class ArchitectureValidator:
|
||||
line_number=i,
|
||||
message="Bare except clause catches all exceptions including system exits",
|
||||
context=line.strip(),
|
||||
suggestion="Specify exception type: except ValueError: or except Exception:"
|
||||
suggestion="Specify exception type: except ValueError: or except Exception:",
|
||||
)
|
||||
|
||||
def _validate_javascript(self, target_path: Path):
|
||||
@@ -443,11 +465,16 @@ class ArchitectureValidator:
|
||||
|
||||
for file_path in js_files:
|
||||
content = file_path.read_text()
|
||||
lines = content.split('\n')
|
||||
lines = content.split("\n")
|
||||
|
||||
# JS-001: Check for window.apiClient
|
||||
for i, line in enumerate(lines, 1):
|
||||
if 'window.apiClient' in line and '//' not in line[:line.find('window.apiClient')] if 'window.apiClient' in line else True:
|
||||
if (
|
||||
"window.apiClient" in line
|
||||
and "//" not in line[: line.find("window.apiClient")]
|
||||
if "window.apiClient" in line
|
||||
else True
|
||||
):
|
||||
self._add_violation(
|
||||
rule_id="JS-001",
|
||||
rule_name="Use apiClient directly",
|
||||
@@ -456,14 +483,14 @@ class ArchitectureValidator:
|
||||
line_number=i,
|
||||
message="Use apiClient directly instead of window.apiClient",
|
||||
context=line.strip(),
|
||||
suggestion="Replace window.apiClient with apiClient"
|
||||
suggestion="Replace window.apiClient with apiClient",
|
||||
)
|
||||
|
||||
# JS-002: Check for console usage
|
||||
for i, line in enumerate(lines, 1):
|
||||
if re.search(r'console\.(log|warn|error)', line):
|
||||
if re.search(r"console\.(log|warn|error)", line):
|
||||
# Skip if it's a comment or bootstrap message
|
||||
if '//' in line or '✅' in line or 'eslint-disable' in line:
|
||||
if "//" in line or "✅" in line or "eslint-disable" in line:
|
||||
continue
|
||||
|
||||
self._add_violation(
|
||||
@@ -474,7 +501,7 @@ class ArchitectureValidator:
|
||||
line_number=i,
|
||||
message="Use centralized logger instead of console",
|
||||
context=line.strip()[:80],
|
||||
suggestion="Use window.LogConfig.createLogger('moduleName')"
|
||||
suggestion="Use window.LogConfig.createLogger('moduleName')",
|
||||
)
|
||||
|
||||
def _validate_templates(self, target_path: Path):
|
||||
@@ -486,14 +513,16 @@ class ArchitectureValidator:
|
||||
|
||||
for file_path in template_files:
|
||||
# Skip base template and partials
|
||||
if 'base.html' in file_path.name or 'partials' in str(file_path):
|
||||
if "base.html" in file_path.name or "partials" in str(file_path):
|
||||
continue
|
||||
|
||||
content = file_path.read_text()
|
||||
lines = content.split('\n')
|
||||
lines = content.split("\n")
|
||||
|
||||
# TPL-001: Check for extends
|
||||
has_extends = any('{% extends' in line and 'admin/base.html' in line for line in lines)
|
||||
has_extends = any(
|
||||
"{% extends" in line and "admin/base.html" in line for line in lines
|
||||
)
|
||||
|
||||
if not has_extends:
|
||||
self._add_violation(
|
||||
@@ -504,23 +533,29 @@ class ArchitectureValidator:
|
||||
line_number=1,
|
||||
message="Admin template does not extend admin/base.html",
|
||||
context=file_path.name,
|
||||
suggestion="Add {% extends 'admin/base.html' %} at the top"
|
||||
suggestion="Add {% extends 'admin/base.html' %} at the top",
|
||||
)
|
||||
|
||||
def _get_rule(self, rule_id: str) -> Dict[str, Any]:
|
||||
"""Get rule configuration by ID"""
|
||||
# Look in different rule categories
|
||||
for category in ['api_endpoint_rules', 'service_layer_rules', 'model_rules',
|
||||
'exception_rules', 'javascript_rules', 'template_rules']:
|
||||
for category in [
|
||||
"api_endpoint_rules",
|
||||
"service_layer_rules",
|
||||
"model_rules",
|
||||
"exception_rules",
|
||||
"javascript_rules",
|
||||
"template_rules",
|
||||
]:
|
||||
rules = self.config.get(category, [])
|
||||
for rule in rules:
|
||||
if rule.get('id') == rule_id:
|
||||
if rule.get("id") == rule_id:
|
||||
return rule
|
||||
return None
|
||||
|
||||
def _should_ignore_file(self, file_path: Path) -> bool:
|
||||
"""Check if file should be ignored"""
|
||||
ignore_patterns = self.config.get('ignore', {}).get('files', [])
|
||||
ignore_patterns = self.config.get("ignore", {}).get("files", [])
|
||||
|
||||
for pattern in ignore_patterns:
|
||||
if file_path.match(pattern):
|
||||
@@ -528,9 +563,17 @@ class ArchitectureValidator:
|
||||
|
||||
return False
|
||||
|
||||
def _add_violation(self, rule_id: str, rule_name: str, severity: Severity,
|
||||
file_path: Path, line_number: int, message: str,
|
||||
context: str = "", suggestion: str = ""):
|
||||
def _add_violation(
|
||||
self,
|
||||
rule_id: str,
|
||||
rule_name: str,
|
||||
severity: Severity,
|
||||
file_path: Path,
|
||||
line_number: int,
|
||||
message: str,
|
||||
context: str = "",
|
||||
suggestion: str = "",
|
||||
):
|
||||
"""Add a violation to results"""
|
||||
violation = Violation(
|
||||
rule_id=rule_id,
|
||||
@@ -540,7 +583,7 @@ class ArchitectureValidator:
|
||||
line_number=line_number,
|
||||
message=message,
|
||||
context=context,
|
||||
suggestion=suggestion
|
||||
suggestion=suggestion,
|
||||
)
|
||||
self.result.violations.append(violation)
|
||||
|
||||
@@ -590,24 +633,34 @@ class ArchitectureValidator:
|
||||
|
||||
violations_json = []
|
||||
for v in self.result.violations:
|
||||
rel_path = str(v.file_path.relative_to(self.project_root)) if self.project_root in v.file_path.parents else str(v.file_path)
|
||||
violations_json.append({
|
||||
'rule_id': v.rule_id,
|
||||
'rule_name': v.rule_name,
|
||||
'severity': v.severity.value,
|
||||
'file_path': rel_path,
|
||||
'line_number': v.line_number,
|
||||
'message': v.message,
|
||||
'context': v.context or '',
|
||||
'suggestion': v.suggestion or ''
|
||||
})
|
||||
rel_path = (
|
||||
str(v.file_path.relative_to(self.project_root))
|
||||
if self.project_root in v.file_path.parents
|
||||
else str(v.file_path)
|
||||
)
|
||||
violations_json.append(
|
||||
{
|
||||
"rule_id": v.rule_id,
|
||||
"rule_name": v.rule_name,
|
||||
"severity": v.severity.value,
|
||||
"file_path": rel_path,
|
||||
"line_number": v.line_number,
|
||||
"message": v.message,
|
||||
"context": v.context or "",
|
||||
"suggestion": v.suggestion or "",
|
||||
}
|
||||
)
|
||||
|
||||
output = {
|
||||
'files_checked': self.result.files_checked,
|
||||
'total_violations': len(self.result.violations),
|
||||
'errors': len([v for v in self.result.violations if v.severity == Severity.ERROR]),
|
||||
'warnings': len([v for v in self.result.violations if v.severity == Severity.WARNING]),
|
||||
'violations': violations_json
|
||||
"files_checked": self.result.files_checked,
|
||||
"total_violations": len(self.result.violations),
|
||||
"errors": len(
|
||||
[v for v in self.result.violations if v.severity == Severity.ERROR]
|
||||
),
|
||||
"warnings": len(
|
||||
[v for v in self.result.violations if v.severity == Severity.WARNING]
|
||||
),
|
||||
"violations": violations_json,
|
||||
}
|
||||
|
||||
print(json.dumps(output, indent=2))
|
||||
@@ -616,7 +669,11 @@ class ArchitectureValidator:
|
||||
|
||||
def _print_violation(self, v: Violation):
|
||||
"""Print a single violation"""
|
||||
rel_path = v.file_path.relative_to(self.project_root) if self.project_root in v.file_path.parents else v.file_path
|
||||
rel_path = (
|
||||
v.file_path.relative_to(self.project_root)
|
||||
if self.project_root in v.file_path.parents
|
||||
else v.file_path
|
||||
)
|
||||
|
||||
print(f"\n [{v.rule_id}] {v.rule_name}")
|
||||
print(f" File: {rel_path}:{v.line_number}")
|
||||
@@ -634,40 +691,40 @@ def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Validate architecture patterns in codebase",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog=__doc__
|
||||
epilog=__doc__,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'path',
|
||||
nargs='?',
|
||||
"path",
|
||||
nargs="?",
|
||||
type=Path,
|
||||
default=Path.cwd(),
|
||||
help="Path to validate (default: current directory)"
|
||||
help="Path to validate (default: current directory)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--config',
|
||||
"-c",
|
||||
"--config",
|
||||
type=Path,
|
||||
default=Path.cwd() / '.architecture-rules.yaml',
|
||||
help="Path to architecture rules config (default: .architecture-rules.yaml)"
|
||||
default=Path.cwd() / ".architecture-rules.yaml",
|
||||
help="Path to architecture rules config (default: .architecture-rules.yaml)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-v', '--verbose',
|
||||
action='store_true',
|
||||
help="Show detailed output including context"
|
||||
"-v",
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
help="Show detailed output including context",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--errors-only',
|
||||
action='store_true',
|
||||
help="Only show errors, suppress warnings"
|
||||
"--errors-only", action="store_true", help="Only show errors, suppress warnings"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--json',
|
||||
action='store_true',
|
||||
help="Output results as JSON (for programmatic use)"
|
||||
"--json",
|
||||
action="store_true",
|
||||
help="Output results as JSON (for programmatic use)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
@@ -687,5 +744,5 @@ def main():
|
||||
sys.exit(exit_code)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -3,10 +3,12 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
from urllib.parse import urlparse
|
||||
|
||||
# Add project root to Python path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||
@@ -20,10 +22,10 @@ def get_database_info():
|
||||
parsed = urlparse(db_url)
|
||||
|
||||
db_type = parsed.scheme
|
||||
if db_type == 'sqlite':
|
||||
if db_type == "sqlite":
|
||||
# Extract path from sqlite:///./path or sqlite:///path
|
||||
db_path = db_url.replace('sqlite:///', '')
|
||||
if db_path.startswith('./'):
|
||||
db_path = db_url.replace("sqlite:///", "")
|
||||
if db_path.startswith("./"):
|
||||
db_path = db_path[2:]
|
||||
return db_type, db_path
|
||||
else:
|
||||
@@ -40,7 +42,7 @@ def verify_database_setup():
|
||||
db_type, db_path = get_database_info()
|
||||
print(f"[INFO] Database type: {db_type}")
|
||||
|
||||
if db_type == 'sqlite':
|
||||
if db_type == "sqlite":
|
||||
if not os.path.exists(db_path):
|
||||
print(f"[ERROR] Database file not found: {db_path}")
|
||||
return False
|
||||
@@ -54,10 +56,14 @@ def verify_database_setup():
|
||||
|
||||
with engine.connect() as conn:
|
||||
# Get table list (works for both SQLite and PostgreSQL)
|
||||
if db_type == 'sqlite':
|
||||
result = conn.execute(text("SELECT name FROM sqlite_master WHERE type='table'"))
|
||||
if db_type == "sqlite":
|
||||
result = conn.execute(
|
||||
text("SELECT name FROM sqlite_master WHERE type='table'")
|
||||
)
|
||||
else:
|
||||
result = conn.execute(text("SELECT tablename FROM pg_tables WHERE schemaname='public'"))
|
||||
result = conn.execute(
|
||||
text("SELECT tablename FROM pg_tables WHERE schemaname='public'")
|
||||
)
|
||||
|
||||
tables = [row[0] for row in result.fetchall()]
|
||||
|
||||
@@ -65,12 +71,17 @@ def verify_database_setup():
|
||||
|
||||
# Expected tables from your models
|
||||
expected_tables = [
|
||||
'users', 'products', 'inventory', 'vendors', 'products',
|
||||
'marketplace_import_jobs', 'alembic_version'
|
||||
"users",
|
||||
"products",
|
||||
"inventory",
|
||||
"vendors",
|
||||
"products",
|
||||
"marketplace_import_jobs",
|
||||
"alembic_version",
|
||||
]
|
||||
|
||||
for table in sorted(tables):
|
||||
if table == 'alembic_version':
|
||||
if table == "alembic_version":
|
||||
print(f" * {table} (migration tracking)")
|
||||
elif table in expected_tables:
|
||||
print(f" * {table}")
|
||||
@@ -78,12 +89,12 @@ def verify_database_setup():
|
||||
print(f" ? {table} (unexpected)")
|
||||
|
||||
# Check for missing expected tables
|
||||
missing_tables = set(expected_tables) - set(tables) - {'alembic_version'}
|
||||
missing_tables = set(expected_tables) - set(tables) - {"alembic_version"}
|
||||
if missing_tables:
|
||||
print(f"[WARNING] Missing expected tables: {missing_tables}")
|
||||
|
||||
# Check if alembic_version table exists
|
||||
if 'alembic_version' in tables:
|
||||
if "alembic_version" in tables:
|
||||
result = conn.execute(text("SELECT version_num FROM alembic_version"))
|
||||
version = result.fetchone()
|
||||
if version:
|
||||
@@ -126,16 +137,19 @@ def verify_model_structure():
|
||||
# Test database models
|
||||
try:
|
||||
from models.database.base import Base
|
||||
|
||||
print(f"[OK] Database Base imported")
|
||||
print(f"[INFO] Found {len(Base.metadata.tables)} database tables: {list(Base.metadata.tables.keys())}")
|
||||
print(
|
||||
f"[INFO] Found {len(Base.metadata.tables)} database tables: {list(Base.metadata.tables.keys())}"
|
||||
)
|
||||
|
||||
# Import specific models
|
||||
from models.database.user import User
|
||||
from models.database.marketplace_product import MarketplaceProduct
|
||||
from models.database.inventory import Inventory
|
||||
from models.database.vendor import Vendor
|
||||
from models.database.product import Product
|
||||
from models.database.marketplace_import_job import MarketplaceImportJob
|
||||
from models.database.marketplace_product import MarketplaceProduct
|
||||
from models.database.product import Product
|
||||
from models.database.user import User
|
||||
from models.database.vendor import Vendor
|
||||
|
||||
print("[OK] All database models imported successfully")
|
||||
|
||||
@@ -146,13 +160,23 @@ def verify_model_structure():
|
||||
# Test API models
|
||||
try:
|
||||
import models.schema
|
||||
|
||||
print("[OK] API models package imported")
|
||||
|
||||
# Test specific API model imports
|
||||
api_modules = ['base', 'auth', 'product', 'inventory', 'vendor ', 'marketplace', 'admin', 'stats']
|
||||
api_modules = [
|
||||
"base",
|
||||
"auth",
|
||||
"product",
|
||||
"inventory",
|
||||
"vendor ",
|
||||
"marketplace",
|
||||
"admin",
|
||||
"stats",
|
||||
]
|
||||
for module in api_modules:
|
||||
try:
|
||||
__import__(f'models.api.{module}')
|
||||
__import__(f"models.api.{module}")
|
||||
print(f" * models.api.{module}")
|
||||
except ImportError:
|
||||
print(f" ? models.api.{module} (not found, optional)")
|
||||
@@ -176,7 +200,7 @@ def check_project_structure():
|
||||
"models/database/inventory.py",
|
||||
"app/core/config.py",
|
||||
"alembic/env.py",
|
||||
"alembic.ini"
|
||||
"alembic.ini",
|
||||
]
|
||||
|
||||
for path in critical_paths:
|
||||
@@ -189,7 +213,7 @@ def check_project_structure():
|
||||
init_files = [
|
||||
"models/__init__.py",
|
||||
"models/database/__init__.py",
|
||||
"models/api/__init__.py"
|
||||
"models/api/__init__.py",
|
||||
]
|
||||
|
||||
print(f"\n[INIT] Checking __init__.py files...")
|
||||
@@ -221,7 +245,9 @@ if __name__ == "__main__":
|
||||
print("Next steps:")
|
||||
print(" 1. Run 'make dev' to start your FastAPI server")
|
||||
print(" 2. Visit http://localhost:8000/docs for interactive API docs")
|
||||
print(" 3. Use your API endpoints for authentication, products, inventory, etc.")
|
||||
print(
|
||||
" 3. Use your API endpoints for authentication, products, inventory, etc."
|
||||
)
|
||||
sys.exit(0)
|
||||
else:
|
||||
print()
|
||||
|
||||
Reference in New Issue
Block a user