style: apply black and isort formatting across entire codebase

- Standardize quote style (single to double quotes)
- Reorder and group imports alphabetically
- Fix line breaks and indentation for consistency
- Apply PEP 8 formatting standards

Also updated Makefile to exclude both venv and .venv from code quality checks.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
2025-11-28 19:30:17 +01:00
parent 13f0094743
commit 21c13ca39b
236 changed files with 8450 additions and 6545 deletions

View File

@@ -2,9 +2,9 @@
"""Database backup utility that uses project configuration."""
import os
import sys
import shutil
import sqlite3
import sys
from datetime import datetime
from pathlib import Path
from urllib.parse import urlparse
@@ -19,10 +19,10 @@ def get_database_path():
"""Extract database path from DATABASE_URL."""
db_url = settings.database_url
if db_url.startswith('sqlite:///'):
if db_url.startswith("sqlite:///"):
# Remove sqlite:/// prefix and handle relative paths
db_path = db_url.replace('sqlite:///', '')
if db_path.startswith('./'):
db_path = db_url.replace("sqlite:///", "")
if db_path.startswith("./"):
db_path = db_path[2:] # Remove ./ prefix
return db_path
else:
@@ -76,7 +76,7 @@ def backup_database():
print("[BACKUP] Starting database backup...")
print(f"[INFO] Database URL: {settings.database_url}")
if settings.database_url.startswith('sqlite'):
if settings.database_url.startswith("sqlite"):
return backup_sqlite_database()
else:
print("[INFO] For PostgreSQL databases, use pg_dump:")

View File

@@ -26,20 +26,19 @@ Usage:
"""
import sys
from pathlib import Path
from datetime import datetime, timezone
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from sqlalchemy.orm import Session
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.core.database import SessionLocal
from models.database.content_page import ContentPage
# ============================================================================
# DEFAULT PAGE CONTENT
# ============================================================================
@@ -458,6 +457,7 @@ DEFAULT_PAGES = [
# SCRIPT FUNCTIONS
# ============================================================================
def create_default_pages(db: Session) -> None:
"""
Create default platform content pages.
@@ -475,13 +475,14 @@ def create_default_pages(db: Session) -> None:
# Check if page already exists (platform default with this slug)
existing = db.execute(
select(ContentPage).where(
ContentPage.vendor_id == None,
ContentPage.slug == page_data["slug"]
ContentPage.vendor_id == None, ContentPage.slug == page_data["slug"]
)
).scalar_one_or_none()
if existing:
print(f" ⏭️ Skipped: {page_data['title']} (/{page_data['slug']}) - already exists")
print(
f" ⏭️ Skipped: {page_data['title']} (/{page_data['slug']}) - already exists"
)
skipped_count += 1
continue
@@ -519,7 +520,9 @@ def create_default_pages(db: Session) -> None:
if created_count > 0:
print("✅ Default platform content pages created successfully!\n")
print("Next steps:")
print(" 1. View pages at: /about, /contact, /faq, /shipping, /returns, /privacy, /terms")
print(
" 1. View pages at: /about, /contact, /faq, /shipping, /returns, /privacy, /terms"
)
print(" 2. Vendors can override these pages through the vendor dashboard")
print(" 3. Edit platform defaults through the admin panel\n")
else:
@@ -530,6 +533,7 @@ def create_default_pages(db: Session) -> None:
# MAIN EXECUTION
# ============================================================================
def main():
"""Main execution function."""
print("\n🚀 Starting Default Content Pages Creation Script...\n")

View File

