From 21c13ca39b67c28d70db123382717043812e881d Mon Sep 17 00:00:00 2001 From: Samir Boulahtit Date: Fri, 28 Nov 2025 19:30:17 +0100 Subject: [PATCH] style: apply black and isort formatting across entire codebase MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Standardize quote style (single to double quotes) - Reorder and group imports alphabetically - Fix line breaks and indentation for consistency - Apply PEP 8 formatting standards Also updated Makefile to exclude both venv and .venv from code quality checks. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- Makefile | 12 +- alembic/env.py | 18 +- ...51b2e50581_initial_migration_all_tables.py | 1388 ++++++++++------- ...07_ensure_content_pages_table_with_all_.py | 87 +- ...dd_architecture_quality_tracking_tables.py | 348 +++-- .../a2064e1dfcd4_add_cart_items_table.py | 63 +- ...dd_template_field_to_content_pages_for_.py | 16 +- .../fa7d4d10e358_add_rbac_enhancements.py | 95 +- ...20ce8b4_add_content_pages_table_for_cms.py | 31 +- app/api/deps.py | 193 ++- app/api/main.py | 22 +- app/api/v1/__init__.py | 4 +- app/api/v1/admin/__init__.py | 26 +- app/api/v1/admin/audit.py | 68 +- app/api/v1/admin/auth.py | 15 +- app/api/v1/admin/code_quality.py | 161 +- app/api/v1/admin/content_pages.py | 65 +- app/api/v1/admin/dashboard.py | 3 +- app/api/v1/admin/marketplace.py | 2 +- app/api/v1/admin/monitoring.py | 2 +- app/api/v1/admin/notifications.py | 36 +- app/api/v1/admin/settings.py | 93 +- app/api/v1/admin/users.py | 2 +- app/api/v1/admin/vendor_domains.py | 93 +- app/api/v1/admin/vendor_themes.py | 46 +- app/api/v1/admin/vendors.py | 116 +- app/api/v1/shared/health.py | 2 +- app/api/v1/shared/uploads.py | 2 +- app/api/v1/shop/__init__.py | 6 +- app/api/v1/shop/auth.py | 114 +- app/api/v1/shop/cart.py | 81 +- app/api/v1/shop/content_pages.py | 26 +- app/api/v1/shop/orders.py | 72 +- app/api/v1/shop/products.py | 44 +- app/api/v1/vendor/__init__.py | 28 +- app/api/v1/vendor/analytics.py | 11 +- app/api/v1/vendor/auth.py | 77 +- app/api/v1/vendor/content_pages.py | 82 +- app/api/v1/vendor/customers.py | 79 +- app/api/v1/vendor/dashboard.py | 28 +- app/api/v1/vendor/info.py | 23 +- app/api/v1/vendor/inventory.py | 104 +- app/api/v1/vendor/marketplace.py | 48 +- app/api/v1/vendor/media.py | 104 +- app/api/v1/vendor/notifications.py | 119 +- app/api/v1/vendor/orders.py | 59 +- app/api/v1/vendor/payments.py | 86 +- app/api/v1/vendor/products.py | 124 +- app/api/v1/vendor/profile.py | 19 +- app/api/v1/vendor/settings.py | 17 +- app/api/v1/vendor/team.py | 193 +-- app/core/config.py | 39 +- app/core/environment.py | 25 +- app/core/lifespan.py | 9 +- app/core/permissions.py | 13 +- app/core/theme_presets.py | 112 +- app/exceptions/__init__.py | 285 ++-- app/exceptions/admin.py | 71 +- app/exceptions/auth.py | 16 +- app/exceptions/backup.py | 2 +- app/exceptions/base.py | 9 +- app/exceptions/cart.py | 24 +- app/exceptions/customer.py | 37 +- app/exceptions/error_renderer.py | 23 +- app/exceptions/handler.py | 109 +- app/exceptions/inventory.py | 34 +- app/exceptions/marketplace.py | 2 +- app/exceptions/marketplace_import_job.py | 51 +- app/exceptions/marketplace_product.py | 28 +- app/exceptions/media.py | 2 +- app/exceptions/monitoring.py | 2 +- app/exceptions/notification.py | 2 +- app/exceptions/order.py | 26 +- app/exceptions/payment.py | 2 +- app/exceptions/product.py | 9 +- app/exceptions/search.py | 2 +- app/exceptions/team.py | 61 +- app/exceptions/vendor.py | 33 +- app/exceptions/vendor_domain.py | 46 +- app/exceptions/vendor_theme.py | 39 +- app/models/architecture_scan.py | 92 +- app/routes/admin_pages.py | 261 ++-- app/routes/shop_pages.py | 313 ++-- app/routes/vendor_pages.py | 169 +- app/services/admin_audit_service.py | 54 +- app/services/admin_service.py | 246 +-- app/services/admin_settings_service.py | 114 +- app/services/audit_service.py | 2 +- app/services/auth_service.py | 22 +- app/services/backup_service.py | 2 +- app/services/cache_service.py | 2 +- app/services/cart_service.py | 266 ++-- app/services/code_quality_service.py | 287 ++-- app/services/configuration_service.py | 2 +- app/services/content_page_service.py | 57 +- app/services/customer_service.py | 190 ++- app/services/inventory_service.py | 139 +- .../marketplace_import_job_service.py | 25 +- app/services/marketplace_product_service.py | 195 ++- app/services/media_service.py | 2 +- app/services/monitoring_service.py | 2 +- app/services/notification_service.py | 2 +- app/services/order_service.py | 172 +- app/services/payment_service.py | 2 +- app/services/product_service.py | 71 +- app/services/search_service.py | 2 +- app/services/stats_service.py | 232 +-- app/services/team_service.py | 81 +- app/services/vendor_domain_service.py | 168 +- app/services/vendor_service.py | 152 +- app/services/vendor_team_service.py | 208 +-- app/services/vendor_theme_service.py | 161 +- app/tasks/background_tasks.py | 6 +- app/utils/csv_processor.py | 11 +- app/utils/data_processing.py | 4 +- app/utils/database.py | 1 + main.py | 142 +- middleware/auth.py | 35 +- middleware/context.py | 35 +- middleware/decorators.py | 7 +- middleware/theme_context.py | 39 +- middleware/vendor_context.py | 159 +- models/__init__.py | 17 +- models/database/__init__.py | 17 +- models/database/admin.py | 40 +- models/database/audit.py | 2 +- models/database/backup.py | 2 +- models/database/base.py | 5 +- models/database/cart.py | 5 +- models/database/configuration.py | 2 +- models/database/content_page.py | 45 +- models/database/customer.py | 14 +- models/database/inventory.py | 8 +- models/database/marketplace.py | 2 +- models/database/marketplace_import_job.py | 3 +- models/database/marketplace_product.py | 4 +- models/database/media.py | 2 +- models/database/monitoring.py | 2 +- models/database/notification.py | 2 +- models/database/order.py | 26 +- models/database/payment.py | 2 +- models/database/product.py | 12 +- models/database/search.py | 2 +- models/database/task.py | 2 +- models/database/user.py | 17 +- models/database/vendor.py | 118 +- models/database/vendor_domain.py | 23 +- models/database/vendor_theme.py | 41 +- models/schema/__init__.py | 9 +- models/schema/admin.py | 75 +- models/schema/auth.py | 8 +- models/schema/cart.py | 24 +- models/schema/customer.py | 27 +- models/schema/inventory.py | 10 +- models/schema/marketplace.py | 2 +- models/schema/marketplace_import_job.py | 9 +- models/schema/marketplace_product.py | 12 +- models/schema/media.py | 2 +- models/schema/monitoring.py | 2 +- models/schema/notification.py | 2 +- models/schema/order.py | 20 +- models/schema/payment.py | 2 +- models/schema/product.py | 13 +- models/schema/search.py | 2 +- models/schema/stats.py | 11 +- models/schema/team.py | 90 +- models/schema/vendor.py | 62 +- models/schema/vendor_domain.py | 25 +- models/schema/vendor_theme.py | 40 +- scripts/backup_database.py | 10 +- scripts/create_default_content_pages.py | 18 +- scripts/create_inventory.py | 42 +- scripts/create_landing_page.py | 45 +- scripts/create_platform_pages.py | 144 +- scripts/init_production.py | 25 +- scripts/route_diagnostics.py | 28 +- scripts/seed_demo.py | 129 +- scripts/show_structure.py | 317 ++-- scripts/test_auth_complete.py | 198 +-- scripts/test_vendor_management.py | 41 +- scripts/validate_architecture.py | 271 ++-- scripts/verify_setup.py | 72 +- storage/backends.py | 2 +- storage/utils.py | 2 +- tasks/analytics_tasks.py | 2 +- tasks/backup_tasks.py | 2 +- tasks/cleanup_tasks.py | 2 +- tasks/email_tasks.py | 2 +- tasks/marketplace_import.py | 8 +- tasks/media_processing.py | 2 +- tasks/search_indexing.py | 2 +- tasks/task_manager.py | 2 +- tests/conftest.py | 4 +- tests/fixtures/auth_fixtures.py | 2 + tests/fixtures/customer_fixtures.py | 1 + .../marketplace_import_job_fixtures.py | 1 + .../fixtures/marketplace_product_fixtures.py | 8 +- tests/fixtures/testing_fixtures.py | 5 +- tests/fixtures/vendor_fixtures.py | 6 +- .../api/v1/test_admin_endpoints.py | 25 +- .../integration/api/v1/test_auth_endpoints.py | 14 +- tests/integration/api/v1/test_filtering.py | 97 +- .../api/v1/test_inventory_endpoints.py | 85 +- .../test_marketplace_import_job_endpoints.py | 125 +- .../api/v1/test_marketplace_product_export.py | 101 +- .../v1/test_marketplace_products_endpoints.py | 88 +- tests/integration/api/v1/test_pagination.py | 102 +- .../api/v1/test_stats_endpoints.py | 5 +- .../api/v1/test_vendor_endpoints.py | 83 +- .../api/v1/vendor/test_authentication.py | 98 +- .../api/v1/vendor/test_dashboard.py | 30 +- tests/integration/middleware/conftest.py | 22 +- .../middleware/test_context_detection_flow.py | 267 +++- .../middleware/test_middleware_stack.py | 166 +- .../middleware/test_theme_loading_flow.py | 218 ++- .../middleware/test_vendor_context_flow.py | 243 ++- .../security/test_input_validation.py | 19 +- .../integration/workflows/test_integration.py | 16 +- tests/performance/test_api_performance.py | 11 +- tests/system/test_error_handling.py | 107 +- tests/unit/middleware/test_auth.py | 100 +- tests/unit/middleware/test_context.py | 48 +- tests/unit/middleware/test_decorators.py | 30 +- tests/unit/middleware/test_logging.py | 39 +- tests/unit/middleware/test_rate_limiter.py | 25 +- tests/unit/middleware/test_theme_context.py | 43 +- tests/unit/middleware/test_vendor_context.py | 207 ++- tests/unit/models/test_database_models.py | 27 +- tests/unit/services/test_admin_service.py | 46 +- tests/unit/services/test_auth_service.py | 22 +- tests/unit/services/test_inventory_service.py | 133 +- .../unit/services/test_marketplace_service.py | 85 +- tests/unit/services/test_product_service.py | 98 +- tests/unit/services/test_stats_service.py | 34 +- tests/unit/services/test_vendor_service.py | 75 +- tests/unit/utils/test_csv_processor.py | 8 +- 236 files changed, 8450 insertions(+), 6545 deletions(-) diff --git a/Makefile b/Makefile index 75b36c8f..d201da6a 100644 --- a/Makefile +++ b/Makefile @@ -176,19 +176,19 @@ test-inventory: format: @echo "Running black..." - $(PYTHON) -m black . --exclude venv + $(PYTHON) -m black . --exclude '/(\.)?venv/' @echo "Running isort..." - $(PYTHON) -m isort . --skip venv + $(PYTHON) -m isort . --skip venv --skip .venv lint: @echo "Running linting..." - $(PYTHON) -m ruff check . --exclude venv - $(PYTHON) -m mypy . --ignore-missing-imports --exclude venv + $(PYTHON) -m ruff check . --exclude venv --exclude .venv + $(PYTHON) -m mypy . --ignore-missing-imports --exclude '.*(\.)?venv.*' lint-flake8: @echo "Running linting..." - $(PYTHON) -m flake8 . --max-line-length=120 --extend-ignore=E203,W503,I201,I100 --exclude=venv,__pycache__,.git - $(PYTHON) -m mypy . --ignore-missing-imports --exclude venv + $(PYTHON) -m flake8 . --max-line-length=120 --extend-ignore=E203,W503,I201,I100 --exclude=venv,.venv,__pycache__,.git + $(PYTHON) -m mypy . --ignore-missing-imports --exclude '.*(\.)?venv.*' check: format lint diff --git a/alembic/env.py b/alembic/env.py index ae52fcdf..8605ecc3 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -15,6 +15,7 @@ import sys from logging.config import fileConfig from sqlalchemy import engine_from_config, pool + from alembic import context # Add your project directory to the Python path @@ -39,13 +40,9 @@ print("=" * 70) # ADMIN MODELS # ---------------------------------------------------------------------------- try: - from models.database.admin import ( - AdminAuditLog, - AdminNotification, - AdminSetting, - PlatformAlert, - AdminSession - ) + from models.database.admin import (AdminAuditLog, AdminNotification, + AdminSession, AdminSetting, + PlatformAlert) print(" ✓ Admin models imported (5 models)") print(" - AdminAuditLog") @@ -70,7 +67,7 @@ except ImportError as e: # VENDOR MODELS # ---------------------------------------------------------------------------- try: - from models.database.vendor import Vendor, VendorUser, Role + from models.database.vendor import Role, Vendor, VendorUser print(" ✓ Vendor models imported (3 models)") print(" - Vendor") @@ -248,10 +245,7 @@ def run_migrations_online() -> None: ) with connectable.connect() as connection: - context.configure( - connection=connection, - target_metadata=target_metadata - ) + context.configure(connection=connection, target_metadata=target_metadata) with context.begin_transaction(): context.run_migrations() diff --git a/alembic/versions/4951b2e50581_initial_migration_all_tables.py b/alembic/versions/4951b2e50581_initial_migration_all_tables.py index 36b6fe27..bdf730a1 100644 --- a/alembic/versions/4951b2e50581_initial_migration_all_tables.py +++ b/alembic/versions/4951b2e50581_initial_migration_all_tables.py @@ -1,18 +1,19 @@ """Initial migration - all tables Revision ID: 4951b2e50581 -Revises: +Revises: Create Date: 2025-10-27 22:28:33.137564 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision: str = '4951b2e50581' +revision: str = "4951b2e50581" down_revision: Union[str, None] = None branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -20,549 +21,888 @@ depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.create_table('marketplace_products', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('marketplace_product_id', sa.String(), nullable=False), - sa.Column('title', sa.String(), nullable=False), - sa.Column('description', sa.String(), nullable=True), - sa.Column('link', sa.String(), nullable=True), - sa.Column('image_link', sa.String(), nullable=True), - sa.Column('availability', sa.String(), nullable=True), - sa.Column('price', sa.String(), nullable=True), - sa.Column('brand', sa.String(), nullable=True), - sa.Column('gtin', sa.String(), nullable=True), - sa.Column('mpn', sa.String(), nullable=True), - sa.Column('condition', sa.String(), nullable=True), - sa.Column('adult', sa.String(), nullable=True), - sa.Column('multipack', sa.Integer(), nullable=True), - sa.Column('is_bundle', sa.String(), nullable=True), - sa.Column('age_group', sa.String(), nullable=True), - sa.Column('color', sa.String(), nullable=True), - sa.Column('gender', sa.String(), nullable=True), - sa.Column('material', sa.String(), nullable=True), - sa.Column('pattern', sa.String(), nullable=True), - sa.Column('size', sa.String(), nullable=True), - sa.Column('size_type', sa.String(), nullable=True), - sa.Column('size_system', sa.String(), nullable=True), - sa.Column('item_group_id', sa.String(), nullable=True), - sa.Column('google_product_category', sa.String(), nullable=True), - sa.Column('product_type', sa.String(), nullable=True), - sa.Column('custom_label_0', sa.String(), nullable=True), - sa.Column('custom_label_1', sa.String(), nullable=True), - sa.Column('custom_label_2', sa.String(), nullable=True), - sa.Column('custom_label_3', sa.String(), nullable=True), - sa.Column('custom_label_4', sa.String(), nullable=True), - sa.Column('additional_image_link', sa.String(), nullable=True), - sa.Column('sale_price', sa.String(), nullable=True), - sa.Column('unit_pricing_measure', sa.String(), nullable=True), - sa.Column('unit_pricing_base_measure', sa.String(), nullable=True), - sa.Column('identifier_exists', sa.String(), nullable=True), - sa.Column('shipping', sa.String(), nullable=True), - sa.Column('currency', sa.String(), nullable=True), - sa.Column('marketplace', sa.String(), nullable=True), - sa.Column('vendor_name', sa.String(), nullable=True), - sa.Column('created_at', sa.DateTime(), nullable=False), - sa.Column('updated_at', sa.DateTime(), nullable=False), - sa.PrimaryKeyConstraint('id') + op.create_table( + "marketplace_products", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("marketplace_product_id", sa.String(), nullable=False), + sa.Column("title", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column("link", sa.String(), nullable=True), + sa.Column("image_link", sa.String(), nullable=True), + sa.Column("availability", sa.String(), nullable=True), + sa.Column("price", sa.String(), nullable=True), + sa.Column("brand", sa.String(), nullable=True), + sa.Column("gtin", sa.String(), nullable=True), + sa.Column("mpn", sa.String(), nullable=True), + sa.Column("condition", sa.String(), nullable=True), + sa.Column("adult", sa.String(), nullable=True), + sa.Column("multipack", sa.Integer(), nullable=True), + sa.Column("is_bundle", sa.String(), nullable=True), + sa.Column("age_group", sa.String(), nullable=True), + sa.Column("color", sa.String(), nullable=True), + sa.Column("gender", sa.String(), nullable=True), + sa.Column("material", sa.String(), nullable=True), + sa.Column("pattern", sa.String(), nullable=True), + sa.Column("size", sa.String(), nullable=True), + sa.Column("size_type", sa.String(), nullable=True), + sa.Column("size_system", sa.String(), nullable=True), + sa.Column("item_group_id", sa.String(), nullable=True), + sa.Column("google_product_category", sa.String(), nullable=True), + sa.Column("product_type", sa.String(), nullable=True), + sa.Column("custom_label_0", sa.String(), nullable=True), + sa.Column("custom_label_1", sa.String(), nullable=True), + sa.Column("custom_label_2", sa.String(), nullable=True), + sa.Column("custom_label_3", sa.String(), nullable=True), + sa.Column("custom_label_4", sa.String(), nullable=True), + sa.Column("additional_image_link", sa.String(), nullable=True), + sa.Column("sale_price", sa.String(), nullable=True), + sa.Column("unit_pricing_measure", sa.String(), nullable=True), + sa.Column("unit_pricing_base_measure", sa.String(), nullable=True), + sa.Column("identifier_exists", sa.String(), nullable=True), + sa.Column("shipping", sa.String(), nullable=True), + sa.Column("currency", sa.String(), nullable=True), + sa.Column("marketplace", sa.String(), nullable=True), + sa.Column("vendor_name", sa.String(), nullable=True), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint("id"), ) - op.create_index('idx_marketplace_brand', 'marketplace_products', ['marketplace', 'brand'], unique=False) - op.create_index('idx_marketplace_vendor', 'marketplace_products', ['marketplace', 'vendor_name'], unique=False) - op.create_index(op.f('ix_marketplace_products_availability'), 'marketplace_products', ['availability'], unique=False) - op.create_index(op.f('ix_marketplace_products_brand'), 'marketplace_products', ['brand'], unique=False) - op.create_index(op.f('ix_marketplace_products_google_product_category'), 'marketplace_products', ['google_product_category'], unique=False) - op.create_index(op.f('ix_marketplace_products_gtin'), 'marketplace_products', ['gtin'], unique=False) - op.create_index(op.f('ix_marketplace_products_id'), 'marketplace_products', ['id'], unique=False) - op.create_index(op.f('ix_marketplace_products_marketplace'), 'marketplace_products', ['marketplace'], unique=False) - op.create_index(op.f('ix_marketplace_products_marketplace_product_id'), 'marketplace_products', ['marketplace_product_id'], unique=True) - op.create_index(op.f('ix_marketplace_products_vendor_name'), 'marketplace_products', ['vendor_name'], unique=False) - op.create_table('users', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('email', sa.String(), nullable=False), - sa.Column('username', sa.String(), nullable=False), - sa.Column('first_name', sa.String(), nullable=True), - sa.Column('last_name', sa.String(), nullable=True), - sa.Column('hashed_password', sa.String(), nullable=False), - sa.Column('role', sa.String(), nullable=False), - sa.Column('is_active', sa.Boolean(), nullable=False), - sa.Column('last_login', sa.DateTime(), nullable=True), - sa.Column('created_at', sa.DateTime(), nullable=False), - sa.Column('updated_at', sa.DateTime(), nullable=False), - sa.PrimaryKeyConstraint('id') + op.create_index( + "idx_marketplace_brand", + "marketplace_products", + ["marketplace", "brand"], + unique=False, ) - op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True) - op.create_index(op.f('ix_users_id'), 'users', ['id'], unique=False) - op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=True) - op.create_table('admin_audit_logs', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('admin_user_id', sa.Integer(), nullable=False), - sa.Column('action', sa.String(length=100), nullable=False), - sa.Column('target_type', sa.String(length=50), nullable=False), - sa.Column('target_id', sa.String(length=100), nullable=False), - sa.Column('details', sa.JSON(), nullable=True), - sa.Column('ip_address', sa.String(length=45), nullable=True), - sa.Column('user_agent', sa.Text(), nullable=True), - sa.Column('request_id', sa.String(length=100), nullable=True), - sa.Column('created_at', sa.DateTime(), nullable=False), - sa.Column('updated_at', sa.DateTime(), nullable=False), - sa.ForeignKeyConstraint(['admin_user_id'], ['users.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_index( + "idx_marketplace_vendor", + "marketplace_products", + ["marketplace", "vendor_name"], + unique=False, ) - op.create_index(op.f('ix_admin_audit_logs_action'), 'admin_audit_logs', ['action'], unique=False) - op.create_index(op.f('ix_admin_audit_logs_admin_user_id'), 'admin_audit_logs', ['admin_user_id'], unique=False) - op.create_index(op.f('ix_admin_audit_logs_id'), 'admin_audit_logs', ['id'], unique=False) - op.create_index(op.f('ix_admin_audit_logs_target_id'), 'admin_audit_logs', ['target_id'], unique=False) - op.create_index(op.f('ix_admin_audit_logs_target_type'), 'admin_audit_logs', ['target_type'], unique=False) - op.create_table('admin_notifications', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('type', sa.String(length=50), nullable=False), - sa.Column('priority', sa.String(length=20), nullable=True), - sa.Column('title', sa.String(length=200), nullable=False), - sa.Column('message', sa.Text(), nullable=False), - sa.Column('is_read', sa.Boolean(), nullable=True), - sa.Column('read_at', sa.DateTime(), nullable=True), - sa.Column('read_by_user_id', sa.Integer(), nullable=True), - sa.Column('action_required', sa.Boolean(), nullable=True), - sa.Column('action_url', sa.String(length=500), nullable=True), - sa.Column('notification_metadata', sa.JSON(), nullable=True), - sa.Column('created_at', sa.DateTime(), nullable=False), - sa.Column('updated_at', sa.DateTime(), nullable=False), - sa.ForeignKeyConstraint(['read_by_user_id'], ['users.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_index( + op.f("ix_marketplace_products_availability"), + "marketplace_products", + ["availability"], + unique=False, ) - op.create_index(op.f('ix_admin_notifications_action_required'), 'admin_notifications', ['action_required'], unique=False) - op.create_index(op.f('ix_admin_notifications_id'), 'admin_notifications', ['id'], unique=False) - op.create_index(op.f('ix_admin_notifications_is_read'), 'admin_notifications', ['is_read'], unique=False) - op.create_index(op.f('ix_admin_notifications_priority'), 'admin_notifications', ['priority'], unique=False) - op.create_index(op.f('ix_admin_notifications_type'), 'admin_notifications', ['type'], unique=False) - op.create_table('admin_sessions', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('admin_user_id', sa.Integer(), nullable=False), - sa.Column('session_token', sa.String(length=255), nullable=False), - sa.Column('ip_address', sa.String(length=45), nullable=False), - sa.Column('user_agent', sa.Text(), nullable=True), - sa.Column('login_at', sa.DateTime(), nullable=False), - sa.Column('last_activity_at', sa.DateTime(), nullable=False), - sa.Column('logout_at', sa.DateTime(), nullable=True), - sa.Column('is_active', sa.Boolean(), nullable=True), - sa.Column('logout_reason', sa.String(length=50), nullable=True), - sa.Column('created_at', sa.DateTime(), nullable=False), - sa.Column('updated_at', sa.DateTime(), nullable=False), - sa.ForeignKeyConstraint(['admin_user_id'], ['users.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_index( + op.f("ix_marketplace_products_brand"), + "marketplace_products", + ["brand"], + unique=False, ) - op.create_index(op.f('ix_admin_sessions_admin_user_id'), 'admin_sessions', ['admin_user_id'], unique=False) - op.create_index(op.f('ix_admin_sessions_id'), 'admin_sessions', ['id'], unique=False) - op.create_index(op.f('ix_admin_sessions_is_active'), 'admin_sessions', ['is_active'], unique=False) - op.create_index(op.f('ix_admin_sessions_login_at'), 'admin_sessions', ['login_at'], unique=False) - op.create_index(op.f('ix_admin_sessions_session_token'), 'admin_sessions', ['session_token'], unique=True) - op.create_table('admin_settings', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('key', sa.String(length=100), nullable=False), - sa.Column('value', sa.Text(), nullable=False), - sa.Column('value_type', sa.String(length=20), nullable=True), - sa.Column('category', sa.String(length=50), nullable=True), - sa.Column('description', sa.Text(), nullable=True), - sa.Column('is_encrypted', sa.Boolean(), nullable=True), - sa.Column('is_public', sa.Boolean(), nullable=True), - sa.Column('last_modified_by_user_id', sa.Integer(), nullable=True), - sa.Column('created_at', sa.DateTime(), nullable=False), - sa.Column('updated_at', sa.DateTime(), nullable=False), - sa.ForeignKeyConstraint(['last_modified_by_user_id'], ['users.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_index( + op.f("ix_marketplace_products_google_product_category"), + "marketplace_products", + ["google_product_category"], + unique=False, ) - op.create_index(op.f('ix_admin_settings_category'), 'admin_settings', ['category'], unique=False) - op.create_index(op.f('ix_admin_settings_id'), 'admin_settings', ['id'], unique=False) - op.create_index(op.f('ix_admin_settings_key'), 'admin_settings', ['key'], unique=True) - op.create_table('platform_alerts', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('alert_type', sa.String(length=50), nullable=False), - sa.Column('severity', sa.String(length=20), nullable=False), - sa.Column('title', sa.String(length=200), nullable=False), - sa.Column('description', sa.Text(), nullable=True), - sa.Column('affected_vendors', sa.JSON(), nullable=True), - sa.Column('affected_systems', sa.JSON(), nullable=True), - sa.Column('is_resolved', sa.Boolean(), nullable=True), - sa.Column('resolved_at', sa.DateTime(), nullable=True), - sa.Column('resolved_by_user_id', sa.Integer(), nullable=True), - sa.Column('resolution_notes', sa.Text(), nullable=True), - sa.Column('auto_generated', sa.Boolean(), nullable=True), - sa.Column('occurrence_count', sa.Integer(), nullable=True), - sa.Column('first_occurred_at', sa.DateTime(), nullable=False), - sa.Column('last_occurred_at', sa.DateTime(), nullable=False), - sa.Column('created_at', sa.DateTime(), nullable=False), - sa.Column('updated_at', sa.DateTime(), nullable=False), - sa.ForeignKeyConstraint(['resolved_by_user_id'], ['users.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_index( + op.f("ix_marketplace_products_gtin"), + "marketplace_products", + ["gtin"], + unique=False, ) - op.create_index(op.f('ix_platform_alerts_alert_type'), 'platform_alerts', ['alert_type'], unique=False) - op.create_index(op.f('ix_platform_alerts_id'), 'platform_alerts', ['id'], unique=False) - op.create_index(op.f('ix_platform_alerts_is_resolved'), 'platform_alerts', ['is_resolved'], unique=False) - op.create_index(op.f('ix_platform_alerts_severity'), 'platform_alerts', ['severity'], unique=False) - op.create_table('vendors', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('vendor_code', sa.String(), nullable=False), - sa.Column('subdomain', sa.String(length=100), nullable=False), - sa.Column('name', sa.String(), nullable=False), - sa.Column('description', sa.Text(), nullable=True), - sa.Column('owner_user_id', sa.Integer(), nullable=False), - sa.Column('contact_email', sa.String(), nullable=True), - sa.Column('contact_phone', sa.String(), nullable=True), - sa.Column('website', sa.String(), nullable=True), - sa.Column('letzshop_csv_url_fr', sa.String(), nullable=True), - sa.Column('letzshop_csv_url_en', sa.String(), nullable=True), - sa.Column('letzshop_csv_url_de', sa.String(), nullable=True), - sa.Column('business_address', sa.Text(), nullable=True), - sa.Column('tax_number', sa.String(), nullable=True), - sa.Column('is_active', sa.Boolean(), nullable=True), - sa.Column('is_verified', sa.Boolean(), nullable=True), - sa.Column('created_at', sa.DateTime(), nullable=False), - sa.Column('updated_at', sa.DateTime(), nullable=False), - sa.ForeignKeyConstraint(['owner_user_id'], ['users.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_index( + op.f("ix_marketplace_products_id"), "marketplace_products", ["id"], unique=False ) - op.create_index(op.f('ix_vendors_id'), 'vendors', ['id'], unique=False) - op.create_index(op.f('ix_vendors_subdomain'), 'vendors', ['subdomain'], unique=True) - op.create_index(op.f('ix_vendors_vendor_code'), 'vendors', ['vendor_code'], unique=True) - op.create_table('customers', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('vendor_id', sa.Integer(), nullable=False), - sa.Column('email', sa.String(length=255), nullable=False), - sa.Column('hashed_password', sa.String(length=255), nullable=False), - sa.Column('first_name', sa.String(length=100), nullable=True), - sa.Column('last_name', sa.String(length=100), nullable=True), - sa.Column('phone', sa.String(length=50), nullable=True), - sa.Column('customer_number', sa.String(length=100), nullable=False), - sa.Column('preferences', sa.JSON(), nullable=True), - sa.Column('marketing_consent', sa.Boolean(), nullable=True), - sa.Column('last_order_date', sa.DateTime(), nullable=True), - sa.Column('total_orders', sa.Integer(), nullable=True), - sa.Column('total_spent', sa.Numeric(precision=10, scale=2), nullable=True), - sa.Column('is_active', sa.Boolean(), nullable=False), - sa.Column('created_at', sa.DateTime(), nullable=False), - sa.Column('updated_at', sa.DateTime(), nullable=False), - sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_index( + op.f("ix_marketplace_products_marketplace"), + "marketplace_products", + ["marketplace"], + unique=False, ) - op.create_index(op.f('ix_customers_customer_number'), 'customers', ['customer_number'], unique=False) - op.create_index(op.f('ix_customers_email'), 'customers', ['email'], unique=False) - op.create_index(op.f('ix_customers_id'), 'customers', ['id'], unique=False) - op.create_table('marketplace_import_jobs', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('vendor_id', sa.Integer(), nullable=False), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.Column('marketplace', sa.String(), nullable=False), - sa.Column('source_url', sa.String(), nullable=False), - sa.Column('status', sa.String(), nullable=False), - sa.Column('imported_count', sa.Integer(), nullable=True), - sa.Column('updated_count', sa.Integer(), nullable=True), - sa.Column('error_count', sa.Integer(), nullable=True), - sa.Column('total_processed', sa.Integer(), nullable=True), - sa.Column('error_message', sa.Text(), nullable=True), - sa.Column('started_at', sa.DateTime(timezone=True), nullable=True), - sa.Column('completed_at', sa.DateTime(timezone=True), nullable=True), - sa.Column('created_at', sa.DateTime(), nullable=False), - sa.Column('updated_at', sa.DateTime(), nullable=False), - sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), - sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_index( + op.f("ix_marketplace_products_marketplace_product_id"), + "marketplace_products", + ["marketplace_product_id"], + unique=True, ) - op.create_index('idx_import_user_marketplace', 'marketplace_import_jobs', ['user_id', 'marketplace'], unique=False) - op.create_index('idx_import_vendor_created', 'marketplace_import_jobs', ['vendor_id', 'created_at'], unique=False) - op.create_index('idx_import_vendor_status', 'marketplace_import_jobs', ['vendor_id', 'status'], unique=False) - op.create_index(op.f('ix_marketplace_import_jobs_id'), 'marketplace_import_jobs', ['id'], unique=False) - op.create_index(op.f('ix_marketplace_import_jobs_marketplace'), 'marketplace_import_jobs', ['marketplace'], unique=False) - op.create_index(op.f('ix_marketplace_import_jobs_vendor_id'), 'marketplace_import_jobs', ['vendor_id'], unique=False) - op.create_table('products', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('vendor_id', sa.Integer(), nullable=False), - sa.Column('marketplace_product_id', sa.Integer(), nullable=False), - sa.Column('product_id', sa.String(), nullable=True), - sa.Column('price', sa.Float(), nullable=True), - sa.Column('sale_price', sa.Float(), nullable=True), - sa.Column('currency', sa.String(), nullable=True), - sa.Column('availability', sa.String(), nullable=True), - sa.Column('condition', sa.String(), nullable=True), - sa.Column('is_featured', sa.Boolean(), nullable=True), - sa.Column('is_active', sa.Boolean(), nullable=True), - sa.Column('display_order', sa.Integer(), nullable=True), - sa.Column('min_quantity', sa.Integer(), nullable=True), - sa.Column('max_quantity', sa.Integer(), nullable=True), - sa.Column('created_at', sa.DateTime(), nullable=False), - sa.Column('updated_at', sa.DateTime(), nullable=False), - sa.ForeignKeyConstraint(['marketplace_product_id'], ['marketplace_products.id'], ), - sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('vendor_id', 'marketplace_product_id', name='uq_product') + op.create_index( + op.f("ix_marketplace_products_vendor_name"), + "marketplace_products", + ["vendor_name"], + unique=False, ) - op.create_index('idx_product_active', 'products', ['vendor_id', 'is_active'], unique=False) - op.create_index('idx_product_featured', 'products', ['vendor_id', 'is_featured'], unique=False) - op.create_index(op.f('ix_products_id'), 'products', ['id'], unique=False) - op.create_table('roles', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('vendor_id', sa.Integer(), nullable=False), - sa.Column('name', sa.String(length=100), nullable=False), - sa.Column('permissions', sa.JSON(), nullable=True), - sa.Column('created_at', sa.DateTime(), nullable=False), - sa.Column('updated_at', sa.DateTime(), nullable=False), - sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_table( + "users", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("email", sa.String(), nullable=False), + sa.Column("username", sa.String(), nullable=False), + sa.Column("first_name", sa.String(), nullable=True), + sa.Column("last_name", sa.String(), nullable=True), + sa.Column("hashed_password", sa.String(), nullable=False), + sa.Column("role", sa.String(), nullable=False), + sa.Column("is_active", sa.Boolean(), nullable=False), + sa.Column("last_login", sa.DateTime(), nullable=True), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint("id"), ) - op.create_index(op.f('ix_roles_id'), 'roles', ['id'], unique=False) - op.create_table('vendor_domains', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('vendor_id', sa.Integer(), nullable=False), - sa.Column('domain', sa.String(length=255), nullable=False), - sa.Column('is_primary', sa.Boolean(), nullable=False), - sa.Column('is_active', sa.Boolean(), nullable=False), - sa.Column('ssl_status', sa.String(length=50), nullable=True), - sa.Column('ssl_verified_at', sa.DateTime(timezone=True), nullable=True), - sa.Column('verification_token', sa.String(length=100), nullable=True), - sa.Column('is_verified', sa.Boolean(), nullable=False), - sa.Column('verified_at', sa.DateTime(timezone=True), nullable=True), - sa.Column('created_at', sa.DateTime(), nullable=False), - sa.Column('updated_at', sa.DateTime(), nullable=False), - sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('vendor_id', 'domain', name='uq_vendor_domain'), - sa.UniqueConstraint('verification_token') + op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True) + op.create_index(op.f("ix_users_id"), "users", ["id"], unique=False) + op.create_index(op.f("ix_users_username"), "users", ["username"], unique=True) + op.create_table( + "admin_audit_logs", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("admin_user_id", sa.Integer(), nullable=False), + sa.Column("action", sa.String(length=100), nullable=False), + sa.Column("target_type", sa.String(length=50), nullable=False), + sa.Column("target_id", sa.String(length=100), nullable=False), + sa.Column("details", sa.JSON(), nullable=True), + sa.Column("ip_address", sa.String(length=45), nullable=True), + sa.Column("user_agent", sa.Text(), nullable=True), + sa.Column("request_id", sa.String(length=100), nullable=True), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["admin_user_id"], + ["users.id"], + ), + sa.PrimaryKeyConstraint("id"), ) - op.create_index('idx_domain_active', 'vendor_domains', ['domain', 'is_active'], unique=False) - op.create_index('idx_vendor_primary', 'vendor_domains', ['vendor_id', 'is_primary'], unique=False) - op.create_index(op.f('ix_vendor_domains_domain'), 'vendor_domains', ['domain'], unique=True) - op.create_index(op.f('ix_vendor_domains_id'), 'vendor_domains', ['id'], unique=False) - op.create_table('vendor_themes', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('vendor_id', sa.Integer(), nullable=False), - sa.Column('theme_name', sa.String(length=100), nullable=True), - sa.Column('is_active', sa.Boolean(), nullable=True), - sa.Column('colors', sa.JSON(), nullable=True), - sa.Column('font_family_heading', sa.String(length=100), nullable=True), - sa.Column('font_family_body', sa.String(length=100), nullable=True), - sa.Column('logo_url', sa.String(length=500), nullable=True), - sa.Column('logo_dark_url', sa.String(length=500), nullable=True), - sa.Column('favicon_url', sa.String(length=500), nullable=True), - sa.Column('banner_url', sa.String(length=500), nullable=True), - sa.Column('layout_style', sa.String(length=50), nullable=True), - sa.Column('header_style', sa.String(length=50), nullable=True), - sa.Column('product_card_style', sa.String(length=50), nullable=True), - sa.Column('custom_css', sa.Text(), nullable=True), - sa.Column('social_links', sa.JSON(), nullable=True), - sa.Column('meta_title_template', sa.String(length=200), nullable=True), - sa.Column('meta_description', sa.Text(), nullable=True), - sa.Column('created_at', sa.DateTime(), nullable=False), - sa.Column('updated_at', sa.DateTime(), nullable=False), - sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('vendor_id') + op.create_index( + op.f("ix_admin_audit_logs_action"), "admin_audit_logs", ["action"], unique=False ) - op.create_index(op.f('ix_vendor_themes_id'), 'vendor_themes', ['id'], unique=False) - op.create_table('customer_addresses', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('vendor_id', sa.Integer(), nullable=False), - sa.Column('customer_id', sa.Integer(), nullable=False), - sa.Column('address_type', sa.String(length=50), nullable=False), - sa.Column('first_name', sa.String(length=100), nullable=False), - sa.Column('last_name', sa.String(length=100), nullable=False), - sa.Column('company', sa.String(length=200), nullable=True), - sa.Column('address_line_1', sa.String(length=255), nullable=False), - sa.Column('address_line_2', sa.String(length=255), nullable=True), - sa.Column('city', sa.String(length=100), nullable=False), - sa.Column('postal_code', sa.String(length=20), nullable=False), - sa.Column('country', sa.String(length=100), nullable=False), - sa.Column('is_default', sa.Boolean(), nullable=True), - sa.Column('created_at', sa.DateTime(), nullable=False), - sa.Column('updated_at', sa.DateTime(), nullable=False), - sa.ForeignKeyConstraint(['customer_id'], ['customers.id'], ), - sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_index( + op.f("ix_admin_audit_logs_admin_user_id"), + "admin_audit_logs", + ["admin_user_id"], + unique=False, ) - op.create_index(op.f('ix_customer_addresses_id'), 'customer_addresses', ['id'], unique=False) - op.create_table('inventory', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('product_id', sa.Integer(), nullable=False), - sa.Column('vendor_id', sa.Integer(), nullable=False), - sa.Column('location', sa.String(), nullable=False), - sa.Column('quantity', sa.Integer(), nullable=False), - sa.Column('reserved_quantity', sa.Integer(), nullable=True), - sa.Column('gtin', sa.String(), nullable=True), - sa.Column('created_at', sa.DateTime(), nullable=False), - sa.Column('updated_at', sa.DateTime(), nullable=False), - sa.ForeignKeyConstraint(['product_id'], ['products.id'], ), - sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('product_id', 'location', name='uq_inventory_product_location') + op.create_index( + op.f("ix_admin_audit_logs_id"), "admin_audit_logs", ["id"], unique=False ) - op.create_index('idx_inventory_product_location', 'inventory', ['product_id', 'location'], unique=False) - op.create_index('idx_inventory_vendor_product', 'inventory', ['vendor_id', 'product_id'], unique=False) - op.create_index(op.f('ix_inventory_gtin'), 'inventory', ['gtin'], unique=False) - op.create_index(op.f('ix_inventory_id'), 'inventory', ['id'], unique=False) - op.create_index(op.f('ix_inventory_location'), 'inventory', ['location'], unique=False) - op.create_index(op.f('ix_inventory_product_id'), 'inventory', ['product_id'], unique=False) - op.create_index(op.f('ix_inventory_vendor_id'), 'inventory', ['vendor_id'], unique=False) - op.create_table('vendor_users', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('vendor_id', sa.Integer(), nullable=False), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.Column('role_id', sa.Integer(), nullable=False), - sa.Column('invited_by', sa.Integer(), nullable=True), - sa.Column('is_active', sa.Boolean(), nullable=False), - sa.Column('created_at', sa.DateTime(), nullable=False), - sa.Column('updated_at', sa.DateTime(), nullable=False), - sa.ForeignKeyConstraint(['invited_by'], ['users.id'], ), - sa.ForeignKeyConstraint(['role_id'], ['roles.id'], ), - sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), - sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_index( + op.f("ix_admin_audit_logs_target_id"), + "admin_audit_logs", + ["target_id"], + unique=False, ) - op.create_index(op.f('ix_vendor_users_id'), 'vendor_users', ['id'], unique=False) - op.create_table('orders', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('vendor_id', sa.Integer(), nullable=False), - sa.Column('customer_id', sa.Integer(), nullable=False), - sa.Column('order_number', sa.String(), nullable=False), - sa.Column('status', sa.String(), nullable=False), - sa.Column('subtotal', sa.Float(), nullable=False), - sa.Column('tax_amount', sa.Float(), nullable=True), - sa.Column('shipping_amount', sa.Float(), nullable=True), - sa.Column('discount_amount', sa.Float(), nullable=True), - sa.Column('total_amount', sa.Float(), nullable=False), - sa.Column('currency', sa.String(), nullable=True), - sa.Column('shipping_address_id', sa.Integer(), nullable=False), - sa.Column('billing_address_id', sa.Integer(), nullable=False), - sa.Column('shipping_method', sa.String(), nullable=True), - sa.Column('tracking_number', sa.String(), nullable=True), - sa.Column('customer_notes', sa.Text(), nullable=True), - sa.Column('internal_notes', sa.Text(), nullable=True), - sa.Column('paid_at', sa.DateTime(), nullable=True), - sa.Column('shipped_at', sa.DateTime(), nullable=True), - sa.Column('delivered_at', sa.DateTime(), nullable=True), - sa.Column('cancelled_at', sa.DateTime(), nullable=True), - sa.Column('created_at', sa.DateTime(), nullable=False), - sa.Column('updated_at', sa.DateTime(), nullable=False), - sa.ForeignKeyConstraint(['billing_address_id'], ['customer_addresses.id'], ), - sa.ForeignKeyConstraint(['customer_id'], ['customers.id'], ), - sa.ForeignKeyConstraint(['shipping_address_id'], ['customer_addresses.id'], ), - sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_index( + op.f("ix_admin_audit_logs_target_type"), + "admin_audit_logs", + ["target_type"], + unique=False, ) - op.create_index(op.f('ix_orders_customer_id'), 'orders', ['customer_id'], unique=False) - op.create_index(op.f('ix_orders_id'), 'orders', ['id'], unique=False) - op.create_index(op.f('ix_orders_order_number'), 'orders', ['order_number'], unique=True) - op.create_index(op.f('ix_orders_status'), 'orders', ['status'], unique=False) - op.create_index(op.f('ix_orders_vendor_id'), 'orders', ['vendor_id'], unique=False) - op.create_table('order_items', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('order_id', sa.Integer(), nullable=False), - sa.Column('product_id', sa.Integer(), nullable=False), - sa.Column('product_name', sa.String(), nullable=False), - sa.Column('product_sku', sa.String(), nullable=True), - sa.Column('quantity', sa.Integer(), nullable=False), - sa.Column('unit_price', sa.Float(), nullable=False), - sa.Column('total_price', sa.Float(), nullable=False), - sa.Column('inventory_reserved', sa.Boolean(), nullable=True), - sa.Column('inventory_fulfilled', sa.Boolean(), nullable=True), - sa.Column('created_at', sa.DateTime(), nullable=False), - sa.Column('updated_at', sa.DateTime(), nullable=False), - sa.ForeignKeyConstraint(['order_id'], ['orders.id'], ), - sa.ForeignKeyConstraint(['product_id'], ['products.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_table( + "admin_notifications", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("type", sa.String(length=50), nullable=False), + sa.Column("priority", sa.String(length=20), nullable=True), + sa.Column("title", sa.String(length=200), nullable=False), + sa.Column("message", sa.Text(), nullable=False), + sa.Column("is_read", sa.Boolean(), nullable=True), + sa.Column("read_at", sa.DateTime(), nullable=True), + sa.Column("read_by_user_id", sa.Integer(), nullable=True), + sa.Column("action_required", sa.Boolean(), nullable=True), + sa.Column("action_url", sa.String(length=500), nullable=True), + sa.Column("notification_metadata", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["read_by_user_id"], + ["users.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_admin_notifications_action_required"), + "admin_notifications", + ["action_required"], + unique=False, + ) + op.create_index( + op.f("ix_admin_notifications_id"), "admin_notifications", ["id"], unique=False + ) + op.create_index( + op.f("ix_admin_notifications_is_read"), + "admin_notifications", + ["is_read"], + unique=False, + ) + op.create_index( + op.f("ix_admin_notifications_priority"), + "admin_notifications", + ["priority"], + unique=False, + ) + op.create_index( + op.f("ix_admin_notifications_type"), + "admin_notifications", + ["type"], + unique=False, + ) + op.create_table( + "admin_sessions", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("admin_user_id", sa.Integer(), nullable=False), + sa.Column("session_token", sa.String(length=255), nullable=False), + sa.Column("ip_address", sa.String(length=45), nullable=False), + sa.Column("user_agent", sa.Text(), nullable=True), + sa.Column("login_at", sa.DateTime(), nullable=False), + sa.Column("last_activity_at", sa.DateTime(), nullable=False), + sa.Column("logout_at", sa.DateTime(), nullable=True), + sa.Column("is_active", sa.Boolean(), nullable=True), + sa.Column("logout_reason", sa.String(length=50), nullable=True), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["admin_user_id"], + ["users.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_admin_sessions_admin_user_id"), + "admin_sessions", + ["admin_user_id"], + unique=False, + ) + op.create_index( + op.f("ix_admin_sessions_id"), "admin_sessions", ["id"], unique=False + ) + op.create_index( + op.f("ix_admin_sessions_is_active"), + "admin_sessions", + ["is_active"], + unique=False, + ) + op.create_index( + op.f("ix_admin_sessions_login_at"), "admin_sessions", ["login_at"], unique=False + ) + op.create_index( + op.f("ix_admin_sessions_session_token"), + "admin_sessions", + ["session_token"], + unique=True, + ) + op.create_table( + "admin_settings", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("key", sa.String(length=100), nullable=False), + sa.Column("value", sa.Text(), nullable=False), + sa.Column("value_type", sa.String(length=20), nullable=True), + sa.Column("category", sa.String(length=50), nullable=True), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("is_encrypted", sa.Boolean(), nullable=True), + sa.Column("is_public", sa.Boolean(), nullable=True), + sa.Column("last_modified_by_user_id", sa.Integer(), nullable=True), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["last_modified_by_user_id"], + ["users.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_admin_settings_category"), "admin_settings", ["category"], unique=False + ) + op.create_index( + op.f("ix_admin_settings_id"), "admin_settings", ["id"], unique=False + ) + op.create_index( + op.f("ix_admin_settings_key"), "admin_settings", ["key"], unique=True + ) + op.create_table( + "platform_alerts", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("alert_type", sa.String(length=50), nullable=False), + sa.Column("severity", sa.String(length=20), nullable=False), + sa.Column("title", sa.String(length=200), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("affected_vendors", sa.JSON(), nullable=True), + sa.Column("affected_systems", sa.JSON(), nullable=True), + sa.Column("is_resolved", sa.Boolean(), nullable=True), + sa.Column("resolved_at", sa.DateTime(), nullable=True), + sa.Column("resolved_by_user_id", sa.Integer(), nullable=True), + sa.Column("resolution_notes", sa.Text(), nullable=True), + sa.Column("auto_generated", sa.Boolean(), nullable=True), + sa.Column("occurrence_count", sa.Integer(), nullable=True), + sa.Column("first_occurred_at", sa.DateTime(), nullable=False), + sa.Column("last_occurred_at", sa.DateTime(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["resolved_by_user_id"], + ["users.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_platform_alerts_alert_type"), + "platform_alerts", + ["alert_type"], + unique=False, + ) + op.create_index( + op.f("ix_platform_alerts_id"), "platform_alerts", ["id"], unique=False + ) + op.create_index( + op.f("ix_platform_alerts_is_resolved"), + "platform_alerts", + ["is_resolved"], + unique=False, + ) + op.create_index( + op.f("ix_platform_alerts_severity"), + "platform_alerts", + ["severity"], + unique=False, + ) + op.create_table( + "vendors", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("vendor_code", sa.String(), nullable=False), + sa.Column("subdomain", sa.String(length=100), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("owner_user_id", sa.Integer(), nullable=False), + sa.Column("contact_email", sa.String(), nullable=True), + sa.Column("contact_phone", sa.String(), nullable=True), + sa.Column("website", sa.String(), nullable=True), + sa.Column("letzshop_csv_url_fr", sa.String(), nullable=True), + sa.Column("letzshop_csv_url_en", sa.String(), nullable=True), + sa.Column("letzshop_csv_url_de", sa.String(), nullable=True), + sa.Column("business_address", sa.Text(), nullable=True), + sa.Column("tax_number", sa.String(), nullable=True), + sa.Column("is_active", sa.Boolean(), nullable=True), + sa.Column("is_verified", sa.Boolean(), nullable=True), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["owner_user_id"], + ["users.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_vendors_id"), "vendors", ["id"], unique=False) + op.create_index(op.f("ix_vendors_subdomain"), "vendors", ["subdomain"], unique=True) + op.create_index( + op.f("ix_vendors_vendor_code"), "vendors", ["vendor_code"], unique=True + ) + op.create_table( + "customers", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("vendor_id", sa.Integer(), nullable=False), + sa.Column("email", sa.String(length=255), nullable=False), + sa.Column("hashed_password", sa.String(length=255), nullable=False), + sa.Column("first_name", sa.String(length=100), nullable=True), + sa.Column("last_name", sa.String(length=100), nullable=True), + sa.Column("phone", sa.String(length=50), nullable=True), + sa.Column("customer_number", sa.String(length=100), nullable=False), + sa.Column("preferences", sa.JSON(), nullable=True), + sa.Column("marketing_consent", sa.Boolean(), nullable=True), + sa.Column("last_order_date", sa.DateTime(), nullable=True), + sa.Column("total_orders", sa.Integer(), nullable=True), + sa.Column("total_spent", sa.Numeric(precision=10, scale=2), nullable=True), + sa.Column("is_active", sa.Boolean(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["vendor_id"], + ["vendors.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_customers_customer_number"), + "customers", + ["customer_number"], + unique=False, + ) + op.create_index(op.f("ix_customers_email"), "customers", ["email"], unique=False) + op.create_index(op.f("ix_customers_id"), "customers", ["id"], unique=False) + op.create_table( + "marketplace_import_jobs", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("vendor_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("marketplace", sa.String(), nullable=False), + sa.Column("source_url", sa.String(), nullable=False), + sa.Column("status", sa.String(), nullable=False), + sa.Column("imported_count", sa.Integer(), nullable=True), + sa.Column("updated_count", sa.Integer(), nullable=True), + sa.Column("error_count", sa.Integer(), nullable=True), + sa.Column("total_processed", sa.Integer(), nullable=True), + sa.Column("error_message", sa.Text(), nullable=True), + sa.Column("started_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.ForeignKeyConstraint( + ["vendor_id"], + ["vendors.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "idx_import_user_marketplace", + "marketplace_import_jobs", + ["user_id", "marketplace"], + unique=False, + ) + op.create_index( + "idx_import_vendor_created", + "marketplace_import_jobs", + ["vendor_id", "created_at"], + unique=False, + ) + op.create_index( + "idx_import_vendor_status", + "marketplace_import_jobs", + ["vendor_id", "status"], + unique=False, + ) + op.create_index( + op.f("ix_marketplace_import_jobs_id"), + "marketplace_import_jobs", + ["id"], + unique=False, + ) + op.create_index( + op.f("ix_marketplace_import_jobs_marketplace"), + "marketplace_import_jobs", + ["marketplace"], + unique=False, + ) + op.create_index( + op.f("ix_marketplace_import_jobs_vendor_id"), + "marketplace_import_jobs", + ["vendor_id"], + unique=False, + ) + op.create_table( + "products", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("vendor_id", sa.Integer(), nullable=False), + sa.Column("marketplace_product_id", sa.Integer(), nullable=False), + sa.Column("product_id", sa.String(), nullable=True), + sa.Column("price", sa.Float(), nullable=True), + sa.Column("sale_price", sa.Float(), nullable=True), + sa.Column("currency", sa.String(), nullable=True), + sa.Column("availability", sa.String(), nullable=True), + sa.Column("condition", sa.String(), nullable=True), + sa.Column("is_featured", sa.Boolean(), nullable=True), + sa.Column("is_active", sa.Boolean(), nullable=True), + sa.Column("display_order", sa.Integer(), nullable=True), + sa.Column("min_quantity", sa.Integer(), nullable=True), + sa.Column("max_quantity", sa.Integer(), nullable=True), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["marketplace_product_id"], + ["marketplace_products.id"], + ), + sa.ForeignKeyConstraint( + ["vendor_id"], + ["vendors.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("vendor_id", "marketplace_product_id", name="uq_product"), + ) + op.create_index( + "idx_product_active", "products", ["vendor_id", "is_active"], unique=False + ) + op.create_index( + "idx_product_featured", "products", ["vendor_id", "is_featured"], unique=False + ) + op.create_index(op.f("ix_products_id"), "products", ["id"], unique=False) + op.create_table( + "roles", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("vendor_id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(length=100), nullable=False), + sa.Column("permissions", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["vendor_id"], + ["vendors.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_roles_id"), "roles", ["id"], unique=False) + op.create_table( + "vendor_domains", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("vendor_id", sa.Integer(), nullable=False), + sa.Column("domain", sa.String(length=255), nullable=False), + sa.Column("is_primary", sa.Boolean(), nullable=False), + sa.Column("is_active", sa.Boolean(), nullable=False), + sa.Column("ssl_status", sa.String(length=50), nullable=True), + sa.Column("ssl_verified_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("verification_token", sa.String(length=100), nullable=True), + sa.Column("is_verified", sa.Boolean(), nullable=False), + sa.Column("verified_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(["vendor_id"], ["vendors.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("vendor_id", "domain", name="uq_vendor_domain"), + sa.UniqueConstraint("verification_token"), + ) + op.create_index( + "idx_domain_active", "vendor_domains", ["domain", "is_active"], unique=False + ) + op.create_index( + "idx_vendor_primary", + "vendor_domains", + ["vendor_id", "is_primary"], + unique=False, + ) + op.create_index( + op.f("ix_vendor_domains_domain"), "vendor_domains", ["domain"], unique=True + ) + op.create_index( + op.f("ix_vendor_domains_id"), "vendor_domains", ["id"], unique=False + ) + op.create_table( + "vendor_themes", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("vendor_id", sa.Integer(), nullable=False), + sa.Column("theme_name", sa.String(length=100), nullable=True), + sa.Column("is_active", sa.Boolean(), nullable=True), + sa.Column("colors", sa.JSON(), nullable=True), + sa.Column("font_family_heading", sa.String(length=100), nullable=True), + sa.Column("font_family_body", sa.String(length=100), nullable=True), + sa.Column("logo_url", sa.String(length=500), nullable=True), + sa.Column("logo_dark_url", sa.String(length=500), nullable=True), + sa.Column("favicon_url", sa.String(length=500), nullable=True), + sa.Column("banner_url", sa.String(length=500), nullable=True), + sa.Column("layout_style", sa.String(length=50), nullable=True), + sa.Column("header_style", sa.String(length=50), nullable=True), + sa.Column("product_card_style", sa.String(length=50), nullable=True), + sa.Column("custom_css", sa.Text(), nullable=True), + sa.Column("social_links", sa.JSON(), nullable=True), + sa.Column("meta_title_template", sa.String(length=200), nullable=True), + sa.Column("meta_description", sa.Text(), nullable=True), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(["vendor_id"], ["vendors.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("vendor_id"), + ) + op.create_index(op.f("ix_vendor_themes_id"), "vendor_themes", ["id"], unique=False) + op.create_table( + "customer_addresses", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("vendor_id", sa.Integer(), nullable=False), + sa.Column("customer_id", sa.Integer(), nullable=False), + sa.Column("address_type", sa.String(length=50), nullable=False), + sa.Column("first_name", sa.String(length=100), nullable=False), + sa.Column("last_name", sa.String(length=100), nullable=False), + sa.Column("company", sa.String(length=200), nullable=True), + sa.Column("address_line_1", sa.String(length=255), nullable=False), + sa.Column("address_line_2", sa.String(length=255), nullable=True), + sa.Column("city", sa.String(length=100), nullable=False), + sa.Column("postal_code", sa.String(length=20), nullable=False), + sa.Column("country", sa.String(length=100), nullable=False), + sa.Column("is_default", sa.Boolean(), nullable=True), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["customer_id"], + ["customers.id"], + ), + sa.ForeignKeyConstraint( + ["vendor_id"], + ["vendors.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_customer_addresses_id"), "customer_addresses", ["id"], unique=False + ) + op.create_table( + "inventory", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("product_id", sa.Integer(), nullable=False), + sa.Column("vendor_id", sa.Integer(), nullable=False), + sa.Column("location", sa.String(), nullable=False), + sa.Column("quantity", sa.Integer(), nullable=False), + sa.Column("reserved_quantity", sa.Integer(), nullable=True), + sa.Column("gtin", sa.String(), nullable=True), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["product_id"], + ["products.id"], + ), + sa.ForeignKeyConstraint( + ["vendor_id"], + ["vendors.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint( + "product_id", "location", name="uq_inventory_product_location" + ), + ) + op.create_index( + "idx_inventory_product_location", + "inventory", + ["product_id", "location"], + unique=False, + ) + op.create_index( + "idx_inventory_vendor_product", + "inventory", + ["vendor_id", "product_id"], + unique=False, + ) + op.create_index(op.f("ix_inventory_gtin"), "inventory", ["gtin"], unique=False) + op.create_index(op.f("ix_inventory_id"), "inventory", ["id"], unique=False) + op.create_index( + op.f("ix_inventory_location"), "inventory", ["location"], unique=False + ) + op.create_index( + op.f("ix_inventory_product_id"), "inventory", ["product_id"], unique=False + ) + op.create_index( + op.f("ix_inventory_vendor_id"), "inventory", ["vendor_id"], unique=False + ) + op.create_table( + "vendor_users", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("vendor_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("role_id", sa.Integer(), nullable=False), + sa.Column("invited_by", sa.Integer(), nullable=True), + sa.Column("is_active", sa.Boolean(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["invited_by"], + ["users.id"], + ), + sa.ForeignKeyConstraint( + ["role_id"], + ["roles.id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.ForeignKeyConstraint( + ["vendor_id"], + ["vendors.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_vendor_users_id"), "vendor_users", ["id"], unique=False) + op.create_table( + "orders", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("vendor_id", sa.Integer(), nullable=False), + sa.Column("customer_id", sa.Integer(), nullable=False), + sa.Column("order_number", sa.String(), nullable=False), + sa.Column("status", sa.String(), nullable=False), + sa.Column("subtotal", sa.Float(), nullable=False), + sa.Column("tax_amount", sa.Float(), nullable=True), + sa.Column("shipping_amount", sa.Float(), nullable=True), + sa.Column("discount_amount", sa.Float(), nullable=True), + sa.Column("total_amount", sa.Float(), nullable=False), + sa.Column("currency", sa.String(), nullable=True), + sa.Column("shipping_address_id", sa.Integer(), nullable=False), + sa.Column("billing_address_id", sa.Integer(), nullable=False), + sa.Column("shipping_method", sa.String(), nullable=True), + sa.Column("tracking_number", sa.String(), nullable=True), + sa.Column("customer_notes", sa.Text(), nullable=True), + sa.Column("internal_notes", sa.Text(), nullable=True), + sa.Column("paid_at", sa.DateTime(), nullable=True), + sa.Column("shipped_at", sa.DateTime(), nullable=True), + sa.Column("delivered_at", sa.DateTime(), nullable=True), + sa.Column("cancelled_at", sa.DateTime(), nullable=True), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["billing_address_id"], + ["customer_addresses.id"], + ), + sa.ForeignKeyConstraint( + ["customer_id"], + ["customers.id"], + ), + sa.ForeignKeyConstraint( + ["shipping_address_id"], + ["customer_addresses.id"], + ), + sa.ForeignKeyConstraint( + ["vendor_id"], + ["vendors.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_orders_customer_id"), "orders", ["customer_id"], unique=False + ) + op.create_index(op.f("ix_orders_id"), "orders", ["id"], unique=False) + op.create_index( + op.f("ix_orders_order_number"), "orders", ["order_number"], unique=True + ) + op.create_index(op.f("ix_orders_status"), "orders", ["status"], unique=False) + op.create_index(op.f("ix_orders_vendor_id"), "orders", ["vendor_id"], unique=False) + op.create_table( + "order_items", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("order_id", sa.Integer(), nullable=False), + sa.Column("product_id", sa.Integer(), nullable=False), + sa.Column("product_name", sa.String(), nullable=False), + sa.Column("product_sku", sa.String(), nullable=True), + sa.Column("quantity", sa.Integer(), nullable=False), + sa.Column("unit_price", sa.Float(), nullable=False), + sa.Column("total_price", sa.Float(), nullable=False), + sa.Column("inventory_reserved", sa.Boolean(), nullable=True), + sa.Column("inventory_fulfilled", sa.Boolean(), nullable=True), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["order_id"], + ["orders.id"], + ), + sa.ForeignKeyConstraint( + ["product_id"], + ["products.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_order_items_id"), "order_items", ["id"], unique=False) + op.create_index( + op.f("ix_order_items_order_id"), "order_items", ["order_id"], unique=False ) - op.create_index(op.f('ix_order_items_id'), 'order_items', ['id'], unique=False) - op.create_index(op.f('ix_order_items_order_id'), 'order_items', ['order_id'], unique=False) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_index(op.f('ix_order_items_order_id'), table_name='order_items') - op.drop_index(op.f('ix_order_items_id'), table_name='order_items') - op.drop_table('order_items') - op.drop_index(op.f('ix_orders_vendor_id'), table_name='orders') - op.drop_index(op.f('ix_orders_status'), table_name='orders') - op.drop_index(op.f('ix_orders_order_number'), table_name='orders') - op.drop_index(op.f('ix_orders_id'), table_name='orders') - op.drop_index(op.f('ix_orders_customer_id'), table_name='orders') - op.drop_table('orders') - op.drop_index(op.f('ix_vendor_users_id'), table_name='vendor_users') - op.drop_table('vendor_users') - op.drop_index(op.f('ix_inventory_vendor_id'), table_name='inventory') - op.drop_index(op.f('ix_inventory_product_id'), table_name='inventory') - op.drop_index(op.f('ix_inventory_location'), table_name='inventory') - op.drop_index(op.f('ix_inventory_id'), table_name='inventory') - op.drop_index(op.f('ix_inventory_gtin'), table_name='inventory') - op.drop_index('idx_inventory_vendor_product', table_name='inventory') - op.drop_index('idx_inventory_product_location', table_name='inventory') - op.drop_table('inventory') - op.drop_index(op.f('ix_customer_addresses_id'), table_name='customer_addresses') - op.drop_table('customer_addresses') - op.drop_index(op.f('ix_vendor_themes_id'), table_name='vendor_themes') - op.drop_table('vendor_themes') - op.drop_index(op.f('ix_vendor_domains_id'), table_name='vendor_domains') - op.drop_index(op.f('ix_vendor_domains_domain'), table_name='vendor_domains') - op.drop_index('idx_vendor_primary', table_name='vendor_domains') - op.drop_index('idx_domain_active', table_name='vendor_domains') - op.drop_table('vendor_domains') - op.drop_index(op.f('ix_roles_id'), table_name='roles') - op.drop_table('roles') - op.drop_index(op.f('ix_products_id'), table_name='products') - op.drop_index('idx_product_featured', table_name='products') - op.drop_index('idx_product_active', table_name='products') - op.drop_table('products') - op.drop_index(op.f('ix_marketplace_import_jobs_vendor_id'), table_name='marketplace_import_jobs') - op.drop_index(op.f('ix_marketplace_import_jobs_marketplace'), table_name='marketplace_import_jobs') - op.drop_index(op.f('ix_marketplace_import_jobs_id'), table_name='marketplace_import_jobs') - op.drop_index('idx_import_vendor_status', table_name='marketplace_import_jobs') - op.drop_index('idx_import_vendor_created', table_name='marketplace_import_jobs') - op.drop_index('idx_import_user_marketplace', table_name='marketplace_import_jobs') - op.drop_table('marketplace_import_jobs') - op.drop_index(op.f('ix_customers_id'), table_name='customers') - op.drop_index(op.f('ix_customers_email'), table_name='customers') - op.drop_index(op.f('ix_customers_customer_number'), table_name='customers') - op.drop_table('customers') - op.drop_index(op.f('ix_vendors_vendor_code'), table_name='vendors') - op.drop_index(op.f('ix_vendors_subdomain'), table_name='vendors') - op.drop_index(op.f('ix_vendors_id'), table_name='vendors') - op.drop_table('vendors') - op.drop_index(op.f('ix_platform_alerts_severity'), table_name='platform_alerts') - op.drop_index(op.f('ix_platform_alerts_is_resolved'), table_name='platform_alerts') - op.drop_index(op.f('ix_platform_alerts_id'), table_name='platform_alerts') - op.drop_index(op.f('ix_platform_alerts_alert_type'), table_name='platform_alerts') - op.drop_table('platform_alerts') - op.drop_index(op.f('ix_admin_settings_key'), table_name='admin_settings') - op.drop_index(op.f('ix_admin_settings_id'), table_name='admin_settings') - op.drop_index(op.f('ix_admin_settings_category'), table_name='admin_settings') - op.drop_table('admin_settings') - op.drop_index(op.f('ix_admin_sessions_session_token'), table_name='admin_sessions') - op.drop_index(op.f('ix_admin_sessions_login_at'), table_name='admin_sessions') - op.drop_index(op.f('ix_admin_sessions_is_active'), table_name='admin_sessions') - op.drop_index(op.f('ix_admin_sessions_id'), table_name='admin_sessions') - op.drop_index(op.f('ix_admin_sessions_admin_user_id'), table_name='admin_sessions') - op.drop_table('admin_sessions') - op.drop_index(op.f('ix_admin_notifications_type'), table_name='admin_notifications') - op.drop_index(op.f('ix_admin_notifications_priority'), table_name='admin_notifications') - op.drop_index(op.f('ix_admin_notifications_is_read'), table_name='admin_notifications') - op.drop_index(op.f('ix_admin_notifications_id'), table_name='admin_notifications') - op.drop_index(op.f('ix_admin_notifications_action_required'), table_name='admin_notifications') - op.drop_table('admin_notifications') - op.drop_index(op.f('ix_admin_audit_logs_target_type'), table_name='admin_audit_logs') - op.drop_index(op.f('ix_admin_audit_logs_target_id'), table_name='admin_audit_logs') - op.drop_index(op.f('ix_admin_audit_logs_id'), table_name='admin_audit_logs') - op.drop_index(op.f('ix_admin_audit_logs_admin_user_id'), table_name='admin_audit_logs') - op.drop_index(op.f('ix_admin_audit_logs_action'), table_name='admin_audit_logs') - op.drop_table('admin_audit_logs') - op.drop_index(op.f('ix_users_username'), table_name='users') - op.drop_index(op.f('ix_users_id'), table_name='users') - op.drop_index(op.f('ix_users_email'), table_name='users') - op.drop_table('users') - op.drop_index(op.f('ix_marketplace_products_vendor_name'), table_name='marketplace_products') - op.drop_index(op.f('ix_marketplace_products_marketplace_product_id'), table_name='marketplace_products') - op.drop_index(op.f('ix_marketplace_products_marketplace'), table_name='marketplace_products') - op.drop_index(op.f('ix_marketplace_products_id'), table_name='marketplace_products') - op.drop_index(op.f('ix_marketplace_products_gtin'), table_name='marketplace_products') - op.drop_index(op.f('ix_marketplace_products_google_product_category'), table_name='marketplace_products') - op.drop_index(op.f('ix_marketplace_products_brand'), table_name='marketplace_products') - op.drop_index(op.f('ix_marketplace_products_availability'), table_name='marketplace_products') - op.drop_index('idx_marketplace_vendor', table_name='marketplace_products') - op.drop_index('idx_marketplace_brand', table_name='marketplace_products') - op.drop_table('marketplace_products') + op.drop_index(op.f("ix_order_items_order_id"), table_name="order_items") + op.drop_index(op.f("ix_order_items_id"), table_name="order_items") + op.drop_table("order_items") + op.drop_index(op.f("ix_orders_vendor_id"), table_name="orders") + op.drop_index(op.f("ix_orders_status"), table_name="orders") + op.drop_index(op.f("ix_orders_order_number"), table_name="orders") + op.drop_index(op.f("ix_orders_id"), table_name="orders") + op.drop_index(op.f("ix_orders_customer_id"), table_name="orders") + op.drop_table("orders") + op.drop_index(op.f("ix_vendor_users_id"), table_name="vendor_users") + op.drop_table("vendor_users") + op.drop_index(op.f("ix_inventory_vendor_id"), table_name="inventory") + op.drop_index(op.f("ix_inventory_product_id"), table_name="inventory") + op.drop_index(op.f("ix_inventory_location"), table_name="inventory") + op.drop_index(op.f("ix_inventory_id"), table_name="inventory") + op.drop_index(op.f("ix_inventory_gtin"), table_name="inventory") + op.drop_index("idx_inventory_vendor_product", table_name="inventory") + op.drop_index("idx_inventory_product_location", table_name="inventory") + op.drop_table("inventory") + op.drop_index(op.f("ix_customer_addresses_id"), table_name="customer_addresses") + op.drop_table("customer_addresses") + op.drop_index(op.f("ix_vendor_themes_id"), table_name="vendor_themes") + op.drop_table("vendor_themes") + op.drop_index(op.f("ix_vendor_domains_id"), table_name="vendor_domains") + op.drop_index(op.f("ix_vendor_domains_domain"), table_name="vendor_domains") + op.drop_index("idx_vendor_primary", table_name="vendor_domains") + op.drop_index("idx_domain_active", table_name="vendor_domains") + op.drop_table("vendor_domains") + op.drop_index(op.f("ix_roles_id"), table_name="roles") + op.drop_table("roles") + op.drop_index(op.f("ix_products_id"), table_name="products") + op.drop_index("idx_product_featured", table_name="products") + op.drop_index("idx_product_active", table_name="products") + op.drop_table("products") + op.drop_index( + op.f("ix_marketplace_import_jobs_vendor_id"), + table_name="marketplace_import_jobs", + ) + op.drop_index( + op.f("ix_marketplace_import_jobs_marketplace"), + table_name="marketplace_import_jobs", + ) + op.drop_index( + op.f("ix_marketplace_import_jobs_id"), table_name="marketplace_import_jobs" + ) + op.drop_index("idx_import_vendor_status", table_name="marketplace_import_jobs") + op.drop_index("idx_import_vendor_created", table_name="marketplace_import_jobs") + op.drop_index("idx_import_user_marketplace", table_name="marketplace_import_jobs") + op.drop_table("marketplace_import_jobs") + op.drop_index(op.f("ix_customers_id"), table_name="customers") + op.drop_index(op.f("ix_customers_email"), table_name="customers") + op.drop_index(op.f("ix_customers_customer_number"), table_name="customers") + op.drop_table("customers") + op.drop_index(op.f("ix_vendors_vendor_code"), table_name="vendors") + op.drop_index(op.f("ix_vendors_subdomain"), table_name="vendors") + op.drop_index(op.f("ix_vendors_id"), table_name="vendors") + op.drop_table("vendors") + op.drop_index(op.f("ix_platform_alerts_severity"), table_name="platform_alerts") + op.drop_index(op.f("ix_platform_alerts_is_resolved"), table_name="platform_alerts") + op.drop_index(op.f("ix_platform_alerts_id"), table_name="platform_alerts") + op.drop_index(op.f("ix_platform_alerts_alert_type"), table_name="platform_alerts") + op.drop_table("platform_alerts") + op.drop_index(op.f("ix_admin_settings_key"), table_name="admin_settings") + op.drop_index(op.f("ix_admin_settings_id"), table_name="admin_settings") + op.drop_index(op.f("ix_admin_settings_category"), table_name="admin_settings") + op.drop_table("admin_settings") + op.drop_index(op.f("ix_admin_sessions_session_token"), table_name="admin_sessions") + op.drop_index(op.f("ix_admin_sessions_login_at"), table_name="admin_sessions") + op.drop_index(op.f("ix_admin_sessions_is_active"), table_name="admin_sessions") + op.drop_index(op.f("ix_admin_sessions_id"), table_name="admin_sessions") + op.drop_index(op.f("ix_admin_sessions_admin_user_id"), table_name="admin_sessions") + op.drop_table("admin_sessions") + op.drop_index(op.f("ix_admin_notifications_type"), table_name="admin_notifications") + op.drop_index( + op.f("ix_admin_notifications_priority"), table_name="admin_notifications" + ) + op.drop_index( + op.f("ix_admin_notifications_is_read"), table_name="admin_notifications" + ) + op.drop_index(op.f("ix_admin_notifications_id"), table_name="admin_notifications") + op.drop_index( + op.f("ix_admin_notifications_action_required"), table_name="admin_notifications" + ) + op.drop_table("admin_notifications") + op.drop_index( + op.f("ix_admin_audit_logs_target_type"), table_name="admin_audit_logs" + ) + op.drop_index(op.f("ix_admin_audit_logs_target_id"), table_name="admin_audit_logs") + op.drop_index(op.f("ix_admin_audit_logs_id"), table_name="admin_audit_logs") + op.drop_index( + op.f("ix_admin_audit_logs_admin_user_id"), table_name="admin_audit_logs" + ) + op.drop_index(op.f("ix_admin_audit_logs_action"), table_name="admin_audit_logs") + op.drop_table("admin_audit_logs") + op.drop_index(op.f("ix_users_username"), table_name="users") + op.drop_index(op.f("ix_users_id"), table_name="users") + op.drop_index(op.f("ix_users_email"), table_name="users") + op.drop_table("users") + op.drop_index( + op.f("ix_marketplace_products_vendor_name"), table_name="marketplace_products" + ) + op.drop_index( + op.f("ix_marketplace_products_marketplace_product_id"), + table_name="marketplace_products", + ) + op.drop_index( + op.f("ix_marketplace_products_marketplace"), table_name="marketplace_products" + ) + op.drop_index(op.f("ix_marketplace_products_id"), table_name="marketplace_products") + op.drop_index( + op.f("ix_marketplace_products_gtin"), table_name="marketplace_products" + ) + op.drop_index( + op.f("ix_marketplace_products_google_product_category"), + table_name="marketplace_products", + ) + op.drop_index( + op.f("ix_marketplace_products_brand"), table_name="marketplace_products" + ) + op.drop_index( + op.f("ix_marketplace_products_availability"), table_name="marketplace_products" + ) + op.drop_index("idx_marketplace_vendor", table_name="marketplace_products") + op.drop_index("idx_marketplace_brand", table_name="marketplace_products") + op.drop_table("marketplace_products") # ### end Alembic commands ### diff --git a/alembic/versions/72aa309d4007_ensure_content_pages_table_with_all_.py b/alembic/versions/72aa309d4007_ensure_content_pages_table_with_all_.py index 23e274db..24a6d43d 100644 --- a/alembic/versions/72aa309d4007_ensure_content_pages_table_with_all_.py +++ b/alembic/versions/72aa309d4007_ensure_content_pages_table_with_all_.py @@ -5,59 +5,72 @@ Revises: fef1d20ce8b4 Create Date: 2025-11-22 15:16:13.213613 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision: str = '72aa309d4007' -down_revision: Union[str, None] = 'fef1d20ce8b4' +revision: str = "72aa309d4007" +down_revision: Union[str, None] = "fef1d20ce8b4" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.create_table('content_pages', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('vendor_id', sa.Integer(), nullable=True), - sa.Column('slug', sa.String(length=100), nullable=False), - sa.Column('title', sa.String(length=200), nullable=False), - sa.Column('content', sa.Text(), nullable=False), - sa.Column('content_format', sa.String(length=20), nullable=True), - sa.Column('meta_description', sa.String(length=300), nullable=True), - sa.Column('meta_keywords', sa.String(length=300), nullable=True), - sa.Column('is_published', sa.Boolean(), nullable=False), - sa.Column('published_at', sa.DateTime(timezone=True), nullable=True), - sa.Column('display_order', sa.Integer(), nullable=True), - sa.Column('show_in_footer', sa.Boolean(), nullable=True), - sa.Column('show_in_header', sa.Boolean(), nullable=True), - sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), - sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False), - sa.Column('created_by', sa.Integer(), nullable=True), - sa.Column('updated_by', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['created_by'], ['users.id'], ondelete='SET NULL'), - sa.ForeignKeyConstraint(['updated_by'], ['users.id'], ondelete='SET NULL'), - sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('vendor_id', 'slug', name='uq_vendor_slug') + op.create_table( + "content_pages", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("vendor_id", sa.Integer(), nullable=True), + sa.Column("slug", sa.String(length=100), nullable=False), + sa.Column("title", sa.String(length=200), nullable=False), + sa.Column("content", sa.Text(), nullable=False), + sa.Column("content_format", sa.String(length=20), nullable=True), + sa.Column("meta_description", sa.String(length=300), nullable=True), + sa.Column("meta_keywords", sa.String(length=300), nullable=True), + sa.Column("is_published", sa.Boolean(), nullable=False), + sa.Column("published_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("display_order", sa.Integer(), nullable=True), + sa.Column("show_in_footer", sa.Boolean(), nullable=True), + sa.Column("show_in_header", sa.Boolean(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("created_by", sa.Integer(), nullable=True), + sa.Column("updated_by", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(["created_by"], ["users.id"], ondelete="SET NULL"), + sa.ForeignKeyConstraint(["updated_by"], ["users.id"], ondelete="SET NULL"), + sa.ForeignKeyConstraint(["vendor_id"], ["vendors.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("vendor_id", "slug", name="uq_vendor_slug"), + ) + op.create_index( + "idx_slug_published", "content_pages", ["slug", "is_published"], unique=False + ) + op.create_index( + "idx_vendor_published", + "content_pages", + ["vendor_id", "is_published"], + unique=False, + ) + op.create_index(op.f("ix_content_pages_id"), "content_pages", ["id"], unique=False) + op.create_index( + op.f("ix_content_pages_slug"), "content_pages", ["slug"], unique=False + ) + op.create_index( + op.f("ix_content_pages_vendor_id"), "content_pages", ["vendor_id"], unique=False ) - op.create_index('idx_slug_published', 'content_pages', ['slug', 'is_published'], unique=False) - op.create_index('idx_vendor_published', 'content_pages', ['vendor_id', 'is_published'], unique=False) - op.create_index(op.f('ix_content_pages_id'), 'content_pages', ['id'], unique=False) - op.create_index(op.f('ix_content_pages_slug'), 'content_pages', ['slug'], unique=False) - op.create_index(op.f('ix_content_pages_vendor_id'), 'content_pages', ['vendor_id'], unique=False) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_index(op.f('ix_content_pages_vendor_id'), table_name='content_pages') - op.drop_index(op.f('ix_content_pages_slug'), table_name='content_pages') - op.drop_index(op.f('ix_content_pages_id'), table_name='content_pages') - op.drop_index('idx_vendor_published', table_name='content_pages') - op.drop_index('idx_slug_published', table_name='content_pages') - op.drop_table('content_pages') + op.drop_index(op.f("ix_content_pages_vendor_id"), table_name="content_pages") + op.drop_index(op.f("ix_content_pages_slug"), table_name="content_pages") + op.drop_index(op.f("ix_content_pages_id"), table_name="content_pages") + op.drop_index("idx_vendor_published", table_name="content_pages") + op.drop_index("idx_slug_published", table_name="content_pages") + op.drop_table("content_pages") # ### end Alembic commands ### diff --git a/alembic/versions/7a7ce92593d5_add_architecture_quality_tracking_tables.py b/alembic/versions/7a7ce92593d5_add_architecture_quality_tracking_tables.py index 6fc55203..d8c85c8c 100644 --- a/alembic/versions/7a7ce92593d5_add_architecture_quality_tracking_tables.py +++ b/alembic/versions/7a7ce92593d5_add_architecture_quality_tracking_tables.py @@ -5,15 +5,17 @@ Revises: a2064e1dfcd4 Create Date: 2025-11-28 09:21:16.545203 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql +from alembic import op + # revision identifiers, used by Alembic. -revision: str = '7a7ce92593d5' -down_revision: Union[str, None] = 'a2064e1dfcd4' +revision: str = "7a7ce92593d5" +down_revision: Union[str, None] = "a2064e1dfcd4" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -21,127 +23,269 @@ depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # Create architecture_scans table op.create_table( - 'architecture_scans', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('timestamp', sa.DateTime(timezone=True), server_default=sa.text("(datetime('now'))"), nullable=False), - sa.Column('total_files', sa.Integer(), nullable=True), - sa.Column('total_violations', sa.Integer(), nullable=True), - sa.Column('errors', sa.Integer(), nullable=True), - sa.Column('warnings', sa.Integer(), nullable=True), - sa.Column('duration_seconds', sa.Float(), nullable=True), - sa.Column('triggered_by', sa.String(length=100), nullable=True), - sa.Column('git_commit_hash', sa.String(length=40), nullable=True), - sa.PrimaryKeyConstraint('id') + "architecture_scans", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column( + "timestamp", + sa.DateTime(timezone=True), + server_default=sa.text("(datetime('now'))"), + nullable=False, + ), + sa.Column("total_files", sa.Integer(), nullable=True), + sa.Column("total_violations", sa.Integer(), nullable=True), + sa.Column("errors", sa.Integer(), nullable=True), + sa.Column("warnings", sa.Integer(), nullable=True), + sa.Column("duration_seconds", sa.Float(), nullable=True), + sa.Column("triggered_by", sa.String(length=100), nullable=True), + sa.Column("git_commit_hash", sa.String(length=40), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_architecture_scans_id"), "architecture_scans", ["id"], unique=False + ) + op.create_index( + op.f("ix_architecture_scans_timestamp"), + "architecture_scans", + ["timestamp"], + unique=False, ) - op.create_index(op.f('ix_architecture_scans_id'), 'architecture_scans', ['id'], unique=False) - op.create_index(op.f('ix_architecture_scans_timestamp'), 'architecture_scans', ['timestamp'], unique=False) # Create architecture_rules table op.create_table( - 'architecture_rules', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('rule_id', sa.String(length=20), nullable=False), - sa.Column('category', sa.String(length=50), nullable=False), - sa.Column('name', sa.String(length=200), nullable=False), - sa.Column('description', sa.Text(), nullable=True), - sa.Column('severity', sa.String(length=10), nullable=False), - sa.Column('enabled', sa.Boolean(), nullable=False, server_default='1'), - sa.Column('custom_config', sa.JSON(), nullable=True), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text("(datetime('now'))"), nullable=False), - sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text("(datetime('now'))"), nullable=False), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('rule_id') + "architecture_rules", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("rule_id", sa.String(length=20), nullable=False), + sa.Column("category", sa.String(length=50), nullable=False), + sa.Column("name", sa.String(length=200), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("severity", sa.String(length=10), nullable=False), + sa.Column("enabled", sa.Boolean(), nullable=False, server_default="1"), + sa.Column("custom_config", sa.JSON(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("(datetime('now'))"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("(datetime('now'))"), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("rule_id"), + ) + op.create_index( + op.f("ix_architecture_rules_id"), "architecture_rules", ["id"], unique=False + ) + op.create_index( + op.f("ix_architecture_rules_rule_id"), + "architecture_rules", + ["rule_id"], + unique=True, ) - op.create_index(op.f('ix_architecture_rules_id'), 'architecture_rules', ['id'], unique=False) - op.create_index(op.f('ix_architecture_rules_rule_id'), 'architecture_rules', ['rule_id'], unique=True) # Create architecture_violations table op.create_table( - 'architecture_violations', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('scan_id', sa.Integer(), nullable=False), - sa.Column('rule_id', sa.String(length=20), nullable=False), - sa.Column('rule_name', sa.String(length=200), nullable=False), - sa.Column('severity', sa.String(length=10), nullable=False), - sa.Column('file_path', sa.String(length=500), nullable=False), - sa.Column('line_number', sa.Integer(), nullable=False), - sa.Column('message', sa.Text(), nullable=False), - sa.Column('context', sa.Text(), nullable=True), - sa.Column('suggestion', sa.Text(), nullable=True), - sa.Column('status', sa.String(length=20), server_default='open', nullable=True), - sa.Column('assigned_to', sa.Integer(), nullable=True), - sa.Column('resolved_at', sa.DateTime(timezone=True), nullable=True), - sa.Column('resolved_by', sa.Integer(), nullable=True), - sa.Column('resolution_note', sa.Text(), nullable=True), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text("(datetime('now'))"), nullable=False), - sa.ForeignKeyConstraint(['assigned_to'], ['users.id'], ), - sa.ForeignKeyConstraint(['resolved_by'], ['users.id'], ), - sa.ForeignKeyConstraint(['scan_id'], ['architecture_scans.id'], ), - sa.PrimaryKeyConstraint('id') + "architecture_violations", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("scan_id", sa.Integer(), nullable=False), + sa.Column("rule_id", sa.String(length=20), nullable=False), + sa.Column("rule_name", sa.String(length=200), nullable=False), + sa.Column("severity", sa.String(length=10), nullable=False), + sa.Column("file_path", sa.String(length=500), nullable=False), + sa.Column("line_number", sa.Integer(), nullable=False), + sa.Column("message", sa.Text(), nullable=False), + sa.Column("context", sa.Text(), nullable=True), + sa.Column("suggestion", sa.Text(), nullable=True), + sa.Column("status", sa.String(length=20), server_default="open", nullable=True), + sa.Column("assigned_to", sa.Integer(), nullable=True), + sa.Column("resolved_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("resolved_by", sa.Integer(), nullable=True), + sa.Column("resolution_note", sa.Text(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("(datetime('now'))"), + nullable=False, + ), + sa.ForeignKeyConstraint( + ["assigned_to"], + ["users.id"], + ), + sa.ForeignKeyConstraint( + ["resolved_by"], + ["users.id"], + ), + sa.ForeignKeyConstraint( + ["scan_id"], + ["architecture_scans.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_architecture_violations_file_path"), + "architecture_violations", + ["file_path"], + unique=False, + ) + op.create_index( + op.f("ix_architecture_violations_id"), + "architecture_violations", + ["id"], + unique=False, + ) + op.create_index( + op.f("ix_architecture_violations_rule_id"), + "architecture_violations", + ["rule_id"], + unique=False, + ) + op.create_index( + op.f("ix_architecture_violations_scan_id"), + "architecture_violations", + ["scan_id"], + unique=False, + ) + op.create_index( + op.f("ix_architecture_violations_severity"), + "architecture_violations", + ["severity"], + unique=False, + ) + op.create_index( + op.f("ix_architecture_violations_status"), + "architecture_violations", + ["status"], + unique=False, ) - op.create_index(op.f('ix_architecture_violations_file_path'), 'architecture_violations', ['file_path'], unique=False) - op.create_index(op.f('ix_architecture_violations_id'), 'architecture_violations', ['id'], unique=False) - op.create_index(op.f('ix_architecture_violations_rule_id'), 'architecture_violations', ['rule_id'], unique=False) - op.create_index(op.f('ix_architecture_violations_scan_id'), 'architecture_violations', ['scan_id'], unique=False) - op.create_index(op.f('ix_architecture_violations_severity'), 'architecture_violations', ['severity'], unique=False) - op.create_index(op.f('ix_architecture_violations_status'), 'architecture_violations', ['status'], unique=False) # Create violation_assignments table op.create_table( - 'violation_assignments', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('violation_id', sa.Integer(), nullable=False), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.Column('assigned_at', sa.DateTime(timezone=True), server_default=sa.text("(datetime('now'))"), nullable=False), - sa.Column('assigned_by', sa.Integer(), nullable=True), - sa.Column('due_date', sa.DateTime(timezone=True), nullable=True), - sa.Column('priority', sa.String(length=10), server_default='medium', nullable=True), - sa.ForeignKeyConstraint(['assigned_by'], ['users.id'], ), - sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), - sa.ForeignKeyConstraint(['violation_id'], ['architecture_violations.id'], ), - sa.PrimaryKeyConstraint('id') + "violation_assignments", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("violation_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column( + "assigned_at", + sa.DateTime(timezone=True), + server_default=sa.text("(datetime('now'))"), + nullable=False, + ), + sa.Column("assigned_by", sa.Integer(), nullable=True), + sa.Column("due_date", sa.DateTime(timezone=True), nullable=True), + sa.Column( + "priority", sa.String(length=10), server_default="medium", nullable=True + ), + sa.ForeignKeyConstraint( + ["assigned_by"], + ["users.id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.ForeignKeyConstraint( + ["violation_id"], + ["architecture_violations.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_violation_assignments_id"), + "violation_assignments", + ["id"], + unique=False, + ) + op.create_index( + op.f("ix_violation_assignments_violation_id"), + "violation_assignments", + ["violation_id"], + unique=False, ) - op.create_index(op.f('ix_violation_assignments_id'), 'violation_assignments', ['id'], unique=False) - op.create_index(op.f('ix_violation_assignments_violation_id'), 'violation_assignments', ['violation_id'], unique=False) # Create violation_comments table op.create_table( - 'violation_comments', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('violation_id', sa.Integer(), nullable=False), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.Column('comment', sa.Text(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text("(datetime('now'))"), nullable=False), - sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), - sa.ForeignKeyConstraint(['violation_id'], ['architecture_violations.id'], ), - sa.PrimaryKeyConstraint('id') + "violation_comments", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("violation_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("comment", sa.Text(), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("(datetime('now'))"), + nullable=False, + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.ForeignKeyConstraint( + ["violation_id"], + ["architecture_violations.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_violation_comments_id"), "violation_comments", ["id"], unique=False + ) + op.create_index( + op.f("ix_violation_comments_violation_id"), + "violation_comments", + ["violation_id"], + unique=False, ) - op.create_index(op.f('ix_violation_comments_id'), 'violation_comments', ['id'], unique=False) - op.create_index(op.f('ix_violation_comments_violation_id'), 'violation_comments', ['violation_id'], unique=False) def downgrade() -> None: # Drop tables in reverse order (to respect foreign key constraints) - op.drop_index(op.f('ix_violation_comments_violation_id'), table_name='violation_comments') - op.drop_index(op.f('ix_violation_comments_id'), table_name='violation_comments') - op.drop_table('violation_comments') + op.drop_index( + op.f("ix_violation_comments_violation_id"), table_name="violation_comments" + ) + op.drop_index(op.f("ix_violation_comments_id"), table_name="violation_comments") + op.drop_table("violation_comments") - op.drop_index(op.f('ix_violation_assignments_violation_id'), table_name='violation_assignments') - op.drop_index(op.f('ix_violation_assignments_id'), table_name='violation_assignments') - op.drop_table('violation_assignments') + op.drop_index( + op.f("ix_violation_assignments_violation_id"), + table_name="violation_assignments", + ) + op.drop_index( + op.f("ix_violation_assignments_id"), table_name="violation_assignments" + ) + op.drop_table("violation_assignments") - op.drop_index(op.f('ix_architecture_violations_status'), table_name='architecture_violations') - op.drop_index(op.f('ix_architecture_violations_severity'), table_name='architecture_violations') - op.drop_index(op.f('ix_architecture_violations_scan_id'), table_name='architecture_violations') - op.drop_index(op.f('ix_architecture_violations_rule_id'), table_name='architecture_violations') - op.drop_index(op.f('ix_architecture_violations_id'), table_name='architecture_violations') - op.drop_index(op.f('ix_architecture_violations_file_path'), table_name='architecture_violations') - op.drop_table('architecture_violations') + op.drop_index( + op.f("ix_architecture_violations_status"), table_name="architecture_violations" + ) + op.drop_index( + op.f("ix_architecture_violations_severity"), + table_name="architecture_violations", + ) + op.drop_index( + op.f("ix_architecture_violations_scan_id"), table_name="architecture_violations" + ) + op.drop_index( + op.f("ix_architecture_violations_rule_id"), table_name="architecture_violations" + ) + op.drop_index( + op.f("ix_architecture_violations_id"), table_name="architecture_violations" + ) + op.drop_index( + op.f("ix_architecture_violations_file_path"), + table_name="architecture_violations", + ) + op.drop_table("architecture_violations") - op.drop_index(op.f('ix_architecture_rules_rule_id'), table_name='architecture_rules') - op.drop_index(op.f('ix_architecture_rules_id'), table_name='architecture_rules') - op.drop_table('architecture_rules') + op.drop_index( + op.f("ix_architecture_rules_rule_id"), table_name="architecture_rules" + ) + op.drop_index(op.f("ix_architecture_rules_id"), table_name="architecture_rules") + op.drop_table("architecture_rules") - op.drop_index(op.f('ix_architecture_scans_timestamp'), table_name='architecture_scans') - op.drop_index(op.f('ix_architecture_scans_id'), table_name='architecture_scans') - op.drop_table('architecture_scans') + op.drop_index( + op.f("ix_architecture_scans_timestamp"), table_name="architecture_scans" + ) + op.drop_index(op.f("ix_architecture_scans_id"), table_name="architecture_scans") + op.drop_table("architecture_scans") diff --git a/alembic/versions/a2064e1dfcd4_add_cart_items_table.py b/alembic/versions/a2064e1dfcd4_add_cart_items_table.py index 7773c638..3eca3ff3 100644 --- a/alembic/versions/a2064e1dfcd4_add_cart_items_table.py +++ b/alembic/versions/a2064e1dfcd4_add_cart_items_table.py @@ -5,15 +5,16 @@ Revises: f68d8da5315a Create Date: 2025-11-23 19:52:40.509538 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision: str = 'a2064e1dfcd4' -down_revision: Union[str, None] = 'f68d8da5315a' +revision: str = "a2064e1dfcd4" +down_revision: Union[str, None] = "f68d8da5315a" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -21,34 +22,46 @@ depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # Create cart_items table op.create_table( - 'cart_items', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('vendor_id', sa.Integer(), nullable=False), - sa.Column('product_id', sa.Integer(), nullable=False), - sa.Column('session_id', sa.String(length=255), nullable=False), - sa.Column('quantity', sa.Integer(), nullable=False), - sa.Column('price_at_add', sa.Float(), nullable=False), - sa.Column('created_at', sa.DateTime(), nullable=True), - sa.Column('updated_at', sa.DateTime(), nullable=True), - sa.ForeignKeyConstraint(['product_id'], ['products.id'], ), - sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('vendor_id', 'session_id', 'product_id', name='uq_cart_item') + "cart_items", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("vendor_id", sa.Integer(), nullable=False), + sa.Column("product_id", sa.Integer(), nullable=False), + sa.Column("session_id", sa.String(length=255), nullable=False), + sa.Column("quantity", sa.Integer(), nullable=False), + sa.Column("price_at_add", sa.Float(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=True), + sa.Column("updated_at", sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint( + ["product_id"], + ["products.id"], + ), + sa.ForeignKeyConstraint( + ["vendor_id"], + ["vendors.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint( + "vendor_id", "session_id", "product_id", name="uq_cart_item" + ), ) # Create indexes - op.create_index('idx_cart_session', 'cart_items', ['vendor_id', 'session_id'], unique=False) - op.create_index('idx_cart_created', 'cart_items', ['created_at'], unique=False) - op.create_index(op.f('ix_cart_items_id'), 'cart_items', ['id'], unique=False) - op.create_index(op.f('ix_cart_items_session_id'), 'cart_items', ['session_id'], unique=False) + op.create_index( + "idx_cart_session", "cart_items", ["vendor_id", "session_id"], unique=False + ) + op.create_index("idx_cart_created", "cart_items", ["created_at"], unique=False) + op.create_index(op.f("ix_cart_items_id"), "cart_items", ["id"], unique=False) + op.create_index( + op.f("ix_cart_items_session_id"), "cart_items", ["session_id"], unique=False + ) def downgrade() -> None: # Drop indexes - op.drop_index(op.f('ix_cart_items_session_id'), table_name='cart_items') - op.drop_index(op.f('ix_cart_items_id'), table_name='cart_items') - op.drop_index('idx_cart_created', table_name='cart_items') - op.drop_index('idx_cart_session', table_name='cart_items') + op.drop_index(op.f("ix_cart_items_session_id"), table_name="cart_items") + op.drop_index(op.f("ix_cart_items_id"), table_name="cart_items") + op.drop_index("idx_cart_created", table_name="cart_items") + op.drop_index("idx_cart_session", table_name="cart_items") # Drop table - op.drop_table('cart_items') + op.drop_table("cart_items") diff --git a/alembic/versions/f68d8da5315a_add_template_field_to_content_pages_for_.py b/alembic/versions/f68d8da5315a_add_template_field_to_content_pages_for_.py index 2b40dda2..b095835d 100644 --- a/alembic/versions/f68d8da5315a_add_template_field_to_content_pages_for_.py +++ b/alembic/versions/f68d8da5315a_add_template_field_to_content_pages_for_.py @@ -5,24 +5,30 @@ Revises: 72aa309d4007 Create Date: 2025-11-22 23:51:40.694983 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision: str = 'f68d8da5315a' -down_revision: Union[str, None] = '72aa309d4007' +revision: str = "f68d8da5315a" +down_revision: Union[str, None] = "72aa309d4007" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # Add template column to content_pages table - op.add_column('content_pages', sa.Column('template', sa.String(length=50), nullable=False, server_default='default')) + op.add_column( + "content_pages", + sa.Column( + "template", sa.String(length=50), nullable=False, server_default="default" + ), + ) def downgrade() -> None: # Remove template column from content_pages table - op.drop_column('content_pages', 'template') + op.drop_column("content_pages", "template") diff --git a/alembic/versions/fa7d4d10e358_add_rbac_enhancements.py b/alembic/versions/fa7d4d10e358_add_rbac_enhancements.py index 2fcca7a8..6c0decce 100644 --- a/alembic/versions/fa7d4d10e358_add_rbac_enhancements.py +++ b/alembic/versions/fa7d4d10e358_add_rbac_enhancements.py @@ -6,15 +6,16 @@ Create Date: 2025-11-13 16:51:25.010057 SQLite-compatible version """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision: str = 'fa7d4d10e358' -down_revision: Union[str, None] = '4951b2e50581' +revision: str = "fa7d4d10e358" +down_revision: Union[str, None] = "4951b2e50581" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -28,9 +29,14 @@ def upgrade(): # ======================================================================== # User table changes # ======================================================================== - with op.batch_alter_table('users', schema=None) as batch_op: + with op.batch_alter_table("users", schema=None) as batch_op: batch_op.add_column( - sa.Column('is_email_verified', sa.Boolean(), nullable=False, server_default='false') + sa.Column( + "is_email_verified", + sa.Boolean(), + nullable=False, + server_default="false", + ) ) # Set existing active users as verified @@ -39,68 +45,65 @@ def upgrade(): # ======================================================================== # VendorUser table changes (requires table recreation for SQLite) # ======================================================================== - with op.batch_alter_table('vendor_users', schema=None) as batch_op: + with op.batch_alter_table("vendor_users", schema=None) as batch_op: # Add new columns batch_op.add_column( - sa.Column('user_type', sa.String(length=20), nullable=False, server_default='member') + sa.Column( + "user_type", + sa.String(length=20), + nullable=False, + server_default="member", + ) ) batch_op.add_column( - sa.Column('invitation_token', sa.String(length=100), nullable=True) + sa.Column("invitation_token", sa.String(length=100), nullable=True) ) batch_op.add_column( - sa.Column('invitation_sent_at', sa.DateTime(), nullable=True) + sa.Column("invitation_sent_at", sa.DateTime(), nullable=True) ) batch_op.add_column( - sa.Column('invitation_accepted_at', sa.DateTime(), nullable=True) + sa.Column("invitation_accepted_at", sa.DateTime(), nullable=True) ) # Create index on invitation_token - batch_op.create_index( - 'idx_vendor_users_invitation_token', - ['invitation_token'] - ) + batch_op.create_index("idx_vendor_users_invitation_token", ["invitation_token"]) # Modify role_id to be nullable (this recreates the table in SQLite) - batch_op.alter_column( - 'role_id', - existing_type=sa.Integer(), - nullable=True - ) + batch_op.alter_column("role_id", existing_type=sa.Integer(), nullable=True) # Change is_active default (this recreates the table in SQLite) batch_op.alter_column( - 'is_active', - existing_type=sa.Boolean(), - server_default='false' + "is_active", existing_type=sa.Boolean(), server_default="false" ) # Set owners correctly (after table modifications) # SQLite-compatible UPDATE with subquery - op.execute(""" + op.execute( + """ UPDATE vendor_users SET user_type = 'owner' WHERE (vendor_id, user_id) IN ( SELECT id, owner_user_id FROM vendors ) - """) + """ + ) # Set existing owners as active - op.execute(""" + op.execute( + """ UPDATE vendor_users SET is_active = TRUE WHERE user_type = 'owner' - """) + """ + ) # ======================================================================== # Role table changes # ======================================================================== - with op.batch_alter_table('roles', schema=None) as batch_op: + with op.batch_alter_table("roles", schema=None) as batch_op: # Create index on vendor_id and name - batch_op.create_index( - 'idx_roles_vendor_name', - ['vendor_id', 'name'] - ) + batch_op.create_index("idx_roles_vendor_name", ["vendor_id", "name"]) # Note: JSONB conversion only for PostgreSQL # SQLite stores JSON as TEXT by default, no conversion needed @@ -115,37 +118,31 @@ def downgrade(): # ======================================================================== # Role table changes # ======================================================================== - with op.batch_alter_table('roles', schema=None) as batch_op: - batch_op.drop_index('idx_roles_vendor_name') + with op.batch_alter_table("roles", schema=None) as batch_op: + batch_op.drop_index("idx_roles_vendor_name") # ======================================================================== # VendorUser table changes # ======================================================================== - with op.batch_alter_table('vendor_users', schema=None) as batch_op: + with op.batch_alter_table("vendor_users", schema=None) as batch_op: # Revert is_active default batch_op.alter_column( - 'is_active', - existing_type=sa.Boolean(), - server_default='true' + "is_active", existing_type=sa.Boolean(), server_default="true" ) # Revert role_id to NOT NULL # Note: This might fail if there are NULL values - batch_op.alter_column( - 'role_id', - existing_type=sa.Integer(), - nullable=False - ) + batch_op.alter_column("role_id", existing_type=sa.Integer(), nullable=False) # Drop indexes and columns - batch_op.drop_index('idx_vendor_users_invitation_token') - batch_op.drop_column('invitation_accepted_at') - batch_op.drop_column('invitation_sent_at') - batch_op.drop_column('invitation_token') - batch_op.drop_column('user_type') + batch_op.drop_index("idx_vendor_users_invitation_token") + batch_op.drop_column("invitation_accepted_at") + batch_op.drop_column("invitation_sent_at") + batch_op.drop_column("invitation_token") + batch_op.drop_column("user_type") # ======================================================================== # User table changes # ======================================================================== - with op.batch_alter_table('users', schema=None) as batch_op: - batch_op.drop_column('is_email_verified') \ No newline at end of file + with op.batch_alter_table("users", schema=None) as batch_op: + batch_op.drop_column("is_email_verified") diff --git a/alembic/versions/fef1d20ce8b4_add_content_pages_table_for_cms.py b/alembic/versions/fef1d20ce8b4_add_content_pages_table_for_cms.py index 0c895879..cc24f81a 100644 --- a/alembic/versions/fef1d20ce8b4_add_content_pages_table_for_cms.py +++ b/alembic/versions/fef1d20ce8b4_add_content_pages_table_for_cms.py @@ -5,30 +5,43 @@ Revises: fa7d4d10e358 Create Date: 2025-11-22 13:41:18.069674 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision: str = 'fef1d20ce8b4' -down_revision: Union[str, None] = 'fa7d4d10e358' +revision: str = "fef1d20ce8b4" +down_revision: Union[str, None] = "fa7d4d10e358" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_index('idx_roles_vendor_name', table_name='roles') - op.drop_index('idx_vendor_users_invitation_token', table_name='vendor_users') - op.create_index(op.f('ix_vendor_users_invitation_token'), 'vendor_users', ['invitation_token'], unique=False) + op.drop_index("idx_roles_vendor_name", table_name="roles") + op.drop_index("idx_vendor_users_invitation_token", table_name="vendor_users") + op.create_index( + op.f("ix_vendor_users_invitation_token"), + "vendor_users", + ["invitation_token"], + unique=False, + ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_index(op.f('ix_vendor_users_invitation_token'), table_name='vendor_users') - op.create_index('idx_vendor_users_invitation_token', 'vendor_users', ['invitation_token'], unique=False) - op.create_index('idx_roles_vendor_name', 'roles', ['vendor_id', 'name'], unique=False) + op.drop_index(op.f("ix_vendor_users_invitation_token"), table_name="vendor_users") + op.create_index( + "idx_vendor_users_invitation_token", + "vendor_users", + ["invitation_token"], + unique=False, + ) + op.create_index( + "idx_roles_vendor_name", "roles", ["vendor_id", "name"], unique=False + ) # ### end Alembic commands ### diff --git a/app/api/deps.py b/app/api/deps.py index d219a60a..f2761a9a 100644 --- a/app/api/deps.py +++ b/app/api/deps.py @@ -34,22 +34,20 @@ The cookie path restrictions prevent cross-context cookie leakage: import logging from typing import Optional -from fastapi import Depends, Request, Cookie +from fastapi import Cookie, Depends, Request from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from sqlalchemy.orm import Session from app.core.database import get_db +from app.exceptions import (AdminRequiredException, + InsufficientPermissionsException, + InvalidTokenException, + UnauthorizedVendorAccessException, + VendorNotFoundException) from middleware.auth import AuthManager from middleware.rate_limiter import RateLimiter -from models.database.vendor import Vendor from models.database.user import User -from app.exceptions import ( - AdminRequiredException, - InvalidTokenException, - InsufficientPermissionsException, - VendorNotFoundException, - UnauthorizedVendorAccessException -) +from models.database.vendor import Vendor # Initialize dependencies security = HTTPBearer(auto_error=False) # auto_error=False prevents automatic 403 @@ -62,11 +60,12 @@ logger = logging.getLogger(__name__) # HELPER FUNCTIONS # ============================================================================ + def _get_token_from_request( - credentials: Optional[HTTPAuthorizationCredentials], - cookie_value: Optional[str], - cookie_name: str, - request_path: str + credentials: Optional[HTTPAuthorizationCredentials], + cookie_value: Optional[str], + cookie_name: str, + request_path: str, ) -> tuple[Optional[str], Optional[str]]: """ Extract token from Authorization header or cookie. @@ -108,10 +107,7 @@ def _validate_user_token(token: str, db: Session) -> User: Raises: InvalidTokenException: If token is invalid """ - mock_credentials = HTTPAuthorizationCredentials( - scheme="Bearer", - credentials=token - ) + mock_credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token) return auth_manager.get_current_user(db, mock_credentials) @@ -119,11 +115,12 @@ def _validate_user_token(token: str, db: Session) -> User: # ADMIN AUTHENTICATION # ============================================================================ + def get_current_admin_from_cookie_or_header( - request: Request, - credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), - admin_token: Optional[str] = Cookie(None), - db: Session = Depends(get_db), + request: Request, + credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), + admin_token: Optional[str] = Cookie(None), + db: Session = Depends(get_db), ) -> User: """ Get current admin user from admin_token cookie or Authorization header. @@ -148,10 +145,7 @@ def get_current_admin_from_cookie_or_header( AdminRequiredException: If user is not admin """ token, source = _get_token_from_request( - credentials, - admin_token, - "admin_token", - str(request.url.path) + credentials, admin_token, "admin_token", str(request.url.path) ) if not token: @@ -172,8 +166,8 @@ def get_current_admin_from_cookie_or_header( def get_current_admin_api( - credentials: HTTPAuthorizationCredentials = Depends(security), - db: Session = Depends(get_db), + credentials: HTTPAuthorizationCredentials = Depends(security), + db: Session = Depends(get_db), ) -> User: """ Get current admin user from Authorization header ONLY. @@ -208,11 +202,12 @@ def get_current_admin_api( # VENDOR AUTHENTICATION # ============================================================================ + def get_current_vendor_from_cookie_or_header( - request: Request, - credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), - vendor_token: Optional[str] = Cookie(None), - db: Session = Depends(get_db), + request: Request, + credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), + vendor_token: Optional[str] = Cookie(None), + db: Session = Depends(get_db), ) -> User: """ Get current vendor user from vendor_token cookie or Authorization header. @@ -237,10 +232,7 @@ def get_current_vendor_from_cookie_or_header( InsufficientPermissionsException: If user is not vendor or is admin """ token, source = _get_token_from_request( - credentials, - vendor_token, - "vendor_token", - str(request.url.path) + credentials, vendor_token, "vendor_token", str(request.url.path) ) if not token: @@ -270,8 +262,8 @@ def get_current_vendor_from_cookie_or_header( def get_current_vendor_api( - credentials: HTTPAuthorizationCredentials = Depends(security), - db: Session = Depends(get_db), + credentials: HTTPAuthorizationCredentials = Depends(security), + db: Session = Depends(get_db), ) -> User: """ Get current vendor user from Authorization header ONLY. @@ -310,11 +302,12 @@ def get_current_vendor_api( # CUSTOMER AUTHENTICATION (SHOP) # ============================================================================ + def get_current_customer_from_cookie_or_header( - request: Request, - credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), - customer_token: Optional[str] = Cookie(None), - db: Session = Depends(get_db), + request: Request, + credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), + customer_token: Optional[str] = Cookie(None), + db: Session = Depends(get_db), ): """ Get current customer from customer_token cookie or Authorization header. @@ -338,15 +331,14 @@ def get_current_customer_from_cookie_or_header( Raises: InvalidTokenException: If no token or invalid token """ - from models.database.customer import Customer - from jose import jwt, JWTError from datetime import datetime, timezone + from jose import JWTError, jwt + + from models.database.customer import Customer + token, source = _get_token_from_request( - credentials, - customer_token, - "customer_token", - str(request.url.path) + credentials, customer_token, "customer_token", str(request.url.path) ) if not token: @@ -356,9 +348,7 @@ def get_current_customer_from_cookie_or_header( # Decode and validate customer JWT token try: payload = jwt.decode( - token, - auth_manager.secret_key, - algorithms=[auth_manager.algorithm] + token, auth_manager.secret_key, algorithms=[auth_manager.algorithm] ) # Verify this is a customer token @@ -375,7 +365,9 @@ def get_current_customer_from_cookie_or_header( # Verify token hasn't expired exp = payload.get("exp") - if exp and datetime.fromtimestamp(exp, tz=timezone.utc) < datetime.now(timezone.utc): + if exp and datetime.fromtimestamp(exp, tz=timezone.utc) < datetime.now( + timezone.utc + ): logger.warning(f"Expired customer token for customer_id={customer_id}") raise InvalidTokenException("Token has expired") @@ -400,8 +392,8 @@ def get_current_customer_from_cookie_or_header( def get_current_customer_api( - credentials: HTTPAuthorizationCredentials = Depends(security), - db: Session = Depends(get_db), + credentials: HTTPAuthorizationCredentials = Depends(security), + db: Session = Depends(get_db), ) -> User: """ Get current customer user from Authorization header ONLY. @@ -445,9 +437,10 @@ def get_current_customer_api( # GENERIC AUTHENTICATION (for mixed-use endpoints) # ============================================================================ + def get_current_user( - credentials: HTTPAuthorizationCredentials = Depends(security), - db: Session = Depends(get_db) + credentials: HTTPAuthorizationCredentials = Depends(security), + db: Session = Depends(get_db), ) -> User: """ Get current authenticated user from Authorization header only. @@ -475,10 +468,11 @@ def get_current_user( # VENDOR OWNERSHIP VERIFICATION # ============================================================================ + def get_user_vendor( - vendor_code: str, - current_user: User = Depends(get_current_vendor_from_cookie_or_header), - db: Session = Depends(get_db), + vendor_code: str, + current_user: User = Depends(get_current_vendor_from_cookie_or_header), + db: Session = Depends(get_db), ) -> Vendor: """ Get vendor and verify user ownership/membership. @@ -500,9 +494,7 @@ def get_user_vendor( VendorNotFoundException: If vendor doesn't exist UnauthorizedVendorAccessException: If user doesn't have access """ - vendor = db.query(Vendor).filter( - Vendor.vendor_code == vendor_code.upper() - ).first() + vendor = db.query(Vendor).filter(Vendor.vendor_code == vendor_code.upper()).first() if not vendor: raise VendorNotFoundException(vendor_code) @@ -517,10 +509,12 @@ def get_user_vendor( # User doesn't have access to this vendor raise UnauthorizedVendorAccessException(vendor_code, current_user.id) + # ============================================================================ # PERMISSIONS CHECKING # ============================================================================ + def require_vendor_permission(permission: str): """ Dependency factory to require a specific vendor permission. @@ -535,9 +529,9 @@ def require_vendor_permission(permission: str): """ def permission_checker( - request: Request, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_vendor_from_cookie_or_header), + request: Request, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_vendor_from_cookie_or_header), ) -> User: # Get vendor from request state (set by middleware) vendor = getattr(request.state, "vendor", None) @@ -557,9 +551,9 @@ def require_vendor_permission(permission: str): def require_vendor_owner( - request: Request, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_vendor_from_cookie_or_header), + request: Request, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_vendor_from_cookie_or_header), ) -> User: """ Dependency to require vendor owner role. @@ -600,9 +594,9 @@ def require_any_vendor_permission(*permissions: str): """ def permission_checker( - request: Request, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_vendor_from_cookie_or_header), + request: Request, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_vendor_from_cookie_or_header), ) -> User: vendor = getattr(request.state, "vendor", None) if not vendor: @@ -610,8 +604,7 @@ def require_any_vendor_permission(*permissions: str): # Check if user has ANY of the required permissions has_permission = any( - current_user.has_vendor_permission(vendor.id, perm) - for perm in permissions + current_user.has_vendor_permission(vendor.id, perm) for perm in permissions ) if not has_permission: @@ -641,9 +634,9 @@ def require_all_vendor_permissions(*permissions: str): """ def permission_checker( - request: Request, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_vendor_from_cookie_or_header), + request: Request, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_vendor_from_cookie_or_header), ) -> User: vendor = getattr(request.state, "vendor", None) if not vendor: @@ -651,7 +644,8 @@ def require_all_vendor_permissions(*permissions: str): # Check if user has ALL required permissions missing_permissions = [ - perm for perm in permissions + perm + for perm in permissions if not current_user.has_vendor_permission(vendor.id, perm) ] @@ -667,8 +661,8 @@ def require_all_vendor_permissions(*permissions: str): def get_user_permissions( - request: Request, - current_user: User = Depends(get_current_vendor_from_cookie_or_header), + request: Request, + current_user: User = Depends(get_current_vendor_from_cookie_or_header), ) -> list: """ Get all permissions for current user in current vendor. @@ -682,6 +676,7 @@ def get_user_permissions( # If owner, return all permissions if current_user.is_owner_of(vendor.id): from app.core.permissions import VendorPermissions + return [p.value for p in VendorPermissions] # Get permissions from vendor membership @@ -696,11 +691,12 @@ def get_user_permissions( # OPTIONAL AUTHENTICATION (For Login Page Redirects) # ============================================================================ + def get_current_admin_optional( - request: Request, - credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), - admin_token: Optional[str] = Cookie(None), - db: Session = Depends(get_db), + request: Request, + credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), + admin_token: Optional[str] = Cookie(None), + db: Session = Depends(get_db), ) -> Optional[User]: """ Get current admin user from admin_token cookie or Authorization header. @@ -723,10 +719,7 @@ def get_current_admin_optional( None: If no token, invalid token, or user is not admin """ token, source = _get_token_from_request( - credentials, - admin_token, - "admin_token", - str(request.url.path) + credentials, admin_token, "admin_token", str(request.url.path) ) if not token: @@ -747,10 +740,10 @@ def get_current_admin_optional( def get_current_vendor_optional( - request: Request, - credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), - vendor_token: Optional[str] = Cookie(None), - db: Session = Depends(get_db), + request: Request, + credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), + vendor_token: Optional[str] = Cookie(None), + db: Session = Depends(get_db), ) -> Optional[User]: """ Get current vendor user from vendor_token cookie or Authorization header. @@ -773,10 +766,7 @@ def get_current_vendor_optional( None: If no token, invalid token, or user is not vendor """ token, source = _get_token_from_request( - credentials, - vendor_token, - "vendor_token", - str(request.url.path) + credentials, vendor_token, "vendor_token", str(request.url.path) ) if not token: @@ -797,10 +787,10 @@ def get_current_vendor_optional( def get_current_customer_optional( - request: Request, - credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), - customer_token: Optional[str] = Cookie(None), - db: Session = Depends(get_db), + request: Request, + credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), + customer_token: Optional[str] = Cookie(None), + db: Session = Depends(get_db), ) -> Optional[User]: """ Get current customer user from customer_token cookie or Authorization header. @@ -823,10 +813,7 @@ def get_current_customer_optional( None: If no token, invalid token, or user is not customer """ token, source = _get_token_from_request( - credentials, - customer_token, - "customer_token", - str(request.url.path) + credentials, customer_token, "customer_token", str(request.url.path) ) if not token: @@ -844,5 +831,3 @@ def get_current_customer_optional( pass return None - - diff --git a/app/api/main.py b/app/api/main.py index a6997bd1..1f1c43be 100644 --- a/app/api/main.py +++ b/app/api/main.py @@ -9,7 +9,8 @@ This module provides: """ from fastapi import APIRouter -from app.api.v1 import admin, vendor, shop + +from app.api.v1 import admin, shop, vendor api_router = APIRouter() @@ -18,31 +19,18 @@ api_router = APIRouter() # Prefix: /api/v1/admin # ============================================================================ -api_router.include_router( - admin.router, - prefix="/v1/admin", - tags=["admin"] -) +api_router.include_router(admin.router, prefix="/v1/admin", tags=["admin"]) # ============================================================================ # VENDOR ROUTES (Vendor-scoped operations) # Prefix: /api/v1/vendor # ============================================================================ -api_router.include_router( - vendor.router, - prefix="/v1/vendor", - tags=["vendor"] -) +api_router.include_router(vendor.router, prefix="/v1/vendor", tags=["vendor"]) # ============================================================================ # SHOP ROUTES (Public shop frontend API) # Prefix: /api/v1/shop # ============================================================================ -api_router.include_router( - shop.router, - prefix="/v1/shop", - tags=["shop"] -) - +api_router.include_router(shop.router, prefix="/v1/shop", tags=["shop"]) diff --git a/app/api/v1/__init__.py b/app/api/v1/__init__.py index 8890d586..223a83c5 100644 --- a/app/api/v1/__init__.py +++ b/app/api/v1/__init__.py @@ -3,6 +3,6 @@ API Version 1 - All endpoints """ -from . import admin, vendor, shop +from . import admin, shop, vendor -__all__ = ["admin", "vendor", "shop"] \ No newline at end of file +__all__ = ["admin", "vendor", "shop"] diff --git a/app/api/v1/admin/__init__.py b/app/api/v1/admin/__init__.py index f2e5ca10..a7b9d009 100644 --- a/app/api/v1/admin/__init__.py +++ b/app/api/v1/admin/__init__.py @@ -24,21 +24,9 @@ IMPORTANT: from fastapi import APIRouter # Import all admin routers -from . import ( - auth, - vendors, - vendor_domains, - vendor_themes, - users, - dashboard, - marketplace, - monitoring, - audit, - settings, - notifications, - content_pages, - code_quality -) +from . import (audit, auth, code_quality, content_pages, dashboard, + marketplace, monitoring, notifications, settings, users, + vendor_domains, vendor_themes, vendors) # Create admin router router = APIRouter() @@ -66,7 +54,9 @@ router.include_router(vendor_domains.router, tags=["admin-vendor-domains"]) router.include_router(vendor_themes.router, tags=["admin-vendor-themes"]) # Include content pages management endpoints -router.include_router(content_pages.router, prefix="/content-pages", tags=["admin-content-pages"]) +router.include_router( + content_pages.router, prefix="/content-pages", tags=["admin-content-pages"] +) # ============================================================================ @@ -115,7 +105,9 @@ router.include_router(notifications.router, tags=["admin-notifications"]) # ============================================================================ # Include code quality and architecture validation endpoints -router.include_router(code_quality.router, prefix="/code-quality", tags=["admin-code-quality"]) +router.include_router( + code_quality.router, prefix="/code-quality", tags=["admin-code-quality"] +) # Export the router __all__ = ["router"] diff --git a/app/api/v1/admin/audit.py b/app/api/v1/admin/audit.py index a150cbf2..d3d91511 100644 --- a/app/api/v1/admin/audit.py +++ b/app/api/v1/admin/audit.py @@ -9,8 +9,8 @@ Provides endpoints for: """ import logging -from typing import Optional from datetime import datetime +from typing import Optional from fastapi import APIRouter, Depends, Query from sqlalchemy.orm import Session @@ -18,12 +18,10 @@ from sqlalchemy.orm import Session from app.api.deps import get_current_admin_api from app.core.database import get_db from app.services.admin_audit_service import admin_audit_service -from models.schema.admin import ( - AdminAuditLogResponse, - AdminAuditLogFilters, - AdminAuditLogListResponse -) from models.database.user import User +from models.schema.admin import (AdminAuditLogFilters, + AdminAuditLogListResponse, + AdminAuditLogResponse) router = APIRouter(prefix="/audit") logger = logging.getLogger(__name__) @@ -31,15 +29,15 @@ logger = logging.getLogger(__name__) @router.get("/logs", response_model=AdminAuditLogListResponse) def get_audit_logs( - admin_user_id: Optional[int] = Query(None, description="Filter by admin user"), - action: Optional[str] = Query(None, description="Filter by action type"), - target_type: Optional[str] = Query(None, description="Filter by target type"), - date_from: Optional[datetime] = Query(None, description="Filter from date"), - date_to: Optional[datetime] = Query(None, description="Filter to date"), - skip: int = Query(0, ge=0, description="Number of records to skip"), - limit: int = Query(100, ge=1, le=1000, description="Maximum records to return"), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api), + admin_user_id: Optional[int] = Query(None, description="Filter by admin user"), + action: Optional[str] = Query(None, description="Filter by action type"), + target_type: Optional[str] = Query(None, description="Filter by target type"), + date_from: Optional[datetime] = Query(None, description="Filter from date"), + date_to: Optional[datetime] = Query(None, description="Filter to date"), + skip: int = Query(0, ge=0, description="Number of records to skip"), + limit: int = Query(100, ge=1, le=1000, description="Maximum records to return"), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """ Get filtered admin audit logs. @@ -54,7 +52,7 @@ def get_audit_logs( date_from=date_from, date_to=date_to, skip=skip, - limit=limit + limit=limit, ) logs = admin_audit_service.get_audit_logs(db, filters) @@ -62,19 +60,14 @@ def get_audit_logs( logger.info(f"Admin {current_admin.username} retrieved {len(logs)} audit logs") - return AdminAuditLogListResponse( - logs=logs, - total=total, - skip=skip, - limit=limit - ) + return AdminAuditLogListResponse(logs=logs, total=total, skip=skip, limit=limit) @router.get("/logs/recent", response_model=list[AdminAuditLogResponse]) def get_recent_audit_logs( - limit: int = Query(20, ge=1, le=100), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api), + limit: int = Query(20, ge=1, le=100), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """Get recent audit logs (last 20 by default).""" filters = AdminAuditLogFilters(limit=limit) @@ -83,25 +76,23 @@ def get_recent_audit_logs( @router.get("/logs/my-actions", response_model=list[AdminAuditLogResponse]) def get_my_actions( - limit: int = Query(50, ge=1, le=100), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api), + limit: int = Query(50, ge=1, le=100), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """Get audit logs for current admin's actions.""" return admin_audit_service.get_recent_actions_by_admin( - db=db, - admin_user_id=current_admin.id, - limit=limit + db=db, admin_user_id=current_admin.id, limit=limit ) @router.get("/logs/target/{target_type}/{target_id}") def get_actions_by_target( - target_type: str, - target_id: str, - limit: int = Query(50, ge=1, le=100), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api), + target_type: str, + target_id: str, + limit: int = Query(50, ge=1, le=100), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """ Get all actions performed on a specific target. @@ -109,8 +100,5 @@ def get_actions_by_target( Useful for tracking the history of a specific vendor, user, or entity. """ return admin_audit_service.get_actions_by_target( - db=db, - target_type=target_type, - target_id=target_id, - limit=limit + db=db, target_type=target_type, target_id=target_id, limit=limit ) diff --git a/app/api/v1/admin/auth.py b/app/api/v1/admin/auth.py index 5de07e00..a739e5cd 100644 --- a/app/api/v1/admin/auth.py +++ b/app/api/v1/admin/auth.py @@ -10,16 +10,17 @@ This prevents admin cookies from being sent to vendor routes. """ import logging + from fastapi import APIRouter, Depends, Response from sqlalchemy.orm import Session +from app.api.deps import get_current_admin_api from app.core.database import get_db from app.core.environment import should_use_secure_cookies -from app.services.auth_service import auth_service from app.exceptions import InvalidCredentialsException -from models.schema.auth import LoginResponse, UserLogin, UserResponse +from app.services.auth_service import auth_service from models.database.user import User -from app.api.deps import get_current_admin_api +from models.schema.auth import LoginResponse, UserLogin, UserResponse router = APIRouter(prefix="/auth") logger = logging.getLogger(__name__) @@ -27,9 +28,7 @@ logger = logging.getLogger(__name__) @router.post("/login", response_model=LoginResponse) def admin_login( - user_credentials: UserLogin, - response: Response, - db: Session = Depends(get_db) + user_credentials: UserLogin, response: Response, db: Session = Depends(get_db) ): """ Admin login endpoint. @@ -49,7 +48,9 @@ def admin_login( # Verify user is admin if login_result["user"].role != "admin": - logger.warning(f"Non-admin user attempted admin login: {user_credentials.email_or_username}") + logger.warning( + f"Non-admin user attempted admin login: {user_credentials.email_or_username}" + ) raise InvalidCredentialsException("Admin access required") logger.info(f"Admin login successful: {login_result['user'].username}") diff --git a/app/api/v1/admin/code_quality.py b/app/api/v1/admin/code_quality.py index de3e5674..8193638b 100644 --- a/app/api/v1/admin/code_quality.py +++ b/app/api/v1/admin/code_quality.py @@ -3,25 +3,27 @@ Code Quality API Endpoints RESTful API for architecture validation and violation management """ -from typing import Optional from datetime import datetime -from fastapi import APIRouter, Depends, HTTPException, Query -from sqlalchemy.orm import Session -from pydantic import BaseModel, Field +from typing import Optional +from fastapi import APIRouter, Depends, HTTPException, Query +from pydantic import BaseModel, Field +from sqlalchemy.orm import Session + +from app.api.deps import get_current_admin_api from app.core.database import get_db from app.services.code_quality_service import code_quality_service -from app.api.deps import get_current_admin_api from models.database.user import User - router = APIRouter() # Pydantic Models for API + class ScanResponse(BaseModel): """Response model for a scan""" + id: int timestamp: str total_files: int @@ -38,6 +40,7 @@ class ScanResponse(BaseModel): class ViolationResponse(BaseModel): """Response model for a violation""" + id: int scan_id: int rule_id: str @@ -61,6 +64,7 @@ class ViolationResponse(BaseModel): class ViolationListResponse(BaseModel): """Response model for paginated violations list""" + violations: list[ViolationResponse] total: int page: int @@ -70,34 +74,42 @@ class ViolationListResponse(BaseModel): class ViolationDetailResponse(ViolationResponse): """Response model for single violation with relationships""" + assignments: list = [] comments: list = [] class AssignViolationRequest(BaseModel): """Request model for assigning a violation""" + user_id: int = Field(..., description="User ID to assign to") due_date: Optional[datetime] = Field(None, description="Due date for resolution") - priority: str = Field("medium", description="Priority level (low, medium, high, critical)") + priority: str = Field( + "medium", description="Priority level (low, medium, high, critical)" + ) class ResolveViolationRequest(BaseModel): """Request model for resolving a violation""" + resolution_note: str = Field(..., description="Note about the resolution") class IgnoreViolationRequest(BaseModel): """Request model for ignoring a violation""" + reason: str = Field(..., description="Reason for ignoring") class AddCommentRequest(BaseModel): """Request model for adding a comment""" + comment: str = Field(..., min_length=1, description="Comment text") class DashboardStatsResponse(BaseModel): """Response model for dashboard statistics""" + total_violations: int errors: int warnings: int @@ -116,10 +128,10 @@ class DashboardStatsResponse(BaseModel): # API Endpoints + @router.post("/scan", response_model=ScanResponse) async def trigger_scan( - db: Session = Depends(get_db), - current_user: User = Depends(get_current_admin_api) + db: Session = Depends(get_db), current_user: User = Depends(get_current_admin_api) ): """ Trigger a new architecture scan @@ -127,7 +139,9 @@ async def trigger_scan( Requires authentication. Runs the validator script and stores results. """ try: - scan = code_quality_service.run_scan(db, triggered_by=f"manual:{current_user.username}") + scan = code_quality_service.run_scan( + db, triggered_by=f"manual:{current_user.username}" + ) return ScanResponse( id=scan.id, @@ -138,7 +152,7 @@ async def trigger_scan( warnings=scan.warnings, duration_seconds=scan.duration_seconds, triggered_by=scan.triggered_by, - git_commit_hash=scan.git_commit_hash + git_commit_hash=scan.git_commit_hash, ) except Exception as e: raise HTTPException(status_code=500, detail=f"Scan failed: {str(e)}") @@ -148,7 +162,7 @@ async def trigger_scan( async def list_scans( limit: int = Query(30, ge=1, le=100, description="Number of scans to return"), db: Session = Depends(get_db), - current_user: User = Depends(get_current_admin_api) + current_user: User = Depends(get_current_admin_api), ): """ Get scan history @@ -167,7 +181,7 @@ async def list_scans( warnings=scan.warnings, duration_seconds=scan.duration_seconds, triggered_by=scan.triggered_by, - git_commit_hash=scan.git_commit_hash + git_commit_hash=scan.git_commit_hash, ) for scan in scans ] @@ -175,15 +189,23 @@ async def list_scans( @router.get("/violations", response_model=ViolationListResponse) async def list_violations( - scan_id: Optional[int] = Query(None, description="Filter by scan ID (defaults to latest)"), - severity: Optional[str] = Query(None, description="Filter by severity (error, warning)"), - status: Optional[str] = Query(None, description="Filter by status (open, assigned, resolved, ignored)"), + scan_id: Optional[int] = Query( + None, description="Filter by scan ID (defaults to latest)" + ), + severity: Optional[str] = Query( + None, description="Filter by severity (error, warning)" + ), + status: Optional[str] = Query( + None, description="Filter by status (open, assigned, resolved, ignored)" + ), rule_id: Optional[str] = Query(None, description="Filter by rule ID"), - file_path: Optional[str] = Query(None, description="Filter by file path (partial match)"), + file_path: Optional[str] = Query( + None, description="Filter by file path (partial match)" + ), page: int = Query(1, ge=1, description="Page number"), page_size: int = Query(50, ge=1, le=200, description="Items per page"), db: Session = Depends(get_db), - current_user: User = Depends(get_current_admin_api) + current_user: User = Depends(get_current_admin_api), ): """ Get violations with filtering and pagination @@ -200,7 +222,7 @@ async def list_violations( rule_id=rule_id, file_path=file_path, limit=page_size, - offset=offset + offset=offset, ) total_pages = (total + page_size - 1) // page_size @@ -223,14 +245,14 @@ async def list_violations( resolved_at=v.resolved_at.isoformat() if v.resolved_at else None, resolved_by=v.resolved_by, resolution_note=v.resolution_note, - created_at=v.created_at.isoformat() + created_at=v.created_at.isoformat(), ) for v in violations ], total=total, page=page, page_size=page_size, - total_pages=total_pages + total_pages=total_pages, ) @@ -238,7 +260,7 @@ async def list_violations( async def get_violation( violation_id: int, db: Session = Depends(get_db), - current_user: User = Depends(get_current_admin_api) + current_user: User = Depends(get_current_admin_api), ): """ Get single violation with details @@ -253,12 +275,12 @@ async def get_violation( # Format assignments assignments = [ { - 'id': a.id, - 'user_id': a.user_id, - 'assigned_at': a.assigned_at.isoformat(), - 'assigned_by': a.assigned_by, - 'due_date': a.due_date.isoformat() if a.due_date else None, - 'priority': a.priority + "id": a.id, + "user_id": a.user_id, + "assigned_at": a.assigned_at.isoformat(), + "assigned_by": a.assigned_by, + "due_date": a.due_date.isoformat() if a.due_date else None, + "priority": a.priority, } for a in violation.assignments ] @@ -266,10 +288,10 @@ async def get_violation( # Format comments comments = [ { - 'id': c.id, - 'user_id': c.user_id, - 'comment': c.comment, - 'created_at': c.created_at.isoformat() + "id": c.id, + "user_id": c.user_id, + "comment": c.comment, + "created_at": c.created_at.isoformat(), } for c in violation.comments ] @@ -287,12 +309,14 @@ async def get_violation( suggestion=violation.suggestion, status=violation.status, assigned_to=violation.assigned_to, - resolved_at=violation.resolved_at.isoformat() if violation.resolved_at else None, + resolved_at=( + violation.resolved_at.isoformat() if violation.resolved_at else None + ), resolved_by=violation.resolved_by, resolution_note=violation.resolution_note, created_at=violation.created_at.isoformat(), assignments=assignments, - comments=comments + comments=comments, ) @@ -301,7 +325,7 @@ async def assign_violation( violation_id: int, request: AssignViolationRequest, db: Session = Depends(get_db), - current_user: User = Depends(get_current_admin_api) + current_user: User = Depends(get_current_admin_api), ): """ Assign violation to a developer @@ -315,17 +339,19 @@ async def assign_violation( user_id=request.user_id, assigned_by=current_user.id, due_date=request.due_date, - priority=request.priority + priority=request.priority, ) return { - 'id': assignment.id, - 'violation_id': assignment.violation_id, - 'user_id': assignment.user_id, - 'assigned_at': assignment.assigned_at.isoformat(), - 'assigned_by': assignment.assigned_by, - 'due_date': assignment.due_date.isoformat() if assignment.due_date else None, - 'priority': assignment.priority + "id": assignment.id, + "violation_id": assignment.violation_id, + "user_id": assignment.user_id, + "assigned_at": assignment.assigned_at.isoformat(), + "assigned_by": assignment.assigned_by, + "due_date": ( + assignment.due_date.isoformat() if assignment.due_date else None + ), + "priority": assignment.priority, } except Exception as e: raise HTTPException(status_code=400, detail=str(e)) @@ -336,7 +362,7 @@ async def resolve_violation( violation_id: int, request: ResolveViolationRequest, db: Session = Depends(get_db), - current_user: User = Depends(get_current_admin_api) + current_user: User = Depends(get_current_admin_api), ): """ Mark violation as resolved @@ -348,15 +374,17 @@ async def resolve_violation( db, violation_id=violation_id, resolved_by=current_user.id, - resolution_note=request.resolution_note + resolution_note=request.resolution_note, ) return { - 'id': violation.id, - 'status': violation.status, - 'resolved_at': violation.resolved_at.isoformat() if violation.resolved_at else None, - 'resolved_by': violation.resolved_by, - 'resolution_note': violation.resolution_note + "id": violation.id, + "status": violation.status, + "resolved_at": ( + violation.resolved_at.isoformat() if violation.resolved_at else None + ), + "resolved_by": violation.resolved_by, + "resolution_note": violation.resolution_note, } except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) @@ -369,7 +397,7 @@ async def ignore_violation( violation_id: int, request: IgnoreViolationRequest, db: Session = Depends(get_db), - current_user: User = Depends(get_current_admin_api) + current_user: User = Depends(get_current_admin_api), ): """ Mark violation as ignored (won't fix) @@ -381,15 +409,17 @@ async def ignore_violation( db, violation_id=violation_id, ignored_by=current_user.id, - reason=request.reason + reason=request.reason, ) return { - 'id': violation.id, - 'status': violation.status, - 'resolved_at': violation.resolved_at.isoformat() if violation.resolved_at else None, - 'resolved_by': violation.resolved_by, - 'resolution_note': violation.resolution_note + "id": violation.id, + "status": violation.status, + "resolved_at": ( + violation.resolved_at.isoformat() if violation.resolved_at else None + ), + "resolved_by": violation.resolved_by, + "resolution_note": violation.resolution_note, } except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) @@ -402,7 +432,7 @@ async def add_comment( violation_id: int, request: AddCommentRequest, db: Session = Depends(get_db), - current_user: User = Depends(get_current_admin_api) + current_user: User = Depends(get_current_admin_api), ): """ Add comment to violation @@ -414,15 +444,15 @@ async def add_comment( db, violation_id=violation_id, user_id=current_user.id, - comment=request.comment + comment=request.comment, ) return { - 'id': comment.id, - 'violation_id': comment.violation_id, - 'user_id': comment.user_id, - 'comment': comment.comment, - 'created_at': comment.created_at.isoformat() + "id": comment.id, + "violation_id": comment.violation_id, + "user_id": comment.user_id, + "comment": comment.comment, + "created_at": comment.created_at.isoformat(), } except Exception as e: raise HTTPException(status_code=400, detail=str(e)) @@ -430,8 +460,7 @@ async def add_comment( @router.get("/stats", response_model=DashboardStatsResponse) async def get_dashboard_stats( - db: Session = Depends(get_db), - current_user: User = Depends(get_current_admin_api) + db: Session = Depends(get_db), current_user: User = Depends(get_current_admin_api) ): """ Get dashboard statistics diff --git a/app/api/v1/admin/content_pages.py b/app/api/v1/admin/content_pages.py index 28d6c1b8..aceb521c 100644 --- a/app/api/v1/admin/content_pages.py +++ b/app/api/v1/admin/content_pages.py @@ -10,6 +10,7 @@ Platform administrators can: import logging from typing import List, Optional + from fastapi import APIRouter, Depends, HTTPException, Query from pydantic import BaseModel, Field from sqlalchemy.orm import Session @@ -26,24 +27,43 @@ logger = logging.getLogger(__name__) # REQUEST/RESPONSE SCHEMAS # ============================================================================ + class ContentPageCreate(BaseModel): """Schema for creating a content page.""" - slug: str = Field(..., max_length=100, description="URL-safe identifier (about, faq, contact, etc.)") + + slug: str = Field( + ..., + max_length=100, + description="URL-safe identifier (about, faq, contact, etc.)", + ) title: str = Field(..., max_length=200, description="Page title") content: str = Field(..., description="HTML or Markdown content") - content_format: str = Field(default="html", description="Content format: html or markdown") - template: str = Field(default="default", max_length=50, description="Template name (default, minimal, modern)") - meta_description: Optional[str] = Field(None, max_length=300, description="SEO meta description") - meta_keywords: Optional[str] = Field(None, max_length=300, description="SEO keywords") + content_format: str = Field( + default="html", description="Content format: html or markdown" + ) + template: str = Field( + default="default", + max_length=50, + description="Template name (default, minimal, modern)", + ) + meta_description: Optional[str] = Field( + None, max_length=300, description="SEO meta description" + ) + meta_keywords: Optional[str] = Field( + None, max_length=300, description="SEO keywords" + ) is_published: bool = Field(default=False, description="Publish immediately") show_in_footer: bool = Field(default=True, description="Show in footer navigation") show_in_header: bool = Field(default=False, description="Show in header navigation") display_order: int = Field(default=0, description="Display order (lower = first)") - vendor_id: Optional[int] = Field(None, description="Vendor ID (None for platform default)") + vendor_id: Optional[int] = Field( + None, description="Vendor ID (None for platform default)" + ) class ContentPageUpdate(BaseModel): """Schema for updating a content page.""" + title: Optional[str] = Field(None, max_length=200) content: Optional[str] = None content_format: Optional[str] = None @@ -58,6 +78,7 @@ class ContentPageUpdate(BaseModel): class ContentPageResponse(BaseModel): """Schema for content page response.""" + id: int vendor_id: Optional[int] vendor_name: Optional[str] @@ -84,11 +105,12 @@ class ContentPageResponse(BaseModel): # PLATFORM DEFAULT PAGES (vendor_id=NULL) # ============================================================================ + @router.get("/platform", response_model=List[ContentPageResponse]) def list_platform_pages( include_unpublished: bool = Query(False, description="Include draft pages"), current_user: User = Depends(get_current_admin_api), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ List all platform default content pages. @@ -96,8 +118,7 @@ def list_platform_pages( These are used as fallbacks when vendors haven't created custom pages. """ pages = content_page_service.list_all_platform_pages( - db, - include_unpublished=include_unpublished + db, include_unpublished=include_unpublished ) return [page.to_dict() for page in pages] @@ -107,7 +128,7 @@ def list_platform_pages( def create_platform_page( page_data: ContentPageCreate, current_user: User = Depends(get_current_admin_api), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Create a new platform default content page. @@ -129,7 +150,7 @@ def create_platform_page( show_in_footer=page_data.show_in_footer, show_in_header=page_data.show_in_header, display_order=page_data.display_order, - created_by=current_user.id + created_by=current_user.id, ) return page.to_dict() @@ -139,12 +160,13 @@ def create_platform_page( # ALL CONTENT PAGES (Platform + Vendors) # ============================================================================ + @router.get("/", response_model=List[ContentPageResponse]) def list_all_pages( vendor_id: Optional[int] = Query(None, description="Filter by vendor ID"), include_unpublished: bool = Query(False, description="Include draft pages"), current_user: User = Depends(get_current_admin_api), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ List all content pages (platform defaults and vendor overrides). @@ -153,15 +175,14 @@ def list_all_pages( """ if vendor_id: pages = content_page_service.list_all_vendor_pages( - db, - vendor_id=vendor_id, - include_unpublished=include_unpublished + db, vendor_id=vendor_id, include_unpublished=include_unpublished ) else: # Get all pages (both platform and vendor) - from models.database.content_page import ContentPage from sqlalchemy import and_ + from models.database.content_page import ContentPage + filters = [] if not include_unpublished: filters.append(ContentPage.is_published == True) @@ -169,7 +190,9 @@ def list_all_pages( pages = ( db.query(ContentPage) .filter(and_(*filters) if filters else True) - .order_by(ContentPage.vendor_id, ContentPage.display_order, ContentPage.title) + .order_by( + ContentPage.vendor_id, ContentPage.display_order, ContentPage.title + ) .all() ) @@ -180,7 +203,7 @@ def list_all_pages( def get_page( page_id: int, current_user: User = Depends(get_current_admin_api), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Get a specific content page by ID.""" page = content_page_service.get_page_by_id(db, page_id) @@ -196,7 +219,7 @@ def update_page( page_id: int, page_data: ContentPageUpdate, current_user: User = Depends(get_current_admin_api), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Update a content page (platform or vendor).""" page = content_page_service.update_page( @@ -212,7 +235,7 @@ def update_page( show_in_footer=page_data.show_in_footer, show_in_header=page_data.show_in_header, display_order=page_data.display_order, - updated_by=current_user.id + updated_by=current_user.id, ) if not page: @@ -225,7 +248,7 @@ def update_page( def delete_page( page_id: int, current_user: User = Depends(get_current_admin_api), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Delete a content page.""" success = content_page_service.delete_page(db, page_id) diff --git a/app/api/v1/admin/dashboard.py b/app/api/v1/admin/dashboard.py index 38db202b..1ca77a50 100644 --- a/app/api/v1/admin/dashboard.py +++ b/app/api/v1/admin/dashboard.py @@ -5,6 +5,7 @@ Admin dashboard and statistics endpoints. import logging from typing import List + from fastapi import APIRouter, Depends from sqlalchemy.orm import Session @@ -87,4 +88,4 @@ def get_platform_statistics( "products": stats_service.get_product_statistics(db), "orders": stats_service.get_order_statistics(db), "imports": stats_service.get_import_statistics(db), - } \ No newline at end of file + } diff --git a/app/api/v1/admin/marketplace.py b/app/api/v1/admin/marketplace.py index 0fa5843b..b9456a53 100644 --- a/app/api/v1/admin/marketplace.py +++ b/app/api/v1/admin/marketplace.py @@ -13,8 +13,8 @@ from app.api.deps import get_current_admin_api from app.core.database import get_db from app.services.admin_service import admin_service from app.services.stats_service import stats_service -from models.schema.marketplace_import_job import MarketplaceImportJobResponse from models.database.user import User +from models.schema.marketplace_import_job import MarketplaceImportJobResponse router = APIRouter(prefix="/marketplace-import-jobs") logger = logging.getLogger(__name__) diff --git a/app/api/v1/admin/monitoring.py b/app/api/v1/admin/monitoring.py index 35f4f24b..0b5699d2 100644 --- a/app/api/v1/admin/monitoring.py +++ b/app/api/v1/admin/monitoring.py @@ -1 +1 @@ -# Platform monitoring and alerts +# Platform monitoring and alerts diff --git a/app/api/v1/admin/notifications.py b/app/api/v1/admin/notifications.py index 974c9b34..780885b7 100644 --- a/app/api/v1/admin/notifications.py +++ b/app/api/v1/admin/notifications.py @@ -16,16 +16,13 @@ from sqlalchemy.orm import Session from app.api.deps import get_current_admin_api from app.core.database import get_db -from models.schema.admin import ( - AdminNotificationCreate, - AdminNotificationResponse, - AdminNotificationListResponse, - PlatformAlertCreate, - PlatformAlertResponse, - PlatformAlertListResponse, - PlatformAlertResolve -) from models.database.user import User +from models.schema.admin import (AdminNotificationCreate, + AdminNotificationListResponse, + AdminNotificationResponse, + PlatformAlertCreate, + PlatformAlertListResponse, + PlatformAlertResolve, PlatformAlertResponse) router = APIRouter(prefix="/notifications") logger = logging.getLogger(__name__) @@ -35,6 +32,7 @@ logger = logging.getLogger(__name__) # ADMIN NOTIFICATIONS # ============================================================================ + @router.get("", response_model=AdminNotificationListResponse) def get_notifications( priority: Optional[str] = Query(None, description="Filter by priority"), @@ -47,11 +45,7 @@ def get_notifications( """Get admin notifications with filtering.""" # TODO: Implement notification service return AdminNotificationListResponse( - notifications=[], - total=0, - unread_count=0, - skip=skip, - limit=limit + notifications=[], total=0, unread_count=0, skip=skip, limit=limit ) @@ -90,10 +84,13 @@ def mark_all_as_read( # PLATFORM ALERTS # ============================================================================ + @router.get("/alerts", response_model=PlatformAlertListResponse) def get_platform_alerts( severity: Optional[str] = Query(None, description="Filter by severity"), - is_resolved: Optional[bool] = Query(None, description="Filter by resolution status"), + is_resolved: Optional[bool] = Query( + None, description="Filter by resolution status" + ), skip: int = Query(0, ge=0), limit: int = Query(50, ge=1, le=100), db: Session = Depends(get_db), @@ -102,12 +99,7 @@ def get_platform_alerts( """Get platform alerts with filtering.""" # TODO: Implement alert service return PlatformAlertListResponse( - alerts=[], - total=0, - active_count=0, - critical_count=0, - skip=skip, - limit=limit + alerts=[], total=0, active_count=0, critical_count=0, skip=skip, limit=limit ) @@ -147,5 +139,5 @@ def get_alert_statistics( "total_alerts": 0, "active_alerts": 0, "critical_alerts": 0, - "resolved_today": 0 + "resolved_today": 0, } diff --git a/app/api/v1/admin/settings.py b/app/api/v1/admin/settings.py index 0fd2297f..025b2a3c 100644 --- a/app/api/v1/admin/settings.py +++ b/app/api/v1/admin/settings.py @@ -16,15 +16,11 @@ from sqlalchemy.orm import Session from app.api.deps import get_current_admin_api from app.core.database import get_db -from app.services.admin_settings_service import admin_settings_service from app.services.admin_audit_service import admin_audit_service -from models.schema.admin import ( - AdminSettingCreate, - AdminSettingResponse, - AdminSettingUpdate, - AdminSettingListResponse -) +from app.services.admin_settings_service import admin_settings_service from models.database.user import User +from models.schema.admin import (AdminSettingCreate, AdminSettingListResponse, + AdminSettingResponse, AdminSettingUpdate) router = APIRouter(prefix="/settings") logger = logging.getLogger(__name__) @@ -32,10 +28,10 @@ logger = logging.getLogger(__name__) @router.get("", response_model=AdminSettingListResponse) def get_all_settings( - category: Optional[str] = Query(None, description="Filter by category"), - is_public: Optional[bool] = Query(None, description="Filter by public flag"), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api), + category: Optional[str] = Query(None, description="Filter by category"), + is_public: Optional[bool] = Query(None, description="Filter by public flag"), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """ Get all platform settings. @@ -46,16 +42,14 @@ def get_all_settings( settings = admin_settings_service.get_all_settings(db, category, is_public) return AdminSettingListResponse( - settings=settings, - total=len(settings), - category=category + settings=settings, total=len(settings), category=category ) @router.get("/categories") def get_setting_categories( - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """Get list of all setting categories.""" # This could be enhanced to return counts per category @@ -66,22 +60,23 @@ def get_setting_categories( "marketplace", "notifications", "integrations", - "payments" + "payments", ] } @router.get("/{key}", response_model=AdminSettingResponse) def get_setting( - key: str, - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api), + key: str, + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """Get specific setting by key.""" setting = admin_settings_service.get_setting_by_key(db, key) if not setting: from fastapi import HTTPException + raise HTTPException(status_code=404, detail=f"Setting '{key}' not found") return AdminSettingResponse.model_validate(setting) @@ -89,9 +84,9 @@ def get_setting( @router.post("", response_model=AdminSettingResponse) def create_setting( - setting_data: AdminSettingCreate, - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api), + setting_data: AdminSettingCreate, + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """ Create new platform setting. @@ -99,9 +94,7 @@ def create_setting( Setting keys should be lowercase with underscores (e.g., max_vendors_allowed). """ result = admin_settings_service.create_setting( - db=db, - setting_data=setting_data, - admin_user_id=current_admin.id + db=db, setting_data=setting_data, admin_user_id=current_admin.id ) # Log action @@ -111,7 +104,10 @@ def create_setting( action="create_setting", target_type="setting", target_id=setting_data.key, - details={"category": setting_data.category, "value_type": setting_data.value_type} + details={ + "category": setting_data.category, + "value_type": setting_data.value_type, + }, ) return result @@ -119,19 +115,16 @@ def create_setting( @router.put("/{key}", response_model=AdminSettingResponse) def update_setting( - key: str, - update_data: AdminSettingUpdate, - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api), + key: str, + update_data: AdminSettingUpdate, + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """Update existing setting value.""" old_value = admin_settings_service.get_setting_value(db, key) result = admin_settings_service.update_setting( - db=db, - key=key, - update_data=update_data, - admin_user_id=current_admin.id + db=db, key=key, update_data=update_data, admin_user_id=current_admin.id ) # Log action @@ -141,7 +134,7 @@ def update_setting( action="update_setting", target_type="setting", target_id=key, - details={"old_value": str(old_value), "new_value": update_data.value} + details={"old_value": str(old_value), "new_value": update_data.value}, ) return result @@ -149,9 +142,9 @@ def update_setting( @router.post("/upsert", response_model=AdminSettingResponse) def upsert_setting( - setting_data: AdminSettingCreate, - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api), + setting_data: AdminSettingCreate, + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """ Create or update setting (upsert). @@ -159,9 +152,7 @@ def upsert_setting( If setting exists, updates its value. If not, creates new setting. """ result = admin_settings_service.upsert_setting( - db=db, - setting_data=setting_data, - admin_user_id=current_admin.id + db=db, setting_data=setting_data, admin_user_id=current_admin.id ) # Log action @@ -171,7 +162,7 @@ def upsert_setting( action="upsert_setting", target_type="setting", target_id=setting_data.key, - details={"category": setting_data.category} + details={"category": setting_data.category}, ) return result @@ -179,10 +170,10 @@ def upsert_setting( @router.delete("/{key}") def delete_setting( - key: str, - confirm: bool = Query(False, description="Must be true to confirm deletion"), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api), + key: str, + confirm: bool = Query(False, description="Must be true to confirm deletion"), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """ Delete platform setting. @@ -195,13 +186,11 @@ def delete_setting( if not confirm: raise HTTPException( status_code=400, - detail="Deletion requires confirmation parameter: confirm=true" + detail="Deletion requires confirmation parameter: confirm=true", ) message = admin_settings_service.delete_setting( - db=db, - key=key, - admin_user_id=current_admin.id + db=db, key=key, admin_user_id=current_admin.id ) # Log action @@ -211,7 +200,7 @@ def delete_setting( action="delete_setting", target_type="setting", target_id=key, - details={} + details={}, ) return {"message": message} diff --git a/app/api/v1/admin/users.py b/app/api/v1/admin/users.py index 5c4d968c..2cadd850 100644 --- a/app/api/v1/admin/users.py +++ b/app/api/v1/admin/users.py @@ -13,8 +13,8 @@ from app.api.deps import get_current_admin_api from app.core.database import get_db from app.services.admin_service import admin_service from app.services.stats_service import stats_service -from models.schema.auth import UserResponse from models.database.user import User +from models.schema.auth import UserResponse router = APIRouter(prefix="/users") logger = logging.getLogger(__name__) diff --git a/app/api/v1/admin/vendor_domains.py b/app/api/v1/admin/vendor_domains.py index a01604f3..c33e1786 100644 --- a/app/api/v1/admin/vendor_domains.py +++ b/app/api/v1/admin/vendor_domains.py @@ -12,24 +12,22 @@ Follows the architecture pattern: import logging from typing import List -from fastapi import APIRouter, Depends, Path, Body, Query +from fastapi import APIRouter, Body, Depends, Path, Query from sqlalchemy.orm import Session from app.api.deps import get_current_admin_api from app.core.database import get_db -from app.services.vendor_domain_service import vendor_domain_service from app.exceptions import VendorNotFoundException -from models.schema.vendor_domain import ( - VendorDomainCreate, - VendorDomainUpdate, - VendorDomainResponse, - VendorDomainListResponse, - DomainVerificationInstructions, - DomainVerificationResponse, - DomainDeletionResponse, -) +from app.services.vendor_domain_service import vendor_domain_service from models.database.user import User from models.database.vendor import Vendor +from models.schema.vendor_domain import (DomainDeletionResponse, + DomainVerificationInstructions, + DomainVerificationResponse, + VendorDomainCreate, + VendorDomainListResponse, + VendorDomainResponse, + VendorDomainUpdate) router = APIRouter(prefix="/vendors") logger = logging.getLogger(__name__) @@ -57,10 +55,10 @@ def _get_vendor_by_id(db: Session, vendor_id: int) -> Vendor: @router.post("/{vendor_id}/domains", response_model=VendorDomainResponse) def add_vendor_domain( - vendor_id: int = Path(..., description="Vendor ID", gt=0), - domain_data: VendorDomainCreate = Body(...), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api), + vendor_id: int = Path(..., description="Vendor ID", gt=0), + domain_data: VendorDomainCreate = Body(...), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """ Add a custom domain to vendor (Admin only). @@ -88,9 +86,7 @@ def add_vendor_domain( - 422: Invalid domain format or reserved subdomain """ domain = vendor_domain_service.add_domain( - db=db, - vendor_id=vendor_id, - domain_data=domain_data + db=db, vendor_id=vendor_id, domain_data=domain_data ) return VendorDomainResponse( @@ -111,9 +107,9 @@ def add_vendor_domain( @router.get("/{vendor_id}/domains", response_model=VendorDomainListResponse) def list_vendor_domains( - vendor_id: int = Path(..., description="Vendor ID", gt=0), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api), + vendor_id: int = Path(..., description="Vendor ID", gt=0), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """ List all domains for a vendor (Admin only). @@ -148,15 +144,15 @@ def list_vendor_domains( ) for d in domains ], - total=len(domains) + total=len(domains), ) @router.get("/domains/{domain_id}", response_model=VendorDomainResponse) def get_domain_details( - domain_id: int = Path(..., description="Domain ID", gt=0), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api), + domain_id: int = Path(..., description="Domain ID", gt=0), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """ Get detailed information about a specific domain (Admin only). @@ -174,7 +170,9 @@ def get_domain_details( is_active=domain.is_active, is_verified=domain.is_verified, ssl_status=domain.ssl_status, - verification_token=domain.verification_token if not domain.is_verified else None, + verification_token=( + domain.verification_token if not domain.is_verified else None + ), verified_at=domain.verified_at, ssl_verified_at=domain.ssl_verified_at, created_at=domain.created_at, @@ -184,10 +182,10 @@ def get_domain_details( @router.put("/domains/{domain_id}", response_model=VendorDomainResponse) def update_vendor_domain( - domain_id: int = Path(..., description="Domain ID", gt=0), - domain_update: VendorDomainUpdate = Body(...), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api), + domain_id: int = Path(..., description="Domain ID", gt=0), + domain_update: VendorDomainUpdate = Body(...), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """ Update domain settings (Admin only). @@ -206,9 +204,7 @@ def update_vendor_domain( - 400: Cannot activate unverified domain """ domain = vendor_domain_service.update_domain( - db=db, - domain_id=domain_id, - domain_update=domain_update + db=db, domain_id=domain_id, domain_update=domain_update ) return VendorDomainResponse( @@ -229,9 +225,9 @@ def update_vendor_domain( @router.delete("/domains/{domain_id}", response_model=DomainDeletionResponse) def delete_vendor_domain( - domain_id: int = Path(..., description="Domain ID", gt=0), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api), + domain_id: int = Path(..., description="Domain ID", gt=0), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """ Delete a custom domain (Admin only). @@ -250,17 +246,15 @@ def delete_vendor_domain( message = vendor_domain_service.delete_domain(db, domain_id) return DomainDeletionResponse( - message=message, - domain=domain_name, - vendor_id=vendor_id + message=message, domain=domain_name, vendor_id=vendor_id ) @router.post("/domains/{domain_id}/verify", response_model=DomainVerificationResponse) def verify_domain_ownership( - domain_id: int = Path(..., description="Domain ID", gt=0), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api), + domain_id: int = Path(..., description="Domain ID", gt=0), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """ Verify domain ownership via DNS TXT record (Admin only). @@ -290,15 +284,18 @@ def verify_domain_ownership( message=message, domain=domain.domain, verified_at=domain.verified_at, - is_verified=domain.is_verified + is_verified=domain.is_verified, ) -@router.get("/domains/{domain_id}/verification-instructions", response_model=DomainVerificationInstructions) +@router.get( + "/domains/{domain_id}/verification-instructions", + response_model=DomainVerificationInstructions, +) def get_domain_verification_instructions( - domain_id: int = Path(..., description="Domain ID", gt=0), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api), + domain_id: int = Path(..., description="Domain ID", gt=0), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """ Get DNS verification instructions for domain (Admin only). @@ -324,5 +321,5 @@ def get_domain_verification_instructions( verification_token=instructions["verification_token"], instructions=instructions["instructions"], txt_record=instructions["txt_record"], - common_registrars=instructions["common_registrars"] + common_registrars=instructions["common_registrars"], ) diff --git a/app/api/v1/admin/vendor_themes.py b/app/api/v1/admin/vendor_themes.py index 7a4b51fc..3a0fa06e 100644 --- a/app/api/v1/admin/vendor_themes.py +++ b/app/api/v1/admin/vendor_themes.py @@ -20,11 +20,8 @@ from sqlalchemy.orm import Session from app.api.deps import get_current_admin_api, get_db from app.services.vendor_theme_service import vendor_theme_service from models.database.user import User -from models.schema.vendor_theme import ( - VendorThemeResponse, - VendorThemeUpdate, - ThemePresetListResponse -) +from models.schema.vendor_theme import (ThemePresetListResponse, + VendorThemeResponse, VendorThemeUpdate) router = APIRouter(prefix="/vendor-themes") logger = logging.getLogger(__name__) @@ -34,10 +31,9 @@ logger = logging.getLogger(__name__) # PRESET ENDPOINTS # ============================================================================ + @router.get("/presets", response_model=ThemePresetListResponse) -async def get_theme_presets( - current_admin: User = Depends(get_current_admin_api) -): +async def get_theme_presets(current_admin: User = Depends(get_current_admin_api)): """ Get all available theme presets with preview information. @@ -59,11 +55,12 @@ async def get_theme_presets( # THEME RETRIEVAL # ============================================================================ + @router.get("/{vendor_code}", response_model=VendorThemeResponse) async def get_vendor_theme( - vendor_code: str = Path(..., description="Vendor code"), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api) + vendor_code: str = Path(..., description="Vendor code"), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """ Get theme configuration for a vendor. @@ -93,12 +90,13 @@ async def get_vendor_theme( # THEME UPDATE # ============================================================================ + @router.put("/{vendor_code}", response_model=VendorThemeResponse) async def update_vendor_theme( - vendor_code: str = Path(..., description="Vendor code"), - theme_data: VendorThemeUpdate = None, - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api) + vendor_code: str = Path(..., description="Vendor code"), + theme_data: VendorThemeUpdate = None, + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """ Update or create theme for a vendor. @@ -140,12 +138,13 @@ async def update_vendor_theme( # PRESET APPLICATION # ============================================================================ + @router.post("/{vendor_code}/preset/{preset_name}") async def apply_theme_preset( - vendor_code: str = Path(..., description="Vendor code"), - preset_name: str = Path(..., description="Preset name"), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api) + vendor_code: str = Path(..., description="Vendor code"), + preset_name: str = Path(..., description="Preset name"), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """ Apply a theme preset to a vendor. @@ -184,7 +183,7 @@ async def apply_theme_preset( return { "message": f"Applied {preset_name} preset successfully", - "theme": theme.to_dict() + "theme": theme.to_dict(), } @@ -192,11 +191,12 @@ async def apply_theme_preset( # THEME DELETION # ============================================================================ + @router.delete("/{vendor_code}") async def delete_vendor_theme( - vendor_code: str = Path(..., description="Vendor code"), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api) + vendor_code: str = Path(..., description="Vendor code"), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """ Delete custom theme for a vendor. diff --git a/app/api/v1/admin/vendors.py b/app/api/v1/admin/vendors.py index e2db6b88..58b1f615 100644 --- a/app/api/v1/admin/vendors.py +++ b/app/api/v1/admin/vendors.py @@ -6,28 +6,24 @@ Vendor management endpoints for admin. import logging from typing import Optional -from fastapi import APIRouter, Depends, Query, Path, Body +from fastapi import APIRouter, Body, Depends, Path, Query +from sqlalchemy import func from sqlalchemy.orm import Session from app.api.deps import get_current_admin_api from app.core.database import get_db +from app.exceptions import (ConfirmationRequiredException, + VendorNotFoundException) from app.services.admin_service import admin_service from app.services.stats_service import stats_service -from app.exceptions import VendorNotFoundException, ConfirmationRequiredException -from models.schema.stats import VendorStatsResponse -from models.schema.vendor import ( - VendorListResponse, - VendorResponse, - VendorDetailResponse, - VendorCreate, - VendorCreateResponse, - VendorUpdate, - VendorTransferOwnership, - VendorTransferOwnershipResponse, -) from models.database.user import User from models.database.vendor import Vendor -from sqlalchemy import func +from models.schema.stats import VendorStatsResponse +from models.schema.vendor import (VendorCreate, VendorCreateResponse, + VendorDetailResponse, VendorListResponse, + VendorResponse, VendorTransferOwnership, + VendorTransferOwnershipResponse, + VendorUpdate) router = APIRouter(prefix="/vendors") logger = logging.getLogger(__name__) @@ -60,9 +56,11 @@ def _get_vendor_by_identifier(db: Session, identifier: str) -> Vendor: pass # Try as vendor_code (case-insensitive) - vendor = db.query(Vendor).filter( - func.upper(Vendor.vendor_code) == identifier.upper() - ).first() + vendor = ( + db.query(Vendor) + .filter(func.upper(Vendor.vendor_code) == identifier.upper()) + .first() + ) if not vendor: raise VendorNotFoundException(identifier, identifier_type="code") @@ -72,9 +70,9 @@ def _get_vendor_by_identifier(db: Session, identifier: str) -> Vendor: @router.post("", response_model=VendorCreateResponse) def create_vendor_with_owner( - vendor_data: VendorCreate, - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api), + vendor_data: VendorCreate, + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """ Create a new vendor with owner user account (Admin only). @@ -93,8 +91,7 @@ def create_vendor_with_owner( Returns vendor details with owner credentials. """ vendor, owner_user, temp_password = admin_service.create_vendor_with_owner( - db=db, - vendor_data=vendor_data + db=db, vendor_data=vendor_data ) return VendorCreateResponse( @@ -121,19 +118,19 @@ def create_vendor_with_owner( owner_email=owner_user.email, owner_username=owner_user.username, temporary_password=temp_password, - login_url=f"http://localhost:8000/vendor/{vendor.subdomain}/login" + login_url=f"http://localhost:8000/vendor/{vendor.subdomain}/login", ) @router.get("", response_model=VendorListResponse) def get_all_vendors_admin( - skip: int = Query(0, ge=0), - limit: int = Query(100, ge=1, le=1000), - search: Optional[str] = Query(None, description="Search by name or vendor code"), - is_active: Optional[bool] = Query(None), - is_verified: Optional[bool] = Query(None), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api), + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=1000), + search: Optional[str] = Query(None, description="Search by name or vendor code"), + is_active: Optional[bool] = Query(None), + is_verified: Optional[bool] = Query(None), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """Get all vendors with filtering (Admin only).""" vendors, total = admin_service.get_all_vendors( @@ -142,15 +139,15 @@ def get_all_vendors_admin( limit=limit, search=search, is_active=is_active, - is_verified=is_verified + is_verified=is_verified, ) return VendorListResponse(vendors=vendors, total=total, skip=skip, limit=limit) @router.get("/stats", response_model=VendorStatsResponse) def get_vendor_statistics_endpoint( - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """Get vendor statistics for admin dashboard (Admin only).""" stats = stats_service.get_vendor_statistics(db) @@ -165,9 +162,9 @@ def get_vendor_statistics_endpoint( @router.get("/{vendor_identifier}", response_model=VendorDetailResponse) def get_vendor_details( - vendor_identifier: str = Path(..., description="Vendor ID or vendor_code"), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api), + vendor_identifier: str = Path(..., description="Vendor ID or vendor_code"), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """ Get detailed vendor information including owner details (Admin only). @@ -208,10 +205,10 @@ def get_vendor_details( @router.put("/{vendor_identifier}", response_model=VendorDetailResponse) def update_vendor( - vendor_identifier: str = Path(..., description="Vendor ID or vendor_code"), - vendor_update: VendorUpdate = Body(...), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api), + vendor_identifier: str = Path(..., description="Vendor ID or vendor_code"), + vendor_update: VendorUpdate = Body(...), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """ Update vendor information (Admin only). @@ -257,12 +254,15 @@ def update_vendor( ) -@router.post("/{vendor_identifier}/transfer-ownership", response_model=VendorTransferOwnershipResponse) +@router.post( + "/{vendor_identifier}/transfer-ownership", + response_model=VendorTransferOwnershipResponse, +) def transfer_vendor_ownership( - vendor_identifier: str = Path(..., description="Vendor ID or vendor_code"), - transfer_data: VendorTransferOwnership = Body(...), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api), + vendor_identifier: str = Path(..., description="Vendor ID or vendor_code"), + transfer_data: VendorTransferOwnership = Body(...), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """ Transfer vendor ownership to another user (Admin only). @@ -311,10 +311,10 @@ def transfer_vendor_ownership( @router.put("/{vendor_identifier}/verification", response_model=VendorDetailResponse) def toggle_vendor_verification( - vendor_identifier: str = Path(..., description="Vendor ID or vendor_code"), - verification_data: dict = Body(..., example={"is_verified": True}), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api), + vendor_identifier: str = Path(..., description="Vendor ID or vendor_code"), + verification_data: dict = Body(..., example={"is_verified": True}), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """ Toggle vendor verification status (Admin only). @@ -362,10 +362,10 @@ def toggle_vendor_verification( @router.put("/{vendor_identifier}/status", response_model=VendorDetailResponse) def toggle_vendor_status( - vendor_identifier: str = Path(..., description="Vendor ID or vendor_code"), - status_data: dict = Body(..., example={"is_active": True}), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api), + vendor_identifier: str = Path(..., description="Vendor ID or vendor_code"), + status_data: dict = Body(..., example={"is_active": True}), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """ Toggle vendor active status (Admin only). @@ -413,10 +413,10 @@ def toggle_vendor_status( @router.delete("/{vendor_identifier}") def delete_vendor( - vendor_identifier: str = Path(..., description="Vendor ID or vendor_code"), - confirm: bool = Query(False, description="Must be true to confirm deletion"), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_api), + vendor_identifier: str = Path(..., description="Vendor ID or vendor_code"), + confirm: bool = Query(False, description="Must be true to confirm deletion"), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_api), ): """ Delete vendor and all associated data (Admin only). @@ -436,7 +436,7 @@ def delete_vendor( if not confirm: raise ConfirmationRequiredException( operation="delete_vendor", - message="Deletion requires confirmation parameter: confirm=true" + message="Deletion requires confirmation parameter: confirm=true", ) vendor = _get_vendor_by_identifier(db, vendor_identifier) diff --git a/app/api/v1/shared/health.py b/app/api/v1/shared/health.py index 069c2bac..0b438ce2 100644 --- a/app/api/v1/shared/health.py +++ b/app/api/v1/shared/health.py @@ -1 +1 @@ -# Health checks +# Health checks diff --git a/app/api/v1/shared/uploads.py b/app/api/v1/shared/uploads.py index 9a4e1aa0..e573f17b 100644 --- a/app/api/v1/shared/uploads.py +++ b/app/api/v1/shared/uploads.py @@ -1 +1 @@ -# File upload handling +# File upload handling diff --git a/app/api/v1/shop/__init__.py b/app/api/v1/shop/__init__.py index f0e93ed4..23d483bd 100644 --- a/app/api/v1/shop/__init__.py +++ b/app/api/v1/shop/__init__.py @@ -21,7 +21,7 @@ Authentication: from fastapi import APIRouter # Import shop routers -from . import products, cart, orders, auth, content_pages +from . import auth, cart, content_pages, orders, products # Create shop router router = APIRouter() @@ -43,6 +43,8 @@ router.include_router(cart.router, tags=["shop-cart"]) router.include_router(orders.router, tags=["shop-orders"]) # Content pages (public) -router.include_router(content_pages.router, prefix="/content-pages", tags=["shop-content-pages"]) +router.include_router( + content_pages.router, prefix="/content-pages", tags=["shop-content-pages"] +) __all__ = ["router"] diff --git a/app/api/v1/shop/auth.py b/app/api/v1/shop/auth.py index f9d50b43..0c3365d8 100644 --- a/app/api/v1/shop/auth.py +++ b/app/api/v1/shop/auth.py @@ -15,15 +15,16 @@ This prevents: """ import logging -from fastapi import APIRouter, Depends, Response, Request, HTTPException + +from fastapi import APIRouter, Depends, HTTPException, Request, Response +from pydantic import BaseModel from sqlalchemy.orm import Session from app.core.database import get_db +from app.core.environment import should_use_secure_cookies from app.services.customer_service import customer_service from models.schema.auth import UserLogin from models.schema.customer import CustomerRegister, CustomerResponse -from app.core.environment import should_use_secure_cookies -from pydantic import BaseModel router = APIRouter() logger = logging.getLogger(__name__) @@ -32,6 +33,7 @@ logger = logging.getLogger(__name__) # Response model for customer login class CustomerLoginResponse(BaseModel): """Customer login response with token and customer data.""" + access_token: str token_type: str expires_in: int @@ -40,9 +42,7 @@ class CustomerLoginResponse(BaseModel): @router.post("/auth/register", response_model=CustomerResponse) def register_customer( - request: Request, - customer_data: CustomerRegister, - db: Session = Depends(get_db) + request: Request, customer_data: CustomerRegister, db: Session = Depends(get_db) ): """ Register a new customer for current vendor. @@ -59,12 +59,12 @@ def register_customer( - phone: Customer phone number (optional) """ # Get vendor from middleware - vendor = getattr(request.state, 'vendor', None) + vendor = getattr(request.state, "vendor", None) if not vendor: raise HTTPException( status_code=404, - detail="Vendor not found. Please access via vendor domain/subdomain/path." + detail="Vendor not found. Please access via vendor domain/subdomain/path.", ) logger.debug( @@ -73,14 +73,12 @@ def register_customer( "vendor_id": vendor.id, "vendor_code": vendor.subdomain, "email": customer_data.email, - } + }, ) # Create customer account customer = customer_service.register_customer( - db=db, - vendor_id=vendor.id, - customer_data=customer_data + db=db, vendor_id=vendor.id, customer_data=customer_data ) logger.info( @@ -89,7 +87,7 @@ def register_customer( "customer_id": customer.id, "vendor_id": vendor.id, "email": customer.email, - } + }, ) return CustomerResponse.model_validate(customer) @@ -100,7 +98,7 @@ def customer_login( request: Request, user_credentials: UserLogin, response: Response, - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Customer login for current vendor. @@ -121,12 +119,12 @@ def customer_login( - password: Customer password """ # Get vendor from middleware - vendor = getattr(request.state, 'vendor', None) + vendor = getattr(request.state, "vendor", None) if not vendor: raise HTTPException( status_code=404, - detail="Vendor not found. Please access via vendor domain/subdomain/path." + detail="Vendor not found. Please access via vendor domain/subdomain/path.", ) logger.debug( @@ -135,33 +133,39 @@ def customer_login( "vendor_id": vendor.id, "vendor_code": vendor.subdomain, "email_or_username": user_credentials.email_or_username, - } + }, ) # Authenticate customer login_result = customer_service.login_customer( - db=db, - vendor_id=vendor.id, - credentials=user_credentials + db=db, vendor_id=vendor.id, credentials=user_credentials ) logger.info( f"Customer login successful: {login_result['customer'].email} for vendor {vendor.subdomain}", extra={ - "customer_id": login_result['customer'].id, + "customer_id": login_result["customer"].id, "vendor_id": vendor.id, - "email": login_result['customer'].email, - } + "email": login_result["customer"].email, + }, ) # Calculate cookie path based on vendor access method - vendor_context = getattr(request.state, 'vendor_context', None) - access_method = vendor_context.get('detection_method', 'unknown') if vendor_context else 'unknown' + vendor_context = getattr(request.state, "vendor_context", None) + access_method = ( + vendor_context.get("detection_method", "unknown") + if vendor_context + else "unknown" + ) cookie_path = "/shop" # Default for domain/subdomain access if access_method == "path": # For path-based access like /vendors/wizamart/shop - full_prefix = vendor_context.get('full_prefix', '/vendor/') if vendor_context else '/vendor/' + full_prefix = ( + vendor_context.get("full_prefix", "/vendor/") + if vendor_context + else "/vendor/" + ) cookie_path = f"{full_prefix}{vendor.subdomain}/shop" # Set HTTP-only cookie for browser navigation @@ -180,10 +184,10 @@ def customer_login( f"Set customer_token cookie with {login_result['token_data']['expires_in']}s expiry " f"(path={cookie_path}, httponly=True, secure={should_use_secure_cookies()})", extra={ - "expires_in": login_result['token_data']['expires_in'], + "expires_in": login_result["token_data"]["expires_in"], "secure": should_use_secure_cookies(), "cookie_path": cookie_path, - } + }, ) # Return full login response @@ -196,10 +200,7 @@ def customer_login( @router.post("/auth/logout") -def customer_logout( - request: Request, - response: Response -): +def customer_logout(request: Request, response: Response): """ Customer logout for current vendor. @@ -208,24 +209,32 @@ def customer_logout( Client should also remove token from localStorage. """ # Get vendor from middleware (for logging) - vendor = getattr(request.state, 'vendor', None) + vendor = getattr(request.state, "vendor", None) logger.info( f"Customer logout for vendor {vendor.subdomain if vendor else 'unknown'}", extra={ "vendor_id": vendor.id if vendor else None, "vendor_code": vendor.subdomain if vendor else None, - } + }, ) # Calculate cookie path based on vendor access method (must match login) - vendor_context = getattr(request.state, 'vendor_context', None) - access_method = vendor_context.get('detection_method', 'unknown') if vendor_context else 'unknown' + vendor_context = getattr(request.state, "vendor_context", None) + access_method = ( + vendor_context.get("detection_method", "unknown") + if vendor_context + else "unknown" + ) cookie_path = "/shop" # Default for domain/subdomain access if access_method == "path" and vendor: # For path-based access like /vendors/wizamart/shop - full_prefix = vendor_context.get('full_prefix', '/vendor/') if vendor_context else '/vendor/' + full_prefix = ( + vendor_context.get("full_prefix", "/vendor/") + if vendor_context + else "/vendor/" + ) cookie_path = f"{full_prefix}{vendor.subdomain}/shop" # Clear the cookie (must match path used when setting) @@ -240,11 +249,7 @@ def customer_logout( @router.post("/auth/forgot-password") -def forgot_password( - request: Request, - email: str, - db: Session = Depends(get_db) -): +def forgot_password(request: Request, email: str, db: Session = Depends(get_db)): """ Request password reset for customer. @@ -255,12 +260,12 @@ def forgot_password( - email: Customer email address """ # Get vendor from middleware - vendor = getattr(request.state, 'vendor', None) + vendor = getattr(request.state, "vendor", None) if not vendor: raise HTTPException( status_code=404, - detail="Vendor not found. Please access via vendor domain/subdomain/path." + detail="Vendor not found. Please access via vendor domain/subdomain/path.", ) logger.debug( @@ -269,7 +274,7 @@ def forgot_password( "vendor_id": vendor.id, "vendor_code": vendor.subdomain, "email": email, - } + }, ) # TODO: Implement password reset functionality @@ -278,9 +283,7 @@ def forgot_password( # - Send reset email to customer # - Return success message (don't reveal if email exists) - logger.info( - f"Password reset requested for {email} (vendor: {vendor.subdomain})" - ) + logger.info(f"Password reset requested for {email} (vendor: {vendor.subdomain})") return { "message": "If an account exists with this email, a password reset link has been sent." @@ -289,10 +292,7 @@ def forgot_password( @router.post("/auth/reset-password") def reset_password( - request: Request, - reset_token: str, - new_password: str, - db: Session = Depends(get_db) + request: Request, reset_token: str, new_password: str, db: Session = Depends(get_db) ): """ Reset customer password using reset token. @@ -304,12 +304,12 @@ def reset_password( - new_password: New password """ # Get vendor from middleware - vendor = getattr(request.state, 'vendor', None) + vendor = getattr(request.state, "vendor", None) if not vendor: raise HTTPException( status_code=404, - detail="Vendor not found. Please access via vendor domain/subdomain/path." + detail="Vendor not found. Please access via vendor domain/subdomain/path.", ) logger.debug( @@ -317,7 +317,7 @@ def reset_password( extra={ "vendor_id": vendor.id, "vendor_code": vendor.subdomain, - } + }, ) # TODO: Implement password reset @@ -327,9 +327,7 @@ def reset_password( # - Invalidate reset token # - Return success - logger.info( - f"Password reset completed (vendor: {vendor.subdomain})" - ) + logger.info(f"Password reset completed (vendor: {vendor.subdomain})") return { "message": "Password reset successfully. You can now log in with your new password." diff --git a/app/api/v1/shop/cart.py b/app/api/v1/shop/cart.py index 35520670..ac51667c 100644 --- a/app/api/v1/shop/cart.py +++ b/app/api/v1/shop/cart.py @@ -8,18 +8,15 @@ No authentication required - uses session ID for cart tracking. """ import logging -from fastapi import APIRouter, Depends, Path, Body, Request, HTTPException + +from fastapi import APIRouter, Body, Depends, HTTPException, Path, Request from sqlalchemy.orm import Session from app.core.database import get_db from app.services.cart_service import cart_service -from models.schema.cart import ( - AddToCartRequest, - UpdateCartItemRequest, - CartResponse, - CartOperationResponse, - ClearCartResponse, -) +from models.schema.cart import (AddToCartRequest, CartOperationResponse, + CartResponse, ClearCartResponse, + UpdateCartItemRequest) router = APIRouter() logger = logging.getLogger(__name__) @@ -29,6 +26,7 @@ logger = logging.getLogger(__name__) # CART ENDPOINTS # ============================================================================ + @router.get("/cart/{session_id}", response_model=CartResponse) def get_cart( request: Request, @@ -45,12 +43,12 @@ def get_cart( - session_id: Unique session identifier for the cart """ # Get vendor from middleware - vendor = getattr(request.state, 'vendor', None) + vendor = getattr(request.state, "vendor", None) if not vendor: raise HTTPException( status_code=404, - detail="Vendor not found. Please access via vendor domain/subdomain/path." + detail="Vendor not found. Please access via vendor domain/subdomain/path.", ) logger.info( @@ -59,23 +57,19 @@ def get_cart( "vendor_id": vendor.id, "vendor_code": vendor.subdomain, "session_id": session_id, - } + }, ) - cart = cart_service.get_cart( - db=db, - vendor_id=vendor.id, - session_id=session_id - ) + cart = cart_service.get_cart(db=db, vendor_id=vendor.id, session_id=session_id) logger.info( f"[SHOP_API] get_cart result: {len(cart.get('items', []))} items in cart", extra={ "session_id": session_id, "vendor_id": vendor.id, - "item_count": len(cart.get('items', [])), - "total": cart.get('total', 0), - } + "item_count": len(cart.get("items", [])), + "total": cart.get("total", 0), + }, ) return CartResponse.from_service_dict(cart) @@ -102,12 +96,12 @@ def add_to_cart( - quantity: Quantity to add (default: 1) """ # Get vendor from middleware - vendor = getattr(request.state, 'vendor', None) + vendor = getattr(request.state, "vendor", None) if not vendor: raise HTTPException( status_code=404, - detail="Vendor not found. Please access via vendor domain/subdomain/path." + detail="Vendor not found. Please access via vendor domain/subdomain/path.", ) logger.info( @@ -118,7 +112,7 @@ def add_to_cart( "session_id": session_id, "product_id": cart_data.product_id, "quantity": cart_data.quantity, - } + }, ) result = cart_service.add_to_cart( @@ -126,7 +120,7 @@ def add_to_cart( vendor_id=vendor.id, session_id=session_id, product_id=cart_data.product_id, - quantity=cart_data.quantity + quantity=cart_data.quantity, ) logger.info( @@ -134,13 +128,15 @@ def add_to_cart( extra={ "session_id": session_id, "result": result, - } + }, ) return CartOperationResponse(**result) -@router.put("/cart/{session_id}/items/{product_id}", response_model=CartOperationResponse) +@router.put( + "/cart/{session_id}/items/{product_id}", response_model=CartOperationResponse +) def update_cart_item( request: Request, session_id: str = Path(..., description="Shopping session ID"), @@ -162,12 +158,12 @@ def update_cart_item( - quantity: New quantity (must be >= 1) """ # Get vendor from middleware - vendor = getattr(request.state, 'vendor', None) + vendor = getattr(request.state, "vendor", None) if not vendor: raise HTTPException( status_code=404, - detail="Vendor not found. Please access via vendor domain/subdomain/path." + detail="Vendor not found. Please access via vendor domain/subdomain/path.", ) logger.debug( @@ -178,7 +174,7 @@ def update_cart_item( "session_id": session_id, "product_id": product_id, "quantity": cart_data.quantity, - } + }, ) result = cart_service.update_cart_item( @@ -186,13 +182,15 @@ def update_cart_item( vendor_id=vendor.id, session_id=session_id, product_id=product_id, - quantity=cart_data.quantity + quantity=cart_data.quantity, ) return CartOperationResponse(**result) -@router.delete("/cart/{session_id}/items/{product_id}", response_model=CartOperationResponse) +@router.delete( + "/cart/{session_id}/items/{product_id}", response_model=CartOperationResponse +) def remove_from_cart( request: Request, session_id: str = Path(..., description="Shopping session ID"), @@ -210,12 +208,12 @@ def remove_from_cart( - product_id: ID of product to remove """ # Get vendor from middleware - vendor = getattr(request.state, 'vendor', None) + vendor = getattr(request.state, "vendor", None) if not vendor: raise HTTPException( status_code=404, - detail="Vendor not found. Please access via vendor domain/subdomain/path." + detail="Vendor not found. Please access via vendor domain/subdomain/path.", ) logger.debug( @@ -225,14 +223,11 @@ def remove_from_cart( "vendor_code": vendor.subdomain, "session_id": session_id, "product_id": product_id, - } + }, ) result = cart_service.remove_from_cart( - db=db, - vendor_id=vendor.id, - session_id=session_id, - product_id=product_id + db=db, vendor_id=vendor.id, session_id=session_id, product_id=product_id ) return CartOperationResponse(**result) @@ -254,12 +249,12 @@ def clear_cart( - session_id: Unique session identifier for the cart """ # Get vendor from middleware - vendor = getattr(request.state, 'vendor', None) + vendor = getattr(request.state, "vendor", None) if not vendor: raise HTTPException( status_code=404, - detail="Vendor not found. Please access via vendor domain/subdomain/path." + detail="Vendor not found. Please access via vendor domain/subdomain/path.", ) logger.debug( @@ -268,13 +263,9 @@ def clear_cart( "vendor_id": vendor.id, "vendor_code": vendor.subdomain, "session_id": session_id, - } + }, ) - result = cart_service.clear_cart( - db=db, - vendor_id=vendor.id, - session_id=session_id - ) + result = cart_service.clear_cart(db=db, vendor_id=vendor.id, session_id=session_id) return ClearCartResponse(**result) diff --git a/app/api/v1/shop/content_pages.py b/app/api/v1/shop/content_pages.py index 1c228911..48e19d64 100644 --- a/app/api/v1/shop/content_pages.py +++ b/app/api/v1/shop/content_pages.py @@ -8,6 +8,7 @@ No authentication required. import logging from typing import List + from fastapi import APIRouter, Depends, HTTPException, Request from pydantic import BaseModel from sqlalchemy.orm import Session @@ -23,8 +24,10 @@ logger = logging.getLogger(__name__) # RESPONSE SCHEMAS # ============================================================================ + class PublicContentPageResponse(BaseModel): """Public content page response (no internal IDs).""" + slug: str title: str content: str @@ -36,6 +39,7 @@ class PublicContentPageResponse(BaseModel): class ContentPageListItem(BaseModel): """Content page list item for navigation.""" + slug: str title: str show_in_footer: bool @@ -47,25 +51,21 @@ class ContentPageListItem(BaseModel): # PUBLIC ENDPOINTS # ============================================================================ + @router.get("/navigation", response_model=List[ContentPageListItem]) -def get_navigation_pages( - request: Request, - db: Session = Depends(get_db) -): +def get_navigation_pages(request: Request, db: Session = Depends(get_db)): """ Get list of content pages for navigation (footer/header). Uses vendor from request.state (set by middleware). Returns vendor overrides + platform defaults. """ - vendor = getattr(request.state, 'vendor', None) + vendor = getattr(request.state, "vendor", None) vendor_id = vendor.id if vendor else None # Get all published pages for this vendor pages = content_page_service.list_pages_for_vendor( - db, - vendor_id=vendor_id, - include_unpublished=False + db, vendor_id=vendor_id, include_unpublished=False ) return [ @@ -81,25 +81,21 @@ def get_navigation_pages( @router.get("/{slug}", response_model=PublicContentPageResponse) -def get_content_page( - slug: str, - request: Request, - db: Session = Depends(get_db) -): +def get_content_page(slug: str, request: Request, db: Session = Depends(get_db)): """ Get a specific content page by slug. Uses vendor from request.state (set by middleware). Returns vendor override if exists, otherwise platform default. """ - vendor = getattr(request.state, 'vendor', None) + vendor = getattr(request.state, "vendor", None) vendor_id = vendor.id if vendor else None page = content_page_service.get_page_for_vendor( db, slug=slug, vendor_id=vendor_id, - include_unpublished=False # Only show published pages + include_unpublished=False, # Only show published pages ) if not page: diff --git a/app/api/v1/shop/orders.py b/app/api/v1/shop/orders.py index 3523e16e..74ac25bf 100644 --- a/app/api/v1/shop/orders.py +++ b/app/api/v1/shop/orders.py @@ -10,31 +10,23 @@ Requires customer authentication for most operations. import logging from typing import Optional -from fastapi import APIRouter, Depends, Path, Query, Request, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Path, Query, Request from sqlalchemy.orm import Session -from app.core.database import get_db from app.api.deps import get_current_customer_api -from app.services.order_service import order_service +from app.core.database import get_db from app.services.customer_service import customer_service -from models.schema.order import ( - OrderCreate, - OrderResponse, - OrderDetailResponse, - OrderListResponse -) -from models.database.user import User +from app.services.order_service import order_service from models.database.customer import Customer +from models.database.user import User +from models.schema.order import (OrderCreate, OrderDetailResponse, + OrderListResponse, OrderResponse) router = APIRouter() logger = logging.getLogger(__name__) -def get_customer_from_user( - request: Request, - user: User, - db: Session -) -> Customer: +def get_customer_from_user(request: Request, user: User, db: Session) -> Customer: """ Helper to get Customer record from authenticated User. @@ -49,25 +41,22 @@ def get_customer_from_user( Raises: HTTPException: If customer not found or vendor mismatch """ - vendor = getattr(request.state, 'vendor', None) + vendor = getattr(request.state, "vendor", None) if not vendor: raise HTTPException( status_code=404, - detail="Vendor not found. Please access via vendor domain/subdomain/path." + detail="Vendor not found. Please access via vendor domain/subdomain/path.", ) # Find customer record for this user and vendor customer = customer_service.get_customer_by_user_id( - db=db, - vendor_id=vendor.id, - user_id=user.id + db=db, vendor_id=vendor.id, user_id=user.id ) if not customer: raise HTTPException( - status_code=404, - detail="Customer account not found for current vendor" + status_code=404, detail="Customer account not found for current vendor" ) return customer @@ -91,12 +80,12 @@ def place_order( - Order data including shipping address, payment method, etc. """ # Get vendor from middleware - vendor = getattr(request.state, 'vendor', None) + vendor = getattr(request.state, "vendor", None) if not vendor: raise HTTPException( status_code=404, - detail="Vendor not found. Please access via vendor domain/subdomain/path." + detail="Vendor not found. Please access via vendor domain/subdomain/path.", ) # Get customer record @@ -109,14 +98,12 @@ def place_order( "vendor_code": vendor.subdomain, "customer_id": customer.id, "user_id": current_user.id, - } + }, ) # Create order order = order_service.create_order( - db=db, - vendor_id=vendor.id, - order_data=order_data + db=db, vendor_id=vendor.id, order_data=order_data ) logger.info( @@ -127,7 +114,7 @@ def place_order( "order_number": order.order_number, "customer_id": customer.id, "total_amount": float(order.total_amount), - } + }, ) # TODO: Update customer stats @@ -156,12 +143,12 @@ def get_my_orders( - limit: Maximum number of orders to return """ # Get vendor from middleware - vendor = getattr(request.state, 'vendor', None) + vendor = getattr(request.state, "vendor", None) if not vendor: raise HTTPException( status_code=404, - detail="Vendor not found. Please access via vendor domain/subdomain/path." + detail="Vendor not found. Please access via vendor domain/subdomain/path.", ) # Get customer record @@ -175,23 +162,19 @@ def get_my_orders( "customer_id": customer.id, "skip": skip, "limit": limit, - } + }, ) # Get orders orders, total = order_service.get_customer_orders( - db=db, - vendor_id=vendor.id, - customer_id=customer.id, - skip=skip, - limit=limit + db=db, vendor_id=vendor.id, customer_id=customer.id, skip=skip, limit=limit ) return OrderListResponse( orders=[OrderResponse.model_validate(o) for o in orders], total=total, skip=skip, - limit=limit + limit=limit, ) @@ -212,12 +195,12 @@ def get_order_details( - order_id: ID of the order to retrieve """ # Get vendor from middleware - vendor = getattr(request.state, 'vendor', None) + vendor = getattr(request.state, "vendor", None) if not vendor: raise HTTPException( status_code=404, - detail="Vendor not found. Please access via vendor domain/subdomain/path." + detail="Vendor not found. Please access via vendor domain/subdomain/path.", ) # Get customer record @@ -230,19 +213,16 @@ def get_order_details( "vendor_code": vendor.subdomain, "customer_id": customer.id, "order_id": order_id, - } + }, ) # Get order - order = order_service.get_order( - db=db, - vendor_id=vendor.id, - order_id=order_id - ) + order = order_service.get_order(db=db, vendor_id=vendor.id, order_id=order_id) # Verify order belongs to customer if order.customer_id != customer.id: from app.exceptions import OrderNotFoundException + raise OrderNotFoundException(str(order_id)) return OrderDetailResponse.model_validate(order) diff --git a/app/api/v1/shop/products.py b/app/api/v1/shop/products.py index 60bdcb83..37836bec 100644 --- a/app/api/v1/shop/products.py +++ b/app/api/v1/shop/products.py @@ -10,12 +10,13 @@ No authentication required. import logging from typing import Optional -from fastapi import APIRouter, Depends, Query, Path, Request, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Path, Query, Request from sqlalchemy.orm import Session from app.core.database import get_db from app.services.product_service import product_service -from models.schema.product import ProductResponse, ProductDetailResponse, ProductListResponse +from models.schema.product import (ProductDetailResponse, ProductListResponse, + ProductResponse) router = APIRouter() logger = logging.getLogger(__name__) @@ -27,7 +28,9 @@ def get_product_catalog( skip: int = Query(0, ge=0), limit: int = Query(100, ge=1, le=1000), search: Optional[str] = Query(None, description="Search products by name"), - is_featured: Optional[bool] = Query(None, description="Filter by featured products"), + is_featured: Optional[bool] = Query( + None, description="Filter by featured products" + ), db: Session = Depends(get_db), ): """ @@ -44,12 +47,12 @@ def get_product_catalog( - is_featured: Filter by featured products only """ # Get vendor from middleware (injected by VendorContextMiddleware) - vendor = getattr(request.state, 'vendor', None) + vendor = getattr(request.state, "vendor", None) if not vendor: raise HTTPException( status_code=404, - detail="Vendor not found. Please access via vendor domain/subdomain/path." + detail="Vendor not found. Please access via vendor domain/subdomain/path.", ) logger.debug( @@ -61,7 +64,7 @@ def get_product_catalog( "limit": limit, "search": search, "is_featured": is_featured, - } + }, ) # Get only active products for public view @@ -71,14 +74,14 @@ def get_product_catalog( skip=skip, limit=limit, is_active=True, # Only show active products to customers - is_featured=is_featured + is_featured=is_featured, ) return ProductListResponse( products=[ProductResponse.model_validate(p) for p in products], total=total, skip=skip, - limit=limit + limit=limit, ) @@ -98,12 +101,12 @@ def get_product_details( - product_id: ID of the product to retrieve """ # Get vendor from middleware - vendor = getattr(request.state, 'vendor', None) + vendor = getattr(request.state, "vendor", None) if not vendor: raise HTTPException( status_code=404, - detail="Vendor not found. Please access via vendor domain/subdomain/path." + detail="Vendor not found. Please access via vendor domain/subdomain/path.", ) logger.debug( @@ -112,18 +115,17 @@ def get_product_details( "vendor_id": vendor.id, "vendor_code": vendor.subdomain, "product_id": product_id, - } + }, ) product = product_service.get_product( - db=db, - vendor_id=vendor.id, - product_id=product_id + db=db, vendor_id=vendor.id, product_id=product_id ) # Check if product is active if not product.is_active: from app.exceptions import ProductNotActiveException + raise ProductNotActiveException(str(product_id)) return ProductDetailResponse.model_validate(product) @@ -150,12 +152,12 @@ def search_products( - limit: Maximum number of results to return """ # Get vendor from middleware - vendor = getattr(request.state, 'vendor', None) + vendor = getattr(request.state, "vendor", None) if not vendor: raise HTTPException( status_code=404, - detail="Vendor not found. Please access via vendor domain/subdomain/path." + detail="Vendor not found. Please access via vendor domain/subdomain/path.", ) logger.debug( @@ -166,22 +168,18 @@ def search_products( "query": q, "skip": skip, "limit": limit, - } + }, ) # TODO: Implement full-text search functionality # For now, return filtered products products, total = product_service.get_vendor_products( - db=db, - vendor_id=vendor.id, - skip=skip, - limit=limit, - is_active=True + db=db, vendor_id=vendor.id, skip=skip, limit=limit, is_active=True ) return ProductListResponse( products=[ProductResponse.model_validate(p) for p in products], total=total, skip=skip, - limit=limit + limit=limit, ) diff --git a/app/api/v1/vendor/__init__.py b/app/api/v1/vendor/__init__.py index 9eb0cba8..b984af4f 100644 --- a/app/api/v1/vendor/__init__.py +++ b/app/api/v1/vendor/__init__.py @@ -13,25 +13,9 @@ IMPORTANT: from fastapi import APIRouter # Import all sub-routers (JSON API only) -from . import ( - info, - auth, - dashboard, - profile, - settings, - products, - orders, - customers, - team, - inventory, - marketplace, - payments, - media, - notifications, - analytics, - content_pages, -) - +from . import (analytics, auth, content_pages, customers, dashboard, info, + inventory, marketplace, media, notifications, orders, payments, + products, profile, settings, team) # Create vendor router router = APIRouter() @@ -68,7 +52,11 @@ router.include_router(notifications.router, tags=["vendor-notifications"]) router.include_router(analytics.router, tags=["vendor-analytics"]) # Content pages management -router.include_router(content_pages.router, prefix="/{vendor_code}/content-pages", tags=["vendor-content-pages"]) +router.include_router( + content_pages.router, + prefix="/{vendor_code}/content-pages", + tags=["vendor-content-pages"], +) # Vendor info endpoint - MUST BE LAST! Has catch-all GET /{vendor_code} router.include_router(info.router, tags=["vendor-info"]) diff --git a/app/api/v1/vendor/analytics.py b/app/api/v1/vendor/analytics.py index 4e99270c..7d92ac5b 100644 --- a/app/api/v1/vendor/analytics.py +++ b/app/api/v1/vendor/analytics.py @@ -4,13 +4,14 @@ Vendor analytics and reporting endpoints. """ import logging + from fastapi import APIRouter, Depends, Query from sqlalchemy.orm import Session from app.api.deps import get_current_vendor_api from app.core.database import get_db -from middleware.vendor_context import require_vendor_context from app.services.stats_service import stats_service +from middleware.vendor_context import require_vendor_context from models.database.user import User from models.database.vendor import Vendor @@ -20,10 +21,10 @@ logger = logging.getLogger(__name__) @router.get("") def get_vendor_analytics( - period: str = Query("30d", description="Time period: 7d, 30d, 90d, 1y"), - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + period: str = Query("30d", description="Time period: 7d, 30d, 90d, 1y"), + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """Get vendor analytics data for specified time period.""" return stats_service.get_vendor_analytics(db, vendor.id, period) diff --git a/app/api/v1/vendor/auth.py b/app/api/v1/vendor/auth.py index 127355bc..8f8bac33 100644 --- a/app/api/v1/vendor/auth.py +++ b/app/api/v1/vendor/auth.py @@ -13,19 +13,20 @@ This prevents: """ import logging + from fastapi import APIRouter, Depends, Request, Response +from pydantic import BaseModel from sqlalchemy.orm import Session -from app.core.database import get_db -from app.services.auth_service import auth_service -from app.exceptions import InvalidCredentialsException -from middleware.vendor_context import get_current_vendor -from models.schema.auth import UserLogin -from models.database.vendor import Vendor, VendorUser, Role -from models.database.user import User -from pydantic import BaseModel from app.api.deps import get_current_vendor_api +from app.core.database import get_db from app.core.environment import should_use_secure_cookies +from app.exceptions import InvalidCredentialsException +from app.services.auth_service import auth_service +from middleware.vendor_context import get_current_vendor +from models.database.user import User +from models.database.vendor import Role, Vendor, VendorUser +from models.schema.auth import UserLogin router = APIRouter(prefix="/auth") logger = logging.getLogger(__name__) @@ -43,10 +44,10 @@ class VendorLoginResponse(BaseModel): @router.post("/login", response_model=VendorLoginResponse) def vendor_login( - user_credentials: UserLogin, - request: Request, - response: Response, - db: Session = Depends(get_db) + user_credentials: UserLogin, + request: Request, + response: Response, + db: Session = Depends(get_db), ): """ Vendor team member login. @@ -64,13 +65,16 @@ def vendor_login( vendor = get_current_vendor(request) # If no vendor from middleware, try to get from request body - if not vendor and hasattr(user_credentials, 'vendor_code'): - vendor_code = getattr(user_credentials, 'vendor_code', None) + if not vendor and hasattr(user_credentials, "vendor_code"): + vendor_code = getattr(user_credentials, "vendor_code", None) if vendor_code: - vendor = db.query(Vendor).filter( - Vendor.vendor_code == vendor_code.upper(), - Vendor.is_active == True - ).first() + vendor = ( + db.query(Vendor) + .filter( + Vendor.vendor_code == vendor_code.upper(), Vendor.is_active == True + ) + .first() + ) # Authenticate user login_result = auth_service.login_user(db=db, user_credentials=user_credentials) @@ -79,7 +83,9 @@ def vendor_login( # CRITICAL: Prevent admin users from using vendor login if user.role == "admin": logger.warning(f"Admin user attempted vendor login: {user.username}") - raise InvalidCredentialsException("Admins cannot access vendor portal. Please use admin portal.") + raise InvalidCredentialsException( + "Admins cannot access vendor portal. Please use admin portal." + ) # Determine vendor and role vendor_role = "Member" @@ -92,11 +98,16 @@ def vendor_login( vendor_role = "Owner" else: # Check if user is team member - vendor_user = db.query(VendorUser).join(Role).filter( - VendorUser.user_id == user.id, - VendorUser.vendor_id == vendor.id, - VendorUser.is_active == True - ).first() + vendor_user = ( + db.query(VendorUser) + .join(Role) + .filter( + VendorUser.user_id == user.id, + VendorUser.vendor_id == vendor.id, + VendorUser.is_active == True, + ) + .first() + ) if vendor_user: vendor_role = vendor_user.role.name @@ -117,17 +128,14 @@ def vendor_login( # Check vendor memberships elif user.vendor_memberships: active_membership = next( - (vm for vm in user.vendor_memberships if vm.is_active), - None + (vm for vm in user.vendor_memberships if vm.is_active), None ) if active_membership: vendor = active_membership.vendor vendor_role = active_membership.role.name if not vendor: - raise InvalidCredentialsException( - "User is not associated with any vendor" - ) + raise InvalidCredentialsException("User is not associated with any vendor") logger.info( f"Vendor team login successful: {user.username} " @@ -161,7 +169,7 @@ def vendor_login( "username": user.username, "email": user.email, "role": user.role, - "is_active": user.is_active + "is_active": user.is_active, }, vendor={ "id": vendor.id, @@ -169,9 +177,9 @@ def vendor_login( "subdomain": vendor.subdomain, "name": vendor.name, "is_active": vendor.is_active, - "is_verified": vendor.is_verified + "is_verified": vendor.is_verified, }, - vendor_role=vendor_role + vendor_role=vendor_role, ) @@ -198,8 +206,7 @@ def vendor_logout(response: Response): @router.get("/me") def get_current_vendor_user( - user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db) + user: User = Depends(get_current_vendor_api), db: Session = Depends(get_db) ): """ Get current authenticated vendor user. @@ -212,5 +219,5 @@ def get_current_vendor_user( "username": user.username, "email": user.email, "role": user.role, - "is_active": user.is_active + "is_active": user.is_active, } diff --git a/app/api/v1/vendor/content_pages.py b/app/api/v1/vendor/content_pages.py index b9af5302..4a552898 100644 --- a/app/api/v1/vendor/content_pages.py +++ b/app/api/v1/vendor/content_pages.py @@ -10,6 +10,7 @@ Vendors can: import logging from typing import List, Optional + from fastapi import APIRouter, Depends, HTTPException, Query from pydantic import BaseModel, Field from sqlalchemy.orm import Session @@ -26,14 +27,26 @@ logger = logging.getLogger(__name__) # REQUEST/RESPONSE SCHEMAS # ============================================================================ + class VendorContentPageCreate(BaseModel): """Schema for creating a vendor content page.""" - slug: str = Field(..., max_length=100, description="URL-safe identifier (about, faq, contact, etc.)") + + slug: str = Field( + ..., + max_length=100, + description="URL-safe identifier (about, faq, contact, etc.)", + ) title: str = Field(..., max_length=200, description="Page title") content: str = Field(..., description="HTML or Markdown content") - content_format: str = Field(default="html", description="Content format: html or markdown") - meta_description: Optional[str] = Field(None, max_length=300, description="SEO meta description") - meta_keywords: Optional[str] = Field(None, max_length=300, description="SEO keywords") + content_format: str = Field( + default="html", description="Content format: html or markdown" + ) + meta_description: Optional[str] = Field( + None, max_length=300, description="SEO meta description" + ) + meta_keywords: Optional[str] = Field( + None, max_length=300, description="SEO keywords" + ) is_published: bool = Field(default=False, description="Publish immediately") show_in_footer: bool = Field(default=True, description="Show in footer navigation") show_in_header: bool = Field(default=False, description="Show in header navigation") @@ -42,6 +55,7 @@ class VendorContentPageCreate(BaseModel): class VendorContentPageUpdate(BaseModel): """Schema for updating a vendor content page.""" + title: Optional[str] = Field(None, max_length=200) content: Optional[str] = None content_format: Optional[str] = None @@ -55,6 +69,7 @@ class VendorContentPageUpdate(BaseModel): class ContentPageResponse(BaseModel): """Schema for content page response.""" + id: int vendor_id: Optional[int] vendor_name: Optional[str] @@ -81,11 +96,12 @@ class ContentPageResponse(BaseModel): # VENDOR CONTENT PAGES # ============================================================================ + @router.get("/", response_model=List[ContentPageResponse]) def list_vendor_pages( include_unpublished: bool = Query(False, description="Include draft pages"), current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ List all content pages available for this vendor. @@ -93,12 +109,12 @@ def list_vendor_pages( Returns vendor-specific overrides + platform defaults (vendor overrides take precedence). """ if not current_user.vendor_id: - raise HTTPException(status_code=403, detail="User is not associated with a vendor") + raise HTTPException( + status_code=403, detail="User is not associated with a vendor" + ) pages = content_page_service.list_pages_for_vendor( - db, - vendor_id=current_user.vendor_id, - include_unpublished=include_unpublished + db, vendor_id=current_user.vendor_id, include_unpublished=include_unpublished ) return [page.to_dict() for page in pages] @@ -108,7 +124,7 @@ def list_vendor_pages( def list_vendor_overrides( include_unpublished: bool = Query(False, description="Include draft pages"), current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ List only vendor-specific content pages (no platform defaults). @@ -116,12 +132,12 @@ def list_vendor_overrides( Shows what the vendor has customized. """ if not current_user.vendor_id: - raise HTTPException(status_code=403, detail="User is not associated with a vendor") + raise HTTPException( + status_code=403, detail="User is not associated with a vendor" + ) pages = content_page_service.list_all_vendor_pages( - db, - vendor_id=current_user.vendor_id, - include_unpublished=include_unpublished + db, vendor_id=current_user.vendor_id, include_unpublished=include_unpublished ) return [page.to_dict() for page in pages] @@ -132,7 +148,7 @@ def get_page( slug: str, include_unpublished: bool = Query(False, description="Include draft pages"), current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Get a specific content page by slug. @@ -140,13 +156,15 @@ def get_page( Returns vendor override if exists, otherwise platform default. """ if not current_user.vendor_id: - raise HTTPException(status_code=403, detail="User is not associated with a vendor") + raise HTTPException( + status_code=403, detail="User is not associated with a vendor" + ) page = content_page_service.get_page_for_vendor( db, slug=slug, vendor_id=current_user.vendor_id, - include_unpublished=include_unpublished + include_unpublished=include_unpublished, ) if not page: @@ -159,7 +177,7 @@ def get_page( def create_vendor_page( page_data: VendorContentPageCreate, current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Create a vendor-specific content page override. @@ -167,7 +185,9 @@ def create_vendor_page( This will be shown instead of the platform default for this vendor. """ if not current_user.vendor_id: - raise HTTPException(status_code=403, detail="User is not associated with a vendor") + raise HTTPException( + status_code=403, detail="User is not associated with a vendor" + ) page = content_page_service.create_page( db, @@ -182,7 +202,7 @@ def create_vendor_page( show_in_footer=page_data.show_in_footer, show_in_header=page_data.show_in_header, display_order=page_data.display_order, - created_by=current_user.id + created_by=current_user.id, ) return page.to_dict() @@ -193,7 +213,7 @@ def update_vendor_page( page_id: int, page_data: VendorContentPageUpdate, current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Update a vendor-specific content page. @@ -201,7 +221,9 @@ def update_vendor_page( Can only update pages owned by this vendor. """ if not current_user.vendor_id: - raise HTTPException(status_code=403, detail="User is not associated with a vendor") + raise HTTPException( + status_code=403, detail="User is not associated with a vendor" + ) # Verify ownership existing_page = content_page_service.get_page_by_id(db, page_id) @@ -209,7 +231,9 @@ def update_vendor_page( raise HTTPException(status_code=404, detail="Content page not found") if existing_page.vendor_id != current_user.vendor_id: - raise HTTPException(status_code=403, detail="Cannot edit pages from other vendors") + raise HTTPException( + status_code=403, detail="Cannot edit pages from other vendors" + ) # Update page = content_page_service.update_page( @@ -224,7 +248,7 @@ def update_vendor_page( show_in_footer=page_data.show_in_footer, show_in_header=page_data.show_in_header, display_order=page_data.display_order, - updated_by=current_user.id + updated_by=current_user.id, ) return page.to_dict() @@ -234,7 +258,7 @@ def update_vendor_page( def delete_vendor_page( page_id: int, current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Delete a vendor-specific content page. @@ -243,7 +267,9 @@ def delete_vendor_page( After deletion, platform default will be shown (if exists). """ if not current_user.vendor_id: - raise HTTPException(status_code=403, detail="User is not associated with a vendor") + raise HTTPException( + status_code=403, detail="User is not associated with a vendor" + ) # Verify ownership existing_page = content_page_service.get_page_by_id(db, page_id) @@ -251,7 +277,9 @@ def delete_vendor_page( raise HTTPException(status_code=404, detail="Content page not found") if existing_page.vendor_id != current_user.vendor_id: - raise HTTPException(status_code=403, detail="Cannot delete pages from other vendors") + raise HTTPException( + status_code=403, detail="Cannot delete pages from other vendors" + ) # Delete content_page_service.delete_page(db, page_id) diff --git a/app/api/v1/vendor/customers.py b/app/api/v1/vendor/customers.py index 2727a017..3a267c17 100644 --- a/app/api/v1/vendor/customers.py +++ b/app/api/v1/vendor/customers.py @@ -6,6 +6,7 @@ Vendor customer management endpoints. import logging from typing import Optional + from fastapi import APIRouter, Depends, Query from sqlalchemy.orm import Session @@ -21,13 +22,13 @@ logger = logging.getLogger(__name__) @router.get("") def get_vendor_customers( - skip: int = Query(0, ge=0), - limit: int = Query(100, ge=1, le=1000), - search: Optional[str] = Query(None), - is_active: Optional[bool] = Query(None), - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=1000), + search: Optional[str] = Query(None), + is_active: Optional[bool] = Query(None), + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Get all customers for this vendor. @@ -43,16 +44,16 @@ def get_vendor_customers( "total": 0, "skip": skip, "limit": limit, - "message": "Customer management coming in Slice 4" + "message": "Customer management coming in Slice 4", } @router.get("/{customer_id}") def get_customer_details( - customer_id: int, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + customer_id: int, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Get detailed customer information. @@ -63,17 +64,15 @@ def get_customer_details( - Include order history - Include total spent, etc. """ - return { - "message": "Customer details coming in Slice 4" - } + return {"message": "Customer details coming in Slice 4"} @router.get("/{customer_id}/orders") def get_customer_orders( - customer_id: int, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + customer_id: int, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Get order history for a specific customer. @@ -83,19 +82,16 @@ def get_customer_orders( - Filter by vendor_id - Return order details """ - return { - "orders": [], - "message": "Customer orders coming in Slice 5" - } + return {"orders": [], "message": "Customer orders coming in Slice 5"} @router.put("/{customer_id}") def update_customer( - customer_id: int, - customer_data: dict, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + customer_id: int, + customer_data: dict, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Update customer information. @@ -105,17 +101,15 @@ def update_customer( - Verify customer belongs to vendor - Update customer preferences """ - return { - "message": "Customer update coming in Slice 4" - } + return {"message": "Customer update coming in Slice 4"} @router.put("/{customer_id}/status") def toggle_customer_status( - customer_id: int, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + customer_id: int, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Activate/deactivate customer account. @@ -125,17 +119,15 @@ def toggle_customer_status( - Verify customer belongs to vendor - Log the change """ - return { - "message": "Customer status toggle coming in Slice 4" - } + return {"message": "Customer status toggle coming in Slice 4"} @router.get("/{customer_id}/stats") def get_customer_statistics( - customer_id: int, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + customer_id: int, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Get customer statistics and metrics. @@ -151,6 +143,5 @@ def get_customer_statistics( "total_spent": 0.0, "average_order_value": 0.0, "last_order_date": None, - "message": "Customer statistics coming in Slice 4" + "message": "Customer statistics coming in Slice 4", } - diff --git a/app/api/v1/vendor/dashboard.py b/app/api/v1/vendor/dashboard.py index 9e8cc066..b49351a6 100644 --- a/app/api/v1/vendor/dashboard.py +++ b/app/api/v1/vendor/dashboard.py @@ -4,13 +4,14 @@ Vendor dashboard and statistics endpoints. """ import logging + from fastapi import APIRouter, Depends, Request from sqlalchemy.orm import Session from app.api.deps import get_current_vendor_api from app.core.database import get_db -from middleware.vendor_context import require_vendor_context from app.services.stats_service import stats_service +from middleware.vendor_context import require_vendor_context from models.database.user import User from models.database.vendor import Vendor @@ -20,9 +21,9 @@ logger = logging.getLogger(__name__) @router.get("/stats") def get_vendor_dashboard_stats( - request: Request, - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + request: Request, + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Get vendor-specific dashboard statistics. @@ -38,24 +39,23 @@ def get_vendor_dashboard_stats( """ # Get vendor from authenticated user's vendor_user record from models.database.vendor import VendorUser - vendor_user = db.query(VendorUser).filter( - VendorUser.user_id == current_user.id - ).first() + + vendor_user = ( + db.query(VendorUser).filter(VendorUser.user_id == current_user.id).first() + ) if not vendor_user: from fastapi import HTTPException + raise HTTPException( - status_code=403, - detail="User is not associated with any vendor" + status_code=403, detail="User is not associated with any vendor" ) vendor = vendor_user.vendor if not vendor or not vendor.is_active: from fastapi import HTTPException - raise HTTPException( - status_code=404, - detail="Vendor not found or inactive" - ) + + raise HTTPException(status_code=404, detail="Vendor not found or inactive") # Get vendor-scoped statistics stats_data = stats_service.get_vendor_stats(db=db, vendor_id=vendor.id) @@ -82,5 +82,5 @@ def get_vendor_dashboard_stats( "revenue": { "total": stats_data.get("total_revenue", 0), "this_month": stats_data.get("revenue_this_month", 0), - } + }, } diff --git a/app/api/v1/vendor/info.py b/app/api/v1/vendor/info.py index fffcebb7..e84d0a8d 100644 --- a/app/api/v1/vendor/info.py +++ b/app/api/v1/vendor/info.py @@ -8,14 +8,15 @@ This module provides: """ import logging -from fastapi import APIRouter, Path, Depends -from sqlalchemy.orm import Session + +from fastapi import APIRouter, Depends, Path from sqlalchemy import func +from sqlalchemy.orm import Session from app.core.database import get_db from app.exceptions import VendorNotFoundException -from models.schema.vendor import VendorResponse, VendorDetailResponse from models.database.vendor import Vendor +from models.schema.vendor import VendorDetailResponse, VendorResponse router = APIRouter() logger = logging.getLogger(__name__) @@ -35,10 +36,14 @@ def _get_vendor_by_code(db: Session, vendor_code: str) -> Vendor: Raises: VendorNotFoundException: If vendor not found or inactive """ - vendor = db.query(Vendor).filter( - func.upper(Vendor.vendor_code) == vendor_code.upper(), - Vendor.is_active == True - ).first() + vendor = ( + db.query(Vendor) + .filter( + func.upper(Vendor.vendor_code) == vendor_code.upper(), + Vendor.is_active == True, + ) + .first() + ) if not vendor: logger.warning(f"Vendor not found or inactive: {vendor_code}") @@ -49,8 +54,8 @@ def _get_vendor_by_code(db: Session, vendor_code: str) -> Vendor: @router.get("/{vendor_code}", response_model=VendorDetailResponse) def get_vendor_info( - vendor_code: str = Path(..., description="Vendor code"), - db: Session = Depends(get_db) + vendor_code: str = Path(..., description="Vendor code"), + db: Session = Depends(get_db), ): """ Get public vendor information by vendor code. diff --git a/app/api/v1/vendor/inventory.py b/app/api/v1/vendor/inventory.py index 65b2d331..0cb49c2c 100644 --- a/app/api/v1/vendor/inventory.py +++ b/app/api/v1/vendor/inventory.py @@ -7,19 +7,14 @@ from sqlalchemy.orm import Session from app.api.deps import get_current_vendor_api from app.core.database import get_db -from middleware.vendor_context import require_vendor_context from app.services.inventory_service import inventory_service -from models.schema.inventory import ( - InventoryCreate, - InventoryAdjust, - InventoryUpdate, - InventoryReserve, - InventoryResponse, - ProductInventorySummary, - InventoryListResponse -) +from middleware.vendor_context import require_vendor_context from models.database.user import User from models.database.vendor import Vendor +from models.schema.inventory import (InventoryAdjust, InventoryCreate, + InventoryListResponse, InventoryReserve, + InventoryResponse, InventoryUpdate, + ProductInventorySummary) router = APIRouter() logger = logging.getLogger(__name__) @@ -27,10 +22,10 @@ logger = logging.getLogger(__name__) @router.post("/inventory/set", response_model=InventoryResponse) def set_inventory( - inventory: InventoryCreate, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + inventory: InventoryCreate, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """Set exact inventory quantity (replaces existing).""" return inventory_service.set_inventory(db, vendor.id, inventory) @@ -38,10 +33,10 @@ def set_inventory( @router.post("/inventory/adjust", response_model=InventoryResponse) def adjust_inventory( - adjustment: InventoryAdjust, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + adjustment: InventoryAdjust, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """Adjust inventory (positive to add, negative to remove).""" return inventory_service.adjust_inventory(db, vendor.id, adjustment) @@ -49,10 +44,10 @@ def adjust_inventory( @router.post("/inventory/reserve", response_model=InventoryResponse) def reserve_inventory( - reservation: InventoryReserve, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + reservation: InventoryReserve, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """Reserve inventory for an order.""" return inventory_service.reserve_inventory(db, vendor.id, reservation) @@ -60,10 +55,10 @@ def reserve_inventory( @router.post("/inventory/release", response_model=InventoryResponse) def release_reservation( - reservation: InventoryReserve, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + reservation: InventoryReserve, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """Release reserved inventory (cancel order).""" return inventory_service.release_reservation(db, vendor.id, reservation) @@ -71,10 +66,10 @@ def release_reservation( @router.post("/inventory/fulfill", response_model=InventoryResponse) def fulfill_reservation( - reservation: InventoryReserve, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + reservation: InventoryReserve, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """Fulfill reservation (complete order, remove from stock).""" return inventory_service.fulfill_reservation(db, vendor.id, reservation) @@ -82,10 +77,10 @@ def fulfill_reservation( @router.get("/inventory/product/{product_id}", response_model=ProductInventorySummary) def get_product_inventory( - product_id: int, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + product_id: int, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """Get inventory summary for a product.""" return inventory_service.get_product_inventory(db, vendor.id, product_id) @@ -93,13 +88,13 @@ def get_product_inventory( @router.get("/inventory", response_model=InventoryListResponse) def get_vendor_inventory( - skip: int = Query(0, ge=0), - limit: int = Query(100, ge=1, le=1000), - location: Optional[str] = Query(None), - low_stock: Optional[int] = Query(None, ge=0), - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=1000), + location: Optional[str] = Query(None), + low_stock: Optional[int] = Query(None, ge=0), + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """Get all inventory for vendor.""" inventories = inventory_service.get_vendor_inventory( @@ -110,31 +105,30 @@ def get_vendor_inventory( total = len(inventories) # You might want a separate count query for large datasets return InventoryListResponse( - inventories=inventories, - total=total, - skip=skip, - limit=limit + inventories=inventories, total=total, skip=skip, limit=limit ) @router.put("/inventory/{inventory_id}", response_model=InventoryResponse) def update_inventory( - inventory_id: int, - inventory_update: InventoryUpdate, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + inventory_id: int, + inventory_update: InventoryUpdate, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """Update inventory entry.""" - return inventory_service.update_inventory(db, vendor.id, inventory_id, inventory_update) + return inventory_service.update_inventory( + db, vendor.id, inventory_id, inventory_update + ) @router.delete("/inventory/{inventory_id}") def delete_inventory( - inventory_id: int, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + inventory_id: int, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """Delete inventory entry.""" inventory_service.delete_inventory(db, vendor.id, inventory_id) diff --git a/app/api/v1/vendor/marketplace.py b/app/api/v1/vendor/marketplace.py index cb98d980..9084bfd2 100644 --- a/app/api/v1/vendor/marketplace.py +++ b/app/api/v1/vendor/marketplace.py @@ -12,16 +12,15 @@ from sqlalchemy.orm import Session from app.api.deps import get_current_vendor_api from app.core.database import get_db -from middleware.vendor_context import require_vendor_context # IMPORTANT -from app.services.marketplace_import_job_service import marketplace_import_job_service +from app.services.marketplace_import_job_service import \ + marketplace_import_job_service from app.tasks.background_tasks import process_marketplace_import from middleware.decorators import rate_limit -from models.schema.marketplace_import_job import ( - MarketplaceImportJobResponse, - MarketplaceImportJobRequest -) +from middleware.vendor_context import require_vendor_context # IMPORTANT from models.database.user import User from models.database.vendor import Vendor +from models.schema.marketplace_import_job import (MarketplaceImportJobRequest, + MarketplaceImportJobResponse) router = APIRouter() logger = logging.getLogger(__name__) @@ -30,11 +29,11 @@ logger = logging.getLogger(__name__) @router.post("/import", response_model=MarketplaceImportJobResponse) @rate_limit(max_requests=10, window_seconds=3600) async def import_products_from_marketplace( - request: MarketplaceImportJobRequest, - background_tasks: BackgroundTasks, - vendor: Vendor = Depends(require_vendor_context()), # ADDED: Vendor from middleware - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + request: MarketplaceImportJobRequest, + background_tasks: BackgroundTasks, + vendor: Vendor = Depends(require_vendor_context()), # ADDED: Vendor from middleware + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """Import products from marketplace CSV with background processing (Protected).""" logger.info( @@ -66,7 +65,7 @@ async def import_products_from_marketplace( vendor_name=vendor.name, # FIXED: from vendor object source_url=request.source_url, message=f"Marketplace import started from {request.marketplace}. " - f"Check status with /import-status/{import_job.id}", + f"Check status with /import-status/{import_job.id}", imported=0, updated=0, total_processed=0, @@ -77,10 +76,10 @@ async def import_products_from_marketplace( @router.get("/imports/{job_id}", response_model=MarketplaceImportJobResponse) def get_marketplace_import_status( - job_id: int, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + job_id: int, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """Get status of marketplace import job (Protected).""" job = marketplace_import_job_service.get_import_job_by_id(db, job_id, current_user) @@ -88,6 +87,7 @@ def get_marketplace_import_status( # Verify job belongs to current vendor if job.vendor_id != vendor.id: from app.exceptions import UnauthorizedVendorAccessException + raise UnauthorizedVendorAccessException(vendor.vendor_code, current_user.id) return marketplace_import_job_service.convert_to_response_model(job) @@ -95,12 +95,12 @@ def get_marketplace_import_status( @router.get("/imports", response_model=List[MarketplaceImportJobResponse]) def get_marketplace_import_jobs( - marketplace: Optional[str] = Query(None, description="Filter by marketplace"), - skip: int = Query(0, ge=0), - limit: int = Query(50, ge=1, le=100), - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + marketplace: Optional[str] = Query(None, description="Filter by marketplace"), + skip: int = Query(0, ge=0), + limit: int = Query(50, ge=1, le=100), + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """Get marketplace import jobs for current vendor (Protected).""" jobs = marketplace_import_job_service.get_import_jobs( @@ -112,4 +112,6 @@ def get_marketplace_import_jobs( limit=limit, ) - return [marketplace_import_job_service.convert_to_response_model(job) for job in jobs] + return [ + marketplace_import_job_service.convert_to_response_model(job) for job in jobs + ] diff --git a/app/api/v1/vendor/media.py b/app/api/v1/vendor/media.py index d90ee1fe..900bbb6e 100644 --- a/app/api/v1/vendor/media.py +++ b/app/api/v1/vendor/media.py @@ -6,7 +6,8 @@ Vendor media and file management endpoints. import logging from typing import Optional -from fastapi import APIRouter, Depends, Query, UploadFile, File + +from fastapi import APIRouter, Depends, File, Query, UploadFile from sqlalchemy.orm import Session from app.api.deps import get_current_vendor_api @@ -21,13 +22,13 @@ logger = logging.getLogger(__name__) @router.get("") def get_media_library( - skip: int = Query(0, ge=0), - limit: int = Query(100, ge=1, le=1000), - media_type: Optional[str] = Query(None, description="image, video, document"), - search: Optional[str] = Query(None), - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=1000), + media_type: Optional[str] = Query(None, description="image, video, document"), + search: Optional[str] = Query(None), + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Get vendor media library. @@ -44,17 +45,17 @@ def get_media_library( "total": 0, "skip": skip, "limit": limit, - "message": "Media library coming in Slice 3" + "message": "Media library coming in Slice 3", } @router.post("/upload") async def upload_media( - file: UploadFile = File(...), - folder: Optional[str] = Query(None, description="products, general, etc."), - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + file: UploadFile = File(...), + folder: Optional[str] = Query(None, description="products, general, etc."), + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Upload media file. @@ -70,17 +71,17 @@ async def upload_media( return { "file_url": None, "thumbnail_url": None, - "message": "Media upload coming in Slice 3" + "message": "Media upload coming in Slice 3", } @router.post("/upload/multiple") async def upload_multiple_media( - files: list[UploadFile] = File(...), - folder: Optional[str] = Query(None), - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + files: list[UploadFile] = File(...), + folder: Optional[str] = Query(None), + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Upload multiple media files at once. @@ -94,16 +95,16 @@ async def upload_multiple_media( return { "uploaded_files": [], "failed_files": [], - "message": "Multiple upload coming in Slice 3" + "message": "Multiple upload coming in Slice 3", } @router.get("/{media_id}") def get_media_details( - media_id: int, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + media_id: int, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Get media file details. @@ -113,18 +114,16 @@ def get_media_details( - Return file URL - Return usage information (which products use this file) """ - return { - "message": "Media details coming in Slice 3" - } + return {"message": "Media details coming in Slice 3"} @router.put("/{media_id}") def update_media_metadata( - media_id: int, - metadata: dict, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + media_id: int, + metadata: dict, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Update media file metadata. @@ -135,17 +134,15 @@ def update_media_metadata( - Update tags/categories - Update description """ - return { - "message": "Media update coming in Slice 3" - } + return {"message": "Media update coming in Slice 3"} @router.delete("/{media_id}") def delete_media( - media_id: int, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + media_id: int, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Delete media file. @@ -157,17 +154,15 @@ def delete_media( - Delete database record - Return success/error """ - return { - "message": "Media deletion coming in Slice 3" - } + return {"message": "Media deletion coming in Slice 3"} @router.get("/{media_id}/usage") def get_media_usage( - media_id: int, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + media_id: int, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Get where this media file is being used. @@ -180,16 +175,16 @@ def get_media_usage( return { "products": [], "other_usage": [], - "message": "Media usage tracking coming in Slice 3" + "message": "Media usage tracking coming in Slice 3", } @router.post("/optimize/{media_id}") def optimize_media( - media_id: int, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + media_id: int, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Optimize media file (compress, resize, etc.). @@ -200,7 +195,4 @@ def optimize_media( - Keep original - Update database with new versions """ - return { - "message": "Media optimization coming in Slice 3" - } - + return {"message": "Media optimization coming in Slice 3"} diff --git a/app/api/v1/vendor/notifications.py b/app/api/v1/vendor/notifications.py index 2889a9e7..2b6a8205 100644 --- a/app/api/v1/vendor/notifications.py +++ b/app/api/v1/vendor/notifications.py @@ -1,4 +1,4 @@ -# Notification management +# Notification management # app/api/v1/vendor/notifications.py """ Vendor notification management endpoints. @@ -6,6 +6,7 @@ Vendor notification management endpoints. import logging from typing import Optional + from fastapi import APIRouter, Depends, Query from sqlalchemy.orm import Session @@ -21,12 +22,12 @@ logger = logging.getLogger(__name__) @router.get("") def get_notifications( - skip: int = Query(0, ge=0), - limit: int = Query(50, ge=1, le=100), - unread_only: Optional[bool] = Query(False), - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + skip: int = Query(0, ge=0), + limit: int = Query(50, ge=1, le=100), + unread_only: Optional[bool] = Query(False), + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Get vendor notifications. @@ -41,15 +42,15 @@ def get_notifications( "notifications": [], "total": 0, "unread_count": 0, - "message": "Notifications coming in Slice 5" + "message": "Notifications coming in Slice 5", } @router.get("/unread-count") def get_unread_count( - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Get count of unread notifications. @@ -58,18 +59,15 @@ def get_unread_count( - Count unread notifications for vendor - Used for notification badge """ - return { - "unread_count": 0, - "message": "Unread count coming in Slice 5" - } + return {"unread_count": 0, "message": "Unread count coming in Slice 5"} @router.put("/{notification_id}/read") def mark_as_read( - notification_id: int, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + notification_id: int, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Mark notification as read. @@ -78,16 +76,14 @@ def mark_as_read( - Mark single notification as read - Update read timestamp """ - return { - "message": "Mark as read coming in Slice 5" - } + return {"message": "Mark as read coming in Slice 5"} @router.put("/mark-all-read") def mark_all_as_read( - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Mark all notifications as read. @@ -96,17 +92,15 @@ def mark_all_as_read( - Mark all vendor notifications as read - Update timestamps """ - return { - "message": "Mark all as read coming in Slice 5" - } + return {"message": "Mark all as read coming in Slice 5"} @router.delete("/{notification_id}") def delete_notification( - notification_id: int, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + notification_id: int, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Delete notification. @@ -115,16 +109,14 @@ def delete_notification( - Delete single notification - Verify notification belongs to vendor """ - return { - "message": "Notification deletion coming in Slice 5" - } + return {"message": "Notification deletion coming in Slice 5"} @router.get("/settings") def get_notification_settings( - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Get notification preferences. @@ -138,16 +130,16 @@ def get_notification_settings( "email_notifications": True, "in_app_notifications": True, "notification_types": {}, - "message": "Notification settings coming in Slice 5" + "message": "Notification settings coming in Slice 5", } @router.put("/settings") def update_notification_settings( - settings: dict, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + settings: dict, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Update notification preferences. @@ -157,16 +149,14 @@ def update_notification_settings( - Update in-app notification settings - Enable/disable specific notification types """ - return { - "message": "Notification settings update coming in Slice 5" - } + return {"message": "Notification settings update coming in Slice 5"} @router.get("/templates") def get_notification_templates( - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Get notification email templates. @@ -176,19 +166,16 @@ def get_notification_templates( - Include: order confirmation, shipping notification, etc. - Return template details """ - return { - "templates": [], - "message": "Notification templates coming in Slice 5" - } + return {"templates": [], "message": "Notification templates coming in Slice 5"} @router.put("/templates/{template_id}") def update_notification_template( - template_id: int, - template_data: dict, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + template_id: int, + template_data: dict, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Update notification email template. @@ -199,17 +186,15 @@ def update_notification_template( - Validate template variables - Preview template """ - return { - "message": "Template update coming in Slice 5" - } + return {"message": "Template update coming in Slice 5"} @router.post("/test") def send_test_notification( - notification_data: dict, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + notification_data: dict, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Send test notification. @@ -219,6 +204,4 @@ def send_test_notification( - Use specified template - Send to current user's email """ - return { - "message": "Test notification coming in Slice 5" - } + return {"message": "Test notification coming in Slice 5"} diff --git a/app/api/v1/vendor/orders.py b/app/api/v1/vendor/orders.py index 49b5a360..a32b4253 100644 --- a/app/api/v1/vendor/orders.py +++ b/app/api/v1/vendor/orders.py @@ -6,21 +6,17 @@ Vendor order management endpoints. import logging from typing import Optional -from fastapi import APIRouter, Depends, Query, Request, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Query, Request from sqlalchemy.orm import Session from app.api.deps import get_current_vendor_api from app.core.database import get_db -from middleware.vendor_context import require_vendor_context from app.services.order_service import order_service -from models.schema.order import ( - OrderResponse, - OrderDetailResponse, - OrderListResponse, - OrderUpdate -) +from middleware.vendor_context import require_vendor_context from models.database.user import User from models.database.vendor import Vendor, VendorUser +from models.schema.order import (OrderDetailResponse, OrderListResponse, + OrderResponse, OrderUpdate) router = APIRouter(prefix="/orders") logger = logging.getLogger(__name__) @@ -28,13 +24,13 @@ logger = logging.getLogger(__name__) @router.get("", response_model=OrderListResponse) def get_vendor_orders( - skip: int = Query(0, ge=0), - limit: int = Query(100, ge=1, le=1000), - status: Optional[str] = Query(None, description="Filter by order status"), - customer_id: Optional[int] = Query(None, description="Filter by customer"), - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=1000), + status: Optional[str] = Query(None, description="Filter by order status"), + customer_id: Optional[int] = Query(None, description="Filter by customer"), + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Get all orders for vendor. @@ -51,45 +47,41 @@ def get_vendor_orders( skip=skip, limit=limit, status=status, - customer_id=customer_id + customer_id=customer_id, ) return OrderListResponse( orders=[OrderResponse.model_validate(o) for o in orders], total=total, skip=skip, - limit=limit + limit=limit, ) @router.get("/{order_id}", response_model=OrderDetailResponse) def get_order_details( - order_id: int, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + order_id: int, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Get detailed order information including items and addresses. Requires Authorization header (API endpoint). """ - order = order_service.get_order( - db=db, - vendor_id=vendor.id, - order_id=order_id - ) + order = order_service.get_order(db=db, vendor_id=vendor.id, order_id=order_id) return OrderDetailResponse.model_validate(order) @router.put("/{order_id}/status", response_model=OrderResponse) def update_order_status( - order_id: int, - order_update: OrderUpdate, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + order_id: int, + order_update: OrderUpdate, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Update order status and tracking information. @@ -105,10 +97,7 @@ def update_order_status( Requires Authorization header (API endpoint). """ order = order_service.update_order_status( - db=db, - vendor_id=vendor.id, - order_id=order_id, - order_update=order_update + db=db, vendor_id=vendor.id, order_id=order_id, order_update=order_update ) logger.info( diff --git a/app/api/v1/vendor/payments.py b/app/api/v1/vendor/payments.py index a25c7b69..a29672f6 100644 --- a/app/api/v1/vendor/payments.py +++ b/app/api/v1/vendor/payments.py @@ -1,10 +1,11 @@ -# Payment configuration and processing +# Payment configuration and processing # app/api/v1/vendor/payments.py """ Vendor payment configuration and processing endpoints. """ import logging + from fastapi import APIRouter, Depends from sqlalchemy.orm import Session @@ -20,9 +21,9 @@ logger = logging.getLogger(__name__) @router.get("/config") def get_payment_configuration( - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Get vendor payment configuration. @@ -38,16 +39,16 @@ def get_payment_configuration( "accepted_methods": [], "currency": "EUR", "stripe_connected": False, - "message": "Payment configuration coming in Slice 5" + "message": "Payment configuration coming in Slice 5", } @router.put("/config") def update_payment_configuration( - payment_config: dict, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + payment_config: dict, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Update vendor payment configuration. @@ -58,17 +59,15 @@ def update_payment_configuration( - Update accepted payment methods - Validate configuration before saving """ - return { - "message": "Payment configuration update coming in Slice 5" - } + return {"message": "Payment configuration update coming in Slice 5"} @router.post("/stripe/connect") def connect_stripe_account( - stripe_data: dict, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + stripe_data: dict, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Connect Stripe account for payment processing. @@ -79,16 +78,14 @@ def connect_stripe_account( - Verify Stripe account is active - Enable payment processing """ - return { - "message": "Stripe connection coming in Slice 5" - } + return {"message": "Stripe connection coming in Slice 5"} @router.delete("/stripe/disconnect") def disconnect_stripe_account( - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Disconnect Stripe account. @@ -98,16 +95,14 @@ def disconnect_stripe_account( - Disable payment processing - Warn about pending payments """ - return { - "message": "Stripe disconnection coming in Slice 5" - } + return {"message": "Stripe disconnection coming in Slice 5"} @router.get("/methods") def get_payment_methods( - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Get accepted payment methods for vendor. @@ -116,17 +111,14 @@ def get_payment_methods( - Return list of enabled payment methods - Include: credit card, PayPal, bank transfer, etc. """ - return { - "methods": [], - "message": "Payment methods coming in Slice 5" - } + return {"methods": [], "message": "Payment methods coming in Slice 5"} @router.get("/transactions") def get_payment_transactions( - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Get payment transaction history. @@ -140,15 +132,15 @@ def get_payment_transactions( return { "transactions": [], "total": 0, - "message": "Payment transactions coming in Slice 5" + "message": "Payment transactions coming in Slice 5", } @router.get("/balance") def get_payment_balance( - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Get vendor payment balance and payout information. @@ -164,17 +156,17 @@ def get_payment_balance( "pending_balance": 0.0, "currency": "EUR", "next_payout_date": None, - "message": "Payment balance coming in Slice 5" + "message": "Payment balance coming in Slice 5", } @router.post("/refund/{payment_id}") def refund_payment( - payment_id: int, - refund_data: dict, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + payment_id: int, + refund_data: dict, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Process payment refund. @@ -185,6 +177,4 @@ def refund_payment( - Update order status - Send refund notification to customer """ - return { - "message": "Payment refund coming in Slice 5" - } + return {"message": "Payment refund coming in Slice 5"} diff --git a/app/api/v1/vendor/products.py b/app/api/v1/vendor/products.py index d229ca57..f2263ef3 100644 --- a/app/api/v1/vendor/products.py +++ b/app/api/v1/vendor/products.py @@ -11,17 +11,13 @@ from sqlalchemy.orm import Session from app.api.deps import get_current_vendor_api from app.core.database import get_db -from middleware.vendor_context import require_vendor_context from app.services.product_service import product_service -from models.schema.product import ( - ProductCreate, - ProductUpdate, - ProductResponse, - ProductDetailResponse, - ProductListResponse -) +from middleware.vendor_context import require_vendor_context from models.database.user import User from models.database.vendor import Vendor +from models.schema.product import (ProductCreate, ProductDetailResponse, + ProductListResponse, ProductResponse, + ProductUpdate) router = APIRouter(prefix="/products") logger = logging.getLogger(__name__) @@ -29,13 +25,13 @@ logger = logging.getLogger(__name__) @router.get("", response_model=ProductListResponse) def get_vendor_products( - skip: int = Query(0, ge=0), - limit: int = Query(100, ge=1, le=1000), - is_active: Optional[bool] = Query(None), - is_featured: Optional[bool] = Query(None), - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=1000), + is_active: Optional[bool] = Query(None), + is_featured: Optional[bool] = Query(None), + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Get all products in vendor catalog. @@ -50,29 +46,27 @@ def get_vendor_products( skip=skip, limit=limit, is_active=is_active, - is_featured=is_featured + is_featured=is_featured, ) return ProductListResponse( products=[ProductResponse.model_validate(p) for p in products], total=total, skip=skip, - limit=limit + limit=limit, ) @router.get("/{product_id}", response_model=ProductDetailResponse) def get_product_details( - product_id: int, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + product_id: int, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """Get detailed product information including inventory.""" product = product_service.get_product( - db=db, - vendor_id=vendor.id, - product_id=product_id + db=db, vendor_id=vendor.id, product_id=product_id ) return ProductDetailResponse.model_validate(product) @@ -80,10 +74,10 @@ def get_product_details( @router.post("", response_model=ProductResponse) def add_product_to_catalog( - product_data: ProductCreate, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + product_data: ProductCreate, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Add a product from marketplace to vendor catalog. @@ -91,9 +85,7 @@ def add_product_to_catalog( This publishes a MarketplaceProduct to the vendor's public catalog. """ product = product_service.create_product( - db=db, - vendor_id=vendor.id, - product_data=product_data + db=db, vendor_id=vendor.id, product_data=product_data ) logger.info( @@ -106,18 +98,15 @@ def add_product_to_catalog( @router.put("/{product_id}", response_model=ProductResponse) def update_product( - product_id: int, - product_data: ProductUpdate, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + product_id: int, + product_data: ProductUpdate, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """Update product in vendor catalog.""" product = product_service.update_product( - db=db, - vendor_id=vendor.id, - product_id=product_id, - product_update=product_data + db=db, vendor_id=vendor.id, product_id=product_id, product_update=product_data ) logger.info( @@ -130,17 +119,13 @@ def update_product( @router.delete("/{product_id}") def remove_product_from_catalog( - product_id: int, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + product_id: int, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """Remove product from vendor catalog.""" - product_service.delete_product( - db=db, - vendor_id=vendor.id, - product_id=product_id - ) + product_service.delete_product(db=db, vendor_id=vendor.id, product_id=product_id) logger.info( f"Product {product_id} removed from catalog by user {current_user.username} " @@ -152,10 +137,10 @@ def remove_product_from_catalog( @router.post("/from-import/{marketplace_product_id}", response_model=ProductResponse) def publish_from_marketplace( - marketplace_product_id: int, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + marketplace_product_id: int, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """ Publish a marketplace product to vendor catalog. @@ -163,14 +148,11 @@ def publish_from_marketplace( Shortcut endpoint for publishing directly from marketplace import. """ product_data = ProductCreate( - marketplace_product_id=marketplace_product_id, - is_active=True + marketplace_product_id=marketplace_product_id, is_active=True ) product = product_service.create_product( - db=db, - vendor_id=vendor.id, - product_data=product_data + db=db, vendor_id=vendor.id, product_data=product_data ) logger.info( @@ -183,10 +165,10 @@ def publish_from_marketplace( @router.put("/{product_id}/toggle-active") def toggle_product_active( - product_id: int, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + product_id: int, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """Toggle product active status.""" product = product_service.get_product(db, vendor.id, product_id) @@ -198,18 +180,15 @@ def toggle_product_active( status = "activated" if product.is_active else "deactivated" logger.info(f"Product {product_id} {status} for vendor {vendor.vendor_code}") - return { - "message": f"Product {status}", - "is_active": product.is_active - } + return {"message": f"Product {status}", "is_active": product.is_active} @router.put("/{product_id}/toggle-featured") def toggle_product_featured( - product_id: int, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + product_id: int, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """Toggle product featured status.""" product = product_service.get_product(db, vendor.id, product_id) @@ -221,7 +200,4 @@ def toggle_product_featured( status = "featured" if product.is_featured else "unfeatured" logger.info(f"Product {product_id} {status} for vendor {vendor.vendor_code}") - return { - "message": f"Product {status}", - "is_featured": product.is_featured - } + return {"message": f"Product {status}", "is_featured": product.is_featured} diff --git a/app/api/v1/vendor/profile.py b/app/api/v1/vendor/profile.py index b04f88bf..2161b493 100644 --- a/app/api/v1/vendor/profile.py +++ b/app/api/v1/vendor/profile.py @@ -4,16 +4,17 @@ Vendor profile management endpoints. """ import logging + from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.orm import Session from app.api.deps import get_current_vendor_api from app.core.database import get_db -from middleware.vendor_context import require_vendor_context from app.services.vendor_service import vendor_service -from models.schema.vendor import VendorUpdate, VendorResponse +from middleware.vendor_context import require_vendor_context from models.database.user import User from models.database.vendor import Vendor +from models.schema.vendor import VendorResponse, VendorUpdate router = APIRouter(prefix="/profile") logger = logging.getLogger(__name__) @@ -21,9 +22,9 @@ logger = logging.getLogger(__name__) @router.get("", response_model=VendorResponse) def get_vendor_profile( - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """Get current vendor profile information.""" return vendor @@ -31,10 +32,10 @@ def get_vendor_profile( @router.put("", response_model=VendorResponse) def update_vendor_profile( - vendor_update: VendorUpdate, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + vendor_update: VendorUpdate, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """Update vendor profile information.""" # Verify user has permission to update vendor diff --git a/app/api/v1/vendor/settings.py b/app/api/v1/vendor/settings.py index 7391bd91..0ab97b5c 100644 --- a/app/api/v1/vendor/settings.py +++ b/app/api/v1/vendor/settings.py @@ -4,13 +4,14 @@ Vendor settings and configuration endpoints. """ import logging + from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.orm import Session from app.api.deps import get_current_vendor_api from app.core.database import get_db -from middleware.vendor_context import require_vendor_context from app.services.vendor_service import vendor_service +from middleware.vendor_context import require_vendor_context from models.database.user import User from models.database.vendor import Vendor @@ -20,9 +21,9 @@ logger = logging.getLogger(__name__) @router.get("") def get_vendor_settings( - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """Get vendor settings and configuration.""" return { @@ -44,10 +45,10 @@ def get_vendor_settings( @router.put("/marketplace") def update_marketplace_settings( - marketplace_config: dict, - vendor: Vendor = Depends(require_vendor_context()), - current_user: User = Depends(get_current_vendor_api), - db: Session = Depends(get_db), + marketplace_config: dict, + vendor: Vendor = Depends(require_vendor_context()), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), ): """Update marketplace integration settings.""" # Verify permissions diff --git a/app/api/v1/vendor/team.py b/app/api/v1/vendor/team.py index ad4108e2..fe84c858 100644 --- a/app/api/v1/vendor/team.py +++ b/app/api/v1/vendor/team.py @@ -12,35 +12,24 @@ Implements complete team management with: import logging from typing import List + from fastapi import APIRouter, Depends, Request from sqlalchemy.orm import Session +from app.api.deps import (get_current_vendor_api, get_user_permissions, + require_vendor_owner, require_vendor_permission) from app.core.database import get_db from app.core.permissions import VendorPermissions -from app.api.deps import ( - get_current_vendor_api, - require_vendor_owner, - require_vendor_permission, - get_user_permissions -) from app.services.vendor_team_service import vendor_team_service from models.database.user import User from models.database.vendor import Vendor -from models.schema.team import ( - TeamMemberInvite, - TeamMemberUpdate, - TeamMemberResponse, - TeamMemberListResponse, - InvitationAccept, - InvitationResponse, - InvitationAcceptResponse, - RoleResponse, - RoleListResponse, - UserPermissionsResponse, - TeamStatistics, - BulkRemoveRequest, - BulkRemoveResponse, -) +from models.schema.team import (BulkRemoveRequest, BulkRemoveResponse, + InvitationAccept, InvitationAcceptResponse, + InvitationResponse, RoleListResponse, + RoleResponse, TeamMemberInvite, + TeamMemberListResponse, TeamMemberResponse, + TeamMemberUpdate, TeamStatistics, + UserPermissionsResponse) router = APIRouter(prefix="/team") logger = logging.getLogger(__name__) @@ -50,14 +39,15 @@ logger = logging.getLogger(__name__) # Team Member Routes # ============================================================================ + @router.get("/members", response_model=TeamMemberListResponse) def list_team_members( - request: Request, - include_inactive: bool = False, - db: Session = Depends(get_db), - current_user: User = Depends(require_vendor_permission( - VendorPermissions.TEAM_VIEW.value - )) + request: Request, + include_inactive: bool = False, + db: Session = Depends(get_db), + current_user: User = Depends( + require_vendor_permission(VendorPermissions.TEAM_VIEW.value) + ), ): """ Get all team members for current vendor. @@ -74,9 +64,7 @@ def list_team_members( vendor = request.state.vendor members = vendor_team_service.get_team_members( - db=db, - vendor=vendor, - include_inactive=include_inactive + db=db, vendor=vendor, include_inactive=include_inactive ) # Calculate statistics @@ -90,19 +78,16 @@ def list_team_members( ) return TeamMemberListResponse( - members=members, - total=total, - active_count=active, - pending_invitations=pending + members=members, total=total, active_count=active, pending_invitations=pending ) @router.post("/invite", response_model=InvitationResponse) def invite_team_member( - invitation: TeamMemberInvite, - request: Request, - db: Session = Depends(get_db), - current_user: User = Depends(require_vendor_owner) # Owner only + invitation: TeamMemberInvite, + request: Request, + db: Session = Depends(get_db), + current_user: User = Depends(require_vendor_owner), # Owner only ): """ Invite a new team member to the vendor. @@ -135,7 +120,7 @@ def invite_team_member( vendor=vendor, inviter=current_user, email=invitation.email, - role_id=invitation.role_id + role_id=invitation.role_id, ) elif invitation.role_name: # Use role name with optional custom permissions @@ -145,7 +130,7 @@ def invite_team_member( inviter=current_user, email=invitation.email, role_name=invitation.role_name, - custom_permissions=invitation.custom_permissions + custom_permissions=invitation.custom_permissions, ) else: # Default to Staff role @@ -154,7 +139,7 @@ def invite_team_member( vendor=vendor, inviter=current_user, email=invitation.email, - role_name="staff" + role_name="staff", ) logger.info( @@ -166,15 +151,12 @@ def invite_team_member( message="Invitation sent successfully", email=result["email"], role=result["role"], - invitation_sent=True + invitation_sent=True, ) @router.post("/accept-invitation", response_model=InvitationAcceptResponse) -def accept_invitation( - acceptance: InvitationAccept, - db: Session = Depends(get_db) -): +def accept_invitation(acceptance: InvitationAccept, db: Session = Depends(get_db)): """ Accept a team invitation and activate account. @@ -196,7 +178,7 @@ def accept_invitation( invitation_token=acceptance.invitation_token, password=acceptance.password, first_name=acceptance.first_name, - last_name=acceptance.last_name + last_name=acceptance.last_name, ) logger.info( @@ -210,26 +192,26 @@ def accept_invitation( "id": result["vendor"].id, "vendor_code": result["vendor"].vendor_code, "name": result["vendor"].name, - "subdomain": result["vendor"].subdomain + "subdomain": result["vendor"].subdomain, }, user={ "id": result["user"].id, "email": result["user"].email, "username": result["user"].username, - "full_name": result["user"].full_name + "full_name": result["user"].full_name, }, - role=result["role"] + role=result["role"], ) @router.get("/members/{user_id}", response_model=TeamMemberResponse) def get_team_member( - user_id: int, - request: Request, - db: Session = Depends(get_db), - current_user: User = Depends(require_vendor_permission( - VendorPermissions.TEAM_VIEW.value - )) + user_id: int, + request: Request, + db: Session = Depends(get_db), + current_user: User = Depends( + require_vendor_permission(VendorPermissions.TEAM_VIEW.value) + ), ): """ Get details of a specific team member. @@ -239,14 +221,13 @@ def get_team_member( vendor = request.state.vendor members = vendor_team_service.get_team_members( - db=db, - vendor=vendor, - include_inactive=True + db=db, vendor=vendor, include_inactive=True ) member = next((m for m in members if m["id"] == user_id), None) if not member: from app.exceptions import UserNotFoundException + raise UserNotFoundException(str(user_id)) return TeamMemberResponse(**member) @@ -254,11 +235,11 @@ def get_team_member( @router.put("/members/{user_id}", response_model=TeamMemberResponse) def update_team_member( - user_id: int, - update_data: TeamMemberUpdate, - request: Request, - db: Session = Depends(get_db), - current_user: User = Depends(require_vendor_owner) # Owner only + user_id: int, + update_data: TeamMemberUpdate, + request: Request, + db: Session = Depends(get_db), + current_user: User = Depends(require_vendor_owner), # Owner only ): """ Update a team member's role or status. @@ -280,7 +261,7 @@ def update_team_member( vendor=vendor, user_id=user_id, new_role_id=update_data.role_id, - is_active=update_data.is_active + is_active=update_data.is_active, ) logger.info( @@ -297,10 +278,10 @@ def update_team_member( @router.delete("/members/{user_id}") def remove_team_member( - user_id: int, - request: Request, - db: Session = Depends(get_db), - current_user: User = Depends(require_vendor_owner) # Owner only + user_id: int, + request: Request, + db: Session = Depends(get_db), + current_user: User = Depends(require_vendor_owner), # Owner only ): """ Remove a team member from the vendor. @@ -316,29 +297,22 @@ def remove_team_member( """ vendor = request.state.vendor - vendor_team_service.remove_team_member( - db=db, - vendor=vendor, - user_id=user_id - ) + vendor_team_service.remove_team_member(db=db, vendor=vendor, user_id=user_id) logger.info( f"Team member removed: {user_id} from {vendor.vendor_code} " f"by {current_user.username}" ) - return { - "message": "Team member removed successfully", - "user_id": user_id - } + return {"message": "Team member removed successfully", "user_id": user_id} @router.post("/members/bulk-remove", response_model=BulkRemoveResponse) def bulk_remove_team_members( - bulk_remove: BulkRemoveRequest, - request: Request, - db: Session = Depends(get_db), - current_user: User = Depends(require_vendor_owner) + bulk_remove: BulkRemoveRequest, + request: Request, + db: Session = Depends(get_db), + current_user: User = Depends(require_vendor_owner), ): """ Remove multiple team members at once. @@ -354,17 +328,12 @@ def bulk_remove_team_members( for user_id in bulk_remove.user_ids: try: vendor_team_service.remove_team_member( - db=db, - vendor=vendor, - user_id=user_id + db=db, vendor=vendor, user_id=user_id ) success_count += 1 except Exception as e: failed_count += 1 - errors.append({ - "user_id": user_id, - "error": str(e) - }) + errors.append({"user_id": user_id, "error": str(e)}) logger.info( f"Bulk remove completed: {success_count} removed, {failed_count} failed " @@ -372,9 +341,7 @@ def bulk_remove_team_members( ) return BulkRemoveResponse( - success_count=success_count, - failed_count=failed_count, - errors=errors + success_count=success_count, failed_count=failed_count, errors=errors ) @@ -382,13 +349,14 @@ def bulk_remove_team_members( # Role Management Routes # ============================================================================ + @router.get("/roles", response_model=RoleListResponse) def list_roles( - request: Request, - db: Session = Depends(get_db), - current_user: User = Depends(require_vendor_permission( - VendorPermissions.TEAM_VIEW.value - )) + request: Request, + db: Session = Depends(get_db), + current_user: User = Depends( + require_vendor_permission(VendorPermissions.TEAM_VIEW.value) + ), ): """ Get all available roles for the vendor. @@ -403,21 +371,19 @@ def list_roles( roles = vendor_team_service.get_vendor_roles(db=db, vendor_id=vendor.id) - return RoleListResponse( - roles=roles, - total=len(roles) - ) + return RoleListResponse(roles=roles, total=len(roles)) # ============================================================================ # Permission Routes # ============================================================================ + @router.get("/me/permissions", response_model=UserPermissionsResponse) def get_my_permissions( - request: Request, - permissions: List[str] = Depends(get_user_permissions), - current_user: User = Depends(get_current_vendor_api) + request: Request, + permissions: List[str] = Depends(get_user_permissions), + current_user: User = Depends(get_current_vendor_api), ): """ Get current user's permissions in this vendor. @@ -443,7 +409,7 @@ def get_my_permissions( permissions=permissions, permission_count=len(permissions), is_owner=is_owner, - role_name=role_name + role_name=role_name, ) @@ -451,13 +417,14 @@ def get_my_permissions( # Statistics Routes # ============================================================================ + @router.get("/statistics", response_model=TeamStatistics) def get_team_statistics( - request: Request, - db: Session = Depends(get_db), - current_user: User = Depends(require_vendor_permission( - VendorPermissions.TEAM_VIEW.value - )) + request: Request, + db: Session = Depends(get_db), + current_user: User = Depends( + require_vendor_permission(VendorPermissions.TEAM_VIEW.value) + ), ): """ Get team statistics for the vendor. @@ -474,9 +441,7 @@ def get_team_statistics( vendor = request.state.vendor members = vendor_team_service.get_team_members( - db=db, - vendor=vendor, - include_inactive=True + db=db, vendor=vendor, include_inactive=True ) # Calculate statistics @@ -500,5 +465,5 @@ def get_team_statistics( pending_invitations=pending, owners=owners, team_members=team_members, - roles_breakdown=roles_breakdown + roles_breakdown=roles_breakdown, ) diff --git a/app/core/config.py b/app/core/config.py index 59a5793b..ed0b2641 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -14,6 +14,7 @@ This module focuses purely on configuration storage and validation. """ from typing import List, Optional + from pydantic_settings import BaseSettings @@ -137,13 +138,9 @@ settings = Settings() # ENVIRONMENT UTILITIES - Module-level functions # ============================================================================= # Import environment detection utilities -from app.core.environment import ( - get_environment, - is_development, - is_production, - is_staging, - should_use_secure_cookies -) +from app.core.environment import (get_environment, is_development, + is_production, is_staging, + should_use_secure_cookies) def get_current_environment() -> str: @@ -190,6 +187,7 @@ def is_staging_environment() -> bool: # VALIDATION FUNCTIONS # ============================================================================= + def validate_production_settings() -> List[str]: """ Validate settings for production environment. @@ -243,22 +241,19 @@ def print_environment_info(): # ============================================================================= __all__ = [ # Settings singleton - 'settings', - + "settings", # Environment detection (re-exported from app.core.environment) - 'get_environment', - 'is_development', - 'is_production', - 'is_staging', - 'should_use_secure_cookies', - + "get_environment", + "is_development", + "is_production", + "is_staging", + "should_use_secure_cookies", # Convenience functions - 'get_current_environment', - 'is_production_environment', - 'is_development_environment', - 'is_staging_environment', - + "get_current_environment", + "is_production_environment", + "is_development_environment", + "is_staging_environment", # Validation - 'validate_production_settings', - 'print_environment_info', + "validate_production_settings", + "print_environment_info", ] diff --git a/app/core/environment.py b/app/core/environment.py index e34e6fc4..0a430311 100644 --- a/app/core/environment.py +++ b/app/core/environment.py @@ -15,7 +15,7 @@ EnvironmentType = Literal["development", "staging", "production"] def get_environment() -> EnvironmentType: """ Detect current environment automatically. - + Detection logic: 1. Check ENV environment variable if set 2. Check ENVIRONMENT environment variable if set @@ -23,7 +23,7 @@ def get_environment() -> EnvironmentType: - localhost, 127.0.0.1 → development - Contains 'staging' → staging - Otherwise → production (safe default) - + Returns: str: 'development', 'staging', or 'production' """ @@ -35,7 +35,7 @@ def get_environment() -> EnvironmentType: return "staging" elif env in ["production", "prod"]: return "production" - + # Priority 2: ENVIRONMENT variable env = os.getenv("ENVIRONMENT", "").lower() if env in ["development", "dev", "local"]: @@ -44,22 +44,25 @@ def get_environment() -> EnvironmentType: return "staging" elif env in ["production", "prod"]: return "production" - + # Priority 3: Auto-detect from common indicators - + # Check if running in debug mode (common in development) if os.getenv("DEBUG", "").lower() in ["true", "1", "yes"]: return "development" - + # Check common development indicators hostname = os.getenv("HOSTNAME", "").lower() - if any(dev_indicator in hostname for dev_indicator in ["local", "dev", "laptop", "desktop"]): + if any( + dev_indicator in hostname + for dev_indicator in ["local", "dev", "laptop", "desktop"] + ): return "development" - + # Check for staging indicators if "staging" in hostname or "stage" in hostname: return "staging" - + # Default to development for safety (HTTPS not required in dev) # Change this to "production" if you prefer secure-by-default return "development" @@ -83,7 +86,7 @@ def is_production() -> bool: def should_use_secure_cookies() -> bool: """ Determine if cookies should have secure flag (HTTPS only). - + Returns: bool: True if production or staging, False if development """ @@ -97,7 +100,7 @@ _cached_environment: EnvironmentType | None = None def get_cached_environment() -> EnvironmentType: """ Get environment with caching. - + Environment is detected once and cached for performance. Useful if you call this frequently. """ diff --git a/app/core/lifespan.py b/app/core/lifespan.py index 6c0731c6..f26a8363 100644 --- a/app/core/lifespan.py +++ b/app/core/lifespan.py @@ -15,11 +15,13 @@ from fastapi import FastAPI from sqlalchemy import text from middleware.auth import AuthManager -# Remove this import if not needed: from models.database.base import Base from .database import SessionLocal, engine from .logging import setup_logging +# Remove this import if not needed: from models.database.base import Base + + logger = logging.getLogger(__name__) auth_manager = AuthManager() @@ -46,7 +48,9 @@ def check_database_ready(): try: with engine.connect() as conn: # Try to query a table that should exist - result = conn.execute(text("SELECT name FROM sqlite_master WHERE type='table' LIMIT 1")) + result = conn.execute( + text("SELECT name FROM sqlite_master WHERE type='table' LIMIT 1") + ) tables = result.fetchall() return len(tables) > 0 except Exception: @@ -93,6 +97,7 @@ def verify_startup_requirements(): logger.info("[OK] Startup verification passed") return True + # You can call this in your main.py if desired: # if not verify_startup_requirements(): # raise RuntimeError("Application startup requirements not met") diff --git a/app/core/permissions.py b/app/core/permissions.py index 5651abf9..2a04968c 100644 --- a/app/core/permissions.py +++ b/app/core/permissions.py @@ -17,6 +17,7 @@ class VendorPermissions(str, Enum): Naming convention: RESOURCE_ACTION """ + # Dashboard DASHBOARD_VIEW = "dashboard.view" @@ -166,17 +167,23 @@ class PermissionChecker: return required_permission in permissions @staticmethod - def has_any_permission(permissions: List[str], required_permissions: List[str]) -> bool: + def has_any_permission( + permissions: List[str], required_permissions: List[str] + ) -> bool: """Check if a permission list contains ANY of the required permissions.""" return any(perm in permissions for perm in required_permissions) @staticmethod - def has_all_permissions(permissions: List[str], required_permissions: List[str]) -> bool: + def has_all_permissions( + permissions: List[str], required_permissions: List[str] + ) -> bool: """Check if a permission list contains ALL of the required permissions.""" return all(perm in permissions for perm in required_permissions) @staticmethod - def get_missing_permissions(permissions: List[str], required_permissions: List[str]) -> List[str]: + def get_missing_permissions( + permissions: List[str], required_permissions: List[str] + ) -> List[str]: """Get list of missing permissions.""" return [perm for perm in required_permissions if perm not in permissions] diff --git a/app/core/theme_presets.py b/app/core/theme_presets.py index 7b091ad3..20eea4a4 100644 --- a/app/core/theme_presets.py +++ b/app/core/theme_presets.py @@ -16,19 +16,11 @@ THEME_PRESETS = { "accent": "#ec4899", # Pink "background": "#ffffff", # White "text": "#1f2937", # Gray-800 - "border": "#e5e7eb" # Gray-200 + "border": "#e5e7eb", # Gray-200 }, - "fonts": { - "heading": "Inter, sans-serif", - "body": "Inter, sans-serif" - }, - "layout": { - "style": "grid", - "header": "fixed", - "product_card": "modern" - } + "fonts": {"heading": "Inter, sans-serif", "body": "Inter, sans-serif"}, + "layout": {"style": "grid", "header": "fixed", "product_card": "modern"}, }, - "modern": { "colors": { "primary": "#6366f1", # Indigo - Modern tech look @@ -36,19 +28,11 @@ THEME_PRESETS = { "accent": "#ec4899", # Pink "background": "#ffffff", # White "text": "#1f2937", # Gray-800 - "border": "#e5e7eb" # Gray-200 + "border": "#e5e7eb", # Gray-200 }, - "fonts": { - "heading": "Inter, sans-serif", - "body": "Inter, sans-serif" - }, - "layout": { - "style": "grid", - "header": "fixed", - "product_card": "modern" - } + "fonts": {"heading": "Inter, sans-serif", "body": "Inter, sans-serif"}, + "layout": {"style": "grid", "header": "fixed", "product_card": "modern"}, }, - "classic": { "colors": { "primary": "#1e40af", # Dark blue - Traditional @@ -56,19 +40,11 @@ THEME_PRESETS = { "accent": "#dc2626", # Red "background": "#ffffff", # White "text": "#1f2937", # Gray-800 - "border": "#d1d5db" # Gray-300 + "border": "#d1d5db", # Gray-300 }, - "fonts": { - "heading": "Georgia, serif", - "body": "Arial, sans-serif" - }, - "layout": { - "style": "list", - "header": "static", - "product_card": "classic" - } + "fonts": {"heading": "Georgia, serif", "body": "Arial, sans-serif"}, + "layout": {"style": "list", "header": "static", "product_card": "classic"}, }, - "minimal": { "colors": { "primary": "#000000", # Black - Ultra minimal @@ -76,19 +52,11 @@ THEME_PRESETS = { "accent": "#666666", # Medium gray "background": "#ffffff", # White "text": "#000000", # Black - "border": "#e5e7eb" # Light gray + "border": "#e5e7eb", # Light gray }, - "fonts": { - "heading": "Helvetica, sans-serif", - "body": "Helvetica, sans-serif" - }, - "layout": { - "style": "grid", - "header": "transparent", - "product_card": "minimal" - } + "fonts": {"heading": "Helvetica, sans-serif", "body": "Helvetica, sans-serif"}, + "layout": {"style": "grid", "header": "transparent", "product_card": "minimal"}, }, - "vibrant": { "colors": { "primary": "#f59e0b", # Orange - Bold & energetic @@ -96,19 +64,11 @@ THEME_PRESETS = { "accent": "#8b5cf6", # Purple "background": "#ffffff", # White "text": "#1f2937", # Gray-800 - "border": "#fbbf24" # Yellow + "border": "#fbbf24", # Yellow }, - "fonts": { - "heading": "Poppins, sans-serif", - "body": "Open Sans, sans-serif" - }, - "layout": { - "style": "masonry", - "header": "fixed", - "product_card": "modern" - } + "fonts": {"heading": "Poppins, sans-serif", "body": "Open Sans, sans-serif"}, + "layout": {"style": "masonry", "header": "fixed", "product_card": "modern"}, }, - "elegant": { "colors": { "primary": "#6b7280", # Gray - Sophisticated @@ -116,19 +76,11 @@ THEME_PRESETS = { "accent": "#d97706", # Amber "background": "#ffffff", # White "text": "#1f2937", # Gray-800 - "border": "#e5e7eb" # Gray-200 + "border": "#e5e7eb", # Gray-200 }, - "fonts": { - "heading": "Playfair Display, serif", - "body": "Lato, sans-serif" - }, - "layout": { - "style": "grid", - "header": "fixed", - "product_card": "classic" - } + "fonts": {"heading": "Playfair Display, serif", "body": "Lato, sans-serif"}, + "layout": {"style": "grid", "header": "fixed", "product_card": "classic"}, }, - "nature": { "colors": { "primary": "#059669", # Green - Natural & eco @@ -136,18 +88,11 @@ THEME_PRESETS = { "accent": "#f59e0b", # Amber "background": "#ffffff", # White "text": "#1f2937", # Gray-800 - "border": "#d1fae5" # Light green + "border": "#d1fae5", # Light green }, - "fonts": { - "heading": "Montserrat, sans-serif", - "body": "Open Sans, sans-serif" - }, - "layout": { - "style": "grid", - "header": "fixed", - "product_card": "modern" - } - } + "fonts": {"heading": "Montserrat, sans-serif", "body": "Open Sans, sans-serif"}, + "layout": {"style": "grid", "header": "fixed", "product_card": "modern"}, + }, } @@ -243,7 +188,7 @@ def get_preset_preview(preset_name: str) -> dict: "minimal": "Ultra-clean black and white aesthetic", "vibrant": "Bold and energetic with bright accent colors", "elegant": "Sophisticated gray tones with refined typography", - "nature": "Fresh and eco-friendly green color palette" + "nature": "Fresh and eco-friendly green color palette", } return { @@ -259,10 +204,7 @@ def get_preset_preview(preset_name: str) -> dict: def create_custom_preset( - colors: dict, - fonts: dict, - layout: dict, - name: str = "custom" + colors: dict, fonts: dict, layout: dict, name: str = "custom" ) -> dict: """ Create a custom preset from provided settings. @@ -304,8 +246,4 @@ def create_custom_preset( if "product_card" not in layout: layout["product_card"] = "modern" - return { - "colors": colors, - "fonts": fonts, - "layout": layout - } + return {"colors": colors, "fonts": fonts, "layout": layout} diff --git a/app/exceptions/__init__.py b/app/exceptions/__init__.py index 19f5fd2c..fa7f17f1 100644 --- a/app/exceptions/__init__.py +++ b/app/exceptions/__init__.py @@ -6,179 +6,109 @@ This module provides frontend-friendly exceptions with consistent error codes, messages, and HTTP status mappings. """ -# Base exceptions -from .base import ( - WizamartException, - ValidationException, - AuthenticationException, - AuthorizationException, - ResourceNotFoundException, - ConflictException, - BusinessLogicException, - ExternalServiceException, - RateLimitException, - ServiceUnavailableException, -) - -# Authentication exceptions -from .auth import ( - InvalidCredentialsException, - TokenExpiredException, - InvalidTokenException, - InsufficientPermissionsException, - UserNotActiveException, - AdminRequiredException, - UserAlreadyExistsException -) - # Admin exceptions -from .admin import ( - UserNotFoundException, - UserStatusChangeException, - VendorVerificationException, - AdminOperationException, - CannotModifyAdminException, - CannotModifySelfException, - InvalidAdminActionException, - BulkOperationException, - ConfirmationRequiredException, -) - -# Marketplace import job exceptions -from .marketplace_import_job import ( - MarketplaceImportException, - ImportJobNotFoundException, - ImportJobNotOwnedException, - InvalidImportDataException, - ImportJobCannotBeCancelledException, - ImportJobCannotBeDeletedException, - MarketplaceConnectionException, - MarketplaceDataParsingException, - ImportRateLimitException, - InvalidMarketplaceException, - ImportJobAlreadyProcessingException, -) - -# Marketplace product exceptions -from .marketplace_product import ( - MarketplaceProductNotFoundException, - MarketplaceProductAlreadyExistsException, - InvalidMarketplaceProductDataException, - MarketplaceProductValidationException, - InvalidGTINException, - MarketplaceProductCSVImportException, -) - -# Inventory exceptions -from .inventory import ( - InventoryNotFoundException, - InsufficientInventoryException, - InvalidInventoryOperationException, - InventoryValidationException, - NegativeInventoryException, - InvalidQuantityException, - LocationNotFoundException -) - -# Vendor exceptions -from .vendor import ( - VendorNotFoundException, - VendorAlreadyExistsException, - VendorNotActiveException, - VendorNotVerifiedException, - UnauthorizedVendorAccessException, - InvalidVendorDataException, - MaxVendorsReachedException, - VendorValidationException, -) - -# Vendor domain exceptions -from .vendor_domain import ( - VendorDomainNotFoundException, - VendorDomainAlreadyExistsException, - InvalidDomainFormatException, - ReservedDomainException, - DomainNotVerifiedException, - DomainVerificationFailedException, - DomainAlreadyVerifiedException, - MultiplePrimaryDomainsException, - DNSVerificationException, - MaxDomainsReachedException, - UnauthorizedDomainAccessException, -) - -# Vendor theme exceptions -from .vendor_theme import ( - VendorThemeNotFoundException, - InvalidThemeDataException, - ThemePresetNotFoundException, - ThemeValidationException, - ThemePresetAlreadyAppliedException, - InvalidColorFormatException, - InvalidFontFamilyException, - ThemeOperationException, -) - -# Customer exceptions -from .customer import ( - CustomerNotFoundException, - CustomerAlreadyExistsException, - DuplicateCustomerEmailException, - CustomerNotActiveException, - InvalidCustomerCredentialsException, - CustomerValidationException, - CustomerAuthorizationException, -) - -# Team exceptions -from .team import ( - TeamMemberNotFoundException, - TeamMemberAlreadyExistsException, - TeamInvitationNotFoundException, - TeamInvitationExpiredException, - TeamInvitationAlreadyAcceptedException, - UnauthorizedTeamActionException, - CannotRemoveOwnerException, - CannotModifyOwnRoleException, - RoleNotFoundException, - InvalidRoleException, - InsufficientTeamPermissionsException, - MaxTeamMembersReachedException, - TeamValidationException, - InvalidInvitationDataException, - InvalidInvitationTokenException, -) - -# Product exceptions -from .product import ( - ProductNotFoundException, - ProductAlreadyExistsException, - ProductNotInCatalogException, - ProductNotActiveException, - InvalidProductDataException, - ProductValidationException, - CannotDeleteProductWithInventoryException, - CannotDeleteProductWithOrdersException, -) - -# Order exceptions -from .order import ( - OrderNotFoundException, - OrderAlreadyExistsException, - OrderValidationException, - InvalidOrderStatusException, - OrderCannotBeCancelledException, -) - +from .admin import (AdminOperationException, BulkOperationException, + CannotModifyAdminException, CannotModifySelfException, + ConfirmationRequiredException, InvalidAdminActionException, + UserNotFoundException, UserStatusChangeException, + VendorVerificationException) +# Authentication exceptions +from .auth import (AdminRequiredException, InsufficientPermissionsException, + InvalidCredentialsException, InvalidTokenException, + TokenExpiredException, UserAlreadyExistsException, + UserNotActiveException) +# Base exceptions +from .base import (AuthenticationException, AuthorizationException, + BusinessLogicException, ConflictException, + ExternalServiceException, RateLimitException, + ResourceNotFoundException, ServiceUnavailableException, + ValidationException, WizamartException) # Cart exceptions -from .cart import ( - CartItemNotFoundException, - EmptyCartException, - CartValidationException, - InsufficientInventoryForCartException, - InvalidCartQuantityException, - ProductNotAvailableForCartException, -) +from .cart import (CartItemNotFoundException, CartValidationException, + EmptyCartException, InsufficientInventoryForCartException, + InvalidCartQuantityException, + ProductNotAvailableForCartException) +# Customer exceptions +from .customer import (CustomerAlreadyExistsException, + CustomerAuthorizationException, + CustomerNotActiveException, CustomerNotFoundException, + CustomerValidationException, + DuplicateCustomerEmailException, + InvalidCustomerCredentialsException) +# Inventory exceptions +from .inventory import (InsufficientInventoryException, + InvalidInventoryOperationException, + InvalidQuantityException, InventoryNotFoundException, + InventoryValidationException, + LocationNotFoundException, NegativeInventoryException) +# Marketplace import job exceptions +from .marketplace_import_job import (ImportJobAlreadyProcessingException, + ImportJobCannotBeCancelledException, + ImportJobCannotBeDeletedException, + ImportJobNotFoundException, + ImportJobNotOwnedException, + ImportRateLimitException, + InvalidImportDataException, + InvalidMarketplaceException, + MarketplaceConnectionException, + MarketplaceDataParsingException, + MarketplaceImportException) +# Marketplace product exceptions +from .marketplace_product import (InvalidGTINException, + InvalidMarketplaceProductDataException, + MarketplaceProductAlreadyExistsException, + MarketplaceProductCSVImportException, + MarketplaceProductNotFoundException, + MarketplaceProductValidationException) +# Order exceptions +from .order import (InvalidOrderStatusException, OrderAlreadyExistsException, + OrderCannotBeCancelledException, OrderNotFoundException, + OrderValidationException) +# Product exceptions +from .product import (CannotDeleteProductWithInventoryException, + CannotDeleteProductWithOrdersException, + InvalidProductDataException, + ProductAlreadyExistsException, ProductNotActiveException, + ProductNotFoundException, ProductNotInCatalogException, + ProductValidationException) +# Team exceptions +from .team import (CannotModifyOwnRoleException, CannotRemoveOwnerException, + InsufficientTeamPermissionsException, + InvalidInvitationDataException, + InvalidInvitationTokenException, InvalidRoleException, + MaxTeamMembersReachedException, RoleNotFoundException, + TeamInvitationAlreadyAcceptedException, + TeamInvitationExpiredException, + TeamInvitationNotFoundException, + TeamMemberAlreadyExistsException, + TeamMemberNotFoundException, TeamValidationException, + UnauthorizedTeamActionException) +# Vendor exceptions +from .vendor import (InvalidVendorDataException, MaxVendorsReachedException, + UnauthorizedVendorAccessException, + VendorAlreadyExistsException, VendorNotActiveException, + VendorNotFoundException, VendorNotVerifiedException, + VendorValidationException) +# Vendor domain exceptions +from .vendor_domain import (DNSVerificationException, + DomainAlreadyVerifiedException, + DomainNotVerifiedException, + DomainVerificationFailedException, + InvalidDomainFormatException, + MaxDomainsReachedException, + MultiplePrimaryDomainsException, + ReservedDomainException, + UnauthorizedDomainAccessException, + VendorDomainAlreadyExistsException, + VendorDomainNotFoundException) +# Vendor theme exceptions +from .vendor_theme import (InvalidColorFormatException, + InvalidFontFamilyException, + InvalidThemeDataException, ThemeOperationException, + ThemePresetAlreadyAppliedException, + ThemePresetNotFoundException, + ThemeValidationException, + VendorThemeNotFoundException) __all__ = [ # Base exceptions @@ -192,7 +122,6 @@ __all__ = [ "ExternalServiceException", "RateLimitException", "ServiceUnavailableException", - # Auth exceptions "InvalidCredentialsException", "TokenExpiredException", @@ -201,7 +130,6 @@ __all__ = [ "UserNotActiveException", "AdminRequiredException", "UserAlreadyExistsException", - # Customer exceptions "CustomerNotFoundException", "CustomerAlreadyExistsException", @@ -210,7 +138,6 @@ __all__ = [ "InvalidCustomerCredentialsException", "CustomerValidationException", "CustomerAuthorizationException", - # Team exceptions "TeamMemberNotFoundException", "TeamMemberAlreadyExistsException", @@ -227,7 +154,6 @@ __all__ = [ "TeamValidationException", "InvalidInvitationDataException", "InvalidInvitationTokenException", - # Inventory exceptions "InventoryNotFoundException", "InsufficientInventoryException", @@ -236,7 +162,6 @@ __all__ = [ "NegativeInventoryException", "InvalidQuantityException", "LocationNotFoundException", - # Vendor exceptions "VendorNotFoundException", "VendorAlreadyExistsException", @@ -246,7 +171,6 @@ __all__ = [ "InvalidVendorDataException", "MaxVendorsReachedException", "VendorValidationException", - # Vendor Domain "VendorDomainNotFoundException", "VendorDomainAlreadyExistsException", @@ -259,7 +183,6 @@ __all__ = [ "DNSVerificationException", "MaxDomainsReachedException", "UnauthorizedDomainAccessException", - # Vendor Theme "VendorThemeNotFoundException", "InvalidThemeDataException", @@ -269,7 +192,6 @@ __all__ = [ "InvalidColorFormatException", "InvalidFontFamilyException", "ThemeOperationException", - # Product exceptions "ProductNotFoundException", "ProductAlreadyExistsException", @@ -279,14 +201,12 @@ __all__ = [ "ProductValidationException", "CannotDeleteProductWithInventoryException", "CannotDeleteProductWithOrdersException", - # Order exceptions "OrderNotFoundException", "OrderAlreadyExistsException", "OrderValidationException", "InvalidOrderStatusException", "OrderCannotBeCancelledException", - # Cart exceptions "CartItemNotFoundException", "EmptyCartException", @@ -294,7 +214,6 @@ __all__ = [ "InsufficientInventoryForCartException", "InvalidCartQuantityException", "ProductNotAvailableForCartException", - # MarketplaceProduct exceptions "MarketplaceProductNotFoundException", "MarketplaceProductAlreadyExistsException", @@ -302,7 +221,6 @@ __all__ = [ "MarketplaceProductValidationException", "InvalidGTINException", "MarketplaceProductCSVImportException", - # Marketplace import exceptions "MarketplaceImportException", "ImportJobNotFoundException", @@ -315,7 +233,6 @@ __all__ = [ "ImportRateLimitException", "InvalidMarketplaceException", "ImportJobAlreadyProcessingException", - # Admin exceptions "UserNotFoundException", "UserStatusChangeException", diff --git a/app/exceptions/admin.py b/app/exceptions/admin.py index 8df4d62e..c98abdb0 100644 --- a/app/exceptions/admin.py +++ b/app/exceptions/admin.py @@ -4,12 +4,9 @@ Admin operations specific exceptions. """ from typing import Any, Dict, Optional -from .base import ( - ResourceNotFoundException, - BusinessLogicException, - AuthorizationException, - ValidationException -) + +from .base import (AuthorizationException, BusinessLogicException, + ResourceNotFoundException, ValidationException) class UserNotFoundException(ResourceNotFoundException): @@ -35,11 +32,11 @@ class UserStatusChangeException(BusinessLogicException): """Raised when user status cannot be changed.""" def __init__( - self, - user_id: int, - current_status: str, - attempted_action: str, - reason: Optional[str] = None, + self, + user_id: int, + current_status: str, + attempted_action: str, + reason: Optional[str] = None, ): message = f"Cannot {attempted_action} user {user_id} (current status: {current_status})" if reason: @@ -61,10 +58,10 @@ class ShopVerificationException(BusinessLogicException): """Raised when shop verification fails.""" def __init__( - self, - shop_id: int, - reason: str, - current_verification_status: Optional[bool] = None, + self, + shop_id: int, + reason: str, + current_verification_status: Optional[bool] = None, ): details = { "shop_id": shop_id, @@ -85,11 +82,11 @@ class AdminOperationException(BusinessLogicException): """Raised when admin operation fails.""" def __init__( - self, - operation: str, - reason: str, - target_type: Optional[str] = None, - target_id: Optional[str] = None, + self, + operation: str, + reason: str, + target_type: Optional[str] = None, + target_id: Optional[str] = None, ): message = f"Admin operation '{operation}' failed: {reason}" @@ -142,10 +139,10 @@ class InvalidAdminActionException(ValidationException): """Raised when admin action is invalid.""" def __init__( - self, - action: str, - reason: str, - valid_actions: Optional[list] = None, + self, + action: str, + reason: str, + valid_actions: Optional[list] = None, ): details = { "action": action, @@ -166,11 +163,11 @@ class BulkOperationException(BusinessLogicException): """Raised when bulk admin operation fails.""" def __init__( - self, - operation: str, - total_items: int, - failed_items: int, - errors: Optional[Dict[str, Any]] = None, + self, + operation: str, + total_items: int, + failed_items: int, + errors: Optional[Dict[str, Any]] = None, ): message = f"Bulk {operation} completed with errors: {failed_items}/{total_items} failed" @@ -195,10 +192,10 @@ class ConfirmationRequiredException(BusinessLogicException): """Raised when a destructive operation requires explicit confirmation.""" def __init__( - self, - operation: str, - message: Optional[str] = None, - confirmation_param: str = "confirm" + self, + operation: str, + message: Optional[str] = None, + confirmation_param: str = "confirm", ): if not message: message = f"Operation '{operation}' requires confirmation parameter: {confirmation_param}=true" @@ -217,10 +214,10 @@ class VendorVerificationException(BusinessLogicException): """Raised when vendor verification fails.""" def __init__( - self, - vendor_id: int, - reason: str, - current_verification_status: Optional[bool] = None, + self, + vendor_id: int, + reason: str, + current_verification_status: Optional[bool] = None, ): details = { "vendor_id": vendor_id, diff --git a/app/exceptions/auth.py b/app/exceptions/auth.py index d78bf32b..48c96c15 100644 --- a/app/exceptions/auth.py +++ b/app/exceptions/auth.py @@ -4,7 +4,9 @@ Authentication and authorization specific exceptions. """ from typing import Optional -from .base import AuthenticationException, AuthorizationException, ConflictException + +from .base import (AuthenticationException, AuthorizationException, + ConflictException) class InvalidCredentialsException(AuthenticationException): @@ -41,9 +43,9 @@ class InsufficientPermissionsException(AuthorizationException): """Raised when user lacks required permissions for an action.""" def __init__( - self, - message: str = "Insufficient permissions for this action", - required_permission: Optional[str] = None, + self, + message: str = "Insufficient permissions for this action", + required_permission: Optional[str] = None, ): details = {} if required_permission: @@ -80,9 +82,9 @@ class UserAlreadyExistsException(ConflictException): """Raised when trying to register with existing username/email.""" def __init__( - self, - message: str = "User already exists", - field: Optional[str] = None, + self, + message: str = "User already exists", + field: Optional[str] = None, ): details = {} if field: diff --git a/app/exceptions/backup.py b/app/exceptions/backup.py index f8e0394d..5d0229cb 100644 --- a/app/exceptions/backup.py +++ b/app/exceptions/backup.py @@ -1 +1 @@ -# Backup/recovery exceptions +# Backup/recovery exceptions diff --git a/app/exceptions/base.py b/app/exceptions/base.py index 79fc589d..5cfe05a3 100644 --- a/app/exceptions/base.py +++ b/app/exceptions/base.py @@ -39,8 +39,6 @@ class WizamartException(Exception): return result - - class ValidationException(WizamartException): """Raised when request validation fails.""" @@ -62,8 +60,6 @@ class ValidationException(WizamartException): ) - - class AuthenticationException(WizamartException): """Raised when authentication fails.""" @@ -97,6 +93,7 @@ class AuthorizationException(WizamartException): details=details, ) + class ResourceNotFoundException(WizamartException): """Raised when a requested resource is not found.""" @@ -122,6 +119,7 @@ class ResourceNotFoundException(WizamartException): }, ) + class ConflictException(WizamartException): """Raised when a resource conflict occurs.""" @@ -138,6 +136,7 @@ class ConflictException(WizamartException): details=details, ) + class BusinessLogicException(WizamartException): """Raised when business logic rules are violated.""" @@ -196,6 +195,7 @@ class RateLimitException(WizamartException): details=rate_limit_details, ) + class ServiceUnavailableException(WizamartException): """Raised when service is unavailable.""" @@ -206,6 +206,7 @@ class ServiceUnavailableException(WizamartException): status_code=503, ) + # Note: Domain-specific exceptions like VendorNotFoundException, UserNotFoundException, etc. # are defined in their respective domain modules (vendor.py, admin.py, etc.) # to keep domain-specific logic separate from base exceptions. diff --git a/app/exceptions/cart.py b/app/exceptions/cart.py index 1a56e527..6a980730 100644 --- a/app/exceptions/cart.py +++ b/app/exceptions/cart.py @@ -4,11 +4,9 @@ Shopping cart specific exceptions. """ from typing import Optional -from .base import ( - ResourceNotFoundException, - ValidationException, - BusinessLogicException -) + +from .base import (BusinessLogicException, ResourceNotFoundException, + ValidationException) class CartItemNotFoundException(ResourceNotFoundException): @@ -19,22 +17,16 @@ class CartItemNotFoundException(ResourceNotFoundException): resource_type="CartItem", identifier=str(product_id), message=f"Product {product_id} not found in cart", - error_code="CART_ITEM_NOT_FOUND" + error_code="CART_ITEM_NOT_FOUND", ) - self.details.update({ - "product_id": product_id, - "session_id": session_id - }) + self.details.update({"product_id": product_id, "session_id": session_id}) class EmptyCartException(ValidationException): """Raised when trying to perform operations on an empty cart.""" def __init__(self, session_id: str): - super().__init__( - message="Cart is empty", - details={"session_id": session_id} - ) + super().__init__(message="Cart is empty", details={"session_id": session_id}) self.error_code = "CART_EMPTY" @@ -82,7 +74,9 @@ class InsufficientInventoryForCartException(BusinessLogicException): class InvalidCartQuantityException(ValidationException): """Raised when cart quantity is invalid.""" - def __init__(self, quantity: int, min_quantity: int = 1, max_quantity: Optional[int] = None): + def __init__( + self, quantity: int, min_quantity: int = 1, max_quantity: Optional[int] = None + ): if quantity < min_quantity: message = f"Quantity must be at least {min_quantity}" elif max_quantity and quantity > max_quantity: diff --git a/app/exceptions/customer.py b/app/exceptions/customer.py index ca3d68b0..a9f50ace 100644 --- a/app/exceptions/customer.py +++ b/app/exceptions/customer.py @@ -4,13 +4,10 @@ Customer management specific exceptions. """ from typing import Any, Dict, Optional -from .base import ( - ResourceNotFoundException, - ConflictException, - ValidationException, - AuthenticationException, - BusinessLogicException -) + +from .base import (AuthenticationException, BusinessLogicException, + ConflictException, ResourceNotFoundException, + ValidationException) class CustomerNotFoundException(ResourceNotFoundException): @@ -21,7 +18,7 @@ class CustomerNotFoundException(ResourceNotFoundException): resource_type="Customer", identifier=customer_identifier, message=f"Customer '{customer_identifier}' not found", - error_code="CUSTOMER_NOT_FOUND" + error_code="CUSTOMER_NOT_FOUND", ) @@ -32,7 +29,7 @@ class CustomerAlreadyExistsException(ConflictException): super().__init__( message=f"Customer with email '{email}' already exists", error_code="CUSTOMER_ALREADY_EXISTS", - details={"email": email} + details={"email": email}, ) @@ -43,10 +40,7 @@ class DuplicateCustomerEmailException(ConflictException): super().__init__( message=f"Email '{email}' is already registered for this vendor", error_code="DUPLICATE_CUSTOMER_EMAIL", - details={ - "email": email, - "vendor_code": vendor_code - } + details={"email": email, "vendor_code": vendor_code}, ) @@ -57,7 +51,7 @@ class CustomerNotActiveException(BusinessLogicException): super().__init__( message=f"Customer account '{email}' is not active", error_code="CUSTOMER_NOT_ACTIVE", - details={"email": email} + details={"email": email}, ) @@ -67,7 +61,7 @@ class InvalidCustomerCredentialsException(AuthenticationException): def __init__(self): super().__init__( message="Invalid email or password", - error_code="INVALID_CUSTOMER_CREDENTIALS" + error_code="INVALID_CUSTOMER_CREDENTIALS", ) @@ -78,13 +72,9 @@ class CustomerValidationException(ValidationException): self, message: str = "Customer validation failed", field: Optional[str] = None, - details: Optional[Dict[str, Any]] = None + details: Optional[Dict[str, Any]] = None, ): - super().__init__( - message=message, - field=field, - details=details - ) + super().__init__(message=message, field=field, details=details) self.error_code = "CUSTOMER_VALIDATION_FAILED" @@ -95,8 +85,5 @@ class CustomerAuthorizationException(BusinessLogicException): super().__init__( message=f"Customer '{customer_email}' not authorized for: {operation}", error_code="CUSTOMER_NOT_AUTHORIZED", - details={ - "customer_email": customer_email, - "operation": operation - } + details={"customer_email": customer_email, "operation": operation}, ) diff --git a/app/exceptions/error_renderer.py b/app/exceptions/error_renderer.py index ac574cd4..cc804722 100644 --- a/app/exceptions/error_renderer.py +++ b/app/exceptions/error_renderer.py @@ -7,7 +7,7 @@ Handles fallback logic and context-specific customization. """ import logging from pathlib import Path -from typing import Optional, Dict, Any +from typing import Any, Dict, Optional from fastapi import Request from fastapi.responses import HTMLResponse @@ -114,7 +114,7 @@ class ErrorPageRenderer: "error_code": error_code, "context": context_type.value, "template": template_path, - } + }, ) try: @@ -129,8 +129,7 @@ class ErrorPageRenderer: ) except Exception as e: logger.error( - f"Failed to render error template {template_path}: {e}", - exc_info=True + f"Failed to render error template {template_path}: {e}", exc_info=True ) # Return basic HTML as absolute fallback return ErrorPageRenderer._render_basic_html_fallback( @@ -228,7 +227,9 @@ class ErrorPageRenderer: } @staticmethod - def _get_context_data(request: Request, context_type: RequestContext) -> Dict[str, Any]: + def _get_context_data( + request: Request, context_type: RequestContext + ) -> Dict[str, Any]: """Get context-specific data for error templates.""" data = {} @@ -261,11 +262,19 @@ class ErrorPageRenderer: # Calculate base_url for shop links vendor_context = getattr(request.state, "vendor_context", None) - access_method = vendor_context.get('detection_method', 'unknown') if vendor_context else 'unknown' + access_method = ( + vendor_context.get("detection_method", "unknown") + if vendor_context + else "unknown" + ) base_url = "/" if access_method == "path" and vendor: # Use the full_prefix from vendor_context to determine which pattern was used - full_prefix = vendor_context.get('full_prefix', '/vendor/') if vendor_context else '/vendor/' + full_prefix = ( + vendor_context.get("full_prefix", "/vendor/") + if vendor_context + else "/vendor/" + ) base_url = f"{full_prefix}{vendor.subdomain}/" data["base_url"] = base_url diff --git a/app/exceptions/handler.py b/app/exceptions/handler.py index 2ea2690e..c1b38edf 100644 --- a/app/exceptions/handler.py +++ b/app/exceptions/handler.py @@ -13,13 +13,14 @@ This module provides classes and functions for: import logging from typing import Union -from fastapi import Request, HTTPException +from fastapi import HTTPException, Request from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, RedirectResponse +from middleware.context import RequestContext, get_request_context + from .base import WizamartException from .error_renderer import ErrorPageRenderer -from middleware.context import RequestContext, get_request_context logger = logging.getLogger(__name__) @@ -38,8 +39,8 @@ def setup_exception_handlers(app): extra={ "path": request.url.path, "accept": request.headers.get("accept", ""), - "method": request.method - } + "method": request.method, + }, ) # Redirect to appropriate login page based on context @@ -56,15 +57,12 @@ def setup_exception_handlers(app): "url": str(request.url), "method": request.method, "exception_type": type(exc).__name__, - } + }, ) # Check if this is an API request if _is_api_request(request): - return JSONResponse( - status_code=exc.status_code, - content=exc.to_dict() - ) + return JSONResponse(status_code=exc.status_code, content=exc.to_dict()) # Check if this is an HTML page request if _is_html_page_request(request): @@ -78,10 +76,7 @@ def setup_exception_handlers(app): ) # Default to JSON for unknown request types - return JSONResponse( - status_code=exc.status_code, - content=exc.to_dict() - ) + return JSONResponse(status_code=exc.status_code, content=exc.to_dict()) @app.exception_handler(HTTPException) async def http_exception_handler(request: Request, exc: HTTPException): @@ -96,7 +91,7 @@ def setup_exception_handlers(app): "url": str(request.url), "method": request.method, "exception_type": "HTTPException", - } + }, ) # Check if this is an API request @@ -107,7 +102,7 @@ def setup_exception_handlers(app): "error_code": f"HTTP_{exc.status_code}", "message": exc.detail, "status_code": exc.status_code, - } + }, ) # Check if this is an HTML page request @@ -128,11 +123,13 @@ def setup_exception_handlers(app): "error_code": f"HTTP_{exc.status_code}", "message": exc.detail, "status_code": exc.status_code, - } + }, ) @app.exception_handler(RequestValidationError) - async def validation_exception_handler(request: Request, exc: RequestValidationError): + async def validation_exception_handler( + request: Request, exc: RequestValidationError + ): """Handle Pydantic validation errors with consistent format.""" # Sanitize errors to remove sensitive data from logs @@ -140,8 +137,8 @@ def setup_exception_handlers(app): for error in exc.errors(): sanitized_error = error.copy() # Remove 'input' field which may contain passwords - if 'input' in sanitized_error: - sanitized_error['input'] = '' + if "input" in sanitized_error: + sanitized_error["input"] = "" sanitized_errors.append(sanitized_error) logger.error( @@ -151,7 +148,7 @@ def setup_exception_handlers(app): "url": str(request.url), "method": request.method, "exception_type": "RequestValidationError", - } + }, ) # Clean up validation errors to ensure JSON serializability @@ -159,15 +156,17 @@ def setup_exception_handlers(app): for error in exc.errors(): clean_error = {} for key, value in error.items(): - if key == 'input' and isinstance(value, bytes): + if key == "input" and isinstance(value, bytes): # Convert bytes to string representation for JSON serialization clean_error[key] = f"" - elif key == 'ctx' and isinstance(value, dict): + elif key == "ctx" and isinstance(value, dict): # Handle the 'ctx' field that contains ValueError objects clean_ctx = {} for ctx_key, ctx_value in value.items(): if isinstance(ctx_value, Exception): - clean_ctx[ctx_key] = str(ctx_value) # Convert exception to string + clean_ctx[ctx_key] = str( + ctx_value + ) # Convert exception to string else: clean_ctx[ctx_key] = ctx_value clean_error[key] = clean_ctx @@ -186,10 +185,8 @@ def setup_exception_handlers(app): "error_code": "VALIDATION_ERROR", "message": "Request validation failed", "status_code": 422, - "details": { - "validation_errors": clean_errors - } - } + "details": {"validation_errors": clean_errors}, + }, ) # Check if this is an HTML page request @@ -210,10 +207,8 @@ def setup_exception_handlers(app): "error_code": "VALIDATION_ERROR", "message": "Request validation failed", "status_code": 422, - "details": { - "validation_errors": clean_errors - } - } + "details": {"validation_errors": clean_errors}, + }, ) @app.exception_handler(Exception) @@ -227,7 +222,7 @@ def setup_exception_handlers(app): "url": str(request.url), "method": request.method, "exception_type": type(exc).__name__, - } + }, ) # Check if this is an API request @@ -238,7 +233,7 @@ def setup_exception_handlers(app): "error_code": "INTERNAL_SERVER_ERROR", "message": "Internal server error", "status_code": 500, - } + }, ) # Check if this is an HTML page request @@ -259,7 +254,7 @@ def setup_exception_handlers(app): "error_code": "INTERNAL_SERVER_ERROR", "message": "Internal server error", "status_code": 500, - } + }, ) @app.exception_handler(404) @@ -275,11 +270,8 @@ def setup_exception_handlers(app): "error_code": "ENDPOINT_NOT_FOUND", "message": f"Endpoint not found: {request.url.path}", "status_code": 404, - "details": { - "path": request.url.path, - "method": request.method - } - } + "details": {"path": request.url.path, "method": request.method}, + }, ) # Check if this is an HTML page request @@ -300,11 +292,8 @@ def setup_exception_handlers(app): "error_code": "ENDPOINT_NOT_FOUND", "message": f"Endpoint not found: {request.url.path}", "status_code": 404, - "details": { - "path": request.url.path, - "method": request.method - } - } + "details": {"path": request.url.path, "method": request.method}, + }, ) @@ -332,8 +321,8 @@ def _is_html_page_request(request: Request) -> bool: extra={ "path": request.url.path, "method": request.method, - "accept": request.headers.get("accept", "") - } + "accept": request.headers.get("accept", ""), + }, ) # Don't redirect API calls @@ -354,7 +343,9 @@ def _is_html_page_request(request: Request) -> bool: # MUST explicitly accept HTML (strict check) accept_header = request.headers.get("accept", "") if "text/html" not in accept_header: - logger.debug(f"Not HTML page: Accept header doesn't include text/html: {accept_header}") + logger.debug( + f"Not HTML page: Accept header doesn't include text/html: {accept_header}" + ) return False logger.debug("IS HTML page request") @@ -379,13 +370,21 @@ def _redirect_to_login(request: Request) -> RedirectResponse: elif context_type == RequestContext.SHOP: # For shop context, redirect to shop login (customer login) # Calculate base_url for proper routing (supports domain, subdomain, and path-based access) - vendor = getattr(request.state, 'vendor', None) - vendor_context = getattr(request.state, 'vendor_context', None) - access_method = vendor_context.get('detection_method', 'unknown') if vendor_context else 'unknown' + vendor = getattr(request.state, "vendor", None) + vendor_context = getattr(request.state, "vendor_context", None) + access_method = ( + vendor_context.get("detection_method", "unknown") + if vendor_context + else "unknown" + ) base_url = "/" if access_method == "path" and vendor: - full_prefix = vendor_context.get('full_prefix', '/vendor/') if vendor_context else '/vendor/' + full_prefix = ( + vendor_context.get("full_prefix", "/vendor/") + if vendor_context + else "/vendor/" + ) base_url = f"{full_prefix}{vendor.subdomain}/" login_url = f"{base_url}shop/account/login" @@ -401,22 +400,28 @@ def _redirect_to_login(request: Request) -> RedirectResponse: def raise_not_found(resource_type: str, identifier: str) -> None: """Convenience function to raise ResourceNotFoundException.""" from .base import ResourceNotFoundException + raise ResourceNotFoundException(resource_type, identifier) -def raise_validation_error(message: str, field: str = None, details: dict = None) -> None: +def raise_validation_error( + message: str, field: str = None, details: dict = None +) -> None: """Convenience function to raise ValidationException.""" from .base import ValidationException + raise ValidationException(message, field, details) def raise_auth_error(message: str = "Authentication failed") -> None: """Convenience function to raise AuthenticationException.""" from .base import AuthenticationException + raise AuthenticationException(message) def raise_permission_error(message: str = "Access denied") -> None: """Convenience function to raise AuthorizationException.""" from .base import AuthorizationException + raise AuthorizationException(message) diff --git a/app/exceptions/inventory.py b/app/exceptions/inventory.py index 86a26427..c2c73fd7 100644 --- a/app/exceptions/inventory.py +++ b/app/exceptions/inventory.py @@ -4,7 +4,9 @@ Inventory management specific exceptions. """ from typing import Any, Dict, Optional -from .base import ResourceNotFoundException, ValidationException, BusinessLogicException + +from .base import (BusinessLogicException, ResourceNotFoundException, + ValidationException) class InventoryNotFoundException(ResourceNotFoundException): @@ -14,7 +16,9 @@ class InventoryNotFoundException(ResourceNotFoundException): if identifier_type.lower() == "gtin": message = f"No inventory found for GTIN '{identifier}'" else: - message = f"Inventory record with {identifier_type} '{identifier}' not found" + message = ( + f"Inventory record with {identifier_type} '{identifier}' not found" + ) super().__init__( resource_type="Inventory", @@ -28,11 +32,11 @@ class InsufficientInventoryException(BusinessLogicException): """Raised when trying to remove more inventory than available.""" def __init__( - self, - gtin: str, - location: str, - requested: int, - available: int, + self, + gtin: str, + location: str, + requested: int, + available: int, ): message = f"Insufficient inventory for GTIN '{gtin}' at '{location}'. Requested: {requested}, Available: {available}" @@ -52,10 +56,10 @@ class InvalidInventoryOperationException(ValidationException): """Raised when inventory operation is invalid.""" def __init__( - self, - message: str, - operation: Optional[str] = None, - details: Optional[Dict[str, Any]] = None, + self, + message: str, + operation: Optional[str] = None, + details: Optional[Dict[str, Any]] = None, ): if not details: details = {} @@ -74,10 +78,10 @@ class InventoryValidationException(ValidationException): """Raised when inventory data validation fails.""" def __init__( - self, - message: str = "Inventory validation failed", - field: Optional[str] = None, - validation_errors: Optional[Dict[str, str]] = None, + self, + message: str = "Inventory validation failed", + field: Optional[str] = None, + validation_errors: Optional[Dict[str, str]] = None, ): details = {} if validation_errors: diff --git a/app/exceptions/marketplace.py b/app/exceptions/marketplace.py index 1a92b6ab..28f96110 100644 --- a/app/exceptions/marketplace.py +++ b/app/exceptions/marketplace.py @@ -1 +1 @@ -# Import/marketplace exceptions +# Import/marketplace exceptions diff --git a/app/exceptions/marketplace_import_job.py b/app/exceptions/marketplace_import_job.py index c383579e..c442e1f9 100644 --- a/app/exceptions/marketplace_import_job.py +++ b/app/exceptions/marketplace_import_job.py @@ -4,24 +4,21 @@ Marketplace import specific exceptions. """ from typing import Any, Dict, Optional -from .base import ( - ResourceNotFoundException, - ValidationException, - BusinessLogicException, - AuthorizationException, - ExternalServiceException -) + +from .base import (AuthorizationException, BusinessLogicException, + ExternalServiceException, ResourceNotFoundException, + ValidationException) class MarketplaceImportException(BusinessLogicException): """Base exception for marketplace import operations.""" def __init__( - self, - message: str, - error_code: str = "MARKETPLACE_IMPORT_ERROR", - marketplace: Optional[str] = None, - details: Optional[Dict[str, Any]] = None, + self, + message: str, + error_code: str = "MARKETPLACE_IMPORT_ERROR", + marketplace: Optional[str] = None, + details: Optional[Dict[str, Any]] = None, ): if not details: details = {} @@ -67,11 +64,11 @@ class InvalidImportDataException(ValidationException): """Raised when import data is invalid.""" def __init__( - self, - message: str = "Invalid import data", - field: Optional[str] = None, - row_number: Optional[int] = None, - details: Optional[Dict[str, Any]] = None, + self, + message: str = "Invalid import data", + field: Optional[str] = None, + row_number: Optional[int] = None, + details: Optional[Dict[str, Any]] = None, ): if not details: details = {} @@ -118,7 +115,9 @@ class ImportJobCannotBeDeletedException(BusinessLogicException): class MarketplaceConnectionException(ExternalServiceException): """Raised when marketplace connection fails.""" - def __init__(self, marketplace: str, message: str = "Failed to connect to marketplace"): + def __init__( + self, marketplace: str, message: str = "Failed to connect to marketplace" + ): super().__init__( service=marketplace, message=f"{message}: {marketplace}", @@ -130,10 +129,10 @@ class MarketplaceDataParsingException(ValidationException): """Raised when marketplace data cannot be parsed.""" def __init__( - self, - marketplace: str, - message: str = "Failed to parse marketplace data", - details: Optional[Dict[str, Any]] = None, + self, + marketplace: str, + message: str = "Failed to parse marketplace data", + details: Optional[Dict[str, Any]] = None, ): if not details: details = {} @@ -150,10 +149,10 @@ class ImportRateLimitException(BusinessLogicException): """Raised when import rate limit is exceeded.""" def __init__( - self, - max_imports: int, - time_window: str, - retry_after: Optional[int] = None, + self, + max_imports: int, + time_window: str, + retry_after: Optional[int] = None, ): details = { "max_imports": max_imports, diff --git a/app/exceptions/marketplace_product.py b/app/exceptions/marketplace_product.py index a153a717..8fbc3e71 100644 --- a/app/exceptions/marketplace_product.py +++ b/app/exceptions/marketplace_product.py @@ -4,7 +4,9 @@ MarketplaceProduct management specific exceptions. """ from typing import Any, Dict, Optional -from .base import ResourceNotFoundException, ConflictException, ValidationException, BusinessLogicException + +from .base import (BusinessLogicException, ConflictException, + ResourceNotFoundException, ValidationException) class MarketplaceProductNotFoundException(ResourceNotFoundException): @@ -34,10 +36,10 @@ class InvalidMarketplaceProductDataException(ValidationException): """Raised when product data is invalid.""" def __init__( - self, - message: str = "Invalid product data", - field: Optional[str] = None, - details: Optional[Dict[str, Any]] = None, + self, + message: str = "Invalid product data", + field: Optional[str] = None, + details: Optional[Dict[str, Any]] = None, ): super().__init__( message=message, @@ -51,10 +53,10 @@ class MarketplaceProductValidationException(ValidationException): """Raised when product validation fails.""" def __init__( - self, - message: str, - field: Optional[str] = None, - validation_errors: Optional[Dict[str, str]] = None, + self, + message: str, + field: Optional[str] = None, + validation_errors: Optional[Dict[str, str]] = None, ): details = {} if validation_errors: @@ -84,10 +86,10 @@ class MarketplaceProductCSVImportException(BusinessLogicException): """Raised when product CSV import fails.""" def __init__( - self, - message: str = "MarketplaceProduct CSV import failed", - row_number: Optional[int] = None, - errors: Optional[Dict[str, Any]] = None, + self, + message: str = "MarketplaceProduct CSV import failed", + row_number: Optional[int] = None, + errors: Optional[Dict[str, Any]] = None, ): details = {} if row_number: diff --git a/app/exceptions/media.py b/app/exceptions/media.py index a56df71b..d7663627 100644 --- a/app/exceptions/media.py +++ b/app/exceptions/media.py @@ -1 +1 @@ -# Media/file management exceptions +# Media/file management exceptions diff --git a/app/exceptions/monitoring.py b/app/exceptions/monitoring.py index 70f33c86..d16c3949 100644 --- a/app/exceptions/monitoring.py +++ b/app/exceptions/monitoring.py @@ -1 +1 @@ -# Monitoring exceptions +# Monitoring exceptions diff --git a/app/exceptions/notification.py b/app/exceptions/notification.py index c1772546..8dfb9341 100644 --- a/app/exceptions/notification.py +++ b/app/exceptions/notification.py @@ -1 +1 @@ -# Notification exceptions +# Notification exceptions diff --git a/app/exceptions/order.py b/app/exceptions/order.py index 3052f206..d5b74705 100644 --- a/app/exceptions/order.py +++ b/app/exceptions/order.py @@ -4,11 +4,9 @@ Order management specific exceptions. """ from typing import Optional -from .base import ( - ResourceNotFoundException, - ValidationException, - BusinessLogicException -) + +from .base import (BusinessLogicException, ResourceNotFoundException, + ValidationException) class OrderNotFoundException(ResourceNotFoundException): @@ -19,7 +17,7 @@ class OrderNotFoundException(ResourceNotFoundException): resource_type="Order", identifier=order_identifier, message=f"Order '{order_identifier}' not found", - error_code="ORDER_NOT_FOUND" + error_code="ORDER_NOT_FOUND", ) @@ -30,7 +28,7 @@ class OrderAlreadyExistsException(ValidationException): super().__init__( message=f"Order with number '{order_number}' already exists", error_code="ORDER_ALREADY_EXISTS", - details={"order_number": order_number} + details={"order_number": order_number}, ) @@ -39,9 +37,7 @@ class OrderValidationException(ValidationException): def __init__(self, message: str, details: Optional[dict] = None): super().__init__( - message=message, - error_code="ORDER_VALIDATION_FAILED", - details=details + message=message, error_code="ORDER_VALIDATION_FAILED", details=details ) @@ -52,10 +48,7 @@ class InvalidOrderStatusException(BusinessLogicException): super().__init__( message=f"Cannot change order status from '{current_status}' to '{new_status}'", error_code="INVALID_ORDER_STATUS_CHANGE", - details={ - "current_status": current_status, - "new_status": new_status - } + details={"current_status": current_status, "new_status": new_status}, ) @@ -66,8 +59,5 @@ class OrderCannotBeCancelledException(BusinessLogicException): super().__init__( message=f"Order '{order_number}' cannot be cancelled: {reason}", error_code="ORDER_CANNOT_BE_CANCELLED", - details={ - "order_number": order_number, - "reason": reason - } + details={"order_number": order_number, "reason": reason}, ) diff --git a/app/exceptions/payment.py b/app/exceptions/payment.py index c3fbe15b..2c7f0f89 100644 --- a/app/exceptions/payment.py +++ b/app/exceptions/payment.py @@ -1 +1 @@ -# Payment processing exceptions +# Payment processing exceptions diff --git a/app/exceptions/product.py b/app/exceptions/product.py index fcd984ba..72a3d8ba 100644 --- a/app/exceptions/product.py +++ b/app/exceptions/product.py @@ -4,12 +4,9 @@ Product (vendor catalog) specific exceptions. """ from typing import Optional -from .base import ( - ResourceNotFoundException, - ConflictException, - ValidationException, - BusinessLogicException -) + +from .base import (BusinessLogicException, ConflictException, + ResourceNotFoundException, ValidationException) class ProductNotFoundException(ResourceNotFoundException): diff --git a/app/exceptions/search.py b/app/exceptions/search.py index 5db85736..ed380099 100644 --- a/app/exceptions/search.py +++ b/app/exceptions/search.py @@ -1 +1 @@ -# Search exceptions +# Search exceptions diff --git a/app/exceptions/team.py b/app/exceptions/team.py index 4578a690..21283db1 100644 --- a/app/exceptions/team.py +++ b/app/exceptions/team.py @@ -4,13 +4,10 @@ Team management specific exceptions. """ from typing import Any, Dict, Optional -from .base import ( - ResourceNotFoundException, - ConflictException, - ValidationException, - AuthorizationException, - BusinessLogicException -) + +from .base import (AuthorizationException, BusinessLogicException, + ConflictException, ResourceNotFoundException, + ValidationException) class TeamMemberNotFoundException(ResourceNotFoundException): @@ -20,7 +17,9 @@ class TeamMemberNotFoundException(ResourceNotFoundException): details = {"user_id": user_id} if vendor_id: details["vendor_id"] = vendor_id - message = f"Team member with user ID '{user_id}' not found in vendor {vendor_id}" + message = ( + f"Team member with user ID '{user_id}' not found in vendor {vendor_id}" + ) else: message = f"Team member with user ID '{user_id}' not found" @@ -84,7 +83,12 @@ class TeamInvitationAlreadyAcceptedException(ConflictException): class UnauthorizedTeamActionException(AuthorizationException): """Raised when user tries to perform team action without permission.""" - def __init__(self, action: str, user_id: Optional[int] = None, required_permission: Optional[str] = None): + def __init__( + self, + action: str, + user_id: Optional[int] = None, + required_permission: Optional[str] = None, + ): details = {"action": action} if user_id: details["user_id"] = user_id @@ -147,10 +151,10 @@ class InvalidRoleException(ValidationException): """Raised when role data is invalid.""" def __init__( - self, - message: str = "Invalid role data", - field: Optional[str] = None, - details: Optional[Dict[str, Any]] = None, + self, + message: str = "Invalid role data", + field: Optional[str] = None, + details: Optional[Dict[str, Any]] = None, ): super().__init__( message=message, @@ -164,10 +168,10 @@ class InsufficientTeamPermissionsException(AuthorizationException): """Raised when user lacks required team permissions for an action.""" def __init__( - self, - required_permission: str, - user_id: Optional[int] = None, - action: Optional[str] = None, + self, + required_permission: str, + user_id: Optional[int] = None, + action: Optional[str] = None, ): details = {"required_permission": required_permission} if user_id: @@ -202,10 +206,10 @@ class TeamValidationException(ValidationException): """Raised when team operation validation fails.""" def __init__( - self, - message: str = "Team operation validation failed", - field: Optional[str] = None, - validation_errors: Optional[Dict[str, str]] = None, + self, + message: str = "Team operation validation failed", + field: Optional[str] = None, + validation_errors: Optional[Dict[str, str]] = None, ): details = {} if validation_errors: @@ -223,10 +227,10 @@ class InvalidInvitationDataException(ValidationException): """Raised when team invitation data is invalid.""" def __init__( - self, - message: str = "Invalid invitation data", - field: Optional[str] = None, - details: Optional[Dict[str, Any]] = None, + self, + message: str = "Invalid invitation data", + field: Optional[str] = None, + details: Optional[Dict[str, Any]] = None, ): super().__init__( message=message, @@ -240,6 +244,7 @@ class InvalidInvitationDataException(ValidationException): # NEW: Add InvalidInvitationTokenException # ============================================================================ + class InvalidInvitationTokenException(ValidationException): """Raised when invitation token is invalid, expired, or already used. @@ -248,9 +253,9 @@ class InvalidInvitationTokenException(ValidationException): """ def __init__( - self, - message: str = "Invalid or expired invitation token", - invitation_token: Optional[str] = None + self, + message: str = "Invalid or expired invitation token", + invitation_token: Optional[str] = None, ): details = {} if invitation_token: diff --git a/app/exceptions/vendor.py b/app/exceptions/vendor.py index 5563488f..44380e67 100644 --- a/app/exceptions/vendor.py +++ b/app/exceptions/vendor.py @@ -4,13 +4,10 @@ Vendor management specific exceptions. """ from typing import Any, Dict, Optional -from .base import ( - ResourceNotFoundException, - ConflictException, - ValidationException, - AuthorizationException, - BusinessLogicException -) + +from .base import (AuthorizationException, BusinessLogicException, + ConflictException, ResourceNotFoundException, + ValidationException) class VendorNotFoundException(ResourceNotFoundException): @@ -82,10 +79,10 @@ class InvalidVendorDataException(ValidationException): """Raised when vendor data is invalid or incomplete.""" def __init__( - self, - message: str = "Invalid vendor data", - field: Optional[str] = None, - details: Optional[Dict[str, Any]] = None, + self, + message: str = "Invalid vendor data", + field: Optional[str] = None, + details: Optional[Dict[str, Any]] = None, ): super().__init__( message=message, @@ -99,10 +96,10 @@ class VendorValidationException(ValidationException): """Raised when vendor validation fails.""" def __init__( - self, - message: str = "Vendor validation failed", - field: Optional[str] = None, - validation_errors: Optional[Dict[str, str]] = None, + self, + message: str = "Vendor validation failed", + field: Optional[str] = None, + validation_errors: Optional[Dict[str, str]] = None, ): details = {} if validation_errors: @@ -120,9 +117,9 @@ class IncompleteVendorDataException(ValidationException): """Raised when vendor data is missing required fields.""" def __init__( - self, - vendor_code: str, - missing_fields: list, + self, + vendor_code: str, + missing_fields: list, ): super().__init__( message=f"Vendor '{vendor_code}' is missing required fields: {', '.join(missing_fields)}", diff --git a/app/exceptions/vendor_domain.py b/app/exceptions/vendor_domain.py index cb04d883..8a8e704b 100644 --- a/app/exceptions/vendor_domain.py +++ b/app/exceptions/vendor_domain.py @@ -4,13 +4,10 @@ Vendor domain management specific exceptions. """ from typing import Any, Dict, Optional -from .base import ( - ResourceNotFoundException, - ConflictException, - ValidationException, - BusinessLogicException, - ExternalServiceException -) + +from .base import (BusinessLogicException, ConflictException, + ExternalServiceException, ResourceNotFoundException, + ValidationException) class VendorDomainNotFoundException(ResourceNotFoundException): @@ -64,10 +61,7 @@ class ReservedDomainException(ValidationException): super().__init__( message=f"Domain cannot use reserved subdomain: {reserved_part}", field="domain", - details={ - "domain": domain, - "reserved_part": reserved_part - }, + details={"domain": domain, "reserved_part": reserved_part}, ) self.error_code = "RESERVED_DOMAIN" @@ -79,10 +73,7 @@ class DomainNotVerifiedException(BusinessLogicException): super().__init__( message=f"Domain '{domain}' must be verified before activation", error_code="DOMAIN_NOT_VERIFIED", - details={ - "domain_id": domain_id, - "domain": domain - }, + details={"domain_id": domain_id, "domain": domain}, ) @@ -93,10 +84,7 @@ class DomainVerificationFailedException(BusinessLogicException): super().__init__( message=f"Domain verification failed for '{domain}': {reason}", error_code="DOMAIN_VERIFICATION_FAILED", - details={ - "domain": domain, - "reason": reason - }, + details={"domain": domain, "reason": reason}, ) @@ -107,10 +95,7 @@ class DomainAlreadyVerifiedException(BusinessLogicException): super().__init__( message=f"Domain '{domain}' is already verified", error_code="DOMAIN_ALREADY_VERIFIED", - details={ - "domain_id": domain_id, - "domain": domain - }, + details={"domain_id": domain_id, "domain": domain}, ) @@ -133,10 +118,7 @@ class DNSVerificationException(ExternalServiceException): service_name="DNS", message=f"DNS verification failed for '{domain}': {reason}", error_code="DNS_VERIFICATION_ERROR", - details={ - "domain": domain, - "reason": reason - }, + details={"domain": domain, "reason": reason}, ) @@ -147,10 +129,7 @@ class MaxDomainsReachedException(BusinessLogicException): super().__init__( message=f"Maximum number of domains reached ({max_domains})", error_code="MAX_DOMAINS_REACHED", - details={ - "vendor_id": vendor_id, - "max_domains": max_domains - }, + details={"vendor_id": vendor_id, "max_domains": max_domains}, ) @@ -161,8 +140,5 @@ class UnauthorizedDomainAccessException(BusinessLogicException): super().__init__( message=f"Unauthorized access to domain {domain_id}", error_code="UNAUTHORIZED_DOMAIN_ACCESS", - details={ - "domain_id": domain_id, - "vendor_id": vendor_id - }, + details={"domain_id": domain_id, "vendor_id": vendor_id}, ) diff --git a/app/exceptions/vendor_theme.py b/app/exceptions/vendor_theme.py index b38622c3..08000432 100644 --- a/app/exceptions/vendor_theme.py +++ b/app/exceptions/vendor_theme.py @@ -4,12 +4,9 @@ Vendor theme management specific exceptions. """ from typing import Any, Dict, Optional -from .base import ( - ResourceNotFoundException, - ConflictException, - ValidationException, - BusinessLogicException -) + +from .base import (BusinessLogicException, ConflictException, + ResourceNotFoundException, ValidationException) class VendorThemeNotFoundException(ResourceNotFoundException): @@ -28,10 +25,10 @@ class InvalidThemeDataException(ValidationException): """Raised when theme data is invalid.""" def __init__( - self, - message: str = "Invalid theme data", - field: Optional[str] = None, - details: Optional[Dict[str, Any]] = None, + self, + message: str = "Invalid theme data", + field: Optional[str] = None, + details: Optional[Dict[str, Any]] = None, ): super().__init__( message=message, @@ -62,10 +59,10 @@ class ThemeValidationException(ValidationException): """Raised when theme validation fails.""" def __init__( - self, - message: str = "Theme validation failed", - field: Optional[str] = None, - validation_errors: Optional[Dict[str, str]] = None, + self, + message: str = "Theme validation failed", + field: Optional[str] = None, + validation_errors: Optional[Dict[str, str]] = None, ): details = {} if validation_errors: @@ -86,10 +83,7 @@ class ThemePresetAlreadyAppliedException(BusinessLogicException): super().__init__( message=f"Preset '{preset_name}' is already applied to vendor '{vendor_code}'", error_code="THEME_PRESET_ALREADY_APPLIED", - details={ - "preset_name": preset_name, - "vendor_code": vendor_code - }, + details={"preset_name": preset_name, "vendor_code": vendor_code}, ) @@ -120,18 +114,13 @@ class InvalidFontFamilyException(ValidationException): class ThemeOperationException(BusinessLogicException): """Raised when theme operation fails.""" - def __init__( - self, - operation: str, - vendor_code: str, - reason: str - ): + def __init__(self, operation: str, vendor_code: str, reason: str): super().__init__( message=f"Theme operation '{operation}' failed for vendor '{vendor_code}': {reason}", error_code="THEME_OPERATION_FAILED", details={ "operation": operation, "vendor_code": vendor_code, - "reason": reason + "reason": reason, }, ) diff --git a/app/models/architecture_scan.py b/app/models/architecture_scan.py index 39b861ec..403fbcbf 100644 --- a/app/models/architecture_scan.py +++ b/app/models/architecture_scan.py @@ -3,18 +3,23 @@ Architecture Scan Models Database models for tracking code quality scans and violations """ -from sqlalchemy import Column, Integer, String, Float, DateTime, Text, Boolean, ForeignKey, JSON +from sqlalchemy import (JSON, Boolean, Column, DateTime, Float, ForeignKey, + Integer, String, Text) from sqlalchemy.orm import relationship from sqlalchemy.sql import func + from app.core.database import Base class ArchitectureScan(Base): """Represents a single run of the architecture validator""" + __tablename__ = "architecture_scans" id = Column(Integer, primary_key=True, index=True) - timestamp = Column(DateTime(timezone=True), server_default=func.now(), nullable=False, index=True) + timestamp = Column( + DateTime(timezone=True), server_default=func.now(), nullable=False, index=True + ) total_files = Column(Integer, default=0) total_violations = Column(Integer, default=0) errors = Column(Integer, default=0) @@ -24,7 +29,9 @@ class ArchitectureScan(Base): git_commit_hash = Column(String(40)) # Relationship to violations - violations = relationship("ArchitectureViolation", back_populates="scan", cascade="all, delete-orphan") + violations = relationship( + "ArchitectureViolation", back_populates="scan", cascade="all, delete-orphan" + ) def __repr__(self): return f"" @@ -32,31 +39,48 @@ class ArchitectureScan(Base): class ArchitectureViolation(Base): """Represents a single architectural violation found during a scan""" + __tablename__ = "architecture_violations" id = Column(Integer, primary_key=True, index=True) - scan_id = Column(Integer, ForeignKey("architecture_scans.id"), nullable=False, index=True) + scan_id = Column( + Integer, ForeignKey("architecture_scans.id"), nullable=False, index=True + ) rule_id = Column(String(20), nullable=False, index=True) # e.g., 'API-001' rule_name = Column(String(200), nullable=False) - severity = Column(String(10), nullable=False, index=True) # 'error', 'warning', 'info' + severity = Column( + String(10), nullable=False, index=True + ) # 'error', 'warning', 'info' file_path = Column(String(500), nullable=False, index=True) line_number = Column(Integer, nullable=False) message = Column(Text, nullable=False) context = Column(Text) # Code snippet suggestion = Column(Text) - status = Column(String(20), default='open', index=True) # 'open', 'assigned', 'resolved', 'ignored', 'technical_debt' + status = Column( + String(20), default="open", index=True + ) # 'open', 'assigned', 'resolved', 'ignored', 'technical_debt' assigned_to = Column(Integer, ForeignKey("users.id")) resolved_at = Column(DateTime(timezone=True)) resolved_by = Column(Integer, ForeignKey("users.id")) resolution_note = Column(Text) - created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) + created_at = Column( + DateTime(timezone=True), server_default=func.now(), nullable=False + ) # Relationships scan = relationship("ArchitectureScan", back_populates="violations") - assigned_user = relationship("User", foreign_keys=[assigned_to], backref="assigned_violations") - resolver = relationship("User", foreign_keys=[resolved_by], backref="resolved_violations") - assignments = relationship("ViolationAssignment", back_populates="violation", cascade="all, delete-orphan") - comments = relationship("ViolationComment", back_populates="violation", cascade="all, delete-orphan") + assigned_user = relationship( + "User", foreign_keys=[assigned_to], backref="assigned_violations" + ) + resolver = relationship( + "User", foreign_keys=[resolved_by], backref="resolved_violations" + ) + assignments = relationship( + "ViolationAssignment", back_populates="violation", cascade="all, delete-orphan" + ) + comments = relationship( + "ViolationComment", back_populates="violation", cascade="all, delete-orphan" + ) def __repr__(self): return f"" @@ -64,18 +88,30 @@ class ArchitectureViolation(Base): class ArchitectureRule(Base): """Architecture rules configuration (from YAML with database overrides)""" + __tablename__ = "architecture_rules" id = Column(Integer, primary_key=True, index=True) - rule_id = Column(String(20), unique=True, nullable=False, index=True) # e.g., 'API-001' - category = Column(String(50), nullable=False) # 'api_endpoint', 'service_layer', etc. + rule_id = Column( + String(20), unique=True, nullable=False, index=True + ) # e.g., 'API-001' + category = Column( + String(50), nullable=False + ) # 'api_endpoint', 'service_layer', etc. name = Column(String(200), nullable=False) description = Column(Text) severity = Column(String(10), nullable=False) # Can override default from YAML enabled = Column(Boolean, default=True, nullable=False) custom_config = Column(JSON) # For rule-specific settings - created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) - updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) + created_at = Column( + DateTime(timezone=True), server_default=func.now(), nullable=False + ) + updated_at = Column( + DateTime(timezone=True), + server_default=func.now(), + onupdate=func.now(), + nullable=False, + ) def __repr__(self): return f"" @@ -83,20 +119,29 @@ class ArchitectureRule(Base): class ViolationAssignment(Base): """Tracks assignment of violations to developers""" + __tablename__ = "violation_assignments" id = Column(Integer, primary_key=True, index=True) - violation_id = Column(Integer, ForeignKey("architecture_violations.id"), nullable=False, index=True) + violation_id = Column( + Integer, ForeignKey("architecture_violations.id"), nullable=False, index=True + ) user_id = Column(Integer, ForeignKey("users.id"), nullable=False) - assigned_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) + assigned_at = Column( + DateTime(timezone=True), server_default=func.now(), nullable=False + ) assigned_by = Column(Integer, ForeignKey("users.id")) due_date = Column(DateTime(timezone=True)) - priority = Column(String(10), default='medium') # 'low', 'medium', 'high', 'critical' + priority = Column( + String(10), default="medium" + ) # 'low', 'medium', 'high', 'critical' # Relationships violation = relationship("ArchitectureViolation", back_populates="assignments") user = relationship("User", foreign_keys=[user_id], backref="violation_assignments") - assigner = relationship("User", foreign_keys=[assigned_by], backref="assigned_by_me") + assigner = relationship( + "User", foreign_keys=[assigned_by], backref="assigned_by_me" + ) def __repr__(self): return f"" @@ -104,13 +149,18 @@ class ViolationAssignment(Base): class ViolationComment(Base): """Comments on violations for collaboration""" + __tablename__ = "violation_comments" id = Column(Integer, primary_key=True, index=True) - violation_id = Column(Integer, ForeignKey("architecture_violations.id"), nullable=False, index=True) + violation_id = Column( + Integer, ForeignKey("architecture_violations.id"), nullable=False, index=True + ) user_id = Column(Integer, ForeignKey("users.id"), nullable=False) comment = Column(Text, nullable=False) - created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) + created_at = Column( + DateTime(timezone=True), server_default=func.now(), nullable=False + ) # Relationships violation = relationship("ArchitectureViolation", back_populates="comments") diff --git a/app/routes/admin_pages.py b/app/routes/admin_pages.py index 49b10873..736b27db 100644 --- a/app/routes/admin_pages.py +++ b/app/routes/admin_pages.py @@ -30,17 +30,15 @@ Routes: - GET /code-quality/violations/{violation_id} → Violation details (auth required) """ -from fastapi import APIRouter, Request, Depends, Path +from typing import Optional + +from fastapi import APIRouter, Depends, Path, Request from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.templating import Jinja2Templates from sqlalchemy.orm import Session -from typing import Optional -from app.api.deps import ( - get_current_admin_from_cookie_or_header, - get_current_admin_optional, - get_db -) +from app.api.deps import (get_current_admin_from_cookie_or_header, + get_current_admin_optional, get_db) from models.database.user import User router = APIRouter() @@ -51,9 +49,10 @@ templates = Jinja2Templates(directory="app/templates") # PUBLIC ROUTES (No Authentication Required) # ============================================================================ + @router.get("/", response_class=RedirectResponse, include_in_schema=False) async def admin_root( - current_user: Optional[User] = Depends(get_current_admin_optional) + current_user: Optional[User] = Depends(get_current_admin_optional), ): """ Redirect /admin/ based on authentication status. @@ -70,8 +69,7 @@ async def admin_root( @router.get("/login", response_class=HTMLResponse, include_in_schema=False) async def admin_login_page( - request: Request, - current_user: Optional[User] = Depends(get_current_admin_optional) + request: Request, current_user: Optional[User] = Depends(get_current_admin_optional) ): """ Render admin login page. @@ -83,21 +81,19 @@ async def admin_login_page( # User is already logged in as admin, redirect to dashboard return RedirectResponse(url="/admin/dashboard", status_code=302) - return templates.TemplateResponse( - "admin/login.html", - {"request": request} - ) + return templates.TemplateResponse("admin/login.html", {"request": request}) # ============================================================================ # AUTHENTICATED ROUTES (Admin Only) # ============================================================================ + @router.get("/dashboard", response_class=HTMLResponse, include_in_schema=False) async def admin_dashboard_page( - request: Request, - current_user: User = Depends(get_current_admin_from_cookie_or_header), - db: Session = Depends(get_db) + request: Request, + current_user: User = Depends(get_current_admin_from_cookie_or_header), + db: Session = Depends(get_db), ): """ Render admin dashboard page. @@ -108,7 +104,7 @@ async def admin_dashboard_page( { "request": request, "user": current_user, - } + }, ) @@ -116,11 +112,12 @@ async def admin_dashboard_page( # VENDOR MANAGEMENT ROUTES # ============================================================================ + @router.get("/vendors", response_class=HTMLResponse, include_in_schema=False) async def admin_vendors_list_page( - request: Request, - current_user: User = Depends(get_current_admin_from_cookie_or_header), - db: Session = Depends(get_db) + request: Request, + current_user: User = Depends(get_current_admin_from_cookie_or_header), + db: Session = Depends(get_db), ): """ Render vendors management page. @@ -131,15 +128,15 @@ async def admin_vendors_list_page( { "request": request, "user": current_user, - } + }, ) @router.get("/vendors/create", response_class=HTMLResponse, include_in_schema=False) async def admin_vendor_create_page( - request: Request, - current_user: User = Depends(get_current_admin_from_cookie_or_header), - db: Session = Depends(get_db) + request: Request, + current_user: User = Depends(get_current_admin_from_cookie_or_header), + db: Session = Depends(get_db), ): """ Render vendor creation form. @@ -149,16 +146,18 @@ async def admin_vendor_create_page( { "request": request, "user": current_user, - } + }, ) -@router.get("/vendors/{vendor_code}", response_class=HTMLResponse, include_in_schema=False) +@router.get( + "/vendors/{vendor_code}", response_class=HTMLResponse, include_in_schema=False +) async def admin_vendor_detail_page( - request: Request, - vendor_code: str = Path(..., description="Vendor code"), - current_user: User = Depends(get_current_admin_from_cookie_or_header), - db: Session = Depends(get_db) + request: Request, + vendor_code: str = Path(..., description="Vendor code"), + current_user: User = Depends(get_current_admin_from_cookie_or_header), + db: Session = Depends(get_db), ): """ Render vendor detail page. @@ -170,16 +169,18 @@ async def admin_vendor_detail_page( "request": request, "user": current_user, "vendor_code": vendor_code, - } + }, ) -@router.get("/vendors/{vendor_code}/edit", response_class=HTMLResponse, include_in_schema=False) +@router.get( + "/vendors/{vendor_code}/edit", response_class=HTMLResponse, include_in_schema=False +) async def admin_vendor_edit_page( - request: Request, - vendor_code: str = Path(..., description="Vendor code"), - current_user: User = Depends(get_current_admin_from_cookie_or_header), - db: Session = Depends(get_db) + request: Request, + vendor_code: str = Path(..., description="Vendor code"), + current_user: User = Depends(get_current_admin_from_cookie_or_header), + db: Session = Depends(get_db), ): """ Render vendor edit form. @@ -190,7 +191,7 @@ async def admin_vendor_edit_page( "request": request, "user": current_user, "vendor_code": vendor_code, - } + }, ) @@ -198,12 +199,17 @@ async def admin_vendor_edit_page( # VENDOR DOMAINS ROUTES # ============================================================================ -@router.get("/vendors/{vendor_code}/domains", response_class=HTMLResponse, include_in_schema=False) + +@router.get( + "/vendors/{vendor_code}/domains", + response_class=HTMLResponse, + include_in_schema=False, +) async def admin_vendor_domains_page( - request: Request, - vendor_code: str = Path(..., description="Vendor code"), - current_user: User = Depends(get_current_admin_from_cookie_or_header), - db: Session = Depends(get_db) + request: Request, + vendor_code: str = Path(..., description="Vendor code"), + current_user: User = Depends(get_current_admin_from_cookie_or_header), + db: Session = Depends(get_db), ): """ Render vendor domains management page. @@ -215,7 +221,7 @@ async def admin_vendor_domains_page( "request": request, "user": current_user, "vendor_code": vendor_code, - } + }, ) @@ -223,12 +229,15 @@ async def admin_vendor_domains_page( # VENDOR THEMES ROUTES # ============================================================================ -@router.get("/vendors/{vendor_code}/theme", response_class=HTMLResponse, include_in_schema=False) + +@router.get( + "/vendors/{vendor_code}/theme", response_class=HTMLResponse, include_in_schema=False +) async def admin_vendor_theme_page( - request: Request, - vendor_code: str = Path(..., description="Vendor code"), - current_user: User = Depends(get_current_admin_from_cookie_or_header), - db: Session = Depends(get_db) + request: Request, + vendor_code: str = Path(..., description="Vendor code"), + current_user: User = Depends(get_current_admin_from_cookie_or_header), + db: Session = Depends(get_db), ): """ Render vendor theme customization page. @@ -240,7 +249,7 @@ async def admin_vendor_theme_page( "request": request, "user": current_user, "vendor_code": vendor_code, - } + }, ) @@ -248,11 +257,12 @@ async def admin_vendor_theme_page( # USER MANAGEMENT ROUTES # ============================================================================ + @router.get("/users", response_class=HTMLResponse, include_in_schema=False) async def admin_users_page( - request: Request, - current_user: User = Depends(get_current_admin_from_cookie_or_header), - db: Session = Depends(get_db) + request: Request, + current_user: User = Depends(get_current_admin_from_cookie_or_header), + db: Session = Depends(get_db), ): """ Render users management page. @@ -263,7 +273,7 @@ async def admin_users_page( { "request": request, "user": current_user, - } + }, ) @@ -271,11 +281,12 @@ async def admin_users_page( # IMPORT MANAGEMENT ROUTES # ============================================================================ + @router.get("/imports", response_class=HTMLResponse, include_in_schema=False) async def admin_imports_page( - request: Request, - current_user: User = Depends(get_current_admin_from_cookie_or_header), - db: Session = Depends(get_db) + request: Request, + current_user: User = Depends(get_current_admin_from_cookie_or_header), + db: Session = Depends(get_db), ): """ Render imports management page. @@ -286,7 +297,7 @@ async def admin_imports_page( { "request": request, "user": current_user, - } + }, ) @@ -294,11 +305,12 @@ async def admin_imports_page( # SETTINGS ROUTES # ============================================================================ + @router.get("/settings", response_class=HTMLResponse, include_in_schema=False) async def admin_settings_page( - request: Request, - current_user: User = Depends(get_current_admin_from_cookie_or_header), - db: Session = Depends(get_db) + request: Request, + current_user: User = Depends(get_current_admin_from_cookie_or_header), + db: Session = Depends(get_db), ): """ Render admin settings page. @@ -309,7 +321,7 @@ async def admin_settings_page( { "request": request, "user": current_user, - } + }, ) @@ -317,11 +329,12 @@ async def admin_settings_page( # CONTENT MANAGEMENT SYSTEM (CMS) ROUTES # ============================================================================ + @router.get("/platform-homepage", response_class=HTMLResponse, include_in_schema=False) async def admin_platform_homepage_manager( - request: Request, - current_user: User = Depends(get_current_admin_from_cookie_or_header), - db: Session = Depends(get_db) + request: Request, + current_user: User = Depends(get_current_admin_from_cookie_or_header), + db: Session = Depends(get_db), ): """ Render platform homepage manager. @@ -332,15 +345,15 @@ async def admin_platform_homepage_manager( { "request": request, "user": current_user, - } + }, ) @router.get("/content-pages", response_class=HTMLResponse, include_in_schema=False) async def admin_content_pages_list( - request: Request, - current_user: User = Depends(get_current_admin_from_cookie_or_header), - db: Session = Depends(get_db) + request: Request, + current_user: User = Depends(get_current_admin_from_cookie_or_header), + db: Session = Depends(get_db), ): """ Render content pages list. @@ -351,15 +364,17 @@ async def admin_content_pages_list( { "request": request, "user": current_user, - } + }, ) -@router.get("/content-pages/create", response_class=HTMLResponse, include_in_schema=False) +@router.get( + "/content-pages/create", response_class=HTMLResponse, include_in_schema=False +) async def admin_content_page_create( - request: Request, - current_user: User = Depends(get_current_admin_from_cookie_or_header), - db: Session = Depends(get_db) + request: Request, + current_user: User = Depends(get_current_admin_from_cookie_or_header), + db: Session = Depends(get_db), ): """ Render create content page form. @@ -371,16 +386,20 @@ async def admin_content_page_create( "request": request, "user": current_user, "page_id": None, # Indicates this is a create operation - } + }, ) -@router.get("/content-pages/{page_id}/edit", response_class=HTMLResponse, include_in_schema=False) +@router.get( + "/content-pages/{page_id}/edit", + response_class=HTMLResponse, + include_in_schema=False, +) async def admin_content_page_edit( - request: Request, - page_id: int = Path(..., description="Content page ID"), - current_user: User = Depends(get_current_admin_from_cookie_or_header), - db: Session = Depends(get_db) + request: Request, + page_id: int = Path(..., description="Content page ID"), + current_user: User = Depends(get_current_admin_from_cookie_or_header), + db: Session = Depends(get_db), ): """ Render edit content page form. @@ -392,7 +411,7 @@ async def admin_content_page_edit( "request": request, "user": current_user, "page_id": page_id, - } + }, ) @@ -400,11 +419,12 @@ async def admin_content_page_edit( # DEVELOPER TOOLS - COMPONENTS & TESTING # ============================================================================ + @router.get("/components", response_class=HTMLResponse, include_in_schema=False) async def admin_components_page( - request: Request, - current_user: User = Depends(get_current_admin_from_cookie_or_header), - db: Session = Depends(get_db) + request: Request, + current_user: User = Depends(get_current_admin_from_cookie_or_header), + db: Session = Depends(get_db), ): """ Render UI components library page. @@ -415,7 +435,7 @@ async def admin_components_page( { "request": request, "user": current_user, - } + }, ) @@ -423,7 +443,7 @@ async def admin_components_page( async def admin_icons_page( request: Request, current_user: User = Depends(get_current_admin_from_cookie_or_header), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Render icons browser page. @@ -434,15 +454,15 @@ async def admin_icons_page( { "request": request, "user": current_user, - } + }, ) @router.get("/testing", response_class=HTMLResponse, include_in_schema=False) async def admin_testing_hub( - request: Request, - current_user: User = Depends(get_current_admin_from_cookie_or_header), - db: Session = Depends(get_db) + request: Request, + current_user: User = Depends(get_current_admin_from_cookie_or_header), + db: Session = Depends(get_db), ): """ Render testing hub page. @@ -453,15 +473,15 @@ async def admin_testing_hub( { "request": request, "user": current_user, - } + }, ) @router.get("/test/auth-flow", response_class=HTMLResponse, include_in_schema=False) async def admin_test_auth_flow( - request: Request, - current_user: User = Depends(get_current_admin_from_cookie_or_header), - db: Session = Depends(get_db) + request: Request, + current_user: User = Depends(get_current_admin_from_cookie_or_header), + db: Session = Depends(get_db), ): """ Render authentication flow testing page. @@ -472,15 +492,19 @@ async def admin_test_auth_flow( { "request": request, "user": current_user, - } + }, ) -@router.get("/test/vendors-users-migration", response_class=HTMLResponse, include_in_schema=False) +@router.get( + "/test/vendors-users-migration", + response_class=HTMLResponse, + include_in_schema=False, +) async def admin_test_vendors_users_migration( - request: Request, - current_user: User = Depends(get_current_admin_from_cookie_or_header), - db: Session = Depends(get_db) + request: Request, + current_user: User = Depends(get_current_admin_from_cookie_or_header), + db: Session = Depends(get_db), ): """ Render vendors and users migration testing page. @@ -491,7 +515,7 @@ async def admin_test_vendors_users_migration( { "request": request, "user": current_user, - } + }, ) @@ -499,11 +523,12 @@ async def admin_test_vendors_users_migration( # CODE QUALITY & ARCHITECTURE ROUTES # ============================================================================ + @router.get("/code-quality", response_class=HTMLResponse, include_in_schema=False) async def admin_code_quality_dashboard( - request: Request, - current_user: User = Depends(get_current_admin_from_cookie_or_header), - db: Session = Depends(get_db) + request: Request, + current_user: User = Depends(get_current_admin_from_cookie_or_header), + db: Session = Depends(get_db), ): """ Render code quality dashboard. @@ -514,15 +539,17 @@ async def admin_code_quality_dashboard( { "request": request, "user": current_user, - } + }, ) -@router.get("/code-quality/violations", response_class=HTMLResponse, include_in_schema=False) +@router.get( + "/code-quality/violations", response_class=HTMLResponse, include_in_schema=False +) async def admin_code_quality_violations( - request: Request, - current_user: User = Depends(get_current_admin_from_cookie_or_header), - db: Session = Depends(get_db) + request: Request, + current_user: User = Depends(get_current_admin_from_cookie_or_header), + db: Session = Depends(get_db), ): """ Render violations list page. @@ -533,16 +560,20 @@ async def admin_code_quality_violations( { "request": request, "user": current_user, - } + }, ) -@router.get("/code-quality/violations/{violation_id}", response_class=HTMLResponse, include_in_schema=False) +@router.get( + "/code-quality/violations/{violation_id}", + response_class=HTMLResponse, + include_in_schema=False, +) async def admin_code_quality_violation_detail( - request: Request, - violation_id: int = Path(..., description="Violation ID"), - current_user: User = Depends(get_current_admin_from_cookie_or_header), - db: Session = Depends(get_db) + request: Request, + violation_id: int = Path(..., description="Violation ID"), + current_user: User = Depends(get_current_admin_from_cookie_or_header), + db: Session = Depends(get_db), ): """ Render violation detail page. @@ -554,5 +585,5 @@ async def admin_code_quality_violation_detail( "request": request, "user": current_user, "violation_id": violation_id, - } + }, ) diff --git a/app/routes/shop_pages.py b/app/routes/shop_pages.py index 569113d9..8103d809 100644 --- a/app/routes/shop_pages.py +++ b/app/routes/shop_pages.py @@ -31,7 +31,8 @@ Routes (all mounted at /shop/* or /vendors/{code}/shop/* prefix): """ import logging -from fastapi import APIRouter, Request, Depends, Path + +from fastapi import APIRouter, Depends, Path, Request from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.templating import Jinja2Templates from sqlalchemy.orm import Session @@ -50,6 +51,7 @@ logger = logging.getLogger(__name__) # HELPER: Build Shop Template Context # ============================================================================ + def get_shop_context(request: Request, db: Session = None, **extra_context) -> dict: """ Build template context for shop pages. @@ -76,13 +78,17 @@ def get_shop_context(request: Request, db: Session = None, **extra_context) -> d get_shop_context(request, db=db, user=current_user, product_id=123) """ # Extract from middleware state - vendor = getattr(request.state, 'vendor', None) - theme = getattr(request.state, 'theme', None) - clean_path = getattr(request.state, 'clean_path', request.url.path) - vendor_context = getattr(request.state, 'vendor_context', None) + vendor = getattr(request.state, "vendor", None) + theme = getattr(request.state, "theme", None) + clean_path = getattr(request.state, "clean_path", request.url.path) + vendor_context = getattr(request.state, "vendor_context", None) # Get detection method from vendor_context - access_method = vendor_context.get('detection_method', 'unknown') if vendor_context else 'unknown' + access_method = ( + vendor_context.get("detection_method", "unknown") + if vendor_context + else "unknown" + ) if vendor is None: logger.warning( @@ -91,7 +97,7 @@ def get_shop_context(request: Request, db: Session = None, **extra_context) -> d "path": request.url.path, "host": request.headers.get("host", ""), "has_vendor": False, - } + }, ) # Calculate base URL for links @@ -100,7 +106,11 @@ def get_shop_context(request: Request, db: Session = None, **extra_context) -> d base_url = "/" if access_method == "path" and vendor: # Use the full_prefix from vendor_context to determine which pattern was used - full_prefix = vendor_context.get('full_prefix', '/vendor/') if vendor_context else '/vendor/' + full_prefix = ( + vendor_context.get("full_prefix", "/vendor/") + if vendor_context + else "/vendor/" + ) base_url = f"{full_prefix}{vendor.subdomain}/" # Load footer navigation pages from CMS if db session provided @@ -111,22 +121,16 @@ def get_shop_context(request: Request, db: Session = None, **extra_context) -> d vendor_id = vendor.id # Get pages configured to show in footer footer_pages = content_page_service.list_pages_for_vendor( - db, - vendor_id=vendor_id, - footer_only=True, - include_unpublished=False + db, vendor_id=vendor_id, footer_only=True, include_unpublished=False ) # Get pages configured to show in header header_pages = content_page_service.list_pages_for_vendor( - db, - vendor_id=vendor_id, - header_only=True, - include_unpublished=False + db, vendor_id=vendor_id, header_only=True, include_unpublished=False ) except Exception as e: logger.error( f"[SHOP_CONTEXT] Failed to load navigation pages", - extra={"error": str(e), "vendor_id": vendor.id if vendor else None} + extra={"error": str(e), "vendor_id": vendor.id if vendor else None}, ) context = { @@ -156,7 +160,7 @@ def get_shop_context(request: Request, db: Session = None, **extra_context) -> d "footer_pages_count": len(footer_pages), "header_pages_count": len(header_pages), "extra_keys": list(extra_context.keys()) if extra_context else [], - } + }, ) return context @@ -166,6 +170,7 @@ def get_shop_context(request: Request, db: Session = None, **extra_context) -> d # PUBLIC SHOP ROUTES (No Authentication Required) # ============================================================================ + @router.get("/", response_class=HTMLResponse, include_in_schema=False) @router.get("/products", response_class=HTMLResponse, include_in_schema=False) async def shop_products_page(request: Request, db: Session = Depends(get_db)): @@ -177,21 +182,21 @@ async def shop_products_page(request: Request, db: Session = Depends(get_db)): f"[SHOP_HANDLER] shop_products_page REACHED", extra={ "path": request.url.path, - "vendor": getattr(request.state, 'vendor', 'NOT SET'), - "context": getattr(request.state, 'context_type', 'NOT SET'), - } + "vendor": getattr(request.state, "vendor", "NOT SET"), + "context": getattr(request.state, "context_type", "NOT SET"), + }, ) return templates.TemplateResponse( - "shop/products.html", - get_shop_context(request, db=db) + "shop/products.html", get_shop_context(request, db=db) ) -@router.get("/products/{product_id}", response_class=HTMLResponse, include_in_schema=False) +@router.get( + "/products/{product_id}", response_class=HTMLResponse, include_in_schema=False +) async def shop_product_detail_page( - request: Request, - product_id: int = Path(..., description="Product ID") + request: Request, product_id: int = Path(..., description="Product ID") ): """ Render product detail page. @@ -201,21 +206,21 @@ async def shop_product_detail_page( f"[SHOP_HANDLER] shop_products_page REACHED", extra={ "path": request.url.path, - "vendor": getattr(request.state, 'vendor', 'NOT SET'), - "context": getattr(request.state, 'context_type', 'NOT SET'), - } + "vendor": getattr(request.state, "vendor", "NOT SET"), + "context": getattr(request.state, "context_type", "NOT SET"), + }, ) return templates.TemplateResponse( - "shop/product.html", - get_shop_context(request, product_id=product_id) + "shop/product.html", get_shop_context(request, product_id=product_id) ) -@router.get("/categories/{category_slug}", response_class=HTMLResponse, include_in_schema=False) +@router.get( + "/categories/{category_slug}", response_class=HTMLResponse, include_in_schema=False +) async def shop_category_page( - request: Request, - category_slug: str = Path(..., description="Category slug") + request: Request, category_slug: str = Path(..., description="Category slug") ): """ Render category products page. @@ -225,14 +230,13 @@ async def shop_category_page( f"[SHOP_HANDLER] shop_products_page REACHED", extra={ "path": request.url.path, - "vendor": getattr(request.state, 'vendor', 'NOT SET'), - "context": getattr(request.state, 'context_type', 'NOT SET'), - } + "vendor": getattr(request.state, "vendor", "NOT SET"), + "context": getattr(request.state, "context_type", "NOT SET"), + }, ) return templates.TemplateResponse( - "shop/category.html", - get_shop_context(request, category_slug=category_slug) + "shop/category.html", get_shop_context(request, category_slug=category_slug) ) @@ -246,15 +250,12 @@ async def shop_cart_page(request: Request): f"[SHOP_HANDLER] shop_products_page REACHED", extra={ "path": request.url.path, - "vendor": getattr(request.state, 'vendor', 'NOT SET'), - "context": getattr(request.state, 'context_type', 'NOT SET'), - } + "vendor": getattr(request.state, "vendor", "NOT SET"), + "context": getattr(request.state, "context_type", "NOT SET"), + }, ) - return templates.TemplateResponse( - "shop/cart.html", - get_shop_context(request) - ) + return templates.TemplateResponse("shop/cart.html", get_shop_context(request)) @router.get("/checkout", response_class=HTMLResponse, include_in_schema=False) @@ -267,15 +268,12 @@ async def shop_checkout_page(request: Request): f"[SHOP_HANDLER] shop_products_page REACHED", extra={ "path": request.url.path, - "vendor": getattr(request.state, 'vendor', 'NOT SET'), - "context": getattr(request.state, 'context_type', 'NOT SET'), - } + "vendor": getattr(request.state, "vendor", "NOT SET"), + "context": getattr(request.state, "context_type", "NOT SET"), + }, ) - return templates.TemplateResponse( - "shop/checkout.html", - get_shop_context(request) - ) + return templates.TemplateResponse("shop/checkout.html", get_shop_context(request)) @router.get("/search", response_class=HTMLResponse, include_in_schema=False) @@ -288,21 +286,19 @@ async def shop_search_page(request: Request): f"[SHOP_HANDLER] shop_products_page REACHED", extra={ "path": request.url.path, - "vendor": getattr(request.state, 'vendor', 'NOT SET'), - "context": getattr(request.state, 'context_type', 'NOT SET'), - } + "vendor": getattr(request.state, "vendor", "NOT SET"), + "context": getattr(request.state, "context_type", "NOT SET"), + }, ) - return templates.TemplateResponse( - "shop/search.html", - get_shop_context(request) - ) + return templates.TemplateResponse("shop/search.html", get_shop_context(request)) # ============================================================================ # CUSTOMER ACCOUNT - PUBLIC ROUTES (No Authentication) # ============================================================================ + @router.get("/account/register", response_class=HTMLResponse, include_in_schema=False) async def shop_register_page(request: Request): """ @@ -313,14 +309,13 @@ async def shop_register_page(request: Request): f"[SHOP_HANDLER] shop_products_page REACHED", extra={ "path": request.url.path, - "vendor": getattr(request.state, 'vendor', 'NOT SET'), - "context": getattr(request.state, 'context_type', 'NOT SET'), - } + "vendor": getattr(request.state, "vendor", "NOT SET"), + "context": getattr(request.state, "context_type", "NOT SET"), + }, ) return templates.TemplateResponse( - "shop/account/register.html", - get_shop_context(request) + "shop/account/register.html", get_shop_context(request) ) @@ -334,18 +329,19 @@ async def shop_login_page(request: Request): f"[SHOP_HANDLER] shop_products_page REACHED", extra={ "path": request.url.path, - "vendor": getattr(request.state, 'vendor', 'NOT SET'), - "context": getattr(request.state, 'context_type', 'NOT SET'), - } + "vendor": getattr(request.state, "vendor", "NOT SET"), + "context": getattr(request.state, "context_type", "NOT SET"), + }, ) return templates.TemplateResponse( - "shop/account/login.html", - get_shop_context(request) + "shop/account/login.html", get_shop_context(request) ) -@router.get("/account/forgot-password", response_class=HTMLResponse, include_in_schema=False) +@router.get( + "/account/forgot-password", response_class=HTMLResponse, include_in_schema=False +) async def shop_forgot_password_page(request: Request): """ Render forgot password page. @@ -355,14 +351,13 @@ async def shop_forgot_password_page(request: Request): f"[SHOP_HANDLER] shop_products_page REACHED", extra={ "path": request.url.path, - "vendor": getattr(request.state, 'vendor', 'NOT SET'), - "context": getattr(request.state, 'context_type', 'NOT SET'), - } + "vendor": getattr(request.state, "vendor", "NOT SET"), + "context": getattr(request.state, "context_type", "NOT SET"), + }, ) return templates.TemplateResponse( - "shop/account/forgot-password.html", - get_shop_context(request) + "shop/account/forgot-password.html", get_shop_context(request) ) @@ -370,6 +365,7 @@ async def shop_forgot_password_page(request: Request): # CUSTOMER ACCOUNT - AUTHENTICATED ROUTES # ============================================================================ + @router.get("/account", response_class=RedirectResponse, include_in_schema=False) @router.get("/account/", response_class=RedirectResponse, include_in_schema=False) async def shop_account_root(request: Request): @@ -380,19 +376,27 @@ async def shop_account_root(request: Request): f"[SHOP_HANDLER] shop_products_page REACHED", extra={ "path": request.url.path, - "vendor": getattr(request.state, 'vendor', 'NOT SET'), - "context": getattr(request.state, 'context_type', 'NOT SET'), - } + "vendor": getattr(request.state, "vendor", "NOT SET"), + "context": getattr(request.state, "context_type", "NOT SET"), + }, ) # Get base_url from context for proper redirect - vendor = getattr(request.state, 'vendor', None) - vendor_context = getattr(request.state, 'vendor_context', None) - access_method = vendor_context.get('detection_method', 'unknown') if vendor_context else 'unknown' + vendor = getattr(request.state, "vendor", None) + vendor_context = getattr(request.state, "vendor_context", None) + access_method = ( + vendor_context.get("detection_method", "unknown") + if vendor_context + else "unknown" + ) base_url = "/" if access_method == "path" and vendor: - full_prefix = vendor_context.get('full_prefix', '/vendor/') if vendor_context else '/vendor/' + full_prefix = ( + vendor_context.get("full_prefix", "/vendor/") + if vendor_context + else "/vendor/" + ) base_url = f"{full_prefix}{vendor.subdomain}/" return RedirectResponse(url=f"{base_url}shop/account/dashboard", status_code=302) @@ -400,9 +404,9 @@ async def shop_account_root(request: Request): @router.get("/account/dashboard", response_class=HTMLResponse, include_in_schema=False) async def shop_account_dashboard_page( - request: Request, - current_customer: Customer = Depends(get_current_customer_from_cookie_or_header), - db: Session = Depends(get_db) + request: Request, + current_customer: Customer = Depends(get_current_customer_from_cookie_or_header), + db: Session = Depends(get_db), ): """ Render customer account dashboard. @@ -413,22 +417,21 @@ async def shop_account_dashboard_page( f"[SHOP_HANDLER] shop_products_page REACHED", extra={ "path": request.url.path, - "vendor": getattr(request.state, 'vendor', 'NOT SET'), - "context": getattr(request.state, 'context_type', 'NOT SET'), - } + "vendor": getattr(request.state, "vendor", "NOT SET"), + "context": getattr(request.state, "context_type", "NOT SET"), + }, ) return templates.TemplateResponse( - "shop/account/dashboard.html", - get_shop_context(request, user=current_customer) + "shop/account/dashboard.html", get_shop_context(request, user=current_customer) ) @router.get("/account/orders", response_class=HTMLResponse, include_in_schema=False) async def shop_orders_page( - request: Request, - current_customer: Customer = Depends(get_current_customer_from_cookie_or_header), - db: Session = Depends(get_db) + request: Request, + current_customer: Customer = Depends(get_current_customer_from_cookie_or_header), + db: Session = Depends(get_db), ): """ Render customer orders history page. @@ -439,23 +442,24 @@ async def shop_orders_page( f"[SHOP_HANDLER] shop_products_page REACHED", extra={ "path": request.url.path, - "vendor": getattr(request.state, 'vendor', 'NOT SET'), - "context": getattr(request.state, 'context_type', 'NOT SET'), - } + "vendor": getattr(request.state, "vendor", "NOT SET"), + "context": getattr(request.state, "context_type", "NOT SET"), + }, ) return templates.TemplateResponse( - "shop/account/orders.html", - get_shop_context(request, user=current_customer) + "shop/account/orders.html", get_shop_context(request, user=current_customer) ) -@router.get("/account/orders/{order_id}", response_class=HTMLResponse, include_in_schema=False) +@router.get( + "/account/orders/{order_id}", response_class=HTMLResponse, include_in_schema=False +) async def shop_order_detail_page( - request: Request, - order_id: int = Path(..., description="Order ID"), - current_customer: Customer = Depends(get_current_customer_from_cookie_or_header), - db: Session = Depends(get_db) + request: Request, + order_id: int = Path(..., description="Order ID"), + current_customer: Customer = Depends(get_current_customer_from_cookie_or_header), + db: Session = Depends(get_db), ): """ Render customer order detail page. @@ -466,22 +470,22 @@ async def shop_order_detail_page( f"[SHOP_HANDLER] shop_products_page REACHED", extra={ "path": request.url.path, - "vendor": getattr(request.state, 'vendor', 'NOT SET'), - "context": getattr(request.state, 'context_type', 'NOT SET'), - } + "vendor": getattr(request.state, "vendor", "NOT SET"), + "context": getattr(request.state, "context_type", "NOT SET"), + }, ) return templates.TemplateResponse( "shop/account/order-detail.html", - get_shop_context(request, user=current_customer, order_id=order_id) + get_shop_context(request, user=current_customer, order_id=order_id), ) @router.get("/account/profile", response_class=HTMLResponse, include_in_schema=False) async def shop_profile_page( - request: Request, - current_customer: Customer = Depends(get_current_customer_from_cookie_or_header), - db: Session = Depends(get_db) + request: Request, + current_customer: Customer = Depends(get_current_customer_from_cookie_or_header), + db: Session = Depends(get_db), ): """ Render customer profile page. @@ -492,22 +496,21 @@ async def shop_profile_page( f"[SHOP_HANDLER] shop_products_page REACHED", extra={ "path": request.url.path, - "vendor": getattr(request.state, 'vendor', 'NOT SET'), - "context": getattr(request.state, 'context_type', 'NOT SET'), - } + "vendor": getattr(request.state, "vendor", "NOT SET"), + "context": getattr(request.state, "context_type", "NOT SET"), + }, ) return templates.TemplateResponse( - "shop/account/profile.html", - get_shop_context(request, user=current_customer) + "shop/account/profile.html", get_shop_context(request, user=current_customer) ) @router.get("/account/addresses", response_class=HTMLResponse, include_in_schema=False) async def shop_addresses_page( - request: Request, - current_customer: Customer = Depends(get_current_customer_from_cookie_or_header), - db: Session = Depends(get_db) + request: Request, + current_customer: Customer = Depends(get_current_customer_from_cookie_or_header), + db: Session = Depends(get_db), ): """ Render customer addresses management page. @@ -518,22 +521,21 @@ async def shop_addresses_page( f"[SHOP_HANDLER] shop_products_page REACHED", extra={ "path": request.url.path, - "vendor": getattr(request.state, 'vendor', 'NOT SET'), - "context": getattr(request.state, 'context_type', 'NOT SET'), - } + "vendor": getattr(request.state, "vendor", "NOT SET"), + "context": getattr(request.state, "context_type", "NOT SET"), + }, ) return templates.TemplateResponse( - "shop/account/addresses.html", - get_shop_context(request, user=current_customer) + "shop/account/addresses.html", get_shop_context(request, user=current_customer) ) @router.get("/account/wishlist", response_class=HTMLResponse, include_in_schema=False) async def shop_wishlist_page( - request: Request, - current_customer: Customer = Depends(get_current_customer_from_cookie_or_header), - db: Session = Depends(get_db) + request: Request, + current_customer: Customer = Depends(get_current_customer_from_cookie_or_header), + db: Session = Depends(get_db), ): """ Render customer wishlist page. @@ -544,22 +546,21 @@ async def shop_wishlist_page( f"[SHOP_HANDLER] shop_products_page REACHED", extra={ "path": request.url.path, - "vendor": getattr(request.state, 'vendor', 'NOT SET'), - "context": getattr(request.state, 'context_type', 'NOT SET'), - } + "vendor": getattr(request.state, "vendor", "NOT SET"), + "context": getattr(request.state, "context_type", "NOT SET"), + }, ) return templates.TemplateResponse( - "shop/account/wishlist.html", - get_shop_context(request, user=current_customer) + "shop/account/wishlist.html", get_shop_context(request, user=current_customer) ) @router.get("/account/settings", response_class=HTMLResponse, include_in_schema=False) async def shop_settings_page( - request: Request, - current_customer: Customer = Depends(get_current_customer_from_cookie_or_header), - db: Session = Depends(get_db) + request: Request, + current_customer: Customer = Depends(get_current_customer_from_cookie_or_header), + db: Session = Depends(get_db), ): """ Render customer account settings page. @@ -570,14 +571,13 @@ async def shop_settings_page( f"[SHOP_HANDLER] shop_products_page REACHED", extra={ "path": request.url.path, - "vendor": getattr(request.state, 'vendor', 'NOT SET'), - "context": getattr(request.state, 'context_type', 'NOT SET'), - } + "vendor": getattr(request.state, "vendor", "NOT SET"), + "context": getattr(request.state, "context_type", "NOT SET"), + }, ) return templates.TemplateResponse( - "shop/account/settings.html", - get_shop_context(request, user=current_customer) + "shop/account/settings.html", get_shop_context(request, user=current_customer) ) @@ -585,11 +585,12 @@ async def shop_settings_page( # DYNAMIC CONTENT PAGES (CMS) # ============================================================================ + @router.get("/{slug}", response_class=HTMLResponse, include_in_schema=False) async def generic_content_page( request: Request, slug: str = Path(..., description="Content page slug"), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Generic content page handler (CMS). @@ -612,20 +613,17 @@ async def generic_content_page( extra={ "path": request.url.path, "slug": slug, - "vendor": getattr(request.state, 'vendor', 'NOT SET'), - "context": getattr(request.state, 'context_type', 'NOT SET'), - } + "vendor": getattr(request.state, "vendor", "NOT SET"), + "context": getattr(request.state, "context_type", "NOT SET"), + }, ) - vendor = getattr(request.state, 'vendor', None) + vendor = getattr(request.state, "vendor", None) vendor_id = vendor.id if vendor else None # Load content page from database (vendor override → platform default) page = content_page_service.get_page_for_vendor( - db, - slug=slug, - vendor_id=vendor_id, - include_unpublished=False + db, slug=slug, vendor_id=vendor_id, include_unpublished=False ) if not page: @@ -635,7 +633,7 @@ async def generic_content_page( "slug": slug, "vendor_id": vendor_id, "vendor_name": vendor.name if vendor else None, - } + }, ) raise HTTPException(status_code=404, detail=f"Page not found: {slug}") @@ -647,12 +645,11 @@ async def generic_content_page( "page_title": page.title, "is_vendor_override": page.vendor_id is not None, "vendor_id": vendor_id, - } + }, ) return templates.TemplateResponse( - "shop/content-page.html", - get_shop_context(request, page=page) + "shop/content-page.html", get_shop_context(request, page=page) ) @@ -660,6 +657,7 @@ async def generic_content_page( # DEBUG ENDPOINTS - For troubleshooting context issues # ============================================================================ + @router.get("/debug/context", response_class=HTMLResponse, include_in_schema=False) async def debug_context(request: Request): """ @@ -670,8 +668,8 @@ async def debug_context(request: Request): URL: /shop/debug/context """ - vendor = getattr(request.state, 'vendor', None) - theme = getattr(request.state, 'theme', None) + vendor = getattr(request.state, "vendor", None) + theme = getattr(request.state, "theme", None) debug_info = { "path": request.url.path, @@ -687,12 +685,13 @@ async def debug_context(request: Request): "found": theme is not None, "name": theme.get("theme_name") if theme else None, }, - "clean_path": getattr(request.state, 'clean_path', 'NOT SET'), - "context_type": str(getattr(request.state, 'context_type', 'NOT SET')), + "clean_path": getattr(request.state, "clean_path", "NOT SET"), + "context_type": str(getattr(request.state, "context_type", "NOT SET")), } # Return as JSON-like HTML for easy reading import json + html_content = f""" diff --git a/app/routes/vendor_pages.py b/app/routes/vendor_pages.py index a6b24ad2..41dd702d 100644 --- a/app/routes/vendor_pages.py +++ b/app/routes/vendor_pages.py @@ -21,18 +21,16 @@ Routes: - GET /vendor/{vendor_code}/settings → Vendor settings """ -from fastapi import APIRouter, Request, Depends, Path, HTTPException +import logging +from typing import Optional + +from fastapi import APIRouter, Depends, HTTPException, Path, Request from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.templating import Jinja2Templates from sqlalchemy.orm import Session -from typing import Optional -import logging -from app.api.deps import ( - get_current_vendor_from_cookie_or_header, - get_current_vendor_optional, - get_db -) +from app.api.deps import (get_current_vendor_from_cookie_or_header, + get_current_vendor_optional, get_db) from app.services.content_page_service import content_page_service from models.database.user import User @@ -46,6 +44,7 @@ templates = Jinja2Templates(directory="app/templates") # PUBLIC ROUTES (No Authentication Required) # ============================================================================ + @router.get("/{vendor_code}", response_class=RedirectResponse, include_in_schema=False) async def vendor_root_no_slash(vendor_code: str = Path(..., description="Vendor code")): """ @@ -57,8 +56,8 @@ async def vendor_root_no_slash(vendor_code: str = Path(..., description="Vendor @router.get("/{vendor_code}/", response_class=RedirectResponse, include_in_schema=False) async def vendor_root( - vendor_code: str = Path(..., description="Vendor code"), - current_user: Optional[User] = Depends(get_current_vendor_optional) + vendor_code: str = Path(..., description="Vendor code"), + current_user: Optional[User] = Depends(get_current_vendor_optional), ): """ Redirect /vendor/{code}/ based on authentication status. @@ -73,11 +72,13 @@ async def vendor_root( return RedirectResponse(url=f"/vendor/{vendor_code}/login", status_code=302) -@router.get("/{vendor_code}/login", response_class=HTMLResponse, include_in_schema=False) +@router.get( + "/{vendor_code}/login", response_class=HTMLResponse, include_in_schema=False +) async def vendor_login_page( - request: Request, - vendor_code: str = Path(..., description="Vendor code"), - current_user: Optional[User] = Depends(get_current_vendor_optional) + request: Request, + vendor_code: str = Path(..., description="Vendor code"), + current_user: Optional[User] = Depends(get_current_vendor_optional), ): """ Render vendor login page. @@ -99,7 +100,7 @@ async def vendor_login_page( { "request": request, "vendor_code": vendor_code, - } + }, ) @@ -107,11 +108,14 @@ async def vendor_login_page( # AUTHENTICATED ROUTES (Vendor Users Only) # ============================================================================ -@router.get("/{vendor_code}/dashboard", response_class=HTMLResponse, include_in_schema=False) + +@router.get( + "/{vendor_code}/dashboard", response_class=HTMLResponse, include_in_schema=False +) async def vendor_dashboard_page( - request: Request, - vendor_code: str = Path(..., description="Vendor code"), - current_user: User = Depends(get_current_vendor_from_cookie_or_header) + request: Request, + vendor_code: str = Path(..., description="Vendor code"), + current_user: User = Depends(get_current_vendor_from_cookie_or_header), ): """ Render vendor dashboard. @@ -128,7 +132,7 @@ async def vendor_dashboard_page( "request": request, "user": current_user, "vendor_code": vendor_code, - } + }, ) @@ -136,11 +140,14 @@ async def vendor_dashboard_page( # PRODUCT MANAGEMENT # ============================================================================ -@router.get("/{vendor_code}/products", response_class=HTMLResponse, include_in_schema=False) + +@router.get( + "/{vendor_code}/products", response_class=HTMLResponse, include_in_schema=False +) async def vendor_products_page( - request: Request, - vendor_code: str = Path(..., description="Vendor code"), - current_user: User = Depends(get_current_vendor_from_cookie_or_header) + request: Request, + vendor_code: str = Path(..., description="Vendor code"), + current_user: User = Depends(get_current_vendor_from_cookie_or_header), ): """ Render products management page. @@ -152,7 +159,7 @@ async def vendor_products_page( "request": request, "user": current_user, "vendor_code": vendor_code, - } + }, ) @@ -160,11 +167,14 @@ async def vendor_products_page( # ORDER MANAGEMENT # ============================================================================ -@router.get("/{vendor_code}/orders", response_class=HTMLResponse, include_in_schema=False) + +@router.get( + "/{vendor_code}/orders", response_class=HTMLResponse, include_in_schema=False +) async def vendor_orders_page( - request: Request, - vendor_code: str = Path(..., description="Vendor code"), - current_user: User = Depends(get_current_vendor_from_cookie_or_header) + request: Request, + vendor_code: str = Path(..., description="Vendor code"), + current_user: User = Depends(get_current_vendor_from_cookie_or_header), ): """ Render orders management page. @@ -176,7 +186,7 @@ async def vendor_orders_page( "request": request, "user": current_user, "vendor_code": vendor_code, - } + }, ) @@ -184,11 +194,14 @@ async def vendor_orders_page( # CUSTOMER MANAGEMENT # ============================================================================ -@router.get("/{vendor_code}/customers", response_class=HTMLResponse, include_in_schema=False) + +@router.get( + "/{vendor_code}/customers", response_class=HTMLResponse, include_in_schema=False +) async def vendor_customers_page( - request: Request, - vendor_code: str = Path(..., description="Vendor code"), - current_user: User = Depends(get_current_vendor_from_cookie_or_header) + request: Request, + vendor_code: str = Path(..., description="Vendor code"), + current_user: User = Depends(get_current_vendor_from_cookie_or_header), ): """ Render customers management page. @@ -200,7 +213,7 @@ async def vendor_customers_page( "request": request, "user": current_user, "vendor_code": vendor_code, - } + }, ) @@ -208,11 +221,14 @@ async def vendor_customers_page( # INVENTORY MANAGEMENT # ============================================================================ -@router.get("/{vendor_code}/inventory", response_class=HTMLResponse, include_in_schema=False) + +@router.get( + "/{vendor_code}/inventory", response_class=HTMLResponse, include_in_schema=False +) async def vendor_inventory_page( - request: Request, - vendor_code: str = Path(..., description="Vendor code"), - current_user: User = Depends(get_current_vendor_from_cookie_or_header) + request: Request, + vendor_code: str = Path(..., description="Vendor code"), + current_user: User = Depends(get_current_vendor_from_cookie_or_header), ): """ Render inventory management page. @@ -224,7 +240,7 @@ async def vendor_inventory_page( "request": request, "user": current_user, "vendor_code": vendor_code, - } + }, ) @@ -232,11 +248,14 @@ async def vendor_inventory_page( # MARKETPLACE IMPORTS # ============================================================================ -@router.get("/{vendor_code}/marketplace", response_class=HTMLResponse, include_in_schema=False) + +@router.get( + "/{vendor_code}/marketplace", response_class=HTMLResponse, include_in_schema=False +) async def vendor_marketplace_page( - request: Request, - vendor_code: str = Path(..., description="Vendor code"), - current_user: User = Depends(get_current_vendor_from_cookie_or_header) + request: Request, + vendor_code: str = Path(..., description="Vendor code"), + current_user: User = Depends(get_current_vendor_from_cookie_or_header), ): """ Render marketplace import page. @@ -248,7 +267,7 @@ async def vendor_marketplace_page( "request": request, "user": current_user, "vendor_code": vendor_code, - } + }, ) @@ -256,11 +275,12 @@ async def vendor_marketplace_page( # TEAM MANAGEMENT # ============================================================================ + @router.get("/{vendor_code}/team", response_class=HTMLResponse, include_in_schema=False) async def vendor_team_page( - request: Request, - vendor_code: str = Path(..., description="Vendor code"), - current_user: User = Depends(get_current_vendor_from_cookie_or_header) + request: Request, + vendor_code: str = Path(..., description="Vendor code"), + current_user: User = Depends(get_current_vendor_from_cookie_or_header), ): """ Render team management page. @@ -272,7 +292,7 @@ async def vendor_team_page( "request": request, "user": current_user, "vendor_code": vendor_code, - } + }, ) @@ -280,11 +300,14 @@ async def vendor_team_page( # PROFILE & SETTINGS # ============================================================================ -@router.get("/{vendor_code}/profile", response_class=HTMLResponse, include_in_schema=False) + +@router.get( + "/{vendor_code}/profile", response_class=HTMLResponse, include_in_schema=False +) async def vendor_profile_page( - request: Request, - vendor_code: str = Path(..., description="Vendor code"), - current_user: User = Depends(get_current_vendor_from_cookie_or_header) + request: Request, + vendor_code: str = Path(..., description="Vendor code"), + current_user: User = Depends(get_current_vendor_from_cookie_or_header), ): """ Render vendor profile page. @@ -296,15 +319,17 @@ async def vendor_profile_page( "request": request, "user": current_user, "vendor_code": vendor_code, - } + }, ) -@router.get("/{vendor_code}/settings", response_class=HTMLResponse, include_in_schema=False) +@router.get( + "/{vendor_code}/settings", response_class=HTMLResponse, include_in_schema=False +) async def vendor_settings_page( - request: Request, - vendor_code: str = Path(..., description="Vendor code"), - current_user: User = Depends(get_current_vendor_from_cookie_or_header) + request: Request, + vendor_code: str = Path(..., description="Vendor code"), + current_user: User = Depends(get_current_vendor_from_cookie_or_header), ): """ Render vendor settings page. @@ -316,7 +341,7 @@ async def vendor_settings_page( "request": request, "user": current_user, "vendor_code": vendor_code, - } + }, ) @@ -324,12 +349,15 @@ async def vendor_settings_page( # DYNAMIC CONTENT PAGES (CMS) # ============================================================================ -@router.get("/{vendor_code}/{slug}", response_class=HTMLResponse, include_in_schema=False) + +@router.get( + "/{vendor_code}/{slug}", response_class=HTMLResponse, include_in_schema=False +) async def vendor_content_page( request: Request, vendor_code: str = Path(..., description="Vendor code"), slug: str = Path(..., description="Content page slug"), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Generic content page handler for vendor shop (CMS). @@ -351,20 +379,17 @@ async def vendor_content_page( "path": request.url.path, "vendor_code": vendor_code, "slug": slug, - "vendor": getattr(request.state, 'vendor', 'NOT SET'), - "context": getattr(request.state, 'context_type', 'NOT SET'), - } + "vendor": getattr(request.state, "vendor", "NOT SET"), + "context": getattr(request.state, "context_type", "NOT SET"), + }, ) - vendor = getattr(request.state, 'vendor', None) + vendor = getattr(request.state, "vendor", None) vendor_id = vendor.id if vendor else None # Load content page from database (vendor override → platform default) page = content_page_service.get_page_for_vendor( - db, - slug=slug, - vendor_id=vendor_id, - include_unpublished=False + db, slug=slug, vendor_id=vendor_id, include_unpublished=False ) if not page: @@ -374,7 +399,7 @@ async def vendor_content_page( "slug": slug, "vendor_code": vendor_code, "vendor_id": vendor_id, - } + }, ) raise HTTPException(status_code=404, detail="Page not found") @@ -385,7 +410,7 @@ async def vendor_content_page( "page_id": page.id, "is_vendor_override": page.vendor_id is not None, "vendor_id": vendor_id, - } + }, ) return templates.TemplateResponse( @@ -394,5 +419,5 @@ async def vendor_content_page( "request": request, "page": page, "vendor_code": vendor_code, - } + }, ) diff --git a/app/services/admin_audit_service.py b/app/services/admin_audit_service.py index ee9da385..503a2c32 100644 --- a/app/services/admin_audit_service.py +++ b/app/services/admin_audit_service.py @@ -10,15 +10,15 @@ This module provides functions for: import logging from datetime import datetime, timezone -from typing import List, Optional, Dict, Any +from typing import Any, Dict, List, Optional -from sqlalchemy.orm import Session from sqlalchemy import and_, or_ +from sqlalchemy.orm import Session +from app.exceptions import AdminOperationException from models.database.admin import AdminAuditLog from models.database.user import User from models.schema.admin import AdminAuditLogFilters, AdminAuditLogResponse -from app.exceptions import AdminOperationException logger = logging.getLogger(__name__) @@ -36,7 +36,7 @@ class AdminAuditService: details: Optional[Dict[str, Any]] = None, ip_address: Optional[str] = None, user_agent: Optional[str] = None, - request_id: Optional[str] = None + request_id: Optional[str] = None, ) -> AdminAuditLog: """ Log an admin action to the audit trail. @@ -63,7 +63,7 @@ class AdminAuditService: details=details or {}, ip_address=ip_address, user_agent=user_agent, - request_id=request_id + request_id=request_id, ) db.add(audit_log) @@ -84,9 +84,7 @@ class AdminAuditService: return None def get_audit_logs( - self, - db: Session, - filters: AdminAuditLogFilters + self, db: Session, filters: AdminAuditLogFilters ) -> List[AdminAuditLogResponse]: """ Get filtered admin audit logs with pagination. @@ -98,7 +96,9 @@ class AdminAuditService: List of audit log responses """ try: - query = db.query(AdminAuditLog).join(User, AdminAuditLog.admin_user_id == User.id) + query = db.query(AdminAuditLog).join( + User, AdminAuditLog.admin_user_id == User.id + ) # Apply filters conditions = [] @@ -123,8 +123,7 @@ class AdminAuditService: # Execute query with pagination logs = ( - query - .order_by(AdminAuditLog.created_at.desc()) + query.order_by(AdminAuditLog.created_at.desc()) .offset(filters.skip) .limit(filters.limit) .all() @@ -143,7 +142,7 @@ class AdminAuditService: ip_address=log.ip_address, user_agent=log.user_agent, request_id=log.request_id, - created_at=log.created_at + created_at=log.created_at, ) for log in logs ] @@ -151,15 +150,10 @@ class AdminAuditService: except Exception as e: logger.error(f"Failed to retrieve audit logs: {str(e)}") raise AdminOperationException( - operation="get_audit_logs", - reason="Database query failed" + operation="get_audit_logs", reason="Database query failed" ) - def get_audit_logs_count( - self, - db: Session, - filters: AdminAuditLogFilters - ) -> int: + def get_audit_logs_count(self, db: Session, filters: AdminAuditLogFilters) -> int: """Get total count of audit logs matching filters.""" try: query = db.query(AdminAuditLog) @@ -192,24 +186,14 @@ class AdminAuditService: return 0 def get_recent_actions_by_admin( - self, - db: Session, - admin_user_id: int, - limit: int = 10 + self, db: Session, admin_user_id: int, limit: int = 10 ) -> List[AdminAuditLogResponse]: """Get recent actions by a specific admin.""" - filters = AdminAuditLogFilters( - admin_user_id=admin_user_id, - limit=limit - ) + filters = AdminAuditLogFilters(admin_user_id=admin_user_id, limit=limit) return self.get_audit_logs(db, filters) def get_actions_by_target( - self, - db: Session, - target_type: str, - target_id: str, - limit: int = 50 + self, db: Session, target_type: str, target_id: str, limit: int = 50 ) -> List[AdminAuditLogResponse]: """Get all actions performed on a specific target.""" try: @@ -218,7 +202,7 @@ class AdminAuditService: .filter( and_( AdminAuditLog.target_type == target_type, - AdminAuditLog.target_id == str(target_id) + AdminAuditLog.target_id == str(target_id), ) ) .order_by(AdminAuditLog.created_at.desc()) @@ -236,7 +220,7 @@ class AdminAuditService: target_id=log.target_id, details=log.details, ip_address=log.ip_address, - created_at=log.created_at + created_at=log.created_at, ) for log in logs ] @@ -247,4 +231,4 @@ class AdminAuditService: # Create service instance -admin_audit_service = AdminAuditService() \ No newline at end of file +admin_audit_service = AdminAuditService() diff --git a/app/services/admin_service.py b/app/services/admin_service.py index f6529725..65637df8 100644 --- a/app/services/admin_service.py +++ b/app/services/admin_service.py @@ -16,24 +16,19 @@ import string from datetime import datetime, timezone from typing import List, Optional, Tuple -from sqlalchemy.orm import Session from sqlalchemy import func, or_ +from sqlalchemy.orm import Session -from app.exceptions import ( - UserNotFoundException, - UserStatusChangeException, - CannotModifySelfException, - VendorNotFoundException, - VendorAlreadyExistsException, - VendorVerificationException, - AdminOperationException, - ValidationException, -) +from app.exceptions import (AdminOperationException, CannotModifySelfException, + UserNotFoundException, UserStatusChangeException, + ValidationException, VendorAlreadyExistsException, + VendorNotFoundException, + VendorVerificationException) +from models.database.marketplace_import_job import MarketplaceImportJob +from models.database.user import User +from models.database.vendor import Role, Vendor, VendorUser from models.schema.marketplace_import_job import MarketplaceImportJobResponse from models.schema.vendor import VendorCreate -from models.database.marketplace_import_job import MarketplaceImportJob -from models.database.vendor import Vendor, Role, VendorUser -from models.database.user import User logger = logging.getLogger(__name__) @@ -52,12 +47,11 @@ class AdminService: except Exception as e: logger.error(f"Failed to retrieve users: {str(e)}") raise AdminOperationException( - operation="get_all_users", - reason="Database query failed" + operation="get_all_users", reason="Database query failed" ) def toggle_user_status( - self, db: Session, user_id: int, current_admin_id: int + self, db: Session, user_id: int, current_admin_id: int ) -> Tuple[User, str]: """Toggle user active status.""" user = self._get_user_by_id_or_raise(db, user_id) @@ -72,7 +66,7 @@ class AdminService: user_id=user_id, current_status="admin", attempted_action="toggle status", - reason="Cannot modify another admin user" + reason="Cannot modify another admin user", ) try: @@ -95,7 +89,7 @@ class AdminService: user_id=user_id, current_status="active" if original_status else "inactive", attempted_action="toggle status", - reason="Database update failed" + reason="Database update failed", ) # ============================================================================ @@ -103,7 +97,7 @@ class AdminService: # ============================================================================ def create_vendor_with_owner( - self, db: Session, vendor_data: VendorCreate + self, db: Session, vendor_data: VendorCreate ) -> Tuple[Vendor, User, str]: """ Create vendor with owner user account. @@ -118,17 +112,23 @@ class AdminService: """ try: # Check if vendor code already exists - existing_vendor = db.query(Vendor).filter( - func.upper(Vendor.vendor_code) == vendor_data.vendor_code.upper() - ).first() + existing_vendor = ( + db.query(Vendor) + .filter( + func.upper(Vendor.vendor_code) == vendor_data.vendor_code.upper() + ) + .first() + ) if existing_vendor: raise VendorAlreadyExistsException(vendor_data.vendor_code) # Check if subdomain already exists - existing_subdomain = db.query(Vendor).filter( - func.lower(Vendor.subdomain) == vendor_data.subdomain.lower() - ).first() + existing_subdomain = ( + db.query(Vendor) + .filter(func.lower(Vendor.subdomain) == vendor_data.subdomain.lower()) + .first() + ) if existing_subdomain: raise ValidationException( @@ -140,15 +140,14 @@ class AdminService: # Create owner user with owner_email from middleware.auth import AuthManager + auth_manager = AuthManager() owner_username = f"{vendor_data.subdomain}_owner" owner_email = vendor_data.owner_email # ✅ For User authentication # Check if user with this email already exists - existing_user = db.query(User).filter( - User.email == owner_email - ).first() + existing_user = db.query(User).filter(User.email == owner_email).first() if existing_user: # Use existing user as owner @@ -215,17 +214,17 @@ class AdminService: logger.error(f"Failed to create vendor: {str(e)}") raise AdminOperationException( operation="create_vendor_with_owner", - reason=f"Failed to create vendor: {str(e)}" + reason=f"Failed to create vendor: {str(e)}", ) def get_all_vendors( - self, - db: Session, - skip: int = 0, - limit: int = 100, - search: Optional[str] = None, - is_active: Optional[bool] = None, - is_verified: Optional[bool] = None + self, + db: Session, + skip: int = 0, + limit: int = 100, + search: Optional[str] = None, + is_active: Optional[bool] = None, + is_verified: Optional[bool] = None, ) -> Tuple[List[Vendor], int]: """Get paginated list of all vendors with filtering.""" try: @@ -238,7 +237,7 @@ class AdminService: or_( Vendor.name.ilike(search_term), Vendor.vendor_code.ilike(search_term), - Vendor.subdomain.ilike(search_term) + Vendor.subdomain.ilike(search_term), ) ) @@ -255,8 +254,7 @@ class AdminService: except Exception as e: logger.error(f"Failed to retrieve vendors: {str(e)}") raise AdminOperationException( - operation="get_all_vendors", - reason="Database query failed" + operation="get_all_vendors", reason="Database query failed" ) def get_vendor_by_id(self, db: Session, vendor_id: int) -> Vendor: @@ -290,7 +288,7 @@ class AdminService: raise VendorVerificationException( vendor_id=vendor_id, reason="Database update failed", - current_verification_status=original_status + current_verification_status=original_status, ) def toggle_vendor_status(self, db: Session, vendor_id: int) -> Tuple[Vendor, str]: @@ -317,7 +315,7 @@ class AdminService: operation="toggle_vendor_status", reason="Database update failed", target_type="vendor", - target_id=str(vendor_id) + target_id=str(vendor_id), ) def delete_vendor(self, db: Session, vendor_id: int) -> str: @@ -345,15 +343,11 @@ class AdminService: db.rollback() logger.error(f"Failed to delete vendor {vendor_id}: {str(e)}") raise AdminOperationException( - operation="delete_vendor", - reason="Database deletion failed" + operation="delete_vendor", reason="Database deletion failed" ) def update_vendor( - self, - db: Session, - vendor_id: int, - vendor_update # VendorUpdate schema + self, db: Session, vendor_id: int, vendor_update # VendorUpdate schema ) -> Vendor: """ Update vendor information (Admin only). @@ -387,11 +381,18 @@ class AdminService: update_data = vendor_update.model_dump(exclude_unset=True) # Check subdomain uniqueness if changing - if 'subdomain' in update_data and update_data['subdomain'] != vendor.subdomain: - existing = db.query(Vendor).filter( - Vendor.subdomain == update_data['subdomain'], - Vendor.id != vendor_id - ).first() + if ( + "subdomain" in update_data + and update_data["subdomain"] != vendor.subdomain + ): + existing = ( + db.query(Vendor) + .filter( + Vendor.subdomain == update_data["subdomain"], + Vendor.id != vendor_id, + ) + .first() + ) if existing: raise ValidationException( f"Subdomain '{update_data['subdomain']}' is already taken" @@ -419,17 +420,16 @@ class AdminService: db.rollback() logger.error(f"Failed to update vendor {vendor_id}: {str(e)}") raise AdminOperationException( - operation="update_vendor", - reason=f"Database update failed: {str(e)}" + operation="update_vendor", reason=f"Database update failed: {str(e)}" ) # Add this NEW method for transferring ownership: def transfer_vendor_ownership( - self, - db: Session, - vendor_id: int, - transfer_data # VendorTransferOwnership schema + self, + db: Session, + vendor_id: int, + transfer_data, # VendorTransferOwnership schema ) -> Tuple[Vendor, User, User]: """ Transfer vendor ownership to another user. @@ -466,9 +466,9 @@ class AdminService: old_owner = vendor.owner # Get new owner - new_owner = db.query(User).filter( - User.id == transfer_data.new_owner_user_id - ).first() + new_owner = ( + db.query(User).filter(User.id == transfer_data.new_owner_user_id).first() + ) if not new_owner: raise UserNotFoundException(str(transfer_data.new_owner_user_id)) @@ -487,26 +487,32 @@ class AdminService: try: # Get Owner role for this vendor - owner_role = db.query(Role).filter( - Role.vendor_id == vendor_id, - Role.name == "Owner" - ).first() + owner_role = ( + db.query(Role) + .filter(Role.vendor_id == vendor_id, Role.name == "Owner") + .first() + ) if not owner_role: raise ValidationException("Owner role not found for vendor") # Get Manager role (to demote old owner) - manager_role = db.query(Role).filter( - Role.vendor_id == vendor_id, - Role.name == "Manager" - ).first() + manager_role = ( + db.query(Role) + .filter(Role.vendor_id == vendor_id, Role.name == "Manager") + .first() + ) # Remove old owner from Owner role - old_owner_link = db.query(VendorUser).filter( - VendorUser.vendor_id == vendor_id, - VendorUser.user_id == old_owner.id, - VendorUser.role_id == owner_role.id - ).first() + old_owner_link = ( + db.query(VendorUser) + .filter( + VendorUser.vendor_id == vendor_id, + VendorUser.user_id == old_owner.id, + VendorUser.role_id == owner_role.id, + ) + .first() + ) if old_owner_link: if manager_role: @@ -525,10 +531,14 @@ class AdminService: ) # Check if new owner already has a vendor_user link - new_owner_link = db.query(VendorUser).filter( - VendorUser.vendor_id == vendor_id, - VendorUser.user_id == new_owner.id - ).first() + new_owner_link = ( + db.query(VendorUser) + .filter( + VendorUser.vendor_id == vendor_id, + VendorUser.user_id == new_owner.id, + ) + .first() + ) if new_owner_link: # Update existing link to Owner role @@ -540,7 +550,7 @@ class AdminService: vendor_id=vendor_id, user_id=new_owner.id, role_id=owner_role.id, - is_active=True + is_active=True, ) db.add(new_owner_link) @@ -568,10 +578,12 @@ class AdminService: raise except Exception as e: db.rollback() - logger.error(f"Failed to transfer ownership for vendor {vendor_id}: {str(e)}") + logger.error( + f"Failed to transfer ownership for vendor {vendor_id}: {str(e)}" + ) raise AdminOperationException( operation="transfer_vendor_ownership", - reason=f"Ownership transfer failed: {str(e)}" + reason=f"Ownership transfer failed: {str(e)}", ) # ============================================================================ @@ -579,13 +591,13 @@ class AdminService: # ============================================================================ def get_marketplace_import_jobs( - self, - db: Session, - marketplace: Optional[str] = None, - vendor_name: Optional[str] = None, - status: Optional[str] = None, - skip: int = 0, - limit: int = 100, + self, + db: Session, + marketplace: Optional[str] = None, + vendor_name: Optional[str] = None, + status: Optional[str] = None, + skip: int = 0, + limit: int = 100, ) -> List[MarketplaceImportJobResponse]: """Get filtered and paginated marketplace import jobs.""" try: @@ -596,7 +608,9 @@ class AdminService: MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%") ) if vendor_name: - query = query.filter(MarketplaceImportJob.vendor_name.ilike(f"%{vendor_name}%")) + query = query.filter( + MarketplaceImportJob.vendor_name.ilike(f"%{vendor_name}%") + ) if status: query = query.filter(MarketplaceImportJob.status == status) @@ -612,8 +626,7 @@ class AdminService: except Exception as e: logger.error(f"Failed to retrieve marketplace import jobs: {str(e)}") raise AdminOperationException( - operation="get_marketplace_import_jobs", - reason="Database query failed" + operation="get_marketplace_import_jobs", reason="Database query failed" ) # ============================================================================ @@ -624,10 +637,7 @@ class AdminService: """Get recently created vendors.""" try: vendors = ( - db.query(Vendor) - .order_by(Vendor.created_at.desc()) - .limit(limit) - .all() + db.query(Vendor).order_by(Vendor.created_at.desc()).limit(limit).all() ) return [ @@ -638,7 +648,7 @@ class AdminService: "subdomain": v.subdomain, "is_active": v.is_active, "is_verified": v.is_verified, - "created_at": v.created_at + "created_at": v.created_at, } for v in vendors ] @@ -663,7 +673,7 @@ class AdminService: "vendor_name": j.vendor_name, "status": j.status, "total_processed": j.total_processed or 0, - "created_at": j.created_at + "created_at": j.created_at, } for j in jobs ] @@ -692,47 +702,53 @@ class AdminService: def _generate_temp_password(self, length: int = 12) -> str: """Generate secure temporary password.""" alphabet = string.ascii_letters + string.digits + "!@#$%^&*" - return ''.join(secrets.choice(alphabet) for _ in range(length)) + return "".join(secrets.choice(alphabet) for _ in range(length)) def _create_default_roles(self, db: Session, vendor_id: int): """Create default roles for a new vendor.""" default_roles = [ - { - "name": "Owner", - "permissions": ["*"] # Full access - }, + {"name": "Owner", "permissions": ["*"]}, # Full access { "name": "Manager", "permissions": [ - "products.*", "orders.*", "customers.view", - "inventory.*", "team.view" - ] + "products.*", + "orders.*", + "customers.view", + "inventory.*", + "team.view", + ], }, { "name": "Editor", "permissions": [ - "products.view", "products.edit", - "orders.view", "inventory.view" - ] + "products.view", + "products.edit", + "orders.view", + "inventory.view", + ], }, { "name": "Viewer", "permissions": [ - "products.view", "orders.view", - "customers.view", "inventory.view" - ] - } + "products.view", + "orders.view", + "customers.view", + "inventory.view", + ], + }, ] for role_data in default_roles: role = Role( vendor_id=vendor_id, name=role_data["name"], - permissions=role_data["permissions"] + permissions=role_data["permissions"], ) db.add(role) - def _convert_job_to_response(self, job: MarketplaceImportJob) -> MarketplaceImportJobResponse: + def _convert_job_to_response( + self, job: MarketplaceImportJob + ) -> MarketplaceImportJobResponse: """Convert database model to response schema.""" return MarketplaceImportJobResponse( job_id=job.id, diff --git a/app/services/admin_settings_service.py b/app/services/admin_settings_service.py index dec031af..fa4565f3 100644 --- a/app/services/admin_settings_service.py +++ b/app/services/admin_settings_service.py @@ -8,25 +8,19 @@ This module provides functions for: - Encrypting sensitive settings """ -import logging import json -from typing import Optional, List, Any, Dict +import logging from datetime import datetime, timezone +from typing import Any, Dict, List, Optional -from sqlalchemy.orm import Session from sqlalchemy import func +from sqlalchemy.orm import Session +from app.exceptions import (AdminOperationException, ResourceNotFoundException, + ValidationException) from models.database.admin import AdminSetting -from models.schema.admin import ( - AdminSettingCreate, - AdminSettingResponse, - AdminSettingUpdate -) -from app.exceptions import ( - AdminOperationException, - ValidationException, - ResourceNotFoundException -) +from models.schema.admin import (AdminSettingCreate, AdminSettingResponse, + AdminSettingUpdate) logger = logging.getLogger(__name__) @@ -34,26 +28,19 @@ logger = logging.getLogger(__name__) class AdminSettingsService: """Service for managing platform-wide settings.""" - def get_setting_by_key( - self, - db: Session, - key: str - ) -> Optional[AdminSetting]: + def get_setting_by_key(self, db: Session, key: str) -> Optional[AdminSetting]: """Get setting by key.""" try: - return db.query(AdminSetting).filter( - func.lower(AdminSetting.key) == key.lower() - ).first() + return ( + db.query(AdminSetting) + .filter(func.lower(AdminSetting.key) == key.lower()) + .first() + ) except Exception as e: logger.error(f"Failed to get setting {key}: {str(e)}") return None - def get_setting_value( - self, - db: Session, - key: str, - default: Any = None - ) -> Any: + def get_setting_value(self, db: Session, key: str, default: Any = None) -> Any: """ Get setting value with type conversion. @@ -76,7 +63,7 @@ class AdminSettingsService: elif setting.value_type == "float": return float(setting.value) elif setting.value_type == "boolean": - return setting.value.lower() in ('true', '1', 'yes') + return setting.value.lower() in ("true", "1", "yes") elif setting.value_type == "json": return json.loads(setting.value) else: @@ -86,10 +73,10 @@ class AdminSettingsService: return default def get_all_settings( - self, - db: Session, - category: Optional[str] = None, - is_public: Optional[bool] = None + self, + db: Session, + category: Optional[str] = None, + is_public: Optional[bool] = None, ) -> List[AdminSettingResponse]: """Get all settings with optional filtering.""" try: @@ -104,22 +91,16 @@ class AdminSettingsService: settings = query.order_by(AdminSetting.category, AdminSetting.key).all() return [ - AdminSettingResponse.model_validate(setting) - for setting in settings + AdminSettingResponse.model_validate(setting) for setting in settings ] except Exception as e: logger.error(f"Failed to get settings: {str(e)}") raise AdminOperationException( - operation="get_all_settings", - reason="Database query failed" + operation="get_all_settings", reason="Database query failed" ) - def get_settings_by_category( - self, - db: Session, - category: str - ) -> Dict[str, Any]: + def get_settings_by_category(self, db: Session, category: str) -> Dict[str, Any]: """ Get all settings in a category as a dictionary. @@ -136,7 +117,7 @@ class AdminSettingsService: elif setting.value_type == "float": result[setting.key] = float(setting.value) elif setting.value_type == "boolean": - result[setting.key] = setting.value.lower() in ('true', '1', 'yes') + result[setting.key] = setting.value.lower() in ("true", "1", "yes") elif setting.value_type == "json": result[setting.key] = json.loads(setting.value) else: @@ -145,10 +126,7 @@ class AdminSettingsService: return result def create_setting( - self, - db: Session, - setting_data: AdminSettingCreate, - admin_user_id: int + self, db: Session, setting_data: AdminSettingCreate, admin_user_id: int ) -> AdminSettingResponse: """Create new setting.""" try: @@ -176,7 +154,7 @@ class AdminSettingsService: description=setting_data.description, is_encrypted=setting_data.is_encrypted, is_public=setting_data.is_public, - last_modified_by_user_id=admin_user_id + last_modified_by_user_id=admin_user_id, ) db.add(setting) @@ -194,25 +172,17 @@ class AdminSettingsService: db.rollback() logger.error(f"Failed to create setting: {str(e)}") raise AdminOperationException( - operation="create_setting", - reason="Database operation failed" + operation="create_setting", reason="Database operation failed" ) def update_setting( - self, - db: Session, - key: str, - update_data: AdminSettingUpdate, - admin_user_id: int + self, db: Session, key: str, update_data: AdminSettingUpdate, admin_user_id: int ) -> AdminSettingResponse: """Update existing setting.""" setting = self.get_setting_by_key(db, key) if not setting: - raise ResourceNotFoundException( - resource_type="setting", - identifier=key - ) + raise ResourceNotFoundException(resource_type="setting", identifier=key) try: # Validate new value @@ -244,42 +214,29 @@ class AdminSettingsService: db.rollback() logger.error(f"Failed to update setting {key}: {str(e)}") raise AdminOperationException( - operation="update_setting", - reason="Database operation failed" + operation="update_setting", reason="Database operation failed" ) def upsert_setting( - self, - db: Session, - setting_data: AdminSettingCreate, - admin_user_id: int + self, db: Session, setting_data: AdminSettingCreate, admin_user_id: int ) -> AdminSettingResponse: """Create or update setting (upsert).""" existing = self.get_setting_by_key(db, setting_data.key) if existing: update_data = AdminSettingUpdate( - value=setting_data.value, - description=setting_data.description + value=setting_data.value, description=setting_data.description ) return self.update_setting(db, setting_data.key, update_data, admin_user_id) else: return self.create_setting(db, setting_data, admin_user_id) - def delete_setting( - self, - db: Session, - key: str, - admin_user_id: int - ) -> str: + def delete_setting(self, db: Session, key: str, admin_user_id: int) -> str: """Delete setting.""" setting = self.get_setting_by_key(db, key) if not setting: - raise ResourceNotFoundException( - resource_type="setting", - identifier=key - ) + raise ResourceNotFoundException(resource_type="setting", identifier=key) try: db.delete(setting) @@ -293,8 +250,7 @@ class AdminSettingsService: db.rollback() logger.error(f"Failed to delete setting {key}: {str(e)}") raise AdminOperationException( - operation="delete_setting", - reason="Database operation failed" + operation="delete_setting", reason="Database operation failed" ) # ============================================================================ @@ -309,7 +265,7 @@ class AdminSettingsService: elif value_type == "float": float(value) elif value_type == "boolean": - if value.lower() not in ('true', 'false', '1', '0', 'yes', 'no'): + if value.lower() not in ("true", "false", "1", "0", "yes", "no"): raise ValueError("Invalid boolean value") elif value_type == "json": json.loads(value) diff --git a/app/services/audit_service.py b/app/services/audit_service.py index ea03eb12..51cb387c 100644 --- a/app/services/audit_service.py +++ b/app/services/audit_service.py @@ -1 +1 @@ -# Audit logging services +# Audit logging services diff --git a/app/services/auth_service.py b/app/services/auth_service.py index c2c6c400..f05da054 100644 --- a/app/services/auth_service.py +++ b/app/services/auth_service.py @@ -13,15 +13,12 @@ from typing import Any, Dict, Optional from sqlalchemy.orm import Session -from app.exceptions import ( - UserAlreadyExistsException, - InvalidCredentialsException, - UserNotActiveException, - ValidationException, -) +from app.exceptions import (InvalidCredentialsException, + UserAlreadyExistsException, UserNotActiveException, + ValidationException) from middleware.auth import AuthManager -from models.schema.auth import UserLogin, UserRegister from models.database.user import User +from models.schema.auth import UserLogin, UserRegister logger = logging.getLogger(__name__) @@ -51,11 +48,15 @@ class AuthService: try: # Check if email already exists if self._email_exists(db, user_data.email): - raise UserAlreadyExistsException("Email already registered", field="email") + raise UserAlreadyExistsException( + "Email already registered", field="email" + ) # Check if username already exists if self._username_exists(db, user_data.username): - raise UserAlreadyExistsException("Username already taken", field="username") + raise UserAlreadyExistsException( + "Username already taken", field="username" + ) # Hash password and create user hashed_password = self.auth_manager.hash_password(user_data.password) @@ -182,7 +183,9 @@ class AuthService: Dictionary with access_token, token_type, and expires_in """ from datetime import datetime, timedelta, timezone + from jose import jwt + from app.core.config import settings try: @@ -217,6 +220,5 @@ class AuthService: return db.query(User).filter(User.username == username).first() is not None - # Create service instance following the same pattern as other services auth_service = AuthService() diff --git a/app/services/backup_service.py b/app/services/backup_service.py index b4d8b0d4..0759bea8 100644 --- a/app/services/backup_service.py +++ b/app/services/backup_service.py @@ -1 +1 @@ -# Backup and recovery services +# Backup and recovery services diff --git a/app/services/cache_service.py b/app/services/cache_service.py index b0c83d5f..63e64747 100644 --- a/app/services/cache_service.py +++ b/app/services/cache_service.py @@ -1 +1 @@ -# Caching services +# Caching services diff --git a/app/services/cart_service.py b/app/services/cart_service.py index d566b244..b3c8ddd1 100644 --- a/app/services/cart_service.py +++ b/app/services/cart_service.py @@ -9,23 +9,20 @@ This module provides: """ import logging -from typing import Dict, List, Optional from datetime import datetime, timezone +from typing import Dict, List, Optional -from sqlalchemy.orm import Session from sqlalchemy import and_ +from sqlalchemy.orm import Session +from app.exceptions import (CartItemNotFoundException, CartValidationException, + InsufficientInventoryForCartException, + InvalidCartQuantityException, + ProductNotAvailableForCartException, + ProductNotFoundException) +from models.database.cart import CartItem from models.database.product import Product from models.database.vendor import Vendor -from models.database.cart import CartItem -from app.exceptions import ( - ProductNotFoundException, - CartItemNotFoundException, - CartValidationException, - InsufficientInventoryForCartException, - InvalidCartQuantityException, - ProductNotAvailableForCartException, -) logger = logging.getLogger(__name__) @@ -33,12 +30,7 @@ logger = logging.getLogger(__name__) class CartService: """Service for managing shopping carts.""" - def get_cart( - self, - db: Session, - vendor_id: int, - session_id: str - ) -> Dict: + def get_cart(self, db: Session, vendor_id: int, session_id: str) -> Dict: """ Get cart contents for a session. @@ -55,20 +47,21 @@ class CartService: extra={ "vendor_id": vendor_id, "session_id": session_id, - } + }, ) # Fetch cart items from database - cart_items = db.query(CartItem).filter( - and_( - CartItem.vendor_id == vendor_id, - CartItem.session_id == session_id + cart_items = ( + db.query(CartItem) + .filter( + and_(CartItem.vendor_id == vendor_id, CartItem.session_id == session_id) ) - ).all() + .all() + ) logger.info( f"[CART_SERVICE] Found {len(cart_items)} items in database", - extra={"item_count": len(cart_items)} + extra={"item_count": len(cart_items)}, ) # Build response @@ -79,14 +72,20 @@ class CartService: product = cart_item.product line_total = cart_item.line_total - items.append({ - "product_id": product.id, - "product_name": product.marketplace_product.title, - "quantity": cart_item.quantity, - "price": cart_item.price_at_add, - "line_total": line_total, - "image_url": product.marketplace_product.image_link if product.marketplace_product else None, - }) + items.append( + { + "product_id": product.id, + "product_name": product.marketplace_product.title, + "quantity": cart_item.quantity, + "price": cart_item.price_at_add, + "line_total": line_total, + "image_url": ( + product.marketplace_product.image_link + if product.marketplace_product + else None + ), + } + ) subtotal += line_total @@ -95,23 +94,23 @@ class CartService: "session_id": session_id, "items": items, "subtotal": subtotal, - "total": subtotal # Could add tax/shipping later + "total": subtotal, # Could add tax/shipping later } logger.info( f"[CART_SERVICE] get_cart returning: {len(cart_data['items'])} items, total: {cart_data['total']}", - extra={"cart": cart_data} + extra={"cart": cart_data}, ) return cart_data def add_to_cart( - self, - db: Session, - vendor_id: int, - session_id: str, - product_id: int, - quantity: int = 1 + self, + db: Session, + vendor_id: int, + session_id: str, + product_id: int, + quantity: int = 1, ) -> Dict: """ Add product to cart. @@ -136,23 +135,27 @@ class CartService: "vendor_id": vendor_id, "session_id": session_id, "product_id": product_id, - "quantity": quantity - } + "quantity": quantity, + }, ) # Verify product exists and belongs to vendor - product = db.query(Product).filter( - and_( - Product.id == product_id, - Product.vendor_id == vendor_id, - Product.is_active == True + product = ( + db.query(Product) + .filter( + and_( + Product.id == product_id, + Product.vendor_id == vendor_id, + Product.is_active == True, + ) ) - ).first() + .first() + ) if not product: logger.error( f"[CART_SERVICE] Product not found", - extra={"product_id": product_id, "vendor_id": vendor_id} + extra={"product_id": product_id, "vendor_id": vendor_id}, ) raise ProductNotFoundException(product_id=product_id, vendor_id=vendor_id) @@ -161,21 +164,25 @@ class CartService: extra={ "product_id": product_id, "product_name": product.marketplace_product.title, - "available_inventory": product.available_inventory - } + "available_inventory": product.available_inventory, + }, ) # Get current price (use sale_price if available, otherwise regular price) current_price = product.sale_price if product.sale_price else product.price # Check if item already exists in cart - existing_item = db.query(CartItem).filter( - and_( - CartItem.vendor_id == vendor_id, - CartItem.session_id == session_id, - CartItem.product_id == product_id + existing_item = ( + db.query(CartItem) + .filter( + and_( + CartItem.vendor_id == vendor_id, + CartItem.session_id == session_id, + CartItem.product_id == product_id, + ) ) - ).first() + .first() + ) if existing_item: # Update quantity @@ -190,14 +197,14 @@ class CartService: "current_in_cart": existing_item.quantity, "adding": quantity, "requested_total": new_quantity, - "available": product.available_inventory - } + "available": product.available_inventory, + }, ) raise InsufficientInventoryForCartException( product_id=product_id, product_name=product.marketplace_product.title, requested=new_quantity, - available=product.available_inventory + available=product.available_inventory, ) existing_item.quantity = new_quantity @@ -206,16 +213,13 @@ class CartService: logger.info( f"[CART_SERVICE] Updated existing cart item", - extra={ - "cart_item_id": existing_item.id, - "new_quantity": new_quantity - } + extra={"cart_item_id": existing_item.id, "new_quantity": new_quantity}, ) return { "message": "Product quantity updated in cart", "product_id": product_id, - "quantity": new_quantity + "quantity": new_quantity, } else: # Check inventory for new item @@ -225,14 +229,14 @@ class CartService: extra={ "product_id": product_id, "requested": quantity, - "available": product.available_inventory - } + "available": product.available_inventory, + }, ) raise InsufficientInventoryForCartException( product_id=product_id, product_name=product.marketplace_product.title, requested=quantity, - available=product.available_inventory + available=product.available_inventory, ) # Create new cart item @@ -241,7 +245,7 @@ class CartService: session_id=session_id, product_id=product_id, quantity=quantity, - price_at_add=current_price + price_at_add=current_price, ) db.add(cart_item) db.commit() @@ -252,23 +256,23 @@ class CartService: extra={ "cart_item_id": cart_item.id, "quantity": quantity, - "price": current_price - } + "price": current_price, + }, ) return { "message": "Product added to cart", "product_id": product_id, - "quantity": quantity + "quantity": quantity, } def update_cart_item( - self, - db: Session, - vendor_id: int, - session_id: str, - product_id: int, - quantity: int + self, + db: Session, + vendor_id: int, + session_id: str, + product_id: int, + quantity: int, ) -> Dict: """ Update quantity of item in cart. @@ -292,25 +296,35 @@ class CartService: raise InvalidCartQuantityException(quantity=quantity, min_quantity=1) # Find cart item - cart_item = db.query(CartItem).filter( - and_( - CartItem.vendor_id == vendor_id, - CartItem.session_id == session_id, - CartItem.product_id == product_id + cart_item = ( + db.query(CartItem) + .filter( + and_( + CartItem.vendor_id == vendor_id, + CartItem.session_id == session_id, + CartItem.product_id == product_id, + ) ) - ).first() + .first() + ) if not cart_item: - raise CartItemNotFoundException(product_id=product_id, session_id=session_id) + raise CartItemNotFoundException( + product_id=product_id, session_id=session_id + ) # Verify product still exists and is active - product = db.query(Product).filter( - and_( - Product.id == product_id, - Product.vendor_id == vendor_id, - Product.is_active == True + product = ( + db.query(Product) + .filter( + and_( + Product.id == product_id, + Product.vendor_id == vendor_id, + Product.is_active == True, + ) ) - ).first() + .first() + ) if not product: raise ProductNotFoundException(str(product_id)) @@ -321,7 +335,7 @@ class CartService: product_id=product_id, product_name=product.marketplace_product.title, requested=quantity, - available=product.available_inventory + available=product.available_inventory, ) # Update quantity @@ -334,22 +348,18 @@ class CartService: extra={ "cart_item_id": cart_item.id, "product_id": product_id, - "new_quantity": quantity - } + "new_quantity": quantity, + }, ) return { "message": "Cart updated", "product_id": product_id, - "quantity": quantity + "quantity": quantity, } def remove_from_cart( - self, - db: Session, - vendor_id: int, - session_id: str, - product_id: int + self, db: Session, vendor_id: int, session_id: str, product_id: int ) -> Dict: """ Remove item from cart. @@ -367,16 +377,22 @@ class CartService: ProductNotFoundException: If product not in cart """ # Find and delete cart item - cart_item = db.query(CartItem).filter( - and_( - CartItem.vendor_id == vendor_id, - CartItem.session_id == session_id, - CartItem.product_id == product_id + cart_item = ( + db.query(CartItem) + .filter( + and_( + CartItem.vendor_id == vendor_id, + CartItem.session_id == session_id, + CartItem.product_id == product_id, + ) ) - ).first() + .first() + ) if not cart_item: - raise CartItemNotFoundException(product_id=product_id, session_id=session_id) + raise CartItemNotFoundException( + product_id=product_id, session_id=session_id + ) db.delete(cart_item) db.commit() @@ -386,21 +402,13 @@ class CartService: extra={ "cart_item_id": cart_item.id, "product_id": product_id, - "session_id": session_id - } + "session_id": session_id, + }, ) - return { - "message": "Item removed from cart", - "product_id": product_id - } + return {"message": "Item removed from cart", "product_id": product_id} - def clear_cart( - self, - db: Session, - vendor_id: int, - session_id: str - ) -> Dict: + def clear_cart(self, db: Session, vendor_id: int, session_id: str) -> Dict: """ Clear all items from cart. @@ -413,12 +421,13 @@ class CartService: Success message with count of items removed """ # Delete all cart items for this session - deleted_count = db.query(CartItem).filter( - and_( - CartItem.vendor_id == vendor_id, - CartItem.session_id == session_id + deleted_count = ( + db.query(CartItem) + .filter( + and_(CartItem.vendor_id == vendor_id, CartItem.session_id == session_id) ) - ).delete() + .delete() + ) db.commit() @@ -427,14 +436,11 @@ class CartService: extra={ "session_id": session_id, "vendor_id": vendor_id, - "items_removed": deleted_count - } + "items_removed": deleted_count, + }, ) - return { - "message": "Cart cleared", - "items_removed": deleted_count - } + return {"message": "Cart cleared", "items_removed": deleted_count} # Create service instance diff --git a/app/services/code_quality_service.py b/app/services/code_quality_service.py index d129ff2b..ff058e78 100644 --- a/app/services/code_quality_service.py +++ b/app/services/code_quality_service.py @@ -3,22 +3,20 @@ Code Quality Service Business logic for managing architecture scans and violations """ -import subprocess import json import logging +import subprocess from datetime import datetime -from typing import List, Tuple, Optional, Dict from pathlib import Path -from sqlalchemy.orm import Session -from sqlalchemy import func, desc +from typing import Dict, List, Optional, Tuple -from app.models.architecture_scan import ( - ArchitectureScan, - ArchitectureViolation, - ArchitectureRule, - ViolationAssignment, - ViolationComment -) +from sqlalchemy import desc, func +from sqlalchemy.orm import Session + +from app.models.architecture_scan import (ArchitectureRule, ArchitectureScan, + ArchitectureViolation, + ViolationAssignment, + ViolationComment) logger = logging.getLogger(__name__) @@ -26,7 +24,7 @@ logger = logging.getLogger(__name__) class CodeQualityService: """Service for managing code quality scans and violations""" - def run_scan(self, db: Session, triggered_by: str = 'manual') -> ArchitectureScan: + def run_scan(self, db: Session, triggered_by: str = "manual") -> ArchitectureScan: """ Run architecture validator and store results in database @@ -49,10 +47,10 @@ class CodeQualityService: start_time = datetime.now() try: result = subprocess.run( - ['python', 'scripts/validate_architecture.py', '--json'], + ["python", "scripts/validate_architecture.py", "--json"], capture_output=True, text=True, - timeout=300 # 5 minute timeout + timeout=300, # 5 minute timeout ) except subprocess.TimeoutExpired: logger.error("Architecture scan timed out after 5 minutes") @@ -63,17 +61,17 @@ class CodeQualityService: # Parse JSON output (get only the JSON part, skip progress messages) try: # Find the JSON part in stdout - lines = result.stdout.strip().split('\n') + lines = result.stdout.strip().split("\n") json_start = -1 for i, line in enumerate(lines): - if line.strip().startswith('{'): + if line.strip().startswith("{"): json_start = i break if json_start == -1: raise ValueError("No JSON output found") - json_output = '\n'.join(lines[json_start:]) + json_output = "\n".join(lines[json_start:]) data = json.loads(json_output) except (json.JSONDecodeError, ValueError) as e: logger.error(f"Failed to parse validator output: {e}") @@ -84,33 +82,33 @@ class CodeQualityService: # Create scan record scan = ArchitectureScan( timestamp=datetime.now(), - total_files=data.get('files_checked', 0), - total_violations=data.get('total_violations', 0), - errors=data.get('errors', 0), - warnings=data.get('warnings', 0), + total_files=data.get("files_checked", 0), + total_violations=data.get("total_violations", 0), + errors=data.get("errors", 0), + warnings=data.get("warnings", 0), duration_seconds=duration, triggered_by=triggered_by, - git_commit_hash=git_commit + git_commit_hash=git_commit, ) db.add(scan) db.flush() # Get scan.id # Create violation records - violations_data = data.get('violations', []) + violations_data = data.get("violations", []) logger.info(f"Creating {len(violations_data)} violation records") for v in violations_data: violation = ArchitectureViolation( scan_id=scan.id, - rule_id=v['rule_id'], - rule_name=v['rule_name'], - severity=v['severity'], - file_path=v['file_path'], - line_number=v['line_number'], - message=v['message'], - context=v.get('context', ''), - suggestion=v.get('suggestion', ''), - status='open' + rule_id=v["rule_id"], + rule_name=v["rule_name"], + severity=v["severity"], + file_path=v["file_path"], + line_number=v["line_number"], + message=v["message"], + context=v.get("context", ""), + suggestion=v.get("suggestion", ""), + status="open", ) db.add(violation) @@ -122,7 +120,11 @@ class CodeQualityService: def get_latest_scan(self, db: Session) -> Optional[ArchitectureScan]: """Get the most recent scan""" - return db.query(ArchitectureScan).order_by(desc(ArchitectureScan.timestamp)).first() + return ( + db.query(ArchitectureScan) + .order_by(desc(ArchitectureScan.timestamp)) + .first() + ) def get_scan_by_id(self, db: Session, scan_id: int) -> Optional[ArchitectureScan]: """Get scan by ID""" @@ -139,10 +141,12 @@ class CodeQualityService: Returns: List of ArchitectureScan objects, newest first """ - return db.query(ArchitectureScan)\ - .order_by(desc(ArchitectureScan.timestamp))\ - .limit(limit)\ + return ( + db.query(ArchitectureScan) + .order_by(desc(ArchitectureScan.timestamp)) + .limit(limit) .all() + ) def get_violations( self, @@ -153,7 +157,7 @@ class CodeQualityService: rule_id: str = None, file_path: str = None, limit: int = 100, - offset: int = 0 + offset: int = 0, ) -> Tuple[List[ArchitectureViolation], int]: """ Get violations with filtering and pagination @@ -194,24 +198,32 @@ class CodeQualityService: query = query.filter(ArchitectureViolation.rule_id == rule_id) if file_path: - query = query.filter(ArchitectureViolation.file_path.like(f'%{file_path}%')) + query = query.filter(ArchitectureViolation.file_path.like(f"%{file_path}%")) # Get total count total = query.count() # Get page of results - violations = query.order_by( - ArchitectureViolation.severity.desc(), - ArchitectureViolation.file_path - ).limit(limit).offset(offset).all() + violations = ( + query.order_by( + ArchitectureViolation.severity.desc(), ArchitectureViolation.file_path + ) + .limit(limit) + .offset(offset) + .all() + ) return violations, total - def get_violation_by_id(self, db: Session, violation_id: int) -> Optional[ArchitectureViolation]: + def get_violation_by_id( + self, db: Session, violation_id: int + ) -> Optional[ArchitectureViolation]: """Get single violation with details""" - return db.query(ArchitectureViolation).filter( - ArchitectureViolation.id == violation_id - ).first() + return ( + db.query(ArchitectureViolation) + .filter(ArchitectureViolation.id == violation_id) + .first() + ) def assign_violation( self, @@ -220,7 +232,7 @@ class CodeQualityService: user_id: int, assigned_by: int, due_date: datetime = None, - priority: str = 'medium' + priority: str = "medium", ) -> ViolationAssignment: """ Assign violation to a developer @@ -239,7 +251,7 @@ class CodeQualityService: # Update violation status violation = self.get_violation_by_id(db, violation_id) if violation: - violation.status = 'assigned' + violation.status = "assigned" violation.assigned_to = user_id # Create assignment record @@ -248,7 +260,7 @@ class CodeQualityService: user_id=user_id, assigned_by=assigned_by, due_date=due_date, - priority=priority + priority=priority, ) db.add(assignment) db.commit() @@ -257,11 +269,7 @@ class CodeQualityService: return assignment def resolve_violation( - self, - db: Session, - violation_id: int, - resolved_by: int, - resolution_note: str + self, db: Session, violation_id: int, resolved_by: int, resolution_note: str ) -> ArchitectureViolation: """ Mark violation as resolved @@ -279,7 +287,7 @@ class CodeQualityService: if not violation: raise ValueError(f"Violation {violation_id} not found") - violation.status = 'resolved' + violation.status = "resolved" violation.resolved_at = datetime.now() violation.resolved_by = resolved_by violation.resolution_note = resolution_note @@ -289,11 +297,7 @@ class CodeQualityService: return violation def ignore_violation( - self, - db: Session, - violation_id: int, - ignored_by: int, - reason: str + self, db: Session, violation_id: int, ignored_by: int, reason: str ) -> ArchitectureViolation: """ Mark violation as ignored/won't fix @@ -311,7 +315,7 @@ class CodeQualityService: if not violation: raise ValueError(f"Violation {violation_id} not found") - violation.status = 'ignored' + violation.status = "ignored" violation.resolved_at = datetime.now() violation.resolved_by = ignored_by violation.resolution_note = f"Ignored: {reason}" @@ -321,11 +325,7 @@ class CodeQualityService: return violation def add_comment( - self, - db: Session, - violation_id: int, - user_id: int, - comment: str + self, db: Session, violation_id: int, user_id: int, comment: str ) -> ViolationComment: """ Add comment to violation @@ -340,9 +340,7 @@ class CodeQualityService: ViolationComment object """ comment_obj = ViolationComment( - violation_id=violation_id, - user_id=user_id, - comment=comment + violation_id=violation_id, user_id=user_id, comment=comment ) db.add(comment_obj) db.commit() @@ -360,79 +358,95 @@ class CodeQualityService: latest_scan = self.get_latest_scan(db) if not latest_scan: return { - 'total_violations': 0, - 'errors': 0, - 'warnings': 0, - 'open': 0, - 'assigned': 0, - 'resolved': 0, - 'ignored': 0, - 'technical_debt_score': 100, - 'trend': [], - 'by_severity': {}, - 'by_rule': {}, - 'by_module': {}, - 'top_files': [] + "total_violations": 0, + "errors": 0, + "warnings": 0, + "open": 0, + "assigned": 0, + "resolved": 0, + "ignored": 0, + "technical_debt_score": 100, + "trend": [], + "by_severity": {}, + "by_rule": {}, + "by_module": {}, + "top_files": [], } # Get violation counts by status - status_counts = db.query( - ArchitectureViolation.status, - func.count(ArchitectureViolation.id) - ).filter( - ArchitectureViolation.scan_id == latest_scan.id - ).group_by(ArchitectureViolation.status).all() + status_counts = ( + db.query(ArchitectureViolation.status, func.count(ArchitectureViolation.id)) + .filter(ArchitectureViolation.scan_id == latest_scan.id) + .group_by(ArchitectureViolation.status) + .all() + ) status_dict = {status: count for status, count in status_counts} # Get violations by severity - severity_counts = db.query( - ArchitectureViolation.severity, - func.count(ArchitectureViolation.id) - ).filter( - ArchitectureViolation.scan_id == latest_scan.id - ).group_by(ArchitectureViolation.severity).all() + severity_counts = ( + db.query( + ArchitectureViolation.severity, func.count(ArchitectureViolation.id) + ) + .filter(ArchitectureViolation.scan_id == latest_scan.id) + .group_by(ArchitectureViolation.severity) + .all() + ) by_severity = {sev: count for sev, count in severity_counts} # Get violations by rule - rule_counts = db.query( - ArchitectureViolation.rule_id, - func.count(ArchitectureViolation.id) - ).filter( - ArchitectureViolation.scan_id == latest_scan.id - ).group_by(ArchitectureViolation.rule_id).all() + rule_counts = ( + db.query( + ArchitectureViolation.rule_id, func.count(ArchitectureViolation.id) + ) + .filter(ArchitectureViolation.scan_id == latest_scan.id) + .group_by(ArchitectureViolation.rule_id) + .all() + ) - by_rule = {rule: count for rule, count in sorted(rule_counts, key=lambda x: x[1], reverse=True)[:10]} + by_rule = { + rule: count + for rule, count in sorted(rule_counts, key=lambda x: x[1], reverse=True)[ + :10 + ] + } # Get top violating files - file_counts = db.query( - ArchitectureViolation.file_path, - func.count(ArchitectureViolation.id).label('count') - ).filter( - ArchitectureViolation.scan_id == latest_scan.id - ).group_by(ArchitectureViolation.file_path)\ - .order_by(desc('count'))\ - .limit(10).all() + file_counts = ( + db.query( + ArchitectureViolation.file_path, + func.count(ArchitectureViolation.id).label("count"), + ) + .filter(ArchitectureViolation.scan_id == latest_scan.id) + .group_by(ArchitectureViolation.file_path) + .order_by(desc("count")) + .limit(10) + .all() + ) - top_files = [{'file': file, 'count': count} for file, count in file_counts] + top_files = [{"file": file, "count": count} for file, count in file_counts] # Get violations by module (extract module from file path) by_module = {} - violations = db.query(ArchitectureViolation.file_path).filter( - ArchitectureViolation.scan_id == latest_scan.id - ).all() + violations = ( + db.query(ArchitectureViolation.file_path) + .filter(ArchitectureViolation.scan_id == latest_scan.id) + .all() + ) for v in violations: - path_parts = v.file_path.split('/') + path_parts = v.file_path.split("/") if len(path_parts) >= 2: - module = '/'.join(path_parts[:2]) # e.g., 'app/api' + module = "/".join(path_parts[:2]) # e.g., 'app/api' else: module = path_parts[0] by_module[module] = by_module.get(module, 0) + 1 # Sort by count and take top 10 - by_module = dict(sorted(by_module.items(), key=lambda x: x[1], reverse=True)[:10]) + by_module = dict( + sorted(by_module.items(), key=lambda x: x[1], reverse=True)[:10] + ) # Calculate technical debt score tech_debt_score = self.calculate_technical_debt_score(db, latest_scan.id) @@ -441,29 +455,29 @@ class CodeQualityService: trend_scans = self.get_scan_history(db, limit=7) trend = [ { - 'timestamp': scan.timestamp.isoformat(), - 'violations': scan.total_violations, - 'errors': scan.errors, - 'warnings': scan.warnings + "timestamp": scan.timestamp.isoformat(), + "violations": scan.total_violations, + "errors": scan.errors, + "warnings": scan.warnings, } for scan in reversed(trend_scans) # Oldest first for chart ] return { - 'total_violations': latest_scan.total_violations, - 'errors': latest_scan.errors, - 'warnings': latest_scan.warnings, - 'open': status_dict.get('open', 0), - 'assigned': status_dict.get('assigned', 0), - 'resolved': status_dict.get('resolved', 0), - 'ignored': status_dict.get('ignored', 0), - 'technical_debt_score': tech_debt_score, - 'trend': trend, - 'by_severity': by_severity, - 'by_rule': by_rule, - 'by_module': by_module, - 'top_files': top_files, - 'last_scan': latest_scan.timestamp.isoformat() if latest_scan else None + "total_violations": latest_scan.total_violations, + "errors": latest_scan.errors, + "warnings": latest_scan.warnings, + "open": status_dict.get("open", 0), + "assigned": status_dict.get("assigned", 0), + "resolved": status_dict.get("resolved", 0), + "ignored": status_dict.get("ignored", 0), + "technical_debt_score": tech_debt_score, + "trend": trend, + "by_severity": by_severity, + "by_rule": by_rule, + "by_module": by_module, + "top_files": top_files, + "last_scan": latest_scan.timestamp.isoformat() if latest_scan else None, } def calculate_technical_debt_score(self, db: Session, scan_id: int = None) -> int: @@ -497,10 +511,7 @@ class CodeQualityService: """Get current git commit hash""" try: result = subprocess.run( - ['git', 'rev-parse', 'HEAD'], - capture_output=True, - text=True, - timeout=5 + ["git", "rev-parse", "HEAD"], capture_output=True, text=True, timeout=5 ) if result.returncode == 0: return result.stdout.strip()[:40] diff --git a/app/services/configuration_service.py b/app/services/configuration_service.py index ff83d3b4..54f872ab 100644 --- a/app/services/configuration_service.py +++ b/app/services/configuration_service.py @@ -1 +1 @@ -# Configuration management services +# Configuration management services diff --git a/app/services/content_page_service.py b/app/services/content_page_service.py index beb15fe4..af17bc7d 100644 --- a/app/services/content_page_service.py +++ b/app/services/content_page_service.py @@ -19,8 +19,9 @@ This allows: import logging from datetime import datetime, timezone from typing import List, Optional -from sqlalchemy.orm import Session + from sqlalchemy import and_, or_ +from sqlalchemy.orm import Session from models.database.content_page import ContentPage @@ -35,7 +36,7 @@ class ContentPageService: db: Session, slug: str, vendor_id: Optional[int] = None, - include_unpublished: bool = False + include_unpublished: bool = False, ) -> Optional[ContentPage]: """ Get content page for a vendor with fallback to platform default. @@ -62,28 +63,20 @@ class ContentPageService: if vendor_id: vendor_page = ( db.query(ContentPage) - .filter( - and_( - ContentPage.vendor_id == vendor_id, - *filters - ) - ) + .filter(and_(ContentPage.vendor_id == vendor_id, *filters)) .first() ) if vendor_page: - logger.debug(f"Found vendor-specific page: {slug} for vendor_id={vendor_id}") + logger.debug( + f"Found vendor-specific page: {slug} for vendor_id={vendor_id}" + ) return vendor_page # Fallback to platform default platform_page = ( db.query(ContentPage) - .filter( - and_( - ContentPage.vendor_id == None, - *filters - ) - ) + .filter(and_(ContentPage.vendor_id == None, *filters)) .first() ) @@ -100,7 +93,7 @@ class ContentPageService: vendor_id: Optional[int] = None, include_unpublished: bool = False, footer_only: bool = False, - header_only: bool = False + header_only: bool = False, ) -> List[ContentPage]: """ List all available pages for a vendor (includes vendor overrides + platform defaults). @@ -133,12 +126,7 @@ class ContentPageService: if vendor_id: vendor_pages = ( db.query(ContentPage) - .filter( - and_( - ContentPage.vendor_id == vendor_id, - *filters - ) - ) + .filter(and_(ContentPage.vendor_id == vendor_id, *filters)) .order_by(ContentPage.display_order, ContentPage.title) .all() ) @@ -146,12 +134,7 @@ class ContentPageService: # Get platform defaults platform_pages = ( db.query(ContentPage) - .filter( - and_( - ContentPage.vendor_id == None, - *filters - ) - ) + .filter(and_(ContentPage.vendor_id == None, *filters)) .order_by(ContentPage.display_order, ContentPage.title) .all() ) @@ -159,8 +142,7 @@ class ContentPageService: # Merge: vendor overrides take precedence vendor_slugs = {page.slug for page in vendor_pages} all_pages = vendor_pages + [ - page for page in platform_pages - if page.slug not in vendor_slugs + page for page in platform_pages if page.slug not in vendor_slugs ] # Sort by display_order @@ -183,7 +165,7 @@ class ContentPageService: show_in_footer: bool = True, show_in_header: bool = False, display_order: int = 0, - created_by: Optional[int] = None + created_by: Optional[int] = None, ) -> ContentPage: """ Create a new content page. @@ -229,7 +211,9 @@ class ContentPageService: db.commit() db.refresh(page) - logger.info(f"Created content page: {slug} (vendor_id={vendor_id}, id={page.id})") + logger.info( + f"Created content page: {slug} (vendor_id={vendor_id}, id={page.id})" + ) return page @staticmethod @@ -246,7 +230,7 @@ class ContentPageService: show_in_footer: Optional[bool] = None, show_in_header: Optional[bool] = None, display_order: Optional[int] = None, - updated_by: Optional[int] = None + updated_by: Optional[int] = None, ) -> Optional[ContentPage]: """ Update an existing content page. @@ -338,9 +322,7 @@ class ContentPageService: @staticmethod def list_all_vendor_pages( - db: Session, - vendor_id: int, - include_unpublished: bool = False + db: Session, vendor_id: int, include_unpublished: bool = False ) -> List[ContentPage]: """ List only vendor-specific pages (no platform defaults). @@ -367,8 +349,7 @@ class ContentPageService: @staticmethod def list_all_platform_pages( - db: Session, - include_unpublished: bool = False + db: Session, include_unpublished: bool = False ) -> List[ContentPage]: """ List only platform default pages. diff --git a/app/services/customer_service.py b/app/services/customer_service.py index 53e70ee6..d76078b6 100644 --- a/app/services/customer_service.py +++ b/app/services/customer_service.py @@ -8,24 +8,24 @@ with complete vendor isolation. import logging from datetime import datetime, timedelta -from typing import Optional, Dict, Any -from sqlalchemy.orm import Session -from sqlalchemy import and_ +from typing import Any, Dict, Optional +from sqlalchemy import and_ +from sqlalchemy.orm import Session + +from app.exceptions.customer import (CustomerAlreadyExistsException, + CustomerNotActiveException, + CustomerNotFoundException, + CustomerValidationException, + DuplicateCustomerEmailException, + InvalidCustomerCredentialsException) +from app.exceptions.vendor import (VendorNotActiveException, + VendorNotFoundException) +from app.services.auth_service import AuthService from models.database.customer import Customer, CustomerAddress from models.database.vendor import Vendor -from models.schema.customer import CustomerRegister, CustomerUpdate from models.schema.auth import UserLogin -from app.exceptions.customer import ( - CustomerNotFoundException, - CustomerAlreadyExistsException, - CustomerNotActiveException, - InvalidCustomerCredentialsException, - CustomerValidationException, - DuplicateCustomerEmailException -) -from app.exceptions.vendor import VendorNotFoundException, VendorNotActiveException -from app.services.auth_service import AuthService +from models.schema.customer import CustomerRegister, CustomerUpdate logger = logging.getLogger(__name__) @@ -37,10 +37,7 @@ class CustomerService: self.auth_service = AuthService() def register_customer( - self, - db: Session, - vendor_id: int, - customer_data: CustomerRegister + self, db: Session, vendor_id: int, customer_data: CustomerRegister ) -> Customer: """ Register a new customer for a specific vendor. @@ -68,18 +65,26 @@ class CustomerService: raise VendorNotActiveException(vendor.vendor_code) # Check if email already exists for this vendor - existing_customer = db.query(Customer).filter( - and_( - Customer.vendor_id == vendor_id, - Customer.email == customer_data.email.lower() + existing_customer = ( + db.query(Customer) + .filter( + and_( + Customer.vendor_id == vendor_id, + Customer.email == customer_data.email.lower(), + ) ) - ).first() + .first() + ) if existing_customer: - raise DuplicateCustomerEmailException(customer_data.email, vendor.vendor_code) + raise DuplicateCustomerEmailException( + customer_data.email, vendor.vendor_code + ) # Generate unique customer number for this vendor - customer_number = self._generate_customer_number(db, vendor_id, vendor.vendor_code) + customer_number = self._generate_customer_number( + db, vendor_id, vendor.vendor_code + ) # Hash password hashed_password = self.auth_service.hash_password(customer_data.password) @@ -93,8 +98,12 @@ class CustomerService: last_name=customer_data.last_name, phone=customer_data.phone, customer_number=customer_number, - marketing_consent=customer_data.marketing_consent if hasattr(customer_data, 'marketing_consent') else False, - is_active=True + marketing_consent=( + customer_data.marketing_consent + if hasattr(customer_data, "marketing_consent") + else False + ), + is_active=True, ) try: @@ -114,15 +123,11 @@ class CustomerService: db.rollback() logger.error(f"Error registering customer: {str(e)}") raise CustomerValidationException( - message="Failed to register customer", - details={"error": str(e)} + message="Failed to register customer", details={"error": str(e)} ) def login_customer( - self, - db: Session, - vendor_id: int, - credentials: UserLogin + self, db: Session, vendor_id: int, credentials: UserLogin ) -> Dict[str, Any]: """ Authenticate customer and generate JWT token. @@ -146,20 +151,23 @@ class CustomerService: raise VendorNotFoundException(str(vendor_id), identifier_type="id") # Find customer by email (vendor-scoped) - customer = db.query(Customer).filter( - and_( - Customer.vendor_id == vendor_id, - Customer.email == credentials.email_or_username.lower() + customer = ( + db.query(Customer) + .filter( + and_( + Customer.vendor_id == vendor_id, + Customer.email == credentials.email_or_username.lower(), + ) ) - ).first() + .first() + ) if not customer: raise InvalidCustomerCredentialsException() # Verify password using auth_manager directly if not self.auth_service.auth_manager.verify_password( - credentials.password, - customer.hashed_password + credentials.password, customer.hashed_password ): raise InvalidCustomerCredentialsException() @@ -170,6 +178,7 @@ class CustomerService: # Generate JWT token with customer context # Use auth_manager directly since Customer is not a User model from datetime import datetime, timedelta, timezone + from jose import jwt auth_manager = self.auth_service.auth_manager @@ -185,7 +194,9 @@ class CustomerService: "iat": datetime.now(timezone.utc), } - token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm) + token = jwt.encode( + payload, auth_manager.secret_key, algorithm=auth_manager.algorithm + ) token_data = { "access_token": token, @@ -198,17 +209,9 @@ class CustomerService: f"for vendor {vendor.vendor_code}" ) - return { - "customer": customer, - "token_data": token_data - } + return {"customer": customer, "token_data": token_data} - def get_customer( - self, - db: Session, - vendor_id: int, - customer_id: int - ) -> Customer: + def get_customer(self, db: Session, vendor_id: int, customer_id: int) -> Customer: """ Get customer by ID with vendor isolation. @@ -223,12 +226,11 @@ class CustomerService: Raises: CustomerNotFoundException: If customer not found """ - customer = db.query(Customer).filter( - and_( - Customer.id == customer_id, - Customer.vendor_id == vendor_id - ) - ).first() + customer = ( + db.query(Customer) + .filter(and_(Customer.id == customer_id, Customer.vendor_id == vendor_id)) + .first() + ) if not customer: raise CustomerNotFoundException(str(customer_id)) @@ -236,10 +238,7 @@ class CustomerService: return customer def get_customer_by_email( - self, - db: Session, - vendor_id: int, - email: str + self, db: Session, vendor_id: int, email: str ) -> Optional[Customer]: """ Get customer by email (vendor-scoped). @@ -252,19 +251,20 @@ class CustomerService: Returns: Optional[Customer]: Customer object or None """ - return db.query(Customer).filter( - and_( - Customer.vendor_id == vendor_id, - Customer.email == email.lower() + return ( + db.query(Customer) + .filter( + and_(Customer.vendor_id == vendor_id, Customer.email == email.lower()) ) - ).first() + .first() + ) def update_customer( - self, - db: Session, - vendor_id: int, - customer_id: int, - customer_data: CustomerUpdate + self, + db: Session, + vendor_id: int, + customer_id: int, + customer_data: CustomerUpdate, ) -> Customer: """ Update customer profile. @@ -290,13 +290,17 @@ class CustomerService: for field, value in update_data.items(): if field == "email" and value: # Check if new email already exists for this vendor - existing = db.query(Customer).filter( - and_( - Customer.vendor_id == vendor_id, - Customer.email == value.lower(), - Customer.id != customer_id + existing = ( + db.query(Customer) + .filter( + and_( + Customer.vendor_id == vendor_id, + Customer.email == value.lower(), + Customer.id != customer_id, + ) ) - ).first() + .first() + ) if existing: raise DuplicateCustomerEmailException(value, "vendor") @@ -317,15 +321,11 @@ class CustomerService: db.rollback() logger.error(f"Error updating customer: {str(e)}") raise CustomerValidationException( - message="Failed to update customer", - details={"error": str(e)} + message="Failed to update customer", details={"error": str(e)} ) def deactivate_customer( - self, - db: Session, - vendor_id: int, - customer_id: int + self, db: Session, vendor_id: int, customer_id: int ) -> Customer: """ Deactivate customer account. @@ -352,10 +352,7 @@ class CustomerService: return customer def update_customer_stats( - self, - db: Session, - customer_id: int, - order_total: float + self, db: Session, customer_id: int, order_total: float ) -> None: """ Update customer statistics after order. @@ -377,10 +374,7 @@ class CustomerService: logger.debug(f"Updated stats for customer {customer.email}") def _generate_customer_number( - self, - db: Session, - vendor_id: int, - vendor_code: str + self, db: Session, vendor_id: int, vendor_code: str ) -> str: """ Generate unique customer number for vendor. @@ -397,21 +391,23 @@ class CustomerService: str: Unique customer number """ # Get count of customers for this vendor - count = db.query(Customer).filter( - Customer.vendor_id == vendor_id - ).count() + count = db.query(Customer).filter(Customer.vendor_id == vendor_id).count() # Generate number with padding sequence = str(count + 1).zfill(5) customer_number = f"{vendor_code.upper()}-CUST-{sequence}" # Ensure uniqueness (in case of deletions) - while db.query(Customer).filter( + while ( + db.query(Customer) + .filter( and_( Customer.vendor_id == vendor_id, - Customer.customer_number == customer_number + Customer.customer_number == customer_number, ) - ).first(): + ) + .first() + ): count += 1 sequence = str(count + 1).zfill(5) customer_number = f"{vendor_code.upper()}-CUST-{sequence}" diff --git a/app/services/inventory_service.py b/app/services/inventory_service.py index 79d0a673..b7f945d2 100644 --- a/app/services/inventory_service.py +++ b/app/services/inventory_service.py @@ -5,27 +5,20 @@ from typing import List, Optional from sqlalchemy.orm import Session -from app.exceptions import ( - InventoryNotFoundException, - InsufficientInventoryException, - InvalidInventoryOperationException, - InventoryValidationException, - NegativeInventoryException, - InvalidQuantityException, - ValidationException, - ProductNotFoundException, -) -from models.schema.inventory import ( - InventoryCreate, - InventoryAdjust, - InventoryUpdate, - InventoryReserve, - InventoryLocationResponse, - ProductInventorySummary -) +from app.exceptions import (InsufficientInventoryException, + InvalidInventoryOperationException, + InvalidQuantityException, + InventoryNotFoundException, + InventoryValidationException, + NegativeInventoryException, + ProductNotFoundException, ValidationException) from models.database.inventory import Inventory from models.database.product import Product from models.database.vendor import Vendor +from models.schema.inventory import (InventoryAdjust, InventoryCreate, + InventoryLocationResponse, + InventoryReserve, InventoryUpdate, + ProductInventorySummary) logger = logging.getLogger(__name__) @@ -34,7 +27,7 @@ class InventoryService: """Service for inventory operations with vendor isolation.""" def set_inventory( - self, db: Session, vendor_id: int, inventory_data: InventoryCreate + self, db: Session, vendor_id: int, inventory_data: InventoryCreate ) -> Inventory: """ Set exact inventory quantity for a product at a location (replaces existing). @@ -93,7 +86,11 @@ class InventoryService: ) return new_inventory - except (ProductNotFoundException, InvalidQuantityException, InventoryValidationException): + except ( + ProductNotFoundException, + InvalidQuantityException, + InventoryValidationException, + ): db.rollback() raise except Exception as e: @@ -102,7 +99,7 @@ class InventoryService: raise ValidationException("Failed to set inventory") def adjust_inventory( - self, db: Session, vendor_id: int, inventory_data: InventoryAdjust + self, db: Session, vendor_id: int, inventory_data: InventoryAdjust ) -> Inventory: """ Adjust inventory by adding or removing quantity. @@ -124,7 +121,9 @@ class InventoryService: location = self._validate_location(inventory_data.location) # Check if inventory exists - existing = self._get_inventory_entry(db, inventory_data.product_id, location) + existing = self._get_inventory_entry( + db, inventory_data.product_id, location + ) if not existing: # Create new if adding, error if removing @@ -173,8 +172,12 @@ class InventoryService: ) return existing - except (ProductNotFoundException, InventoryNotFoundException, - InsufficientInventoryException, InventoryValidationException): + except ( + ProductNotFoundException, + InventoryNotFoundException, + InsufficientInventoryException, + InventoryValidationException, + ): db.rollback() raise except Exception as e: @@ -183,7 +186,7 @@ class InventoryService: raise ValidationException("Failed to adjust inventory") def reserve_inventory( - self, db: Session, vendor_id: int, reserve_data: InventoryReserve + self, db: Session, vendor_id: int, reserve_data: InventoryReserve ) -> Inventory: """ Reserve inventory for an order (increases reserved_quantity). @@ -231,8 +234,12 @@ class InventoryService: ) return inventory - except (ProductNotFoundException, InventoryNotFoundException, - InsufficientInventoryException, InvalidQuantityException): + except ( + ProductNotFoundException, + InventoryNotFoundException, + InsufficientInventoryException, + InvalidQuantityException, + ): db.rollback() raise except Exception as e: @@ -241,7 +248,7 @@ class InventoryService: raise ValidationException("Failed to reserve inventory") def release_reservation( - self, db: Session, vendor_id: int, reserve_data: InventoryReserve + self, db: Session, vendor_id: int, reserve_data: InventoryReserve ) -> Inventory: """ Release reserved inventory (decreases reserved_quantity). @@ -287,7 +294,11 @@ class InventoryService: ) return inventory - except (ProductNotFoundException, InventoryNotFoundException, InvalidQuantityException): + except ( + ProductNotFoundException, + InventoryNotFoundException, + InvalidQuantityException, + ): db.rollback() raise except Exception as e: @@ -296,7 +307,7 @@ class InventoryService: raise ValidationException("Failed to release reservation") def fulfill_reservation( - self, db: Session, vendor_id: int, reserve_data: InventoryReserve + self, db: Session, vendor_id: int, reserve_data: InventoryReserve ) -> Inventory: """ Fulfill a reservation (decreases both quantity and reserved_quantity). @@ -349,8 +360,12 @@ class InventoryService: ) return inventory - except (ProductNotFoundException, InventoryNotFoundException, - InsufficientInventoryException, InvalidQuantityException): + except ( + ProductNotFoundException, + InventoryNotFoundException, + InsufficientInventoryException, + InvalidQuantityException, + ): db.rollback() raise except Exception as e: @@ -359,7 +374,7 @@ class InventoryService: raise ValidationException("Failed to fulfill reservation") def get_product_inventory( - self, db: Session, vendor_id: int, product_id: int + self, db: Session, vendor_id: int, product_id: int ) -> ProductInventorySummary: """ Get inventory summary for a product across all locations. @@ -376,9 +391,7 @@ class InventoryService: product = self._get_vendor_product(db, vendor_id, product_id) inventory_entries = ( - db.query(Inventory) - .filter(Inventory.product_id == product_id) - .all() + db.query(Inventory).filter(Inventory.product_id == product_id).all() ) if not inventory_entries: @@ -425,8 +438,13 @@ class InventoryService: raise ValidationException("Failed to retrieve product inventory") def get_vendor_inventory( - self, db: Session, vendor_id: int, skip: int = 0, limit: int = 100, - location: Optional[str] = None, low_stock_threshold: Optional[int] = None + self, + db: Session, + vendor_id: int, + skip: int = 0, + limit: int = 100, + location: Optional[str] = None, + low_stock_threshold: Optional[int] = None, ) -> List[Inventory]: """ Get all inventory for a vendor with filtering. @@ -458,8 +476,11 @@ class InventoryService: raise ValidationException("Failed to retrieve vendor inventory") def update_inventory( - self, db: Session, vendor_id: int, inventory_id: int, - inventory_update: InventoryUpdate + self, + db: Session, + vendor_id: int, + inventory_id: int, + inventory_update: InventoryUpdate, ) -> Inventory: """Update inventory entry.""" try: @@ -475,7 +496,9 @@ class InventoryService: inventory.quantity = inventory_update.quantity if inventory_update.reserved_quantity is not None: - self._validate_quantity(inventory_update.reserved_quantity, allow_zero=True) + self._validate_quantity( + inventory_update.reserved_quantity, allow_zero=True + ) inventory.reserved_quantity = inventory_update.reserved_quantity if inventory_update.location: @@ -488,7 +511,11 @@ class InventoryService: logger.info(f"Updated inventory {inventory_id}") return inventory - except (InventoryNotFoundException, InvalidQuantityException, InventoryValidationException): + except ( + InventoryNotFoundException, + InvalidQuantityException, + InventoryValidationException, + ): db.rollback() raise except Exception as e: @@ -496,9 +523,7 @@ class InventoryService: logger.error(f"Error updating inventory: {str(e)}") raise ValidationException("Failed to update inventory") - def delete_inventory( - self, db: Session, vendor_id: int, inventory_id: int - ) -> bool: + def delete_inventory(self, db: Session, vendor_id: int, inventory_id: int) -> bool: """Delete inventory entry.""" try: inventory = self._get_inventory_by_id(db, inventory_id) @@ -521,28 +546,30 @@ class InventoryService: raise ValidationException("Failed to delete inventory") # Private helper methods - def _get_vendor_product(self, db: Session, vendor_id: int, product_id: int) -> Product: + def _get_vendor_product( + self, db: Session, vendor_id: int, product_id: int + ) -> Product: """Get product and verify it belongs to vendor.""" - product = db.query(Product).filter( - Product.id == product_id, - Product.vendor_id == vendor_id - ).first() + product = ( + db.query(Product) + .filter(Product.id == product_id, Product.vendor_id == vendor_id) + .first() + ) if not product: - raise ProductNotFoundException(f"Product {product_id} not found in your catalog") + raise ProductNotFoundException( + f"Product {product_id} not found in your catalog" + ) return product def _get_inventory_entry( - self, db: Session, product_id: int, location: str + self, db: Session, product_id: int, location: str ) -> Optional[Inventory]: """Get inventory entry by product and location.""" return ( db.query(Inventory) - .filter( - Inventory.product_id == product_id, - Inventory.location == location - ) + .filter(Inventory.product_id == product_id, Inventory.location == location) .first() ) diff --git a/app/services/marketplace_import_job_service.py b/app/services/marketplace_import_job_service.py index b6a696d1..2007cd2b 100644 --- a/app/services/marketplace_import_job_service.py +++ b/app/services/marketplace_import_job_service.py @@ -5,20 +5,15 @@ from typing import List, Optional from sqlalchemy.orm import Session -from app.exceptions import ( - ImportJobNotFoundException, - ImportJobNotOwnedException, - ImportJobCannotBeCancelledException, - ImportJobCannotBeDeletedException, - ValidationException, -) -from models.schema.marketplace_import_job import ( - MarketplaceImportJobResponse, - MarketplaceImportJobRequest -) +from app.exceptions import (ImportJobCannotBeCancelledException, + ImportJobCannotBeDeletedException, + ImportJobNotFoundException, + ImportJobNotOwnedException, ValidationException) from models.database.marketplace_import_job import MarketplaceImportJob -from models.database.vendor import Vendor from models.database.user import User +from models.database.vendor import Vendor +from models.schema.marketplace_import_job import (MarketplaceImportJobRequest, + MarketplaceImportJobResponse) logger = logging.getLogger(__name__) @@ -31,7 +26,7 @@ class MarketplaceImportJobService: db: Session, request: MarketplaceImportJobRequest, vendor: Vendor, # CHANGED: Vendor object from middleware - user: User + user: User, ) -> MarketplaceImportJob: """ Create a new marketplace import job. @@ -147,7 +142,9 @@ class MarketplaceImportJobService: marketplace=job.marketplace, vendor_id=job.vendor_id, vendor_code=job.vendor.vendor_code if job.vendor else None, # FIXED - vendor_name=job.vendor.name if job.vendor else None, # FIXED: from relationship + vendor_name=( + job.vendor.name if job.vendor else None + ), # FIXED: from relationship source_url=job.source_url, imported=job.imported_count or 0, updated=job.updated_count or 0, diff --git a/app/services/marketplace_product_service.py b/app/services/marketplace_product_service.py index 7c220160..fd572b32 100644 --- a/app/services/marketplace_product_service.py +++ b/app/services/marketplace_product_service.py @@ -17,19 +17,20 @@ from typing import Generator, List, Optional, Tuple from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session -from app.exceptions import ( - MarketplaceProductNotFoundException, - MarketplaceProductAlreadyExistsException, - InvalidMarketplaceProductDataException, - MarketplaceProductValidationException, - ValidationException, -) -from app.services.marketplace_import_job_service import marketplace_import_job_service -from models.schema.marketplace_product import MarketplaceProductCreate, MarketplaceProductUpdate -from models.schema.inventory import InventoryLocationResponse, InventorySummaryResponse -from models.database.marketplace_product import MarketplaceProduct -from models.database.inventory import Inventory +from app.exceptions import (InvalidMarketplaceProductDataException, + MarketplaceProductAlreadyExistsException, + MarketplaceProductNotFoundException, + MarketplaceProductValidationException, + ValidationException) +from app.services.marketplace_import_job_service import \ + marketplace_import_job_service from app.utils.data_processing import GTINProcessor, PriceProcessor +from models.database.inventory import Inventory +from models.database.marketplace_product import MarketplaceProduct +from models.schema.inventory import (InventoryLocationResponse, + InventorySummaryResponse) +from models.schema.marketplace_product import (MarketplaceProductCreate, + MarketplaceProductUpdate) logger = logging.getLogger(__name__) @@ -42,14 +43,18 @@ class MarketplaceProductService: self.gtin_processor = GTINProcessor() self.price_processor = PriceProcessor() - def create_product(self, db: Session, product_data: MarketplaceProductCreate) -> MarketplaceProduct: + def create_product( + self, db: Session, product_data: MarketplaceProductCreate + ) -> MarketplaceProduct: """Create a new product with validation.""" try: # Process and validate GTIN if provided if product_data.gtin: normalized_gtin = self.gtin_processor.normalize(product_data.gtin) if not normalized_gtin: - raise InvalidMarketplaceProductDataException("Invalid GTIN format", field="gtin") + raise InvalidMarketplaceProductDataException( + "Invalid GTIN format", field="gtin" + ) product_data.gtin = normalized_gtin # Process price if provided @@ -70,11 +75,18 @@ class MarketplaceProductService: product_data.marketplace = "Letzshop" # Validate required fields - if not product_data.marketplace_product_id or not product_data.marketplace_product_id.strip(): - raise MarketplaceProductValidationException("MarketplaceProduct ID is required", field="marketplace_product_id") + if ( + not product_data.marketplace_product_id + or not product_data.marketplace_product_id.strip() + ): + raise MarketplaceProductValidationException( + "MarketplaceProduct ID is required", field="marketplace_product_id" + ) if not product_data.title or not product_data.title.strip(): - raise MarketplaceProductValidationException("MarketplaceProduct title is required", field="title") + raise MarketplaceProductValidationException( + "MarketplaceProduct title is required", field="title" + ) db_product = MarketplaceProduct(**product_data.model_dump()) db.add(db_product) @@ -84,30 +96,47 @@ class MarketplaceProductService: logger.info(f"Created product {db_product.marketplace_product_id}") return db_product - except (InvalidMarketplaceProductDataException, MarketplaceProductValidationException): + except ( + InvalidMarketplaceProductDataException, + MarketplaceProductValidationException, + ): db.rollback() raise # Re-raise custom exceptions except IntegrityError as e: db.rollback() logger.error(f"Database integrity error: {str(e)}") if "marketplace_product_id" in str(e).lower() or "unique" in str(e).lower(): - raise MarketplaceProductAlreadyExistsException(product_data.marketplace_product_id) + raise MarketplaceProductAlreadyExistsException( + product_data.marketplace_product_id + ) else: - raise MarketplaceProductValidationException("Data integrity constraint violation") + raise MarketplaceProductValidationException( + "Data integrity constraint violation" + ) except Exception as e: db.rollback() logger.error(f"Error creating product: {str(e)}") raise ValidationException("Failed to create product") - def get_product_by_id(self, db: Session, marketplace_product_id: str) -> Optional[MarketplaceProduct]: + def get_product_by_id( + self, db: Session, marketplace_product_id: str + ) -> Optional[MarketplaceProduct]: """Get a product by its ID.""" try: - return db.query(MarketplaceProduct).filter(MarketplaceProduct.marketplace_product_id == marketplace_product_id).first() + return ( + db.query(MarketplaceProduct) + .filter( + MarketplaceProduct.marketplace_product_id == marketplace_product_id + ) + .first() + ) except Exception as e: logger.error(f"Error getting product {marketplace_product_id}: {str(e)}") return None - def get_product_by_id_or_raise(self, db: Session, marketplace_product_id: str) -> MarketplaceProduct: + def get_product_by_id_or_raise( + self, db: Session, marketplace_product_id: str + ) -> MarketplaceProduct: """ Get a product by its ID or raise exception. @@ -127,16 +156,16 @@ class MarketplaceProductService: return product def get_products_with_filters( - self, - db: Session, - skip: int = 0, - limit: int = 100, - brand: Optional[str] = None, - category: Optional[str] = None, - availability: Optional[str] = None, - marketplace: Optional[str] = None, - vendor_name: Optional[str] = None, - search: Optional[str] = None, + self, + db: Session, + skip: int = 0, + limit: int = 100, + brand: Optional[str] = None, + category: Optional[str] = None, + availability: Optional[str] = None, + marketplace: Optional[str] = None, + vendor_name: Optional[str] = None, + search: Optional[str] = None, ) -> Tuple[List[MarketplaceProduct], int]: """ Get products with filtering and pagination. @@ -162,13 +191,19 @@ class MarketplaceProductService: if brand: query = query.filter(MarketplaceProduct.brand.ilike(f"%{brand}%")) if category: - query = query.filter(MarketplaceProduct.google_product_category.ilike(f"%{category}%")) + query = query.filter( + MarketplaceProduct.google_product_category.ilike(f"%{category}%") + ) if availability: query = query.filter(MarketplaceProduct.availability == availability) if marketplace: - query = query.filter(MarketplaceProduct.marketplace.ilike(f"%{marketplace}%")) + query = query.filter( + MarketplaceProduct.marketplace.ilike(f"%{marketplace}%") + ) if vendor_name: - query = query.filter(MarketplaceProduct.vendor_name.ilike(f"%{vendor_name}%")) + query = query.filter( + MarketplaceProduct.vendor_name.ilike(f"%{vendor_name}%") + ) if search: # Search in title, description, marketplace, and name search_term = f"%{search}%" @@ -188,7 +223,12 @@ class MarketplaceProductService: logger.error(f"Error getting products with filters: {str(e)}") raise ValidationException("Failed to retrieve products") - def update_product(self, db: Session, marketplace_product_id: str, product_update: MarketplaceProductUpdate) -> MarketplaceProduct: + def update_product( + self, + db: Session, + marketplace_product_id: str, + product_update: MarketplaceProductUpdate, + ) -> MarketplaceProduct: """Update product with validation.""" try: product = self.get_product_by_id_or_raise(db, marketplace_product_id) @@ -200,7 +240,9 @@ class MarketplaceProductService: if "gtin" in update_data and update_data["gtin"]: normalized_gtin = self.gtin_processor.normalize(update_data["gtin"]) if not normalized_gtin: - raise InvalidMarketplaceProductDataException("Invalid GTIN format", field="gtin") + raise InvalidMarketplaceProductDataException( + "Invalid GTIN format", field="gtin" + ) update_data["gtin"] = normalized_gtin # Process price if being updated @@ -217,8 +259,12 @@ class MarketplaceProductService: raise InvalidMarketplaceProductDataException(str(e), field="price") # Validate required fields if being updated - if "title" in update_data and (not update_data["title"] or not update_data["title"].strip()): - raise MarketplaceProductValidationException("MarketplaceProduct title cannot be empty", field="title") + if "title" in update_data and ( + not update_data["title"] or not update_data["title"].strip() + ): + raise MarketplaceProductValidationException( + "MarketplaceProduct title cannot be empty", field="title" + ) for key, value in update_data.items(): setattr(product, key, value) @@ -230,7 +276,11 @@ class MarketplaceProductService: logger.info(f"Updated product {marketplace_product_id}") return product - except (MarketplaceProductNotFoundException, InvalidMarketplaceProductDataException, MarketplaceProductValidationException): + except ( + MarketplaceProductNotFoundException, + InvalidMarketplaceProductDataException, + MarketplaceProductValidationException, + ): db.rollback() raise # Re-raise custom exceptions except Exception as e: @@ -272,7 +322,9 @@ class MarketplaceProductService: logger.error(f"Error deleting product {marketplace_product_id}: {str(e)}") raise ValidationException("Failed to delete product") - def get_inventory_info(self, db: Session, gtin: str) -> Optional[InventorySummaryResponse]: + def get_inventory_info( + self, db: Session, gtin: str + ) -> Optional[InventorySummaryResponse]: """ Get inventory information for a product by GTIN. @@ -290,7 +342,9 @@ class MarketplaceProductService: total_quantity = sum(entry.quantity for entry in inventory_entries) locations = [ - InventoryLocationResponse(location=entry.location, quantity=entry.quantity) + InventoryLocationResponse( + location=entry.location, quantity=entry.quantity + ) for entry in inventory_entries ] @@ -305,13 +359,14 @@ class MarketplaceProductService: import csv from io import StringIO from typing import Generator, Optional + from sqlalchemy.orm import Session def generate_csv_export( - self, - db: Session, - marketplace: Optional[str] = None, - vendor_name: Optional[str] = None, + self, + db: Session, + marketplace: Optional[str] = None, + vendor_name: Optional[str] = None, ) -> Generator[str, None, None]: """ Generate CSV export with streaming for memory efficiency and proper CSV escaping. @@ -331,9 +386,18 @@ class MarketplaceProductService: # Write header row headers = [ - "marketplace_product_id", "title", "description", "link", "image_link", - "availability", "price", "currency", "brand", "gtin", - "marketplace", "name" + "marketplace_product_id", + "title", + "description", + "link", + "image_link", + "availability", + "price", + "currency", + "brand", + "gtin", + "marketplace", + "name", ] writer.writerow(headers) yield output.getvalue() @@ -350,9 +414,13 @@ class MarketplaceProductService: # Apply marketplace filters if marketplace: - query = query.filter(MarketplaceProduct.marketplace.ilike(f"%{marketplace}%")) + query = query.filter( + MarketplaceProduct.marketplace.ilike(f"%{marketplace}%") + ) if vendor_name: - query = query.filter(MarketplaceProduct.vendor_name.ilike(f"%{vendor_name}%")) + query = query.filter( + MarketplaceProduct.vendor_name.ilike(f"%{vendor_name}%") + ) products = query.offset(offset).limit(batch_size).all() if not products: @@ -392,8 +460,12 @@ class MarketplaceProductService: """Check if product exists by ID.""" try: return ( - db.query(MarketplaceProduct).filter(MarketplaceProduct.marketplace_product_id == marketplace_product_id).first() - is not None + db.query(MarketplaceProduct) + .filter( + MarketplaceProduct.marketplace_product_id == marketplace_product_id + ) + .first() + is not None ) except Exception as e: logger.error(f"Error checking if product exists: {str(e)}") @@ -402,18 +474,27 @@ class MarketplaceProductService: # Private helper methods def _validate_product_data(self, product_data: dict) -> None: """Validate product data structure.""" - required_fields = ['marketplace_product_id', 'title'] + required_fields = ["marketplace_product_id", "title"] for field in required_fields: if field not in product_data or not product_data[field]: - raise MarketplaceProductValidationException(f"{field} is required", field=field) + raise MarketplaceProductValidationException( + f"{field} is required", field=field + ) def _normalize_product_data(self, product_data: dict) -> dict: """Normalize and clean product data.""" normalized = product_data.copy() # Trim whitespace from string fields - string_fields = ['marketplace_product_id', 'title', 'description', 'brand', 'marketplace', 'name'] + string_fields = [ + "marketplace_product_id", + "title", + "description", + "brand", + "marketplace", + "name", + ] for field in string_fields: if field in normalized and normalized[field]: normalized[field] = normalized[field].strip() diff --git a/app/services/media_service.py b/app/services/media_service.py index 7d008007..0c726681 100644 --- a/app/services/media_service.py +++ b/app/services/media_service.py @@ -1 +1 @@ -# File and media management services +# File and media management services diff --git a/app/services/monitoring_service.py b/app/services/monitoring_service.py index f1d92af4..9d5ec783 100644 --- a/app/services/monitoring_service.py +++ b/app/services/monitoring_service.py @@ -1 +1 @@ -# Application monitoring services +# Application monitoring services diff --git a/app/services/notification_service.py b/app/services/notification_service.py index c25cb0d9..4eaa423b 100644 --- a/app/services/notification_service.py +++ b/app/services/notification_service.py @@ -1 +1 @@ -# Email/notification services +# Email/notification services diff --git a/app/services/order_service.py b/app/services/order_service.py index b4606c43..9e1005fd 100644 --- a/app/services/order_service.py +++ b/app/services/order_service.py @@ -9,24 +9,21 @@ This module provides: """ import logging -from datetime import datetime, timezone -from typing import List, Optional, Tuple import random import string +from datetime import datetime, timezone +from typing import List, Optional, Tuple -from sqlalchemy.orm import Session from sqlalchemy import and_, or_ +from sqlalchemy.orm import Session -from models.database.order import Order, OrderItem +from app.exceptions import (CustomerNotFoundException, + InsufficientInventoryException, + OrderNotFoundException, ValidationException) from models.database.customer import Customer, CustomerAddress +from models.database.order import Order, OrderItem from models.database.product import Product -from models.schema.order import OrderCreate, OrderUpdate, OrderAddressCreate -from app.exceptions import ( - OrderNotFoundException, - ValidationException, - InsufficientInventoryException, - CustomerNotFoundException -) +from models.schema.order import OrderAddressCreate, OrderCreate, OrderUpdate logger = logging.getLogger(__name__) @@ -42,23 +39,27 @@ class OrderService: Example: ORD-1-20250110-A1B2C3 """ timestamp = datetime.now(timezone.utc).strftime("%Y%m%d") - random_suffix = ''.join(random.choices(string.ascii_uppercase + string.digits, k=6)) + random_suffix = "".join( + random.choices(string.ascii_uppercase + string.digits, k=6) + ) order_number = f"ORD-{vendor_id}-{timestamp}-{random_suffix}" # Ensure uniqueness while db.query(Order).filter(Order.order_number == order_number).first(): - random_suffix = ''.join(random.choices(string.ascii_uppercase + string.digits, k=6)) + random_suffix = "".join( + random.choices(string.ascii_uppercase + string.digits, k=6) + ) order_number = f"ORD-{vendor_id}-{timestamp}-{random_suffix}" return order_number def _create_customer_address( - self, - db: Session, - vendor_id: int, - customer_id: int, - address_data: OrderAddressCreate, - address_type: str + self, + db: Session, + vendor_id: int, + customer_id: int, + address_data: OrderAddressCreate, + address_type: str, ) -> CustomerAddress: """Create a customer address for order.""" address = CustomerAddress( @@ -73,17 +74,14 @@ class OrderService: city=address_data.city, postal_code=address_data.postal_code, country=address_data.country, - is_default=False + is_default=False, ) db.add(address) db.flush() # Get ID without committing return address def create_order( - self, - db: Session, - vendor_id: int, - order_data: OrderCreate + self, db: Session, vendor_id: int, order_data: OrderCreate ) -> Order: """ Create a new order. @@ -104,12 +102,15 @@ class OrderService: # Validate customer exists if provided customer_id = order_data.customer_id if customer_id: - customer = db.query(Customer).filter( - and_( - Customer.id == customer_id, - Customer.vendor_id == vendor_id + customer = ( + db.query(Customer) + .filter( + and_( + Customer.id == customer_id, Customer.vendor_id == vendor_id + ) ) - ).first() + .first() + ) if not customer: raise CustomerNotFoundException(str(customer_id)) @@ -124,7 +125,7 @@ class OrderService: vendor_id=vendor_id, customer_id=customer_id, address_data=order_data.shipping_address, - address_type="shipping" + address_type="shipping", ) # Create billing address (use shipping if not provided) @@ -134,7 +135,7 @@ class OrderService: vendor_id=vendor_id, customer_id=customer_id, address_data=order_data.billing_address, - address_type="billing" + address_type="billing", ) else: billing_address = shipping_address @@ -145,23 +146,29 @@ class OrderService: for item_data in order_data.items: # Get product - product = db.query(Product).filter( - and_( - Product.id == item_data.product_id, - Product.vendor_id == vendor_id, - Product.is_active == True + product = ( + db.query(Product) + .filter( + and_( + Product.id == item_data.product_id, + Product.vendor_id == vendor_id, + Product.is_active == True, + ) ) - ).first() + .first() + ) if not product: - raise ValidationException(f"Product {item_data.product_id} not found") + raise ValidationException( + f"Product {item_data.product_id} not found" + ) # Check inventory if product.available_inventory < item_data.quantity: raise InsufficientInventoryException( product_id=product.id, requested=item_data.quantity, - available=product.available_inventory + available=product.available_inventory, ) # Calculate item total @@ -172,14 +179,16 @@ class OrderService: item_total = unit_price * item_data.quantity subtotal += item_total - order_items_data.append({ - "product_id": product.id, - "product_name": product.marketplace_product.title, - "product_sku": product.product_id, - "quantity": item_data.quantity, - "unit_price": unit_price, - "total_price": item_total - }) + order_items_data.append( + { + "product_id": product.id, + "product_name": product.marketplace_product.title, + "product_sku": product.product_id, + "quantity": item_data.quantity, + "unit_price": unit_price, + "total_price": item_total, + } + ) # Calculate tax and shipping (simple implementation) tax_amount = 0.0 # TODO: Implement tax calculation @@ -205,7 +214,7 @@ class OrderService: shipping_address_id=shipping_address.id, billing_address_id=billing_address.id, shipping_method=order_data.shipping_method, - customer_notes=order_data.customer_notes + customer_notes=order_data.customer_notes, ) db.add(order) @@ -213,10 +222,7 @@ class OrderService: # Create order items for item_data in order_items_data: - order_item = OrderItem( - order_id=order.id, - **item_data - ) + order_item = OrderItem(order_id=order.id, **item_data) db.add(order_item) db.commit() @@ -229,7 +235,11 @@ class OrderService: return order - except (ValidationException, InsufficientInventoryException, CustomerNotFoundException): + except ( + ValidationException, + InsufficientInventoryException, + CustomerNotFoundException, + ): db.rollback() raise except Exception as e: @@ -237,19 +247,13 @@ class OrderService: logger.error(f"Error creating order: {str(e)}") raise ValidationException(f"Failed to create order: {str(e)}") - def get_order( - self, - db: Session, - vendor_id: int, - order_id: int - ) -> Order: + def get_order(self, db: Session, vendor_id: int, order_id: int) -> Order: """Get order by ID.""" - order = db.query(Order).filter( - and_( - Order.id == order_id, - Order.vendor_id == vendor_id - ) - ).first() + order = ( + db.query(Order) + .filter(and_(Order.id == order_id, Order.vendor_id == vendor_id)) + .first() + ) if not order: raise OrderNotFoundException(str(order_id)) @@ -257,13 +261,13 @@ class OrderService: return order def get_vendor_orders( - self, - db: Session, - vendor_id: int, - skip: int = 0, - limit: int = 100, - status: Optional[str] = None, - customer_id: Optional[int] = None + self, + db: Session, + vendor_id: int, + skip: int = 0, + limit: int = 100, + status: Optional[str] = None, + customer_id: Optional[int] = None, ) -> Tuple[List[Order], int]: """ Get orders for vendor with filtering. @@ -296,28 +300,20 @@ class OrderService: return orders, total def get_customer_orders( - self, - db: Session, - vendor_id: int, - customer_id: int, - skip: int = 0, - limit: int = 100 + self, + db: Session, + vendor_id: int, + customer_id: int, + skip: int = 0, + limit: int = 100, ) -> Tuple[List[Order], int]: """Get orders for a specific customer.""" return self.get_vendor_orders( - db=db, - vendor_id=vendor_id, - skip=skip, - limit=limit, - customer_id=customer_id + db=db, vendor_id=vendor_id, skip=skip, limit=limit, customer_id=customer_id ) def update_order_status( - self, - db: Session, - vendor_id: int, - order_id: int, - order_update: OrderUpdate + self, db: Session, vendor_id: int, order_id: int, order_update: OrderUpdate ) -> Order: """ Update order status and tracking information. diff --git a/app/services/payment_service.py b/app/services/payment_service.py index ec9ec61e..5aef72d7 100644 --- a/app/services/payment_service.py +++ b/app/services/payment_service.py @@ -1 +1 @@ -# Payment processing services +# Payment processing services diff --git a/app/services/product_service.py b/app/services/product_service.py index 52db4e09..da7f46ba 100644 --- a/app/services/product_service.py +++ b/app/services/product_service.py @@ -14,14 +14,11 @@ from typing import List, Optional, Tuple from sqlalchemy.orm import Session -from app.exceptions import ( - ProductNotFoundException, - ProductAlreadyExistsException, - ValidationException, -) -from models.schema.product import ProductCreate, ProductUpdate -from models.database.product import Product +from app.exceptions import (ProductAlreadyExistsException, + ProductNotFoundException, ValidationException) from models.database.marketplace_product import MarketplaceProduct +from models.database.product import Product +from models.schema.product import ProductCreate, ProductUpdate logger = logging.getLogger(__name__) @@ -45,10 +42,11 @@ class ProductService: ProductNotFoundException: If product not found """ try: - product = db.query(Product).filter( - Product.id == product_id, - Product.vendor_id == vendor_id - ).first() + product = ( + db.query(Product) + .filter(Product.id == product_id, Product.vendor_id == vendor_id) + .first() + ) if not product: raise ProductNotFoundException(f"Product {product_id} not found") @@ -62,7 +60,7 @@ class ProductService: raise ValidationException("Failed to retrieve product") def create_product( - self, db: Session, vendor_id: int, product_data: ProductCreate + self, db: Session, vendor_id: int, product_data: ProductCreate ) -> Product: """ Add a product from marketplace to vendor catalog. @@ -81,10 +79,14 @@ class ProductService: """ try: # Verify marketplace product exists and belongs to vendor - marketplace_product = db.query(MarketplaceProduct).filter( - MarketplaceProduct.id == product_data.marketplace_product_id, - MarketplaceProduct.vendor_id == vendor_id - ).first() + marketplace_product = ( + db.query(MarketplaceProduct) + .filter( + MarketplaceProduct.id == product_data.marketplace_product_id, + MarketplaceProduct.vendor_id == vendor_id, + ) + .first() + ) if not marketplace_product: raise ValidationException( @@ -92,10 +94,15 @@ class ProductService: ) # Check if already in catalog - existing = db.query(Product).filter( - Product.vendor_id == vendor_id, - Product.marketplace_product_id == product_data.marketplace_product_id - ).first() + existing = ( + db.query(Product) + .filter( + Product.vendor_id == vendor_id, + Product.marketplace_product_id + == product_data.marketplace_product_id, + ) + .first() + ) if existing: raise ProductAlreadyExistsException( @@ -122,9 +129,7 @@ class ProductService: db.commit() db.refresh(product) - logger.info( - f"Added product {product.id} to vendor {vendor_id} catalog" - ) + logger.info(f"Added product {product.id} to vendor {vendor_id} catalog") return product except (ProductAlreadyExistsException, ValidationException): @@ -136,7 +141,11 @@ class ProductService: raise ValidationException("Failed to create product") def update_product( - self, db: Session, vendor_id: int, product_id: int, product_update: ProductUpdate + self, + db: Session, + vendor_id: int, + product_id: int, + product_update: ProductUpdate, ) -> Product: """ Update product in vendor catalog. @@ -202,13 +211,13 @@ class ProductService: raise ValidationException("Failed to delete product") def get_vendor_products( - self, - db: Session, - vendor_id: int, - skip: int = 0, - limit: int = 100, - is_active: Optional[bool] = None, - is_featured: Optional[bool] = None, + self, + db: Session, + vendor_id: int, + skip: int = 0, + limit: int = 100, + is_active: Optional[bool] = None, + is_featured: Optional[bool] = None, ) -> Tuple[List[Product], int]: """ Get products in vendor catalog with filtering. diff --git a/app/services/search_service.py b/app/services/search_service.py index f2f192b8..58830348 100644 --- a/app/services/search_service.py +++ b/app/services/search_service.py @@ -1 +1 @@ -# Search and indexing services +# Search and indexing services diff --git a/app/services/stats_service.py b/app/services/stats_service.py index beabc8f1..ba8f381b 100644 --- a/app/services/stats_service.py +++ b/app/services/stats_service.py @@ -10,25 +10,21 @@ This module provides: """ import logging -from typing import Any, Dict, List from datetime import datetime, timedelta +from typing import Any, Dict, List from sqlalchemy import func from sqlalchemy.orm import Session -from app.exceptions import ( - VendorNotFoundException, - AdminOperationException, -) - -from models.database.marketplace_product import MarketplaceProduct -from models.database.product import Product -from models.database.inventory import Inventory -from models.database.vendor import Vendor -from models.database.order import Order -from models.database.user import User +from app.exceptions import AdminOperationException, VendorNotFoundException from models.database.customer import Customer +from models.database.inventory import Inventory from models.database.marketplace_import_job import MarketplaceImportJob +from models.database.marketplace_product import MarketplaceProduct +from models.database.order import Order +from models.database.product import Product +from models.database.user import User +from models.database.vendor import Vendor logger = logging.getLogger(__name__) @@ -62,63 +58,77 @@ class StatsService: try: # Catalog statistics - total_catalog_products = db.query(Product).filter( - Product.vendor_id == vendor_id, - Product.is_active == True - ).count() + total_catalog_products = ( + db.query(Product) + .filter(Product.vendor_id == vendor_id, Product.is_active == True) + .count() + ) - featured_products = db.query(Product).filter( - Product.vendor_id == vendor_id, - Product.is_featured == True, - Product.is_active == True - ).count() + featured_products = ( + db.query(Product) + .filter( + Product.vendor_id == vendor_id, + Product.is_featured == True, + Product.is_active == True, + ) + .count() + ) # Staging statistics # TODO: This is fragile - MarketplaceProduct uses vendor_name (string) not vendor_id # Should add vendor_id foreign key to MarketplaceProduct for robust querying # For now, matching by vendor name which could fail if names don't match exactly - staging_products = db.query(MarketplaceProduct).filter( - MarketplaceProduct.vendor_name == vendor.name - ).count() + staging_products = ( + db.query(MarketplaceProduct) + .filter(MarketplaceProduct.vendor_name == vendor.name) + .count() + ) # Inventory statistics - total_inventory = db.query( - func.sum(Inventory.quantity) - ).filter( - Inventory.vendor_id == vendor_id - ).scalar() or 0 + total_inventory = ( + db.query(func.sum(Inventory.quantity)) + .filter(Inventory.vendor_id == vendor_id) + .scalar() + or 0 + ) - reserved_inventory = db.query( - func.sum(Inventory.reserved_quantity) - ).filter( - Inventory.vendor_id == vendor_id - ).scalar() or 0 + reserved_inventory = ( + db.query(func.sum(Inventory.reserved_quantity)) + .filter(Inventory.vendor_id == vendor_id) + .scalar() + or 0 + ) - inventory_locations = db.query( - func.count(func.distinct(Inventory.location)) - ).filter( - Inventory.vendor_id == vendor_id - ).scalar() or 0 + inventory_locations = ( + db.query(func.count(func.distinct(Inventory.location))) + .filter(Inventory.vendor_id == vendor_id) + .scalar() + or 0 + ) # Import statistics - total_imports = db.query(MarketplaceImportJob).filter( - MarketplaceImportJob.vendor_id == vendor_id - ).count() + total_imports = ( + db.query(MarketplaceImportJob) + .filter(MarketplaceImportJob.vendor_id == vendor_id) + .count() + ) - successful_imports = db.query(MarketplaceImportJob).filter( - MarketplaceImportJob.vendor_id == vendor_id, - MarketplaceImportJob.status == "completed" - ).count() + successful_imports = ( + db.query(MarketplaceImportJob) + .filter( + MarketplaceImportJob.vendor_id == vendor_id, + MarketplaceImportJob.status == "completed", + ) + .count() + ) # Orders - total_orders = db.query(Order).filter( - Order.vendor_id == vendor_id - ).count() + total_orders = db.query(Order).filter(Order.vendor_id == vendor_id).count() # Customers - total_customers = db.query(Customer).filter( - Customer.vendor_id == vendor_id - ).count() + total_customers = ( + db.query(Customer).filter(Customer.vendor_id == vendor_id).count() + ) return { "catalog": { @@ -138,7 +148,11 @@ class StatsService: "imports": { "total_imports": total_imports, "successful_imports": successful_imports, - "success_rate": (successful_imports / total_imports * 100) if total_imports > 0 else 0, + "success_rate": ( + (successful_imports / total_imports * 100) + if total_imports > 0 + else 0 + ), }, "orders": { "total_orders": total_orders, @@ -151,16 +165,18 @@ class StatsService: except VendorNotFoundException: raise except Exception as e: - logger.error(f"Failed to retrieve vendor statistics for vendor {vendor_id}: {str(e)}") + logger.error( + f"Failed to retrieve vendor statistics for vendor {vendor_id}: {str(e)}" + ) raise AdminOperationException( operation="get_vendor_stats", reason=f"Database query failed: {str(e)}", target_type="vendor", - target_id=str(vendor_id) + target_id=str(vendor_id), ) def get_vendor_analytics( - self, db: Session, vendor_id: int, period: str = "30d" + self, db: Session, vendor_id: int, period: str = "30d" ) -> Dict[str, Any]: """ Get a specific vendor analytics for a time period. @@ -188,21 +204,28 @@ class StatsService: start_date = datetime.utcnow() - timedelta(days=days) # Import activity - recent_imports = db.query(MarketplaceImportJob).filter( - MarketplaceImportJob.vendor_id == vendor_id, - MarketplaceImportJob.created_at >= start_date - ).count() + recent_imports = ( + db.query(MarketplaceImportJob) + .filter( + MarketplaceImportJob.vendor_id == vendor_id, + MarketplaceImportJob.created_at >= start_date, + ) + .count() + ) # Products added to catalog - products_added = db.query(Product).filter( - Product.vendor_id == vendor_id, - Product.created_at >= start_date - ).count() + products_added = ( + db.query(Product) + .filter( + Product.vendor_id == vendor_id, Product.created_at >= start_date + ) + .count() + ) # Inventory changes - inventory_entries = db.query(Inventory).filter( - Inventory.vendor_id == vendor_id - ).count() + inventory_entries = ( + db.query(Inventory).filter(Inventory.vendor_id == vendor_id).count() + ) return { "period": period, @@ -221,12 +244,14 @@ class StatsService: except VendorNotFoundException: raise except Exception as e: - logger.error(f"Failed to retrieve vendor analytics for vendor {vendor_id}: {str(e)}") + logger.error( + f"Failed to retrieve vendor analytics for vendor {vendor_id}: {str(e)}" + ) raise AdminOperationException( operation="get_vendor_analytics", reason=f"Database query failed: {str(e)}", target_type="vendor", - target_id=str(vendor_id) + target_id=str(vendor_id), ) def get_vendor_statistics(self, db: Session) -> dict: @@ -234,7 +259,9 @@ class StatsService: try: total_vendors = db.query(Vendor).count() active_vendors = db.query(Vendor).filter(Vendor.is_active == True).count() - verified_vendors = db.query(Vendor).filter(Vendor.is_verified == True).count() + verified_vendors = ( + db.query(Vendor).filter(Vendor.is_verified == True).count() + ) inactive_vendors = total_vendors - active_vendors return { @@ -242,13 +269,14 @@ class StatsService: "active_vendors": active_vendors, "inactive_vendors": inactive_vendors, "verified_vendors": verified_vendors, - "verification_rate": (verified_vendors / total_vendors * 100) if total_vendors > 0 else 0 + "verification_rate": ( + (verified_vendors / total_vendors * 100) if total_vendors > 0 else 0 + ), } except Exception as e: logger.error(f"Failed to get vendor statistics: {str(e)}") raise AdminOperationException( - operation="get_vendor_statistics", - reason="Database query failed" + operation="get_vendor_statistics", reason="Database query failed" ) # ======================================================================== @@ -302,7 +330,7 @@ class StatsService: logger.error(f"Failed to retrieve comprehensive statistics: {str(e)}") raise AdminOperationException( operation="get_comprehensive_stats", - reason=f"Database query failed: {str(e)}" + reason=f"Database query failed: {str(e)}", ) def get_marketplace_breakdown_stats(self, db: Session) -> List[Dict[str, Any]]: @@ -323,8 +351,12 @@ class StatsService: db.query( MarketplaceProduct.marketplace, func.count(MarketplaceProduct.id).label("total_products"), - func.count(func.distinct(MarketplaceProduct.vendor_name)).label("unique_vendors"), - func.count(func.distinct(MarketplaceProduct.brand)).label("unique_brands"), + func.count(func.distinct(MarketplaceProduct.vendor_name)).label( + "unique_vendors" + ), + func.count(func.distinct(MarketplaceProduct.brand)).label( + "unique_brands" + ), ) .filter(MarketplaceProduct.marketplace.isnot(None)) .group_by(MarketplaceProduct.marketplace) @@ -342,10 +374,12 @@ class StatsService: ] except Exception as e: - logger.error(f"Failed to retrieve marketplace breakdown statistics: {str(e)}") + logger.error( + f"Failed to retrieve marketplace breakdown statistics: {str(e)}" + ) raise AdminOperationException( operation="get_marketplace_breakdown_stats", - reason=f"Database query failed: {str(e)}" + reason=f"Database query failed: {str(e)}", ) def get_user_statistics(self, db: Session) -> Dict[str, Any]: @@ -372,13 +406,14 @@ class StatsService: "active_users": active_users, "inactive_users": inactive_users, "admin_users": admin_users, - "activation_rate": (active_users / total_users * 100) if total_users > 0 else 0 + "activation_rate": ( + (active_users / total_users * 100) if total_users > 0 else 0 + ), } except Exception as e: logger.error(f"Failed to get user statistics: {str(e)}") raise AdminOperationException( - operation="get_user_statistics", - reason="Database query failed" + operation="get_user_statistics", reason="Database query failed" ) def get_import_statistics(self, db: Session) -> Dict[str, Any]: @@ -396,18 +431,22 @@ class StatsService: """ try: total = db.query(MarketplaceImportJob).count() - completed = db.query(MarketplaceImportJob).filter( - MarketplaceImportJob.status == "completed" - ).count() - failed = db.query(MarketplaceImportJob).filter( - MarketplaceImportJob.status == "failed" - ).count() + completed = ( + db.query(MarketplaceImportJob) + .filter(MarketplaceImportJob.status == "completed") + .count() + ) + failed = ( + db.query(MarketplaceImportJob) + .filter(MarketplaceImportJob.status == "failed") + .count() + ) return { "total_imports": total, "completed_imports": completed, "failed_imports": failed, - "success_rate": (completed / total * 100) if total > 0 else 0 + "success_rate": (completed / total * 100) if total > 0 else 0, } except Exception as e: logger.error(f"Failed to get import statistics: {str(e)}") @@ -415,7 +454,7 @@ class StatsService: "total_imports": 0, "completed_imports": 0, "failed_imports": 0, - "success_rate": 0 + "success_rate": 0, } def get_order_statistics(self, db: Session) -> Dict[str, Any]: @@ -431,11 +470,7 @@ class StatsService: Note: TODO: Implement when Order model is fully available """ - return { - "total_orders": 0, - "pending_orders": 0, - "completed_orders": 0 - } + return {"total_orders": 0, "pending_orders": 0, "completed_orders": 0} def get_product_statistics(self, db: Session) -> Dict[str, Any]: """ @@ -450,11 +485,7 @@ class StatsService: Note: TODO: Implement when Product model is fully available """ - return { - "total_products": 0, - "active_products": 0, - "out_of_stock": 0 - } + return {"total_products": 0, "active_products": 0, "out_of_stock": 0} # ======================================================================== # PRIVATE HELPER METHODS @@ -491,8 +522,7 @@ class StatsService: return ( db.query(MarketplaceProduct.brand) .filter( - MarketplaceProduct.brand.isnot(None), - MarketplaceProduct.brand != "" + MarketplaceProduct.brand.isnot(None), MarketplaceProduct.brand != "" ) .distinct() .count() diff --git a/app/services/team_service.py b/app/services/team_service.py index 543df373..84799d22 100644 --- a/app/services/team_service.py +++ b/app/services/team_service.py @@ -9,17 +9,15 @@ This module provides: """ import logging -from typing import List, Dict, Any from datetime import datetime, timezone +from typing import Any, Dict, List from sqlalchemy.orm import Session -from app.exceptions import ( - ValidationException, - UnauthorizedVendorAccessException, -) -from models.database.vendor import VendorUser, Role +from app.exceptions import (UnauthorizedVendorAccessException, + ValidationException) from models.database.user import User +from models.database.vendor import Role, VendorUser logger = logging.getLogger(__name__) @@ -28,7 +26,7 @@ class TeamService: """Service for team management operations.""" def get_team_members( - self, db: Session, vendor_id: int, current_user: User + self, db: Session, vendor_id: int, current_user: User ) -> List[Dict[str, Any]]: """ Get all team members for vendor. @@ -42,23 +40,26 @@ class TeamService: List of team members """ try: - vendor_users = db.query(VendorUser).filter( - VendorUser.vendor_id == vendor_id, - VendorUser.is_active == True - ).all() + vendor_users = ( + db.query(VendorUser) + .filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True) + .all() + ) members = [] for vu in vendor_users: - members.append({ - "id": vu.user_id, - "email": vu.user.email, - "first_name": vu.user.first_name, - "last_name": vu.user.last_name, - "role": vu.role.name, - "role_id": vu.role_id, - "is_active": vu.is_active, - "joined_at": vu.created_at, - }) + members.append( + { + "id": vu.user_id, + "email": vu.user.email, + "first_name": vu.user.first_name, + "last_name": vu.user.last_name, + "role": vu.role.name, + "role_id": vu.role_id, + "is_active": vu.is_active, + "joined_at": vu.created_at, + } + ) return members @@ -67,7 +68,7 @@ class TeamService: raise ValidationException("Failed to retrieve team members") def invite_team_member( - self, db: Session, vendor_id: int, invitation_data: dict, current_user: User + self, db: Session, vendor_id: int, invitation_data: dict, current_user: User ) -> Dict[str, Any]: """ Invite a new team member. @@ -95,12 +96,12 @@ class TeamService: raise ValidationException("Failed to invite team member") def update_team_member( - self, - db: Session, - vendor_id: int, - user_id: int, - update_data: dict, - current_user: User + self, + db: Session, + vendor_id: int, + user_id: int, + update_data: dict, + current_user: User, ) -> Dict[str, Any]: """ Update team member role or status. @@ -116,10 +117,13 @@ class TeamService: Updated member info """ try: - vendor_user = db.query(VendorUser).filter( - VendorUser.vendor_id == vendor_id, - VendorUser.user_id == user_id - ).first() + vendor_user = ( + db.query(VendorUser) + .filter( + VendorUser.vendor_id == vendor_id, VendorUser.user_id == user_id + ) + .first() + ) if not vendor_user: raise ValidationException("Team member not found") @@ -146,7 +150,7 @@ class TeamService: raise ValidationException("Failed to update team member") def remove_team_member( - self, db: Session, vendor_id: int, user_id: int, current_user: User + self, db: Session, vendor_id: int, user_id: int, current_user: User ) -> bool: """ Remove team member from vendor. @@ -161,10 +165,13 @@ class TeamService: True if removed """ try: - vendor_user = db.query(VendorUser).filter( - VendorUser.vendor_id == vendor_id, - VendorUser.user_id == user_id - ).first() + vendor_user = ( + db.query(VendorUser) + .filter( + VendorUser.vendor_id == vendor_id, VendorUser.user_id == user_id + ) + .first() + ) if not vendor_user: raise ValidationException("Team member not found") diff --git a/app/services/vendor_domain_service.py b/app/services/vendor_domain_service.py index 73fd9da3..1d1776fb 100644 --- a/app/services/vendor_domain_service.py +++ b/app/services/vendor_domain_service.py @@ -12,30 +12,28 @@ This module provides classes and functions for: import logging import secrets -from typing import List, Tuple, Optional from datetime import datetime, timezone +from typing import List, Optional, Tuple -from sqlalchemy.orm import Session from sqlalchemy import and_ +from sqlalchemy.orm import Session -from app.exceptions import ( - VendorNotFoundException, - VendorDomainNotFoundException, - VendorDomainAlreadyExistsException, - InvalidDomainFormatException, - ReservedDomainException, - DomainNotVerifiedException, - DomainVerificationFailedException, - DomainAlreadyVerifiedException, - MultiplePrimaryDomainsException, - DNSVerificationException, - MaxDomainsReachedException, - UnauthorizedDomainAccessException, - ValidationException, -) -from models.schema.vendor_domain import VendorDomainCreate, VendorDomainUpdate +from app.exceptions import (DNSVerificationException, + DomainAlreadyVerifiedException, + DomainNotVerifiedException, + DomainVerificationFailedException, + InvalidDomainFormatException, + MaxDomainsReachedException, + MultiplePrimaryDomainsException, + ReservedDomainException, + UnauthorizedDomainAccessException, + ValidationException, + VendorDomainAlreadyExistsException, + VendorDomainNotFoundException, + VendorNotFoundException) from models.database.vendor import Vendor from models.database.vendor_domain import VendorDomain +from models.schema.vendor_domain import VendorDomainCreate, VendorDomainUpdate logger = logging.getLogger(__name__) @@ -45,13 +43,19 @@ class VendorDomainService: def __init__(self): self.max_domains_per_vendor = 10 # Configure as needed - self.reserved_subdomains = ['www', 'admin', 'api', 'mail', 'smtp', 'ftp', 'cpanel', 'webmail'] + self.reserved_subdomains = [ + "www", + "admin", + "api", + "mail", + "smtp", + "ftp", + "cpanel", + "webmail", + ] def add_domain( - self, - db: Session, - vendor_id: int, - domain_data: VendorDomainCreate + self, db: Session, vendor_id: int, domain_data: VendorDomainCreate ) -> VendorDomain: """ Add a custom domain to vendor. @@ -85,12 +89,14 @@ class VendorDomainService: # Check if domain already exists if self._domain_exists(db, normalized_domain): - existing_domain = db.query(VendorDomain).filter( - VendorDomain.domain == normalized_domain - ).first() + existing_domain = ( + db.query(VendorDomain) + .filter(VendorDomain.domain == normalized_domain) + .first() + ) raise VendorDomainAlreadyExistsException( normalized_domain, - existing_domain.vendor_id if existing_domain else None + existing_domain.vendor_id if existing_domain else None, ) # If setting as primary, unset other primary domains @@ -104,8 +110,8 @@ class VendorDomainService: is_primary=domain_data.is_primary, verification_token=secrets.token_urlsafe(32), is_verified=False, # Requires DNS verification - is_active=False, # Cannot be active until verified - ssl_status="pending" + is_active=False, # Cannot be active until verified + ssl_status="pending", ) db.add(new_domain) @@ -120,7 +126,7 @@ class VendorDomainService: VendorDomainAlreadyExistsException, MaxDomainsReachedException, InvalidDomainFormatException, - ReservedDomainException + ReservedDomainException, ): db.rollback() raise @@ -129,11 +135,7 @@ class VendorDomainService: logger.error(f"Error adding domain: {str(e)}") raise ValidationException("Failed to add domain") - def get_vendor_domains( - self, - db: Session, - vendor_id: int - ) -> List[VendorDomain]: + def get_vendor_domains(self, db: Session, vendor_id: int) -> List[VendorDomain]: """ Get all domains for a vendor. @@ -151,12 +153,14 @@ class VendorDomainService: # Verify vendor exists self._get_vendor_by_id_or_raise(db, vendor_id) - domains = db.query(VendorDomain).filter( - VendorDomain.vendor_id == vendor_id - ).order_by( - VendorDomain.is_primary.desc(), - VendorDomain.created_at.desc() - ).all() + domains = ( + db.query(VendorDomain) + .filter(VendorDomain.vendor_id == vendor_id) + .order_by( + VendorDomain.is_primary.desc(), VendorDomain.created_at.desc() + ) + .all() + ) return domains @@ -166,11 +170,7 @@ class VendorDomainService: logger.error(f"Error getting vendor domains: {str(e)}") raise ValidationException("Failed to retrieve domains") - def get_domain_by_id( - self, - db: Session, - domain_id: int - ) -> VendorDomain: + def get_domain_by_id(self, db: Session, domain_id: int) -> VendorDomain: """ Get domain by ID. @@ -190,10 +190,7 @@ class VendorDomainService: return domain def update_domain( - self, - db: Session, - domain_id: int, - domain_update: VendorDomainUpdate + self, db: Session, domain_id: int, domain_update: VendorDomainUpdate ) -> VendorDomain: """ Update domain settings. @@ -215,7 +212,9 @@ class VendorDomainService: # If setting as primary, unset other primary domains if domain_update.is_primary: - self._unset_primary_domains(db, domain.vendor_id, exclude_domain_id=domain_id) + self._unset_primary_domains( + db, domain.vendor_id, exclude_domain_id=domain_id + ) domain.is_primary = True # If activating, check verification @@ -240,11 +239,7 @@ class VendorDomainService: logger.error(f"Error updating domain: {str(e)}") raise ValidationException("Failed to update domain") - def delete_domain( - self, - db: Session, - domain_id: int - ) -> str: + def delete_domain(self, db: Session, domain_id: int) -> str: """ Delete a custom domain. @@ -277,11 +272,7 @@ class VendorDomainService: logger.error(f"Error deleting domain: {str(e)}") raise ValidationException("Failed to delete domain") - def verify_domain( - self, - db: Session, - domain_id: int - ) -> Tuple[VendorDomain, str]: + def verify_domain(self, db: Session, domain_id: int) -> Tuple[VendorDomain, str]: """ Verify domain ownership via DNS TXT record. @@ -313,8 +304,7 @@ class VendorDomainService: # Query DNS TXT records try: txt_records = dns.resolver.resolve( - f"_wizamart-verify.{domain.domain}", - 'TXT' + f"_wizamart-verify.{domain.domain}", "TXT" ) # Check if verification token is present @@ -332,42 +322,33 @@ class VendorDomainService: # Token not found raise DomainVerificationFailedException( - domain.domain, - "Verification token not found in DNS records" + domain.domain, "Verification token not found in DNS records" ) except dns.resolver.NXDOMAIN: raise DomainVerificationFailedException( domain.domain, - f"DNS record _wizamart-verify.{domain.domain} not found" + f"DNS record _wizamart-verify.{domain.domain} not found", ) except dns.resolver.NoAnswer: raise DomainVerificationFailedException( - domain.domain, - "No TXT records found for verification" + domain.domain, "No TXT records found for verification" ) except Exception as dns_error: - raise DNSVerificationException( - domain.domain, - str(dns_error) - ) + raise DNSVerificationException(domain.domain, str(dns_error)) except ( VendorDomainNotFoundException, DomainAlreadyVerifiedException, DomainVerificationFailedException, - DNSVerificationException + DNSVerificationException, ): raise except Exception as e: logger.error(f"Error verifying domain: {str(e)}") raise ValidationException("Failed to verify domain") - def get_verification_instructions( - self, - db: Session, - domain_id: int - ) -> dict: + def get_verification_instructions(self, db: Session, domain_id: int) -> dict: """ Get DNS verification instructions for domain. @@ -390,20 +371,20 @@ class VendorDomainService: "step1": "Go to your domain's DNS settings (at your domain registrar)", "step2": "Add a new TXT record with the following values:", "step3": "Wait for DNS propagation (5-15 minutes)", - "step4": "Click 'Verify Domain' button in admin panel" + "step4": "Click 'Verify Domain' button in admin panel", }, "txt_record": { "type": "TXT", "name": "_wizamart-verify", "value": domain.verification_token, - "ttl": 3600 + "ttl": 3600, }, "common_registrars": { "Cloudflare": "https://dash.cloudflare.com", "GoDaddy": "https://dcc.godaddy.com/manage/dns", "Namecheap": "https://www.namecheap.com/myaccount/domain-list/", - "Google Domains": "https://domains.google.com" - } + "Google Domains": "https://domains.google.com", + }, } # Private helper methods @@ -416,36 +397,33 @@ class VendorDomainService: def _check_domain_limit(self, db: Session, vendor_id: int) -> None: """Check if vendor has reached maximum domain limit.""" - domain_count = db.query(VendorDomain).filter( - VendorDomain.vendor_id == vendor_id - ).count() + domain_count = ( + db.query(VendorDomain).filter(VendorDomain.vendor_id == vendor_id).count() + ) if domain_count >= self.max_domains_per_vendor: raise MaxDomainsReachedException(vendor_id, self.max_domains_per_vendor) def _domain_exists(self, db: Session, domain: str) -> bool: """Check if domain already exists in system.""" - return db.query(VendorDomain).filter( - VendorDomain.domain == domain - ).first() is not None + return ( + db.query(VendorDomain).filter(VendorDomain.domain == domain).first() + is not None + ) def _validate_domain_format(self, domain: str) -> None: """Validate domain format and check for reserved subdomains.""" # Check for reserved subdomains - first_part = domain.split('.')[0] + first_part = domain.split(".")[0] if first_part in self.reserved_subdomains: raise ReservedDomainException(domain, first_part) def _unset_primary_domains( - self, - db: Session, - vendor_id: int, - exclude_domain_id: Optional[int] = None + self, db: Session, vendor_id: int, exclude_domain_id: Optional[int] = None ) -> None: """Unset all primary domains for vendor.""" query = db.query(VendorDomain).filter( - VendorDomain.vendor_id == vendor_id, - VendorDomain.is_primary == True + VendorDomain.vendor_id == vendor_id, VendorDomain.is_primary == True ) if exclude_domain_id: diff --git a/app/services/vendor_service.py b/app/services/vendor_service.py index 3185e56b..d686c201 100644 --- a/app/services/vendor_service.py +++ b/app/services/vendor_service.py @@ -15,22 +15,19 @@ from typing import List, Optional, Tuple from sqlalchemy import func from sqlalchemy.orm import Session -from app.exceptions import ( - VendorNotFoundException, - VendorAlreadyExistsException, - UnauthorizedVendorAccessException, - InvalidVendorDataException, - MarketplaceProductNotFoundException, - ProductAlreadyExistsException, - MaxVendorsReachedException, - ValidationException, -) -from models.schema.vendor import VendorCreate -from models.schema.product import ProductCreate +from app.exceptions import (InvalidVendorDataException, + MarketplaceProductNotFoundException, + MaxVendorsReachedException, + ProductAlreadyExistsException, + UnauthorizedVendorAccessException, + ValidationException, VendorAlreadyExistsException, + VendorNotFoundException) from models.database.marketplace_product import MarketplaceProduct -from models.database.vendor import Vendor from models.database.product import Product from models.database.user import User +from models.database.vendor import Vendor +from models.schema.product import ProductCreate +from models.schema.vendor import VendorCreate logger = logging.getLogger(__name__) @@ -39,7 +36,7 @@ class VendorService: """Service class for vendor operations following the application's service pattern.""" def create_vendor( - self, db: Session, vendor_data: VendorCreate, current_user: User + self, db: Session, vendor_data: VendorCreate, current_user: User ) -> Vendor: """ Create a new vendor. @@ -47,7 +44,7 @@ class VendorService: Args: db: Database session vendor_data: Vendor creation data - current_user: User creating the vendor + current_user: User creating the vendor Returns: Created vendor object @@ -91,7 +88,11 @@ class VendorService: ) return new_vendor - except (VendorAlreadyExistsException, MaxVendorsReachedException, InvalidVendorDataException): + except ( + VendorAlreadyExistsException, + MaxVendorsReachedException, + InvalidVendorDataException, + ): db.rollback() raise # Re-raise custom exceptions except Exception as e: @@ -100,13 +101,13 @@ class VendorService: raise ValidationException("Failed to create vendor ") def get_vendors( - self, - db: Session, - current_user: User, - skip: int = 0, - limit: int = 100, - active_only: bool = True, - verified_only: bool = False, + self, + db: Session, + current_user: User, + skip: int = 0, + limit: int = 100, + active_only: bool = True, + verified_only: bool = False, ) -> Tuple[List[Vendor], int]: """ Get vendors with filtering. @@ -129,7 +130,10 @@ class VendorService: if current_user.role != "admin": query = query.filter( (Vendor.is_active == True) - & ((Vendor.is_verified == True) | (Vendor.owner_user_id == current_user.id)) + & ( + (Vendor.is_verified == True) + | (Vendor.owner_user_id == current_user.id) + ) ) else: # Admin can apply filters @@ -147,14 +151,16 @@ class VendorService: logger.error(f"Error getting vendors: {str(e)}") raise ValidationException("Failed to retrieve vendors") - def get_vendor_by_code(self, db: Session, vendor_code: str, current_user: User) -> Vendor: + def get_vendor_by_code( + self, db: Session, vendor_code: str, current_user: User + ) -> Vendor: """ Get vendor by vendor code with access control. Args: db: Database session vendor_code: Vendor code to find - current_user: Current user requesting the vendor + current_user: Current user requesting the vendor Returns: Vendor object @@ -170,14 +176,14 @@ class VendorService: .first() ) - if not vendor : + if not vendor: raise VendorNotFoundException(vendor_code) # Check access permissions if not self._can_access_vendor(vendor, current_user): raise UnauthorizedVendorAccessException(vendor_code, current_user.id) - return vendor + return vendor except (VendorNotFoundException, UnauthorizedVendorAccessException): raise # Re-raise custom exceptions @@ -186,7 +192,7 @@ class VendorService: raise ValidationException("Failed to retrieve vendor ") def add_product_to_catalog( - self, db: Session, vendor : Vendor, product: ProductCreate + self, db: Session, vendor: Vendor, product: ProductCreate ) -> Product: """ Add existing product to vendor catalog with vendor -specific settings. @@ -201,15 +207,19 @@ class VendorService: Raises: MarketplaceProductNotFoundException: If product not found - ProductAlreadyExistsException: If product already in vendor + ProductAlreadyExistsException: If product already in vendor """ try: # Check if product exists - marketplace_product = self._get_product_by_id_or_raise(db, product.marketplace_product_id) + marketplace_product = self._get_product_by_id_or_raise( + db, product.marketplace_product_id + ) - # Check if product already in vendor + # Check if product already in vendor if self._product_in_catalog(db, vendor.id, marketplace_product.id): - raise ProductAlreadyExistsException(vendor.vendor_code, product.marketplace_product_id) + raise ProductAlreadyExistsException( + vendor.vendor_code, product.marketplace_product_id + ) # Create vendor -product association new_product = Product( @@ -225,7 +235,9 @@ class VendorService: # Load the product relationship db.refresh(new_product) - logger.info(f"MarketplaceProduct {product.marketplace_product_id} added to vendor {vendor.vendor_code}") + logger.info( + f"MarketplaceProduct {product.marketplace_product_id} added to vendor {vendor.vendor_code}" + ) return new_product except (MarketplaceProductNotFoundException, ProductAlreadyExistsException): @@ -237,14 +249,14 @@ class VendorService: raise ValidationException("Failed to add product to vendor ") def get_products( - self, - db: Session, - vendor : Vendor, - current_user: User, - skip: int = 0, - limit: int = 100, - active_only: bool = True, - featured_only: bool = False, + self, + db: Session, + vendor: Vendor, + current_user: User, + skip: int = 0, + limit: int = 100, + active_only: bool = True, + featured_only: bool = False, ) -> Tuple[List[Product], int]: """ Get products in vendor catalog with filtering. @@ -267,7 +279,9 @@ class VendorService: try: # Check access permissions if not self._can_access_vendor(vendor, current_user): - raise UnauthorizedVendorAccessException(vendor.vendor_code, current_user.id) + raise UnauthorizedVendorAccessException( + vendor.vendor_code, current_user.id + ) # Query vendor products query = db.query(Product).filter(Product.vendor_id == vendor.id) @@ -292,17 +306,20 @@ class VendorService: def _validate_vendor_data(self, vendor_data: VendorCreate) -> None: """Validate vendor creation data.""" if not vendor_data.vendor_code or not vendor_data.vendor_code.strip(): - raise InvalidVendorDataException("Vendor code is required", field="vendor_code") + raise InvalidVendorDataException( + "Vendor code is required", field="vendor_code" + ) if not vendor_data.vendor_name or not vendor_data.vendor_name.strip(): raise InvalidVendorDataException("Vendor name is required", field="name") # Validate vendor code format (alphanumeric, underscores, hyphens) import re - if not re.match(r'^[A-Za-z0-9_-]+$', vendor_data.vendor_code): + + if not re.match(r"^[A-Za-z0-9_-]+$", vendor_data.vendor_code): raise InvalidVendorDataException( "Vendor code can only contain letters, numbers, underscores, and hyphens", - field="vendor_code" + field="vendor_code", ) def _check_vendor_limit(self, db: Session, user: User) -> None: @@ -310,7 +327,9 @@ class VendorService: if user.role == "admin": return # Admins have no limit - user_vendor_count = db.query(Vendor).filter(Vendor.owner_user_id == user.id).count() + user_vendor_count = ( + db.query(Vendor).filter(Vendor.owner_user_id == user.id).count() + ) max_vendors = 5 # Configure this as needed if user_vendor_count >= max_vendors: @@ -319,30 +338,40 @@ class VendorService: def _vendor_code_exists(self, db: Session, vendor_code: str) -> bool: """Check if vendor code already exists (case-insensitive).""" return ( - db.query(Vendor) - .filter(func.upper(Vendor.vendor_code) == vendor_code.upper()) - .first() is not None + db.query(Vendor) + .filter(func.upper(Vendor.vendor_code) == vendor_code.upper()) + .first() + is not None ) - def _get_product_by_id_or_raise(self, db: Session, marketplace_product_id: str) -> MarketplaceProduct: + def _get_product_by_id_or_raise( + self, db: Session, marketplace_product_id: str + ) -> MarketplaceProduct: """Get product by ID or raise exception.""" - product = db.query(MarketplaceProduct).filter(MarketplaceProduct.marketplace_product_id == marketplace_product_id).first() + product = ( + db.query(MarketplaceProduct) + .filter(MarketplaceProduct.marketplace_product_id == marketplace_product_id) + .first() + ) if not product: raise MarketplaceProductNotFoundException(marketplace_product_id) return product - def _product_in_catalog(self, db: Session, vendor_id: int, marketplace_product_id: int) -> bool: + def _product_in_catalog( + self, db: Session, vendor_id: int, marketplace_product_id: int + ) -> bool: """Check if product is already in vendor.""" return ( - db.query(Product) - .filter( - Product.vendor_id == vendor_id, - Product.marketplace_product_id == marketplace_product_id - ) - .first() is not None + db.query(Product) + .filter( + Product.vendor_id == vendor_id, + Product.marketplace_product_id == marketplace_product_id, + ) + .first() + is not None ) - def _can_access_vendor(self, vendor : Vendor, user: User) -> bool: + def _can_access_vendor(self, vendor: Vendor, user: User) -> bool: """Check if user can access vendor.""" # Admins and owners can always access if user.role == "admin" or vendor.owner_user_id == user.id: @@ -351,9 +380,10 @@ class VendorService: # Others can only access active and verified vendors return vendor.is_active and vendor.is_verified - def _is_vendor_owner(self, vendor : Vendor, user: User) -> bool: + def _is_vendor_owner(self, vendor: Vendor, user: User) -> bool: """Check if user is vendor owner.""" return vendor.owner_user_id == user.id + # Create service instance following the same pattern as other services vendor_service = VendorService() diff --git a/app/services/vendor_team_service.py b/app/services/vendor_team_service.py index 763ea8f1..2ed95c29 100644 --- a/app/services/vendor_team_service.py +++ b/app/services/vendor_team_service.py @@ -11,23 +11,20 @@ Handles: import logging import secrets from datetime import datetime, timedelta -from typing import Optional, List, Dict, Any +from typing import Any, Dict, List, Optional from sqlalchemy.orm import Session from app.core.permissions import get_preset_permissions -from app.exceptions import ( - TeamMemberAlreadyExistsException, - InvalidInvitationTokenException, - TeamInvitationAlreadyAcceptedException, - MaxTeamMembersReachedException, - UserNotFoundException, - VendorNotFoundException, - CannotRemoveOwnerException, -) -from models.database.user import User -from models.database.vendor import Vendor, VendorUser, VendorUserType, Role +from app.exceptions import (CannotRemoveOwnerException, + InvalidInvitationTokenException, + MaxTeamMembersReachedException, + TeamInvitationAlreadyAcceptedException, + TeamMemberAlreadyExistsException, + UserNotFoundException, VendorNotFoundException) from middleware.auth import AuthManager +from models.database.user import User +from models.database.vendor import Role, Vendor, VendorUser, VendorUserType logger = logging.getLogger(__name__) @@ -40,13 +37,13 @@ class VendorTeamService: self.max_team_members = 50 # Configure as needed def invite_team_member( - self, - db: Session, - vendor: Vendor, - inviter: User, - email: str, - role_name: str, - custom_permissions: Optional[List[str]] = None, + self, + db: Session, + vendor: Vendor, + inviter: User, + email: str, + role_name: str, + custom_permissions: Optional[List[str]] = None, ) -> Dict[str, Any]: """ Invite a new team member to a vendor. @@ -69,10 +66,14 @@ class VendorTeamService: """ try: # Check team size limit - current_team_size = db.query(VendorUser).filter( - VendorUser.vendor_id == vendor.id, - VendorUser.is_active == True, - ).count() + current_team_size = ( + db.query(VendorUser) + .filter( + VendorUser.vendor_id == vendor.id, + VendorUser.is_active == True, + ) + .count() + ) if current_team_size >= self.max_team_members: raise MaxTeamMembersReachedException( @@ -85,22 +86,34 @@ class VendorTeamService: if user: # Check if already a member - existing_membership = db.query(VendorUser).filter( - VendorUser.vendor_id == vendor.id, - VendorUser.user_id == user.id, - ).first() + existing_membership = ( + db.query(VendorUser) + .filter( + VendorUser.vendor_id == vendor.id, + VendorUser.user_id == user.id, + ) + .first() + ) if existing_membership: if existing_membership.is_active: - raise TeamMemberAlreadyExistsException(email, vendor.vendor_code) + raise TeamMemberAlreadyExistsException( + email, vendor.vendor_code + ) # Reactivate old membership - existing_membership.is_active = False # Will be activated on acceptance - existing_membership.invitation_token = self._generate_invitation_token() + existing_membership.is_active = ( + False # Will be activated on acceptance + ) + existing_membership.invitation_token = ( + self._generate_invitation_token() + ) existing_membership.invitation_sent_at = datetime.utcnow() existing_membership.invitation_accepted_at = None db.commit() - logger.info(f"Re-invited user {email} to vendor {vendor.vendor_code}") + logger.info( + f"Re-invited user {email} to vendor {vendor.vendor_code}" + ) return { "invitation_token": existing_membership.invitation_token, "email": email, @@ -108,7 +121,7 @@ class VendorTeamService: } else: # Create new user account (inactive until invitation accepted) - username = email.split('@')[0] + username = email.split("@")[0] # Ensure unique username base_username = username counter = 1 @@ -179,12 +192,12 @@ class VendorTeamService: raise def accept_invitation( - self, - db: Session, - invitation_token: str, - password: str, - first_name: Optional[str] = None, - last_name: Optional[str] = None, + self, + db: Session, + invitation_token: str, + password: str, + first_name: Optional[str] = None, + last_name: Optional[str] = None, ) -> Dict[str, Any]: """ Accept a team invitation and activate account. @@ -201,9 +214,13 @@ class VendorTeamService: """ try: # Find invitation - vendor_user = db.query(VendorUser).filter( - VendorUser.invitation_token == invitation_token, - ).first() + vendor_user = ( + db.query(VendorUser) + .filter( + VendorUser.invitation_token == invitation_token, + ) + .first() + ) if not vendor_user: raise InvalidInvitationTokenException() @@ -247,7 +264,10 @@ class VendorTeamService: "role": vendor_user.role.name if vendor_user.role else "member", } - except (InvalidInvitationTokenException, TeamInvitationAlreadyAcceptedException): + except ( + InvalidInvitationTokenException, + TeamInvitationAlreadyAcceptedException, + ): raise except Exception as e: db.rollback() @@ -255,10 +275,10 @@ class VendorTeamService: raise def remove_team_member( - self, - db: Session, - vendor: Vendor, - user_id: int, + self, + db: Session, + vendor: Vendor, + user_id: int, ) -> bool: """ Remove a team member from a vendor. @@ -274,10 +294,14 @@ class VendorTeamService: True if removed """ try: - vendor_user = db.query(VendorUser).filter( - VendorUser.vendor_id == vendor.id, - VendorUser.user_id == user_id, - ).first() + vendor_user = ( + db.query(VendorUser) + .filter( + VendorUser.vendor_id == vendor.id, + VendorUser.user_id == user_id, + ) + .first() + ) if not vendor_user: raise UserNotFoundException(str(user_id)) @@ -301,12 +325,12 @@ class VendorTeamService: raise def update_member_role( - self, - db: Session, - vendor: Vendor, - user_id: int, - new_role_name: str, - custom_permissions: Optional[List[str]] = None, + self, + db: Session, + vendor: Vendor, + user_id: int, + new_role_name: str, + custom_permissions: Optional[List[str]] = None, ) -> VendorUser: """ Update a team member's role. @@ -322,10 +346,14 @@ class VendorTeamService: Updated VendorUser """ try: - vendor_user = db.query(VendorUser).filter( - VendorUser.vendor_id == vendor.id, - VendorUser.user_id == user_id, - ).first() + vendor_user = ( + db.query(VendorUser) + .filter( + VendorUser.vendor_id == vendor.id, + VendorUser.user_id == user_id, + ) + .first() + ) if not vendor_user: raise UserNotFoundException(str(user_id)) @@ -360,10 +388,10 @@ class VendorTeamService: raise def get_team_members( - self, - db: Session, - vendor: Vendor, - include_inactive: bool = False, + self, + db: Session, + vendor: Vendor, + include_inactive: bool = False, ) -> List[Dict[str, Any]]: """ Get all team members for a vendor. @@ -387,20 +415,22 @@ class VendorTeamService: members = [] for vu in vendor_users: - members.append({ - "id": vu.user.id, - "email": vu.user.email, - "username": vu.user.username, - "full_name": vu.user.full_name, - "user_type": vu.user_type, - "role": vu.role.name if vu.role else "owner", - "permissions": vu.get_all_permissions(), - "is_active": vu.is_active, - "is_owner": vu.is_owner, - "invitation_pending": vu.is_invitation_pending, - "invited_at": vu.invitation_sent_at, - "accepted_at": vu.invitation_accepted_at, - }) + members.append( + { + "id": vu.user.id, + "email": vu.user.email, + "username": vu.user.username, + "full_name": vu.user.full_name, + "user_type": vu.user_type, + "role": vu.role.name if vu.role else "owner", + "permissions": vu.get_all_permissions(), + "is_active": vu.is_active, + "is_owner": vu.is_owner, + "invitation_pending": vu.is_invitation_pending, + "invited_at": vu.invitation_sent_at, + "accepted_at": vu.invitation_accepted_at, + } + ) return members @@ -411,18 +441,22 @@ class VendorTeamService: return secrets.token_urlsafe(32) def _get_or_create_role( - self, - db: Session, - vendor: Vendor, - role_name: str, - custom_permissions: Optional[List[str]] = None, + self, + db: Session, + vendor: Vendor, + role_name: str, + custom_permissions: Optional[List[str]] = None, ) -> Role: """Get existing role or create new one with preset/custom permissions.""" # Try to find existing role with same name - role = db.query(Role).filter( - Role.vendor_id == vendor.id, - Role.name == role_name, - ).first() + role = ( + db.query(Role) + .filter( + Role.vendor_id == vendor.id, + Role.name == role_name, + ) + .first() + ) if role and custom_permissions is None: # Use existing role diff --git a/app/services/vendor_theme_service.py b/app/services/vendor_theme_service.py index 067f022b..46cd3e3b 100644 --- a/app/services/vendor_theme_service.py +++ b/app/services/vendor_theme_service.py @@ -8,32 +8,24 @@ Handles theme CRUD operations, preset application, and validation. import logging import re -from typing import Optional, Dict, List +from typing import Dict, List, Optional + from sqlalchemy.orm import Session +from app.core.theme_presets import (THEME_PRESETS, apply_preset, + get_available_presets, get_preset_preview) +from app.exceptions.vendor import VendorNotFoundException +from app.exceptions.vendor_theme import (InvalidColorFormatException, + InvalidFontFamilyException, + InvalidThemeDataException, + ThemeOperationException, + ThemePresetAlreadyAppliedException, + ThemePresetNotFoundException, + ThemeValidationException, + VendorThemeNotFoundException) from models.database.vendor import Vendor from models.database.vendor_theme import VendorTheme -from models.schema.vendor_theme import ( - VendorThemeUpdate, - ThemePresetPreview -) -from app.exceptions.vendor import VendorNotFoundException -from app.exceptions.vendor_theme import ( - VendorThemeNotFoundException, - InvalidThemeDataException, - ThemePresetNotFoundException, - ThemeValidationException, - InvalidColorFormatException, - InvalidFontFamilyException, - ThemePresetAlreadyAppliedException, - ThemeOperationException -) -from app.core.theme_presets import ( - apply_preset, - get_available_presets, - get_preset_preview, - THEME_PRESETS -) +from models.schema.vendor_theme import ThemePresetPreview, VendorThemeUpdate logger = logging.getLogger(__name__) @@ -71,9 +63,9 @@ class VendorThemeService: Raises: VendorNotFoundException: If vendor not found """ - vendor = db.query(Vendor).filter( - Vendor.vendor_code == vendor_code.upper() - ).first() + vendor = ( + db.query(Vendor).filter(Vendor.vendor_code == vendor_code.upper()).first() + ) if not vendor: self.logger.warning(f"Vendor not found: {vendor_code}") @@ -105,12 +97,12 @@ class VendorThemeService: vendor = self._get_vendor_by_code(db, vendor_code) # Get theme - theme = db.query(VendorTheme).filter( - VendorTheme.vendor_id == vendor.id - ).first() + theme = db.query(VendorTheme).filter(VendorTheme.vendor_id == vendor.id).first() if not theme: - self.logger.info(f"No custom theme for vendor {vendor_code}, returning default") + self.logger.info( + f"No custom theme for vendor {vendor_code}, returning default" + ) return self._get_default_theme() return theme.to_dict() @@ -130,23 +122,16 @@ class VendorThemeService: "accent": "#ec4899", "background": "#ffffff", "text": "#1f2937", - "border": "#e5e7eb" - }, - "fonts": { - "heading": "Inter, sans-serif", - "body": "Inter, sans-serif" + "border": "#e5e7eb", }, + "fonts": {"heading": "Inter, sans-serif", "body": "Inter, sans-serif"}, "branding": { "logo": None, "logo_dark": None, "favicon": None, - "banner": None - }, - "layout": { - "style": "grid", - "header": "fixed", - "product_card": "modern" + "banner": None, }, + "layout": {"style": "grid", "header": "fixed", "product_card": "modern"}, "social_links": {}, "custom_css": None, "css_variables": { @@ -158,7 +143,7 @@ class VendorThemeService: "--color-border": "#e5e7eb", "--font-heading": "Inter, sans-serif", "--font-body": "Inter, sans-serif", - } + }, } # ============================================================================ @@ -166,10 +151,7 @@ class VendorThemeService: # ============================================================================ def update_theme( - self, - db: Session, - vendor_code: str, - theme_data: VendorThemeUpdate + self, db: Session, vendor_code: str, theme_data: VendorThemeUpdate ) -> VendorTheme: """ Update or create theme for vendor. @@ -194,9 +176,9 @@ class VendorThemeService: vendor = self._get_vendor_by_code(db, vendor_code) # Get or create theme - theme = db.query(VendorTheme).filter( - VendorTheme.vendor_id == vendor.id - ).first() + theme = ( + db.query(VendorTheme).filter(VendorTheme.vendor_id == vendor.id).first() + ) if not theme: self.logger.info(f"Creating new theme for vendor {vendor_code}") @@ -224,15 +206,11 @@ class VendorThemeService: db.rollback() self.logger.error(f"Failed to update theme for vendor {vendor_code}: {e}") raise ThemeOperationException( - operation="update", - vendor_code=vendor_code, - reason=str(e) + operation="update", vendor_code=vendor_code, reason=str(e) ) def _apply_theme_updates( - self, - theme: VendorTheme, - theme_data: VendorThemeUpdate + self, theme: VendorTheme, theme_data: VendorThemeUpdate ) -> None: """ Apply theme updates to theme object. @@ -251,30 +229,30 @@ class VendorThemeService: # Update fonts if theme_data.fonts: - if theme_data.fonts.get('heading'): - theme.font_family_heading = theme_data.fonts['heading'] - if theme_data.fonts.get('body'): - theme.font_family_body = theme_data.fonts['body'] + if theme_data.fonts.get("heading"): + theme.font_family_heading = theme_data.fonts["heading"] + if theme_data.fonts.get("body"): + theme.font_family_body = theme_data.fonts["body"] # Update branding if theme_data.branding: - if theme_data.branding.get('logo') is not None: - theme.logo_url = theme_data.branding['logo'] - if theme_data.branding.get('logo_dark') is not None: - theme.logo_dark_url = theme_data.branding['logo_dark'] - if theme_data.branding.get('favicon') is not None: - theme.favicon_url = theme_data.branding['favicon'] - if theme_data.branding.get('banner') is not None: - theme.banner_url = theme_data.branding['banner'] + if theme_data.branding.get("logo") is not None: + theme.logo_url = theme_data.branding["logo"] + if theme_data.branding.get("logo_dark") is not None: + theme.logo_dark_url = theme_data.branding["logo_dark"] + if theme_data.branding.get("favicon") is not None: + theme.favicon_url = theme_data.branding["favicon"] + if theme_data.branding.get("banner") is not None: + theme.banner_url = theme_data.branding["banner"] # Update layout if theme_data.layout: - if theme_data.layout.get('style'): - theme.layout_style = theme_data.layout['style'] - if theme_data.layout.get('header'): - theme.header_style = theme_data.layout['header'] - if theme_data.layout.get('product_card'): - theme.product_card_style = theme_data.layout['product_card'] + if theme_data.layout.get("style"): + theme.layout_style = theme_data.layout["style"] + if theme_data.layout.get("header"): + theme.header_style = theme_data.layout["header"] + if theme_data.layout.get("product_card"): + theme.product_card_style = theme_data.layout["product_card"] # Update custom CSS if theme_data.custom_css is not None: @@ -289,10 +267,7 @@ class VendorThemeService: # ============================================================================ def apply_theme_preset( - self, - db: Session, - vendor_code: str, - preset_name: str + self, db: Session, vendor_code: str, preset_name: str ) -> VendorTheme: """ Apply a theme preset to vendor. @@ -322,9 +297,9 @@ class VendorThemeService: vendor = self._get_vendor_by_code(db, vendor_code) # Get or create theme - theme = db.query(VendorTheme).filter( - VendorTheme.vendor_id == vendor.id - ).first() + theme = ( + db.query(VendorTheme).filter(VendorTheme.vendor_id == vendor.id).first() + ) if not theme: self.logger.info(f"Creating new theme for vendor {vendor_code}") @@ -338,7 +313,9 @@ class VendorThemeService: db.commit() db.refresh(theme) - self.logger.info(f"Preset '{preset_name}' applied successfully to vendor {vendor_code}") + self.logger.info( + f"Preset '{preset_name}' applied successfully to vendor {vendor_code}" + ) return theme except (VendorNotFoundException, ThemePresetNotFoundException): @@ -349,9 +326,7 @@ class VendorThemeService: db.rollback() self.logger.error(f"Failed to apply preset to vendor {vendor_code}: {e}") raise ThemeOperationException( - operation="apply_preset", - vendor_code=vendor_code, - reason=str(e) + operation="apply_preset", vendor_code=vendor_code, reason=str(e) ) def get_available_presets(self) -> List[ThemePresetPreview]: @@ -399,9 +374,9 @@ class VendorThemeService: vendor = self._get_vendor_by_code(db, vendor_code) # Get theme - theme = db.query(VendorTheme).filter( - VendorTheme.vendor_id == vendor.id - ).first() + theme = ( + db.query(VendorTheme).filter(VendorTheme.vendor_id == vendor.id).first() + ) if not theme: raise VendorThemeNotFoundException(vendor_code) @@ -423,9 +398,7 @@ class VendorThemeService: db.rollback() self.logger.error(f"Failed to delete theme for vendor {vendor_code}: {e}") raise ThemeOperationException( - operation="delete", - vendor_code=vendor_code, - reason=str(e) + operation="delete", vendor_code=vendor_code, reason=str(e) ) # ============================================================================ @@ -459,9 +432,9 @@ class VendorThemeService: # Validate layout values if theme_data.layout: valid_layouts = { - 'style': ['grid', 'list', 'masonry'], - 'header': ['fixed', 'static', 'transparent'], - 'product_card': ['modern', 'classic', 'minimal'] + "style": ["grid", "list", "masonry"], + "header": ["fixed", "static", "transparent"], + "product_card": ["modern", "classic", "minimal"], } for layout_key, layout_value in theme_data.layout.items(): @@ -472,7 +445,7 @@ class VendorThemeService: field=layout_key, validation_errors={ layout_key: f"Must be one of: {', '.join(valid_layouts[layout_key])}" - } + }, ) def _is_valid_color(self, color: str) -> bool: @@ -489,7 +462,7 @@ class VendorThemeService: return False # Check for hex color format (#RGB or #RRGGBB) - hex_pattern = r'^#([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$' + hex_pattern = r"^#([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$" return bool(re.match(hex_pattern, color)) def _is_valid_font(self, font: str) -> bool: diff --git a/app/tasks/background_tasks.py b/app/tasks/background_tasks.py index 7f1ae503..87ca42f4 100644 --- a/app/tasks/background_tasks.py +++ b/app/tasks/background_tasks.py @@ -3,9 +3,9 @@ import logging from datetime import datetime, timezone from app.core.database import SessionLocal +from app.utils.csv_processor import CSVProcessor from models.database.marketplace_import_job import MarketplaceImportJob from models.database.vendor import Vendor -from app.utils.csv_processor import CSVProcessor logger = logging.getLogger(__name__) @@ -15,7 +15,7 @@ async def process_marketplace_import( url: str, marketplace: str, vendor_id: int, # FIXED: Changed from vendor_name to vendor_id - batch_size: int = 1000 + batch_size: int = 1000, ): """Background task to process marketplace CSV import.""" db = SessionLocal() @@ -59,7 +59,7 @@ async def process_marketplace_import( marketplace, vendor_id, # FIXED: Pass vendor_id instead of vendor_name batch_size, - db + db, ) # Update job with results diff --git a/app/utils/csv_processor.py b/app/utils/csv_processor.py index 7bfd7300..6fb9b148 100644 --- a/app/utils/csv_processor.py +++ b/app/utils/csv_processor.py @@ -195,7 +195,7 @@ class CSVProcessor: Args: url: URL to the CSV file marketplace: Name of the marketplace (e.g., 'Letzshop', 'Amazon') - vendor_name: Name of the vendor + vendor_name: Name of the vendor batch_size: Number of rows to process in each batch db: Database session @@ -267,7 +267,9 @@ class CSVProcessor: # Validate required fields if not product_data.get("marketplace_product_id"): - logger.warning(f"Row {index}: Missing marketplace_product_id, skipping") + logger.warning( + f"Row {index}: Missing marketplace_product_id, skipping" + ) errors += 1 continue @@ -279,7 +281,10 @@ class CSVProcessor: # Check if product exists existing_product = ( db.query(MarketplaceProduct) - .filter(MarketplaceProduct.marketplace_product_id == literal(product_data["marketplace_product_id"])) + .filter( + MarketplaceProduct.marketplace_product_id + == literal(product_data["marketplace_product_id"]) + ) .first() ) diff --git a/app/utils/data_processing.py b/app/utils/data_processing.py index fe44665b..85168a39 100644 --- a/app/utils/data_processing.py +++ b/app/utils/data_processing.py @@ -109,7 +109,9 @@ class PriceProcessor: r"([A-Z]{3})\s*([0-9.,]+)": lambda m: (m.group(2), m.group(1)), } - def parse_price_currency(self, price_str: any) -> Tuple[Optional[str], Optional[str]]: + def parse_price_currency( + self, price_str: any + ) -> Tuple[Optional[str], Optional[str]]: """ Parse a price string to extract the numeric value and currency. diff --git a/app/utils/database.py b/app/utils/database.py index 29c75b22..72fa3e58 100644 --- a/app/utils/database.py +++ b/app/utils/database.py @@ -7,6 +7,7 @@ This module provides utility functions and classes to interact with a database u """ import logging + from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import QueuePool diff --git a/main.py b/main.py index 4f306a30..b1af97ab 100644 --- a/main.py +++ b/main.py @@ -9,13 +9,13 @@ Multi-tenant e-commerce marketplace platform with: - Middleware stack for context injection """ -import sys import io +import sys # Fix Windows console encoding issues (must be at the very top) -if sys.platform == 'win32': - sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') - sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8') +if sys.platform == "win32": + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") + sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8") import logging from datetime import datetime, timezone @@ -23,27 +23,25 @@ from pathlib import Path from fastapi import Depends, FastAPI, HTTPException, Request, Response from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import HTMLResponse, RedirectResponse, FileResponse +from fastapi.responses import FileResponse, HTMLResponse, RedirectResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates from sqlalchemy import text from sqlalchemy.orm import Session from app.api.main import api_router - -# Import page routers -from app.routes import admin_pages, vendor_pages, shop_pages from app.core.config import settings from app.core.database import get_db from app.core.lifespan import lifespan -from app.exceptions.handler import setup_exception_handlers from app.exceptions import ServiceUnavailableException - +from app.exceptions.handler import setup_exception_handlers +# Import page routers +from app.routes import admin_pages, shop_pages, vendor_pages +from middleware.context import ContextMiddleware +from middleware.logging import LoggingMiddleware +from middleware.theme_context import ThemeContextMiddleware # Import REFACTORED class-based middleware from middleware.vendor_context import VendorContextMiddleware -from middleware.context import ContextMiddleware -from middleware.theme_context import ThemeContextMiddleware -from middleware.logging import LoggingMiddleware logger = logging.getLogger(__name__) @@ -146,6 +144,7 @@ app.include_router(api_router, prefix="/api") # FAVICON ROUTES (Must be registered BEFORE page routers) # ============================================================================ + def serve_favicon() -> Response: """ Serve favicon with caching headers. @@ -164,7 +163,7 @@ def serve_favicon() -> Response: media_type="image/x-icon", headers={ "Cache-Control": "public, max-age=86400", # Cache for 1 day - } + }, ) return Response(status_code=204) @@ -194,10 +193,7 @@ logger.info("=" * 80) # Admin pages logger.info("Registering admin page routes: /admin/*") app.include_router( - admin_pages.router, - prefix="/admin", - tags=["admin-pages"], - include_in_schema=False + admin_pages.router, prefix="/admin", tags=["admin-pages"], include_in_schema=False ) # Vendor management pages (dashboard, products, orders, etc.) @@ -206,7 +202,7 @@ app.include_router( vendor_pages.router, prefix="/vendor", tags=["vendor-pages"], - include_in_schema=False + include_in_schema=False, ) # Customer shop pages - Register at TWO prefixes: @@ -217,46 +213,42 @@ logger.info(" - /shop/* (subdomain/custom domain mode)") logger.info(" - /vendors/{code}/shop/* (path-based development mode)") app.include_router( - shop_pages.router, - prefix="/shop", - tags=["shop-pages"], - include_in_schema=False + shop_pages.router, prefix="/shop", tags=["shop-pages"], include_in_schema=False ) app.include_router( shop_pages.router, prefix="/vendors/{vendor_code}/shop", tags=["shop-pages"], - include_in_schema=False + include_in_schema=False, ) + # Add handler for /vendors/{vendor_code}/ root path -@app.get("/vendors/{vendor_code}/", response_class=HTMLResponse, include_in_schema=False) -async def vendor_root_path(vendor_code: str, request: Request, db: Session = Depends(get_db)): +@app.get( + "/vendors/{vendor_code}/", response_class=HTMLResponse, include_in_schema=False +) +async def vendor_root_path( + vendor_code: str, request: Request, db: Session = Depends(get_db) +): """Handle vendor root path (e.g., /vendors/wizamart/)""" # Vendor should already be in request.state from middleware - vendor = getattr(request.state, 'vendor', None) + vendor = getattr(request.state, "vendor", None) if not vendor: raise HTTPException(status_code=404, detail=f"Vendor '{vendor_code}' not found") - from app.services.content_page_service import content_page_service from app.routes.shop_pages import get_shop_context + from app.services.content_page_service import content_page_service # Try to find landing page landing_page = content_page_service.get_page_for_vendor( - db, - slug="landing", - vendor_id=vendor.id, - include_unpublished=False + db, slug="landing", vendor_id=vendor.id, include_unpublished=False ) if not landing_page: landing_page = content_page_service.get_page_for_vendor( - db, - slug="home", - vendor_id=vendor.id, - include_unpublished=False + db, slug="home", vendor_id=vendor.id, include_unpublished=False ) if landing_page: @@ -265,8 +257,7 @@ async def vendor_root_path(vendor_code: str, request: Request, db: Session = Dep template_path = f"vendor/landing-{template_name}.html" return templates.TemplateResponse( - template_path, - get_shop_context(request, db=db, page=landing_page) + template_path, get_shop_context(request, db=db, page=landing_page) ) else: # No landing page - redirect to shop @@ -298,22 +289,16 @@ async def platform_homepage(request: Request, db: Session = Depends(get_db)): db, slug="platform_homepage", vendor_id=None, # Platform-level page - include_unpublished=False + include_unpublished=False, ) # Load header and footer navigation header_pages = content_page_service.list_pages_for_vendor( - db, - vendor_id=None, - header_only=True, - include_unpublished=False + db, vendor_id=None, header_only=True, include_unpublished=False ) footer_pages = content_page_service.list_pages_for_vendor( - db, - vendor_id=None, - footer_only=True, - include_unpublished=False + db, vendor_id=None, footer_only=True, include_unpublished=False ) if homepage: @@ -330,7 +315,7 @@ async def platform_homepage(request: Request, db: Session = Depends(get_db)): "page": homepage, "header_pages": header_pages, "footer_pages": footer_pages, - } + }, ) else: # Fallback to default static template @@ -342,15 +327,13 @@ async def platform_homepage(request: Request, db: Session = Depends(get_db)): "request": request, "header_pages": header_pages, "footer_pages": footer_pages, - } + }, ) @app.get("/{slug}", response_class=HTMLResponse, include_in_schema=False) async def platform_content_page( - request: Request, - slug: str, - db: Session = Depends(get_db) + request: Request, slug: str, db: Session = Depends(get_db) ): """ Platform content pages: /about, /faq, /terms, /contact, etc. @@ -366,10 +349,7 @@ async def platform_content_page( # Load page from CMS page = content_page_service.get_page_for_vendor( - db, - slug=slug, - vendor_id=None, # Platform pages only - include_unpublished=False + db, slug=slug, vendor_id=None, include_unpublished=False # Platform pages only ) if not page: @@ -378,17 +358,11 @@ async def platform_content_page( # Load header and footer navigation header_pages = content_page_service.list_pages_for_vendor( - db, - vendor_id=None, - header_only=True, - include_unpublished=False + db, vendor_id=None, header_only=True, include_unpublished=False ) footer_pages = content_page_service.list_pages_for_vendor( - db, - vendor_id=None, - footer_only=True, - include_unpublished=False + db, vendor_id=None, footer_only=True, include_unpublished=False ) logger.info(f"[PLATFORM] Rendering content page: {page.title} (/{slug})") @@ -400,7 +374,7 @@ async def platform_content_page( "page": page, "header_pages": header_pages, "footer_pages": footer_pages, - } + }, ) @@ -411,8 +385,8 @@ logger.info("=" * 80) logger.info("REGISTERED ROUTES SUMMARY") logger.info("=" * 80) for route in app.routes: - if hasattr(route, 'methods') and hasattr(route, 'path'): - methods = ', '.join(route.methods) if route.methods else 'N/A' + if hasattr(route, "methods") and hasattr(route, "path"): + methods = ", ".join(route.methods) if route.methods else "N/A" logger.info(f" {methods:<10} {route.path:<60}") logger.info("=" * 80) @@ -420,6 +394,7 @@ logger.info("=" * 80) # API ROUTES (JSON Responses) # ============================================================================ + # Public Routes (no authentication required) @app.get("/", response_class=HTMLResponse, include_in_schema=False) async def root(request: Request, db: Session = Depends(get_db)): @@ -428,7 +403,7 @@ async def root(request: Request, db: Session = Depends(get_db)): - If vendor detected (domain/subdomain): Show vendor landing page or redirect to shop - If no vendor (platform root): Redirect to documentation """ - vendor = getattr(request.state, 'vendor', None) + vendor = getattr(request.state, "vendor", None) if vendor: # Vendor context detected - serve landing page @@ -436,19 +411,13 @@ async def root(request: Request, db: Session = Depends(get_db)): # Try to find landing page (slug='landing' or 'home') landing_page = content_page_service.get_page_for_vendor( - db, - slug="landing", - vendor_id=vendor.id, - include_unpublished=False + db, slug="landing", vendor_id=vendor.id, include_unpublished=False ) if not landing_page: # Try 'home' slug as fallback landing_page = content_page_service.get_page_for_vendor( - db, - slug="home", - vendor_id=vendor.id, - include_unpublished=False + db, slug="home", vendor_id=vendor.id, include_unpublished=False ) if landing_page: @@ -459,17 +428,26 @@ async def root(request: Request, db: Session = Depends(get_db)): template_path = f"vendor/landing-{template_name}.html" return templates.TemplateResponse( - template_path, - get_shop_context(request, db=db, page=landing_page) + template_path, get_shop_context(request, db=db, page=landing_page) ) else: # No landing page - redirect to shop - vendor_context = getattr(request.state, 'vendor_context', None) - access_method = vendor_context.get('detection_method', 'unknown') if vendor_context else 'unknown' + vendor_context = getattr(request.state, "vendor_context", None) + access_method = ( + vendor_context.get("detection_method", "unknown") + if vendor_context + else "unknown" + ) if access_method == "path": - full_prefix = vendor_context.get('full_prefix', '/vendor/') if vendor_context else '/vendor/' - return RedirectResponse(url=f"{full_prefix}{vendor.subdomain}/shop/", status_code=302) + full_prefix = ( + vendor_context.get("full_prefix", "/vendor/") + if vendor_context + else "/vendor/" + ) + return RedirectResponse( + url=f"{full_prefix}{vendor.subdomain}/shop/", status_code=302 + ) else: # Domain/subdomain return RedirectResponse(url="/shop/", status_code=302) diff --git a/middleware/auth.py b/middleware/auth.py index 81110a33..d37207ce 100644 --- a/middleware/auth.py +++ b/middleware/auth.py @@ -26,14 +26,10 @@ from jose import jwt from passlib.context import CryptContext from sqlalchemy.orm import Session -from app.exceptions import ( - AdminRequiredException, - InvalidTokenException, - TokenExpiredException, - UserNotActiveException, - InvalidCredentialsException, - InsufficientPermissionsException -) +from app.exceptions import (AdminRequiredException, + InsufficientPermissionsException, + InvalidCredentialsException, InvalidTokenException, + TokenExpiredException, UserNotActiveException) from models.database.user import User logger = logging.getLogger(__name__) @@ -99,7 +95,9 @@ class AuthManager: """ return pwd_context.verify(plain_password, hashed_password) - def authenticate_user(self, db: Session, username: str, password: str) -> Optional[User]: + def authenticate_user( + self, db: Session, username: str, password: str + ) -> Optional[User]: """Authenticate user credentials against the database. Supports authentication using either username or email address. @@ -201,7 +199,9 @@ class AuthManager: raise InvalidTokenException("Token missing expiration") # Check if token has expired (additional check beyond jwt.decode) - if datetime.now(timezone.utc) > datetime.fromtimestamp(exp, tz=timezone.utc): + if datetime.now(timezone.utc) > datetime.fromtimestamp( + exp, tz=timezone.utc + ): raise TokenExpiredException() # Validate user identifier claim exists @@ -214,7 +214,9 @@ class AuthManager: "user_id": int(user_id), "username": payload.get("username"), "email": payload.get("email"), - "role": payload.get("role", "user"), # Default to "user" role if not specified + "role": payload.get( + "role", "user" + ), # Default to "user" role if not specified } except jwt.ExpiredSignatureError: @@ -232,7 +234,9 @@ class AuthManager: logger.error(f"Token verification error: {e}") raise InvalidTokenException("Authentication failed") - def get_current_user(self, db: Session, credentials: HTTPAuthorizationCredentials) -> User: + def get_current_user( + self, db: Session, credentials: HTTPAuthorizationCredentials + ) -> User: """Extract and validate the current authenticated user from request credentials. Verifies the JWT token from the Authorization header, looks up the user @@ -286,8 +290,10 @@ class AuthManager: # This will only execute if user has "admin" role pass """ + def decorator(func): """Decorator that wraps the function with role checking.""" + def wrapper(current_user: User, *args, **kwargs): # Check if current user has the required role if current_user.role != required_role: @@ -339,8 +345,7 @@ class AuthManager: # Check if user has vendor or admin role (admins have full access) if current_user.role not in ["vendor", "admin"]: raise InsufficientPermissionsException( - message="Vendor access required", - required_permission="vendor" + message="Vendor access required", required_permission="vendor" ) return current_user @@ -363,7 +368,7 @@ class AuthManager: if current_user.role not in ["customer", "admin"]: raise InsufficientPermissionsException( message="Customer account access required", - required_permission="customer" + required_permission="customer", ) return current_user diff --git a/middleware/context.py b/middleware/context.py index 53b6e075..00b88fbd 100644 --- a/middleware/context.py +++ b/middleware/context.py @@ -17,14 +17,16 @@ Class-based middleware provides: import logging from enum import Enum -from starlette.middleware.base import BaseHTTPMiddleware + from fastapi import Request +from starlette.middleware.base import BaseHTTPMiddleware logger = logging.getLogger(__name__) class RequestContext(str, Enum): """Request context types for the application.""" + API = "api" ADMIN = "admin" VENDOR_DASHBOARD = "vendor" @@ -59,7 +61,7 @@ class ContextManager: # Use clean_path if available (extracted by vendor_context_middleware) # Falls back to original path if clean_path not set # This is critical for correct context detection with path-based routing - path = getattr(request.state, 'clean_path', request.url.path) + path = getattr(request.state, "clean_path", request.url.path) host = request.headers.get("host", "") @@ -71,10 +73,10 @@ class ContextManager: f"[CONTEXT] Detecting context", extra={ "original_path": request.url.path, - "clean_path": getattr(request.state, 'clean_path', 'NOT SET'), + "clean_path": getattr(request.state, "clean_path", "NOT SET"), "path_to_check": path, "host": host, - } + }, ) # 1. API context (highest priority) @@ -84,24 +86,30 @@ class ContextManager: # 2. Admin context if ContextManager._is_admin_context(request, host, path): - logger.debug("[CONTEXT] Detected as ADMIN", extra={"path": path, "host": host}) + logger.debug( + "[CONTEXT] Detected as ADMIN", extra={"path": path, "host": host} + ) return RequestContext.ADMIN # 3. Vendor Dashboard context (vendor management area) # Check both clean_path and original path for vendor dashboard original_path = request.url.path - if ContextManager._is_vendor_dashboard_context(path) or \ - ContextManager._is_vendor_dashboard_context(original_path): - logger.debug("[CONTEXT] Detected as VENDOR_DASHBOARD", extra={"path": path, "original_path": original_path}) + if ContextManager._is_vendor_dashboard_context( + path + ) or ContextManager._is_vendor_dashboard_context(original_path): + logger.debug( + "[CONTEXT] Detected as VENDOR_DASHBOARD", + extra={"path": path, "original_path": original_path}, + ) return RequestContext.VENDOR_DASHBOARD # 4. Shop context (vendor storefront) # Check if vendor context exists (set by vendor_context_middleware) - if hasattr(request.state, 'vendor') and request.state.vendor: + if hasattr(request.state, "vendor") and request.state.vendor: # If we have a vendor and it's not admin or vendor dashboard, it's shop logger.debug( "[CONTEXT] Detected as SHOP (has vendor context)", - extra={"vendor": request.state.vendor.name} + extra={"vendor": request.state.vendor.name}, ) return RequestContext.SHOP @@ -173,11 +181,12 @@ class ContextMiddleware(BaseHTTPMiddleware): f"[CONTEXT_MIDDLEWARE] Context detected: {context_type.value}", extra={ "path": request.url.path, - "clean_path": getattr(request.state, 'clean_path', 'NOT SET'), + "clean_path": getattr(request.state, "clean_path", "NOT SET"), "host": request.headers.get("host", ""), "context": context_type.value, - "has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None, - } + "has_vendor": hasattr(request.state, "vendor") + and request.state.vendor is not None, + }, ) # Continue processing diff --git a/middleware/decorators.py b/middleware/decorators.py index f068f6fe..2766301d 100644 --- a/middleware/decorators.py +++ b/middleware/decorators.py @@ -9,12 +9,14 @@ This module provides classes and functions for: """ from functools import wraps + from app.exceptions.base import RateLimitException # Add this import from middleware.rate_limiter import RateLimiter # Initialize rate limiter instance rate_limiter = RateLimiter() + def rate_limit(max_requests: int = 100, window_seconds: int = 3600): """Rate limiting decorator for FastAPI endpoints.""" @@ -26,10 +28,11 @@ def rate_limit(max_requests: int = 100, window_seconds: int = 3600): if not rate_limiter.allow_request(client_id, max_requests, window_seconds): # Use custom exception instead of HTTPException raise RateLimitException( - message="Rate limit exceeded", - retry_after=window_seconds + message="Rate limit exceeded", retry_after=window_seconds ) return await func(*args, **kwargs) + return wrapper + return decorator diff --git a/middleware/theme_context.py b/middleware/theme_context.py index a7531eae..e89e5117 100644 --- a/middleware/theme_context.py +++ b/middleware/theme_context.py @@ -11,9 +11,10 @@ Class-based middleware provides: """ import logging -from starlette.middleware.base import BaseHTTPMiddleware + from fastapi import Request from sqlalchemy.orm import Session +from starlette.middleware.base import BaseHTTPMiddleware from app.core.database import get_db from models.database.vendor_theme import VendorTheme @@ -30,10 +31,11 @@ class ThemeContextManager: Get theme configuration for vendor. Returns default theme if no custom theme is configured. """ - theme = db.query(VendorTheme).filter( - VendorTheme.vendor_id == vendor_id, - VendorTheme.is_active == True - ).first() + theme = ( + db.query(VendorTheme) + .filter(VendorTheme.vendor_id == vendor_id, VendorTheme.is_active == True) + .first() + ) if theme: return theme.to_dict() @@ -52,23 +54,16 @@ class ThemeContextManager: "accent": "#ec4899", "background": "#ffffff", "text": "#1f2937", - "border": "#e5e7eb" - }, - "fonts": { - "heading": "Inter, sans-serif", - "body": "Inter, sans-serif" + "border": "#e5e7eb", }, + "fonts": {"heading": "Inter, sans-serif", "body": "Inter, sans-serif"}, "branding": { "logo": None, "logo_dark": None, "favicon": None, - "banner": None - }, - "layout": { - "style": "grid", - "header": "fixed", - "product_card": "modern" + "banner": None, }, + "layout": {"style": "grid", "header": "fixed", "product_card": "modern"}, "social_links": {}, "custom_css": None, "css_variables": { @@ -80,7 +75,7 @@ class ThemeContextManager: "--color-border": "#e5e7eb", "--font-heading": "Inter, sans-serif", "--font-body": "Inter, sans-serif", - } + }, } @@ -106,7 +101,7 @@ class ThemeContextMiddleware(BaseHTTPMiddleware): Load and inject theme context. """ # Only inject theme for shop pages (not admin or API) - if hasattr(request.state, 'vendor') and request.state.vendor: + if hasattr(request.state, "vendor") and request.state.vendor: vendor = request.state.vendor # Get database session @@ -123,13 +118,13 @@ class ThemeContextMiddleware(BaseHTTPMiddleware): extra={ "vendor_id": vendor.id, "vendor_name": vendor.name, - "theme_name": theme.get('theme_name', 'default'), - } + "theme_name": theme.get("theme_name", "default"), + }, ) except Exception as e: logger.error( f"[THEME] Failed to load theme for vendor {vendor.id}: {e}", - exc_info=True + exc_info=True, ) # Fallback to default theme request.state.theme = ThemeContextManager.get_default_theme() @@ -140,7 +135,7 @@ class ThemeContextMiddleware(BaseHTTPMiddleware): request.state.theme = ThemeContextManager.get_default_theme() logger.debug( "[THEME] No vendor context, using default theme", - extra={"has_vendor": False} + extra={"has_vendor": False}, ) # Continue processing diff --git a/middleware/vendor_context.py b/middleware/vendor_context.py index 33e3aa70..982afee6 100644 --- a/middleware/vendor_context.py +++ b/middleware/vendor_context.py @@ -13,10 +13,11 @@ Also extracts clean_path for nested routing patterns. import logging from typing import Optional -from sqlalchemy.orm import Session -from sqlalchemy import func -from starlette.middleware.base import BaseHTTPMiddleware + from fastapi import Request +from sqlalchemy import func +from sqlalchemy.orm import Session +from starlette.middleware.base import BaseHTTPMiddleware from app.core.config import settings from app.core.database import get_db @@ -50,14 +51,15 @@ class VendorContextManager: # Method 1: Custom domain detection (HIGHEST PRIORITY) # Check if this is a custom domain (not platform.com and not localhost) - platform_domain = getattr(settings, 'platform_domain', 'platform.com') + platform_domain = getattr(settings, "platform_domain", "platform.com") is_custom_domain = ( - host and - not host.endswith(f".{platform_domain}") and - host != platform_domain and - host not in ["localhost", "127.0.0.1", "admin.localhost", "admin.127.0.0.1"] and - not host.startswith("admin.") + host + and not host.endswith(f".{platform_domain}") + and host != platform_domain + and host + not in ["localhost", "127.0.0.1", "admin.localhost", "admin.127.0.0.1"] + and not host.startswith("admin.") ) if is_custom_domain: @@ -66,7 +68,7 @@ class VendorContextManager: "domain": normalized_domain, "detection_method": "custom_domain", "host": host, - "original_host": request.headers.get("host", "") + "original_host": request.headers.get("host", ""), } # Method 2: Subdomain detection (vendor1.platform.com) @@ -78,7 +80,7 @@ class VendorContextManager: return { "subdomain": subdomain, "detection_method": "subdomain", - "host": host + "host": host, } # Method 3: Path-based detection (/vendor/vendorname/ or /vendors/vendorname/) @@ -96,9 +98,9 @@ class VendorContextManager: return { "subdomain": vendor_code, "detection_method": "path", - "path_prefix": path[:prefix_len + len(vendor_code)], + "path_prefix": path[: prefix_len + len(vendor_code)], "full_prefix": path[:prefix_len], # /vendor/ or /vendors/ - "host": host + "host": host, } return None @@ -136,10 +138,14 @@ class VendorContextManager: logger.warning(f"Vendor for domain {domain} is not active") return None - logger.info(f"[OK] Vendor found via custom domain: {domain} → {vendor.name}") + logger.info( + f"[OK] Vendor found via custom domain: {domain} → {vendor.name}" + ) return vendor else: - logger.warning(f"No active vendor found for custom domain: {domain}") + logger.warning( + f"No active vendor found for custom domain: {domain}" + ) return None # Method 2 & 3: Subdomain or path-based lookup @@ -154,7 +160,9 @@ class VendorContextManager: if vendor: method = context.get("detection_method", "unknown") - logger.info(f"[OK] Vendor found via {method}: {subdomain} → {vendor.name}") + logger.info( + f"[OK] Vendor found via {method}: {subdomain} → {vendor.name}" + ) else: logger.warning(f"No active vendor found for subdomain: {subdomain}") @@ -176,7 +184,7 @@ class VendorContextManager: path_prefix = vendor_context.get("path_prefix", "") if path.startswith(path_prefix): - clean_path = path[len(path_prefix):] + clean_path = path[len(path_prefix) :] return clean_path if clean_path else "/" return request.url.path @@ -232,6 +240,7 @@ class VendorContextManager: try: from urllib.parse import urlparse + parsed = urlparse(referer) referer_host = parsed.hostname or "" referer_path = parsed.path or "" @@ -246,27 +255,33 @@ class VendorContextManager: "referer": referer, "referer_host": referer_host, "referer_path": referer_path, - } + }, ) # Method 1: Path-based detection from referer path # /vendors/wizamart/shop/products → wizamart - if referer_path.startswith("/vendors/") or referer_path.startswith("/vendor/"): - prefix = "/vendors/" if referer_path.startswith("/vendors/") else "/vendor/" - path_parts = referer_path[len(prefix):].split("/") + if referer_path.startswith("/vendors/") or referer_path.startswith( + "/vendor/" + ): + prefix = ( + "/vendors/" if referer_path.startswith("/vendors/") else "/vendor/" + ) + path_parts = referer_path[len(prefix) :].split("/") if len(path_parts) >= 1 and path_parts[0]: vendor_code = path_parts[0] prefix_len = len(prefix) logger.debug( f"[VENDOR] Extracted vendor from Referer path: {vendor_code}", - extra={"vendor_code": vendor_code, "method": "referer_path"} + extra={"vendor_code": vendor_code, "method": "referer_path"}, ) # Use "path" as detection_method to be consistent with direct path detection # This allows cookie path logic to work the same way return { "subdomain": vendor_code, "detection_method": "path", # Consistent with direct path detection - "path_prefix": referer_path[:prefix_len + len(vendor_code)], # /vendor/vendor1 + "path_prefix": referer_path[ + : prefix_len + len(vendor_code) + ], # /vendor/vendor1 "full_prefix": prefix, # /vendor/ or /vendors/ "host": referer_host, "referer": referer, @@ -274,7 +289,7 @@ class VendorContextManager: # Method 2: Subdomain detection from referer host # wizamart.platform.com → wizamart - platform_domain = getattr(settings, 'platform_domain', 'platform.com') + platform_domain = getattr(settings, "platform_domain", "platform.com") if "." in referer_host: parts = referer_host.split(".") if len(parts) >= 2 and parts[0] not in ["www", "admin", "api"]: @@ -283,7 +298,10 @@ class VendorContextManager: subdomain = parts[0] logger.debug( f"[VENDOR] Extracted vendor from Referer subdomain: {subdomain}", - extra={"subdomain": subdomain, "method": "referer_subdomain"} + extra={ + "subdomain": subdomain, + "method": "referer_subdomain", + }, ) return { "subdomain": subdomain, @@ -295,19 +313,23 @@ class VendorContextManager: # Method 3: Custom domain detection from referer host # custom-shop.com → custom-shop.com is_custom_domain = ( - referer_host and - not referer_host.endswith(f".{platform_domain}") and - referer_host != platform_domain and - referer_host not in ["localhost", "127.0.0.1"] and - not referer_host.startswith("admin.") + referer_host + and not referer_host.endswith(f".{platform_domain}") + and referer_host != platform_domain + and referer_host not in ["localhost", "127.0.0.1"] + and not referer_host.startswith("admin.") ) if is_custom_domain: from models.database.vendor_domain import VendorDomain + normalized_domain = VendorDomain.normalize_domain(referer_host) logger.debug( f"[VENDOR] Extracted vendor from Referer custom domain: {normalized_domain}", - extra={"domain": normalized_domain, "method": "referer_custom_domain"} + extra={ + "domain": normalized_domain, + "method": "referer_custom_domain", + }, ) return { "domain": normalized_domain, @@ -319,7 +341,7 @@ class VendorContextManager: except Exception as e: logger.warning( f"[VENDOR] Failed to extract vendor from Referer: {e}", - extra={"referer": referer, "error": str(e)} + extra={"referer": referer, "error": str(e)}, ) return None @@ -330,12 +352,28 @@ class VendorContextManager: path = request.url.path.lower() static_extensions = ( - '.ico', '.css', '.js', '.png', '.jpg', '.jpeg', '.gif', '.svg', - '.woff', '.woff2', '.ttf', '.eot', '.webp', '.map', '.json', - '.xml', '.txt', '.pdf', '.webmanifest' + ".ico", + ".css", + ".js", + ".png", + ".jpg", + ".jpeg", + ".gif", + ".svg", + ".woff", + ".woff2", + ".ttf", + ".eot", + ".webp", + ".map", + ".json", + ".xml", + ".txt", + ".pdf", + ".webmanifest", ) - static_paths = ('/static/', '/media/', '/assets/', '/.well-known/') + static_paths = ("/static/", "/media/", "/assets/", "/.well-known/") if path.endswith(static_extensions): return True @@ -343,7 +381,7 @@ class VendorContextManager: if any(path.startswith(static_path) for static_path in static_paths): return True - if 'favicon.ico' in path: + if "favicon.ico" in path: return True return False @@ -372,13 +410,13 @@ class VendorContextMiddleware(BaseHTTPMiddleware): """ # Skip vendor detection for admin, static files, and system requests if ( - VendorContextManager.is_admin_request(request) or - VendorContextManager.is_static_file_request(request) or - request.url.path in ["/", "/health", "/docs", "/redoc", "/openapi.json"] + VendorContextManager.is_admin_request(request) + or VendorContextManager.is_static_file_request(request) + or request.url.path in ["/", "/health", "/docs", "/redoc", "/openapi.json"] ): logger.debug( f"[VENDOR] Skipping vendor detection: {request.url.path}", - extra={"path": request.url.path, "reason": "admin/static/system"} + extra={"path": request.url.path, "reason": "admin/static/system"}, ) request.state.vendor = None request.state.vendor_context = None @@ -389,7 +427,10 @@ class VendorContextMiddleware(BaseHTTPMiddleware): if VendorContextManager.is_shop_api_request(request): logger.debug( f"[VENDOR] Shop API request detected: {request.url.path}", - extra={"path": request.url.path, "referer": request.headers.get("referer", "")} + extra={ + "path": request.url.path, + "referer": request.headers.get("referer", ""), + }, ) vendor_context = VendorContextManager.extract_vendor_from_referer(request) @@ -398,7 +439,9 @@ class VendorContextMiddleware(BaseHTTPMiddleware): db_gen = get_db() db = next(db_gen) try: - vendor = VendorContextManager.get_vendor_from_context(db, vendor_context) + vendor = VendorContextManager.get_vendor_from_context( + db, vendor_context + ) if vendor: request.state.vendor = vendor @@ -411,19 +454,23 @@ class VendorContextMiddleware(BaseHTTPMiddleware): "vendor_id": vendor.id, "vendor_name": vendor.name, "vendor_subdomain": vendor.subdomain, - "detection_method": vendor_context.get("detection_method"), + "detection_method": vendor_context.get( + "detection_method" + ), "api_path": request.url.path, "referer": vendor_context.get("referer", ""), - } + }, ) else: logger.warning( f"[WARNING] Vendor context from Referer but vendor not found", extra={ "context": vendor_context, - "detection_method": vendor_context.get("detection_method"), + "detection_method": vendor_context.get( + "detection_method" + ), "api_path": request.url.path, - } + }, ) request.state.vendor = None request.state.vendor_context = vendor_context @@ -433,7 +480,7 @@ class VendorContextMiddleware(BaseHTTPMiddleware): else: logger.warning( f"[VENDOR] Shop API request without Referer header", - extra={"path": request.url.path} + extra={"path": request.url.path}, ) request.state.vendor = None request.state.vendor_context = None @@ -445,7 +492,7 @@ class VendorContextMiddleware(BaseHTTPMiddleware): if VendorContextManager.is_api_request(request): logger.debug( f"[VENDOR] Skipping vendor detection for non-shop API: {request.url.path}", - extra={"path": request.url.path, "reason": "api"} + extra={"path": request.url.path, "reason": "api"}, ) request.state.vendor = None request.state.vendor_context = None @@ -459,7 +506,9 @@ class VendorContextMiddleware(BaseHTTPMiddleware): db_gen = get_db() db = next(db_gen) try: - vendor = VendorContextManager.get_vendor_from_context(db, vendor_context) + vendor = VendorContextManager.get_vendor_from_context( + db, vendor_context + ) if vendor: request.state.vendor = vendor @@ -477,7 +526,7 @@ class VendorContextMiddleware(BaseHTTPMiddleware): "detection_method": vendor_context.get("detection_method"), "original_path": request.url.path, "clean_path": request.state.clean_path, - } + }, ) else: logger.warning( @@ -485,7 +534,7 @@ class VendorContextMiddleware(BaseHTTPMiddleware): extra={ "context": vendor_context, "detection_method": vendor_context.get("detection_method"), - } + }, ) request.state.vendor = None request.state.vendor_context = vendor_context @@ -498,7 +547,7 @@ class VendorContextMiddleware(BaseHTTPMiddleware): extra={ "path": request.url.path, "host": request.headers.get("host", ""), - } + }, ) request.state.vendor = None request.state.vendor_context = None @@ -520,9 +569,9 @@ def require_vendor_context(): vendor = get_current_vendor(request) if not vendor: from fastapi import HTTPException + raise HTTPException( - status_code=404, - detail="Vendor not found or not active" + status_code=404, detail="Vendor not found or not active" ) return vendor diff --git a/models/__init__.py b/models/__init__.py index 8d148728..ed80aa1c 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,17 +1,16 @@ # models/__init__.py """Models package - Database and API models.""" -# Database models (SQLAlchemy) -from .database.base import Base -from .database.user import User -from .database.marketplace_product import MarketplaceProduct -from .database.inventory import Inventory -from .database.vendor import Vendor -from .database.product import Product -from .database.marketplace_import_job import MarketplaceImportJob - # API models (Pydantic) - import the modules, not all classes from . import schema +# Database models (SQLAlchemy) +from .database.base import Base +from .database.inventory import Inventory +from .database.marketplace_import_job import MarketplaceImportJob +from .database.marketplace_product import MarketplaceProduct +from .database.product import Product +from .database.user import User +from .database.vendor import Vendor # Export database models for Alembic __all__ = [ diff --git a/models/database/__init__.py b/models/database/__init__.py index 24ed1b4e..374c8e70 100644 --- a/models/database/__init__.py +++ b/models/database/__init__.py @@ -1,17 +1,18 @@ # models/database/__init__.py """Database models package.""" -from .admin import AdminAuditLog, AdminNotification, AdminSetting, PlatformAlert, AdminSession +from .admin import (AdminAuditLog, AdminNotification, AdminSession, + AdminSetting, PlatformAlert) from .base import Base from .customer import Customer, CustomerAddress -from .order import Order, OrderItem -from .user import User -from .marketplace_product import MarketplaceProduct from .inventory import Inventory -from .vendor import Vendor, Role, VendorUser +from .marketplace_import_job import MarketplaceImportJob +from .marketplace_product import MarketplaceProduct +from .order import Order, OrderItem +from .product import Product +from .user import User +from .vendor import Role, Vendor, VendorUser from .vendor_domain import VendorDomain from .vendor_theme import VendorTheme -from .product import Product -from .marketplace_import_job import MarketplaceImportJob __all__ = [ # Admin-specific models @@ -34,5 +35,5 @@ __all__ = [ "MarketplaceImportJob", "MarketplaceProduct", "VendorDomain", - "VendorTheme" + "VendorTheme", ] diff --git a/models/database/admin.py b/models/database/admin.py index 32b10ae8..45c04cee 100644 --- a/models/database/admin.py +++ b/models/database/admin.py @@ -1,4 +1,4 @@ -# Admin-specific models +# Admin-specific models # models/database/admin.py """ Admin-specific database models. @@ -10,9 +10,12 @@ This module provides models for: - Platform alerts (system-wide issues) """ -from sqlalchemy import Column, Integer, String, DateTime, Boolean, Text, JSON, ForeignKey +from sqlalchemy import (JSON, Boolean, Column, DateTime, ForeignKey, Integer, + String, Text) from sqlalchemy.orm import relationship + from app.core.database import Base + from .base import TimestampMixin @@ -23,12 +26,17 @@ class AdminAuditLog(Base, TimestampMixin): Separate from regular audit logs - focuses on admin-specific operations like vendor creation, user management, and system configuration changes. """ + __tablename__ = "admin_audit_logs" id = Column(Integer, primary_key=True, index=True) admin_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True) - action = Column(String(100), nullable=False, index=True) # create_vendor, delete_vendor, etc. - target_type = Column(String(50), nullable=False, index=True) # vendor, user, import_job, setting + action = Column( + String(100), nullable=False, index=True + ) # create_vendor, delete_vendor, etc. + target_type = Column( + String(50), nullable=False, index=True + ) # vendor, user, import_job, setting target_id = Column(String(100), nullable=False, index=True) details = Column(JSON) # Additional context about the action ip_address = Column(String(45)) # IPv4 or IPv6 @@ -49,11 +57,16 @@ class AdminNotification(Base, TimestampMixin): Different from vendor/customer notifications - these are for platform administrators to track system health and issues requiring attention. """ + __tablename__ = "admin_notifications" id = Column(Integer, primary_key=True, index=True) - type = Column(String(50), nullable=False, index=True) # system_alert, vendor_issue, import_failure - priority = Column(String(20), default="normal", index=True) # low, normal, high, critical + type = Column( + String(50), nullable=False, index=True + ) # system_alert, vendor_issue, import_failure + priority = Column( + String(20), default="normal", index=True + ) # low, normal, high, critical title = Column(String(200), nullable=False) message = Column(Text, nullable=False) is_read = Column(Boolean, default=False, index=True) @@ -84,13 +97,16 @@ class AdminSetting(Base, TimestampMixin): - smtp_settings - stripe_api_keys (encrypted) """ + __tablename__ = "admin_settings" id = Column(Integer, primary_key=True, index=True) key = Column(String(100), unique=True, nullable=False, index=True) value = Column(Text, nullable=False) value_type = Column(String(20), default="string") # string, integer, boolean, json - category = Column(String(50), index=True) # system, security, marketplace, notifications + category = Column( + String(50), index=True + ) # system, security, marketplace, notifications description = Column(Text) is_encrypted = Column(Boolean, default=False) is_public = Column(Boolean, default=False) # Can be exposed to frontend? @@ -110,11 +126,16 @@ class PlatformAlert(Base, TimestampMixin): Tracks platform issues, performance problems, security incidents, and other system-level concerns that require admin attention. """ + __tablename__ = "platform_alerts" id = Column(Integer, primary_key=True, index=True) - alert_type = Column(String(50), nullable=False, index=True) # security, performance, capacity, integration - severity = Column(String(20), nullable=False, index=True) # info, warning, error, critical + alert_type = Column( + String(50), nullable=False, index=True + ) # security, performance, capacity, integration + severity = Column( + String(20), nullable=False, index=True + ) # info, warning, error, critical title = Column(String(200), nullable=False) description = Column(Text) affected_vendors = Column(JSON) # List of affected vendor IDs @@ -142,6 +163,7 @@ class AdminSession(Base, TimestampMixin): Helps identify suspicious login patterns, track concurrent sessions, and enforce session policies for admin users. """ + __tablename__ = "admin_sessions" id = Column(Integer, primary_key=True, index=True) diff --git a/models/database/audit.py b/models/database/audit.py index a4093794..05461a88 100644 --- a/models/database/audit.py +++ b/models/database/audit.py @@ -1 +1 @@ -# AuditLog, DataExportLog models +# AuditLog, DataExportLog models diff --git a/models/database/backup.py b/models/database/backup.py index 7efe94f4..a9ac9d64 100644 --- a/models/database/backup.py +++ b/models/database/backup.py @@ -1 +1 @@ -# BackupLog, RestoreLog models +# BackupLog, RestoreLog models diff --git a/models/database/base.py b/models/database/base.py index 5c675ba2..052c3c31 100644 --- a/models/database/base.py +++ b/models/database/base.py @@ -10,5 +10,8 @@ class TimestampMixin: created_at = Column(DateTime, default=datetime.now(timezone.utc), nullable=False) updated_at = Column( - DateTime, default=datetime.now(timezone.utc), onupdate=datetime.now(timezone.utc), nullable=False + DateTime, + default=datetime.now(timezone.utc), + onupdate=datetime.now(timezone.utc), + nullable=False, ) diff --git a/models/database/cart.py b/models/database/cart.py index 46198d44..c1763eb5 100644 --- a/models/database/cart.py +++ b/models/database/cart.py @@ -1,7 +1,9 @@ # models/database/cart.py """Cart item database model.""" from datetime import datetime -from sqlalchemy import Column, Float, ForeignKey, Index, Integer, String, UniqueConstraint + +from sqlalchemy import (Column, Float, ForeignKey, Index, Integer, String, + UniqueConstraint) from sqlalchemy.orm import relationship from app.core.database import Base @@ -15,6 +17,7 @@ class CartItem(Base, TimestampMixin): Stores cart items per session, vendor, and product. Sessions are identified by a session_id string (from browser cookies). """ + __tablename__ = "cart_items" id = Column(Integer, primary_key=True, index=True) diff --git a/models/database/configuration.py b/models/database/configuration.py index a6d610ee..ae38c260 100644 --- a/models/database/configuration.py +++ b/models/database/configuration.py @@ -1 +1 @@ -# PlatformConfig, VendorConfig, FeatureFlag models +# PlatformConfig, VendorConfig, FeatureFlag models diff --git a/models/database/content_page.py b/models/database/content_page.py index 687dfc9f..32223bf0 100644 --- a/models/database/content_page.py +++ b/models/database/content_page.py @@ -15,7 +15,9 @@ Features: """ from datetime import datetime, timezone -from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Text, UniqueConstraint, Index + +from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Index, Integer, + String, Text, UniqueConstraint) from sqlalchemy.orm import relationship from app.core.database import Base @@ -34,15 +36,20 @@ class ContentPage(Base): 2. If not found, use platform default (slug only) 3. If neither exists, show 404 or default template """ + __tablename__ = "content_pages" id = Column(Integer, primary_key=True, index=True) # Vendor association (NULL = platform default) - vendor_id = Column(Integer, ForeignKey("vendors.id", ondelete="CASCADE"), nullable=True, index=True) + vendor_id = Column( + Integer, ForeignKey("vendors.id", ondelete="CASCADE"), nullable=True, index=True + ) # Page identification - slug = Column(String(100), nullable=False, index=True) # about, faq, contact, shipping, returns, etc. + slug = Column( + String(100), nullable=False, index=True + ) # about, faq, contact, shipping, returns, etc. title = Column(String(200), nullable=False) # Content @@ -68,12 +75,25 @@ class ContentPage(Base): show_in_header = Column(Boolean, default=False) # Timestamps - created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False) - updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc), nullable=False) + created_at = Column( + DateTime(timezone=True), + default=lambda: datetime.now(timezone.utc), + nullable=False, + ) + updated_at = Column( + DateTime(timezone=True), + default=lambda: datetime.now(timezone.utc), + onupdate=lambda: datetime.now(timezone.utc), + nullable=False, + ) # Author tracking (admin or vendor user who created/updated) - created_by = Column(Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True) - updated_by = Column(Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True) + created_by = Column( + Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True + ) + updated_by = Column( + Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True + ) # Relationships vendor = relationship("Vendor", back_populates="content_pages") @@ -84,11 +104,10 @@ class ContentPage(Base): __table_args__ = ( # Unique combination: vendor can only have one page per slug # Platform defaults (vendor_id=NULL) can only have one page per slug - UniqueConstraint('vendor_id', 'slug', name='uq_vendor_slug'), - + UniqueConstraint("vendor_id", "slug", name="uq_vendor_slug"), # Indexes for performance - Index('idx_vendor_published', 'vendor_id', 'is_published'), - Index('idx_slug_published', 'slug', 'is_published'), + Index("idx_vendor_published", "vendor_id", "is_published"), + Index("idx_slug_published", "slug", "is_published"), ) def __repr__(self): @@ -119,7 +138,9 @@ class ContentPage(Base): "meta_description": self.meta_description, "meta_keywords": self.meta_keywords, "is_published": self.is_published, - "published_at": self.published_at.isoformat() if self.published_at else None, + "published_at": ( + self.published_at.isoformat() if self.published_at else None + ), "display_order": self.display_order, "show_in_footer": self.show_in_footer, "show_in_header": self.show_in_header, diff --git a/models/database/customer.py b/models/database/customer.py index ec4b67e0..8972f806 100644 --- a/models/database/customer.py +++ b/models/database/customer.py @@ -1,8 +1,12 @@ from datetime import datetime from decimal import Decimal -from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey, Text, JSON, Numeric + +from sqlalchemy import (JSON, Boolean, Column, DateTime, ForeignKey, Integer, + Numeric, String, Text) from sqlalchemy.orm import relationship + from app.core.database import Base + from .base import TimestampMixin @@ -11,12 +15,16 @@ class Customer(Base, TimestampMixin): id = Column(Integer, primary_key=True, index=True) vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False) - email = Column(String(255), nullable=False, index=True) # Unique within vendor scope + email = Column( + String(255), nullable=False, index=True + ) # Unique within vendor scope hashed_password = Column(String(255), nullable=False) first_name = Column(String(100)) last_name = Column(String(100)) phone = Column(String(50)) - customer_number = Column(String(100), nullable=False, index=True) # Vendor-specific ID + customer_number = Column( + String(100), nullable=False, index=True + ) # Vendor-specific ID preferences = Column(JSON, default=dict) marketing_consent = Column(Boolean, default=False) last_order_date = Column(DateTime) diff --git a/models/database/inventory.py b/models/database/inventory.py index fa22744f..9318c0d8 100644 --- a/models/database/inventory.py +++ b/models/database/inventory.py @@ -1,6 +1,8 @@ # models/database/inventory.py from datetime import datetime -from sqlalchemy import Column, ForeignKey, Index, Integer, String, UniqueConstraint + +from sqlalchemy import (Column, ForeignKey, Index, Integer, String, + UniqueConstraint) from sqlalchemy.orm import relationship from app.core.database import Base @@ -27,7 +29,9 @@ class Inventory(Base, TimestampMixin): # Constraints __table_args__ = ( - UniqueConstraint("product_id", "location", name="uq_inventory_product_location"), + UniqueConstraint( + "product_id", "location", name="uq_inventory_product_location" + ), Index("idx_inventory_vendor_product", "vendor_id", "product_id"), Index("idx_inventory_product_location", "product_id", "location"), ) diff --git a/models/database/marketplace.py b/models/database/marketplace.py index 9e5c025c..5cae6517 100644 --- a/models/database/marketplace.py +++ b/models/database/marketplace.py @@ -1 +1 @@ -# MarketplaceImportJob model +# MarketplaceImportJob model diff --git a/models/database/marketplace_import_job.py b/models/database/marketplace_import_job.py index 27412cb4..c6587d21 100644 --- a/models/database/marketplace_import_job.py +++ b/models/database/marketplace_import_job.py @@ -1,6 +1,7 @@ from datetime import datetime, timezone -from sqlalchemy import Column, DateTime, ForeignKey, Index, Integer, String, Text +from sqlalchemy import (Column, DateTime, ForeignKey, Index, Integer, String, + Text) from sqlalchemy.orm import relationship from app.core.database import Base diff --git a/models/database/marketplace_product.py b/models/database/marketplace_product.py index 389c0112..e248c81d 100644 --- a/models/database/marketplace_product.py +++ b/models/database/marketplace_product.py @@ -55,7 +55,9 @@ class MarketplaceProduct(Base, TimestampMixin): marketplace = Column( String, index=True, nullable=True, default="Letzshop" ) # Index for marketplace filtering - vendor_name = Column(String, index=True, nullable=True) # Index for vendor filtering + vendor_name = Column( + String, index=True, nullable=True + ) # Index for vendor filtering product = relationship("Product", back_populates="marketplace_product") diff --git a/models/database/media.py b/models/database/media.py index 9e51d54d..9246810c 100644 --- a/models/database/media.py +++ b/models/database/media.py @@ -1 +1 @@ -# MediaFile, ProductMedia models +# MediaFile, ProductMedia models diff --git a/models/database/monitoring.py b/models/database/monitoring.py index ed2edb20..22938930 100644 --- a/models/database/monitoring.py +++ b/models/database/monitoring.py @@ -1 +1 @@ -# PerformanceMetric, ErrorLog, SystemAlert models +# PerformanceMetric, ErrorLog, SystemAlert models diff --git a/models/database/notification.py b/models/database/notification.py index 5224c7a9..3eb5495b 100644 --- a/models/database/notification.py +++ b/models/database/notification.py @@ -1 +1 @@ -# NotificationTemplate, NotificationQueue, NotificationLog models +# NotificationTemplate, NotificationQueue, NotificationLog models diff --git a/models/database/order.py b/models/database/order.py index 562450a6..a62febdb 100644 --- a/models/database/order.py +++ b/models/database/order.py @@ -1,6 +1,8 @@ # models/database/order.py from datetime import datetime -from sqlalchemy import Column, DateTime, Float, ForeignKey, Integer, String, Text, Boolean + +from sqlalchemy import (Boolean, Column, DateTime, Float, ForeignKey, Integer, + String, Text) from sqlalchemy.orm import relationship from app.core.database import Base @@ -9,11 +11,14 @@ from models.database.base import TimestampMixin class Order(Base, TimestampMixin): """Customer orders.""" + __tablename__ = "orders" id = Column(Integer, primary_key=True, index=True) vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False, index=True) - customer_id = Column(Integer, ForeignKey("customers.id"), nullable=False, index=True) + customer_id = Column( + Integer, ForeignKey("customers.id"), nullable=False, index=True + ) order_number = Column(String, nullable=False, unique=True, index=True) @@ -30,8 +35,12 @@ class Order(Base, TimestampMixin): currency = Column(String, default="EUR") # Addresses (stored as IDs) - shipping_address_id = Column(Integer, ForeignKey("customer_addresses.id"), nullable=False) - billing_address_id = Column(Integer, ForeignKey("customer_addresses.id"), nullable=False) + shipping_address_id = Column( + Integer, ForeignKey("customer_addresses.id"), nullable=False + ) + billing_address_id = Column( + Integer, ForeignKey("customer_addresses.id"), nullable=False + ) # Shipping shipping_method = Column(String, nullable=True) @@ -50,8 +59,12 @@ class Order(Base, TimestampMixin): # Relationships vendor = relationship("Vendor") customer = relationship("Customer", back_populates="orders") - items = relationship("OrderItem", back_populates="order", cascade="all, delete-orphan") - shipping_address = relationship("CustomerAddress", foreign_keys=[shipping_address_id]) + items = relationship( + "OrderItem", back_populates="order", cascade="all, delete-orphan" + ) + shipping_address = relationship( + "CustomerAddress", foreign_keys=[shipping_address_id] + ) billing_address = relationship("CustomerAddress", foreign_keys=[billing_address_id]) def __repr__(self): @@ -60,6 +73,7 @@ class Order(Base, TimestampMixin): class OrderItem(Base, TimestampMixin): """Individual items in an order.""" + __tablename__ = "order_items" id = Column(Integer, primary_key=True, index=True) diff --git a/models/database/payment.py b/models/database/payment.py index 4ebd4ebd..cf4bb549 100644 --- a/models/database/payment.py +++ b/models/database/payment.py @@ -1 +1 @@ -# Payment, PaymentMethod, VendorPaymentConfig models +# Payment, PaymentMethod, VendorPaymentConfig models diff --git a/models/database/product.py b/models/database/product.py index 8e8262d9..269a6885 100644 --- a/models/database/product.py +++ b/models/database/product.py @@ -1,6 +1,8 @@ # models/database/product.py from datetime import datetime -from sqlalchemy import Boolean, Column, Float, ForeignKey, Index, Integer, String, UniqueConstraint + +from sqlalchemy import (Boolean, Column, Float, ForeignKey, Index, Integer, + String, UniqueConstraint) from sqlalchemy.orm import relationship from app.core.database import Base @@ -12,7 +14,9 @@ class Product(Base, TimestampMixin): id = Column(Integer, primary_key=True, index=True) vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False) - marketplace_product_id = Column(Integer, ForeignKey("marketplace_products.id"), nullable=False) + marketplace_product_id = Column( + Integer, ForeignKey("marketplace_products.id"), nullable=False + ) # Vendor-specific overrides product_id = Column(String) # Vendor's internal SKU @@ -34,7 +38,9 @@ class Product(Base, TimestampMixin): # Relationships vendor = relationship("Vendor", back_populates="products") marketplace_product = relationship("MarketplaceProduct", back_populates="product") - inventory_entries = relationship("Inventory", back_populates="product", cascade="all, delete-orphan") + inventory_entries = relationship( + "Inventory", back_populates="product", cascade="all, delete-orphan" + ) # Constraints __table_args__ = ( diff --git a/models/database/search.py b/models/database/search.py index e13e7439..1d8f6e66 100644 --- a/models/database/search.py +++ b/models/database/search.py @@ -1 +1 @@ -# SearchIndex, SearchQuery models +# SearchIndex, SearchQuery models diff --git a/models/database/task.py b/models/database/task.py index 34945e62..ef22426a 100644 --- a/models/database/task.py +++ b/models/database/task.py @@ -1 +1 @@ -# TaskLog model +# TaskLog model diff --git a/models/database/user.py b/models/database/user.py index 48650405..dcce9329 100644 --- a/models/database/user.py +++ b/models/database/user.py @@ -10,16 +10,18 @@ ROLE CLARIFICATION: - Vendor-specific roles (manager, staff, etc.) are stored in VendorUser.role - Customers are NOT in the User table - they use the Customer model """ -from sqlalchemy import Boolean, Column, DateTime, Integer, String, Enum -from sqlalchemy.orm import relationship import enum +from sqlalchemy import Boolean, Column, DateTime, Enum, Integer, String +from sqlalchemy.orm import relationship + from app.core.database import Base from models.database.base import TimestampMixin class UserRole(str, enum.Enum): """Platform-level user roles.""" + ADMIN = "admin" # Platform administrator VENDOR = "vendor" # Vendor owner or team member @@ -44,12 +46,12 @@ class User(Base, TimestampMixin): last_login = Column(DateTime, nullable=True) # Relationships - marketplace_import_jobs = relationship("MarketplaceImportJob", back_populates="user") + marketplace_import_jobs = relationship( + "MarketplaceImportJob", back_populates="user" + ) owned_vendors = relationship("Vendor", back_populates="owner") vendor_memberships = relationship( - "VendorUser", - foreign_keys="[VendorUser.user_id]", - back_populates="user" + "VendorUser", foreign_keys="[VendorUser.user_id]", back_populates="user" ) def __repr__(self): @@ -84,8 +86,7 @@ class User(Base, TimestampMixin): return True # Check if team member return any( - vm.vendor_id == vendor_id and vm.is_active - for vm in self.vendor_memberships + vm.vendor_id == vendor_id and vm.is_active for vm in self.vendor_memberships ) def get_vendor_role(self, vendor_id: int) -> str: diff --git a/models/database/vendor.py b/models/database/vendor.py index 813db3ce..c962e852 100644 --- a/models/database/vendor.py +++ b/models/database/vendor.py @@ -5,29 +5,37 @@ Vendor model representing entities that sell products or services. This module defines the Vendor model along with its relationships to other models such as User (owner), Product, Customer, Order, and MarketplaceImportJob. """ -from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Text, JSON, DateTime +import enum + +from sqlalchemy import (JSON, Boolean, Column, DateTime, ForeignKey, Integer, + String, Text) from sqlalchemy.orm import relationship +from app.core.config import settings # Import Base from the central database module instead of creating a new one from app.core.database import Base from models.database.base import TimestampMixin -from app.core.config import settings -import enum + class Vendor(Base, TimestampMixin): """Represents a vendor in the system.""" __tablename__ = "vendors" # Name of the table in the database - id = Column(Integer, primary_key=True, index=True) # Primary key and indexed column for vendor ID - vendor_code = Column(String, unique=True, index=True, - nullable=False) # Unique, indexed, non-nullable vendor code column - subdomain = Column(String(100), unique=True, nullable=False, - index=True) # Unique, non-nullable subdomain column with indexing + id = Column( + Integer, primary_key=True, index=True + ) # Primary key and indexed column for vendor ID + vendor_code = Column( + String, unique=True, index=True, nullable=False + ) # Unique, indexed, non-nullable vendor code column + subdomain = Column( + String(100), unique=True, nullable=False, index=True + ) # Unique, non-nullable subdomain column with indexing name = Column(String, nullable=False) # Non-nullable name column for the vendor description = Column(Text) # Optional text description column for the vendor - owner_user_id = Column(Integer, ForeignKey("users.id"), - nullable=False) # Foreign key to user ID of the vendor's owner + owner_user_id = Column( + Integer, ForeignKey("users.id"), nullable=False + ) # Foreign key to user ID of the vendor's owner # Contact information contact_email = Column(String) # Optional email column for contact information @@ -40,34 +48,46 @@ class Vendor(Base, TimestampMixin): letzshop_csv_url_de = Column(String) # URL for German CSV in Letzshop # Business information - business_address = Column(Text) # Optional text address column for business information + business_address = Column( + Text + ) # Optional text address column for business information tax_number = Column(String) # Optional tax number column for business information # Status - is_active = Column(Boolean, default=True) # Boolean to indicate if the vendor is active - is_verified = Column(Boolean, default=False) # Boolean to indicate if the vendor is verified - + is_active = Column( + Boolean, default=True + ) # Boolean to indicate if the vendor is active + is_verified = Column( + Boolean, default=False + ) # Boolean to indicate if the vendor is verified # ======================================================================== # Relationships # ======================================================================== - owner = relationship("User", back_populates="owned_vendors") # Relationship with User model for the vendor's owner - vendor_users = relationship("VendorUser", - back_populates="vendor") # Relationship with VendorUser model for users in this vendor - products = relationship("Product", - back_populates="vendor") # Relationship with Product model for products of this vendor - customers = relationship("Customer", - back_populates="vendor") # Relationship with Customer model for customers of this vendor - orders = relationship("Order", - back_populates="vendor") # Relationship with Order model for orders placed by this vendor - marketplace_import_jobs = relationship("MarketplaceImportJob", - back_populates="vendor") # Relationship with MarketplaceImportJob model for import jobs related to this vendor + owner = relationship( + "User", back_populates="owned_vendors" + ) # Relationship with User model for the vendor's owner + vendor_users = relationship( + "VendorUser", back_populates="vendor" + ) # Relationship with VendorUser model for users in this vendor + products = relationship( + "Product", back_populates="vendor" + ) # Relationship with Product model for products of this vendor + customers = relationship( + "Customer", back_populates="vendor" + ) # Relationship with Customer model for customers of this vendor + orders = relationship( + "Order", back_populates="vendor" + ) # Relationship with Order model for orders placed by this vendor + marketplace_import_jobs = relationship( + "MarketplaceImportJob", back_populates="vendor" + ) # Relationship with MarketplaceImportJob model for import jobs related to this vendor domains = relationship( "VendorDomain", back_populates="vendor", cascade="all, delete-orphan", - order_by="VendorDomain.is_primary.desc()" + order_by="VendorDomain.is_primary.desc()", ) # Relationship with VendorDomain model for custom domains of the vendor # Single theme relationship (ONE vendor = ONE theme) @@ -77,14 +97,12 @@ class Vendor(Base, TimestampMixin): "VendorTheme", back_populates="vendor", uselist=False, - cascade="all, delete-orphan" + cascade="all, delete-orphan", ) # Relationship with VendorTheme model for the active theme of the vendor # Content pages relationship (vendor can override platform default pages) content_pages = relationship( - "ContentPage", - back_populates="vendor", - cascade="all, delete-orphan" + "ContentPage", back_populates="vendor", cascade="all, delete-orphan" ) # Relationship with ContentPage model for vendor-specific content pages def __repr__(self): @@ -121,23 +139,16 @@ class Vendor(Base, TimestampMixin): "accent": "#ec4899", "background": "#ffffff", "text": "#1f2937", - "border": "#e5e7eb" - }, - "fonts": { - "heading": "Inter, sans-serif", - "body": "Inter, sans-serif" + "border": "#e5e7eb", }, + "fonts": {"heading": "Inter, sans-serif", "body": "Inter, sans-serif"}, "branding": { "logo": None, "logo_dark": None, "favicon": None, - "banner": None - }, - "layout": { - "style": "grid", - "header": "fixed", - "product_card": "modern" + "banner": None, }, + "layout": {"style": "grid", "header": "fixed", "product_card": "modern"}, "social_links": {}, "custom_css": None, "css_variables": { @@ -149,18 +160,22 @@ class Vendor(Base, TimestampMixin): "--color-border": "#e5e7eb", "--font-heading": "Inter, sans-serif", "--font-body": "Inter, sans-serif", - } + }, } def get_primary_color(self) -> str: """Get primary color from active theme.""" theme = self.get_effective_theme() - return theme.get("colors", {}).get("primary", "#6366f1") # Default to default theme if not found + return theme.get("colors", {}).get( + "primary", "#6366f1" + ) # Default to default theme if not found def get_logo_url(self) -> str: """Get logo URL from active theme.""" theme = self.get_effective_theme() - return theme.get("branding", {}).get("logo") # Return None or the logo URL if found + return theme.get("branding", {}).get( + "logo" + ) # Return None or the logo URL if found # ======================================================================== # Domain Helper Methods @@ -177,7 +192,9 @@ class Vendor(Base, TimestampMixin): @property def all_domains(self): """Get all active domains (subdomain + custom domains).""" - domains = [f"{self.subdomain}.{settings.platform_domain}"] # Start with the main subdomain + domains = [ + f"{self.subdomain}.{settings.platform_domain}" + ] # Start with the main subdomain for domain in self.domains: if domain.is_active: domains.append(domain.domain) # Add other active custom domains @@ -186,6 +203,7 @@ class Vendor(Base, TimestampMixin): class VendorUserType(str, enum.Enum): """Types of vendor users.""" + OWNER = "owner" # Vendor owner (full access to vendor area) TEAM_MEMBER = "member" # Team member (role-based access to vendor area) @@ -222,14 +240,18 @@ class VendorUser(Base, TimestampMixin): invitation_sent_at = Column(DateTime, nullable=True) invitation_accepted_at = Column(DateTime, nullable=True) - is_active = Column(Boolean, default=False, nullable=False) # False until invitation accepted + is_active = Column( + Boolean, default=False, nullable=False + ) # False until invitation accepted """Indicates whether the VendorUser role is active.""" # Relationships vendor = relationship("Vendor", back_populates="vendor_users") """Relationship to the Vendor model, representing the associated vendor.""" - user = relationship("User", foreign_keys=[user_id], back_populates="vendor_memberships") + user = relationship( + "User", foreign_keys=[user_id], back_populates="vendor_memberships" + ) """Relationship to the User model, representing the user who holds this role within the vendor.""" inviter = relationship("User", foreign_keys=[invited_by]) @@ -287,6 +309,7 @@ class VendorUser(Base, TimestampMixin): if self.is_owner: # Return all possible permissions from app.core.permissions import VendorPermissions + return list(VendorPermissions.__members__.values()) if self.role and self.role.permissions: @@ -294,6 +317,7 @@ class VendorUser(Base, TimestampMixin): return [] + class Role(Base, TimestampMixin): """Represents a role within a vendor's system.""" diff --git a/models/database/vendor_domain.py b/models/database/vendor_domain.py index 5bafea68..e22512e1 100644 --- a/models/database/vendor_domain.py +++ b/models/database/vendor_domain.py @@ -3,11 +3,11 @@ Vendor Domain Model - Maps custom domains to vendors """ from datetime import datetime, timezone -from sqlalchemy import ( - Column, Integer, String, Boolean, DateTime, - ForeignKey, UniqueConstraint, Index -) + +from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Index, Integer, + String, UniqueConstraint) from sqlalchemy.orm import relationship + from app.core.database import Base from models.database.base import TimestampMixin @@ -21,10 +21,13 @@ class VendorDomain(Base, TimestampMixin): - shop.mybusiness.com → Vendor 2 - www.customdomain1.com → Vendor 1 (www is stripped) """ + __tablename__ = "vendor_domains" id = Column(Integer, primary_key=True, index=True) - vendor_id = Column(Integer, ForeignKey("vendors.id", ondelete="CASCADE"), nullable=False) + vendor_id = Column( + Integer, ForeignKey("vendors.id", ondelete="CASCADE"), nullable=False + ) # Domain configuration domain = Column(String(255), nullable=False, unique=True, index=True) @@ -32,7 +35,9 @@ class VendorDomain(Base, TimestampMixin): is_active = Column(Boolean, default=True, nullable=False) # SSL/TLS status (for monitoring) - ssl_status = Column(String(50), default="pending") # pending, active, expired, error + ssl_status = Column( + String(50), default="pending" + ) # pending, active, expired, error ssl_verified_at = Column(DateTime(timezone=True), nullable=True) # DNS verification (to confirm domain ownership) @@ -45,9 +50,9 @@ class VendorDomain(Base, TimestampMixin): # Constraints __table_args__ = ( - UniqueConstraint('vendor_id', 'domain', name='uq_vendor_domain'), - Index('idx_domain_active', 'domain', 'is_active'), - Index('idx_vendor_primary', 'vendor_id', 'is_primary'), + UniqueConstraint("vendor_id", "domain", name="uq_vendor_domain"), + Index("idx_domain_active", "domain", "is_active"), + Index("idx_vendor_primary", "vendor_id", "is_primary"), ) def __repr__(self): diff --git a/models/database/vendor_theme.py b/models/database/vendor_theme.py index 60db45e5..fcb83cb4 100644 --- a/models/database/vendor_theme.py +++ b/models/database/vendor_theme.py @@ -3,8 +3,9 @@ Vendor Theme Configuration Model Allows each vendor to customize their shop's appearance """ -from sqlalchemy import Column, Integer, String, Boolean, Text, JSON, ForeignKey +from sqlalchemy import JSON, Boolean, Column, ForeignKey, Integer, String, Text from sqlalchemy.orm import relationship + from app.core.database import Base from models.database.base import TimestampMixin @@ -22,6 +23,7 @@ class VendorTheme(Base, TimestampMixin): Theme presets available: default, modern, classic, minimal, vibrant """ + __tablename__ = "vendor_themes" id = Column(Integer, primary_key=True, index=True) @@ -29,22 +31,27 @@ class VendorTheme(Base, TimestampMixin): Integer, ForeignKey("vendors.id", ondelete="CASCADE"), nullable=False, - unique=True # ONE vendor = ONE theme + unique=True, # ONE vendor = ONE theme ) # Basic Theme Settings - theme_name = Column(String(100), default="default") # default, modern, classic, minimal, vibrant + theme_name = Column( + String(100), default="default" + ) # default, modern, classic, minimal, vibrant is_active = Column(Boolean, default=True) # Color Scheme (JSON for flexibility) - colors = Column(JSON, default={ - "primary": "#6366f1", # Indigo - "secondary": "#8b5cf6", # Purple - "accent": "#ec4899", # Pink - "background": "#ffffff", # White - "text": "#1f2937", # Gray-800 - "border": "#e5e7eb" # Gray-200 - }) + colors = Column( + JSON, + default={ + "primary": "#6366f1", # Indigo + "secondary": "#8b5cf6", # Purple + "accent": "#ec4899", # Pink + "background": "#ffffff", # White + "text": "#1f2937", # Gray-800 + "border": "#e5e7eb", # Gray-200 + }, + ) # Typography font_family_heading = Column(String(100), default="Inter, sans-serif") @@ -59,7 +66,9 @@ class VendorTheme(Base, TimestampMixin): # Layout Preferences layout_style = Column(String(50), default="grid") # grid, list, masonry header_style = Column(String(50), default="fixed") # fixed, static, transparent - product_card_style = Column(String(50), default="modern") # modern, classic, minimal + product_card_style = Column( + String(50), default="modern" + ) # modern, classic, minimal # Custom CSS (for advanced customization) custom_css = Column(Text, nullable=True) @@ -68,14 +77,18 @@ class VendorTheme(Base, TimestampMixin): social_links = Column(JSON, default={}) # {facebook: "url", instagram: "url", etc.} # SEO & Meta - meta_title_template = Column(String(200), nullable=True) # e.g., "{product_name} - {shop_name}" + meta_title_template = Column( + String(200), nullable=True + ) # e.g., "{product_name} - {shop_name}" meta_description = Column(Text, nullable=True) # Relationships - FIXED: back_populates must match the relationship name in Vendor model vendor = relationship("Vendor", back_populates="vendor_theme") def __repr__(self): - return f"" + return ( + f"" + ) @property def primary_color(self): diff --git a/models/schema/__init__.py b/models/schema/__init__.py index aff6c19d..9358039d 100644 --- a/models/schema/__init__.py +++ b/models/schema/__init__.py @@ -1,14 +1,9 @@ # models/schema/__init__.py """API models package - Pydantic models for request/response validation.""" -from . import auth # Import API model modules -from . import base -from . import marketplace_import_job -from . import marketplace_product -from . import stats -from . import inventory -from . import vendor +from . import (auth, base, inventory, marketplace_import_job, + marketplace_product, stats, vendor) # Common imports for convenience from .base import * # Base Pydantic models diff --git a/models/schema/admin.py b/models/schema/admin.py index 22318417..03a533b2 100644 --- a/models/schema/admin.py +++ b/models/schema/admin.py @@ -12,16 +12,18 @@ This module provides schemas for: """ from datetime import datetime -from typing import Optional, List, Dict, Any -from pydantic import BaseModel, Field, field_validator +from typing import Any, Dict, List, Optional +from pydantic import BaseModel, Field, field_validator # ============================================================================ # ADMIN AUDIT LOG SCHEMAS # ============================================================================ + class AdminAuditLogResponse(BaseModel): """Response model for admin audit logs.""" + id: int admin_user_id: int admin_username: Optional[str] = None @@ -39,6 +41,7 @@ class AdminAuditLogResponse(BaseModel): class AdminAuditLogFilters(BaseModel): """Filters for querying audit logs.""" + admin_user_id: Optional[int] = None action: Optional[str] = None target_type: Optional[str] = None @@ -50,6 +53,7 @@ class AdminAuditLogFilters(BaseModel): class AdminAuditLogListResponse(BaseModel): """Paginated list of audit logs.""" + logs: List[AdminAuditLogResponse] total: int skip: int @@ -60,8 +64,10 @@ class AdminAuditLogListResponse(BaseModel): # ADMIN NOTIFICATION SCHEMAS # ============================================================================ + class AdminNotificationCreate(BaseModel): """Create admin notification.""" + type: str = Field(..., max_length=50, description="Notification type") priority: str = Field(default="normal", description="Priority level") title: str = Field(..., max_length=200) @@ -70,10 +76,10 @@ class AdminNotificationCreate(BaseModel): action_url: Optional[str] = Field(None, max_length=500) metadata: Optional[Dict[str, Any]] = None - @field_validator('priority') + @field_validator("priority") @classmethod def validate_priority(cls, v): - allowed = ['low', 'normal', 'high', 'critical'] + allowed = ["low", "normal", "high", "critical"] if v not in allowed: raise ValueError(f"Priority must be one of: {', '.join(allowed)}") return v @@ -81,6 +87,7 @@ class AdminNotificationCreate(BaseModel): class AdminNotificationResponse(BaseModel): """Admin notification response.""" + id: int type: str priority: str @@ -99,11 +106,13 @@ class AdminNotificationResponse(BaseModel): class AdminNotificationUpdate(BaseModel): """Mark notification as read.""" + is_read: bool = True class AdminNotificationListResponse(BaseModel): """Paginated list of notifications.""" + notifications: List[AdminNotificationResponse] total: int unread_count: int @@ -115,8 +124,10 @@ class AdminNotificationListResponse(BaseModel): # ADMIN SETTINGS SCHEMAS # ============================================================================ + class AdminSettingCreate(BaseModel): """Create or update admin setting.""" + key: str = Field(..., max_length=100, description="Unique setting key") value: str = Field(..., description="Setting value") value_type: str = Field(default="string", description="Data type") @@ -125,25 +136,28 @@ class AdminSettingCreate(BaseModel): is_encrypted: bool = Field(default=False) is_public: bool = Field(default=False, description="Can be exposed to frontend") - @field_validator('value_type') + @field_validator("value_type") @classmethod def validate_value_type(cls, v): - allowed = ['string', 'integer', 'boolean', 'json', 'float'] + allowed = ["string", "integer", "boolean", "json", "float"] if v not in allowed: raise ValueError(f"Value type must be one of: {', '.join(allowed)}") return v - @field_validator('key') + @field_validator("key") @classmethod def validate_key_format(cls, v): # Setting keys should be lowercase with underscores - if not v.replace('_', '').isalnum(): - raise ValueError("Setting key must contain only letters, numbers, and underscores") + if not v.replace("_", "").isalnum(): + raise ValueError( + "Setting key must contain only letters, numbers, and underscores" + ) return v.lower() class AdminSettingResponse(BaseModel): """Admin setting response.""" + id: int key: str value: str @@ -160,12 +174,14 @@ class AdminSettingResponse(BaseModel): class AdminSettingUpdate(BaseModel): """Update admin setting value.""" + value: str description: Optional[str] = None class AdminSettingListResponse(BaseModel): """List of settings by category.""" + settings: List[AdminSettingResponse] total: int category: Optional[str] = None @@ -175,8 +191,10 @@ class AdminSettingListResponse(BaseModel): # PLATFORM ALERT SCHEMAS # ============================================================================ + class PlatformAlertCreate(BaseModel): """Create platform alert.""" + alert_type: str = Field(..., max_length=50) severity: str = Field(..., description="Alert severity") title: str = Field(..., max_length=200) @@ -185,18 +203,25 @@ class PlatformAlertCreate(BaseModel): affected_systems: Optional[List[str]] = None auto_generated: bool = Field(default=True) - @field_validator('severity') + @field_validator("severity") @classmethod def validate_severity(cls, v): - allowed = ['info', 'warning', 'error', 'critical'] + allowed = ["info", "warning", "error", "critical"] if v not in allowed: raise ValueError(f"Severity must be one of: {', '.join(allowed)}") return v - @field_validator('alert_type') + @field_validator("alert_type") @classmethod def validate_alert_type(cls, v): - allowed = ['security', 'performance', 'capacity', 'integration', 'database', 'system'] + allowed = [ + "security", + "performance", + "capacity", + "integration", + "database", + "system", + ] if v not in allowed: raise ValueError(f"Alert type must be one of: {', '.join(allowed)}") return v @@ -204,6 +229,7 @@ class PlatformAlertCreate(BaseModel): class PlatformAlertResponse(BaseModel): """Platform alert response.""" + id: int alert_type: str severity: str @@ -226,12 +252,14 @@ class PlatformAlertResponse(BaseModel): class PlatformAlertResolve(BaseModel): """Resolve platform alert.""" + is_resolved: bool = True resolution_notes: Optional[str] = None class PlatformAlertListResponse(BaseModel): """Paginated list of platform alerts.""" + alerts: List[PlatformAlertResponse] total: int active_count: int @@ -244,17 +272,19 @@ class PlatformAlertListResponse(BaseModel): # BULK OPERATION SCHEMAS # ============================================================================ + class BulkVendorAction(BaseModel): """Bulk actions on vendors.""" + vendor_ids: List[int] = Field(..., min_length=1, max_length=100) action: str = Field(..., description="Action to perform") confirm: bool = Field(default=False, description="Required for destructive actions") reason: Optional[str] = Field(None, description="Reason for bulk action") - @field_validator('action') + @field_validator("action") @classmethod def validate_action(cls, v): - allowed = ['activate', 'deactivate', 'verify', 'unverify', 'delete'] + allowed = ["activate", "deactivate", "verify", "unverify", "delete"] if v not in allowed: raise ValueError(f"Action must be one of: {', '.join(allowed)}") return v @@ -262,6 +292,7 @@ class BulkVendorAction(BaseModel): class BulkVendorActionResponse(BaseModel): """Response for bulk vendor actions.""" + successful: List[int] failed: Dict[int, str] # vendor_id -> error_message total_processed: int @@ -271,15 +302,16 @@ class BulkVendorActionResponse(BaseModel): class BulkUserAction(BaseModel): """Bulk actions on users.""" + user_ids: List[int] = Field(..., min_length=1, max_length=100) action: str = Field(..., description="Action to perform") confirm: bool = Field(default=False) reason: Optional[str] = None - @field_validator('action') + @field_validator("action") @classmethod def validate_action(cls, v): - allowed = ['activate', 'deactivate', 'delete'] + allowed = ["activate", "deactivate", "delete"] if v not in allowed: raise ValueError(f"Action must be one of: {', '.join(allowed)}") return v @@ -287,6 +319,7 @@ class BulkUserAction(BaseModel): class BulkUserActionResponse(BaseModel): """Response for bulk user actions.""" + successful: List[int] failed: Dict[int, str] total_processed: int @@ -298,8 +331,10 @@ class BulkUserActionResponse(BaseModel): # ADMIN DASHBOARD SCHEMAS # ============================================================================ + class AdminDashboardStats(BaseModel): """Comprehensive admin dashboard statistics.""" + platform: Dict[str, Any] users: Dict[str, Any] vendors: Dict[str, Any] @@ -317,8 +352,10 @@ class AdminDashboardStats(BaseModel): # SYSTEM HEALTH SCHEMAS # ============================================================================ + class ComponentHealthStatus(BaseModel): """Health status for a system component.""" + status: str # healthy, degraded, unhealthy response_time_ms: Optional[float] = None error_message: Optional[str] = None @@ -328,6 +365,7 @@ class ComponentHealthStatus(BaseModel): class SystemHealthResponse(BaseModel): """System health check response.""" + overall_status: str # healthy, degraded, critical database: ComponentHealthStatus redis: ComponentHealthStatus @@ -342,8 +380,10 @@ class SystemHealthResponse(BaseModel): # ADMIN SESSION SCHEMAS # ============================================================================ + class AdminSessionResponse(BaseModel): """Admin session information.""" + id: int admin_user_id: int admin_username: Optional[str] = None @@ -360,6 +400,7 @@ class AdminSessionResponse(BaseModel): class AdminSessionListResponse(BaseModel): """List of admin sessions.""" + sessions: List[AdminSessionResponse] total: int active_count: int diff --git a/models/schema/auth.py b/models/schema/auth.py index 95ced43c..5d57aedc 100644 --- a/models/schema/auth.py +++ b/models/schema/auth.py @@ -17,7 +17,9 @@ class UserRegister(BaseModel): @classmethod def validate_username(cls, v): if not re.match(r"^[a-zA-Z0-9_]+$", v): - raise ValueError("Username must contain only letters, numbers, or underscores") + raise ValueError( + "Username must contain only letters, numbers, or underscores" + ) return v.lower().strip() @field_validator("password") @@ -31,7 +33,9 @@ class UserRegister(BaseModel): class UserLogin(BaseModel): email_or_username: str = Field(..., description="Username or email address") password: str = Field(..., description="Password") - vendor_code: Optional[str] = Field(None, description="Optional vendor code for context") + vendor_code: Optional[str] = Field( + None, description="Optional vendor code for context" + ) @field_validator("email_or_username") @classmethod diff --git a/models/schema/cart.py b/models/schema/cart.py index be4511c0..a5c705b7 100644 --- a/models/schema/cart.py +++ b/models/schema/cart.py @@ -5,21 +5,24 @@ Pydantic schemas for shopping cart operations. from datetime import datetime from typing import List, Optional -from pydantic import BaseModel, Field, ConfigDict +from pydantic import BaseModel, ConfigDict, Field # ============================================================================ # Request Schemas # ============================================================================ + class AddToCartRequest(BaseModel): """Request model for adding items to cart.""" + product_id: int = Field(..., description="Product ID to add", gt=0) quantity: int = Field(1, ge=1, description="Quantity to add") class UpdateCartItemRequest(BaseModel): """Request model for updating cart item quantity.""" + quantity: int = Field(..., ge=1, description="New quantity (must be >= 1)") @@ -27,23 +30,30 @@ class UpdateCartItemRequest(BaseModel): # Response Schemas # ============================================================================ + class CartItemResponse(BaseModel): """Response model for a single cart item.""" + model_config = ConfigDict(from_attributes=True) product_id: int = Field(..., description="Product ID") product_name: str = Field(..., description="Product name") quantity: int = Field(..., description="Quantity in cart") price: float = Field(..., description="Price per unit when added to cart") - line_total: float = Field(..., description="Total price for this line (price * quantity)") + line_total: float = Field( + ..., description="Total price for this line (price * quantity)" + ) image_url: Optional[str] = Field(None, description="Product image URL") class CartResponse(BaseModel): """Response model for shopping cart.""" + vendor_id: int = Field(..., description="Vendor ID") session_id: str = Field(..., description="Shopping session ID") - items: List[CartItemResponse] = Field(default_factory=list, description="Cart items") + items: List[CartItemResponse] = Field( + default_factory=list, description="Cart items" + ) subtotal: float = Field(..., description="Subtotal of all items") total: float = Field(..., description="Total amount (currently same as subtotal)") item_count: int = Field(..., description="Total number of items in cart") @@ -63,18 +73,22 @@ class CartResponse(BaseModel): items=items, subtotal=cart_dict["subtotal"], total=cart_dict["total"], - item_count=len(items) + item_count=len(items), ) class CartOperationResponse(BaseModel): """Response model for cart operations (add, update, remove).""" + message: str = Field(..., description="Operation result message") product_id: int = Field(..., description="Product ID affected") - quantity: Optional[int] = Field(None, description="New quantity (for add/update operations)") + quantity: Optional[int] = Field( + None, description="New quantity (for add/update operations)" + ) class ClearCartResponse(BaseModel): """Response model for clearing cart.""" + message: str = Field(..., description="Operation result message") items_removed: int = Field(..., description="Number of items removed from cart") diff --git a/models/schema/customer.py b/models/schema/customer.py index 3be4c89a..9c25ac28 100644 --- a/models/schema/customer.py +++ b/models/schema/customer.py @@ -5,35 +5,34 @@ Pydantic schema for customer-related operations. from datetime import datetime from decimal import Decimal -from typing import Optional, Dict, Any, List -from pydantic import BaseModel, EmailStr, Field, field_validator +from typing import Any, Dict, List, Optional +from pydantic import BaseModel, EmailStr, Field, field_validator # ============================================================================ # Customer Registration & Authentication # ============================================================================ + class CustomerRegister(BaseModel): """Schema for customer registration.""" email: EmailStr = Field(..., description="Customer email address") password: str = Field( - ..., - min_length=8, - description="Password (minimum 8 characters)" + ..., min_length=8, description="Password (minimum 8 characters)" ) first_name: str = Field(..., min_length=1, max_length=100) last_name: str = Field(..., min_length=1, max_length=100) phone: Optional[str] = Field(None, max_length=50) marketing_consent: bool = Field(default=False) - @field_validator('email') + @field_validator("email") @classmethod def email_lowercase(cls, v: str) -> str: """Convert email to lowercase.""" return v.lower() - @field_validator('password') + @field_validator("password") @classmethod def password_strength(cls, v: str) -> str: """Validate password strength.""" @@ -55,7 +54,7 @@ class CustomerUpdate(BaseModel): phone: Optional[str] = Field(None, max_length=50) marketing_consent: Optional[bool] = None - @field_validator('email') + @field_validator("email") @classmethod def email_lowercase(cls, v: Optional[str]) -> Optional[str]: """Convert email to lowercase.""" @@ -66,6 +65,7 @@ class CustomerUpdate(BaseModel): # Customer Response # ============================================================================ + class CustomerResponse(BaseModel): """Schema for customer response (excludes password).""" @@ -84,9 +84,7 @@ class CustomerResponse(BaseModel): created_at: datetime updated_at: datetime - model_config = { - "from_attributes": True - } + model_config = {"from_attributes": True} @property def full_name(self) -> str: @@ -110,6 +108,7 @@ class CustomerListResponse(BaseModel): # Customer Address # ============================================================================ + class CustomerAddressCreate(BaseModel): """Schema for creating customer address.""" @@ -159,14 +158,14 @@ class CustomerAddressResponse(BaseModel): created_at: datetime updated_at: datetime - model_config = { - "from_attributes": True - } + model_config = {"from_attributes": True} + # ============================================================================ # Customer Preferences # ============================================================================ + class CustomerPreferencesUpdate(BaseModel): """Schema for updating customer preferences.""" diff --git a/models/schema/inventory.py b/models/schema/inventory.py index 759383e2..89b3b584 100644 --- a/models/schema/inventory.py +++ b/models/schema/inventory.py @@ -1,6 +1,7 @@ # models/schema/inventory.py from datetime import datetime from typing import List, Optional + from pydantic import BaseModel, ConfigDict, Field @@ -11,16 +12,21 @@ class InventoryBase(BaseModel): class InventoryCreate(InventoryBase): """Set exact inventory quantity (replaces existing).""" + quantity: int = Field(..., description="Exact inventory quantity", ge=0) class InventoryAdjust(InventoryBase): """Add or remove inventory quantity.""" - quantity: int = Field(..., description="Quantity to add (positive) or remove (negative)") + + quantity: int = Field( + ..., description="Quantity to add (positive) or remove (negative)" + ) class InventoryUpdate(BaseModel): """Update inventory fields.""" + quantity: Optional[int] = Field(None, ge=0) reserved_quantity: Optional[int] = Field(None, ge=0) location: Optional[str] = None @@ -28,6 +34,7 @@ class InventoryUpdate(BaseModel): class InventoryReserve(BaseModel): """Reserve inventory for orders.""" + product_id: int location: str quantity: int = Field(..., gt=0) @@ -60,6 +67,7 @@ class InventoryLocationResponse(BaseModel): class ProductInventorySummary(BaseModel): """Inventory summary for a product.""" + product_id: int vendor_id: int product_sku: Optional[str] diff --git a/models/schema/marketplace.py b/models/schema/marketplace.py index 22a01dd7..4097cbf4 100644 --- a/models/schema/marketplace.py +++ b/models/schema/marketplace.py @@ -1 +1 @@ -# Marketplace import job models +# Marketplace import job models diff --git a/models/schema/marketplace_import_job.py b/models/schema/marketplace_import_job.py index 5a72c244..ece440be 100644 --- a/models/schema/marketplace_import_job.py +++ b/models/schema/marketplace_import_job.py @@ -1,5 +1,6 @@ from datetime import datetime from typing import Optional + from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -8,9 +9,12 @@ class MarketplaceImportJobRequest(BaseModel): Note: vendor_id is injected by middleware, not from request body. """ + source_url: str = Field(..., description="URL to CSV file from marketplace") marketplace: str = Field(default="Letzshop", description="Marketplace name") - batch_size: Optional[int] = Field(1000, description="Processing batch size", ge=100, le=10000) + batch_size: Optional[int] = Field( + 1000, description="Processing batch size", ge=100, le=10000 + ) @field_validator("source_url") @classmethod @@ -28,6 +32,7 @@ class MarketplaceImportJobRequest(BaseModel): class MarketplaceImportJobResponse(BaseModel): """Response schema for marketplace import job.""" + model_config = ConfigDict(from_attributes=True) job_id: int @@ -55,6 +60,7 @@ class MarketplaceImportJobResponse(BaseModel): class MarketplaceImportJobListResponse(BaseModel): """Response schema for list of import jobs.""" + jobs: list[MarketplaceImportJobResponse] total: int skip: int @@ -63,6 +69,7 @@ class MarketplaceImportJobListResponse(BaseModel): class MarketplaceImportJobStatusUpdate(BaseModel): """Schema for updating import job status (internal use).""" + status: str imported_count: Optional[int] = None updated_count: Optional[int] = None diff --git a/models/schema/marketplace_product.py b/models/schema/marketplace_product.py index 56e5406a..e459f1a3 100644 --- a/models/schema/marketplace_product.py +++ b/models/schema/marketplace_product.py @@ -1,9 +1,12 @@ # models/schema/marketplace_products.py - Simplified validation from datetime import datetime from typing import List, Optional + from pydantic import BaseModel, ConfigDict, Field + from models.schema.inventory import ProductInventorySummary + class MarketplaceProductBase(BaseModel): marketplace_product_id: Optional[str] = None title: Optional[str] = None @@ -45,27 +48,34 @@ class MarketplaceProductBase(BaseModel): marketplace: Optional[str] = None vendor_name: Optional[str] = None + class MarketplaceProductCreate(MarketplaceProductBase): - marketplace_product_id: str = Field(..., description="MarketplaceProduct identifier") + marketplace_product_id: str = Field( + ..., description="MarketplaceProduct identifier" + ) title: str = Field(..., description="MarketplaceProduct title") # Removed: min_length constraints and custom validators # Service will handle empty string validation with proper domain exceptions + class MarketplaceProductUpdate(MarketplaceProductBase): pass + class MarketplaceProductResponse(MarketplaceProductBase): model_config = ConfigDict(from_attributes=True) id: int created_at: datetime updated_at: datetime + class MarketplaceProductListResponse(BaseModel): products: List[MarketplaceProductResponse] total: int skip: int limit: int + class MarketplaceProductDetailResponse(BaseModel): product: MarketplaceProductResponse inventory_info: Optional[ProductInventorySummary] = None diff --git a/models/schema/media.py b/models/schema/media.py index cd0a71f0..cb2935b4 100644 --- a/models/schema/media.py +++ b/models/schema/media.py @@ -1 +1 @@ -# Media/file management models +# Media/file management models diff --git a/models/schema/monitoring.py b/models/schema/monitoring.py index 84f2d5a3..c7ab11c1 100644 --- a/models/schema/monitoring.py +++ b/models/schema/monitoring.py @@ -1 +1 @@ -# Monitoring models +# Monitoring models diff --git a/models/schema/notification.py b/models/schema/notification.py index cb27dfba..36f54ac4 100644 --- a/models/schema/notification.py +++ b/models/schema/notification.py @@ -1 +1 @@ -# Notification models +# Notification models diff --git a/models/schema/order.py b/models/schema/order.py index 6d34e9ef..a5a1d371 100644 --- a/models/schema/order.py +++ b/models/schema/order.py @@ -4,23 +4,26 @@ Pydantic schema for order operations. """ from datetime import datetime -from typing import List, Optional from decimal import Decimal -from pydantic import BaseModel, Field, ConfigDict +from typing import List, Optional +from pydantic import BaseModel, ConfigDict, Field # ============================================================================ # Order Item Schemas # ============================================================================ + class OrderItemCreate(BaseModel): """Schema for creating an order item.""" + product_id: int quantity: int = Field(..., ge=1) class OrderItemResponse(BaseModel): """Schema for order item response.""" + model_config = ConfigDict(from_attributes=True) id: int @@ -41,8 +44,10 @@ class OrderItemResponse(BaseModel): # Order Address Schemas # ============================================================================ + class OrderAddressCreate(BaseModel): """Schema for order address (shipping/billing).""" + first_name: str = Field(..., min_length=1, max_length=100) last_name: str = Field(..., min_length=1, max_length=100) company: Optional[str] = Field(None, max_length=200) @@ -55,6 +60,7 @@ class OrderAddressCreate(BaseModel): class OrderAddressResponse(BaseModel): """Schema for order address response.""" + model_config = ConfigDict(from_attributes=True) id: int @@ -73,8 +79,10 @@ class OrderAddressResponse(BaseModel): # Order Create/Update Schemas # ============================================================================ + class OrderCreate(BaseModel): """Schema for creating an order.""" + customer_id: Optional[int] = None # Optional for guest checkout items: List[OrderItemCreate] = Field(..., min_length=1) @@ -92,9 +100,9 @@ class OrderCreate(BaseModel): class OrderUpdate(BaseModel): """Schema for updating order status.""" + status: Optional[str] = Field( - None, - pattern="^(pending|processing|shipped|delivered|cancelled|refunded)$" + None, pattern="^(pending|processing|shipped|delivered|cancelled|refunded)$" ) tracking_number: Optional[str] = None internal_notes: Optional[str] = None @@ -104,8 +112,10 @@ class OrderUpdate(BaseModel): # Order Response Schemas # ============================================================================ + class OrderResponse(BaseModel): """Schema for order response.""" + model_config = ConfigDict(from_attributes=True) id: int @@ -141,6 +151,7 @@ class OrderResponse(BaseModel): class OrderDetailResponse(OrderResponse): """Schema for detailed order response with items and addresses.""" + items: List[OrderItemResponse] shipping_address: OrderAddressResponse billing_address: OrderAddressResponse @@ -148,6 +159,7 @@ class OrderDetailResponse(OrderResponse): class OrderListResponse(BaseModel): """Schema for paginated order list.""" + orders: List[OrderResponse] total: int skip: int diff --git a/models/schema/payment.py b/models/schema/payment.py index 1442207e..1ad2d971 100644 --- a/models/schema/payment.py +++ b/models/schema/payment.py @@ -1 +1 @@ -# Payment models +# Payment models diff --git a/models/schema/product.py b/models/schema/product.py index 7af82de7..1844e955 100644 --- a/models/schema/product.py +++ b/models/schema/product.py @@ -1,14 +1,20 @@ # models/schema/product.py from datetime import datetime from typing import List, Optional + from pydantic import BaseModel, ConfigDict, Field -from models.schema.marketplace_product import MarketplaceProductResponse + from models.schema.inventory import InventoryLocationResponse +from models.schema.marketplace_product import MarketplaceProductResponse class ProductCreate(BaseModel): - marketplace_product_id: int = Field(..., description="MarketplaceProduct ID to add to vendor catalog") - product_id: Optional[str] = Field(None, description="Vendor's internal SKU/product ID") + marketplace_product_id: int = Field( + ..., description="MarketplaceProduct ID to add to vendor catalog" + ) + product_id: Optional[str] = Field( + None, description="Vendor's internal SKU/product ID" + ) price: Optional[float] = Field(None, ge=0) sale_price: Optional[float] = Field(None, ge=0) currency: Optional[str] = None @@ -59,6 +65,7 @@ class ProductResponse(BaseModel): class ProductDetailResponse(ProductResponse): """Product with full inventory details.""" + inventory_locations: List[InventoryLocationResponse] = [] diff --git a/models/schema/search.py b/models/schema/search.py index 69abe85d..2e8b1772 100644 --- a/models/schema/search.py +++ b/models/schema/search.py @@ -1 +1 @@ -# Search models +# Search models diff --git a/models/schema/stats.py b/models/schema/stats.py index 0bdf9b65..4ef278fd 100644 --- a/models/schema/stats.py +++ b/models/schema/stats.py @@ -1,6 +1,6 @@ import re -from decimal import Decimal from datetime import datetime +from decimal import Decimal from typing import List, Optional from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator @@ -22,10 +22,12 @@ class MarketplaceStatsResponse(BaseModel): unique_vendors: int unique_brands: int + # ============================================================================ # Customer Statistics # ============================================================================ + class CustomerStatsResponse(BaseModel): """Schema for customer statistics.""" @@ -42,8 +44,10 @@ class CustomerStatsResponse(BaseModel): # Order Statistics # ============================================================================ + class OrderStatsResponse(BaseModel): """Schema for order statistics.""" + total_orders: int pending_orders: int processing_orders: int @@ -53,13 +57,16 @@ class OrderStatsResponse(BaseModel): total_revenue: Decimal average_order_value: Decimal + # ============================================================================ # Vendor Statistics # ============================================================================ + class VendorStatsResponse(BaseModel): """Vendor statistics response schema.""" + total: int = Field(..., description="Total number of vendors") verified: int = Field(..., description="Number of verified vendors") pending: int = Field(..., description="Number of pending verification vendors") - inactive: int = Field(..., description="Number of inactive vendors") \ No newline at end of file + inactive: int = Field(..., description="Number of inactive vendors") diff --git a/models/schema/team.py b/models/schema/team.py index ec716979..5d3f5ca0 100644 --- a/models/schema/team.py +++ b/models/schema/team.py @@ -10,33 +10,40 @@ This module defines request/response schemas for: """ from datetime import datetime -from typing import Optional, List -from pydantic import BaseModel, EmailStr, Field, field_validator +from typing import List, Optional +from pydantic import BaseModel, EmailStr, Field, field_validator # ============================================================================ # Role Schemas # ============================================================================ + class RoleBase(BaseModel): """Base role schema.""" + name: str = Field(..., min_length=1, max_length=100, description="Role name") - permissions: List[str] = Field(default_factory=list, description="List of permission strings") + permissions: List[str] = Field( + default_factory=list, description="List of permission strings" + ) class RoleCreate(RoleBase): """Schema for creating a role.""" + pass class RoleUpdate(BaseModel): """Schema for updating a role.""" + name: Optional[str] = Field(None, min_length=1, max_length=100) permissions: Optional[List[str]] = None class RoleResponse(RoleBase): """Schema for role response.""" + id: int vendor_id: int created_at: datetime @@ -48,6 +55,7 @@ class RoleResponse(RoleBase): class RoleListResponse(BaseModel): """Schema for role list response.""" + roles: List[RoleResponse] total: int @@ -56,8 +64,10 @@ class RoleListResponse(BaseModel): # Team Member Schemas # ============================================================================ + class TeamMemberBase(BaseModel): """Base team member schema.""" + email: EmailStr = Field(..., description="Team member email address") first_name: Optional[str] = Field(None, max_length=100) last_name: Optional[str] = Field(None, max_length=100) @@ -65,30 +75,34 @@ class TeamMemberBase(BaseModel): class TeamMemberInvite(TeamMemberBase): """Schema for inviting a team member.""" - role_id: Optional[int] = Field(None, description="Role ID to assign (for preset roles)") - role_name: Optional[str] = Field(None, description="Role name (manager, staff, support, etc.)") + + role_id: Optional[int] = Field( + None, description="Role ID to assign (for preset roles)" + ) + role_name: Optional[str] = Field( + None, description="Role name (manager, staff, support, etc.)" + ) custom_permissions: Optional[List[str]] = Field( - None, - description="Custom permissions (overrides role preset)" + None, description="Custom permissions (overrides role preset)" ) - @field_validator('role_name') + @field_validator("role_name") def validate_role_name(cls, v): """Validate role name is in allowed presets.""" if v is not None: - allowed_roles = ['manager', 'staff', 'support', 'viewer', 'marketing'] + allowed_roles = ["manager", "staff", "support", "viewer", "marketing"] if v.lower() not in allowed_roles: raise ValueError( f"Role name must be one of: {', '.join(allowed_roles)}" ) return v.lower() if v else v - @field_validator('custom_permissions') + @field_validator("custom_permissions") def validate_custom_permissions(cls, v, values): """Ensure either role_id/role_name OR custom_permissions is provided.""" if v is not None and len(v) > 0: # If custom permissions provided, role_name should be provided too - if 'role_name' not in values or not values['role_name']: + if "role_name" not in values or not values["role_name"]: raise ValueError( "role_name is required when providing custom_permissions" ) @@ -97,12 +111,14 @@ class TeamMemberInvite(TeamMemberBase): class TeamMemberUpdate(BaseModel): """Schema for updating a team member.""" + role_id: Optional[int] = Field(None, description="New role ID") is_active: Optional[bool] = Field(None, description="Active status") class TeamMemberResponse(BaseModel): """Schema for team member response.""" + id: int = Field(..., description="User ID") email: EmailStr username: str @@ -112,15 +128,18 @@ class TeamMemberResponse(BaseModel): user_type: str = Field(..., description="'owner' or 'member'") role_name: str = Field(..., description="Role name") role_id: Optional[int] - permissions: List[str] = Field(default_factory=list, description="User's permissions") + permissions: List[str] = Field( + default_factory=list, description="User's permissions" + ) is_active: bool is_owner: bool invitation_pending: bool = Field( - default=False, - description="True if invitation not yet accepted" + default=False, description="True if invitation not yet accepted" ) invited_at: Optional[datetime] = Field(None, description="When invitation was sent") - accepted_at: Optional[datetime] = Field(None, description="When invitation was accepted") + accepted_at: Optional[datetime] = Field( + None, description="When invitation was accepted" + ) joined_at: datetime = Field(..., description="When user joined vendor") class Config: @@ -129,6 +148,7 @@ class TeamMemberResponse(BaseModel): class TeamMemberListResponse(BaseModel): """Schema for team member list response.""" + members: List[TeamMemberResponse] total: int active_count: int @@ -139,19 +159,20 @@ class TeamMemberListResponse(BaseModel): # Invitation Schemas # ============================================================================ + class InvitationAccept(BaseModel): """Schema for accepting a team invitation.""" - invitation_token: str = Field(..., min_length=32, description="Invitation token from email") + + invitation_token: str = Field( + ..., min_length=32, description="Invitation token from email" + ) password: str = Field( - ..., - min_length=8, - max_length=128, - description="Password for new account" + ..., min_length=8, max_length=128, description="Password for new account" ) first_name: str = Field(..., min_length=1, max_length=100) last_name: str = Field(..., min_length=1, max_length=100) - @field_validator('password') + @field_validator("password") def validate_password_strength(cls, v): """Validate password meets minimum requirements.""" if len(v) < 8: @@ -172,18 +193,19 @@ class InvitationAccept(BaseModel): class InvitationResponse(BaseModel): """Schema for invitation response.""" + message: str email: EmailStr role: str invitation_token: Optional[str] = Field( - None, - description="Token (only returned in dev/test environments)" + None, description="Token (only returned in dev/test environments)" ) invitation_sent: bool = Field(default=True) class InvitationAcceptResponse(BaseModel): """Schema for invitation acceptance response.""" + message: str vendor: dict = Field(..., description="Vendor information") user: dict = Field(..., description="User information") @@ -194,8 +216,10 @@ class InvitationAcceptResponse(BaseModel): # Team Statistics Schema # ============================================================================ + class TeamStatistics(BaseModel): """Schema for team statistics.""" + total_members: int active_members: int inactive_members: int @@ -203,8 +227,7 @@ class TeamStatistics(BaseModel): owners: int team_members: int roles_breakdown: dict = Field( - default_factory=dict, - description="Count of members per role" + default_factory=dict, description="Count of members per role" ) @@ -212,13 +235,18 @@ class TeamStatistics(BaseModel): # Bulk Operations Schemas # ============================================================================ + class BulkRemoveRequest(BaseModel): """Schema for bulk removing team members.""" - user_ids: List[int] = Field(..., min_items=1, description="List of user IDs to remove") + + user_ids: List[int] = Field( + ..., min_items=1, description="List of user IDs to remove" + ) class BulkRemoveResponse(BaseModel): """Schema for bulk remove response.""" + success_count: int failed_count: int errors: List[dict] = Field(default_factory=list) @@ -228,21 +256,27 @@ class BulkRemoveResponse(BaseModel): # Permission Check Schemas # ============================================================================ + class PermissionCheckRequest(BaseModel): """Schema for checking permissions.""" + permissions: List[str] = Field(..., min_items=1, description="Permissions to check") class PermissionCheckResponse(BaseModel): """Schema for permission check response.""" + has_all: bool = Field(..., description="True if user has all permissions") has_any: bool = Field(..., description="True if user has any permission") granted: List[str] = Field(default_factory=list, description="Permissions user has") - denied: List[str] = Field(default_factory=list, description="Permissions user lacks") + denied: List[str] = Field( + default_factory=list, description="Permissions user lacks" + ) class UserPermissionsResponse(BaseModel): """Schema for user's permissions response.""" + permissions: List[str] = Field(default_factory=list) permission_count: int is_owner: bool @@ -253,8 +287,10 @@ class UserPermissionsResponse(BaseModel): # Error Response Schema # ============================================================================ + class TeamErrorResponse(BaseModel): """Schema for team operation errors.""" + error_code: str message: str details: Optional[dict] = None diff --git a/models/schema/vendor.py b/models/schema/vendor.py index f6080a69..e622f407 100644 --- a/models/schema/vendor.py +++ b/models/schema/vendor.py @@ -15,7 +15,8 @@ Schemas include: import re from datetime import datetime -from typing import Dict, List, Optional, Any +from typing import Any, Dict, List, Optional + from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -27,32 +28,26 @@ class VendorCreate(BaseModel): ..., description="Unique vendor identifier (e.g., TECHSTORE)", min_length=2, - max_length=50 + max_length=50, ) subdomain: str = Field( - ..., - description="Unique subdomain for the vendor", - min_length=2, - max_length=100 + ..., description="Unique subdomain for the vendor", min_length=2, max_length=100 ) name: str = Field( - ..., - description="Display name of the vendor", - min_length=2, - max_length=255 + ..., description="Display name of the vendor", min_length=2, max_length=255 ) description: Optional[str] = Field(None, description="Vendor description") # Owner Information (Creates User Account) owner_email: str = Field( ..., - description="Email for the vendor owner (used for login and authentication)" + description="Email for the vendor owner (used for login and authentication)", ) # Business Contact Information (Vendor Fields) contact_email: Optional[str] = Field( None, - description="Public business contact email (defaults to owner_email if not provided)" + description="Public business contact email (defaults to owner_email if not provided)", ) contact_phone: Optional[str] = Field(None, description="Contact phone number") website: Optional[str] = Field(None, description="Website URL") @@ -78,8 +73,10 @@ class VendorCreate(BaseModel): @classmethod def validate_subdomain(cls, v): """Validate subdomain format: lowercase alphanumeric with hyphens.""" - if v and not re.match(r'^[a-z0-9][a-z0-9-]*[a-z0-9]$', v): - raise ValueError("Subdomain must contain only lowercase letters, numbers, and hyphens") + if v and not re.match(r"^[a-z0-9][a-z0-9-]*[a-z0-9]$", v): + raise ValueError( + "Subdomain must contain only lowercase letters, numbers, and hyphens" + ) return v.lower() if v else v @field_validator("vendor_code") @@ -104,8 +101,7 @@ class VendorUpdate(BaseModel): # Business Contact Information (Vendor Fields) contact_email: Optional[str] = Field( - None, - description="Public business contact email" + None, description="Public business contact email" ) contact_phone: Optional[str] = None website: Optional[str] = None @@ -142,6 +138,7 @@ class VendorUpdate(BaseModel): class VendorResponse(BaseModel): """Standard schema for vendor response data.""" + model_config = ConfigDict(from_attributes=True) id: int @@ -184,13 +181,9 @@ class VendorDetailResponse(VendorResponse): """ owner_email: str = Field( - ..., - description="Email of the vendor owner (for login/authentication)" - ) - owner_username: str = Field( - ..., - description="Username of the vendor owner" + ..., description="Email of the vendor owner (for login/authentication)" ) + owner_username: str = Field(..., description="Username of the vendor owner") class VendorCreateResponse(VendorDetailResponse): @@ -201,17 +194,14 @@ class VendorCreateResponse(VendorDetailResponse): """ temporary_password: str = Field( - ..., - description="Temporary password for owner (SHOWN ONLY ONCE)" - ) - login_url: Optional[str] = Field( - None, - description="URL for vendor owner to login" + ..., description="Temporary password for owner (SHOWN ONLY ONCE)" ) + login_url: Optional[str] = Field(None, description="URL for vendor owner to login") class VendorListResponse(BaseModel): """Schema for paginated vendor list.""" + vendors: List[VendorResponse] total: int skip: int @@ -220,6 +210,7 @@ class VendorListResponse(BaseModel): class VendorSummary(BaseModel): """Lightweight vendor summary for dropdowns and quick references.""" + model_config = ConfigDict(from_attributes=True) id: int @@ -239,20 +230,17 @@ class VendorTransferOwnership(BaseModel): """ new_owner_user_id: int = Field( - ..., - description="ID of the user who will become the new owner", - gt=0 + ..., description="ID of the user who will become the new owner", gt=0 ) confirm_transfer: bool = Field( - ..., - description="Must be true to confirm ownership transfer" + ..., description="Must be true to confirm ownership transfer" ) transfer_reason: Optional[str] = Field( None, max_length=500, - description="Reason for ownership transfer (for audit logs)" + description="Reason for ownership transfer (for audit logs)", ) @field_validator("confirm_transfer") @@ -273,12 +261,10 @@ class VendorTransferOwnershipResponse(BaseModel): vendor_name: str old_owner: Dict[str, Any] = Field( - ..., - description="Information about the previous owner" + ..., description="Information about the previous owner" ) new_owner: Dict[str, Any] = Field( - ..., - description="Information about the new owner" + ..., description="Information about the new owner" ) transferred_at: datetime diff --git a/models/schema/vendor_domain.py b/models/schema/vendor_domain.py index aca0eab3..a99b47bb 100644 --- a/models/schema/vendor_domain.py +++ b/models/schema/vendor_domain.py @@ -12,7 +12,8 @@ Schemas include: import re from datetime import datetime -from typing import List, Optional, Dict +from typing import Dict, List, Optional + from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -23,14 +24,13 @@ class VendorDomainCreate(BaseModel): ..., description="Custom domain (e.g., myshop.com or shop.mybrand.com)", min_length=3, - max_length=255 + max_length=255, ) is_primary: bool = Field( - default=False, - description="Set as primary domain for the vendor" + default=False, description="Set as primary domain for the vendor" ) - @field_validator('domain') + @field_validator("domain") @classmethod def validate_domain(cls, v: str) -> str: """Validate and normalize domain.""" @@ -44,20 +44,22 @@ class VendorDomainCreate(BaseModel): domain = domain.lower().strip() # Basic validation - if not domain or '/' in domain: + if not domain or "/" in domain: raise ValueError("Invalid domain format") - if '.' not in domain: + if "." not in domain: raise ValueError("Domain must have at least one dot") # Check for reserved subdomains - reserved = ['www', 'admin', 'api', 'mail', 'smtp', 'ftp', 'cpanel', 'webmail'] - first_part = domain.split('.')[0] + reserved = ["www", "admin", "api", "mail", "smtp", "ftp", "cpanel", "webmail"] + first_part = domain.split(".")[0] if first_part in reserved: - raise ValueError(f"Domain cannot start with reserved subdomain: {first_part}") + raise ValueError( + f"Domain cannot start with reserved subdomain: {first_part}" + ) # Validate domain format (basic regex) - domain_pattern = r'^[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?(\.[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?)*$' + domain_pattern = r"^[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?(\.[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?)*$" if not re.match(domain_pattern, domain): raise ValueError("Invalid domain format") @@ -75,6 +77,7 @@ class VendorDomainUpdate(BaseModel): class VendorDomainResponse(BaseModel): """Standard schema for vendor domain response.""" + model_config = ConfigDict(from_attributes=True) id: int diff --git a/models/schema/vendor_theme.py b/models/schema/vendor_theme.py index 5697ed18..9ba482fc 100644 --- a/models/schema/vendor_theme.py +++ b/models/schema/vendor_theme.py @@ -3,12 +3,14 @@ Pydantic schemas for vendor theme operations. """ -from typing import Dict, Optional, List +from typing import Dict, List, Optional + from pydantic import BaseModel, Field class VendorThemeColors(BaseModel): """Color scheme for vendor theme.""" + primary: Optional[str] = Field(None, description="Primary brand color") secondary: Optional[str] = Field(None, description="Secondary color") accent: Optional[str] = Field(None, description="Accent/CTA color") @@ -19,12 +21,14 @@ class VendorThemeColors(BaseModel): class VendorThemeFonts(BaseModel): """Typography settings for vendor theme.""" + heading: Optional[str] = Field(None, description="Font for headings") body: Optional[str] = Field(None, description="Font for body text") class VendorThemeBranding(BaseModel): """Branding assets for vendor theme.""" + logo: Optional[str] = Field(None, description="Logo URL") logo_dark: Optional[str] = Field(None, description="Dark mode logo URL") favicon: Optional[str] = Field(None, description="Favicon URL") @@ -33,36 +37,54 @@ class VendorThemeBranding(BaseModel): class VendorThemeLayout(BaseModel): """Layout settings for vendor theme.""" - style: Optional[str] = Field(None, description="Product layout style (grid, list, masonry)") - header: Optional[str] = Field(None, description="Header style (fixed, static, transparent)") - product_card: Optional[str] = Field(None, description="Product card style (modern, classic, minimal)") + + style: Optional[str] = Field( + None, description="Product layout style (grid, list, masonry)" + ) + header: Optional[str] = Field( + None, description="Header style (fixed, static, transparent)" + ) + product_card: Optional[str] = Field( + None, description="Product card style (modern, classic, minimal)" + ) class VendorThemeUpdate(BaseModel): """Schema for updating vendor theme (partial updates allowed).""" + theme_name: Optional[str] = Field(None, description="Theme preset name") colors: Optional[Dict[str, str]] = Field(None, description="Color scheme") fonts: Optional[Dict[str, str]] = Field(None, description="Font settings") - branding: Optional[Dict[str, Optional[str]]] = Field(None, description="Branding assets") + branding: Optional[Dict[str, Optional[str]]] = Field( + None, description="Branding assets" + ) layout: Optional[Dict[str, str]] = Field(None, description="Layout settings") custom_css: Optional[str] = Field(None, description="Custom CSS rules") - social_links: Optional[Dict[str, str]] = Field(None, description="Social media links") + social_links: Optional[Dict[str, str]] = Field( + None, description="Social media links" + ) class VendorThemeResponse(BaseModel): """Schema for vendor theme response.""" + theme_name: str = Field(..., description="Theme name") colors: Dict[str, str] = Field(..., description="Color scheme") fonts: Dict[str, str] = Field(..., description="Font settings") branding: Dict[str, Optional[str]] = Field(..., description="Branding assets") layout: Dict[str, str] = Field(..., description="Layout settings") - social_links: Optional[Dict[str, str]] = Field(default_factory=dict, description="Social links") + social_links: Optional[Dict[str, str]] = Field( + default_factory=dict, description="Social links" + ) custom_css: Optional[str] = Field(None, description="Custom CSS") - css_variables: Optional[Dict[str, str]] = Field(None, description="CSS custom properties") + css_variables: Optional[Dict[str, str]] = Field( + None, description="CSS custom properties" + ) class ThemePresetPreview(BaseModel): """Preview information for a theme preset.""" + name: str = Field(..., description="Preset name") description: str = Field(..., description="Preset description") primary_color: str = Field(..., description="Primary color") @@ -75,10 +97,12 @@ class ThemePresetPreview(BaseModel): class ThemePresetResponse(BaseModel): """Response after applying a preset.""" + message: str = Field(..., description="Success message") theme: VendorThemeResponse = Field(..., description="Applied theme") class ThemePresetListResponse(BaseModel): """List of available theme presets.""" + presets: List[ThemePresetPreview] = Field(..., description="Available presets") diff --git a/scripts/backup_database.py b/scripts/backup_database.py index 52fbfcd1..41385cd5 100644 --- a/scripts/backup_database.py +++ b/scripts/backup_database.py @@ -2,9 +2,9 @@ """Database backup utility that uses project configuration.""" import os -import sys import shutil import sqlite3 +import sys from datetime import datetime from pathlib import Path from urllib.parse import urlparse @@ -19,10 +19,10 @@ def get_database_path(): """Extract database path from DATABASE_URL.""" db_url = settings.database_url - if db_url.startswith('sqlite:///'): + if db_url.startswith("sqlite:///"): # Remove sqlite:/// prefix and handle relative paths - db_path = db_url.replace('sqlite:///', '') - if db_path.startswith('./'): + db_path = db_url.replace("sqlite:///", "") + if db_path.startswith("./"): db_path = db_path[2:] # Remove ./ prefix return db_path else: @@ -76,7 +76,7 @@ def backup_database(): print("[BACKUP] Starting database backup...") print(f"[INFO] Database URL: {settings.database_url}") - if settings.database_url.startswith('sqlite'): + if settings.database_url.startswith("sqlite"): return backup_sqlite_database() else: print("[INFO] For PostgreSQL databases, use pg_dump:") diff --git a/scripts/create_default_content_pages.py b/scripts/create_default_content_pages.py index 927559c3..14fc665d 100755 --- a/scripts/create_default_content_pages.py +++ b/scripts/create_default_content_pages.py @@ -26,20 +26,19 @@ Usage: """ import sys -from pathlib import Path from datetime import datetime, timezone +from pathlib import Path # Add project root to path project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) -from sqlalchemy.orm import Session from sqlalchemy import select +from sqlalchemy.orm import Session from app.core.database import SessionLocal from models.database.content_page import ContentPage - # ============================================================================ # DEFAULT PAGE CONTENT # ============================================================================ @@ -458,6 +457,7 @@ DEFAULT_PAGES = [ # SCRIPT FUNCTIONS # ============================================================================ + def create_default_pages(db: Session) -> None: """ Create default platform content pages. @@ -475,13 +475,14 @@ def create_default_pages(db: Session) -> None: # Check if page already exists (platform default with this slug) existing = db.execute( select(ContentPage).where( - ContentPage.vendor_id == None, - ContentPage.slug == page_data["slug"] + ContentPage.vendor_id == None, ContentPage.slug == page_data["slug"] ) ).scalar_one_or_none() if existing: - print(f" ⏭️ Skipped: {page_data['title']} (/{page_data['slug']}) - already exists") + print( + f" ⏭️ Skipped: {page_data['title']} (/{page_data['slug']}) - already exists" + ) skipped_count += 1 continue @@ -519,7 +520,9 @@ def create_default_pages(db: Session) -> None: if created_count > 0: print("✅ Default platform content pages created successfully!\n") print("Next steps:") - print(" 1. View pages at: /about, /contact, /faq, /shipping, /returns, /privacy, /terms") + print( + " 1. View pages at: /about, /contact, /faq, /shipping, /returns, /privacy, /terms" + ) print(" 2. Vendors can override these pages through the vendor dashboard") print(" 3. Edit platform defaults through the admin panel\n") else: @@ -530,6 +533,7 @@ def create_default_pages(db: Session) -> None: # MAIN EXECUTION # ============================================================================ + def main(): """Main execution function.""" print("\n🚀 Starting Default Content Pages Creation Script...\n") diff --git a/scripts/create_inventory.py b/scripts/create_inventory.py index a69a341e..d1bc55a0 100755 --- a/scripts/create_inventory.py +++ b/scripts/create_inventory.py @@ -12,17 +12,18 @@ from pathlib import Path # Add project root to path sys.path.insert(0, str(Path(__file__).parent.parent)) -import sqlite3 -from datetime import datetime, UTC import os +import sqlite3 +from datetime import UTC, datetime + from dotenv import load_dotenv # Load environment variables load_dotenv() # Get database URL -database_url = os.getenv('DATABASE_URL', 'wizamart.db') -db_path = database_url.replace('sqlite:///', '') +database_url = os.getenv("DATABASE_URL", "wizamart.db") +db_path = database_url.replace("sqlite:///", "") print(f"📦 Creating inventory entries in {db_path}...") @@ -30,12 +31,14 @@ conn = sqlite3.connect(db_path) cursor = conn.cursor() # Get products without inventory -cursor.execute(""" +cursor.execute( + """ SELECT p.id, p.vendor_id, p.product_id FROM products p LEFT JOIN inventory i ON p.id = i.product_id WHERE i.id IS NULL -""") +""" +) products_without_inventory = cursor.fetchall() if not products_without_inventory: @@ -47,7 +50,8 @@ print(f"📦 Creating inventory for {len(products_without_inventory)} products.. # Create inventory entries for product_id, vendor_id, sku in products_without_inventory: - cursor.execute(""" + cursor.execute( + """ INSERT INTO inventory ( vendor_id, product_id, @@ -57,15 +61,17 @@ for product_id, vendor_id, sku in products_without_inventory: created_at, updated_at ) VALUES (?, ?, ?, ?, ?, ?, ?) - """, ( - vendor_id, - product_id, - 'Main Warehouse', - 100, # Total quantity - 0, # Reserved quantity - datetime.now(UTC), - datetime.now(UTC) - )) + """, + ( + vendor_id, + product_id, + "Main Warehouse", + 100, # Total quantity + 0, # Reserved quantity + datetime.now(UTC), + datetime.now(UTC), + ), + ) conn.commit() @@ -76,7 +82,9 @@ cursor.execute("SELECT COUNT(*) FROM inventory") total_count = cursor.fetchone()[0] print(f"\n📊 Total inventory entries: {total_count}") -cursor.execute("SELECT product_id, location, quantity, reserved_quantity FROM inventory LIMIT 5") +cursor.execute( + "SELECT product_id, location, quantity, reserved_quantity FROM inventory LIMIT 5" +) print("\n📦 Sample inventory:") for row in cursor.fetchall(): available = row[2] - row[3] diff --git a/scripts/create_landing_page.py b/scripts/create_landing_page.py index e5d463fc..203c33f3 100755 --- a/scripts/create_landing_page.py +++ b/scripts/create_landing_page.py @@ -12,18 +12,20 @@ from pathlib import Path # Add project root to path sys.path.insert(0, str(Path(__file__).parent.parent)) -from sqlalchemy.orm import Session -from app.core.database import SessionLocal -from models.database.vendor import Vendor -from models.database.content_page import ContentPage from datetime import datetime, timezone +from sqlalchemy.orm import Session + +from app.core.database import SessionLocal +from models.database.content_page import ContentPage +from models.database.vendor import Vendor + def create_landing_page( vendor_subdomain: str, template: str = "default", title: str = None, - content: str = None + content: str = None, ): """ Create a landing page for a vendor. @@ -38,9 +40,7 @@ def create_landing_page( try: # Find vendor - vendor = db.query(Vendor).filter( - Vendor.subdomain == vendor_subdomain - ).first() + vendor = db.query(Vendor).filter(Vendor.subdomain == vendor_subdomain).first() if not vendor: print(f"❌ Vendor '{vendor_subdomain}' not found!") @@ -49,10 +49,11 @@ def create_landing_page( print(f"✅ Found vendor: {vendor.name} (ID: {vendor.id})") # Check if landing page already exists - existing = db.query(ContentPage).filter( - ContentPage.vendor_id == vendor.id, - ContentPage.slug == "landing" - ).first() + existing = ( + db.query(ContentPage) + .filter(ContentPage.vendor_id == vendor.id, ContentPage.slug == "landing") + .first() + ) if existing: print(f"⚠️ Landing page already exists (ID: {existing.id})") @@ -74,7 +75,8 @@ def create_landing_page( vendor_id=vendor.id, slug="landing", title=title or f"Welcome to {vendor.name}", - content=content or f""" + content=content + or f"""

About {vendor.name}

{vendor.description or 'Your trusted shopping destination for quality products.'}

@@ -96,7 +98,7 @@ def create_landing_page( published_at=datetime.now(timezone.utc), show_in_footer=False, show_in_header=False, - display_order=0 + display_order=0, ) db.add(landing_page) @@ -141,10 +143,13 @@ def list_vendors(): print(f" Code: {vendor.vendor_code}") # Check if has landing page - landing = db.query(ContentPage).filter( - ContentPage.vendor_id == vendor.id, - ContentPage.slug == "landing" - ).first() + landing = ( + db.query(ContentPage) + .filter( + ContentPage.vendor_id == vendor.id, ContentPage.slug == "landing" + ) + .first() + ) if landing: print(f" Landing Page: ✅ ({landing.template})") @@ -165,7 +170,7 @@ def show_templates(): ("default", "Clean professional layout with 3-column quick links"), ("minimal", "Ultra-simple centered design with single CTA"), ("modern", "Full-screen hero with animations and features"), - ("full", "Maximum features with split-screen hero and stats") + ("full", "Maximum features with split-screen hero and stats"), ] for name, desc in templates: @@ -209,7 +214,7 @@ if __name__ == "__main__": success = create_landing_page( vendor_subdomain=vendor_subdomain, template=template, - title=title if title else None + title=title if title else None, ) if success: diff --git a/scripts/create_platform_pages.py b/scripts/create_platform_pages.py index e757beb4..cd694bad 100755 --- a/scripts/create_platform_pages.py +++ b/scripts/create_platform_pages.py @@ -46,16 +46,22 @@ def create_platform_pages(): print("1. Creating Platform Homepage...") # Check if already exists - existing = db.query(ContentPage).filter_by(vendor_id=None, slug="platform_homepage").first() + existing = ( + db.query(ContentPage) + .filter_by(vendor_id=None, slug="platform_homepage") + .first() + ) if existing: - print(f" ⚠️ Skipped: Platform Homepage - already exists (ID: {existing.id})") + print( + f" ⚠️ Skipped: Platform Homepage - already exists (ID: {existing.id})" + ) else: try: homepage = content_page_service.create_page( - db, - slug="platform_homepage", - title="Welcome to Our Multi-Vendor Marketplace", - content=""" + db, + slug="platform_homepage", + title="Welcome to Our Multi-Vendor Marketplace", + content="""

Connect vendors with customers worldwide. Build your online store and reach millions of shoppers.

@@ -64,14 +70,14 @@ def create_platform_pages(): with minimal effort and maximum impact.

""", - template="modern", # Uses platform/homepage-modern.html - vendor_id=None, # Platform-level page - is_published=True, - show_in_header=False, # Homepage is not in menu (it's the root) - show_in_footer=False, - display_order=0, + template="modern", # Uses platform/homepage-modern.html + vendor_id=None, # Platform-level page + is_published=True, + show_in_header=False, # Homepage is not in menu (it's the root) + show_in_footer=False, + display_order=0, meta_description="Leading multi-vendor marketplace platform. Connect with thousands of vendors and discover millions of products.", - meta_keywords="marketplace, multi-vendor, e-commerce, online shopping, platform" + meta_keywords="marketplace, multi-vendor, e-commerce, online shopping, platform", ) print(f" ✅ Created: {homepage.title} (/{homepage.slug})") except Exception as e: @@ -88,10 +94,10 @@ def create_platform_pages(): else: try: about = content_page_service.create_page( - db, - slug="about", - title="About Us", - content=""" + db, + slug="about", + title="About Us", + content="""

Our Mission

We're on a mission to democratize e-commerce by providing powerful, @@ -121,13 +127,13 @@ def create_platform_pages():

  • Excellence: We strive for the highest quality in everything we do
  • """, - vendor_id=None, - is_published=True, - show_in_header=True, # Show in header navigation - show_in_footer=True, # Show in footer - display_order=1, + vendor_id=None, + is_published=True, + show_in_header=True, # Show in header navigation + show_in_footer=True, # Show in footer + display_order=1, meta_description="Learn about our mission to democratize e-commerce and empower entrepreneurs worldwide.", - meta_keywords="about us, mission, vision, values, company" + meta_keywords="about us, mission, vision, values, company", ) print(f" ✅ Created: {about.title} (/{about.slug})") except Exception as e: @@ -144,10 +150,10 @@ def create_platform_pages(): else: try: faq = content_page_service.create_page( - db, - slug="faq", - title="Frequently Asked Questions", - content=""" + db, + slug="faq", + title="Frequently Asked Questions", + content="""

    Getting Started

    How do I create a vendor account?

    @@ -204,13 +210,13 @@ def create_platform_pages(): and marketing tools.

    """, - vendor_id=None, - is_published=True, - show_in_header=True, # Show in header navigation - show_in_footer=True, - display_order=2, + vendor_id=None, + is_published=True, + show_in_header=True, # Show in header navigation + show_in_footer=True, + display_order=2, meta_description="Find answers to common questions about our marketplace platform.", - meta_keywords="faq, frequently asked questions, help, support" + meta_keywords="faq, frequently asked questions, help, support", ) print(f" ✅ Created: {faq.title} (/{faq.slug})") except Exception as e: @@ -221,16 +227,18 @@ def create_platform_pages(): # ======================================================================== print("4. Creating Contact Us page...") - existing = db.query(ContentPage).filter_by(vendor_id=None, slug="contact").first() + existing = ( + db.query(ContentPage).filter_by(vendor_id=None, slug="contact").first() + ) if existing: print(f" ⚠️ Skipped: Contact Us - already exists (ID: {existing.id})") else: try: contact = content_page_service.create_page( - db, - slug="contact", - title="Contact Us", - content=""" + db, + slug="contact", + title="Contact Us", + content="""

    Get in Touch

    We'd love to hear from you! Whether you have questions about our platform, @@ -271,13 +279,13 @@ def create_platform_pages(): Email: press@marketplace.com

    """, - vendor_id=None, - is_published=True, - show_in_header=True, # Show in header navigation - show_in_footer=True, - display_order=3, + vendor_id=None, + is_published=True, + show_in_header=True, # Show in header navigation + show_in_footer=True, + display_order=3, meta_description="Get in touch with our team. We're here to help you succeed.", - meta_keywords="contact, support, email, phone, address" + meta_keywords="contact, support, email, phone, address", ) print(f" ✅ Created: {contact.title} (/{contact.slug})") except Exception as e: @@ -290,14 +298,16 @@ def create_platform_pages(): existing = db.query(ContentPage).filter_by(vendor_id=None, slug="terms").first() if existing: - print(f" ⚠️ Skipped: Terms of Service - already exists (ID: {existing.id})") + print( + f" ⚠️ Skipped: Terms of Service - already exists (ID: {existing.id})" + ) else: try: terms = content_page_service.create_page( - db, - slug="terms", - title="Terms of Service", - content=""" + db, + slug="terms", + title="Terms of Service", + content="""

    Last updated: January 1, 2024

    1. Acceptance of Terms

    @@ -361,13 +371,13 @@ def create_platform_pages(): legal@marketplace.com.

    """, - vendor_id=None, - is_published=True, - show_in_header=False, # Too legal for header - show_in_footer=True, # Show in footer - display_order=10, + vendor_id=None, + is_published=True, + show_in_header=False, # Too legal for header + show_in_footer=True, # Show in footer + display_order=10, meta_description="Read our terms of service and platform usage policies.", - meta_keywords="terms of service, terms, legal, policy, agreement" + meta_keywords="terms of service, terms, legal, policy, agreement", ) print(f" ✅ Created: {terms.title} (/{terms.slug})") except Exception as e: @@ -378,16 +388,18 @@ def create_platform_pages(): # ======================================================================== print("6. Creating Privacy Policy page...") - existing = db.query(ContentPage).filter_by(vendor_id=None, slug="privacy").first() + existing = ( + db.query(ContentPage).filter_by(vendor_id=None, slug="privacy").first() + ) if existing: print(f" ⚠️ Skipped: Privacy Policy - already exists (ID: {existing.id})") else: try: privacy = content_page_service.create_page( - db, - slug="privacy", - title="Privacy Policy", - content=""" + db, + slug="privacy", + title="Privacy Policy", + content="""

    Last updated: January 1, 2024

    1. Information We Collect

    @@ -453,13 +465,13 @@ def create_platform_pages(): privacy@marketplace.com.

    """, - vendor_id=None, - is_published=True, - show_in_header=False, # Too legal for header - show_in_footer=True, # Show in footer - display_order=11, + vendor_id=None, + is_published=True, + show_in_header=False, # Too legal for header + show_in_footer=True, # Show in footer + display_order=11, meta_description="Learn how we collect, use, and protect your personal information.", - meta_keywords="privacy policy, privacy, data protection, gdpr, personal information" + meta_keywords="privacy policy, privacy, data protection, gdpr, personal information", ) print(f" ✅ Created: {privacy.title} (/{privacy.slug})") except Exception as e: diff --git a/scripts/init_production.py b/scripts/init_production.py index 6a86f412..882d8f70 100644 --- a/scripts/init_production.py +++ b/scripts/init_production.py @@ -17,29 +17,30 @@ This script is idempotent - safe to run multiple times. """ import sys -from pathlib import Path from datetime import datetime, timezone +from pathlib import Path # Add project root to path project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) -from sqlalchemy.orm import Session from sqlalchemy import select +from sqlalchemy.orm import Session +from app.core.config import (print_environment_info, settings, + validate_production_settings) from app.core.database import SessionLocal -from app.core.config import settings, print_environment_info, validate_production_settings -from app.core.environment import is_production, get_environment -from models.database.user import User -from models.database.admin import AdminSetting -from middleware.auth import AuthManager +from app.core.environment import get_environment, is_production from app.core.permissions import PermissionGroups - +from middleware.auth import AuthManager +from models.database.admin import AdminSetting +from models.database.user import User # ============================================================================= # HELPER FUNCTIONS # ============================================================================= + def print_header(text: str): """Print formatted header.""" print("\n" + "=" * 70) @@ -71,6 +72,7 @@ def print_error(text: str): # INITIALIZATION FUNCTIONS # ============================================================================= + def create_admin_user(db: Session, auth_manager: AuthManager) -> User: """Create or get the platform admin user.""" @@ -206,9 +208,7 @@ def create_admin_settings(db: Session) -> int: for setting_data in default_settings: # Check if setting already exists existing = db.execute( - select(AdminSetting).where( - AdminSetting.key == setting_data["key"] - ) + select(AdminSetting).where(AdminSetting.key == setting_data["key"]) ).scalar_one_or_none() if not existing: @@ -281,6 +281,7 @@ def verify_rbac_schema(db: Session) -> bool: # MAIN INITIALIZATION # ============================================================================= + def initialize_production(db: Session, auth_manager: AuthManager): """Initialize production database with essential data.""" @@ -362,6 +363,7 @@ def print_summary(db: Session): # MAIN ENTRY POINT # ============================================================================= + def main(): """Main entry point.""" @@ -405,6 +407,7 @@ def main(): print_header("❌ INITIALIZATION FAILED") print(f"\nError: {e}\n") import traceback + traceback.print_exc() sys.exit(1) diff --git a/scripts/route_diagnostics.py b/scripts/route_diagnostics.py index 0616e221..febda516 100644 --- a/scripts/route_diagnostics.py +++ b/scripts/route_diagnostics.py @@ -7,7 +7,7 @@ Usage: python route_diagnostics.py """ import sys -from typing import List, Dict +from typing import Dict, List def check_route_order(): @@ -29,23 +29,23 @@ def check_route_order(): html_routes = [] for route in routes: - if hasattr(route, 'path'): + if hasattr(route, "path"): path = route.path - methods = getattr(route, 'methods', set()) + methods = getattr(route, "methods", set()) # Determine if JSON or HTML based on common patterns - if 'login' in path or 'dashboard' in path or 'products' in path: + if "login" in path or "dashboard" in path or "products" in path: # Check response class if available - if hasattr(route, 'response_class'): + if hasattr(route, "response_class"): response_class = str(route.response_class) - if 'HTML' in response_class: + if "HTML" in response_class: html_routes.append((path, methods)) else: json_routes.append((path, methods)) else: # Check endpoint name/tags - endpoint = getattr(route, 'endpoint', None) - if endpoint and 'page' in endpoint.__name__.lower(): + endpoint = getattr(route, "endpoint", None) + if endpoint and "page" in endpoint.__name__.lower(): html_routes.append((path, methods)) else: json_routes.append((path, methods)) @@ -57,8 +57,10 @@ def check_route_order(): # Check for specific vendor info route vendor_info_found = False for route in routes: - if hasattr(route, 'path'): - if '/{vendor_code}' == route.path and 'GET' in getattr(route, 'methods', set()): + if hasattr(route, "path"): + if "/{vendor_code}" == route.path and "GET" in getattr( + route, "methods", set() + ): vendor_info_found = True print("✅ Found vendor info endpoint: GET /{vendor_code}") break @@ -100,14 +102,14 @@ def test_vendor_endpoint(): print(f" Content-Type: {response.headers.get('content-type', 'N/A')}") # Check if response is JSON - content_type = response.headers.get('content-type', '') - if 'application/json' in content_type: + content_type = response.headers.get("content-type", "") + if "application/json" in content_type: print("✅ Response is JSON") data = response.json() print(f" Vendor: {data.get('name', 'N/A')}") print(f" Code: {data.get('vendor_code', 'N/A')}") return True - elif 'text/html' in content_type: + elif "text/html" in content_type: print("❌ ERROR: Response is HTML, not JSON!") print(" This confirms the route ordering issue") print(" The HTML page route is catching the API request") diff --git a/scripts/seed_demo.py b/scripts/seed_demo.py index 6e224ea7..13cade60 100644 --- a/scripts/seed_demo.py +++ b/scripts/seed_demo.py @@ -26,39 +26,39 @@ This script is idempotent when run normally. """ import sys +from datetime import datetime, timedelta, timezone +from decimal import Decimal from pathlib import Path from typing import Dict, List -from datetime import datetime, timezone, timedelta -from decimal import Decimal # Add project root to path project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) -from sqlalchemy.orm import Session -from sqlalchemy import select, delete - -from app.core.database import SessionLocal -from app.core.config import settings -from app.core.environment import is_production, get_environment -from models.database.user import User -from models.database.vendor import Vendor, VendorUser, Role -from models.database.vendor_domain import VendorDomain -from models.database.vendor_theme import VendorTheme -from models.database.customer import Customer, CustomerAddress -from models.database.product import Product -from models.database.marketplace_product import MarketplaceProduct -from models.database.marketplace_import_job import MarketplaceImportJob -from models.database.order import Order, OrderItem -from models.database.admin import PlatformAlert -from middleware.auth import AuthManager - # ============================================================================= # MODE DETECTION (from environment variable set by Makefile) # ============================================================================= import os -SEED_MODE = os.getenv('SEED_MODE', 'normal') # normal, minimal, reset +from sqlalchemy import delete, select +from sqlalchemy.orm import Session + +from app.core.config import settings +from app.core.database import SessionLocal +from app.core.environment import get_environment, is_production +from middleware.auth import AuthManager +from models.database.admin import PlatformAlert +from models.database.customer import Customer, CustomerAddress +from models.database.marketplace_import_job import MarketplaceImportJob +from models.database.marketplace_product import MarketplaceProduct +from models.database.order import Order, OrderItem +from models.database.product import Product +from models.database.user import User +from models.database.vendor import Role, Vendor, VendorUser +from models.database.vendor_domain import VendorDomain +from models.database.vendor_theme import VendorTheme + +SEED_MODE = os.getenv("SEED_MODE", "normal") # normal, minimal, reset # ============================================================================= # DEMO DATA CONFIGURATION @@ -147,6 +147,7 @@ THEME_PRESETS = { # HELPER FUNCTIONS # ============================================================================= + def print_header(text: str): """Print formatted header.""" print("\n" + "=" * 70) @@ -178,6 +179,7 @@ def print_error(text: str): # SAFETY CHECKS # ============================================================================= + def check_environment(): """Prevent running demo seed in production.""" @@ -196,9 +198,7 @@ def check_environment(): def check_admin_exists(db: Session) -> bool: """Check if admin user exists.""" - admin = db.execute( - select(User).where(User.role == "admin") - ).scalar_one_or_none() + admin = db.execute(select(User).where(User.role == "admin")).scalar_one_or_none() if not admin: print_error("No admin user found!") @@ -214,6 +214,7 @@ def check_admin_exists(db: Session) -> bool: # DATA DELETION (for reset mode) # ============================================================================= + def reset_all_data(db: Session): """Delete ALL data from database (except admin user).""" @@ -258,13 +259,14 @@ def reset_all_data(db: Session): # SEEDING FUNCTIONS # ============================================================================= + def create_demo_vendors(db: Session, auth_manager: AuthManager) -> List[Vendor]: """Create demo vendors with users.""" vendors = [] # Determine how many vendors to create based on mode - vendor_count = 1 if SEED_MODE == 'minimal' else settings.seed_demo_vendors + vendor_count = 1 if SEED_MODE == "minimal" else settings.seed_demo_vendors vendors_to_create = DEMO_VENDORS[:vendor_count] users_to_create = DEMO_VENDOR_USERS[:vendor_count] @@ -320,7 +322,9 @@ def create_demo_vendors(db: Session, auth_manager: AuthManager) -> List[Vendor]: db.add(vendor_user_link) # Create vendor theme - theme_colors = THEME_PRESETS.get(vendor_data["theme_preset"], THEME_PRESETS["modern"]) + theme_colors = THEME_PRESETS.get( + vendor_data["theme_preset"], THEME_PRESETS["modern"] + ) theme = VendorTheme( vendor_id=vendor.id, theme_name=vendor_data["theme_preset"], @@ -355,7 +359,9 @@ def create_demo_vendors(db: Session, auth_manager: AuthManager) -> List[Vendor]: return vendors -def create_demo_customers(db: Session, vendor: Vendor, auth_manager: AuthManager, count: int) -> List[Customer]: +def create_demo_customers( + db: Session, vendor: Vendor, auth_manager: AuthManager, count: int +) -> List[Customer]: """Create demo customers for a vendor.""" customers = [] @@ -367,10 +373,11 @@ def create_demo_customers(db: Session, vendor: Vendor, auth_manager: AuthManager customer_number = f"CUST-{vendor.vendor_code}-{i:04d}" # Check if customer already exists - existing_customer = db.query(Customer).filter( - Customer.vendor_id == vendor.id, - Customer.email == email - ).first() + existing_customer = ( + db.query(Customer) + .filter(Customer.vendor_id == vendor.id, Customer.email == email) + .first() + ) if existing_customer: customers.append(existing_customer) @@ -412,19 +419,22 @@ def create_demo_products(db: Session, vendor: Vendor, count: int) -> List[Produc product_id = f"{vendor.vendor_code}-PROD-{i:03d}" # Check if this product already exists - existing_product = db.query(Product).filter( - Product.vendor_id == vendor.id, - Product.product_id == product_id - ).first() + existing_product = ( + db.query(Product) + .filter(Product.vendor_id == vendor.id, Product.product_id == product_id) + .first() + ) if existing_product: products.append(existing_product) continue # Skip creation, product already exists # Check if marketplace product already exists - existing_mp = db.query(MarketplaceProduct).filter( - MarketplaceProduct.marketplace_product_id == marketplace_product_id - ).first() + existing_mp = ( + db.query(MarketplaceProduct) + .filter(MarketplaceProduct.marketplace_product_id == marketplace_product_id) + .first() + ) if existing_mp: marketplace_product = existing_mp @@ -485,6 +495,7 @@ def create_demo_products(db: Session, vendor: Vendor, count: int) -> List[Produc # MAIN SEEDING # ============================================================================= + def seed_demo_data(db: Session, auth_manager: AuthManager): """Seed demo data for development.""" @@ -501,7 +512,7 @@ def seed_demo_data(db: Session, auth_manager: AuthManager): sys.exit(1) # Step 3: Reset data if in reset mode - if SEED_MODE == 'reset': + if SEED_MODE == "reset": print_step(3, "Resetting data...") reset_all_data(db) @@ -513,20 +524,13 @@ def seed_demo_data(db: Session, auth_manager: AuthManager): print_step(5, "Creating demo customers...") for vendor in vendors: create_demo_customers( - db, - vendor, - auth_manager, - count=settings.seed_customers_per_vendor + db, vendor, auth_manager, count=settings.seed_customers_per_vendor ) # Step 6: Create products print_step(6, "Creating demo products...") for vendor in vendors: - create_demo_products( - db, - vendor, - count=settings.seed_products_per_vendor - ) + create_demo_products(db, vendor, count=settings.seed_products_per_vendor) # Commit all changes db.commit() @@ -558,17 +562,18 @@ def print_summary(db: Session): print(f" Subdomain: {vendor.subdomain}.{settings.platform_domain}") # Query custom domains separately - custom_domain = db.query(VendorDomain).filter( - VendorDomain.vendor_id == vendor.id, - VendorDomain.is_active == True - ).first() + custom_domain = ( + db.query(VendorDomain) + .filter(VendorDomain.vendor_id == vendor.id, VendorDomain.is_active == True) + .first() + ) if custom_domain: # Try different possible field names (model field might vary) domain_value = ( - getattr(custom_domain, 'domain', None) or - getattr(custom_domain, 'domain_name', None) or - getattr(custom_domain, 'name', None) + getattr(custom_domain, "domain", None) + or getattr(custom_domain, "domain_name", None) + or getattr(custom_domain, "name", None) ) if domain_value: print(f" Custom: {domain_value}") @@ -584,8 +589,12 @@ def print_summary(db: Session): print(f" Email: {vendor_data['email']}") print(f" Password: {vendor_data['password']}") if vendor: - print(f" Login: http://localhost:8000/vendor/{vendor.vendor_code}/login") - print(f" or http://{vendor.subdomain}.localhost:8000/vendor/login") + print( + f" Login: http://localhost:8000/vendor/{vendor.vendor_code}/login" + ) + print( + f" or http://{vendor.subdomain}.localhost:8000/vendor/login" + ) print() print(f"\n🛒 Demo Customer Credentials:") @@ -600,7 +609,9 @@ def print_summary(db: Session): print("─" * 70) for vendor in vendors: print(f" {vendor.name}:") - print(f" Path-based: http://localhost:8000/vendors/{vendor.vendor_code}/shop/") + print( + f" Path-based: http://localhost:8000/vendors/{vendor.vendor_code}/shop/" + ) print(f" Subdomain: http://{vendor.subdomain}.localhost:8000/") print() @@ -621,6 +632,7 @@ def print_summary(db: Session): # MAIN ENTRY POINT # ============================================================================= + def main(): """Main entry point.""" @@ -647,6 +659,7 @@ def main(): print_header("❌ SEEDING FAILED") print(f"\nError: {e}\n") import traceback + traceback.print_exc() sys.exit(1) diff --git a/scripts/show_structure.py b/scripts/show_structure.py index 6d3eeb43..83d382ae 100644 --- a/scripts/show_structure.py +++ b/scripts/show_structure.py @@ -11,11 +11,11 @@ Usage: """ import os -import sys import subprocess +import sys from datetime import datetime from pathlib import Path -from typing import List, Dict +from typing import Dict, List def count_files(directory: str, pattern: str) -> int: @@ -26,10 +26,14 @@ def count_files(directory: str, pattern: str) -> int: count = 0 for root, dirs, files in os.walk(directory): # Skip __pycache__ and other cache directories - dirs[:] = [d for d in dirs if d not in ['__pycache__', '.pytest_cache', '.git', 'node_modules']] + dirs[:] = [ + d + for d in dirs + if d not in ["__pycache__", ".pytest_cache", ".git", "node_modules"] + ] for file in files: - if pattern == '*' or file.endswith(pattern): + if pattern == "*" or file.endswith(pattern): count += 1 return count @@ -41,26 +45,26 @@ def get_tree_structure(directory: str, exclude_patterns: List[str] = None) -> st # Try to use system tree command first try: - if sys.platform == 'win32': + if sys.platform == "win32": # Windows tree command result = subprocess.run( - ['tree', '/F', '/A', directory], + ["tree", "/F", "/A", directory], capture_output=True, text=True, - encoding='utf-8', - errors='replace' + encoding="utf-8", + errors="replace", ) return result.stdout else: # Linux/Mac tree command with exclusions exclude_args = [] if exclude_patterns: - exclude_args = ['-I', '|'.join(exclude_patterns)] + exclude_args = ["-I", "|".join(exclude_patterns)] result = subprocess.run( - ['tree', '-F', '-a'] + exclude_args + [directory], + ["tree", "-F", "-a"] + exclude_args + [directory], capture_output=True, - text=True + text=True, ) if result.returncode == 0: return result.stdout @@ -71,10 +75,19 @@ def get_tree_structure(directory: str, exclude_patterns: List[str] = None) -> st return generate_manual_tree(directory, exclude_patterns) -def generate_manual_tree(directory: str, exclude_patterns: List[str] = None, prefix: str = "") -> str: +def generate_manual_tree( + directory: str, exclude_patterns: List[str] = None, prefix: str = "" +) -> str: """Generate tree structure manually when tree command is not available.""" if exclude_patterns is None: - exclude_patterns = ['__pycache__', '.pytest_cache', '.git', 'node_modules', '*.pyc', '*.pyo'] + exclude_patterns = [ + "__pycache__", + ".pytest_cache", + ".git", + "node_modules", + "*.pyc", + "*.pyo", + ] output = [] path = Path(directory) @@ -86,7 +99,7 @@ def generate_manual_tree(directory: str, exclude_patterns: List[str] = None, pre # Skip excluded patterns skip = False for pattern in exclude_patterns: - if pattern.startswith('*'): + if pattern.startswith("*"): # File extension pattern if item.name.endswith(pattern[1:]): skip = True @@ -106,7 +119,9 @@ def generate_manual_tree(directory: str, exclude_patterns: List[str] = None, pre if item.is_dir(): output.append(f"{prefix}{current_prefix}{item.name}/") extension = " " if is_last else "│ " - subtree = generate_manual_tree(str(item), exclude_patterns, prefix + extension) + subtree = generate_manual_tree( + str(item), exclude_patterns, prefix + extension + ) if subtree: output.append(subtree) else: @@ -127,20 +142,36 @@ def generate_frontend_structure() -> str: # Templates section output.append("") - output.append("╔══════════════════════════════════════════════════════════════════╗") - output.append("║ JINJA2 TEMPLATES ║") - output.append("║ Location: app/templates ║") - output.append("╚══════════════════════════════════════════════════════════════════╝") + output.append( + "╔══════════════════════════════════════════════════════════════════╗" + ) + output.append( + "║ JINJA2 TEMPLATES ║" + ) + output.append( + "║ Location: app/templates ║" + ) + output.append( + "╚══════════════════════════════════════════════════════════════════╝" + ) output.append("") output.append(get_tree_structure("app/templates")) # Static assets section output.append("") output.append("") - output.append("╔══════════════════════════════════════════════════════════════════╗") - output.append("║ STATIC ASSETS ║") - output.append("║ Location: static ║") - output.append("╚══════════════════════════════════════════════════════════════════╝") + output.append( + "╔══════════════════════════════════════════════════════════════════╗" + ) + output.append( + "║ STATIC ASSETS ║" + ) + output.append( + "║ Location: static ║" + ) + output.append( + "╚══════════════════════════════════════════════════════════════════╝" + ) output.append("") output.append(get_tree_structure("static")) @@ -148,11 +179,21 @@ def generate_frontend_structure() -> str: if os.path.exists("docs"): output.append("") output.append("") - output.append("╔══════════════════════════════════════════════════════════════════╗") - output.append("║ DOCUMENTATION ║") - output.append("║ Location: docs ║") - output.append("║ (also listed in tools structure) ║") - output.append("╚══════════════════════════════════════════════════════════════════╝") + output.append( + "╔══════════════════════════════════════════════════════════════════╗" + ) + output.append( + "║ DOCUMENTATION ║" + ) + output.append( + "║ Location: docs ║" + ) + output.append( + "║ (also listed in tools structure) ║" + ) + output.append( + "╚══════════════════════════════════════════════════════════════════╝" + ) output.append("") output.append("Note: Documentation is also included in tools structure") output.append(" for infrastructure/DevOps context.") @@ -160,9 +201,15 @@ def generate_frontend_structure() -> str: # Statistics section output.append("") output.append("") - output.append("╔══════════════════════════════════════════════════════════════════╗") - output.append("║ STATISTICS ║") - output.append("╚══════════════════════════════════════════════════════════════════╝") + output.append( + "╔══════════════════════════════════════════════════════════════════╗" + ) + output.append( + "║ STATISTICS ║" + ) + output.append( + "╚══════════════════════════════════════════════════════════════════╝" + ) output.append("") output.append("Templates:") @@ -174,8 +221,8 @@ def generate_frontend_structure() -> str: output.append(f" - JavaScript files: {count_files('static', '.js')}") output.append(f" - CSS files: {count_files('static', '.css')}") output.append(" - Image files:") - for ext in ['png', 'jpg', 'jpeg', 'gif', 'svg', 'webp', 'ico']: - count = count_files('static', f'.{ext}') + for ext in ["png", "jpg", "jpeg", "gif", "svg", "webp", "ico"]: + count = count_files("static", f".{ext}") if count > 0: output.append(f" - .{ext}: {count}") @@ -200,41 +247,58 @@ def generate_backend_structure() -> str: output.append("=" * 78) output.append("") - exclude = ['__pycache__', '*.pyc', '*.pyo', '.pytest_cache', '*.egg-info', 'templates'] + exclude = [ + "__pycache__", + "*.pyc", + "*.pyo", + ".pytest_cache", + "*.egg-info", + "templates", + ] # Backend directories to include backend_dirs = [ - ('app', 'Application Code'), - ('middleware', 'Middleware Components'), - ('models', 'Database Models'), - ('storage', 'File Storage'), - ('tasks', 'Background Tasks'), - ('logs', 'Application Logs'), + ("app", "Application Code"), + ("middleware", "Middleware Components"), + ("models", "Database Models"), + ("storage", "File Storage"), + ("tasks", "Background Tasks"), + ("logs", "Application Logs"), ] for directory, title in backend_dirs: if os.path.exists(directory): output.append("") - output.append("╔══════════════════════════════════════════════════════════════════╗") + output.append( + "╔══════════════════════════════════════════════════════════════════╗" + ) output.append(f"║ {title.upper().center(62)} ║") output.append(f"║ Location: {directory + '/'.ljust(51)} ║") - output.append("╚══════════════════════════════════════════════════════════════════╝") + output.append( + "╚══════════════════════════════════════════════════════════════════╝" + ) output.append("") output.append(get_tree_structure(directory, exclude)) # Statistics section output.append("") output.append("") - output.append("╔══════════════════════════════════════════════════════════════════╗") - output.append("║ STATISTICS ║") - output.append("╚══════════════════════════════════════════════════════════════════╝") + output.append( + "╔══════════════════════════════════════════════════════════════════╗" + ) + output.append( + "║ STATISTICS ║" + ) + output.append( + "╚══════════════════════════════════════════════════════════════════╝" + ) output.append("") output.append("Python Files by Directory:") total_py_files = 0 for directory, title in backend_dirs: if os.path.exists(directory): - count = count_files(directory, '.py') + count = count_files(directory, ".py") total_py_files += count output.append(f" - {directory}/: {count} files") @@ -242,11 +306,11 @@ def generate_backend_structure() -> str: output.append("") output.append("Application Components (if in app/):") - components = ['routes', 'services', 'schemas', 'exceptions', 'utils'] + components = ["routes", "services", "schemas", "exceptions", "utils"] for component in components: component_path = f"app/{component}" if os.path.exists(component_path): - count = count_files(component_path, '.py') + count = count_files(component_path, ".py") output.append(f" - app/{component}: {count} files") output.append("") @@ -264,48 +328,58 @@ def generate_tools_structure() -> str: output.append("=" * 78) output.append("") - exclude = ['__pycache__', '*.pyc', '*.pyo', '.pytest_cache', '*.egg-info'] + exclude = ["__pycache__", "*.pyc", "*.pyo", ".pytest_cache", "*.egg-info"] # Tools directories to include tools_dirs = [ - ('alembic', 'Database Migrations'), - ('scripts', 'Utility Scripts'), - ('docker', 'Docker Configuration'), - ('docs', 'Documentation'), + ("alembic", "Database Migrations"), + ("scripts", "Utility Scripts"), + ("docker", "Docker Configuration"), + ("docs", "Documentation"), ] for directory, title in tools_dirs: if os.path.exists(directory): output.append("") - output.append("╔══════════════════════════════════════════════════════════════════╗") + output.append( + "╔══════════════════════════════════════════════════════════════════╗" + ) output.append(f"║ {title.upper().center(62)} ║") output.append(f"║ Location: {directory + '/'.ljust(51)} ║") - output.append("╚══════════════════════════════════════════════════════════════════╝") + output.append( + "╚══════════════════════════════════════════════════════════════════╝" + ) output.append("") output.append(get_tree_structure(directory, exclude)) # Configuration files section output.append("") output.append("") - output.append("╔══════════════════════════════════════════════════════════════════╗") - output.append("║ CONFIGURATION FILES ║") - output.append("╚══════════════════════════════════════════════════════════════════╝") + output.append( + "╔══════════════════════════════════════════════════════════════════╗" + ) + output.append( + "║ CONFIGURATION FILES ║" + ) + output.append( + "╚══════════════════════════════════════════════════════════════════╝" + ) output.append("") output.append("Root configuration files:") config_files = [ - ('Makefile', 'Build automation'), - ('requirements.txt', 'Python dependencies'), - ('pyproject.toml', 'Python project config'), - ('setup.py', 'Python setup script'), - ('setup.cfg', 'Setup configuration'), - ('alembic.ini', 'Alembic migrations config'), - ('mkdocs.yml', 'MkDocs documentation config'), - ('Dockerfile', 'Docker image definition'), - ('docker-compose.yml', 'Docker services'), - ('.dockerignore', 'Docker ignore patterns'), - ('.gitignore', 'Git ignore patterns'), - ('.env.example', 'Environment variables template'), + ("Makefile", "Build automation"), + ("requirements.txt", "Python dependencies"), + ("pyproject.toml", "Python project config"), + ("setup.py", "Python setup script"), + ("setup.cfg", "Setup configuration"), + ("alembic.ini", "Alembic migrations config"), + ("mkdocs.yml", "MkDocs documentation config"), + ("Dockerfile", "Docker image definition"), + ("docker-compose.yml", "Docker services"), + (".dockerignore", "Docker ignore patterns"), + (".gitignore", "Git ignore patterns"), + (".env.example", "Environment variables template"), ] for file, description in config_files: @@ -315,9 +389,15 @@ def generate_tools_structure() -> str: # Statistics section output.append("") output.append("") - output.append("╔══════════════════════════════════════════════════════════════════╗") - output.append("║ STATISTICS ║") - output.append("╚══════════════════════════════════════════════════════════════════╝") + output.append( + "╔══════════════════════════════════════════════════════════════════╗" + ) + output.append( + "║ STATISTICS ║" + ) + output.append( + "╚══════════════════════════════════════════════════════════════════╝" + ) output.append("") output.append("Database Migrations:") @@ -327,7 +407,9 @@ def generate_tools_structure() -> str: if migration_count > 0: # Get first and last migration try: - migrations = sorted([f for f in os.listdir("alembic/versions") if f.endswith('.py')]) + migrations = sorted( + [f for f in os.listdir("alembic/versions") if f.endswith(".py")] + ) if migrations: output.append(f" - First: {migrations[0][:40]}...") if len(migrations) > 1: @@ -341,9 +423,9 @@ def generate_tools_structure() -> str: output.append("Scripts:") if os.path.exists("scripts"): script_types = { - '.py': 'Python scripts', - '.sh': 'Shell scripts', - '.bat': 'Batch scripts', + ".py": "Python scripts", + ".sh": "Shell scripts", + ".bat": "Batch scripts", } for ext, desc in script_types.items(): count = count_files("scripts", ext) @@ -356,8 +438,8 @@ def generate_tools_structure() -> str: output.append("Documentation:") if os.path.exists("docs"): doc_types = { - '.md': 'Markdown files', - '.rst': 'reStructuredText files', + ".md": "Markdown files", + ".rst": "reStructuredText files", } for ext, desc in doc_types.items(): count = count_files("docs", ext) @@ -368,7 +450,7 @@ def generate_tools_structure() -> str: output.append("") output.append("Docker:") - docker_files = ['Dockerfile', 'docker-compose.yml', '.dockerignore'] + docker_files = ["Dockerfile", "docker-compose.yml", ".dockerignore"] docker_exists = any(os.path.exists(f) for f in docker_files) if docker_exists: output.append(" ✓ Docker configuration present") @@ -394,25 +476,44 @@ def generate_test_structure() -> str: # Test files section output.append("") - output.append("╔══════════════════════════════════════════════════════════════════╗") - output.append("║ TEST FILES ║") - output.append("║ Location: tests/ ║") - output.append("╚══════════════════════════════════════════════════════════════════╝") + output.append( + "╔══════════════════════════════════════════════════════════════════╗" + ) + output.append( + "║ TEST FILES ║" + ) + output.append( + "║ Location: tests/ ║" + ) + output.append( + "╚══════════════════════════════════════════════════════════════════╝" + ) output.append("") - exclude = ['__pycache__', '*.pyc', '*.pyo', '.pytest_cache', '*.egg-info'] + exclude = ["__pycache__", "*.pyc", "*.pyo", ".pytest_cache", "*.egg-info"] output.append(get_tree_structure("tests", exclude)) # Configuration section output.append("") output.append("") - output.append("╔══════════════════════════════════════════════════════════════════╗") - output.append("║ TEST CONFIGURATION ║") - output.append("╚══════════════════════════════════════════════════════════════════╝") + output.append( + "╔══════════════════════════════════════════════════════════════════╗" + ) + output.append( + "║ TEST CONFIGURATION ║" + ) + output.append( + "╚══════════════════════════════════════════════════════════════════╝" + ) output.append("") output.append("Test configuration files:") - test_config_files = ['pytest.ini', 'conftest.py', 'tests/conftest.py', '.coveragerc'] + test_config_files = [ + "pytest.ini", + "conftest.py", + "tests/conftest.py", + ".coveragerc", + ] for file in test_config_files: if os.path.exists(file): output.append(f" ✓ {file}") @@ -420,16 +521,22 @@ def generate_test_structure() -> str: # Statistics section output.append("") output.append("") - output.append("╔══════════════════════════════════════════════════════════════════╗") - output.append("║ STATISTICS ║") - output.append("╚══════════════════════════════════════════════════════════════════╝") + output.append( + "╔══════════════════════════════════════════════════════════════════╗" + ) + output.append( + "║ STATISTICS ║" + ) + output.append( + "╚══════════════════════════════════════════════════════════════════╝" + ) output.append("") # Count test files test_file_count = 0 if os.path.exists("tests"): for root, dirs, files in os.walk("tests"): - dirs[:] = [d for d in dirs if d != '__pycache__'] + dirs[:] = [d for d in dirs if d != "__pycache__"] for file in files: if file.startswith("test_") and file.endswith(".py"): test_file_count += 1 @@ -439,13 +546,13 @@ def generate_test_structure() -> str: output.append("") output.append("By Category:") - categories = ['unit', 'integration', 'system', 'e2e', 'performance'] + categories = ["unit", "integration", "system", "e2e", "performance"] for category in categories: category_path = f"tests/{category}" if os.path.exists(category_path): count = 0 for root, dirs, files in os.walk(category_path): - dirs[:] = [d for d in dirs if d != '__pycache__'] + dirs[:] = [d for d in dirs if d != "__pycache__"] for file in files: if file.startswith("test_") and file.endswith(".py"): count += 1 @@ -455,12 +562,12 @@ def generate_test_structure() -> str: test_function_count = 0 if os.path.exists("tests"): for root, dirs, files in os.walk("tests"): - dirs[:] = [d for d in dirs if d != '__pycache__'] + dirs[:] = [d for d in dirs if d != "__pycache__"] for file in files: if file.startswith("test_") and file.endswith(".py"): filepath = os.path.join(root, file) try: - with open(filepath, 'r', encoding='utf-8') as f: + with open(filepath, "r", encoding="utf-8") as f: for line in f: if line.strip().startswith("def test_"): test_function_count += 1 @@ -494,19 +601,19 @@ def main(): structure_type = sys.argv[1].lower() generators = { - 'frontend': ('frontend-structure.txt', generate_frontend_structure), - 'backend': ('backend-structure.txt', generate_backend_structure), - 'tests': ('test-structure.txt', generate_test_structure), - 'tools': ('tools-structure.txt', generate_tools_structure), + "frontend": ("frontend-structure.txt", generate_frontend_structure), + "backend": ("backend-structure.txt", generate_backend_structure), + "tests": ("test-structure.txt", generate_test_structure), + "tools": ("tools-structure.txt", generate_tools_structure), } - if structure_type == 'all': + if structure_type == "all": for name, (filename, generator) in generators.items(): print(f"\n{'=' * 60}") print(f"Generating {name} structure...") - print('=' * 60) + print("=" * 60) content = generator() - with open(filename, 'w', encoding='utf-8') as f: + with open(filename, "w", encoding="utf-8") as f: f.write(content) print(f"✅ {name.capitalize()} structure saved to {filename}") print(f"\n{content}\n") @@ -514,7 +621,7 @@ def main(): filename, generator = generators[structure_type] print(f"Generating {structure_type} structure...") content = generator() - with open(filename, 'w', encoding='utf-8') as f: + with open(filename, "w", encoding="utf-8") as f: f.write(content) print(f"\n✅ Structure saved to {filename}\n") print(content) diff --git a/scripts/test_auth_complete.py b/scripts/test_auth_complete.py index 82e78a14..abc5c18a 100644 --- a/scripts/test_auth_complete.py +++ b/scripts/test_auth_complete.py @@ -20,22 +20,26 @@ Requirements: * Customer: username=customer, password=customer123, vendor_id=1 """ -import requests import json from typing import Dict, Optional +import requests + BASE_URL = "http://localhost:8000" + class Color: """Terminal colors for pretty output""" - GREEN = '\033[92m' - RED = '\033[91m' - YELLOW = '\033[93m' - BLUE = '\033[94m' - MAGENTA = '\033[95m' - CYAN = '\033[96m' - BOLD = '\033[1m' - END = '\033[0m' + + GREEN = "\033[92m" + RED = "\033[91m" + YELLOW = "\033[93m" + BLUE = "\033[94m" + MAGENTA = "\033[95m" + CYAN = "\033[96m" + BOLD = "\033[1m" + END = "\033[0m" + def print_section(name: str): """Print section header""" @@ -43,66 +47,70 @@ def print_section(name: str): print(f" {name}") print(f"{'═' * 60}{Color.END}") + def print_test(name: str): """Print test name""" print(f"\n{Color.BOLD}{Color.BLUE}🧪 Test: {name}{Color.END}") + def print_success(message: str): """Print success message""" print(f"{Color.GREEN}✅ {message}{Color.END}") + def print_error(message: str): """Print error message""" print(f"{Color.RED}❌ {message}{Color.END}") + def print_info(message: str): """Print info message""" print(f"{Color.YELLOW}ℹ️ {message}{Color.END}") + def print_warning(message: str): """Print warning message""" print(f"{Color.MAGENTA}⚠️ {message}{Color.END}") + # ============================================================================ # ADMIN AUTHENTICATION TESTS # ============================================================================ + def test_admin_login() -> Optional[Dict]: """Test admin login and cookie configuration""" print_test("Admin Login") - + try: response = requests.post( f"{BASE_URL}/api/v1/admin/auth/login", - json={"username": "admin", "password": "admin123"} + json={"username": "admin", "password": "admin123"}, ) - + if response.status_code == 200: data = response.json() cookies = response.cookies - + if "access_token" in data: print_success("Admin login successful") print_success(f"Received access token: {data['access_token'][:20]}...") else: print_error("No access token in response") return None - + if "admin_token" in cookies: print_success("admin_token cookie set") print_info("Cookie path should be /admin (verify in browser)") else: print_error("admin_token cookie NOT set") - - return { - "token": data["access_token"], - "user": data.get("user", {}) - } + + return {"token": data["access_token"], "user": data.get("user", {})} else: print_error(f"Login failed: {response.status_code}") print_error(f"Response: {response.text}") return None - + except Exception as e: print_error(f"Exception during admin login: {str(e)}") return None @@ -111,13 +119,13 @@ def test_admin_login() -> Optional[Dict]: def test_admin_cannot_access_vendor_api(admin_token: str): """Test that admin token cannot access vendor API""" print_test("Admin Token on Vendor API (Should Block)") - + try: response = requests.get( f"{BASE_URL}/api/v1/vendor/TESTVENDOR/products", - headers={"Authorization": f"Bearer {admin_token}"} + headers={"Authorization": f"Bearer {admin_token}"}, ) - + if response.status_code in [401, 403]: data = response.json() print_success("Admin correctly blocked from vendor API") @@ -129,7 +137,7 @@ def test_admin_cannot_access_vendor_api(admin_token: str): else: print_warning(f"Unexpected status code: {response.status_code}") return False - + except Exception as e: print_error(f"Exception: {str(e)}") return False @@ -138,13 +146,13 @@ def test_admin_cannot_access_vendor_api(admin_token: str): def test_admin_cannot_access_customer_api(admin_token: str): """Test that admin token cannot access customer account pages""" print_test("Admin Token on Customer API (Should Block)") - + try: response = requests.get( f"{BASE_URL}/shop/account/dashboard", - headers={"Authorization": f"Bearer {admin_token}"} + headers={"Authorization": f"Bearer {admin_token}"}, ) - + # Customer pages may return HTML or JSON error if response.status_code in [401, 403]: print_success("Admin correctly blocked from customer pages") @@ -155,7 +163,7 @@ def test_admin_cannot_access_customer_api(admin_token: str): else: print_warning(f"Unexpected status code: {response.status_code}") return False - + except Exception as e: print_error(f"Exception: {str(e)}") return False @@ -165,46 +173,47 @@ def test_admin_cannot_access_customer_api(admin_token: str): # VENDOR AUTHENTICATION TESTS # ============================================================================ + def test_vendor_login() -> Optional[Dict]: """Test vendor login and cookie configuration""" print_test("Vendor Login") - + try: response = requests.post( f"{BASE_URL}/api/v1/vendor/auth/login", - json={"username": "vendor", "password": "vendor123"} + json={"username": "vendor", "password": "vendor123"}, ) - + if response.status_code == 200: data = response.json() cookies = response.cookies - + if "access_token" in data: print_success("Vendor login successful") print_success(f"Received access token: {data['access_token'][:20]}...") else: print_error("No access token in response") return None - + if "vendor_token" in cookies: print_success("vendor_token cookie set") print_info("Cookie path should be /vendor (verify in browser)") else: print_error("vendor_token cookie NOT set") - + if "vendor" in data: print_success(f"Vendor: {data['vendor'].get('vendor_code', 'N/A')}") - + return { "token": data["access_token"], "user": data.get("user", {}), - "vendor": data.get("vendor", {}) + "vendor": data.get("vendor", {}), } else: print_error(f"Login failed: {response.status_code}") print_error(f"Response: {response.text}") return None - + except Exception as e: print_error(f"Exception during vendor login: {str(e)}") return None @@ -213,13 +222,13 @@ def test_vendor_login() -> Optional[Dict]: def test_vendor_cannot_access_admin_api(vendor_token: str): """Test that vendor token cannot access admin API""" print_test("Vendor Token on Admin API (Should Block)") - + try: response = requests.get( f"{BASE_URL}/api/v1/admin/vendors", - headers={"Authorization": f"Bearer {vendor_token}"} + headers={"Authorization": f"Bearer {vendor_token}"}, ) - + if response.status_code in [401, 403]: data = response.json() print_success("Vendor correctly blocked from admin API") @@ -231,7 +240,7 @@ def test_vendor_cannot_access_admin_api(vendor_token: str): else: print_warning(f"Unexpected status code: {response.status_code}") return False - + except Exception as e: print_error(f"Exception: {str(e)}") return False @@ -240,13 +249,13 @@ def test_vendor_cannot_access_admin_api(vendor_token: str): def test_vendor_cannot_access_customer_api(vendor_token: str): """Test that vendor token cannot access customer account pages""" print_test("Vendor Token on Customer API (Should Block)") - + try: response = requests.get( f"{BASE_URL}/shop/account/dashboard", - headers={"Authorization": f"Bearer {vendor_token}"} + headers={"Authorization": f"Bearer {vendor_token}"}, ) - + if response.status_code in [401, 403]: print_success("Vendor correctly blocked from customer pages") return True @@ -256,7 +265,7 @@ def test_vendor_cannot_access_customer_api(vendor_token: str): else: print_warning(f"Unexpected status code: {response.status_code}") return False - + except Exception as e: print_error(f"Exception: {str(e)}") return False @@ -266,42 +275,40 @@ def test_vendor_cannot_access_customer_api(vendor_token: str): # CUSTOMER AUTHENTICATION TESTS # ============================================================================ + def test_customer_login() -> Optional[Dict]: """Test customer login and cookie configuration""" print_test("Customer Login") - + try: response = requests.post( f"{BASE_URL}/api/v1/public/vendors/1/customers/login", - json={"username": "customer", "password": "customer123"} + json={"username": "customer", "password": "customer123"}, ) - + if response.status_code == 200: data = response.json() cookies = response.cookies - + if "access_token" in data: print_success("Customer login successful") print_success(f"Received access token: {data['access_token'][:20]}...") else: print_error("No access token in response") return None - + if "customer_token" in cookies: print_success("customer_token cookie set") print_info("Cookie path should be /shop (verify in browser)") else: print_error("customer_token cookie NOT set") - - return { - "token": data["access_token"], - "user": data.get("user", {}) - } + + return {"token": data["access_token"], "user": data.get("user", {})} else: print_error(f"Login failed: {response.status_code}") print_error(f"Response: {response.text}") return None - + except Exception as e: print_error(f"Exception during customer login: {str(e)}") return None @@ -310,13 +317,13 @@ def test_customer_login() -> Optional[Dict]: def test_customer_cannot_access_admin_api(customer_token: str): """Test that customer token cannot access admin API""" print_test("Customer Token on Admin API (Should Block)") - + try: response = requests.get( f"{BASE_URL}/api/v1/admin/vendors", - headers={"Authorization": f"Bearer {customer_token}"} + headers={"Authorization": f"Bearer {customer_token}"}, ) - + if response.status_code in [401, 403]: data = response.json() print_success("Customer correctly blocked from admin API") @@ -328,7 +335,7 @@ def test_customer_cannot_access_admin_api(customer_token: str): else: print_warning(f"Unexpected status code: {response.status_code}") return False - + except Exception as e: print_error(f"Exception: {str(e)}") return False @@ -337,13 +344,13 @@ def test_customer_cannot_access_admin_api(customer_token: str): def test_customer_cannot_access_vendor_api(customer_token: str): """Test that customer token cannot access vendor API""" print_test("Customer Token on Vendor API (Should Block)") - + try: response = requests.get( f"{BASE_URL}/api/v1/vendor/TESTVENDOR/products", - headers={"Authorization": f"Bearer {customer_token}"} + headers={"Authorization": f"Bearer {customer_token}"}, ) - + if response.status_code in [401, 403]: data = response.json() print_success("Customer correctly blocked from vendor API") @@ -355,7 +362,7 @@ def test_customer_cannot_access_vendor_api(customer_token: str): else: print_warning(f"Unexpected status code: {response.status_code}") return False - + except Exception as e: print_error(f"Exception: {str(e)}") return False @@ -364,17 +371,17 @@ def test_customer_cannot_access_vendor_api(customer_token: str): def test_public_shop_access(): """Test that public shop pages are accessible without authentication""" print_test("Public Shop Access (No Auth Required)") - + try: response = requests.get(f"{BASE_URL}/shop/products") - + if response.status_code == 200: print_success("Public shop pages accessible without auth") return True else: print_error(f"Failed to access public shop: {response.status_code}") return False - + except Exception as e: print_error(f"Exception: {str(e)}") return False @@ -383,10 +390,10 @@ def test_public_shop_access(): def test_health_check(): """Test health check endpoint""" print_test("Health Check") - + try: response = requests.get(f"{BASE_URL}/health") - + if response.status_code == 200: data = response.json() print_success("Health check passed") @@ -395,7 +402,7 @@ def test_health_check(): else: print_error(f"Health check failed: {response.status_code}") return False - + except Exception as e: print_error(f"Exception: {str(e)}") return False @@ -405,19 +412,16 @@ def test_health_check(): # MAIN TEST RUNNER # ============================================================================ + def main(): """Run all tests""" print(f"\n{Color.BOLD}{Color.CYAN}{'═' * 60}") print(f" 🔒 COMPLETE AUTHENTICATION SYSTEM TEST SUITE") print(f"{'═' * 60}{Color.END}") print(f"Testing server at: {BASE_URL}") - - results = { - "passed": 0, - "failed": 0, - "total": 0 - } - + + results = {"passed": 0, "failed": 0, "total": 0} + # Health check first print_section("🏥 System Health") results["total"] += 1 @@ -427,12 +431,12 @@ def main(): results["failed"] += 1 print_error("Server not responding. Is it running?") return - + # ======================================================================== # ADMIN TESTS # ======================================================================== print_section("👤 Admin Authentication Tests") - + # Admin login results["total"] += 1 admin_auth = test_admin_login() @@ -441,7 +445,7 @@ def main(): else: results["failed"] += 1 admin_auth = None - + # Admin cross-context tests if admin_auth: results["total"] += 1 @@ -449,18 +453,18 @@ def main(): results["passed"] += 1 else: results["failed"] += 1 - + results["total"] += 1 if test_admin_cannot_access_customer_api(admin_auth["token"]): results["passed"] += 1 else: results["failed"] += 1 - + # ======================================================================== # VENDOR TESTS # ======================================================================== print_section("🏪 Vendor Authentication Tests") - + # Vendor login results["total"] += 1 vendor_auth = test_vendor_login() @@ -469,7 +473,7 @@ def main(): else: results["failed"] += 1 vendor_auth = None - + # Vendor cross-context tests if vendor_auth: results["total"] += 1 @@ -477,25 +481,25 @@ def main(): results["passed"] += 1 else: results["failed"] += 1 - + results["total"] += 1 if test_vendor_cannot_access_customer_api(vendor_auth["token"]): results["passed"] += 1 else: results["failed"] += 1 - + # ======================================================================== # CUSTOMER TESTS # ======================================================================== print_section("🛒 Customer Authentication Tests") - + # Public shop access results["total"] += 1 if test_public_shop_access(): results["passed"] += 1 else: results["failed"] += 1 - + # Customer login results["total"] += 1 customer_auth = test_customer_login() @@ -504,7 +508,7 @@ def main(): else: results["failed"] += 1 customer_auth = None - + # Customer cross-context tests if customer_auth: results["total"] += 1 @@ -512,29 +516,31 @@ def main(): results["passed"] += 1 else: results["failed"] += 1 - + results["total"] += 1 if test_customer_cannot_access_vendor_api(customer_auth["token"]): results["passed"] += 1 else: results["failed"] += 1 - + # ======================================================================== # RESULTS # ======================================================================== print_section("📊 Test Results") print(f"\n{Color.BOLD}Total Tests: {results['total']}{Color.END}") print_success(f"Passed: {results['passed']}") - if results['failed'] > 0: + if results["failed"] > 0: print_error(f"Failed: {results['failed']}") - - if results['failed'] == 0: + + if results["failed"] == 0: print(f"\n{Color.GREEN}{Color.BOLD}🎉 ALL TESTS PASSED!{Color.END}") - print(f"{Color.GREEN}Your authentication system is properly isolated!{Color.END}") + print( + f"{Color.GREEN}Your authentication system is properly isolated!{Color.END}" + ) else: print(f"\n{Color.RED}{Color.BOLD}⚠️ SOME TESTS FAILED{Color.END}") print(f"{Color.RED}Please review the output above.{Color.END}") - + # Browser tests reminder print_section("🌐 Manual Browser Tests") print("Please also verify in browser:") diff --git a/scripts/test_vendor_management.py b/scripts/test_vendor_management.py index 751f59c2..71292c8b 100644 --- a/scripts/test_vendor_management.py +++ b/scripts/test_vendor_management.py @@ -9,10 +9,11 @@ Tests: - Transfer ownership """ -import requests import json from pprint import pprint +import requests + BASE_URL = "http://localhost:8000" ADMIN_TOKEN = None # Will be set after login @@ -25,8 +26,8 @@ def login_admin(): f"{BASE_URL}/api/v1/admin/auth/login", # ✅ Changed: added /admin/ json={ "username": "admin", # Replace with your admin username - "password": "admin123" # Replace with your admin password - } + "password": "admin123", # Replace with your admin password + }, ) if response.status_code == 200: @@ -44,7 +45,7 @@ def get_headers(): """Get authorization headers.""" return { "Authorization": f"Bearer {ADMIN_TOKEN}", - "Content-Type": "application/json" + "Content-Type": "application/json", } @@ -60,13 +61,11 @@ def test_create_vendor_with_both_emails(): "subdomain": "testdual", "owner_email": "owner@testdual.com", "contact_email": "contact@testdual.com", - "description": "Test vendor with separate emails" + "description": "Test vendor with separate emails", } response = requests.post( - f"{BASE_URL}/api/v1/admin/vendors", - headers=get_headers(), - json=vendor_data + f"{BASE_URL}/api/v1/admin/vendors", headers=get_headers(), json=vendor_data ) if response.status_code == 200: @@ -79,7 +78,7 @@ def test_create_vendor_with_both_emails(): print(f" Username: {data['owner_username']}") print(f" Password: {data['temporary_password']}") print(f"\n🔗 Login URL: {data['login_url']}") - return data['id'] + return data["id"] else: print(f"❌ Failed: {response.status_code}") print(response.text) @@ -97,13 +96,11 @@ def test_create_vendor_single_email(): "name": "Test Single Email Vendor", "subdomain": "testsingle", "owner_email": "owner@testsingle.com", - "description": "Test vendor with single email" + "description": "Test vendor with single email", } response = requests.post( - f"{BASE_URL}/api/v1/admin/vendors", - headers=get_headers(), - json=vendor_data + f"{BASE_URL}/api/v1/admin/vendors", headers=get_headers(), json=vendor_data ) if response.status_code == 200: @@ -113,12 +110,12 @@ def test_create_vendor_single_email(): print(f" Owner Email: {data['owner_email']}") print(f" Contact Email: {data['contact_email']}") - if data['owner_email'] == data['contact_email']: + if data["owner_email"] == data["contact_email"]: print("✅ Contact email correctly defaulted to owner email") else: print("❌ Contact email should have defaulted to owner email") - return data['id'] + return data["id"] else: print(f"❌ Failed: {response.status_code}") print(response.text) @@ -133,13 +130,13 @@ def test_update_vendor_contact_email(vendor_id): update_data = { "contact_email": "newcontact@business.com", - "name": "Updated Vendor Name" + "name": "Updated Vendor Name", } response = requests.put( f"{BASE_URL}/api/v1/admin/vendors/{vendor_id}", headers=get_headers(), - json=update_data + json=update_data, ) if response.status_code == 200: @@ -163,8 +160,7 @@ def test_get_vendor_details(vendor_id): print("=" * 60) response = requests.get( - f"{BASE_URL}/api/v1/admin/vendors/{vendor_id}", - headers=get_headers() + f"{BASE_URL}/api/v1/admin/vendors/{vendor_id}", headers=get_headers() ) if response.status_code == 200: @@ -196,8 +192,7 @@ def test_transfer_ownership(vendor_id, new_owner_user_id): # First, get current owner info response = requests.get( - f"{BASE_URL}/api/v1/admin/vendors/{vendor_id}", - headers=get_headers() + f"{BASE_URL}/api/v1/admin/vendors/{vendor_id}", headers=get_headers() ) if response.status_code == 200: @@ -211,13 +206,13 @@ def test_transfer_ownership(vendor_id, new_owner_user_id): transfer_data = { "new_owner_user_id": new_owner_user_id, "confirm_transfer": True, - "transfer_reason": "Testing ownership transfer feature" + "transfer_reason": "Testing ownership transfer feature", } response = requests.post( f"{BASE_URL}/api/v1/admin/vendors/{vendor_id}/transfer-ownership", headers=get_headers(), - json=transfer_data + json=transfer_data, ) if response.status_code == 200: diff --git a/scripts/validate_architecture.py b/scripts/validate_architecture.py index dc186de8..6e654957 100755 --- a/scripts/validate_architecture.py +++ b/scripts/validate_architecture.py @@ -22,15 +22,17 @@ import argparse import ast import re import sys -from pathlib import Path -from typing import List, Dict, Any, Tuple from dataclasses import dataclass, field from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Tuple + import yaml class Severity(Enum): """Validation severity levels""" + ERROR = "error" WARNING = "warning" INFO = "info" @@ -39,6 +41,7 @@ class Severity(Enum): @dataclass class Violation: """Represents an architectural rule violation""" + rule_id: str rule_name: str severity: Severity @@ -52,6 +55,7 @@ class Violation: @dataclass class ValidationResult: """Results of architecture validation""" + violations: List[Violation] = field(default_factory=list) files_checked: int = 0 rules_applied: int = 0 @@ -82,7 +86,7 @@ class ArchitectureValidator: print(f"❌ Configuration file not found: {self.config_path}") sys.exit(1) - with open(self.config_path, 'r') as f: + with open(self.config_path, "r") as f: config = yaml.safe_load(f) print(f"📋 Loaded architecture rules: {config.get('project', 'unknown')}") @@ -126,7 +130,7 @@ class ArchitectureValidator: continue content = file_path.read_text() - lines = content.split('\n') + lines = content.split("\n") # API-001: Check for Pydantic model usage self._check_pydantic_usage(file_path, content, lines) @@ -147,8 +151,8 @@ class ArchitectureValidator: return # Check for response_model in route decorators - route_pattern = r'@router\.(get|post|put|delete|patch)' - dict_return_pattern = r'return\s+\{.*\}' + route_pattern = r"@router\.(get|post|put|delete|patch)" + dict_return_pattern = r"return\s+\{.*\}" for i, line in enumerate(lines, 1): # Check for dict returns in endpoints @@ -166,47 +170,51 @@ class ArchitectureValidator: if re.search(dict_return_pattern, func_line): self._add_violation( rule_id="API-001", - rule_name=rule['name'], + rule_name=rule["name"], severity=Severity.ERROR, file_path=file_path, line_number=j + 1, message="Endpoint returns raw dict instead of Pydantic model", context=func_line.strip(), - suggestion="Define a Pydantic response model and use response_model parameter" + suggestion="Define a Pydantic response model and use response_model parameter", ) - def _check_no_business_logic_in_endpoints(self, file_path: Path, content: str, lines: List[str]): + def _check_no_business_logic_in_endpoints( + self, file_path: Path, content: str, lines: List[str] + ): """API-002: Ensure no business logic in endpoints""" rule = self._get_rule("API-002") if not rule: return anti_patterns = [ - (r'db\.add\(', "Database operations should be in service layer"), - (r'db\.commit\(\)', "Database commits should be in service layer"), - (r'db\.query\(', "Database queries should be in service layer"), - (r'db\.execute\(', "Database operations should be in service layer"), + (r"db\.add\(", "Database operations should be in service layer"), + (r"db\.commit\(\)", "Database commits should be in service layer"), + (r"db\.query\(", "Database queries should be in service layer"), + (r"db\.execute\(", "Database operations should be in service layer"), ] for i, line in enumerate(lines, 1): # Skip service method calls (allowed) - if '_service.' in line or 'service.' in line: + if "_service." in line or "service." in line: continue for pattern, message in anti_patterns: if re.search(pattern, line): self._add_violation( rule_id="API-002", - rule_name=rule['name'], + rule_name=rule["name"], severity=Severity.ERROR, file_path=file_path, line_number=i, message=message, context=line.strip(), - suggestion="Move database operations to service layer" + suggestion="Move database operations to service layer", ) - def _check_endpoint_exception_handling(self, file_path: Path, content: str, lines: List[str]): + def _check_endpoint_exception_handling( + self, file_path: Path, content: str, lines: List[str] + ): """API-003: Check proper exception handling in endpoints""" rule = self._get_rule("API-003") if not rule: @@ -222,40 +230,41 @@ class ArchitectureValidator: if isinstance(node, ast.FunctionDef): # Check if it's a route handler has_router_decorator = any( - isinstance(d, ast.Call) and - isinstance(d.func, ast.Attribute) and - getattr(d.func.value, 'id', None) == 'router' + isinstance(d, ast.Call) + and isinstance(d.func, ast.Attribute) + and getattr(d.func.value, "id", None) == "router" for d in node.decorator_list ) if has_router_decorator: # Check if function body has try/except has_try_except = any( - isinstance(child, ast.Try) - for child in ast.walk(node) + isinstance(child, ast.Try) for child in ast.walk(node) ) # Check if function calls service methods has_service_call = any( - isinstance(child, ast.Call) and - isinstance(child.func, ast.Attribute) and - 'service' in getattr(child.func.value, 'id', '').lower() + isinstance(child, ast.Call) + and isinstance(child.func, ast.Attribute) + and "service" in getattr(child.func.value, "id", "").lower() for child in ast.walk(node) ) if has_service_call and not has_try_except: self._add_violation( rule_id="API-003", - rule_name=rule['name'], + rule_name=rule["name"], severity=Severity.WARNING, file_path=file_path, line_number=node.lineno, message=f"Endpoint '{node.name}' calls service but lacks exception handling", context=f"def {node.name}(...)", - suggestion="Wrap service calls in try/except and convert to HTTPException" + suggestion="Wrap service calls in try/except and convert to HTTPException", ) - def _check_endpoint_authentication(self, file_path: Path, content: str, lines: List[str]): + def _check_endpoint_authentication( + self, file_path: Path, content: str, lines: List[str] + ): """API-004: Check authentication on endpoints""" rule = self._get_rule("API-004") if not rule: @@ -264,24 +273,28 @@ class ArchitectureValidator: # This is a warning-level check # Look for endpoints without Depends(get_current_*) for i, line in enumerate(lines, 1): - if '@router.' in line and ('post' in line or 'put' in line or 'delete' in line): + if "@router." in line and ( + "post" in line or "put" in line or "delete" in line + ): # Check next 5 lines for auth has_auth = False for j in range(i, min(i + 5, len(lines))): - if 'Depends(get_current_' in lines[j]: + if "Depends(get_current_" in lines[j]: has_auth = True break - if not has_auth and 'include_in_schema=False' not in ' '.join(lines[i:i+5]): + if not has_auth and "include_in_schema=False" not in " ".join( + lines[i : i + 5] + ): self._add_violation( rule_id="API-004", - rule_name=rule['name'], + rule_name=rule["name"], severity=Severity.WARNING, file_path=file_path, line_number=i, message="Endpoint may be missing authentication", context=line.strip(), - suggestion="Add Depends(get_current_user) or similar if endpoint should be protected" + suggestion="Add Depends(get_current_user) or similar if endpoint should be protected", ) def _validate_service_layer(self, target_path: Path): @@ -296,7 +309,7 @@ class ArchitectureValidator: continue content = file_path.read_text() - lines = content.split('\n') + lines = content.split("\n") # SVC-001: No HTTPException in services self._check_no_http_exception_in_services(file_path, content, lines) @@ -307,38 +320,45 @@ class ArchitectureValidator: # SVC-003: DB session as parameter self._check_db_session_parameter(file_path, content, lines) - def _check_no_http_exception_in_services(self, file_path: Path, content: str, lines: List[str]): + def _check_no_http_exception_in_services( + self, file_path: Path, content: str, lines: List[str] + ): """SVC-001: Services must not raise HTTPException""" rule = self._get_rule("SVC-001") if not rule: return for i, line in enumerate(lines, 1): - if 'raise HTTPException' in line: + if "raise HTTPException" in line: self._add_violation( rule_id="SVC-001", - rule_name=rule['name'], + rule_name=rule["name"], severity=Severity.ERROR, file_path=file_path, line_number=i, message="Service raises HTTPException - use domain exceptions instead", context=line.strip(), - suggestion="Create custom exception class (e.g., VendorNotFoundError) and raise that" + suggestion="Create custom exception class (e.g., VendorNotFoundError) and raise that", ) - if 'from fastapi import HTTPException' in line or 'from fastapi.exceptions import HTTPException' in line: + if ( + "from fastapi import HTTPException" in line + or "from fastapi.exceptions import HTTPException" in line + ): self._add_violation( rule_id="SVC-001", - rule_name=rule['name'], + rule_name=rule["name"], severity=Severity.ERROR, file_path=file_path, line_number=i, message="Service imports HTTPException - services should not know about HTTP", context=line.strip(), - suggestion="Remove HTTPException import and use domain exceptions" + suggestion="Remove HTTPException import and use domain exceptions", ) - def _check_service_exceptions(self, file_path: Path, content: str, lines: List[str]): + def _check_service_exceptions( + self, file_path: Path, content: str, lines: List[str] + ): """SVC-002: Check for proper exception handling""" rule = self._get_rule("SVC-002") if not rule: @@ -346,19 +366,21 @@ class ArchitectureValidator: for i, line in enumerate(lines, 1): # Check for generic Exception raises - if re.match(r'\s*raise Exception\(', line): + if re.match(r"\s*raise Exception\(", line): self._add_violation( rule_id="SVC-002", - rule_name=rule['name'], + rule_name=rule["name"], severity=Severity.WARNING, file_path=file_path, line_number=i, message="Service raises generic Exception - use specific domain exception", context=line.strip(), - suggestion="Create custom exception class for this error case" + suggestion="Create custom exception class for this error case", ) - def _check_db_session_parameter(self, file_path: Path, content: str, lines: List[str]): + def _check_db_session_parameter( + self, file_path: Path, content: str, lines: List[str] + ): """SVC-003: Service methods should accept db session as parameter""" rule = self._get_rule("SVC-003") if not rule: @@ -366,16 +388,16 @@ class ArchitectureValidator: # Check for SessionLocal() creation in service files for i, line in enumerate(lines, 1): - if 'SessionLocal()' in line and 'class' not in line: + if "SessionLocal()" in line and "class" not in line: self._add_violation( rule_id="SVC-003", - rule_name=rule['name'], + rule_name=rule["name"], severity=Severity.ERROR, file_path=file_path, line_number=i, message="Service creates database session internally", context=line.strip(), - suggestion="Accept db: Session as method parameter instead" + suggestion="Accept db: Session as method parameter instead", ) def _validate_models(self, target_path: Path): @@ -391,11 +413,11 @@ class ArchitectureValidator: continue content = file_path.read_text() - lines = content.split('\n') + lines = content.split("\n") # Check for mixing SQLAlchemy and Pydantic for i, line in enumerate(lines, 1): - if re.search(r'class.*\(Base.*,.*BaseModel.*\)', line): + if re.search(r"class.*\(Base.*,.*BaseModel.*\)", line): self._add_violation( rule_id="MDL-002", rule_name="Separate SQLAlchemy and Pydantic models", @@ -404,7 +426,7 @@ class ArchitectureValidator: line_number=i, message="Model mixes SQLAlchemy Base and Pydantic BaseModel", context=line.strip(), - suggestion="Keep SQLAlchemy models and Pydantic models separate" + suggestion="Keep SQLAlchemy models and Pydantic models separate", ) def _validate_exceptions(self, target_path: Path): @@ -418,11 +440,11 @@ class ArchitectureValidator: continue content = file_path.read_text() - lines = content.split('\n') + lines = content.split("\n") # EXC-002: Check for bare except for i, line in enumerate(lines, 1): - if re.match(r'\s*except\s*:', line): + if re.match(r"\s*except\s*:", line): self._add_violation( rule_id="EXC-002", rule_name="No bare except clauses", @@ -431,7 +453,7 @@ class ArchitectureValidator: line_number=i, message="Bare except clause catches all exceptions including system exits", context=line.strip(), - suggestion="Specify exception type: except ValueError: or except Exception:" + suggestion="Specify exception type: except ValueError: or except Exception:", ) def _validate_javascript(self, target_path: Path): @@ -443,11 +465,16 @@ class ArchitectureValidator: for file_path in js_files: content = file_path.read_text() - lines = content.split('\n') + lines = content.split("\n") # JS-001: Check for window.apiClient for i, line in enumerate(lines, 1): - if 'window.apiClient' in line and '//' not in line[:line.find('window.apiClient')] if 'window.apiClient' in line else True: + if ( + "window.apiClient" in line + and "//" not in line[: line.find("window.apiClient")] + if "window.apiClient" in line + else True + ): self._add_violation( rule_id="JS-001", rule_name="Use apiClient directly", @@ -456,14 +483,14 @@ class ArchitectureValidator: line_number=i, message="Use apiClient directly instead of window.apiClient", context=line.strip(), - suggestion="Replace window.apiClient with apiClient" + suggestion="Replace window.apiClient with apiClient", ) # JS-002: Check for console usage for i, line in enumerate(lines, 1): - if re.search(r'console\.(log|warn|error)', line): + if re.search(r"console\.(log|warn|error)", line): # Skip if it's a comment or bootstrap message - if '//' in line or '✅' in line or 'eslint-disable' in line: + if "//" in line or "✅" in line or "eslint-disable" in line: continue self._add_violation( @@ -474,7 +501,7 @@ class ArchitectureValidator: line_number=i, message="Use centralized logger instead of console", context=line.strip()[:80], - suggestion="Use window.LogConfig.createLogger('moduleName')" + suggestion="Use window.LogConfig.createLogger('moduleName')", ) def _validate_templates(self, target_path: Path): @@ -486,14 +513,16 @@ class ArchitectureValidator: for file_path in template_files: # Skip base template and partials - if 'base.html' in file_path.name or 'partials' in str(file_path): + if "base.html" in file_path.name or "partials" in str(file_path): continue content = file_path.read_text() - lines = content.split('\n') + lines = content.split("\n") # TPL-001: Check for extends - has_extends = any('{% extends' in line and 'admin/base.html' in line for line in lines) + has_extends = any( + "{% extends" in line and "admin/base.html" in line for line in lines + ) if not has_extends: self._add_violation( @@ -504,23 +533,29 @@ class ArchitectureValidator: line_number=1, message="Admin template does not extend admin/base.html", context=file_path.name, - suggestion="Add {% extends 'admin/base.html' %} at the top" + suggestion="Add {% extends 'admin/base.html' %} at the top", ) def _get_rule(self, rule_id: str) -> Dict[str, Any]: """Get rule configuration by ID""" # Look in different rule categories - for category in ['api_endpoint_rules', 'service_layer_rules', 'model_rules', - 'exception_rules', 'javascript_rules', 'template_rules']: + for category in [ + "api_endpoint_rules", + "service_layer_rules", + "model_rules", + "exception_rules", + "javascript_rules", + "template_rules", + ]: rules = self.config.get(category, []) for rule in rules: - if rule.get('id') == rule_id: + if rule.get("id") == rule_id: return rule return None def _should_ignore_file(self, file_path: Path) -> bool: """Check if file should be ignored""" - ignore_patterns = self.config.get('ignore', {}).get('files', []) + ignore_patterns = self.config.get("ignore", {}).get("files", []) for pattern in ignore_patterns: if file_path.match(pattern): @@ -528,9 +563,17 @@ class ArchitectureValidator: return False - def _add_violation(self, rule_id: str, rule_name: str, severity: Severity, - file_path: Path, line_number: int, message: str, - context: str = "", suggestion: str = ""): + def _add_violation( + self, + rule_id: str, + rule_name: str, + severity: Severity, + file_path: Path, + line_number: int, + message: str, + context: str = "", + suggestion: str = "", + ): """Add a violation to results""" violation = Violation( rule_id=rule_id, @@ -540,7 +583,7 @@ class ArchitectureValidator: line_number=line_number, message=message, context=context, - suggestion=suggestion + suggestion=suggestion, ) self.result.violations.append(violation) @@ -590,24 +633,34 @@ class ArchitectureValidator: violations_json = [] for v in self.result.violations: - rel_path = str(v.file_path.relative_to(self.project_root)) if self.project_root in v.file_path.parents else str(v.file_path) - violations_json.append({ - 'rule_id': v.rule_id, - 'rule_name': v.rule_name, - 'severity': v.severity.value, - 'file_path': rel_path, - 'line_number': v.line_number, - 'message': v.message, - 'context': v.context or '', - 'suggestion': v.suggestion or '' - }) + rel_path = ( + str(v.file_path.relative_to(self.project_root)) + if self.project_root in v.file_path.parents + else str(v.file_path) + ) + violations_json.append( + { + "rule_id": v.rule_id, + "rule_name": v.rule_name, + "severity": v.severity.value, + "file_path": rel_path, + "line_number": v.line_number, + "message": v.message, + "context": v.context or "", + "suggestion": v.suggestion or "", + } + ) output = { - 'files_checked': self.result.files_checked, - 'total_violations': len(self.result.violations), - 'errors': len([v for v in self.result.violations if v.severity == Severity.ERROR]), - 'warnings': len([v for v in self.result.violations if v.severity == Severity.WARNING]), - 'violations': violations_json + "files_checked": self.result.files_checked, + "total_violations": len(self.result.violations), + "errors": len( + [v for v in self.result.violations if v.severity == Severity.ERROR] + ), + "warnings": len( + [v for v in self.result.violations if v.severity == Severity.WARNING] + ), + "violations": violations_json, } print(json.dumps(output, indent=2)) @@ -616,7 +669,11 @@ class ArchitectureValidator: def _print_violation(self, v: Violation): """Print a single violation""" - rel_path = v.file_path.relative_to(self.project_root) if self.project_root in v.file_path.parents else v.file_path + rel_path = ( + v.file_path.relative_to(self.project_root) + if self.project_root in v.file_path.parents + else v.file_path + ) print(f"\n [{v.rule_id}] {v.rule_name}") print(f" File: {rel_path}:{v.line_number}") @@ -634,40 +691,40 @@ def main(): parser = argparse.ArgumentParser( description="Validate architecture patterns in codebase", formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=__doc__ + epilog=__doc__, ) parser.add_argument( - 'path', - nargs='?', + "path", + nargs="?", type=Path, default=Path.cwd(), - help="Path to validate (default: current directory)" + help="Path to validate (default: current directory)", ) parser.add_argument( - '-c', '--config', + "-c", + "--config", type=Path, - default=Path.cwd() / '.architecture-rules.yaml', - help="Path to architecture rules config (default: .architecture-rules.yaml)" + default=Path.cwd() / ".architecture-rules.yaml", + help="Path to architecture rules config (default: .architecture-rules.yaml)", ) parser.add_argument( - '-v', '--verbose', - action='store_true', - help="Show detailed output including context" + "-v", + "--verbose", + action="store_true", + help="Show detailed output including context", ) parser.add_argument( - '--errors-only', - action='store_true', - help="Only show errors, suppress warnings" + "--errors-only", action="store_true", help="Only show errors, suppress warnings" ) parser.add_argument( - '--json', - action='store_true', - help="Output results as JSON (for programmatic use)" + "--json", + action="store_true", + help="Output results as JSON (for programmatic use)", ) args = parser.parse_args() @@ -687,5 +744,5 @@ def main(): sys.exit(exit_code) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/scripts/verify_setup.py b/scripts/verify_setup.py index f7b5edf6..432550bd 100644 --- a/scripts/verify_setup.py +++ b/scripts/verify_setup.py @@ -3,10 +3,12 @@ import os import sys +from urllib.parse import urlparse + from sqlalchemy import create_engine, text + from alembic import command from alembic.config import Config -from urllib.parse import urlparse # Add project root to Python path sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) @@ -20,10 +22,10 @@ def get_database_info(): parsed = urlparse(db_url) db_type = parsed.scheme - if db_type == 'sqlite': + if db_type == "sqlite": # Extract path from sqlite:///./path or sqlite:///path - db_path = db_url.replace('sqlite:///', '') - if db_path.startswith('./'): + db_path = db_url.replace("sqlite:///", "") + if db_path.startswith("./"): db_path = db_path[2:] return db_type, db_path else: @@ -40,7 +42,7 @@ def verify_database_setup(): db_type, db_path = get_database_info() print(f"[INFO] Database type: {db_type}") - if db_type == 'sqlite': + if db_type == "sqlite": if not os.path.exists(db_path): print(f"[ERROR] Database file not found: {db_path}") return False @@ -54,10 +56,14 @@ def verify_database_setup(): with engine.connect() as conn: # Get table list (works for both SQLite and PostgreSQL) - if db_type == 'sqlite': - result = conn.execute(text("SELECT name FROM sqlite_master WHERE type='table'")) + if db_type == "sqlite": + result = conn.execute( + text("SELECT name FROM sqlite_master WHERE type='table'") + ) else: - result = conn.execute(text("SELECT tablename FROM pg_tables WHERE schemaname='public'")) + result = conn.execute( + text("SELECT tablename FROM pg_tables WHERE schemaname='public'") + ) tables = [row[0] for row in result.fetchall()] @@ -65,12 +71,17 @@ def verify_database_setup(): # Expected tables from your models expected_tables = [ - 'users', 'products', 'inventory', 'vendors', 'products', - 'marketplace_import_jobs', 'alembic_version' + "users", + "products", + "inventory", + "vendors", + "products", + "marketplace_import_jobs", + "alembic_version", ] for table in sorted(tables): - if table == 'alembic_version': + if table == "alembic_version": print(f" * {table} (migration tracking)") elif table in expected_tables: print(f" * {table}") @@ -78,12 +89,12 @@ def verify_database_setup(): print(f" ? {table} (unexpected)") # Check for missing expected tables - missing_tables = set(expected_tables) - set(tables) - {'alembic_version'} + missing_tables = set(expected_tables) - set(tables) - {"alembic_version"} if missing_tables: print(f"[WARNING] Missing expected tables: {missing_tables}") # Check if alembic_version table exists - if 'alembic_version' in tables: + if "alembic_version" in tables: result = conn.execute(text("SELECT version_num FROM alembic_version")) version = result.fetchone() if version: @@ -126,16 +137,19 @@ def verify_model_structure(): # Test database models try: from models.database.base import Base + print(f"[OK] Database Base imported") - print(f"[INFO] Found {len(Base.metadata.tables)} database tables: {list(Base.metadata.tables.keys())}") + print( + f"[INFO] Found {len(Base.metadata.tables)} database tables: {list(Base.metadata.tables.keys())}" + ) # Import specific models - from models.database.user import User - from models.database.marketplace_product import MarketplaceProduct from models.database.inventory import Inventory - from models.database.vendor import Vendor - from models.database.product import Product from models.database.marketplace_import_job import MarketplaceImportJob + from models.database.marketplace_product import MarketplaceProduct + from models.database.product import Product + from models.database.user import User + from models.database.vendor import Vendor print("[OK] All database models imported successfully") @@ -146,13 +160,23 @@ def verify_model_structure(): # Test API models try: import models.schema + print("[OK] API models package imported") # Test specific API model imports - api_modules = ['base', 'auth', 'product', 'inventory', 'vendor ', 'marketplace', 'admin', 'stats'] + api_modules = [ + "base", + "auth", + "product", + "inventory", + "vendor ", + "marketplace", + "admin", + "stats", + ] for module in api_modules: try: - __import__(f'models.api.{module}') + __import__(f"models.api.{module}") print(f" * models.api.{module}") except ImportError: print(f" ? models.api.{module} (not found, optional)") @@ -176,7 +200,7 @@ def check_project_structure(): "models/database/inventory.py", "app/core/config.py", "alembic/env.py", - "alembic.ini" + "alembic.ini", ] for path in critical_paths: @@ -189,7 +213,7 @@ def check_project_structure(): init_files = [ "models/__init__.py", "models/database/__init__.py", - "models/api/__init__.py" + "models/api/__init__.py", ] print(f"\n[INIT] Checking __init__.py files...") @@ -221,7 +245,9 @@ if __name__ == "__main__": print("Next steps:") print(" 1. Run 'make dev' to start your FastAPI server") print(" 2. Visit http://localhost:8000/docs for interactive API docs") - print(" 3. Use your API endpoints for authentication, products, inventory, etc.") + print( + " 3. Use your API endpoints for authentication, products, inventory, etc." + ) sys.exit(0) else: print() diff --git a/storage/backends.py b/storage/backends.py index 4eb1b1ea..999a775e 100644 --- a/storage/backends.py +++ b/storage/backends.py @@ -1 +1 @@ -# Storage backend implementations +# Storage backend implementations diff --git a/storage/utils.py b/storage/utils.py index 99c27a1d..7d8066b1 100644 --- a/storage/utils.py +++ b/storage/utils.py @@ -1 +1 @@ -# Storage utilities +# Storage utilities diff --git a/tasks/analytics_tasks.py b/tasks/analytics_tasks.py index 6adffe9b..7882ffad 100644 --- a/tasks/analytics_tasks.py +++ b/tasks/analytics_tasks.py @@ -1 +1 @@ -# Analytics and reporting tasks +# Analytics and reporting tasks diff --git a/tasks/backup_tasks.py b/tasks/backup_tasks.py index ce44c3ec..13c5b701 100644 --- a/tasks/backup_tasks.py +++ b/tasks/backup_tasks.py @@ -1 +1 @@ -# Backup and recovery tasks +# Backup and recovery tasks diff --git a/tasks/cleanup_tasks.py b/tasks/cleanup_tasks.py index 822a2e92..da7c9643 100644 --- a/tasks/cleanup_tasks.py +++ b/tasks/cleanup_tasks.py @@ -1 +1 @@ -# Data cleanup and maintenance tasks +# Data cleanup and maintenance tasks diff --git a/tasks/email_tasks.py b/tasks/email_tasks.py index 81bf1fd3..3576fcb6 100644 --- a/tasks/email_tasks.py +++ b/tasks/email_tasks.py @@ -1 +1 @@ -# Email sending tasks +# Email sending tasks diff --git a/tasks/marketplace_import.py b/tasks/marketplace_import.py index 6c318097..4af01878 100644 --- a/tasks/marketplace_import.py +++ b/tasks/marketplace_import.py @@ -1,12 +1,12 @@ -# Marketplace CSV import tasks +# Marketplace CSV import tasks # app/tasks/background_tasks.py import logging from datetime import datetime, timezone from app.core.database import SessionLocal +from app.utils.csv_processor import CSVProcessor from models.database.marketplace_import_job import MarketplaceImportJob from models.database.vendor import Vendor -from app.utils.csv_processor import CSVProcessor logger = logging.getLogger(__name__) @@ -16,7 +16,7 @@ async def process_marketplace_import( url: str, marketplace: str, vendor_id: int, # FIXED: Changed from vendor_name to vendor_id - batch_size: int = 1000 + batch_size: int = 1000, ): """Background task to process marketplace CSV import.""" db = SessionLocal() @@ -60,7 +60,7 @@ async def process_marketplace_import( marketplace, vendor_id, # FIXED: Pass vendor_id instead of vendor_name batch_size, - db + db, ) # Update job with results diff --git a/tasks/media_processing.py b/tasks/media_processing.py index 5a57bed2..f2e05cc6 100644 --- a/tasks/media_processing.py +++ b/tasks/media_processing.py @@ -1 +1 @@ -# Image processing and optimization tasks +# Image processing and optimization tasks diff --git a/tasks/search_indexing.py b/tasks/search_indexing.py index 64f776af..c73a2b86 100644 --- a/tasks/search_indexing.py +++ b/tasks/search_indexing.py @@ -1 +1 @@ -# Search index maintenance tasks +# Search index maintenance tasks diff --git a/tasks/task_manager.py b/tasks/task_manager.py index d7fe33a0..655b273c 100644 --- a/tasks/task_manager.py +++ b/tasks/task_manager.py @@ -1 +1 @@ -# Celery configuration and task management +# Celery configuration and task management diff --git a/tests/conftest.py b/tests/conftest.py index 0a800a28..6f922731 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,13 +7,13 @@ from sqlalchemy.pool import StaticPool from app.core.database import Base, get_db from main import app +from models.database.inventory import Inventory # Import all models to ensure they're registered with Base metadata from models.database.marketplace_import_job import MarketplaceImportJob from models.database.marketplace_product import MarketplaceProduct -from models.database.vendor import Vendor from models.database.product import Product -from models.database.inventory import Inventory from models.database.user import User +from models.database.vendor import Vendor # Use in-memory SQLite database for tests SQLALCHEMY_TEST_DATABASE_URL = "sqlite:///:memory:" diff --git a/tests/fixtures/auth_fixtures.py b/tests/fixtures/auth_fixtures.py index fe712109..c2618c66 100644 --- a/tests/fixtures/auth_fixtures.py +++ b/tests/fixtures/auth_fixtures.py @@ -53,6 +53,7 @@ def test_admin(db, auth_manager): db.expunge(admin) return admin + @pytest.fixture def another_admin(db, auth_manager): """Create another test admin user for testing admin-to-admin interactions""" @@ -72,6 +73,7 @@ def another_admin(db, auth_manager): db.expunge(admin) return admin + @pytest.fixture def other_user(db, auth_manager): """Create a different user for testing access controls""" diff --git a/tests/fixtures/customer_fixtures.py b/tests/fixtures/customer_fixtures.py index b9235a88..5666834b 100644 --- a/tests/fixtures/customer_fixtures.py +++ b/tests/fixtures/customer_fixtures.py @@ -1,5 +1,6 @@ # tests/fixtures/customer_fixtures.py import pytest + from models.database.customer import Customer, CustomerAddress diff --git a/tests/fixtures/marketplace_import_job_fixtures.py b/tests/fixtures/marketplace_import_job_fixtures.py index d61e2441..c79decd7 100644 --- a/tests/fixtures/marketplace_import_job_fixtures.py +++ b/tests/fixtures/marketplace_import_job_fixtures.py @@ -1,5 +1,6 @@ # tests/fixtures/marketplace_import_job_fixtures.py import pytest + from models.database.marketplace_import_job import MarketplaceImportJob diff --git a/tests/fixtures/marketplace_product_fixtures.py b/tests/fixtures/marketplace_product_fixtures.py index cb2b148d..a27f9b87 100644 --- a/tests/fixtures/marketplace_product_fixtures.py +++ b/tests/fixtures/marketplace_product_fixtures.py @@ -117,7 +117,9 @@ def marketplace_product_factory(): @pytest.fixture -def test_marketplace_product_with_inventory(db, test_marketplace_product, test_inventory): +def test_marketplace_product_with_inventory( + db, test_marketplace_product, test_inventory +): """MarketplaceProduct with associated inventory record.""" # Ensure they're linked by GTIN if test_marketplace_product.gtin != test_inventory.gtin: @@ -126,6 +128,6 @@ def test_marketplace_product_with_inventory(db, test_marketplace_product, test_i db.refresh(test_inventory) return { - 'marketplace_product': test_marketplace_product, - 'inventory': test_inventory + "marketplace_product": test_marketplace_product, + "inventory": test_inventory, } diff --git a/tests/fixtures/testing_fixtures.py b/tests/fixtures/testing_fixtures.py index ffbcb335..5e3f0cc0 100644 --- a/tests/fixtures/testing_fixtures.py +++ b/tests/fixtures/testing_fixtures.py @@ -8,8 +8,9 @@ This module provides fixtures for: - Additional testing utilities """ -import pytest from unittest.mock import Mock + +import pytest from sqlalchemy.exc import SQLAlchemyError @@ -26,7 +27,7 @@ def empty_db(db): "inventory", # Fixed: singular not plural "products", # Referenced by products "vendors", # Has foreign key to users - "users" # Base table + "users", # Base table ] for table in tables_to_clear: diff --git a/tests/fixtures/vendor_fixtures.py b/tests/fixtures/vendor_fixtures.py index 04cb902c..4b2443e4 100644 --- a/tests/fixtures/vendor_fixtures.py +++ b/tests/fixtures/vendor_fixtures.py @@ -1,10 +1,11 @@ # tests/fixtures/vendor_fixtures.py import uuid + import pytest -from models.database.vendor import Vendor -from models.database.product import Product from models.database.inventory import Inventory +from models.database.product import Product +from models.database.vendor import Vendor @pytest.fixture @@ -185,6 +186,7 @@ def multiple_inventory_entries(db, multiple_products, test_vendor): def create_unique_vendor_factory(): """Factory function to create unique vendors in tests""" + def _create_vendor(db, owner_user_id, **kwargs): unique_id = str(uuid.uuid4())[:8] defaults = { diff --git a/tests/integration/api/v1/test_admin_endpoints.py b/tests/integration/api/v1/test_admin_endpoints.py index 61b940a1..0ede38e9 100644 --- a/tests/integration/api/v1/test_admin_endpoints.py +++ b/tests/integration/api/v1/test_admin_endpoints.py @@ -59,7 +59,9 @@ class TestAdminAPI: assert response.status_code == 400 # Business logic error data = response.json() assert data["error_code"] == "CANNOT_MODIFY_SELF" - assert "Cannot perform 'deactivate account' on your own account" in data["message"] + assert ( + "Cannot perform 'deactivate account' on your own account" in data["message"] + ) def test_toggle_user_status_cannot_modify_admin( self, client, admin_headers, test_admin, another_admin @@ -85,7 +87,9 @@ class TestAdminAPI: # Check that test_vendor is in the response vendor_codes = [ - vendor ["vendor_code"] for vendor in data["vendors"] if "vendor_code" in vendor + vendor["vendor_code"] + for vendor in data["vendors"] + if "vendor_code" in vendor ] assert test_vendor.vendor_code in vendor_codes @@ -98,7 +102,7 @@ class TestAdminAPI: assert data["error_code"] == "ADMIN_REQUIRED" def test_verify_vendor_admin(self, client, admin_headers, test_vendor): - """Test admin verifying/unverifying vendor """ + """Test admin verifying/unverifying vendor""" response = client.put( f"/api/v1/admin/vendors/{test_vendor.id}/verify", headers=admin_headers ) @@ -109,8 +113,10 @@ class TestAdminAPI: assert test_vendor.vendor_code in message def test_verify_vendor_not_found(self, client, admin_headers): - """Test admin verifying non-existent vendor """ - response = client.put("/api/v1/admin/vendors/99999/verify", headers=admin_headers) + """Test admin verifying non-existent vendor""" + response = client.put( + "/api/v1/admin/vendors/99999/verify", headers=admin_headers + ) assert response.status_code == 404 data = response.json() @@ -129,8 +135,10 @@ class TestAdminAPI: assert test_vendor.vendor_code in message def test_toggle_vendor_status_not_found(self, client, admin_headers): - """Test admin toggling status for non-existent vendor """ - response = client.put("/api/v1/admin/vendors/99999/status", headers=admin_headers) + """Test admin toggling status for non-existent vendor""" + response = client.put( + "/api/v1/admin/vendors/99999/status", headers=admin_headers + ) assert response.status_code == 404 data = response.json() @@ -166,7 +174,8 @@ class TestAdminAPI: data = response.json() assert len(data) >= 1 assert all( - job["marketplace"] == test_marketplace_import_job.marketplace for job in data + job["marketplace"] == test_marketplace_import_job.marketplace + for job in data ) def test_get_marketplace_import_jobs_non_admin(self, client, auth_headers): diff --git a/tests/integration/api/v1/test_auth_endpoints.py b/tests/integration/api/v1/test_auth_endpoints.py index dd41d40d..83e0ca07 100644 --- a/tests/integration/api/v1/test_auth_endpoints.py +++ b/tests/integration/api/v1/test_auth_endpoints.py @@ -1,7 +1,8 @@ # tests/integration/api/v1/test_auth_endpoints.py +from datetime import datetime, timedelta, timezone + import pytest from jose import jwt -from datetime import datetime, timedelta, timezone @pytest.mark.integration @@ -178,8 +179,7 @@ class TestAuthenticationAPI: def test_get_current_user_invalid_token(self, client): """Test getting current user with invalid token""" response = client.get( - "/api/v1/auth/me", - headers={"Authorization": "Bearer invalid_token_here"} + "/api/v1/auth/me", headers={"Authorization": "Bearer invalid_token_here"} ) assert response.status_code == 401 @@ -201,14 +201,11 @@ class TestAuthenticationAPI: } expired_token = jwt.encode( - expired_payload, - auth_manager.secret_key, - algorithm=auth_manager.algorithm + expired_payload, auth_manager.secret_key, algorithm=auth_manager.algorithm ) response = client.get( - "/api/v1/auth/me", - headers={"Authorization": f"Bearer {expired_token}"} + "/api/v1/auth/me", headers={"Authorization": f"Bearer {expired_token}"} ) assert response.status_code == 401 @@ -321,4 +318,3 @@ class TestAuthManager: user = auth_manager.authenticate_user(db, "nonexistent", "password") assert user is None - diff --git a/tests/integration/api/v1/test_filtering.py b/tests/integration/api/v1/test_filtering.py index b28d3862..8e89fe2e 100644 --- a/tests/integration/api/v1/test_filtering.py +++ b/tests/integration/api/v1/test_filtering.py @@ -14,19 +14,34 @@ class TestFiltering: """Test filtering products by brand successfully""" # Create products with different brands using unique IDs import uuid + unique_suffix = str(uuid.uuid4())[:8] products = [ - MarketplaceProduct(marketplace_product_id=f"BRAND1_{unique_suffix}", title="MarketplaceProduct 1", brand="BrandA"), - MarketplaceProduct(marketplace_product_id=f"BRAND2_{unique_suffix}", title="MarketplaceProduct 2", brand="BrandB"), - MarketplaceProduct(marketplace_product_id=f"BRAND3_{unique_suffix}", title="MarketplaceProduct 3", brand="BrandA"), + MarketplaceProduct( + marketplace_product_id=f"BRAND1_{unique_suffix}", + title="MarketplaceProduct 1", + brand="BrandA", + ), + MarketplaceProduct( + marketplace_product_id=f"BRAND2_{unique_suffix}", + title="MarketplaceProduct 2", + brand="BrandB", + ), + MarketplaceProduct( + marketplace_product_id=f"BRAND3_{unique_suffix}", + title="MarketplaceProduct 3", + brand="BrandA", + ), ] db.add_all(products) db.commit() # Filter by BrandA - response = client.get("/api/v1/marketplace/product?brand=BrandA", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product?brand=BrandA", headers=auth_headers + ) assert response.status_code == 200 data = response.json() assert data["total"] >= 2 # At least our test products @@ -37,7 +52,9 @@ class TestFiltering: assert product["brand"] == "BrandA" # Filter by BrandB - response = client.get("/api/v1/marketplace/product?brand=BrandB", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product?brand=BrandB", headers=auth_headers + ) assert response.status_code == 200 data = response.json() assert data["total"] >= 1 # At least our test product @@ -45,37 +62,57 @@ class TestFiltering: def test_product_marketplace_filter_success(self, client, auth_headers, db): """Test filtering products by marketplace successfully""" import uuid + unique_suffix = str(uuid.uuid4())[:8] products = [ - MarketplaceProduct(marketplace_product_id=f"MKT1_{unique_suffix}", title="MarketplaceProduct 1", marketplace="Amazon"), - MarketplaceProduct(marketplace_product_id=f"MKT2_{unique_suffix}", title="MarketplaceProduct 2", marketplace="eBay"), - MarketplaceProduct(marketplace_product_id=f"MKT3_{unique_suffix}", title="MarketplaceProduct 3", marketplace="Amazon"), + MarketplaceProduct( + marketplace_product_id=f"MKT1_{unique_suffix}", + title="MarketplaceProduct 1", + marketplace="Amazon", + ), + MarketplaceProduct( + marketplace_product_id=f"MKT2_{unique_suffix}", + title="MarketplaceProduct 2", + marketplace="eBay", + ), + MarketplaceProduct( + marketplace_product_id=f"MKT3_{unique_suffix}", + title="MarketplaceProduct 3", + marketplace="Amazon", + ), ] db.add_all(products) db.commit() - response = client.get("/api/v1/marketplace/product?marketplace=Amazon", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product?marketplace=Amazon", headers=auth_headers + ) assert response.status_code == 200 data = response.json() assert data["total"] >= 2 # At least our test products # Verify all returned products have Amazon marketplace - amazon_products = [p for p in data["products"] if p["marketplace_product_id"].endswith(unique_suffix)] + amazon_products = [ + p + for p in data["products"] + if p["marketplace_product_id"].endswith(unique_suffix) + ] for product in amazon_products: assert product["marketplace"] == "Amazon" def test_product_search_filter_success(self, client, auth_headers, db): """Test searching products by text successfully""" import uuid + unique_suffix = str(uuid.uuid4())[:8] products = [ MarketplaceProduct( marketplace_product_id=f"SEARCH1_{unique_suffix}", title=f"Apple iPhone {unique_suffix}", - description="Smartphone" + description="Smartphone", ), MarketplaceProduct( marketplace_product_id=f"SEARCH2_{unique_suffix}", @@ -85,7 +122,7 @@ class TestFiltering: MarketplaceProduct( marketplace_product_id=f"SEARCH3_{unique_suffix}", title=f"iPad Tablet {unique_suffix}", - description="Apple tablet" + description="Apple tablet", ), ] @@ -93,13 +130,17 @@ class TestFiltering: db.commit() # Search for "Apple" - response = client.get(f"/api/v1/marketplace/product?search=Apple", headers=auth_headers) + response = client.get( + f"/api/v1/marketplace/product?search=Apple", headers=auth_headers + ) assert response.status_code == 200 data = response.json() assert data["total"] >= 2 # iPhone and iPad # Search for "phone" - response = client.get(f"/api/v1/marketplace/product?search=phone", headers=auth_headers) + response = client.get( + f"/api/v1/marketplace/product?search=phone", headers=auth_headers + ) assert response.status_code == 200 data = response.json() assert data["total"] >= 2 # iPhone and Galaxy @@ -107,6 +148,7 @@ class TestFiltering: def test_combined_filters_success(self, client, auth_headers, db): """Test combining multiple filters successfully""" import uuid + unique_suffix = str(uuid.uuid4())[:8] products = [ @@ -135,14 +177,19 @@ class TestFiltering: # Filter by brand AND marketplace response = client.get( - "/api/v1/marketplace/product?brand=Apple&marketplace=Amazon", headers=auth_headers + "/api/v1/marketplace/product?brand=Apple&marketplace=Amazon", + headers=auth_headers, ) assert response.status_code == 200 data = response.json() assert data["total"] >= 1 # At least iPhone matches both # Find our specific test product - matching_products = [p for p in data["products"] if p["marketplace_product_id"].endswith(unique_suffix)] + matching_products = [ + p + for p in data["products"] + if p["marketplace_product_id"].endswith(unique_suffix) + ] for product in matching_products: assert product["brand"] == "Apple" assert product["marketplace"] == "Amazon" @@ -150,7 +197,8 @@ class TestFiltering: def test_filter_with_no_results(self, client, auth_headers): """Test filtering with criteria that returns no results""" response = client.get( - "/api/v1/marketplace/product?brand=NonexistentBrand123456", headers=auth_headers + "/api/v1/marketplace/product?brand=NonexistentBrand123456", + headers=auth_headers, ) assert response.status_code == 200 @@ -161,6 +209,7 @@ class TestFiltering: def test_filter_case_insensitive(self, client, auth_headers, db): """Test that filters are case-insensitive""" import uuid + unique_suffix = str(uuid.uuid4())[:8] product = MarketplaceProduct( @@ -174,7 +223,10 @@ class TestFiltering: # Test different case variations for brand_filter in ["TestBrand", "testbrand", "TESTBRAND"]: - response = client.get(f"/api/v1/marketplace/product?brand={brand_filter}", headers=auth_headers) + response = client.get( + f"/api/v1/marketplace/product?brand={brand_filter}", + headers=auth_headers, + ) assert response.status_code == 200 data = response.json() assert data["total"] >= 1 @@ -183,9 +235,14 @@ class TestFiltering: """Test behavior with invalid filter parameters""" # Test with very long filter values long_brand = "A" * 1000 - response = client.get(f"/api/v1/marketplace/product?brand={long_brand}", headers=auth_headers) + response = client.get( + f"/api/v1/marketplace/product?brand={long_brand}", headers=auth_headers + ) assert response.status_code == 200 # Should handle gracefully # Test with special characters - response = client.get("/api/v1/marketplace/product?brand=", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product?brand=", + headers=auth_headers, + ) assert response.status_code == 200 # Should handle gracefully diff --git a/tests/integration/api/v1/test_inventory_endpoints.py b/tests/integration/api/v1/test_inventory_endpoints.py index bf121bba..1e00dc34 100644 --- a/tests/integration/api/v1/test_inventory_endpoints.py +++ b/tests/integration/api/v1/test_inventory_endpoints.py @@ -17,7 +17,9 @@ class TestInventoryAPI: "quantity": 100, } - response = client.post("/api/v1/inventory", headers=auth_headers, json=inventory_data) + response = client.post( + "/api/v1/inventory", headers=auth_headers, json=inventory_data + ) assert response.status_code == 200 data = response.json() @@ -38,7 +40,9 @@ class TestInventoryAPI: "quantity": 75, } - response = client.post("/api/v1/inventory", headers=auth_headers, json=inventory_data) + response = client.post( + "/api/v1/inventory", headers=auth_headers, json=inventory_data + ) assert response.status_code == 200 data = response.json() @@ -52,7 +56,9 @@ class TestInventoryAPI: "quantity": 100, } - response = client.post("/api/v1/inventory", headers=auth_headers, json=inventory_data) + response = client.post( + "/api/v1/inventory", headers=auth_headers, json=inventory_data + ) assert response.status_code == 422 data = response.json() @@ -60,7 +66,9 @@ class TestInventoryAPI: assert data["status_code"] == 422 assert "GTIN is required" in data["message"] - def test_set_inventory_invalid_quantity_validation_error(self, client, auth_headers): + def test_set_inventory_invalid_quantity_validation_error( + self, client, auth_headers + ): """Test setting inventory with invalid quantity returns InvalidQuantityException""" inventory_data = { "gtin": "1234567890123", @@ -68,7 +76,9 @@ class TestInventoryAPI: "quantity": -10, # Negative quantity } - response = client.post("/api/v1/inventory", headers=auth_headers, json=inventory_data) + response = client.post( + "/api/v1/inventory", headers=auth_headers, json=inventory_data + ) assert response.status_code in [400, 422] data = response.json() @@ -138,7 +148,9 @@ class TestInventoryAPI: data = response.json() assert data["quantity"] == 35 # 50 - 15 - def test_remove_inventory_insufficient_returns_business_logic_error(self, client, auth_headers, db): + def test_remove_inventory_insufficient_returns_business_logic_error( + self, client, auth_headers, db + ): """Test removing more inventory than available returns InsufficientInventoryException""" # Create initial inventory inventory = Inventory(gtin="1234567890123", location="WAREHOUSE_A", quantity=10) @@ -184,7 +196,9 @@ class TestInventoryAPI: assert data["error_code"] == "INVENTORY_NOT_FOUND" assert data["status_code"] == 404 - def test_negative_inventory_not_allowed_business_logic_error(self, client, auth_headers, db): + def test_negative_inventory_not_allowed_business_logic_error( + self, client, auth_headers, db + ): """Test operations resulting in negative inventory returns NegativeInventoryException""" # Create initial inventory inventory = Inventory(gtin="1234567890123", location="WAREHOUSE_A", quantity=5) @@ -204,14 +218,21 @@ class TestInventoryAPI: assert response.status_code == 400 data = response.json() # This might be caught as INSUFFICIENT_INVENTORY or NEGATIVE_INVENTORY_NOT_ALLOWED - assert data["error_code"] in ["INSUFFICIENT_INVENTORY", "NEGATIVE_INVENTORY_NOT_ALLOWED"] + assert data["error_code"] in [ + "INSUFFICIENT_INVENTORY", + "NEGATIVE_INVENTORY_NOT_ALLOWED", + ] assert data["status_code"] == 400 def test_get_inventory_by_gtin_success(self, client, auth_headers, db): """Test getting inventory summary for GTIN successfully""" # Create inventory in multiple locations - inventory1 = Inventory(gtin="1234567890123", location="WAREHOUSE_A", quantity=50) - inventory2 = Inventory(gtin="1234567890123", location="WAREHOUSE_B", quantity=25) + inventory1 = Inventory( + gtin="1234567890123", location="WAREHOUSE_A", quantity=50 + ) + inventory2 = Inventory( + gtin="1234567890123", location="WAREHOUSE_B", quantity=25 + ) db.add_all([inventory1, inventory2]) db.commit() @@ -238,12 +259,18 @@ class TestInventoryAPI: def test_get_total_inventory_success(self, client, auth_headers, db): """Test getting total inventory for GTIN successfully""" # Create inventory in multiple locations - inventory1 = Inventory(gtin="1234567890123", location="WAREHOUSE_A", quantity=50) - inventory2 = Inventory(gtin="1234567890123", location="WAREHOUSE_B", quantity=25) + inventory1 = Inventory( + gtin="1234567890123", location="WAREHOUSE_A", quantity=50 + ) + inventory2 = Inventory( + gtin="1234567890123", location="WAREHOUSE_B", quantity=25 + ) db.add_all([inventory1, inventory2]) db.commit() - response = client.get("/api/v1/inventory/1234567890123/total", headers=auth_headers) + response = client.get( + "/api/v1/inventory/1234567890123/total", headers=auth_headers + ) assert response.status_code == 200 data = response.json() @@ -253,7 +280,9 @@ class TestInventoryAPI: def test_get_total_inventory_not_found(self, client, auth_headers): """Test getting total inventory for nonexistent GTIN returns InventoryNotFoundException""" - response = client.get("/api/v1/inventory/9999999999999/total", headers=auth_headers) + response = client.get( + "/api/v1/inventory/9999999999999/total", headers=auth_headers + ) assert response.status_code == 404 data = response.json() @@ -263,8 +292,12 @@ class TestInventoryAPI: def test_get_all_inventory_success(self, client, auth_headers, db): """Test getting all inventory entries successfully""" # Create some inventory entries - inventory1 = Inventory(gtin="1234567890123", location="WAREHOUSE_A", quantity=50) - inventory2 = Inventory(gtin="9876543210987", location="WAREHOUSE_B", quantity=25) + inventory1 = Inventory( + gtin="1234567890123", location="WAREHOUSE_A", quantity=50 + ) + inventory2 = Inventory( + gtin="9876543210987", location="WAREHOUSE_B", quantity=25 + ) db.add_all([inventory1, inventory2]) db.commit() @@ -277,20 +310,28 @@ class TestInventoryAPI: def test_get_all_inventory_with_filters(self, client, auth_headers, db): """Test getting inventory entries with filtering""" # Create inventory entries - inventory1 = Inventory(gtin="1234567890123", location="WAREHOUSE_A", quantity=50) - inventory2 = Inventory(gtin="9876543210987", location="WAREHOUSE_B", quantity=25) + inventory1 = Inventory( + gtin="1234567890123", location="WAREHOUSE_A", quantity=50 + ) + inventory2 = Inventory( + gtin="9876543210987", location="WAREHOUSE_B", quantity=25 + ) db.add_all([inventory1, inventory2]) db.commit() # Filter by location - response = client.get("/api/v1/inventory?location=WAREHOUSE_A", headers=auth_headers) + response = client.get( + "/api/v1/inventory?location=WAREHOUSE_A", headers=auth_headers + ) assert response.status_code == 200 data = response.json() for inventory in data: assert inventory["location"] == "WAREHOUSE_A" # Filter by GTIN - response = client.get("/api/v1/inventory?gtin=1234567890123", headers=auth_headers) + response = client.get( + "/api/v1/inventory?gtin=1234567890123", headers=auth_headers + ) assert response.status_code == 200 data = response.json() for inventory in data: @@ -390,7 +431,9 @@ class TestInventoryAPI: "quantity": 100, } - response = client.post("/api/v1/inventory", headers=auth_headers, json=inventory_data) + response = client.post( + "/api/v1/inventory", headers=auth_headers, json=inventory_data + ) # This depends on whether your service validates locations if response.status_code == 404: diff --git a/tests/integration/api/v1/test_marketplace_import_job_endpoints.py b/tests/integration/api/v1/test_marketplace_import_job_endpoints.py index 9e1988cc..37b617f1 100644 --- a/tests/integration/api/v1/test_marketplace_import_job_endpoints.py +++ b/tests/integration/api/v1/test_marketplace_import_job_endpoints.py @@ -8,9 +8,11 @@ import pytest @pytest.mark.api @pytest.mark.marketplace class TestMarketplaceImportJobAPI: - def test_import_from_marketplace(self, client, auth_headers, test_vendor, test_user): + def test_import_from_marketplace( + self, client, auth_headers, test_vendor, test_user + ): """Test marketplace import endpoint - just test job creation""" - # Ensure user owns the vendor + # Ensure user owns the vendor test_vendor.owner_user_id = test_user.id import_data = { @@ -32,7 +34,7 @@ class TestMarketplaceImportJobAPI: assert data["vendor_id"] == test_vendor.id def test_import_from_marketplace_invalid_vendor(self, client, auth_headers): - """Test marketplace import with invalid vendor """ + """Test marketplace import with invalid vendor""" import_data = { "url": "https://example.com/products.csv", "marketplace": "TestMarket", @@ -48,7 +50,9 @@ class TestMarketplaceImportJobAPI: assert data["error_code"] == "VENDOR_NOT_FOUND" assert "NONEXISTENT" in data["message"] - def test_import_from_marketplace_unauthorized_vendor(self, client, auth_headers, test_vendor, other_user): + def test_import_from_marketplace_unauthorized_vendor( + self, client, auth_headers, test_vendor, other_user + ): """Test marketplace import with unauthorized vendor access""" # Set vendor owner to different user test_vendor.owner_user_id = other_user.id @@ -85,8 +89,10 @@ class TestMarketplaceImportJobAPI: assert data["error_code"] == "VALIDATION_ERROR" assert "Request validation failed" in data["message"] - def test_import_from_marketplace_admin_access(self, client, admin_headers, test_vendor): - """Test that admin can import for any vendor """ + def test_import_from_marketplace_admin_access( + self, client, admin_headers, test_vendor + ): + """Test that admin can import for any vendor""" import_data = { "url": "https://example.com/products.csv", "marketplace": "AdminMarket", @@ -94,7 +100,9 @@ class TestMarketplaceImportJobAPI: } response = client.post( - "/api/v1/marketplace/import-product", headers=admin_headers, json=import_data + "/api/v1/marketplace/import-product", + headers=admin_headers, + json=import_data, ) assert response.status_code == 200 @@ -102,11 +110,13 @@ class TestMarketplaceImportJobAPI: assert data["marketplace"] == "AdminMarket" assert data["vendor_code"] == test_vendor.vendor_code - def test_get_marketplace_import_status(self, client, auth_headers, test_marketplace_import_job): + def test_get_marketplace_import_status( + self, client, auth_headers, test_marketplace_import_job + ): """Test getting marketplace import status""" response = client.get( f"/api/v1/marketplace/import-status/{test_marketplace_import_job.id}", - headers=auth_headers + headers=auth_headers, ) assert response.status_code == 200 @@ -118,8 +128,7 @@ class TestMarketplaceImportJobAPI: def test_get_marketplace_import_status_not_found(self, client, auth_headers): """Test getting status of non-existent import job""" response = client.get( - "/api/v1/marketplace/import-status/99999", - headers=auth_headers + "/api/v1/marketplace/import-status/99999", headers=auth_headers ) assert response.status_code == 404 @@ -127,21 +136,25 @@ class TestMarketplaceImportJobAPI: assert data["error_code"] == "IMPORT_JOB_NOT_FOUND" assert "99999" in data["message"] - def test_get_marketplace_import_status_unauthorized(self, client, auth_headers, test_marketplace_import_job, other_user): + def test_get_marketplace_import_status_unauthorized( + self, client, auth_headers, test_marketplace_import_job, other_user + ): """Test getting status of unauthorized import job""" # Change job owner to other user test_marketplace_import_job.user_id = other_user.id response = client.get( f"/api/v1/marketplace/import-status/{test_marketplace_import_job.id}", - headers=auth_headers + headers=auth_headers, ) assert response.status_code == 403 data = response.json() assert data["error_code"] == "IMPORT_JOB_NOT_OWNED" - def test_get_marketplace_import_jobs(self, client, auth_headers, test_marketplace_import_job): + def test_get_marketplace_import_jobs( + self, client, auth_headers, test_marketplace_import_job + ): """Test getting marketplace import jobs""" response = client.get("/api/v1/marketplace/import-jobs", headers=auth_headers) @@ -154,11 +167,13 @@ class TestMarketplaceImportJobAPI: job_ids = [job["job_id"] for job in data] assert test_marketplace_import_job.id in job_ids - def test_get_marketplace_import_jobs_with_filters(self, client, auth_headers, test_marketplace_import_job): + def test_get_marketplace_import_jobs_with_filters( + self, client, auth_headers, test_marketplace_import_job + ): """Test getting import jobs with filters""" response = client.get( f"/api/v1/marketplace/import-jobs?marketplace={test_marketplace_import_job.marketplace}", - headers=auth_headers + headers=auth_headers, ) assert response.status_code == 200 @@ -167,13 +182,15 @@ class TestMarketplaceImportJobAPI: assert len(data) >= 1 for job in data: - assert test_marketplace_import_job.marketplace.lower() in job["marketplace"].lower() + assert ( + test_marketplace_import_job.marketplace.lower() + in job["marketplace"].lower() + ) def test_get_marketplace_import_jobs_pagination(self, client, auth_headers): """Test import jobs pagination""" response = client.get( - "/api/v1/marketplace/import-jobs?skip=0&limit=5", - headers=auth_headers + "/api/v1/marketplace/import-jobs?skip=0&limit=5", headers=auth_headers ) assert response.status_code == 200 @@ -181,9 +198,13 @@ class TestMarketplaceImportJobAPI: assert isinstance(data, list) assert len(data) <= 5 - def test_get_marketplace_import_stats(self, client, auth_headers, test_marketplace_import_job): + def test_get_marketplace_import_stats( + self, client, auth_headers, test_marketplace_import_job + ): """Test getting marketplace import statistics""" - response = client.get("/api/v1/marketplace/marketplace-import-stats", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/marketplace-import-stats", headers=auth_headers + ) assert response.status_code == 200 data = response.json() @@ -195,12 +216,15 @@ class TestMarketplaceImportJobAPI: assert isinstance(data["total_jobs"], int) assert data["total_jobs"] >= 1 - def test_cancel_marketplace_import_job(self, client, auth_headers, test_user, test_vendor, db): + def test_cancel_marketplace_import_job( + self, client, auth_headers, test_user, test_vendor, db + ): """Test cancelling a marketplace import job""" # Create a pending job that can be cancelled - from models.database.marketplace_import_job import MarketplaceImportJob import uuid + from models.database.marketplace_import_job import MarketplaceImportJob + unique_id = str(uuid.uuid4())[:8] job = MarketplaceImportJob( status="pending", @@ -219,8 +243,7 @@ class TestMarketplaceImportJobAPI: db.refresh(job) response = client.put( - f"/api/v1/marketplace/import-jobs/{job.id}/cancel", - headers=auth_headers + f"/api/v1/marketplace/import-jobs/{job.id}/cancel", headers=auth_headers ) assert response.status_code == 200 @@ -232,15 +255,16 @@ class TestMarketplaceImportJobAPI: def test_cancel_marketplace_import_job_not_found(self, client, auth_headers): """Test cancelling non-existent import job""" response = client.put( - "/api/v1/marketplace/import-jobs/99999/cancel", - headers=auth_headers + "/api/v1/marketplace/import-jobs/99999/cancel", headers=auth_headers ) assert response.status_code == 404 data = response.json() assert data["error_code"] == "IMPORT_JOB_NOT_FOUND" - def test_cancel_marketplace_import_job_cannot_cancel(self, client, auth_headers, test_marketplace_import_job, db): + def test_cancel_marketplace_import_job_cannot_cancel( + self, client, auth_headers, test_marketplace_import_job, db + ): """Test cancelling a job that cannot be cancelled""" # Set job to completed status test_marketplace_import_job.status = "completed" @@ -248,7 +272,7 @@ class TestMarketplaceImportJobAPI: response = client.put( f"/api/v1/marketplace/import-jobs/{test_marketplace_import_job.id}/cancel", - headers=auth_headers + headers=auth_headers, ) assert response.status_code == 400 @@ -256,12 +280,15 @@ class TestMarketplaceImportJobAPI: assert data["error_code"] == "IMPORT_JOB_CANNOT_BE_CANCELLED" assert "completed" in data["message"] - def test_delete_marketplace_import_job(self, client, auth_headers, test_user, test_vendor, db): + def test_delete_marketplace_import_job( + self, client, auth_headers, test_user, test_vendor, db + ): """Test deleting a marketplace import job""" # Create a completed job that can be deleted - from models.database.marketplace_import_job import MarketplaceImportJob import uuid + from models.database.marketplace_import_job import MarketplaceImportJob + unique_id = str(uuid.uuid4())[:8] job = MarketplaceImportJob( status="completed", @@ -280,8 +307,7 @@ class TestMarketplaceImportJobAPI: db.refresh(job) response = client.delete( - f"/api/v1/marketplace/import-jobs/{job.id}", - headers=auth_headers + f"/api/v1/marketplace/import-jobs/{job.id}", headers=auth_headers ) assert response.status_code == 200 @@ -291,20 +317,22 @@ class TestMarketplaceImportJobAPI: def test_delete_marketplace_import_job_not_found(self, client, auth_headers): """Test deleting non-existent import job""" response = client.delete( - "/api/v1/marketplace/import-jobs/99999", - headers=auth_headers + "/api/v1/marketplace/import-jobs/99999", headers=auth_headers ) assert response.status_code == 404 data = response.json() assert data["error_code"] == "IMPORT_JOB_NOT_FOUND" - def test_delete_marketplace_import_job_cannot_delete(self, client, auth_headers, test_user, test_vendor, db): + def test_delete_marketplace_import_job_cannot_delete( + self, client, auth_headers, test_user, test_vendor, db + ): """Test deleting a job that cannot be deleted""" # Create a pending job that cannot be deleted - from models.database.marketplace_import_job import MarketplaceImportJob import uuid + from models.database.marketplace_import_job import MarketplaceImportJob + unique_id = str(uuid.uuid4())[:8] job = MarketplaceImportJob( status="pending", @@ -323,8 +351,7 @@ class TestMarketplaceImportJobAPI: db.refresh(job) response = client.delete( - f"/api/v1/marketplace/import-jobs/{job.id}", - headers=auth_headers + f"/api/v1/marketplace/import-jobs/{job.id}", headers=auth_headers ) assert response.status_code == 400 @@ -352,7 +379,9 @@ class TestMarketplaceImportJobAPI: data = response.json() assert data["error_code"] == "INVALID_TOKEN" - def test_admin_can_access_all_jobs(self, client, admin_headers, test_marketplace_import_job): + def test_admin_can_access_all_jobs( + self, client, admin_headers, test_marketplace_import_job + ): """Test that admin can access all import jobs""" response = client.get("/api/v1/marketplace/import-jobs", headers=admin_headers) @@ -363,23 +392,28 @@ class TestMarketplaceImportJobAPI: job_ids = [job["job_id"] for job in data] assert test_marketplace_import_job.id in job_ids - def test_admin_can_view_any_job_status(self, client, admin_headers, test_marketplace_import_job): + def test_admin_can_view_any_job_status( + self, client, admin_headers, test_marketplace_import_job + ): """Test that admin can view any job status""" response = client.get( f"/api/v1/marketplace/import-status/{test_marketplace_import_job.id}", - headers=admin_headers + headers=admin_headers, ) assert response.status_code == 200 data = response.json() assert data["job_id"] == test_marketplace_import_job.id - def test_admin_can_cancel_any_job(self, client, admin_headers, test_user, test_vendor, db): + def test_admin_can_cancel_any_job( + self, client, admin_headers, test_user, test_vendor, db + ): """Test that admin can cancel any job""" # Create a pending job owned by different user - from models.database.marketplace_import_job import MarketplaceImportJob import uuid + from models.database.marketplace_import_job import MarketplaceImportJob + unique_id = str(uuid.uuid4())[:8] job = MarketplaceImportJob( status="pending", @@ -398,8 +432,7 @@ class TestMarketplaceImportJobAPI: db.refresh(job) response = client.put( - f"/api/v1/marketplace/import-jobs/{job.id}/cancel", - headers=admin_headers + f"/api/v1/marketplace/import-jobs/{job.id}/cancel", headers=admin_headers ) assert response.status_code == 200 diff --git a/tests/integration/api/v1/test_marketplace_product_export.py b/tests/integration/api/v1/test_marketplace_product_export.py index 02438470..272d1a6e 100644 --- a/tests/integration/api/v1/test_marketplace_product_export.py +++ b/tests/integration/api/v1/test_marketplace_product_export.py @@ -1,7 +1,7 @@ # tests/integration/api/v1/test_export.py import csv -from io import StringIO import uuid +from io import StringIO import pytest @@ -13,9 +13,13 @@ from models.database.marketplace_product import MarketplaceProduct @pytest.mark.performance # for the performance test class TestExportFunctionality: - def test_csv_export_basic_success(self, client, auth_headers, test_marketplace_product): + def test_csv_export_basic_success( + self, client, auth_headers, test_marketplace_product + ): """Test basic CSV export functionality successfully""" - response = client.get("/api/v1/marketplace/product/export-csv", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product/export-csv", headers=auth_headers + ) assert response.status_code == 200 assert response.headers["content-type"] == "text/csv; charset=utf-8" @@ -26,13 +30,22 @@ class TestExportFunctionality: # Check header row header = next(csv_reader) - expected_fields = ["marketplace_product_id", "title", "description", "price", "marketplace"] + expected_fields = [ + "marketplace_product_id", + "title", + "description", + "price", + "marketplace", + ] for field in expected_fields: assert field in header # Verify test product appears in export - csv_lines = csv_content.split('\n') - test_product_found = any(test_marketplace_product.marketplace_product_id in line for line in csv_lines) + csv_lines = csv_content.split("\n") + test_product_found = any( + test_marketplace_product.marketplace_product_id in line + for line in csv_lines + ) assert test_product_found, "Test product should appear in CSV export" def test_csv_export_with_marketplace_filter_success(self, client, auth_headers, db): @@ -43,12 +56,12 @@ class TestExportFunctionality: MarketplaceProduct( marketplace_product_id=f"EXP1_{unique_suffix}", title=f"Amazon MarketplaceProduct {unique_suffix}", - marketplace="Amazon" + marketplace="Amazon", ), MarketplaceProduct( marketplace_product_id=f"EXP2_{unique_suffix}", title=f"eBay MarketplaceProduct {unique_suffix}", - marketplace="eBay" + marketplace="eBay", ), ] @@ -56,7 +69,8 @@ class TestExportFunctionality: db.commit() response = client.get( - "/api/v1/marketplace/product/export-csv?marketplace=Amazon", headers=auth_headers + "/api/v1/marketplace/product/export-csv?marketplace=Amazon", + headers=auth_headers, ) assert response.status_code == 200 assert response.headers["content-type"] == "text/csv; charset=utf-8" @@ -72,12 +86,12 @@ class TestExportFunctionality: MarketplaceProduct( marketplace_product_id=f"VENDOR1_{unique_suffix}", title=f"Vendor1 MarketplaceProduct {unique_suffix}", - vendor_name="TestVendor1" + vendor_name="TestVendor1", ), MarketplaceProduct( marketplace_product_id=f"VENDOR2_{unique_suffix}", title=f"Vendor2 MarketplaceProduct {unique_suffix}", - vendor_name="TestVendor2" + vendor_name="TestVendor2", ), ] @@ -101,19 +115,19 @@ class TestExportFunctionality: marketplace_product_id=f"COMBO1_{unique_suffix}", title=f"Combo MarketplaceProduct 1 {unique_suffix}", marketplace="Amazon", - vendor_name="TestVendor" + vendor_name="TestVendor", ), MarketplaceProduct( marketplace_product_id=f"COMBO2_{unique_suffix}", title=f"Combo MarketplaceProduct 2 {unique_suffix}", marketplace="eBay", - vendor_name="TestVendor" + vendor_name="TestVendor", ), MarketplaceProduct( marketplace_product_id=f"COMBO3_{unique_suffix}", title=f"Combo MarketplaceProduct 3 {unique_suffix}", marketplace="Amazon", - vendor_name="OtherVendor" + vendor_name="OtherVendor", ), ] @@ -122,27 +136,27 @@ class TestExportFunctionality: response = client.get( "/api/v1/marketplace/product?marketplace=Amazon&name=TestVendor", - headers=auth_headers + headers=auth_headers, ) assert response.status_code == 200 csv_content = response.content.decode("utf-8") assert f"COMBO1_{unique_suffix}" in csv_content # Matches both filters assert f"COMBO2_{unique_suffix}" not in csv_content # Wrong marketplace - assert f"COMBO3_{unique_suffix}" not in csv_content # Wrong vendor + assert f"COMBO3_{unique_suffix}" not in csv_content # Wrong vendor def test_csv_export_no_results(self, client, auth_headers): """Test CSV export with filters that return no results""" response = client.get( "/api/v1/marketplace/product/export-csv?marketplace=NonexistentMarketplace12345", - headers=auth_headers + headers=auth_headers, ) assert response.status_code == 200 assert response.headers["content-type"] == "text/csv; charset=utf-8" csv_content = response.content.decode("utf-8") - csv_lines = csv_content.strip().split('\n') + csv_lines = csv_content.strip().split("\n") # Should have header row even with no data assert len(csv_lines) >= 1 # First line should be headers @@ -161,27 +175,32 @@ class TestExportFunctionality: title=f"Performance MarketplaceProduct {i}", marketplace="Performance", description=f"Performance test product {i}", - price="10.99" + price="10.99", ) products.append(product) # Add in batches to avoid memory issues for i in range(0, len(products), 50): - batch = products[i:i + 50] + batch = products[i : i + 50] db.add_all(batch) db.commit() import time + start_time = time.time() - response = client.get("/api/v1/marketplace/product/export-csv", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product/export-csv", headers=auth_headers + ) end_time = time.time() execution_time = end_time - start_time assert response.status_code == 200 assert response.headers["content-type"] == "text/csv; charset=utf-8" - assert execution_time < 10.0, f"Export took {execution_time:.2f} seconds, should be under 10s" + assert ( + execution_time < 10.0 + ), f"Export took {execution_time:.2f} seconds, should be under 10s" # Verify content contains our test data csv_content = response.content.decode("utf-8") @@ -197,9 +216,13 @@ class TestExportFunctionality: assert data["error_code"] == "INVALID_TOKEN" assert data["status_code"] == 401 - def test_csv_export_streaming_response_format(self, client, auth_headers, test_marketplace_product): + def test_csv_export_streaming_response_format( + self, client, auth_headers, test_marketplace_product + ): """Test that CSV export returns proper streaming response with correct headers""" - response = client.get("/api/v1/marketplace/product/export-csv", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product/export-csv", headers=auth_headers + ) assert response.status_code == 200 assert response.headers["content-type"] == "text/csv; charset=utf-8" @@ -217,23 +240,27 @@ class TestExportFunctionality: # Create product with special characters that might break CSV product = MarketplaceProduct( marketplace_product_id=f"SPECIAL_{unique_suffix}", - title=f'MarketplaceProduct with quotes and commas {unique_suffix}', # Simplified to avoid CSV escaping issues + title=f"MarketplaceProduct with quotes and commas {unique_suffix}", # Simplified to avoid CSV escaping issues description=f"Description with special chars {unique_suffix}", marketplace="Test Market", - price="19.99" + price="19.99", ) db.add(product) db.commit() - response = client.get("/api/v1/marketplace/product/export-csv", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product/export-csv", headers=auth_headers + ) assert response.status_code == 200 csv_content = response.content.decode("utf-8") # Verify our test product appears in the CSV content assert f"SPECIAL_{unique_suffix}" in csv_content - assert f"MarketplaceProduct with quotes and commas {unique_suffix}" in csv_content + assert ( + f"MarketplaceProduct with quotes and commas {unique_suffix}" in csv_content + ) assert "Test Market" in csv_content assert "19.99" in csv_content @@ -243,7 +270,12 @@ class TestExportFunctionality: header = next(csv_reader) # Verify header contains expected fields - expected_fields = ["marketplace_product_id", "title", "marketplace", "price"] + expected_fields = [ + "marketplace_product_id", + "title", + "marketplace", + "price", + ] for field in expected_fields: assert field in header @@ -264,12 +296,15 @@ class TestExportFunctionality: assert parsed_successfully, "CSV should be parseable despite special characters" - def test_csv_export_error_handling_service_failure(self, client, auth_headers, monkeypatch): + def test_csv_export_error_handling_service_failure( + self, client, auth_headers, monkeypatch + ): """Test CSV export handles service failures gracefully""" # Mock the service to raise an exception def mock_generate_csv_export(*args, **kwargs): from app.exceptions import ValidationException + raise ValidationException("Mocked service failure") # This would require access to your service instance to mock properly @@ -293,7 +328,9 @@ class TestExportFunctionality: def test_csv_export_filename_generation(self, client, auth_headers): """Test CSV export generates appropriate filenames based on filters""" # Test basic export filename - response = client.get("/api/v1/marketplace/product/export-csv", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product/export-csv", headers=auth_headers + ) assert response.status_code == 200 content_disposition = response.headers.get("content-disposition", "") @@ -302,7 +339,7 @@ class TestExportFunctionality: # Test with marketplace filter response = client.get( "/api/v1/marketplace/product/export-csv?marketplace=Amazon", - headers=auth_headers + headers=auth_headers, ) assert response.status_code == 200 diff --git a/tests/integration/api/v1/test_marketplace_products_endpoints.py b/tests/integration/api/v1/test_marketplace_products_endpoints.py index 3e01c0e8..dbda5e8d 100644 --- a/tests/integration/api/v1/test_marketplace_products_endpoints.py +++ b/tests/integration/api/v1/test_marketplace_products_endpoints.py @@ -16,7 +16,9 @@ class TestMarketplaceProductsAPI: assert data["products"] == [] assert data["total"] == 0 - def test_get_products_with_data(self, client, auth_headers, test_marketplace_product): + def test_get_products_with_data( + self, client, auth_headers, test_marketplace_product + ): """Test getting products with data""" response = client.get("/api/v1/marketplace/product", headers=auth_headers) @@ -25,25 +27,39 @@ class TestMarketplaceProductsAPI: assert len(data["products"]) >= 1 assert data["total"] >= 1 # Find our test product - test_product_found = any(p["marketplace_product_id"] == test_marketplace_product.marketplace_product_id for p in data["products"]) + test_product_found = any( + p["marketplace_product_id"] + == test_marketplace_product.marketplace_product_id + for p in data["products"] + ) assert test_product_found - def test_get_products_with_filters(self, client, auth_headers, test_marketplace_product): + def test_get_products_with_filters( + self, client, auth_headers, test_marketplace_product + ): """Test filtering products""" # Test brand filter - response = client.get(f"/api/v1/marketplace/product?brand={test_marketplace_product.brand}", headers=auth_headers) + response = client.get( + f"/api/v1/marketplace/product?brand={test_marketplace_product.brand}", + headers=auth_headers, + ) assert response.status_code == 200 data = response.json() assert data["total"] >= 1 # Test marketplace filter - response = client.get(f"/api/v1/marketplace/product?marketplace={test_marketplace_product.marketplace}", headers=auth_headers) + response = client.get( + f"/api/v1/marketplace/product?marketplace={test_marketplace_product.marketplace}", + headers=auth_headers, + ) assert response.status_code == 200 data = response.json() assert data["total"] >= 1 # Test search - response = client.get("/api/v1/marketplace/product?search=Test", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product?search=Test", headers=auth_headers + ) assert response.status_code == 200 data = response.json() assert data["total"] >= 1 @@ -71,7 +87,9 @@ class TestMarketplaceProductsAPI: assert data["title"] == "New MarketplaceProduct" assert data["marketplace"] == "Amazon" - def test_create_product_duplicate_id_returns_conflict(self, client, auth_headers, test_marketplace_product): + def test_create_product_duplicate_id_returns_conflict( + self, client, auth_headers, test_marketplace_product + ): """Test creating product with duplicate ID returns MarketplaceProductAlreadyExistsException""" product_data = { "marketplace_product_id": test_marketplace_product.marketplace_product_id, @@ -93,7 +111,10 @@ class TestMarketplaceProductsAPI: assert data["error_code"] == "PRODUCT_ALREADY_EXISTS" assert data["status_code"] == 409 assert test_marketplace_product.marketplace_product_id in data["message"] - assert data["details"]["marketplace_product_id"] == test_marketplace_product.marketplace_product_id + assert ( + data["details"]["marketplace_product_id"] + == test_marketplace_product.marketplace_product_id + ) def test_create_product_missing_title_validation_error(self, client, auth_headers): """Test creating product without title returns ValidationException""" @@ -114,7 +135,9 @@ class TestMarketplaceProductsAPI: assert "MarketplaceProduct title is required" in data["message"] assert data["details"]["field"] == "title" - def test_create_product_missing_product_id_validation_error(self, client, auth_headers): + def test_create_product_missing_product_id_validation_error( + self, client, auth_headers + ): """Test creating product without marketplace_product_id returns ValidationException""" product_data = { "marketplace_product_id": "", # Empty product ID @@ -191,20 +214,28 @@ class TestMarketplaceProductsAPI: assert "Request validation failed" in data["message"] assert "validation_errors" in data["details"] - def test_get_product_by_id_success(self, client, auth_headers, test_marketplace_product): + def test_get_product_by_id_success( + self, client, auth_headers, test_marketplace_product + ): """Test getting specific product successfully""" response = client.get( - f"/api/v1/marketplace/product/{test_marketplace_product.marketplace_product_id}", headers=auth_headers + f"/api/v1/marketplace/product/{test_marketplace_product.marketplace_product_id}", + headers=auth_headers, ) assert response.status_code == 200 data = response.json() - assert data["product"]["marketplace_product_id"] == test_marketplace_product.marketplace_product_id + assert ( + data["product"]["marketplace_product_id"] + == test_marketplace_product.marketplace_product_id + ) assert data["product"]["title"] == test_marketplace_product.title def test_get_nonexistent_product_returns_not_found(self, client, auth_headers): """Test getting nonexistent product returns MarketplaceProductNotFoundException""" - response = client.get("/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers + ) assert response.status_code == 404 data = response.json() @@ -214,7 +245,9 @@ class TestMarketplaceProductsAPI: assert data["details"]["resource_type"] == "MarketplaceProduct" assert data["details"]["identifier"] == "NONEXISTENT" - def test_update_product_success(self, client, auth_headers, test_marketplace_product): + def test_update_product_success( + self, client, auth_headers, test_marketplace_product + ): """Test updating product successfully""" update_data = {"title": "Updated MarketplaceProduct Title", "price": "25.99"} @@ -247,7 +280,9 @@ class TestMarketplaceProductsAPI: assert data["details"]["resource_type"] == "MarketplaceProduct" assert data["details"]["identifier"] == "NONEXISTENT" - def test_update_product_empty_title_validation_error(self, client, auth_headers, test_marketplace_product): + def test_update_product_empty_title_validation_error( + self, client, auth_headers, test_marketplace_product + ): """Test updating product with empty title returns MarketplaceProductValidationException""" update_data = {"title": ""} @@ -264,7 +299,9 @@ class TestMarketplaceProductsAPI: assert "MarketplaceProduct title cannot be empty" in data["message"] assert data["details"]["field"] == "title" - def test_update_product_invalid_gtin_data_error(self, client, auth_headers, test_marketplace_product): + def test_update_product_invalid_gtin_data_error( + self, client, auth_headers, test_marketplace_product + ): """Test updating product with invalid GTIN returns InvalidMarketplaceProductDataException""" update_data = {"gtin": "invalid_gtin"} @@ -281,7 +318,9 @@ class TestMarketplaceProductsAPI: assert "Invalid GTIN format" in data["message"] assert data["details"]["field"] == "gtin" - def test_update_product_invalid_price_data_error(self, client, auth_headers, test_marketplace_product): + def test_update_product_invalid_price_data_error( + self, client, auth_headers, test_marketplace_product + ): """Test updating product with invalid price returns InvalidMarketplaceProductDataException""" update_data = {"price": "invalid_price"} @@ -298,10 +337,13 @@ class TestMarketplaceProductsAPI: assert "Invalid price format" in data["message"] assert data["details"]["field"] == "price" - def test_delete_product_success(self, client, auth_headers, test_marketplace_product): + def test_delete_product_success( + self, client, auth_headers, test_marketplace_product + ): """Test deleting product successfully""" response = client.delete( - f"/api/v1/marketplace/product/{test_marketplace_product.marketplace_product_id}", headers=auth_headers + f"/api/v1/marketplace/product/{test_marketplace_product.marketplace_product_id}", + headers=auth_headers, ) assert response.status_code == 200 @@ -309,7 +351,9 @@ class TestMarketplaceProductsAPI: def test_delete_nonexistent_product_returns_not_found(self, client, auth_headers): """Test deleting nonexistent product returns MarketplaceProductNotFoundException""" - response = client.delete("/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers) + response = client.delete( + "/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers + ) assert response.status_code == 404 data = response.json() @@ -331,7 +375,9 @@ class TestMarketplaceProductsAPI: def test_exception_structure_consistency(self, client, auth_headers): """Test that all exceptions follow the consistent WizamartException structure""" # Test with a known error case - response = client.get("/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers + ) assert response.status_code == 404 data = response.json() diff --git a/tests/integration/api/v1/test_pagination.py b/tests/integration/api/v1/test_pagination.py index c2440c18..22ead4b0 100644 --- a/tests/integration/api/v1/test_pagination.py +++ b/tests/integration/api/v1/test_pagination.py @@ -4,6 +4,7 @@ import pytest from models.database.marketplace_product import MarketplaceProduct from models.database.vendor import Vendor + @pytest.mark.integration @pytest.mark.api @pytest.mark.database @@ -13,6 +14,7 @@ class TestPagination: def test_product_pagination_success(self, client, auth_headers, db): """Test pagination for product listing successfully""" import uuid + unique_suffix = str(uuid.uuid4())[:8] # Create multiple products @@ -29,7 +31,9 @@ class TestPagination: db.commit() # Test first page - response = client.get("/api/v1/marketplace/product?limit=10&skip=0", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product?limit=10&skip=0", headers=auth_headers + ) assert response.status_code == 200 data = response.json() assert len(data["products"]) == 10 @@ -38,21 +42,29 @@ class TestPagination: assert data["limit"] == 10 # Test second page - response = client.get("/api/v1/marketplace/product?limit=10&skip=10", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product?limit=10&skip=10", headers=auth_headers + ) assert response.status_code == 200 data = response.json() assert len(data["products"]) == 10 assert data["skip"] == 10 # Test last page (should have remaining products) - response = client.get("/api/v1/marketplace/product?limit=10&skip=20", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product?limit=10&skip=20", headers=auth_headers + ) assert response.status_code == 200 data = response.json() assert len(data["products"]) >= 5 # At least 5 remaining from our test set - def test_pagination_boundary_negative_skip_validation_error(self, client, auth_headers): + def test_pagination_boundary_negative_skip_validation_error( + self, client, auth_headers + ): """Test negative skip parameter returns ValidationException""" - response = client.get("/api/v1/marketplace/product?skip=-1", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product?skip=-1", headers=auth_headers + ) assert response.status_code == 422 data = response.json() @@ -61,9 +73,13 @@ class TestPagination: assert "Request validation failed" in data["message"] assert "validation_errors" in data["details"] - def test_pagination_boundary_zero_limit_validation_error(self, client, auth_headers): + def test_pagination_boundary_zero_limit_validation_error( + self, client, auth_headers + ): """Test zero limit parameter returns ValidationException""" - response = client.get("/api/v1/marketplace/product?limit=0", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product?limit=0", headers=auth_headers + ) assert response.status_code == 422 data = response.json() @@ -71,9 +87,13 @@ class TestPagination: assert data["status_code"] == 422 assert "Request validation failed" in data["message"] - def test_pagination_boundary_excessive_limit_validation_error(self, client, auth_headers): + def test_pagination_boundary_excessive_limit_validation_error( + self, client, auth_headers + ): """Test excessive limit parameter returns ValidationException""" - response = client.get("/api/v1/marketplace/product?limit=10000", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product?limit=10000", headers=auth_headers + ) assert response.status_code == 422 data = response.json() @@ -84,7 +104,9 @@ class TestPagination: def test_pagination_beyond_available_records(self, client, auth_headers, db): """Test pagination beyond available records returns empty results""" # Test skip beyond available records - response = client.get("/api/v1/marketplace/product?skip=10000&limit=10", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product?skip=10000&limit=10", headers=auth_headers + ) assert response.status_code == 200 data = response.json() @@ -96,6 +118,7 @@ class TestPagination: def test_pagination_with_filters(self, client, auth_headers, db): """Test pagination combined with filtering""" import uuid + unique_suffix = str(uuid.uuid4())[:8] # Create products with same brand for filtering @@ -115,7 +138,7 @@ class TestPagination: # Test first page with filter response = client.get( "/api/v1/marketplace/product?brand=FilterBrand&limit=5&skip=0", - headers=auth_headers + headers=auth_headers, ) assert response.status_code == 200 @@ -124,14 +147,18 @@ class TestPagination: assert data["total"] >= 15 # At least our test products # Verify all products have the filtered brand - test_products = [p for p in data["products"] if p["marketplace_product_id"].endswith(unique_suffix)] + test_products = [ + p + for p in data["products"] + if p["marketplace_product_id"].endswith(unique_suffix) + ] for product in test_products: assert product["brand"] == "FilterBrand" # Test second page with same filter response = client.get( "/api/v1/marketplace/product?brand=FilterBrand&limit=5&skip=5", - headers=auth_headers + headers=auth_headers, ) assert response.status_code == 200 @@ -152,6 +179,7 @@ class TestPagination: def test_pagination_consistency(self, client, auth_headers, db): """Test pagination consistency across multiple requests""" import uuid + unique_suffix = str(uuid.uuid4())[:8] # Create products with predictable ordering @@ -168,14 +196,22 @@ class TestPagination: db.commit() # Get first page - response1 = client.get("/api/v1/marketplace/product?limit=5&skip=0", headers=auth_headers) + response1 = client.get( + "/api/v1/marketplace/product?limit=5&skip=0", headers=auth_headers + ) assert response1.status_code == 200 - first_page_ids = [p["marketplace_product_id"] for p in response1.json()["products"]] + first_page_ids = [ + p["marketplace_product_id"] for p in response1.json()["products"] + ] # Get second page - response2 = client.get("/api/v1/marketplace/product?limit=5&skip=5", headers=auth_headers) + response2 = client.get( + "/api/v1/marketplace/product?limit=5&skip=5", headers=auth_headers + ) assert response2.status_code == 200 - second_page_ids = [p["marketplace_product_id"] for p in response2.json()["products"]] + second_page_ids = [ + p["marketplace_product_id"] for p in response2.json()["products"] + ] # Verify no overlap between pages overlap = set(first_page_ids) & set(second_page_ids) @@ -184,11 +220,13 @@ class TestPagination: def test_vendor_pagination_success(self, client, admin_headers, db, test_user): """Test pagination for vendor listing successfully""" import uuid + unique_suffix = str(uuid.uuid4())[:8] # Create multiple vendors for pagination testing from models.database.vendor import Vendor - vendors =[] + + vendors = [] for i in range(15): vendor = Vendor( vendor_code=f"PAGEVENDOR{i:03d}_{unique_suffix}", @@ -202,9 +240,7 @@ class TestPagination: db.commit() # Test first page (assuming admin endpoint exists) - response = client.get( - "/api/v1/vendor?limit=5&skip=0", headers=admin_headers - ) + response = client.get("/api/v1/vendor?limit=5&skip=0", headers=admin_headers) assert response.status_code == 200 data = response.json() assert len(data["vendors"]) == 5 @@ -215,10 +251,12 @@ class TestPagination: def test_inventory_pagination_success(self, client, auth_headers, db): """Test pagination for inventory listing successfully""" import uuid + unique_suffix = str(uuid.uuid4())[:8] # Create multiple inventory entries from models.database.inventory import Inventory + inventory_entries = [] for i in range(20): inventory = Inventory( @@ -249,7 +287,9 @@ class TestPagination: import time start_time = time.time() - response = client.get("/api/v1/marketplace/product?skip=1000&limit=10", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product?skip=1000&limit=10", headers=auth_headers + ) end_time = time.time() assert response.status_code == 200 @@ -262,19 +302,25 @@ class TestPagination: def test_pagination_with_invalid_parameters_types(self, client, auth_headers): """Test pagination with invalid parameter types returns ValidationException""" # Test non-numeric skip - response = client.get("/api/v1/marketplace/product?skip=invalid", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product?skip=invalid", headers=auth_headers + ) assert response.status_code == 422 data = response.json() assert data["error_code"] == "VALIDATION_ERROR" # Test non-numeric limit - response = client.get("/api/v1/marketplace/product?limit=invalid", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product?limit=invalid", headers=auth_headers + ) assert response.status_code == 422 data = response.json() assert data["error_code"] == "VALIDATION_ERROR" # Test float values (should be converted or rejected) - response = client.get("/api/v1/marketplace/product?skip=10.5&limit=5.5", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product?skip=10.5&limit=5.5", headers=auth_headers + ) assert response.status_code in [200, 422] # Depends on implementation def test_empty_dataset_pagination(self, client, auth_headers): @@ -282,7 +328,7 @@ class TestPagination: # Use a filter that should return no results response = client.get( "/api/v1/marketplace/product?brand=NonexistentBrand999&limit=10&skip=0", - headers=auth_headers + headers=auth_headers, ) assert response.status_code == 200 @@ -294,7 +340,9 @@ class TestPagination: def test_exception_structure_in_pagination_errors(self, client, auth_headers): """Test that pagination validation errors follow consistent exception structure""" - response = client.get("/api/v1/marketplace/product?skip=-1", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product?skip=-1", headers=auth_headers + ) assert response.status_code == 422 data = response.json() diff --git a/tests/integration/api/v1/test_stats_endpoints.py b/tests/integration/api/v1/test_stats_endpoints.py index eade22f2..253615c9 100644 --- a/tests/integration/api/v1/test_stats_endpoints.py +++ b/tests/integration/api/v1/test_stats_endpoints.py @@ -1,6 +1,7 @@ # tests/integration/api/v1/test_stats_endpoints.py import pytest + @pytest.mark.integration @pytest.mark.api @pytest.mark.stats @@ -18,7 +19,9 @@ class TestStatsAPI: assert "unique_vendors" in data assert data["total_products"] >= 1 - def test_get_marketplace_stats(self, client, auth_headers, test_marketplace_product): + def test_get_marketplace_stats( + self, client, auth_headers, test_marketplace_product + ): """Test getting marketplace statistics""" response = client.get("/api/v1/stats/marketplace", headers=auth_headers) diff --git a/tests/integration/api/v1/test_vendor_endpoints.py b/tests/integration/api/v1/test_vendor_endpoints.py index 663cca06..6b689d8c 100644 --- a/tests/integration/api/v1/test_vendor_endpoints.py +++ b/tests/integration/api/v1/test_vendor_endpoints.py @@ -23,7 +23,9 @@ class TestVendorsAPI: assert data["name"] == "New Vendor" assert data["is_active"] is True - def test_create_vendor_duplicate_code_returns_conflict(self, client, auth_headers, test_vendor): + def test_create_vendor_duplicate_code_returns_conflict( + self, client, auth_headers, test_vendor + ): """Test creating vendor with duplicate code returns VendorAlreadyExistsException""" vendor_data = { "vendor_code": test_vendor.vendor_code, @@ -40,7 +42,9 @@ class TestVendorsAPI: assert test_vendor.vendor_code in data["message"] assert data["details"]["vendor_code"] == test_vendor.vendor_code - def test_create_vendor_missing_vendor_code_validation_error(self, client, auth_headers): + def test_create_vendor_missing_vendor_code_validation_error( + self, client, auth_headers + ): """Test creating vendor without vendor_code returns ValidationException""" vendor_data = { "name": "Vendor without Code", @@ -56,7 +60,9 @@ class TestVendorsAPI: assert "Request validation failed" in data["message"] assert "validation_errors" in data["details"] - def test_create_vendor_empty_vendor_name_validation_error(self, client, auth_headers): + def test_create_vendor_empty_vendor_name_validation_error( + self, client, auth_headers + ): """Test creating vendor with empty name returns VendorValidationException""" vendor_data = { "vendor_code": "EMPTYNAME", @@ -73,7 +79,9 @@ class TestVendorsAPI: assert "Vendor name is required" in data["message"] assert data["details"]["field"] == "name" - def test_create_vendor_max_vendors_reached_business_logic_error(self, client, auth_headers, db, test_user): + def test_create_vendor_max_vendors_reached_business_logic_error( + self, client, auth_headers, db, test_user + ): """Test creating vendor when max vendors reached returns MaxVendorsReachedException""" # This test would require creating the maximum allowed vendors first # The exact implementation depends on your business rules @@ -94,8 +102,10 @@ class TestVendorsAPI: assert data["total"] >= 1 assert len(data["vendors"]) >= 1 - # Find our test vendor - test_vendor_found = any(s["vendor_code"] == test_vendor.vendor_code for s in data["vendors"]) + # Find our test vendor + test_vendor_found = any( + s["vendor_code"] == test_vendor.vendor_code for s in data["vendors"] + ) assert test_vendor_found def test_get_vendors_with_filters(self, client, auth_headers, test_vendor): @@ -105,7 +115,7 @@ class TestVendorsAPI: assert response.status_code == 200 data = response.json() for vendor in data["vendors"]: - assert vendor ["is_active"] is True + assert vendor["is_active"] is True # Test verified_only filter response = client.get("/api/v1/vendor?verified_only=true", headers=auth_headers) @@ -135,7 +145,9 @@ class TestVendorsAPI: assert data["details"]["resource_type"] == "Vendor" assert data["details"]["identifier"] == "NONEXISTENT" - def test_get_vendor_unauthorized_access(self, client, auth_headers, test_vendor, other_user, db): + def test_get_vendor_unauthorized_access( + self, client, auth_headers, test_vendor, other_user, db + ): """Test accessing vendor owned by another user returns UnauthorizedVendorAccessException""" # Change vendor owner to other user AND make it unverified/inactive # so that non-owner users cannot access it @@ -154,7 +166,9 @@ class TestVendorsAPI: assert test_vendor.vendor_code in data["message"] assert data["details"]["vendor_code"] == test_vendor.vendor_code - def test_get_vendor_unauthorized_access_with_inactive_vendor(self, client, auth_headers, inactive_vendor): + def test_get_vendor_unauthorized_access_with_inactive_vendor( + self, client, auth_headers, inactive_vendor + ): """Test accessing inactive vendor owned by another user returns UnauthorizedVendorAccessException""" # inactive_vendor fixture already creates an unverified, inactive vendor owned by other_user response = client.get( @@ -168,7 +182,9 @@ class TestVendorsAPI: assert inactive_vendor.vendor_code in data["message"] assert data["details"]["vendor_code"] == inactive_vendor.vendor_code - def test_get_vendor_public_access_allowed(self, client, auth_headers, verified_vendor): + def test_get_vendor_public_access_allowed( + self, client, auth_headers, verified_vendor + ): """Test accessing verified vendor owned by another user is allowed (public access)""" # verified_vendor fixture creates a verified, active vendor owned by other_user # This should allow public access per your business logic @@ -181,7 +197,9 @@ class TestVendorsAPI: assert data["vendor_code"] == verified_vendor.vendor_code assert data["name"] == verified_vendor.name - def test_add_product_to_vendor_success(self, client, auth_headers, test_vendor, unique_product): + def test_add_product_to_vendor_success( + self, client, auth_headers, test_vendor, unique_product + ): """Test adding product to vendor successfully""" product_data = { "marketplace_product_id": unique_product.marketplace_product_id, # Use string marketplace_product_id, not database id @@ -193,7 +211,7 @@ class TestVendorsAPI: response = client.post( f"/api/v1/vendor/{test_vendor.vendor_code}/products", headers=auth_headers, - json=product_data + json=product_data, ) assert response.status_code == 200 @@ -207,10 +225,15 @@ class TestVendorsAPI: # MarketplaceProduct details are nested in the 'marketplace_product' field assert "marketplace_product" in data - assert data["marketplace_product"]["marketplace_product_id"] == unique_product.marketplace_product_id + assert ( + data["marketplace_product"]["marketplace_product_id"] + == unique_product.marketplace_product_id + ) assert data["marketplace_product"]["id"] == unique_product.id - def test_add_product_to_vendor_already_exists_conflict(self, client, auth_headers, test_vendor, test_product): + def test_add_product_to_vendor_already_exists_conflict( + self, client, auth_headers, test_vendor, test_product + ): """Test adding product that already exists in vendor returns ProductAlreadyExistsException""" # test_product fixture already creates a relationship, get the marketplace_product_id string existing_product = test_product.marketplace_product @@ -223,7 +246,7 @@ class TestVendorsAPI: response = client.post( f"/api/v1/vendor/{test_vendor.vendor_code}/products", headers=auth_headers, - json=product_data + json=product_data, ) assert response.status_code == 409 @@ -233,7 +256,9 @@ class TestVendorsAPI: assert test_vendor.vendor_code in data["message"] assert existing_product.marketplace_product_id in data["message"] - def test_add_nonexistent_product_to_vendor_not_found(self, client, auth_headers, test_vendor): + def test_add_nonexistent_product_to_vendor_not_found( + self, client, auth_headers, test_vendor + ): """Test adding nonexistent product to vendor returns MarketplaceProductNotFoundException""" product_data = { "marketplace_product_id": "NONEXISTENT_PRODUCT", # Use string marketplace_product_id that doesn't exist @@ -243,7 +268,7 @@ class TestVendorsAPI: response = client.post( f"/api/v1/vendor/{test_vendor.vendor_code}/products", headers=auth_headers, - json=product_data + json=product_data, ) assert response.status_code == 404 @@ -252,11 +277,12 @@ class TestVendorsAPI: assert data["status_code"] == 404 assert "NONEXISTENT_PRODUCT" in data["message"] - def test_get_products_success(self, client, auth_headers, test_vendor, test_product): + def test_get_products_success( + self, client, auth_headers, test_vendor, test_product + ): """Test getting vendor products successfully""" response = client.get( - f"/api/v1/vendor/{test_vendor.vendor_code}/products", - headers=auth_headers + f"/api/v1/vendor/{test_vendor.vendor_code}/products", headers=auth_headers ) assert response.status_code == 200 @@ -271,22 +297,21 @@ class TestVendorsAPI: # Test active_only filter response = client.get( f"/api/v1/vendor/{test_vendor.vendor_code}/products?active_only=true", - headers=auth_headers + headers=auth_headers, ) assert response.status_code == 200 # Test featured_only filter response = client.get( f"/api/v1/vendor/{test_vendor.vendor_code}/products?featured_only=true", - headers=auth_headers + headers=auth_headers, ) assert response.status_code == 200 def test_get_products_from_nonexistent_vendor_not_found(self, client, auth_headers): """Test getting products from nonexistent vendor returns VendorNotFoundException""" response = client.get( - "/api/v1/vendor/NONEXISTENT/products", - headers=auth_headers + "/api/v1/vendor/NONEXISTENT/products", headers=auth_headers ) assert response.status_code == 404 @@ -295,7 +320,9 @@ class TestVendorsAPI: assert data["status_code"] == 404 assert "NONEXISTENT" in data["message"] - def test_vendor_not_active_business_logic_error(self, client, auth_headers, test_vendor, db): + def test_vendor_not_active_business_logic_error( + self, client, auth_headers, test_vendor, db + ): """Test accessing inactive vendor returns VendorNotActiveException (if enforced)""" # Set vendor to inactive test_vendor.is_active = False @@ -313,7 +340,9 @@ class TestVendorsAPI: assert data["status_code"] == 400 assert test_vendor.vendor_code in data["message"] - def test_vendor_not_verified_business_logic_error(self, client, auth_headers, test_vendor, db): + def test_vendor_not_verified_business_logic_error( + self, client, auth_headers, test_vendor, db + ): """Test operations requiring verification returns VendorNotVerifiedException (if enforced)""" # Set vendor to unverified test_vendor.is_verified = False @@ -328,7 +357,7 @@ class TestVendorsAPI: response = client.post( f"/api/v1/vendor/{test_vendor.vendor_code}/products", headers=auth_headers, - json=product_data + json=product_data, ) # If your service requires verification for adding products diff --git a/tests/integration/api/v1/vendor/test_authentication.py b/tests/integration/api/v1/vendor/test_authentication.py index 8a76417b..a5cf271c 100644 --- a/tests/integration/api/v1/vendor/test_authentication.py +++ b/tests/integration/api/v1/vendor/test_authentication.py @@ -10,8 +10,9 @@ These tests verify that: 5. Vendor context middleware works correctly with API authentication """ -import pytest from datetime import datetime, timedelta, timezone + +import pytest from jose import jwt @@ -26,7 +27,9 @@ class TestVendorAPIAuthentication: # Authentication Tests - /api/v1/vendor/auth/me # ======================================================================== - def test_vendor_auth_me_success(self, client, vendor_user_headers, test_vendor_user): + def test_vendor_auth_me_success( + self, client, vendor_user_headers, test_vendor_user + ): """Test /auth/me endpoint with valid vendor user token""" response = client.get("/api/v1/vendor/auth/me", headers=vendor_user_headers) @@ -50,7 +53,7 @@ class TestVendorAPIAuthentication: """Test /auth/me endpoint with invalid token format""" response = client.get( "/api/v1/vendor/auth/me", - headers={"Authorization": "Bearer invalid_token_here"} + headers={"Authorization": "Bearer invalid_token_here"}, ) assert response.status_code == 401 @@ -66,7 +69,9 @@ class TestVendorAPIAuthentication: assert data["error_code"] == "FORBIDDEN" assert "Admin users cannot access vendor API" in data["message"] - def test_vendor_auth_me_with_regular_user_token(self, client, auth_headers, test_user): + def test_vendor_auth_me_with_regular_user_token( + self, client, auth_headers, test_user + ): """Test /auth/me endpoint rejects regular users""" response = client.get("/api/v1/vendor/auth/me", headers=auth_headers) @@ -88,14 +93,12 @@ class TestVendorAPIAuthentication: } expired_token = jwt.encode( - expired_payload, - auth_manager.secret_key, - algorithm=auth_manager.algorithm + expired_payload, auth_manager.secret_key, algorithm=auth_manager.algorithm ) response = client.get( "/api/v1/vendor/auth/me", - headers={"Authorization": f"Bearer {expired_token}"} + headers={"Authorization": f"Bearer {expired_token}"}, ) assert response.status_code == 401 @@ -111,8 +114,7 @@ class TestVendorAPIAuthentication: ): """Test dashboard stats with valid vendor authentication""" response = client.get( - "/api/v1/vendor/dashboard/stats", - headers=vendor_user_headers + "/api/v1/vendor/dashboard/stats", headers=vendor_user_headers ) assert response.status_code == 200 @@ -131,10 +133,7 @@ class TestVendorAPIAuthentication: def test_vendor_dashboard_stats_with_admin(self, client, admin_headers): """Test dashboard stats rejects admin users""" - response = client.get( - "/api/v1/vendor/dashboard/stats", - headers=admin_headers - ) + response = client.get("/api/v1/vendor/dashboard/stats", headers=admin_headers) assert response.status_code == 403 data = response.json() @@ -145,10 +144,7 @@ class TestVendorAPIAuthentication: # Login to get session cookie login_response = client.post( "/api/v1/vendor/auth/login", - json={ - "username": test_vendor_user.username, - "password": "vendorpass123" - } + json={"username": test_vendor_user.username, "password": "vendorpass123"}, ) assert login_response.status_code == 200 @@ -169,10 +165,7 @@ class TestVendorAPIAuthentication: # Get a valid session by logging in login_response = client.post( "/api/v1/vendor/auth/login", - json={ - "username": test_vendor_user.username, - "password": "vendorpass123" - } + json={"username": test_vendor_user.username, "password": "vendorpass123"}, ) assert login_response.status_code == 200 @@ -191,8 +184,9 @@ class TestVendorAPIAuthentication: response = client.get(endpoint) # All should fail with 401 (header required) - assert response.status_code == 401, \ - f"Endpoint {endpoint} should reject cookie-only auth" + assert ( + response.status_code == 401 + ), f"Endpoint {endpoint} should reject cookie-only auth" # ======================================================================== # Role-Based Access Control Tests @@ -211,13 +205,15 @@ class TestVendorAPIAuthentication: for endpoint in endpoints: # Test with regular user token response = client.get(endpoint, headers=auth_headers) - assert response.status_code == 403, \ - f"Endpoint {endpoint} should reject regular users" + assert ( + response.status_code == 403 + ), f"Endpoint {endpoint} should reject regular users" # Test with admin token response = client.get(endpoint, headers=admin_headers) - assert response.status_code == 403, \ - f"Endpoint {endpoint} should reject admin users" + assert ( + response.status_code == 403 + ), f"Endpoint {endpoint} should reject admin users" def test_vendor_api_accepts_only_vendor_role( self, client, vendor_user_headers, test_vendor_user @@ -229,8 +225,10 @@ class TestVendorAPIAuthentication: for endpoint in endpoints: response = client.get(endpoint, headers=vendor_user_headers) - assert response.status_code in [200, 404], \ - f"Endpoint {endpoint} should accept vendor users (got {response.status_code})" + assert response.status_code in [ + 200, + 404, + ], f"Endpoint {endpoint} should accept vendor users (got {response.status_code})" # ======================================================================== # Token Validation Tests @@ -248,8 +246,9 @@ class TestVendorAPIAuthentication: for headers in malformed_headers: response = client.get("/api/v1/vendor/auth/me", headers=headers) - assert response.status_code == 401, \ - f"Should reject malformed header: {headers}" + assert ( + response.status_code == 401 + ), f"Should reject malformed header: {headers}" def test_token_with_missing_claims(self, client, auth_manager): """Test token missing required claims""" @@ -261,14 +260,12 @@ class TestVendorAPIAuthentication: } invalid_token = jwt.encode( - invalid_payload, - auth_manager.secret_key, - algorithm=auth_manager.algorithm + invalid_payload, auth_manager.secret_key, algorithm=auth_manager.algorithm ) response = client.get( "/api/v1/vendor/auth/me", - headers={"Authorization": f"Bearer {invalid_token}"} + headers={"Authorization": f"Bearer {invalid_token}"}, ) assert response.status_code == 401 @@ -298,9 +295,7 @@ class TestVendorAPIAuthentication: db.add(test_vendor_user) db.commit() - def test_concurrent_requests_with_same_token( - self, client, vendor_user_headers - ): + def test_concurrent_requests_with_same_token(self, client, vendor_user_headers): """Test that the same token can be used for multiple concurrent requests""" # Make multiple requests with the same token responses = [] @@ -314,10 +309,7 @@ class TestVendorAPIAuthentication: def test_vendor_api_with_empty_authorization_header(self, client): """Test vendor API with empty Authorization header value""" - response = client.get( - "/api/v1/vendor/auth/me", - headers={"Authorization": ""} - ) + response = client.get("/api/v1/vendor/auth/me", headers={"Authorization": ""}) assert response.status_code == 401 @@ -328,17 +320,12 @@ class TestVendorAPIAuthentication: class TestVendorAPIConsistency: """Test that all vendor API endpoints use consistent authentication""" - def test_all_vendor_endpoints_require_header_auth( - self, client, test_vendor_user - ): + def test_all_vendor_endpoints_require_header_auth(self, client, test_vendor_user): """Verify all vendor API endpoints require Authorization header""" # Login to establish session client.post( "/api/v1/vendor/auth/login", - json={ - "username": test_vendor_user.username, - "password": "vendorpass123" - } + json={"username": test_vendor_user.username, "password": "vendorpass123"}, ) # All vendor API endpoints (excluding public endpoints like /info) @@ -361,8 +348,9 @@ class TestVendorAPIConsistency: response = client.post(endpoint, json={}) # All should reject cookie-only auth with 401 - assert response.status_code == 401, \ - f"Endpoint {endpoint} should require Authorization header (got {response.status_code})" + assert ( + response.status_code == 401 + ), f"Endpoint {endpoint} should require Authorization header (got {response.status_code})" def test_vendor_endpoints_accept_vendor_token( self, client, vendor_user_headers, test_vendor_with_vendor_user @@ -380,5 +368,7 @@ class TestVendorAPIConsistency: response = client.get(endpoint, headers=vendor_user_headers) # Should not be authentication/authorization errors - assert response.status_code not in [401, 403], \ - f"Endpoint {endpoint} should accept vendor token (got {response.status_code}: {response.text})" + assert response.status_code not in [ + 401, + 403, + ], f"Endpoint {endpoint} should accept vendor token (got {response.status_code}: {response.text})" diff --git a/tests/integration/api/v1/vendor/test_dashboard.py b/tests/integration/api/v1/vendor/test_dashboard.py index f2e4d67f..5fe90cc3 100644 --- a/tests/integration/api/v1/vendor/test_dashboard.py +++ b/tests/integration/api/v1/vendor/test_dashboard.py @@ -23,8 +23,7 @@ class TestVendorDashboardAPI: ): """Test dashboard stats returns correct data structure""" response = client.get( - "/api/v1/vendor/dashboard/stats", - headers=vendor_user_headers + "/api/v1/vendor/dashboard/stats", headers=vendor_user_headers ) assert response.status_code == 200 @@ -66,9 +65,9 @@ class TestVendorDashboardAPI: self, client, db, test_vendor_user, auth_manager ): """Test that dashboard stats only show data for the authenticated vendor""" - from models.database.vendor import Vendor, VendorUser - from models.database.product import Product from models.database.marketplace_product import MarketplaceProduct + from models.database.product import Product + from models.database.vendor import Vendor, VendorUser # Create two separate vendors with different data vendor1 = Vendor( @@ -118,10 +117,7 @@ class TestVendorDashboardAPI: vendor1_headers = {"Authorization": f"Bearer {token_data['access_token']}"} # Get stats for vendor1 - response = client.get( - "/api/v1/vendor/dashboard/stats", - headers=vendor1_headers - ) + response = client.get("/api/v1/vendor/dashboard/stats", headers=vendor1_headers) assert response.status_code == 200 data = response.json() @@ -130,9 +126,7 @@ class TestVendorDashboardAPI: assert data["vendor"]["id"] == vendor1.id assert data["products"]["total"] == 3 - def test_dashboard_stats_without_vendor_association( - self, client, db, auth_manager - ): + def test_dashboard_stats_without_vendor_association(self, client, db, auth_manager): """Test dashboard stats for user not associated with any vendor""" from models.database.user import User @@ -206,8 +200,7 @@ class TestVendorDashboardAPI: ): """Test dashboard stats for vendor with no data""" response = client.get( - "/api/v1/vendor/dashboard/stats", - headers=vendor_user_headers + "/api/v1/vendor/dashboard/stats", headers=vendor_user_headers ) assert response.status_code == 200 @@ -224,8 +217,8 @@ class TestVendorDashboardAPI: self, client, db, vendor_user_headers, test_vendor_with_vendor_user ): """Test dashboard stats accuracy with actual products""" - from models.database.product import Product from models.database.marketplace_product import MarketplaceProduct + from models.database.product import Product # Create marketplace products mp = MarketplaceProduct( @@ -249,8 +242,7 @@ class TestVendorDashboardAPI: # Get stats response = client.get( - "/api/v1/vendor/dashboard/stats", - headers=vendor_user_headers + "/api/v1/vendor/dashboard/stats", headers=vendor_user_headers ) assert response.status_code == 200 @@ -267,8 +259,7 @@ class TestVendorDashboardAPI: start_time = time.time() response = client.get( - "/api/v1/vendor/dashboard/stats", - headers=vendor_user_headers + "/api/v1/vendor/dashboard/stats", headers=vendor_user_headers ) end_time = time.time() @@ -284,8 +275,7 @@ class TestVendorDashboardAPI: responses = [] for _ in range(3): response = client.get( - "/api/v1/vendor/dashboard/stats", - headers=vendor_user_headers + "/api/v1/vendor/dashboard/stats", headers=vendor_user_headers ) responses.append(response) diff --git a/tests/integration/middleware/conftest.py b/tests/integration/middleware/conftest.py index a5e253b4..663dd723 100644 --- a/tests/integration/middleware/conftest.py +++ b/tests/integration/middleware/conftest.py @@ -3,6 +3,7 @@ Fixtures specific to middleware integration tests. """ import pytest + from models.database.vendor import Vendor from models.database.vendor_domain import VendorDomain from models.database.vendor_theme import VendorTheme @@ -12,10 +13,7 @@ from models.database.vendor_theme import VendorTheme def vendor_with_subdomain(db): """Create a vendor with subdomain for testing.""" vendor = Vendor( - name="Test Vendor", - code="testvendor", - subdomain="testvendor", - is_active=True + name="Test Vendor", code="testvendor", subdomain="testvendor", is_active=True ) db.add(vendor) db.commit() @@ -30,7 +28,7 @@ def vendor_with_custom_domain(db): name="Custom Domain Vendor", code="customvendor", subdomain="customvendor", - is_active=True + is_active=True, ) db.add(vendor) db.commit() @@ -38,10 +36,7 @@ def vendor_with_custom_domain(db): # Add custom domain domain = VendorDomain( - vendor_id=vendor.id, - domain="customdomain.com", - is_active=True, - is_primary=True + vendor_id=vendor.id, domain="customdomain.com", is_active=True, is_primary=True ) db.add(domain) db.commit() @@ -56,7 +51,7 @@ def vendor_with_theme(db): name="Themed Vendor", code="themedvendor", subdomain="themedvendor", - is_active=True + is_active=True, ) db.add(vendor) db.commit() @@ -69,7 +64,7 @@ def vendor_with_theme(db): secondary_color="#33FF57", logo_url="/static/vendors/themedvendor/logo.png", favicon_url="/static/vendors/themedvendor/favicon.ico", - custom_css="body { background: #FF5733; }" + custom_css="body { background: #FF5733; }", ) db.add(theme) db.commit() @@ -81,10 +76,7 @@ def vendor_with_theme(db): def inactive_vendor(db): """Create an inactive vendor for testing.""" vendor = Vendor( - name="Inactive Vendor", - code="inactive", - subdomain="inactive", - is_active=False + name="Inactive Vendor", code="inactive", subdomain="inactive", is_active=False ) db.add(vendor) db.commit() diff --git a/tests/integration/middleware/test_context_detection_flow.py b/tests/integration/middleware/test_context_detection_flow.py index 301fbd29..8c06f3f7 100644 --- a/tests/integration/middleware/test_context_detection_flow.py +++ b/tests/integration/middleware/test_context_detection_flow.py @@ -5,8 +5,10 @@ Integration tests for request context detection end-to-end flow. These tests verify that context type (API, ADMIN, VENDOR_DASHBOARD, SHOP, FALLBACK) is correctly detected through real HTTP requests. """ -import pytest from unittest.mock import patch + +import pytest + from middleware.context import RequestContext @@ -23,13 +25,22 @@ class TestContextDetectionFlow: def test_api_path_detected_as_api_context(self, client): """Test that /api/* paths are detected as API context.""" from fastapi import Request + from main import app @app.get("/api/test-api-context") async def test_api(request: Request): return { - "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None, - "context_enum": request.state.context_type.name if hasattr(request.state, 'context_type') else None + "context_type": ( + request.state.context_type.value + if hasattr(request.state, "context_type") + else None + ), + "context_enum": ( + request.state.context_type.name + if hasattr(request.state, "context_type") + else None + ), } response = client.get("/api/test-api-context") @@ -42,12 +53,17 @@ class TestContextDetectionFlow: def test_nested_api_path_detected_as_api_context(self, client): """Test that nested /api/ paths are detected as API context.""" from fastapi import Request + from main import app @app.get("/api/v1/vendor/products") async def test_nested_api(request: Request): return { - "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None + "context_type": ( + request.state.context_type.value + if hasattr(request.state, "context_type") + else None + ) } response = client.get("/api/v1/vendor/products") @@ -63,13 +79,22 @@ class TestContextDetectionFlow: def test_admin_path_detected_as_admin_context(self, client): """Test that /admin/* paths are detected as ADMIN context.""" from fastapi import Request + from main import app @app.get("/admin/test-admin-context") async def test_admin(request: Request): return { - "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None, - "context_enum": request.state.context_type.name if hasattr(request.state, 'context_type') else None + "context_type": ( + request.state.context_type.value + if hasattr(request.state, "context_type") + else None + ), + "context_enum": ( + request.state.context_type.name + if hasattr(request.state, "context_type") + else None + ), } response = client.get("/admin/test-admin-context") @@ -82,19 +107,23 @@ class TestContextDetectionFlow: def test_admin_subdomain_detected_as_admin_context(self, client): """Test that admin.* subdomain is detected as ADMIN context.""" from fastapi import Request + from main import app @app.get("/test-admin-subdomain-context") async def test_admin_subdomain(request: Request): return { - "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None + "context_type": ( + request.state.context_type.value + if hasattr(request.state, "context_type") + else None + ) } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( - "/test-admin-subdomain-context", - headers={"host": "admin.platform.com"} + "/test-admin-subdomain-context", headers={"host": "admin.platform.com"} ) assert response.status_code == 200 @@ -104,12 +133,17 @@ class TestContextDetectionFlow: def test_nested_admin_path_detected_as_admin_context(self, client): """Test that nested /admin/ paths are detected as ADMIN context.""" from fastapi import Request + from main import app @app.get("/admin/vendors/123/edit") async def test_nested_admin(request: Request): return { - "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None + "context_type": ( + request.state.context_type.value + if hasattr(request.state, "context_type") + else None + ) } response = client.get("/admin/vendors/123/edit") @@ -125,21 +159,31 @@ class TestContextDetectionFlow: def test_vendor_dashboard_path_detected(self, client, vendor_with_subdomain): """Test that /vendor/* paths are detected as VENDOR_DASHBOARD context.""" from fastapi import Request + from main import app @app.get("/vendor/test-vendor-dashboard") async def test_vendor_dashboard(request: Request): return { - "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None, - "context_enum": request.state.context_type.name if hasattr(request.state, 'context_type') else None, - "has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None + "context_type": ( + request.state.context_type.value + if hasattr(request.state, "context_type") + else None + ), + "context_enum": ( + request.state.context_type.name + if hasattr(request.state, "context_type") + else None + ), + "has_vendor": hasattr(request.state, "vendor") + and request.state.vendor is not None, } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/vendor/test-vendor-dashboard", - headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"} + headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}, ) assert response.status_code == 200 @@ -151,19 +195,24 @@ class TestContextDetectionFlow: def test_nested_vendor_dashboard_path_detected(self, client, vendor_with_subdomain): """Test that nested /vendor/ paths are detected as VENDOR_DASHBOARD context.""" from fastapi import Request + from main import app @app.get("/vendor/products/123/edit") async def test_nested_vendor(request: Request): return { - "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None + "context_type": ( + request.state.context_type.value + if hasattr(request.state, "context_type") + else None + ) } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/vendor/products/123/edit", - headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"} + headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}, ) assert response.status_code == 200 @@ -174,24 +223,36 @@ class TestContextDetectionFlow: # Shop Context Detection Tests # ======================================================================== - def test_shop_path_with_vendor_detected_as_shop(self, client, vendor_with_subdomain): + def test_shop_path_with_vendor_detected_as_shop( + self, client, vendor_with_subdomain + ): """Test that /shop/* paths with vendor are detected as SHOP context.""" from fastapi import Request + from main import app @app.get("/shop/test-shop-context") async def test_shop(request: Request): return { - "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None, - "context_enum": request.state.context_type.name if hasattr(request.state, 'context_type') else None, - "has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None + "context_type": ( + request.state.context_type.value + if hasattr(request.state, "context_type") + else None + ), + "context_enum": ( + request.state.context_type.name + if hasattr(request.state, "context_type") + else None + ), + "has_vendor": hasattr(request.state, "vendor") + and request.state.vendor is not None, } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/shop/test-shop-context", - headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"} + headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}, ) assert response.status_code == 200 @@ -200,23 +261,31 @@ class TestContextDetectionFlow: assert data["context_enum"] == "SHOP" assert data["has_vendor"] is True - def test_root_path_with_vendor_detected_as_shop(self, client, vendor_with_subdomain): + def test_root_path_with_vendor_detected_as_shop( + self, client, vendor_with_subdomain + ): """Test that root path with vendor is detected as SHOP context.""" from fastapi import Request + from main import app @app.get("/test-root-shop") async def test_root_shop(request: Request): return { - "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None, - "has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None + "context_type": ( + request.state.context_type.value + if hasattr(request.state, "context_type") + else None + ), + "has_vendor": hasattr(request.state, "vendor") + and request.state.vendor is not None, } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/test-root-shop", - headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"} + headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}, ) assert response.status_code == 200 @@ -228,21 +297,27 @@ class TestContextDetectionFlow: def test_custom_domain_shop_detected(self, client, vendor_with_custom_domain): """Test that custom domain shop is detected as SHOP context.""" from fastapi import Request + from main import app @app.get("/products") async def test_custom_domain_shop(request: Request): return { - "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None, - "vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None + "context_type": ( + request.state.context_type.value + if hasattr(request.state, "context_type") + else None + ), + "vendor_code": ( + request.state.vendor.code + if hasattr(request.state, "vendor") and request.state.vendor + else None + ), } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" - response = client.get( - "/products", - headers={"host": "customdomain.com"} - ) + response = client.get("/products", headers={"host": "customdomain.com"}) assert response.status_code == 200 data = response.json() @@ -256,21 +331,30 @@ class TestContextDetectionFlow: def test_unknown_path_without_vendor_fallback_context(self, client): """Test that unknown paths without vendor get FALLBACK context.""" from fastapi import Request + from main import app @app.get("/test-fallback-context") async def test_fallback(request: Request): return { - "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None, - "context_enum": request.state.context_type.name if hasattr(request.state, 'context_type') else None, - "has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None + "context_type": ( + request.state.context_type.value + if hasattr(request.state, "context_type") + else None + ), + "context_enum": ( + request.state.context_type.name + if hasattr(request.state, "context_type") + else None + ), + "has_vendor": hasattr(request.state, "vendor") + and request.state.vendor is not None, } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( - "/test-fallback-context", - headers={"host": "platform.com"} + "/test-fallback-context", headers={"host": "platform.com"} ) assert response.status_code == 200 @@ -286,20 +370,26 @@ class TestContextDetectionFlow: def test_api_path_overrides_vendor_context(self, client, vendor_with_subdomain): """Test that /api/* path sets API context even with vendor.""" from fastapi import Request + from main import app @app.get("/api/test-api-priority") async def test_api_priority(request: Request): return { - "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None, - "has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None + "context_type": ( + request.state.context_type.value + if hasattr(request.state, "context_type") + else None + ), + "has_vendor": hasattr(request.state, "vendor") + and request.state.vendor is not None, } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/api/test-api-priority", - headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"} + headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}, ) assert response.status_code == 200 @@ -312,20 +402,26 @@ class TestContextDetectionFlow: def test_admin_path_overrides_vendor_context(self, client, vendor_with_subdomain): """Test that /admin/* path sets ADMIN context even with vendor.""" from fastapi import Request + from main import app @app.get("/admin/test-admin-priority") async def test_admin_priority(request: Request): return { - "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None, - "has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None + "context_type": ( + request.state.context_type.value + if hasattr(request.state, "context_type") + else None + ), + "has_vendor": hasattr(request.state, "vendor") + and request.state.vendor is not None, } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/admin/test-admin-priority", - headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"} + headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}, ) assert response.status_code == 200 @@ -333,22 +429,29 @@ class TestContextDetectionFlow: # Admin path should override vendor context assert data["context_type"] == "admin" - def test_vendor_dashboard_overrides_shop_context(self, client, vendor_with_subdomain): + def test_vendor_dashboard_overrides_shop_context( + self, client, vendor_with_subdomain + ): """Test that /vendor/* path sets VENDOR_DASHBOARD, not SHOP.""" from fastapi import Request + from main import app @app.get("/vendor/test-priority") async def test_vendor_priority(request: Request): return { - "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None + "context_type": ( + request.state.context_type.value + if hasattr(request.state, "context_type") + else None + ) } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/vendor/test-priority", - headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"} + headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}, ) assert response.status_code == 200 @@ -363,21 +466,30 @@ class TestContextDetectionFlow: def test_context_uses_clean_path_for_detection(self, client, vendor_with_subdomain): """Test that context detection uses clean_path, not original path.""" from fastapi import Request + from main import app @app.get("/vendors/{vendor_code}/shop/products") async def test_clean_path_context(vendor_code: str, request: Request): return { - "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None, - "clean_path": request.state.clean_path if hasattr(request.state, 'clean_path') else None, - "original_path": request.url.path + "context_type": ( + request.state.context_type.value + if hasattr(request.state, "context_type") + else None + ), + "clean_path": ( + request.state.clean_path + if hasattr(request.state, "clean_path") + else None + ), + "original_path": request.url.path, } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( f"/vendors/{vendor_with_subdomain.code}/shop/products", - headers={"host": "localhost:8000"} + headers={"host": "localhost:8000"}, ) assert response.status_code == 200 @@ -394,15 +506,20 @@ class TestContextDetectionFlow: def test_context_type_is_enum_instance(self, client): """Test that context_type is a RequestContext enum instance.""" from fastapi import Request + from main import app @app.get("/api/test-enum") async def test_enum(request: Request): - context = request.state.context_type if hasattr(request.state, 'context_type') else None + context = ( + request.state.context_type + if hasattr(request.state, "context_type") + else None + ) return { "is_enum": isinstance(context, RequestContext) if context else False, "enum_name": context.name if context else None, - "enum_value": context.value if context else None + "enum_value": context.value if context else None, } response = client.get("/api/test-enum") @@ -417,23 +534,30 @@ class TestContextDetectionFlow: # Edge Cases # ======================================================================== - def test_empty_path_with_vendor_detected_as_shop(self, client, vendor_with_subdomain): + def test_empty_path_with_vendor_detected_as_shop( + self, client, vendor_with_subdomain + ): """Test that empty/root path with vendor is detected as SHOP.""" from fastapi import Request + from main import app @app.get("/") async def test_root(request: Request): return { - "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None, - "has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None + "context_type": ( + request.state.context_type.value + if hasattr(request.state, "context_type") + else None + ), + "has_vendor": hasattr(request.state, "vendor") + and request.state.vendor is not None, } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( - "/", - headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"} + "/", headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"} ) assert response.status_code in [200, 404] # Might not have root handler @@ -445,13 +569,18 @@ class TestContextDetectionFlow: def test_case_insensitive_context_detection(self, client): """Test that context detection is case insensitive for paths.""" from fastapi import Request + from main import app @app.get("/API/test-case") @app.get("/api/test-case") async def test_case(request: Request): return { - "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None + "context_type": ( + request.state.context_type.value + if hasattr(request.state, "context_type") + else None + ) } # Test uppercase diff --git a/tests/integration/middleware/test_middleware_stack.py b/tests/integration/middleware/test_middleware_stack.py index f61efea3..4578159f 100644 --- a/tests/integration/middleware/test_middleware_stack.py +++ b/tests/integration/middleware/test_middleware_stack.py @@ -5,8 +5,10 @@ Integration tests for the complete middleware stack. These tests verify that all middleware components work together correctly through real HTTP requests, ensuring proper execution order and state injection. """ -import pytest from unittest.mock import patch + +import pytest + from middleware.context import RequestContext @@ -23,14 +25,20 @@ class TestMiddlewareStackIntegration: """Test that /admin/* paths set ADMIN context type.""" # Create a simple endpoint to inspect request state from fastapi import Request + from main import app @app.get("/admin/test-context") async def test_admin_context(request: Request): return { - "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None, - "has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None, - "has_theme": hasattr(request.state, 'theme') + "context_type": ( + request.state.context_type.value + if hasattr(request.state, "context_type") + else None + ), + "has_vendor": hasattr(request.state, "vendor") + and request.state.vendor is not None, + "has_theme": hasattr(request.state, "theme"), } response = client.get("/admin/test-context") @@ -44,20 +52,24 @@ class TestMiddlewareStackIntegration: def test_admin_subdomain_sets_admin_context(self, client): """Test that admin.* subdomain sets ADMIN context type.""" from fastapi import Request + from main import app @app.get("/test-admin-subdomain") async def test_admin_subdomain(request: Request): return { - "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None + "context_type": ( + request.state.context_type.value + if hasattr(request.state, "context_type") + else None + ) } # Simulate request with admin subdomain - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( - "/test-admin-subdomain", - headers={"host": "admin.platform.com"} + "/test-admin-subdomain", headers={"host": "admin.platform.com"} ) assert response.status_code == 200 @@ -71,12 +83,17 @@ class TestMiddlewareStackIntegration: def test_api_path_sets_api_context(self, client): """Test that /api/* paths set API context type.""" from fastapi import Request + from main import app @app.get("/api/test-context") async def test_api_context(request: Request): return { - "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None + "context_type": ( + request.state.context_type.value + if hasattr(request.state, "context_type") + else None + ) } response = client.get("/api/test-context") @@ -89,25 +106,40 @@ class TestMiddlewareStackIntegration: # Vendor Dashboard Context Tests # ======================================================================== - def test_vendor_dashboard_path_sets_vendor_context(self, client, vendor_with_subdomain): + def test_vendor_dashboard_path_sets_vendor_context( + self, client, vendor_with_subdomain + ): """Test that /vendor/* paths with vendor set VENDOR_DASHBOARD context.""" from fastapi import Request + from main import app @app.get("/vendor/test-context") async def test_vendor_context(request: Request): return { - "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None, - "vendor_id": request.state.vendor_id if hasattr(request.state, 'vendor_id') else None, - "vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') else None + "context_type": ( + request.state.context_type.value + if hasattr(request.state, "context_type") + else None + ), + "vendor_id": ( + request.state.vendor_id + if hasattr(request.state, "vendor_id") + else None + ), + "vendor_code": ( + request.state.vendor.code + if hasattr(request.state, "vendor") + else None + ), } # Request with vendor subdomain - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/vendor/test-context", - headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"} + headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}, ) assert response.status_code == 200 @@ -120,24 +152,35 @@ class TestMiddlewareStackIntegration: # Shop Context Tests # ======================================================================== - def test_shop_path_with_subdomain_sets_shop_context(self, client, vendor_with_subdomain): + def test_shop_path_with_subdomain_sets_shop_context( + self, client, vendor_with_subdomain + ): """Test that /shop/* paths with vendor subdomain set SHOP context.""" from fastapi import Request + from main import app @app.get("/shop/test-context") async def test_shop_context(request: Request): return { - "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None, - "vendor_id": request.state.vendor_id if hasattr(request.state, 'vendor_id') else None, - "has_theme": hasattr(request.state, 'theme') + "context_type": ( + request.state.context_type.value + if hasattr(request.state, "context_type") + else None + ), + "vendor_id": ( + request.state.vendor_id + if hasattr(request.state, "vendor_id") + else None + ), + "has_theme": hasattr(request.state, "theme"), } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/shop/test-context", - headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"} + headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}, ) assert response.status_code == 200 @@ -146,23 +189,33 @@ class TestMiddlewareStackIntegration: assert data["vendor_id"] == vendor_with_subdomain.id assert data["has_theme"] is True - def test_shop_path_with_custom_domain_sets_shop_context(self, client, vendor_with_custom_domain): + def test_shop_path_with_custom_domain_sets_shop_context( + self, client, vendor_with_custom_domain + ): """Test that /shop/* paths with custom domain set SHOP context.""" from fastapi import Request + from main import app @app.get("/shop/test-custom-domain") async def test_shop_custom_domain(request: Request): return { - "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None, - "vendor_id": request.state.vendor_id if hasattr(request.state, 'vendor_id') else None + "context_type": ( + request.state.context_type.value + if hasattr(request.state, "context_type") + else None + ), + "vendor_id": ( + request.state.vendor_id + if hasattr(request.state, "vendor_id") + else None + ), } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( - "/shop/test-custom-domain", - headers={"host": "customdomain.com"} + "/shop/test-custom-domain", headers={"host": "customdomain.com"} ) assert response.status_code == 200 @@ -174,9 +227,12 @@ class TestMiddlewareStackIntegration: # Middleware Execution Order Tests # ======================================================================== - def test_vendor_context_runs_before_context_detection(self, client, vendor_with_subdomain): + def test_vendor_context_runs_before_context_detection( + self, client, vendor_with_subdomain + ): """Test that VendorContextMiddleware runs before ContextDetectionMiddleware.""" from fastapi import Request + from main import app @app.get("/test-execution-order") @@ -184,16 +240,16 @@ class TestMiddlewareStackIntegration: # If vendor context runs first, clean_path should be available # before context detection uses it return { - "has_vendor": hasattr(request.state, 'vendor'), - "has_clean_path": hasattr(request.state, 'clean_path'), - "has_context_type": hasattr(request.state, 'context_type') + "has_vendor": hasattr(request.state, "vendor"), + "has_clean_path": hasattr(request.state, "clean_path"), + "has_context_type": hasattr(request.state, "context_type"), } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/test-execution-order", - headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"} + headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}, ) assert response.status_code == 200 @@ -206,21 +262,26 @@ class TestMiddlewareStackIntegration: def test_theme_context_runs_after_vendor_context(self, client, vendor_with_theme): """Test that ThemeContextMiddleware runs after VendorContextMiddleware.""" from fastapi import Request + from main import app @app.get("/test-theme-loading") async def test_theme_loading(request: Request): return { - "has_vendor": hasattr(request.state, 'vendor'), - "has_theme": hasattr(request.state, 'theme'), - "theme_primary_color": request.state.theme.get('primary_color') if hasattr(request.state, 'theme') else None + "has_vendor": hasattr(request.state, "vendor"), + "has_theme": hasattr(request.state, "theme"), + "theme_primary_color": ( + request.state.theme.get("primary_color") + if hasattr(request.state, "theme") + else None + ), } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/test-theme-loading", - headers={"host": f"{vendor_with_theme.subdomain}.platform.com"} + headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}, ) assert response.status_code == 200 @@ -249,21 +310,27 @@ class TestMiddlewareStackIntegration: def test_missing_vendor_graceful_handling(self, client): """Test that missing vendor is handled gracefully.""" from fastapi import Request + from main import app @app.get("/test-missing-vendor") async def test_missing_vendor(request: Request): return { - "has_vendor": hasattr(request.state, 'vendor'), - "vendor": request.state.vendor if hasattr(request.state, 'vendor') else None, - "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None + "has_vendor": hasattr(request.state, "vendor"), + "vendor": ( + request.state.vendor if hasattr(request.state, "vendor") else None + ), + "context_type": ( + request.state.context_type.value + if hasattr(request.state, "context_type") + else None + ), } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( - "/test-missing-vendor", - headers={"host": "nonexistent.platform.com"} + "/test-missing-vendor", headers={"host": "nonexistent.platform.com"} ) assert response.status_code == 200 @@ -276,20 +343,23 @@ class TestMiddlewareStackIntegration: def test_inactive_vendor_not_loaded(self, client, inactive_vendor): """Test that inactive vendors are not loaded.""" from fastapi import Request + from main import app @app.get("/test-inactive-vendor") async def test_inactive_vendor_endpoint(request: Request): return { - "has_vendor": hasattr(request.state, 'vendor'), - "vendor": request.state.vendor if hasattr(request.state, 'vendor') else None + "has_vendor": hasattr(request.state, "vendor"), + "vendor": ( + request.state.vendor if hasattr(request.state, "vendor") else None + ), } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/test-inactive-vendor", - headers={"host": f"{inactive_vendor.subdomain}.platform.com"} + headers={"host": f"{inactive_vendor.subdomain}.platform.com"}, ) assert response.status_code == 200 diff --git a/tests/integration/middleware/test_theme_loading_flow.py b/tests/integration/middleware/test_theme_loading_flow.py index 9b19bcfc..00bb55cc 100644 --- a/tests/integration/middleware/test_theme_loading_flow.py +++ b/tests/integration/middleware/test_theme_loading_flow.py @@ -5,9 +5,10 @@ Integration tests for theme loading end-to-end flow. These tests verify that vendor themes are correctly loaded and injected into request.state through real HTTP requests. """ -import pytest from unittest.mock import patch +import pytest + @pytest.mark.integration @pytest.mark.middleware @@ -22,21 +23,22 @@ class TestThemeLoadingFlow: def test_theme_loaded_for_vendor_with_custom_theme(self, client, vendor_with_theme): """Test that custom theme is loaded for vendor with theme.""" from fastapi import Request + from main import app @app.get("/test-theme-loading") async def test_theme(request: Request): - theme = request.state.theme if hasattr(request.state, 'theme') else None + theme = request.state.theme if hasattr(request.state, "theme") else None return { "has_theme": theme is not None, - "theme_data": theme if theme else None + "theme_data": theme if theme else None, } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/test-theme-loading", - headers={"host": f"{vendor_with_theme.subdomain}.platform.com"} + headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}, ) assert response.status_code == 200 @@ -46,24 +48,27 @@ class TestThemeLoadingFlow: assert data["theme_data"]["primary_color"] == "#FF5733" assert data["theme_data"]["secondary_color"] == "#33FF57" - def test_default_theme_loaded_for_vendor_without_theme(self, client, vendor_with_subdomain): + def test_default_theme_loaded_for_vendor_without_theme( + self, client, vendor_with_subdomain + ): """Test that default theme is loaded for vendor without custom theme.""" from fastapi import Request + from main import app @app.get("/test-default-theme") async def test_default_theme(request: Request): - theme = request.state.theme if hasattr(request.state, 'theme') else None + theme = request.state.theme if hasattr(request.state, "theme") else None return { "has_theme": theme is not None, - "theme_data": theme if theme else None + "theme_data": theme if theme else None, } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/test-default-theme", - headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"} + headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}, ) assert response.status_code == 200 @@ -76,21 +81,20 @@ class TestThemeLoadingFlow: def test_no_theme_loaded_without_vendor(self, client): """Test that no theme is loaded when there's no vendor.""" from fastapi import Request + from main import app @app.get("/test-no-theme") async def test_no_theme(request: Request): return { - "has_theme": hasattr(request.state, 'theme'), - "has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None + "has_theme": hasattr(request.state, "theme"), + "has_vendor": hasattr(request.state, "vendor") + and request.state.vendor is not None, } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" - response = client.get( - "/test-no-theme", - headers={"host": "platform.com"} - ) + response = client.get("/test-no-theme", headers={"host": "platform.com"}) assert response.status_code == 200 data = response.json() @@ -105,24 +109,25 @@ class TestThemeLoadingFlow: def test_custom_theme_contains_all_fields(self, client, vendor_with_theme): """Test that custom theme contains all expected fields.""" from fastapi import Request + from main import app @app.get("/test-theme-fields") async def test_theme_fields(request: Request): - theme = request.state.theme if hasattr(request.state, 'theme') else {} + theme = request.state.theme if hasattr(request.state, "theme") else {} return { "primary_color": theme.get("primary_color"), "secondary_color": theme.get("secondary_color"), "logo_url": theme.get("logo_url"), "favicon_url": theme.get("favicon_url"), - "custom_css": theme.get("custom_css") + "custom_css": theme.get("custom_css"), } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/test-theme-fields", - headers={"host": f"{vendor_with_theme.subdomain}.platform.com"} + headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}, ) assert response.status_code == 200 @@ -136,24 +141,25 @@ class TestThemeLoadingFlow: def test_default_theme_structure(self, client, vendor_with_subdomain): """Test that default theme has expected structure.""" from fastapi import Request + from main import app @app.get("/test-default-theme-structure") async def test_default_structure(request: Request): - theme = request.state.theme if hasattr(request.state, 'theme') else {} + theme = request.state.theme if hasattr(request.state, "theme") else {} return { "has_primary_color": "primary_color" in theme, "has_secondary_color": "secondary_color" in theme, "has_logo_url": "logo_url" in theme, "has_favicon_url": "favicon_url" in theme, - "has_custom_css": "custom_css" in theme + "has_custom_css": "custom_css" in theme, } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/test-default-theme-structure", - headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"} + headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}, ) assert response.status_code == 200 @@ -169,21 +175,30 @@ class TestThemeLoadingFlow: def test_theme_loaded_in_shop_context(self, client, vendor_with_theme): """Test that theme is loaded in SHOP context.""" from fastapi import Request + from main import app @app.get("/shop/test-shop-theme") async def test_shop_theme(request: Request): return { - "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None, - "has_theme": hasattr(request.state, 'theme'), - "theme_primary": request.state.theme.get("primary_color") if hasattr(request.state, 'theme') else None + "context_type": ( + request.state.context_type.value + if hasattr(request.state, "context_type") + else None + ), + "has_theme": hasattr(request.state, "theme"), + "theme_primary": ( + request.state.theme.get("primary_color") + if hasattr(request.state, "theme") + else None + ), } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/shop/test-shop-theme", - headers={"host": f"{vendor_with_theme.subdomain}.platform.com"} + headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}, ) assert response.status_code == 200 @@ -195,21 +210,30 @@ class TestThemeLoadingFlow: def test_theme_loaded_in_vendor_dashboard_context(self, client, vendor_with_theme): """Test that theme is loaded in VENDOR_DASHBOARD context.""" from fastapi import Request + from main import app @app.get("/vendor/test-dashboard-theme") async def test_dashboard_theme(request: Request): return { - "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None, - "has_theme": hasattr(request.state, 'theme'), - "theme_secondary": request.state.theme.get("secondary_color") if hasattr(request.state, 'theme') else None + "context_type": ( + request.state.context_type.value + if hasattr(request.state, "context_type") + else None + ), + "has_theme": hasattr(request.state, "theme"), + "theme_secondary": ( + request.state.theme.get("secondary_color") + if hasattr(request.state, "theme") + else None + ), } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/vendor/test-dashboard-theme", - headers={"host": f"{vendor_with_theme.subdomain}.platform.com"} + headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}, ) assert response.status_code == 200 @@ -221,21 +245,27 @@ class TestThemeLoadingFlow: def test_theme_loaded_in_api_context_with_vendor(self, client, vendor_with_theme): """Test that theme is loaded in API context when vendor is present.""" from fastapi import Request + from main import app @app.get("/api/test-api-theme") async def test_api_theme(request: Request): return { - "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None, - "has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None, - "has_theme": hasattr(request.state, 'theme') + "context_type": ( + request.state.context_type.value + if hasattr(request.state, "context_type") + else None + ), + "has_vendor": hasattr(request.state, "vendor") + and request.state.vendor is not None, + "has_theme": hasattr(request.state, "theme"), } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/api/test-api-theme", - headers={"host": f"{vendor_with_theme.subdomain}.platform.com"} + headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}, ) assert response.status_code == 200 @@ -248,13 +278,18 @@ class TestThemeLoadingFlow: def test_no_theme_in_admin_context(self, client): """Test that theme is not loaded in ADMIN context (no vendor).""" from fastapi import Request + from main import app @app.get("/admin/test-admin-no-theme") async def test_admin_no_theme(request: Request): return { - "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None, - "has_theme": hasattr(request.state, 'theme') + "context_type": ( + request.state.context_type.value + if hasattr(request.state, "context_type") + else None + ), + "has_theme": hasattr(request.state, "theme"), } response = client.get("/admin/test-admin-no-theme") @@ -272,20 +307,29 @@ class TestThemeLoadingFlow: def test_theme_loaded_with_subdomain_routing(self, client, vendor_with_theme): """Test theme loading with subdomain routing.""" from fastapi import Request + from main import app @app.get("/test-subdomain-theme") async def test_subdomain_theme(request: Request): return { - "vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None, - "theme_logo": request.state.theme.get("logo_url") if hasattr(request.state, 'theme') else None + "vendor_code": ( + request.state.vendor.code + if hasattr(request.state, "vendor") and request.state.vendor + else None + ), + "theme_logo": ( + request.state.theme.get("logo_url") + if hasattr(request.state, "theme") + else None + ), } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/test-subdomain-theme", - headers={"host": f"{vendor_with_theme.subdomain}.platform.com"} + headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}, ) assert response.status_code == 200 @@ -293,33 +337,40 @@ class TestThemeLoadingFlow: assert data["vendor_code"] == vendor_with_theme.code assert data["theme_logo"] == "/static/vendors/themedvendor/logo.png" - def test_theme_loaded_with_custom_domain_routing(self, client, vendor_with_custom_domain, db): + def test_theme_loaded_with_custom_domain_routing( + self, client, vendor_with_custom_domain, db + ): """Test theme loading with custom domain routing.""" # Add theme to custom domain vendor from models.database.vendor_theme import VendorTheme + theme = VendorTheme( vendor_id=vendor_with_custom_domain.id, primary_color="#123456", - secondary_color="#654321" + secondary_color="#654321", ) db.add(theme) db.commit() from fastapi import Request + from main import app @app.get("/test-custom-domain-theme") async def test_custom_domain_theme(request: Request): return { - "has_theme": hasattr(request.state, 'theme'), - "theme_primary": request.state.theme.get("primary_color") if hasattr(request.state, 'theme') else None + "has_theme": hasattr(request.state, "theme"), + "theme_primary": ( + request.state.theme.get("primary_color") + if hasattr(request.state, "theme") + else None + ), } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( - "/test-custom-domain-theme", - headers={"host": "customdomain.com"} + "/test-custom-domain-theme", headers={"host": "customdomain.com"} ) assert response.status_code == 200 @@ -331,28 +382,37 @@ class TestThemeLoadingFlow: # Theme Dependency on Vendor Context Tests # ======================================================================== - def test_theme_middleware_depends_on_vendor_middleware(self, client, vendor_with_theme): + def test_theme_middleware_depends_on_vendor_middleware( + self, client, vendor_with_theme + ): """Test that theme loading depends on vendor being detected first.""" from fastapi import Request + from main import app @app.get("/test-theme-vendor-dependency") async def test_dependency(request: Request): return { - "has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None, - "vendor_id": request.state.vendor_id if hasattr(request.state, 'vendor_id') else None, - "has_theme": hasattr(request.state, 'theme'), + "has_vendor": hasattr(request.state, "vendor") + and request.state.vendor is not None, + "vendor_id": ( + request.state.vendor_id + if hasattr(request.state, "vendor_id") + else None + ), + "has_theme": hasattr(request.state, "theme"), "vendor_matches_theme": ( request.state.vendor_id == vendor_with_theme.id - if hasattr(request.state, 'vendor_id') else False - ) + if hasattr(request.state, "vendor_id") + else False + ), } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/test-theme-vendor-dependency", - headers={"host": f"{vendor_with_theme.subdomain}.platform.com"} + headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}, ) assert response.status_code == 200 @@ -369,15 +429,20 @@ class TestThemeLoadingFlow: def test_theme_loaded_consistently_across_requests(self, client, vendor_with_theme): """Test that theme is loaded consistently across multiple requests.""" from fastapi import Request + from main import app @app.get("/test-theme-consistency") async def test_consistency(request: Request): return { - "theme_primary": request.state.theme.get("primary_color") if hasattr(request.state, 'theme') else None + "theme_primary": ( + request.state.theme.get("primary_color") + if hasattr(request.state, "theme") + else None + ) } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" # Make multiple requests @@ -385,7 +450,7 @@ class TestThemeLoadingFlow: for _ in range(3): response = client.get( "/test-theme-consistency", - headers={"host": f"{vendor_with_theme.subdomain}.platform.com"} + headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}, ) responses.append(response.json()) @@ -396,25 +461,28 @@ class TestThemeLoadingFlow: # Edge Cases and Error Handling Tests # ======================================================================== - def test_theme_gracefully_handles_missing_theme_fields(self, client, vendor_with_subdomain): + def test_theme_gracefully_handles_missing_theme_fields( + self, client, vendor_with_subdomain + ): """Test that missing theme fields are handled gracefully.""" from fastapi import Request + from main import app @app.get("/test-partial-theme") async def test_partial_theme(request: Request): - theme = request.state.theme if hasattr(request.state, 'theme') else {} + theme = request.state.theme if hasattr(request.state, "theme") else {} return { "has_theme": bool(theme), "primary_color": theme.get("primary_color", "default"), - "logo_url": theme.get("logo_url", "default") + "logo_url": theme.get("logo_url", "default"), } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/test-partial-theme", - headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"} + headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}, ) assert response.status_code == 200 @@ -427,23 +495,21 @@ class TestThemeLoadingFlow: def test_theme_dict_is_mutable(self, client, vendor_with_theme): """Test that theme dict can be accessed and read from.""" from fastapi import Request + from main import app @app.get("/test-theme-mutable") async def test_mutable(request: Request): - theme = request.state.theme if hasattr(request.state, 'theme') else {} + theme = request.state.theme if hasattr(request.state, "theme") else {} # Try to access theme values primary = theme.get("primary_color") - return { - "can_read": primary is not None, - "value": primary - } + return {"can_read": primary is not None, "value": primary} - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/test-theme-mutable", - headers={"host": f"{vendor_with_theme.subdomain}.platform.com"} + headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}, ) assert response.status_code == 200 diff --git a/tests/integration/middleware/test_vendor_context_flow.py b/tests/integration/middleware/test_vendor_context_flow.py index 25a5bf94..657ab068 100644 --- a/tests/integration/middleware/test_vendor_context_flow.py +++ b/tests/integration/middleware/test_vendor_context_flow.py @@ -5,9 +5,10 @@ Integration tests for vendor context detection end-to-end flow. These tests verify that vendor detection works correctly through real HTTP requests for all routing modes: subdomain, custom domain, and path-based. """ -import pytest from unittest.mock import patch +import pytest + @pytest.mark.integration @pytest.mark.middleware @@ -22,23 +23,37 @@ class TestVendorContextFlow: def test_subdomain_vendor_detection(self, client, vendor_with_subdomain): """Test vendor detection via subdomain routing.""" from fastapi import Request + from main import app @app.get("/test-subdomain-detection") async def test_subdomain(request: Request): return { - "vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None, - "vendor_id": request.state.vendor_id if hasattr(request.state, 'vendor_id') else None, - "vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None, - "vendor_name": request.state.vendor.name if hasattr(request.state, 'vendor') and request.state.vendor else None, - "detection_method": "subdomain" + "vendor_detected": hasattr(request.state, "vendor") + and request.state.vendor is not None, + "vendor_id": ( + request.state.vendor_id + if hasattr(request.state, "vendor_id") + else None + ), + "vendor_code": ( + request.state.vendor.code + if hasattr(request.state, "vendor") and request.state.vendor + else None + ), + "vendor_name": ( + request.state.vendor.name + if hasattr(request.state, "vendor") and request.state.vendor + else None + ), + "detection_method": "subdomain", } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/test-subdomain-detection", - headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"} + headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}, ) assert response.status_code == 200 @@ -51,20 +66,28 @@ class TestVendorContextFlow: def test_subdomain_with_port_detection(self, client, vendor_with_subdomain): """Test vendor detection via subdomain with port number.""" from fastapi import Request + from main import app @app.get("/test-subdomain-port") async def test_subdomain_port(request: Request): return { - "vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None, - "vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None + "vendor_detected": hasattr(request.state, "vendor") + and request.state.vendor is not None, + "vendor_code": ( + request.state.vendor.code + if hasattr(request.state, "vendor") and request.state.vendor + else None + ), } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/test-subdomain-port", - headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com:8000"} + headers={ + "host": f"{vendor_with_subdomain.subdomain}.platform.com:8000" + }, ) assert response.status_code == 200 @@ -75,20 +98,24 @@ class TestVendorContextFlow: def test_nonexistent_subdomain_returns_no_vendor(self, client): """Test that nonexistent subdomain doesn't crash and returns no vendor.""" from fastapi import Request + from main import app @app.get("/test-nonexistent-subdomain") async def test_nonexistent(request: Request): return { - "vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None, - "vendor": request.state.vendor if hasattr(request.state, 'vendor') else None + "vendor_detected": hasattr(request.state, "vendor") + and request.state.vendor is not None, + "vendor": ( + request.state.vendor if hasattr(request.state, "vendor") else None + ), } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/test-nonexistent-subdomain", - headers={"host": "nonexistent.platform.com"} + headers={"host": "nonexistent.platform.com"}, ) assert response.status_code == 200 @@ -102,22 +129,31 @@ class TestVendorContextFlow: def test_custom_domain_vendor_detection(self, client, vendor_with_custom_domain): """Test vendor detection via custom domain.""" from fastapi import Request + from main import app @app.get("/test-custom-domain") async def test_custom_domain(request: Request): return { - "vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None, - "vendor_id": request.state.vendor_id if hasattr(request.state, 'vendor_id') else None, - "vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None, - "detection_method": "custom_domain" + "vendor_detected": hasattr(request.state, "vendor") + and request.state.vendor is not None, + "vendor_id": ( + request.state.vendor_id + if hasattr(request.state, "vendor_id") + else None + ), + "vendor_code": ( + request.state.vendor.code + if hasattr(request.state, "vendor") and request.state.vendor + else None + ), + "detection_method": "custom_domain", } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( - "/test-custom-domain", - headers={"host": "customdomain.com"} + "/test-custom-domain", headers={"host": "customdomain.com"} ) assert response.status_code == 200 @@ -129,21 +165,26 @@ class TestVendorContextFlow: def test_custom_domain_with_www_detection(self, client, vendor_with_custom_domain): """Test vendor detection via custom domain with www prefix.""" from fastapi import Request + from main import app @app.get("/test-custom-domain-www") async def test_custom_domain_www(request: Request): return { - "vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None, - "vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None + "vendor_detected": hasattr(request.state, "vendor") + and request.state.vendor is not None, + "vendor_code": ( + request.state.vendor.code + if hasattr(request.state, "vendor") and request.state.vendor + else None + ), } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" # Test with www prefix - should still detect vendor response = client.get( - "/test-custom-domain-www", - headers={"host": "www.customdomain.com"} + "/test-custom-domain-www", headers={"host": "www.customdomain.com"} ) # This might fail if your implementation doesn't strip www @@ -154,25 +195,37 @@ class TestVendorContextFlow: # Path-Based Detection Tests (Development Mode) # ======================================================================== - def test_path_based_vendor_detection_vendors_prefix(self, client, vendor_with_subdomain): + def test_path_based_vendor_detection_vendors_prefix( + self, client, vendor_with_subdomain + ): """Test vendor detection via path-based routing with /vendors/ prefix.""" from fastapi import Request + from main import app @app.get("/vendors/{vendor_code}/test-path") async def test_path_based(vendor_code: str, request: Request): return { - "vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None, + "vendor_detected": hasattr(request.state, "vendor") + and request.state.vendor is not None, "vendor_code_param": vendor_code, - "vendor_code_state": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None, - "clean_path": request.state.clean_path if hasattr(request.state, 'clean_path') else None + "vendor_code_state": ( + request.state.vendor.code + if hasattr(request.state, "vendor") and request.state.vendor + else None + ), + "clean_path": ( + request.state.clean_path + if hasattr(request.state, "clean_path") + else None + ), } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( f"/vendors/{vendor_with_subdomain.code}/test-path", - headers={"host": "localhost:8000"} + headers={"host": "localhost:8000"}, ) assert response.status_code == 200 @@ -181,23 +234,31 @@ class TestVendorContextFlow: assert data["vendor_code_param"] == vendor_with_subdomain.code assert data["vendor_code_state"] == vendor_with_subdomain.code - def test_path_based_vendor_detection_vendor_prefix(self, client, vendor_with_subdomain): + def test_path_based_vendor_detection_vendor_prefix( + self, client, vendor_with_subdomain + ): """Test vendor detection via path-based routing with /vendor/ prefix.""" from fastapi import Request + from main import app @app.get("/vendor/{vendor_code}/test") async def test_vendor_path(vendor_code: str, request: Request): return { - "vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None, - "vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None + "vendor_detected": hasattr(request.state, "vendor") + and request.state.vendor is not None, + "vendor_code": ( + request.state.vendor.code + if hasattr(request.state, "vendor") and request.state.vendor + else None + ), } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( f"/vendor/{vendor_with_subdomain.code}/test", - headers={"host": "localhost:8000"} + headers={"host": "localhost:8000"}, ) assert response.status_code == 200 @@ -209,48 +270,63 @@ class TestVendorContextFlow: # Clean Path Extraction Tests # ======================================================================== - def test_clean_path_extracted_from_vendor_prefix(self, client, vendor_with_subdomain): + def test_clean_path_extracted_from_vendor_prefix( + self, client, vendor_with_subdomain + ): """Test that clean_path is correctly extracted from path-based routing.""" from fastapi import Request + from main import app @app.get("/vendors/{vendor_code}/shop/products") async def test_clean_path(vendor_code: str, request: Request): return { - "clean_path": request.state.clean_path if hasattr(request.state, 'clean_path') else None, - "original_path": request.url.path + "clean_path": ( + request.state.clean_path + if hasattr(request.state, "clean_path") + else None + ), + "original_path": request.url.path, } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( f"/vendors/{vendor_with_subdomain.code}/shop/products", - headers={"host": "localhost:8000"} + headers={"host": "localhost:8000"}, ) assert response.status_code == 200 data = response.json() # Clean path should have vendor prefix removed assert data["clean_path"] == "/shop/products" - assert f"/vendors/{vendor_with_subdomain.code}/shop/products" in data["original_path"] + assert ( + f"/vendors/{vendor_with_subdomain.code}/shop/products" + in data["original_path"] + ) def test_clean_path_unchanged_for_subdomain(self, client, vendor_with_subdomain): """Test that clean_path equals original path for subdomain routing.""" from fastapi import Request + from main import app @app.get("/shop/test-clean-path") async def test_subdomain_clean_path(request: Request): return { - "clean_path": request.state.clean_path if hasattr(request.state, 'clean_path') else None, - "original_path": request.url.path + "clean_path": ( + request.state.clean_path + if hasattr(request.state, "clean_path") + else None + ), + "original_path": request.url.path, } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/shop/test-clean-path", - headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"} + headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}, ) assert response.status_code == 200 @@ -266,21 +342,30 @@ class TestVendorContextFlow: def test_vendor_id_injected_into_request_state(self, client, vendor_with_subdomain): """Test that vendor_id is correctly injected into request.state.""" from fastapi import Request + from main import app @app.get("/test-vendor-id-injection") async def test_vendor_id(request: Request): return { - "has_vendor_id": hasattr(request.state, 'vendor_id'), - "vendor_id": request.state.vendor_id if hasattr(request.state, 'vendor_id') else None, - "vendor_id_type": type(request.state.vendor_id).__name__ if hasattr(request.state, 'vendor_id') else None + "has_vendor_id": hasattr(request.state, "vendor_id"), + "vendor_id": ( + request.state.vendor_id + if hasattr(request.state, "vendor_id") + else None + ), + "vendor_id_type": ( + type(request.state.vendor_id).__name__ + if hasattr(request.state, "vendor_id") + else None + ), } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/test-vendor-id-injection", - headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"} + headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}, ) assert response.status_code == 200 @@ -289,30 +374,41 @@ class TestVendorContextFlow: assert data["vendor_id"] == vendor_with_subdomain.id assert data["vendor_id_type"] == "int" - def test_vendor_object_injected_into_request_state(self, client, vendor_with_subdomain): + def test_vendor_object_injected_into_request_state( + self, client, vendor_with_subdomain + ): """Test that full vendor object is injected into request.state.""" from fastapi import Request + from main import app @app.get("/test-vendor-object-injection") async def test_vendor_object(request: Request): - vendor = request.state.vendor if hasattr(request.state, 'vendor') and request.state.vendor else None + vendor = ( + request.state.vendor + if hasattr(request.state, "vendor") and request.state.vendor + else None + ) return { "has_vendor": vendor is not None, - "vendor_attributes": { - "id": vendor.id if vendor else None, - "name": vendor.name if vendor else None, - "code": vendor.code if vendor else None, - "subdomain": vendor.subdomain if vendor else None, - "is_active": vendor.is_active if vendor else None - } if vendor else None + "vendor_attributes": ( + { + "id": vendor.id if vendor else None, + "name": vendor.name if vendor else None, + "code": vendor.code if vendor else None, + "subdomain": vendor.subdomain if vendor else None, + "is_active": vendor.is_active if vendor else None, + } + if vendor + else None + ), } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/test-vendor-object-injection", - headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"} + headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}, ) assert response.status_code == 200 @@ -330,19 +426,21 @@ class TestVendorContextFlow: def test_inactive_vendor_not_detected(self, client, inactive_vendor): """Test that inactive vendors are not detected.""" from fastapi import Request + from main import app @app.get("/test-inactive-vendor-detection") async def test_inactive(request: Request): return { - "vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None + "vendor_detected": hasattr(request.state, "vendor") + and request.state.vendor is not None } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( "/test-inactive-vendor-detection", - headers={"host": f"{inactive_vendor.subdomain}.platform.com"} + headers={"host": f"{inactive_vendor.subdomain}.platform.com"}, ) assert response.status_code == 200 @@ -352,19 +450,20 @@ class TestVendorContextFlow: def test_platform_domain_without_subdomain_no_vendor(self, client): """Test that platform domain without subdomain doesn't detect vendor.""" from fastapi import Request + from main import app @app.get("/test-platform-domain") async def test_platform(request: Request): return { - "vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None + "vendor_detected": hasattr(request.state, "vendor") + and request.state.vendor is not None } - with patch('app.core.config.settings') as mock_settings: + with patch("app.core.config.settings") as mock_settings: mock_settings.platform_domain = "platform.com" response = client.get( - "/test-platform-domain", - headers={"host": "platform.com"} + "/test-platform-domain", headers={"host": "platform.com"} ) assert response.status_code == 200 diff --git a/tests/integration/security/test_input_validation.py b/tests/integration/security/test_input_validation.py index 1b44136f..17cb6433 100644 --- a/tests/integration/security/test_input_validation.py +++ b/tests/integration/security/test_input_validation.py @@ -11,7 +11,8 @@ class TestInputValidation: malicious_search = "'; DROP TABLE products; --" response = client.get( - f"/api/v1/marketplace/product?search={malicious_search}", headers=auth_headers + f"/api/v1/marketplace/product?search={malicious_search}", + headers=auth_headers, ) # Should not crash and should return normal response @@ -40,17 +41,23 @@ class TestInputValidation: def test_parameter_validation(self, client, auth_headers): """Test parameter validation for API endpoints""" # Test invalid pagination parameters - response = client.get("/api/v1/marketplace/product?limit=-1", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product?limit=-1", headers=auth_headers + ) assert response.status_code == 422 # Validation error - response = client.get("/api/v1/marketplace/product?skip=-1", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product?skip=-1", headers=auth_headers + ) assert response.status_code == 422 # Validation error def test_json_validation(self, client, auth_headers): """Test JSON validation for POST requests""" # Test invalid JSON structure response = client.post( - "/api/v1/marketplace/product", headers=auth_headers, content="invalid json content" + "/api/v1/marketplace/product", + headers=auth_headers, + content="invalid json content", ) assert response.status_code == 422 # JSON decode error @@ -58,6 +65,8 @@ class TestInputValidation: response = client.post( "/api/v1/marketplace/product", headers=auth_headers, - json={"title": "Test MarketplaceProduct"}, # Missing required marketplace_product_id + json={ + "title": "Test MarketplaceProduct" + }, # Missing required marketplace_product_id ) assert response.status_code == 422 # Validation error diff --git a/tests/integration/workflows/test_integration.py b/tests/integration/workflows/test_integration.py index a10cfe01..ba50d374 100644 --- a/tests/integration/workflows/test_integration.py +++ b/tests/integration/workflows/test_integration.py @@ -33,12 +33,15 @@ class TestIntegrationFlows: "quantity": 50, } - response = client.post("/api/v1/inventory", headers=auth_headers, json=inventory_data) + response = client.post( + "/api/v1/inventory", headers=auth_headers, json=inventory_data + ) assert response.status_code == 200 # 3. Get product with inventory info response = client.get( - f"/api/v1/marketplace/product/{product['marketplace_product_id']}", headers=auth_headers + f"/api/v1/marketplace/product/{product['marketplace_product_id']}", + headers=auth_headers, ) assert response.status_code == 200 product_detail = response.json() @@ -55,14 +58,15 @@ class TestIntegrationFlows: # 5. Search for product response = client.get( - "/api/v1/marketplace/product?search=Updated Integration", headers=auth_headers + "/api/v1/marketplace/product?search=Updated Integration", + headers=auth_headers, ) assert response.status_code == 200 assert response.json()["total"] == 1 def test_product_workflow(self, client, auth_headers): """Test vendor creation and product management workflow""" - # 1. Create a vendor + # 1. Create a vendor vendor_data = { "vendor_code": "FLOWVENDOR", "name": "Integration Flow Vendor", @@ -91,7 +95,9 @@ class TestIntegrationFlows: # This would test the vendor -product association # 4. Get vendor details - response = client.get(f"/api/v1/vendor/{vendor ['vendor_code']}", headers=auth_headers) + response = client.get( + f"/api/v1/vendor/{vendor ['vendor_code']}", headers=auth_headers + ) assert response.status_code == 200 def test_inventory_operations_workflow(self, client, auth_headers): diff --git a/tests/performance/test_api_performance.py b/tests/performance/test_api_performance.py index e60f7888..5d442052 100644 --- a/tests/performance/test_api_performance.py +++ b/tests/performance/test_api_performance.py @@ -28,7 +28,9 @@ class TestPerformance: # Time the request start_time = time.time() - response = client.get("/api/v1/marketplace/product?limit=100", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product?limit=100", headers=auth_headers + ) end_time = time.time() assert response.status_code == 200 @@ -54,7 +56,9 @@ class TestPerformance: # Time search request start_time = time.time() - response = client.get("/api/v1/marketplace/product?search=Searchable", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product?search=Searchable", headers=auth_headers + ) end_time = time.time() assert response.status_code == 200 @@ -115,7 +119,8 @@ class TestPerformance: for offset in offsets: start_time = time.time() response = client.get( - f"/api/v1/marketplace/product?skip={offset}&limit=20", headers=auth_headers + f"/api/v1/marketplace/product?skip={offset}&limit=20", + headers=auth_headers, ) end_time = time.time() diff --git a/tests/system/test_error_handling.py b/tests/system/test_error_handling.py index 9247035f..1cde868b 100644 --- a/tests/system/test_error_handling.py +++ b/tests/system/test_error_handling.py @@ -5,9 +5,10 @@ System tests for error handling across the LetzVendor API. Tests the complete error handling flow from FastAPI through custom exception handlers to ensure proper HTTP status codes, error structures, and client-friendly responses. """ -import pytest import json +import pytest + @pytest.mark.system class TestErrorHandling: @@ -16,9 +17,7 @@ class TestErrorHandling: def test_invalid_json_request(self, client, auth_headers): """Test handling of malformed JSON requests""" response = client.post( - "/api/v1/vendor", - headers=auth_headers, - content="{ invalid json syntax" + "/api/v1/vendor", headers=auth_headers, content="{ invalid json syntax" ) assert response.status_code == 422 @@ -31,9 +30,7 @@ class TestErrorHandling: """Test validation errors for missing required fields""" # Missing name response = client.post( - "/api/v1/vendor", - headers=auth_headers, - json={"vendor_code": "TESTVENDOR"} + "/api/v1/vendor", headers=auth_headers, json={"vendor_code": "TESTVENDOR"} ) assert response.status_code == 422 @@ -48,10 +45,7 @@ class TestErrorHandling: response = client.post( "/api/v1/vendor", headers=auth_headers, - json={ - "vendor_code": "INVALID@VENDOR!", - "name": "Test Vendor" - } + json={"vendor_code": "INVALID@VENDOR!", "name": "Test Vendor"}, ) assert response.status_code == 422 @@ -92,7 +86,7 @@ class TestErrorHandling: assert data["status_code"] == 401 def test_vendor_not_found(self, client, auth_headers): - """Test accessing non-existent vendor """ + """Test accessing non-existent vendor""" response = client.get("/api/v1/vendor/NONEXISTENT", headers=auth_headers) assert response.status_code == 404 @@ -104,7 +98,9 @@ class TestErrorHandling: def test_product_not_found(self, client, auth_headers): """Test accessing non-existent product""" - response = client.get("/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers) + response = client.get( + "/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers + ) assert response.status_code == 404 data = response.json() @@ -117,7 +113,7 @@ class TestErrorHandling: """Test creating vendor with duplicate vendor code""" vendor_data = { "vendor_code": test_vendor.vendor_code, - "name": "Duplicate Vendor" + "name": "Duplicate Vendor", } response = client.post("/api/v1/vendor", headers=auth_headers, json=vendor_data) @@ -128,25 +124,34 @@ class TestErrorHandling: assert data["status_code"] == 409 assert data["details"]["vendor_code"] == test_vendor.vendor_code.upper() - def test_duplicate_product_creation(self, client, auth_headers, test_marketplace_product): + def test_duplicate_product_creation( + self, client, auth_headers, test_marketplace_product + ): """Test creating product with duplicate product ID""" product_data = { "marketplace_product_id": test_marketplace_product.marketplace_product_id, "title": "Duplicate MarketplaceProduct", - "gtin": "1234567890123" + "gtin": "1234567890123", } - response = client.post("/api/v1/marketplace/product", headers=auth_headers, json=product_data) + response = client.post( + "/api/v1/marketplace/product", headers=auth_headers, json=product_data + ) assert response.status_code == 409 data = response.json() assert data["error_code"] == "PRODUCT_ALREADY_EXISTS" assert data["status_code"] == 409 - assert data["details"]["marketplace_product_id"] == test_marketplace_product.marketplace_product_id + assert ( + data["details"]["marketplace_product_id"] + == test_marketplace_product.marketplace_product_id + ) def test_unauthorized_vendor_access(self, client, auth_headers, inactive_vendor): """Test accessing vendor without proper permissions""" - response = client.get(f"/api/v1/vendor/{inactive_vendor.vendor_code}", headers=auth_headers) + response = client.get( + f"/api/v1/vendor/{inactive_vendor.vendor_code}", headers=auth_headers + ) assert response.status_code == 403 data = response.json() @@ -154,27 +159,33 @@ class TestErrorHandling: assert data["status_code"] == 403 assert data["details"]["vendor_code"] == inactive_vendor.vendor_code - def test_insufficient_permissions(self, client, auth_headers, admin_only_endpoint="/api/v1/admin/users"): + def test_insufficient_permissions( + self, client, auth_headers, admin_only_endpoint="/api/v1/admin/users" + ): """Test accessing admin endpoints with regular user""" response = client.get(admin_only_endpoint, headers=auth_headers) - assert response.status_code in [403, 404] # 403 for permission denied, 404 if endpoint doesn't exist + assert response.status_code in [ + 403, + 404, + ] # 403 for permission denied, 404 if endpoint doesn't exist if response.status_code == 403: data = response.json() assert data["error_code"] in ["ADMIN_REQUIRED", "INSUFFICIENT_PERMISSIONS"] assert data["status_code"] == 403 - def test_business_logic_violation_max_vendors(self, client, auth_headers, monkeypatch): + def test_business_logic_violation_max_vendors( + self, client, auth_headers, monkeypatch + ): """Test business logic violation - creating too many vendors""" # This test would require mocking the vendor limit check # For now, test the error structure when creating multiple vendors vendors_created = [] for i in range(6): # Assume limit is 5 - vendor_data = { - "vendor_code": f"VENDOR{i:03d}", - "name": f"Test Vendor {i}" - } - response = client.post("/api/v1/vendor", headers=auth_headers, json=vendor_data) + vendor_data = {"vendor_code": f"VENDOR{i:03d}", "name": f"Test Vendor {i}"} + response = client.post( + "/api/v1/vendor", headers=auth_headers, json=vendor_data + ) vendors_created.append(response) # At least one should succeed, and if limit is enforced, later ones should fail @@ -193,10 +204,12 @@ class TestErrorHandling: product_data = { "marketplace_product_id": "TESTPROD001", "title": "Test MarketplaceProduct", - "gtin": "invalid_gtin_format" + "gtin": "invalid_gtin_format", } - response = client.post("/api/v1/marketplace/product", headers=auth_headers, json=product_data) + response = client.post( + "/api/v1/marketplace/product", headers=auth_headers, json=product_data + ) assert response.status_code == 422 data = response.json() @@ -204,13 +217,15 @@ class TestErrorHandling: assert data["status_code"] == 422 assert data["details"]["field"] == "gtin" - def test_inventory_insufficient_quantity(self, client, auth_headers, test_vendor, test_marketplace_product): + def test_inventory_insufficient_quantity( + self, client, auth_headers, test_vendor, test_marketplace_product + ): """Test business logic error for insufficient inventory""" # First create some inventory inventory_data = { "gtin": test_marketplace_product.gtin, "location": "WAREHOUSE_A", - "quantity": 5 + "quantity": 5, } client.post("/api/v1/inventory", headers=auth_headers, json=inventory_data) @@ -218,9 +233,11 @@ class TestErrorHandling: remove_data = { "gtin": test_marketplace_product.gtin, "location": "WAREHOUSE_A", - "quantity": 10 # More than the 5 we added + "quantity": 10, # More than the 5 we added } - response = client.post("/api/v1/inventory/remove", headers=auth_headers, json=remove_data) + response = client.post( + "/api/v1/inventory/remove", headers=auth_headers, json=remove_data + ) # This should ALWAYS fail with insufficient inventory error assert response.status_code == 400 @@ -257,7 +274,7 @@ class TestErrorHandling: response = client.post( "/api/v1/vendor", headers=headers, - content="TEST" + content="TEST", ) assert response.status_code in [400, 415, 422] @@ -268,7 +285,7 @@ class TestErrorHandling: vendor_data = { "vendor_code": "LARGEVENDOR", "name": "Large Vendor", - "description": large_description + "description": large_description, } response = client.post("/api/v1/vendor", headers=auth_headers, json=vendor_data) @@ -310,10 +327,12 @@ class TestErrorHandling: # Test invalid marketplace import_data = { "marketplace": "INVALID_MARKETPLACE", - "vendor_code": test_vendor.vendor_code + "vendor_code": test_vendor.vendor_code, } - response = client.post("/api/v1/imports", headers=auth_headers, json=import_data) + response = client.post( + "/api/v1/imports", headers=auth_headers, json=import_data + ) if response.status_code == 422: data = response.json() @@ -329,10 +348,12 @@ class TestErrorHandling: # Test with potentially problematic external data import_data = { "marketplace": "LETZSHOP", - "external_url": "https://nonexistent-marketplace.com/api" + "external_url": "https://nonexistent-marketplace.com/api", } - response = client.post("/api/v1/imports", headers=auth_headers, json=import_data) + response = client.post( + "/api/v1/imports", headers=auth_headers, json=import_data + ) # If it's a real external service error, check structure if response.status_code == 502: @@ -356,7 +377,9 @@ class TestErrorHandling: # All error responses should have these fields required_fields = ["error_code", "message", "status_code"] for field in required_fields: - assert field in data, f"Missing {field} in error response for {endpoint}" + assert ( + field in data + ), f"Missing {field} in error response for {endpoint}" # Details field should be present (can be empty dict) assert "details" in data @@ -420,5 +443,7 @@ class TestErrorRecovery: # Check that error was logged (if your app logs 404s as errors) # Adjust based on your logging configuration - error_logs = [record for record in caplog.records if record.levelno >= logging.ERROR] + error_logs = [ + record for record in caplog.records if record.levelno >= logging.ERROR + ] # May or may not have logs depending on whether 404s are logged as errors diff --git a/tests/unit/middleware/test_auth.py b/tests/unit/middleware/test_auth.py index d998b4cb..d78e5b6d 100644 --- a/tests/unit/middleware/test_auth.py +++ b/tests/unit/middleware/test_auth.py @@ -12,21 +12,18 @@ Tests cover: - Error handling and edge cases """ -import pytest -from unittest.mock import Mock, MagicMock, patch from datetime import datetime, timedelta, timezone -from jose import jwt -from fastapi import HTTPException +from unittest.mock import MagicMock, Mock, patch +import pytest +from fastapi import HTTPException +from jose import jwt + +from app.exceptions import (AdminRequiredException, + InsufficientPermissionsException, + InvalidCredentialsException, InvalidTokenException, + TokenExpiredException, UserNotActiveException) from middleware.auth import AuthManager -from app.exceptions import ( - InvalidTokenException, - TokenExpiredException, - UserNotActiveException, - InvalidCredentialsException, - AdminRequiredException, - InsufficientPermissionsException, -) from models.database.user import User @@ -124,7 +121,9 @@ class TestUserAuthentication: mock_db.query.return_value.filter.return_value.first.return_value = mock_user - result = auth_manager.authenticate_user(mock_db, "test@example.com", "password123") + result = auth_manager.authenticate_user( + mock_db, "test@example.com", "password123" + ) assert result is mock_user @@ -192,7 +191,9 @@ class TestJWTTokenCreation: token = token_data["access_token"] # Decode without verification to check payload - payload = jwt.decode(token, auth_manager.secret_key, algorithms=[auth_manager.algorithm]) + payload = jwt.decode( + token, auth_manager.secret_key, algorithms=[auth_manager.algorithm] + ) assert payload["sub"] == "42" assert payload["username"] == "testuser" @@ -205,8 +206,12 @@ class TestJWTTokenCreation: """Test tokens are different for different users.""" auth_manager = AuthManager() - user1 = Mock(spec=User, id=1, username="user1", email="user1@test.com", role="customer") - user2 = Mock(spec=User, id=2, username="user2", email="user2@test.com", role="vendor") + user1 = Mock( + spec=User, id=1, username="user1", email="user1@test.com", role="customer" + ) + user2 = Mock( + spec=User, id=2, username="user2", email="user2@test.com", role="vendor" + ) token1 = auth_manager.create_access_token(user1)["access_token"] token2 = auth_manager.create_access_token(user2)["access_token"] @@ -227,7 +232,7 @@ class TestJWTTokenCreation: payload = jwt.decode( token_data["access_token"], auth_manager.secret_key, - algorithms=[auth_manager.algorithm] + algorithms=[auth_manager.algorithm], ) assert payload["role"] == "admin" @@ -311,9 +316,11 @@ class TestJWTTokenVerification: # Create token without 'sub' field payload = { "username": "testuser", - "exp": datetime.now(timezone.utc) + timedelta(minutes=30) + "exp": datetime.now(timezone.utc) + timedelta(minutes=30), } - token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm) + token = jwt.encode( + payload, auth_manager.secret_key, algorithm=auth_manager.algorithm + ) with pytest.raises(InvalidTokenException) as exc_info: auth_manager.verify_token(token) @@ -325,11 +332,10 @@ class TestJWTTokenVerification: auth_manager = AuthManager() # Create token without 'exp' field - payload = { - "sub": "1", - "username": "testuser" - } - token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm) + payload = {"sub": "1", "username": "testuser"} + token = jwt.encode( + payload, auth_manager.secret_key, algorithm=auth_manager.algorithm + ) with pytest.raises(InvalidTokenException) as exc_info: auth_manager.verify_token(token) @@ -343,7 +349,7 @@ class TestJWTTokenVerification: payload = { "sub": "1", "username": "testuser", - "exp": datetime.now(timezone.utc) + timedelta(minutes=30) + "exp": datetime.now(timezone.utc) + timedelta(minutes=30), } # Create token with different algorithm token = jwt.encode(payload, auth_manager.secret_key, algorithm="HS512") @@ -357,15 +363,13 @@ class TestJWTTokenVerification: # Create a token with expiration in the past past_time = datetime.now(timezone.utc) - timedelta(minutes=1) - payload = { - "sub": "1", - "username": "testuser", - "exp": past_time.timestamp() - } - token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm) + payload = {"sub": "1", "username": "testuser", "exp": past_time.timestamp()} + token = jwt.encode( + payload, auth_manager.secret_key, algorithm=auth_manager.algorithm + ) # Mock jwt.decode to bypass its expiration check and test line 205 - with patch('middleware.auth.jwt.decode') as mock_decode: + with patch("middleware.auth.jwt.decode") as mock_decode: mock_decode.return_value = payload with pytest.raises(TokenExpiredException): @@ -580,7 +584,9 @@ class TestCreateDefaultAdminUser: # Existing admin user existing_admin = Mock(spec=User) - mock_db.query.return_value.filter.return_value.first.return_value = existing_admin + mock_db.query.return_value.filter.return_value.first.return_value = ( + existing_admin + ) result = auth_manager.create_default_admin_user(mock_db) @@ -599,19 +605,21 @@ class TestAuthManagerConfiguration: def test_default_configuration(self): """Test AuthManager uses default configuration.""" - with patch.dict('os.environ', {}, clear=True): + with patch.dict("os.environ", {}, clear=True): auth_manager = AuthManager() assert auth_manager.algorithm == "HS256" assert auth_manager.token_expire_minutes == 30 - assert auth_manager.secret_key == "your-secret-key-change-in-production-please" + assert ( + auth_manager.secret_key == "your-secret-key-change-in-production-please" + ) def test_custom_configuration(self): """Test AuthManager uses environment variables.""" - with patch.dict('os.environ', { - 'JWT_SECRET_KEY': 'custom-secret-key', - 'JWT_EXPIRE_MINUTES': '60' - }): + with patch.dict( + "os.environ", + {"JWT_SECRET_KEY": "custom-secret-key", "JWT_EXPIRE_MINUTES": "60"}, + ): auth_manager = AuthManager() assert auth_manager.secret_key == "custom-secret-key" @@ -619,9 +627,7 @@ class TestAuthManagerConfiguration: def test_partial_custom_configuration(self): """Test AuthManager with partial environment configuration.""" - with patch.dict('os.environ', { - 'JWT_EXPIRE_MINUTES': '120' - }, clear=False): + with patch.dict("os.environ", {"JWT_EXPIRE_MINUTES": "120"}, clear=False): auth_manager = AuthManager() assert auth_manager.token_expire_minutes == 120 @@ -656,9 +662,11 @@ class TestEdgeCases: "sub": "1", "username": "testuser", "iat": datetime.now(timezone.utc) + timedelta(hours=1), # Future time - "exp": datetime.now(timezone.utc) + timedelta(hours=2) + "exp": datetime.now(timezone.utc) + timedelta(hours=2), } - token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm) + token = jwt.encode( + payload, auth_manager.secret_key, algorithm=auth_manager.algorithm + ) # Should still verify successfully (JWT doesn't validate iat by default) result = auth_manager.verify_token(token) @@ -698,7 +706,9 @@ class TestEdgeCases: token = token_data["access_token"] # Mock jose.jwt.decode to raise an unexpected exception - with patch('middleware.auth.jwt.decode', side_effect=RuntimeError("Unexpected error")): + with patch( + "middleware.auth.jwt.decode", side_effect=RuntimeError("Unexpected error") + ): with pytest.raises(InvalidTokenException) as exc_info: auth_manager.verify_token(token) diff --git a/tests/unit/middleware/test_context.py b/tests/unit/middleware/test_context.py index 95786c33..d2bf6f1b 100644 --- a/tests/unit/middleware/test_context.py +++ b/tests/unit/middleware/test_context.py @@ -10,16 +10,13 @@ Tests cover: - Edge cases and error handling """ +from unittest.mock import AsyncMock, Mock, patch + import pytest -from unittest.mock import Mock, AsyncMock, patch from fastapi import Request -from middleware.context import ( - ContextManager, - ContextMiddleware, - RequestContext, - get_request_context, -) +from middleware.context import (ContextManager, ContextMiddleware, + RequestContext, get_request_context) @pytest.mark.unit @@ -321,22 +318,38 @@ class TestContextManagerHelpers: def test_is_admin_context_from_subdomain(self): """Test _is_admin_context with admin subdomain.""" request = Mock() - assert ContextManager._is_admin_context(request, "admin.platform.com", "/dashboard") is True + assert ( + ContextManager._is_admin_context( + request, "admin.platform.com", "/dashboard" + ) + is True + ) def test_is_admin_context_from_path(self): """Test _is_admin_context with admin path.""" request = Mock() - assert ContextManager._is_admin_context(request, "localhost", "/admin/users") is True + assert ( + ContextManager._is_admin_context(request, "localhost", "/admin/users") + is True + ) def test_is_admin_context_both(self): """Test _is_admin_context with both subdomain and path.""" request = Mock() - assert ContextManager._is_admin_context(request, "admin.platform.com", "/admin/users") is True + assert ( + ContextManager._is_admin_context( + request, "admin.platform.com", "/admin/users" + ) + is True + ) def test_is_not_admin_context(self): """Test _is_admin_context returns False for non-admin.""" request = Mock() - assert ContextManager._is_admin_context(request, "vendor.platform.com", "/shop") is False + assert ( + ContextManager._is_admin_context(request, "vendor.platform.com", "/shop") + is False + ) def test_is_vendor_dashboard_context(self): """Test _is_vendor_dashboard_context with /vendor/ path.""" @@ -344,11 +357,16 @@ class TestContextManagerHelpers: def test_is_vendor_dashboard_context_nested(self): """Test _is_vendor_dashboard_context with nested vendor path.""" - assert ContextManager._is_vendor_dashboard_context("/vendor/products/list") is True + assert ( + ContextManager._is_vendor_dashboard_context("/vendor/products/list") is True + ) def test_is_not_vendor_dashboard_context_vendors_plural(self): """Test _is_vendor_dashboard_context excludes /vendors/ path.""" - assert ContextManager._is_vendor_dashboard_context("/vendors/shop123/products") is False + assert ( + ContextManager._is_vendor_dashboard_context("/vendors/shop123/products") + is False + ) def test_is_not_vendor_dashboard_context(self): """Test _is_vendor_dashboard_context returns False for non-vendor paths.""" @@ -373,7 +391,7 @@ class TestContextMiddleware: await middleware.dispatch(request, call_next) - assert hasattr(request.state, 'context_type') + assert hasattr(request.state, "context_type") assert request.state.context_type == RequestContext.API call_next.assert_called_once_with(request) @@ -565,7 +583,7 @@ class TestEdgeCases: request.url = Mock(path="/api/vendors") request.headers = {"host": "localhost"} # No state attribute at all - delattr(request, 'state') + delattr(request, "state") # Should still work, falling back to url.path with pytest.raises(AttributeError): diff --git a/tests/unit/middleware/test_decorators.py b/tests/unit/middleware/test_decorators.py index 1b957a90..19d4ed25 100644 --- a/tests/unit/middleware/test_decorators.py +++ b/tests/unit/middleware/test_decorators.py @@ -12,16 +12,18 @@ Tests cover: - Edge cases and isolation """ -import pytest from unittest.mock import Mock -from middleware.decorators import rate_limit, rate_limiter -from app.exceptions.base import RateLimitException +import pytest + +from app.exceptions.base import RateLimitException +from middleware.decorators import rate_limit, rate_limiter # ============================================================================= # Fixtures # ============================================================================= + @pytest.fixture(autouse=True) def reset_rate_limiter(): """Reset rate limiter state before each test to ensure isolation.""" @@ -34,6 +36,7 @@ def reset_rate_limiter(): # Rate Limit Decorator Tests # ============================================================================= + @pytest.mark.unit @pytest.mark.auth class TestRateLimitDecorator: @@ -42,6 +45,7 @@ class TestRateLimitDecorator: @pytest.mark.asyncio async def test_decorator_allows_within_limit(self): """Test decorator allows requests within rate limit.""" + @rate_limit(max_requests=10, window_seconds=3600) async def test_endpoint(): return {"status": "ok"} @@ -53,6 +57,7 @@ class TestRateLimitDecorator: @pytest.mark.asyncio async def test_decorator_blocks_exceeding_limit(self): """Test decorator blocks requests exceeding rate limit.""" + @rate_limit(max_requests=2, window_seconds=3600) async def test_endpoint_blocked(): return {"status": "ok"} @@ -71,6 +76,7 @@ class TestRateLimitDecorator: @pytest.mark.asyncio async def test_decorator_preserves_function_metadata(self): """Test decorator preserves original function metadata.""" + @rate_limit(max_requests=10, window_seconds=3600) async def test_endpoint(): """Test endpoint docstring.""" @@ -82,21 +88,19 @@ class TestRateLimitDecorator: @pytest.mark.asyncio async def test_decorator_with_args_and_kwargs(self): """Test decorator works with function arguments.""" + @rate_limit(max_requests=10, window_seconds=3600) async def test_endpoint(arg1, arg2, kwarg1=None): return {"arg1": arg1, "arg2": arg2, "kwarg1": kwarg1} result = await test_endpoint("value1", "value2", kwarg1="value3") - assert result == { - "arg1": "value1", - "arg2": "value2", - "kwarg1": "value3" - } + assert result == {"arg1": "value1", "arg2": "value2", "kwarg1": "value3"} @pytest.mark.asyncio async def test_decorator_default_parameters(self): """Test decorator uses default parameters.""" + @rate_limit() # Use defaults async def test_endpoint(): return {"status": "ok"} @@ -108,6 +112,7 @@ class TestRateLimitDecorator: @pytest.mark.asyncio async def test_decorator_exception_includes_retry_after(self): """Test rate limit exception includes retry_after.""" + @rate_limit(max_requests=1, window_seconds=60) async def test_endpoint_retry(): return {"status": "ok"} @@ -128,6 +133,7 @@ class TestRateLimitDecoratorEdgeCases: @pytest.mark.asyncio async def test_decorator_with_zero_max_requests(self): """Test decorator with max_requests=0 blocks all requests.""" + @rate_limit(max_requests=0, window_seconds=3600) async def test_endpoint_zero(): return {"status": "ok"} @@ -139,6 +145,7 @@ class TestRateLimitDecoratorEdgeCases: @pytest.mark.asyncio async def test_decorator_with_very_short_window(self): """Test decorator with very short time window.""" + @rate_limit(max_requests=1, window_seconds=1) async def test_endpoint_short(): return {"status": "ok"} @@ -154,6 +161,7 @@ class TestRateLimitDecoratorEdgeCases: @pytest.mark.asyncio async def test_decorator_multiple_functions_separate_limits(self): """Test that different functions have separate rate limits.""" + @rate_limit(max_requests=1, window_seconds=3600) async def endpoint1(): return {"endpoint": "1"} @@ -178,6 +186,7 @@ class TestRateLimitDecoratorEdgeCases: @pytest.mark.asyncio async def test_decorator_with_exception_in_function(self): """Test decorator handles exceptions from wrapped function.""" + @rate_limit(max_requests=10, window_seconds=3600) async def test_endpoint_error(): raise ValueError("Function error") @@ -191,6 +200,7 @@ class TestRateLimitDecoratorEdgeCases: @pytest.mark.asyncio async def test_decorator_isolation_between_tests(self): """Test that rate limiter state is properly isolated between tests.""" + @rate_limit(max_requests=2, window_seconds=3600) async def test_endpoint_isolation(): return {"status": "ok"} @@ -212,6 +222,7 @@ class TestRateLimitDecoratorReturnValues: @pytest.mark.asyncio async def test_decorator_returns_dict(self): """Test decorator correctly returns dictionary.""" + @rate_limit(max_requests=10, window_seconds=3600) async def return_dict(): return {"key": "value", "number": 42} @@ -222,6 +233,7 @@ class TestRateLimitDecoratorReturnValues: @pytest.mark.asyncio async def test_decorator_returns_list(self): """Test decorator correctly returns list.""" + @rate_limit(max_requests=10, window_seconds=3600) async def return_list(): return [1, 2, 3, 4, 5] @@ -232,6 +244,7 @@ class TestRateLimitDecoratorReturnValues: @pytest.mark.asyncio async def test_decorator_returns_none(self): """Test decorator correctly returns None.""" + @rate_limit(max_requests=10, window_seconds=3600) async def return_none(): return None @@ -242,6 +255,7 @@ class TestRateLimitDecoratorReturnValues: @pytest.mark.asyncio async def test_decorator_returns_object(self): """Test decorator correctly returns custom objects.""" + class TestObject: def __init__(self): self.name = "test_object" diff --git a/tests/unit/middleware/test_logging.py b/tests/unit/middleware/test_logging.py index 5bb88434..6fda9238 100644 --- a/tests/unit/middleware/test_logging.py +++ b/tests/unit/middleware/test_logging.py @@ -11,9 +11,10 @@ Tests cover: - Edge cases (missing client info, etc.) """ -import pytest import asyncio -from unittest.mock import Mock, AsyncMock, patch +from unittest.mock import AsyncMock, Mock, patch + +import pytest from fastapi import Request from middleware.logging import LoggingMiddleware @@ -40,7 +41,7 @@ class TestLoggingMiddleware: call_next = AsyncMock(return_value=response) - with patch('middleware.logging.logger') as mock_logger: + with patch("middleware.logging.logger") as mock_logger: await middleware.dispatch(request, call_next) # Verify request was logged @@ -65,7 +66,7 @@ class TestLoggingMiddleware: call_next = AsyncMock(return_value=response) - with patch('middleware.logging.logger') as mock_logger: + with patch("middleware.logging.logger") as mock_logger: result = await middleware.dispatch(request, call_next) # Verify response was logged @@ -89,7 +90,7 @@ class TestLoggingMiddleware: call_next = AsyncMock(return_value=response) - with patch('middleware.logging.logger'): + with patch("middleware.logging.logger"): result = await middleware.dispatch(request, call_next) assert "X-Process-Time" in response.headers @@ -113,11 +114,13 @@ class TestLoggingMiddleware: call_next = AsyncMock(return_value=response) - with patch('middleware.logging.logger') as mock_logger: + with patch("middleware.logging.logger") as mock_logger: await middleware.dispatch(request, call_next) # Should log "unknown" for client - assert any("unknown" in str(call) for call in mock_logger.info.call_args_list) + assert any( + "unknown" in str(call) for call in mock_logger.info.call_args_list + ) @pytest.mark.asyncio async def test_middleware_logs_exceptions(self): @@ -131,8 +134,9 @@ class TestLoggingMiddleware: call_next = AsyncMock(side_effect=Exception("Test error")) - with patch('middleware.logging.logger') as mock_logger, \ - pytest.raises(Exception): + with patch("middleware.logging.logger") as mock_logger, pytest.raises( + Exception + ): await middleware.dispatch(request, call_next) # Verify error was logged @@ -156,7 +160,7 @@ class TestLoggingMiddleware: call_next = slow_call_next - with patch('middleware.logging.logger'): + with patch("middleware.logging.logger"): result = await middleware.dispatch(request, call_next) process_time = float(result.headers["X-Process-Time"]) @@ -181,7 +185,7 @@ class TestLoggingEdgeCases: response = Mock(status_code=200, headers={}) call_next = AsyncMock(return_value=response) - with patch('middleware.logging.logger'): + with patch("middleware.logging.logger"): result = await middleware.dispatch(request, call_next) # Should still have process time, even if very small @@ -205,11 +209,13 @@ class TestLoggingEdgeCases: response = Mock(status_code=200, headers={}) call_next = AsyncMock(return_value=response) - with patch('middleware.logging.logger') as mock_logger: + with patch("middleware.logging.logger") as mock_logger: await middleware.dispatch(request, call_next) # Verify method was logged - assert any(method in str(call) for call in mock_logger.info.call_args_list) + assert any( + method in str(call) for call in mock_logger.info.call_args_list + ) @pytest.mark.asyncio async def test_middleware_logs_different_status_codes(self): @@ -227,8 +233,11 @@ class TestLoggingEdgeCases: response = Mock(status_code=status_code, headers={}) call_next = AsyncMock(return_value=response) - with patch('middleware.logging.logger') as mock_logger: + with patch("middleware.logging.logger") as mock_logger: await middleware.dispatch(request, call_next) # Verify status code was logged - assert any(str(status_code) in str(call) for call in mock_logger.info.call_args_list) + assert any( + str(status_code) in str(call) + for call in mock_logger.info.call_args_list + ) diff --git a/tests/unit/middleware/test_rate_limiter.py b/tests/unit/middleware/test_rate_limiter.py index 94c5418d..c16e2495 100644 --- a/tests/unit/middleware/test_rate_limiter.py +++ b/tests/unit/middleware/test_rate_limiter.py @@ -11,10 +11,11 @@ Tests cover: - Edge cases and concurrency scenarios """ -import pytest -from unittest.mock import Mock, patch -from datetime import datetime, timedelta, timezone from collections import deque +from datetime import datetime, timedelta, timezone +from unittest.mock import Mock, patch + +import pytest from middleware.rate_limiter import RateLimiter @@ -306,8 +307,8 @@ class TestRateLimiterStatistics: # Add requests at different times now = datetime.now(timezone.utc) limiter.clients[client_id].append(now - timedelta(minutes=30)) # Within hour - limiter.clients[client_id].append(now - timedelta(hours=2)) # Within day - limiter.clients[client_id].append(now - timedelta(hours=12)) # Within day + limiter.clients[client_id].append(now - timedelta(hours=2)) # Within day + limiter.clients[client_id].append(now - timedelta(hours=12)) # Within day stats = limiter.get_client_stats(client_id) @@ -411,7 +412,9 @@ class TestRateLimiterEdgeCases: limiter = RateLimiter() client_id = "long_window_client" - result = limiter.allow_request(client_id, max_requests=10, window_seconds=86400*365) + result = limiter.allow_request( + client_id, max_requests=10, window_seconds=86400 * 365 + ) assert result is True @@ -421,10 +424,16 @@ class TestRateLimiterEdgeCases: client_id = "same_client" # Allow with one limit - assert limiter.allow_request(client_id, max_requests=10, window_seconds=3600) is True + assert ( + limiter.allow_request(client_id, max_requests=10, window_seconds=3600) + is True + ) # Check with stricter limit - assert limiter.allow_request(client_id, max_requests=1, window_seconds=3600) is False + assert ( + limiter.allow_request(client_id, max_requests=1, window_seconds=3600) + is False + ) def test_rate_limiter_unicode_client_id(self): """Test rate limiter with unicode client ID.""" diff --git a/tests/unit/middleware/test_theme_context.py b/tests/unit/middleware/test_theme_context.py index 885d902a..2cb9887b 100644 --- a/tests/unit/middleware/test_theme_context.py +++ b/tests/unit/middleware/test_theme_context.py @@ -11,15 +11,14 @@ Tests cover: - Edge cases and error handling """ +from unittest.mock import AsyncMock, MagicMock, Mock, patch + import pytest -from unittest.mock import Mock, AsyncMock, MagicMock, patch from fastapi import Request -from middleware.theme_context import ( - ThemeContextManager, - ThemeContextMiddleware, - get_current_theme, -) +from middleware.theme_context import (ThemeContextManager, + ThemeContextMiddleware, + get_current_theme) @pytest.mark.unit @@ -42,7 +41,14 @@ class TestThemeContextManager: """Test default theme has all required colors.""" theme = ThemeContextManager.get_default_theme() - required_colors = ["primary", "secondary", "accent", "background", "text", "border"] + required_colors = [ + "primary", + "secondary", + "accent", + "background", + "text", + "border", + ] for color in required_colors: assert color in theme["colors"] assert theme["colors"][color].startswith("#") @@ -79,10 +85,7 @@ class TestThemeContextManager: mock_theme = Mock() # Mock to_dict to return actual dictionary - custom_theme_dict = { - "theme_name": "custom", - "colors": {"primary": "#ff0000"} - } + custom_theme_dict = {"theme_name": "custom", "colors": {"primary": "#ff0000"}} mock_theme.to_dict.return_value = custom_theme_dict # Correct filter chain: query().filter().first() @@ -141,8 +144,11 @@ class TestThemeContextMiddleware: mock_db = MagicMock() mock_theme = {"theme_name": "test_theme"} - with patch('middleware.theme_context.get_db', return_value=iter([mock_db])), \ - patch.object(ThemeContextManager, 'get_vendor_theme', return_value=mock_theme): + with patch( + "middleware.theme_context.get_db", return_value=iter([mock_db]) + ), patch.object( + ThemeContextManager, "get_vendor_theme", return_value=mock_theme + ): await middleware.dispatch(request, call_next) @@ -161,7 +167,7 @@ class TestThemeContextMiddleware: await middleware.dispatch(request, call_next) - assert hasattr(request.state, 'theme') + assert hasattr(request.state, "theme") assert request.state.theme["theme_name"] == "default" call_next.assert_called_once() @@ -178,8 +184,11 @@ class TestThemeContextMiddleware: mock_db = MagicMock() - with patch('middleware.theme_context.get_db', return_value=iter([mock_db])), \ - patch.object(ThemeContextManager, 'get_vendor_theme', side_effect=Exception("DB Error")): + with patch( + "middleware.theme_context.get_db", return_value=iter([mock_db]) + ), patch.object( + ThemeContextManager, "get_vendor_theme", side_effect=Exception("DB Error") + ): await middleware.dispatch(request, call_next) @@ -224,7 +233,7 @@ class TestThemeEdgeCases: mock_db = MagicMock() - with patch('middleware.theme_context.get_db', return_value=iter([mock_db])): + with patch("middleware.theme_context.get_db", return_value=iter([mock_db])): await middleware.dispatch(request, call_next) # Verify database was closed diff --git a/tests/unit/middleware/test_vendor_context.py b/tests/unit/middleware/test_vendor_context.py index 3071b092..8a3c44b2 100644 --- a/tests/unit/middleware/test_vendor_context.py +++ b/tests/unit/middleware/test_vendor_context.py @@ -11,17 +11,16 @@ Tests cover: - Edge cases and error handling """ +from unittest.mock import AsyncMock, MagicMock, Mock, patch + import pytest -from unittest.mock import Mock, MagicMock, patch, AsyncMock -from fastapi import Request, HTTPException +from fastapi import HTTPException, Request from sqlalchemy.orm import Session -from middleware.vendor_context import ( - VendorContextManager, - VendorContextMiddleware, - get_current_vendor, - require_vendor_context, -) +from middleware.vendor_context import (VendorContextManager, + VendorContextMiddleware, + get_current_vendor, + require_vendor_context) @pytest.mark.unit @@ -39,7 +38,7 @@ class TestVendorContextManager: request.headers = {"host": "customdomain1.com"} request.url = Mock(path="/") - with patch('middleware.vendor_context.settings') as mock_settings: + with patch("middleware.vendor_context.settings") as mock_settings: mock_settings.platform_domain = "platform.com" context = VendorContextManager.detect_vendor_context(request) @@ -55,7 +54,7 @@ class TestVendorContextManager: request.headers = {"host": "customdomain1.com:8000"} request.url = Mock(path="/") - with patch('middleware.vendor_context.settings') as mock_settings: + with patch("middleware.vendor_context.settings") as mock_settings: mock_settings.platform_domain = "platform.com" context = VendorContextManager.detect_vendor_context(request) @@ -71,7 +70,7 @@ class TestVendorContextManager: request.headers = {"host": "vendor1.platform.com"} request.url = Mock(path="/") - with patch('middleware.vendor_context.settings') as mock_settings: + with patch("middleware.vendor_context.settings") as mock_settings: mock_settings.platform_domain = "platform.com" context = VendorContextManager.detect_vendor_context(request) @@ -87,7 +86,7 @@ class TestVendorContextManager: request.headers = {"host": "vendor1.platform.com:8000"} request.url = Mock(path="/") - with patch('middleware.vendor_context.settings') as mock_settings: + with patch("middleware.vendor_context.settings") as mock_settings: mock_settings.platform_domain = "platform.com" context = VendorContextManager.detect_vendor_context(request) @@ -140,7 +139,7 @@ class TestVendorContextManager: request.headers = {"host": "admin.platform.com"} request.url = Mock(path="/") - with patch('middleware.vendor_context.settings') as mock_settings: + with patch("middleware.vendor_context.settings") as mock_settings: mock_settings.platform_domain = "platform.com" context = VendorContextManager.detect_vendor_context(request) @@ -153,7 +152,7 @@ class TestVendorContextManager: request.headers = {"host": "www.platform.com"} request.url = Mock(path="/") - with patch('middleware.vendor_context.settings') as mock_settings: + with patch("middleware.vendor_context.settings") as mock_settings: mock_settings.platform_domain = "platform.com" context = VendorContextManager.detect_vendor_context(request) @@ -166,7 +165,7 @@ class TestVendorContextManager: request.headers = {"host": "api.platform.com"} request.url = Mock(path="/") - with patch('middleware.vendor_context.settings') as mock_settings: + with patch("middleware.vendor_context.settings") as mock_settings: mock_settings.platform_domain = "platform.com" context = VendorContextManager.detect_vendor_context(request) @@ -179,7 +178,7 @@ class TestVendorContextManager: request.headers = {"host": "localhost"} request.url = Mock(path="/") - with patch('middleware.vendor_context.settings') as mock_settings: + with patch("middleware.vendor_context.settings") as mock_settings: mock_settings.platform_domain = "platform.com" context = VendorContextManager.detect_vendor_context(request) @@ -198,12 +197,11 @@ class TestVendorContextManager: mock_vendor.is_active = True mock_vendor_domain.vendor = mock_vendor - mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor_domain + mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = ( + mock_vendor_domain + ) - context = { - "detection_method": "custom_domain", - "domain": "customdomain1.com" - } + context = {"detection_method": "custom_domain", "domain": "customdomain1.com"} vendor = VendorContextManager.get_vendor_from_context(mock_db, context) @@ -218,12 +216,11 @@ class TestVendorContextManager: mock_vendor.is_active = False mock_vendor_domain.vendor = mock_vendor - mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor_domain + mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = ( + mock_vendor_domain + ) - context = { - "detection_method": "custom_domain", - "domain": "customdomain1.com" - } + context = {"detection_method": "custom_domain", "domain": "customdomain1.com"} vendor = VendorContextManager.get_vendor_from_context(mock_db, context) @@ -232,12 +229,11 @@ class TestVendorContextManager: def test_get_vendor_from_custom_domain_not_found(self): """Test custom domain not found in database.""" mock_db = Mock(spec=Session) - mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None + mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = ( + None + ) - context = { - "detection_method": "custom_domain", - "domain": "nonexistent.com" - } + context = {"detection_method": "custom_domain", "domain": "nonexistent.com"} vendor = VendorContextManager.get_vendor_from_context(mock_db, context) @@ -249,12 +245,11 @@ class TestVendorContextManager: mock_vendor = Mock() mock_vendor.is_active = True - mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor + mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = ( + mock_vendor + ) - context = { - "detection_method": "subdomain", - "subdomain": "vendor1" - } + context = {"detection_method": "subdomain", "subdomain": "vendor1"} vendor = VendorContextManager.get_vendor_from_context(mock_db, context) @@ -266,12 +261,11 @@ class TestVendorContextManager: mock_vendor = Mock() mock_vendor.is_active = True - mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor + mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = ( + mock_vendor + ) - context = { - "detection_method": "path", - "subdomain": "vendor1" - } + context = {"detection_method": "path", "subdomain": "vendor1"} vendor = VendorContextManager.get_vendor_from_context(mock_db, context) @@ -291,12 +285,11 @@ class TestVendorContextManager: mock_vendor = Mock() mock_vendor.is_active = True - mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor + mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = ( + mock_vendor + ) - context = { - "detection_method": "subdomain", - "subdomain": "VENDOR1" # Uppercase - } + context = {"detection_method": "subdomain", "subdomain": "VENDOR1"} # Uppercase vendor = VendorContextManager.get_vendor_from_context(mock_db, context) @@ -311,10 +304,7 @@ class TestVendorContextManager: request = Mock(spec=Request) request.url = Mock(path="/vendor/vendor1/shop/products") - vendor_context = { - "detection_method": "path", - "path_prefix": "/vendor/vendor1" - } + vendor_context = {"detection_method": "path", "path_prefix": "/vendor/vendor1"} clean_path = VendorContextManager.extract_clean_path(request, vendor_context) @@ -325,10 +315,7 @@ class TestVendorContextManager: request = Mock(spec=Request) request.url = Mock(path="/vendors/vendor1/shop/products") - vendor_context = { - "detection_method": "path", - "path_prefix": "/vendors/vendor1" - } + vendor_context = {"detection_method": "path", "path_prefix": "/vendors/vendor1"} clean_path = VendorContextManager.extract_clean_path(request, vendor_context) @@ -339,10 +326,7 @@ class TestVendorContextManager: request = Mock(spec=Request) request.url = Mock(path="/vendor/vendor1") - vendor_context = { - "detection_method": "path", - "path_prefix": "/vendor/vendor1" - } + vendor_context = {"detection_method": "path", "path_prefix": "/vendor/vendor1"} clean_path = VendorContextManager.extract_clean_path(request, vendor_context) @@ -353,10 +337,7 @@ class TestVendorContextManager: request = Mock(spec=Request) request.url = Mock(path="/shop/products") - vendor_context = { - "detection_method": "subdomain", - "subdomain": "vendor1" - } + vendor_context = {"detection_method": "subdomain", "subdomain": "vendor1"} clean_path = VendorContextManager.extract_clean_path(request, vendor_context) @@ -425,21 +406,24 @@ class TestVendorContextManager: # Static File Detection Tests # ======================================================================== - @pytest.mark.parametrize("path", [ - "/static/css/style.css", - "/static/js/app.js", - "/media/images/product.png", - "/assets/logo.svg", - "/.well-known/security.txt", - "/favicon.ico", - "/image.jpg", - "/style.css", - "/app.webmanifest", - "/static/", # Path starting with /static/ but no extension - "/media/uploads", # Path starting with /media/ but no extension - "/subfolder/favicon.ico", # favicon.ico in subfolder - "/favicon.ico.bak", # Contains favicon.ico but doesn't end with static extension (hits line 226) - ]) + @pytest.mark.parametrize( + "path", + [ + "/static/css/style.css", + "/static/js/app.js", + "/media/images/product.png", + "/assets/logo.svg", + "/.well-known/security.txt", + "/favicon.ico", + "/image.jpg", + "/style.css", + "/app.webmanifest", + "/static/", # Path starting with /static/ but no extension + "/media/uploads", # Path starting with /media/ but no extension + "/subfolder/favicon.ico", # favicon.ico in subfolder + "/favicon.ico.bak", # Contains favicon.ico but doesn't end with static extension (hits line 226) + ], + ) def test_is_static_file_request(self, path): """Test static file detection for various paths and extensions.""" request = Mock(spec=Request) @@ -447,12 +431,15 @@ class TestVendorContextManager: assert VendorContextManager.is_static_file_request(request) is True - @pytest.mark.parametrize("path", [ - "/shop/products", - "/admin/dashboard", - "/api/vendors", - "/about", - ]) + @pytest.mark.parametrize( + "path", + [ + "/shop/products", + "/admin/dashboard", + "/api/vendors", + "/about", + ], + ) def test_is_not_static_file_request(self, path): """Test non-static file paths.""" request = Mock(spec=Request) @@ -478,7 +465,7 @@ class TestVendorContextMiddleware: call_next = AsyncMock(return_value=Mock()) - with patch.object(VendorContextManager, 'is_admin_request', return_value=True): + with patch.object(VendorContextManager, "is_admin_request", return_value=True): await middleware.dispatch(request, call_next) assert request.state.vendor is None @@ -498,7 +485,7 @@ class TestVendorContextMiddleware: call_next = AsyncMock(return_value=Mock()) - with patch.object(VendorContextManager, 'is_api_request', return_value=True): + with patch.object(VendorContextManager, "is_api_request", return_value=True): await middleware.dispatch(request, call_next) assert request.state.vendor is None @@ -517,7 +504,9 @@ class TestVendorContextMiddleware: call_next = AsyncMock(return_value=Mock()) - with patch.object(VendorContextManager, 'is_static_file_request', return_value=True): + with patch.object( + VendorContextManager, "is_static_file_request", return_value=True + ): await middleware.dispatch(request, call_next) assert request.state.vendor is None @@ -540,17 +529,19 @@ class TestVendorContextMiddleware: mock_vendor.name = "Test Vendor" mock_vendor.subdomain = "vendor1" - vendor_context = { - "detection_method": "subdomain", - "subdomain": "vendor1" - } + vendor_context = {"detection_method": "subdomain", "subdomain": "vendor1"} mock_db = MagicMock() - with patch.object(VendorContextManager, 'detect_vendor_context', return_value=vendor_context), \ - patch.object(VendorContextManager, 'get_vendor_from_context', return_value=mock_vendor), \ - patch.object(VendorContextManager, 'extract_clean_path', return_value="/shop/products"), \ - patch('middleware.vendor_context.get_db', return_value=iter([mock_db])): + with patch.object( + VendorContextManager, "detect_vendor_context", return_value=vendor_context + ), patch.object( + VendorContextManager, "get_vendor_from_context", return_value=mock_vendor + ), patch.object( + VendorContextManager, "extract_clean_path", return_value="/shop/products" + ), patch( + "middleware.vendor_context.get_db", return_value=iter([mock_db]) + ): await middleware.dispatch(request, call_next) @@ -571,16 +562,17 @@ class TestVendorContextMiddleware: call_next = AsyncMock(return_value=Mock()) - vendor_context = { - "detection_method": "subdomain", - "subdomain": "nonexistent" - } + vendor_context = {"detection_method": "subdomain", "subdomain": "nonexistent"} mock_db = MagicMock() - with patch.object(VendorContextManager, 'detect_vendor_context', return_value=vendor_context), \ - patch.object(VendorContextManager, 'get_vendor_from_context', return_value=None), \ - patch('middleware.vendor_context.get_db', return_value=iter([mock_db])): + with patch.object( + VendorContextManager, "detect_vendor_context", return_value=vendor_context + ), patch.object( + VendorContextManager, "get_vendor_from_context", return_value=None + ), patch( + "middleware.vendor_context.get_db", return_value=iter([mock_db]) + ): await middleware.dispatch(request, call_next) @@ -601,7 +593,9 @@ class TestVendorContextMiddleware: call_next = AsyncMock(return_value=Mock()) - with patch.object(VendorContextManager, 'detect_vendor_context', return_value=None): + with patch.object( + VendorContextManager, "detect_vendor_context", return_value=None + ): await middleware.dispatch(request, call_next) assert request.state.vendor is None @@ -714,7 +708,7 @@ class TestEdgeCases: request.headers = {"host": "shop.vendor1.platform.com"} request.url = Mock(path="/") - with patch('middleware.vendor_context.settings') as mock_settings: + with patch("middleware.vendor_context.settings") as mock_settings: mock_settings.platform_domain = "platform.com" context = VendorContextManager.detect_vendor_context(request) @@ -735,11 +729,14 @@ class TestEdgeCases: context = {"subdomain": "nonexistent", "detection_method": "subdomain"} - with patch('middleware.vendor_context.logger') as mock_logger: + with patch("middleware.vendor_context.logger") as mock_logger: vendor = VendorContextManager.get_vendor_from_context(mock_db, context) assert vendor is None # Verify warning was logged mock_logger.warning.assert_called() warning_message = str(mock_logger.warning.call_args) - assert "No active vendor found for subdomain" in warning_message and "nonexistent" in warning_message + assert ( + "No active vendor found for subdomain" in warning_message + and "nonexistent" in warning_message + ) diff --git a/tests/unit/models/test_database_models.py b/tests/unit/models/test_database_models.py index 8cae8dcf..a25997b4 100644 --- a/tests/unit/models/test_database_models.py +++ b/tests/unit/models/test_database_models.py @@ -1,16 +1,17 @@ # tests/unit/models/test_database_models.py -import pytest from datetime import datetime, timezone + +import pytest from sqlalchemy.exc import IntegrityError -from models.database.marketplace_product import MarketplaceProduct -from models.database.vendor import Vendor, VendorUser, Role -from models.database.inventory import Inventory -from models.database.user import User -from models.database.marketplace_import_job import MarketplaceImportJob -from models.database.product import Product from models.database.customer import Customer, CustomerAddress +from models.database.inventory import Inventory +from models.database.marketplace_import_job import MarketplaceImportJob +from models.database.marketplace_product import MarketplaceProduct from models.database.order import Order, OrderItem +from models.database.product import Product +from models.database.user import User +from models.database.vendor import Role, Vendor, VendorUser @pytest.mark.unit @@ -277,7 +278,7 @@ class TestMarketplaceProductModel: vendor_id=test_vendor.id, marketplace_product_id="UNIQUE_001", title="Product 1", - marketplace="Letzshop" + marketplace="Letzshop", ) db.add(product1) db.commit() @@ -288,7 +289,7 @@ class TestMarketplaceProductModel: vendor_id=test_vendor.id, marketplace_product_id="UNIQUE_001", title="Product 2", - marketplace="Letzshop" + marketplace="Letzshop", ) db.add(product2) db.commit() @@ -515,7 +516,9 @@ class TestCustomerModel: class TestOrderModel: """Test Order model""" - def test_order_creation(self, db, test_vendor, test_customer, test_customer_address): + def test_order_creation( + self, db, test_vendor, test_customer, test_customer_address + ): """Test Order model with customer relationship""" order = Order( vendor_id=test_vendor.id, @@ -563,7 +566,9 @@ class TestOrderModel: assert float(order_item.unit_price) == 49.99 assert float(order_item.total_price) == 99.98 - def test_order_number_uniqueness(self, db, test_vendor, test_customer, test_customer_address): + def test_order_number_uniqueness( + self, db, test_vendor, test_customer, test_customer_address + ): """Test order_number unique constraint""" order1 = Order( vendor_id=test_vendor.id, diff --git a/tests/unit/services/test_admin_service.py b/tests/unit/services/test_admin_service.py index 01b9374a..000ec495 100644 --- a/tests/unit/services/test_admin_service.py +++ b/tests/unit/services/test_admin_service.py @@ -1,14 +1,10 @@ # tests/unit/services/test_admin_service.py import pytest -from app.exceptions import ( - UserNotFoundException, - UserStatusChangeException, - CannotModifySelfException, - VendorNotFoundException, - VendorVerificationException, - AdminOperationException, -) +from app.exceptions import (AdminOperationException, CannotModifySelfException, + UserNotFoundException, UserStatusChangeException, + VendorNotFoundException, + VendorVerificationException) from app.services.admin_service import AdminService from app.services.stats_service import stats_service from models.database.marketplace_import_job import MarketplaceImportJob @@ -85,7 +81,9 @@ class TestAdminService: assert exception.error_code == "CANNOT_MODIFY_SELF" assert "deactivate account" in exception.message - def test_toggle_user_status_cannot_modify_admin(self, db, test_admin, another_admin): + def test_toggle_user_status_cannot_modify_admin( + self, db, test_admin, another_admin + ): """Test that admin cannot modify another admin""" with pytest.raises(UserStatusChangeException) as exc_info: self.service.toggle_user_status(db, another_admin.id, test_admin.id) @@ -148,7 +146,7 @@ class TestAdminService: assert "99999" in exception.message def test_toggle_vendor_status_deactivate(self, db, test_vendor): - """Test deactivating a vendor """ + """Test deactivating a vendor""" original_status = test_vendor.is_active vendor, message = self.service.toggle_vendor_status(db, test_vendor.id) @@ -170,21 +168,26 @@ class TestAdminService: assert exception.error_code == "VENDOR_NOT_FOUND" # Marketplace Import Jobs Tests - def test_get_marketplace_import_jobs_no_filters(self, db, test_marketplace_import_job): + def test_get_marketplace_import_jobs_no_filters( + self, db, test_marketplace_import_job + ): """Test getting marketplace import jobs without filters""" result = self.service.get_marketplace_import_jobs(db, skip=0, limit=10) assert len(result) >= 1 # Find our test job in the results test_job = next( - (job for job in result if job.job_id == test_marketplace_import_job.id), None + (job for job in result if job.job_id == test_marketplace_import_job.id), + None, ) assert test_job is not None assert test_job.marketplace == test_marketplace_import_job.marketplace assert test_job.vendor_name == test_marketplace_import_job.name assert test_job.status == test_marketplace_import_job.status - def test_get_marketplace_import_jobs_with_marketplace_filter(self, db, test_marketplace_import_job): + def test_get_marketplace_import_jobs_with_marketplace_filter( + self, db, test_marketplace_import_job + ): """Test filtering marketplace import jobs by marketplace""" result = self.service.get_marketplace_import_jobs( db, marketplace=test_marketplace_import_job.marketplace, skip=0, limit=10 @@ -192,9 +195,14 @@ class TestAdminService: assert len(result) >= 1 for job in result: - assert test_marketplace_import_job.marketplace.lower() in job.marketplace.lower() + assert ( + test_marketplace_import_job.marketplace.lower() + in job.marketplace.lower() + ) - def test_get_marketplace_import_jobs_with_vendor_filter(self, db, test_marketplace_import_job): + def test_get_marketplace_import_jobs_with_vendor_filter( + self, db, test_marketplace_import_job + ): """Test filtering marketplace import jobs by vendor name""" result = self.service.get_marketplace_import_jobs( db, vendor_name=test_marketplace_import_job.name, skip=0, limit=10 @@ -204,7 +212,9 @@ class TestAdminService: for job in result: assert test_marketplace_import_job.name.lower() in job.vendor_name.lower() - def test_get_marketplace_import_jobs_with_status_filter(self, db, test_marketplace_import_job): + def test_get_marketplace_import_jobs_with_status_filter( + self, db, test_marketplace_import_job + ): """Test filtering marketplace import jobs by status""" result = self.service.get_marketplace_import_jobs( db, status=test_marketplace_import_job.status, skip=0, limit=10 @@ -214,7 +224,9 @@ class TestAdminService: for job in result: assert job.status == test_marketplace_import_job.status - def test_get_marketplace_import_jobs_pagination(self, db, test_marketplace_import_job): + def test_get_marketplace_import_jobs_pagination( + self, db, test_marketplace_import_job + ): """Test marketplace import jobs pagination""" result_page1 = self.service.get_marketplace_import_jobs(db, skip=0, limit=1) result_page2 = self.service.get_marketplace_import_jobs(db, skip=1, limit=1) diff --git a/tests/unit/services/test_auth_service.py b/tests/unit/services/test_auth_service.py index e46809c9..5198e31a 100644 --- a/tests/unit/services/test_auth_service.py +++ b/tests/unit/services/test_auth_service.py @@ -1,11 +1,9 @@ # tests/test_auth_service.py import pytest -from app.exceptions.auth import ( - UserAlreadyExistsException, - InvalidCredentialsException, - UserNotActiveException, -) +from app.exceptions.auth import (InvalidCredentialsException, + UserAlreadyExistsException, + UserNotActiveException) from app.exceptions.base import ValidationException from app.services.auth_service import AuthService from models.schema.auth import UserLogin, UserRegister @@ -218,11 +216,14 @@ class TestAuthService: def test_create_access_token_failure(self, test_user, monkeypatch): """Test creating access token handles failures""" + # Mock the auth_manager to raise an exception def mock_create_token(*args, **kwargs): raise Exception("Token creation failed") - monkeypatch.setattr(self.service.auth_manager, "create_access_token", mock_create_token) + monkeypatch.setattr( + self.service.auth_manager, "create_access_token", mock_create_token + ) with pytest.raises(ValidationException) as exc_info: self.service.create_access_token(test_user) @@ -250,11 +251,14 @@ class TestAuthService: def test_hash_password_failure(self, monkeypatch): """Test password hashing handles failures""" + # Mock the auth_manager to raise an exception def mock_hash_password(*args, **kwargs): raise Exception("Hashing failed") - monkeypatch.setattr(self.service.auth_manager, "hash_password", mock_hash_password) + monkeypatch.setattr( + self.service.auth_manager, "hash_password", mock_hash_password + ) with pytest.raises(ValidationException) as exc_info: self.service.hash_password("testpassword") @@ -267,9 +271,7 @@ class TestAuthService: def test_register_user_database_error(self, db_with_error): """Test user registration handles database errors""" user_data = UserRegister( - email="test@example.com", - username="testuser", - password="password123" + email="test@example.com", username="testuser", password="password123" ) with pytest.raises(ValidationException) as exc_info: diff --git a/tests/unit/services/test_inventory_service.py b/tests/unit/services/test_inventory_service.py index 78942474..d1b38364 100644 --- a/tests/unit/services/test_inventory_service.py +++ b/tests/unit/services/test_inventory_service.py @@ -3,19 +3,17 @@ import uuid import pytest +from app.exceptions import (InsufficientInventoryException, + InvalidInventoryOperationException, + InvalidQuantityException, + InventoryNotFoundException, + InventoryValidationException, + NegativeInventoryException, ValidationException) from app.services.inventory_service import InventoryService -from app.exceptions import ( - InventoryNotFoundException, - InsufficientInventoryException, - InvalidInventoryOperationException, - InventoryValidationException, - NegativeInventoryException, - InvalidQuantityException, - ValidationException, -) -from models.schema.inventory import InventoryAdd, InventoryCreate, InventoryUpdate -from models.database.marketplace_product import MarketplaceProduct from models.database.inventory import Inventory +from models.database.marketplace_product import MarketplaceProduct +from models.schema.inventory import (InventoryAdd, InventoryCreate, + InventoryUpdate) @pytest.mark.unit @@ -40,10 +38,14 @@ class TestInventoryService: def test_normalize_gtin_valid(self): """Test GTIN normalization with valid GTINs.""" # Test various valid GTIN formats - these should remain unchanged - assert self.service._normalize_gtin("1234567890123") == "1234567890123" # EAN-13 + assert ( + self.service._normalize_gtin("1234567890123") == "1234567890123" + ) # EAN-13 assert self.service._normalize_gtin("123456789012") == "123456789012" # UPC-A assert self.service._normalize_gtin("12345678") == "12345678" # EAN-8 - assert self.service._normalize_gtin("12345678901234") == "12345678901234" # GTIN-14 + assert ( + self.service._normalize_gtin("12345678901234") == "12345678901234" + ) # GTIN-14 # Test with decimal points (should be removed) assert self.service._normalize_gtin("1234567890123.0") == "1234567890123" @@ -52,11 +54,17 @@ class TestInventoryService: assert self.service._normalize_gtin(" 1234567890123 ") == "1234567890123" # Test short GTINs being padded - assert self.service._normalize_gtin("123") == "0000000000123" # Padded to EAN-13 - assert self.service._normalize_gtin("12345") == "0000000012345" # Padded to EAN-13 + assert ( + self.service._normalize_gtin("123") == "0000000000123" + ) # Padded to EAN-13 + assert ( + self.service._normalize_gtin("12345") == "0000000012345" + ) # Padded to EAN-13 # Test long GTINs being truncated - assert self.service._normalize_gtin("123456789012345") == "3456789012345" # Truncated to 13 + assert ( + self.service._normalize_gtin("123456789012345") == "3456789012345" + ) # Truncated to 13 def test_normalize_gtin_edge_cases(self): """Test GTIN normalization edge cases.""" @@ -65,9 +73,15 @@ class TestInventoryService: assert self.service._normalize_gtin(123) == "0000000000123" # Test mixed valid/invalid characters - assert self.service._normalize_gtin("123-456-789-012") == "123456789012" # Dashes removed - assert self.service._normalize_gtin("123 456 789 012") == "123456789012" # Spaces removed - assert self.service._normalize_gtin("ABC123456789012DEF") == "123456789012" # Letters removed + assert ( + self.service._normalize_gtin("123-456-789-012") == "123456789012" + ) # Dashes removed + assert ( + self.service._normalize_gtin("123 456 789 012") == "123456789012" + ) # Spaces removed + assert ( + self.service._normalize_gtin("ABC123456789012DEF") == "123456789012" + ) # Letters removed def test_set_inventory_new_entry_success(self, db): """Test setting inventory for a new GTIN/location combination successfully.""" @@ -162,7 +176,9 @@ class TestInventoryService: def test_add_inventory_invalid_gtin_validation_error(self, db): """Test adding inventory with invalid GTIN returns InventoryValidationException.""" - inventory_data = InventoryAdd(gtin="invalid_gtin", location="WAREHOUSE_A", quantity=50) + inventory_data = InventoryAdd( + gtin="invalid_gtin", location="WAREHOUSE_A", quantity=50 + ) with pytest.raises(InventoryValidationException) as exc_info: self.service.add_inventory(db, inventory_data) @@ -180,11 +196,12 @@ class TestInventoryService: assert exc_info.value.error_code == "INVALID_QUANTITY" assert "Quantity must be positive" in str(exc_info.value) - def test_remove_inventory_success(self, db, test_inventory): """Test removing inventory successfully.""" original_quantity = test_inventory.quantity - remove_quantity = min(10, original_quantity) # Ensure we don't remove more than available + remove_quantity = min( + 10, original_quantity + ) # Ensure we don't remove more than available inventory_data = InventoryAdd( gtin=test_inventory.gtin, @@ -212,7 +229,9 @@ class TestInventoryService: assert exc_info.value.error_code == "INSUFFICIENT_INVENTORY" assert exc_info.value.details["gtin"] == test_inventory.gtin assert exc_info.value.details["location"] == test_inventory.location - assert exc_info.value.details["requested_quantity"] == test_inventory.quantity + 10 + assert ( + exc_info.value.details["requested_quantity"] == test_inventory.quantity + 10 + ) assert exc_info.value.details["available_quantity"] == test_inventory.quantity def test_remove_inventory_nonexistent_entry_not_found(self, db): @@ -231,7 +250,9 @@ class TestInventoryService: def test_remove_inventory_invalid_gtin_validation_error(self, db): """Test removing inventory with invalid GTIN returns InventoryValidationException.""" - inventory_data = InventoryAdd(gtin="invalid_gtin", location="WAREHOUSE_A", quantity=10) + inventory_data = InventoryAdd( + gtin="invalid_gtin", location="WAREHOUSE_A", quantity=10 + ) with pytest.raises(InventoryValidationException) as exc_info: self.service.remove_inventory(db, inventory_data) @@ -254,7 +275,9 @@ class TestInventoryService: # The service prevents negative inventory through InsufficientInventoryException assert exc_info.value.error_code == "INSUFFICIENT_INVENTORY" - def test_get_inventory_by_gtin_success(self, db, test_inventory, test_marketplace_product): + def test_get_inventory_by_gtin_success( + self, db, test_inventory, test_marketplace_product + ): """Test getting inventory summary by GTIN successfully.""" result = self.service.get_inventory_by_gtin(db, test_inventory.gtin) @@ -265,14 +288,20 @@ class TestInventoryService: assert result.locations[0].quantity == test_inventory.quantity assert result.product_title == test_marketplace_product.title - def test_get_inventory_by_gtin_multiple_locations_success(self, db, test_marketplace_product): + def test_get_inventory_by_gtin_multiple_locations_success( + self, db, test_marketplace_product + ): """Test getting inventory summary with multiple locations successfully.""" unique_gtin = test_marketplace_product.gtin unique_id = str(uuid.uuid4())[:8] # Create multiple inventory entries for the same GTIN with unique locations - inventory1 = Inventory(gtin=unique_gtin, location=f"WAREHOUSE_A_{unique_id}", quantity=50) - inventory2 = Inventory(gtin=unique_gtin, location=f"WAREHOUSE_B_{unique_id}", quantity=30) + inventory1 = Inventory( + gtin=unique_gtin, location=f"WAREHOUSE_A_{unique_id}", quantity=50 + ) + inventory2 = Inventory( + gtin=unique_gtin, location=f"WAREHOUSE_B_{unique_id}", quantity=30 + ) db.add(inventory1) db.add(inventory2) @@ -301,7 +330,9 @@ class TestInventoryService: assert exc_info.value.error_code == "INVENTORY_VALIDATION_FAILED" assert "Invalid GTIN format" in str(exc_info.value) - def test_get_total_inventory_success(self, db, test_inventory, test_marketplace_product): + def test_get_total_inventory_success( + self, db, test_inventory, test_marketplace_product + ): """Test getting total inventory for a GTIN successfully.""" result = self.service.get_total_inventory(db, test_inventory.gtin) @@ -364,7 +395,9 @@ class TestInventoryService: result = self.service.get_all_inventory(db, skip=2, limit=2) - assert len(result) <= 2 # Should be at most 2, might be less if other records exist + assert ( + len(result) <= 2 + ) # Should be at most 2, might be less if other records exist def test_update_inventory_success(self, db, test_inventory): """Test updating inventory quantity successfully.""" @@ -404,7 +437,9 @@ class TestInventoryService: assert result is True # Verify the inventory is actually deleted - deleted_inventory = db.query(Inventory).filter(Inventory.id == inventory_id).first() + deleted_inventory = ( + db.query(Inventory).filter(Inventory.id == inventory_id).first() + ) assert deleted_inventory is None def test_delete_inventory_not_found_error(self, db): @@ -415,7 +450,9 @@ class TestInventoryService: assert exc_info.value.error_code == "INVENTORY_NOT_FOUND" assert "99999" in str(exc_info.value) - def test_get_low_inventory_items_success(self, db, test_inventory, test_marketplace_product): + def test_get_low_inventory_items_success( + self, db, test_inventory, test_marketplace_product + ): """Test getting low inventory items successfully.""" # Set inventory to a low value test_inventory.quantity = 5 @@ -424,7 +461,9 @@ class TestInventoryService: result = self.service.get_low_inventory_items(db, threshold=10) assert len(result) >= 1 - low_inventory_item = next((item for item in result if item["gtin"] == test_inventory.gtin), None) + low_inventory_item = next( + (item for item in result if item["gtin"] == test_inventory.gtin), None + ) assert low_inventory_item is not None assert low_inventory_item["current_quantity"] == 5 assert low_inventory_item["location"] == test_inventory.location @@ -440,9 +479,13 @@ class TestInventoryService: def test_get_inventory_summary_by_location_success(self, db, test_inventory): """Test getting inventory summary by location successfully.""" - result = self.service.get_inventory_summary_by_location(db, test_inventory.location) + result = self.service.get_inventory_summary_by_location( + db, test_inventory.location + ) - assert result["location"] == test_inventory.location.upper() # Service normalizes to uppercase + assert ( + result["location"] == test_inventory.location.upper() + ) # Service normalizes to uppercase assert result["total_items"] >= 1 assert result["total_quantity"] >= test_inventory.quantity assert result["unique_gtins"] >= 1 @@ -450,7 +493,9 @@ class TestInventoryService: def test_get_inventory_summary_by_location_empty_result(self, db): """Test getting inventory summary for location with no inventory.""" unique_id = str(uuid.uuid4())[:8] - result = self.service.get_inventory_summary_by_location(db, f"EMPTY_LOCATION_{unique_id}") + result = self.service.get_inventory_summary_by_location( + db, f"EMPTY_LOCATION_{unique_id}" + ) assert result["total_items"] == 0 assert result["total_quantity"] == 0 @@ -459,12 +504,16 @@ class TestInventoryService: def test_validate_quantity_edge_cases(self, db): """Test quantity validation with edge cases.""" # Test zero quantity with allow_zero=True (should succeed) - inventory_data = InventoryCreate(gtin="1234567890123", location="WAREHOUSE_A", quantity=0) + inventory_data = InventoryCreate( + gtin="1234567890123", location="WAREHOUSE_A", quantity=0 + ) result = self.service.set_inventory(db, inventory_data) assert result.quantity == 0 # Test zero quantity with add_inventory (should fail - doesn't allow zero) - inventory_data_add = InventoryAdd(gtin="1234567890123", location="WAREHOUSE_B", quantity=0) + inventory_data_add = InventoryAdd( + gtin="1234567890123", location="WAREHOUSE_B", quantity=0 + ) with pytest.raises(InvalidQuantityException): self.service.add_inventory(db, inventory_data_add) @@ -477,10 +526,10 @@ class TestInventoryService: exception = exc_info.value # Verify exception structure matches WizamartException.to_dict() - assert hasattr(exception, 'error_code') - assert hasattr(exception, 'message') - assert hasattr(exception, 'status_code') - assert hasattr(exception, 'details') + assert hasattr(exception, "error_code") + assert hasattr(exception, "message") + assert hasattr(exception, "status_code") + assert hasattr(exception, "details") assert isinstance(exception.error_code, str) assert isinstance(exception.message, str) diff --git a/tests/unit/services/test_marketplace_service.py b/tests/unit/services/test_marketplace_service.py index 5a2151b4..a43d7add 100644 --- a/tests/unit/services/test_marketplace_service.py +++ b/tests/unit/services/test_marketplace_service.py @@ -4,19 +4,18 @@ from datetime import datetime import pytest -from app.exceptions.marketplace_import_job import ( - ImportJobNotFoundException, - ImportJobNotOwnedException, - ImportJobCannotBeCancelledException, - ImportJobCannotBeDeletedException, -) -from app.exceptions.vendor import VendorNotFoundException, UnauthorizedVendorAccessException from app.exceptions.base import ValidationException -from app.services.marketplace_import_job_service import MarketplaceImportJobService -from models.schema.marketplace_import_job import MarketplaceImportJobRequest +from app.exceptions.marketplace_import_job import ( + ImportJobCannotBeCancelledException, ImportJobCannotBeDeletedException, + ImportJobNotFoundException, ImportJobNotOwnedException) +from app.exceptions.vendor import (UnauthorizedVendorAccessException, + VendorNotFoundException) +from app.services.marketplace_import_job_service import \ + MarketplaceImportJobService from models.database.marketplace_import_job import MarketplaceImportJob -from models.database.vendor import Vendor from models.database.user import User +from models.database.vendor import Vendor +from models.schema.marketplace_import_job import MarketplaceImportJobRequest @pytest.mark.unit @@ -31,7 +30,9 @@ class TestMarketplaceService: test_vendor.owner_user_id = test_user.id db.commit() - result = self.service.validate_vendor_access(db, test_vendor.vendor_code, test_user) + result = self.service.validate_vendor_access( + db, test_vendor.vendor_code, test_user + ) assert result.vendor_code == test_vendor.vendor_code assert result.owner_user_id == test_user.id @@ -39,8 +40,10 @@ class TestMarketplaceService: def test_validate_vendor_access_admin_can_access_any_vendor( self, db, test_vendor, test_admin ): - """Test that admin users can access any vendor """ - result = self.service.validate_vendor_access(db, test_vendor.vendor_code, test_admin) + """Test that admin users can access any vendor""" + result = self.service.validate_vendor_access( + db, test_vendor.vendor_code, test_admin + ) assert result.vendor_code == test_vendor.vendor_code @@ -57,7 +60,7 @@ class TestMarketplaceService: def test_validate_vendor_access_permission_denied( self, db, test_vendor, test_user, other_user ): - """Test vendor access validation when user doesn't own the vendor """ + """Test vendor access validation when user doesn't own the vendor""" # Set the vendor owner to a different user test_vendor.owner_user_id = other_user.id db.commit() @@ -93,7 +96,7 @@ class TestMarketplaceService: assert result.vendor_name == test_vendor.name def test_create_import_job_invalid_vendor(self, db, test_user): - """Test import job creation with invalid vendor """ + """Test import job creation with invalid vendor""" request = MarketplaceImportJobRequest( url="https://example.com/products.csv", marketplace="Amazon", @@ -108,7 +111,9 @@ class TestMarketplaceService: assert exception.error_code == "VENDOR_NOT_FOUND" assert "INVALID_VENDOR" in exception.message - def test_create_import_job_unauthorized_access(self, db, test_vendor, test_user, other_user): + def test_create_import_job_unauthorized_access( + self, db, test_vendor, test_user, other_user + ): """Test import job creation with unauthorized vendor access""" # Set the vendor owner to a different user test_vendor.owner_user_id = other_user.id @@ -127,7 +132,9 @@ class TestMarketplaceService: exception = exc_info.value assert exception.error_code == "UNAUTHORIZED_VENDOR_ACCESS" - def test_get_import_job_by_id_success(self, db, test_marketplace_import_job, test_user): + def test_get_import_job_by_id_success( + self, db, test_marketplace_import_job, test_user + ): """Test getting import job by ID for job owner""" result = self.service.get_import_job_by_id( db, test_marketplace_import_job.id, test_user @@ -161,14 +168,18 @@ class TestMarketplaceService: ): """Test access denied when user doesn't own the job""" with pytest.raises(ImportJobNotOwnedException) as exc_info: - self.service.get_import_job_by_id(db, test_marketplace_import_job.id, other_user) + self.service.get_import_job_by_id( + db, test_marketplace_import_job.id, other_user + ) exception = exc_info.value assert exception.error_code == "IMPORT_JOB_NOT_OWNED" assert exception.status_code == 403 assert str(test_marketplace_import_job.id) in exception.message - def test_get_import_jobs_user_filter(self, db, test_marketplace_import_job, test_user): + def test_get_import_jobs_user_filter( + self, db, test_marketplace_import_job, test_user + ): """Test getting import jobs filtered by user""" jobs = self.service.get_import_jobs(db, test_user) @@ -176,7 +187,9 @@ class TestMarketplaceService: assert any(job.id == test_marketplace_import_job.id for job in jobs) assert test_marketplace_import_job.user_id == test_user.id - def test_get_import_jobs_admin_sees_all(self, db, test_marketplace_import_job, test_admin): + def test_get_import_jobs_admin_sees_all( + self, db, test_marketplace_import_job, test_admin + ): """Test that admin sees all import jobs""" jobs = self.service.get_import_jobs(db, test_admin) @@ -192,7 +205,9 @@ class TestMarketplaceService: ) assert len(jobs) >= 1 - assert any(job.marketplace == test_marketplace_import_job.marketplace for job in jobs) + assert any( + job.marketplace == test_marketplace_import_job.marketplace for job in jobs + ) def test_get_import_jobs_with_pagination(self, db, test_user, test_vendor): """Test getting import jobs with pagination""" @@ -330,10 +345,14 @@ class TestMarketplaceService: exception = exc_info.value assert exception.error_code == "IMPORT_JOB_NOT_FOUND" - def test_cancel_import_job_access_denied(self, db, test_marketplace_import_job, other_user): + def test_cancel_import_job_access_denied( + self, db, test_marketplace_import_job, other_user + ): """Test cancelling import job without access""" with pytest.raises(ImportJobNotOwnedException) as exc_info: - self.service.cancel_import_job(db, test_marketplace_import_job.id, other_user) + self.service.cancel_import_job( + db, test_marketplace_import_job.id, other_user + ) exception = exc_info.value assert exception.error_code == "IMPORT_JOB_NOT_OWNED" @@ -347,7 +366,9 @@ class TestMarketplaceService: db.commit() with pytest.raises(ImportJobCannotBeCancelledException) as exc_info: - self.service.cancel_import_job(db, test_marketplace_import_job.id, test_user) + self.service.cancel_import_job( + db, test_marketplace_import_job.id, test_user + ) exception = exc_info.value assert exception.error_code == "IMPORT_JOB_CANNOT_BE_CANCELLED" @@ -396,10 +417,14 @@ class TestMarketplaceService: exception = exc_info.value assert exception.error_code == "IMPORT_JOB_NOT_FOUND" - def test_delete_import_job_access_denied(self, db, test_marketplace_import_job, other_user): + def test_delete_import_job_access_denied( + self, db, test_marketplace_import_job, other_user + ): """Test deleting import job without access""" with pytest.raises(ImportJobNotOwnedException) as exc_info: - self.service.delete_import_job(db, test_marketplace_import_job.id, other_user) + self.service.delete_import_job( + db, test_marketplace_import_job.id, other_user + ) exception = exc_info.value assert exception.error_code == "IMPORT_JOB_NOT_OWNED" @@ -440,11 +465,15 @@ class TestMarketplaceService: db.commit() # Test with lowercase vendor code - result = self.service.validate_vendor_access(db, test_vendor.vendor_code.lower(), test_user) + result = self.service.validate_vendor_access( + db, test_vendor.vendor_code.lower(), test_user + ) assert result.vendor_code == test_vendor.vendor_code # Test with uppercase vendor code - result = self.service.validate_vendor_access(db, test_vendor.vendor_code.upper(), test_user) + result = self.service.validate_vendor_access( + db, test_vendor.vendor_code.upper(), test_user + ) assert result.vendor_code == test_vendor.vendor_code def test_create_import_job_database_error(self, db_with_error, test_user): diff --git a/tests/unit/services/test_product_service.py b/tests/unit/services/test_product_service.py index 1a8f9367..1028b7a1 100644 --- a/tests/unit/services/test_product_service.py +++ b/tests/unit/services/test_product_service.py @@ -1,16 +1,15 @@ # tests/test_product_service.py import pytest +from app.exceptions import (InvalidMarketplaceProductDataException, + MarketplaceProductAlreadyExistsException, + MarketplaceProductNotFoundException, + MarketplaceProductValidationException, + ValidationException) from app.services.marketplace_product_service import MarketplaceProductService -from app.exceptions import ( - MarketplaceProductNotFoundException, - MarketplaceProductAlreadyExistsException, - InvalidMarketplaceProductDataException, - MarketplaceProductValidationException, - ValidationException, -) -from models.schema.marketplace_product import MarketplaceProductCreate, MarketplaceProductUpdate from models.database.marketplace_product import MarketplaceProduct +from models.schema.marketplace_product import (MarketplaceProductCreate, + MarketplaceProductUpdate) @pytest.mark.unit @@ -98,7 +97,10 @@ class TestProductService: assert exc_info.value.error_code == "PRODUCT_ALREADY_EXISTS" assert test_marketplace_product.marketplace_product_id in str(exc_info.value) assert exc_info.value.status_code == 409 - assert exc_info.value.details.get("marketplace_product_id") == test_marketplace_product.marketplace_product_id + assert ( + exc_info.value.details.get("marketplace_product_id") + == test_marketplace_product.marketplace_product_id + ) def test_create_product_invalid_price(self, db): """Test product creation with invalid price raises InvalidMarketplaceProductDataException""" @@ -117,9 +119,14 @@ class TestProductService: def test_get_product_by_id_or_raise_success(self, db, test_marketplace_product): """Test successful product retrieval by ID""" - product = self.service.get_product_by_id_or_raise(db, test_marketplace_product.marketplace_product_id) + product = self.service.get_product_by_id_or_raise( + db, test_marketplace_product.marketplace_product_id + ) - assert product.marketplace_product_id == test_marketplace_product.marketplace_product_id + assert ( + product.marketplace_product_id + == test_marketplace_product.marketplace_product_id + ) assert product.title == test_marketplace_product.title def test_get_product_by_id_or_raise_not_found(self, db): @@ -152,21 +159,35 @@ class TestProductService: assert total >= 1 assert len(products) >= 1 # Verify search worked by checking that title contains search term - found_product = next((p for p in products if p.marketplace_product_id == test_marketplace_product.marketplace_product_id), None) + found_product = next( + ( + p + for p in products + if p.marketplace_product_id + == test_marketplace_product.marketplace_product_id + ), + None, + ) assert found_product is not None def test_update_product_success(self, db, test_marketplace_product): """Test successful product update""" update_data = MarketplaceProductUpdate( - title="Updated MarketplaceProduct Title", - price="39.99" + title="Updated MarketplaceProduct Title", price="39.99" ) - updated_product = self.service.update_product(db, test_marketplace_product.marketplace_product_id, update_data) + updated_product = self.service.update_product( + db, test_marketplace_product.marketplace_product_id, update_data + ) assert updated_product.title == "Updated MarketplaceProduct Title" - assert updated_product.price == "39.99" # Price is stored as string after processing - assert updated_product.marketplace_product_id == test_marketplace_product.marketplace_product_id # ID unchanged + assert ( + updated_product.price == "39.99" + ) # Price is stored as string after processing + assert ( + updated_product.marketplace_product_id + == test_marketplace_product.marketplace_product_id + ) # ID unchanged def test_update_product_not_found(self, db): """Test updating non-existent product raises MarketplaceProductNotFoundException""" @@ -183,7 +204,9 @@ class TestProductService: update_data = MarketplaceProductUpdate(gtin="invalid_gtin") with pytest.raises(InvalidMarketplaceProductDataException) as exc_info: - self.service.update_product(db, test_marketplace_product.marketplace_product_id, update_data) + self.service.update_product( + db, test_marketplace_product.marketplace_product_id, update_data + ) assert exc_info.value.error_code == "INVALID_PRODUCT_DATA" assert "Invalid GTIN format" in str(exc_info.value) @@ -194,7 +217,9 @@ class TestProductService: update_data = MarketplaceProductUpdate(title="") with pytest.raises(MarketplaceProductValidationException) as exc_info: - self.service.update_product(db, test_marketplace_product.marketplace_product_id, update_data) + self.service.update_product( + db, test_marketplace_product.marketplace_product_id, update_data + ) assert exc_info.value.error_code == "PRODUCT_VALIDATION_FAILED" assert "MarketplaceProduct title cannot be empty" in str(exc_info.value) @@ -205,7 +230,9 @@ class TestProductService: update_data = MarketplaceProductUpdate(price="invalid_price") with pytest.raises(InvalidMarketplaceProductDataException) as exc_info: - self.service.update_product(db, test_marketplace_product.marketplace_product_id, update_data) + self.service.update_product( + db, test_marketplace_product.marketplace_product_id, update_data + ) assert exc_info.value.error_code == "INVALID_PRODUCT_DATA" assert "Invalid price format" in str(exc_info.value) @@ -213,12 +240,16 @@ class TestProductService: def test_delete_product_success(self, db, test_marketplace_product): """Test successful product deletion""" - result = self.service.delete_product(db, test_marketplace_product.marketplace_product_id) + result = self.service.delete_product( + db, test_marketplace_product.marketplace_product_id + ) assert result is True # Verify product is deleted - deleted_product = self.service.get_product_by_id(db, test_marketplace_product.marketplace_product_id) + deleted_product = self.service.get_product_by_id( + db, test_marketplace_product.marketplace_product_id + ) assert deleted_product is None def test_delete_product_not_found(self, db): @@ -229,10 +260,14 @@ class TestProductService: assert exc_info.value.error_code == "PRODUCT_NOT_FOUND" assert "NONEXISTENT" in str(exc_info.value) - def test_get_inventory_info_success(self, db, test_marketplace_product_with_inventory): + def test_get_inventory_info_success( + self, db, test_marketplace_product_with_inventory + ): """Test getting inventory info for product with inventory""" # Extract the product from the dictionary - marketplace_product = test_marketplace_product_with_inventory['marketplace_product'] + marketplace_product = test_marketplace_product_with_inventory[ + "marketplace_product" + ] inventory_info = self.service.get_inventory_info(db, marketplace_product.gtin) @@ -243,13 +278,17 @@ class TestProductService: def test_get_inventory_info_no_inventory(self, db, test_marketplace_product): """Test getting inventory info for product without inventory""" - inventory_info = self.service.get_inventory_info(db, test_marketplace_product.gtin or "1234567890123") + inventory_info = self.service.get_inventory_info( + db, test_marketplace_product.gtin or "1234567890123" + ) assert inventory_info is None def test_product_exists_true(self, db, test_marketplace_product): """Test product_exists returns True for existing product""" - exists = self.service.product_exists(db, test_marketplace_product.marketplace_product_id) + exists = self.service.product_exists( + db, test_marketplace_product.marketplace_product_id + ) assert exists is True def test_product_exists_false(self, db): @@ -265,7 +304,9 @@ class TestProductService: csv_lines = list(csv_generator) assert len(csv_lines) > 1 # Header + at least one data row - assert csv_lines[0].startswith("marketplace_product_id,title,description") # Check header + assert csv_lines[0].startswith( + "marketplace_product_id,title,description" + ) # Check header # Check that test product appears in CSV csv_content = "".join(csv_lines) @@ -274,8 +315,7 @@ class TestProductService: def test_generate_csv_export_with_filters(self, db, test_marketplace_product): """Test CSV export with marketplace filter""" csv_generator = self.service.generate_csv_export( - db, - marketplace=test_marketplace_product.marketplace + db, marketplace=test_marketplace_product.marketplace ) csv_lines = list(csv_generator) diff --git a/tests/unit/services/test_stats_service.py b/tests/unit/services/test_stats_service.py index 55582672..249aedfa 100644 --- a/tests/unit/services/test_stats_service.py +++ b/tests/unit/services/test_stats_service.py @@ -2,8 +2,8 @@ import pytest from app.services.stats_service import StatsService -from models.database.marketplace_product import MarketplaceProduct from models.database.inventory import Inventory +from models.database.marketplace_product import MarketplaceProduct @pytest.mark.unit @@ -15,7 +15,9 @@ class TestStatsService: """Setup method following the same pattern as other service tests""" self.service = StatsService() - def test_get_comprehensive_stats_basic(self, db, test_marketplace_product, test_inventory): + def test_get_comprehensive_stats_basic( + self, db, test_marketplace_product, test_inventory + ): """Test getting comprehensive stats with basic data""" stats = self.service.get_comprehensive_stats(db) @@ -31,7 +33,9 @@ class TestStatsService: assert stats["total_inventory_entries"] >= 1 assert stats["total_inventory_quantity"] >= 10 # test_inventory has quantity 10 - def test_get_comprehensive_stats_multiple_products(self, db, test_marketplace_product): + def test_get_comprehensive_stats_multiple_products( + self, db, test_marketplace_product + ): """Test comprehensive stats with multiple products across different dimensions""" # Create products with different brands, categories, marketplaces additional_products = [ @@ -87,7 +91,7 @@ class TestStatsService: brand=None, # Null brand google_product_category=None, # Null category marketplace=None, # Null marketplace - vendor_name=None, # Null vendor + vendor_name=None, # Null vendor price="10.00", currency="EUR", ), @@ -97,7 +101,7 @@ class TestStatsService: brand="", # Empty brand google_product_category="", # Empty category marketplace="", # Empty marketplace - vendor_name="", # Empty vendor + vendor_name="", # Empty vendor price="15.00", currency="EUR", ), @@ -124,7 +128,11 @@ class TestStatsService: # Find our test marketplace in the results test_marketplace_stat = next( - (stat for stat in stats if stat["marketplace"] == test_marketplace_product.marketplace), + ( + stat + for stat in stats + if stat["marketplace"] == test_marketplace_product.marketplace + ), None, ) assert test_marketplace_stat is not None @@ -309,7 +317,9 @@ class TestStatsService: count = self.service._get_unique_marketplaces_count(db) - assert count >= 2 # At least Amazon and eBay, plus test_marketplace_product marketplace + assert ( + count >= 2 + ) # At least Amazon and eBay, plus test_marketplace_product marketplace assert isinstance(count, int) def test_get_unique_vendors_count(self, db, test_marketplace_product): @@ -338,7 +348,9 @@ class TestStatsService: count = self.service._get_unique_vendors_count(db) - assert count >= 2 # At least VendorA and VendorB, plus test_marketplace_product vendor + assert ( + count >= 2 + ) # At least VendorA and VendorB, plus test_marketplace_product vendor assert isinstance(count, int) def test_get_inventory_statistics(self, db, test_inventory): @@ -438,7 +450,7 @@ class TestStatsService: db.add_all(marketplace_products) db.commit() - vendors =self.service._get_vendors_by_marketplace(db, "TestMarketplace") + vendors = self.service._get_vendors_by_marketplace(db, "TestMarketplace") assert len(vendors) == 2 assert "TestVendor1" in vendors @@ -482,7 +494,9 @@ class TestStatsService: def test_get_products_by_marketplace_not_found(self, db): """Test getting product count for non-existent marketplace""" - count = self.service._get_products_by_marketplace_count(db, "NonExistentMarketplace") + count = self.service._get_products_by_marketplace_count( + db, "NonExistentMarketplace" + ) assert count == 0 diff --git a/tests/unit/services/test_vendor_service.py b/tests/unit/services/test_vendor_service.py index c0a1bc1a..286ea923 100644 --- a/tests/unit/services/test_vendor_service.py +++ b/tests/unit/services/test_vendor_service.py @@ -1,19 +1,16 @@ # tests/test_vendor_service.py (updated to use custom exceptions) import pytest +from app.exceptions import (InvalidVendorDataException, + MarketplaceProductNotFoundException, + MaxVendorsReachedException, + ProductAlreadyExistsException, + UnauthorizedVendorAccessException, + ValidationException, VendorAlreadyExistsException, + VendorNotFoundException) from app.services.vendor_service import VendorService -from app.exceptions import ( - VendorNotFoundException, - VendorAlreadyExistsException, - UnauthorizedVendorAccessException, - InvalidVendorDataException, - MarketplaceProductNotFoundException, - ProductAlreadyExistsException, - MaxVendorsReachedException, - ValidationException, -) -from models.schema.vendor import VendorCreate from models.schema.product import ProductCreate +from models.schema.vendor import VendorCreate @pytest.mark.unit @@ -38,15 +35,17 @@ class TestVendorService: assert vendor is not None assert vendor.vendor_code == "NEWVENDOR" assert vendor.owner_user_id == test_user.id - assert vendor.is_verified is False # Regular user creates unverified vendor + assert vendor.is_verified is False # Regular user creates unverified vendor def test_create_vendor_admin_auto_verify(self, db, test_admin, vendor_factory): """Test admin creates verified vendor automatically""" - vendor_data = VendorCreate(vendor_code="ADMINVENDOR", vendor_name="Admin Test Vendor") + vendor_data = VendorCreate( + vendor_code="ADMINVENDOR", vendor_name="Admin Test Vendor" + ) vendor = self.service.create_vendor(db, vendor_data, test_admin) - assert vendor.is_verified is True # Admin creates verified vendor + assert vendor.is_verified is True # Admin creates verified vendor def test_create_vendor_duplicate_code(self, db, test_user, test_vendor): """Test vendor creation fails with duplicate vendor code""" @@ -88,7 +87,9 @@ class TestVendorService: def test_create_vendor_invalid_code_format(self, db, test_user): """Test vendor creation fails with invalid vendor code format""" - vendor_data = VendorCreate(vendor_code="INVALID@CODE!", vendor_name="Test Vendor") + vendor_data = VendorCreate( + vendor_code="INVALID@CODE!", vendor_name="Test Vendor" + ) with pytest.raises(InvalidVendorDataException) as exc_info: self.service.create_vendor(db, vendor_data, test_user) @@ -105,7 +106,9 @@ class TestVendorService: def mock_check_vendor_limit(self, db, user): raise MaxVendorsReachedException(max_vendors=5, user_id=user.id) - monkeypatch.setattr(VendorService, "_check_vendor_limit", mock_check_vendor_limit) + monkeypatch.setattr( + VendorService, "_check_vendor_limit", mock_check_vendor_limit + ) vendor_data = VendorCreate(vendor_code="NEWVENDOR", vendor_name="New Vendor") @@ -118,7 +121,9 @@ class TestVendorService: assert exception.details["max_vendors"] == 5 assert exception.details["user_id"] == test_user.id - def test_get_vendors_regular_user(self, db, test_user, test_vendor, inactive_vendor): + def test_get_vendors_regular_user( + self, db, test_user, test_vendor, inactive_vendor + ): """Test regular user can only see active verified vendors and own vendors""" vendors, total = self.service.get_vendors(db, test_user, skip=0, limit=10) @@ -127,7 +132,7 @@ class TestVendorService: assert inactive_vendor.vendor_code not in vendor_codes def test_get_vendors_admin_user( - self, db, test_admin, test_vendor, inactive_vendor, verified_vendor + self, db, test_admin, test_vendor, inactive_vendor, verified_vendor ): """Test admin user can see all vendors with filters""" vendors, total = self.service.get_vendors( @@ -140,14 +145,16 @@ class TestVendorService: assert verified_vendor.vendor_code in vendor_codes def test_get_vendor_by_code_owner_access(self, db, test_user, test_vendor): - """Test vendor owner can access their own vendor """ - vendor = self.service.get_vendor_by_code(db, test_vendor.vendor_code.lower(), test_user) + """Test vendor owner can access their own vendor""" + vendor = self.service.get_vendor_by_code( + db, test_vendor.vendor_code.lower(), test_user + ) assert vendor is not None assert vendor.id == test_vendor.id def test_get_vendor_by_code_admin_access(self, db, test_admin, test_vendor): - """Test admin can access any vendor """ + """Test admin can access any vendor""" vendor = self.service.get_vendor_by_code( db, test_vendor.vendor_code.lower(), test_admin ) @@ -178,16 +185,14 @@ class TestVendorService: assert exception.details["user_id"] == test_user.id def test_add_product_to_vendor_success(self, db, test_vendor, unique_product): - """Test successfully adding product to vendor """ + """Test successfully adding product to vendor""" product_data = ProductCreate( marketplace_product_id=unique_product.marketplace_product_id, price="15.99", is_featured=True, ) - product = self.service.add_product_to_catalog( - db, test_vendor, product_data - ) + product = self.service.add_product_to_catalog(db, test_vendor, product_data) assert product is not None assert product.vendor_id == test_vendor.id @@ -195,7 +200,9 @@ class TestVendorService: def test_add_product_to_vendor_product_not_found(self, db, test_vendor): """Test adding non-existent product to vendor fails""" - product_data = ProductCreate(marketplace_product_id="NONEXISTENT", price="15.99") + product_data = ProductCreate( + marketplace_product_id="NONEXISTENT", price="15.99" + ) with pytest.raises(MarketplaceProductNotFoundException) as exc_info: self.service.add_product_to_catalog(db, test_vendor, product_data) @@ -209,7 +216,8 @@ class TestVendorService: def test_add_product_to_vendor_already_exists(self, db, test_vendor, test_product): """Test adding product that's already in vendor fails""" product_data = ProductCreate( - marketplace_product_id=test_product.marketplace_product.marketplace_product_id, price="15.99" + marketplace_product_id=test_product.marketplace_product.marketplace_product_id, + price="15.99", ) with pytest.raises(ProductAlreadyExistsException) as exc_info: @@ -219,11 +227,12 @@ class TestVendorService: assert exception.status_code == 409 assert exception.error_code == "PRODUCT_ALREADY_EXISTS" assert exception.details["vendor_code"] == test_vendor.vendor_code - assert exception.details["marketplace_product_id"] == test_product.marketplace_product.marketplace_product_id + assert ( + exception.details["marketplace_product_id"] + == test_product.marketplace_product.marketplace_product_id + ) - def test_get_products_owner_access( - self, db, test_user, test_vendor, test_product - ): + def test_get_products_owner_access(self, db, test_user, test_vendor, test_product): """Test vendor owner can get vendor products""" products, total = self.service.get_products(db, test_vendor, test_user) @@ -291,7 +300,9 @@ class TestVendorService: assert exception.error_code == "VALIDATION_ERROR" assert "Failed to retrieve vendors" in exception.message - def test_add_product_database_error(self, db, test_vendor, unique_product, monkeypatch): + def test_add_product_database_error( + self, db, test_vendor, unique_product, monkeypatch + ): """Test add product handles database errors gracefully""" def mock_commit(): diff --git a/tests/unit/utils/test_csv_processor.py b/tests/unit/utils/test_csv_processor.py index 4f35ee16..85615a10 100644 --- a/tests/unit/utils/test_csv_processor.py +++ b/tests/unit/utils/test_csv_processor.py @@ -18,7 +18,9 @@ class TestCSVProcessor: def test_download_csv_encoding_fallback(self, mock_get): """Test CSV download with encoding fallback""" # Create content with special characters that would fail UTF-8 if not properly encoded - special_content = "marketplace_product_id,title,price\nTEST001,Café MarketplaceProduct,10.99" + special_content = ( + "marketplace_product_id,title,price\nTEST001,Café MarketplaceProduct,10.99" + ) mock_response = Mock() mock_response.status_code = 200 @@ -40,9 +42,7 @@ class TestCSVProcessor: mock_response = Mock() mock_response.status_code = 200 # Create bytes that will fail most encodings - mock_response.content = ( - b"marketplace_product_id,title,price\nTEST001,\xff\xfe MarketplaceProduct,10.99" - ) + mock_response.content = b"marketplace_product_id,title,price\nTEST001,\xff\xfe MarketplaceProduct,10.99" mock_response.raise_for_status.return_value = None mock_get.return_value = mock_response