@@ -12,17 +12,18 @@ from pathlib import Path
# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent))
import sqlite3
from datetime import datetime, UTC
import os
import sqlite3
from datetime import UTC, datetime
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
# Get database URL
database_url = os.getenv('DATABASE_URL', 'wizamart.db')
db_path = database_url.replace('sqlite:///', '')
database_url = os.getenv("DATABASE_URL", "wizamart.db")
db_path = database_url.replace("sqlite:///", "")
print(f"📦 Creating inventory entries in {db_path}...")
@@ -30,12 +31,14 @@ conn = sqlite3.connect(db_path)
cursor = conn.cursor()
# Get products without inventory
cursor.execute("""
cursor.execute(
"""
SELECT p.id, p.vendor_id, p.product_id
FROM products p
LEFT JOIN inventory i ON p.id = i.product_id
WHERE i.id IS NULL
""")
"""
)
products_without_inventory = cursor.fetchall()
if not products_without_inventory:
@@ -47,7 +50,8 @@ print(f"📦 Creating inventory for {len(products_without_inventory)} products..
# Create inventory entries
for product_id, vendor_id, sku in products_without_inventory:
cursor.execute("""
cursor.execute(
"""
INSERT INTO inventory (
vendor_id,
product_id,
@@ -57,15 +61,17 @@ for product_id, vendor_id, sku in products_without_inventory:
created_at,
updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?)
""", (
vendor_id,
product_id,
'Main Warehouse',
100, # Total quantity
0, # Reserved quantity
datetime.now(UTC),
datetime.now(UTC)
))
""",
(
vendor_id,
product_id,
"Main Warehouse",
100, # Total quantity
0, # Reserved quantity
datetime.now(UTC),
datetime.now(UTC),
),
)
conn.commit()
@@ -76,7 +82,9 @@ cursor.execute("SELECT COUNT(*) FROM inventory")
total_count = cursor.fetchone()[0]
print(f"\n📊 Total inventory entries: {total_count}")
cursor.execute("SELECT product_id, location, quantity, reserved_quantity FROM inventory LIMIT 5")
cursor.execute(
"SELECT product_id, location, quantity, reserved_quantity FROM inventory LIMIT 5"
)
print("\n📦 Sample inventory:")
for row in cursor.fetchall():
available = row[2] - row[3]

View File

@@ -12,18 +12,20 @@ from pathlib import Path
# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy.orm import Session
from app.core.database import SessionLocal
from models.database.vendor import Vendor
from models.database.content_page import ContentPage
from datetime import datetime, timezone
from sqlalchemy.orm import Session
from app.core.database import SessionLocal
from models.database.content_page import ContentPage
from models.database.vendor import Vendor
def create_landing_page(
vendor_subdomain: str,
template: str = "default",
title: str = None,
content: str = None
content: str = None,
):
"""
Create a landing page for a vendor.
@@ -38,9 +40,7 @@ def create_landing_page(
try:
# Find vendor
vendor = db.query(Vendor).filter(
Vendor.subdomain == vendor_subdomain
).first()
vendor = db.query(Vendor).filter(Vendor.subdomain == vendor_subdomain).first()
if not vendor:
print(f"❌ Vendor '{vendor_subdomain}' not found!")
@@ -49,10 +49,11 @@ def create_landing_page(
print(f"✅ Found vendor: {vendor.name} (ID: {vendor.id})")
# Check if landing page already exists
existing = db.query(ContentPage).filter(
ContentPage.vendor_id == vendor.id,
ContentPage.slug == "landing"
).first()
existing = (
db.query(ContentPage)
.filter(ContentPage.vendor_id == vendor.id, ContentPage.slug == "landing")
.first()
)
if existing:
print(f"⚠️ Landing page already exists (ID: {existing.id})")
@@ -74,7 +75,8 @@ def create_landing_page(
vendor_id=vendor.id,
slug="landing",
title=title or f"Welcome to {vendor.name}",
content=content or f"""
content=content
or f"""
<h2>About {vendor.name}</h2>
<p>{vendor.description or 'Your trusted shopping destination for quality products.'}</p>
@@ -96,7 +98,7 @@ def create_landing_page(
published_at=datetime.now(timezone.utc),
show_in_footer=False,
show_in_header=False,
display_order=0
display_order=0,
)
db.add(landing_page)
@@ -141,10 +143,13 @@ def list_vendors():
print(f" Code: {vendor.vendor_code}")
# Check if has landing page
landing = db.query(ContentPage).filter(
ContentPage.vendor_id == vendor.id,
ContentPage.slug == "landing"
).first()
landing = (
db.query(ContentPage)
.filter(
ContentPage.vendor_id == vendor.id, ContentPage.slug == "landing"
)
.first()
)
if landing:
print(f" Landing Page: ✅ ({landing.template})")
@@ -165,7 +170,7 @@ def show_templates():
("default", "Clean professional layout with 3-column quick links"),
("minimal", "Ultra-simple centered design with single CTA"),
("modern", "Full-screen hero with animations and features"),
("full", "Maximum features with split-screen hero and stats")
("full", "Maximum features with split-screen hero and stats"),
]
for name, desc in templates:
@@ -209,7 +214,7 @@ if __name__ == "__main__":
success = create_landing_page(
vendor_subdomain=vendor_subdomain,
template=template,
title=title if title else None
title=title if title else None,
)
if success:

View File

@@ -46,16 +46,22 @@ def create_platform_pages():
print("1. Creating Platform Homepage...")
# Check if already exists
existing = db.query(ContentPage).filter_by(vendor_id=None, slug="platform_homepage").first()
existing = (
db.query(ContentPage)
.filter_by(vendor_id=None, slug="platform_homepage")
.first()
)
if existing:
print(f" ⚠️ Skipped: Platform Homepage - already exists (ID: {existing.id})")
print(
f" ⚠️ Skipped: Platform Homepage - already exists (ID: {existing.id})"
)
else:
try:
homepage = content_page_service.create_page(
db,
slug="platform_homepage",
title="Welcome to Our Multi-Vendor Marketplace",
content="""
db,
slug="platform_homepage",
title="Welcome to Our Multi-Vendor Marketplace",
content="""
<p class="lead">
Connect vendors with customers worldwide. Build your online store and reach millions of shoppers.
</p>
@@ -64,14 +70,14 @@ def create_platform_pages():
with minimal effort and maximum impact.
</p>
""",
template="modern", # Uses platform/homepage-modern.html
vendor_id=None, # Platform-level page
is_published=True,
show_in_header=False, # Homepage is not in menu (it's the root)
show_in_footer=False,
display_order=0,
template="modern", # Uses platform/homepage-modern.html
vendor_id=None, # Platform-level page
is_published=True,
show_in_header=False, # Homepage is not in menu (it's the root)
show_in_footer=False,
display_order=0,
meta_description="Leading multi-vendor marketplace platform. Connect with thousands of vendors and discover millions of products.",
meta_keywords="marketplace, multi-vendor, e-commerce, online shopping, platform"
meta_keywords="marketplace, multi-vendor, e-commerce, online shopping, platform",
)
print(f" ✅ Created: {homepage.title} (/{homepage.slug})")
except Exception as e:
@@ -88,10 +94,10 @@ def create_platform_pages():
else:
try:
about = content_page_service.create_page(
db,
slug="about",
title="About Us",
content="""
db,
slug="about",
title="About Us",
content="""
<h2>Our Mission</h2>
<p>
We're on a mission to democratize e-commerce by providing powerful,
@@ -121,13 +127,13 @@ def create_platform_pages():
<li><strong>Excellence:</strong> We strive for the highest quality in everything we do</li>
</ul>
""",
vendor_id=None,
is_published=True,
show_in_header=True, # Show in header navigation
show_in_footer=True, # Show in footer
display_order=1,
vendor_id=None,
is_published=True,
show_in_header=True, # Show in header navigation
show_in_footer=True, # Show in footer
display_order=1,
meta_description="Learn about our mission to democratize e-commerce and empower entrepreneurs worldwide.",
meta_keywords="about us, mission, vision, values, company"
meta_keywords="about us, mission, vision, values, company",
)
print(f" ✅ Created: {about.title} (/{about.slug})")
except Exception as e:
@@ -144,10 +150,10 @@ def create_platform_pages():
else:
try:
faq = content_page_service.create_page(
db,
slug="faq",
title="Frequently Asked Questions",
content="""
db,
slug="faq",
title="Frequently Asked Questions",
content="""
<h2>Getting Started</h2>
<h3>How do I create a vendor account?</h3>
@@ -204,13 +210,13 @@ def create_platform_pages():
and marketing tools.
</p>
""",
vendor_id=None,
is_published=True,
show_in_header=True, # Show in header navigation
show_in_footer=True,
display_order=2,
vendor_id=None,
is_published=True,
show_in_header=True, # Show in header navigation
show_in_footer=True,
display_order=2,
meta_description="Find answers to common questions about our marketplace platform.",
meta_keywords="faq, frequently asked questions, help, support"
meta_keywords="faq, frequently asked questions, help, support",
)
print(f" ✅ Created: {faq.title} (/{faq.slug})")
except Exception as e:
@@ -221,16 +227,18 @@ def create_platform_pages():
# ========================================================================
print("4. Creating Contact Us page...")
existing = db.query(ContentPage).filter_by(vendor_id=None, slug="contact").first()
existing = (
db.query(ContentPage).filter_by(vendor_id=None, slug="contact").first()
)
if existing:
print(f" ⚠️ Skipped: Contact Us - already exists (ID: {existing.id})")
else:
try:
contact = content_page_service.create_page(
db,
slug="contact",
title="Contact Us",
content="""
db,
slug="contact",
title="Contact Us",
content="""
<h2>Get in Touch</h2>
<p>
We'd love to hear from you! Whether you have questions about our platform,
@@ -271,13 +279,13 @@ def create_platform_pages():
Email: <a href="mailto:press@marketplace.com">press@marketplace.com</a>
</p>
""",
vendor_id=None,
is_published=True,
show_in_header=True, # Show in header navigation
show_in_footer=True,
display_order=3,
vendor_id=None,
is_published=True,
show_in_header=True, # Show in header navigation
show_in_footer=True,
display_order=3,
meta_description="Get in touch with our team. We're here to help you succeed.",
meta_keywords="contact, support, email, phone, address"
meta_keywords="contact, support, email, phone, address",
)
print(f" ✅ Created: {contact.title} (/{contact.slug})")
except Exception as e:
@@ -290,14 +298,16 @@ def create_platform_pages():
existing = db.query(ContentPage).filter_by(vendor_id=None, slug="terms").first()
if existing:
print(f" ⚠️ Skipped: Terms of Service - already exists (ID: {existing.id})")
print(
f" ⚠️ Skipped: Terms of Service - already exists (ID: {existing.id})"
)
else:
try:
terms = content_page_service.create_page(
db,
slug="terms",
title="Terms of Service",
content="""
db,
slug="terms",
title="Terms of Service",
content="""
<p><em>Last updated: January 1, 2024</em></p>
<h2>1. Acceptance of Terms</h2>
@@ -361,13 +371,13 @@ def create_platform_pages():
<a href="mailto:legal@marketplace.com">legal@marketplace.com</a>.
</p>
""",
vendor_id=None,
is_published=True,
show_in_header=False, # Too legal for header
show_in_footer=True, # Show in footer
display_order=10,
vendor_id=None,
is_published=True,
show_in_header=False, # Too legal for header
show_in_footer=True, # Show in footer
display_order=10,
meta_description="Read our terms of service and platform usage policies.",
meta_keywords="terms of service, terms, legal, policy, agreement"
meta_keywords="terms of service, terms, legal, policy, agreement",
)
print(f" ✅ Created: {terms.title} (/{terms.slug})")
except Exception as e:
@@ -378,16 +388,18 @@ def create_platform_pages():
# ========================================================================
print("6. Creating Privacy Policy page...")
existing = db.query(ContentPage).filter_by(vendor_id=None, slug="privacy").first()
existing = (
db.query(ContentPage).filter_by(vendor_id=None, slug="privacy").first()
)
if existing:
print(f" ⚠️ Skipped: Privacy Policy - already exists (ID: {existing.id})")
else:
try:
privacy = content_page_service.create_page(
db,
slug="privacy",
title="Privacy Policy",
content="""
db,
slug="privacy",
title="Privacy Policy",
content="""
<p><em>Last updated: January 1, 2024</em></p>
<h2>1. Information We Collect</h2>
@@ -453,13 +465,13 @@ def create_platform_pages():
<a href="mailto:privacy@marketplace.com">privacy@marketplace.com</a>.
</p>
""",
vendor_id=None,
is_published=True,
show_in_header=False, # Too legal for header
show_in_footer=True, # Show in footer
display_order=11,
vendor_id=None,
is_published=True,
show_in_header=False, # Too legal for header
show_in_footer=True, # Show in footer
display_order=11,
meta_description="Learn how we collect, use, and protect your personal information.",
meta_keywords="privacy policy, privacy, data protection, gdpr, personal information"
meta_keywords="privacy policy, privacy, data protection, gdpr, personal information",
)
print(f" ✅ Created: {privacy.title} (/{privacy.slug})")
except Exception as e:

View File

@@ -17,29 +17,30 @@ This script is idempotent - safe to run multiple times.
"""
import sys
from pathlib import Path
from datetime import datetime, timezone
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from sqlalchemy.orm import Session
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.core.config import (print_environment_info, settings,
validate_production_settings)
from app.core.database import SessionLocal
from app.core.config import settings, print_environment_info, validate_production_settings
from app.core.environment import is_production, get_environment
from models.database.user import User
from models.database.admin import AdminSetting
from middleware.auth import AuthManager
from app.core.environment import get_environment, is_production
from app.core.permissions import PermissionGroups
from middleware.auth import AuthManager
from models.database.admin import AdminSetting
from models.database.user import User
# =============================================================================
# HELPER FUNCTIONS
# =============================================================================
def print_header(text: str):
"""Print formatted header."""
print("\n" + "=" * 70)
@@ -71,6 +72,7 @@ def print_error(text: str):
# INITIALIZATION FUNCTIONS
# =============================================================================
def create_admin_user(db: Session, auth_manager: AuthManager) -> User:
"""Create or get the platform admin user."""
@@ -206,9 +208,7 @@ def create_admin_settings(db: Session) -> int:
for setting_data in default_settings:
# Check if setting already exists
existing = db.execute(
select(AdminSetting).where(
AdminSetting.key == setting_data["key"]
)
select(AdminSetting).where(AdminSetting.key == setting_data["key"])
).scalar_one_or_none()
if not existing:
@@ -281,6 +281,7 @@ def verify_rbac_schema(db: Session) -> bool:
# MAIN INITIALIZATION
# =============================================================================
def initialize_production(db: Session, auth_manager: AuthManager):
"""Initialize production database with essential data."""
@@ -362,6 +363,7 @@ def print_summary(db: Session):
# MAIN ENTRY POINT
# =============================================================================
def main():
"""Main entry point."""
@@ -405,6 +407,7 @@ def main():
print_header("❌ INITIALIZATION FAILED")
print(f"\nError: {e}\n")
import traceback
traceback.print_exc()
sys.exit(1)

View File

@@ -7,7 +7,7 @@ Usage: python route_diagnostics.py
"""
import sys
from typing import List, Dict
from typing import Dict, List
def check_route_order():
@@ -29,23 +29,23 @@ def check_route_order():
html_routes = []
for route in routes:
if hasattr(route, 'path'):
if hasattr(route, "path"):
path = route.path
methods = getattr(route, 'methods', set())
methods = getattr(route, "methods", set())
# Determine if JSON or HTML based on common patterns
if 'login' in path or 'dashboard' in path or 'products' in path:
if "login" in path or "dashboard" in path or "products" in path:
# Check response class if available
if hasattr(route, 'response_class'):
if hasattr(route, "response_class"):
response_class = str(route.response_class)
if 'HTML' in response_class:
if "HTML" in response_class:
html_routes.append((path, methods))
else:
json_routes.append((path, methods))
else:
# Check endpoint name/tags
endpoint = getattr(route, 'endpoint', None)
if endpoint and 'page' in endpoint.__name__.lower():
endpoint = getattr(route, "endpoint", None)
if endpoint and "page" in endpoint.__name__.lower():
html_routes.append((path, methods))
else:
json_routes.append((path, methods))
@@ -57,8 +57,10 @@ def check_route_order():
# Check for specific vendor info route
vendor_info_found = False
for route in routes:
if hasattr(route, 'path'):
if '/{vendor_code}' == route.path and 'GET' in getattr(route, 'methods', set()):
if hasattr(route, "path"):
if "/{vendor_code}" == route.path and "GET" in getattr(
route, "methods", set()
):
vendor_info_found = True
print("✅ Found vendor info endpoint: GET /{vendor_code}")
break
@@ -100,14 +102,14 @@ def test_vendor_endpoint():
print(f" Content-Type: {response.headers.get('content-type', 'N/A')}")
# Check if response is JSON
content_type = response.headers.get('content-type', '')
if 'application/json' in content_type:
content_type = response.headers.get("content-type", "")
if "application/json" in content_type:
print("✅ Response is JSON")
data = response.json()
print(f" Vendor: {data.get('name', 'N/A')}")
print(f" Code: {data.get('vendor_code', 'N/A')}")
return True
elif 'text/html' in content_type:
elif "text/html" in content_type:
print("❌ ERROR: Response is HTML, not JSON!")
print(" This confirms the route ordering issue")
print(" The HTML page route is catching the API request")

View File

@@ -26,39 +26,39 @@ This script is idempotent when run normally.
"""
import sys
from datetime import datetime, timedelta, timezone
from decimal import Decimal
from pathlib import Path
from typing import Dict, List
from datetime import datetime, timezone, timedelta
from decimal import Decimal
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from sqlalchemy.orm import Session
from sqlalchemy import select, delete
from app.core.database import SessionLocal
from app.core.config import settings
from app.core.environment import is_production, get_environment
from models.database.user import User
from models.database.vendor import Vendor, VendorUser, Role
from models.database.vendor_domain import VendorDomain
from models.database.vendor_theme import VendorTheme
from models.database.customer import Customer, CustomerAddress
from models.database.product import Product
from models.database.marketplace_product import MarketplaceProduct
from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.order import Order, OrderItem
from models.database.admin import PlatformAlert
from middleware.auth import AuthManager
# =============================================================================
# MODE DETECTION (from environment variable set by Makefile)
# =============================================================================
import os
SEED_MODE = os.getenv('SEED_MODE', 'normal') # normal, minimal, reset
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.database import SessionLocal
from app.core.environment import get_environment, is_production
from middleware.auth import AuthManager
from models.database.admin import PlatformAlert
from models.database.customer import Customer, CustomerAddress
from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.marketplace_product import MarketplaceProduct
from models.database.order import Order, OrderItem
from models.database.product import Product
from models.database.user import User
from models.database.vendor import Role, Vendor, VendorUser
from models.database.vendor_domain import VendorDomain
from models.database.vendor_theme import VendorTheme
SEED_MODE = os.getenv("SEED_MODE", "normal") # normal, minimal, reset
# =============================================================================
# DEMO DATA CONFIGURATION
@@ -147,6 +147,7 @@ THEME_PRESETS = {
# HELPER FUNCTIONS
# =============================================================================
def print_header(text: str):
"""Print formatted header."""
print("\n" + "=" * 70)
@@ -178,6 +179,7 @@ def print_error(text: str):
# SAFETY CHECKS
# =============================================================================
def check_environment():
"""Prevent running demo seed in production."""
@@ -196,9 +198,7 @@ def check_environment():
def check_admin_exists(db: Session) -> bool:
"""Check if admin user exists."""
admin = db.execute(
select(User).where(User.role == "admin")
).scalar_one_or_none()
admin = db.execute(select(User).where(User.role == "admin")).scalar_one_or_none()
if not admin:
print_error("No admin user found!")
@@ -214,6 +214,7 @@ def check_admin_exists(db: Session) -> bool:
# DATA DELETION (for reset mode)
# =============================================================================
def reset_all_data(db: Session):
"""Delete ALL data from database (except admin user)."""
@@ -258,13 +259,14 @@ def reset_all_data(db: Session):
# SEEDING FUNCTIONS
# =============================================================================
def create_demo_vendors(db: Session, auth_manager: AuthManager) -> List[Vendor]:
"""Create demo vendors with users."""
vendors = []
# Determine how many vendors to create based on mode
vendor_count = 1 if SEED_MODE == 'minimal' else settings.seed_demo_vendors
vendor_count = 1 if SEED_MODE == "minimal" else settings.seed_demo_vendors
vendors_to_create = DEMO_VENDORS[:vendor_count]
users_to_create = DEMO_VENDOR_USERS[:vendor_count]
@@ -320,7 +322,9 @@ def create_demo_vendors(db: Session, auth_manager: AuthManager) -> List[Vendor]:
db.add(vendor_user_link)
# Create vendor theme
theme_colors = THEME_PRESETS.get(vendor_data["theme_preset"], THEME_PRESETS["modern"])
theme_colors = THEME_PRESETS.get(
vendor_data["theme_preset"], THEME_PRESETS["modern"]
)
theme = VendorTheme(
vendor_id=vendor.id,
theme_name=vendor_data["theme_preset"],
@@ -355,7 +359,9 @@ def create_demo_vendors(db: Session, auth_manager: AuthManager) -> List[Vendor]:
return vendors
def create_demo_customers(db: Session, vendor: Vendor, auth_manager: AuthManager, count: int) -> List[Customer]:
def create_demo_customers(
db: Session, vendor: Vendor, auth_manager: AuthManager, count: int
) -> List[Customer]:
"""Create demo customers for a vendor."""
customers = []
@@ -367,10 +373,11 @@ def create_demo_customers(db: Session, vendor: Vendor, auth_manager: AuthManager
customer_number = f"CUST-{vendor.vendor_code}-{i:04d}"
# Check if customer already exists
existing_customer = db.query(Customer).filter(
Customer.vendor_id == vendor.id,
Customer.email == email
).first()
existing_customer = (
db.query(Customer)
.filter(Customer.vendor_id == vendor.id, Customer.email == email)
.first()
)
if existing_customer:
customers.append(existing_customer)
@@ -412,19 +419,22 @@ def create_demo_products(db: Session, vendor: Vendor, count: int) -> List[Produc
product_id = f"{vendor.vendor_code}-PROD-{i:03d}"
# Check if this product already exists
existing_product = db.query(Product).filter(
Product.vendor_id == vendor.id,
Product.product_id == product_id
).first()
existing_product = (
db.query(Product)
.filter(Product.vendor_id == vendor.id, Product.product_id == product_id)
.first()
)
if existing_product:
products.append(existing_product)
continue # Skip creation, product already exists
# Check if marketplace product already exists
existing_mp = db.query(MarketplaceProduct).filter(
MarketplaceProduct.marketplace_product_id == marketplace_product_id
).first()
existing_mp = (
db.query(MarketplaceProduct)
.filter(MarketplaceProduct.marketplace_product_id == marketplace_product_id)
.first()
)
if existing_mp:
marketplace_product = existing_mp
@@ -485,6 +495,7 @@ def create_demo_products(db: Session, vendor: Vendor, count: int) -> List[Produc
# MAIN SEEDING
# =============================================================================
def seed_demo_data(db: Session, auth_manager: AuthManager):
"""Seed demo data for development."""
@@ -501,7 +512,7 @@ def seed_demo_data(db: Session, auth_manager: AuthManager):
sys.exit(1)
# Step 3: Reset data if in reset mode
if SEED_MODE == 'reset':
if SEED_MODE == "reset":
print_step(3, "Resetting data...")
reset_all_data(db)
@@ -513,20 +524,13 @@ def seed_demo_data(db: Session, auth_manager: AuthManager):
print_step(5, "Creating demo customers...")
for vendor in vendors:
create_demo_customers(
db,
vendor,
auth_manager,
count=settings.seed_customers_per_vendor
db, vendor, auth_manager, count=settings.seed_customers_per_vendor
)
# Step 6: Create products
print_step(6, "Creating demo products...")
for vendor in vendors:
create_demo_products(
db,
vendor,
count=settings.seed_products_per_vendor
)
create_demo_products(db, vendor, count=settings.seed_products_per_vendor)
# Commit all changes
db.commit()
@@ -558,17 +562,18 @@ def print_summary(db: Session):
print(f" Subdomain: {vendor.subdomain}.{settings.platform_domain}")
# Query custom domains separately
custom_domain = db.query(VendorDomain).filter(
VendorDomain.vendor_id == vendor.id,
VendorDomain.is_active == True
).first()
custom_domain = (
db.query(VendorDomain)
.filter(VendorDomain.vendor_id == vendor.id, VendorDomain.is_active == True)
.first()
)
if custom_domain:
# Try different possible field names (model field might vary)
domain_value = (
getattr(custom_domain, 'domain', None) or
getattr(custom_domain, 'domain_name', None) or
getattr(custom_domain, 'name', None)
getattr(custom_domain, "domain", None)
or getattr(custom_domain, "domain_name", None)
or getattr(custom_domain, "name", None)
)
if domain_value:
print(f" Custom: {domain_value}")
@@ -584,8 +589,12 @@ def print_summary(db: Session):
print(f" Email: {vendor_data['email']}")
print(f" Password: {vendor_data['password']}")
if vendor:
print(f" Login: http://localhost:8000/vendor/{vendor.vendor_code}/login")
print(f" or http://{vendor.subdomain}.localhost:8000/vendor/login")
print(
f" Login: http://localhost:8000/vendor/{vendor.vendor_code}/login"
)
print(
f" or http://{vendor.subdomain}.localhost:8000/vendor/login"
)
print()
print(f"\n🛒 Demo Customer Credentials:")
@@ -600,7 +609,9 @@ def print_summary(db: Session):
print("" * 70)
for vendor in vendors:
print(f" {vendor.name}:")
print(f" Path-based: http://localhost:8000/vendors/{vendor.vendor_code}/shop/")
print(
f" Path-based: http://localhost:8000/vendors/{vendor.vendor_code}/shop/"
)
print(f" Subdomain: http://{vendor.subdomain}.localhost:8000/")
print()
@@ -621,6 +632,7 @@ def print_summary(db: Session):
# MAIN ENTRY POINT
# =============================================================================
def main():
"""Main entry point."""
@@ -647,6 +659,7 @@ def main():
print_header("❌ SEEDING FAILED")
print(f"\nError: {e}\n")
import traceback
traceback.print_exc()
sys.exit(1)

View File

@@ -11,11 +11,11 @@ Usage:
"""
import os
import sys
import subprocess
import sys
from datetime import datetime
from pathlib import Path
from typing import List, Dict
from typing import Dict, List
def count_files(directory: str, pattern: str) -> int:
@@ -26,10 +26,14 @@ def count_files(directory: str, pattern: str) -> int:
count = 0
for root, dirs, files in os.walk(directory):
# Skip __pycache__ and other cache directories
dirs[:] = [d for d in dirs if d not in ['__pycache__', '.pytest_cache', '.git', 'node_modules']]
dirs[:] = [
d
for d in dirs
if d not in ["__pycache__", ".pytest_cache", ".git", "node_modules"]
]
for file in files:
if pattern == '*' or file.endswith(pattern):
if pattern == "*" or file.endswith(pattern):
count += 1
return count
@@ -41,26 +45,26 @@ def get_tree_structure(directory: str, exclude_patterns: List[str] = None) -> st
# Try to use system tree command first
try:
if sys.platform == 'win32':
if sys.platform == "win32":
# Windows tree command
result = subprocess.run(
['tree', '/F', '/A', directory],
["tree", "/F", "/A", directory],
capture_output=True,
text=True,
encoding='utf-8',
errors='replace'
encoding="utf-8",
errors="replace",
)
return result.stdout
else:
# Linux/Mac tree command with exclusions
exclude_args = []
if exclude_patterns:
exclude_args = ['-I', '|'.join(exclude_patterns)]
exclude_args = ["-I", "|".join(exclude_patterns)]
result = subprocess.run(
['tree', '-F', '-a'] + exclude_args + [directory],
["tree", "-F", "-a"] + exclude_args + [directory],
capture_output=True,
text=True
text=True,
)
if result.returncode == 0:
return result.stdout
@@ -71,10 +75,19 @@ def get_tree_structure(directory: str, exclude_patterns: List[str] = None) -> st
return generate_manual_tree(directory, exclude_patterns)
def generate_manual_tree(directory: str, exclude_patterns: List[str] = None, prefix: str = "") -> str:
def generate_manual_tree(
directory: str, exclude_patterns: List[str] = None, prefix: str = ""
) -> str:
"""Generate tree structure manually when tree command is not available."""
if exclude_patterns is None:
exclude_patterns = ['__pycache__', '.pytest_cache', '.git', 'node_modules', '*.pyc', '*.pyo']
exclude_patterns = [
"__pycache__",
".pytest_cache",
".git",
"node_modules",
"*.pyc",
"*.pyo",
]
output = []
path = Path(directory)
@@ -86,7 +99,7 @@ def generate_manual_tree(directory: str, exclude_patterns: List[str] = None, pre
# Skip excluded patterns
skip = False
for pattern in exclude_patterns:
if pattern.startswith('*'):
if pattern.startswith("*"):
# File extension pattern
if item.name.endswith(pattern[1:]):
skip = True
@@ -106,7 +119,9 @@ def generate_manual_tree(directory: str, exclude_patterns: List[str] = None, pre
if item.is_dir():
output.append(f"{prefix}{current_prefix}{item.name}/")
extension = " " if is_last else ""
subtree = generate_manual_tree(str(item), exclude_patterns, prefix + extension)
subtree = generate_manual_tree(
str(item), exclude_patterns, prefix + extension
)
if subtree:
output.append(subtree)
else:
@@ -127,20 +142,36 @@ def generate_frontend_structure() -> str:
# Templates section
output.append("")
output.append("╔══════════════════════════════════════════════════════════════════╗")
output.append("║ JINJA2 TEMPLATES ║")
output.append("║ Location: app/templates ║")
output.append("╚══════════════════════════════════════════════════════════════════╝")
output.append(
"╔══════════════════════════════════════════════════════════════════╗"
)
output.append(
"║ JINJA2 TEMPLATES ║"
)
output.append(
"║ Location: app/templates ║"
)
output.append(
"╚══════════════════════════════════════════════════════════════════╝"
)
output.append("")
output.append(get_tree_structure("app/templates"))
# Static assets section
output.append("")
output.append("")
output.append("╔══════════════════════════════════════════════════════════════════╗")
output.append("║ STATIC ASSETS ║")
output.append("║ Location: static ║")
output.append("╚══════════════════════════════════════════════════════════════════╝")
output.append(
"╔══════════════════════════════════════════════════════════════════╗"
)
output.append(
"║ STATIC ASSETS ║"
)
output.append(
"║ Location: static ║"
)
output.append(
"╚══════════════════════════════════════════════════════════════════╝"
)
output.append("")
output.append(get_tree_structure("static"))
@@ -148,11 +179,21 @@ def generate_frontend_structure() -> str:
if os.path.exists("docs"):
output.append("")
output.append("")
output.append("╔══════════════════════════════════════════════════════════════════╗")
output.append("║ DOCUMENTATION ║")
output.append("║ Location: docs ║")
output.append("║ (also listed in tools structure) ║")
output.append("╚══════════════════════════════════════════════════════════════════╝")
output.append(
"╔══════════════════════════════════════════════════════════════════╗"
)
output.append(
"║ DOCUMENTATION ║"
)
output.append(
"║ Location: docs ║"
)
output.append(
"║ (also listed in tools structure) ║"
)
output.append(
"╚══════════════════════════════════════════════════════════════════╝"
)
output.append("")
output.append("Note: Documentation is also included in tools structure")
output.append(" for infrastructure/DevOps context.")
@@ -160,9 +201,15 @@ def generate_frontend_structure() -> str:
# Statistics section
output.append("")
output.append("")
output.append("╔══════════════════════════════════════════════════════════════════╗")
output.append("║ STATISTICS ║")
output.append("╚══════════════════════════════════════════════════════════════════╝")
output.append(
"╔══════════════════════════════════════════════════════════════════╗"
)
output.append(
"║ STATISTICS ║"
)
output.append(
"╚══════════════════════════════════════════════════════════════════╝"
)
output.append("")
output.append("Templates:")
@@ -174,8 +221,8 @@ def generate_frontend_structure() -> str:
output.append(f" - JavaScript files: {count_files('static', '.js')}")
output.append(f" - CSS files: {count_files('static', '.css')}")
output.append(" - Image files:")
for ext in ['png', 'jpg', 'jpeg', 'gif', 'svg', 'webp', 'ico']:
count = count_files('static', f'.{ext}')
for ext in ["png", "jpg", "jpeg", "gif", "svg", "webp", "ico"]:
count = count_files("static", f".{ext}")
if count > 0:
output.append(f" - .{ext}: {count}")
@@ -200,41 +247,58 @@ def generate_backend_structure() -> str:
output.append("=" * 78)
output.append("")
exclude = ['__pycache__', '*.pyc', '*.pyo', '.pytest_cache', '*.egg-info', 'templates']
exclude = [
"__pycache__",
"*.pyc",
"*.pyo",
".pytest_cache",
"*.egg-info",
"templates",
]
# Backend directories to include
backend_dirs = [
('app', 'Application Code'),
('middleware', 'Middleware Components'),
('models', 'Database Models'),
('storage', 'File Storage'),
('tasks', 'Background Tasks'),
('logs', 'Application Logs'),
("app", "Application Code"),
("middleware", "Middleware Components"),
("models", "Database Models"),
("storage", "File Storage"),
("tasks", "Background Tasks"),
("logs", "Application Logs"),
]
for directory, title in backend_dirs:
if os.path.exists(directory):
output.append("")
output.append("╔══════════════════════════════════════════════════════════════════╗")
output.append(
"╔══════════════════════════════════════════════════════════════════╗"
)
output.append(f"{title.upper().center(62)}")
output.append(f"║ Location: {directory + '/'.ljust(51)}")
output.append("╚══════════════════════════════════════════════════════════════════╝")
output.append(
"╚══════════════════════════════════════════════════════════════════╝"
)
output.append("")
output.append(get_tree_structure(directory, exclude))
# Statistics section
output.append("")
output.append("")
output.append("╔══════════════════════════════════════════════════════════════════╗")
output.append("║ STATISTICS ║")
output.append("╚══════════════════════════════════════════════════════════════════╝")
output.append(
"╔══════════════════════════════════════════════════════════════════╗"
)
output.append(
"║ STATISTICS ║"
)
output.append(
"╚══════════════════════════════════════════════════════════════════╝"
)
output.append("")
output.append("Python Files by Directory:")
total_py_files = 0
for directory, title in backend_dirs:
if os.path.exists(directory):
count = count_files(directory, '.py')
count = count_files(directory, ".py")
total_py_files += count
output.append(f" - {directory}/: {count} files")
@@ -242,11 +306,11 @@ def generate_backend_structure() -> str:
output.append("")
output.append("Application Components (if in app/):")
components = ['routes', 'services', 'schemas', 'exceptions', 'utils']
components = ["routes", "services", "schemas", "exceptions", "utils"]
for component in components:
component_path = f"app/{component}"
if os.path.exists(component_path):
count = count_files(component_path, '.py')
count = count_files(component_path, ".py")
output.append(f" - app/{component}: {count} files")
output.append("")
@@ -264,48 +328,58 @@ def generate_tools_structure() -> str:
output.append("=" * 78)
output.append("")
exclude = ['__pycache__', '*.pyc', '*.pyo', '.pytest_cache', '*.egg-info']
exclude = ["__pycache__", "*.pyc", "*.pyo", ".pytest_cache", "*.egg-info"]
# Tools directories to include
tools_dirs = [
('alembic', 'Database Migrations'),
('scripts', 'Utility Scripts'),
('docker', 'Docker Configuration'),
('docs', 'Documentation'),
("alembic", "Database Migrations"),
("scripts", "Utility Scripts"),
("docker", "Docker Configuration"),
("docs", "Documentation"),
]
for directory, title in tools_dirs:
if os.path.exists(directory):
output.append("")
output.append("╔══════════════════════════════════════════════════════════════════╗")
output.append(
"╔══════════════════════════════════════════════════════════════════╗"
)
output.append(f"{title.upper().center(62)}")
output.append(f"║ Location: {directory + '/'.ljust(51)}")
output.append("╚══════════════════════════════════════════════════════════════════╝")
output.append(
"╚══════════════════════════════════════════════════════════════════╝"
)
output.append("")
output.append(get_tree_structure(directory, exclude))
# Configuration files section
output.append("")
output.append("")
output.append("╔══════════════════════════════════════════════════════════════════╗")
output.append("║ CONFIGURATION FILES ║")
output.append("╚══════════════════════════════════════════════════════════════════╝")
output.append(
"╔══════════════════════════════════════════════════════════════════╗"
)
output.append(
"║ CONFIGURATION FILES ║"
)
output.append(
"╚══════════════════════════════════════════════════════════════════╝"
)
output.append("")
output.append("Root configuration files:")
config_files = [
('Makefile', 'Build automation'),
('requirements.txt', 'Python dependencies'),
('pyproject.toml', 'Python project config'),
('setup.py', 'Python setup script'),
('setup.cfg', 'Setup configuration'),
('alembic.ini', 'Alembic migrations config'),
('mkdocs.yml', 'MkDocs documentation config'),
('Dockerfile', 'Docker image definition'),
('docker-compose.yml', 'Docker services'),
('.dockerignore', 'Docker ignore patterns'),
('.gitignore', 'Git ignore patterns'),
('.env.example', 'Environment variables template'),
("Makefile", "Build automation"),
("requirements.txt", "Python dependencies"),
("pyproject.toml", "Python project config"),
("setup.py", "Python setup script"),
("setup.cfg", "Setup configuration"),
("alembic.ini", "Alembic migrations config"),
("mkdocs.yml", "MkDocs documentation config"),
("Dockerfile", "Docker image definition"),
("docker-compose.yml", "Docker services"),
(".dockerignore", "Docker ignore patterns"),
(".gitignore", "Git ignore patterns"),
(".env.example", "Environment variables template"),
]
for file, description in config_files:
@@ -315,9 +389,15 @@ def generate_tools_structure() -> str:
# Statistics section
output.append("")
output.append("")
output.append("╔══════════════════════════════════════════════════════════════════╗")
output.append("║ STATISTICS ║")
output.append("╚══════════════════════════════════════════════════════════════════╝")
output.append(
"╔══════════════════════════════════════════════════════════════════╗"
)
output.append(
"║ STATISTICS ║"
)
output.append(
"╚══════════════════════════════════════════════════════════════════╝"
)
output.append("")
output.append("Database Migrations:")
@@ -327,7 +407,9 @@ def generate_tools_structure() -> str:
if migration_count > 0:
# Get first and last migration
try:
migrations = sorted([f for f in os.listdir("alembic/versions") if f.endswith('.py')])
migrations = sorted(
[f for f in os.listdir("alembic/versions") if f.endswith(".py")]
)
if migrations:
output.append(f" - First: {migrations[0][:40]}...")
if len(migrations) > 1:
@@ -341,9 +423,9 @@ def generate_tools_structure() -> str:
output.append("Scripts:")
if os.path.exists("scripts"):
script_types = {
'.py': 'Python scripts',
'.sh': 'Shell scripts',
'.bat': 'Batch scripts',
".py": "Python scripts",
".sh": "Shell scripts",
".bat": "Batch scripts",
}
for ext, desc in script_types.items():
count = count_files("scripts", ext)
@@ -356,8 +438,8 @@ def generate_tools_structure() -> str:
output.append("Documentation:")
if os.path.exists("docs"):
doc_types = {
'.md': 'Markdown files',
'.rst': 'reStructuredText files',
".md": "Markdown files",
".rst": "reStructuredText files",
}
for ext, desc in doc_types.items():
count = count_files("docs", ext)
@@ -368,7 +450,7 @@ def generate_tools_structure() -> str:
output.append("")
output.append("Docker:")
docker_files = ['Dockerfile', 'docker-compose.yml', '.dockerignore']
docker_files = ["Dockerfile", "docker-compose.yml", ".dockerignore"]
docker_exists = any(os.path.exists(f) for f in docker_files)
if docker_exists:
output.append(" ✓ Docker configuration present")
@@ -394,25 +476,44 @@ def generate_test_structure() -> str:
# Test files section
output.append("")
output.append("╔══════════════════════════════════════════════════════════════════╗")
output.append("║ TEST FILES ║")
output.append("║ Location: tests/ ║")
output.append("╚══════════════════════════════════════════════════════════════════╝")
output.append(
"╔══════════════════════════════════════════════════════════════════╗"
)
output.append(
"║ TEST FILES ║"
)
output.append(
"║ Location: tests/ ║"
)
output.append(
"╚══════════════════════════════════════════════════════════════════╝"
)
output.append("")
exclude = ['__pycache__', '*.pyc', '*.pyo', '.pytest_cache', '*.egg-info']
exclude = ["__pycache__", "*.pyc", "*.pyo", ".pytest_cache", "*.egg-info"]
output.append(get_tree_structure("tests", exclude))
# Configuration section
output.append("")
output.append("")
output.append("╔══════════════════════════════════════════════════════════════════╗")
output.append("║ TEST CONFIGURATION ║")
output.append("╚══════════════════════════════════════════════════════════════════╝")
output.append(
"╔══════════════════════════════════════════════════════════════════╗"
)
output.append(
"║ TEST CONFIGURATION ║"
)
output.append(
"╚══════════════════════════════════════════════════════════════════╝"
)
output.append("")
output.append("Test configuration files:")
test_config_files = ['pytest.ini', 'conftest.py', 'tests/conftest.py', '.coveragerc']
test_config_files = [
"pytest.ini",
"conftest.py",
"tests/conftest.py",
".coveragerc",
]
for file in test_config_files:
if os.path.exists(file):
output.append(f"{file}")
@@ -420,16 +521,22 @@ def generate_test_structure() -> str:
# Statistics section
output.append("")
output.append("")
output.append("╔══════════════════════════════════════════════════════════════════╗")
output.append("║ STATISTICS ║")
output.append("╚══════════════════════════════════════════════════════════════════╝")
output.append(
"╔══════════════════════════════════════════════════════════════════╗"
)
output.append(
"║ STATISTICS ║"
)
output.append(
"╚══════════════════════════════════════════════════════════════════╝"
)
output.append("")
# Count test files
test_file_count = 0
if os.path.exists("tests"):
for root, dirs, files in os.walk("tests"):
dirs[:] = [d for d in dirs if d != '__pycache__']
dirs[:] = [d for d in dirs if d != "__pycache__"]
for file in files:
if file.startswith("test_") and file.endswith(".py"):
test_file_count += 1
@@ -439,13 +546,13 @@ def generate_test_structure() -> str:
output.append("")
output.append("By Category:")
categories = ['unit', 'integration', 'system', 'e2e', 'performance']
categories = ["unit", "integration", "system", "e2e", "performance"]
for category in categories:
category_path = f"tests/{category}"
if os.path.exists(category_path):
count = 0
for root, dirs, files in os.walk(category_path):
dirs[:] = [d for d in dirs if d != '__pycache__']
dirs[:] = [d for d in dirs if d != "__pycache__"]
for file in files:
if file.startswith("test_") and file.endswith(".py"):
count += 1
@@ -455,12 +562,12 @@ def generate_test_structure() -> str:
test_function_count = 0
if os.path.exists("tests"):
for root, dirs, files in os.walk("tests"):
dirs[:] = [d for d in dirs if d != '__pycache__']
dirs[:] = [d for d in dirs if d != "__pycache__"]
for file in files:
if file.startswith("test_") and file.endswith(".py"):
filepath = os.path.join(root, file)
try:
with open(filepath, 'r', encoding='utf-8') as f:
with open(filepath, "r", encoding="utf-8") as f:
for line in f:
if line.strip().startswith("def test_"):
test_function_count += 1
@@ -494,19 +601,19 @@ def main():
structure_type = sys.argv[1].lower()
generators = {
'frontend': ('frontend-structure.txt', generate_frontend_structure),
'backend': ('backend-structure.txt', generate_backend_structure),
'tests': ('test-structure.txt', generate_test_structure),
'tools': ('tools-structure.txt', generate_tools_structure),
"frontend": ("frontend-structure.txt", generate_frontend_structure),
"backend": ("backend-structure.txt", generate_backend_structure),
"tests": ("test-structure.txt", generate_test_structure),
"tools": ("tools-structure.txt", generate_tools_structure),
}
if structure_type == 'all':
if structure_type == "all":
for name, (filename, generator) in generators.items():
print(f"\n{'=' * 60}")
print(f"Generating {name} structure...")
print('=' * 60)
print("=" * 60)
content = generator()
with open(filename, 'w', encoding='utf-8') as f:
with open(filename, "w", encoding="utf-8") as f:
f.write(content)
print(f"{name.capitalize()} structure saved to {filename}")
print(f"\n{content}\n")
@@ -514,7 +621,7 @@ def main():
filename, generator = generators[structure_type]
print(f"Generating {structure_type} structure...")
content = generator()
with open(filename, 'w', encoding='utf-8') as f:
with open(filename, "w", encoding="utf-8") as f:
f.write(content)
print(f"\n✅ Structure saved to {filename}\n")
print(content)

View File

@@ -20,22 +20,26 @@ Requirements:
* Customer: username=customer, password=customer123, vendor_id=1
"""
import requests
import json
from typing import Dict, Optional
import requests
BASE_URL = "http://localhost:8000"
class Color:
"""Terminal colors for pretty output"""
GREEN = '\033[92m'
RED = '\033[91m'
YELLOW = '\033[93m'
BLUE = '\033[94m'
MAGENTA = '\033[95m'
CYAN = '\033[96m'
BOLD = '\033[1m'
END = '\033[0m'
GREEN = "\033[92m"
RED = "\033[91m"
YELLOW = "\033[93m"
BLUE = "\033[94m"
MAGENTA = "\033[95m"
CYAN = "\033[96m"
BOLD = "\033[1m"
END = "\033[0m"
def print_section(name: str):
"""Print section header"""
@@ -43,66 +47,70 @@ def print_section(name: str):
print(f" {name}")
print(f"{'' * 60}{Color.END}")
def print_test(name: str):
"""Print test name"""
print(f"\n{Color.BOLD}{Color.BLUE}🧪 Test: {name}{Color.END}")
def print_success(message: str):
"""Print success message"""
print(f"{Color.GREEN}{message}{Color.END}")
def print_error(message: str):
"""Print error message"""
print(f"{Color.RED}{message}{Color.END}")
def print_info(message: str):
"""Print info message"""
print(f"{Color.YELLOW} {message}{Color.END}")
def print_warning(message: str):
"""Print warning message"""
print(f"{Color.MAGENTA}⚠️ {message}{Color.END}")
# ============================================================================
# ADMIN AUTHENTICATION TESTS
# ============================================================================
def test_admin_login() -> Optional[Dict]:
"""Test admin login and cookie configuration"""
print_test("Admin Login")
try:
response = requests.post(
f"{BASE_URL}/api/v1/admin/auth/login",
json={"username": "admin", "password": "admin123"}
json={"username": "admin", "password": "admin123"},
)
if response.status_code == 200:
data = response.json()
cookies = response.cookies
if "access_token" in data:
print_success("Admin login successful")
print_success(f"Received access token: {data['access_token'][:20]}...")
else:
print_error("No access token in response")
return None
if "admin_token" in cookies:
print_success("admin_token cookie set")
print_info("Cookie path should be /admin (verify in browser)")
else:
print_error("admin_token cookie NOT set")
return {
"token": data["access_token"],
"user": data.get("user", {})
}
return {"token": data["access_token"], "user": data.get("user", {})}
else:
print_error(f"Login failed: {response.status_code}")
print_error(f"Response: {response.text}")
return None
except Exception as e:
print_error(f"Exception during admin login: {str(e)}")
return None
@@ -111,13 +119,13 @@ def test_admin_login() -> Optional[Dict]:
def test_admin_cannot_access_vendor_api(admin_token: str):
"""Test that admin token cannot access vendor API"""
print_test("Admin Token on Vendor API (Should Block)")
try:
response = requests.get(
f"{BASE_URL}/api/v1/vendor/TESTVENDOR/products",
headers={"Authorization": f"Bearer {admin_token}"}
headers={"Authorization": f"Bearer {admin_token}"},
)
if response.status_code in [401, 403]:
data = response.json()
print_success("Admin correctly blocked from vendor API")
@@ -129,7 +137,7 @@ def test_admin_cannot_access_vendor_api(admin_token: str):
else:
print_warning(f"Unexpected status code: {response.status_code}")
return False
except Exception as e:
print_error(f"Exception: {str(e)}")
return False
@@ -138,13 +146,13 @@ def test_admin_cannot_access_vendor_api(admin_token: str):
def test_admin_cannot_access_customer_api(admin_token: str):
"""Test that admin token cannot access customer account pages"""
print_test("Admin Token on Customer API (Should Block)")
try:
response = requests.get(
f"{BASE_URL}/shop/account/dashboard",
headers={"Authorization": f"Bearer {admin_token}"}
headers={"Authorization": f"Bearer {admin_token}"},
)
# Customer pages may return HTML or JSON error
if response.status_code in [401, 403]:
print_success("Admin correctly blocked from customer pages")
@@ -155,7 +163,7 @@ def test_admin_cannot_access_customer_api(admin_token: str):
else:
print_warning(f"Unexpected status code: {response.status_code}")
return False
except Exception as e:
print_error(f"Exception: {str(e)}")
return False
@@ -165,46 +173,47 @@ def test_admin_cannot_access_customer_api(admin_token: str):
# VENDOR AUTHENTICATION TESTS
# ============================================================================
def test_vendor_login() -> Optional[Dict]:
"""Test vendor login and cookie configuration"""
print_test("Vendor Login")
try:
response = requests.post(
f"{BASE_URL}/api/v1/vendor/auth/login",
json={"username": "vendor", "password": "vendor123"}
json={"username": "vendor", "password": "vendor123"},
)
if response.status_code == 200:
data = response.json()
cookies = response.cookies
if "access_token" in data:
print_success("Vendor login successful")
print_success(f"Received access token: {data['access_token'][:20]}...")
else:
print_error("No access token in response")
return None
if "vendor_token" in cookies:
print_success("vendor_token cookie set")
print_info("Cookie path should be /vendor (verify in browser)")
else:
print_error("vendor_token cookie NOT set")
if "vendor" in data:
print_success(f"Vendor: {data['vendor'].get('vendor_code', 'N/A')}")
return {
"token": data["access_token"],
"user": data.get("user", {}),
"vendor": data.get("vendor", {})
"vendor": data.get("vendor", {}),
}
else:
print_error(f"Login failed: {response.status_code}")
print_error(f"Response: {response.text}")
return None
except Exception as e:
print_error(f"Exception during vendor login: {str(e)}")
return None
@@ -213,13 +222,13 @@ def test_vendor_login() -> Optional[Dict]:
def test_vendor_cannot_access_admin_api(vendor_token: str):
"""Test that vendor token cannot access admin API"""
print_test("Vendor Token on Admin API (Should Block)")
try:
response = requests.get(
f"{BASE_URL}/api/v1/admin/vendors",
headers={"Authorization": f"Bearer {vendor_token}"}
headers={"Authorization": f"Bearer {vendor_token}"},
)
if response.status_code in [401, 403]:
data = response.json()
print_success("Vendor correctly blocked from admin API")
@@ -231,7 +240,7 @@ def test_vendor_cannot_access_admin_api(vendor_token: str):
else:
print_warning(f"Unexpected status code: {response.status_code}")
return False
except Exception as e:
print_error(f"Exception: {str(e)}")
return False
@@ -240,13 +249,13 @@ def test_vendor_cannot_access_admin_api(vendor_token: str):
def test_vendor_cannot_access_customer_api(vendor_token: str):
"""Test that vendor token cannot access customer account pages"""
print_test("Vendor Token on Customer API (Should Block)")
try:
response = requests.get(
f"{BASE_URL}/shop/account/dashboard",
headers={"Authorization": f"Bearer {vendor_token}"}
headers={"Authorization": f"Bearer {vendor_token}"},
)
if response.status_code in [401, 403]:
print_success("Vendor correctly blocked from customer pages")
return True
@@ -256,7 +265,7 @@ def test_vendor_cannot_access_customer_api(vendor_token: str):
else:
print_warning(f"Unexpected status code: {response.status_code}")
return False
except Exception as e:
print_error(f"Exception: {str(e)}")
return False
@@ -266,42 +275,40 @@ def test_vendor_cannot_access_customer_api(vendor_token: str):
# CUSTOMER AUTHENTICATION TESTS
# ============================================================================
def test_customer_login() -> Optional[Dict]:
"""Test customer login and cookie configuration"""
print_test("Customer Login")
try:
response = requests.post(
f"{BASE_URL}/api/v1/public/vendors/1/customers/login",
json={"username": "customer", "password": "customer123"}
json={"username": "customer", "password": "customer123"},
)
if response.status_code == 200:
data = response.json()
cookies = response.cookies
if "access_token" in data:
print_success("Customer login successful")
print_success(f"Received access token: {data['access_token'][:20]}...")
else:
print_error("No access token in response")
return None
if "customer_token" in cookies:
print_success("customer_token cookie set")
print_info("Cookie path should be /shop (verify in browser)")
else:
print_error("customer_token cookie NOT set")
return {
"token": data["access_token"],
"user": data.get("user", {})
}
return {"token": data["access_token"], "user": data.get("user", {})}
else:
print_error(f"Login failed: {response.status_code}")
print_error(f"Response: {response.text}")
return None
except Exception as e:
print_error(f"Exception during customer login: {str(e)}")
return None
@@ -310,13 +317,13 @@ def test_customer_login() -> Optional[Dict]:
def test_customer_cannot_access_admin_api(customer_token: str):
"""Test that customer token cannot access admin API"""
print_test("Customer Token on Admin API (Should Block)")
try:
response = requests.get(
f"{BASE_URL}/api/v1/admin/vendors",
headers={"Authorization": f"Bearer {customer_token}"}
headers={"Authorization": f"Bearer {customer_token}"},
)
if response.status_code in [401, 403]:
data = response.json()
print_success("Customer correctly blocked from admin API")
@@ -328,7 +335,7 @@ def test_customer_cannot_access_admin_api(customer_token: str):
else:
print_warning(f"Unexpected status code: {response.status_code}")
return False
except Exception as e:
print_error(f"Exception: {str(e)}")
return False
@@ -337,13 +344,13 @@ def test_customer_cannot_access_admin_api(customer_token: str):
def test_customer_cannot_access_vendor_api(customer_token: str):
"""Test that customer token cannot access vendor API"""
print_test("Customer Token on Vendor API (Should Block)")
try:
response = requests.get(
f"{BASE_URL}/api/v1/vendor/TESTVENDOR/products",
headers={"Authorization": f"Bearer {customer_token}"}
headers={"Authorization": f"Bearer {customer_token}"},
)
if response.status_code in [401, 403]:
data = response.json()
print_success("Customer correctly blocked from vendor API")
@@ -355,7 +362,7 @@ def test_customer_cannot_access_vendor_api(customer_token: str):
else:
print_warning(f"Unexpected status code: {response.status_code}")
return False
except Exception as e:
print_error(f"Exception: {str(e)}")
return False
@@ -364,17 +371,17 @@ def test_customer_cannot_access_vendor_api(customer_token: str):
def test_public_shop_access():
"""Test that public shop pages are accessible without authentication"""
print_test("Public Shop Access (No Auth Required)")
try:
response = requests.get(f"{BASE_URL}/shop/products")
if response.status_code == 200:
print_success("Public shop pages accessible without auth")
return True
else:
print_error(f"Failed to access public shop: {response.status_code}")
return False
except Exception as e:
print_error(f"Exception: {str(e)}")
return False
@@ -383,10 +390,10 @@ def test_public_shop_access():
def test_health_check():
"""Test health check endpoint"""
print_test("Health Check")
try:
response = requests.get(f"{BASE_URL}/health")
if response.status_code == 200:
data = response.json()
print_success("Health check passed")
@@ -395,7 +402,7 @@ def test_health_check():
else:
print_error(f"Health check failed: {response.status_code}")
return False
except Exception as e:
print_error(f"Exception: {str(e)}")
return False
@@ -405,19 +412,16 @@ def test_health_check():
# MAIN TEST RUNNER
# ============================================================================
def main():
"""Run all tests"""
print(f"\n{Color.BOLD}{Color.CYAN}{'' * 60}")
print(f" 🔒 COMPLETE AUTHENTICATION SYSTEM TEST SUITE")
print(f"{'' * 60}{Color.END}")
print(f"Testing server at: {BASE_URL}")
results = {
"passed": 0,
"failed": 0,
"total": 0
}
results = {"passed": 0, "failed": 0, "total": 0}
# Health check first
print_section("🏥 System Health")
results["total"] += 1
@@ -427,12 +431,12 @@ def main():
results["failed"] += 1
print_error("Server not responding. Is it running?")
return
# ========================================================================
# ADMIN TESTS
# ========================================================================
print_section("👤 Admin Authentication Tests")
# Admin login
results["total"] += 1
admin_auth = test_admin_login()
@@ -441,7 +445,7 @@ def main():
else:
results["failed"] += 1
admin_auth = None
# Admin cross-context tests
if admin_auth:
results["total"] += 1
@@ -449,18 +453,18 @@ def main():
results["passed"] += 1
else:
results["failed"] += 1
results["total"] += 1
if test_admin_cannot_access_customer_api(admin_auth["token"]):
results["passed"] += 1
else:
results["failed"] += 1
# ========================================================================
# VENDOR TESTS
# ========================================================================
print_section("🏪 Vendor Authentication Tests")
# Vendor login
results["total"] += 1
vendor_auth = test_vendor_login()
@@ -469,7 +473,7 @@ def main():
else:
results["failed"] += 1
vendor_auth = None
# Vendor cross-context tests
if vendor_auth:
results["total"] += 1
@@ -477,25 +481,25 @@ def main():
results["passed"] += 1
else:
results["failed"] += 1
results["total"] += 1
if test_vendor_cannot_access_customer_api(vendor_auth["token"]):
results["passed"] += 1
else:
results["failed"] += 1
# ========================================================================
# CUSTOMER TESTS
# ========================================================================
print_section("🛒 Customer Authentication Tests")
# Public shop access
results["total"] += 1
if test_public_shop_access():
results["passed"] += 1
else:
results["failed"] += 1
# Customer login
results["total"] += 1
customer_auth = test_customer_login()
@@ -504,7 +508,7 @@ def main():
else:
results["failed"] += 1
customer_auth = None
# Customer cross-context tests
if customer_auth:
results["total"] += 1
@@ -512,29 +516,31 @@ def main():
results["passed"] += 1
else:
results["failed"] += 1
results["total"] += 1
if test_customer_cannot_access_vendor_api(customer_auth["token"]):
results["passed"] += 1
else:
results["failed"] += 1
# ========================================================================
# RESULTS
# ========================================================================
print_section("📊 Test Results")
print(f"\n{Color.BOLD}Total Tests: {results['total']}{Color.END}")
print_success(f"Passed: {results['passed']}")
if results['failed'] > 0:
if results["failed"] > 0:
print_error(f"Failed: {results['failed']}")
if results['failed'] == 0:
if results["failed"] == 0:
print(f"\n{Color.GREEN}{Color.BOLD}🎉 ALL TESTS PASSED!{Color.END}")
print(f"{Color.GREEN}Your authentication system is properly isolated!{Color.END}")
print(
f"{Color.GREEN}Your authentication system is properly isolated!{Color.END}"
)
else:
print(f"\n{Color.RED}{Color.BOLD}⚠️ SOME TESTS FAILED{Color.END}")
print(f"{Color.RED}Please review the output above.{Color.END}")
# Browser tests reminder
print_section("🌐 Manual Browser Tests")
print("Please also verify in browser:")

View File

@@ -9,10 +9,11 @@ Tests:
- Transfer ownership
"""
import requests
import json
from pprint import pprint
import requests
BASE_URL = "http://localhost:8000"
ADMIN_TOKEN = None # Will be set after login
@@ -25,8 +26,8 @@ def login_admin():
f"{BASE_URL}/api/v1/admin/auth/login", # ✅ Changed: added /admin/
json={
"username": "admin", # Replace with your admin username
"password": "admin123" # Replace with your admin password
}
"password": "admin123", # Replace with your admin password
},
)
if response.status_code == 200:
@@ -44,7 +45,7 @@ def get_headers():
"""Get authorization headers."""
return {
"Authorization": f"Bearer {ADMIN_TOKEN}",
"Content-Type": "application/json"
"Content-Type": "application/json",
}
@@ -60,13 +61,11 @@ def test_create_vendor_with_both_emails():
"subdomain": "testdual",
"owner_email": "owner@testdual.com",
"contact_email": "contact@testdual.com",
"description": "Test vendor with separate emails"
"description": "Test vendor with separate emails",
}
response = requests.post(
f"{BASE_URL}/api/v1/admin/vendors",
headers=get_headers(),
json=vendor_data
f"{BASE_URL}/api/v1/admin/vendors", headers=get_headers(), json=vendor_data
)
if response.status_code == 200:
@@ -79,7 +78,7 @@ def test_create_vendor_with_both_emails():
print(f" Username: {data['owner_username']}")
print(f" Password: {data['temporary_password']}")
print(f"\n🔗 Login URL: {data['login_url']}")
return data['id']
return data["id"]
else:
print(f"❌ Failed: {response.status_code}")
print(response.text)
@@ -97,13 +96,11 @@ def test_create_vendor_single_email():
"name": "Test Single Email Vendor",
"subdomain": "testsingle",
"owner_email": "owner@testsingle.com",
"description": "Test vendor with single email"
"description": "Test vendor with single email",
}
response = requests.post(
f"{BASE_URL}/api/v1/admin/vendors",
headers=get_headers(),
json=vendor_data
f"{BASE_URL}/api/v1/admin/vendors", headers=get_headers(), json=vendor_data
)
if response.status_code == 200:
@@ -113,12 +110,12 @@ def test_create_vendor_single_email():
print(f" Owner Email: {data['owner_email']}")
print(f" Contact Email: {data['contact_email']}")
if data['owner_email'] == data['contact_email']:
if data["owner_email"] == data["contact_email"]:
print("✅ Contact email correctly defaulted to owner email")
else:
print("❌ Contact email should have defaulted to owner email")
return data['id']
return data["id"]
else:
print(f"❌ Failed: {response.status_code}")
print(response.text)
@@ -133,13 +130,13 @@ def test_update_vendor_contact_email(vendor_id):
update_data = {
"contact_email": "newcontact@business.com",
"name": "Updated Vendor Name"
"name": "Updated Vendor Name",
}
response = requests.put(
f"{BASE_URL}/api/v1/admin/vendors/{vendor_id}",
headers=get_headers(),
json=update_data
json=update_data,
)
if response.status_code == 200:
@@ -163,8 +160,7 @@ def test_get_vendor_details(vendor_id):
print("=" * 60)
response = requests.get(
f"{BASE_URL}/api/v1/admin/vendors/{vendor_id}",
headers=get_headers()
f"{BASE_URL}/api/v1/admin/vendors/{vendor_id}", headers=get_headers()
)
if response.status_code == 200:
@@ -196,8 +192,7 @@ def test_transfer_ownership(vendor_id, new_owner_user_id):
# First, get current owner info
response = requests.get(
f"{BASE_URL}/api/v1/admin/vendors/{vendor_id}",
headers=get_headers()
f"{BASE_URL}/api/v1/admin/vendors/{vendor_id}", headers=get_headers()
)
if response.status_code == 200:
@@ -211,13 +206,13 @@ def test_transfer_ownership(vendor_id, new_owner_user_id):
transfer_data = {
"new_owner_user_id": new_owner_user_id,
"confirm_transfer": True,
"transfer_reason": "Testing ownership transfer feature"
"transfer_reason": "Testing ownership transfer feature",
}
response = requests.post(
f"{BASE_URL}/api/v1/admin/vendors/{vendor_id}/transfer-ownership",
headers=get_headers(),
json=transfer_data
json=transfer_data,
)
if response.status_code == 200:

View File

@@ -22,15 +22,17 @@ import argparse
import ast
import re
import sys
from pathlib import Path
from typing import List, Dict, Any, Tuple
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Tuple
import yaml
class Severity(Enum):
"""Validation severity levels"""
ERROR = "error"
WARNING = "warning"
INFO = "info"
@@ -39,6 +41,7 @@ class Severity(Enum):
@dataclass
class Violation:
"""Represents an architectural rule violation"""
rule_id: str
rule_name: str
severity: Severity
@@ -52,6 +55,7 @@ class Violation:
@dataclass
class ValidationResult:
"""Results of architecture validation"""
violations: List[Violation] = field(default_factory=list)
files_checked: int = 0
rules_applied: int = 0
@@ -82,7 +86,7 @@ class ArchitectureValidator:
print(f"❌ Configuration file not found: {self.config_path}")
sys.exit(1)
with open(self.config_path, 'r') as f:
with open(self.config_path, "r") as f:
config = yaml.safe_load(f)
print(f"📋 Loaded architecture rules: {config.get('project', 'unknown')}")
@@ -126,7 +130,7 @@ class ArchitectureValidator:
continue
content = file_path.read_text()
lines = content.split('\n')
lines = content.split("\n")
# API-001: Check for Pydantic model usage
self._check_pydantic_usage(file_path, content, lines)
@@ -147,8 +151,8 @@ class ArchitectureValidator:
return
# Check for response_model in route decorators
route_pattern = r'@router\.(get|post|put|delete|patch)'
dict_return_pattern = r'return\s+\{.*\}'
route_pattern = r"@router\.(get|post|put|delete|patch)"
dict_return_pattern = r"return\s+\{.*\}"
for i, line in enumerate(lines, 1):
# Check for dict returns in endpoints
@@ -166,47 +170,51 @@ class ArchitectureValidator:
if re.search(dict_return_pattern, func_line):
self._add_violation(
rule_id="API-001",
rule_name=rule['name'],
rule_name=rule["name"],
severity=Severity.ERROR,
file_path=file_path,
line_number=j + 1,
message="Endpoint returns raw dict instead of Pydantic model",
context=func_line.strip(),
suggestion="Define a Pydantic response model and use response_model parameter"
suggestion="Define a Pydantic response model and use response_model parameter",
)
def _check_no_business_logic_in_endpoints(self, file_path: Path, content: str, lines: List[str]):
def _check_no_business_logic_in_endpoints(
self, file_path: Path, content: str, lines: List[str]
):
"""API-002: Ensure no business logic in endpoints"""
rule = self._get_rule("API-002")
if not rule:
return
anti_patterns = [
(r'db\.add\(', "Database operations should be in service layer"),
(r'db\.commit\(\)', "Database commits should be in service layer"),
(r'db\.query\(', "Database queries should be in service layer"),
(r'db\.execute\(', "Database operations should be in service layer"),
(r"db\.add\(", "Database operations should be in service layer"),
(r"db\.commit\(\)", "Database commits should be in service layer"),
(r"db\.query\(", "Database queries should be in service layer"),
(r"db\.execute\(", "Database operations should be in service layer"),
]
for i, line in enumerate(lines, 1):
# Skip service method calls (allowed)
if '_service.' in line or 'service.' in line:
if "_service." in line or "service." in line:
continue
for pattern, message in anti_patterns:
if re.search(pattern, line):
self._add_violation(
rule_id="API-002",
rule_name=rule['name'],
rule_name=rule["name"],
severity=Severity.ERROR,
file_path=file_path,
line_number=i,
message=message,
context=line.strip(),
suggestion="Move database operations to service layer"
suggestion="Move database operations to service layer",
)
def _check_endpoint_exception_handling(self, file_path: Path, content: str, lines: List[str]):
def _check_endpoint_exception_handling(
self, file_path: Path, content: str, lines: List[str]
):
"""API-003: Check proper exception handling in endpoints"""
rule = self._get_rule("API-003")
if not rule:
@@ -222,40 +230,41 @@ class ArchitectureValidator:
if isinstance(node, ast.FunctionDef):
# Check if it's a route handler
has_router_decorator = any(
isinstance(d, ast.Call) and
isinstance(d.func, ast.Attribute) and
getattr(d.func.value, 'id', None) == 'router'
isinstance(d, ast.Call)
and isinstance(d.func, ast.Attribute)
and getattr(d.func.value, "id", None) == "router"
for d in node.decorator_list
)
if has_router_decorator:
# Check if function body has try/except
has_try_except = any(
isinstance(child, ast.Try)
for child in ast.walk(node)
isinstance(child, ast.Try) for child in ast.walk(node)
)
# Check if function calls service methods
has_service_call = any(
isinstance(child, ast.Call) and
isinstance(child.func, ast.Attribute) and
'service' in getattr(child.func.value, 'id', '').lower()
isinstance(child, ast.Call)
and isinstance(child.func, ast.Attribute)
and "service" in getattr(child.func.value, "id", "").lower()
for child in ast.walk(node)
)
if has_service_call and not has_try_except:
self._add_violation(
rule_id="API-003",
rule_name=rule['name'],
rule_name=rule["name"],
severity=Severity.WARNING,
file_path=file_path,
line_number=node.lineno,
message=f"Endpoint '{node.name}' calls service but lacks exception handling",
context=f"def {node.name}(...)",
suggestion="Wrap service calls in try/except and convert to HTTPException"
suggestion="Wrap service calls in try/except and convert to HTTPException",
)
def _check_endpoint_authentication(self, file_path: Path, content: str, lines: List[str]):
def _check_endpoint_authentication(
self, file_path: Path, content: str, lines: List[str]
):
"""API-004: Check authentication on endpoints"""
rule = self._get_rule("API-004")
if not rule:
@@ -264,24 +273,28 @@ class ArchitectureValidator:
# This is a warning-level check
# Look for endpoints without Depends(get_current_*)
for i, line in enumerate(lines, 1):
if '@router.' in line and ('post' in line or 'put' in line or 'delete' in line):
if "@router." in line and (
"post" in line or "put" in line or "delete" in line
):
# Check next 5 lines for auth
has_auth = False
for j in range(i, min(i + 5, len(lines))):
if 'Depends(get_current_' in lines[j]:
if "Depends(get_current_" in lines[j]:
has_auth = True
break
if not has_auth and 'include_in_schema=False' not in ' '.join(lines[i:i+5]):
if not has_auth and "include_in_schema=False" not in " ".join(
lines[i : i + 5]
):
self._add_violation(
rule_id="API-004",
rule_name=rule['name'],
rule_name=rule["name"],
severity=Severity.WARNING,
file_path=file_path,
line_number=i,
message="Endpoint may be missing authentication",
context=line.strip(),
suggestion="Add Depends(get_current_user) or similar if endpoint should be protected"
suggestion="Add Depends(get_current_user) or similar if endpoint should be protected",
)
def _validate_service_layer(self, target_path: Path):
@@ -296,7 +309,7 @@ class ArchitectureValidator:
continue
content = file_path.read_text()
lines = content.split('\n')
lines = content.split("\n")
# SVC-001: No HTTPException in services
self._check_no_http_exception_in_services(file_path, content, lines)
@@ -307,38 +320,45 @@ class ArchitectureValidator:
# SVC-003: DB session as parameter
self._check_db_session_parameter(file_path, content, lines)
def _check_no_http_exception_in_services(self, file_path: Path, content: str, lines: List[str]):
def _check_no_http_exception_in_services(
self, file_path: Path, content: str, lines: List[str]
):
"""SVC-001: Services must not raise HTTPException"""
rule = self._get_rule("SVC-001")
if not rule:
return
for i, line in enumerate(lines, 1):
if 'raise HTTPException' in line:
if "raise HTTPException" in line:
self._add_violation(
rule_id="SVC-001",
rule_name=rule['name'],
rule_name=rule["name"],
severity=Severity.ERROR,
file_path=file_path,
line_number=i,
message="Service raises HTTPException - use domain exceptions instead",
context=line.strip(),
suggestion="Create custom exception class (e.g., VendorNotFoundError) and raise that"
suggestion="Create custom exception class (e.g., VendorNotFoundError) and raise that",
)
if 'from fastapi import HTTPException' in line or 'from fastapi.exceptions import HTTPException' in line:
if (
"from fastapi import HTTPException" in line
or "from fastapi.exceptions import HTTPException" in line
):
self._add_violation(
rule_id="SVC-001",
rule_name=rule['name'],
rule_name=rule["name"],
severity=Severity.ERROR,
file_path=file_path,
line_number=i,
message="Service imports HTTPException - services should not know about HTTP",
context=line.strip(),
suggestion="Remove HTTPException import and use domain exceptions"
suggestion="Remove HTTPException import and use domain exceptions",
)
def _check_service_exceptions(self, file_path: Path, content: str, lines: List[str]):
def _check_service_exceptions(
self, file_path: Path, content: str, lines: List[str]
):
"""SVC-002: Check for proper exception handling"""
rule = self._get_rule("SVC-002")
if not rule:
@@ -346,19 +366,21 @@ class ArchitectureValidator:
for i, line in enumerate(lines, 1):
# Check for generic Exception raises
if re.match(r'\s*raise Exception\(', line):
if re.match(r"\s*raise Exception\(", line):
self._add_violation(
rule_id="SVC-002",
rule_name=rule['name'],
rule_name=rule["name"],
severity=Severity.WARNING,
file_path=file_path,
line_number=i,
message="Service raises generic Exception - use specific domain exception",
context=line.strip(),
suggestion="Create custom exception class for this error case"
suggestion="Create custom exception class for this error case",
)
def _check_db_session_parameter(self, file_path: Path, content: str, lines: List[str]):
def _check_db_session_parameter(
self, file_path: Path, content: str, lines: List[str]
):
"""SVC-003: Service methods should accept db session as parameter"""
rule = self._get_rule("SVC-003")
if not rule:
@@ -366,16 +388,16 @@ class ArchitectureValidator:
# Check for SessionLocal() creation in service files
for i, line in enumerate(lines, 1):
if 'SessionLocal()' in line and 'class' not in line:
if "SessionLocal()" in line and "class" not in line:
self._add_violation(
rule_id="SVC-003",
rule_name=rule['name'],
rule_name=rule["name"],
severity=Severity.ERROR,
file_path=file_path,
line_number=i,
message="Service creates database session internally",
context=line.strip(),
suggestion="Accept db: Session as method parameter instead"
suggestion="Accept db: Session as method parameter instead",
)
def _validate_models(self, target_path: Path):
@@ -391,11 +413,11 @@ class ArchitectureValidator:
continue
content = file_path.read_text()
lines = content.split('\n')
lines = content.split("\n")
# Check for mixing SQLAlchemy and Pydantic
for i, line in enumerate(lines, 1):
if re.search(r'class.*\(Base.*,.*BaseModel.*\)', line):
if re.search(r"class.*\(Base.*,.*BaseModel.*\)", line):
self._add_violation(
rule_id="MDL-002",
rule_name="Separate SQLAlchemy and Pydantic models",
@@ -404,7 +426,7 @@ class ArchitectureValidator:
line_number=i,
message="Model mixes SQLAlchemy Base and Pydantic BaseModel",
context=line.strip(),
suggestion="Keep SQLAlchemy models and Pydantic models separate"
suggestion="Keep SQLAlchemy models and Pydantic models separate",
)
def _validate_exceptions(self, target_path: Path):
@@ -418,11 +440,11 @@ class ArchitectureValidator:
continue
content = file_path.read_text()
lines = content.split('\n')
lines = content.split("\n")
# EXC-002: Check for bare except
for i, line in enumerate(lines, 1):
if re.match(r'\s*except\s*:', line):
if re.match(r"\s*except\s*:", line):
self._add_violation(
rule_id="EXC-002",
rule_name="No bare except clauses",
@@ -431,7 +453,7 @@ class ArchitectureValidator:
line_number=i,
message="Bare except clause catches all exceptions including system exits",
context=line.strip(),
suggestion="Specify exception type: except ValueError: or except Exception:"
suggestion="Specify exception type: except ValueError: or except Exception:",
)
def _validate_javascript(self, target_path: Path):
@@ -443,11 +465,16 @@ class ArchitectureValidator:
for file_path in js_files:
content = file_path.read_text()
lines = content.split('\n')
lines = content.split("\n")
# JS-001: Check for window.apiClient
for i, line in enumerate(lines, 1):
if 'window.apiClient' in line and '//' not in line[:line.find('window.apiClient')] if 'window.apiClient' in line else True:
if (
"window.apiClient" in line
and "//" not in line[: line.find("window.apiClient")]
if "window.apiClient" in line
else True
):
self._add_violation(
rule_id="JS-001",
rule_name="Use apiClient directly",
@@ -456,14 +483,14 @@ class ArchitectureValidator:
line_number=i,
message="Use apiClient directly instead of window.apiClient",
context=line.strip(),
suggestion="Replace window.apiClient with apiClient"
suggestion="Replace window.apiClient with apiClient",
)
# JS-002: Check for console usage
for i, line in enumerate(lines, 1):
if re.search(r'console\.(log|warn|error)', line):
if re.search(r"console\.(log|warn|error)", line):
# Skip if it's a comment or bootstrap message
if '//' in line or '' in line or 'eslint-disable' in line:
if "//" in line or "" in line or "eslint-disable" in line:
continue
self._add_violation(
@@ -474,7 +501,7 @@ class ArchitectureValidator:
line_number=i,
message="Use centralized logger instead of console",
context=line.strip()[:80],
suggestion="Use window.LogConfig.createLogger('moduleName')"
suggestion="Use window.LogConfig.createLogger('moduleName')",
)
def _validate_templates(self, target_path: Path):
@@ -486,14 +513,16 @@ class ArchitectureValidator:
for file_path in template_files:
# Skip base template and partials
if 'base.html' in file_path.name or 'partials' in str(file_path):
if "base.html" in file_path.name or "partials" in str(file_path):
continue
content = file_path.read_text()
lines = content.split('\n')
lines = content.split("\n")
# TPL-001: Check for extends
has_extends = any('{% extends' in line and 'admin/base.html' in line for line in lines)
has_extends = any(
"{% extends" in line and "admin/base.html" in line for line in lines
)
if not has_extends:
self._add_violation(
@@ -504,23 +533,29 @@ class ArchitectureValidator:
line_number=1,
message="Admin template does not extend admin/base.html",
context=file_path.name,
suggestion="Add {% extends 'admin/base.html' %} at the top"
suggestion="Add {% extends 'admin/base.html' %} at the top",
)
def _get_rule(self, rule_id: str) -> Dict[str, Any]:
"""Get rule configuration by ID"""
# Look in different rule categories
for category in ['api_endpoint_rules', 'service_layer_rules', 'model_rules',
'exception_rules', 'javascript_rules', 'template_rules']:
for category in [
"api_endpoint_rules",
"service_layer_rules",
"model_rules",
"exception_rules",
"javascript_rules",
"template_rules",
]:
rules = self.config.get(category, [])
for rule in rules:
if rule.get('id') == rule_id:
if rule.get("id") == rule_id:
return rule
return None
def _should_ignore_file(self, file_path: Path) -> bool:
"""Check if file should be ignored"""
ignore_patterns = self.config.get('ignore', {}).get('files', [])
ignore_patterns = self.config.get("ignore", {}).get("files", [])
for pattern in ignore_patterns:
if file_path.match(pattern):
@@ -528,9 +563,17 @@ class ArchitectureValidator:
return False
def _add_violation(self, rule_id: str, rule_name: str, severity: Severity,
file_path: Path, line_number: int, message: str,
context: str = "", suggestion: str = ""):
def _add_violation(
self,
rule_id: str,
rule_name: str,
severity: Severity,
file_path: Path,
line_number: int,
message: str,
context: str = "",
suggestion: str = "",
):
"""Add a violation to results"""
violation = Violation(
rule_id=rule_id,
@@ -540,7 +583,7 @@ class ArchitectureValidator:
line_number=line_number,
message=message,
context=context,
suggestion=suggestion
suggestion=suggestion,
)
self.result.violations.append(violation)
@@ -590,24 +633,34 @@ class ArchitectureValidator:
violations_json = []
for v in self.result.violations:
rel_path = str(v.file_path.relative_to(self.project_root)) if self.project_root in v.file_path.parents else str(v.file_path)
violations_json.append({
'rule_id': v.rule_id,
'rule_name': v.rule_name,
'severity': v.severity.value,
'file_path': rel_path,
'line_number': v.line_number,
'message': v.message,
'context': v.context or '',
'suggestion': v.suggestion or ''
})
rel_path = (
str(v.file_path.relative_to(self.project_root))
if self.project_root in v.file_path.parents
else str(v.file_path)
)
violations_json.append(
{
"rule_id": v.rule_id,
"rule_name": v.rule_name,
"severity": v.severity.value,
"file_path": rel_path,
"line_number": v.line_number,
"message": v.message,
"context": v.context or "",
"suggestion": v.suggestion or "",
}
)
output = {
'files_checked': self.result.files_checked,
'total_violations': len(self.result.violations),
'errors': len([v for v in self.result.violations if v.severity == Severity.ERROR]),
'warnings': len([v for v in self.result.violations if v.severity == Severity.WARNING]),
'violations': violations_json
"files_checked": self.result.files_checked,
"total_violations": len(self.result.violations),
"errors": len(
[v for v in self.result.violations if v.severity == Severity.ERROR]
),
"warnings": len(
[v for v in self.result.violations if v.severity == Severity.WARNING]
),
"violations": violations_json,
}
print(json.dumps(output, indent=2))
@@ -616,7 +669,11 @@ class ArchitectureValidator:
def _print_violation(self, v: Violation):
"""Print a single violation"""
rel_path = v.file_path.relative_to(self.project_root) if self.project_root in v.file_path.parents else v.file_path
rel_path = (
v.file_path.relative_to(self.project_root)
if self.project_root in v.file_path.parents
else v.file_path
)
print(f"\n [{v.rule_id}] {v.rule_name}")
print(f" File: {rel_path}:{v.line_number}")
@@ -634,40 +691,40 @@ def main():
parser = argparse.ArgumentParser(
description="Validate architecture patterns in codebase",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=__doc__
epilog=__doc__,
)
parser.add_argument(
'path',
nargs='?',
"path",
nargs="?",
type=Path,
default=Path.cwd(),
help="Path to validate (default: current directory)"
help="Path to validate (default: current directory)",
)
parser.add_argument(
'-c', '--config',
"-c",
"--config",
type=Path,
default=Path.cwd() / '.architecture-rules.yaml',
help="Path to architecture rules config (default: .architecture-rules.yaml)"
default=Path.cwd() / ".architecture-rules.yaml",
help="Path to architecture rules config (default: .architecture-rules.yaml)",
)
parser.add_argument(
'-v', '--verbose',
action='store_true',
help="Show detailed output including context"
"-v",
"--verbose",
action="store_true",
help="Show detailed output including context",
)
parser.add_argument(
'--errors-only',
action='store_true',
help="Only show errors, suppress warnings"
"--errors-only", action="store_true", help="Only show errors, suppress warnings"
)
parser.add_argument(
'--json',
action='store_true',
help="Output results as JSON (for programmatic use)"
"--json",
action="store_true",
help="Output results as JSON (for programmatic use)",
)
args = parser.parse_args()
@@ -687,5 +744,5 @@ def main():
sys.exit(exit_code)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@@ -3,10 +3,12 @@
import os
import sys
from urllib.parse import urlparse
from sqlalchemy import create_engine, text
from alembic import command
from alembic.config import Config
from urllib.parse import urlparse
# Add project root to Python path
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
@@ -20,10 +22,10 @@ def get_database_info():
parsed = urlparse(db_url)
db_type = parsed.scheme
if db_type == 'sqlite':
if db_type == "sqlite":
# Extract path from sqlite:///./path or sqlite:///path
db_path = db_url.replace('sqlite:///', '')
if db_path.startswith('./'):
db_path = db_url.replace("sqlite:///", "")
if db_path.startswith("./"):
db_path = db_path[2:]
return db_type, db_path
else:
@@ -40,7 +42,7 @@ def verify_database_setup():
db_type, db_path = get_database_info()
print(f"[INFO] Database type: {db_type}")
if db_type == 'sqlite':
if db_type == "sqlite":
if not os.path.exists(db_path):
print(f"[ERROR] Database file not found: {db_path}")
return False
@@ -54,10 +56,14 @@ def verify_database_setup():
with engine.connect() as conn:
# Get table list (works for both SQLite and PostgreSQL)
if db_type == 'sqlite':
result = conn.execute(text("SELECT name FROM sqlite_master WHERE type='table'"))
if db_type == "sqlite":
result = conn.execute(
text("SELECT name FROM sqlite_master WHERE type='table'")
)
else:
result = conn.execute(text("SELECT tablename FROM pg_tables WHERE schemaname='public'"))
result = conn.execute(
text("SELECT tablename FROM pg_tables WHERE schemaname='public'")
)
tables = [row[0] for row in result.fetchall()]
@@ -65,12 +71,17 @@ def verify_database_setup():
# Expected tables from your models
expected_tables = [
'users', 'products', 'inventory', 'vendors', 'products',
'marketplace_import_jobs', 'alembic_version'
"users",
"products",
"inventory",
"vendors",
"products",
"marketplace_import_jobs",
"alembic_version",
]
for table in sorted(tables):
if table == 'alembic_version':
if table == "alembic_version":
print(f" * {table} (migration tracking)")
elif table in expected_tables:
print(f" * {table}")
@@ -78,12 +89,12 @@ def verify_database_setup():
print(f" ? {table} (unexpected)")
# Check for missing expected tables
missing_tables = set(expected_tables) - set(tables) - {'alembic_version'}
missing_tables = set(expected_tables) - set(tables) - {"alembic_version"}
if missing_tables:
print(f"[WARNING] Missing expected tables: {missing_tables}")
# Check if alembic_version table exists
if 'alembic_version' in tables:
if "alembic_version" in tables:
result = conn.execute(text("SELECT version_num FROM alembic_version"))
version = result.fetchone()
if version:
@@ -126,16 +137,19 @@ def verify_model_structure():
# Test database models
try:
from models.database.base import Base
print(f"[OK] Database Base imported")
print(f"[INFO] Found {len(Base.metadata.tables)} database tables: {list(Base.metadata.tables.keys())}")
print(
f"[INFO] Found {len(Base.metadata.tables)} database tables: {list(Base.metadata.tables.keys())}"
)
# Import specific models
from models.database.user import User
from models.database.marketplace_product import MarketplaceProduct
from models.database.inventory import Inventory
from models.database.vendor import Vendor
from models.database.product import Product
from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.marketplace_product import MarketplaceProduct
from models.database.product import Product
from models.database.user import User
from models.database.vendor import Vendor
print("[OK] All database models imported successfully")
@@ -146,13 +160,23 @@ def verify_model_structure():
# Test API models
try:
import models.schema
print("[OK] API models package imported")
# Test specific API model imports
api_modules = ['base', 'auth', 'product', 'inventory', 'vendor ', 'marketplace', 'admin', 'stats']
api_modules = [
"base",
"auth",
"product",
"inventory",
"vendor ",
"marketplace",
"admin",
"stats",
]
for module in api_modules:
try:
__import__(f'models.api.{module}')
__import__(f"models.api.{module}")
print(f" * models.api.{module}")
except ImportError:
print(f" ? models.api.{module} (not found, optional)")
@@ -176,7 +200,7 @@ def check_project_structure():
"models/database/inventory.py",
"app/core/config.py",
"alembic/env.py",
"alembic.ini"
"alembic.ini",
]
for path in critical_paths:
@@ -189,7 +213,7 @@ def check_project_structure():
init_files = [
"models/__init__.py",
"models/database/__init__.py",
"models/api/__init__.py"
"models/api/__init__.py",
]
print(f"\n[INIT] Checking __init__.py files...")
@@ -221,7 +245,9 @@ if __name__ == "__main__":
print("Next steps:")
print(" 1. Run 'make dev' to start your FastAPI server")
print(" 2. Visit http://localhost:8000/docs for interactive API docs")
print(" 3. Use your API endpoints for authentication, products, inventory, etc.")
print(
" 3. Use your API endpoints for authentication, products, inventory, etc."
)
sys.exit(0)
else:
print()