From 21c13ca39b67c28d70db123382717043812e881d Mon Sep 17 00:00:00 2001
From: Samir Boulahtit
Date: Fri, 28 Nov 2025 19:30:17 +0100
Subject: [PATCH] style: apply black and isort formatting across entire
codebase
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- Standardize quote style (single to double quotes)
- Reorder and group imports alphabetically
- Fix line breaks and indentation for consistency
- Apply PEP 8 formatting standards
Also updated Makefile to exclude both venv and .venv from code quality checks.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude
---
Makefile | 12 +-
alembic/env.py | 18 +-
...51b2e50581_initial_migration_all_tables.py | 1388 ++++++++++-------
...07_ensure_content_pages_table_with_all_.py | 87 +-
...dd_architecture_quality_tracking_tables.py | 348 +++--
.../a2064e1dfcd4_add_cart_items_table.py | 63 +-
...dd_template_field_to_content_pages_for_.py | 16 +-
.../fa7d4d10e358_add_rbac_enhancements.py | 95 +-
...20ce8b4_add_content_pages_table_for_cms.py | 31 +-
app/api/deps.py | 193 ++-
app/api/main.py | 22 +-
app/api/v1/__init__.py | 4 +-
app/api/v1/admin/__init__.py | 26 +-
app/api/v1/admin/audit.py | 68 +-
app/api/v1/admin/auth.py | 15 +-
app/api/v1/admin/code_quality.py | 161 +-
app/api/v1/admin/content_pages.py | 65 +-
app/api/v1/admin/dashboard.py | 3 +-
app/api/v1/admin/marketplace.py | 2 +-
app/api/v1/admin/monitoring.py | 2 +-
app/api/v1/admin/notifications.py | 36 +-
app/api/v1/admin/settings.py | 93 +-
app/api/v1/admin/users.py | 2 +-
app/api/v1/admin/vendor_domains.py | 93 +-
app/api/v1/admin/vendor_themes.py | 46 +-
app/api/v1/admin/vendors.py | 116 +-
app/api/v1/shared/health.py | 2 +-
app/api/v1/shared/uploads.py | 2 +-
app/api/v1/shop/__init__.py | 6 +-
app/api/v1/shop/auth.py | 114 +-
app/api/v1/shop/cart.py | 81 +-
app/api/v1/shop/content_pages.py | 26 +-
app/api/v1/shop/orders.py | 72 +-
app/api/v1/shop/products.py | 44 +-
app/api/v1/vendor/__init__.py | 28 +-
app/api/v1/vendor/analytics.py | 11 +-
app/api/v1/vendor/auth.py | 77 +-
app/api/v1/vendor/content_pages.py | 82 +-
app/api/v1/vendor/customers.py | 79 +-
app/api/v1/vendor/dashboard.py | 28 +-
app/api/v1/vendor/info.py | 23 +-
app/api/v1/vendor/inventory.py | 104 +-
app/api/v1/vendor/marketplace.py | 48 +-
app/api/v1/vendor/media.py | 104 +-
app/api/v1/vendor/notifications.py | 119 +-
app/api/v1/vendor/orders.py | 59 +-
app/api/v1/vendor/payments.py | 86 +-
app/api/v1/vendor/products.py | 124 +-
app/api/v1/vendor/profile.py | 19 +-
app/api/v1/vendor/settings.py | 17 +-
app/api/v1/vendor/team.py | 193 +--
app/core/config.py | 39 +-
app/core/environment.py | 25 +-
app/core/lifespan.py | 9 +-
app/core/permissions.py | 13 +-
app/core/theme_presets.py | 112 +-
app/exceptions/__init__.py | 285 ++--
app/exceptions/admin.py | 71 +-
app/exceptions/auth.py | 16 +-
app/exceptions/backup.py | 2 +-
app/exceptions/base.py | 9 +-
app/exceptions/cart.py | 24 +-
app/exceptions/customer.py | 37 +-
app/exceptions/error_renderer.py | 23 +-
app/exceptions/handler.py | 109 +-
app/exceptions/inventory.py | 34 +-
app/exceptions/marketplace.py | 2 +-
app/exceptions/marketplace_import_job.py | 51 +-
app/exceptions/marketplace_product.py | 28 +-
app/exceptions/media.py | 2 +-
app/exceptions/monitoring.py | 2 +-
app/exceptions/notification.py | 2 +-
app/exceptions/order.py | 26 +-
app/exceptions/payment.py | 2 +-
app/exceptions/product.py | 9 +-
app/exceptions/search.py | 2 +-
app/exceptions/team.py | 61 +-
app/exceptions/vendor.py | 33 +-
app/exceptions/vendor_domain.py | 46 +-
app/exceptions/vendor_theme.py | 39 +-
app/models/architecture_scan.py | 92 +-
app/routes/admin_pages.py | 261 ++--
app/routes/shop_pages.py | 313 ++--
app/routes/vendor_pages.py | 169 +-
app/services/admin_audit_service.py | 54 +-
app/services/admin_service.py | 246 +--
app/services/admin_settings_service.py | 114 +-
app/services/audit_service.py | 2 +-
app/services/auth_service.py | 22 +-
app/services/backup_service.py | 2 +-
app/services/cache_service.py | 2 +-
app/services/cart_service.py | 266 ++--
app/services/code_quality_service.py | 287 ++--
app/services/configuration_service.py | 2 +-
app/services/content_page_service.py | 57 +-
app/services/customer_service.py | 190 ++-
app/services/inventory_service.py | 139 +-
.../marketplace_import_job_service.py | 25 +-
app/services/marketplace_product_service.py | 195 ++-
app/services/media_service.py | 2 +-
app/services/monitoring_service.py | 2 +-
app/services/notification_service.py | 2 +-
app/services/order_service.py | 172 +-
app/services/payment_service.py | 2 +-
app/services/product_service.py | 71 +-
app/services/search_service.py | 2 +-
app/services/stats_service.py | 232 +--
app/services/team_service.py | 81 +-
app/services/vendor_domain_service.py | 168 +-
app/services/vendor_service.py | 152 +-
app/services/vendor_team_service.py | 208 +--
app/services/vendor_theme_service.py | 161 +-
app/tasks/background_tasks.py | 6 +-
app/utils/csv_processor.py | 11 +-
app/utils/data_processing.py | 4 +-
app/utils/database.py | 1 +
main.py | 142 +-
middleware/auth.py | 35 +-
middleware/context.py | 35 +-
middleware/decorators.py | 7 +-
middleware/theme_context.py | 39 +-
middleware/vendor_context.py | 159 +-
models/__init__.py | 17 +-
models/database/__init__.py | 17 +-
models/database/admin.py | 40 +-
models/database/audit.py | 2 +-
models/database/backup.py | 2 +-
models/database/base.py | 5 +-
models/database/cart.py | 5 +-
models/database/configuration.py | 2 +-
models/database/content_page.py | 45 +-
models/database/customer.py | 14 +-
models/database/inventory.py | 8 +-
models/database/marketplace.py | 2 +-
models/database/marketplace_import_job.py | 3 +-
models/database/marketplace_product.py | 4 +-
models/database/media.py | 2 +-
models/database/monitoring.py | 2 +-
models/database/notification.py | 2 +-
models/database/order.py | 26 +-
models/database/payment.py | 2 +-
models/database/product.py | 12 +-
models/database/search.py | 2 +-
models/database/task.py | 2 +-
models/database/user.py | 17 +-
models/database/vendor.py | 118 +-
models/database/vendor_domain.py | 23 +-
models/database/vendor_theme.py | 41 +-
models/schema/__init__.py | 9 +-
models/schema/admin.py | 75 +-
models/schema/auth.py | 8 +-
models/schema/cart.py | 24 +-
models/schema/customer.py | 27 +-
models/schema/inventory.py | 10 +-
models/schema/marketplace.py | 2 +-
models/schema/marketplace_import_job.py | 9 +-
models/schema/marketplace_product.py | 12 +-
models/schema/media.py | 2 +-
models/schema/monitoring.py | 2 +-
models/schema/notification.py | 2 +-
models/schema/order.py | 20 +-
models/schema/payment.py | 2 +-
models/schema/product.py | 13 +-
models/schema/search.py | 2 +-
models/schema/stats.py | 11 +-
models/schema/team.py | 90 +-
models/schema/vendor.py | 62 +-
models/schema/vendor_domain.py | 25 +-
models/schema/vendor_theme.py | 40 +-
scripts/backup_database.py | 10 +-
scripts/create_default_content_pages.py | 18 +-
scripts/create_inventory.py | 42 +-
scripts/create_landing_page.py | 45 +-
scripts/create_platform_pages.py | 144 +-
scripts/init_production.py | 25 +-
scripts/route_diagnostics.py | 28 +-
scripts/seed_demo.py | 129 +-
scripts/show_structure.py | 317 ++--
scripts/test_auth_complete.py | 198 +--
scripts/test_vendor_management.py | 41 +-
scripts/validate_architecture.py | 271 ++--
scripts/verify_setup.py | 72 +-
storage/backends.py | 2 +-
storage/utils.py | 2 +-
tasks/analytics_tasks.py | 2 +-
tasks/backup_tasks.py | 2 +-
tasks/cleanup_tasks.py | 2 +-
tasks/email_tasks.py | 2 +-
tasks/marketplace_import.py | 8 +-
tasks/media_processing.py | 2 +-
tasks/search_indexing.py | 2 +-
tasks/task_manager.py | 2 +-
tests/conftest.py | 4 +-
tests/fixtures/auth_fixtures.py | 2 +
tests/fixtures/customer_fixtures.py | 1 +
.../marketplace_import_job_fixtures.py | 1 +
.../fixtures/marketplace_product_fixtures.py | 8 +-
tests/fixtures/testing_fixtures.py | 5 +-
tests/fixtures/vendor_fixtures.py | 6 +-
.../api/v1/test_admin_endpoints.py | 25 +-
.../integration/api/v1/test_auth_endpoints.py | 14 +-
tests/integration/api/v1/test_filtering.py | 97 +-
.../api/v1/test_inventory_endpoints.py | 85 +-
.../test_marketplace_import_job_endpoints.py | 125 +-
.../api/v1/test_marketplace_product_export.py | 101 +-
.../v1/test_marketplace_products_endpoints.py | 88 +-
tests/integration/api/v1/test_pagination.py | 102 +-
.../api/v1/test_stats_endpoints.py | 5 +-
.../api/v1/test_vendor_endpoints.py | 83 +-
.../api/v1/vendor/test_authentication.py | 98 +-
.../api/v1/vendor/test_dashboard.py | 30 +-
tests/integration/middleware/conftest.py | 22 +-
.../middleware/test_context_detection_flow.py | 267 +++-
.../middleware/test_middleware_stack.py | 166 +-
.../middleware/test_theme_loading_flow.py | 218 ++-
.../middleware/test_vendor_context_flow.py | 243 ++-
.../security/test_input_validation.py | 19 +-
.../integration/workflows/test_integration.py | 16 +-
tests/performance/test_api_performance.py | 11 +-
tests/system/test_error_handling.py | 107 +-
tests/unit/middleware/test_auth.py | 100 +-
tests/unit/middleware/test_context.py | 48 +-
tests/unit/middleware/test_decorators.py | 30 +-
tests/unit/middleware/test_logging.py | 39 +-
tests/unit/middleware/test_rate_limiter.py | 25 +-
tests/unit/middleware/test_theme_context.py | 43 +-
tests/unit/middleware/test_vendor_context.py | 207 ++-
tests/unit/models/test_database_models.py | 27 +-
tests/unit/services/test_admin_service.py | 46 +-
tests/unit/services/test_auth_service.py | 22 +-
tests/unit/services/test_inventory_service.py | 133 +-
.../unit/services/test_marketplace_service.py | 85 +-
tests/unit/services/test_product_service.py | 98 +-
tests/unit/services/test_stats_service.py | 34 +-
tests/unit/services/test_vendor_service.py | 75 +-
tests/unit/utils/test_csv_processor.py | 8 +-
236 files changed, 8450 insertions(+), 6545 deletions(-)
diff --git a/Makefile b/Makefile
index 75b36c8f..d201da6a 100644
--- a/Makefile
+++ b/Makefile
@@ -176,19 +176,19 @@ test-inventory:
format:
@echo "Running black..."
- $(PYTHON) -m black . --exclude venv
+ $(PYTHON) -m black . --exclude '/(\.)?venv/'
@echo "Running isort..."
- $(PYTHON) -m isort . --skip venv
+ $(PYTHON) -m isort . --skip venv --skip .venv
lint:
@echo "Running linting..."
- $(PYTHON) -m ruff check . --exclude venv
- $(PYTHON) -m mypy . --ignore-missing-imports --exclude venv
+ $(PYTHON) -m ruff check . --exclude venv --exclude .venv
+ $(PYTHON) -m mypy . --ignore-missing-imports --exclude '.*(\.)?venv.*'
lint-flake8:
@echo "Running linting..."
- $(PYTHON) -m flake8 . --max-line-length=120 --extend-ignore=E203,W503,I201,I100 --exclude=venv,__pycache__,.git
- $(PYTHON) -m mypy . --ignore-missing-imports --exclude venv
+ $(PYTHON) -m flake8 . --max-line-length=120 --extend-ignore=E203,W503,I201,I100 --exclude=venv,.venv,__pycache__,.git
+ $(PYTHON) -m mypy . --ignore-missing-imports --exclude '.*(\.)?venv.*'
check: format lint
diff --git a/alembic/env.py b/alembic/env.py
index ae52fcdf..8605ecc3 100644
--- a/alembic/env.py
+++ b/alembic/env.py
@@ -15,6 +15,7 @@ import sys
from logging.config import fileConfig
from sqlalchemy import engine_from_config, pool
+
from alembic import context
# Add your project directory to the Python path
@@ -39,13 +40,9 @@ print("=" * 70)
# ADMIN MODELS
# ----------------------------------------------------------------------------
try:
- from models.database.admin import (
- AdminAuditLog,
- AdminNotification,
- AdminSetting,
- PlatformAlert,
- AdminSession
- )
+ from models.database.admin import (AdminAuditLog, AdminNotification,
+ AdminSession, AdminSetting,
+ PlatformAlert)
print(" ✓ Admin models imported (5 models)")
print(" - AdminAuditLog")
@@ -70,7 +67,7 @@ except ImportError as e:
# VENDOR MODELS
# ----------------------------------------------------------------------------
try:
- from models.database.vendor import Vendor, VendorUser, Role
+ from models.database.vendor import Role, Vendor, VendorUser
print(" ✓ Vendor models imported (3 models)")
print(" - Vendor")
@@ -248,10 +245,7 @@ def run_migrations_online() -> None:
)
with connectable.connect() as connection:
- context.configure(
- connection=connection,
- target_metadata=target_metadata
- )
+ context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
diff --git a/alembic/versions/4951b2e50581_initial_migration_all_tables.py b/alembic/versions/4951b2e50581_initial_migration_all_tables.py
index 36b6fe27..bdf730a1 100644
--- a/alembic/versions/4951b2e50581_initial_migration_all_tables.py
+++ b/alembic/versions/4951b2e50581_initial_migration_all_tables.py
@@ -1,18 +1,19 @@
"""Initial migration - all tables
Revision ID: 4951b2e50581
-Revises:
+Revises:
Create Date: 2025-10-27 22:28:33.137564
"""
+
from typing import Sequence, Union
-from alembic import op
import sqlalchemy as sa
+from alembic import op
# revision identifiers, used by Alembic.
-revision: str = '4951b2e50581'
+revision: str = "4951b2e50581"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@@ -20,549 +21,888 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
- op.create_table('marketplace_products',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('marketplace_product_id', sa.String(), nullable=False),
- sa.Column('title', sa.String(), nullable=False),
- sa.Column('description', sa.String(), nullable=True),
- sa.Column('link', sa.String(), nullable=True),
- sa.Column('image_link', sa.String(), nullable=True),
- sa.Column('availability', sa.String(), nullable=True),
- sa.Column('price', sa.String(), nullable=True),
- sa.Column('brand', sa.String(), nullable=True),
- sa.Column('gtin', sa.String(), nullable=True),
- sa.Column('mpn', sa.String(), nullable=True),
- sa.Column('condition', sa.String(), nullable=True),
- sa.Column('adult', sa.String(), nullable=True),
- sa.Column('multipack', sa.Integer(), nullable=True),
- sa.Column('is_bundle', sa.String(), nullable=True),
- sa.Column('age_group', sa.String(), nullable=True),
- sa.Column('color', sa.String(), nullable=True),
- sa.Column('gender', sa.String(), nullable=True),
- sa.Column('material', sa.String(), nullable=True),
- sa.Column('pattern', sa.String(), nullable=True),
- sa.Column('size', sa.String(), nullable=True),
- sa.Column('size_type', sa.String(), nullable=True),
- sa.Column('size_system', sa.String(), nullable=True),
- sa.Column('item_group_id', sa.String(), nullable=True),
- sa.Column('google_product_category', sa.String(), nullable=True),
- sa.Column('product_type', sa.String(), nullable=True),
- sa.Column('custom_label_0', sa.String(), nullable=True),
- sa.Column('custom_label_1', sa.String(), nullable=True),
- sa.Column('custom_label_2', sa.String(), nullable=True),
- sa.Column('custom_label_3', sa.String(), nullable=True),
- sa.Column('custom_label_4', sa.String(), nullable=True),
- sa.Column('additional_image_link', sa.String(), nullable=True),
- sa.Column('sale_price', sa.String(), nullable=True),
- sa.Column('unit_pricing_measure', sa.String(), nullable=True),
- sa.Column('unit_pricing_base_measure', sa.String(), nullable=True),
- sa.Column('identifier_exists', sa.String(), nullable=True),
- sa.Column('shipping', sa.String(), nullable=True),
- sa.Column('currency', sa.String(), nullable=True),
- sa.Column('marketplace', sa.String(), nullable=True),
- sa.Column('vendor_name', sa.String(), nullable=True),
- sa.Column('created_at', sa.DateTime(), nullable=False),
- sa.Column('updated_at', sa.DateTime(), nullable=False),
- sa.PrimaryKeyConstraint('id')
+ op.create_table(
+ "marketplace_products",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("marketplace_product_id", sa.String(), nullable=False),
+ sa.Column("title", sa.String(), nullable=False),
+ sa.Column("description", sa.String(), nullable=True),
+ sa.Column("link", sa.String(), nullable=True),
+ sa.Column("image_link", sa.String(), nullable=True),
+ sa.Column("availability", sa.String(), nullable=True),
+ sa.Column("price", sa.String(), nullable=True),
+ sa.Column("brand", sa.String(), nullable=True),
+ sa.Column("gtin", sa.String(), nullable=True),
+ sa.Column("mpn", sa.String(), nullable=True),
+ sa.Column("condition", sa.String(), nullable=True),
+ sa.Column("adult", sa.String(), nullable=True),
+ sa.Column("multipack", sa.Integer(), nullable=True),
+ sa.Column("is_bundle", sa.String(), nullable=True),
+ sa.Column("age_group", sa.String(), nullable=True),
+ sa.Column("color", sa.String(), nullable=True),
+ sa.Column("gender", sa.String(), nullable=True),
+ sa.Column("material", sa.String(), nullable=True),
+ sa.Column("pattern", sa.String(), nullable=True),
+ sa.Column("size", sa.String(), nullable=True),
+ sa.Column("size_type", sa.String(), nullable=True),
+ sa.Column("size_system", sa.String(), nullable=True),
+ sa.Column("item_group_id", sa.String(), nullable=True),
+ sa.Column("google_product_category", sa.String(), nullable=True),
+ sa.Column("product_type", sa.String(), nullable=True),
+ sa.Column("custom_label_0", sa.String(), nullable=True),
+ sa.Column("custom_label_1", sa.String(), nullable=True),
+ sa.Column("custom_label_2", sa.String(), nullable=True),
+ sa.Column("custom_label_3", sa.String(), nullable=True),
+ sa.Column("custom_label_4", sa.String(), nullable=True),
+ sa.Column("additional_image_link", sa.String(), nullable=True),
+ sa.Column("sale_price", sa.String(), nullable=True),
+ sa.Column("unit_pricing_measure", sa.String(), nullable=True),
+ sa.Column("unit_pricing_base_measure", sa.String(), nullable=True),
+ sa.Column("identifier_exists", sa.String(), nullable=True),
+ sa.Column("shipping", sa.String(), nullable=True),
+ sa.Column("currency", sa.String(), nullable=True),
+ sa.Column("marketplace", sa.String(), nullable=True),
+ sa.Column("vendor_name", sa.String(), nullable=True),
+ sa.Column("created_at", sa.DateTime(), nullable=False),
+ sa.Column("updated_at", sa.DateTime(), nullable=False),
+ sa.PrimaryKeyConstraint("id"),
)
- op.create_index('idx_marketplace_brand', 'marketplace_products', ['marketplace', 'brand'], unique=False)
- op.create_index('idx_marketplace_vendor', 'marketplace_products', ['marketplace', 'vendor_name'], unique=False)
- op.create_index(op.f('ix_marketplace_products_availability'), 'marketplace_products', ['availability'], unique=False)
- op.create_index(op.f('ix_marketplace_products_brand'), 'marketplace_products', ['brand'], unique=False)
- op.create_index(op.f('ix_marketplace_products_google_product_category'), 'marketplace_products', ['google_product_category'], unique=False)
- op.create_index(op.f('ix_marketplace_products_gtin'), 'marketplace_products', ['gtin'], unique=False)
- op.create_index(op.f('ix_marketplace_products_id'), 'marketplace_products', ['id'], unique=False)
- op.create_index(op.f('ix_marketplace_products_marketplace'), 'marketplace_products', ['marketplace'], unique=False)
- op.create_index(op.f('ix_marketplace_products_marketplace_product_id'), 'marketplace_products', ['marketplace_product_id'], unique=True)
- op.create_index(op.f('ix_marketplace_products_vendor_name'), 'marketplace_products', ['vendor_name'], unique=False)
- op.create_table('users',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('email', sa.String(), nullable=False),
- sa.Column('username', sa.String(), nullable=False),
- sa.Column('first_name', sa.String(), nullable=True),
- sa.Column('last_name', sa.String(), nullable=True),
- sa.Column('hashed_password', sa.String(), nullable=False),
- sa.Column('role', sa.String(), nullable=False),
- sa.Column('is_active', sa.Boolean(), nullable=False),
- sa.Column('last_login', sa.DateTime(), nullable=True),
- sa.Column('created_at', sa.DateTime(), nullable=False),
- sa.Column('updated_at', sa.DateTime(), nullable=False),
- sa.PrimaryKeyConstraint('id')
+ op.create_index(
+ "idx_marketplace_brand",
+ "marketplace_products",
+ ["marketplace", "brand"],
+ unique=False,
)
- op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
- op.create_index(op.f('ix_users_id'), 'users', ['id'], unique=False)
- op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=True)
- op.create_table('admin_audit_logs',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('admin_user_id', sa.Integer(), nullable=False),
- sa.Column('action', sa.String(length=100), nullable=False),
- sa.Column('target_type', sa.String(length=50), nullable=False),
- sa.Column('target_id', sa.String(length=100), nullable=False),
- sa.Column('details', sa.JSON(), nullable=True),
- sa.Column('ip_address', sa.String(length=45), nullable=True),
- sa.Column('user_agent', sa.Text(), nullable=True),
- sa.Column('request_id', sa.String(length=100), nullable=True),
- sa.Column('created_at', sa.DateTime(), nullable=False),
- sa.Column('updated_at', sa.DateTime(), nullable=False),
- sa.ForeignKeyConstraint(['admin_user_id'], ['users.id'], ),
- sa.PrimaryKeyConstraint('id')
+ op.create_index(
+ "idx_marketplace_vendor",
+ "marketplace_products",
+ ["marketplace", "vendor_name"],
+ unique=False,
)
- op.create_index(op.f('ix_admin_audit_logs_action'), 'admin_audit_logs', ['action'], unique=False)
- op.create_index(op.f('ix_admin_audit_logs_admin_user_id'), 'admin_audit_logs', ['admin_user_id'], unique=False)
- op.create_index(op.f('ix_admin_audit_logs_id'), 'admin_audit_logs', ['id'], unique=False)
- op.create_index(op.f('ix_admin_audit_logs_target_id'), 'admin_audit_logs', ['target_id'], unique=False)
- op.create_index(op.f('ix_admin_audit_logs_target_type'), 'admin_audit_logs', ['target_type'], unique=False)
- op.create_table('admin_notifications',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('type', sa.String(length=50), nullable=False),
- sa.Column('priority', sa.String(length=20), nullable=True),
- sa.Column('title', sa.String(length=200), nullable=False),
- sa.Column('message', sa.Text(), nullable=False),
- sa.Column('is_read', sa.Boolean(), nullable=True),
- sa.Column('read_at', sa.DateTime(), nullable=True),
- sa.Column('read_by_user_id', sa.Integer(), nullable=True),
- sa.Column('action_required', sa.Boolean(), nullable=True),
- sa.Column('action_url', sa.String(length=500), nullable=True),
- sa.Column('notification_metadata', sa.JSON(), nullable=True),
- sa.Column('created_at', sa.DateTime(), nullable=False),
- sa.Column('updated_at', sa.DateTime(), nullable=False),
- sa.ForeignKeyConstraint(['read_by_user_id'], ['users.id'], ),
- sa.PrimaryKeyConstraint('id')
+ op.create_index(
+ op.f("ix_marketplace_products_availability"),
+ "marketplace_products",
+ ["availability"],
+ unique=False,
)
- op.create_index(op.f('ix_admin_notifications_action_required'), 'admin_notifications', ['action_required'], unique=False)
- op.create_index(op.f('ix_admin_notifications_id'), 'admin_notifications', ['id'], unique=False)
- op.create_index(op.f('ix_admin_notifications_is_read'), 'admin_notifications', ['is_read'], unique=False)
- op.create_index(op.f('ix_admin_notifications_priority'), 'admin_notifications', ['priority'], unique=False)
- op.create_index(op.f('ix_admin_notifications_type'), 'admin_notifications', ['type'], unique=False)
- op.create_table('admin_sessions',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('admin_user_id', sa.Integer(), nullable=False),
- sa.Column('session_token', sa.String(length=255), nullable=False),
- sa.Column('ip_address', sa.String(length=45), nullable=False),
- sa.Column('user_agent', sa.Text(), nullable=True),
- sa.Column('login_at', sa.DateTime(), nullable=False),
- sa.Column('last_activity_at', sa.DateTime(), nullable=False),
- sa.Column('logout_at', sa.DateTime(), nullable=True),
- sa.Column('is_active', sa.Boolean(), nullable=True),
- sa.Column('logout_reason', sa.String(length=50), nullable=True),
- sa.Column('created_at', sa.DateTime(), nullable=False),
- sa.Column('updated_at', sa.DateTime(), nullable=False),
- sa.ForeignKeyConstraint(['admin_user_id'], ['users.id'], ),
- sa.PrimaryKeyConstraint('id')
+ op.create_index(
+ op.f("ix_marketplace_products_brand"),
+ "marketplace_products",
+ ["brand"],
+ unique=False,
)
- op.create_index(op.f('ix_admin_sessions_admin_user_id'), 'admin_sessions', ['admin_user_id'], unique=False)
- op.create_index(op.f('ix_admin_sessions_id'), 'admin_sessions', ['id'], unique=False)
- op.create_index(op.f('ix_admin_sessions_is_active'), 'admin_sessions', ['is_active'], unique=False)
- op.create_index(op.f('ix_admin_sessions_login_at'), 'admin_sessions', ['login_at'], unique=False)
- op.create_index(op.f('ix_admin_sessions_session_token'), 'admin_sessions', ['session_token'], unique=True)
- op.create_table('admin_settings',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('key', sa.String(length=100), nullable=False),
- sa.Column('value', sa.Text(), nullable=False),
- sa.Column('value_type', sa.String(length=20), nullable=True),
- sa.Column('category', sa.String(length=50), nullable=True),
- sa.Column('description', sa.Text(), nullable=True),
- sa.Column('is_encrypted', sa.Boolean(), nullable=True),
- sa.Column('is_public', sa.Boolean(), nullable=True),
- sa.Column('last_modified_by_user_id', sa.Integer(), nullable=True),
- sa.Column('created_at', sa.DateTime(), nullable=False),
- sa.Column('updated_at', sa.DateTime(), nullable=False),
- sa.ForeignKeyConstraint(['last_modified_by_user_id'], ['users.id'], ),
- sa.PrimaryKeyConstraint('id')
+ op.create_index(
+ op.f("ix_marketplace_products_google_product_category"),
+ "marketplace_products",
+ ["google_product_category"],
+ unique=False,
)
- op.create_index(op.f('ix_admin_settings_category'), 'admin_settings', ['category'], unique=False)
- op.create_index(op.f('ix_admin_settings_id'), 'admin_settings', ['id'], unique=False)
- op.create_index(op.f('ix_admin_settings_key'), 'admin_settings', ['key'], unique=True)
- op.create_table('platform_alerts',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('alert_type', sa.String(length=50), nullable=False),
- sa.Column('severity', sa.String(length=20), nullable=False),
- sa.Column('title', sa.String(length=200), nullable=False),
- sa.Column('description', sa.Text(), nullable=True),
- sa.Column('affected_vendors', sa.JSON(), nullable=True),
- sa.Column('affected_systems', sa.JSON(), nullable=True),
- sa.Column('is_resolved', sa.Boolean(), nullable=True),
- sa.Column('resolved_at', sa.DateTime(), nullable=True),
- sa.Column('resolved_by_user_id', sa.Integer(), nullable=True),
- sa.Column('resolution_notes', sa.Text(), nullable=True),
- sa.Column('auto_generated', sa.Boolean(), nullable=True),
- sa.Column('occurrence_count', sa.Integer(), nullable=True),
- sa.Column('first_occurred_at', sa.DateTime(), nullable=False),
- sa.Column('last_occurred_at', sa.DateTime(), nullable=False),
- sa.Column('created_at', sa.DateTime(), nullable=False),
- sa.Column('updated_at', sa.DateTime(), nullable=False),
- sa.ForeignKeyConstraint(['resolved_by_user_id'], ['users.id'], ),
- sa.PrimaryKeyConstraint('id')
+ op.create_index(
+ op.f("ix_marketplace_products_gtin"),
+ "marketplace_products",
+ ["gtin"],
+ unique=False,
)
- op.create_index(op.f('ix_platform_alerts_alert_type'), 'platform_alerts', ['alert_type'], unique=False)
- op.create_index(op.f('ix_platform_alerts_id'), 'platform_alerts', ['id'], unique=False)
- op.create_index(op.f('ix_platform_alerts_is_resolved'), 'platform_alerts', ['is_resolved'], unique=False)
- op.create_index(op.f('ix_platform_alerts_severity'), 'platform_alerts', ['severity'], unique=False)
- op.create_table('vendors',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('vendor_code', sa.String(), nullable=False),
- sa.Column('subdomain', sa.String(length=100), nullable=False),
- sa.Column('name', sa.String(), nullable=False),
- sa.Column('description', sa.Text(), nullable=True),
- sa.Column('owner_user_id', sa.Integer(), nullable=False),
- sa.Column('contact_email', sa.String(), nullable=True),
- sa.Column('contact_phone', sa.String(), nullable=True),
- sa.Column('website', sa.String(), nullable=True),
- sa.Column('letzshop_csv_url_fr', sa.String(), nullable=True),
- sa.Column('letzshop_csv_url_en', sa.String(), nullable=True),
- sa.Column('letzshop_csv_url_de', sa.String(), nullable=True),
- sa.Column('business_address', sa.Text(), nullable=True),
- sa.Column('tax_number', sa.String(), nullable=True),
- sa.Column('is_active', sa.Boolean(), nullable=True),
- sa.Column('is_verified', sa.Boolean(), nullable=True),
- sa.Column('created_at', sa.DateTime(), nullable=False),
- sa.Column('updated_at', sa.DateTime(), nullable=False),
- sa.ForeignKeyConstraint(['owner_user_id'], ['users.id'], ),
- sa.PrimaryKeyConstraint('id')
+ op.create_index(
+ op.f("ix_marketplace_products_id"), "marketplace_products", ["id"], unique=False
)
- op.create_index(op.f('ix_vendors_id'), 'vendors', ['id'], unique=False)
- op.create_index(op.f('ix_vendors_subdomain'), 'vendors', ['subdomain'], unique=True)
- op.create_index(op.f('ix_vendors_vendor_code'), 'vendors', ['vendor_code'], unique=True)
- op.create_table('customers',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('vendor_id', sa.Integer(), nullable=False),
- sa.Column('email', sa.String(length=255), nullable=False),
- sa.Column('hashed_password', sa.String(length=255), nullable=False),
- sa.Column('first_name', sa.String(length=100), nullable=True),
- sa.Column('last_name', sa.String(length=100), nullable=True),
- sa.Column('phone', sa.String(length=50), nullable=True),
- sa.Column('customer_number', sa.String(length=100), nullable=False),
- sa.Column('preferences', sa.JSON(), nullable=True),
- sa.Column('marketing_consent', sa.Boolean(), nullable=True),
- sa.Column('last_order_date', sa.DateTime(), nullable=True),
- sa.Column('total_orders', sa.Integer(), nullable=True),
- sa.Column('total_spent', sa.Numeric(precision=10, scale=2), nullable=True),
- sa.Column('is_active', sa.Boolean(), nullable=False),
- sa.Column('created_at', sa.DateTime(), nullable=False),
- sa.Column('updated_at', sa.DateTime(), nullable=False),
- sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ),
- sa.PrimaryKeyConstraint('id')
+ op.create_index(
+ op.f("ix_marketplace_products_marketplace"),
+ "marketplace_products",
+ ["marketplace"],
+ unique=False,
)
- op.create_index(op.f('ix_customers_customer_number'), 'customers', ['customer_number'], unique=False)
- op.create_index(op.f('ix_customers_email'), 'customers', ['email'], unique=False)
- op.create_index(op.f('ix_customers_id'), 'customers', ['id'], unique=False)
- op.create_table('marketplace_import_jobs',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('vendor_id', sa.Integer(), nullable=False),
- sa.Column('user_id', sa.Integer(), nullable=False),
- sa.Column('marketplace', sa.String(), nullable=False),
- sa.Column('source_url', sa.String(), nullable=False),
- sa.Column('status', sa.String(), nullable=False),
- sa.Column('imported_count', sa.Integer(), nullable=True),
- sa.Column('updated_count', sa.Integer(), nullable=True),
- sa.Column('error_count', sa.Integer(), nullable=True),
- sa.Column('total_processed', sa.Integer(), nullable=True),
- sa.Column('error_message', sa.Text(), nullable=True),
- sa.Column('started_at', sa.DateTime(timezone=True), nullable=True),
- sa.Column('completed_at', sa.DateTime(timezone=True), nullable=True),
- sa.Column('created_at', sa.DateTime(), nullable=False),
- sa.Column('updated_at', sa.DateTime(), nullable=False),
- sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
- sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ),
- sa.PrimaryKeyConstraint('id')
+ op.create_index(
+ op.f("ix_marketplace_products_marketplace_product_id"),
+ "marketplace_products",
+ ["marketplace_product_id"],
+ unique=True,
)
- op.create_index('idx_import_user_marketplace', 'marketplace_import_jobs', ['user_id', 'marketplace'], unique=False)
- op.create_index('idx_import_vendor_created', 'marketplace_import_jobs', ['vendor_id', 'created_at'], unique=False)
- op.create_index('idx_import_vendor_status', 'marketplace_import_jobs', ['vendor_id', 'status'], unique=False)
- op.create_index(op.f('ix_marketplace_import_jobs_id'), 'marketplace_import_jobs', ['id'], unique=False)
- op.create_index(op.f('ix_marketplace_import_jobs_marketplace'), 'marketplace_import_jobs', ['marketplace'], unique=False)
- op.create_index(op.f('ix_marketplace_import_jobs_vendor_id'), 'marketplace_import_jobs', ['vendor_id'], unique=False)
- op.create_table('products',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('vendor_id', sa.Integer(), nullable=False),
- sa.Column('marketplace_product_id', sa.Integer(), nullable=False),
- sa.Column('product_id', sa.String(), nullable=True),
- sa.Column('price', sa.Float(), nullable=True),
- sa.Column('sale_price', sa.Float(), nullable=True),
- sa.Column('currency', sa.String(), nullable=True),
- sa.Column('availability', sa.String(), nullable=True),
- sa.Column('condition', sa.String(), nullable=True),
- sa.Column('is_featured', sa.Boolean(), nullable=True),
- sa.Column('is_active', sa.Boolean(), nullable=True),
- sa.Column('display_order', sa.Integer(), nullable=True),
- sa.Column('min_quantity', sa.Integer(), nullable=True),
- sa.Column('max_quantity', sa.Integer(), nullable=True),
- sa.Column('created_at', sa.DateTime(), nullable=False),
- sa.Column('updated_at', sa.DateTime(), nullable=False),
- sa.ForeignKeyConstraint(['marketplace_product_id'], ['marketplace_products.id'], ),
- sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ),
- sa.PrimaryKeyConstraint('id'),
- sa.UniqueConstraint('vendor_id', 'marketplace_product_id', name='uq_product')
+ op.create_index(
+ op.f("ix_marketplace_products_vendor_name"),
+ "marketplace_products",
+ ["vendor_name"],
+ unique=False,
)
- op.create_index('idx_product_active', 'products', ['vendor_id', 'is_active'], unique=False)
- op.create_index('idx_product_featured', 'products', ['vendor_id', 'is_featured'], unique=False)
- op.create_index(op.f('ix_products_id'), 'products', ['id'], unique=False)
- op.create_table('roles',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('vendor_id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=100), nullable=False),
- sa.Column('permissions', sa.JSON(), nullable=True),
- sa.Column('created_at', sa.DateTime(), nullable=False),
- sa.Column('updated_at', sa.DateTime(), nullable=False),
- sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ),
- sa.PrimaryKeyConstraint('id')
+ op.create_table(
+ "users",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("email", sa.String(), nullable=False),
+ sa.Column("username", sa.String(), nullable=False),
+ sa.Column("first_name", sa.String(), nullable=True),
+ sa.Column("last_name", sa.String(), nullable=True),
+ sa.Column("hashed_password", sa.String(), nullable=False),
+ sa.Column("role", sa.String(), nullable=False),
+ sa.Column("is_active", sa.Boolean(), nullable=False),
+ sa.Column("last_login", sa.DateTime(), nullable=True),
+ sa.Column("created_at", sa.DateTime(), nullable=False),
+ sa.Column("updated_at", sa.DateTime(), nullable=False),
+ sa.PrimaryKeyConstraint("id"),
)
- op.create_index(op.f('ix_roles_id'), 'roles', ['id'], unique=False)
- op.create_table('vendor_domains',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('vendor_id', sa.Integer(), nullable=False),
- sa.Column('domain', sa.String(length=255), nullable=False),
- sa.Column('is_primary', sa.Boolean(), nullable=False),
- sa.Column('is_active', sa.Boolean(), nullable=False),
- sa.Column('ssl_status', sa.String(length=50), nullable=True),
- sa.Column('ssl_verified_at', sa.DateTime(timezone=True), nullable=True),
- sa.Column('verification_token', sa.String(length=100), nullable=True),
- sa.Column('is_verified', sa.Boolean(), nullable=False),
- sa.Column('verified_at', sa.DateTime(timezone=True), nullable=True),
- sa.Column('created_at', sa.DateTime(), nullable=False),
- sa.Column('updated_at', sa.DateTime(), nullable=False),
- sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ondelete='CASCADE'),
- sa.PrimaryKeyConstraint('id'),
- sa.UniqueConstraint('vendor_id', 'domain', name='uq_vendor_domain'),
- sa.UniqueConstraint('verification_token')
+ op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
+ op.create_index(op.f("ix_users_id"), "users", ["id"], unique=False)
+ op.create_index(op.f("ix_users_username"), "users", ["username"], unique=True)
+ op.create_table(
+ "admin_audit_logs",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("admin_user_id", sa.Integer(), nullable=False),
+ sa.Column("action", sa.String(length=100), nullable=False),
+ sa.Column("target_type", sa.String(length=50), nullable=False),
+ sa.Column("target_id", sa.String(length=100), nullable=False),
+ sa.Column("details", sa.JSON(), nullable=True),
+ sa.Column("ip_address", sa.String(length=45), nullable=True),
+ sa.Column("user_agent", sa.Text(), nullable=True),
+ sa.Column("request_id", sa.String(length=100), nullable=True),
+ sa.Column("created_at", sa.DateTime(), nullable=False),
+ sa.Column("updated_at", sa.DateTime(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["admin_user_id"],
+ ["users.id"],
+ ),
+ sa.PrimaryKeyConstraint("id"),
)
- op.create_index('idx_domain_active', 'vendor_domains', ['domain', 'is_active'], unique=False)
- op.create_index('idx_vendor_primary', 'vendor_domains', ['vendor_id', 'is_primary'], unique=False)
- op.create_index(op.f('ix_vendor_domains_domain'), 'vendor_domains', ['domain'], unique=True)
- op.create_index(op.f('ix_vendor_domains_id'), 'vendor_domains', ['id'], unique=False)
- op.create_table('vendor_themes',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('vendor_id', sa.Integer(), nullable=False),
- sa.Column('theme_name', sa.String(length=100), nullable=True),
- sa.Column('is_active', sa.Boolean(), nullable=True),
- sa.Column('colors', sa.JSON(), nullable=True),
- sa.Column('font_family_heading', sa.String(length=100), nullable=True),
- sa.Column('font_family_body', sa.String(length=100), nullable=True),
- sa.Column('logo_url', sa.String(length=500), nullable=True),
- sa.Column('logo_dark_url', sa.String(length=500), nullable=True),
- sa.Column('favicon_url', sa.String(length=500), nullable=True),
- sa.Column('banner_url', sa.String(length=500), nullable=True),
- sa.Column('layout_style', sa.String(length=50), nullable=True),
- sa.Column('header_style', sa.String(length=50), nullable=True),
- sa.Column('product_card_style', sa.String(length=50), nullable=True),
- sa.Column('custom_css', sa.Text(), nullable=True),
- sa.Column('social_links', sa.JSON(), nullable=True),
- sa.Column('meta_title_template', sa.String(length=200), nullable=True),
- sa.Column('meta_description', sa.Text(), nullable=True),
- sa.Column('created_at', sa.DateTime(), nullable=False),
- sa.Column('updated_at', sa.DateTime(), nullable=False),
- sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ondelete='CASCADE'),
- sa.PrimaryKeyConstraint('id'),
- sa.UniqueConstraint('vendor_id')
+ op.create_index(
+ op.f("ix_admin_audit_logs_action"), "admin_audit_logs", ["action"], unique=False
)
- op.create_index(op.f('ix_vendor_themes_id'), 'vendor_themes', ['id'], unique=False)
- op.create_table('customer_addresses',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('vendor_id', sa.Integer(), nullable=False),
- sa.Column('customer_id', sa.Integer(), nullable=False),
- sa.Column('address_type', sa.String(length=50), nullable=False),
- sa.Column('first_name', sa.String(length=100), nullable=False),
- sa.Column('last_name', sa.String(length=100), nullable=False),
- sa.Column('company', sa.String(length=200), nullable=True),
- sa.Column('address_line_1', sa.String(length=255), nullable=False),
- sa.Column('address_line_2', sa.String(length=255), nullable=True),
- sa.Column('city', sa.String(length=100), nullable=False),
- sa.Column('postal_code', sa.String(length=20), nullable=False),
- sa.Column('country', sa.String(length=100), nullable=False),
- sa.Column('is_default', sa.Boolean(), nullable=True),
- sa.Column('created_at', sa.DateTime(), nullable=False),
- sa.Column('updated_at', sa.DateTime(), nullable=False),
- sa.ForeignKeyConstraint(['customer_id'], ['customers.id'], ),
- sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ),
- sa.PrimaryKeyConstraint('id')
+ op.create_index(
+ op.f("ix_admin_audit_logs_admin_user_id"),
+ "admin_audit_logs",
+ ["admin_user_id"],
+ unique=False,
)
- op.create_index(op.f('ix_customer_addresses_id'), 'customer_addresses', ['id'], unique=False)
- op.create_table('inventory',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('product_id', sa.Integer(), nullable=False),
- sa.Column('vendor_id', sa.Integer(), nullable=False),
- sa.Column('location', sa.String(), nullable=False),
- sa.Column('quantity', sa.Integer(), nullable=False),
- sa.Column('reserved_quantity', sa.Integer(), nullable=True),
- sa.Column('gtin', sa.String(), nullable=True),
- sa.Column('created_at', sa.DateTime(), nullable=False),
- sa.Column('updated_at', sa.DateTime(), nullable=False),
- sa.ForeignKeyConstraint(['product_id'], ['products.id'], ),
- sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ),
- sa.PrimaryKeyConstraint('id'),
- sa.UniqueConstraint('product_id', 'location', name='uq_inventory_product_location')
+ op.create_index(
+ op.f("ix_admin_audit_logs_id"), "admin_audit_logs", ["id"], unique=False
)
- op.create_index('idx_inventory_product_location', 'inventory', ['product_id', 'location'], unique=False)
- op.create_index('idx_inventory_vendor_product', 'inventory', ['vendor_id', 'product_id'], unique=False)
- op.create_index(op.f('ix_inventory_gtin'), 'inventory', ['gtin'], unique=False)
- op.create_index(op.f('ix_inventory_id'), 'inventory', ['id'], unique=False)
- op.create_index(op.f('ix_inventory_location'), 'inventory', ['location'], unique=False)
- op.create_index(op.f('ix_inventory_product_id'), 'inventory', ['product_id'], unique=False)
- op.create_index(op.f('ix_inventory_vendor_id'), 'inventory', ['vendor_id'], unique=False)
- op.create_table('vendor_users',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('vendor_id', sa.Integer(), nullable=False),
- sa.Column('user_id', sa.Integer(), nullable=False),
- sa.Column('role_id', sa.Integer(), nullable=False),
- sa.Column('invited_by', sa.Integer(), nullable=True),
- sa.Column('is_active', sa.Boolean(), nullable=False),
- sa.Column('created_at', sa.DateTime(), nullable=False),
- sa.Column('updated_at', sa.DateTime(), nullable=False),
- sa.ForeignKeyConstraint(['invited_by'], ['users.id'], ),
- sa.ForeignKeyConstraint(['role_id'], ['roles.id'], ),
- sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
- sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ),
- sa.PrimaryKeyConstraint('id')
+ op.create_index(
+ op.f("ix_admin_audit_logs_target_id"),
+ "admin_audit_logs",
+ ["target_id"],
+ unique=False,
)
- op.create_index(op.f('ix_vendor_users_id'), 'vendor_users', ['id'], unique=False)
- op.create_table('orders',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('vendor_id', sa.Integer(), nullable=False),
- sa.Column('customer_id', sa.Integer(), nullable=False),
- sa.Column('order_number', sa.String(), nullable=False),
- sa.Column('status', sa.String(), nullable=False),
- sa.Column('subtotal', sa.Float(), nullable=False),
- sa.Column('tax_amount', sa.Float(), nullable=True),
- sa.Column('shipping_amount', sa.Float(), nullable=True),
- sa.Column('discount_amount', sa.Float(), nullable=True),
- sa.Column('total_amount', sa.Float(), nullable=False),
- sa.Column('currency', sa.String(), nullable=True),
- sa.Column('shipping_address_id', sa.Integer(), nullable=False),
- sa.Column('billing_address_id', sa.Integer(), nullable=False),
- sa.Column('shipping_method', sa.String(), nullable=True),
- sa.Column('tracking_number', sa.String(), nullable=True),
- sa.Column('customer_notes', sa.Text(), nullable=True),
- sa.Column('internal_notes', sa.Text(), nullable=True),
- sa.Column('paid_at', sa.DateTime(), nullable=True),
- sa.Column('shipped_at', sa.DateTime(), nullable=True),
- sa.Column('delivered_at', sa.DateTime(), nullable=True),
- sa.Column('cancelled_at', sa.DateTime(), nullable=True),
- sa.Column('created_at', sa.DateTime(), nullable=False),
- sa.Column('updated_at', sa.DateTime(), nullable=False),
- sa.ForeignKeyConstraint(['billing_address_id'], ['customer_addresses.id'], ),
- sa.ForeignKeyConstraint(['customer_id'], ['customers.id'], ),
- sa.ForeignKeyConstraint(['shipping_address_id'], ['customer_addresses.id'], ),
- sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ),
- sa.PrimaryKeyConstraint('id')
+ op.create_index(
+ op.f("ix_admin_audit_logs_target_type"),
+ "admin_audit_logs",
+ ["target_type"],
+ unique=False,
)
- op.create_index(op.f('ix_orders_customer_id'), 'orders', ['customer_id'], unique=False)
- op.create_index(op.f('ix_orders_id'), 'orders', ['id'], unique=False)
- op.create_index(op.f('ix_orders_order_number'), 'orders', ['order_number'], unique=True)
- op.create_index(op.f('ix_orders_status'), 'orders', ['status'], unique=False)
- op.create_index(op.f('ix_orders_vendor_id'), 'orders', ['vendor_id'], unique=False)
- op.create_table('order_items',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('order_id', sa.Integer(), nullable=False),
- sa.Column('product_id', sa.Integer(), nullable=False),
- sa.Column('product_name', sa.String(), nullable=False),
- sa.Column('product_sku', sa.String(), nullable=True),
- sa.Column('quantity', sa.Integer(), nullable=False),
- sa.Column('unit_price', sa.Float(), nullable=False),
- sa.Column('total_price', sa.Float(), nullable=False),
- sa.Column('inventory_reserved', sa.Boolean(), nullable=True),
- sa.Column('inventory_fulfilled', sa.Boolean(), nullable=True),
- sa.Column('created_at', sa.DateTime(), nullable=False),
- sa.Column('updated_at', sa.DateTime(), nullable=False),
- sa.ForeignKeyConstraint(['order_id'], ['orders.id'], ),
- sa.ForeignKeyConstraint(['product_id'], ['products.id'], ),
- sa.PrimaryKeyConstraint('id')
+ op.create_table(
+ "admin_notifications",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("type", sa.String(length=50), nullable=False),
+ sa.Column("priority", sa.String(length=20), nullable=True),
+ sa.Column("title", sa.String(length=200), nullable=False),
+ sa.Column("message", sa.Text(), nullable=False),
+ sa.Column("is_read", sa.Boolean(), nullable=True),
+ sa.Column("read_at", sa.DateTime(), nullable=True),
+ sa.Column("read_by_user_id", sa.Integer(), nullable=True),
+ sa.Column("action_required", sa.Boolean(), nullable=True),
+ sa.Column("action_url", sa.String(length=500), nullable=True),
+ sa.Column("notification_metadata", sa.JSON(), nullable=True),
+ sa.Column("created_at", sa.DateTime(), nullable=False),
+ sa.Column("updated_at", sa.DateTime(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["read_by_user_id"],
+ ["users.id"],
+ ),
+ sa.PrimaryKeyConstraint("id"),
+ )
+ op.create_index(
+ op.f("ix_admin_notifications_action_required"),
+ "admin_notifications",
+ ["action_required"],
+ unique=False,
+ )
+ op.create_index(
+ op.f("ix_admin_notifications_id"), "admin_notifications", ["id"], unique=False
+ )
+ op.create_index(
+ op.f("ix_admin_notifications_is_read"),
+ "admin_notifications",
+ ["is_read"],
+ unique=False,
+ )
+ op.create_index(
+ op.f("ix_admin_notifications_priority"),
+ "admin_notifications",
+ ["priority"],
+ unique=False,
+ )
+ op.create_index(
+ op.f("ix_admin_notifications_type"),
+ "admin_notifications",
+ ["type"],
+ unique=False,
+ )
+ op.create_table(
+ "admin_sessions",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("admin_user_id", sa.Integer(), nullable=False),
+ sa.Column("session_token", sa.String(length=255), nullable=False),
+ sa.Column("ip_address", sa.String(length=45), nullable=False),
+ sa.Column("user_agent", sa.Text(), nullable=True),
+ sa.Column("login_at", sa.DateTime(), nullable=False),
+ sa.Column("last_activity_at", sa.DateTime(), nullable=False),
+ sa.Column("logout_at", sa.DateTime(), nullable=True),
+ sa.Column("is_active", sa.Boolean(), nullable=True),
+ sa.Column("logout_reason", sa.String(length=50), nullable=True),
+ sa.Column("created_at", sa.DateTime(), nullable=False),
+ sa.Column("updated_at", sa.DateTime(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["admin_user_id"],
+ ["users.id"],
+ ),
+ sa.PrimaryKeyConstraint("id"),
+ )
+ op.create_index(
+ op.f("ix_admin_sessions_admin_user_id"),
+ "admin_sessions",
+ ["admin_user_id"],
+ unique=False,
+ )
+ op.create_index(
+ op.f("ix_admin_sessions_id"), "admin_sessions", ["id"], unique=False
+ )
+ op.create_index(
+ op.f("ix_admin_sessions_is_active"),
+ "admin_sessions",
+ ["is_active"],
+ unique=False,
+ )
+ op.create_index(
+ op.f("ix_admin_sessions_login_at"), "admin_sessions", ["login_at"], unique=False
+ )
+ op.create_index(
+ op.f("ix_admin_sessions_session_token"),
+ "admin_sessions",
+ ["session_token"],
+ unique=True,
+ )
+ op.create_table(
+ "admin_settings",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("key", sa.String(length=100), nullable=False),
+ sa.Column("value", sa.Text(), nullable=False),
+ sa.Column("value_type", sa.String(length=20), nullable=True),
+ sa.Column("category", sa.String(length=50), nullable=True),
+ sa.Column("description", sa.Text(), nullable=True),
+ sa.Column("is_encrypted", sa.Boolean(), nullable=True),
+ sa.Column("is_public", sa.Boolean(), nullable=True),
+ sa.Column("last_modified_by_user_id", sa.Integer(), nullable=True),
+ sa.Column("created_at", sa.DateTime(), nullable=False),
+ sa.Column("updated_at", sa.DateTime(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["last_modified_by_user_id"],
+ ["users.id"],
+ ),
+ sa.PrimaryKeyConstraint("id"),
+ )
+ op.create_index(
+ op.f("ix_admin_settings_category"), "admin_settings", ["category"], unique=False
+ )
+ op.create_index(
+ op.f("ix_admin_settings_id"), "admin_settings", ["id"], unique=False
+ )
+ op.create_index(
+ op.f("ix_admin_settings_key"), "admin_settings", ["key"], unique=True
+ )
+ op.create_table(
+ "platform_alerts",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("alert_type", sa.String(length=50), nullable=False),
+ sa.Column("severity", sa.String(length=20), nullable=False),
+ sa.Column("title", sa.String(length=200), nullable=False),
+ sa.Column("description", sa.Text(), nullable=True),
+ sa.Column("affected_vendors", sa.JSON(), nullable=True),
+ sa.Column("affected_systems", sa.JSON(), nullable=True),
+ sa.Column("is_resolved", sa.Boolean(), nullable=True),
+ sa.Column("resolved_at", sa.DateTime(), nullable=True),
+ sa.Column("resolved_by_user_id", sa.Integer(), nullable=True),
+ sa.Column("resolution_notes", sa.Text(), nullable=True),
+ sa.Column("auto_generated", sa.Boolean(), nullable=True),
+ sa.Column("occurrence_count", sa.Integer(), nullable=True),
+ sa.Column("first_occurred_at", sa.DateTime(), nullable=False),
+ sa.Column("last_occurred_at", sa.DateTime(), nullable=False),
+ sa.Column("created_at", sa.DateTime(), nullable=False),
+ sa.Column("updated_at", sa.DateTime(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["resolved_by_user_id"],
+ ["users.id"],
+ ),
+ sa.PrimaryKeyConstraint("id"),
+ )
+ op.create_index(
+ op.f("ix_platform_alerts_alert_type"),
+ "platform_alerts",
+ ["alert_type"],
+ unique=False,
+ )
+ op.create_index(
+ op.f("ix_platform_alerts_id"), "platform_alerts", ["id"], unique=False
+ )
+ op.create_index(
+ op.f("ix_platform_alerts_is_resolved"),
+ "platform_alerts",
+ ["is_resolved"],
+ unique=False,
+ )
+ op.create_index(
+ op.f("ix_platform_alerts_severity"),
+ "platform_alerts",
+ ["severity"],
+ unique=False,
+ )
+ op.create_table(
+ "vendors",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("vendor_code", sa.String(), nullable=False),
+ sa.Column("subdomain", sa.String(length=100), nullable=False),
+ sa.Column("name", sa.String(), nullable=False),
+ sa.Column("description", sa.Text(), nullable=True),
+ sa.Column("owner_user_id", sa.Integer(), nullable=False),
+ sa.Column("contact_email", sa.String(), nullable=True),
+ sa.Column("contact_phone", sa.String(), nullable=True),
+ sa.Column("website", sa.String(), nullable=True),
+ sa.Column("letzshop_csv_url_fr", sa.String(), nullable=True),
+ sa.Column("letzshop_csv_url_en", sa.String(), nullable=True),
+ sa.Column("letzshop_csv_url_de", sa.String(), nullable=True),
+ sa.Column("business_address", sa.Text(), nullable=True),
+ sa.Column("tax_number", sa.String(), nullable=True),
+ sa.Column("is_active", sa.Boolean(), nullable=True),
+ sa.Column("is_verified", sa.Boolean(), nullable=True),
+ sa.Column("created_at", sa.DateTime(), nullable=False),
+ sa.Column("updated_at", sa.DateTime(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["owner_user_id"],
+ ["users.id"],
+ ),
+ sa.PrimaryKeyConstraint("id"),
+ )
+ op.create_index(op.f("ix_vendors_id"), "vendors", ["id"], unique=False)
+ op.create_index(op.f("ix_vendors_subdomain"), "vendors", ["subdomain"], unique=True)
+ op.create_index(
+ op.f("ix_vendors_vendor_code"), "vendors", ["vendor_code"], unique=True
+ )
+ op.create_table(
+ "customers",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("vendor_id", sa.Integer(), nullable=False),
+ sa.Column("email", sa.String(length=255), nullable=False),
+ sa.Column("hashed_password", sa.String(length=255), nullable=False),
+ sa.Column("first_name", sa.String(length=100), nullable=True),
+ sa.Column("last_name", sa.String(length=100), nullable=True),
+ sa.Column("phone", sa.String(length=50), nullable=True),
+ sa.Column("customer_number", sa.String(length=100), nullable=False),
+ sa.Column("preferences", sa.JSON(), nullable=True),
+ sa.Column("marketing_consent", sa.Boolean(), nullable=True),
+ sa.Column("last_order_date", sa.DateTime(), nullable=True),
+ sa.Column("total_orders", sa.Integer(), nullable=True),
+ sa.Column("total_spent", sa.Numeric(precision=10, scale=2), nullable=True),
+ sa.Column("is_active", sa.Boolean(), nullable=False),
+ sa.Column("created_at", sa.DateTime(), nullable=False),
+ sa.Column("updated_at", sa.DateTime(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["vendor_id"],
+ ["vendors.id"],
+ ),
+ sa.PrimaryKeyConstraint("id"),
+ )
+ op.create_index(
+ op.f("ix_customers_customer_number"),
+ "customers",
+ ["customer_number"],
+ unique=False,
+ )
+ op.create_index(op.f("ix_customers_email"), "customers", ["email"], unique=False)
+ op.create_index(op.f("ix_customers_id"), "customers", ["id"], unique=False)
+ op.create_table(
+ "marketplace_import_jobs",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("vendor_id", sa.Integer(), nullable=False),
+ sa.Column("user_id", sa.Integer(), nullable=False),
+ sa.Column("marketplace", sa.String(), nullable=False),
+ sa.Column("source_url", sa.String(), nullable=False),
+ sa.Column("status", sa.String(), nullable=False),
+ sa.Column("imported_count", sa.Integer(), nullable=True),
+ sa.Column("updated_count", sa.Integer(), nullable=True),
+ sa.Column("error_count", sa.Integer(), nullable=True),
+ sa.Column("total_processed", sa.Integer(), nullable=True),
+ sa.Column("error_message", sa.Text(), nullable=True),
+ sa.Column("started_at", sa.DateTime(timezone=True), nullable=True),
+ sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
+ sa.Column("created_at", sa.DateTime(), nullable=False),
+ sa.Column("updated_at", sa.DateTime(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["user_id"],
+ ["users.id"],
+ ),
+ sa.ForeignKeyConstraint(
+ ["vendor_id"],
+ ["vendors.id"],
+ ),
+ sa.PrimaryKeyConstraint("id"),
+ )
+ op.create_index(
+ "idx_import_user_marketplace",
+ "marketplace_import_jobs",
+ ["user_id", "marketplace"],
+ unique=False,
+ )
+ op.create_index(
+ "idx_import_vendor_created",
+ "marketplace_import_jobs",
+ ["vendor_id", "created_at"],
+ unique=False,
+ )
+ op.create_index(
+ "idx_import_vendor_status",
+ "marketplace_import_jobs",
+ ["vendor_id", "status"],
+ unique=False,
+ )
+ op.create_index(
+ op.f("ix_marketplace_import_jobs_id"),
+ "marketplace_import_jobs",
+ ["id"],
+ unique=False,
+ )
+ op.create_index(
+ op.f("ix_marketplace_import_jobs_marketplace"),
+ "marketplace_import_jobs",
+ ["marketplace"],
+ unique=False,
+ )
+ op.create_index(
+ op.f("ix_marketplace_import_jobs_vendor_id"),
+ "marketplace_import_jobs",
+ ["vendor_id"],
+ unique=False,
+ )
+ op.create_table(
+ "products",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("vendor_id", sa.Integer(), nullable=False),
+ sa.Column("marketplace_product_id", sa.Integer(), nullable=False),
+ sa.Column("product_id", sa.String(), nullable=True),
+ sa.Column("price", sa.Float(), nullable=True),
+ sa.Column("sale_price", sa.Float(), nullable=True),
+ sa.Column("currency", sa.String(), nullable=True),
+ sa.Column("availability", sa.String(), nullable=True),
+ sa.Column("condition", sa.String(), nullable=True),
+ sa.Column("is_featured", sa.Boolean(), nullable=True),
+ sa.Column("is_active", sa.Boolean(), nullable=True),
+ sa.Column("display_order", sa.Integer(), nullable=True),
+ sa.Column("min_quantity", sa.Integer(), nullable=True),
+ sa.Column("max_quantity", sa.Integer(), nullable=True),
+ sa.Column("created_at", sa.DateTime(), nullable=False),
+ sa.Column("updated_at", sa.DateTime(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["marketplace_product_id"],
+ ["marketplace_products.id"],
+ ),
+ sa.ForeignKeyConstraint(
+ ["vendor_id"],
+ ["vendors.id"],
+ ),
+ sa.PrimaryKeyConstraint("id"),
+ sa.UniqueConstraint("vendor_id", "marketplace_product_id", name="uq_product"),
+ )
+ op.create_index(
+ "idx_product_active", "products", ["vendor_id", "is_active"], unique=False
+ )
+ op.create_index(
+ "idx_product_featured", "products", ["vendor_id", "is_featured"], unique=False
+ )
+ op.create_index(op.f("ix_products_id"), "products", ["id"], unique=False)
+ op.create_table(
+ "roles",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("vendor_id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=100), nullable=False),
+ sa.Column("permissions", sa.JSON(), nullable=True),
+ sa.Column("created_at", sa.DateTime(), nullable=False),
+ sa.Column("updated_at", sa.DateTime(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["vendor_id"],
+ ["vendors.id"],
+ ),
+ sa.PrimaryKeyConstraint("id"),
+ )
+ op.create_index(op.f("ix_roles_id"), "roles", ["id"], unique=False)
+ op.create_table(
+ "vendor_domains",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("vendor_id", sa.Integer(), nullable=False),
+ sa.Column("domain", sa.String(length=255), nullable=False),
+ sa.Column("is_primary", sa.Boolean(), nullable=False),
+ sa.Column("is_active", sa.Boolean(), nullable=False),
+ sa.Column("ssl_status", sa.String(length=50), nullable=True),
+ sa.Column("ssl_verified_at", sa.DateTime(timezone=True), nullable=True),
+ sa.Column("verification_token", sa.String(length=100), nullable=True),
+ sa.Column("is_verified", sa.Boolean(), nullable=False),
+ sa.Column("verified_at", sa.DateTime(timezone=True), nullable=True),
+ sa.Column("created_at", sa.DateTime(), nullable=False),
+ sa.Column("updated_at", sa.DateTime(), nullable=False),
+ sa.ForeignKeyConstraint(["vendor_id"], ["vendors.id"], ondelete="CASCADE"),
+ sa.PrimaryKeyConstraint("id"),
+ sa.UniqueConstraint("vendor_id", "domain", name="uq_vendor_domain"),
+ sa.UniqueConstraint("verification_token"),
+ )
+ op.create_index(
+ "idx_domain_active", "vendor_domains", ["domain", "is_active"], unique=False
+ )
+ op.create_index(
+ "idx_vendor_primary",
+ "vendor_domains",
+ ["vendor_id", "is_primary"],
+ unique=False,
+ )
+ op.create_index(
+ op.f("ix_vendor_domains_domain"), "vendor_domains", ["domain"], unique=True
+ )
+ op.create_index(
+ op.f("ix_vendor_domains_id"), "vendor_domains", ["id"], unique=False
+ )
+ op.create_table(
+ "vendor_themes",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("vendor_id", sa.Integer(), nullable=False),
+ sa.Column("theme_name", sa.String(length=100), nullable=True),
+ sa.Column("is_active", sa.Boolean(), nullable=True),
+ sa.Column("colors", sa.JSON(), nullable=True),
+ sa.Column("font_family_heading", sa.String(length=100), nullable=True),
+ sa.Column("font_family_body", sa.String(length=100), nullable=True),
+ sa.Column("logo_url", sa.String(length=500), nullable=True),
+ sa.Column("logo_dark_url", sa.String(length=500), nullable=True),
+ sa.Column("favicon_url", sa.String(length=500), nullable=True),
+ sa.Column("banner_url", sa.String(length=500), nullable=True),
+ sa.Column("layout_style", sa.String(length=50), nullable=True),
+ sa.Column("header_style", sa.String(length=50), nullable=True),
+ sa.Column("product_card_style", sa.String(length=50), nullable=True),
+ sa.Column("custom_css", sa.Text(), nullable=True),
+ sa.Column("social_links", sa.JSON(), nullable=True),
+ sa.Column("meta_title_template", sa.String(length=200), nullable=True),
+ sa.Column("meta_description", sa.Text(), nullable=True),
+ sa.Column("created_at", sa.DateTime(), nullable=False),
+ sa.Column("updated_at", sa.DateTime(), nullable=False),
+ sa.ForeignKeyConstraint(["vendor_id"], ["vendors.id"], ondelete="CASCADE"),
+ sa.PrimaryKeyConstraint("id"),
+ sa.UniqueConstraint("vendor_id"),
+ )
+ op.create_index(op.f("ix_vendor_themes_id"), "vendor_themes", ["id"], unique=False)
+ op.create_table(
+ "customer_addresses",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("vendor_id", sa.Integer(), nullable=False),
+ sa.Column("customer_id", sa.Integer(), nullable=False),
+ sa.Column("address_type", sa.String(length=50), nullable=False),
+ sa.Column("first_name", sa.String(length=100), nullable=False),
+ sa.Column("last_name", sa.String(length=100), nullable=False),
+ sa.Column("company", sa.String(length=200), nullable=True),
+ sa.Column("address_line_1", sa.String(length=255), nullable=False),
+ sa.Column("address_line_2", sa.String(length=255), nullable=True),
+ sa.Column("city", sa.String(length=100), nullable=False),
+ sa.Column("postal_code", sa.String(length=20), nullable=False),
+ sa.Column("country", sa.String(length=100), nullable=False),
+ sa.Column("is_default", sa.Boolean(), nullable=True),
+ sa.Column("created_at", sa.DateTime(), nullable=False),
+ sa.Column("updated_at", sa.DateTime(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["customer_id"],
+ ["customers.id"],
+ ),
+ sa.ForeignKeyConstraint(
+ ["vendor_id"],
+ ["vendors.id"],
+ ),
+ sa.PrimaryKeyConstraint("id"),
+ )
+ op.create_index(
+ op.f("ix_customer_addresses_id"), "customer_addresses", ["id"], unique=False
+ )
+ op.create_table(
+ "inventory",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("product_id", sa.Integer(), nullable=False),
+ sa.Column("vendor_id", sa.Integer(), nullable=False),
+ sa.Column("location", sa.String(), nullable=False),
+ sa.Column("quantity", sa.Integer(), nullable=False),
+ sa.Column("reserved_quantity", sa.Integer(), nullable=True),
+ sa.Column("gtin", sa.String(), nullable=True),
+ sa.Column("created_at", sa.DateTime(), nullable=False),
+ sa.Column("updated_at", sa.DateTime(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["product_id"],
+ ["products.id"],
+ ),
+ sa.ForeignKeyConstraint(
+ ["vendor_id"],
+ ["vendors.id"],
+ ),
+ sa.PrimaryKeyConstraint("id"),
+ sa.UniqueConstraint(
+ "product_id", "location", name="uq_inventory_product_location"
+ ),
+ )
+ op.create_index(
+ "idx_inventory_product_location",
+ "inventory",
+ ["product_id", "location"],
+ unique=False,
+ )
+ op.create_index(
+ "idx_inventory_vendor_product",
+ "inventory",
+ ["vendor_id", "product_id"],
+ unique=False,
+ )
+ op.create_index(op.f("ix_inventory_gtin"), "inventory", ["gtin"], unique=False)
+ op.create_index(op.f("ix_inventory_id"), "inventory", ["id"], unique=False)
+ op.create_index(
+ op.f("ix_inventory_location"), "inventory", ["location"], unique=False
+ )
+ op.create_index(
+ op.f("ix_inventory_product_id"), "inventory", ["product_id"], unique=False
+ )
+ op.create_index(
+ op.f("ix_inventory_vendor_id"), "inventory", ["vendor_id"], unique=False
+ )
+ op.create_table(
+ "vendor_users",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("vendor_id", sa.Integer(), nullable=False),
+ sa.Column("user_id", sa.Integer(), nullable=False),
+ sa.Column("role_id", sa.Integer(), nullable=False),
+ sa.Column("invited_by", sa.Integer(), nullable=True),
+ sa.Column("is_active", sa.Boolean(), nullable=False),
+ sa.Column("created_at", sa.DateTime(), nullable=False),
+ sa.Column("updated_at", sa.DateTime(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["invited_by"],
+ ["users.id"],
+ ),
+ sa.ForeignKeyConstraint(
+ ["role_id"],
+ ["roles.id"],
+ ),
+ sa.ForeignKeyConstraint(
+ ["user_id"],
+ ["users.id"],
+ ),
+ sa.ForeignKeyConstraint(
+ ["vendor_id"],
+ ["vendors.id"],
+ ),
+ sa.PrimaryKeyConstraint("id"),
+ )
+ op.create_index(op.f("ix_vendor_users_id"), "vendor_users", ["id"], unique=False)
+ op.create_table(
+ "orders",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("vendor_id", sa.Integer(), nullable=False),
+ sa.Column("customer_id", sa.Integer(), nullable=False),
+ sa.Column("order_number", sa.String(), nullable=False),
+ sa.Column("status", sa.String(), nullable=False),
+ sa.Column("subtotal", sa.Float(), nullable=False),
+ sa.Column("tax_amount", sa.Float(), nullable=True),
+ sa.Column("shipping_amount", sa.Float(), nullable=True),
+ sa.Column("discount_amount", sa.Float(), nullable=True),
+ sa.Column("total_amount", sa.Float(), nullable=False),
+ sa.Column("currency", sa.String(), nullable=True),
+ sa.Column("shipping_address_id", sa.Integer(), nullable=False),
+ sa.Column("billing_address_id", sa.Integer(), nullable=False),
+ sa.Column("shipping_method", sa.String(), nullable=True),
+ sa.Column("tracking_number", sa.String(), nullable=True),
+ sa.Column("customer_notes", sa.Text(), nullable=True),
+ sa.Column("internal_notes", sa.Text(), nullable=True),
+ sa.Column("paid_at", sa.DateTime(), nullable=True),
+ sa.Column("shipped_at", sa.DateTime(), nullable=True),
+ sa.Column("delivered_at", sa.DateTime(), nullable=True),
+ sa.Column("cancelled_at", sa.DateTime(), nullable=True),
+ sa.Column("created_at", sa.DateTime(), nullable=False),
+ sa.Column("updated_at", sa.DateTime(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["billing_address_id"],
+ ["customer_addresses.id"],
+ ),
+ sa.ForeignKeyConstraint(
+ ["customer_id"],
+ ["customers.id"],
+ ),
+ sa.ForeignKeyConstraint(
+ ["shipping_address_id"],
+ ["customer_addresses.id"],
+ ),
+ sa.ForeignKeyConstraint(
+ ["vendor_id"],
+ ["vendors.id"],
+ ),
+ sa.PrimaryKeyConstraint("id"),
+ )
+ op.create_index(
+ op.f("ix_orders_customer_id"), "orders", ["customer_id"], unique=False
+ )
+ op.create_index(op.f("ix_orders_id"), "orders", ["id"], unique=False)
+ op.create_index(
+ op.f("ix_orders_order_number"), "orders", ["order_number"], unique=True
+ )
+ op.create_index(op.f("ix_orders_status"), "orders", ["status"], unique=False)
+ op.create_index(op.f("ix_orders_vendor_id"), "orders", ["vendor_id"], unique=False)
+ op.create_table(
+ "order_items",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("order_id", sa.Integer(), nullable=False),
+ sa.Column("product_id", sa.Integer(), nullable=False),
+ sa.Column("product_name", sa.String(), nullable=False),
+ sa.Column("product_sku", sa.String(), nullable=True),
+ sa.Column("quantity", sa.Integer(), nullable=False),
+ sa.Column("unit_price", sa.Float(), nullable=False),
+ sa.Column("total_price", sa.Float(), nullable=False),
+ sa.Column("inventory_reserved", sa.Boolean(), nullable=True),
+ sa.Column("inventory_fulfilled", sa.Boolean(), nullable=True),
+ sa.Column("created_at", sa.DateTime(), nullable=False),
+ sa.Column("updated_at", sa.DateTime(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["order_id"],
+ ["orders.id"],
+ ),
+ sa.ForeignKeyConstraint(
+ ["product_id"],
+ ["products.id"],
+ ),
+ sa.PrimaryKeyConstraint("id"),
+ )
+ op.create_index(op.f("ix_order_items_id"), "order_items", ["id"], unique=False)
+ op.create_index(
+ op.f("ix_order_items_order_id"), "order_items", ["order_id"], unique=False
)
- op.create_index(op.f('ix_order_items_id'), 'order_items', ['id'], unique=False)
- op.create_index(op.f('ix_order_items_order_id'), 'order_items', ['order_id'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_index(op.f('ix_order_items_order_id'), table_name='order_items')
- op.drop_index(op.f('ix_order_items_id'), table_name='order_items')
- op.drop_table('order_items')
- op.drop_index(op.f('ix_orders_vendor_id'), table_name='orders')
- op.drop_index(op.f('ix_orders_status'), table_name='orders')
- op.drop_index(op.f('ix_orders_order_number'), table_name='orders')
- op.drop_index(op.f('ix_orders_id'), table_name='orders')
- op.drop_index(op.f('ix_orders_customer_id'), table_name='orders')
- op.drop_table('orders')
- op.drop_index(op.f('ix_vendor_users_id'), table_name='vendor_users')
- op.drop_table('vendor_users')
- op.drop_index(op.f('ix_inventory_vendor_id'), table_name='inventory')
- op.drop_index(op.f('ix_inventory_product_id'), table_name='inventory')
- op.drop_index(op.f('ix_inventory_location'), table_name='inventory')
- op.drop_index(op.f('ix_inventory_id'), table_name='inventory')
- op.drop_index(op.f('ix_inventory_gtin'), table_name='inventory')
- op.drop_index('idx_inventory_vendor_product', table_name='inventory')
- op.drop_index('idx_inventory_product_location', table_name='inventory')
- op.drop_table('inventory')
- op.drop_index(op.f('ix_customer_addresses_id'), table_name='customer_addresses')
- op.drop_table('customer_addresses')
- op.drop_index(op.f('ix_vendor_themes_id'), table_name='vendor_themes')
- op.drop_table('vendor_themes')
- op.drop_index(op.f('ix_vendor_domains_id'), table_name='vendor_domains')
- op.drop_index(op.f('ix_vendor_domains_domain'), table_name='vendor_domains')
- op.drop_index('idx_vendor_primary', table_name='vendor_domains')
- op.drop_index('idx_domain_active', table_name='vendor_domains')
- op.drop_table('vendor_domains')
- op.drop_index(op.f('ix_roles_id'), table_name='roles')
- op.drop_table('roles')
- op.drop_index(op.f('ix_products_id'), table_name='products')
- op.drop_index('idx_product_featured', table_name='products')
- op.drop_index('idx_product_active', table_name='products')
- op.drop_table('products')
- op.drop_index(op.f('ix_marketplace_import_jobs_vendor_id'), table_name='marketplace_import_jobs')
- op.drop_index(op.f('ix_marketplace_import_jobs_marketplace'), table_name='marketplace_import_jobs')
- op.drop_index(op.f('ix_marketplace_import_jobs_id'), table_name='marketplace_import_jobs')
- op.drop_index('idx_import_vendor_status', table_name='marketplace_import_jobs')
- op.drop_index('idx_import_vendor_created', table_name='marketplace_import_jobs')
- op.drop_index('idx_import_user_marketplace', table_name='marketplace_import_jobs')
- op.drop_table('marketplace_import_jobs')
- op.drop_index(op.f('ix_customers_id'), table_name='customers')
- op.drop_index(op.f('ix_customers_email'), table_name='customers')
- op.drop_index(op.f('ix_customers_customer_number'), table_name='customers')
- op.drop_table('customers')
- op.drop_index(op.f('ix_vendors_vendor_code'), table_name='vendors')
- op.drop_index(op.f('ix_vendors_subdomain'), table_name='vendors')
- op.drop_index(op.f('ix_vendors_id'), table_name='vendors')
- op.drop_table('vendors')
- op.drop_index(op.f('ix_platform_alerts_severity'), table_name='platform_alerts')
- op.drop_index(op.f('ix_platform_alerts_is_resolved'), table_name='platform_alerts')
- op.drop_index(op.f('ix_platform_alerts_id'), table_name='platform_alerts')
- op.drop_index(op.f('ix_platform_alerts_alert_type'), table_name='platform_alerts')
- op.drop_table('platform_alerts')
- op.drop_index(op.f('ix_admin_settings_key'), table_name='admin_settings')
- op.drop_index(op.f('ix_admin_settings_id'), table_name='admin_settings')
- op.drop_index(op.f('ix_admin_settings_category'), table_name='admin_settings')
- op.drop_table('admin_settings')
- op.drop_index(op.f('ix_admin_sessions_session_token'), table_name='admin_sessions')
- op.drop_index(op.f('ix_admin_sessions_login_at'), table_name='admin_sessions')
- op.drop_index(op.f('ix_admin_sessions_is_active'), table_name='admin_sessions')
- op.drop_index(op.f('ix_admin_sessions_id'), table_name='admin_sessions')
- op.drop_index(op.f('ix_admin_sessions_admin_user_id'), table_name='admin_sessions')
- op.drop_table('admin_sessions')
- op.drop_index(op.f('ix_admin_notifications_type'), table_name='admin_notifications')
- op.drop_index(op.f('ix_admin_notifications_priority'), table_name='admin_notifications')
- op.drop_index(op.f('ix_admin_notifications_is_read'), table_name='admin_notifications')
- op.drop_index(op.f('ix_admin_notifications_id'), table_name='admin_notifications')
- op.drop_index(op.f('ix_admin_notifications_action_required'), table_name='admin_notifications')
- op.drop_table('admin_notifications')
- op.drop_index(op.f('ix_admin_audit_logs_target_type'), table_name='admin_audit_logs')
- op.drop_index(op.f('ix_admin_audit_logs_target_id'), table_name='admin_audit_logs')
- op.drop_index(op.f('ix_admin_audit_logs_id'), table_name='admin_audit_logs')
- op.drop_index(op.f('ix_admin_audit_logs_admin_user_id'), table_name='admin_audit_logs')
- op.drop_index(op.f('ix_admin_audit_logs_action'), table_name='admin_audit_logs')
- op.drop_table('admin_audit_logs')
- op.drop_index(op.f('ix_users_username'), table_name='users')
- op.drop_index(op.f('ix_users_id'), table_name='users')
- op.drop_index(op.f('ix_users_email'), table_name='users')
- op.drop_table('users')
- op.drop_index(op.f('ix_marketplace_products_vendor_name'), table_name='marketplace_products')
- op.drop_index(op.f('ix_marketplace_products_marketplace_product_id'), table_name='marketplace_products')
- op.drop_index(op.f('ix_marketplace_products_marketplace'), table_name='marketplace_products')
- op.drop_index(op.f('ix_marketplace_products_id'), table_name='marketplace_products')
- op.drop_index(op.f('ix_marketplace_products_gtin'), table_name='marketplace_products')
- op.drop_index(op.f('ix_marketplace_products_google_product_category'), table_name='marketplace_products')
- op.drop_index(op.f('ix_marketplace_products_brand'), table_name='marketplace_products')
- op.drop_index(op.f('ix_marketplace_products_availability'), table_name='marketplace_products')
- op.drop_index('idx_marketplace_vendor', table_name='marketplace_products')
- op.drop_index('idx_marketplace_brand', table_name='marketplace_products')
- op.drop_table('marketplace_products')
+ op.drop_index(op.f("ix_order_items_order_id"), table_name="order_items")
+ op.drop_index(op.f("ix_order_items_id"), table_name="order_items")
+ op.drop_table("order_items")
+ op.drop_index(op.f("ix_orders_vendor_id"), table_name="orders")
+ op.drop_index(op.f("ix_orders_status"), table_name="orders")
+ op.drop_index(op.f("ix_orders_order_number"), table_name="orders")
+ op.drop_index(op.f("ix_orders_id"), table_name="orders")
+ op.drop_index(op.f("ix_orders_customer_id"), table_name="orders")
+ op.drop_table("orders")
+ op.drop_index(op.f("ix_vendor_users_id"), table_name="vendor_users")
+ op.drop_table("vendor_users")
+ op.drop_index(op.f("ix_inventory_vendor_id"), table_name="inventory")
+ op.drop_index(op.f("ix_inventory_product_id"), table_name="inventory")
+ op.drop_index(op.f("ix_inventory_location"), table_name="inventory")
+ op.drop_index(op.f("ix_inventory_id"), table_name="inventory")
+ op.drop_index(op.f("ix_inventory_gtin"), table_name="inventory")
+ op.drop_index("idx_inventory_vendor_product", table_name="inventory")
+ op.drop_index("idx_inventory_product_location", table_name="inventory")
+ op.drop_table("inventory")
+ op.drop_index(op.f("ix_customer_addresses_id"), table_name="customer_addresses")
+ op.drop_table("customer_addresses")
+ op.drop_index(op.f("ix_vendor_themes_id"), table_name="vendor_themes")
+ op.drop_table("vendor_themes")
+ op.drop_index(op.f("ix_vendor_domains_id"), table_name="vendor_domains")
+ op.drop_index(op.f("ix_vendor_domains_domain"), table_name="vendor_domains")
+ op.drop_index("idx_vendor_primary", table_name="vendor_domains")
+ op.drop_index("idx_domain_active", table_name="vendor_domains")
+ op.drop_table("vendor_domains")
+ op.drop_index(op.f("ix_roles_id"), table_name="roles")
+ op.drop_table("roles")
+ op.drop_index(op.f("ix_products_id"), table_name="products")
+ op.drop_index("idx_product_featured", table_name="products")
+ op.drop_index("idx_product_active", table_name="products")
+ op.drop_table("products")
+ op.drop_index(
+ op.f("ix_marketplace_import_jobs_vendor_id"),
+ table_name="marketplace_import_jobs",
+ )
+ op.drop_index(
+ op.f("ix_marketplace_import_jobs_marketplace"),
+ table_name="marketplace_import_jobs",
+ )
+ op.drop_index(
+ op.f("ix_marketplace_import_jobs_id"), table_name="marketplace_import_jobs"
+ )
+ op.drop_index("idx_import_vendor_status", table_name="marketplace_import_jobs")
+ op.drop_index("idx_import_vendor_created", table_name="marketplace_import_jobs")
+ op.drop_index("idx_import_user_marketplace", table_name="marketplace_import_jobs")
+ op.drop_table("marketplace_import_jobs")
+ op.drop_index(op.f("ix_customers_id"), table_name="customers")
+ op.drop_index(op.f("ix_customers_email"), table_name="customers")
+ op.drop_index(op.f("ix_customers_customer_number"), table_name="customers")
+ op.drop_table("customers")
+ op.drop_index(op.f("ix_vendors_vendor_code"), table_name="vendors")
+ op.drop_index(op.f("ix_vendors_subdomain"), table_name="vendors")
+ op.drop_index(op.f("ix_vendors_id"), table_name="vendors")
+ op.drop_table("vendors")
+ op.drop_index(op.f("ix_platform_alerts_severity"), table_name="platform_alerts")
+ op.drop_index(op.f("ix_platform_alerts_is_resolved"), table_name="platform_alerts")
+ op.drop_index(op.f("ix_platform_alerts_id"), table_name="platform_alerts")
+ op.drop_index(op.f("ix_platform_alerts_alert_type"), table_name="platform_alerts")
+ op.drop_table("platform_alerts")
+ op.drop_index(op.f("ix_admin_settings_key"), table_name="admin_settings")
+ op.drop_index(op.f("ix_admin_settings_id"), table_name="admin_settings")
+ op.drop_index(op.f("ix_admin_settings_category"), table_name="admin_settings")
+ op.drop_table("admin_settings")
+ op.drop_index(op.f("ix_admin_sessions_session_token"), table_name="admin_sessions")
+ op.drop_index(op.f("ix_admin_sessions_login_at"), table_name="admin_sessions")
+ op.drop_index(op.f("ix_admin_sessions_is_active"), table_name="admin_sessions")
+ op.drop_index(op.f("ix_admin_sessions_id"), table_name="admin_sessions")
+ op.drop_index(op.f("ix_admin_sessions_admin_user_id"), table_name="admin_sessions")
+ op.drop_table("admin_sessions")
+ op.drop_index(op.f("ix_admin_notifications_type"), table_name="admin_notifications")
+ op.drop_index(
+ op.f("ix_admin_notifications_priority"), table_name="admin_notifications"
+ )
+ op.drop_index(
+ op.f("ix_admin_notifications_is_read"), table_name="admin_notifications"
+ )
+ op.drop_index(op.f("ix_admin_notifications_id"), table_name="admin_notifications")
+ op.drop_index(
+ op.f("ix_admin_notifications_action_required"), table_name="admin_notifications"
+ )
+ op.drop_table("admin_notifications")
+ op.drop_index(
+ op.f("ix_admin_audit_logs_target_type"), table_name="admin_audit_logs"
+ )
+ op.drop_index(op.f("ix_admin_audit_logs_target_id"), table_name="admin_audit_logs")
+ op.drop_index(op.f("ix_admin_audit_logs_id"), table_name="admin_audit_logs")
+ op.drop_index(
+ op.f("ix_admin_audit_logs_admin_user_id"), table_name="admin_audit_logs"
+ )
+ op.drop_index(op.f("ix_admin_audit_logs_action"), table_name="admin_audit_logs")
+ op.drop_table("admin_audit_logs")
+ op.drop_index(op.f("ix_users_username"), table_name="users")
+ op.drop_index(op.f("ix_users_id"), table_name="users")
+ op.drop_index(op.f("ix_users_email"), table_name="users")
+ op.drop_table("users")
+ op.drop_index(
+ op.f("ix_marketplace_products_vendor_name"), table_name="marketplace_products"
+ )
+ op.drop_index(
+ op.f("ix_marketplace_products_marketplace_product_id"),
+ table_name="marketplace_products",
+ )
+ op.drop_index(
+ op.f("ix_marketplace_products_marketplace"), table_name="marketplace_products"
+ )
+ op.drop_index(op.f("ix_marketplace_products_id"), table_name="marketplace_products")
+ op.drop_index(
+ op.f("ix_marketplace_products_gtin"), table_name="marketplace_products"
+ )
+ op.drop_index(
+ op.f("ix_marketplace_products_google_product_category"),
+ table_name="marketplace_products",
+ )
+ op.drop_index(
+ op.f("ix_marketplace_products_brand"), table_name="marketplace_products"
+ )
+ op.drop_index(
+ op.f("ix_marketplace_products_availability"), table_name="marketplace_products"
+ )
+ op.drop_index("idx_marketplace_vendor", table_name="marketplace_products")
+ op.drop_index("idx_marketplace_brand", table_name="marketplace_products")
+ op.drop_table("marketplace_products")
# ### end Alembic commands ###
diff --git a/alembic/versions/72aa309d4007_ensure_content_pages_table_with_all_.py b/alembic/versions/72aa309d4007_ensure_content_pages_table_with_all_.py
index 23e274db..24a6d43d 100644
--- a/alembic/versions/72aa309d4007_ensure_content_pages_table_with_all_.py
+++ b/alembic/versions/72aa309d4007_ensure_content_pages_table_with_all_.py
@@ -5,59 +5,72 @@ Revises: fef1d20ce8b4
Create Date: 2025-11-22 15:16:13.213613
"""
+
from typing import Sequence, Union
-from alembic import op
import sqlalchemy as sa
+from alembic import op
# revision identifiers, used by Alembic.
-revision: str = '72aa309d4007'
-down_revision: Union[str, None] = 'fef1d20ce8b4'
+revision: str = "72aa309d4007"
+down_revision: Union[str, None] = "fef1d20ce8b4"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
- op.create_table('content_pages',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('vendor_id', sa.Integer(), nullable=True),
- sa.Column('slug', sa.String(length=100), nullable=False),
- sa.Column('title', sa.String(length=200), nullable=False),
- sa.Column('content', sa.Text(), nullable=False),
- sa.Column('content_format', sa.String(length=20), nullable=True),
- sa.Column('meta_description', sa.String(length=300), nullable=True),
- sa.Column('meta_keywords', sa.String(length=300), nullable=True),
- sa.Column('is_published', sa.Boolean(), nullable=False),
- sa.Column('published_at', sa.DateTime(timezone=True), nullable=True),
- sa.Column('display_order', sa.Integer(), nullable=True),
- sa.Column('show_in_footer', sa.Boolean(), nullable=True),
- sa.Column('show_in_header', sa.Boolean(), nullable=True),
- sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
- sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
- sa.Column('created_by', sa.Integer(), nullable=True),
- sa.Column('updated_by', sa.Integer(), nullable=True),
- sa.ForeignKeyConstraint(['created_by'], ['users.id'], ondelete='SET NULL'),
- sa.ForeignKeyConstraint(['updated_by'], ['users.id'], ondelete='SET NULL'),
- sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ondelete='CASCADE'),
- sa.PrimaryKeyConstraint('id'),
- sa.UniqueConstraint('vendor_id', 'slug', name='uq_vendor_slug')
+ op.create_table(
+ "content_pages",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("vendor_id", sa.Integer(), nullable=True),
+ sa.Column("slug", sa.String(length=100), nullable=False),
+ sa.Column("title", sa.String(length=200), nullable=False),
+ sa.Column("content", sa.Text(), nullable=False),
+ sa.Column("content_format", sa.String(length=20), nullable=True),
+ sa.Column("meta_description", sa.String(length=300), nullable=True),
+ sa.Column("meta_keywords", sa.String(length=300), nullable=True),
+ sa.Column("is_published", sa.Boolean(), nullable=False),
+ sa.Column("published_at", sa.DateTime(timezone=True), nullable=True),
+ sa.Column("display_order", sa.Integer(), nullable=True),
+ sa.Column("show_in_footer", sa.Boolean(), nullable=True),
+ sa.Column("show_in_header", sa.Boolean(), nullable=True),
+ sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
+ sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
+ sa.Column("created_by", sa.Integer(), nullable=True),
+ sa.Column("updated_by", sa.Integer(), nullable=True),
+ sa.ForeignKeyConstraint(["created_by"], ["users.id"], ondelete="SET NULL"),
+ sa.ForeignKeyConstraint(["updated_by"], ["users.id"], ondelete="SET NULL"),
+ sa.ForeignKeyConstraint(["vendor_id"], ["vendors.id"], ondelete="CASCADE"),
+ sa.PrimaryKeyConstraint("id"),
+ sa.UniqueConstraint("vendor_id", "slug", name="uq_vendor_slug"),
+ )
+ op.create_index(
+ "idx_slug_published", "content_pages", ["slug", "is_published"], unique=False
+ )
+ op.create_index(
+ "idx_vendor_published",
+ "content_pages",
+ ["vendor_id", "is_published"],
+ unique=False,
+ )
+ op.create_index(op.f("ix_content_pages_id"), "content_pages", ["id"], unique=False)
+ op.create_index(
+ op.f("ix_content_pages_slug"), "content_pages", ["slug"], unique=False
+ )
+ op.create_index(
+ op.f("ix_content_pages_vendor_id"), "content_pages", ["vendor_id"], unique=False
)
- op.create_index('idx_slug_published', 'content_pages', ['slug', 'is_published'], unique=False)
- op.create_index('idx_vendor_published', 'content_pages', ['vendor_id', 'is_published'], unique=False)
- op.create_index(op.f('ix_content_pages_id'), 'content_pages', ['id'], unique=False)
- op.create_index(op.f('ix_content_pages_slug'), 'content_pages', ['slug'], unique=False)
- op.create_index(op.f('ix_content_pages_vendor_id'), 'content_pages', ['vendor_id'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_index(op.f('ix_content_pages_vendor_id'), table_name='content_pages')
- op.drop_index(op.f('ix_content_pages_slug'), table_name='content_pages')
- op.drop_index(op.f('ix_content_pages_id'), table_name='content_pages')
- op.drop_index('idx_vendor_published', table_name='content_pages')
- op.drop_index('idx_slug_published', table_name='content_pages')
- op.drop_table('content_pages')
+ op.drop_index(op.f("ix_content_pages_vendor_id"), table_name="content_pages")
+ op.drop_index(op.f("ix_content_pages_slug"), table_name="content_pages")
+ op.drop_index(op.f("ix_content_pages_id"), table_name="content_pages")
+ op.drop_index("idx_vendor_published", table_name="content_pages")
+ op.drop_index("idx_slug_published", table_name="content_pages")
+ op.drop_table("content_pages")
# ### end Alembic commands ###
diff --git a/alembic/versions/7a7ce92593d5_add_architecture_quality_tracking_tables.py b/alembic/versions/7a7ce92593d5_add_architecture_quality_tracking_tables.py
index 6fc55203..d8c85c8c 100644
--- a/alembic/versions/7a7ce92593d5_add_architecture_quality_tracking_tables.py
+++ b/alembic/versions/7a7ce92593d5_add_architecture_quality_tracking_tables.py
@@ -5,15 +5,17 @@ Revises: a2064e1dfcd4
Create Date: 2025-11-28 09:21:16.545203
"""
+
from typing import Sequence, Union
-from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
+from alembic import op
+
# revision identifiers, used by Alembic.
-revision: str = '7a7ce92593d5'
-down_revision: Union[str, None] = 'a2064e1dfcd4'
+revision: str = "7a7ce92593d5"
+down_revision: Union[str, None] = "a2064e1dfcd4"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@@ -21,127 +23,269 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Create architecture_scans table
op.create_table(
- 'architecture_scans',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('timestamp', sa.DateTime(timezone=True), server_default=sa.text("(datetime('now'))"), nullable=False),
- sa.Column('total_files', sa.Integer(), nullable=True),
- sa.Column('total_violations', sa.Integer(), nullable=True),
- sa.Column('errors', sa.Integer(), nullable=True),
- sa.Column('warnings', sa.Integer(), nullable=True),
- sa.Column('duration_seconds', sa.Float(), nullable=True),
- sa.Column('triggered_by', sa.String(length=100), nullable=True),
- sa.Column('git_commit_hash', sa.String(length=40), nullable=True),
- sa.PrimaryKeyConstraint('id')
+ "architecture_scans",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column(
+ "timestamp",
+ sa.DateTime(timezone=True),
+ server_default=sa.text("(datetime('now'))"),
+ nullable=False,
+ ),
+ sa.Column("total_files", sa.Integer(), nullable=True),
+ sa.Column("total_violations", sa.Integer(), nullable=True),
+ sa.Column("errors", sa.Integer(), nullable=True),
+ sa.Column("warnings", sa.Integer(), nullable=True),
+ sa.Column("duration_seconds", sa.Float(), nullable=True),
+ sa.Column("triggered_by", sa.String(length=100), nullable=True),
+ sa.Column("git_commit_hash", sa.String(length=40), nullable=True),
+ sa.PrimaryKeyConstraint("id"),
+ )
+ op.create_index(
+ op.f("ix_architecture_scans_id"), "architecture_scans", ["id"], unique=False
+ )
+ op.create_index(
+ op.f("ix_architecture_scans_timestamp"),
+ "architecture_scans",
+ ["timestamp"],
+ unique=False,
)
- op.create_index(op.f('ix_architecture_scans_id'), 'architecture_scans', ['id'], unique=False)
- op.create_index(op.f('ix_architecture_scans_timestamp'), 'architecture_scans', ['timestamp'], unique=False)
# Create architecture_rules table
op.create_table(
- 'architecture_rules',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('rule_id', sa.String(length=20), nullable=False),
- sa.Column('category', sa.String(length=50), nullable=False),
- sa.Column('name', sa.String(length=200), nullable=False),
- sa.Column('description', sa.Text(), nullable=True),
- sa.Column('severity', sa.String(length=10), nullable=False),
- sa.Column('enabled', sa.Boolean(), nullable=False, server_default='1'),
- sa.Column('custom_config', sa.JSON(), nullable=True),
- sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text("(datetime('now'))"), nullable=False),
- sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text("(datetime('now'))"), nullable=False),
- sa.PrimaryKeyConstraint('id'),
- sa.UniqueConstraint('rule_id')
+ "architecture_rules",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("rule_id", sa.String(length=20), nullable=False),
+ sa.Column("category", sa.String(length=50), nullable=False),
+ sa.Column("name", sa.String(length=200), nullable=False),
+ sa.Column("description", sa.Text(), nullable=True),
+ sa.Column("severity", sa.String(length=10), nullable=False),
+ sa.Column("enabled", sa.Boolean(), nullable=False, server_default="1"),
+ sa.Column("custom_config", sa.JSON(), nullable=True),
+ sa.Column(
+ "created_at",
+ sa.DateTime(timezone=True),
+ server_default=sa.text("(datetime('now'))"),
+ nullable=False,
+ ),
+ sa.Column(
+ "updated_at",
+ sa.DateTime(timezone=True),
+ server_default=sa.text("(datetime('now'))"),
+ nullable=False,
+ ),
+ sa.PrimaryKeyConstraint("id"),
+ sa.UniqueConstraint("rule_id"),
+ )
+ op.create_index(
+ op.f("ix_architecture_rules_id"), "architecture_rules", ["id"], unique=False
+ )
+ op.create_index(
+ op.f("ix_architecture_rules_rule_id"),
+ "architecture_rules",
+ ["rule_id"],
+ unique=True,
)
- op.create_index(op.f('ix_architecture_rules_id'), 'architecture_rules', ['id'], unique=False)
- op.create_index(op.f('ix_architecture_rules_rule_id'), 'architecture_rules', ['rule_id'], unique=True)
# Create architecture_violations table
op.create_table(
- 'architecture_violations',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('scan_id', sa.Integer(), nullable=False),
- sa.Column('rule_id', sa.String(length=20), nullable=False),
- sa.Column('rule_name', sa.String(length=200), nullable=False),
- sa.Column('severity', sa.String(length=10), nullable=False),
- sa.Column('file_path', sa.String(length=500), nullable=False),
- sa.Column('line_number', sa.Integer(), nullable=False),
- sa.Column('message', sa.Text(), nullable=False),
- sa.Column('context', sa.Text(), nullable=True),
- sa.Column('suggestion', sa.Text(), nullable=True),
- sa.Column('status', sa.String(length=20), server_default='open', nullable=True),
- sa.Column('assigned_to', sa.Integer(), nullable=True),
- sa.Column('resolved_at', sa.DateTime(timezone=True), nullable=True),
- sa.Column('resolved_by', sa.Integer(), nullable=True),
- sa.Column('resolution_note', sa.Text(), nullable=True),
- sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text("(datetime('now'))"), nullable=False),
- sa.ForeignKeyConstraint(['assigned_to'], ['users.id'], ),
- sa.ForeignKeyConstraint(['resolved_by'], ['users.id'], ),
- sa.ForeignKeyConstraint(['scan_id'], ['architecture_scans.id'], ),
- sa.PrimaryKeyConstraint('id')
+ "architecture_violations",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("scan_id", sa.Integer(), nullable=False),
+ sa.Column("rule_id", sa.String(length=20), nullable=False),
+ sa.Column("rule_name", sa.String(length=200), nullable=False),
+ sa.Column("severity", sa.String(length=10), nullable=False),
+ sa.Column("file_path", sa.String(length=500), nullable=False),
+ sa.Column("line_number", sa.Integer(), nullable=False),
+ sa.Column("message", sa.Text(), nullable=False),
+ sa.Column("context", sa.Text(), nullable=True),
+ sa.Column("suggestion", sa.Text(), nullable=True),
+ sa.Column("status", sa.String(length=20), server_default="open", nullable=True),
+ sa.Column("assigned_to", sa.Integer(), nullable=True),
+ sa.Column("resolved_at", sa.DateTime(timezone=True), nullable=True),
+ sa.Column("resolved_by", sa.Integer(), nullable=True),
+ sa.Column("resolution_note", sa.Text(), nullable=True),
+ sa.Column(
+ "created_at",
+ sa.DateTime(timezone=True),
+ server_default=sa.text("(datetime('now'))"),
+ nullable=False,
+ ),
+ sa.ForeignKeyConstraint(
+ ["assigned_to"],
+ ["users.id"],
+ ),
+ sa.ForeignKeyConstraint(
+ ["resolved_by"],
+ ["users.id"],
+ ),
+ sa.ForeignKeyConstraint(
+ ["scan_id"],
+ ["architecture_scans.id"],
+ ),
+ sa.PrimaryKeyConstraint("id"),
+ )
+ op.create_index(
+ op.f("ix_architecture_violations_file_path"),
+ "architecture_violations",
+ ["file_path"],
+ unique=False,
+ )
+ op.create_index(
+ op.f("ix_architecture_violations_id"),
+ "architecture_violations",
+ ["id"],
+ unique=False,
+ )
+ op.create_index(
+ op.f("ix_architecture_violations_rule_id"),
+ "architecture_violations",
+ ["rule_id"],
+ unique=False,
+ )
+ op.create_index(
+ op.f("ix_architecture_violations_scan_id"),
+ "architecture_violations",
+ ["scan_id"],
+ unique=False,
+ )
+ op.create_index(
+ op.f("ix_architecture_violations_severity"),
+ "architecture_violations",
+ ["severity"],
+ unique=False,
+ )
+ op.create_index(
+ op.f("ix_architecture_violations_status"),
+ "architecture_violations",
+ ["status"],
+ unique=False,
)
- op.create_index(op.f('ix_architecture_violations_file_path'), 'architecture_violations', ['file_path'], unique=False)
- op.create_index(op.f('ix_architecture_violations_id'), 'architecture_violations', ['id'], unique=False)
- op.create_index(op.f('ix_architecture_violations_rule_id'), 'architecture_violations', ['rule_id'], unique=False)
- op.create_index(op.f('ix_architecture_violations_scan_id'), 'architecture_violations', ['scan_id'], unique=False)
- op.create_index(op.f('ix_architecture_violations_severity'), 'architecture_violations', ['severity'], unique=False)
- op.create_index(op.f('ix_architecture_violations_status'), 'architecture_violations', ['status'], unique=False)
# Create violation_assignments table
op.create_table(
- 'violation_assignments',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('violation_id', sa.Integer(), nullable=False),
- sa.Column('user_id', sa.Integer(), nullable=False),
- sa.Column('assigned_at', sa.DateTime(timezone=True), server_default=sa.text("(datetime('now'))"), nullable=False),
- sa.Column('assigned_by', sa.Integer(), nullable=True),
- sa.Column('due_date', sa.DateTime(timezone=True), nullable=True),
- sa.Column('priority', sa.String(length=10), server_default='medium', nullable=True),
- sa.ForeignKeyConstraint(['assigned_by'], ['users.id'], ),
- sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
- sa.ForeignKeyConstraint(['violation_id'], ['architecture_violations.id'], ),
- sa.PrimaryKeyConstraint('id')
+ "violation_assignments",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("violation_id", sa.Integer(), nullable=False),
+ sa.Column("user_id", sa.Integer(), nullable=False),
+ sa.Column(
+ "assigned_at",
+ sa.DateTime(timezone=True),
+ server_default=sa.text("(datetime('now'))"),
+ nullable=False,
+ ),
+ sa.Column("assigned_by", sa.Integer(), nullable=True),
+ sa.Column("due_date", sa.DateTime(timezone=True), nullable=True),
+ sa.Column(
+ "priority", sa.String(length=10), server_default="medium", nullable=True
+ ),
+ sa.ForeignKeyConstraint(
+ ["assigned_by"],
+ ["users.id"],
+ ),
+ sa.ForeignKeyConstraint(
+ ["user_id"],
+ ["users.id"],
+ ),
+ sa.ForeignKeyConstraint(
+ ["violation_id"],
+ ["architecture_violations.id"],
+ ),
+ sa.PrimaryKeyConstraint("id"),
+ )
+ op.create_index(
+ op.f("ix_violation_assignments_id"),
+ "violation_assignments",
+ ["id"],
+ unique=False,
+ )
+ op.create_index(
+ op.f("ix_violation_assignments_violation_id"),
+ "violation_assignments",
+ ["violation_id"],
+ unique=False,
)
- op.create_index(op.f('ix_violation_assignments_id'), 'violation_assignments', ['id'], unique=False)
- op.create_index(op.f('ix_violation_assignments_violation_id'), 'violation_assignments', ['violation_id'], unique=False)
# Create violation_comments table
op.create_table(
- 'violation_comments',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('violation_id', sa.Integer(), nullable=False),
- sa.Column('user_id', sa.Integer(), nullable=False),
- sa.Column('comment', sa.Text(), nullable=False),
- sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text("(datetime('now'))"), nullable=False),
- sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
- sa.ForeignKeyConstraint(['violation_id'], ['architecture_violations.id'], ),
- sa.PrimaryKeyConstraint('id')
+ "violation_comments",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("violation_id", sa.Integer(), nullable=False),
+ sa.Column("user_id", sa.Integer(), nullable=False),
+ sa.Column("comment", sa.Text(), nullable=False),
+ sa.Column(
+ "created_at",
+ sa.DateTime(timezone=True),
+ server_default=sa.text("(datetime('now'))"),
+ nullable=False,
+ ),
+ sa.ForeignKeyConstraint(
+ ["user_id"],
+ ["users.id"],
+ ),
+ sa.ForeignKeyConstraint(
+ ["violation_id"],
+ ["architecture_violations.id"],
+ ),
+ sa.PrimaryKeyConstraint("id"),
+ )
+ op.create_index(
+ op.f("ix_violation_comments_id"), "violation_comments", ["id"], unique=False
+ )
+ op.create_index(
+ op.f("ix_violation_comments_violation_id"),
+ "violation_comments",
+ ["violation_id"],
+ unique=False,
)
- op.create_index(op.f('ix_violation_comments_id'), 'violation_comments', ['id'], unique=False)
- op.create_index(op.f('ix_violation_comments_violation_id'), 'violation_comments', ['violation_id'], unique=False)
def downgrade() -> None:
# Drop tables in reverse order (to respect foreign key constraints)
- op.drop_index(op.f('ix_violation_comments_violation_id'), table_name='violation_comments')
- op.drop_index(op.f('ix_violation_comments_id'), table_name='violation_comments')
- op.drop_table('violation_comments')
+ op.drop_index(
+ op.f("ix_violation_comments_violation_id"), table_name="violation_comments"
+ )
+ op.drop_index(op.f("ix_violation_comments_id"), table_name="violation_comments")
+ op.drop_table("violation_comments")
- op.drop_index(op.f('ix_violation_assignments_violation_id'), table_name='violation_assignments')
- op.drop_index(op.f('ix_violation_assignments_id'), table_name='violation_assignments')
- op.drop_table('violation_assignments')
+ op.drop_index(
+ op.f("ix_violation_assignments_violation_id"),
+ table_name="violation_assignments",
+ )
+ op.drop_index(
+ op.f("ix_violation_assignments_id"), table_name="violation_assignments"
+ )
+ op.drop_table("violation_assignments")
- op.drop_index(op.f('ix_architecture_violations_status'), table_name='architecture_violations')
- op.drop_index(op.f('ix_architecture_violations_severity'), table_name='architecture_violations')
- op.drop_index(op.f('ix_architecture_violations_scan_id'), table_name='architecture_violations')
- op.drop_index(op.f('ix_architecture_violations_rule_id'), table_name='architecture_violations')
- op.drop_index(op.f('ix_architecture_violations_id'), table_name='architecture_violations')
- op.drop_index(op.f('ix_architecture_violations_file_path'), table_name='architecture_violations')
- op.drop_table('architecture_violations')
+ op.drop_index(
+ op.f("ix_architecture_violations_status"), table_name="architecture_violations"
+ )
+ op.drop_index(
+ op.f("ix_architecture_violations_severity"),
+ table_name="architecture_violations",
+ )
+ op.drop_index(
+ op.f("ix_architecture_violations_scan_id"), table_name="architecture_violations"
+ )
+ op.drop_index(
+ op.f("ix_architecture_violations_rule_id"), table_name="architecture_violations"
+ )
+ op.drop_index(
+ op.f("ix_architecture_violations_id"), table_name="architecture_violations"
+ )
+ op.drop_index(
+ op.f("ix_architecture_violations_file_path"),
+ table_name="architecture_violations",
+ )
+ op.drop_table("architecture_violations")
- op.drop_index(op.f('ix_architecture_rules_rule_id'), table_name='architecture_rules')
- op.drop_index(op.f('ix_architecture_rules_id'), table_name='architecture_rules')
- op.drop_table('architecture_rules')
+ op.drop_index(
+ op.f("ix_architecture_rules_rule_id"), table_name="architecture_rules"
+ )
+ op.drop_index(op.f("ix_architecture_rules_id"), table_name="architecture_rules")
+ op.drop_table("architecture_rules")
- op.drop_index(op.f('ix_architecture_scans_timestamp'), table_name='architecture_scans')
- op.drop_index(op.f('ix_architecture_scans_id'), table_name='architecture_scans')
- op.drop_table('architecture_scans')
+ op.drop_index(
+ op.f("ix_architecture_scans_timestamp"), table_name="architecture_scans"
+ )
+ op.drop_index(op.f("ix_architecture_scans_id"), table_name="architecture_scans")
+ op.drop_table("architecture_scans")
diff --git a/alembic/versions/a2064e1dfcd4_add_cart_items_table.py b/alembic/versions/a2064e1dfcd4_add_cart_items_table.py
index 7773c638..3eca3ff3 100644
--- a/alembic/versions/a2064e1dfcd4_add_cart_items_table.py
+++ b/alembic/versions/a2064e1dfcd4_add_cart_items_table.py
@@ -5,15 +5,16 @@ Revises: f68d8da5315a
Create Date: 2025-11-23 19:52:40.509538
"""
+
from typing import Sequence, Union
-from alembic import op
import sqlalchemy as sa
+from alembic import op
# revision identifiers, used by Alembic.
-revision: str = 'a2064e1dfcd4'
-down_revision: Union[str, None] = 'f68d8da5315a'
+revision: str = "a2064e1dfcd4"
+down_revision: Union[str, None] = "f68d8da5315a"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@@ -21,34 +22,46 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Create cart_items table
op.create_table(
- 'cart_items',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('vendor_id', sa.Integer(), nullable=False),
- sa.Column('product_id', sa.Integer(), nullable=False),
- sa.Column('session_id', sa.String(length=255), nullable=False),
- sa.Column('quantity', sa.Integer(), nullable=False),
- sa.Column('price_at_add', sa.Float(), nullable=False),
- sa.Column('created_at', sa.DateTime(), nullable=True),
- sa.Column('updated_at', sa.DateTime(), nullable=True),
- sa.ForeignKeyConstraint(['product_id'], ['products.id'], ),
- sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ),
- sa.PrimaryKeyConstraint('id'),
- sa.UniqueConstraint('vendor_id', 'session_id', 'product_id', name='uq_cart_item')
+ "cart_items",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("vendor_id", sa.Integer(), nullable=False),
+ sa.Column("product_id", sa.Integer(), nullable=False),
+ sa.Column("session_id", sa.String(length=255), nullable=False),
+ sa.Column("quantity", sa.Integer(), nullable=False),
+ sa.Column("price_at_add", sa.Float(), nullable=False),
+ sa.Column("created_at", sa.DateTime(), nullable=True),
+ sa.Column("updated_at", sa.DateTime(), nullable=True),
+ sa.ForeignKeyConstraint(
+ ["product_id"],
+ ["products.id"],
+ ),
+ sa.ForeignKeyConstraint(
+ ["vendor_id"],
+ ["vendors.id"],
+ ),
+ sa.PrimaryKeyConstraint("id"),
+ sa.UniqueConstraint(
+ "vendor_id", "session_id", "product_id", name="uq_cart_item"
+ ),
)
# Create indexes
- op.create_index('idx_cart_session', 'cart_items', ['vendor_id', 'session_id'], unique=False)
- op.create_index('idx_cart_created', 'cart_items', ['created_at'], unique=False)
- op.create_index(op.f('ix_cart_items_id'), 'cart_items', ['id'], unique=False)
- op.create_index(op.f('ix_cart_items_session_id'), 'cart_items', ['session_id'], unique=False)
+ op.create_index(
+ "idx_cart_session", "cart_items", ["vendor_id", "session_id"], unique=False
+ )
+ op.create_index("idx_cart_created", "cart_items", ["created_at"], unique=False)
+ op.create_index(op.f("ix_cart_items_id"), "cart_items", ["id"], unique=False)
+ op.create_index(
+ op.f("ix_cart_items_session_id"), "cart_items", ["session_id"], unique=False
+ )
def downgrade() -> None:
# Drop indexes
- op.drop_index(op.f('ix_cart_items_session_id'), table_name='cart_items')
- op.drop_index(op.f('ix_cart_items_id'), table_name='cart_items')
- op.drop_index('idx_cart_created', table_name='cart_items')
- op.drop_index('idx_cart_session', table_name='cart_items')
+ op.drop_index(op.f("ix_cart_items_session_id"), table_name="cart_items")
+ op.drop_index(op.f("ix_cart_items_id"), table_name="cart_items")
+ op.drop_index("idx_cart_created", table_name="cart_items")
+ op.drop_index("idx_cart_session", table_name="cart_items")
# Drop table
- op.drop_table('cart_items')
+ op.drop_table("cart_items")
diff --git a/alembic/versions/f68d8da5315a_add_template_field_to_content_pages_for_.py b/alembic/versions/f68d8da5315a_add_template_field_to_content_pages_for_.py
index 2b40dda2..b095835d 100644
--- a/alembic/versions/f68d8da5315a_add_template_field_to_content_pages_for_.py
+++ b/alembic/versions/f68d8da5315a_add_template_field_to_content_pages_for_.py
@@ -5,24 +5,30 @@ Revises: 72aa309d4007
Create Date: 2025-11-22 23:51:40.694983
"""
+
from typing import Sequence, Union
-from alembic import op
import sqlalchemy as sa
+from alembic import op
# revision identifiers, used by Alembic.
-revision: str = 'f68d8da5315a'
-down_revision: Union[str, None] = '72aa309d4007'
+revision: str = "f68d8da5315a"
+down_revision: Union[str, None] = "72aa309d4007"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add template column to content_pages table
- op.add_column('content_pages', sa.Column('template', sa.String(length=50), nullable=False, server_default='default'))
+ op.add_column(
+ "content_pages",
+ sa.Column(
+ "template", sa.String(length=50), nullable=False, server_default="default"
+ ),
+ )
def downgrade() -> None:
# Remove template column from content_pages table
- op.drop_column('content_pages', 'template')
+ op.drop_column("content_pages", "template")
diff --git a/alembic/versions/fa7d4d10e358_add_rbac_enhancements.py b/alembic/versions/fa7d4d10e358_add_rbac_enhancements.py
index 2fcca7a8..6c0decce 100644
--- a/alembic/versions/fa7d4d10e358_add_rbac_enhancements.py
+++ b/alembic/versions/fa7d4d10e358_add_rbac_enhancements.py
@@ -6,15 +6,16 @@ Create Date: 2025-11-13 16:51:25.010057
SQLite-compatible version
"""
+
from typing import Sequence, Union
-from alembic import op
import sqlalchemy as sa
+from alembic import op
# revision identifiers, used by Alembic.
-revision: str = 'fa7d4d10e358'
-down_revision: Union[str, None] = '4951b2e50581'
+revision: str = "fa7d4d10e358"
+down_revision: Union[str, None] = "4951b2e50581"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@@ -28,9 +29,14 @@ def upgrade():
# ========================================================================
# User table changes
# ========================================================================
- with op.batch_alter_table('users', schema=None) as batch_op:
+ with op.batch_alter_table("users", schema=None) as batch_op:
batch_op.add_column(
- sa.Column('is_email_verified', sa.Boolean(), nullable=False, server_default='false')
+ sa.Column(
+ "is_email_verified",
+ sa.Boolean(),
+ nullable=False,
+ server_default="false",
+ )
)
# Set existing active users as verified
@@ -39,68 +45,65 @@ def upgrade():
# ========================================================================
# VendorUser table changes (requires table recreation for SQLite)
# ========================================================================
- with op.batch_alter_table('vendor_users', schema=None) as batch_op:
+ with op.batch_alter_table("vendor_users", schema=None) as batch_op:
# Add new columns
batch_op.add_column(
- sa.Column('user_type', sa.String(length=20), nullable=False, server_default='member')
+ sa.Column(
+ "user_type",
+ sa.String(length=20),
+ nullable=False,
+ server_default="member",
+ )
)
batch_op.add_column(
- sa.Column('invitation_token', sa.String(length=100), nullable=True)
+ sa.Column("invitation_token", sa.String(length=100), nullable=True)
)
batch_op.add_column(
- sa.Column('invitation_sent_at', sa.DateTime(), nullable=True)
+ sa.Column("invitation_sent_at", sa.DateTime(), nullable=True)
)
batch_op.add_column(
- sa.Column('invitation_accepted_at', sa.DateTime(), nullable=True)
+ sa.Column("invitation_accepted_at", sa.DateTime(), nullable=True)
)
# Create index on invitation_token
- batch_op.create_index(
- 'idx_vendor_users_invitation_token',
- ['invitation_token']
- )
+ batch_op.create_index("idx_vendor_users_invitation_token", ["invitation_token"])
# Modify role_id to be nullable (this recreates the table in SQLite)
- batch_op.alter_column(
- 'role_id',
- existing_type=sa.Integer(),
- nullable=True
- )
+ batch_op.alter_column("role_id", existing_type=sa.Integer(), nullable=True)
# Change is_active default (this recreates the table in SQLite)
batch_op.alter_column(
- 'is_active',
- existing_type=sa.Boolean(),
- server_default='false'
+ "is_active", existing_type=sa.Boolean(), server_default="false"
)
# Set owners correctly (after table modifications)
# SQLite-compatible UPDATE with subquery
- op.execute("""
+ op.execute(
+ """
UPDATE vendor_users
SET user_type = 'owner'
WHERE (vendor_id, user_id) IN (
SELECT id, owner_user_id
FROM vendors
)
- """)
+ """
+ )
# Set existing owners as active
- op.execute("""
+ op.execute(
+ """
UPDATE vendor_users
SET is_active = TRUE
WHERE user_type = 'owner'
- """)
+ """
+ )
# ========================================================================
# Role table changes
# ========================================================================
- with op.batch_alter_table('roles', schema=None) as batch_op:
+ with op.batch_alter_table("roles", schema=None) as batch_op:
# Create index on vendor_id and name
- batch_op.create_index(
- 'idx_roles_vendor_name',
- ['vendor_id', 'name']
- )
+ batch_op.create_index("idx_roles_vendor_name", ["vendor_id", "name"])
# Note: JSONB conversion only for PostgreSQL
# SQLite stores JSON as TEXT by default, no conversion needed
@@ -115,37 +118,31 @@ def downgrade():
# ========================================================================
# Role table changes
# ========================================================================
- with op.batch_alter_table('roles', schema=None) as batch_op:
- batch_op.drop_index('idx_roles_vendor_name')
+ with op.batch_alter_table("roles", schema=None) as batch_op:
+ batch_op.drop_index("idx_roles_vendor_name")
# ========================================================================
# VendorUser table changes
# ========================================================================
- with op.batch_alter_table('vendor_users', schema=None) as batch_op:
+ with op.batch_alter_table("vendor_users", schema=None) as batch_op:
# Revert is_active default
batch_op.alter_column(
- 'is_active',
- existing_type=sa.Boolean(),
- server_default='true'
+ "is_active", existing_type=sa.Boolean(), server_default="true"
)
# Revert role_id to NOT NULL
# Note: This might fail if there are NULL values
- batch_op.alter_column(
- 'role_id',
- existing_type=sa.Integer(),
- nullable=False
- )
+ batch_op.alter_column("role_id", existing_type=sa.Integer(), nullable=False)
# Drop indexes and columns
- batch_op.drop_index('idx_vendor_users_invitation_token')
- batch_op.drop_column('invitation_accepted_at')
- batch_op.drop_column('invitation_sent_at')
- batch_op.drop_column('invitation_token')
- batch_op.drop_column('user_type')
+ batch_op.drop_index("idx_vendor_users_invitation_token")
+ batch_op.drop_column("invitation_accepted_at")
+ batch_op.drop_column("invitation_sent_at")
+ batch_op.drop_column("invitation_token")
+ batch_op.drop_column("user_type")
# ========================================================================
# User table changes
# ========================================================================
- with op.batch_alter_table('users', schema=None) as batch_op:
- batch_op.drop_column('is_email_verified')
\ No newline at end of file
+ with op.batch_alter_table("users", schema=None) as batch_op:
+ batch_op.drop_column("is_email_verified")
diff --git a/alembic/versions/fef1d20ce8b4_add_content_pages_table_for_cms.py b/alembic/versions/fef1d20ce8b4_add_content_pages_table_for_cms.py
index 0c895879..cc24f81a 100644
--- a/alembic/versions/fef1d20ce8b4_add_content_pages_table_for_cms.py
+++ b/alembic/versions/fef1d20ce8b4_add_content_pages_table_for_cms.py
@@ -5,30 +5,43 @@ Revises: fa7d4d10e358
Create Date: 2025-11-22 13:41:18.069674
"""
+
from typing import Sequence, Union
-from alembic import op
import sqlalchemy as sa
+from alembic import op
# revision identifiers, used by Alembic.
-revision: str = 'fef1d20ce8b4'
-down_revision: Union[str, None] = 'fa7d4d10e358'
+revision: str = "fef1d20ce8b4"
+down_revision: Union[str, None] = "fa7d4d10e358"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_index('idx_roles_vendor_name', table_name='roles')
- op.drop_index('idx_vendor_users_invitation_token', table_name='vendor_users')
- op.create_index(op.f('ix_vendor_users_invitation_token'), 'vendor_users', ['invitation_token'], unique=False)
+ op.drop_index("idx_roles_vendor_name", table_name="roles")
+ op.drop_index("idx_vendor_users_invitation_token", table_name="vendor_users")
+ op.create_index(
+ op.f("ix_vendor_users_invitation_token"),
+ "vendor_users",
+ ["invitation_token"],
+ unique=False,
+ )
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_index(op.f('ix_vendor_users_invitation_token'), table_name='vendor_users')
- op.create_index('idx_vendor_users_invitation_token', 'vendor_users', ['invitation_token'], unique=False)
- op.create_index('idx_roles_vendor_name', 'roles', ['vendor_id', 'name'], unique=False)
+ op.drop_index(op.f("ix_vendor_users_invitation_token"), table_name="vendor_users")
+ op.create_index(
+ "idx_vendor_users_invitation_token",
+ "vendor_users",
+ ["invitation_token"],
+ unique=False,
+ )
+ op.create_index(
+ "idx_roles_vendor_name", "roles", ["vendor_id", "name"], unique=False
+ )
# ### end Alembic commands ###
diff --git a/app/api/deps.py b/app/api/deps.py
index d219a60a..f2761a9a 100644
--- a/app/api/deps.py
+++ b/app/api/deps.py
@@ -34,22 +34,20 @@ The cookie path restrictions prevent cross-context cookie leakage:
import logging
from typing import Optional
-from fastapi import Depends, Request, Cookie
+from fastapi import Cookie, Depends, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy.orm import Session
from app.core.database import get_db
+from app.exceptions import (AdminRequiredException,
+ InsufficientPermissionsException,
+ InvalidTokenException,
+ UnauthorizedVendorAccessException,
+ VendorNotFoundException)
from middleware.auth import AuthManager
from middleware.rate_limiter import RateLimiter
-from models.database.vendor import Vendor
from models.database.user import User
-from app.exceptions import (
- AdminRequiredException,
- InvalidTokenException,
- InsufficientPermissionsException,
- VendorNotFoundException,
- UnauthorizedVendorAccessException
-)
+from models.database.vendor import Vendor
# Initialize dependencies
security = HTTPBearer(auto_error=False) # auto_error=False prevents automatic 403
@@ -62,11 +60,12 @@ logger = logging.getLogger(__name__)
# HELPER FUNCTIONS
# ============================================================================
+
def _get_token_from_request(
- credentials: Optional[HTTPAuthorizationCredentials],
- cookie_value: Optional[str],
- cookie_name: str,
- request_path: str
+ credentials: Optional[HTTPAuthorizationCredentials],
+ cookie_value: Optional[str],
+ cookie_name: str,
+ request_path: str,
) -> tuple[Optional[str], Optional[str]]:
"""
Extract token from Authorization header or cookie.
@@ -108,10 +107,7 @@ def _validate_user_token(token: str, db: Session) -> User:
Raises:
InvalidTokenException: If token is invalid
"""
- mock_credentials = HTTPAuthorizationCredentials(
- scheme="Bearer",
- credentials=token
- )
+ mock_credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
return auth_manager.get_current_user(db, mock_credentials)
@@ -119,11 +115,12 @@ def _validate_user_token(token: str, db: Session) -> User:
# ADMIN AUTHENTICATION
# ============================================================================
+
def get_current_admin_from_cookie_or_header(
- request: Request,
- credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
- admin_token: Optional[str] = Cookie(None),
- db: Session = Depends(get_db),
+ request: Request,
+ credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
+ admin_token: Optional[str] = Cookie(None),
+ db: Session = Depends(get_db),
) -> User:
"""
Get current admin user from admin_token cookie or Authorization header.
@@ -148,10 +145,7 @@ def get_current_admin_from_cookie_or_header(
AdminRequiredException: If user is not admin
"""
token, source = _get_token_from_request(
- credentials,
- admin_token,
- "admin_token",
- str(request.url.path)
+ credentials, admin_token, "admin_token", str(request.url.path)
)
if not token:
@@ -172,8 +166,8 @@ def get_current_admin_from_cookie_or_header(
def get_current_admin_api(
- credentials: HTTPAuthorizationCredentials = Depends(security),
- db: Session = Depends(get_db),
+ credentials: HTTPAuthorizationCredentials = Depends(security),
+ db: Session = Depends(get_db),
) -> User:
"""
Get current admin user from Authorization header ONLY.
@@ -208,11 +202,12 @@ def get_current_admin_api(
# VENDOR AUTHENTICATION
# ============================================================================
+
def get_current_vendor_from_cookie_or_header(
- request: Request,
- credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
- vendor_token: Optional[str] = Cookie(None),
- db: Session = Depends(get_db),
+ request: Request,
+ credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
+ vendor_token: Optional[str] = Cookie(None),
+ db: Session = Depends(get_db),
) -> User:
"""
Get current vendor user from vendor_token cookie or Authorization header.
@@ -237,10 +232,7 @@ def get_current_vendor_from_cookie_or_header(
InsufficientPermissionsException: If user is not vendor or is admin
"""
token, source = _get_token_from_request(
- credentials,
- vendor_token,
- "vendor_token",
- str(request.url.path)
+ credentials, vendor_token, "vendor_token", str(request.url.path)
)
if not token:
@@ -270,8 +262,8 @@ def get_current_vendor_from_cookie_or_header(
def get_current_vendor_api(
- credentials: HTTPAuthorizationCredentials = Depends(security),
- db: Session = Depends(get_db),
+ credentials: HTTPAuthorizationCredentials = Depends(security),
+ db: Session = Depends(get_db),
) -> User:
"""
Get current vendor user from Authorization header ONLY.
@@ -310,11 +302,12 @@ def get_current_vendor_api(
# CUSTOMER AUTHENTICATION (SHOP)
# ============================================================================
+
def get_current_customer_from_cookie_or_header(
- request: Request,
- credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
- customer_token: Optional[str] = Cookie(None),
- db: Session = Depends(get_db),
+ request: Request,
+ credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
+ customer_token: Optional[str] = Cookie(None),
+ db: Session = Depends(get_db),
):
"""
Get current customer from customer_token cookie or Authorization header.
@@ -338,15 +331,14 @@ def get_current_customer_from_cookie_or_header(
Raises:
InvalidTokenException: If no token or invalid token
"""
- from models.database.customer import Customer
- from jose import jwt, JWTError
from datetime import datetime, timezone
+ from jose import JWTError, jwt
+
+ from models.database.customer import Customer
+
token, source = _get_token_from_request(
- credentials,
- customer_token,
- "customer_token",
- str(request.url.path)
+ credentials, customer_token, "customer_token", str(request.url.path)
)
if not token:
@@ -356,9 +348,7 @@ def get_current_customer_from_cookie_or_header(
# Decode and validate customer JWT token
try:
payload = jwt.decode(
- token,
- auth_manager.secret_key,
- algorithms=[auth_manager.algorithm]
+ token, auth_manager.secret_key, algorithms=[auth_manager.algorithm]
)
# Verify this is a customer token
@@ -375,7 +365,9 @@ def get_current_customer_from_cookie_or_header(
# Verify token hasn't expired
exp = payload.get("exp")
- if exp and datetime.fromtimestamp(exp, tz=timezone.utc) < datetime.now(timezone.utc):
+ if exp and datetime.fromtimestamp(exp, tz=timezone.utc) < datetime.now(
+ timezone.utc
+ ):
logger.warning(f"Expired customer token for customer_id={customer_id}")
raise InvalidTokenException("Token has expired")
@@ -400,8 +392,8 @@ def get_current_customer_from_cookie_or_header(
def get_current_customer_api(
- credentials: HTTPAuthorizationCredentials = Depends(security),
- db: Session = Depends(get_db),
+ credentials: HTTPAuthorizationCredentials = Depends(security),
+ db: Session = Depends(get_db),
) -> User:
"""
Get current customer user from Authorization header ONLY.
@@ -445,9 +437,10 @@ def get_current_customer_api(
# GENERIC AUTHENTICATION (for mixed-use endpoints)
# ============================================================================
+
def get_current_user(
- credentials: HTTPAuthorizationCredentials = Depends(security),
- db: Session = Depends(get_db)
+ credentials: HTTPAuthorizationCredentials = Depends(security),
+ db: Session = Depends(get_db),
) -> User:
"""
Get current authenticated user from Authorization header only.
@@ -475,10 +468,11 @@ def get_current_user(
# VENDOR OWNERSHIP VERIFICATION
# ============================================================================
+
def get_user_vendor(
- vendor_code: str,
- current_user: User = Depends(get_current_vendor_from_cookie_or_header),
- db: Session = Depends(get_db),
+ vendor_code: str,
+ current_user: User = Depends(get_current_vendor_from_cookie_or_header),
+ db: Session = Depends(get_db),
) -> Vendor:
"""
Get vendor and verify user ownership/membership.
@@ -500,9 +494,7 @@ def get_user_vendor(
VendorNotFoundException: If vendor doesn't exist
UnauthorizedVendorAccessException: If user doesn't have access
"""
- vendor = db.query(Vendor).filter(
- Vendor.vendor_code == vendor_code.upper()
- ).first()
+ vendor = db.query(Vendor).filter(Vendor.vendor_code == vendor_code.upper()).first()
if not vendor:
raise VendorNotFoundException(vendor_code)
@@ -517,10 +509,12 @@ def get_user_vendor(
# User doesn't have access to this vendor
raise UnauthorizedVendorAccessException(vendor_code, current_user.id)
+
# ============================================================================
# PERMISSIONS CHECKING
# ============================================================================
+
def require_vendor_permission(permission: str):
"""
Dependency factory to require a specific vendor permission.
@@ -535,9 +529,9 @@ def require_vendor_permission(permission: str):
"""
def permission_checker(
- request: Request,
- db: Session = Depends(get_db),
- current_user: User = Depends(get_current_vendor_from_cookie_or_header),
+ request: Request,
+ db: Session = Depends(get_db),
+ current_user: User = Depends(get_current_vendor_from_cookie_or_header),
) -> User:
# Get vendor from request state (set by middleware)
vendor = getattr(request.state, "vendor", None)
@@ -557,9 +551,9 @@ def require_vendor_permission(permission: str):
def require_vendor_owner(
- request: Request,
- db: Session = Depends(get_db),
- current_user: User = Depends(get_current_vendor_from_cookie_or_header),
+ request: Request,
+ db: Session = Depends(get_db),
+ current_user: User = Depends(get_current_vendor_from_cookie_or_header),
) -> User:
"""
Dependency to require vendor owner role.
@@ -600,9 +594,9 @@ def require_any_vendor_permission(*permissions: str):
"""
def permission_checker(
- request: Request,
- db: Session = Depends(get_db),
- current_user: User = Depends(get_current_vendor_from_cookie_or_header),
+ request: Request,
+ db: Session = Depends(get_db),
+ current_user: User = Depends(get_current_vendor_from_cookie_or_header),
) -> User:
vendor = getattr(request.state, "vendor", None)
if not vendor:
@@ -610,8 +604,7 @@ def require_any_vendor_permission(*permissions: str):
# Check if user has ANY of the required permissions
has_permission = any(
- current_user.has_vendor_permission(vendor.id, perm)
- for perm in permissions
+ current_user.has_vendor_permission(vendor.id, perm) for perm in permissions
)
if not has_permission:
@@ -641,9 +634,9 @@ def require_all_vendor_permissions(*permissions: str):
"""
def permission_checker(
- request: Request,
- db: Session = Depends(get_db),
- current_user: User = Depends(get_current_vendor_from_cookie_or_header),
+ request: Request,
+ db: Session = Depends(get_db),
+ current_user: User = Depends(get_current_vendor_from_cookie_or_header),
) -> User:
vendor = getattr(request.state, "vendor", None)
if not vendor:
@@ -651,7 +644,8 @@ def require_all_vendor_permissions(*permissions: str):
# Check if user has ALL required permissions
missing_permissions = [
- perm for perm in permissions
+ perm
+ for perm in permissions
if not current_user.has_vendor_permission(vendor.id, perm)
]
@@ -667,8 +661,8 @@ def require_all_vendor_permissions(*permissions: str):
def get_user_permissions(
- request: Request,
- current_user: User = Depends(get_current_vendor_from_cookie_or_header),
+ request: Request,
+ current_user: User = Depends(get_current_vendor_from_cookie_or_header),
) -> list:
"""
Get all permissions for current user in current vendor.
@@ -682,6 +676,7 @@ def get_user_permissions(
# If owner, return all permissions
if current_user.is_owner_of(vendor.id):
from app.core.permissions import VendorPermissions
+
return [p.value for p in VendorPermissions]
# Get permissions from vendor membership
@@ -696,11 +691,12 @@ def get_user_permissions(
# OPTIONAL AUTHENTICATION (For Login Page Redirects)
# ============================================================================
+
def get_current_admin_optional(
- request: Request,
- credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
- admin_token: Optional[str] = Cookie(None),
- db: Session = Depends(get_db),
+ request: Request,
+ credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
+ admin_token: Optional[str] = Cookie(None),
+ db: Session = Depends(get_db),
) -> Optional[User]:
"""
Get current admin user from admin_token cookie or Authorization header.
@@ -723,10 +719,7 @@ def get_current_admin_optional(
None: If no token, invalid token, or user is not admin
"""
token, source = _get_token_from_request(
- credentials,
- admin_token,
- "admin_token",
- str(request.url.path)
+ credentials, admin_token, "admin_token", str(request.url.path)
)
if not token:
@@ -747,10 +740,10 @@ def get_current_admin_optional(
def get_current_vendor_optional(
- request: Request,
- credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
- vendor_token: Optional[str] = Cookie(None),
- db: Session = Depends(get_db),
+ request: Request,
+ credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
+ vendor_token: Optional[str] = Cookie(None),
+ db: Session = Depends(get_db),
) -> Optional[User]:
"""
Get current vendor user from vendor_token cookie or Authorization header.
@@ -773,10 +766,7 @@ def get_current_vendor_optional(
None: If no token, invalid token, or user is not vendor
"""
token, source = _get_token_from_request(
- credentials,
- vendor_token,
- "vendor_token",
- str(request.url.path)
+ credentials, vendor_token, "vendor_token", str(request.url.path)
)
if not token:
@@ -797,10 +787,10 @@ def get_current_vendor_optional(
def get_current_customer_optional(
- request: Request,
- credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
- customer_token: Optional[str] = Cookie(None),
- db: Session = Depends(get_db),
+ request: Request,
+ credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
+ customer_token: Optional[str] = Cookie(None),
+ db: Session = Depends(get_db),
) -> Optional[User]:
"""
Get current customer user from customer_token cookie or Authorization header.
@@ -823,10 +813,7 @@ def get_current_customer_optional(
None: If no token, invalid token, or user is not customer
"""
token, source = _get_token_from_request(
- credentials,
- customer_token,
- "customer_token",
- str(request.url.path)
+ credentials, customer_token, "customer_token", str(request.url.path)
)
if not token:
@@ -844,5 +831,3 @@ def get_current_customer_optional(
pass
return None
-
-
diff --git a/app/api/main.py b/app/api/main.py
index a6997bd1..1f1c43be 100644
--- a/app/api/main.py
+++ b/app/api/main.py
@@ -9,7 +9,8 @@ This module provides:
"""
from fastapi import APIRouter
-from app.api.v1 import admin, vendor, shop
+
+from app.api.v1 import admin, shop, vendor
api_router = APIRouter()
@@ -18,31 +19,18 @@ api_router = APIRouter()
# Prefix: /api/v1/admin
# ============================================================================
-api_router.include_router(
- admin.router,
- prefix="/v1/admin",
- tags=["admin"]
-)
+api_router.include_router(admin.router, prefix="/v1/admin", tags=["admin"])
# ============================================================================
# VENDOR ROUTES (Vendor-scoped operations)
# Prefix: /api/v1/vendor
# ============================================================================
-api_router.include_router(
- vendor.router,
- prefix="/v1/vendor",
- tags=["vendor"]
-)
+api_router.include_router(vendor.router, prefix="/v1/vendor", tags=["vendor"])
# ============================================================================
# SHOP ROUTES (Public shop frontend API)
# Prefix: /api/v1/shop
# ============================================================================
-api_router.include_router(
- shop.router,
- prefix="/v1/shop",
- tags=["shop"]
-)
-
+api_router.include_router(shop.router, prefix="/v1/shop", tags=["shop"])
diff --git a/app/api/v1/__init__.py b/app/api/v1/__init__.py
index 8890d586..223a83c5 100644
--- a/app/api/v1/__init__.py
+++ b/app/api/v1/__init__.py
@@ -3,6 +3,6 @@
API Version 1 - All endpoints
"""
-from . import admin, vendor, shop
+from . import admin, shop, vendor
-__all__ = ["admin", "vendor", "shop"]
\ No newline at end of file
+__all__ = ["admin", "vendor", "shop"]
diff --git a/app/api/v1/admin/__init__.py b/app/api/v1/admin/__init__.py
index f2e5ca10..a7b9d009 100644
--- a/app/api/v1/admin/__init__.py
+++ b/app/api/v1/admin/__init__.py
@@ -24,21 +24,9 @@ IMPORTANT:
from fastapi import APIRouter
# Import all admin routers
-from . import (
- auth,
- vendors,
- vendor_domains,
- vendor_themes,
- users,
- dashboard,
- marketplace,
- monitoring,
- audit,
- settings,
- notifications,
- content_pages,
- code_quality
-)
+from . import (audit, auth, code_quality, content_pages, dashboard,
+ marketplace, monitoring, notifications, settings, users,
+ vendor_domains, vendor_themes, vendors)
# Create admin router
router = APIRouter()
@@ -66,7 +54,9 @@ router.include_router(vendor_domains.router, tags=["admin-vendor-domains"])
router.include_router(vendor_themes.router, tags=["admin-vendor-themes"])
# Include content pages management endpoints
-router.include_router(content_pages.router, prefix="/content-pages", tags=["admin-content-pages"])
+router.include_router(
+ content_pages.router, prefix="/content-pages", tags=["admin-content-pages"]
+)
# ============================================================================
@@ -115,7 +105,9 @@ router.include_router(notifications.router, tags=["admin-notifications"])
# ============================================================================
# Include code quality and architecture validation endpoints
-router.include_router(code_quality.router, prefix="/code-quality", tags=["admin-code-quality"])
+router.include_router(
+ code_quality.router, prefix="/code-quality", tags=["admin-code-quality"]
+)
# Export the router
__all__ = ["router"]
diff --git a/app/api/v1/admin/audit.py b/app/api/v1/admin/audit.py
index a150cbf2..d3d91511 100644
--- a/app/api/v1/admin/audit.py
+++ b/app/api/v1/admin/audit.py
@@ -9,8 +9,8 @@ Provides endpoints for:
"""
import logging
-from typing import Optional
from datetime import datetime
+from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
@@ -18,12 +18,10 @@ from sqlalchemy.orm import Session
from app.api.deps import get_current_admin_api
from app.core.database import get_db
from app.services.admin_audit_service import admin_audit_service
-from models.schema.admin import (
- AdminAuditLogResponse,
- AdminAuditLogFilters,
- AdminAuditLogListResponse
-)
from models.database.user import User
+from models.schema.admin import (AdminAuditLogFilters,
+ AdminAuditLogListResponse,
+ AdminAuditLogResponse)
router = APIRouter(prefix="/audit")
logger = logging.getLogger(__name__)
@@ -31,15 +29,15 @@ logger = logging.getLogger(__name__)
@router.get("/logs", response_model=AdminAuditLogListResponse)
def get_audit_logs(
- admin_user_id: Optional[int] = Query(None, description="Filter by admin user"),
- action: Optional[str] = Query(None, description="Filter by action type"),
- target_type: Optional[str] = Query(None, description="Filter by target type"),
- date_from: Optional[datetime] = Query(None, description="Filter from date"),
- date_to: Optional[datetime] = Query(None, description="Filter to date"),
- skip: int = Query(0, ge=0, description="Number of records to skip"),
- limit: int = Query(100, ge=1, le=1000, description="Maximum records to return"),
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api),
+ admin_user_id: Optional[int] = Query(None, description="Filter by admin user"),
+ action: Optional[str] = Query(None, description="Filter by action type"),
+ target_type: Optional[str] = Query(None, description="Filter by target type"),
+ date_from: Optional[datetime] = Query(None, description="Filter from date"),
+ date_to: Optional[datetime] = Query(None, description="Filter to date"),
+ skip: int = Query(0, ge=0, description="Number of records to skip"),
+ limit: int = Query(100, ge=1, le=1000, description="Maximum records to return"),
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""
Get filtered admin audit logs.
@@ -54,7 +52,7 @@ def get_audit_logs(
date_from=date_from,
date_to=date_to,
skip=skip,
- limit=limit
+ limit=limit,
)
logs = admin_audit_service.get_audit_logs(db, filters)
@@ -62,19 +60,14 @@ def get_audit_logs(
logger.info(f"Admin {current_admin.username} retrieved {len(logs)} audit logs")
- return AdminAuditLogListResponse(
- logs=logs,
- total=total,
- skip=skip,
- limit=limit
- )
+ return AdminAuditLogListResponse(logs=logs, total=total, skip=skip, limit=limit)
@router.get("/logs/recent", response_model=list[AdminAuditLogResponse])
def get_recent_audit_logs(
- limit: int = Query(20, ge=1, le=100),
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api),
+ limit: int = Query(20, ge=1, le=100),
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""Get recent audit logs (last 20 by default)."""
filters = AdminAuditLogFilters(limit=limit)
@@ -83,25 +76,23 @@ def get_recent_audit_logs(
@router.get("/logs/my-actions", response_model=list[AdminAuditLogResponse])
def get_my_actions(
- limit: int = Query(50, ge=1, le=100),
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api),
+ limit: int = Query(50, ge=1, le=100),
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""Get audit logs for current admin's actions."""
return admin_audit_service.get_recent_actions_by_admin(
- db=db,
- admin_user_id=current_admin.id,
- limit=limit
+ db=db, admin_user_id=current_admin.id, limit=limit
)
@router.get("/logs/target/{target_type}/{target_id}")
def get_actions_by_target(
- target_type: str,
- target_id: str,
- limit: int = Query(50, ge=1, le=100),
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api),
+ target_type: str,
+ target_id: str,
+ limit: int = Query(50, ge=1, le=100),
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""
Get all actions performed on a specific target.
@@ -109,8 +100,5 @@ def get_actions_by_target(
Useful for tracking the history of a specific vendor, user, or entity.
"""
return admin_audit_service.get_actions_by_target(
- db=db,
- target_type=target_type,
- target_id=target_id,
- limit=limit
+ db=db, target_type=target_type, target_id=target_id, limit=limit
)
diff --git a/app/api/v1/admin/auth.py b/app/api/v1/admin/auth.py
index 5de07e00..a739e5cd 100644
--- a/app/api/v1/admin/auth.py
+++ b/app/api/v1/admin/auth.py
@@ -10,16 +10,17 @@ This prevents admin cookies from being sent to vendor routes.
"""
import logging
+
from fastapi import APIRouter, Depends, Response
from sqlalchemy.orm import Session
+from app.api.deps import get_current_admin_api
from app.core.database import get_db
from app.core.environment import should_use_secure_cookies
-from app.services.auth_service import auth_service
from app.exceptions import InvalidCredentialsException
-from models.schema.auth import LoginResponse, UserLogin, UserResponse
+from app.services.auth_service import auth_service
from models.database.user import User
-from app.api.deps import get_current_admin_api
+from models.schema.auth import LoginResponse, UserLogin, UserResponse
router = APIRouter(prefix="/auth")
logger = logging.getLogger(__name__)
@@ -27,9 +28,7 @@ logger = logging.getLogger(__name__)
@router.post("/login", response_model=LoginResponse)
def admin_login(
- user_credentials: UserLogin,
- response: Response,
- db: Session = Depends(get_db)
+ user_credentials: UserLogin, response: Response, db: Session = Depends(get_db)
):
"""
Admin login endpoint.
@@ -49,7 +48,9 @@ def admin_login(
# Verify user is admin
if login_result["user"].role != "admin":
- logger.warning(f"Non-admin user attempted admin login: {user_credentials.email_or_username}")
+ logger.warning(
+ f"Non-admin user attempted admin login: {user_credentials.email_or_username}"
+ )
raise InvalidCredentialsException("Admin access required")
logger.info(f"Admin login successful: {login_result['user'].username}")
diff --git a/app/api/v1/admin/code_quality.py b/app/api/v1/admin/code_quality.py
index de3e5674..8193638b 100644
--- a/app/api/v1/admin/code_quality.py
+++ b/app/api/v1/admin/code_quality.py
@@ -3,25 +3,27 @@ Code Quality API Endpoints
RESTful API for architecture validation and violation management
"""
-from typing import Optional
from datetime import datetime
-from fastapi import APIRouter, Depends, HTTPException, Query
-from sqlalchemy.orm import Session
-from pydantic import BaseModel, Field
+from typing import Optional
+from fastapi import APIRouter, Depends, HTTPException, Query
+from pydantic import BaseModel, Field
+from sqlalchemy.orm import Session
+
+from app.api.deps import get_current_admin_api
from app.core.database import get_db
from app.services.code_quality_service import code_quality_service
-from app.api.deps import get_current_admin_api
from models.database.user import User
-
router = APIRouter()
# Pydantic Models for API
+
class ScanResponse(BaseModel):
"""Response model for a scan"""
+
id: int
timestamp: str
total_files: int
@@ -38,6 +40,7 @@ class ScanResponse(BaseModel):
class ViolationResponse(BaseModel):
"""Response model for a violation"""
+
id: int
scan_id: int
rule_id: str
@@ -61,6 +64,7 @@ class ViolationResponse(BaseModel):
class ViolationListResponse(BaseModel):
"""Response model for paginated violations list"""
+
violations: list[ViolationResponse]
total: int
page: int
@@ -70,34 +74,42 @@ class ViolationListResponse(BaseModel):
class ViolationDetailResponse(ViolationResponse):
"""Response model for single violation with relationships"""
+
assignments: list = []
comments: list = []
class AssignViolationRequest(BaseModel):
"""Request model for assigning a violation"""
+
user_id: int = Field(..., description="User ID to assign to")
due_date: Optional[datetime] = Field(None, description="Due date for resolution")
- priority: str = Field("medium", description="Priority level (low, medium, high, critical)")
+ priority: str = Field(
+ "medium", description="Priority level (low, medium, high, critical)"
+ )
class ResolveViolationRequest(BaseModel):
"""Request model for resolving a violation"""
+
resolution_note: str = Field(..., description="Note about the resolution")
class IgnoreViolationRequest(BaseModel):
"""Request model for ignoring a violation"""
+
reason: str = Field(..., description="Reason for ignoring")
class AddCommentRequest(BaseModel):
"""Request model for adding a comment"""
+
comment: str = Field(..., min_length=1, description="Comment text")
class DashboardStatsResponse(BaseModel):
"""Response model for dashboard statistics"""
+
total_violations: int
errors: int
warnings: int
@@ -116,10 +128,10 @@ class DashboardStatsResponse(BaseModel):
# API Endpoints
+
@router.post("/scan", response_model=ScanResponse)
async def trigger_scan(
- db: Session = Depends(get_db),
- current_user: User = Depends(get_current_admin_api)
+ db: Session = Depends(get_db), current_user: User = Depends(get_current_admin_api)
):
"""
Trigger a new architecture scan
@@ -127,7 +139,9 @@ async def trigger_scan(
Requires authentication. Runs the validator script and stores results.
"""
try:
- scan = code_quality_service.run_scan(db, triggered_by=f"manual:{current_user.username}")
+ scan = code_quality_service.run_scan(
+ db, triggered_by=f"manual:{current_user.username}"
+ )
return ScanResponse(
id=scan.id,
@@ -138,7 +152,7 @@ async def trigger_scan(
warnings=scan.warnings,
duration_seconds=scan.duration_seconds,
triggered_by=scan.triggered_by,
- git_commit_hash=scan.git_commit_hash
+ git_commit_hash=scan.git_commit_hash,
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Scan failed: {str(e)}")
@@ -148,7 +162,7 @@ async def trigger_scan(
async def list_scans(
limit: int = Query(30, ge=1, le=100, description="Number of scans to return"),
db: Session = Depends(get_db),
- current_user: User = Depends(get_current_admin_api)
+ current_user: User = Depends(get_current_admin_api),
):
"""
Get scan history
@@ -167,7 +181,7 @@ async def list_scans(
warnings=scan.warnings,
duration_seconds=scan.duration_seconds,
triggered_by=scan.triggered_by,
- git_commit_hash=scan.git_commit_hash
+ git_commit_hash=scan.git_commit_hash,
)
for scan in scans
]
@@ -175,15 +189,23 @@ async def list_scans(
@router.get("/violations", response_model=ViolationListResponse)
async def list_violations(
- scan_id: Optional[int] = Query(None, description="Filter by scan ID (defaults to latest)"),
- severity: Optional[str] = Query(None, description="Filter by severity (error, warning)"),
- status: Optional[str] = Query(None, description="Filter by status (open, assigned, resolved, ignored)"),
+ scan_id: Optional[int] = Query(
+ None, description="Filter by scan ID (defaults to latest)"
+ ),
+ severity: Optional[str] = Query(
+ None, description="Filter by severity (error, warning)"
+ ),
+ status: Optional[str] = Query(
+ None, description="Filter by status (open, assigned, resolved, ignored)"
+ ),
rule_id: Optional[str] = Query(None, description="Filter by rule ID"),
- file_path: Optional[str] = Query(None, description="Filter by file path (partial match)"),
+ file_path: Optional[str] = Query(
+ None, description="Filter by file path (partial match)"
+ ),
page: int = Query(1, ge=1, description="Page number"),
page_size: int = Query(50, ge=1, le=200, description="Items per page"),
db: Session = Depends(get_db),
- current_user: User = Depends(get_current_admin_api)
+ current_user: User = Depends(get_current_admin_api),
):
"""
Get violations with filtering and pagination
@@ -200,7 +222,7 @@ async def list_violations(
rule_id=rule_id,
file_path=file_path,
limit=page_size,
- offset=offset
+ offset=offset,
)
total_pages = (total + page_size - 1) // page_size
@@ -223,14 +245,14 @@ async def list_violations(
resolved_at=v.resolved_at.isoformat() if v.resolved_at else None,
resolved_by=v.resolved_by,
resolution_note=v.resolution_note,
- created_at=v.created_at.isoformat()
+ created_at=v.created_at.isoformat(),
)
for v in violations
],
total=total,
page=page,
page_size=page_size,
- total_pages=total_pages
+ total_pages=total_pages,
)
@@ -238,7 +260,7 @@ async def list_violations(
async def get_violation(
violation_id: int,
db: Session = Depends(get_db),
- current_user: User = Depends(get_current_admin_api)
+ current_user: User = Depends(get_current_admin_api),
):
"""
Get single violation with details
@@ -253,12 +275,12 @@ async def get_violation(
# Format assignments
assignments = [
{
- 'id': a.id,
- 'user_id': a.user_id,
- 'assigned_at': a.assigned_at.isoformat(),
- 'assigned_by': a.assigned_by,
- 'due_date': a.due_date.isoformat() if a.due_date else None,
- 'priority': a.priority
+ "id": a.id,
+ "user_id": a.user_id,
+ "assigned_at": a.assigned_at.isoformat(),
+ "assigned_by": a.assigned_by,
+ "due_date": a.due_date.isoformat() if a.due_date else None,
+ "priority": a.priority,
}
for a in violation.assignments
]
@@ -266,10 +288,10 @@ async def get_violation(
# Format comments
comments = [
{
- 'id': c.id,
- 'user_id': c.user_id,
- 'comment': c.comment,
- 'created_at': c.created_at.isoformat()
+ "id": c.id,
+ "user_id": c.user_id,
+ "comment": c.comment,
+ "created_at": c.created_at.isoformat(),
}
for c in violation.comments
]
@@ -287,12 +309,14 @@ async def get_violation(
suggestion=violation.suggestion,
status=violation.status,
assigned_to=violation.assigned_to,
- resolved_at=violation.resolved_at.isoformat() if violation.resolved_at else None,
+ resolved_at=(
+ violation.resolved_at.isoformat() if violation.resolved_at else None
+ ),
resolved_by=violation.resolved_by,
resolution_note=violation.resolution_note,
created_at=violation.created_at.isoformat(),
assignments=assignments,
- comments=comments
+ comments=comments,
)
@@ -301,7 +325,7 @@ async def assign_violation(
violation_id: int,
request: AssignViolationRequest,
db: Session = Depends(get_db),
- current_user: User = Depends(get_current_admin_api)
+ current_user: User = Depends(get_current_admin_api),
):
"""
Assign violation to a developer
@@ -315,17 +339,19 @@ async def assign_violation(
user_id=request.user_id,
assigned_by=current_user.id,
due_date=request.due_date,
- priority=request.priority
+ priority=request.priority,
)
return {
- 'id': assignment.id,
- 'violation_id': assignment.violation_id,
- 'user_id': assignment.user_id,
- 'assigned_at': assignment.assigned_at.isoformat(),
- 'assigned_by': assignment.assigned_by,
- 'due_date': assignment.due_date.isoformat() if assignment.due_date else None,
- 'priority': assignment.priority
+ "id": assignment.id,
+ "violation_id": assignment.violation_id,
+ "user_id": assignment.user_id,
+ "assigned_at": assignment.assigned_at.isoformat(),
+ "assigned_by": assignment.assigned_by,
+ "due_date": (
+ assignment.due_date.isoformat() if assignment.due_date else None
+ ),
+ "priority": assignment.priority,
}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -336,7 +362,7 @@ async def resolve_violation(
violation_id: int,
request: ResolveViolationRequest,
db: Session = Depends(get_db),
- current_user: User = Depends(get_current_admin_api)
+ current_user: User = Depends(get_current_admin_api),
):
"""
Mark violation as resolved
@@ -348,15 +374,17 @@ async def resolve_violation(
db,
violation_id=violation_id,
resolved_by=current_user.id,
- resolution_note=request.resolution_note
+ resolution_note=request.resolution_note,
)
return {
- 'id': violation.id,
- 'status': violation.status,
- 'resolved_at': violation.resolved_at.isoformat() if violation.resolved_at else None,
- 'resolved_by': violation.resolved_by,
- 'resolution_note': violation.resolution_note
+ "id": violation.id,
+ "status": violation.status,
+ "resolved_at": (
+ violation.resolved_at.isoformat() if violation.resolved_at else None
+ ),
+ "resolved_by": violation.resolved_by,
+ "resolution_note": violation.resolution_note,
}
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@@ -369,7 +397,7 @@ async def ignore_violation(
violation_id: int,
request: IgnoreViolationRequest,
db: Session = Depends(get_db),
- current_user: User = Depends(get_current_admin_api)
+ current_user: User = Depends(get_current_admin_api),
):
"""
Mark violation as ignored (won't fix)
@@ -381,15 +409,17 @@ async def ignore_violation(
db,
violation_id=violation_id,
ignored_by=current_user.id,
- reason=request.reason
+ reason=request.reason,
)
return {
- 'id': violation.id,
- 'status': violation.status,
- 'resolved_at': violation.resolved_at.isoformat() if violation.resolved_at else None,
- 'resolved_by': violation.resolved_by,
- 'resolution_note': violation.resolution_note
+ "id": violation.id,
+ "status": violation.status,
+ "resolved_at": (
+ violation.resolved_at.isoformat() if violation.resolved_at else None
+ ),
+ "resolved_by": violation.resolved_by,
+ "resolution_note": violation.resolution_note,
}
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@@ -402,7 +432,7 @@ async def add_comment(
violation_id: int,
request: AddCommentRequest,
db: Session = Depends(get_db),
- current_user: User = Depends(get_current_admin_api)
+ current_user: User = Depends(get_current_admin_api),
):
"""
Add comment to violation
@@ -414,15 +444,15 @@ async def add_comment(
db,
violation_id=violation_id,
user_id=current_user.id,
- comment=request.comment
+ comment=request.comment,
)
return {
- 'id': comment.id,
- 'violation_id': comment.violation_id,
- 'user_id': comment.user_id,
- 'comment': comment.comment,
- 'created_at': comment.created_at.isoformat()
+ "id": comment.id,
+ "violation_id": comment.violation_id,
+ "user_id": comment.user_id,
+ "comment": comment.comment,
+ "created_at": comment.created_at.isoformat(),
}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -430,8 +460,7 @@ async def add_comment(
@router.get("/stats", response_model=DashboardStatsResponse)
async def get_dashboard_stats(
- db: Session = Depends(get_db),
- current_user: User = Depends(get_current_admin_api)
+ db: Session = Depends(get_db), current_user: User = Depends(get_current_admin_api)
):
"""
Get dashboard statistics
diff --git a/app/api/v1/admin/content_pages.py b/app/api/v1/admin/content_pages.py
index 28d6c1b8..aceb521c 100644
--- a/app/api/v1/admin/content_pages.py
+++ b/app/api/v1/admin/content_pages.py
@@ -10,6 +10,7 @@ Platform administrators can:
import logging
from typing import List, Optional
+
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
@@ -26,24 +27,43 @@ logger = logging.getLogger(__name__)
# REQUEST/RESPONSE SCHEMAS
# ============================================================================
+
class ContentPageCreate(BaseModel):
"""Schema for creating a content page."""
- slug: str = Field(..., max_length=100, description="URL-safe identifier (about, faq, contact, etc.)")
+
+ slug: str = Field(
+ ...,
+ max_length=100,
+ description="URL-safe identifier (about, faq, contact, etc.)",
+ )
title: str = Field(..., max_length=200, description="Page title")
content: str = Field(..., description="HTML or Markdown content")
- content_format: str = Field(default="html", description="Content format: html or markdown")
- template: str = Field(default="default", max_length=50, description="Template name (default, minimal, modern)")
- meta_description: Optional[str] = Field(None, max_length=300, description="SEO meta description")
- meta_keywords: Optional[str] = Field(None, max_length=300, description="SEO keywords")
+ content_format: str = Field(
+ default="html", description="Content format: html or markdown"
+ )
+ template: str = Field(
+ default="default",
+ max_length=50,
+ description="Template name (default, minimal, modern)",
+ )
+ meta_description: Optional[str] = Field(
+ None, max_length=300, description="SEO meta description"
+ )
+ meta_keywords: Optional[str] = Field(
+ None, max_length=300, description="SEO keywords"
+ )
is_published: bool = Field(default=False, description="Publish immediately")
show_in_footer: bool = Field(default=True, description="Show in footer navigation")
show_in_header: bool = Field(default=False, description="Show in header navigation")
display_order: int = Field(default=0, description="Display order (lower = first)")
- vendor_id: Optional[int] = Field(None, description="Vendor ID (None for platform default)")
+ vendor_id: Optional[int] = Field(
+ None, description="Vendor ID (None for platform default)"
+ )
class ContentPageUpdate(BaseModel):
"""Schema for updating a content page."""
+
title: Optional[str] = Field(None, max_length=200)
content: Optional[str] = None
content_format: Optional[str] = None
@@ -58,6 +78,7 @@ class ContentPageUpdate(BaseModel):
class ContentPageResponse(BaseModel):
"""Schema for content page response."""
+
id: int
vendor_id: Optional[int]
vendor_name: Optional[str]
@@ -84,11 +105,12 @@ class ContentPageResponse(BaseModel):
# PLATFORM DEFAULT PAGES (vendor_id=NULL)
# ============================================================================
+
@router.get("/platform", response_model=List[ContentPageResponse])
def list_platform_pages(
include_unpublished: bool = Query(False, description="Include draft pages"),
current_user: User = Depends(get_current_admin_api),
- db: Session = Depends(get_db)
+ db: Session = Depends(get_db),
):
"""
List all platform default content pages.
@@ -96,8 +118,7 @@ def list_platform_pages(
These are used as fallbacks when vendors haven't created custom pages.
"""
pages = content_page_service.list_all_platform_pages(
- db,
- include_unpublished=include_unpublished
+ db, include_unpublished=include_unpublished
)
return [page.to_dict() for page in pages]
@@ -107,7 +128,7 @@ def list_platform_pages(
def create_platform_page(
page_data: ContentPageCreate,
current_user: User = Depends(get_current_admin_api),
- db: Session = Depends(get_db)
+ db: Session = Depends(get_db),
):
"""
Create a new platform default content page.
@@ -129,7 +150,7 @@ def create_platform_page(
show_in_footer=page_data.show_in_footer,
show_in_header=page_data.show_in_header,
display_order=page_data.display_order,
- created_by=current_user.id
+ created_by=current_user.id,
)
return page.to_dict()
@@ -139,12 +160,13 @@ def create_platform_page(
# ALL CONTENT PAGES (Platform + Vendors)
# ============================================================================
+
@router.get("/", response_model=List[ContentPageResponse])
def list_all_pages(
vendor_id: Optional[int] = Query(None, description="Filter by vendor ID"),
include_unpublished: bool = Query(False, description="Include draft pages"),
current_user: User = Depends(get_current_admin_api),
- db: Session = Depends(get_db)
+ db: Session = Depends(get_db),
):
"""
List all content pages (platform defaults and vendor overrides).
@@ -153,15 +175,14 @@ def list_all_pages(
"""
if vendor_id:
pages = content_page_service.list_all_vendor_pages(
- db,
- vendor_id=vendor_id,
- include_unpublished=include_unpublished
+ db, vendor_id=vendor_id, include_unpublished=include_unpublished
)
else:
# Get all pages (both platform and vendor)
- from models.database.content_page import ContentPage
from sqlalchemy import and_
+ from models.database.content_page import ContentPage
+
filters = []
if not include_unpublished:
filters.append(ContentPage.is_published == True)
@@ -169,7 +190,9 @@ def list_all_pages(
pages = (
db.query(ContentPage)
.filter(and_(*filters) if filters else True)
- .order_by(ContentPage.vendor_id, ContentPage.display_order, ContentPage.title)
+ .order_by(
+ ContentPage.vendor_id, ContentPage.display_order, ContentPage.title
+ )
.all()
)
@@ -180,7 +203,7 @@ def list_all_pages(
def get_page(
page_id: int,
current_user: User = Depends(get_current_admin_api),
- db: Session = Depends(get_db)
+ db: Session = Depends(get_db),
):
"""Get a specific content page by ID."""
page = content_page_service.get_page_by_id(db, page_id)
@@ -196,7 +219,7 @@ def update_page(
page_id: int,
page_data: ContentPageUpdate,
current_user: User = Depends(get_current_admin_api),
- db: Session = Depends(get_db)
+ db: Session = Depends(get_db),
):
"""Update a content page (platform or vendor)."""
page = content_page_service.update_page(
@@ -212,7 +235,7 @@ def update_page(
show_in_footer=page_data.show_in_footer,
show_in_header=page_data.show_in_header,
display_order=page_data.display_order,
- updated_by=current_user.id
+ updated_by=current_user.id,
)
if not page:
@@ -225,7 +248,7 @@ def update_page(
def delete_page(
page_id: int,
current_user: User = Depends(get_current_admin_api),
- db: Session = Depends(get_db)
+ db: Session = Depends(get_db),
):
"""Delete a content page."""
success = content_page_service.delete_page(db, page_id)
diff --git a/app/api/v1/admin/dashboard.py b/app/api/v1/admin/dashboard.py
index 38db202b..1ca77a50 100644
--- a/app/api/v1/admin/dashboard.py
+++ b/app/api/v1/admin/dashboard.py
@@ -5,6 +5,7 @@ Admin dashboard and statistics endpoints.
import logging
from typing import List
+
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
@@ -87,4 +88,4 @@ def get_platform_statistics(
"products": stats_service.get_product_statistics(db),
"orders": stats_service.get_order_statistics(db),
"imports": stats_service.get_import_statistics(db),
- }
\ No newline at end of file
+ }
diff --git a/app/api/v1/admin/marketplace.py b/app/api/v1/admin/marketplace.py
index 0fa5843b..b9456a53 100644
--- a/app/api/v1/admin/marketplace.py
+++ b/app/api/v1/admin/marketplace.py
@@ -13,8 +13,8 @@ from app.api.deps import get_current_admin_api
from app.core.database import get_db
from app.services.admin_service import admin_service
from app.services.stats_service import stats_service
-from models.schema.marketplace_import_job import MarketplaceImportJobResponse
from models.database.user import User
+from models.schema.marketplace_import_job import MarketplaceImportJobResponse
router = APIRouter(prefix="/marketplace-import-jobs")
logger = logging.getLogger(__name__)
diff --git a/app/api/v1/admin/monitoring.py b/app/api/v1/admin/monitoring.py
index 35f4f24b..0b5699d2 100644
--- a/app/api/v1/admin/monitoring.py
+++ b/app/api/v1/admin/monitoring.py
@@ -1 +1 @@
-# Platform monitoring and alerts
+# Platform monitoring and alerts
diff --git a/app/api/v1/admin/notifications.py b/app/api/v1/admin/notifications.py
index 974c9b34..780885b7 100644
--- a/app/api/v1/admin/notifications.py
+++ b/app/api/v1/admin/notifications.py
@@ -16,16 +16,13 @@ from sqlalchemy.orm import Session
from app.api.deps import get_current_admin_api
from app.core.database import get_db
-from models.schema.admin import (
- AdminNotificationCreate,
- AdminNotificationResponse,
- AdminNotificationListResponse,
- PlatformAlertCreate,
- PlatformAlertResponse,
- PlatformAlertListResponse,
- PlatformAlertResolve
-)
from models.database.user import User
+from models.schema.admin import (AdminNotificationCreate,
+ AdminNotificationListResponse,
+ AdminNotificationResponse,
+ PlatformAlertCreate,
+ PlatformAlertListResponse,
+ PlatformAlertResolve, PlatformAlertResponse)
router = APIRouter(prefix="/notifications")
logger = logging.getLogger(__name__)
@@ -35,6 +32,7 @@ logger = logging.getLogger(__name__)
# ADMIN NOTIFICATIONS
# ============================================================================
+
@router.get("", response_model=AdminNotificationListResponse)
def get_notifications(
priority: Optional[str] = Query(None, description="Filter by priority"),
@@ -47,11 +45,7 @@ def get_notifications(
"""Get admin notifications with filtering."""
# TODO: Implement notification service
return AdminNotificationListResponse(
- notifications=[],
- total=0,
- unread_count=0,
- skip=skip,
- limit=limit
+ notifications=[], total=0, unread_count=0, skip=skip, limit=limit
)
@@ -90,10 +84,13 @@ def mark_all_as_read(
# PLATFORM ALERTS
# ============================================================================
+
@router.get("/alerts", response_model=PlatformAlertListResponse)
def get_platform_alerts(
severity: Optional[str] = Query(None, description="Filter by severity"),
- is_resolved: Optional[bool] = Query(None, description="Filter by resolution status"),
+ is_resolved: Optional[bool] = Query(
+ None, description="Filter by resolution status"
+ ),
skip: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=100),
db: Session = Depends(get_db),
@@ -102,12 +99,7 @@ def get_platform_alerts(
"""Get platform alerts with filtering."""
# TODO: Implement alert service
return PlatformAlertListResponse(
- alerts=[],
- total=0,
- active_count=0,
- critical_count=0,
- skip=skip,
- limit=limit
+ alerts=[], total=0, active_count=0, critical_count=0, skip=skip, limit=limit
)
@@ -147,5 +139,5 @@ def get_alert_statistics(
"total_alerts": 0,
"active_alerts": 0,
"critical_alerts": 0,
- "resolved_today": 0
+ "resolved_today": 0,
}
diff --git a/app/api/v1/admin/settings.py b/app/api/v1/admin/settings.py
index 0fd2297f..025b2a3c 100644
--- a/app/api/v1/admin/settings.py
+++ b/app/api/v1/admin/settings.py
@@ -16,15 +16,11 @@ from sqlalchemy.orm import Session
from app.api.deps import get_current_admin_api
from app.core.database import get_db
-from app.services.admin_settings_service import admin_settings_service
from app.services.admin_audit_service import admin_audit_service
-from models.schema.admin import (
- AdminSettingCreate,
- AdminSettingResponse,
- AdminSettingUpdate,
- AdminSettingListResponse
-)
+from app.services.admin_settings_service import admin_settings_service
from models.database.user import User
+from models.schema.admin import (AdminSettingCreate, AdminSettingListResponse,
+ AdminSettingResponse, AdminSettingUpdate)
router = APIRouter(prefix="/settings")
logger = logging.getLogger(__name__)
@@ -32,10 +28,10 @@ logger = logging.getLogger(__name__)
@router.get("", response_model=AdminSettingListResponse)
def get_all_settings(
- category: Optional[str] = Query(None, description="Filter by category"),
- is_public: Optional[bool] = Query(None, description="Filter by public flag"),
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api),
+ category: Optional[str] = Query(None, description="Filter by category"),
+ is_public: Optional[bool] = Query(None, description="Filter by public flag"),
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""
Get all platform settings.
@@ -46,16 +42,14 @@ def get_all_settings(
settings = admin_settings_service.get_all_settings(db, category, is_public)
return AdminSettingListResponse(
- settings=settings,
- total=len(settings),
- category=category
+ settings=settings, total=len(settings), category=category
)
@router.get("/categories")
def get_setting_categories(
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api),
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""Get list of all setting categories."""
# This could be enhanced to return counts per category
@@ -66,22 +60,23 @@ def get_setting_categories(
"marketplace",
"notifications",
"integrations",
- "payments"
+ "payments",
]
}
@router.get("/{key}", response_model=AdminSettingResponse)
def get_setting(
- key: str,
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api),
+ key: str,
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""Get specific setting by key."""
setting = admin_settings_service.get_setting_by_key(db, key)
if not setting:
from fastapi import HTTPException
+
raise HTTPException(status_code=404, detail=f"Setting '{key}' not found")
return AdminSettingResponse.model_validate(setting)
@@ -89,9 +84,9 @@ def get_setting(
@router.post("", response_model=AdminSettingResponse)
def create_setting(
- setting_data: AdminSettingCreate,
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api),
+ setting_data: AdminSettingCreate,
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""
Create new platform setting.
@@ -99,9 +94,7 @@ def create_setting(
Setting keys should be lowercase with underscores (e.g., max_vendors_allowed).
"""
result = admin_settings_service.create_setting(
- db=db,
- setting_data=setting_data,
- admin_user_id=current_admin.id
+ db=db, setting_data=setting_data, admin_user_id=current_admin.id
)
# Log action
@@ -111,7 +104,10 @@ def create_setting(
action="create_setting",
target_type="setting",
target_id=setting_data.key,
- details={"category": setting_data.category, "value_type": setting_data.value_type}
+ details={
+ "category": setting_data.category,
+ "value_type": setting_data.value_type,
+ },
)
return result
@@ -119,19 +115,16 @@ def create_setting(
@router.put("/{key}", response_model=AdminSettingResponse)
def update_setting(
- key: str,
- update_data: AdminSettingUpdate,
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api),
+ key: str,
+ update_data: AdminSettingUpdate,
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""Update existing setting value."""
old_value = admin_settings_service.get_setting_value(db, key)
result = admin_settings_service.update_setting(
- db=db,
- key=key,
- update_data=update_data,
- admin_user_id=current_admin.id
+ db=db, key=key, update_data=update_data, admin_user_id=current_admin.id
)
# Log action
@@ -141,7 +134,7 @@ def update_setting(
action="update_setting",
target_type="setting",
target_id=key,
- details={"old_value": str(old_value), "new_value": update_data.value}
+ details={"old_value": str(old_value), "new_value": update_data.value},
)
return result
@@ -149,9 +142,9 @@ def update_setting(
@router.post("/upsert", response_model=AdminSettingResponse)
def upsert_setting(
- setting_data: AdminSettingCreate,
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api),
+ setting_data: AdminSettingCreate,
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""
Create or update setting (upsert).
@@ -159,9 +152,7 @@ def upsert_setting(
If setting exists, updates its value. If not, creates new setting.
"""
result = admin_settings_service.upsert_setting(
- db=db,
- setting_data=setting_data,
- admin_user_id=current_admin.id
+ db=db, setting_data=setting_data, admin_user_id=current_admin.id
)
# Log action
@@ -171,7 +162,7 @@ def upsert_setting(
action="upsert_setting",
target_type="setting",
target_id=setting_data.key,
- details={"category": setting_data.category}
+ details={"category": setting_data.category},
)
return result
@@ -179,10 +170,10 @@ def upsert_setting(
@router.delete("/{key}")
def delete_setting(
- key: str,
- confirm: bool = Query(False, description="Must be true to confirm deletion"),
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api),
+ key: str,
+ confirm: bool = Query(False, description="Must be true to confirm deletion"),
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""
Delete platform setting.
@@ -195,13 +186,11 @@ def delete_setting(
if not confirm:
raise HTTPException(
status_code=400,
- detail="Deletion requires confirmation parameter: confirm=true"
+ detail="Deletion requires confirmation parameter: confirm=true",
)
message = admin_settings_service.delete_setting(
- db=db,
- key=key,
- admin_user_id=current_admin.id
+ db=db, key=key, admin_user_id=current_admin.id
)
# Log action
@@ -211,7 +200,7 @@ def delete_setting(
action="delete_setting",
target_type="setting",
target_id=key,
- details={}
+ details={},
)
return {"message": message}
diff --git a/app/api/v1/admin/users.py b/app/api/v1/admin/users.py
index 5c4d968c..2cadd850 100644
--- a/app/api/v1/admin/users.py
+++ b/app/api/v1/admin/users.py
@@ -13,8 +13,8 @@ from app.api.deps import get_current_admin_api
from app.core.database import get_db
from app.services.admin_service import admin_service
from app.services.stats_service import stats_service
-from models.schema.auth import UserResponse
from models.database.user import User
+from models.schema.auth import UserResponse
router = APIRouter(prefix="/users")
logger = logging.getLogger(__name__)
diff --git a/app/api/v1/admin/vendor_domains.py b/app/api/v1/admin/vendor_domains.py
index a01604f3..c33e1786 100644
--- a/app/api/v1/admin/vendor_domains.py
+++ b/app/api/v1/admin/vendor_domains.py
@@ -12,24 +12,22 @@ Follows the architecture pattern:
import logging
from typing import List
-from fastapi import APIRouter, Depends, Path, Body, Query
+from fastapi import APIRouter, Body, Depends, Path, Query
from sqlalchemy.orm import Session
from app.api.deps import get_current_admin_api
from app.core.database import get_db
-from app.services.vendor_domain_service import vendor_domain_service
from app.exceptions import VendorNotFoundException
-from models.schema.vendor_domain import (
- VendorDomainCreate,
- VendorDomainUpdate,
- VendorDomainResponse,
- VendorDomainListResponse,
- DomainVerificationInstructions,
- DomainVerificationResponse,
- DomainDeletionResponse,
-)
+from app.services.vendor_domain_service import vendor_domain_service
from models.database.user import User
from models.database.vendor import Vendor
+from models.schema.vendor_domain import (DomainDeletionResponse,
+ DomainVerificationInstructions,
+ DomainVerificationResponse,
+ VendorDomainCreate,
+ VendorDomainListResponse,
+ VendorDomainResponse,
+ VendorDomainUpdate)
router = APIRouter(prefix="/vendors")
logger = logging.getLogger(__name__)
@@ -57,10 +55,10 @@ def _get_vendor_by_id(db: Session, vendor_id: int) -> Vendor:
@router.post("/{vendor_id}/domains", response_model=VendorDomainResponse)
def add_vendor_domain(
- vendor_id: int = Path(..., description="Vendor ID", gt=0),
- domain_data: VendorDomainCreate = Body(...),
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api),
+ vendor_id: int = Path(..., description="Vendor ID", gt=0),
+ domain_data: VendorDomainCreate = Body(...),
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""
Add a custom domain to vendor (Admin only).
@@ -88,9 +86,7 @@ def add_vendor_domain(
- 422: Invalid domain format or reserved subdomain
"""
domain = vendor_domain_service.add_domain(
- db=db,
- vendor_id=vendor_id,
- domain_data=domain_data
+ db=db, vendor_id=vendor_id, domain_data=domain_data
)
return VendorDomainResponse(
@@ -111,9 +107,9 @@ def add_vendor_domain(
@router.get("/{vendor_id}/domains", response_model=VendorDomainListResponse)
def list_vendor_domains(
- vendor_id: int = Path(..., description="Vendor ID", gt=0),
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api),
+ vendor_id: int = Path(..., description="Vendor ID", gt=0),
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""
List all domains for a vendor (Admin only).
@@ -148,15 +144,15 @@ def list_vendor_domains(
)
for d in domains
],
- total=len(domains)
+ total=len(domains),
)
@router.get("/domains/{domain_id}", response_model=VendorDomainResponse)
def get_domain_details(
- domain_id: int = Path(..., description="Domain ID", gt=0),
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api),
+ domain_id: int = Path(..., description="Domain ID", gt=0),
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""
Get detailed information about a specific domain (Admin only).
@@ -174,7 +170,9 @@ def get_domain_details(
is_active=domain.is_active,
is_verified=domain.is_verified,
ssl_status=domain.ssl_status,
- verification_token=domain.verification_token if not domain.is_verified else None,
+ verification_token=(
+ domain.verification_token if not domain.is_verified else None
+ ),
verified_at=domain.verified_at,
ssl_verified_at=domain.ssl_verified_at,
created_at=domain.created_at,
@@ -184,10 +182,10 @@ def get_domain_details(
@router.put("/domains/{domain_id}", response_model=VendorDomainResponse)
def update_vendor_domain(
- domain_id: int = Path(..., description="Domain ID", gt=0),
- domain_update: VendorDomainUpdate = Body(...),
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api),
+ domain_id: int = Path(..., description="Domain ID", gt=0),
+ domain_update: VendorDomainUpdate = Body(...),
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""
Update domain settings (Admin only).
@@ -206,9 +204,7 @@ def update_vendor_domain(
- 400: Cannot activate unverified domain
"""
domain = vendor_domain_service.update_domain(
- db=db,
- domain_id=domain_id,
- domain_update=domain_update
+ db=db, domain_id=domain_id, domain_update=domain_update
)
return VendorDomainResponse(
@@ -229,9 +225,9 @@ def update_vendor_domain(
@router.delete("/domains/{domain_id}", response_model=DomainDeletionResponse)
def delete_vendor_domain(
- domain_id: int = Path(..., description="Domain ID", gt=0),
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api),
+ domain_id: int = Path(..., description="Domain ID", gt=0),
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""
Delete a custom domain (Admin only).
@@ -250,17 +246,15 @@ def delete_vendor_domain(
message = vendor_domain_service.delete_domain(db, domain_id)
return DomainDeletionResponse(
- message=message,
- domain=domain_name,
- vendor_id=vendor_id
+ message=message, domain=domain_name, vendor_id=vendor_id
)
@router.post("/domains/{domain_id}/verify", response_model=DomainVerificationResponse)
def verify_domain_ownership(
- domain_id: int = Path(..., description="Domain ID", gt=0),
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api),
+ domain_id: int = Path(..., description="Domain ID", gt=0),
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""
Verify domain ownership via DNS TXT record (Admin only).
@@ -290,15 +284,18 @@ def verify_domain_ownership(
message=message,
domain=domain.domain,
verified_at=domain.verified_at,
- is_verified=domain.is_verified
+ is_verified=domain.is_verified,
)
-@router.get("/domains/{domain_id}/verification-instructions", response_model=DomainVerificationInstructions)
+@router.get(
+ "/domains/{domain_id}/verification-instructions",
+ response_model=DomainVerificationInstructions,
+)
def get_domain_verification_instructions(
- domain_id: int = Path(..., description="Domain ID", gt=0),
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api),
+ domain_id: int = Path(..., description="Domain ID", gt=0),
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""
Get DNS verification instructions for domain (Admin only).
@@ -324,5 +321,5 @@ def get_domain_verification_instructions(
verification_token=instructions["verification_token"],
instructions=instructions["instructions"],
txt_record=instructions["txt_record"],
- common_registrars=instructions["common_registrars"]
+ common_registrars=instructions["common_registrars"],
)
diff --git a/app/api/v1/admin/vendor_themes.py b/app/api/v1/admin/vendor_themes.py
index 7a4b51fc..3a0fa06e 100644
--- a/app/api/v1/admin/vendor_themes.py
+++ b/app/api/v1/admin/vendor_themes.py
@@ -20,11 +20,8 @@ from sqlalchemy.orm import Session
from app.api.deps import get_current_admin_api, get_db
from app.services.vendor_theme_service import vendor_theme_service
from models.database.user import User
-from models.schema.vendor_theme import (
- VendorThemeResponse,
- VendorThemeUpdate,
- ThemePresetListResponse
-)
+from models.schema.vendor_theme import (ThemePresetListResponse,
+ VendorThemeResponse, VendorThemeUpdate)
router = APIRouter(prefix="/vendor-themes")
logger = logging.getLogger(__name__)
@@ -34,10 +31,9 @@ logger = logging.getLogger(__name__)
# PRESET ENDPOINTS
# ============================================================================
+
@router.get("/presets", response_model=ThemePresetListResponse)
-async def get_theme_presets(
- current_admin: User = Depends(get_current_admin_api)
-):
+async def get_theme_presets(current_admin: User = Depends(get_current_admin_api)):
"""
Get all available theme presets with preview information.
@@ -59,11 +55,12 @@ async def get_theme_presets(
# THEME RETRIEVAL
# ============================================================================
+
@router.get("/{vendor_code}", response_model=VendorThemeResponse)
async def get_vendor_theme(
- vendor_code: str = Path(..., description="Vendor code"),
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api)
+ vendor_code: str = Path(..., description="Vendor code"),
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""
Get theme configuration for a vendor.
@@ -93,12 +90,13 @@ async def get_vendor_theme(
# THEME UPDATE
# ============================================================================
+
@router.put("/{vendor_code}", response_model=VendorThemeResponse)
async def update_vendor_theme(
- vendor_code: str = Path(..., description="Vendor code"),
- theme_data: VendorThemeUpdate = None,
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api)
+ vendor_code: str = Path(..., description="Vendor code"),
+ theme_data: VendorThemeUpdate = None,
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""
Update or create theme for a vendor.
@@ -140,12 +138,13 @@ async def update_vendor_theme(
# PRESET APPLICATION
# ============================================================================
+
@router.post("/{vendor_code}/preset/{preset_name}")
async def apply_theme_preset(
- vendor_code: str = Path(..., description="Vendor code"),
- preset_name: str = Path(..., description="Preset name"),
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api)
+ vendor_code: str = Path(..., description="Vendor code"),
+ preset_name: str = Path(..., description="Preset name"),
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""
Apply a theme preset to a vendor.
@@ -184,7 +183,7 @@ async def apply_theme_preset(
return {
"message": f"Applied {preset_name} preset successfully",
- "theme": theme.to_dict()
+ "theme": theme.to_dict(),
}
@@ -192,11 +191,12 @@ async def apply_theme_preset(
# THEME DELETION
# ============================================================================
+
@router.delete("/{vendor_code}")
async def delete_vendor_theme(
- vendor_code: str = Path(..., description="Vendor code"),
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api)
+ vendor_code: str = Path(..., description="Vendor code"),
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""
Delete custom theme for a vendor.
diff --git a/app/api/v1/admin/vendors.py b/app/api/v1/admin/vendors.py
index e2db6b88..58b1f615 100644
--- a/app/api/v1/admin/vendors.py
+++ b/app/api/v1/admin/vendors.py
@@ -6,28 +6,24 @@ Vendor management endpoints for admin.
import logging
from typing import Optional
-from fastapi import APIRouter, Depends, Query, Path, Body
+from fastapi import APIRouter, Body, Depends, Path, Query
+from sqlalchemy import func
from sqlalchemy.orm import Session
from app.api.deps import get_current_admin_api
from app.core.database import get_db
+from app.exceptions import (ConfirmationRequiredException,
+ VendorNotFoundException)
from app.services.admin_service import admin_service
from app.services.stats_service import stats_service
-from app.exceptions import VendorNotFoundException, ConfirmationRequiredException
-from models.schema.stats import VendorStatsResponse
-from models.schema.vendor import (
- VendorListResponse,
- VendorResponse,
- VendorDetailResponse,
- VendorCreate,
- VendorCreateResponse,
- VendorUpdate,
- VendorTransferOwnership,
- VendorTransferOwnershipResponse,
-)
from models.database.user import User
from models.database.vendor import Vendor
-from sqlalchemy import func
+from models.schema.stats import VendorStatsResponse
+from models.schema.vendor import (VendorCreate, VendorCreateResponse,
+ VendorDetailResponse, VendorListResponse,
+ VendorResponse, VendorTransferOwnership,
+ VendorTransferOwnershipResponse,
+ VendorUpdate)
router = APIRouter(prefix="/vendors")
logger = logging.getLogger(__name__)
@@ -60,9 +56,11 @@ def _get_vendor_by_identifier(db: Session, identifier: str) -> Vendor:
pass
# Try as vendor_code (case-insensitive)
- vendor = db.query(Vendor).filter(
- func.upper(Vendor.vendor_code) == identifier.upper()
- ).first()
+ vendor = (
+ db.query(Vendor)
+ .filter(func.upper(Vendor.vendor_code) == identifier.upper())
+ .first()
+ )
if not vendor:
raise VendorNotFoundException(identifier, identifier_type="code")
@@ -72,9 +70,9 @@ def _get_vendor_by_identifier(db: Session, identifier: str) -> Vendor:
@router.post("", response_model=VendorCreateResponse)
def create_vendor_with_owner(
- vendor_data: VendorCreate,
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api),
+ vendor_data: VendorCreate,
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""
Create a new vendor with owner user account (Admin only).
@@ -93,8 +91,7 @@ def create_vendor_with_owner(
Returns vendor details with owner credentials.
"""
vendor, owner_user, temp_password = admin_service.create_vendor_with_owner(
- db=db,
- vendor_data=vendor_data
+ db=db, vendor_data=vendor_data
)
return VendorCreateResponse(
@@ -121,19 +118,19 @@ def create_vendor_with_owner(
owner_email=owner_user.email,
owner_username=owner_user.username,
temporary_password=temp_password,
- login_url=f"http://localhost:8000/vendor/{vendor.subdomain}/login"
+ login_url=f"http://localhost:8000/vendor/{vendor.subdomain}/login",
)
@router.get("", response_model=VendorListResponse)
def get_all_vendors_admin(
- skip: int = Query(0, ge=0),
- limit: int = Query(100, ge=1, le=1000),
- search: Optional[str] = Query(None, description="Search by name or vendor code"),
- is_active: Optional[bool] = Query(None),
- is_verified: Optional[bool] = Query(None),
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api),
+ skip: int = Query(0, ge=0),
+ limit: int = Query(100, ge=1, le=1000),
+ search: Optional[str] = Query(None, description="Search by name or vendor code"),
+ is_active: Optional[bool] = Query(None),
+ is_verified: Optional[bool] = Query(None),
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""Get all vendors with filtering (Admin only)."""
vendors, total = admin_service.get_all_vendors(
@@ -142,15 +139,15 @@ def get_all_vendors_admin(
limit=limit,
search=search,
is_active=is_active,
- is_verified=is_verified
+ is_verified=is_verified,
)
return VendorListResponse(vendors=vendors, total=total, skip=skip, limit=limit)
@router.get("/stats", response_model=VendorStatsResponse)
def get_vendor_statistics_endpoint(
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api),
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""Get vendor statistics for admin dashboard (Admin only)."""
stats = stats_service.get_vendor_statistics(db)
@@ -165,9 +162,9 @@ def get_vendor_statistics_endpoint(
@router.get("/{vendor_identifier}", response_model=VendorDetailResponse)
def get_vendor_details(
- vendor_identifier: str = Path(..., description="Vendor ID or vendor_code"),
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api),
+ vendor_identifier: str = Path(..., description="Vendor ID or vendor_code"),
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""
Get detailed vendor information including owner details (Admin only).
@@ -208,10 +205,10 @@ def get_vendor_details(
@router.put("/{vendor_identifier}", response_model=VendorDetailResponse)
def update_vendor(
- vendor_identifier: str = Path(..., description="Vendor ID or vendor_code"),
- vendor_update: VendorUpdate = Body(...),
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api),
+ vendor_identifier: str = Path(..., description="Vendor ID or vendor_code"),
+ vendor_update: VendorUpdate = Body(...),
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""
Update vendor information (Admin only).
@@ -257,12 +254,15 @@ def update_vendor(
)
-@router.post("/{vendor_identifier}/transfer-ownership", response_model=VendorTransferOwnershipResponse)
+@router.post(
+ "/{vendor_identifier}/transfer-ownership",
+ response_model=VendorTransferOwnershipResponse,
+)
def transfer_vendor_ownership(
- vendor_identifier: str = Path(..., description="Vendor ID or vendor_code"),
- transfer_data: VendorTransferOwnership = Body(...),
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api),
+ vendor_identifier: str = Path(..., description="Vendor ID or vendor_code"),
+ transfer_data: VendorTransferOwnership = Body(...),
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""
Transfer vendor ownership to another user (Admin only).
@@ -311,10 +311,10 @@ def transfer_vendor_ownership(
@router.put("/{vendor_identifier}/verification", response_model=VendorDetailResponse)
def toggle_vendor_verification(
- vendor_identifier: str = Path(..., description="Vendor ID or vendor_code"),
- verification_data: dict = Body(..., example={"is_verified": True}),
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api),
+ vendor_identifier: str = Path(..., description="Vendor ID or vendor_code"),
+ verification_data: dict = Body(..., example={"is_verified": True}),
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""
Toggle vendor verification status (Admin only).
@@ -362,10 +362,10 @@ def toggle_vendor_verification(
@router.put("/{vendor_identifier}/status", response_model=VendorDetailResponse)
def toggle_vendor_status(
- vendor_identifier: str = Path(..., description="Vendor ID or vendor_code"),
- status_data: dict = Body(..., example={"is_active": True}),
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api),
+ vendor_identifier: str = Path(..., description="Vendor ID or vendor_code"),
+ status_data: dict = Body(..., example={"is_active": True}),
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""
Toggle vendor active status (Admin only).
@@ -413,10 +413,10 @@ def toggle_vendor_status(
@router.delete("/{vendor_identifier}")
def delete_vendor(
- vendor_identifier: str = Path(..., description="Vendor ID or vendor_code"),
- confirm: bool = Query(False, description="Must be true to confirm deletion"),
- db: Session = Depends(get_db),
- current_admin: User = Depends(get_current_admin_api),
+ vendor_identifier: str = Path(..., description="Vendor ID or vendor_code"),
+ confirm: bool = Query(False, description="Must be true to confirm deletion"),
+ db: Session = Depends(get_db),
+ current_admin: User = Depends(get_current_admin_api),
):
"""
Delete vendor and all associated data (Admin only).
@@ -436,7 +436,7 @@ def delete_vendor(
if not confirm:
raise ConfirmationRequiredException(
operation="delete_vendor",
- message="Deletion requires confirmation parameter: confirm=true"
+ message="Deletion requires confirmation parameter: confirm=true",
)
vendor = _get_vendor_by_identifier(db, vendor_identifier)
diff --git a/app/api/v1/shared/health.py b/app/api/v1/shared/health.py
index 069c2bac..0b438ce2 100644
--- a/app/api/v1/shared/health.py
+++ b/app/api/v1/shared/health.py
@@ -1 +1 @@
-# Health checks
+# Health checks
diff --git a/app/api/v1/shared/uploads.py b/app/api/v1/shared/uploads.py
index 9a4e1aa0..e573f17b 100644
--- a/app/api/v1/shared/uploads.py
+++ b/app/api/v1/shared/uploads.py
@@ -1 +1 @@
-# File upload handling
+# File upload handling
diff --git a/app/api/v1/shop/__init__.py b/app/api/v1/shop/__init__.py
index f0e93ed4..23d483bd 100644
--- a/app/api/v1/shop/__init__.py
+++ b/app/api/v1/shop/__init__.py
@@ -21,7 +21,7 @@ Authentication:
from fastapi import APIRouter
# Import shop routers
-from . import products, cart, orders, auth, content_pages
+from . import auth, cart, content_pages, orders, products
# Create shop router
router = APIRouter()
@@ -43,6 +43,8 @@ router.include_router(cart.router, tags=["shop-cart"])
router.include_router(orders.router, tags=["shop-orders"])
# Content pages (public)
-router.include_router(content_pages.router, prefix="/content-pages", tags=["shop-content-pages"])
+router.include_router(
+ content_pages.router, prefix="/content-pages", tags=["shop-content-pages"]
+)
__all__ = ["router"]
diff --git a/app/api/v1/shop/auth.py b/app/api/v1/shop/auth.py
index f9d50b43..0c3365d8 100644
--- a/app/api/v1/shop/auth.py
+++ b/app/api/v1/shop/auth.py
@@ -15,15 +15,16 @@ This prevents:
"""
import logging
-from fastapi import APIRouter, Depends, Response, Request, HTTPException
+
+from fastapi import APIRouter, Depends, HTTPException, Request, Response
+from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.core.database import get_db
+from app.core.environment import should_use_secure_cookies
from app.services.customer_service import customer_service
from models.schema.auth import UserLogin
from models.schema.customer import CustomerRegister, CustomerResponse
-from app.core.environment import should_use_secure_cookies
-from pydantic import BaseModel
router = APIRouter()
logger = logging.getLogger(__name__)
@@ -32,6 +33,7 @@ logger = logging.getLogger(__name__)
# Response model for customer login
class CustomerLoginResponse(BaseModel):
"""Customer login response with token and customer data."""
+
access_token: str
token_type: str
expires_in: int
@@ -40,9 +42,7 @@ class CustomerLoginResponse(BaseModel):
@router.post("/auth/register", response_model=CustomerResponse)
def register_customer(
- request: Request,
- customer_data: CustomerRegister,
- db: Session = Depends(get_db)
+ request: Request, customer_data: CustomerRegister, db: Session = Depends(get_db)
):
"""
Register a new customer for current vendor.
@@ -59,12 +59,12 @@ def register_customer(
- phone: Customer phone number (optional)
"""
# Get vendor from middleware
- vendor = getattr(request.state, 'vendor', None)
+ vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
- detail="Vendor not found. Please access via vendor domain/subdomain/path."
+ detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
logger.debug(
@@ -73,14 +73,12 @@ def register_customer(
"vendor_id": vendor.id,
"vendor_code": vendor.subdomain,
"email": customer_data.email,
- }
+ },
)
# Create customer account
customer = customer_service.register_customer(
- db=db,
- vendor_id=vendor.id,
- customer_data=customer_data
+ db=db, vendor_id=vendor.id, customer_data=customer_data
)
logger.info(
@@ -89,7 +87,7 @@ def register_customer(
"customer_id": customer.id,
"vendor_id": vendor.id,
"email": customer.email,
- }
+ },
)
return CustomerResponse.model_validate(customer)
@@ -100,7 +98,7 @@ def customer_login(
request: Request,
user_credentials: UserLogin,
response: Response,
- db: Session = Depends(get_db)
+ db: Session = Depends(get_db),
):
"""
Customer login for current vendor.
@@ -121,12 +119,12 @@ def customer_login(
- password: Customer password
"""
# Get vendor from middleware
- vendor = getattr(request.state, 'vendor', None)
+ vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
- detail="Vendor not found. Please access via vendor domain/subdomain/path."
+ detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
logger.debug(
@@ -135,33 +133,39 @@ def customer_login(
"vendor_id": vendor.id,
"vendor_code": vendor.subdomain,
"email_or_username": user_credentials.email_or_username,
- }
+ },
)
# Authenticate customer
login_result = customer_service.login_customer(
- db=db,
- vendor_id=vendor.id,
- credentials=user_credentials
+ db=db, vendor_id=vendor.id, credentials=user_credentials
)
logger.info(
f"Customer login successful: {login_result['customer'].email} for vendor {vendor.subdomain}",
extra={
- "customer_id": login_result['customer'].id,
+ "customer_id": login_result["customer"].id,
"vendor_id": vendor.id,
- "email": login_result['customer'].email,
- }
+ "email": login_result["customer"].email,
+ },
)
# Calculate cookie path based on vendor access method
- vendor_context = getattr(request.state, 'vendor_context', None)
- access_method = vendor_context.get('detection_method', 'unknown') if vendor_context else 'unknown'
+ vendor_context = getattr(request.state, "vendor_context", None)
+ access_method = (
+ vendor_context.get("detection_method", "unknown")
+ if vendor_context
+ else "unknown"
+ )
cookie_path = "/shop" # Default for domain/subdomain access
if access_method == "path":
# For path-based access like /vendors/wizamart/shop
- full_prefix = vendor_context.get('full_prefix', '/vendor/') if vendor_context else '/vendor/'
+ full_prefix = (
+ vendor_context.get("full_prefix", "/vendor/")
+ if vendor_context
+ else "/vendor/"
+ )
cookie_path = f"{full_prefix}{vendor.subdomain}/shop"
# Set HTTP-only cookie for browser navigation
@@ -180,10 +184,10 @@ def customer_login(
f"Set customer_token cookie with {login_result['token_data']['expires_in']}s expiry "
f"(path={cookie_path}, httponly=True, secure={should_use_secure_cookies()})",
extra={
- "expires_in": login_result['token_data']['expires_in'],
+ "expires_in": login_result["token_data"]["expires_in"],
"secure": should_use_secure_cookies(),
"cookie_path": cookie_path,
- }
+ },
)
# Return full login response
@@ -196,10 +200,7 @@ def customer_login(
@router.post("/auth/logout")
-def customer_logout(
- request: Request,
- response: Response
-):
+def customer_logout(request: Request, response: Response):
"""
Customer logout for current vendor.
@@ -208,24 +209,32 @@ def customer_logout(
Client should also remove token from localStorage.
"""
# Get vendor from middleware (for logging)
- vendor = getattr(request.state, 'vendor', None)
+ vendor = getattr(request.state, "vendor", None)
logger.info(
f"Customer logout for vendor {vendor.subdomain if vendor else 'unknown'}",
extra={
"vendor_id": vendor.id if vendor else None,
"vendor_code": vendor.subdomain if vendor else None,
- }
+ },
)
# Calculate cookie path based on vendor access method (must match login)
- vendor_context = getattr(request.state, 'vendor_context', None)
- access_method = vendor_context.get('detection_method', 'unknown') if vendor_context else 'unknown'
+ vendor_context = getattr(request.state, "vendor_context", None)
+ access_method = (
+ vendor_context.get("detection_method", "unknown")
+ if vendor_context
+ else "unknown"
+ )
cookie_path = "/shop" # Default for domain/subdomain access
if access_method == "path" and vendor:
# For path-based access like /vendors/wizamart/shop
- full_prefix = vendor_context.get('full_prefix', '/vendor/') if vendor_context else '/vendor/'
+ full_prefix = (
+ vendor_context.get("full_prefix", "/vendor/")
+ if vendor_context
+ else "/vendor/"
+ )
cookie_path = f"{full_prefix}{vendor.subdomain}/shop"
# Clear the cookie (must match path used when setting)
@@ -240,11 +249,7 @@ def customer_logout(
@router.post("/auth/forgot-password")
-def forgot_password(
- request: Request,
- email: str,
- db: Session = Depends(get_db)
-):
+def forgot_password(request: Request, email: str, db: Session = Depends(get_db)):
"""
Request password reset for customer.
@@ -255,12 +260,12 @@ def forgot_password(
- email: Customer email address
"""
# Get vendor from middleware
- vendor = getattr(request.state, 'vendor', None)
+ vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
- detail="Vendor not found. Please access via vendor domain/subdomain/path."
+ detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
logger.debug(
@@ -269,7 +274,7 @@ def forgot_password(
"vendor_id": vendor.id,
"vendor_code": vendor.subdomain,
"email": email,
- }
+ },
)
# TODO: Implement password reset functionality
@@ -278,9 +283,7 @@ def forgot_password(
# - Send reset email to customer
# - Return success message (don't reveal if email exists)
- logger.info(
- f"Password reset requested for {email} (vendor: {vendor.subdomain})"
- )
+ logger.info(f"Password reset requested for {email} (vendor: {vendor.subdomain})")
return {
"message": "If an account exists with this email, a password reset link has been sent."
@@ -289,10 +292,7 @@ def forgot_password(
@router.post("/auth/reset-password")
def reset_password(
- request: Request,
- reset_token: str,
- new_password: str,
- db: Session = Depends(get_db)
+ request: Request, reset_token: str, new_password: str, db: Session = Depends(get_db)
):
"""
Reset customer password using reset token.
@@ -304,12 +304,12 @@ def reset_password(
- new_password: New password
"""
# Get vendor from middleware
- vendor = getattr(request.state, 'vendor', None)
+ vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
- detail="Vendor not found. Please access via vendor domain/subdomain/path."
+ detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
logger.debug(
@@ -317,7 +317,7 @@ def reset_password(
extra={
"vendor_id": vendor.id,
"vendor_code": vendor.subdomain,
- }
+ },
)
# TODO: Implement password reset
@@ -327,9 +327,7 @@ def reset_password(
# - Invalidate reset token
# - Return success
- logger.info(
- f"Password reset completed (vendor: {vendor.subdomain})"
- )
+ logger.info(f"Password reset completed (vendor: {vendor.subdomain})")
return {
"message": "Password reset successfully. You can now log in with your new password."
diff --git a/app/api/v1/shop/cart.py b/app/api/v1/shop/cart.py
index 35520670..ac51667c 100644
--- a/app/api/v1/shop/cart.py
+++ b/app/api/v1/shop/cart.py
@@ -8,18 +8,15 @@ No authentication required - uses session ID for cart tracking.
"""
import logging
-from fastapi import APIRouter, Depends, Path, Body, Request, HTTPException
+
+from fastapi import APIRouter, Body, Depends, HTTPException, Path, Request
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.services.cart_service import cart_service
-from models.schema.cart import (
- AddToCartRequest,
- UpdateCartItemRequest,
- CartResponse,
- CartOperationResponse,
- ClearCartResponse,
-)
+from models.schema.cart import (AddToCartRequest, CartOperationResponse,
+ CartResponse, ClearCartResponse,
+ UpdateCartItemRequest)
router = APIRouter()
logger = logging.getLogger(__name__)
@@ -29,6 +26,7 @@ logger = logging.getLogger(__name__)
# CART ENDPOINTS
# ============================================================================
+
@router.get("/cart/{session_id}", response_model=CartResponse)
def get_cart(
request: Request,
@@ -45,12 +43,12 @@ def get_cart(
- session_id: Unique session identifier for the cart
"""
# Get vendor from middleware
- vendor = getattr(request.state, 'vendor', None)
+ vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
- detail="Vendor not found. Please access via vendor domain/subdomain/path."
+ detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
logger.info(
@@ -59,23 +57,19 @@ def get_cart(
"vendor_id": vendor.id,
"vendor_code": vendor.subdomain,
"session_id": session_id,
- }
+ },
)
- cart = cart_service.get_cart(
- db=db,
- vendor_id=vendor.id,
- session_id=session_id
- )
+ cart = cart_service.get_cart(db=db, vendor_id=vendor.id, session_id=session_id)
logger.info(
f"[SHOP_API] get_cart result: {len(cart.get('items', []))} items in cart",
extra={
"session_id": session_id,
"vendor_id": vendor.id,
- "item_count": len(cart.get('items', [])),
- "total": cart.get('total', 0),
- }
+ "item_count": len(cart.get("items", [])),
+ "total": cart.get("total", 0),
+ },
)
return CartResponse.from_service_dict(cart)
@@ -102,12 +96,12 @@ def add_to_cart(
- quantity: Quantity to add (default: 1)
"""
# Get vendor from middleware
- vendor = getattr(request.state, 'vendor', None)
+ vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
- detail="Vendor not found. Please access via vendor domain/subdomain/path."
+ detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
logger.info(
@@ -118,7 +112,7 @@ def add_to_cart(
"session_id": session_id,
"product_id": cart_data.product_id,
"quantity": cart_data.quantity,
- }
+ },
)
result = cart_service.add_to_cart(
@@ -126,7 +120,7 @@ def add_to_cart(
vendor_id=vendor.id,
session_id=session_id,
product_id=cart_data.product_id,
- quantity=cart_data.quantity
+ quantity=cart_data.quantity,
)
logger.info(
@@ -134,13 +128,15 @@ def add_to_cart(
extra={
"session_id": session_id,
"result": result,
- }
+ },
)
return CartOperationResponse(**result)
-@router.put("/cart/{session_id}/items/{product_id}", response_model=CartOperationResponse)
+@router.put(
+ "/cart/{session_id}/items/{product_id}", response_model=CartOperationResponse
+)
def update_cart_item(
request: Request,
session_id: str = Path(..., description="Shopping session ID"),
@@ -162,12 +158,12 @@ def update_cart_item(
- quantity: New quantity (must be >= 1)
"""
# Get vendor from middleware
- vendor = getattr(request.state, 'vendor', None)
+ vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
- detail="Vendor not found. Please access via vendor domain/subdomain/path."
+ detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
logger.debug(
@@ -178,7 +174,7 @@ def update_cart_item(
"session_id": session_id,
"product_id": product_id,
"quantity": cart_data.quantity,
- }
+ },
)
result = cart_service.update_cart_item(
@@ -186,13 +182,15 @@ def update_cart_item(
vendor_id=vendor.id,
session_id=session_id,
product_id=product_id,
- quantity=cart_data.quantity
+ quantity=cart_data.quantity,
)
return CartOperationResponse(**result)
-@router.delete("/cart/{session_id}/items/{product_id}", response_model=CartOperationResponse)
+@router.delete(
+ "/cart/{session_id}/items/{product_id}", response_model=CartOperationResponse
+)
def remove_from_cart(
request: Request,
session_id: str = Path(..., description="Shopping session ID"),
@@ -210,12 +208,12 @@ def remove_from_cart(
- product_id: ID of product to remove
"""
# Get vendor from middleware
- vendor = getattr(request.state, 'vendor', None)
+ vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
- detail="Vendor not found. Please access via vendor domain/subdomain/path."
+ detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
logger.debug(
@@ -225,14 +223,11 @@ def remove_from_cart(
"vendor_code": vendor.subdomain,
"session_id": session_id,
"product_id": product_id,
- }
+ },
)
result = cart_service.remove_from_cart(
- db=db,
- vendor_id=vendor.id,
- session_id=session_id,
- product_id=product_id
+ db=db, vendor_id=vendor.id, session_id=session_id, product_id=product_id
)
return CartOperationResponse(**result)
@@ -254,12 +249,12 @@ def clear_cart(
- session_id: Unique session identifier for the cart
"""
# Get vendor from middleware
- vendor = getattr(request.state, 'vendor', None)
+ vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
- detail="Vendor not found. Please access via vendor domain/subdomain/path."
+ detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
logger.debug(
@@ -268,13 +263,9 @@ def clear_cart(
"vendor_id": vendor.id,
"vendor_code": vendor.subdomain,
"session_id": session_id,
- }
+ },
)
- result = cart_service.clear_cart(
- db=db,
- vendor_id=vendor.id,
- session_id=session_id
- )
+ result = cart_service.clear_cart(db=db, vendor_id=vendor.id, session_id=session_id)
return ClearCartResponse(**result)
diff --git a/app/api/v1/shop/content_pages.py b/app/api/v1/shop/content_pages.py
index 1c228911..48e19d64 100644
--- a/app/api/v1/shop/content_pages.py
+++ b/app/api/v1/shop/content_pages.py
@@ -8,6 +8,7 @@ No authentication required.
import logging
from typing import List
+
from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel
from sqlalchemy.orm import Session
@@ -23,8 +24,10 @@ logger = logging.getLogger(__name__)
# RESPONSE SCHEMAS
# ============================================================================
+
class PublicContentPageResponse(BaseModel):
"""Public content page response (no internal IDs)."""
+
slug: str
title: str
content: str
@@ -36,6 +39,7 @@ class PublicContentPageResponse(BaseModel):
class ContentPageListItem(BaseModel):
"""Content page list item for navigation."""
+
slug: str
title: str
show_in_footer: bool
@@ -47,25 +51,21 @@ class ContentPageListItem(BaseModel):
# PUBLIC ENDPOINTS
# ============================================================================
+
@router.get("/navigation", response_model=List[ContentPageListItem])
-def get_navigation_pages(
- request: Request,
- db: Session = Depends(get_db)
-):
+def get_navigation_pages(request: Request, db: Session = Depends(get_db)):
"""
Get list of content pages for navigation (footer/header).
Uses vendor from request.state (set by middleware).
Returns vendor overrides + platform defaults.
"""
- vendor = getattr(request.state, 'vendor', None)
+ vendor = getattr(request.state, "vendor", None)
vendor_id = vendor.id if vendor else None
# Get all published pages for this vendor
pages = content_page_service.list_pages_for_vendor(
- db,
- vendor_id=vendor_id,
- include_unpublished=False
+ db, vendor_id=vendor_id, include_unpublished=False
)
return [
@@ -81,25 +81,21 @@ def get_navigation_pages(
@router.get("/{slug}", response_model=PublicContentPageResponse)
-def get_content_page(
- slug: str,
- request: Request,
- db: Session = Depends(get_db)
-):
+def get_content_page(slug: str, request: Request, db: Session = Depends(get_db)):
"""
Get a specific content page by slug.
Uses vendor from request.state (set by middleware).
Returns vendor override if exists, otherwise platform default.
"""
- vendor = getattr(request.state, 'vendor', None)
+ vendor = getattr(request.state, "vendor", None)
vendor_id = vendor.id if vendor else None
page = content_page_service.get_page_for_vendor(
db,
slug=slug,
vendor_id=vendor_id,
- include_unpublished=False # Only show published pages
+ include_unpublished=False, # Only show published pages
)
if not page:
diff --git a/app/api/v1/shop/orders.py b/app/api/v1/shop/orders.py
index 3523e16e..74ac25bf 100644
--- a/app/api/v1/shop/orders.py
+++ b/app/api/v1/shop/orders.py
@@ -10,31 +10,23 @@ Requires customer authentication for most operations.
import logging
from typing import Optional
-from fastapi import APIRouter, Depends, Path, Query, Request, HTTPException
+from fastapi import APIRouter, Depends, HTTPException, Path, Query, Request
from sqlalchemy.orm import Session
-from app.core.database import get_db
from app.api.deps import get_current_customer_api
-from app.services.order_service import order_service
+from app.core.database import get_db
from app.services.customer_service import customer_service
-from models.schema.order import (
- OrderCreate,
- OrderResponse,
- OrderDetailResponse,
- OrderListResponse
-)
-from models.database.user import User
+from app.services.order_service import order_service
from models.database.customer import Customer
+from models.database.user import User
+from models.schema.order import (OrderCreate, OrderDetailResponse,
+ OrderListResponse, OrderResponse)
router = APIRouter()
logger = logging.getLogger(__name__)
-def get_customer_from_user(
- request: Request,
- user: User,
- db: Session
-) -> Customer:
+def get_customer_from_user(request: Request, user: User, db: Session) -> Customer:
"""
Helper to get Customer record from authenticated User.
@@ -49,25 +41,22 @@ def get_customer_from_user(
Raises:
HTTPException: If customer not found or vendor mismatch
"""
- vendor = getattr(request.state, 'vendor', None)
+ vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
- detail="Vendor not found. Please access via vendor domain/subdomain/path."
+ detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
# Find customer record for this user and vendor
customer = customer_service.get_customer_by_user_id(
- db=db,
- vendor_id=vendor.id,
- user_id=user.id
+ db=db, vendor_id=vendor.id, user_id=user.id
)
if not customer:
raise HTTPException(
- status_code=404,
- detail="Customer account not found for current vendor"
+ status_code=404, detail="Customer account not found for current vendor"
)
return customer
@@ -91,12 +80,12 @@ def place_order(
- Order data including shipping address, payment method, etc.
"""
# Get vendor from middleware
- vendor = getattr(request.state, 'vendor', None)
+ vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
- detail="Vendor not found. Please access via vendor domain/subdomain/path."
+ detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
# Get customer record
@@ -109,14 +98,12 @@ def place_order(
"vendor_code": vendor.subdomain,
"customer_id": customer.id,
"user_id": current_user.id,
- }
+ },
)
# Create order
order = order_service.create_order(
- db=db,
- vendor_id=vendor.id,
- order_data=order_data
+ db=db, vendor_id=vendor.id, order_data=order_data
)
logger.info(
@@ -127,7 +114,7 @@ def place_order(
"order_number": order.order_number,
"customer_id": customer.id,
"total_amount": float(order.total_amount),
- }
+ },
)
# TODO: Update customer stats
@@ -156,12 +143,12 @@ def get_my_orders(
- limit: Maximum number of orders to return
"""
# Get vendor from middleware
- vendor = getattr(request.state, 'vendor', None)
+ vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
- detail="Vendor not found. Please access via vendor domain/subdomain/path."
+ detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
# Get customer record
@@ -175,23 +162,19 @@ def get_my_orders(
"customer_id": customer.id,
"skip": skip,
"limit": limit,
- }
+ },
)
# Get orders
orders, total = order_service.get_customer_orders(
- db=db,
- vendor_id=vendor.id,
- customer_id=customer.id,
- skip=skip,
- limit=limit
+ db=db, vendor_id=vendor.id, customer_id=customer.id, skip=skip, limit=limit
)
return OrderListResponse(
orders=[OrderResponse.model_validate(o) for o in orders],
total=total,
skip=skip,
- limit=limit
+ limit=limit,
)
@@ -212,12 +195,12 @@ def get_order_details(
- order_id: ID of the order to retrieve
"""
# Get vendor from middleware
- vendor = getattr(request.state, 'vendor', None)
+ vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
- detail="Vendor not found. Please access via vendor domain/subdomain/path."
+ detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
# Get customer record
@@ -230,19 +213,16 @@ def get_order_details(
"vendor_code": vendor.subdomain,
"customer_id": customer.id,
"order_id": order_id,
- }
+ },
)
# Get order
- order = order_service.get_order(
- db=db,
- vendor_id=vendor.id,
- order_id=order_id
- )
+ order = order_service.get_order(db=db, vendor_id=vendor.id, order_id=order_id)
# Verify order belongs to customer
if order.customer_id != customer.id:
from app.exceptions import OrderNotFoundException
+
raise OrderNotFoundException(str(order_id))
return OrderDetailResponse.model_validate(order)
diff --git a/app/api/v1/shop/products.py b/app/api/v1/shop/products.py
index 60bdcb83..37836bec 100644
--- a/app/api/v1/shop/products.py
+++ b/app/api/v1/shop/products.py
@@ -10,12 +10,13 @@ No authentication required.
import logging
from typing import Optional
-from fastapi import APIRouter, Depends, Query, Path, Request, HTTPException
+from fastapi import APIRouter, Depends, HTTPException, Path, Query, Request
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.services.product_service import product_service
-from models.schema.product import ProductResponse, ProductDetailResponse, ProductListResponse
+from models.schema.product import (ProductDetailResponse, ProductListResponse,
+ ProductResponse)
router = APIRouter()
logger = logging.getLogger(__name__)
@@ -27,7 +28,9 @@ def get_product_catalog(
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000),
search: Optional[str] = Query(None, description="Search products by name"),
- is_featured: Optional[bool] = Query(None, description="Filter by featured products"),
+ is_featured: Optional[bool] = Query(
+ None, description="Filter by featured products"
+ ),
db: Session = Depends(get_db),
):
"""
@@ -44,12 +47,12 @@ def get_product_catalog(
- is_featured: Filter by featured products only
"""
# Get vendor from middleware (injected by VendorContextMiddleware)
- vendor = getattr(request.state, 'vendor', None)
+ vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
- detail="Vendor not found. Please access via vendor domain/subdomain/path."
+ detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
logger.debug(
@@ -61,7 +64,7 @@ def get_product_catalog(
"limit": limit,
"search": search,
"is_featured": is_featured,
- }
+ },
)
# Get only active products for public view
@@ -71,14 +74,14 @@ def get_product_catalog(
skip=skip,
limit=limit,
is_active=True, # Only show active products to customers
- is_featured=is_featured
+ is_featured=is_featured,
)
return ProductListResponse(
products=[ProductResponse.model_validate(p) for p in products],
total=total,
skip=skip,
- limit=limit
+ limit=limit,
)
@@ -98,12 +101,12 @@ def get_product_details(
- product_id: ID of the product to retrieve
"""
# Get vendor from middleware
- vendor = getattr(request.state, 'vendor', None)
+ vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
- detail="Vendor not found. Please access via vendor domain/subdomain/path."
+ detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
logger.debug(
@@ -112,18 +115,17 @@ def get_product_details(
"vendor_id": vendor.id,
"vendor_code": vendor.subdomain,
"product_id": product_id,
- }
+ },
)
product = product_service.get_product(
- db=db,
- vendor_id=vendor.id,
- product_id=product_id
+ db=db, vendor_id=vendor.id, product_id=product_id
)
# Check if product is active
if not product.is_active:
from app.exceptions import ProductNotActiveException
+
raise ProductNotActiveException(str(product_id))
return ProductDetailResponse.model_validate(product)
@@ -150,12 +152,12 @@ def search_products(
- limit: Maximum number of results to return
"""
# Get vendor from middleware
- vendor = getattr(request.state, 'vendor', None)
+ vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
- detail="Vendor not found. Please access via vendor domain/subdomain/path."
+ detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
logger.debug(
@@ -166,22 +168,18 @@ def search_products(
"query": q,
"skip": skip,
"limit": limit,
- }
+ },
)
# TODO: Implement full-text search functionality
# For now, return filtered products
products, total = product_service.get_vendor_products(
- db=db,
- vendor_id=vendor.id,
- skip=skip,
- limit=limit,
- is_active=True
+ db=db, vendor_id=vendor.id, skip=skip, limit=limit, is_active=True
)
return ProductListResponse(
products=[ProductResponse.model_validate(p) for p in products],
total=total,
skip=skip,
- limit=limit
+ limit=limit,
)
diff --git a/app/api/v1/vendor/__init__.py b/app/api/v1/vendor/__init__.py
index 9eb0cba8..b984af4f 100644
--- a/app/api/v1/vendor/__init__.py
+++ b/app/api/v1/vendor/__init__.py
@@ -13,25 +13,9 @@ IMPORTANT:
from fastapi import APIRouter
# Import all sub-routers (JSON API only)
-from . import (
- info,
- auth,
- dashboard,
- profile,
- settings,
- products,
- orders,
- customers,
- team,
- inventory,
- marketplace,
- payments,
- media,
- notifications,
- analytics,
- content_pages,
-)
-
+from . import (analytics, auth, content_pages, customers, dashboard, info,
+ inventory, marketplace, media, notifications, orders, payments,
+ products, profile, settings, team)
# Create vendor router
router = APIRouter()
@@ -68,7 +52,11 @@ router.include_router(notifications.router, tags=["vendor-notifications"])
router.include_router(analytics.router, tags=["vendor-analytics"])
# Content pages management
-router.include_router(content_pages.router, prefix="/{vendor_code}/content-pages", tags=["vendor-content-pages"])
+router.include_router(
+ content_pages.router,
+ prefix="/{vendor_code}/content-pages",
+ tags=["vendor-content-pages"],
+)
# Vendor info endpoint - MUST BE LAST! Has catch-all GET /{vendor_code}
router.include_router(info.router, tags=["vendor-info"])
diff --git a/app/api/v1/vendor/analytics.py b/app/api/v1/vendor/analytics.py
index 4e99270c..7d92ac5b 100644
--- a/app/api/v1/vendor/analytics.py
+++ b/app/api/v1/vendor/analytics.py
@@ -4,13 +4,14 @@ Vendor analytics and reporting endpoints.
"""
import logging
+
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api
from app.core.database import get_db
-from middleware.vendor_context import require_vendor_context
from app.services.stats_service import stats_service
+from middleware.vendor_context import require_vendor_context
from models.database.user import User
from models.database.vendor import Vendor
@@ -20,10 +21,10 @@ logger = logging.getLogger(__name__)
@router.get("")
def get_vendor_analytics(
- period: str = Query("30d", description="Time period: 7d, 30d, 90d, 1y"),
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ period: str = Query("30d", description="Time period: 7d, 30d, 90d, 1y"),
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""Get vendor analytics data for specified time period."""
return stats_service.get_vendor_analytics(db, vendor.id, period)
diff --git a/app/api/v1/vendor/auth.py b/app/api/v1/vendor/auth.py
index 127355bc..8f8bac33 100644
--- a/app/api/v1/vendor/auth.py
+++ b/app/api/v1/vendor/auth.py
@@ -13,19 +13,20 @@ This prevents:
"""
import logging
+
from fastapi import APIRouter, Depends, Request, Response
+from pydantic import BaseModel
from sqlalchemy.orm import Session
-from app.core.database import get_db
-from app.services.auth_service import auth_service
-from app.exceptions import InvalidCredentialsException
-from middleware.vendor_context import get_current_vendor
-from models.schema.auth import UserLogin
-from models.database.vendor import Vendor, VendorUser, Role
-from models.database.user import User
-from pydantic import BaseModel
from app.api.deps import get_current_vendor_api
+from app.core.database import get_db
from app.core.environment import should_use_secure_cookies
+from app.exceptions import InvalidCredentialsException
+from app.services.auth_service import auth_service
+from middleware.vendor_context import get_current_vendor
+from models.database.user import User
+from models.database.vendor import Role, Vendor, VendorUser
+from models.schema.auth import UserLogin
router = APIRouter(prefix="/auth")
logger = logging.getLogger(__name__)
@@ -43,10 +44,10 @@ class VendorLoginResponse(BaseModel):
@router.post("/login", response_model=VendorLoginResponse)
def vendor_login(
- user_credentials: UserLogin,
- request: Request,
- response: Response,
- db: Session = Depends(get_db)
+ user_credentials: UserLogin,
+ request: Request,
+ response: Response,
+ db: Session = Depends(get_db),
):
"""
Vendor team member login.
@@ -64,13 +65,16 @@ def vendor_login(
vendor = get_current_vendor(request)
# If no vendor from middleware, try to get from request body
- if not vendor and hasattr(user_credentials, 'vendor_code'):
- vendor_code = getattr(user_credentials, 'vendor_code', None)
+ if not vendor and hasattr(user_credentials, "vendor_code"):
+ vendor_code = getattr(user_credentials, "vendor_code", None)
if vendor_code:
- vendor = db.query(Vendor).filter(
- Vendor.vendor_code == vendor_code.upper(),
- Vendor.is_active == True
- ).first()
+ vendor = (
+ db.query(Vendor)
+ .filter(
+ Vendor.vendor_code == vendor_code.upper(), Vendor.is_active == True
+ )
+ .first()
+ )
# Authenticate user
login_result = auth_service.login_user(db=db, user_credentials=user_credentials)
@@ -79,7 +83,9 @@ def vendor_login(
# CRITICAL: Prevent admin users from using vendor login
if user.role == "admin":
logger.warning(f"Admin user attempted vendor login: {user.username}")
- raise InvalidCredentialsException("Admins cannot access vendor portal. Please use admin portal.")
+ raise InvalidCredentialsException(
+ "Admins cannot access vendor portal. Please use admin portal."
+ )
# Determine vendor and role
vendor_role = "Member"
@@ -92,11 +98,16 @@ def vendor_login(
vendor_role = "Owner"
else:
# Check if user is team member
- vendor_user = db.query(VendorUser).join(Role).filter(
- VendorUser.user_id == user.id,
- VendorUser.vendor_id == vendor.id,
- VendorUser.is_active == True
- ).first()
+ vendor_user = (
+ db.query(VendorUser)
+ .join(Role)
+ .filter(
+ VendorUser.user_id == user.id,
+ VendorUser.vendor_id == vendor.id,
+ VendorUser.is_active == True,
+ )
+ .first()
+ )
if vendor_user:
vendor_role = vendor_user.role.name
@@ -117,17 +128,14 @@ def vendor_login(
# Check vendor memberships
elif user.vendor_memberships:
active_membership = next(
- (vm for vm in user.vendor_memberships if vm.is_active),
- None
+ (vm for vm in user.vendor_memberships if vm.is_active), None
)
if active_membership:
vendor = active_membership.vendor
vendor_role = active_membership.role.name
if not vendor:
- raise InvalidCredentialsException(
- "User is not associated with any vendor"
- )
+ raise InvalidCredentialsException("User is not associated with any vendor")
logger.info(
f"Vendor team login successful: {user.username} "
@@ -161,7 +169,7 @@ def vendor_login(
"username": user.username,
"email": user.email,
"role": user.role,
- "is_active": user.is_active
+ "is_active": user.is_active,
},
vendor={
"id": vendor.id,
@@ -169,9 +177,9 @@ def vendor_login(
"subdomain": vendor.subdomain,
"name": vendor.name,
"is_active": vendor.is_active,
- "is_verified": vendor.is_verified
+ "is_verified": vendor.is_verified,
},
- vendor_role=vendor_role
+ vendor_role=vendor_role,
)
@@ -198,8 +206,7 @@ def vendor_logout(response: Response):
@router.get("/me")
def get_current_vendor_user(
- user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db)
+ user: User = Depends(get_current_vendor_api), db: Session = Depends(get_db)
):
"""
Get current authenticated vendor user.
@@ -212,5 +219,5 @@ def get_current_vendor_user(
"username": user.username,
"email": user.email,
"role": user.role,
- "is_active": user.is_active
+ "is_active": user.is_active,
}
diff --git a/app/api/v1/vendor/content_pages.py b/app/api/v1/vendor/content_pages.py
index b9af5302..4a552898 100644
--- a/app/api/v1/vendor/content_pages.py
+++ b/app/api/v1/vendor/content_pages.py
@@ -10,6 +10,7 @@ Vendors can:
import logging
from typing import List, Optional
+
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
@@ -26,14 +27,26 @@ logger = logging.getLogger(__name__)
# REQUEST/RESPONSE SCHEMAS
# ============================================================================
+
class VendorContentPageCreate(BaseModel):
"""Schema for creating a vendor content page."""
- slug: str = Field(..., max_length=100, description="URL-safe identifier (about, faq, contact, etc.)")
+
+ slug: str = Field(
+ ...,
+ max_length=100,
+ description="URL-safe identifier (about, faq, contact, etc.)",
+ )
title: str = Field(..., max_length=200, description="Page title")
content: str = Field(..., description="HTML or Markdown content")
- content_format: str = Field(default="html", description="Content format: html or markdown")
- meta_description: Optional[str] = Field(None, max_length=300, description="SEO meta description")
- meta_keywords: Optional[str] = Field(None, max_length=300, description="SEO keywords")
+ content_format: str = Field(
+ default="html", description="Content format: html or markdown"
+ )
+ meta_description: Optional[str] = Field(
+ None, max_length=300, description="SEO meta description"
+ )
+ meta_keywords: Optional[str] = Field(
+ None, max_length=300, description="SEO keywords"
+ )
is_published: bool = Field(default=False, description="Publish immediately")
show_in_footer: bool = Field(default=True, description="Show in footer navigation")
show_in_header: bool = Field(default=False, description="Show in header navigation")
@@ -42,6 +55,7 @@ class VendorContentPageCreate(BaseModel):
class VendorContentPageUpdate(BaseModel):
"""Schema for updating a vendor content page."""
+
title: Optional[str] = Field(None, max_length=200)
content: Optional[str] = None
content_format: Optional[str] = None
@@ -55,6 +69,7 @@ class VendorContentPageUpdate(BaseModel):
class ContentPageResponse(BaseModel):
"""Schema for content page response."""
+
id: int
vendor_id: Optional[int]
vendor_name: Optional[str]
@@ -81,11 +96,12 @@ class ContentPageResponse(BaseModel):
# VENDOR CONTENT PAGES
# ============================================================================
+
@router.get("/", response_model=List[ContentPageResponse])
def list_vendor_pages(
include_unpublished: bool = Query(False, description="Include draft pages"),
current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db)
+ db: Session = Depends(get_db),
):
"""
List all content pages available for this vendor.
@@ -93,12 +109,12 @@ def list_vendor_pages(
Returns vendor-specific overrides + platform defaults (vendor overrides take precedence).
"""
if not current_user.vendor_id:
- raise HTTPException(status_code=403, detail="User is not associated with a vendor")
+ raise HTTPException(
+ status_code=403, detail="User is not associated with a vendor"
+ )
pages = content_page_service.list_pages_for_vendor(
- db,
- vendor_id=current_user.vendor_id,
- include_unpublished=include_unpublished
+ db, vendor_id=current_user.vendor_id, include_unpublished=include_unpublished
)
return [page.to_dict() for page in pages]
@@ -108,7 +124,7 @@ def list_vendor_pages(
def list_vendor_overrides(
include_unpublished: bool = Query(False, description="Include draft pages"),
current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db)
+ db: Session = Depends(get_db),
):
"""
List only vendor-specific content pages (no platform defaults).
@@ -116,12 +132,12 @@ def list_vendor_overrides(
Shows what the vendor has customized.
"""
if not current_user.vendor_id:
- raise HTTPException(status_code=403, detail="User is not associated with a vendor")
+ raise HTTPException(
+ status_code=403, detail="User is not associated with a vendor"
+ )
pages = content_page_service.list_all_vendor_pages(
- db,
- vendor_id=current_user.vendor_id,
- include_unpublished=include_unpublished
+ db, vendor_id=current_user.vendor_id, include_unpublished=include_unpublished
)
return [page.to_dict() for page in pages]
@@ -132,7 +148,7 @@ def get_page(
slug: str,
include_unpublished: bool = Query(False, description="Include draft pages"),
current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db)
+ db: Session = Depends(get_db),
):
"""
Get a specific content page by slug.
@@ -140,13 +156,15 @@ def get_page(
Returns vendor override if exists, otherwise platform default.
"""
if not current_user.vendor_id:
- raise HTTPException(status_code=403, detail="User is not associated with a vendor")
+ raise HTTPException(
+ status_code=403, detail="User is not associated with a vendor"
+ )
page = content_page_service.get_page_for_vendor(
db,
slug=slug,
vendor_id=current_user.vendor_id,
- include_unpublished=include_unpublished
+ include_unpublished=include_unpublished,
)
if not page:
@@ -159,7 +177,7 @@ def get_page(
def create_vendor_page(
page_data: VendorContentPageCreate,
current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db)
+ db: Session = Depends(get_db),
):
"""
Create a vendor-specific content page override.
@@ -167,7 +185,9 @@ def create_vendor_page(
This will be shown instead of the platform default for this vendor.
"""
if not current_user.vendor_id:
- raise HTTPException(status_code=403, detail="User is not associated with a vendor")
+ raise HTTPException(
+ status_code=403, detail="User is not associated with a vendor"
+ )
page = content_page_service.create_page(
db,
@@ -182,7 +202,7 @@ def create_vendor_page(
show_in_footer=page_data.show_in_footer,
show_in_header=page_data.show_in_header,
display_order=page_data.display_order,
- created_by=current_user.id
+ created_by=current_user.id,
)
return page.to_dict()
@@ -193,7 +213,7 @@ def update_vendor_page(
page_id: int,
page_data: VendorContentPageUpdate,
current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db)
+ db: Session = Depends(get_db),
):
"""
Update a vendor-specific content page.
@@ -201,7 +221,9 @@ def update_vendor_page(
Can only update pages owned by this vendor.
"""
if not current_user.vendor_id:
- raise HTTPException(status_code=403, detail="User is not associated with a vendor")
+ raise HTTPException(
+ status_code=403, detail="User is not associated with a vendor"
+ )
# Verify ownership
existing_page = content_page_service.get_page_by_id(db, page_id)
@@ -209,7 +231,9 @@ def update_vendor_page(
raise HTTPException(status_code=404, detail="Content page not found")
if existing_page.vendor_id != current_user.vendor_id:
- raise HTTPException(status_code=403, detail="Cannot edit pages from other vendors")
+ raise HTTPException(
+ status_code=403, detail="Cannot edit pages from other vendors"
+ )
# Update
page = content_page_service.update_page(
@@ -224,7 +248,7 @@ def update_vendor_page(
show_in_footer=page_data.show_in_footer,
show_in_header=page_data.show_in_header,
display_order=page_data.display_order,
- updated_by=current_user.id
+ updated_by=current_user.id,
)
return page.to_dict()
@@ -234,7 +258,7 @@ def update_vendor_page(
def delete_vendor_page(
page_id: int,
current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db)
+ db: Session = Depends(get_db),
):
"""
Delete a vendor-specific content page.
@@ -243,7 +267,9 @@ def delete_vendor_page(
After deletion, platform default will be shown (if exists).
"""
if not current_user.vendor_id:
- raise HTTPException(status_code=403, detail="User is not associated with a vendor")
+ raise HTTPException(
+ status_code=403, detail="User is not associated with a vendor"
+ )
# Verify ownership
existing_page = content_page_service.get_page_by_id(db, page_id)
@@ -251,7 +277,9 @@ def delete_vendor_page(
raise HTTPException(status_code=404, detail="Content page not found")
if existing_page.vendor_id != current_user.vendor_id:
- raise HTTPException(status_code=403, detail="Cannot delete pages from other vendors")
+ raise HTTPException(
+ status_code=403, detail="Cannot delete pages from other vendors"
+ )
# Delete
content_page_service.delete_page(db, page_id)
diff --git a/app/api/v1/vendor/customers.py b/app/api/v1/vendor/customers.py
index 2727a017..3a267c17 100644
--- a/app/api/v1/vendor/customers.py
+++ b/app/api/v1/vendor/customers.py
@@ -6,6 +6,7 @@ Vendor customer management endpoints.
import logging
from typing import Optional
+
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
@@ -21,13 +22,13 @@ logger = logging.getLogger(__name__)
@router.get("")
def get_vendor_customers(
- skip: int = Query(0, ge=0),
- limit: int = Query(100, ge=1, le=1000),
- search: Optional[str] = Query(None),
- is_active: Optional[bool] = Query(None),
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ skip: int = Query(0, ge=0),
+ limit: int = Query(100, ge=1, le=1000),
+ search: Optional[str] = Query(None),
+ is_active: Optional[bool] = Query(None),
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Get all customers for this vendor.
@@ -43,16 +44,16 @@ def get_vendor_customers(
"total": 0,
"skip": skip,
"limit": limit,
- "message": "Customer management coming in Slice 4"
+ "message": "Customer management coming in Slice 4",
}
@router.get("/{customer_id}")
def get_customer_details(
- customer_id: int,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ customer_id: int,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Get detailed customer information.
@@ -63,17 +64,15 @@ def get_customer_details(
- Include order history
- Include total spent, etc.
"""
- return {
- "message": "Customer details coming in Slice 4"
- }
+ return {"message": "Customer details coming in Slice 4"}
@router.get("/{customer_id}/orders")
def get_customer_orders(
- customer_id: int,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ customer_id: int,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Get order history for a specific customer.
@@ -83,19 +82,16 @@ def get_customer_orders(
- Filter by vendor_id
- Return order details
"""
- return {
- "orders": [],
- "message": "Customer orders coming in Slice 5"
- }
+ return {"orders": [], "message": "Customer orders coming in Slice 5"}
@router.put("/{customer_id}")
def update_customer(
- customer_id: int,
- customer_data: dict,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ customer_id: int,
+ customer_data: dict,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Update customer information.
@@ -105,17 +101,15 @@ def update_customer(
- Verify customer belongs to vendor
- Update customer preferences
"""
- return {
- "message": "Customer update coming in Slice 4"
- }
+ return {"message": "Customer update coming in Slice 4"}
@router.put("/{customer_id}/status")
def toggle_customer_status(
- customer_id: int,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ customer_id: int,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Activate/deactivate customer account.
@@ -125,17 +119,15 @@ def toggle_customer_status(
- Verify customer belongs to vendor
- Log the change
"""
- return {
- "message": "Customer status toggle coming in Slice 4"
- }
+ return {"message": "Customer status toggle coming in Slice 4"}
@router.get("/{customer_id}/stats")
def get_customer_statistics(
- customer_id: int,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ customer_id: int,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Get customer statistics and metrics.
@@ -151,6 +143,5 @@ def get_customer_statistics(
"total_spent": 0.0,
"average_order_value": 0.0,
"last_order_date": None,
- "message": "Customer statistics coming in Slice 4"
+ "message": "Customer statistics coming in Slice 4",
}
-
diff --git a/app/api/v1/vendor/dashboard.py b/app/api/v1/vendor/dashboard.py
index 9e8cc066..b49351a6 100644
--- a/app/api/v1/vendor/dashboard.py
+++ b/app/api/v1/vendor/dashboard.py
@@ -4,13 +4,14 @@ Vendor dashboard and statistics endpoints.
"""
import logging
+
from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api
from app.core.database import get_db
-from middleware.vendor_context import require_vendor_context
from app.services.stats_service import stats_service
+from middleware.vendor_context import require_vendor_context
from models.database.user import User
from models.database.vendor import Vendor
@@ -20,9 +21,9 @@ logger = logging.getLogger(__name__)
@router.get("/stats")
def get_vendor_dashboard_stats(
- request: Request,
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ request: Request,
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Get vendor-specific dashboard statistics.
@@ -38,24 +39,23 @@ def get_vendor_dashboard_stats(
"""
# Get vendor from authenticated user's vendor_user record
from models.database.vendor import VendorUser
- vendor_user = db.query(VendorUser).filter(
- VendorUser.user_id == current_user.id
- ).first()
+
+ vendor_user = (
+ db.query(VendorUser).filter(VendorUser.user_id == current_user.id).first()
+ )
if not vendor_user:
from fastapi import HTTPException
+
raise HTTPException(
- status_code=403,
- detail="User is not associated with any vendor"
+ status_code=403, detail="User is not associated with any vendor"
)
vendor = vendor_user.vendor
if not vendor or not vendor.is_active:
from fastapi import HTTPException
- raise HTTPException(
- status_code=404,
- detail="Vendor not found or inactive"
- )
+
+ raise HTTPException(status_code=404, detail="Vendor not found or inactive")
# Get vendor-scoped statistics
stats_data = stats_service.get_vendor_stats(db=db, vendor_id=vendor.id)
@@ -82,5 +82,5 @@ def get_vendor_dashboard_stats(
"revenue": {
"total": stats_data.get("total_revenue", 0),
"this_month": stats_data.get("revenue_this_month", 0),
- }
+ },
}
diff --git a/app/api/v1/vendor/info.py b/app/api/v1/vendor/info.py
index fffcebb7..e84d0a8d 100644
--- a/app/api/v1/vendor/info.py
+++ b/app/api/v1/vendor/info.py
@@ -8,14 +8,15 @@ This module provides:
"""
import logging
-from fastapi import APIRouter, Path, Depends
-from sqlalchemy.orm import Session
+
+from fastapi import APIRouter, Depends, Path
from sqlalchemy import func
+from sqlalchemy.orm import Session
from app.core.database import get_db
from app.exceptions import VendorNotFoundException
-from models.schema.vendor import VendorResponse, VendorDetailResponse
from models.database.vendor import Vendor
+from models.schema.vendor import VendorDetailResponse, VendorResponse
router = APIRouter()
logger = logging.getLogger(__name__)
@@ -35,10 +36,14 @@ def _get_vendor_by_code(db: Session, vendor_code: str) -> Vendor:
Raises:
VendorNotFoundException: If vendor not found or inactive
"""
- vendor = db.query(Vendor).filter(
- func.upper(Vendor.vendor_code) == vendor_code.upper(),
- Vendor.is_active == True
- ).first()
+ vendor = (
+ db.query(Vendor)
+ .filter(
+ func.upper(Vendor.vendor_code) == vendor_code.upper(),
+ Vendor.is_active == True,
+ )
+ .first()
+ )
if not vendor:
logger.warning(f"Vendor not found or inactive: {vendor_code}")
@@ -49,8 +54,8 @@ def _get_vendor_by_code(db: Session, vendor_code: str) -> Vendor:
@router.get("/{vendor_code}", response_model=VendorDetailResponse)
def get_vendor_info(
- vendor_code: str = Path(..., description="Vendor code"),
- db: Session = Depends(get_db)
+ vendor_code: str = Path(..., description="Vendor code"),
+ db: Session = Depends(get_db),
):
"""
Get public vendor information by vendor code.
diff --git a/app/api/v1/vendor/inventory.py b/app/api/v1/vendor/inventory.py
index 65b2d331..0cb49c2c 100644
--- a/app/api/v1/vendor/inventory.py
+++ b/app/api/v1/vendor/inventory.py
@@ -7,19 +7,14 @@ from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api
from app.core.database import get_db
-from middleware.vendor_context import require_vendor_context
from app.services.inventory_service import inventory_service
-from models.schema.inventory import (
- InventoryCreate,
- InventoryAdjust,
- InventoryUpdate,
- InventoryReserve,
- InventoryResponse,
- ProductInventorySummary,
- InventoryListResponse
-)
+from middleware.vendor_context import require_vendor_context
from models.database.user import User
from models.database.vendor import Vendor
+from models.schema.inventory import (InventoryAdjust, InventoryCreate,
+ InventoryListResponse, InventoryReserve,
+ InventoryResponse, InventoryUpdate,
+ ProductInventorySummary)
router = APIRouter()
logger = logging.getLogger(__name__)
@@ -27,10 +22,10 @@ logger = logging.getLogger(__name__)
@router.post("/inventory/set", response_model=InventoryResponse)
def set_inventory(
- inventory: InventoryCreate,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ inventory: InventoryCreate,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""Set exact inventory quantity (replaces existing)."""
return inventory_service.set_inventory(db, vendor.id, inventory)
@@ -38,10 +33,10 @@ def set_inventory(
@router.post("/inventory/adjust", response_model=InventoryResponse)
def adjust_inventory(
- adjustment: InventoryAdjust,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ adjustment: InventoryAdjust,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""Adjust inventory (positive to add, negative to remove)."""
return inventory_service.adjust_inventory(db, vendor.id, adjustment)
@@ -49,10 +44,10 @@ def adjust_inventory(
@router.post("/inventory/reserve", response_model=InventoryResponse)
def reserve_inventory(
- reservation: InventoryReserve,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ reservation: InventoryReserve,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""Reserve inventory for an order."""
return inventory_service.reserve_inventory(db, vendor.id, reservation)
@@ -60,10 +55,10 @@ def reserve_inventory(
@router.post("/inventory/release", response_model=InventoryResponse)
def release_reservation(
- reservation: InventoryReserve,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ reservation: InventoryReserve,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""Release reserved inventory (cancel order)."""
return inventory_service.release_reservation(db, vendor.id, reservation)
@@ -71,10 +66,10 @@ def release_reservation(
@router.post("/inventory/fulfill", response_model=InventoryResponse)
def fulfill_reservation(
- reservation: InventoryReserve,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ reservation: InventoryReserve,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""Fulfill reservation (complete order, remove from stock)."""
return inventory_service.fulfill_reservation(db, vendor.id, reservation)
@@ -82,10 +77,10 @@ def fulfill_reservation(
@router.get("/inventory/product/{product_id}", response_model=ProductInventorySummary)
def get_product_inventory(
- product_id: int,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ product_id: int,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""Get inventory summary for a product."""
return inventory_service.get_product_inventory(db, vendor.id, product_id)
@@ -93,13 +88,13 @@ def get_product_inventory(
@router.get("/inventory", response_model=InventoryListResponse)
def get_vendor_inventory(
- skip: int = Query(0, ge=0),
- limit: int = Query(100, ge=1, le=1000),
- location: Optional[str] = Query(None),
- low_stock: Optional[int] = Query(None, ge=0),
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ skip: int = Query(0, ge=0),
+ limit: int = Query(100, ge=1, le=1000),
+ location: Optional[str] = Query(None),
+ low_stock: Optional[int] = Query(None, ge=0),
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""Get all inventory for vendor."""
inventories = inventory_service.get_vendor_inventory(
@@ -110,31 +105,30 @@ def get_vendor_inventory(
total = len(inventories) # You might want a separate count query for large datasets
return InventoryListResponse(
- inventories=inventories,
- total=total,
- skip=skip,
- limit=limit
+ inventories=inventories, total=total, skip=skip, limit=limit
)
@router.put("/inventory/{inventory_id}", response_model=InventoryResponse)
def update_inventory(
- inventory_id: int,
- inventory_update: InventoryUpdate,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ inventory_id: int,
+ inventory_update: InventoryUpdate,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""Update inventory entry."""
- return inventory_service.update_inventory(db, vendor.id, inventory_id, inventory_update)
+ return inventory_service.update_inventory(
+ db, vendor.id, inventory_id, inventory_update
+ )
@router.delete("/inventory/{inventory_id}")
def delete_inventory(
- inventory_id: int,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ inventory_id: int,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""Delete inventory entry."""
inventory_service.delete_inventory(db, vendor.id, inventory_id)
diff --git a/app/api/v1/vendor/marketplace.py b/app/api/v1/vendor/marketplace.py
index cb98d980..9084bfd2 100644
--- a/app/api/v1/vendor/marketplace.py
+++ b/app/api/v1/vendor/marketplace.py
@@ -12,16 +12,15 @@ from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api
from app.core.database import get_db
-from middleware.vendor_context import require_vendor_context # IMPORTANT
-from app.services.marketplace_import_job_service import marketplace_import_job_service
+from app.services.marketplace_import_job_service import \
+ marketplace_import_job_service
from app.tasks.background_tasks import process_marketplace_import
from middleware.decorators import rate_limit
-from models.schema.marketplace_import_job import (
- MarketplaceImportJobResponse,
- MarketplaceImportJobRequest
-)
+from middleware.vendor_context import require_vendor_context # IMPORTANT
from models.database.user import User
from models.database.vendor import Vendor
+from models.schema.marketplace_import_job import (MarketplaceImportJobRequest,
+ MarketplaceImportJobResponse)
router = APIRouter()
logger = logging.getLogger(__name__)
@@ -30,11 +29,11 @@ logger = logging.getLogger(__name__)
@router.post("/import", response_model=MarketplaceImportJobResponse)
@rate_limit(max_requests=10, window_seconds=3600)
async def import_products_from_marketplace(
- request: MarketplaceImportJobRequest,
- background_tasks: BackgroundTasks,
- vendor: Vendor = Depends(require_vendor_context()), # ADDED: Vendor from middleware
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ request: MarketplaceImportJobRequest,
+ background_tasks: BackgroundTasks,
+ vendor: Vendor = Depends(require_vendor_context()), # ADDED: Vendor from middleware
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""Import products from marketplace CSV with background processing (Protected)."""
logger.info(
@@ -66,7 +65,7 @@ async def import_products_from_marketplace(
vendor_name=vendor.name, # FIXED: from vendor object
source_url=request.source_url,
message=f"Marketplace import started from {request.marketplace}. "
- f"Check status with /import-status/{import_job.id}",
+ f"Check status with /import-status/{import_job.id}",
imported=0,
updated=0,
total_processed=0,
@@ -77,10 +76,10 @@ async def import_products_from_marketplace(
@router.get("/imports/{job_id}", response_model=MarketplaceImportJobResponse)
def get_marketplace_import_status(
- job_id: int,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ job_id: int,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""Get status of marketplace import job (Protected)."""
job = marketplace_import_job_service.get_import_job_by_id(db, job_id, current_user)
@@ -88,6 +87,7 @@ def get_marketplace_import_status(
# Verify job belongs to current vendor
if job.vendor_id != vendor.id:
from app.exceptions import UnauthorizedVendorAccessException
+
raise UnauthorizedVendorAccessException(vendor.vendor_code, current_user.id)
return marketplace_import_job_service.convert_to_response_model(job)
@@ -95,12 +95,12 @@ def get_marketplace_import_status(
@router.get("/imports", response_model=List[MarketplaceImportJobResponse])
def get_marketplace_import_jobs(
- marketplace: Optional[str] = Query(None, description="Filter by marketplace"),
- skip: int = Query(0, ge=0),
- limit: int = Query(50, ge=1, le=100),
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ marketplace: Optional[str] = Query(None, description="Filter by marketplace"),
+ skip: int = Query(0, ge=0),
+ limit: int = Query(50, ge=1, le=100),
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""Get marketplace import jobs for current vendor (Protected)."""
jobs = marketplace_import_job_service.get_import_jobs(
@@ -112,4 +112,6 @@ def get_marketplace_import_jobs(
limit=limit,
)
- return [marketplace_import_job_service.convert_to_response_model(job) for job in jobs]
+ return [
+ marketplace_import_job_service.convert_to_response_model(job) for job in jobs
+ ]
diff --git a/app/api/v1/vendor/media.py b/app/api/v1/vendor/media.py
index d90ee1fe..900bbb6e 100644
--- a/app/api/v1/vendor/media.py
+++ b/app/api/v1/vendor/media.py
@@ -6,7 +6,8 @@ Vendor media and file management endpoints.
import logging
from typing import Optional
-from fastapi import APIRouter, Depends, Query, UploadFile, File
+
+from fastapi import APIRouter, Depends, File, Query, UploadFile
from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api
@@ -21,13 +22,13 @@ logger = logging.getLogger(__name__)
@router.get("")
def get_media_library(
- skip: int = Query(0, ge=0),
- limit: int = Query(100, ge=1, le=1000),
- media_type: Optional[str] = Query(None, description="image, video, document"),
- search: Optional[str] = Query(None),
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ skip: int = Query(0, ge=0),
+ limit: int = Query(100, ge=1, le=1000),
+ media_type: Optional[str] = Query(None, description="image, video, document"),
+ search: Optional[str] = Query(None),
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Get vendor media library.
@@ -44,17 +45,17 @@ def get_media_library(
"total": 0,
"skip": skip,
"limit": limit,
- "message": "Media library coming in Slice 3"
+ "message": "Media library coming in Slice 3",
}
@router.post("/upload")
async def upload_media(
- file: UploadFile = File(...),
- folder: Optional[str] = Query(None, description="products, general, etc."),
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ file: UploadFile = File(...),
+ folder: Optional[str] = Query(None, description="products, general, etc."),
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Upload media file.
@@ -70,17 +71,17 @@ async def upload_media(
return {
"file_url": None,
"thumbnail_url": None,
- "message": "Media upload coming in Slice 3"
+ "message": "Media upload coming in Slice 3",
}
@router.post("/upload/multiple")
async def upload_multiple_media(
- files: list[UploadFile] = File(...),
- folder: Optional[str] = Query(None),
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ files: list[UploadFile] = File(...),
+ folder: Optional[str] = Query(None),
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Upload multiple media files at once.
@@ -94,16 +95,16 @@ async def upload_multiple_media(
return {
"uploaded_files": [],
"failed_files": [],
- "message": "Multiple upload coming in Slice 3"
+ "message": "Multiple upload coming in Slice 3",
}
@router.get("/{media_id}")
def get_media_details(
- media_id: int,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ media_id: int,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Get media file details.
@@ -113,18 +114,16 @@ def get_media_details(
- Return file URL
- Return usage information (which products use this file)
"""
- return {
- "message": "Media details coming in Slice 3"
- }
+ return {"message": "Media details coming in Slice 3"}
@router.put("/{media_id}")
def update_media_metadata(
- media_id: int,
- metadata: dict,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ media_id: int,
+ metadata: dict,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Update media file metadata.
@@ -135,17 +134,15 @@ def update_media_metadata(
- Update tags/categories
- Update description
"""
- return {
- "message": "Media update coming in Slice 3"
- }
+ return {"message": "Media update coming in Slice 3"}
@router.delete("/{media_id}")
def delete_media(
- media_id: int,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ media_id: int,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Delete media file.
@@ -157,17 +154,15 @@ def delete_media(
- Delete database record
- Return success/error
"""
- return {
- "message": "Media deletion coming in Slice 3"
- }
+ return {"message": "Media deletion coming in Slice 3"}
@router.get("/{media_id}/usage")
def get_media_usage(
- media_id: int,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ media_id: int,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Get where this media file is being used.
@@ -180,16 +175,16 @@ def get_media_usage(
return {
"products": [],
"other_usage": [],
- "message": "Media usage tracking coming in Slice 3"
+ "message": "Media usage tracking coming in Slice 3",
}
@router.post("/optimize/{media_id}")
def optimize_media(
- media_id: int,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ media_id: int,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Optimize media file (compress, resize, etc.).
@@ -200,7 +195,4 @@ def optimize_media(
- Keep original
- Update database with new versions
"""
- return {
- "message": "Media optimization coming in Slice 3"
- }
-
+ return {"message": "Media optimization coming in Slice 3"}
diff --git a/app/api/v1/vendor/notifications.py b/app/api/v1/vendor/notifications.py
index 2889a9e7..2b6a8205 100644
--- a/app/api/v1/vendor/notifications.py
+++ b/app/api/v1/vendor/notifications.py
@@ -1,4 +1,4 @@
-# Notification management
+# Notification management
# app/api/v1/vendor/notifications.py
"""
Vendor notification management endpoints.
@@ -6,6 +6,7 @@ Vendor notification management endpoints.
import logging
from typing import Optional
+
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
@@ -21,12 +22,12 @@ logger = logging.getLogger(__name__)
@router.get("")
def get_notifications(
- skip: int = Query(0, ge=0),
- limit: int = Query(50, ge=1, le=100),
- unread_only: Optional[bool] = Query(False),
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ skip: int = Query(0, ge=0),
+ limit: int = Query(50, ge=1, le=100),
+ unread_only: Optional[bool] = Query(False),
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Get vendor notifications.
@@ -41,15 +42,15 @@ def get_notifications(
"notifications": [],
"total": 0,
"unread_count": 0,
- "message": "Notifications coming in Slice 5"
+ "message": "Notifications coming in Slice 5",
}
@router.get("/unread-count")
def get_unread_count(
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Get count of unread notifications.
@@ -58,18 +59,15 @@ def get_unread_count(
- Count unread notifications for vendor
- Used for notification badge
"""
- return {
- "unread_count": 0,
- "message": "Unread count coming in Slice 5"
- }
+ return {"unread_count": 0, "message": "Unread count coming in Slice 5"}
@router.put("/{notification_id}/read")
def mark_as_read(
- notification_id: int,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ notification_id: int,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Mark notification as read.
@@ -78,16 +76,14 @@ def mark_as_read(
- Mark single notification as read
- Update read timestamp
"""
- return {
- "message": "Mark as read coming in Slice 5"
- }
+ return {"message": "Mark as read coming in Slice 5"}
@router.put("/mark-all-read")
def mark_all_as_read(
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Mark all notifications as read.
@@ -96,17 +92,15 @@ def mark_all_as_read(
- Mark all vendor notifications as read
- Update timestamps
"""
- return {
- "message": "Mark all as read coming in Slice 5"
- }
+ return {"message": "Mark all as read coming in Slice 5"}
@router.delete("/{notification_id}")
def delete_notification(
- notification_id: int,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ notification_id: int,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Delete notification.
@@ -115,16 +109,14 @@ def delete_notification(
- Delete single notification
- Verify notification belongs to vendor
"""
- return {
- "message": "Notification deletion coming in Slice 5"
- }
+ return {"message": "Notification deletion coming in Slice 5"}
@router.get("/settings")
def get_notification_settings(
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Get notification preferences.
@@ -138,16 +130,16 @@ def get_notification_settings(
"email_notifications": True,
"in_app_notifications": True,
"notification_types": {},
- "message": "Notification settings coming in Slice 5"
+ "message": "Notification settings coming in Slice 5",
}
@router.put("/settings")
def update_notification_settings(
- settings: dict,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ settings: dict,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Update notification preferences.
@@ -157,16 +149,14 @@ def update_notification_settings(
- Update in-app notification settings
- Enable/disable specific notification types
"""
- return {
- "message": "Notification settings update coming in Slice 5"
- }
+ return {"message": "Notification settings update coming in Slice 5"}
@router.get("/templates")
def get_notification_templates(
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Get notification email templates.
@@ -176,19 +166,16 @@ def get_notification_templates(
- Include: order confirmation, shipping notification, etc.
- Return template details
"""
- return {
- "templates": [],
- "message": "Notification templates coming in Slice 5"
- }
+ return {"templates": [], "message": "Notification templates coming in Slice 5"}
@router.put("/templates/{template_id}")
def update_notification_template(
- template_id: int,
- template_data: dict,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ template_id: int,
+ template_data: dict,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Update notification email template.
@@ -199,17 +186,15 @@ def update_notification_template(
- Validate template variables
- Preview template
"""
- return {
- "message": "Template update coming in Slice 5"
- }
+ return {"message": "Template update coming in Slice 5"}
@router.post("/test")
def send_test_notification(
- notification_data: dict,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ notification_data: dict,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Send test notification.
@@ -219,6 +204,4 @@ def send_test_notification(
- Use specified template
- Send to current user's email
"""
- return {
- "message": "Test notification coming in Slice 5"
- }
+ return {"message": "Test notification coming in Slice 5"}
diff --git a/app/api/v1/vendor/orders.py b/app/api/v1/vendor/orders.py
index 49b5a360..a32b4253 100644
--- a/app/api/v1/vendor/orders.py
+++ b/app/api/v1/vendor/orders.py
@@ -6,21 +6,17 @@ Vendor order management endpoints.
import logging
from typing import Optional
-from fastapi import APIRouter, Depends, Query, Request, HTTPException
+from fastapi import APIRouter, Depends, HTTPException, Query, Request
from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api
from app.core.database import get_db
-from middleware.vendor_context import require_vendor_context
from app.services.order_service import order_service
-from models.schema.order import (
- OrderResponse,
- OrderDetailResponse,
- OrderListResponse,
- OrderUpdate
-)
+from middleware.vendor_context import require_vendor_context
from models.database.user import User
from models.database.vendor import Vendor, VendorUser
+from models.schema.order import (OrderDetailResponse, OrderListResponse,
+ OrderResponse, OrderUpdate)
router = APIRouter(prefix="/orders")
logger = logging.getLogger(__name__)
@@ -28,13 +24,13 @@ logger = logging.getLogger(__name__)
@router.get("", response_model=OrderListResponse)
def get_vendor_orders(
- skip: int = Query(0, ge=0),
- limit: int = Query(100, ge=1, le=1000),
- status: Optional[str] = Query(None, description="Filter by order status"),
- customer_id: Optional[int] = Query(None, description="Filter by customer"),
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ skip: int = Query(0, ge=0),
+ limit: int = Query(100, ge=1, le=1000),
+ status: Optional[str] = Query(None, description="Filter by order status"),
+ customer_id: Optional[int] = Query(None, description="Filter by customer"),
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Get all orders for vendor.
@@ -51,45 +47,41 @@ def get_vendor_orders(
skip=skip,
limit=limit,
status=status,
- customer_id=customer_id
+ customer_id=customer_id,
)
return OrderListResponse(
orders=[OrderResponse.model_validate(o) for o in orders],
total=total,
skip=skip,
- limit=limit
+ limit=limit,
)
@router.get("/{order_id}", response_model=OrderDetailResponse)
def get_order_details(
- order_id: int,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ order_id: int,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Get detailed order information including items and addresses.
Requires Authorization header (API endpoint).
"""
- order = order_service.get_order(
- db=db,
- vendor_id=vendor.id,
- order_id=order_id
- )
+ order = order_service.get_order(db=db, vendor_id=vendor.id, order_id=order_id)
return OrderDetailResponse.model_validate(order)
@router.put("/{order_id}/status", response_model=OrderResponse)
def update_order_status(
- order_id: int,
- order_update: OrderUpdate,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ order_id: int,
+ order_update: OrderUpdate,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Update order status and tracking information.
@@ -105,10 +97,7 @@ def update_order_status(
Requires Authorization header (API endpoint).
"""
order = order_service.update_order_status(
- db=db,
- vendor_id=vendor.id,
- order_id=order_id,
- order_update=order_update
+ db=db, vendor_id=vendor.id, order_id=order_id, order_update=order_update
)
logger.info(
diff --git a/app/api/v1/vendor/payments.py b/app/api/v1/vendor/payments.py
index a25c7b69..a29672f6 100644
--- a/app/api/v1/vendor/payments.py
+++ b/app/api/v1/vendor/payments.py
@@ -1,10 +1,11 @@
-# Payment configuration and processing
+# Payment configuration and processing
# app/api/v1/vendor/payments.py
"""
Vendor payment configuration and processing endpoints.
"""
import logging
+
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
@@ -20,9 +21,9 @@ logger = logging.getLogger(__name__)
@router.get("/config")
def get_payment_configuration(
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Get vendor payment configuration.
@@ -38,16 +39,16 @@ def get_payment_configuration(
"accepted_methods": [],
"currency": "EUR",
"stripe_connected": False,
- "message": "Payment configuration coming in Slice 5"
+ "message": "Payment configuration coming in Slice 5",
}
@router.put("/config")
def update_payment_configuration(
- payment_config: dict,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ payment_config: dict,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Update vendor payment configuration.
@@ -58,17 +59,15 @@ def update_payment_configuration(
- Update accepted payment methods
- Validate configuration before saving
"""
- return {
- "message": "Payment configuration update coming in Slice 5"
- }
+ return {"message": "Payment configuration update coming in Slice 5"}
@router.post("/stripe/connect")
def connect_stripe_account(
- stripe_data: dict,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ stripe_data: dict,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Connect Stripe account for payment processing.
@@ -79,16 +78,14 @@ def connect_stripe_account(
- Verify Stripe account is active
- Enable payment processing
"""
- return {
- "message": "Stripe connection coming in Slice 5"
- }
+ return {"message": "Stripe connection coming in Slice 5"}
@router.delete("/stripe/disconnect")
def disconnect_stripe_account(
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Disconnect Stripe account.
@@ -98,16 +95,14 @@ def disconnect_stripe_account(
- Disable payment processing
- Warn about pending payments
"""
- return {
- "message": "Stripe disconnection coming in Slice 5"
- }
+ return {"message": "Stripe disconnection coming in Slice 5"}
@router.get("/methods")
def get_payment_methods(
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Get accepted payment methods for vendor.
@@ -116,17 +111,14 @@ def get_payment_methods(
- Return list of enabled payment methods
- Include: credit card, PayPal, bank transfer, etc.
"""
- return {
- "methods": [],
- "message": "Payment methods coming in Slice 5"
- }
+ return {"methods": [], "message": "Payment methods coming in Slice 5"}
@router.get("/transactions")
def get_payment_transactions(
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Get payment transaction history.
@@ -140,15 +132,15 @@ def get_payment_transactions(
return {
"transactions": [],
"total": 0,
- "message": "Payment transactions coming in Slice 5"
+ "message": "Payment transactions coming in Slice 5",
}
@router.get("/balance")
def get_payment_balance(
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Get vendor payment balance and payout information.
@@ -164,17 +156,17 @@ def get_payment_balance(
"pending_balance": 0.0,
"currency": "EUR",
"next_payout_date": None,
- "message": "Payment balance coming in Slice 5"
+ "message": "Payment balance coming in Slice 5",
}
@router.post("/refund/{payment_id}")
def refund_payment(
- payment_id: int,
- refund_data: dict,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ payment_id: int,
+ refund_data: dict,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Process payment refund.
@@ -185,6 +177,4 @@ def refund_payment(
- Update order status
- Send refund notification to customer
"""
- return {
- "message": "Payment refund coming in Slice 5"
- }
+ return {"message": "Payment refund coming in Slice 5"}
diff --git a/app/api/v1/vendor/products.py b/app/api/v1/vendor/products.py
index d229ca57..f2263ef3 100644
--- a/app/api/v1/vendor/products.py
+++ b/app/api/v1/vendor/products.py
@@ -11,17 +11,13 @@ from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api
from app.core.database import get_db
-from middleware.vendor_context import require_vendor_context
from app.services.product_service import product_service
-from models.schema.product import (
- ProductCreate,
- ProductUpdate,
- ProductResponse,
- ProductDetailResponse,
- ProductListResponse
-)
+from middleware.vendor_context import require_vendor_context
from models.database.user import User
from models.database.vendor import Vendor
+from models.schema.product import (ProductCreate, ProductDetailResponse,
+ ProductListResponse, ProductResponse,
+ ProductUpdate)
router = APIRouter(prefix="/products")
logger = logging.getLogger(__name__)
@@ -29,13 +25,13 @@ logger = logging.getLogger(__name__)
@router.get("", response_model=ProductListResponse)
def get_vendor_products(
- skip: int = Query(0, ge=0),
- limit: int = Query(100, ge=1, le=1000),
- is_active: Optional[bool] = Query(None),
- is_featured: Optional[bool] = Query(None),
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ skip: int = Query(0, ge=0),
+ limit: int = Query(100, ge=1, le=1000),
+ is_active: Optional[bool] = Query(None),
+ is_featured: Optional[bool] = Query(None),
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Get all products in vendor catalog.
@@ -50,29 +46,27 @@ def get_vendor_products(
skip=skip,
limit=limit,
is_active=is_active,
- is_featured=is_featured
+ is_featured=is_featured,
)
return ProductListResponse(
products=[ProductResponse.model_validate(p) for p in products],
total=total,
skip=skip,
- limit=limit
+ limit=limit,
)
@router.get("/{product_id}", response_model=ProductDetailResponse)
def get_product_details(
- product_id: int,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ product_id: int,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""Get detailed product information including inventory."""
product = product_service.get_product(
- db=db,
- vendor_id=vendor.id,
- product_id=product_id
+ db=db, vendor_id=vendor.id, product_id=product_id
)
return ProductDetailResponse.model_validate(product)
@@ -80,10 +74,10 @@ def get_product_details(
@router.post("", response_model=ProductResponse)
def add_product_to_catalog(
- product_data: ProductCreate,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ product_data: ProductCreate,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Add a product from marketplace to vendor catalog.
@@ -91,9 +85,7 @@ def add_product_to_catalog(
This publishes a MarketplaceProduct to the vendor's public catalog.
"""
product = product_service.create_product(
- db=db,
- vendor_id=vendor.id,
- product_data=product_data
+ db=db, vendor_id=vendor.id, product_data=product_data
)
logger.info(
@@ -106,18 +98,15 @@ def add_product_to_catalog(
@router.put("/{product_id}", response_model=ProductResponse)
def update_product(
- product_id: int,
- product_data: ProductUpdate,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ product_id: int,
+ product_data: ProductUpdate,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""Update product in vendor catalog."""
product = product_service.update_product(
- db=db,
- vendor_id=vendor.id,
- product_id=product_id,
- product_update=product_data
+ db=db, vendor_id=vendor.id, product_id=product_id, product_update=product_data
)
logger.info(
@@ -130,17 +119,13 @@ def update_product(
@router.delete("/{product_id}")
def remove_product_from_catalog(
- product_id: int,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ product_id: int,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""Remove product from vendor catalog."""
- product_service.delete_product(
- db=db,
- vendor_id=vendor.id,
- product_id=product_id
- )
+ product_service.delete_product(db=db, vendor_id=vendor.id, product_id=product_id)
logger.info(
f"Product {product_id} removed from catalog by user {current_user.username} "
@@ -152,10 +137,10 @@ def remove_product_from_catalog(
@router.post("/from-import/{marketplace_product_id}", response_model=ProductResponse)
def publish_from_marketplace(
- marketplace_product_id: int,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ marketplace_product_id: int,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""
Publish a marketplace product to vendor catalog.
@@ -163,14 +148,11 @@ def publish_from_marketplace(
Shortcut endpoint for publishing directly from marketplace import.
"""
product_data = ProductCreate(
- marketplace_product_id=marketplace_product_id,
- is_active=True
+ marketplace_product_id=marketplace_product_id, is_active=True
)
product = product_service.create_product(
- db=db,
- vendor_id=vendor.id,
- product_data=product_data
+ db=db, vendor_id=vendor.id, product_data=product_data
)
logger.info(
@@ -183,10 +165,10 @@ def publish_from_marketplace(
@router.put("/{product_id}/toggle-active")
def toggle_product_active(
- product_id: int,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ product_id: int,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""Toggle product active status."""
product = product_service.get_product(db, vendor.id, product_id)
@@ -198,18 +180,15 @@ def toggle_product_active(
status = "activated" if product.is_active else "deactivated"
logger.info(f"Product {product_id} {status} for vendor {vendor.vendor_code}")
- return {
- "message": f"Product {status}",
- "is_active": product.is_active
- }
+ return {"message": f"Product {status}", "is_active": product.is_active}
@router.put("/{product_id}/toggle-featured")
def toggle_product_featured(
- product_id: int,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ product_id: int,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""Toggle product featured status."""
product = product_service.get_product(db, vendor.id, product_id)
@@ -221,7 +200,4 @@ def toggle_product_featured(
status = "featured" if product.is_featured else "unfeatured"
logger.info(f"Product {product_id} {status} for vendor {vendor.vendor_code}")
- return {
- "message": f"Product {status}",
- "is_featured": product.is_featured
- }
+ return {"message": f"Product {status}", "is_featured": product.is_featured}
diff --git a/app/api/v1/vendor/profile.py b/app/api/v1/vendor/profile.py
index b04f88bf..2161b493 100644
--- a/app/api/v1/vendor/profile.py
+++ b/app/api/v1/vendor/profile.py
@@ -4,16 +4,17 @@ Vendor profile management endpoints.
"""
import logging
+
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api
from app.core.database import get_db
-from middleware.vendor_context import require_vendor_context
from app.services.vendor_service import vendor_service
-from models.schema.vendor import VendorUpdate, VendorResponse
+from middleware.vendor_context import require_vendor_context
from models.database.user import User
from models.database.vendor import Vendor
+from models.schema.vendor import VendorResponse, VendorUpdate
router = APIRouter(prefix="/profile")
logger = logging.getLogger(__name__)
@@ -21,9 +22,9 @@ logger = logging.getLogger(__name__)
@router.get("", response_model=VendorResponse)
def get_vendor_profile(
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""Get current vendor profile information."""
return vendor
@@ -31,10 +32,10 @@ def get_vendor_profile(
@router.put("", response_model=VendorResponse)
def update_vendor_profile(
- vendor_update: VendorUpdate,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ vendor_update: VendorUpdate,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""Update vendor profile information."""
# Verify user has permission to update vendor
diff --git a/app/api/v1/vendor/settings.py b/app/api/v1/vendor/settings.py
index 7391bd91..0ab97b5c 100644
--- a/app/api/v1/vendor/settings.py
+++ b/app/api/v1/vendor/settings.py
@@ -4,13 +4,14 @@ Vendor settings and configuration endpoints.
"""
import logging
+
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api
from app.core.database import get_db
-from middleware.vendor_context import require_vendor_context
from app.services.vendor_service import vendor_service
+from middleware.vendor_context import require_vendor_context
from models.database.user import User
from models.database.vendor import Vendor
@@ -20,9 +21,9 @@ logger = logging.getLogger(__name__)
@router.get("")
def get_vendor_settings(
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""Get vendor settings and configuration."""
return {
@@ -44,10 +45,10 @@ def get_vendor_settings(
@router.put("/marketplace")
def update_marketplace_settings(
- marketplace_config: dict,
- vendor: Vendor = Depends(require_vendor_context()),
- current_user: User = Depends(get_current_vendor_api),
- db: Session = Depends(get_db),
+ marketplace_config: dict,
+ vendor: Vendor = Depends(require_vendor_context()),
+ current_user: User = Depends(get_current_vendor_api),
+ db: Session = Depends(get_db),
):
"""Update marketplace integration settings."""
# Verify permissions
diff --git a/app/api/v1/vendor/team.py b/app/api/v1/vendor/team.py
index ad4108e2..fe84c858 100644
--- a/app/api/v1/vendor/team.py
+++ b/app/api/v1/vendor/team.py
@@ -12,35 +12,24 @@ Implements complete team management with:
import logging
from typing import List
+
from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session
+from app.api.deps import (get_current_vendor_api, get_user_permissions,
+ require_vendor_owner, require_vendor_permission)
from app.core.database import get_db
from app.core.permissions import VendorPermissions
-from app.api.deps import (
- get_current_vendor_api,
- require_vendor_owner,
- require_vendor_permission,
- get_user_permissions
-)
from app.services.vendor_team_service import vendor_team_service
from models.database.user import User
from models.database.vendor import Vendor
-from models.schema.team import (
- TeamMemberInvite,
- TeamMemberUpdate,
- TeamMemberResponse,
- TeamMemberListResponse,
- InvitationAccept,
- InvitationResponse,
- InvitationAcceptResponse,
- RoleResponse,
- RoleListResponse,
- UserPermissionsResponse,
- TeamStatistics,
- BulkRemoveRequest,
- BulkRemoveResponse,
-)
+from models.schema.team import (BulkRemoveRequest, BulkRemoveResponse,
+ InvitationAccept, InvitationAcceptResponse,
+ InvitationResponse, RoleListResponse,
+ RoleResponse, TeamMemberInvite,
+ TeamMemberListResponse, TeamMemberResponse,
+ TeamMemberUpdate, TeamStatistics,
+ UserPermissionsResponse)
router = APIRouter(prefix="/team")
logger = logging.getLogger(__name__)
@@ -50,14 +39,15 @@ logger = logging.getLogger(__name__)
# Team Member Routes
# ============================================================================
+
@router.get("/members", response_model=TeamMemberListResponse)
def list_team_members(
- request: Request,
- include_inactive: bool = False,
- db: Session = Depends(get_db),
- current_user: User = Depends(require_vendor_permission(
- VendorPermissions.TEAM_VIEW.value
- ))
+ request: Request,
+ include_inactive: bool = False,
+ db: Session = Depends(get_db),
+ current_user: User = Depends(
+ require_vendor_permission(VendorPermissions.TEAM_VIEW.value)
+ ),
):
"""
Get all team members for current vendor.
@@ -74,9 +64,7 @@ def list_team_members(
vendor = request.state.vendor
members = vendor_team_service.get_team_members(
- db=db,
- vendor=vendor,
- include_inactive=include_inactive
+ db=db, vendor=vendor, include_inactive=include_inactive
)
# Calculate statistics
@@ -90,19 +78,16 @@ def list_team_members(
)
return TeamMemberListResponse(
- members=members,
- total=total,
- active_count=active,
- pending_invitations=pending
+ members=members, total=total, active_count=active, pending_invitations=pending
)
@router.post("/invite", response_model=InvitationResponse)
def invite_team_member(
- invitation: TeamMemberInvite,
- request: Request,
- db: Session = Depends(get_db),
- current_user: User = Depends(require_vendor_owner) # Owner only
+ invitation: TeamMemberInvite,
+ request: Request,
+ db: Session = Depends(get_db),
+ current_user: User = Depends(require_vendor_owner), # Owner only
):
"""
Invite a new team member to the vendor.
@@ -135,7 +120,7 @@ def invite_team_member(
vendor=vendor,
inviter=current_user,
email=invitation.email,
- role_id=invitation.role_id
+ role_id=invitation.role_id,
)
elif invitation.role_name:
# Use role name with optional custom permissions
@@ -145,7 +130,7 @@ def invite_team_member(
inviter=current_user,
email=invitation.email,
role_name=invitation.role_name,
- custom_permissions=invitation.custom_permissions
+ custom_permissions=invitation.custom_permissions,
)
else:
# Default to Staff role
@@ -154,7 +139,7 @@ def invite_team_member(
vendor=vendor,
inviter=current_user,
email=invitation.email,
- role_name="staff"
+ role_name="staff",
)
logger.info(
@@ -166,15 +151,12 @@ def invite_team_member(
message="Invitation sent successfully",
email=result["email"],
role=result["role"],
- invitation_sent=True
+ invitation_sent=True,
)
@router.post("/accept-invitation", response_model=InvitationAcceptResponse)
-def accept_invitation(
- acceptance: InvitationAccept,
- db: Session = Depends(get_db)
-):
+def accept_invitation(acceptance: InvitationAccept, db: Session = Depends(get_db)):
"""
Accept a team invitation and activate account.
@@ -196,7 +178,7 @@ def accept_invitation(
invitation_token=acceptance.invitation_token,
password=acceptance.password,
first_name=acceptance.first_name,
- last_name=acceptance.last_name
+ last_name=acceptance.last_name,
)
logger.info(
@@ -210,26 +192,26 @@ def accept_invitation(
"id": result["vendor"].id,
"vendor_code": result["vendor"].vendor_code,
"name": result["vendor"].name,
- "subdomain": result["vendor"].subdomain
+ "subdomain": result["vendor"].subdomain,
},
user={
"id": result["user"].id,
"email": result["user"].email,
"username": result["user"].username,
- "full_name": result["user"].full_name
+ "full_name": result["user"].full_name,
},
- role=result["role"]
+ role=result["role"],
)
@router.get("/members/{user_id}", response_model=TeamMemberResponse)
def get_team_member(
- user_id: int,
- request: Request,
- db: Session = Depends(get_db),
- current_user: User = Depends(require_vendor_permission(
- VendorPermissions.TEAM_VIEW.value
- ))
+ user_id: int,
+ request: Request,
+ db: Session = Depends(get_db),
+ current_user: User = Depends(
+ require_vendor_permission(VendorPermissions.TEAM_VIEW.value)
+ ),
):
"""
Get details of a specific team member.
@@ -239,14 +221,13 @@ def get_team_member(
vendor = request.state.vendor
members = vendor_team_service.get_team_members(
- db=db,
- vendor=vendor,
- include_inactive=True
+ db=db, vendor=vendor, include_inactive=True
)
member = next((m for m in members if m["id"] == user_id), None)
if not member:
from app.exceptions import UserNotFoundException
+
raise UserNotFoundException(str(user_id))
return TeamMemberResponse(**member)
@@ -254,11 +235,11 @@ def get_team_member(
@router.put("/members/{user_id}", response_model=TeamMemberResponse)
def update_team_member(
- user_id: int,
- update_data: TeamMemberUpdate,
- request: Request,
- db: Session = Depends(get_db),
- current_user: User = Depends(require_vendor_owner) # Owner only
+ user_id: int,
+ update_data: TeamMemberUpdate,
+ request: Request,
+ db: Session = Depends(get_db),
+ current_user: User = Depends(require_vendor_owner), # Owner only
):
"""
Update a team member's role or status.
@@ -280,7 +261,7 @@ def update_team_member(
vendor=vendor,
user_id=user_id,
new_role_id=update_data.role_id,
- is_active=update_data.is_active
+ is_active=update_data.is_active,
)
logger.info(
@@ -297,10 +278,10 @@ def update_team_member(
@router.delete("/members/{user_id}")
def remove_team_member(
- user_id: int,
- request: Request,
- db: Session = Depends(get_db),
- current_user: User = Depends(require_vendor_owner) # Owner only
+ user_id: int,
+ request: Request,
+ db: Session = Depends(get_db),
+ current_user: User = Depends(require_vendor_owner), # Owner only
):
"""
Remove a team member from the vendor.
@@ -316,29 +297,22 @@ def remove_team_member(
"""
vendor = request.state.vendor
- vendor_team_service.remove_team_member(
- db=db,
- vendor=vendor,
- user_id=user_id
- )
+ vendor_team_service.remove_team_member(db=db, vendor=vendor, user_id=user_id)
logger.info(
f"Team member removed: {user_id} from {vendor.vendor_code} "
f"by {current_user.username}"
)
- return {
- "message": "Team member removed successfully",
- "user_id": user_id
- }
+ return {"message": "Team member removed successfully", "user_id": user_id}
@router.post("/members/bulk-remove", response_model=BulkRemoveResponse)
def bulk_remove_team_members(
- bulk_remove: BulkRemoveRequest,
- request: Request,
- db: Session = Depends(get_db),
- current_user: User = Depends(require_vendor_owner)
+ bulk_remove: BulkRemoveRequest,
+ request: Request,
+ db: Session = Depends(get_db),
+ current_user: User = Depends(require_vendor_owner),
):
"""
Remove multiple team members at once.
@@ -354,17 +328,12 @@ def bulk_remove_team_members(
for user_id in bulk_remove.user_ids:
try:
vendor_team_service.remove_team_member(
- db=db,
- vendor=vendor,
- user_id=user_id
+ db=db, vendor=vendor, user_id=user_id
)
success_count += 1
except Exception as e:
failed_count += 1
- errors.append({
- "user_id": user_id,
- "error": str(e)
- })
+ errors.append({"user_id": user_id, "error": str(e)})
logger.info(
f"Bulk remove completed: {success_count} removed, {failed_count} failed "
@@ -372,9 +341,7 @@ def bulk_remove_team_members(
)
return BulkRemoveResponse(
- success_count=success_count,
- failed_count=failed_count,
- errors=errors
+ success_count=success_count, failed_count=failed_count, errors=errors
)
@@ -382,13 +349,14 @@ def bulk_remove_team_members(
# Role Management Routes
# ============================================================================
+
@router.get("/roles", response_model=RoleListResponse)
def list_roles(
- request: Request,
- db: Session = Depends(get_db),
- current_user: User = Depends(require_vendor_permission(
- VendorPermissions.TEAM_VIEW.value
- ))
+ request: Request,
+ db: Session = Depends(get_db),
+ current_user: User = Depends(
+ require_vendor_permission(VendorPermissions.TEAM_VIEW.value)
+ ),
):
"""
Get all available roles for the vendor.
@@ -403,21 +371,19 @@ def list_roles(
roles = vendor_team_service.get_vendor_roles(db=db, vendor_id=vendor.id)
- return RoleListResponse(
- roles=roles,
- total=len(roles)
- )
+ return RoleListResponse(roles=roles, total=len(roles))
# ============================================================================
# Permission Routes
# ============================================================================
+
@router.get("/me/permissions", response_model=UserPermissionsResponse)
def get_my_permissions(
- request: Request,
- permissions: List[str] = Depends(get_user_permissions),
- current_user: User = Depends(get_current_vendor_api)
+ request: Request,
+ permissions: List[str] = Depends(get_user_permissions),
+ current_user: User = Depends(get_current_vendor_api),
):
"""
Get current user's permissions in this vendor.
@@ -443,7 +409,7 @@ def get_my_permissions(
permissions=permissions,
permission_count=len(permissions),
is_owner=is_owner,
- role_name=role_name
+ role_name=role_name,
)
@@ -451,13 +417,14 @@ def get_my_permissions(
# Statistics Routes
# ============================================================================
+
@router.get("/statistics", response_model=TeamStatistics)
def get_team_statistics(
- request: Request,
- db: Session = Depends(get_db),
- current_user: User = Depends(require_vendor_permission(
- VendorPermissions.TEAM_VIEW.value
- ))
+ request: Request,
+ db: Session = Depends(get_db),
+ current_user: User = Depends(
+ require_vendor_permission(VendorPermissions.TEAM_VIEW.value)
+ ),
):
"""
Get team statistics for the vendor.
@@ -474,9 +441,7 @@ def get_team_statistics(
vendor = request.state.vendor
members = vendor_team_service.get_team_members(
- db=db,
- vendor=vendor,
- include_inactive=True
+ db=db, vendor=vendor, include_inactive=True
)
# Calculate statistics
@@ -500,5 +465,5 @@ def get_team_statistics(
pending_invitations=pending,
owners=owners,
team_members=team_members,
- roles_breakdown=roles_breakdown
+ roles_breakdown=roles_breakdown,
)
diff --git a/app/core/config.py b/app/core/config.py
index 59a5793b..ed0b2641 100644
--- a/app/core/config.py
+++ b/app/core/config.py
@@ -14,6 +14,7 @@ This module focuses purely on configuration storage and validation.
"""
from typing import List, Optional
+
from pydantic_settings import BaseSettings
@@ -137,13 +138,9 @@ settings = Settings()
# ENVIRONMENT UTILITIES - Module-level functions
# =============================================================================
# Import environment detection utilities
-from app.core.environment import (
- get_environment,
- is_development,
- is_production,
- is_staging,
- should_use_secure_cookies
-)
+from app.core.environment import (get_environment, is_development,
+ is_production, is_staging,
+ should_use_secure_cookies)
def get_current_environment() -> str:
@@ -190,6 +187,7 @@ def is_staging_environment() -> bool:
# VALIDATION FUNCTIONS
# =============================================================================
+
def validate_production_settings() -> List[str]:
"""
Validate settings for production environment.
@@ -243,22 +241,19 @@ def print_environment_info():
# =============================================================================
__all__ = [
# Settings singleton
- 'settings',
-
+ "settings",
# Environment detection (re-exported from app.core.environment)
- 'get_environment',
- 'is_development',
- 'is_production',
- 'is_staging',
- 'should_use_secure_cookies',
-
+ "get_environment",
+ "is_development",
+ "is_production",
+ "is_staging",
+ "should_use_secure_cookies",
# Convenience functions
- 'get_current_environment',
- 'is_production_environment',
- 'is_development_environment',
- 'is_staging_environment',
-
+ "get_current_environment",
+ "is_production_environment",
+ "is_development_environment",
+ "is_staging_environment",
# Validation
- 'validate_production_settings',
- 'print_environment_info',
+ "validate_production_settings",
+ "print_environment_info",
]
diff --git a/app/core/environment.py b/app/core/environment.py
index e34e6fc4..0a430311 100644
--- a/app/core/environment.py
+++ b/app/core/environment.py
@@ -15,7 +15,7 @@ EnvironmentType = Literal["development", "staging", "production"]
def get_environment() -> EnvironmentType:
"""
Detect current environment automatically.
-
+
Detection logic:
1. Check ENV environment variable if set
2. Check ENVIRONMENT environment variable if set
@@ -23,7 +23,7 @@ def get_environment() -> EnvironmentType:
- localhost, 127.0.0.1 → development
- Contains 'staging' → staging
- Otherwise → production (safe default)
-
+
Returns:
str: 'development', 'staging', or 'production'
"""
@@ -35,7 +35,7 @@ def get_environment() -> EnvironmentType:
return "staging"
elif env in ["production", "prod"]:
return "production"
-
+
# Priority 2: ENVIRONMENT variable
env = os.getenv("ENVIRONMENT", "").lower()
if env in ["development", "dev", "local"]:
@@ -44,22 +44,25 @@ def get_environment() -> EnvironmentType:
return "staging"
elif env in ["production", "prod"]:
return "production"
-
+
# Priority 3: Auto-detect from common indicators
-
+
# Check if running in debug mode (common in development)
if os.getenv("DEBUG", "").lower() in ["true", "1", "yes"]:
return "development"
-
+
# Check common development indicators
hostname = os.getenv("HOSTNAME", "").lower()
- if any(dev_indicator in hostname for dev_indicator in ["local", "dev", "laptop", "desktop"]):
+ if any(
+ dev_indicator in hostname
+ for dev_indicator in ["local", "dev", "laptop", "desktop"]
+ ):
return "development"
-
+
# Check for staging indicators
if "staging" in hostname or "stage" in hostname:
return "staging"
-
+
# Default to development for safety (HTTPS not required in dev)
# Change this to "production" if you prefer secure-by-default
return "development"
@@ -83,7 +86,7 @@ def is_production() -> bool:
def should_use_secure_cookies() -> bool:
"""
Determine if cookies should have secure flag (HTTPS only).
-
+
Returns:
bool: True if production or staging, False if development
"""
@@ -97,7 +100,7 @@ _cached_environment: EnvironmentType | None = None
def get_cached_environment() -> EnvironmentType:
"""
Get environment with caching.
-
+
Environment is detected once and cached for performance.
Useful if you call this frequently.
"""
diff --git a/app/core/lifespan.py b/app/core/lifespan.py
index 6c0731c6..f26a8363 100644
--- a/app/core/lifespan.py
+++ b/app/core/lifespan.py
@@ -15,11 +15,13 @@ from fastapi import FastAPI
from sqlalchemy import text
from middleware.auth import AuthManager
-# Remove this import if not needed: from models.database.base import Base
from .database import SessionLocal, engine
from .logging import setup_logging
+# Remove this import if not needed: from models.database.base import Base
+
+
logger = logging.getLogger(__name__)
auth_manager = AuthManager()
@@ -46,7 +48,9 @@ def check_database_ready():
try:
with engine.connect() as conn:
# Try to query a table that should exist
- result = conn.execute(text("SELECT name FROM sqlite_master WHERE type='table' LIMIT 1"))
+ result = conn.execute(
+ text("SELECT name FROM sqlite_master WHERE type='table' LIMIT 1")
+ )
tables = result.fetchall()
return len(tables) > 0
except Exception:
@@ -93,6 +97,7 @@ def verify_startup_requirements():
logger.info("[OK] Startup verification passed")
return True
+
# You can call this in your main.py if desired:
# if not verify_startup_requirements():
# raise RuntimeError("Application startup requirements not met")
diff --git a/app/core/permissions.py b/app/core/permissions.py
index 5651abf9..2a04968c 100644
--- a/app/core/permissions.py
+++ b/app/core/permissions.py
@@ -17,6 +17,7 @@ class VendorPermissions(str, Enum):
Naming convention: RESOURCE_ACTION
"""
+
# Dashboard
DASHBOARD_VIEW = "dashboard.view"
@@ -166,17 +167,23 @@ class PermissionChecker:
return required_permission in permissions
@staticmethod
- def has_any_permission(permissions: List[str], required_permissions: List[str]) -> bool:
+ def has_any_permission(
+ permissions: List[str], required_permissions: List[str]
+ ) -> bool:
"""Check if a permission list contains ANY of the required permissions."""
return any(perm in permissions for perm in required_permissions)
@staticmethod
- def has_all_permissions(permissions: List[str], required_permissions: List[str]) -> bool:
+ def has_all_permissions(
+ permissions: List[str], required_permissions: List[str]
+ ) -> bool:
"""Check if a permission list contains ALL of the required permissions."""
return all(perm in permissions for perm in required_permissions)
@staticmethod
- def get_missing_permissions(permissions: List[str], required_permissions: List[str]) -> List[str]:
+ def get_missing_permissions(
+ permissions: List[str], required_permissions: List[str]
+ ) -> List[str]:
"""Get list of missing permissions."""
return [perm for perm in required_permissions if perm not in permissions]
diff --git a/app/core/theme_presets.py b/app/core/theme_presets.py
index 7b091ad3..20eea4a4 100644
--- a/app/core/theme_presets.py
+++ b/app/core/theme_presets.py
@@ -16,19 +16,11 @@ THEME_PRESETS = {
"accent": "#ec4899", # Pink
"background": "#ffffff", # White
"text": "#1f2937", # Gray-800
- "border": "#e5e7eb" # Gray-200
+ "border": "#e5e7eb", # Gray-200
},
- "fonts": {
- "heading": "Inter, sans-serif",
- "body": "Inter, sans-serif"
- },
- "layout": {
- "style": "grid",
- "header": "fixed",
- "product_card": "modern"
- }
+ "fonts": {"heading": "Inter, sans-serif", "body": "Inter, sans-serif"},
+ "layout": {"style": "grid", "header": "fixed", "product_card": "modern"},
},
-
"modern": {
"colors": {
"primary": "#6366f1", # Indigo - Modern tech look
@@ -36,19 +28,11 @@ THEME_PRESETS = {
"accent": "#ec4899", # Pink
"background": "#ffffff", # White
"text": "#1f2937", # Gray-800
- "border": "#e5e7eb" # Gray-200
+ "border": "#e5e7eb", # Gray-200
},
- "fonts": {
- "heading": "Inter, sans-serif",
- "body": "Inter, sans-serif"
- },
- "layout": {
- "style": "grid",
- "header": "fixed",
- "product_card": "modern"
- }
+ "fonts": {"heading": "Inter, sans-serif", "body": "Inter, sans-serif"},
+ "layout": {"style": "grid", "header": "fixed", "product_card": "modern"},
},
-
"classic": {
"colors": {
"primary": "#1e40af", # Dark blue - Traditional
@@ -56,19 +40,11 @@ THEME_PRESETS = {
"accent": "#dc2626", # Red
"background": "#ffffff", # White
"text": "#1f2937", # Gray-800
- "border": "#d1d5db" # Gray-300
+ "border": "#d1d5db", # Gray-300
},
- "fonts": {
- "heading": "Georgia, serif",
- "body": "Arial, sans-serif"
- },
- "layout": {
- "style": "list",
- "header": "static",
- "product_card": "classic"
- }
+ "fonts": {"heading": "Georgia, serif", "body": "Arial, sans-serif"},
+ "layout": {"style": "list", "header": "static", "product_card": "classic"},
},
-
"minimal": {
"colors": {
"primary": "#000000", # Black - Ultra minimal
@@ -76,19 +52,11 @@ THEME_PRESETS = {
"accent": "#666666", # Medium gray
"background": "#ffffff", # White
"text": "#000000", # Black
- "border": "#e5e7eb" # Light gray
+ "border": "#e5e7eb", # Light gray
},
- "fonts": {
- "heading": "Helvetica, sans-serif",
- "body": "Helvetica, sans-serif"
- },
- "layout": {
- "style": "grid",
- "header": "transparent",
- "product_card": "minimal"
- }
+ "fonts": {"heading": "Helvetica, sans-serif", "body": "Helvetica, sans-serif"},
+ "layout": {"style": "grid", "header": "transparent", "product_card": "minimal"},
},
-
"vibrant": {
"colors": {
"primary": "#f59e0b", # Orange - Bold & energetic
@@ -96,19 +64,11 @@ THEME_PRESETS = {
"accent": "#8b5cf6", # Purple
"background": "#ffffff", # White
"text": "#1f2937", # Gray-800
- "border": "#fbbf24" # Yellow
+ "border": "#fbbf24", # Yellow
},
- "fonts": {
- "heading": "Poppins, sans-serif",
- "body": "Open Sans, sans-serif"
- },
- "layout": {
- "style": "masonry",
- "header": "fixed",
- "product_card": "modern"
- }
+ "fonts": {"heading": "Poppins, sans-serif", "body": "Open Sans, sans-serif"},
+ "layout": {"style": "masonry", "header": "fixed", "product_card": "modern"},
},
-
"elegant": {
"colors": {
"primary": "#6b7280", # Gray - Sophisticated
@@ -116,19 +76,11 @@ THEME_PRESETS = {
"accent": "#d97706", # Amber
"background": "#ffffff", # White
"text": "#1f2937", # Gray-800
- "border": "#e5e7eb" # Gray-200
+ "border": "#e5e7eb", # Gray-200
},
- "fonts": {
- "heading": "Playfair Display, serif",
- "body": "Lato, sans-serif"
- },
- "layout": {
- "style": "grid",
- "header": "fixed",
- "product_card": "classic"
- }
+ "fonts": {"heading": "Playfair Display, serif", "body": "Lato, sans-serif"},
+ "layout": {"style": "grid", "header": "fixed", "product_card": "classic"},
},
-
"nature": {
"colors": {
"primary": "#059669", # Green - Natural & eco
@@ -136,18 +88,11 @@ THEME_PRESETS = {
"accent": "#f59e0b", # Amber
"background": "#ffffff", # White
"text": "#1f2937", # Gray-800
- "border": "#d1fae5" # Light green
+ "border": "#d1fae5", # Light green
},
- "fonts": {
- "heading": "Montserrat, sans-serif",
- "body": "Open Sans, sans-serif"
- },
- "layout": {
- "style": "grid",
- "header": "fixed",
- "product_card": "modern"
- }
- }
+ "fonts": {"heading": "Montserrat, sans-serif", "body": "Open Sans, sans-serif"},
+ "layout": {"style": "grid", "header": "fixed", "product_card": "modern"},
+ },
}
@@ -243,7 +188,7 @@ def get_preset_preview(preset_name: str) -> dict:
"minimal": "Ultra-clean black and white aesthetic",
"vibrant": "Bold and energetic with bright accent colors",
"elegant": "Sophisticated gray tones with refined typography",
- "nature": "Fresh and eco-friendly green color palette"
+ "nature": "Fresh and eco-friendly green color palette",
}
return {
@@ -259,10 +204,7 @@ def get_preset_preview(preset_name: str) -> dict:
def create_custom_preset(
- colors: dict,
- fonts: dict,
- layout: dict,
- name: str = "custom"
+ colors: dict, fonts: dict, layout: dict, name: str = "custom"
) -> dict:
"""
Create a custom preset from provided settings.
@@ -304,8 +246,4 @@ def create_custom_preset(
if "product_card" not in layout:
layout["product_card"] = "modern"
- return {
- "colors": colors,
- "fonts": fonts,
- "layout": layout
- }
+ return {"colors": colors, "fonts": fonts, "layout": layout}
diff --git a/app/exceptions/__init__.py b/app/exceptions/__init__.py
index 19f5fd2c..fa7f17f1 100644
--- a/app/exceptions/__init__.py
+++ b/app/exceptions/__init__.py
@@ -6,179 +6,109 @@ This module provides frontend-friendly exceptions with consistent error codes,
messages, and HTTP status mappings.
"""
-# Base exceptions
-from .base import (
- WizamartException,
- ValidationException,
- AuthenticationException,
- AuthorizationException,
- ResourceNotFoundException,
- ConflictException,
- BusinessLogicException,
- ExternalServiceException,
- RateLimitException,
- ServiceUnavailableException,
-)
-
-# Authentication exceptions
-from .auth import (
- InvalidCredentialsException,
- TokenExpiredException,
- InvalidTokenException,
- InsufficientPermissionsException,
- UserNotActiveException,
- AdminRequiredException,
- UserAlreadyExistsException
-)
-
# Admin exceptions
-from .admin import (
- UserNotFoundException,
- UserStatusChangeException,
- VendorVerificationException,
- AdminOperationException,
- CannotModifyAdminException,
- CannotModifySelfException,
- InvalidAdminActionException,
- BulkOperationException,
- ConfirmationRequiredException,
-)
-
-# Marketplace import job exceptions
-from .marketplace_import_job import (
- MarketplaceImportException,
- ImportJobNotFoundException,
- ImportJobNotOwnedException,
- InvalidImportDataException,
- ImportJobCannotBeCancelledException,
- ImportJobCannotBeDeletedException,
- MarketplaceConnectionException,
- MarketplaceDataParsingException,
- ImportRateLimitException,
- InvalidMarketplaceException,
- ImportJobAlreadyProcessingException,
-)
-
-# Marketplace product exceptions
-from .marketplace_product import (
- MarketplaceProductNotFoundException,
- MarketplaceProductAlreadyExistsException,
- InvalidMarketplaceProductDataException,
- MarketplaceProductValidationException,
- InvalidGTINException,
- MarketplaceProductCSVImportException,
-)
-
-# Inventory exceptions
-from .inventory import (
- InventoryNotFoundException,
- InsufficientInventoryException,
- InvalidInventoryOperationException,
- InventoryValidationException,
- NegativeInventoryException,
- InvalidQuantityException,
- LocationNotFoundException
-)
-
-# Vendor exceptions
-from .vendor import (
- VendorNotFoundException,
- VendorAlreadyExistsException,
- VendorNotActiveException,
- VendorNotVerifiedException,
- UnauthorizedVendorAccessException,
- InvalidVendorDataException,
- MaxVendorsReachedException,
- VendorValidationException,
-)
-
-# Vendor domain exceptions
-from .vendor_domain import (
- VendorDomainNotFoundException,
- VendorDomainAlreadyExistsException,
- InvalidDomainFormatException,
- ReservedDomainException,
- DomainNotVerifiedException,
- DomainVerificationFailedException,
- DomainAlreadyVerifiedException,
- MultiplePrimaryDomainsException,
- DNSVerificationException,
- MaxDomainsReachedException,
- UnauthorizedDomainAccessException,
-)
-
-# Vendor theme exceptions
-from .vendor_theme import (
- VendorThemeNotFoundException,
- InvalidThemeDataException,
- ThemePresetNotFoundException,
- ThemeValidationException,
- ThemePresetAlreadyAppliedException,
- InvalidColorFormatException,
- InvalidFontFamilyException,
- ThemeOperationException,
-)
-
-# Customer exceptions
-from .customer import (
- CustomerNotFoundException,
- CustomerAlreadyExistsException,
- DuplicateCustomerEmailException,
- CustomerNotActiveException,
- InvalidCustomerCredentialsException,
- CustomerValidationException,
- CustomerAuthorizationException,
-)
-
-# Team exceptions
-from .team import (
- TeamMemberNotFoundException,
- TeamMemberAlreadyExistsException,
- TeamInvitationNotFoundException,
- TeamInvitationExpiredException,
- TeamInvitationAlreadyAcceptedException,
- UnauthorizedTeamActionException,
- CannotRemoveOwnerException,
- CannotModifyOwnRoleException,
- RoleNotFoundException,
- InvalidRoleException,
- InsufficientTeamPermissionsException,
- MaxTeamMembersReachedException,
- TeamValidationException,
- InvalidInvitationDataException,
- InvalidInvitationTokenException,
-)
-
-# Product exceptions
-from .product import (
- ProductNotFoundException,
- ProductAlreadyExistsException,
- ProductNotInCatalogException,
- ProductNotActiveException,
- InvalidProductDataException,
- ProductValidationException,
- CannotDeleteProductWithInventoryException,
- CannotDeleteProductWithOrdersException,
-)
-
-# Order exceptions
-from .order import (
- OrderNotFoundException,
- OrderAlreadyExistsException,
- OrderValidationException,
- InvalidOrderStatusException,
- OrderCannotBeCancelledException,
-)
-
+from .admin import (AdminOperationException, BulkOperationException,
+ CannotModifyAdminException, CannotModifySelfException,
+ ConfirmationRequiredException, InvalidAdminActionException,
+ UserNotFoundException, UserStatusChangeException,
+ VendorVerificationException)
+# Authentication exceptions
+from .auth import (AdminRequiredException, InsufficientPermissionsException,
+ InvalidCredentialsException, InvalidTokenException,
+ TokenExpiredException, UserAlreadyExistsException,
+ UserNotActiveException)
+# Base exceptions
+from .base import (AuthenticationException, AuthorizationException,
+ BusinessLogicException, ConflictException,
+ ExternalServiceException, RateLimitException,
+ ResourceNotFoundException, ServiceUnavailableException,
+ ValidationException, WizamartException)
# Cart exceptions
-from .cart import (
- CartItemNotFoundException,
- EmptyCartException,
- CartValidationException,
- InsufficientInventoryForCartException,
- InvalidCartQuantityException,
- ProductNotAvailableForCartException,
-)
+from .cart import (CartItemNotFoundException, CartValidationException,
+ EmptyCartException, InsufficientInventoryForCartException,
+ InvalidCartQuantityException,
+ ProductNotAvailableForCartException)
+# Customer exceptions
+from .customer import (CustomerAlreadyExistsException,
+ CustomerAuthorizationException,
+ CustomerNotActiveException, CustomerNotFoundException,
+ CustomerValidationException,
+ DuplicateCustomerEmailException,
+ InvalidCustomerCredentialsException)
+# Inventory exceptions
+from .inventory import (InsufficientInventoryException,
+ InvalidInventoryOperationException,
+ InvalidQuantityException, InventoryNotFoundException,
+ InventoryValidationException,
+ LocationNotFoundException, NegativeInventoryException)
+# Marketplace import job exceptions
+from .marketplace_import_job import (ImportJobAlreadyProcessingException,
+ ImportJobCannotBeCancelledException,
+ ImportJobCannotBeDeletedException,
+ ImportJobNotFoundException,
+ ImportJobNotOwnedException,
+ ImportRateLimitException,
+ InvalidImportDataException,
+ InvalidMarketplaceException,
+ MarketplaceConnectionException,
+ MarketplaceDataParsingException,
+ MarketplaceImportException)
+# Marketplace product exceptions
+from .marketplace_product import (InvalidGTINException,
+ InvalidMarketplaceProductDataException,
+ MarketplaceProductAlreadyExistsException,
+ MarketplaceProductCSVImportException,
+ MarketplaceProductNotFoundException,
+ MarketplaceProductValidationException)
+# Order exceptions
+from .order import (InvalidOrderStatusException, OrderAlreadyExistsException,
+ OrderCannotBeCancelledException, OrderNotFoundException,
+ OrderValidationException)
+# Product exceptions
+from .product import (CannotDeleteProductWithInventoryException,
+ CannotDeleteProductWithOrdersException,
+ InvalidProductDataException,
+ ProductAlreadyExistsException, ProductNotActiveException,
+ ProductNotFoundException, ProductNotInCatalogException,
+ ProductValidationException)
+# Team exceptions
+from .team import (CannotModifyOwnRoleException, CannotRemoveOwnerException,
+ InsufficientTeamPermissionsException,
+ InvalidInvitationDataException,
+ InvalidInvitationTokenException, InvalidRoleException,
+ MaxTeamMembersReachedException, RoleNotFoundException,
+ TeamInvitationAlreadyAcceptedException,
+ TeamInvitationExpiredException,
+ TeamInvitationNotFoundException,
+ TeamMemberAlreadyExistsException,
+ TeamMemberNotFoundException, TeamValidationException,
+ UnauthorizedTeamActionException)
+# Vendor exceptions
+from .vendor import (InvalidVendorDataException, MaxVendorsReachedException,
+ UnauthorizedVendorAccessException,
+ VendorAlreadyExistsException, VendorNotActiveException,
+ VendorNotFoundException, VendorNotVerifiedException,
+ VendorValidationException)
+# Vendor domain exceptions
+from .vendor_domain import (DNSVerificationException,
+ DomainAlreadyVerifiedException,
+ DomainNotVerifiedException,
+ DomainVerificationFailedException,
+ InvalidDomainFormatException,
+ MaxDomainsReachedException,
+ MultiplePrimaryDomainsException,
+ ReservedDomainException,
+ UnauthorizedDomainAccessException,
+ VendorDomainAlreadyExistsException,
+ VendorDomainNotFoundException)
+# Vendor theme exceptions
+from .vendor_theme import (InvalidColorFormatException,
+ InvalidFontFamilyException,
+ InvalidThemeDataException, ThemeOperationException,
+ ThemePresetAlreadyAppliedException,
+ ThemePresetNotFoundException,
+ ThemeValidationException,
+ VendorThemeNotFoundException)
__all__ = [
# Base exceptions
@@ -192,7 +122,6 @@ __all__ = [
"ExternalServiceException",
"RateLimitException",
"ServiceUnavailableException",
-
# Auth exceptions
"InvalidCredentialsException",
"TokenExpiredException",
@@ -201,7 +130,6 @@ __all__ = [
"UserNotActiveException",
"AdminRequiredException",
"UserAlreadyExistsException",
-
# Customer exceptions
"CustomerNotFoundException",
"CustomerAlreadyExistsException",
@@ -210,7 +138,6 @@ __all__ = [
"InvalidCustomerCredentialsException",
"CustomerValidationException",
"CustomerAuthorizationException",
-
# Team exceptions
"TeamMemberNotFoundException",
"TeamMemberAlreadyExistsException",
@@ -227,7 +154,6 @@ __all__ = [
"TeamValidationException",
"InvalidInvitationDataException",
"InvalidInvitationTokenException",
-
# Inventory exceptions
"InventoryNotFoundException",
"InsufficientInventoryException",
@@ -236,7 +162,6 @@ __all__ = [
"NegativeInventoryException",
"InvalidQuantityException",
"LocationNotFoundException",
-
# Vendor exceptions
"VendorNotFoundException",
"VendorAlreadyExistsException",
@@ -246,7 +171,6 @@ __all__ = [
"InvalidVendorDataException",
"MaxVendorsReachedException",
"VendorValidationException",
-
# Vendor Domain
"VendorDomainNotFoundException",
"VendorDomainAlreadyExistsException",
@@ -259,7 +183,6 @@ __all__ = [
"DNSVerificationException",
"MaxDomainsReachedException",
"UnauthorizedDomainAccessException",
-
# Vendor Theme
"VendorThemeNotFoundException",
"InvalidThemeDataException",
@@ -269,7 +192,6 @@ __all__ = [
"InvalidColorFormatException",
"InvalidFontFamilyException",
"ThemeOperationException",
-
# Product exceptions
"ProductNotFoundException",
"ProductAlreadyExistsException",
@@ -279,14 +201,12 @@ __all__ = [
"ProductValidationException",
"CannotDeleteProductWithInventoryException",
"CannotDeleteProductWithOrdersException",
-
# Order exceptions
"OrderNotFoundException",
"OrderAlreadyExistsException",
"OrderValidationException",
"InvalidOrderStatusException",
"OrderCannotBeCancelledException",
-
# Cart exceptions
"CartItemNotFoundException",
"EmptyCartException",
@@ -294,7 +214,6 @@ __all__ = [
"InsufficientInventoryForCartException",
"InvalidCartQuantityException",
"ProductNotAvailableForCartException",
-
# MarketplaceProduct exceptions
"MarketplaceProductNotFoundException",
"MarketplaceProductAlreadyExistsException",
@@ -302,7 +221,6 @@ __all__ = [
"MarketplaceProductValidationException",
"InvalidGTINException",
"MarketplaceProductCSVImportException",
-
# Marketplace import exceptions
"MarketplaceImportException",
"ImportJobNotFoundException",
@@ -315,7 +233,6 @@ __all__ = [
"ImportRateLimitException",
"InvalidMarketplaceException",
"ImportJobAlreadyProcessingException",
-
# Admin exceptions
"UserNotFoundException",
"UserStatusChangeException",
diff --git a/app/exceptions/admin.py b/app/exceptions/admin.py
index 8df4d62e..c98abdb0 100644
--- a/app/exceptions/admin.py
+++ b/app/exceptions/admin.py
@@ -4,12 +4,9 @@ Admin operations specific exceptions.
"""
from typing import Any, Dict, Optional
-from .base import (
- ResourceNotFoundException,
- BusinessLogicException,
- AuthorizationException,
- ValidationException
-)
+
+from .base import (AuthorizationException, BusinessLogicException,
+ ResourceNotFoundException, ValidationException)
class UserNotFoundException(ResourceNotFoundException):
@@ -35,11 +32,11 @@ class UserStatusChangeException(BusinessLogicException):
"""Raised when user status cannot be changed."""
def __init__(
- self,
- user_id: int,
- current_status: str,
- attempted_action: str,
- reason: Optional[str] = None,
+ self,
+ user_id: int,
+ current_status: str,
+ attempted_action: str,
+ reason: Optional[str] = None,
):
message = f"Cannot {attempted_action} user {user_id} (current status: {current_status})"
if reason:
@@ -61,10 +58,10 @@ class ShopVerificationException(BusinessLogicException):
"""Raised when shop verification fails."""
def __init__(
- self,
- shop_id: int,
- reason: str,
- current_verification_status: Optional[bool] = None,
+ self,
+ shop_id: int,
+ reason: str,
+ current_verification_status: Optional[bool] = None,
):
details = {
"shop_id": shop_id,
@@ -85,11 +82,11 @@ class AdminOperationException(BusinessLogicException):
"""Raised when admin operation fails."""
def __init__(
- self,
- operation: str,
- reason: str,
- target_type: Optional[str] = None,
- target_id: Optional[str] = None,
+ self,
+ operation: str,
+ reason: str,
+ target_type: Optional[str] = None,
+ target_id: Optional[str] = None,
):
message = f"Admin operation '{operation}' failed: {reason}"
@@ -142,10 +139,10 @@ class InvalidAdminActionException(ValidationException):
"""Raised when admin action is invalid."""
def __init__(
- self,
- action: str,
- reason: str,
- valid_actions: Optional[list] = None,
+ self,
+ action: str,
+ reason: str,
+ valid_actions: Optional[list] = None,
):
details = {
"action": action,
@@ -166,11 +163,11 @@ class BulkOperationException(BusinessLogicException):
"""Raised when bulk admin operation fails."""
def __init__(
- self,
- operation: str,
- total_items: int,
- failed_items: int,
- errors: Optional[Dict[str, Any]] = None,
+ self,
+ operation: str,
+ total_items: int,
+ failed_items: int,
+ errors: Optional[Dict[str, Any]] = None,
):
message = f"Bulk {operation} completed with errors: {failed_items}/{total_items} failed"
@@ -195,10 +192,10 @@ class ConfirmationRequiredException(BusinessLogicException):
"""Raised when a destructive operation requires explicit confirmation."""
def __init__(
- self,
- operation: str,
- message: Optional[str] = None,
- confirmation_param: str = "confirm"
+ self,
+ operation: str,
+ message: Optional[str] = None,
+ confirmation_param: str = "confirm",
):
if not message:
message = f"Operation '{operation}' requires confirmation parameter: {confirmation_param}=true"
@@ -217,10 +214,10 @@ class VendorVerificationException(BusinessLogicException):
"""Raised when vendor verification fails."""
def __init__(
- self,
- vendor_id: int,
- reason: str,
- current_verification_status: Optional[bool] = None,
+ self,
+ vendor_id: int,
+ reason: str,
+ current_verification_status: Optional[bool] = None,
):
details = {
"vendor_id": vendor_id,
diff --git a/app/exceptions/auth.py b/app/exceptions/auth.py
index d78bf32b..48c96c15 100644
--- a/app/exceptions/auth.py
+++ b/app/exceptions/auth.py
@@ -4,7 +4,9 @@ Authentication and authorization specific exceptions.
"""
from typing import Optional
-from .base import AuthenticationException, AuthorizationException, ConflictException
+
+from .base import (AuthenticationException, AuthorizationException,
+ ConflictException)
class InvalidCredentialsException(AuthenticationException):
@@ -41,9 +43,9 @@ class InsufficientPermissionsException(AuthorizationException):
"""Raised when user lacks required permissions for an action."""
def __init__(
- self,
- message: str = "Insufficient permissions for this action",
- required_permission: Optional[str] = None,
+ self,
+ message: str = "Insufficient permissions for this action",
+ required_permission: Optional[str] = None,
):
details = {}
if required_permission:
@@ -80,9 +82,9 @@ class UserAlreadyExistsException(ConflictException):
"""Raised when trying to register with existing username/email."""
def __init__(
- self,
- message: str = "User already exists",
- field: Optional[str] = None,
+ self,
+ message: str = "User already exists",
+ field: Optional[str] = None,
):
details = {}
if field:
diff --git a/app/exceptions/backup.py b/app/exceptions/backup.py
index f8e0394d..5d0229cb 100644
--- a/app/exceptions/backup.py
+++ b/app/exceptions/backup.py
@@ -1 +1 @@
-# Backup/recovery exceptions
+# Backup/recovery exceptions
diff --git a/app/exceptions/base.py b/app/exceptions/base.py
index 79fc589d..5cfe05a3 100644
--- a/app/exceptions/base.py
+++ b/app/exceptions/base.py
@@ -39,8 +39,6 @@ class WizamartException(Exception):
return result
-
-
class ValidationException(WizamartException):
"""Raised when request validation fails."""
@@ -62,8 +60,6 @@ class ValidationException(WizamartException):
)
-
-
class AuthenticationException(WizamartException):
"""Raised when authentication fails."""
@@ -97,6 +93,7 @@ class AuthorizationException(WizamartException):
details=details,
)
+
class ResourceNotFoundException(WizamartException):
"""Raised when a requested resource is not found."""
@@ -122,6 +119,7 @@ class ResourceNotFoundException(WizamartException):
},
)
+
class ConflictException(WizamartException):
"""Raised when a resource conflict occurs."""
@@ -138,6 +136,7 @@ class ConflictException(WizamartException):
details=details,
)
+
class BusinessLogicException(WizamartException):
"""Raised when business logic rules are violated."""
@@ -196,6 +195,7 @@ class RateLimitException(WizamartException):
details=rate_limit_details,
)
+
class ServiceUnavailableException(WizamartException):
"""Raised when service is unavailable."""
@@ -206,6 +206,7 @@ class ServiceUnavailableException(WizamartException):
status_code=503,
)
+
# Note: Domain-specific exceptions like VendorNotFoundException, UserNotFoundException, etc.
# are defined in their respective domain modules (vendor.py, admin.py, etc.)
# to keep domain-specific logic separate from base exceptions.
diff --git a/app/exceptions/cart.py b/app/exceptions/cart.py
index 1a56e527..6a980730 100644
--- a/app/exceptions/cart.py
+++ b/app/exceptions/cart.py
@@ -4,11 +4,9 @@ Shopping cart specific exceptions.
"""
from typing import Optional
-from .base import (
- ResourceNotFoundException,
- ValidationException,
- BusinessLogicException
-)
+
+from .base import (BusinessLogicException, ResourceNotFoundException,
+ ValidationException)
class CartItemNotFoundException(ResourceNotFoundException):
@@ -19,22 +17,16 @@ class CartItemNotFoundException(ResourceNotFoundException):
resource_type="CartItem",
identifier=str(product_id),
message=f"Product {product_id} not found in cart",
- error_code="CART_ITEM_NOT_FOUND"
+ error_code="CART_ITEM_NOT_FOUND",
)
- self.details.update({
- "product_id": product_id,
- "session_id": session_id
- })
+ self.details.update({"product_id": product_id, "session_id": session_id})
class EmptyCartException(ValidationException):
"""Raised when trying to perform operations on an empty cart."""
def __init__(self, session_id: str):
- super().__init__(
- message="Cart is empty",
- details={"session_id": session_id}
- )
+ super().__init__(message="Cart is empty", details={"session_id": session_id})
self.error_code = "CART_EMPTY"
@@ -82,7 +74,9 @@ class InsufficientInventoryForCartException(BusinessLogicException):
class InvalidCartQuantityException(ValidationException):
"""Raised when cart quantity is invalid."""
- def __init__(self, quantity: int, min_quantity: int = 1, max_quantity: Optional[int] = None):
+ def __init__(
+ self, quantity: int, min_quantity: int = 1, max_quantity: Optional[int] = None
+ ):
if quantity < min_quantity:
message = f"Quantity must be at least {min_quantity}"
elif max_quantity and quantity > max_quantity:
diff --git a/app/exceptions/customer.py b/app/exceptions/customer.py
index ca3d68b0..a9f50ace 100644
--- a/app/exceptions/customer.py
+++ b/app/exceptions/customer.py
@@ -4,13 +4,10 @@ Customer management specific exceptions.
"""
from typing import Any, Dict, Optional
-from .base import (
- ResourceNotFoundException,
- ConflictException,
- ValidationException,
- AuthenticationException,
- BusinessLogicException
-)
+
+from .base import (AuthenticationException, BusinessLogicException,
+ ConflictException, ResourceNotFoundException,
+ ValidationException)
class CustomerNotFoundException(ResourceNotFoundException):
@@ -21,7 +18,7 @@ class CustomerNotFoundException(ResourceNotFoundException):
resource_type="Customer",
identifier=customer_identifier,
message=f"Customer '{customer_identifier}' not found",
- error_code="CUSTOMER_NOT_FOUND"
+ error_code="CUSTOMER_NOT_FOUND",
)
@@ -32,7 +29,7 @@ class CustomerAlreadyExistsException(ConflictException):
super().__init__(
message=f"Customer with email '{email}' already exists",
error_code="CUSTOMER_ALREADY_EXISTS",
- details={"email": email}
+ details={"email": email},
)
@@ -43,10 +40,7 @@ class DuplicateCustomerEmailException(ConflictException):
super().__init__(
message=f"Email '{email}' is already registered for this vendor",
error_code="DUPLICATE_CUSTOMER_EMAIL",
- details={
- "email": email,
- "vendor_code": vendor_code
- }
+ details={"email": email, "vendor_code": vendor_code},
)
@@ -57,7 +51,7 @@ class CustomerNotActiveException(BusinessLogicException):
super().__init__(
message=f"Customer account '{email}' is not active",
error_code="CUSTOMER_NOT_ACTIVE",
- details={"email": email}
+ details={"email": email},
)
@@ -67,7 +61,7 @@ class InvalidCustomerCredentialsException(AuthenticationException):
def __init__(self):
super().__init__(
message="Invalid email or password",
- error_code="INVALID_CUSTOMER_CREDENTIALS"
+ error_code="INVALID_CUSTOMER_CREDENTIALS",
)
@@ -78,13 +72,9 @@ class CustomerValidationException(ValidationException):
self,
message: str = "Customer validation failed",
field: Optional[str] = None,
- details: Optional[Dict[str, Any]] = None
+ details: Optional[Dict[str, Any]] = None,
):
- super().__init__(
- message=message,
- field=field,
- details=details
- )
+ super().__init__(message=message, field=field, details=details)
self.error_code = "CUSTOMER_VALIDATION_FAILED"
@@ -95,8 +85,5 @@ class CustomerAuthorizationException(BusinessLogicException):
super().__init__(
message=f"Customer '{customer_email}' not authorized for: {operation}",
error_code="CUSTOMER_NOT_AUTHORIZED",
- details={
- "customer_email": customer_email,
- "operation": operation
- }
+ details={"customer_email": customer_email, "operation": operation},
)
diff --git a/app/exceptions/error_renderer.py b/app/exceptions/error_renderer.py
index ac574cd4..cc804722 100644
--- a/app/exceptions/error_renderer.py
+++ b/app/exceptions/error_renderer.py
@@ -7,7 +7,7 @@ Handles fallback logic and context-specific customization.
"""
import logging
from pathlib import Path
-from typing import Optional, Dict, Any
+from typing import Any, Dict, Optional
from fastapi import Request
from fastapi.responses import HTMLResponse
@@ -114,7 +114,7 @@ class ErrorPageRenderer:
"error_code": error_code,
"context": context_type.value,
"template": template_path,
- }
+ },
)
try:
@@ -129,8 +129,7 @@ class ErrorPageRenderer:
)
except Exception as e:
logger.error(
- f"Failed to render error template {template_path}: {e}",
- exc_info=True
+ f"Failed to render error template {template_path}: {e}", exc_info=True
)
# Return basic HTML as absolute fallback
return ErrorPageRenderer._render_basic_html_fallback(
@@ -228,7 +227,9 @@ class ErrorPageRenderer:
}
@staticmethod
- def _get_context_data(request: Request, context_type: RequestContext) -> Dict[str, Any]:
+ def _get_context_data(
+ request: Request, context_type: RequestContext
+ ) -> Dict[str, Any]:
"""Get context-specific data for error templates."""
data = {}
@@ -261,11 +262,19 @@ class ErrorPageRenderer:
# Calculate base_url for shop links
vendor_context = getattr(request.state, "vendor_context", None)
- access_method = vendor_context.get('detection_method', 'unknown') if vendor_context else 'unknown'
+ access_method = (
+ vendor_context.get("detection_method", "unknown")
+ if vendor_context
+ else "unknown"
+ )
base_url = "/"
if access_method == "path" and vendor:
# Use the full_prefix from vendor_context to determine which pattern was used
- full_prefix = vendor_context.get('full_prefix', '/vendor/') if vendor_context else '/vendor/'
+ full_prefix = (
+ vendor_context.get("full_prefix", "/vendor/")
+ if vendor_context
+ else "/vendor/"
+ )
base_url = f"{full_prefix}{vendor.subdomain}/"
data["base_url"] = base_url
diff --git a/app/exceptions/handler.py b/app/exceptions/handler.py
index 2ea2690e..c1b38edf 100644
--- a/app/exceptions/handler.py
+++ b/app/exceptions/handler.py
@@ -13,13 +13,14 @@ This module provides classes and functions for:
import logging
from typing import Union
-from fastapi import Request, HTTPException
+from fastapi import HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, RedirectResponse
+from middleware.context import RequestContext, get_request_context
+
from .base import WizamartException
from .error_renderer import ErrorPageRenderer
-from middleware.context import RequestContext, get_request_context
logger = logging.getLogger(__name__)
@@ -38,8 +39,8 @@ def setup_exception_handlers(app):
extra={
"path": request.url.path,
"accept": request.headers.get("accept", ""),
- "method": request.method
- }
+ "method": request.method,
+ },
)
# Redirect to appropriate login page based on context
@@ -56,15 +57,12 @@ def setup_exception_handlers(app):
"url": str(request.url),
"method": request.method,
"exception_type": type(exc).__name__,
- }
+ },
)
# Check if this is an API request
if _is_api_request(request):
- return JSONResponse(
- status_code=exc.status_code,
- content=exc.to_dict()
- )
+ return JSONResponse(status_code=exc.status_code, content=exc.to_dict())
# Check if this is an HTML page request
if _is_html_page_request(request):
@@ -78,10 +76,7 @@ def setup_exception_handlers(app):
)
# Default to JSON for unknown request types
- return JSONResponse(
- status_code=exc.status_code,
- content=exc.to_dict()
- )
+ return JSONResponse(status_code=exc.status_code, content=exc.to_dict())
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
@@ -96,7 +91,7 @@ def setup_exception_handlers(app):
"url": str(request.url),
"method": request.method,
"exception_type": "HTTPException",
- }
+ },
)
# Check if this is an API request
@@ -107,7 +102,7 @@ def setup_exception_handlers(app):
"error_code": f"HTTP_{exc.status_code}",
"message": exc.detail,
"status_code": exc.status_code,
- }
+ },
)
# Check if this is an HTML page request
@@ -128,11 +123,13 @@ def setup_exception_handlers(app):
"error_code": f"HTTP_{exc.status_code}",
"message": exc.detail,
"status_code": exc.status_code,
- }
+ },
)
@app.exception_handler(RequestValidationError)
- async def validation_exception_handler(request: Request, exc: RequestValidationError):
+ async def validation_exception_handler(
+ request: Request, exc: RequestValidationError
+ ):
"""Handle Pydantic validation errors with consistent format."""
# Sanitize errors to remove sensitive data from logs
@@ -140,8 +137,8 @@ def setup_exception_handlers(app):
for error in exc.errors():
sanitized_error = error.copy()
# Remove 'input' field which may contain passwords
- if 'input' in sanitized_error:
- sanitized_error['input'] = ''
+ if "input" in sanitized_error:
+ sanitized_error["input"] = ""
sanitized_errors.append(sanitized_error)
logger.error(
@@ -151,7 +148,7 @@ def setup_exception_handlers(app):
"url": str(request.url),
"method": request.method,
"exception_type": "RequestValidationError",
- }
+ },
)
# Clean up validation errors to ensure JSON serializability
@@ -159,15 +156,17 @@ def setup_exception_handlers(app):
for error in exc.errors():
clean_error = {}
for key, value in error.items():
- if key == 'input' and isinstance(value, bytes):
+ if key == "input" and isinstance(value, bytes):
# Convert bytes to string representation for JSON serialization
clean_error[key] = f""
- elif key == 'ctx' and isinstance(value, dict):
+ elif key == "ctx" and isinstance(value, dict):
# Handle the 'ctx' field that contains ValueError objects
clean_ctx = {}
for ctx_key, ctx_value in value.items():
if isinstance(ctx_value, Exception):
- clean_ctx[ctx_key] = str(ctx_value) # Convert exception to string
+ clean_ctx[ctx_key] = str(
+ ctx_value
+ ) # Convert exception to string
else:
clean_ctx[ctx_key] = ctx_value
clean_error[key] = clean_ctx
@@ -186,10 +185,8 @@ def setup_exception_handlers(app):
"error_code": "VALIDATION_ERROR",
"message": "Request validation failed",
"status_code": 422,
- "details": {
- "validation_errors": clean_errors
- }
- }
+ "details": {"validation_errors": clean_errors},
+ },
)
# Check if this is an HTML page request
@@ -210,10 +207,8 @@ def setup_exception_handlers(app):
"error_code": "VALIDATION_ERROR",
"message": "Request validation failed",
"status_code": 422,
- "details": {
- "validation_errors": clean_errors
- }
- }
+ "details": {"validation_errors": clean_errors},
+ },
)
@app.exception_handler(Exception)
@@ -227,7 +222,7 @@ def setup_exception_handlers(app):
"url": str(request.url),
"method": request.method,
"exception_type": type(exc).__name__,
- }
+ },
)
# Check if this is an API request
@@ -238,7 +233,7 @@ def setup_exception_handlers(app):
"error_code": "INTERNAL_SERVER_ERROR",
"message": "Internal server error",
"status_code": 500,
- }
+ },
)
# Check if this is an HTML page request
@@ -259,7 +254,7 @@ def setup_exception_handlers(app):
"error_code": "INTERNAL_SERVER_ERROR",
"message": "Internal server error",
"status_code": 500,
- }
+ },
)
@app.exception_handler(404)
@@ -275,11 +270,8 @@ def setup_exception_handlers(app):
"error_code": "ENDPOINT_NOT_FOUND",
"message": f"Endpoint not found: {request.url.path}",
"status_code": 404,
- "details": {
- "path": request.url.path,
- "method": request.method
- }
- }
+ "details": {"path": request.url.path, "method": request.method},
+ },
)
# Check if this is an HTML page request
@@ -300,11 +292,8 @@ def setup_exception_handlers(app):
"error_code": "ENDPOINT_NOT_FOUND",
"message": f"Endpoint not found: {request.url.path}",
"status_code": 404,
- "details": {
- "path": request.url.path,
- "method": request.method
- }
- }
+ "details": {"path": request.url.path, "method": request.method},
+ },
)
@@ -332,8 +321,8 @@ def _is_html_page_request(request: Request) -> bool:
extra={
"path": request.url.path,
"method": request.method,
- "accept": request.headers.get("accept", "")
- }
+ "accept": request.headers.get("accept", ""),
+ },
)
# Don't redirect API calls
@@ -354,7 +343,9 @@ def _is_html_page_request(request: Request) -> bool:
# MUST explicitly accept HTML (strict check)
accept_header = request.headers.get("accept", "")
if "text/html" not in accept_header:
- logger.debug(f"Not HTML page: Accept header doesn't include text/html: {accept_header}")
+ logger.debug(
+ f"Not HTML page: Accept header doesn't include text/html: {accept_header}"
+ )
return False
logger.debug("IS HTML page request")
@@ -379,13 +370,21 @@ def _redirect_to_login(request: Request) -> RedirectResponse:
elif context_type == RequestContext.SHOP:
# For shop context, redirect to shop login (customer login)
# Calculate base_url for proper routing (supports domain, subdomain, and path-based access)
- vendor = getattr(request.state, 'vendor', None)
- vendor_context = getattr(request.state, 'vendor_context', None)
- access_method = vendor_context.get('detection_method', 'unknown') if vendor_context else 'unknown'
+ vendor = getattr(request.state, "vendor", None)
+ vendor_context = getattr(request.state, "vendor_context", None)
+ access_method = (
+ vendor_context.get("detection_method", "unknown")
+ if vendor_context
+ else "unknown"
+ )
base_url = "/"
if access_method == "path" and vendor:
- full_prefix = vendor_context.get('full_prefix', '/vendor/') if vendor_context else '/vendor/'
+ full_prefix = (
+ vendor_context.get("full_prefix", "/vendor/")
+ if vendor_context
+ else "/vendor/"
+ )
base_url = f"{full_prefix}{vendor.subdomain}/"
login_url = f"{base_url}shop/account/login"
@@ -401,22 +400,28 @@ def _redirect_to_login(request: Request) -> RedirectResponse:
def raise_not_found(resource_type: str, identifier: str) -> None:
"""Convenience function to raise ResourceNotFoundException."""
from .base import ResourceNotFoundException
+
raise ResourceNotFoundException(resource_type, identifier)
-def raise_validation_error(message: str, field: str = None, details: dict = None) -> None:
+def raise_validation_error(
+ message: str, field: str = None, details: dict = None
+) -> None:
"""Convenience function to raise ValidationException."""
from .base import ValidationException
+
raise ValidationException(message, field, details)
def raise_auth_error(message: str = "Authentication failed") -> None:
"""Convenience function to raise AuthenticationException."""
from .base import AuthenticationException
+
raise AuthenticationException(message)
def raise_permission_error(message: str = "Access denied") -> None:
"""Convenience function to raise AuthorizationException."""
from .base import AuthorizationException
+
raise AuthorizationException(message)
diff --git a/app/exceptions/inventory.py b/app/exceptions/inventory.py
index 86a26427..c2c73fd7 100644
--- a/app/exceptions/inventory.py
+++ b/app/exceptions/inventory.py
@@ -4,7 +4,9 @@ Inventory management specific exceptions.
"""
from typing import Any, Dict, Optional
-from .base import ResourceNotFoundException, ValidationException, BusinessLogicException
+
+from .base import (BusinessLogicException, ResourceNotFoundException,
+ ValidationException)
class InventoryNotFoundException(ResourceNotFoundException):
@@ -14,7 +16,9 @@ class InventoryNotFoundException(ResourceNotFoundException):
if identifier_type.lower() == "gtin":
message = f"No inventory found for GTIN '{identifier}'"
else:
- message = f"Inventory record with {identifier_type} '{identifier}' not found"
+ message = (
+ f"Inventory record with {identifier_type} '{identifier}' not found"
+ )
super().__init__(
resource_type="Inventory",
@@ -28,11 +32,11 @@ class InsufficientInventoryException(BusinessLogicException):
"""Raised when trying to remove more inventory than available."""
def __init__(
- self,
- gtin: str,
- location: str,
- requested: int,
- available: int,
+ self,
+ gtin: str,
+ location: str,
+ requested: int,
+ available: int,
):
message = f"Insufficient inventory for GTIN '{gtin}' at '{location}'. Requested: {requested}, Available: {available}"
@@ -52,10 +56,10 @@ class InvalidInventoryOperationException(ValidationException):
"""Raised when inventory operation is invalid."""
def __init__(
- self,
- message: str,
- operation: Optional[str] = None,
- details: Optional[Dict[str, Any]] = None,
+ self,
+ message: str,
+ operation: Optional[str] = None,
+ details: Optional[Dict[str, Any]] = None,
):
if not details:
details = {}
@@ -74,10 +78,10 @@ class InventoryValidationException(ValidationException):
"""Raised when inventory data validation fails."""
def __init__(
- self,
- message: str = "Inventory validation failed",
- field: Optional[str] = None,
- validation_errors: Optional[Dict[str, str]] = None,
+ self,
+ message: str = "Inventory validation failed",
+ field: Optional[str] = None,
+ validation_errors: Optional[Dict[str, str]] = None,
):
details = {}
if validation_errors:
diff --git a/app/exceptions/marketplace.py b/app/exceptions/marketplace.py
index 1a92b6ab..28f96110 100644
--- a/app/exceptions/marketplace.py
+++ b/app/exceptions/marketplace.py
@@ -1 +1 @@
-# Import/marketplace exceptions
+# Import/marketplace exceptions
diff --git a/app/exceptions/marketplace_import_job.py b/app/exceptions/marketplace_import_job.py
index c383579e..c442e1f9 100644
--- a/app/exceptions/marketplace_import_job.py
+++ b/app/exceptions/marketplace_import_job.py
@@ -4,24 +4,21 @@ Marketplace import specific exceptions.
"""
from typing import Any, Dict, Optional
-from .base import (
- ResourceNotFoundException,
- ValidationException,
- BusinessLogicException,
- AuthorizationException,
- ExternalServiceException
-)
+
+from .base import (AuthorizationException, BusinessLogicException,
+ ExternalServiceException, ResourceNotFoundException,
+ ValidationException)
class MarketplaceImportException(BusinessLogicException):
"""Base exception for marketplace import operations."""
def __init__(
- self,
- message: str,
- error_code: str = "MARKETPLACE_IMPORT_ERROR",
- marketplace: Optional[str] = None,
- details: Optional[Dict[str, Any]] = None,
+ self,
+ message: str,
+ error_code: str = "MARKETPLACE_IMPORT_ERROR",
+ marketplace: Optional[str] = None,
+ details: Optional[Dict[str, Any]] = None,
):
if not details:
details = {}
@@ -67,11 +64,11 @@ class InvalidImportDataException(ValidationException):
"""Raised when import data is invalid."""
def __init__(
- self,
- message: str = "Invalid import data",
- field: Optional[str] = None,
- row_number: Optional[int] = None,
- details: Optional[Dict[str, Any]] = None,
+ self,
+ message: str = "Invalid import data",
+ field: Optional[str] = None,
+ row_number: Optional[int] = None,
+ details: Optional[Dict[str, Any]] = None,
):
if not details:
details = {}
@@ -118,7 +115,9 @@ class ImportJobCannotBeDeletedException(BusinessLogicException):
class MarketplaceConnectionException(ExternalServiceException):
"""Raised when marketplace connection fails."""
- def __init__(self, marketplace: str, message: str = "Failed to connect to marketplace"):
+ def __init__(
+ self, marketplace: str, message: str = "Failed to connect to marketplace"
+ ):
super().__init__(
service=marketplace,
message=f"{message}: {marketplace}",
@@ -130,10 +129,10 @@ class MarketplaceDataParsingException(ValidationException):
"""Raised when marketplace data cannot be parsed."""
def __init__(
- self,
- marketplace: str,
- message: str = "Failed to parse marketplace data",
- details: Optional[Dict[str, Any]] = None,
+ self,
+ marketplace: str,
+ message: str = "Failed to parse marketplace data",
+ details: Optional[Dict[str, Any]] = None,
):
if not details:
details = {}
@@ -150,10 +149,10 @@ class ImportRateLimitException(BusinessLogicException):
"""Raised when import rate limit is exceeded."""
def __init__(
- self,
- max_imports: int,
- time_window: str,
- retry_after: Optional[int] = None,
+ self,
+ max_imports: int,
+ time_window: str,
+ retry_after: Optional[int] = None,
):
details = {
"max_imports": max_imports,
diff --git a/app/exceptions/marketplace_product.py b/app/exceptions/marketplace_product.py
index a153a717..8fbc3e71 100644
--- a/app/exceptions/marketplace_product.py
+++ b/app/exceptions/marketplace_product.py
@@ -4,7 +4,9 @@ MarketplaceProduct management specific exceptions.
"""
from typing import Any, Dict, Optional
-from .base import ResourceNotFoundException, ConflictException, ValidationException, BusinessLogicException
+
+from .base import (BusinessLogicException, ConflictException,
+ ResourceNotFoundException, ValidationException)
class MarketplaceProductNotFoundException(ResourceNotFoundException):
@@ -34,10 +36,10 @@ class InvalidMarketplaceProductDataException(ValidationException):
"""Raised when product data is invalid."""
def __init__(
- self,
- message: str = "Invalid product data",
- field: Optional[str] = None,
- details: Optional[Dict[str, Any]] = None,
+ self,
+ message: str = "Invalid product data",
+ field: Optional[str] = None,
+ details: Optional[Dict[str, Any]] = None,
):
super().__init__(
message=message,
@@ -51,10 +53,10 @@ class MarketplaceProductValidationException(ValidationException):
"""Raised when product validation fails."""
def __init__(
- self,
- message: str,
- field: Optional[str] = None,
- validation_errors: Optional[Dict[str, str]] = None,
+ self,
+ message: str,
+ field: Optional[str] = None,
+ validation_errors: Optional[Dict[str, str]] = None,
):
details = {}
if validation_errors:
@@ -84,10 +86,10 @@ class MarketplaceProductCSVImportException(BusinessLogicException):
"""Raised when product CSV import fails."""
def __init__(
- self,
- message: str = "MarketplaceProduct CSV import failed",
- row_number: Optional[int] = None,
- errors: Optional[Dict[str, Any]] = None,
+ self,
+ message: str = "MarketplaceProduct CSV import failed",
+ row_number: Optional[int] = None,
+ errors: Optional[Dict[str, Any]] = None,
):
details = {}
if row_number:
diff --git a/app/exceptions/media.py b/app/exceptions/media.py
index a56df71b..d7663627 100644
--- a/app/exceptions/media.py
+++ b/app/exceptions/media.py
@@ -1 +1 @@
-# Media/file management exceptions
+# Media/file management exceptions
diff --git a/app/exceptions/monitoring.py b/app/exceptions/monitoring.py
index 70f33c86..d16c3949 100644
--- a/app/exceptions/monitoring.py
+++ b/app/exceptions/monitoring.py
@@ -1 +1 @@
-# Monitoring exceptions
+# Monitoring exceptions
diff --git a/app/exceptions/notification.py b/app/exceptions/notification.py
index c1772546..8dfb9341 100644
--- a/app/exceptions/notification.py
+++ b/app/exceptions/notification.py
@@ -1 +1 @@
-# Notification exceptions
+# Notification exceptions
diff --git a/app/exceptions/order.py b/app/exceptions/order.py
index 3052f206..d5b74705 100644
--- a/app/exceptions/order.py
+++ b/app/exceptions/order.py
@@ -4,11 +4,9 @@ Order management specific exceptions.
"""
from typing import Optional
-from .base import (
- ResourceNotFoundException,
- ValidationException,
- BusinessLogicException
-)
+
+from .base import (BusinessLogicException, ResourceNotFoundException,
+ ValidationException)
class OrderNotFoundException(ResourceNotFoundException):
@@ -19,7 +17,7 @@ class OrderNotFoundException(ResourceNotFoundException):
resource_type="Order",
identifier=order_identifier,
message=f"Order '{order_identifier}' not found",
- error_code="ORDER_NOT_FOUND"
+ error_code="ORDER_NOT_FOUND",
)
@@ -30,7 +28,7 @@ class OrderAlreadyExistsException(ValidationException):
super().__init__(
message=f"Order with number '{order_number}' already exists",
error_code="ORDER_ALREADY_EXISTS",
- details={"order_number": order_number}
+ details={"order_number": order_number},
)
@@ -39,9 +37,7 @@ class OrderValidationException(ValidationException):
def __init__(self, message: str, details: Optional[dict] = None):
super().__init__(
- message=message,
- error_code="ORDER_VALIDATION_FAILED",
- details=details
+ message=message, error_code="ORDER_VALIDATION_FAILED", details=details
)
@@ -52,10 +48,7 @@ class InvalidOrderStatusException(BusinessLogicException):
super().__init__(
message=f"Cannot change order status from '{current_status}' to '{new_status}'",
error_code="INVALID_ORDER_STATUS_CHANGE",
- details={
- "current_status": current_status,
- "new_status": new_status
- }
+ details={"current_status": current_status, "new_status": new_status},
)
@@ -66,8 +59,5 @@ class OrderCannotBeCancelledException(BusinessLogicException):
super().__init__(
message=f"Order '{order_number}' cannot be cancelled: {reason}",
error_code="ORDER_CANNOT_BE_CANCELLED",
- details={
- "order_number": order_number,
- "reason": reason
- }
+ details={"order_number": order_number, "reason": reason},
)
diff --git a/app/exceptions/payment.py b/app/exceptions/payment.py
index c3fbe15b..2c7f0f89 100644
--- a/app/exceptions/payment.py
+++ b/app/exceptions/payment.py
@@ -1 +1 @@
-# Payment processing exceptions
+# Payment processing exceptions
diff --git a/app/exceptions/product.py b/app/exceptions/product.py
index fcd984ba..72a3d8ba 100644
--- a/app/exceptions/product.py
+++ b/app/exceptions/product.py
@@ -4,12 +4,9 @@ Product (vendor catalog) specific exceptions.
"""
from typing import Optional
-from .base import (
- ResourceNotFoundException,
- ConflictException,
- ValidationException,
- BusinessLogicException
-)
+
+from .base import (BusinessLogicException, ConflictException,
+ ResourceNotFoundException, ValidationException)
class ProductNotFoundException(ResourceNotFoundException):
diff --git a/app/exceptions/search.py b/app/exceptions/search.py
index 5db85736..ed380099 100644
--- a/app/exceptions/search.py
+++ b/app/exceptions/search.py
@@ -1 +1 @@
-# Search exceptions
+# Search exceptions
diff --git a/app/exceptions/team.py b/app/exceptions/team.py
index 4578a690..21283db1 100644
--- a/app/exceptions/team.py
+++ b/app/exceptions/team.py
@@ -4,13 +4,10 @@ Team management specific exceptions.
"""
from typing import Any, Dict, Optional
-from .base import (
- ResourceNotFoundException,
- ConflictException,
- ValidationException,
- AuthorizationException,
- BusinessLogicException
-)
+
+from .base import (AuthorizationException, BusinessLogicException,
+ ConflictException, ResourceNotFoundException,
+ ValidationException)
class TeamMemberNotFoundException(ResourceNotFoundException):
@@ -20,7 +17,9 @@ class TeamMemberNotFoundException(ResourceNotFoundException):
details = {"user_id": user_id}
if vendor_id:
details["vendor_id"] = vendor_id
- message = f"Team member with user ID '{user_id}' not found in vendor {vendor_id}"
+ message = (
+ f"Team member with user ID '{user_id}' not found in vendor {vendor_id}"
+ )
else:
message = f"Team member with user ID '{user_id}' not found"
@@ -84,7 +83,12 @@ class TeamInvitationAlreadyAcceptedException(ConflictException):
class UnauthorizedTeamActionException(AuthorizationException):
"""Raised when user tries to perform team action without permission."""
- def __init__(self, action: str, user_id: Optional[int] = None, required_permission: Optional[str] = None):
+ def __init__(
+ self,
+ action: str,
+ user_id: Optional[int] = None,
+ required_permission: Optional[str] = None,
+ ):
details = {"action": action}
if user_id:
details["user_id"] = user_id
@@ -147,10 +151,10 @@ class InvalidRoleException(ValidationException):
"""Raised when role data is invalid."""
def __init__(
- self,
- message: str = "Invalid role data",
- field: Optional[str] = None,
- details: Optional[Dict[str, Any]] = None,
+ self,
+ message: str = "Invalid role data",
+ field: Optional[str] = None,
+ details: Optional[Dict[str, Any]] = None,
):
super().__init__(
message=message,
@@ -164,10 +168,10 @@ class InsufficientTeamPermissionsException(AuthorizationException):
"""Raised when user lacks required team permissions for an action."""
def __init__(
- self,
- required_permission: str,
- user_id: Optional[int] = None,
- action: Optional[str] = None,
+ self,
+ required_permission: str,
+ user_id: Optional[int] = None,
+ action: Optional[str] = None,
):
details = {"required_permission": required_permission}
if user_id:
@@ -202,10 +206,10 @@ class TeamValidationException(ValidationException):
"""Raised when team operation validation fails."""
def __init__(
- self,
- message: str = "Team operation validation failed",
- field: Optional[str] = None,
- validation_errors: Optional[Dict[str, str]] = None,
+ self,
+ message: str = "Team operation validation failed",
+ field: Optional[str] = None,
+ validation_errors: Optional[Dict[str, str]] = None,
):
details = {}
if validation_errors:
@@ -223,10 +227,10 @@ class InvalidInvitationDataException(ValidationException):
"""Raised when team invitation data is invalid."""
def __init__(
- self,
- message: str = "Invalid invitation data",
- field: Optional[str] = None,
- details: Optional[Dict[str, Any]] = None,
+ self,
+ message: str = "Invalid invitation data",
+ field: Optional[str] = None,
+ details: Optional[Dict[str, Any]] = None,
):
super().__init__(
message=message,
@@ -240,6 +244,7 @@ class InvalidInvitationDataException(ValidationException):
# NEW: Add InvalidInvitationTokenException
# ============================================================================
+
class InvalidInvitationTokenException(ValidationException):
"""Raised when invitation token is invalid, expired, or already used.
@@ -248,9 +253,9 @@ class InvalidInvitationTokenException(ValidationException):
"""
def __init__(
- self,
- message: str = "Invalid or expired invitation token",
- invitation_token: Optional[str] = None
+ self,
+ message: str = "Invalid or expired invitation token",
+ invitation_token: Optional[str] = None,
):
details = {}
if invitation_token:
diff --git a/app/exceptions/vendor.py b/app/exceptions/vendor.py
index 5563488f..44380e67 100644
--- a/app/exceptions/vendor.py
+++ b/app/exceptions/vendor.py
@@ -4,13 +4,10 @@ Vendor management specific exceptions.
"""
from typing import Any, Dict, Optional
-from .base import (
- ResourceNotFoundException,
- ConflictException,
- ValidationException,
- AuthorizationException,
- BusinessLogicException
-)
+
+from .base import (AuthorizationException, BusinessLogicException,
+ ConflictException, ResourceNotFoundException,
+ ValidationException)
class VendorNotFoundException(ResourceNotFoundException):
@@ -82,10 +79,10 @@ class InvalidVendorDataException(ValidationException):
"""Raised when vendor data is invalid or incomplete."""
def __init__(
- self,
- message: str = "Invalid vendor data",
- field: Optional[str] = None,
- details: Optional[Dict[str, Any]] = None,
+ self,
+ message: str = "Invalid vendor data",
+ field: Optional[str] = None,
+ details: Optional[Dict[str, Any]] = None,
):
super().__init__(
message=message,
@@ -99,10 +96,10 @@ class VendorValidationException(ValidationException):
"""Raised when vendor validation fails."""
def __init__(
- self,
- message: str = "Vendor validation failed",
- field: Optional[str] = None,
- validation_errors: Optional[Dict[str, str]] = None,
+ self,
+ message: str = "Vendor validation failed",
+ field: Optional[str] = None,
+ validation_errors: Optional[Dict[str, str]] = None,
):
details = {}
if validation_errors:
@@ -120,9 +117,9 @@ class IncompleteVendorDataException(ValidationException):
"""Raised when vendor data is missing required fields."""
def __init__(
- self,
- vendor_code: str,
- missing_fields: list,
+ self,
+ vendor_code: str,
+ missing_fields: list,
):
super().__init__(
message=f"Vendor '{vendor_code}' is missing required fields: {', '.join(missing_fields)}",
diff --git a/app/exceptions/vendor_domain.py b/app/exceptions/vendor_domain.py
index cb04d883..8a8e704b 100644
--- a/app/exceptions/vendor_domain.py
+++ b/app/exceptions/vendor_domain.py
@@ -4,13 +4,10 @@ Vendor domain management specific exceptions.
"""
from typing import Any, Dict, Optional
-from .base import (
- ResourceNotFoundException,
- ConflictException,
- ValidationException,
- BusinessLogicException,
- ExternalServiceException
-)
+
+from .base import (BusinessLogicException, ConflictException,
+ ExternalServiceException, ResourceNotFoundException,
+ ValidationException)
class VendorDomainNotFoundException(ResourceNotFoundException):
@@ -64,10 +61,7 @@ class ReservedDomainException(ValidationException):
super().__init__(
message=f"Domain cannot use reserved subdomain: {reserved_part}",
field="domain",
- details={
- "domain": domain,
- "reserved_part": reserved_part
- },
+ details={"domain": domain, "reserved_part": reserved_part},
)
self.error_code = "RESERVED_DOMAIN"
@@ -79,10 +73,7 @@ class DomainNotVerifiedException(BusinessLogicException):
super().__init__(
message=f"Domain '{domain}' must be verified before activation",
error_code="DOMAIN_NOT_VERIFIED",
- details={
- "domain_id": domain_id,
- "domain": domain
- },
+ details={"domain_id": domain_id, "domain": domain},
)
@@ -93,10 +84,7 @@ class DomainVerificationFailedException(BusinessLogicException):
super().__init__(
message=f"Domain verification failed for '{domain}': {reason}",
error_code="DOMAIN_VERIFICATION_FAILED",
- details={
- "domain": domain,
- "reason": reason
- },
+ details={"domain": domain, "reason": reason},
)
@@ -107,10 +95,7 @@ class DomainAlreadyVerifiedException(BusinessLogicException):
super().__init__(
message=f"Domain '{domain}' is already verified",
error_code="DOMAIN_ALREADY_VERIFIED",
- details={
- "domain_id": domain_id,
- "domain": domain
- },
+ details={"domain_id": domain_id, "domain": domain},
)
@@ -133,10 +118,7 @@ class DNSVerificationException(ExternalServiceException):
service_name="DNS",
message=f"DNS verification failed for '{domain}': {reason}",
error_code="DNS_VERIFICATION_ERROR",
- details={
- "domain": domain,
- "reason": reason
- },
+ details={"domain": domain, "reason": reason},
)
@@ -147,10 +129,7 @@ class MaxDomainsReachedException(BusinessLogicException):
super().__init__(
message=f"Maximum number of domains reached ({max_domains})",
error_code="MAX_DOMAINS_REACHED",
- details={
- "vendor_id": vendor_id,
- "max_domains": max_domains
- },
+ details={"vendor_id": vendor_id, "max_domains": max_domains},
)
@@ -161,8 +140,5 @@ class UnauthorizedDomainAccessException(BusinessLogicException):
super().__init__(
message=f"Unauthorized access to domain {domain_id}",
error_code="UNAUTHORIZED_DOMAIN_ACCESS",
- details={
- "domain_id": domain_id,
- "vendor_id": vendor_id
- },
+ details={"domain_id": domain_id, "vendor_id": vendor_id},
)
diff --git a/app/exceptions/vendor_theme.py b/app/exceptions/vendor_theme.py
index b38622c3..08000432 100644
--- a/app/exceptions/vendor_theme.py
+++ b/app/exceptions/vendor_theme.py
@@ -4,12 +4,9 @@ Vendor theme management specific exceptions.
"""
from typing import Any, Dict, Optional
-from .base import (
- ResourceNotFoundException,
- ConflictException,
- ValidationException,
- BusinessLogicException
-)
+
+from .base import (BusinessLogicException, ConflictException,
+ ResourceNotFoundException, ValidationException)
class VendorThemeNotFoundException(ResourceNotFoundException):
@@ -28,10 +25,10 @@ class InvalidThemeDataException(ValidationException):
"""Raised when theme data is invalid."""
def __init__(
- self,
- message: str = "Invalid theme data",
- field: Optional[str] = None,
- details: Optional[Dict[str, Any]] = None,
+ self,
+ message: str = "Invalid theme data",
+ field: Optional[str] = None,
+ details: Optional[Dict[str, Any]] = None,
):
super().__init__(
message=message,
@@ -62,10 +59,10 @@ class ThemeValidationException(ValidationException):
"""Raised when theme validation fails."""
def __init__(
- self,
- message: str = "Theme validation failed",
- field: Optional[str] = None,
- validation_errors: Optional[Dict[str, str]] = None,
+ self,
+ message: str = "Theme validation failed",
+ field: Optional[str] = None,
+ validation_errors: Optional[Dict[str, str]] = None,
):
details = {}
if validation_errors:
@@ -86,10 +83,7 @@ class ThemePresetAlreadyAppliedException(BusinessLogicException):
super().__init__(
message=f"Preset '{preset_name}' is already applied to vendor '{vendor_code}'",
error_code="THEME_PRESET_ALREADY_APPLIED",
- details={
- "preset_name": preset_name,
- "vendor_code": vendor_code
- },
+ details={"preset_name": preset_name, "vendor_code": vendor_code},
)
@@ -120,18 +114,13 @@ class InvalidFontFamilyException(ValidationException):
class ThemeOperationException(BusinessLogicException):
"""Raised when theme operation fails."""
- def __init__(
- self,
- operation: str,
- vendor_code: str,
- reason: str
- ):
+ def __init__(self, operation: str, vendor_code: str, reason: str):
super().__init__(
message=f"Theme operation '{operation}' failed for vendor '{vendor_code}': {reason}",
error_code="THEME_OPERATION_FAILED",
details={
"operation": operation,
"vendor_code": vendor_code,
- "reason": reason
+ "reason": reason,
},
)
diff --git a/app/models/architecture_scan.py b/app/models/architecture_scan.py
index 39b861ec..403fbcbf 100644
--- a/app/models/architecture_scan.py
+++ b/app/models/architecture_scan.py
@@ -3,18 +3,23 @@ Architecture Scan Models
Database models for tracking code quality scans and violations
"""
-from sqlalchemy import Column, Integer, String, Float, DateTime, Text, Boolean, ForeignKey, JSON
+from sqlalchemy import (JSON, Boolean, Column, DateTime, Float, ForeignKey,
+ Integer, String, Text)
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
+
from app.core.database import Base
class ArchitectureScan(Base):
"""Represents a single run of the architecture validator"""
+
__tablename__ = "architecture_scans"
id = Column(Integer, primary_key=True, index=True)
- timestamp = Column(DateTime(timezone=True), server_default=func.now(), nullable=False, index=True)
+ timestamp = Column(
+ DateTime(timezone=True), server_default=func.now(), nullable=False, index=True
+ )
total_files = Column(Integer, default=0)
total_violations = Column(Integer, default=0)
errors = Column(Integer, default=0)
@@ -24,7 +29,9 @@ class ArchitectureScan(Base):
git_commit_hash = Column(String(40))
# Relationship to violations
- violations = relationship("ArchitectureViolation", back_populates="scan", cascade="all, delete-orphan")
+ violations = relationship(
+ "ArchitectureViolation", back_populates="scan", cascade="all, delete-orphan"
+ )
def __repr__(self):
return f""
@@ -32,31 +39,48 @@ class ArchitectureScan(Base):
class ArchitectureViolation(Base):
"""Represents a single architectural violation found during a scan"""
+
__tablename__ = "architecture_violations"
id = Column(Integer, primary_key=True, index=True)
- scan_id = Column(Integer, ForeignKey("architecture_scans.id"), nullable=False, index=True)
+ scan_id = Column(
+ Integer, ForeignKey("architecture_scans.id"), nullable=False, index=True
+ )
rule_id = Column(String(20), nullable=False, index=True) # e.g., 'API-001'
rule_name = Column(String(200), nullable=False)
- severity = Column(String(10), nullable=False, index=True) # 'error', 'warning', 'info'
+ severity = Column(
+ String(10), nullable=False, index=True
+ ) # 'error', 'warning', 'info'
file_path = Column(String(500), nullable=False, index=True)
line_number = Column(Integer, nullable=False)
message = Column(Text, nullable=False)
context = Column(Text) # Code snippet
suggestion = Column(Text)
- status = Column(String(20), default='open', index=True) # 'open', 'assigned', 'resolved', 'ignored', 'technical_debt'
+ status = Column(
+ String(20), default="open", index=True
+ ) # 'open', 'assigned', 'resolved', 'ignored', 'technical_debt'
assigned_to = Column(Integer, ForeignKey("users.id"))
resolved_at = Column(DateTime(timezone=True))
resolved_by = Column(Integer, ForeignKey("users.id"))
resolution_note = Column(Text)
- created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
+ created_at = Column(
+ DateTime(timezone=True), server_default=func.now(), nullable=False
+ )
# Relationships
scan = relationship("ArchitectureScan", back_populates="violations")
- assigned_user = relationship("User", foreign_keys=[assigned_to], backref="assigned_violations")
- resolver = relationship("User", foreign_keys=[resolved_by], backref="resolved_violations")
- assignments = relationship("ViolationAssignment", back_populates="violation", cascade="all, delete-orphan")
- comments = relationship("ViolationComment", back_populates="violation", cascade="all, delete-orphan")
+ assigned_user = relationship(
+ "User", foreign_keys=[assigned_to], backref="assigned_violations"
+ )
+ resolver = relationship(
+ "User", foreign_keys=[resolved_by], backref="resolved_violations"
+ )
+ assignments = relationship(
+ "ViolationAssignment", back_populates="violation", cascade="all, delete-orphan"
+ )
+ comments = relationship(
+ "ViolationComment", back_populates="violation", cascade="all, delete-orphan"
+ )
def __repr__(self):
return f""
@@ -64,18 +88,30 @@ class ArchitectureViolation(Base):
class ArchitectureRule(Base):
"""Architecture rules configuration (from YAML with database overrides)"""
+
__tablename__ = "architecture_rules"
id = Column(Integer, primary_key=True, index=True)
- rule_id = Column(String(20), unique=True, nullable=False, index=True) # e.g., 'API-001'
- category = Column(String(50), nullable=False) # 'api_endpoint', 'service_layer', etc.
+ rule_id = Column(
+ String(20), unique=True, nullable=False, index=True
+ ) # e.g., 'API-001'
+ category = Column(
+ String(50), nullable=False
+ ) # 'api_endpoint', 'service_layer', etc.
name = Column(String(200), nullable=False)
description = Column(Text)
severity = Column(String(10), nullable=False) # Can override default from YAML
enabled = Column(Boolean, default=True, nullable=False)
custom_config = Column(JSON) # For rule-specific settings
- created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
- updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
+ created_at = Column(
+ DateTime(timezone=True), server_default=func.now(), nullable=False
+ )
+ updated_at = Column(
+ DateTime(timezone=True),
+ server_default=func.now(),
+ onupdate=func.now(),
+ nullable=False,
+ )
def __repr__(self):
return f""
@@ -83,20 +119,29 @@ class ArchitectureRule(Base):
class ViolationAssignment(Base):
"""Tracks assignment of violations to developers"""
+
__tablename__ = "violation_assignments"
id = Column(Integer, primary_key=True, index=True)
- violation_id = Column(Integer, ForeignKey("architecture_violations.id"), nullable=False, index=True)
+ violation_id = Column(
+ Integer, ForeignKey("architecture_violations.id"), nullable=False, index=True
+ )
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
- assigned_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
+ assigned_at = Column(
+ DateTime(timezone=True), server_default=func.now(), nullable=False
+ )
assigned_by = Column(Integer, ForeignKey("users.id"))
due_date = Column(DateTime(timezone=True))
- priority = Column(String(10), default='medium') # 'low', 'medium', 'high', 'critical'
+ priority = Column(
+ String(10), default="medium"
+ ) # 'low', 'medium', 'high', 'critical'
# Relationships
violation = relationship("ArchitectureViolation", back_populates="assignments")
user = relationship("User", foreign_keys=[user_id], backref="violation_assignments")
- assigner = relationship("User", foreign_keys=[assigned_by], backref="assigned_by_me")
+ assigner = relationship(
+ "User", foreign_keys=[assigned_by], backref="assigned_by_me"
+ )
def __repr__(self):
return f""
@@ -104,13 +149,18 @@ class ViolationAssignment(Base):
class ViolationComment(Base):
"""Comments on violations for collaboration"""
+
__tablename__ = "violation_comments"
id = Column(Integer, primary_key=True, index=True)
- violation_id = Column(Integer, ForeignKey("architecture_violations.id"), nullable=False, index=True)
+ violation_id = Column(
+ Integer, ForeignKey("architecture_violations.id"), nullable=False, index=True
+ )
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
comment = Column(Text, nullable=False)
- created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
+ created_at = Column(
+ DateTime(timezone=True), server_default=func.now(), nullable=False
+ )
# Relationships
violation = relationship("ArchitectureViolation", back_populates="comments")
diff --git a/app/routes/admin_pages.py b/app/routes/admin_pages.py
index 49b10873..736b27db 100644
--- a/app/routes/admin_pages.py
+++ b/app/routes/admin_pages.py
@@ -30,17 +30,15 @@ Routes:
- GET /code-quality/violations/{violation_id} → Violation details (auth required)
"""
-from fastapi import APIRouter, Request, Depends, Path
+from typing import Optional
+
+from fastapi import APIRouter, Depends, Path, Request
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
from sqlalchemy.orm import Session
-from typing import Optional
-from app.api.deps import (
- get_current_admin_from_cookie_or_header,
- get_current_admin_optional,
- get_db
-)
+from app.api.deps import (get_current_admin_from_cookie_or_header,
+ get_current_admin_optional, get_db)
from models.database.user import User
router = APIRouter()
@@ -51,9 +49,10 @@ templates = Jinja2Templates(directory="app/templates")
# PUBLIC ROUTES (No Authentication Required)
# ============================================================================
+
@router.get("/", response_class=RedirectResponse, include_in_schema=False)
async def admin_root(
- current_user: Optional[User] = Depends(get_current_admin_optional)
+ current_user: Optional[User] = Depends(get_current_admin_optional),
):
"""
Redirect /admin/ based on authentication status.
@@ -70,8 +69,7 @@ async def admin_root(
@router.get("/login", response_class=HTMLResponse, include_in_schema=False)
async def admin_login_page(
- request: Request,
- current_user: Optional[User] = Depends(get_current_admin_optional)
+ request: Request, current_user: Optional[User] = Depends(get_current_admin_optional)
):
"""
Render admin login page.
@@ -83,21 +81,19 @@ async def admin_login_page(
# User is already logged in as admin, redirect to dashboard
return RedirectResponse(url="/admin/dashboard", status_code=302)
- return templates.TemplateResponse(
- "admin/login.html",
- {"request": request}
- )
+ return templates.TemplateResponse("admin/login.html", {"request": request})
# ============================================================================
# AUTHENTICATED ROUTES (Admin Only)
# ============================================================================
+
@router.get("/dashboard", response_class=HTMLResponse, include_in_schema=False)
async def admin_dashboard_page(
- request: Request,
- current_user: User = Depends(get_current_admin_from_cookie_or_header),
- db: Session = Depends(get_db)
+ request: Request,
+ current_user: User = Depends(get_current_admin_from_cookie_or_header),
+ db: Session = Depends(get_db),
):
"""
Render admin dashboard page.
@@ -108,7 +104,7 @@ async def admin_dashboard_page(
{
"request": request,
"user": current_user,
- }
+ },
)
@@ -116,11 +112,12 @@ async def admin_dashboard_page(
# VENDOR MANAGEMENT ROUTES
# ============================================================================
+
@router.get("/vendors", response_class=HTMLResponse, include_in_schema=False)
async def admin_vendors_list_page(
- request: Request,
- current_user: User = Depends(get_current_admin_from_cookie_or_header),
- db: Session = Depends(get_db)
+ request: Request,
+ current_user: User = Depends(get_current_admin_from_cookie_or_header),
+ db: Session = Depends(get_db),
):
"""
Render vendors management page.
@@ -131,15 +128,15 @@ async def admin_vendors_list_page(
{
"request": request,
"user": current_user,
- }
+ },
)
@router.get("/vendors/create", response_class=HTMLResponse, include_in_schema=False)
async def admin_vendor_create_page(
- request: Request,
- current_user: User = Depends(get_current_admin_from_cookie_or_header),
- db: Session = Depends(get_db)
+ request: Request,
+ current_user: User = Depends(get_current_admin_from_cookie_or_header),
+ db: Session = Depends(get_db),
):
"""
Render vendor creation form.
@@ -149,16 +146,18 @@ async def admin_vendor_create_page(
{
"request": request,
"user": current_user,
- }
+ },
)
-@router.get("/vendors/{vendor_code}", response_class=HTMLResponse, include_in_schema=False)
+@router.get(
+ "/vendors/{vendor_code}", response_class=HTMLResponse, include_in_schema=False
+)
async def admin_vendor_detail_page(
- request: Request,
- vendor_code: str = Path(..., description="Vendor code"),
- current_user: User = Depends(get_current_admin_from_cookie_or_header),
- db: Session = Depends(get_db)
+ request: Request,
+ vendor_code: str = Path(..., description="Vendor code"),
+ current_user: User = Depends(get_current_admin_from_cookie_or_header),
+ db: Session = Depends(get_db),
):
"""
Render vendor detail page.
@@ -170,16 +169,18 @@ async def admin_vendor_detail_page(
"request": request,
"user": current_user,
"vendor_code": vendor_code,
- }
+ },
)
-@router.get("/vendors/{vendor_code}/edit", response_class=HTMLResponse, include_in_schema=False)
+@router.get(
+ "/vendors/{vendor_code}/edit", response_class=HTMLResponse, include_in_schema=False
+)
async def admin_vendor_edit_page(
- request: Request,
- vendor_code: str = Path(..., description="Vendor code"),
- current_user: User = Depends(get_current_admin_from_cookie_or_header),
- db: Session = Depends(get_db)
+ request: Request,
+ vendor_code: str = Path(..., description="Vendor code"),
+ current_user: User = Depends(get_current_admin_from_cookie_or_header),
+ db: Session = Depends(get_db),
):
"""
Render vendor edit form.
@@ -190,7 +191,7 @@ async def admin_vendor_edit_page(
"request": request,
"user": current_user,
"vendor_code": vendor_code,
- }
+ },
)
@@ -198,12 +199,17 @@ async def admin_vendor_edit_page(
# VENDOR DOMAINS ROUTES
# ============================================================================
-@router.get("/vendors/{vendor_code}/domains", response_class=HTMLResponse, include_in_schema=False)
+
+@router.get(
+ "/vendors/{vendor_code}/domains",
+ response_class=HTMLResponse,
+ include_in_schema=False,
+)
async def admin_vendor_domains_page(
- request: Request,
- vendor_code: str = Path(..., description="Vendor code"),
- current_user: User = Depends(get_current_admin_from_cookie_or_header),
- db: Session = Depends(get_db)
+ request: Request,
+ vendor_code: str = Path(..., description="Vendor code"),
+ current_user: User = Depends(get_current_admin_from_cookie_or_header),
+ db: Session = Depends(get_db),
):
"""
Render vendor domains management page.
@@ -215,7 +221,7 @@ async def admin_vendor_domains_page(
"request": request,
"user": current_user,
"vendor_code": vendor_code,
- }
+ },
)
@@ -223,12 +229,15 @@ async def admin_vendor_domains_page(
# VENDOR THEMES ROUTES
# ============================================================================
-@router.get("/vendors/{vendor_code}/theme", response_class=HTMLResponse, include_in_schema=False)
+
+@router.get(
+ "/vendors/{vendor_code}/theme", response_class=HTMLResponse, include_in_schema=False
+)
async def admin_vendor_theme_page(
- request: Request,
- vendor_code: str = Path(..., description="Vendor code"),
- current_user: User = Depends(get_current_admin_from_cookie_or_header),
- db: Session = Depends(get_db)
+ request: Request,
+ vendor_code: str = Path(..., description="Vendor code"),
+ current_user: User = Depends(get_current_admin_from_cookie_or_header),
+ db: Session = Depends(get_db),
):
"""
Render vendor theme customization page.
@@ -240,7 +249,7 @@ async def admin_vendor_theme_page(
"request": request,
"user": current_user,
"vendor_code": vendor_code,
- }
+ },
)
@@ -248,11 +257,12 @@ async def admin_vendor_theme_page(
# USER MANAGEMENT ROUTES
# ============================================================================
+
@router.get("/users", response_class=HTMLResponse, include_in_schema=False)
async def admin_users_page(
- request: Request,
- current_user: User = Depends(get_current_admin_from_cookie_or_header),
- db: Session = Depends(get_db)
+ request: Request,
+ current_user: User = Depends(get_current_admin_from_cookie_or_header),
+ db: Session = Depends(get_db),
):
"""
Render users management page.
@@ -263,7 +273,7 @@ async def admin_users_page(
{
"request": request,
"user": current_user,
- }
+ },
)
@@ -271,11 +281,12 @@ async def admin_users_page(
# IMPORT MANAGEMENT ROUTES
# ============================================================================
+
@router.get("/imports", response_class=HTMLResponse, include_in_schema=False)
async def admin_imports_page(
- request: Request,
- current_user: User = Depends(get_current_admin_from_cookie_or_header),
- db: Session = Depends(get_db)
+ request: Request,
+ current_user: User = Depends(get_current_admin_from_cookie_or_header),
+ db: Session = Depends(get_db),
):
"""
Render imports management page.
@@ -286,7 +297,7 @@ async def admin_imports_page(
{
"request": request,
"user": current_user,
- }
+ },
)
@@ -294,11 +305,12 @@ async def admin_imports_page(
# SETTINGS ROUTES
# ============================================================================
+
@router.get("/settings", response_class=HTMLResponse, include_in_schema=False)
async def admin_settings_page(
- request: Request,
- current_user: User = Depends(get_current_admin_from_cookie_or_header),
- db: Session = Depends(get_db)
+ request: Request,
+ current_user: User = Depends(get_current_admin_from_cookie_or_header),
+ db: Session = Depends(get_db),
):
"""
Render admin settings page.
@@ -309,7 +321,7 @@ async def admin_settings_page(
{
"request": request,
"user": current_user,
- }
+ },
)
@@ -317,11 +329,12 @@ async def admin_settings_page(
# CONTENT MANAGEMENT SYSTEM (CMS) ROUTES
# ============================================================================
+
@router.get("/platform-homepage", response_class=HTMLResponse, include_in_schema=False)
async def admin_platform_homepage_manager(
- request: Request,
- current_user: User = Depends(get_current_admin_from_cookie_or_header),
- db: Session = Depends(get_db)
+ request: Request,
+ current_user: User = Depends(get_current_admin_from_cookie_or_header),
+ db: Session = Depends(get_db),
):
"""
Render platform homepage manager.
@@ -332,15 +345,15 @@ async def admin_platform_homepage_manager(
{
"request": request,
"user": current_user,
- }
+ },
)
@router.get("/content-pages", response_class=HTMLResponse, include_in_schema=False)
async def admin_content_pages_list(
- request: Request,
- current_user: User = Depends(get_current_admin_from_cookie_or_header),
- db: Session = Depends(get_db)
+ request: Request,
+ current_user: User = Depends(get_current_admin_from_cookie_or_header),
+ db: Session = Depends(get_db),
):
"""
Render content pages list.
@@ -351,15 +364,17 @@ async def admin_content_pages_list(
{
"request": request,
"user": current_user,
- }
+ },
)
-@router.get("/content-pages/create", response_class=HTMLResponse, include_in_schema=False)
+@router.get(
+ "/content-pages/create", response_class=HTMLResponse, include_in_schema=False
+)
async def admin_content_page_create(
- request: Request,
- current_user: User = Depends(get_current_admin_from_cookie_or_header),
- db: Session = Depends(get_db)
+ request: Request,
+ current_user: User = Depends(get_current_admin_from_cookie_or_header),
+ db: Session = Depends(get_db),
):
"""
Render create content page form.
@@ -371,16 +386,20 @@ async def admin_content_page_create(
"request": request,
"user": current_user,
"page_id": None, # Indicates this is a create operation
- }
+ },
)
-@router.get("/content-pages/{page_id}/edit", response_class=HTMLResponse, include_in_schema=False)
+@router.get(
+ "/content-pages/{page_id}/edit",
+ response_class=HTMLResponse,
+ include_in_schema=False,
+)
async def admin_content_page_edit(
- request: Request,
- page_id: int = Path(..., description="Content page ID"),
- current_user: User = Depends(get_current_admin_from_cookie_or_header),
- db: Session = Depends(get_db)
+ request: Request,
+ page_id: int = Path(..., description="Content page ID"),
+ current_user: User = Depends(get_current_admin_from_cookie_or_header),
+ db: Session = Depends(get_db),
):
"""
Render edit content page form.
@@ -392,7 +411,7 @@ async def admin_content_page_edit(
"request": request,
"user": current_user,
"page_id": page_id,
- }
+ },
)
@@ -400,11 +419,12 @@ async def admin_content_page_edit(
# DEVELOPER TOOLS - COMPONENTS & TESTING
# ============================================================================
+
@router.get("/components", response_class=HTMLResponse, include_in_schema=False)
async def admin_components_page(
- request: Request,
- current_user: User = Depends(get_current_admin_from_cookie_or_header),
- db: Session = Depends(get_db)
+ request: Request,
+ current_user: User = Depends(get_current_admin_from_cookie_or_header),
+ db: Session = Depends(get_db),
):
"""
Render UI components library page.
@@ -415,7 +435,7 @@ async def admin_components_page(
{
"request": request,
"user": current_user,
- }
+ },
)
@@ -423,7 +443,7 @@ async def admin_components_page(
async def admin_icons_page(
request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header),
- db: Session = Depends(get_db)
+ db: Session = Depends(get_db),
):
"""
Render icons browser page.
@@ -434,15 +454,15 @@ async def admin_icons_page(
{
"request": request,
"user": current_user,
- }
+ },
)
@router.get("/testing", response_class=HTMLResponse, include_in_schema=False)
async def admin_testing_hub(
- request: Request,
- current_user: User = Depends(get_current_admin_from_cookie_or_header),
- db: Session = Depends(get_db)
+ request: Request,
+ current_user: User = Depends(get_current_admin_from_cookie_or_header),
+ db: Session = Depends(get_db),
):
"""
Render testing hub page.
@@ -453,15 +473,15 @@ async def admin_testing_hub(
{
"request": request,
"user": current_user,
- }
+ },
)
@router.get("/test/auth-flow", response_class=HTMLResponse, include_in_schema=False)
async def admin_test_auth_flow(
- request: Request,
- current_user: User = Depends(get_current_admin_from_cookie_or_header),
- db: Session = Depends(get_db)
+ request: Request,
+ current_user: User = Depends(get_current_admin_from_cookie_or_header),
+ db: Session = Depends(get_db),
):
"""
Render authentication flow testing page.
@@ -472,15 +492,19 @@ async def admin_test_auth_flow(
{
"request": request,
"user": current_user,
- }
+ },
)
-@router.get("/test/vendors-users-migration", response_class=HTMLResponse, include_in_schema=False)
+@router.get(
+ "/test/vendors-users-migration",
+ response_class=HTMLResponse,
+ include_in_schema=False,
+)
async def admin_test_vendors_users_migration(
- request: Request,
- current_user: User = Depends(get_current_admin_from_cookie_or_header),
- db: Session = Depends(get_db)
+ request: Request,
+ current_user: User = Depends(get_current_admin_from_cookie_or_header),
+ db: Session = Depends(get_db),
):
"""
Render vendors and users migration testing page.
@@ -491,7 +515,7 @@ async def admin_test_vendors_users_migration(
{
"request": request,
"user": current_user,
- }
+ },
)
@@ -499,11 +523,12 @@ async def admin_test_vendors_users_migration(
# CODE QUALITY & ARCHITECTURE ROUTES
# ============================================================================
+
@router.get("/code-quality", response_class=HTMLResponse, include_in_schema=False)
async def admin_code_quality_dashboard(
- request: Request,
- current_user: User = Depends(get_current_admin_from_cookie_or_header),
- db: Session = Depends(get_db)
+ request: Request,
+ current_user: User = Depends(get_current_admin_from_cookie_or_header),
+ db: Session = Depends(get_db),
):
"""
Render code quality dashboard.
@@ -514,15 +539,17 @@ async def admin_code_quality_dashboard(
{
"request": request,
"user": current_user,
- }
+ },
)
-@router.get("/code-quality/violations", response_class=HTMLResponse, include_in_schema=False)
+@router.get(
+ "/code-quality/violations", response_class=HTMLResponse, include_in_schema=False
+)
async def admin_code_quality_violations(
- request: Request,
- current_user: User = Depends(get_current_admin_from_cookie_or_header),
- db: Session = Depends(get_db)
+ request: Request,
+ current_user: User = Depends(get_current_admin_from_cookie_or_header),
+ db: Session = Depends(get_db),
):
"""
Render violations list page.
@@ -533,16 +560,20 @@ async def admin_code_quality_violations(
{
"request": request,
"user": current_user,
- }
+ },
)
-@router.get("/code-quality/violations/{violation_id}", response_class=HTMLResponse, include_in_schema=False)
+@router.get(
+ "/code-quality/violations/{violation_id}",
+ response_class=HTMLResponse,
+ include_in_schema=False,
+)
async def admin_code_quality_violation_detail(
- request: Request,
- violation_id: int = Path(..., description="Violation ID"),
- current_user: User = Depends(get_current_admin_from_cookie_or_header),
- db: Session = Depends(get_db)
+ request: Request,
+ violation_id: int = Path(..., description="Violation ID"),
+ current_user: User = Depends(get_current_admin_from_cookie_or_header),
+ db: Session = Depends(get_db),
):
"""
Render violation detail page.
@@ -554,5 +585,5 @@ async def admin_code_quality_violation_detail(
"request": request,
"user": current_user,
"violation_id": violation_id,
- }
+ },
)
diff --git a/app/routes/shop_pages.py b/app/routes/shop_pages.py
index 569113d9..8103d809 100644
--- a/app/routes/shop_pages.py
+++ b/app/routes/shop_pages.py
@@ -31,7 +31,8 @@ Routes (all mounted at /shop/* or /vendors/{code}/shop/* prefix):
"""
import logging
-from fastapi import APIRouter, Request, Depends, Path
+
+from fastapi import APIRouter, Depends, Path, Request
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
from sqlalchemy.orm import Session
@@ -50,6 +51,7 @@ logger = logging.getLogger(__name__)
# HELPER: Build Shop Template Context
# ============================================================================
+
def get_shop_context(request: Request, db: Session = None, **extra_context) -> dict:
"""
Build template context for shop pages.
@@ -76,13 +78,17 @@ def get_shop_context(request: Request, db: Session = None, **extra_context) -> d
get_shop_context(request, db=db, user=current_user, product_id=123)
"""
# Extract from middleware state
- vendor = getattr(request.state, 'vendor', None)
- theme = getattr(request.state, 'theme', None)
- clean_path = getattr(request.state, 'clean_path', request.url.path)
- vendor_context = getattr(request.state, 'vendor_context', None)
+ vendor = getattr(request.state, "vendor", None)
+ theme = getattr(request.state, "theme", None)
+ clean_path = getattr(request.state, "clean_path", request.url.path)
+ vendor_context = getattr(request.state, "vendor_context", None)
# Get detection method from vendor_context
- access_method = vendor_context.get('detection_method', 'unknown') if vendor_context else 'unknown'
+ access_method = (
+ vendor_context.get("detection_method", "unknown")
+ if vendor_context
+ else "unknown"
+ )
if vendor is None:
logger.warning(
@@ -91,7 +97,7 @@ def get_shop_context(request: Request, db: Session = None, **extra_context) -> d
"path": request.url.path,
"host": request.headers.get("host", ""),
"has_vendor": False,
- }
+ },
)
# Calculate base URL for links
@@ -100,7 +106,11 @@ def get_shop_context(request: Request, db: Session = None, **extra_context) -> d
base_url = "/"
if access_method == "path" and vendor:
# Use the full_prefix from vendor_context to determine which pattern was used
- full_prefix = vendor_context.get('full_prefix', '/vendor/') if vendor_context else '/vendor/'
+ full_prefix = (
+ vendor_context.get("full_prefix", "/vendor/")
+ if vendor_context
+ else "/vendor/"
+ )
base_url = f"{full_prefix}{vendor.subdomain}/"
# Load footer navigation pages from CMS if db session provided
@@ -111,22 +121,16 @@ def get_shop_context(request: Request, db: Session = None, **extra_context) -> d
vendor_id = vendor.id
# Get pages configured to show in footer
footer_pages = content_page_service.list_pages_for_vendor(
- db,
- vendor_id=vendor_id,
- footer_only=True,
- include_unpublished=False
+ db, vendor_id=vendor_id, footer_only=True, include_unpublished=False
)
# Get pages configured to show in header
header_pages = content_page_service.list_pages_for_vendor(
- db,
- vendor_id=vendor_id,
- header_only=True,
- include_unpublished=False
+ db, vendor_id=vendor_id, header_only=True, include_unpublished=False
)
except Exception as e:
logger.error(
f"[SHOP_CONTEXT] Failed to load navigation pages",
- extra={"error": str(e), "vendor_id": vendor.id if vendor else None}
+ extra={"error": str(e), "vendor_id": vendor.id if vendor else None},
)
context = {
@@ -156,7 +160,7 @@ def get_shop_context(request: Request, db: Session = None, **extra_context) -> d
"footer_pages_count": len(footer_pages),
"header_pages_count": len(header_pages),
"extra_keys": list(extra_context.keys()) if extra_context else [],
- }
+ },
)
return context
@@ -166,6 +170,7 @@ def get_shop_context(request: Request, db: Session = None, **extra_context) -> d
# PUBLIC SHOP ROUTES (No Authentication Required)
# ============================================================================
+
@router.get("/", response_class=HTMLResponse, include_in_schema=False)
@router.get("/products", response_class=HTMLResponse, include_in_schema=False)
async def shop_products_page(request: Request, db: Session = Depends(get_db)):
@@ -177,21 +182,21 @@ async def shop_products_page(request: Request, db: Session = Depends(get_db)):
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
- "vendor": getattr(request.state, 'vendor', 'NOT SET'),
- "context": getattr(request.state, 'context_type', 'NOT SET'),
- }
+ "vendor": getattr(request.state, "vendor", "NOT SET"),
+ "context": getattr(request.state, "context_type", "NOT SET"),
+ },
)
return templates.TemplateResponse(
- "shop/products.html",
- get_shop_context(request, db=db)
+ "shop/products.html", get_shop_context(request, db=db)
)
-@router.get("/products/{product_id}", response_class=HTMLResponse, include_in_schema=False)
+@router.get(
+ "/products/{product_id}", response_class=HTMLResponse, include_in_schema=False
+)
async def shop_product_detail_page(
- request: Request,
- product_id: int = Path(..., description="Product ID")
+ request: Request, product_id: int = Path(..., description="Product ID")
):
"""
Render product detail page.
@@ -201,21 +206,21 @@ async def shop_product_detail_page(
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
- "vendor": getattr(request.state, 'vendor', 'NOT SET'),
- "context": getattr(request.state, 'context_type', 'NOT SET'),
- }
+ "vendor": getattr(request.state, "vendor", "NOT SET"),
+ "context": getattr(request.state, "context_type", "NOT SET"),
+ },
)
return templates.TemplateResponse(
- "shop/product.html",
- get_shop_context(request, product_id=product_id)
+ "shop/product.html", get_shop_context(request, product_id=product_id)
)
-@router.get("/categories/{category_slug}", response_class=HTMLResponse, include_in_schema=False)
+@router.get(
+ "/categories/{category_slug}", response_class=HTMLResponse, include_in_schema=False
+)
async def shop_category_page(
- request: Request,
- category_slug: str = Path(..., description="Category slug")
+ request: Request, category_slug: str = Path(..., description="Category slug")
):
"""
Render category products page.
@@ -225,14 +230,13 @@ async def shop_category_page(
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
- "vendor": getattr(request.state, 'vendor', 'NOT SET'),
- "context": getattr(request.state, 'context_type', 'NOT SET'),
- }
+ "vendor": getattr(request.state, "vendor", "NOT SET"),
+ "context": getattr(request.state, "context_type", "NOT SET"),
+ },
)
return templates.TemplateResponse(
- "shop/category.html",
- get_shop_context(request, category_slug=category_slug)
+ "shop/category.html", get_shop_context(request, category_slug=category_slug)
)
@@ -246,15 +250,12 @@ async def shop_cart_page(request: Request):
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
- "vendor": getattr(request.state, 'vendor', 'NOT SET'),
- "context": getattr(request.state, 'context_type', 'NOT SET'),
- }
+ "vendor": getattr(request.state, "vendor", "NOT SET"),
+ "context": getattr(request.state, "context_type", "NOT SET"),
+ },
)
- return templates.TemplateResponse(
- "shop/cart.html",
- get_shop_context(request)
- )
+ return templates.TemplateResponse("shop/cart.html", get_shop_context(request))
@router.get("/checkout", response_class=HTMLResponse, include_in_schema=False)
@@ -267,15 +268,12 @@ async def shop_checkout_page(request: Request):
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
- "vendor": getattr(request.state, 'vendor', 'NOT SET'),
- "context": getattr(request.state, 'context_type', 'NOT SET'),
- }
+ "vendor": getattr(request.state, "vendor", "NOT SET"),
+ "context": getattr(request.state, "context_type", "NOT SET"),
+ },
)
- return templates.TemplateResponse(
- "shop/checkout.html",
- get_shop_context(request)
- )
+ return templates.TemplateResponse("shop/checkout.html", get_shop_context(request))
@router.get("/search", response_class=HTMLResponse, include_in_schema=False)
@@ -288,21 +286,19 @@ async def shop_search_page(request: Request):
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
- "vendor": getattr(request.state, 'vendor', 'NOT SET'),
- "context": getattr(request.state, 'context_type', 'NOT SET'),
- }
+ "vendor": getattr(request.state, "vendor", "NOT SET"),
+ "context": getattr(request.state, "context_type", "NOT SET"),
+ },
)
- return templates.TemplateResponse(
- "shop/search.html",
- get_shop_context(request)
- )
+ return templates.TemplateResponse("shop/search.html", get_shop_context(request))
# ============================================================================
# CUSTOMER ACCOUNT - PUBLIC ROUTES (No Authentication)
# ============================================================================
+
@router.get("/account/register", response_class=HTMLResponse, include_in_schema=False)
async def shop_register_page(request: Request):
"""
@@ -313,14 +309,13 @@ async def shop_register_page(request: Request):
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
- "vendor": getattr(request.state, 'vendor', 'NOT SET'),
- "context": getattr(request.state, 'context_type', 'NOT SET'),
- }
+ "vendor": getattr(request.state, "vendor", "NOT SET"),
+ "context": getattr(request.state, "context_type", "NOT SET"),
+ },
)
return templates.TemplateResponse(
- "shop/account/register.html",
- get_shop_context(request)
+ "shop/account/register.html", get_shop_context(request)
)
@@ -334,18 +329,19 @@ async def shop_login_page(request: Request):
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
- "vendor": getattr(request.state, 'vendor', 'NOT SET'),
- "context": getattr(request.state, 'context_type', 'NOT SET'),
- }
+ "vendor": getattr(request.state, "vendor", "NOT SET"),
+ "context": getattr(request.state, "context_type", "NOT SET"),
+ },
)
return templates.TemplateResponse(
- "shop/account/login.html",
- get_shop_context(request)
+ "shop/account/login.html", get_shop_context(request)
)
-@router.get("/account/forgot-password", response_class=HTMLResponse, include_in_schema=False)
+@router.get(
+ "/account/forgot-password", response_class=HTMLResponse, include_in_schema=False
+)
async def shop_forgot_password_page(request: Request):
"""
Render forgot password page.
@@ -355,14 +351,13 @@ async def shop_forgot_password_page(request: Request):
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
- "vendor": getattr(request.state, 'vendor', 'NOT SET'),
- "context": getattr(request.state, 'context_type', 'NOT SET'),
- }
+ "vendor": getattr(request.state, "vendor", "NOT SET"),
+ "context": getattr(request.state, "context_type", "NOT SET"),
+ },
)
return templates.TemplateResponse(
- "shop/account/forgot-password.html",
- get_shop_context(request)
+ "shop/account/forgot-password.html", get_shop_context(request)
)
@@ -370,6 +365,7 @@ async def shop_forgot_password_page(request: Request):
# CUSTOMER ACCOUNT - AUTHENTICATED ROUTES
# ============================================================================
+
@router.get("/account", response_class=RedirectResponse, include_in_schema=False)
@router.get("/account/", response_class=RedirectResponse, include_in_schema=False)
async def shop_account_root(request: Request):
@@ -380,19 +376,27 @@ async def shop_account_root(request: Request):
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
- "vendor": getattr(request.state, 'vendor', 'NOT SET'),
- "context": getattr(request.state, 'context_type', 'NOT SET'),
- }
+ "vendor": getattr(request.state, "vendor", "NOT SET"),
+ "context": getattr(request.state, "context_type", "NOT SET"),
+ },
)
# Get base_url from context for proper redirect
- vendor = getattr(request.state, 'vendor', None)
- vendor_context = getattr(request.state, 'vendor_context', None)
- access_method = vendor_context.get('detection_method', 'unknown') if vendor_context else 'unknown'
+ vendor = getattr(request.state, "vendor", None)
+ vendor_context = getattr(request.state, "vendor_context", None)
+ access_method = (
+ vendor_context.get("detection_method", "unknown")
+ if vendor_context
+ else "unknown"
+ )
base_url = "/"
if access_method == "path" and vendor:
- full_prefix = vendor_context.get('full_prefix', '/vendor/') if vendor_context else '/vendor/'
+ full_prefix = (
+ vendor_context.get("full_prefix", "/vendor/")
+ if vendor_context
+ else "/vendor/"
+ )
base_url = f"{full_prefix}{vendor.subdomain}/"
return RedirectResponse(url=f"{base_url}shop/account/dashboard", status_code=302)
@@ -400,9 +404,9 @@ async def shop_account_root(request: Request):
@router.get("/account/dashboard", response_class=HTMLResponse, include_in_schema=False)
async def shop_account_dashboard_page(
- request: Request,
- current_customer: Customer = Depends(get_current_customer_from_cookie_or_header),
- db: Session = Depends(get_db)
+ request: Request,
+ current_customer: Customer = Depends(get_current_customer_from_cookie_or_header),
+ db: Session = Depends(get_db),
):
"""
Render customer account dashboard.
@@ -413,22 +417,21 @@ async def shop_account_dashboard_page(
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
- "vendor": getattr(request.state, 'vendor', 'NOT SET'),
- "context": getattr(request.state, 'context_type', 'NOT SET'),
- }
+ "vendor": getattr(request.state, "vendor", "NOT SET"),
+ "context": getattr(request.state, "context_type", "NOT SET"),
+ },
)
return templates.TemplateResponse(
- "shop/account/dashboard.html",
- get_shop_context(request, user=current_customer)
+ "shop/account/dashboard.html", get_shop_context(request, user=current_customer)
)
@router.get("/account/orders", response_class=HTMLResponse, include_in_schema=False)
async def shop_orders_page(
- request: Request,
- current_customer: Customer = Depends(get_current_customer_from_cookie_or_header),
- db: Session = Depends(get_db)
+ request: Request,
+ current_customer: Customer = Depends(get_current_customer_from_cookie_or_header),
+ db: Session = Depends(get_db),
):
"""
Render customer orders history page.
@@ -439,23 +442,24 @@ async def shop_orders_page(
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
- "vendor": getattr(request.state, 'vendor', 'NOT SET'),
- "context": getattr(request.state, 'context_type', 'NOT SET'),
- }
+ "vendor": getattr(request.state, "vendor", "NOT SET"),
+ "context": getattr(request.state, "context_type", "NOT SET"),
+ },
)
return templates.TemplateResponse(
- "shop/account/orders.html",
- get_shop_context(request, user=current_customer)
+ "shop/account/orders.html", get_shop_context(request, user=current_customer)
)
-@router.get("/account/orders/{order_id}", response_class=HTMLResponse, include_in_schema=False)
+@router.get(
+ "/account/orders/{order_id}", response_class=HTMLResponse, include_in_schema=False
+)
async def shop_order_detail_page(
- request: Request,
- order_id: int = Path(..., description="Order ID"),
- current_customer: Customer = Depends(get_current_customer_from_cookie_or_header),
- db: Session = Depends(get_db)
+ request: Request,
+ order_id: int = Path(..., description="Order ID"),
+ current_customer: Customer = Depends(get_current_customer_from_cookie_or_header),
+ db: Session = Depends(get_db),
):
"""
Render customer order detail page.
@@ -466,22 +470,22 @@ async def shop_order_detail_page(
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
- "vendor": getattr(request.state, 'vendor', 'NOT SET'),
- "context": getattr(request.state, 'context_type', 'NOT SET'),
- }
+ "vendor": getattr(request.state, "vendor", "NOT SET"),
+ "context": getattr(request.state, "context_type", "NOT SET"),
+ },
)
return templates.TemplateResponse(
"shop/account/order-detail.html",
- get_shop_context(request, user=current_customer, order_id=order_id)
+ get_shop_context(request, user=current_customer, order_id=order_id),
)
@router.get("/account/profile", response_class=HTMLResponse, include_in_schema=False)
async def shop_profile_page(
- request: Request,
- current_customer: Customer = Depends(get_current_customer_from_cookie_or_header),
- db: Session = Depends(get_db)
+ request: Request,
+ current_customer: Customer = Depends(get_current_customer_from_cookie_or_header),
+ db: Session = Depends(get_db),
):
"""
Render customer profile page.
@@ -492,22 +496,21 @@ async def shop_profile_page(
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
- "vendor": getattr(request.state, 'vendor', 'NOT SET'),
- "context": getattr(request.state, 'context_type', 'NOT SET'),
- }
+ "vendor": getattr(request.state, "vendor", "NOT SET"),
+ "context": getattr(request.state, "context_type", "NOT SET"),
+ },
)
return templates.TemplateResponse(
- "shop/account/profile.html",
- get_shop_context(request, user=current_customer)
+ "shop/account/profile.html", get_shop_context(request, user=current_customer)
)
@router.get("/account/addresses", response_class=HTMLResponse, include_in_schema=False)
async def shop_addresses_page(
- request: Request,
- current_customer: Customer = Depends(get_current_customer_from_cookie_or_header),
- db: Session = Depends(get_db)
+ request: Request,
+ current_customer: Customer = Depends(get_current_customer_from_cookie_or_header),
+ db: Session = Depends(get_db),
):
"""
Render customer addresses management page.
@@ -518,22 +521,21 @@ async def shop_addresses_page(
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
- "vendor": getattr(request.state, 'vendor', 'NOT SET'),
- "context": getattr(request.state, 'context_type', 'NOT SET'),
- }
+ "vendor": getattr(request.state, "vendor", "NOT SET"),
+ "context": getattr(request.state, "context_type", "NOT SET"),
+ },
)
return templates.TemplateResponse(
- "shop/account/addresses.html",
- get_shop_context(request, user=current_customer)
+ "shop/account/addresses.html", get_shop_context(request, user=current_customer)
)
@router.get("/account/wishlist", response_class=HTMLResponse, include_in_schema=False)
async def shop_wishlist_page(
- request: Request,
- current_customer: Customer = Depends(get_current_customer_from_cookie_or_header),
- db: Session = Depends(get_db)
+ request: Request,
+ current_customer: Customer = Depends(get_current_customer_from_cookie_or_header),
+ db: Session = Depends(get_db),
):
"""
Render customer wishlist page.
@@ -544,22 +546,21 @@ async def shop_wishlist_page(
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
- "vendor": getattr(request.state, 'vendor', 'NOT SET'),
- "context": getattr(request.state, 'context_type', 'NOT SET'),
- }
+ "vendor": getattr(request.state, "vendor", "NOT SET"),
+ "context": getattr(request.state, "context_type", "NOT SET"),
+ },
)
return templates.TemplateResponse(
- "shop/account/wishlist.html",
- get_shop_context(request, user=current_customer)
+ "shop/account/wishlist.html", get_shop_context(request, user=current_customer)
)
@router.get("/account/settings", response_class=HTMLResponse, include_in_schema=False)
async def shop_settings_page(
- request: Request,
- current_customer: Customer = Depends(get_current_customer_from_cookie_or_header),
- db: Session = Depends(get_db)
+ request: Request,
+ current_customer: Customer = Depends(get_current_customer_from_cookie_or_header),
+ db: Session = Depends(get_db),
):
"""
Render customer account settings page.
@@ -570,14 +571,13 @@ async def shop_settings_page(
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
- "vendor": getattr(request.state, 'vendor', 'NOT SET'),
- "context": getattr(request.state, 'context_type', 'NOT SET'),
- }
+ "vendor": getattr(request.state, "vendor", "NOT SET"),
+ "context": getattr(request.state, "context_type", "NOT SET"),
+ },
)
return templates.TemplateResponse(
- "shop/account/settings.html",
- get_shop_context(request, user=current_customer)
+ "shop/account/settings.html", get_shop_context(request, user=current_customer)
)
@@ -585,11 +585,12 @@ async def shop_settings_page(
# DYNAMIC CONTENT PAGES (CMS)
# ============================================================================
+
@router.get("/{slug}", response_class=HTMLResponse, include_in_schema=False)
async def generic_content_page(
request: Request,
slug: str = Path(..., description="Content page slug"),
- db: Session = Depends(get_db)
+ db: Session = Depends(get_db),
):
"""
Generic content page handler (CMS).
@@ -612,20 +613,17 @@ async def generic_content_page(
extra={
"path": request.url.path,
"slug": slug,
- "vendor": getattr(request.state, 'vendor', 'NOT SET'),
- "context": getattr(request.state, 'context_type', 'NOT SET'),
- }
+ "vendor": getattr(request.state, "vendor", "NOT SET"),
+ "context": getattr(request.state, "context_type", "NOT SET"),
+ },
)
- vendor = getattr(request.state, 'vendor', None)
+ vendor = getattr(request.state, "vendor", None)
vendor_id = vendor.id if vendor else None
# Load content page from database (vendor override → platform default)
page = content_page_service.get_page_for_vendor(
- db,
- slug=slug,
- vendor_id=vendor_id,
- include_unpublished=False
+ db, slug=slug, vendor_id=vendor_id, include_unpublished=False
)
if not page:
@@ -635,7 +633,7 @@ async def generic_content_page(
"slug": slug,
"vendor_id": vendor_id,
"vendor_name": vendor.name if vendor else None,
- }
+ },
)
raise HTTPException(status_code=404, detail=f"Page not found: {slug}")
@@ -647,12 +645,11 @@ async def generic_content_page(
"page_title": page.title,
"is_vendor_override": page.vendor_id is not None,
"vendor_id": vendor_id,
- }
+ },
)
return templates.TemplateResponse(
- "shop/content-page.html",
- get_shop_context(request, page=page)
+ "shop/content-page.html", get_shop_context(request, page=page)
)
@@ -660,6 +657,7 @@ async def generic_content_page(
# DEBUG ENDPOINTS - For troubleshooting context issues
# ============================================================================
+
@router.get("/debug/context", response_class=HTMLResponse, include_in_schema=False)
async def debug_context(request: Request):
"""
@@ -670,8 +668,8 @@ async def debug_context(request: Request):
URL: /shop/debug/context
"""
- vendor = getattr(request.state, 'vendor', None)
- theme = getattr(request.state, 'theme', None)
+ vendor = getattr(request.state, "vendor", None)
+ theme = getattr(request.state, "theme", None)
debug_info = {
"path": request.url.path,
@@ -687,12 +685,13 @@ async def debug_context(request: Request):
"found": theme is not None,
"name": theme.get("theme_name") if theme else None,
},
- "clean_path": getattr(request.state, 'clean_path', 'NOT SET'),
- "context_type": str(getattr(request.state, 'context_type', 'NOT SET')),
+ "clean_path": getattr(request.state, "clean_path", "NOT SET"),
+ "context_type": str(getattr(request.state, "context_type", "NOT SET")),
}
# Return as JSON-like HTML for easy reading
import json
+
html_content = f"""
diff --git a/app/routes/vendor_pages.py b/app/routes/vendor_pages.py
index a6b24ad2..41dd702d 100644
--- a/app/routes/vendor_pages.py
+++ b/app/routes/vendor_pages.py
@@ -21,18 +21,16 @@ Routes:
- GET /vendor/{vendor_code}/settings → Vendor settings
"""
-from fastapi import APIRouter, Request, Depends, Path, HTTPException
+import logging
+from typing import Optional
+
+from fastapi import APIRouter, Depends, HTTPException, Path, Request
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
from sqlalchemy.orm import Session
-from typing import Optional
-import logging
-from app.api.deps import (
- get_current_vendor_from_cookie_or_header,
- get_current_vendor_optional,
- get_db
-)
+from app.api.deps import (get_current_vendor_from_cookie_or_header,
+ get_current_vendor_optional, get_db)
from app.services.content_page_service import content_page_service
from models.database.user import User
@@ -46,6 +44,7 @@ templates = Jinja2Templates(directory="app/templates")
# PUBLIC ROUTES (No Authentication Required)
# ============================================================================
+
@router.get("/{vendor_code}", response_class=RedirectResponse, include_in_schema=False)
async def vendor_root_no_slash(vendor_code: str = Path(..., description="Vendor code")):
"""
@@ -57,8 +56,8 @@ async def vendor_root_no_slash(vendor_code: str = Path(..., description="Vendor
@router.get("/{vendor_code}/", response_class=RedirectResponse, include_in_schema=False)
async def vendor_root(
- vendor_code: str = Path(..., description="Vendor code"),
- current_user: Optional[User] = Depends(get_current_vendor_optional)
+ vendor_code: str = Path(..., description="Vendor code"),
+ current_user: Optional[User] = Depends(get_current_vendor_optional),
):
"""
Redirect /vendor/{code}/ based on authentication status.
@@ -73,11 +72,13 @@ async def vendor_root(
return RedirectResponse(url=f"/vendor/{vendor_code}/login", status_code=302)
-@router.get("/{vendor_code}/login", response_class=HTMLResponse, include_in_schema=False)
+@router.get(
+ "/{vendor_code}/login", response_class=HTMLResponse, include_in_schema=False
+)
async def vendor_login_page(
- request: Request,
- vendor_code: str = Path(..., description="Vendor code"),
- current_user: Optional[User] = Depends(get_current_vendor_optional)
+ request: Request,
+ vendor_code: str = Path(..., description="Vendor code"),
+ current_user: Optional[User] = Depends(get_current_vendor_optional),
):
"""
Render vendor login page.
@@ -99,7 +100,7 @@ async def vendor_login_page(
{
"request": request,
"vendor_code": vendor_code,
- }
+ },
)
@@ -107,11 +108,14 @@ async def vendor_login_page(
# AUTHENTICATED ROUTES (Vendor Users Only)
# ============================================================================
-@router.get("/{vendor_code}/dashboard", response_class=HTMLResponse, include_in_schema=False)
+
+@router.get(
+ "/{vendor_code}/dashboard", response_class=HTMLResponse, include_in_schema=False
+)
async def vendor_dashboard_page(
- request: Request,
- vendor_code: str = Path(..., description="Vendor code"),
- current_user: User = Depends(get_current_vendor_from_cookie_or_header)
+ request: Request,
+ vendor_code: str = Path(..., description="Vendor code"),
+ current_user: User = Depends(get_current_vendor_from_cookie_or_header),
):
"""
Render vendor dashboard.
@@ -128,7 +132,7 @@ async def vendor_dashboard_page(
"request": request,
"user": current_user,
"vendor_code": vendor_code,
- }
+ },
)
@@ -136,11 +140,14 @@ async def vendor_dashboard_page(
# PRODUCT MANAGEMENT
# ============================================================================
-@router.get("/{vendor_code}/products", response_class=HTMLResponse, include_in_schema=False)
+
+@router.get(
+ "/{vendor_code}/products", response_class=HTMLResponse, include_in_schema=False
+)
async def vendor_products_page(
- request: Request,
- vendor_code: str = Path(..., description="Vendor code"),
- current_user: User = Depends(get_current_vendor_from_cookie_or_header)
+ request: Request,
+ vendor_code: str = Path(..., description="Vendor code"),
+ current_user: User = Depends(get_current_vendor_from_cookie_or_header),
):
"""
Render products management page.
@@ -152,7 +159,7 @@ async def vendor_products_page(
"request": request,
"user": current_user,
"vendor_code": vendor_code,
- }
+ },
)
@@ -160,11 +167,14 @@ async def vendor_products_page(
# ORDER MANAGEMENT
# ============================================================================
-@router.get("/{vendor_code}/orders", response_class=HTMLResponse, include_in_schema=False)
+
+@router.get(
+ "/{vendor_code}/orders", response_class=HTMLResponse, include_in_schema=False
+)
async def vendor_orders_page(
- request: Request,
- vendor_code: str = Path(..., description="Vendor code"),
- current_user: User = Depends(get_current_vendor_from_cookie_or_header)
+ request: Request,
+ vendor_code: str = Path(..., description="Vendor code"),
+ current_user: User = Depends(get_current_vendor_from_cookie_or_header),
):
"""
Render orders management page.
@@ -176,7 +186,7 @@ async def vendor_orders_page(
"request": request,
"user": current_user,
"vendor_code": vendor_code,
- }
+ },
)
@@ -184,11 +194,14 @@ async def vendor_orders_page(
# CUSTOMER MANAGEMENT
# ============================================================================
-@router.get("/{vendor_code}/customers", response_class=HTMLResponse, include_in_schema=False)
+
+@router.get(
+ "/{vendor_code}/customers", response_class=HTMLResponse, include_in_schema=False
+)
async def vendor_customers_page(
- request: Request,
- vendor_code: str = Path(..., description="Vendor code"),
- current_user: User = Depends(get_current_vendor_from_cookie_or_header)
+ request: Request,
+ vendor_code: str = Path(..., description="Vendor code"),
+ current_user: User = Depends(get_current_vendor_from_cookie_or_header),
):
"""
Render customers management page.
@@ -200,7 +213,7 @@ async def vendor_customers_page(
"request": request,
"user": current_user,
"vendor_code": vendor_code,
- }
+ },
)
@@ -208,11 +221,14 @@ async def vendor_customers_page(
# INVENTORY MANAGEMENT
# ============================================================================
-@router.get("/{vendor_code}/inventory", response_class=HTMLResponse, include_in_schema=False)
+
+@router.get(
+ "/{vendor_code}/inventory", response_class=HTMLResponse, include_in_schema=False
+)
async def vendor_inventory_page(
- request: Request,
- vendor_code: str = Path(..., description="Vendor code"),
- current_user: User = Depends(get_current_vendor_from_cookie_or_header)
+ request: Request,
+ vendor_code: str = Path(..., description="Vendor code"),
+ current_user: User = Depends(get_current_vendor_from_cookie_or_header),
):
"""
Render inventory management page.
@@ -224,7 +240,7 @@ async def vendor_inventory_page(
"request": request,
"user": current_user,
"vendor_code": vendor_code,
- }
+ },
)
@@ -232,11 +248,14 @@ async def vendor_inventory_page(
# MARKETPLACE IMPORTS
# ============================================================================
-@router.get("/{vendor_code}/marketplace", response_class=HTMLResponse, include_in_schema=False)
+
+@router.get(
+ "/{vendor_code}/marketplace", response_class=HTMLResponse, include_in_schema=False
+)
async def vendor_marketplace_page(
- request: Request,
- vendor_code: str = Path(..., description="Vendor code"),
- current_user: User = Depends(get_current_vendor_from_cookie_or_header)
+ request: Request,
+ vendor_code: str = Path(..., description="Vendor code"),
+ current_user: User = Depends(get_current_vendor_from_cookie_or_header),
):
"""
Render marketplace import page.
@@ -248,7 +267,7 @@ async def vendor_marketplace_page(
"request": request,
"user": current_user,
"vendor_code": vendor_code,
- }
+ },
)
@@ -256,11 +275,12 @@ async def vendor_marketplace_page(
# TEAM MANAGEMENT
# ============================================================================
+
@router.get("/{vendor_code}/team", response_class=HTMLResponse, include_in_schema=False)
async def vendor_team_page(
- request: Request,
- vendor_code: str = Path(..., description="Vendor code"),
- current_user: User = Depends(get_current_vendor_from_cookie_or_header)
+ request: Request,
+ vendor_code: str = Path(..., description="Vendor code"),
+ current_user: User = Depends(get_current_vendor_from_cookie_or_header),
):
"""
Render team management page.
@@ -272,7 +292,7 @@ async def vendor_team_page(
"request": request,
"user": current_user,
"vendor_code": vendor_code,
- }
+ },
)
@@ -280,11 +300,14 @@ async def vendor_team_page(
# PROFILE & SETTINGS
# ============================================================================
-@router.get("/{vendor_code}/profile", response_class=HTMLResponse, include_in_schema=False)
+
+@router.get(
+ "/{vendor_code}/profile", response_class=HTMLResponse, include_in_schema=False
+)
async def vendor_profile_page(
- request: Request,
- vendor_code: str = Path(..., description="Vendor code"),
- current_user: User = Depends(get_current_vendor_from_cookie_or_header)
+ request: Request,
+ vendor_code: str = Path(..., description="Vendor code"),
+ current_user: User = Depends(get_current_vendor_from_cookie_or_header),
):
"""
Render vendor profile page.
@@ -296,15 +319,17 @@ async def vendor_profile_page(
"request": request,
"user": current_user,
"vendor_code": vendor_code,
- }
+ },
)
-@router.get("/{vendor_code}/settings", response_class=HTMLResponse, include_in_schema=False)
+@router.get(
+ "/{vendor_code}/settings", response_class=HTMLResponse, include_in_schema=False
+)
async def vendor_settings_page(
- request: Request,
- vendor_code: str = Path(..., description="Vendor code"),
- current_user: User = Depends(get_current_vendor_from_cookie_or_header)
+ request: Request,
+ vendor_code: str = Path(..., description="Vendor code"),
+ current_user: User = Depends(get_current_vendor_from_cookie_or_header),
):
"""
Render vendor settings page.
@@ -316,7 +341,7 @@ async def vendor_settings_page(
"request": request,
"user": current_user,
"vendor_code": vendor_code,
- }
+ },
)
@@ -324,12 +349,15 @@ async def vendor_settings_page(
# DYNAMIC CONTENT PAGES (CMS)
# ============================================================================
-@router.get("/{vendor_code}/{slug}", response_class=HTMLResponse, include_in_schema=False)
+
+@router.get(
+ "/{vendor_code}/{slug}", response_class=HTMLResponse, include_in_schema=False
+)
async def vendor_content_page(
request: Request,
vendor_code: str = Path(..., description="Vendor code"),
slug: str = Path(..., description="Content page slug"),
- db: Session = Depends(get_db)
+ db: Session = Depends(get_db),
):
"""
Generic content page handler for vendor shop (CMS).
@@ -351,20 +379,17 @@ async def vendor_content_page(
"path": request.url.path,
"vendor_code": vendor_code,
"slug": slug,
- "vendor": getattr(request.state, 'vendor', 'NOT SET'),
- "context": getattr(request.state, 'context_type', 'NOT SET'),
- }
+ "vendor": getattr(request.state, "vendor", "NOT SET"),
+ "context": getattr(request.state, "context_type", "NOT SET"),
+ },
)
- vendor = getattr(request.state, 'vendor', None)
+ vendor = getattr(request.state, "vendor", None)
vendor_id = vendor.id if vendor else None
# Load content page from database (vendor override → platform default)
page = content_page_service.get_page_for_vendor(
- db,
- slug=slug,
- vendor_id=vendor_id,
- include_unpublished=False
+ db, slug=slug, vendor_id=vendor_id, include_unpublished=False
)
if not page:
@@ -374,7 +399,7 @@ async def vendor_content_page(
"slug": slug,
"vendor_code": vendor_code,
"vendor_id": vendor_id,
- }
+ },
)
raise HTTPException(status_code=404, detail="Page not found")
@@ -385,7 +410,7 @@ async def vendor_content_page(
"page_id": page.id,
"is_vendor_override": page.vendor_id is not None,
"vendor_id": vendor_id,
- }
+ },
)
return templates.TemplateResponse(
@@ -394,5 +419,5 @@ async def vendor_content_page(
"request": request,
"page": page,
"vendor_code": vendor_code,
- }
+ },
)
diff --git a/app/services/admin_audit_service.py b/app/services/admin_audit_service.py
index ee9da385..503a2c32 100644
--- a/app/services/admin_audit_service.py
+++ b/app/services/admin_audit_service.py
@@ -10,15 +10,15 @@ This module provides functions for:
import logging
from datetime import datetime, timezone
-from typing import List, Optional, Dict, Any
+from typing import Any, Dict, List, Optional
-from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
+from sqlalchemy.orm import Session
+from app.exceptions import AdminOperationException
from models.database.admin import AdminAuditLog
from models.database.user import User
from models.schema.admin import AdminAuditLogFilters, AdminAuditLogResponse
-from app.exceptions import AdminOperationException
logger = logging.getLogger(__name__)
@@ -36,7 +36,7 @@ class AdminAuditService:
details: Optional[Dict[str, Any]] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
- request_id: Optional[str] = None
+ request_id: Optional[str] = None,
) -> AdminAuditLog:
"""
Log an admin action to the audit trail.
@@ -63,7 +63,7 @@ class AdminAuditService:
details=details or {},
ip_address=ip_address,
user_agent=user_agent,
- request_id=request_id
+ request_id=request_id,
)
db.add(audit_log)
@@ -84,9 +84,7 @@ class AdminAuditService:
return None
def get_audit_logs(
- self,
- db: Session,
- filters: AdminAuditLogFilters
+ self, db: Session, filters: AdminAuditLogFilters
) -> List[AdminAuditLogResponse]:
"""
Get filtered admin audit logs with pagination.
@@ -98,7 +96,9 @@ class AdminAuditService:
List of audit log responses
"""
try:
- query = db.query(AdminAuditLog).join(User, AdminAuditLog.admin_user_id == User.id)
+ query = db.query(AdminAuditLog).join(
+ User, AdminAuditLog.admin_user_id == User.id
+ )
# Apply filters
conditions = []
@@ -123,8 +123,7 @@ class AdminAuditService:
# Execute query with pagination
logs = (
- query
- .order_by(AdminAuditLog.created_at.desc())
+ query.order_by(AdminAuditLog.created_at.desc())
.offset(filters.skip)
.limit(filters.limit)
.all()
@@ -143,7 +142,7 @@ class AdminAuditService:
ip_address=log.ip_address,
user_agent=log.user_agent,
request_id=log.request_id,
- created_at=log.created_at
+ created_at=log.created_at,
)
for log in logs
]
@@ -151,15 +150,10 @@ class AdminAuditService:
except Exception as e:
logger.error(f"Failed to retrieve audit logs: {str(e)}")
raise AdminOperationException(
- operation="get_audit_logs",
- reason="Database query failed"
+ operation="get_audit_logs", reason="Database query failed"
)
- def get_audit_logs_count(
- self,
- db: Session,
- filters: AdminAuditLogFilters
- ) -> int:
+ def get_audit_logs_count(self, db: Session, filters: AdminAuditLogFilters) -> int:
"""Get total count of audit logs matching filters."""
try:
query = db.query(AdminAuditLog)
@@ -192,24 +186,14 @@ class AdminAuditService:
return 0
def get_recent_actions_by_admin(
- self,
- db: Session,
- admin_user_id: int,
- limit: int = 10
+ self, db: Session, admin_user_id: int, limit: int = 10
) -> List[AdminAuditLogResponse]:
"""Get recent actions by a specific admin."""
- filters = AdminAuditLogFilters(
- admin_user_id=admin_user_id,
- limit=limit
- )
+ filters = AdminAuditLogFilters(admin_user_id=admin_user_id, limit=limit)
return self.get_audit_logs(db, filters)
def get_actions_by_target(
- self,
- db: Session,
- target_type: str,
- target_id: str,
- limit: int = 50
+ self, db: Session, target_type: str, target_id: str, limit: int = 50
) -> List[AdminAuditLogResponse]:
"""Get all actions performed on a specific target."""
try:
@@ -218,7 +202,7 @@ class AdminAuditService:
.filter(
and_(
AdminAuditLog.target_type == target_type,
- AdminAuditLog.target_id == str(target_id)
+ AdminAuditLog.target_id == str(target_id),
)
)
.order_by(AdminAuditLog.created_at.desc())
@@ -236,7 +220,7 @@ class AdminAuditService:
target_id=log.target_id,
details=log.details,
ip_address=log.ip_address,
- created_at=log.created_at
+ created_at=log.created_at,
)
for log in logs
]
@@ -247,4 +231,4 @@ class AdminAuditService:
# Create service instance
-admin_audit_service = AdminAuditService()
\ No newline at end of file
+admin_audit_service = AdminAuditService()
diff --git a/app/services/admin_service.py b/app/services/admin_service.py
index f6529725..65637df8 100644
--- a/app/services/admin_service.py
+++ b/app/services/admin_service.py
@@ -16,24 +16,19 @@ import string
from datetime import datetime, timezone
from typing import List, Optional, Tuple
-from sqlalchemy.orm import Session
from sqlalchemy import func, or_
+from sqlalchemy.orm import Session
-from app.exceptions import (
- UserNotFoundException,
- UserStatusChangeException,
- CannotModifySelfException,
- VendorNotFoundException,
- VendorAlreadyExistsException,
- VendorVerificationException,
- AdminOperationException,
- ValidationException,
-)
+from app.exceptions import (AdminOperationException, CannotModifySelfException,
+ UserNotFoundException, UserStatusChangeException,
+ ValidationException, VendorAlreadyExistsException,
+ VendorNotFoundException,
+ VendorVerificationException)
+from models.database.marketplace_import_job import MarketplaceImportJob
+from models.database.user import User
+from models.database.vendor import Role, Vendor, VendorUser
from models.schema.marketplace_import_job import MarketplaceImportJobResponse
from models.schema.vendor import VendorCreate
-from models.database.marketplace_import_job import MarketplaceImportJob
-from models.database.vendor import Vendor, Role, VendorUser
-from models.database.user import User
logger = logging.getLogger(__name__)
@@ -52,12 +47,11 @@ class AdminService:
except Exception as e:
logger.error(f"Failed to retrieve users: {str(e)}")
raise AdminOperationException(
- operation="get_all_users",
- reason="Database query failed"
+ operation="get_all_users", reason="Database query failed"
)
def toggle_user_status(
- self, db: Session, user_id: int, current_admin_id: int
+ self, db: Session, user_id: int, current_admin_id: int
) -> Tuple[User, str]:
"""Toggle user active status."""
user = self._get_user_by_id_or_raise(db, user_id)
@@ -72,7 +66,7 @@ class AdminService:
user_id=user_id,
current_status="admin",
attempted_action="toggle status",
- reason="Cannot modify another admin user"
+ reason="Cannot modify another admin user",
)
try:
@@ -95,7 +89,7 @@ class AdminService:
user_id=user_id,
current_status="active" if original_status else "inactive",
attempted_action="toggle status",
- reason="Database update failed"
+ reason="Database update failed",
)
# ============================================================================
@@ -103,7 +97,7 @@ class AdminService:
# ============================================================================
def create_vendor_with_owner(
- self, db: Session, vendor_data: VendorCreate
+ self, db: Session, vendor_data: VendorCreate
) -> Tuple[Vendor, User, str]:
"""
Create vendor with owner user account.
@@ -118,17 +112,23 @@ class AdminService:
"""
try:
# Check if vendor code already exists
- existing_vendor = db.query(Vendor).filter(
- func.upper(Vendor.vendor_code) == vendor_data.vendor_code.upper()
- ).first()
+ existing_vendor = (
+ db.query(Vendor)
+ .filter(
+ func.upper(Vendor.vendor_code) == vendor_data.vendor_code.upper()
+ )
+ .first()
+ )
if existing_vendor:
raise VendorAlreadyExistsException(vendor_data.vendor_code)
# Check if subdomain already exists
- existing_subdomain = db.query(Vendor).filter(
- func.lower(Vendor.subdomain) == vendor_data.subdomain.lower()
- ).first()
+ existing_subdomain = (
+ db.query(Vendor)
+ .filter(func.lower(Vendor.subdomain) == vendor_data.subdomain.lower())
+ .first()
+ )
if existing_subdomain:
raise ValidationException(
@@ -140,15 +140,14 @@ class AdminService:
# Create owner user with owner_email
from middleware.auth import AuthManager
+
auth_manager = AuthManager()
owner_username = f"{vendor_data.subdomain}_owner"
owner_email = vendor_data.owner_email # ✅ For User authentication
# Check if user with this email already exists
- existing_user = db.query(User).filter(
- User.email == owner_email
- ).first()
+ existing_user = db.query(User).filter(User.email == owner_email).first()
if existing_user:
# Use existing user as owner
@@ -215,17 +214,17 @@ class AdminService:
logger.error(f"Failed to create vendor: {str(e)}")
raise AdminOperationException(
operation="create_vendor_with_owner",
- reason=f"Failed to create vendor: {str(e)}"
+ reason=f"Failed to create vendor: {str(e)}",
)
def get_all_vendors(
- self,
- db: Session,
- skip: int = 0,
- limit: int = 100,
- search: Optional[str] = None,
- is_active: Optional[bool] = None,
- is_verified: Optional[bool] = None
+ self,
+ db: Session,
+ skip: int = 0,
+ limit: int = 100,
+ search: Optional[str] = None,
+ is_active: Optional[bool] = None,
+ is_verified: Optional[bool] = None,
) -> Tuple[List[Vendor], int]:
"""Get paginated list of all vendors with filtering."""
try:
@@ -238,7 +237,7 @@ class AdminService:
or_(
Vendor.name.ilike(search_term),
Vendor.vendor_code.ilike(search_term),
- Vendor.subdomain.ilike(search_term)
+ Vendor.subdomain.ilike(search_term),
)
)
@@ -255,8 +254,7 @@ class AdminService:
except Exception as e:
logger.error(f"Failed to retrieve vendors: {str(e)}")
raise AdminOperationException(
- operation="get_all_vendors",
- reason="Database query failed"
+ operation="get_all_vendors", reason="Database query failed"
)
def get_vendor_by_id(self, db: Session, vendor_id: int) -> Vendor:
@@ -290,7 +288,7 @@ class AdminService:
raise VendorVerificationException(
vendor_id=vendor_id,
reason="Database update failed",
- current_verification_status=original_status
+ current_verification_status=original_status,
)
def toggle_vendor_status(self, db: Session, vendor_id: int) -> Tuple[Vendor, str]:
@@ -317,7 +315,7 @@ class AdminService:
operation="toggle_vendor_status",
reason="Database update failed",
target_type="vendor",
- target_id=str(vendor_id)
+ target_id=str(vendor_id),
)
def delete_vendor(self, db: Session, vendor_id: int) -> str:
@@ -345,15 +343,11 @@ class AdminService:
db.rollback()
logger.error(f"Failed to delete vendor {vendor_id}: {str(e)}")
raise AdminOperationException(
- operation="delete_vendor",
- reason="Database deletion failed"
+ operation="delete_vendor", reason="Database deletion failed"
)
def update_vendor(
- self,
- db: Session,
- vendor_id: int,
- vendor_update # VendorUpdate schema
+ self, db: Session, vendor_id: int, vendor_update # VendorUpdate schema
) -> Vendor:
"""
Update vendor information (Admin only).
@@ -387,11 +381,18 @@ class AdminService:
update_data = vendor_update.model_dump(exclude_unset=True)
# Check subdomain uniqueness if changing
- if 'subdomain' in update_data and update_data['subdomain'] != vendor.subdomain:
- existing = db.query(Vendor).filter(
- Vendor.subdomain == update_data['subdomain'],
- Vendor.id != vendor_id
- ).first()
+ if (
+ "subdomain" in update_data
+ and update_data["subdomain"] != vendor.subdomain
+ ):
+ existing = (
+ db.query(Vendor)
+ .filter(
+ Vendor.subdomain == update_data["subdomain"],
+ Vendor.id != vendor_id,
+ )
+ .first()
+ )
if existing:
raise ValidationException(
f"Subdomain '{update_data['subdomain']}' is already taken"
@@ -419,17 +420,16 @@ class AdminService:
db.rollback()
logger.error(f"Failed to update vendor {vendor_id}: {str(e)}")
raise AdminOperationException(
- operation="update_vendor",
- reason=f"Database update failed: {str(e)}"
+ operation="update_vendor", reason=f"Database update failed: {str(e)}"
)
# Add this NEW method for transferring ownership:
def transfer_vendor_ownership(
- self,
- db: Session,
- vendor_id: int,
- transfer_data # VendorTransferOwnership schema
+ self,
+ db: Session,
+ vendor_id: int,
+ transfer_data, # VendorTransferOwnership schema
) -> Tuple[Vendor, User, User]:
"""
Transfer vendor ownership to another user.
@@ -466,9 +466,9 @@ class AdminService:
old_owner = vendor.owner
# Get new owner
- new_owner = db.query(User).filter(
- User.id == transfer_data.new_owner_user_id
- ).first()
+ new_owner = (
+ db.query(User).filter(User.id == transfer_data.new_owner_user_id).first()
+ )
if not new_owner:
raise UserNotFoundException(str(transfer_data.new_owner_user_id))
@@ -487,26 +487,32 @@ class AdminService:
try:
# Get Owner role for this vendor
- owner_role = db.query(Role).filter(
- Role.vendor_id == vendor_id,
- Role.name == "Owner"
- ).first()
+ owner_role = (
+ db.query(Role)
+ .filter(Role.vendor_id == vendor_id, Role.name == "Owner")
+ .first()
+ )
if not owner_role:
raise ValidationException("Owner role not found for vendor")
# Get Manager role (to demote old owner)
- manager_role = db.query(Role).filter(
- Role.vendor_id == vendor_id,
- Role.name == "Manager"
- ).first()
+ manager_role = (
+ db.query(Role)
+ .filter(Role.vendor_id == vendor_id, Role.name == "Manager")
+ .first()
+ )
# Remove old owner from Owner role
- old_owner_link = db.query(VendorUser).filter(
- VendorUser.vendor_id == vendor_id,
- VendorUser.user_id == old_owner.id,
- VendorUser.role_id == owner_role.id
- ).first()
+ old_owner_link = (
+ db.query(VendorUser)
+ .filter(
+ VendorUser.vendor_id == vendor_id,
+ VendorUser.user_id == old_owner.id,
+ VendorUser.role_id == owner_role.id,
+ )
+ .first()
+ )
if old_owner_link:
if manager_role:
@@ -525,10 +531,14 @@ class AdminService:
)
# Check if new owner already has a vendor_user link
- new_owner_link = db.query(VendorUser).filter(
- VendorUser.vendor_id == vendor_id,
- VendorUser.user_id == new_owner.id
- ).first()
+ new_owner_link = (
+ db.query(VendorUser)
+ .filter(
+ VendorUser.vendor_id == vendor_id,
+ VendorUser.user_id == new_owner.id,
+ )
+ .first()
+ )
if new_owner_link:
# Update existing link to Owner role
@@ -540,7 +550,7 @@ class AdminService:
vendor_id=vendor_id,
user_id=new_owner.id,
role_id=owner_role.id,
- is_active=True
+ is_active=True,
)
db.add(new_owner_link)
@@ -568,10 +578,12 @@ class AdminService:
raise
except Exception as e:
db.rollback()
- logger.error(f"Failed to transfer ownership for vendor {vendor_id}: {str(e)}")
+ logger.error(
+ f"Failed to transfer ownership for vendor {vendor_id}: {str(e)}"
+ )
raise AdminOperationException(
operation="transfer_vendor_ownership",
- reason=f"Ownership transfer failed: {str(e)}"
+ reason=f"Ownership transfer failed: {str(e)}",
)
# ============================================================================
@@ -579,13 +591,13 @@ class AdminService:
# ============================================================================
def get_marketplace_import_jobs(
- self,
- db: Session,
- marketplace: Optional[str] = None,
- vendor_name: Optional[str] = None,
- status: Optional[str] = None,
- skip: int = 0,
- limit: int = 100,
+ self,
+ db: Session,
+ marketplace: Optional[str] = None,
+ vendor_name: Optional[str] = None,
+ status: Optional[str] = None,
+ skip: int = 0,
+ limit: int = 100,
) -> List[MarketplaceImportJobResponse]:
"""Get filtered and paginated marketplace import jobs."""
try:
@@ -596,7 +608,9 @@ class AdminService:
MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%")
)
if vendor_name:
- query = query.filter(MarketplaceImportJob.vendor_name.ilike(f"%{vendor_name}%"))
+ query = query.filter(
+ MarketplaceImportJob.vendor_name.ilike(f"%{vendor_name}%")
+ )
if status:
query = query.filter(MarketplaceImportJob.status == status)
@@ -612,8 +626,7 @@ class AdminService:
except Exception as e:
logger.error(f"Failed to retrieve marketplace import jobs: {str(e)}")
raise AdminOperationException(
- operation="get_marketplace_import_jobs",
- reason="Database query failed"
+ operation="get_marketplace_import_jobs", reason="Database query failed"
)
# ============================================================================
@@ -624,10 +637,7 @@ class AdminService:
"""Get recently created vendors."""
try:
vendors = (
- db.query(Vendor)
- .order_by(Vendor.created_at.desc())
- .limit(limit)
- .all()
+ db.query(Vendor).order_by(Vendor.created_at.desc()).limit(limit).all()
)
return [
@@ -638,7 +648,7 @@ class AdminService:
"subdomain": v.subdomain,
"is_active": v.is_active,
"is_verified": v.is_verified,
- "created_at": v.created_at
+ "created_at": v.created_at,
}
for v in vendors
]
@@ -663,7 +673,7 @@ class AdminService:
"vendor_name": j.vendor_name,
"status": j.status,
"total_processed": j.total_processed or 0,
- "created_at": j.created_at
+ "created_at": j.created_at,
}
for j in jobs
]
@@ -692,47 +702,53 @@ class AdminService:
def _generate_temp_password(self, length: int = 12) -> str:
"""Generate secure temporary password."""
alphabet = string.ascii_letters + string.digits + "!@#$%^&*"
- return ''.join(secrets.choice(alphabet) for _ in range(length))
+ return "".join(secrets.choice(alphabet) for _ in range(length))
def _create_default_roles(self, db: Session, vendor_id: int):
"""Create default roles for a new vendor."""
default_roles = [
- {
- "name": "Owner",
- "permissions": ["*"] # Full access
- },
+ {"name": "Owner", "permissions": ["*"]}, # Full access
{
"name": "Manager",
"permissions": [
- "products.*", "orders.*", "customers.view",
- "inventory.*", "team.view"
- ]
+ "products.*",
+ "orders.*",
+ "customers.view",
+ "inventory.*",
+ "team.view",
+ ],
},
{
"name": "Editor",
"permissions": [
- "products.view", "products.edit",
- "orders.view", "inventory.view"
- ]
+ "products.view",
+ "products.edit",
+ "orders.view",
+ "inventory.view",
+ ],
},
{
"name": "Viewer",
"permissions": [
- "products.view", "orders.view",
- "customers.view", "inventory.view"
- ]
- }
+ "products.view",
+ "orders.view",
+ "customers.view",
+ "inventory.view",
+ ],
+ },
]
for role_data in default_roles:
role = Role(
vendor_id=vendor_id,
name=role_data["name"],
- permissions=role_data["permissions"]
+ permissions=role_data["permissions"],
)
db.add(role)
- def _convert_job_to_response(self, job: MarketplaceImportJob) -> MarketplaceImportJobResponse:
+ def _convert_job_to_response(
+ self, job: MarketplaceImportJob
+ ) -> MarketplaceImportJobResponse:
"""Convert database model to response schema."""
return MarketplaceImportJobResponse(
job_id=job.id,
diff --git a/app/services/admin_settings_service.py b/app/services/admin_settings_service.py
index dec031af..fa4565f3 100644
--- a/app/services/admin_settings_service.py
+++ b/app/services/admin_settings_service.py
@@ -8,25 +8,19 @@ This module provides functions for:
- Encrypting sensitive settings
"""
-import logging
import json
-from typing import Optional, List, Any, Dict
+import logging
from datetime import datetime, timezone
+from typing import Any, Dict, List, Optional
-from sqlalchemy.orm import Session
from sqlalchemy import func
+from sqlalchemy.orm import Session
+from app.exceptions import (AdminOperationException, ResourceNotFoundException,
+ ValidationException)
from models.database.admin import AdminSetting
-from models.schema.admin import (
- AdminSettingCreate,
- AdminSettingResponse,
- AdminSettingUpdate
-)
-from app.exceptions import (
- AdminOperationException,
- ValidationException,
- ResourceNotFoundException
-)
+from models.schema.admin import (AdminSettingCreate, AdminSettingResponse,
+ AdminSettingUpdate)
logger = logging.getLogger(__name__)
@@ -34,26 +28,19 @@ logger = logging.getLogger(__name__)
class AdminSettingsService:
"""Service for managing platform-wide settings."""
- def get_setting_by_key(
- self,
- db: Session,
- key: str
- ) -> Optional[AdminSetting]:
+ def get_setting_by_key(self, db: Session, key: str) -> Optional[AdminSetting]:
"""Get setting by key."""
try:
- return db.query(AdminSetting).filter(
- func.lower(AdminSetting.key) == key.lower()
- ).first()
+ return (
+ db.query(AdminSetting)
+ .filter(func.lower(AdminSetting.key) == key.lower())
+ .first()
+ )
except Exception as e:
logger.error(f"Failed to get setting {key}: {str(e)}")
return None
- def get_setting_value(
- self,
- db: Session,
- key: str,
- default: Any = None
- ) -> Any:
+ def get_setting_value(self, db: Session, key: str, default: Any = None) -> Any:
"""
Get setting value with type conversion.
@@ -76,7 +63,7 @@ class AdminSettingsService:
elif setting.value_type == "float":
return float(setting.value)
elif setting.value_type == "boolean":
- return setting.value.lower() in ('true', '1', 'yes')
+ return setting.value.lower() in ("true", "1", "yes")
elif setting.value_type == "json":
return json.loads(setting.value)
else:
@@ -86,10 +73,10 @@ class AdminSettingsService:
return default
def get_all_settings(
- self,
- db: Session,
- category: Optional[str] = None,
- is_public: Optional[bool] = None
+ self,
+ db: Session,
+ category: Optional[str] = None,
+ is_public: Optional[bool] = None,
) -> List[AdminSettingResponse]:
"""Get all settings with optional filtering."""
try:
@@ -104,22 +91,16 @@ class AdminSettingsService:
settings = query.order_by(AdminSetting.category, AdminSetting.key).all()
return [
- AdminSettingResponse.model_validate(setting)
- for setting in settings
+ AdminSettingResponse.model_validate(setting) for setting in settings
]
except Exception as e:
logger.error(f"Failed to get settings: {str(e)}")
raise AdminOperationException(
- operation="get_all_settings",
- reason="Database query failed"
+ operation="get_all_settings", reason="Database query failed"
)
- def get_settings_by_category(
- self,
- db: Session,
- category: str
- ) -> Dict[str, Any]:
+ def get_settings_by_category(self, db: Session, category: str) -> Dict[str, Any]:
"""
Get all settings in a category as a dictionary.
@@ -136,7 +117,7 @@ class AdminSettingsService:
elif setting.value_type == "float":
result[setting.key] = float(setting.value)
elif setting.value_type == "boolean":
- result[setting.key] = setting.value.lower() in ('true', '1', 'yes')
+ result[setting.key] = setting.value.lower() in ("true", "1", "yes")
elif setting.value_type == "json":
result[setting.key] = json.loads(setting.value)
else:
@@ -145,10 +126,7 @@ class AdminSettingsService:
return result
def create_setting(
- self,
- db: Session,
- setting_data: AdminSettingCreate,
- admin_user_id: int
+ self, db: Session, setting_data: AdminSettingCreate, admin_user_id: int
) -> AdminSettingResponse:
"""Create new setting."""
try:
@@ -176,7 +154,7 @@ class AdminSettingsService:
description=setting_data.description,
is_encrypted=setting_data.is_encrypted,
is_public=setting_data.is_public,
- last_modified_by_user_id=admin_user_id
+ last_modified_by_user_id=admin_user_id,
)
db.add(setting)
@@ -194,25 +172,17 @@ class AdminSettingsService:
db.rollback()
logger.error(f"Failed to create setting: {str(e)}")
raise AdminOperationException(
- operation="create_setting",
- reason="Database operation failed"
+ operation="create_setting", reason="Database operation failed"
)
def update_setting(
- self,
- db: Session,
- key: str,
- update_data: AdminSettingUpdate,
- admin_user_id: int
+ self, db: Session, key: str, update_data: AdminSettingUpdate, admin_user_id: int
) -> AdminSettingResponse:
"""Update existing setting."""
setting = self.get_setting_by_key(db, key)
if not setting:
- raise ResourceNotFoundException(
- resource_type="setting",
- identifier=key
- )
+ raise ResourceNotFoundException(resource_type="setting", identifier=key)
try:
# Validate new value
@@ -244,42 +214,29 @@ class AdminSettingsService:
db.rollback()
logger.error(f"Failed to update setting {key}: {str(e)}")
raise AdminOperationException(
- operation="update_setting",
- reason="Database operation failed"
+ operation="update_setting", reason="Database operation failed"
)
def upsert_setting(
- self,
- db: Session,
- setting_data: AdminSettingCreate,
- admin_user_id: int
+ self, db: Session, setting_data: AdminSettingCreate, admin_user_id: int
) -> AdminSettingResponse:
"""Create or update setting (upsert)."""
existing = self.get_setting_by_key(db, setting_data.key)
if existing:
update_data = AdminSettingUpdate(
- value=setting_data.value,
- description=setting_data.description
+ value=setting_data.value, description=setting_data.description
)
return self.update_setting(db, setting_data.key, update_data, admin_user_id)
else:
return self.create_setting(db, setting_data, admin_user_id)
- def delete_setting(
- self,
- db: Session,
- key: str,
- admin_user_id: int
- ) -> str:
+ def delete_setting(self, db: Session, key: str, admin_user_id: int) -> str:
"""Delete setting."""
setting = self.get_setting_by_key(db, key)
if not setting:
- raise ResourceNotFoundException(
- resource_type="setting",
- identifier=key
- )
+ raise ResourceNotFoundException(resource_type="setting", identifier=key)
try:
db.delete(setting)
@@ -293,8 +250,7 @@ class AdminSettingsService:
db.rollback()
logger.error(f"Failed to delete setting {key}: {str(e)}")
raise AdminOperationException(
- operation="delete_setting",
- reason="Database operation failed"
+ operation="delete_setting", reason="Database operation failed"
)
# ============================================================================
@@ -309,7 +265,7 @@ class AdminSettingsService:
elif value_type == "float":
float(value)
elif value_type == "boolean":
- if value.lower() not in ('true', 'false', '1', '0', 'yes', 'no'):
+ if value.lower() not in ("true", "false", "1", "0", "yes", "no"):
raise ValueError("Invalid boolean value")
elif value_type == "json":
json.loads(value)
diff --git a/app/services/audit_service.py b/app/services/audit_service.py
index ea03eb12..51cb387c 100644
--- a/app/services/audit_service.py
+++ b/app/services/audit_service.py
@@ -1 +1 @@
-# Audit logging services
+# Audit logging services
diff --git a/app/services/auth_service.py b/app/services/auth_service.py
index c2c6c400..f05da054 100644
--- a/app/services/auth_service.py
+++ b/app/services/auth_service.py
@@ -13,15 +13,12 @@ from typing import Any, Dict, Optional
from sqlalchemy.orm import Session
-from app.exceptions import (
- UserAlreadyExistsException,
- InvalidCredentialsException,
- UserNotActiveException,
- ValidationException,
-)
+from app.exceptions import (InvalidCredentialsException,
+ UserAlreadyExistsException, UserNotActiveException,
+ ValidationException)
from middleware.auth import AuthManager
-from models.schema.auth import UserLogin, UserRegister
from models.database.user import User
+from models.schema.auth import UserLogin, UserRegister
logger = logging.getLogger(__name__)
@@ -51,11 +48,15 @@ class AuthService:
try:
# Check if email already exists
if self._email_exists(db, user_data.email):
- raise UserAlreadyExistsException("Email already registered", field="email")
+ raise UserAlreadyExistsException(
+ "Email already registered", field="email"
+ )
# Check if username already exists
if self._username_exists(db, user_data.username):
- raise UserAlreadyExistsException("Username already taken", field="username")
+ raise UserAlreadyExistsException(
+ "Username already taken", field="username"
+ )
# Hash password and create user
hashed_password = self.auth_manager.hash_password(user_data.password)
@@ -182,7 +183,9 @@ class AuthService:
Dictionary with access_token, token_type, and expires_in
"""
from datetime import datetime, timedelta, timezone
+
from jose import jwt
+
from app.core.config import settings
try:
@@ -217,6 +220,5 @@ class AuthService:
return db.query(User).filter(User.username == username).first() is not None
-
# Create service instance following the same pattern as other services
auth_service = AuthService()
diff --git a/app/services/backup_service.py b/app/services/backup_service.py
index b4d8b0d4..0759bea8 100644
--- a/app/services/backup_service.py
+++ b/app/services/backup_service.py
@@ -1 +1 @@
-# Backup and recovery services
+# Backup and recovery services
diff --git a/app/services/cache_service.py b/app/services/cache_service.py
index b0c83d5f..63e64747 100644
--- a/app/services/cache_service.py
+++ b/app/services/cache_service.py
@@ -1 +1 @@
-# Caching services
+# Caching services
diff --git a/app/services/cart_service.py b/app/services/cart_service.py
index d566b244..b3c8ddd1 100644
--- a/app/services/cart_service.py
+++ b/app/services/cart_service.py
@@ -9,23 +9,20 @@ This module provides:
"""
import logging
-from typing import Dict, List, Optional
from datetime import datetime, timezone
+from typing import Dict, List, Optional
-from sqlalchemy.orm import Session
from sqlalchemy import and_
+from sqlalchemy.orm import Session
+from app.exceptions import (CartItemNotFoundException, CartValidationException,
+ InsufficientInventoryForCartException,
+ InvalidCartQuantityException,
+ ProductNotAvailableForCartException,
+ ProductNotFoundException)
+from models.database.cart import CartItem
from models.database.product import Product
from models.database.vendor import Vendor
-from models.database.cart import CartItem
-from app.exceptions import (
- ProductNotFoundException,
- CartItemNotFoundException,
- CartValidationException,
- InsufficientInventoryForCartException,
- InvalidCartQuantityException,
- ProductNotAvailableForCartException,
-)
logger = logging.getLogger(__name__)
@@ -33,12 +30,7 @@ logger = logging.getLogger(__name__)
class CartService:
"""Service for managing shopping carts."""
- def get_cart(
- self,
- db: Session,
- vendor_id: int,
- session_id: str
- ) -> Dict:
+ def get_cart(self, db: Session, vendor_id: int, session_id: str) -> Dict:
"""
Get cart contents for a session.
@@ -55,20 +47,21 @@ class CartService:
extra={
"vendor_id": vendor_id,
"session_id": session_id,
- }
+ },
)
# Fetch cart items from database
- cart_items = db.query(CartItem).filter(
- and_(
- CartItem.vendor_id == vendor_id,
- CartItem.session_id == session_id
+ cart_items = (
+ db.query(CartItem)
+ .filter(
+ and_(CartItem.vendor_id == vendor_id, CartItem.session_id == session_id)
)
- ).all()
+ .all()
+ )
logger.info(
f"[CART_SERVICE] Found {len(cart_items)} items in database",
- extra={"item_count": len(cart_items)}
+ extra={"item_count": len(cart_items)},
)
# Build response
@@ -79,14 +72,20 @@ class CartService:
product = cart_item.product
line_total = cart_item.line_total
- items.append({
- "product_id": product.id,
- "product_name": product.marketplace_product.title,
- "quantity": cart_item.quantity,
- "price": cart_item.price_at_add,
- "line_total": line_total,
- "image_url": product.marketplace_product.image_link if product.marketplace_product else None,
- })
+ items.append(
+ {
+ "product_id": product.id,
+ "product_name": product.marketplace_product.title,
+ "quantity": cart_item.quantity,
+ "price": cart_item.price_at_add,
+ "line_total": line_total,
+ "image_url": (
+ product.marketplace_product.image_link
+ if product.marketplace_product
+ else None
+ ),
+ }
+ )
subtotal += line_total
@@ -95,23 +94,23 @@ class CartService:
"session_id": session_id,
"items": items,
"subtotal": subtotal,
- "total": subtotal # Could add tax/shipping later
+ "total": subtotal, # Could add tax/shipping later
}
logger.info(
f"[CART_SERVICE] get_cart returning: {len(cart_data['items'])} items, total: {cart_data['total']}",
- extra={"cart": cart_data}
+ extra={"cart": cart_data},
)
return cart_data
def add_to_cart(
- self,
- db: Session,
- vendor_id: int,
- session_id: str,
- product_id: int,
- quantity: int = 1
+ self,
+ db: Session,
+ vendor_id: int,
+ session_id: str,
+ product_id: int,
+ quantity: int = 1,
) -> Dict:
"""
Add product to cart.
@@ -136,23 +135,27 @@ class CartService:
"vendor_id": vendor_id,
"session_id": session_id,
"product_id": product_id,
- "quantity": quantity
- }
+ "quantity": quantity,
+ },
)
# Verify product exists and belongs to vendor
- product = db.query(Product).filter(
- and_(
- Product.id == product_id,
- Product.vendor_id == vendor_id,
- Product.is_active == True
+ product = (
+ db.query(Product)
+ .filter(
+ and_(
+ Product.id == product_id,
+ Product.vendor_id == vendor_id,
+ Product.is_active == True,
+ )
)
- ).first()
+ .first()
+ )
if not product:
logger.error(
f"[CART_SERVICE] Product not found",
- extra={"product_id": product_id, "vendor_id": vendor_id}
+ extra={"product_id": product_id, "vendor_id": vendor_id},
)
raise ProductNotFoundException(product_id=product_id, vendor_id=vendor_id)
@@ -161,21 +164,25 @@ class CartService:
extra={
"product_id": product_id,
"product_name": product.marketplace_product.title,
- "available_inventory": product.available_inventory
- }
+ "available_inventory": product.available_inventory,
+ },
)
# Get current price (use sale_price if available, otherwise regular price)
current_price = product.sale_price if product.sale_price else product.price
# Check if item already exists in cart
- existing_item = db.query(CartItem).filter(
- and_(
- CartItem.vendor_id == vendor_id,
- CartItem.session_id == session_id,
- CartItem.product_id == product_id
+ existing_item = (
+ db.query(CartItem)
+ .filter(
+ and_(
+ CartItem.vendor_id == vendor_id,
+ CartItem.session_id == session_id,
+ CartItem.product_id == product_id,
+ )
)
- ).first()
+ .first()
+ )
if existing_item:
# Update quantity
@@ -190,14 +197,14 @@ class CartService:
"current_in_cart": existing_item.quantity,
"adding": quantity,
"requested_total": new_quantity,
- "available": product.available_inventory
- }
+ "available": product.available_inventory,
+ },
)
raise InsufficientInventoryForCartException(
product_id=product_id,
product_name=product.marketplace_product.title,
requested=new_quantity,
- available=product.available_inventory
+ available=product.available_inventory,
)
existing_item.quantity = new_quantity
@@ -206,16 +213,13 @@ class CartService:
logger.info(
f"[CART_SERVICE] Updated existing cart item",
- extra={
- "cart_item_id": existing_item.id,
- "new_quantity": new_quantity
- }
+ extra={"cart_item_id": existing_item.id, "new_quantity": new_quantity},
)
return {
"message": "Product quantity updated in cart",
"product_id": product_id,
- "quantity": new_quantity
+ "quantity": new_quantity,
}
else:
# Check inventory for new item
@@ -225,14 +229,14 @@ class CartService:
extra={
"product_id": product_id,
"requested": quantity,
- "available": product.available_inventory
- }
+ "available": product.available_inventory,
+ },
)
raise InsufficientInventoryForCartException(
product_id=product_id,
product_name=product.marketplace_product.title,
requested=quantity,
- available=product.available_inventory
+ available=product.available_inventory,
)
# Create new cart item
@@ -241,7 +245,7 @@ class CartService:
session_id=session_id,
product_id=product_id,
quantity=quantity,
- price_at_add=current_price
+ price_at_add=current_price,
)
db.add(cart_item)
db.commit()
@@ -252,23 +256,23 @@ class CartService:
extra={
"cart_item_id": cart_item.id,
"quantity": quantity,
- "price": current_price
- }
+ "price": current_price,
+ },
)
return {
"message": "Product added to cart",
"product_id": product_id,
- "quantity": quantity
+ "quantity": quantity,
}
def update_cart_item(
- self,
- db: Session,
- vendor_id: int,
- session_id: str,
- product_id: int,
- quantity: int
+ self,
+ db: Session,
+ vendor_id: int,
+ session_id: str,
+ product_id: int,
+ quantity: int,
) -> Dict:
"""
Update quantity of item in cart.
@@ -292,25 +296,35 @@ class CartService:
raise InvalidCartQuantityException(quantity=quantity, min_quantity=1)
# Find cart item
- cart_item = db.query(CartItem).filter(
- and_(
- CartItem.vendor_id == vendor_id,
- CartItem.session_id == session_id,
- CartItem.product_id == product_id
+ cart_item = (
+ db.query(CartItem)
+ .filter(
+ and_(
+ CartItem.vendor_id == vendor_id,
+ CartItem.session_id == session_id,
+ CartItem.product_id == product_id,
+ )
)
- ).first()
+ .first()
+ )
if not cart_item:
- raise CartItemNotFoundException(product_id=product_id, session_id=session_id)
+ raise CartItemNotFoundException(
+ product_id=product_id, session_id=session_id
+ )
# Verify product still exists and is active
- product = db.query(Product).filter(
- and_(
- Product.id == product_id,
- Product.vendor_id == vendor_id,
- Product.is_active == True
+ product = (
+ db.query(Product)
+ .filter(
+ and_(
+ Product.id == product_id,
+ Product.vendor_id == vendor_id,
+ Product.is_active == True,
+ )
)
- ).first()
+ .first()
+ )
if not product:
raise ProductNotFoundException(str(product_id))
@@ -321,7 +335,7 @@ class CartService:
product_id=product_id,
product_name=product.marketplace_product.title,
requested=quantity,
- available=product.available_inventory
+ available=product.available_inventory,
)
# Update quantity
@@ -334,22 +348,18 @@ class CartService:
extra={
"cart_item_id": cart_item.id,
"product_id": product_id,
- "new_quantity": quantity
- }
+ "new_quantity": quantity,
+ },
)
return {
"message": "Cart updated",
"product_id": product_id,
- "quantity": quantity
+ "quantity": quantity,
}
def remove_from_cart(
- self,
- db: Session,
- vendor_id: int,
- session_id: str,
- product_id: int
+ self, db: Session, vendor_id: int, session_id: str, product_id: int
) -> Dict:
"""
Remove item from cart.
@@ -367,16 +377,22 @@ class CartService:
ProductNotFoundException: If product not in cart
"""
# Find and delete cart item
- cart_item = db.query(CartItem).filter(
- and_(
- CartItem.vendor_id == vendor_id,
- CartItem.session_id == session_id,
- CartItem.product_id == product_id
+ cart_item = (
+ db.query(CartItem)
+ .filter(
+ and_(
+ CartItem.vendor_id == vendor_id,
+ CartItem.session_id == session_id,
+ CartItem.product_id == product_id,
+ )
)
- ).first()
+ .first()
+ )
if not cart_item:
- raise CartItemNotFoundException(product_id=product_id, session_id=session_id)
+ raise CartItemNotFoundException(
+ product_id=product_id, session_id=session_id
+ )
db.delete(cart_item)
db.commit()
@@ -386,21 +402,13 @@ class CartService:
extra={
"cart_item_id": cart_item.id,
"product_id": product_id,
- "session_id": session_id
- }
+ "session_id": session_id,
+ },
)
- return {
- "message": "Item removed from cart",
- "product_id": product_id
- }
+ return {"message": "Item removed from cart", "product_id": product_id}
- def clear_cart(
- self,
- db: Session,
- vendor_id: int,
- session_id: str
- ) -> Dict:
+ def clear_cart(self, db: Session, vendor_id: int, session_id: str) -> Dict:
"""
Clear all items from cart.
@@ -413,12 +421,13 @@ class CartService:
Success message with count of items removed
"""
# Delete all cart items for this session
- deleted_count = db.query(CartItem).filter(
- and_(
- CartItem.vendor_id == vendor_id,
- CartItem.session_id == session_id
+ deleted_count = (
+ db.query(CartItem)
+ .filter(
+ and_(CartItem.vendor_id == vendor_id, CartItem.session_id == session_id)
)
- ).delete()
+ .delete()
+ )
db.commit()
@@ -427,14 +436,11 @@ class CartService:
extra={
"session_id": session_id,
"vendor_id": vendor_id,
- "items_removed": deleted_count
- }
+ "items_removed": deleted_count,
+ },
)
- return {
- "message": "Cart cleared",
- "items_removed": deleted_count
- }
+ return {"message": "Cart cleared", "items_removed": deleted_count}
# Create service instance
diff --git a/app/services/code_quality_service.py b/app/services/code_quality_service.py
index d129ff2b..ff058e78 100644
--- a/app/services/code_quality_service.py
+++ b/app/services/code_quality_service.py
@@ -3,22 +3,20 @@ Code Quality Service
Business logic for managing architecture scans and violations
"""
-import subprocess
import json
import logging
+import subprocess
from datetime import datetime
-from typing import List, Tuple, Optional, Dict
from pathlib import Path
-from sqlalchemy.orm import Session
-from sqlalchemy import func, desc
+from typing import Dict, List, Optional, Tuple
-from app.models.architecture_scan import (
- ArchitectureScan,
- ArchitectureViolation,
- ArchitectureRule,
- ViolationAssignment,
- ViolationComment
-)
+from sqlalchemy import desc, func
+from sqlalchemy.orm import Session
+
+from app.models.architecture_scan import (ArchitectureRule, ArchitectureScan,
+ ArchitectureViolation,
+ ViolationAssignment,
+ ViolationComment)
logger = logging.getLogger(__name__)
@@ -26,7 +24,7 @@ logger = logging.getLogger(__name__)
class CodeQualityService:
"""Service for managing code quality scans and violations"""
- def run_scan(self, db: Session, triggered_by: str = 'manual') -> ArchitectureScan:
+ def run_scan(self, db: Session, triggered_by: str = "manual") -> ArchitectureScan:
"""
Run architecture validator and store results in database
@@ -49,10 +47,10 @@ class CodeQualityService:
start_time = datetime.now()
try:
result = subprocess.run(
- ['python', 'scripts/validate_architecture.py', '--json'],
+ ["python", "scripts/validate_architecture.py", "--json"],
capture_output=True,
text=True,
- timeout=300 # 5 minute timeout
+ timeout=300, # 5 minute timeout
)
except subprocess.TimeoutExpired:
logger.error("Architecture scan timed out after 5 minutes")
@@ -63,17 +61,17 @@ class CodeQualityService:
# Parse JSON output (get only the JSON part, skip progress messages)
try:
# Find the JSON part in stdout
- lines = result.stdout.strip().split('\n')
+ lines = result.stdout.strip().split("\n")
json_start = -1
for i, line in enumerate(lines):
- if line.strip().startswith('{'):
+ if line.strip().startswith("{"):
json_start = i
break
if json_start == -1:
raise ValueError("No JSON output found")
- json_output = '\n'.join(lines[json_start:])
+ json_output = "\n".join(lines[json_start:])
data = json.loads(json_output)
except (json.JSONDecodeError, ValueError) as e:
logger.error(f"Failed to parse validator output: {e}")
@@ -84,33 +82,33 @@ class CodeQualityService:
# Create scan record
scan = ArchitectureScan(
timestamp=datetime.now(),
- total_files=data.get('files_checked', 0),
- total_violations=data.get('total_violations', 0),
- errors=data.get('errors', 0),
- warnings=data.get('warnings', 0),
+ total_files=data.get("files_checked", 0),
+ total_violations=data.get("total_violations", 0),
+ errors=data.get("errors", 0),
+ warnings=data.get("warnings", 0),
duration_seconds=duration,
triggered_by=triggered_by,
- git_commit_hash=git_commit
+ git_commit_hash=git_commit,
)
db.add(scan)
db.flush() # Get scan.id
# Create violation records
- violations_data = data.get('violations', [])
+ violations_data = data.get("violations", [])
logger.info(f"Creating {len(violations_data)} violation records")
for v in violations_data:
violation = ArchitectureViolation(
scan_id=scan.id,
- rule_id=v['rule_id'],
- rule_name=v['rule_name'],
- severity=v['severity'],
- file_path=v['file_path'],
- line_number=v['line_number'],
- message=v['message'],
- context=v.get('context', ''),
- suggestion=v.get('suggestion', ''),
- status='open'
+ rule_id=v["rule_id"],
+ rule_name=v["rule_name"],
+ severity=v["severity"],
+ file_path=v["file_path"],
+ line_number=v["line_number"],
+ message=v["message"],
+ context=v.get("context", ""),
+ suggestion=v.get("suggestion", ""),
+ status="open",
)
db.add(violation)
@@ -122,7 +120,11 @@ class CodeQualityService:
def get_latest_scan(self, db: Session) -> Optional[ArchitectureScan]:
"""Get the most recent scan"""
- return db.query(ArchitectureScan).order_by(desc(ArchitectureScan.timestamp)).first()
+ return (
+ db.query(ArchitectureScan)
+ .order_by(desc(ArchitectureScan.timestamp))
+ .first()
+ )
def get_scan_by_id(self, db: Session, scan_id: int) -> Optional[ArchitectureScan]:
"""Get scan by ID"""
@@ -139,10 +141,12 @@ class CodeQualityService:
Returns:
List of ArchitectureScan objects, newest first
"""
- return db.query(ArchitectureScan)\
- .order_by(desc(ArchitectureScan.timestamp))\
- .limit(limit)\
+ return (
+ db.query(ArchitectureScan)
+ .order_by(desc(ArchitectureScan.timestamp))
+ .limit(limit)
.all()
+ )
def get_violations(
self,
@@ -153,7 +157,7 @@ class CodeQualityService:
rule_id: str = None,
file_path: str = None,
limit: int = 100,
- offset: int = 0
+ offset: int = 0,
) -> Tuple[List[ArchitectureViolation], int]:
"""
Get violations with filtering and pagination
@@ -194,24 +198,32 @@ class CodeQualityService:
query = query.filter(ArchitectureViolation.rule_id == rule_id)
if file_path:
- query = query.filter(ArchitectureViolation.file_path.like(f'%{file_path}%'))
+ query = query.filter(ArchitectureViolation.file_path.like(f"%{file_path}%"))
# Get total count
total = query.count()
# Get page of results
- violations = query.order_by(
- ArchitectureViolation.severity.desc(),
- ArchitectureViolation.file_path
- ).limit(limit).offset(offset).all()
+ violations = (
+ query.order_by(
+ ArchitectureViolation.severity.desc(), ArchitectureViolation.file_path
+ )
+ .limit(limit)
+ .offset(offset)
+ .all()
+ )
return violations, total
- def get_violation_by_id(self, db: Session, violation_id: int) -> Optional[ArchitectureViolation]:
+ def get_violation_by_id(
+ self, db: Session, violation_id: int
+ ) -> Optional[ArchitectureViolation]:
"""Get single violation with details"""
- return db.query(ArchitectureViolation).filter(
- ArchitectureViolation.id == violation_id
- ).first()
+ return (
+ db.query(ArchitectureViolation)
+ .filter(ArchitectureViolation.id == violation_id)
+ .first()
+ )
def assign_violation(
self,
@@ -220,7 +232,7 @@ class CodeQualityService:
user_id: int,
assigned_by: int,
due_date: datetime = None,
- priority: str = 'medium'
+ priority: str = "medium",
) -> ViolationAssignment:
"""
Assign violation to a developer
@@ -239,7 +251,7 @@ class CodeQualityService:
# Update violation status
violation = self.get_violation_by_id(db, violation_id)
if violation:
- violation.status = 'assigned'
+ violation.status = "assigned"
violation.assigned_to = user_id
# Create assignment record
@@ -248,7 +260,7 @@ class CodeQualityService:
user_id=user_id,
assigned_by=assigned_by,
due_date=due_date,
- priority=priority
+ priority=priority,
)
db.add(assignment)
db.commit()
@@ -257,11 +269,7 @@ class CodeQualityService:
return assignment
def resolve_violation(
- self,
- db: Session,
- violation_id: int,
- resolved_by: int,
- resolution_note: str
+ self, db: Session, violation_id: int, resolved_by: int, resolution_note: str
) -> ArchitectureViolation:
"""
Mark violation as resolved
@@ -279,7 +287,7 @@ class CodeQualityService:
if not violation:
raise ValueError(f"Violation {violation_id} not found")
- violation.status = 'resolved'
+ violation.status = "resolved"
violation.resolved_at = datetime.now()
violation.resolved_by = resolved_by
violation.resolution_note = resolution_note
@@ -289,11 +297,7 @@ class CodeQualityService:
return violation
def ignore_violation(
- self,
- db: Session,
- violation_id: int,
- ignored_by: int,
- reason: str
+ self, db: Session, violation_id: int, ignored_by: int, reason: str
) -> ArchitectureViolation:
"""
Mark violation as ignored/won't fix
@@ -311,7 +315,7 @@ class CodeQualityService:
if not violation:
raise ValueError(f"Violation {violation_id} not found")
- violation.status = 'ignored'
+ violation.status = "ignored"
violation.resolved_at = datetime.now()
violation.resolved_by = ignored_by
violation.resolution_note = f"Ignored: {reason}"
@@ -321,11 +325,7 @@ class CodeQualityService:
return violation
def add_comment(
- self,
- db: Session,
- violation_id: int,
- user_id: int,
- comment: str
+ self, db: Session, violation_id: int, user_id: int, comment: str
) -> ViolationComment:
"""
Add comment to violation
@@ -340,9 +340,7 @@ class CodeQualityService:
ViolationComment object
"""
comment_obj = ViolationComment(
- violation_id=violation_id,
- user_id=user_id,
- comment=comment
+ violation_id=violation_id, user_id=user_id, comment=comment
)
db.add(comment_obj)
db.commit()
@@ -360,79 +358,95 @@ class CodeQualityService:
latest_scan = self.get_latest_scan(db)
if not latest_scan:
return {
- 'total_violations': 0,
- 'errors': 0,
- 'warnings': 0,
- 'open': 0,
- 'assigned': 0,
- 'resolved': 0,
- 'ignored': 0,
- 'technical_debt_score': 100,
- 'trend': [],
- 'by_severity': {},
- 'by_rule': {},
- 'by_module': {},
- 'top_files': []
+ "total_violations": 0,
+ "errors": 0,
+ "warnings": 0,
+ "open": 0,
+ "assigned": 0,
+ "resolved": 0,
+ "ignored": 0,
+ "technical_debt_score": 100,
+ "trend": [],
+ "by_severity": {},
+ "by_rule": {},
+ "by_module": {},
+ "top_files": [],
}
# Get violation counts by status
- status_counts = db.query(
- ArchitectureViolation.status,
- func.count(ArchitectureViolation.id)
- ).filter(
- ArchitectureViolation.scan_id == latest_scan.id
- ).group_by(ArchitectureViolation.status).all()
+ status_counts = (
+ db.query(ArchitectureViolation.status, func.count(ArchitectureViolation.id))
+ .filter(ArchitectureViolation.scan_id == latest_scan.id)
+ .group_by(ArchitectureViolation.status)
+ .all()
+ )
status_dict = {status: count for status, count in status_counts}
# Get violations by severity
- severity_counts = db.query(
- ArchitectureViolation.severity,
- func.count(ArchitectureViolation.id)
- ).filter(
- ArchitectureViolation.scan_id == latest_scan.id
- ).group_by(ArchitectureViolation.severity).all()
+ severity_counts = (
+ db.query(
+ ArchitectureViolation.severity, func.count(ArchitectureViolation.id)
+ )
+ .filter(ArchitectureViolation.scan_id == latest_scan.id)
+ .group_by(ArchitectureViolation.severity)
+ .all()
+ )
by_severity = {sev: count for sev, count in severity_counts}
# Get violations by rule
- rule_counts = db.query(
- ArchitectureViolation.rule_id,
- func.count(ArchitectureViolation.id)
- ).filter(
- ArchitectureViolation.scan_id == latest_scan.id
- ).group_by(ArchitectureViolation.rule_id).all()
+ rule_counts = (
+ db.query(
+ ArchitectureViolation.rule_id, func.count(ArchitectureViolation.id)
+ )
+ .filter(ArchitectureViolation.scan_id == latest_scan.id)
+ .group_by(ArchitectureViolation.rule_id)
+ .all()
+ )
- by_rule = {rule: count for rule, count in sorted(rule_counts, key=lambda x: x[1], reverse=True)[:10]}
+ by_rule = {
+ rule: count
+ for rule, count in sorted(rule_counts, key=lambda x: x[1], reverse=True)[
+ :10
+ ]
+ }
# Get top violating files
- file_counts = db.query(
- ArchitectureViolation.file_path,
- func.count(ArchitectureViolation.id).label('count')
- ).filter(
- ArchitectureViolation.scan_id == latest_scan.id
- ).group_by(ArchitectureViolation.file_path)\
- .order_by(desc('count'))\
- .limit(10).all()
+ file_counts = (
+ db.query(
+ ArchitectureViolation.file_path,
+ func.count(ArchitectureViolation.id).label("count"),
+ )
+ .filter(ArchitectureViolation.scan_id == latest_scan.id)
+ .group_by(ArchitectureViolation.file_path)
+ .order_by(desc("count"))
+ .limit(10)
+ .all()
+ )
- top_files = [{'file': file, 'count': count} for file, count in file_counts]
+ top_files = [{"file": file, "count": count} for file, count in file_counts]
# Get violations by module (extract module from file path)
by_module = {}
- violations = db.query(ArchitectureViolation.file_path).filter(
- ArchitectureViolation.scan_id == latest_scan.id
- ).all()
+ violations = (
+ db.query(ArchitectureViolation.file_path)
+ .filter(ArchitectureViolation.scan_id == latest_scan.id)
+ .all()
+ )
for v in violations:
- path_parts = v.file_path.split('/')
+ path_parts = v.file_path.split("/")
if len(path_parts) >= 2:
- module = '/'.join(path_parts[:2]) # e.g., 'app/api'
+ module = "/".join(path_parts[:2]) # e.g., 'app/api'
else:
module = path_parts[0]
by_module[module] = by_module.get(module, 0) + 1
# Sort by count and take top 10
- by_module = dict(sorted(by_module.items(), key=lambda x: x[1], reverse=True)[:10])
+ by_module = dict(
+ sorted(by_module.items(), key=lambda x: x[1], reverse=True)[:10]
+ )
# Calculate technical debt score
tech_debt_score = self.calculate_technical_debt_score(db, latest_scan.id)
@@ -441,29 +455,29 @@ class CodeQualityService:
trend_scans = self.get_scan_history(db, limit=7)
trend = [
{
- 'timestamp': scan.timestamp.isoformat(),
- 'violations': scan.total_violations,
- 'errors': scan.errors,
- 'warnings': scan.warnings
+ "timestamp": scan.timestamp.isoformat(),
+ "violations": scan.total_violations,
+ "errors": scan.errors,
+ "warnings": scan.warnings,
}
for scan in reversed(trend_scans) # Oldest first for chart
]
return {
- 'total_violations': latest_scan.total_violations,
- 'errors': latest_scan.errors,
- 'warnings': latest_scan.warnings,
- 'open': status_dict.get('open', 0),
- 'assigned': status_dict.get('assigned', 0),
- 'resolved': status_dict.get('resolved', 0),
- 'ignored': status_dict.get('ignored', 0),
- 'technical_debt_score': tech_debt_score,
- 'trend': trend,
- 'by_severity': by_severity,
- 'by_rule': by_rule,
- 'by_module': by_module,
- 'top_files': top_files,
- 'last_scan': latest_scan.timestamp.isoformat() if latest_scan else None
+ "total_violations": latest_scan.total_violations,
+ "errors": latest_scan.errors,
+ "warnings": latest_scan.warnings,
+ "open": status_dict.get("open", 0),
+ "assigned": status_dict.get("assigned", 0),
+ "resolved": status_dict.get("resolved", 0),
+ "ignored": status_dict.get("ignored", 0),
+ "technical_debt_score": tech_debt_score,
+ "trend": trend,
+ "by_severity": by_severity,
+ "by_rule": by_rule,
+ "by_module": by_module,
+ "top_files": top_files,
+ "last_scan": latest_scan.timestamp.isoformat() if latest_scan else None,
}
def calculate_technical_debt_score(self, db: Session, scan_id: int = None) -> int:
@@ -497,10 +511,7 @@ class CodeQualityService:
"""Get current git commit hash"""
try:
result = subprocess.run(
- ['git', 'rev-parse', 'HEAD'],
- capture_output=True,
- text=True,
- timeout=5
+ ["git", "rev-parse", "HEAD"], capture_output=True, text=True, timeout=5
)
if result.returncode == 0:
return result.stdout.strip()[:40]
diff --git a/app/services/configuration_service.py b/app/services/configuration_service.py
index ff83d3b4..54f872ab 100644
--- a/app/services/configuration_service.py
+++ b/app/services/configuration_service.py
@@ -1 +1 @@
-# Configuration management services
+# Configuration management services
diff --git a/app/services/content_page_service.py b/app/services/content_page_service.py
index beb15fe4..af17bc7d 100644
--- a/app/services/content_page_service.py
+++ b/app/services/content_page_service.py
@@ -19,8 +19,9 @@ This allows:
import logging
from datetime import datetime, timezone
from typing import List, Optional
-from sqlalchemy.orm import Session
+
from sqlalchemy import and_, or_
+from sqlalchemy.orm import Session
from models.database.content_page import ContentPage
@@ -35,7 +36,7 @@ class ContentPageService:
db: Session,
slug: str,
vendor_id: Optional[int] = None,
- include_unpublished: bool = False
+ include_unpublished: bool = False,
) -> Optional[ContentPage]:
"""
Get content page for a vendor with fallback to platform default.
@@ -62,28 +63,20 @@ class ContentPageService:
if vendor_id:
vendor_page = (
db.query(ContentPage)
- .filter(
- and_(
- ContentPage.vendor_id == vendor_id,
- *filters
- )
- )
+ .filter(and_(ContentPage.vendor_id == vendor_id, *filters))
.first()
)
if vendor_page:
- logger.debug(f"Found vendor-specific page: {slug} for vendor_id={vendor_id}")
+ logger.debug(
+ f"Found vendor-specific page: {slug} for vendor_id={vendor_id}"
+ )
return vendor_page
# Fallback to platform default
platform_page = (
db.query(ContentPage)
- .filter(
- and_(
- ContentPage.vendor_id == None,
- *filters
- )
- )
+ .filter(and_(ContentPage.vendor_id == None, *filters))
.first()
)
@@ -100,7 +93,7 @@ class ContentPageService:
vendor_id: Optional[int] = None,
include_unpublished: bool = False,
footer_only: bool = False,
- header_only: bool = False
+ header_only: bool = False,
) -> List[ContentPage]:
"""
List all available pages for a vendor (includes vendor overrides + platform defaults).
@@ -133,12 +126,7 @@ class ContentPageService:
if vendor_id:
vendor_pages = (
db.query(ContentPage)
- .filter(
- and_(
- ContentPage.vendor_id == vendor_id,
- *filters
- )
- )
+ .filter(and_(ContentPage.vendor_id == vendor_id, *filters))
.order_by(ContentPage.display_order, ContentPage.title)
.all()
)
@@ -146,12 +134,7 @@ class ContentPageService:
# Get platform defaults
platform_pages = (
db.query(ContentPage)
- .filter(
- and_(
- ContentPage.vendor_id == None,
- *filters
- )
- )
+ .filter(and_(ContentPage.vendor_id == None, *filters))
.order_by(ContentPage.display_order, ContentPage.title)
.all()
)
@@ -159,8 +142,7 @@ class ContentPageService:
# Merge: vendor overrides take precedence
vendor_slugs = {page.slug for page in vendor_pages}
all_pages = vendor_pages + [
- page for page in platform_pages
- if page.slug not in vendor_slugs
+ page for page in platform_pages if page.slug not in vendor_slugs
]
# Sort by display_order
@@ -183,7 +165,7 @@ class ContentPageService:
show_in_footer: bool = True,
show_in_header: bool = False,
display_order: int = 0,
- created_by: Optional[int] = None
+ created_by: Optional[int] = None,
) -> ContentPage:
"""
Create a new content page.
@@ -229,7 +211,9 @@ class ContentPageService:
db.commit()
db.refresh(page)
- logger.info(f"Created content page: {slug} (vendor_id={vendor_id}, id={page.id})")
+ logger.info(
+ f"Created content page: {slug} (vendor_id={vendor_id}, id={page.id})"
+ )
return page
@staticmethod
@@ -246,7 +230,7 @@ class ContentPageService:
show_in_footer: Optional[bool] = None,
show_in_header: Optional[bool] = None,
display_order: Optional[int] = None,
- updated_by: Optional[int] = None
+ updated_by: Optional[int] = None,
) -> Optional[ContentPage]:
"""
Update an existing content page.
@@ -338,9 +322,7 @@ class ContentPageService:
@staticmethod
def list_all_vendor_pages(
- db: Session,
- vendor_id: int,
- include_unpublished: bool = False
+ db: Session, vendor_id: int, include_unpublished: bool = False
) -> List[ContentPage]:
"""
List only vendor-specific pages (no platform defaults).
@@ -367,8 +349,7 @@ class ContentPageService:
@staticmethod
def list_all_platform_pages(
- db: Session,
- include_unpublished: bool = False
+ db: Session, include_unpublished: bool = False
) -> List[ContentPage]:
"""
List only platform default pages.
diff --git a/app/services/customer_service.py b/app/services/customer_service.py
index 53e70ee6..d76078b6 100644
--- a/app/services/customer_service.py
+++ b/app/services/customer_service.py
@@ -8,24 +8,24 @@ with complete vendor isolation.
import logging
from datetime import datetime, timedelta
-from typing import Optional, Dict, Any
-from sqlalchemy.orm import Session
-from sqlalchemy import and_
+from typing import Any, Dict, Optional
+from sqlalchemy import and_
+from sqlalchemy.orm import Session
+
+from app.exceptions.customer import (CustomerAlreadyExistsException,
+ CustomerNotActiveException,
+ CustomerNotFoundException,
+ CustomerValidationException,
+ DuplicateCustomerEmailException,
+ InvalidCustomerCredentialsException)
+from app.exceptions.vendor import (VendorNotActiveException,
+ VendorNotFoundException)
+from app.services.auth_service import AuthService
from models.database.customer import Customer, CustomerAddress
from models.database.vendor import Vendor
-from models.schema.customer import CustomerRegister, CustomerUpdate
from models.schema.auth import UserLogin
-from app.exceptions.customer import (
- CustomerNotFoundException,
- CustomerAlreadyExistsException,
- CustomerNotActiveException,
- InvalidCustomerCredentialsException,
- CustomerValidationException,
- DuplicateCustomerEmailException
-)
-from app.exceptions.vendor import VendorNotFoundException, VendorNotActiveException
-from app.services.auth_service import AuthService
+from models.schema.customer import CustomerRegister, CustomerUpdate
logger = logging.getLogger(__name__)
@@ -37,10 +37,7 @@ class CustomerService:
self.auth_service = AuthService()
def register_customer(
- self,
- db: Session,
- vendor_id: int,
- customer_data: CustomerRegister
+ self, db: Session, vendor_id: int, customer_data: CustomerRegister
) -> Customer:
"""
Register a new customer for a specific vendor.
@@ -68,18 +65,26 @@ class CustomerService:
raise VendorNotActiveException(vendor.vendor_code)
# Check if email already exists for this vendor
- existing_customer = db.query(Customer).filter(
- and_(
- Customer.vendor_id == vendor_id,
- Customer.email == customer_data.email.lower()
+ existing_customer = (
+ db.query(Customer)
+ .filter(
+ and_(
+ Customer.vendor_id == vendor_id,
+ Customer.email == customer_data.email.lower(),
+ )
)
- ).first()
+ .first()
+ )
if existing_customer:
- raise DuplicateCustomerEmailException(customer_data.email, vendor.vendor_code)
+ raise DuplicateCustomerEmailException(
+ customer_data.email, vendor.vendor_code
+ )
# Generate unique customer number for this vendor
- customer_number = self._generate_customer_number(db, vendor_id, vendor.vendor_code)
+ customer_number = self._generate_customer_number(
+ db, vendor_id, vendor.vendor_code
+ )
# Hash password
hashed_password = self.auth_service.hash_password(customer_data.password)
@@ -93,8 +98,12 @@ class CustomerService:
last_name=customer_data.last_name,
phone=customer_data.phone,
customer_number=customer_number,
- marketing_consent=customer_data.marketing_consent if hasattr(customer_data, 'marketing_consent') else False,
- is_active=True
+ marketing_consent=(
+ customer_data.marketing_consent
+ if hasattr(customer_data, "marketing_consent")
+ else False
+ ),
+ is_active=True,
)
try:
@@ -114,15 +123,11 @@ class CustomerService:
db.rollback()
logger.error(f"Error registering customer: {str(e)}")
raise CustomerValidationException(
- message="Failed to register customer",
- details={"error": str(e)}
+ message="Failed to register customer", details={"error": str(e)}
)
def login_customer(
- self,
- db: Session,
- vendor_id: int,
- credentials: UserLogin
+ self, db: Session, vendor_id: int, credentials: UserLogin
) -> Dict[str, Any]:
"""
Authenticate customer and generate JWT token.
@@ -146,20 +151,23 @@ class CustomerService:
raise VendorNotFoundException(str(vendor_id), identifier_type="id")
# Find customer by email (vendor-scoped)
- customer = db.query(Customer).filter(
- and_(
- Customer.vendor_id == vendor_id,
- Customer.email == credentials.email_or_username.lower()
+ customer = (
+ db.query(Customer)
+ .filter(
+ and_(
+ Customer.vendor_id == vendor_id,
+ Customer.email == credentials.email_or_username.lower(),
+ )
)
- ).first()
+ .first()
+ )
if not customer:
raise InvalidCustomerCredentialsException()
# Verify password using auth_manager directly
if not self.auth_service.auth_manager.verify_password(
- credentials.password,
- customer.hashed_password
+ credentials.password, customer.hashed_password
):
raise InvalidCustomerCredentialsException()
@@ -170,6 +178,7 @@ class CustomerService:
# Generate JWT token with customer context
# Use auth_manager directly since Customer is not a User model
from datetime import datetime, timedelta, timezone
+
from jose import jwt
auth_manager = self.auth_service.auth_manager
@@ -185,7 +194,9 @@ class CustomerService:
"iat": datetime.now(timezone.utc),
}
- token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
+ token = jwt.encode(
+ payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
+ )
token_data = {
"access_token": token,
@@ -198,17 +209,9 @@ class CustomerService:
f"for vendor {vendor.vendor_code}"
)
- return {
- "customer": customer,
- "token_data": token_data
- }
+ return {"customer": customer, "token_data": token_data}
- def get_customer(
- self,
- db: Session,
- vendor_id: int,
- customer_id: int
- ) -> Customer:
+ def get_customer(self, db: Session, vendor_id: int, customer_id: int) -> Customer:
"""
Get customer by ID with vendor isolation.
@@ -223,12 +226,11 @@ class CustomerService:
Raises:
CustomerNotFoundException: If customer not found
"""
- customer = db.query(Customer).filter(
- and_(
- Customer.id == customer_id,
- Customer.vendor_id == vendor_id
- )
- ).first()
+ customer = (
+ db.query(Customer)
+ .filter(and_(Customer.id == customer_id, Customer.vendor_id == vendor_id))
+ .first()
+ )
if not customer:
raise CustomerNotFoundException(str(customer_id))
@@ -236,10 +238,7 @@ class CustomerService:
return customer
def get_customer_by_email(
- self,
- db: Session,
- vendor_id: int,
- email: str
+ self, db: Session, vendor_id: int, email: str
) -> Optional[Customer]:
"""
Get customer by email (vendor-scoped).
@@ -252,19 +251,20 @@ class CustomerService:
Returns:
Optional[Customer]: Customer object or None
"""
- return db.query(Customer).filter(
- and_(
- Customer.vendor_id == vendor_id,
- Customer.email == email.lower()
+ return (
+ db.query(Customer)
+ .filter(
+ and_(Customer.vendor_id == vendor_id, Customer.email == email.lower())
)
- ).first()
+ .first()
+ )
def update_customer(
- self,
- db: Session,
- vendor_id: int,
- customer_id: int,
- customer_data: CustomerUpdate
+ self,
+ db: Session,
+ vendor_id: int,
+ customer_id: int,
+ customer_data: CustomerUpdate,
) -> Customer:
"""
Update customer profile.
@@ -290,13 +290,17 @@ class CustomerService:
for field, value in update_data.items():
if field == "email" and value:
# Check if new email already exists for this vendor
- existing = db.query(Customer).filter(
- and_(
- Customer.vendor_id == vendor_id,
- Customer.email == value.lower(),
- Customer.id != customer_id
+ existing = (
+ db.query(Customer)
+ .filter(
+ and_(
+ Customer.vendor_id == vendor_id,
+ Customer.email == value.lower(),
+ Customer.id != customer_id,
+ )
)
- ).first()
+ .first()
+ )
if existing:
raise DuplicateCustomerEmailException(value, "vendor")
@@ -317,15 +321,11 @@ class CustomerService:
db.rollback()
logger.error(f"Error updating customer: {str(e)}")
raise CustomerValidationException(
- message="Failed to update customer",
- details={"error": str(e)}
+ message="Failed to update customer", details={"error": str(e)}
)
def deactivate_customer(
- self,
- db: Session,
- vendor_id: int,
- customer_id: int
+ self, db: Session, vendor_id: int, customer_id: int
) -> Customer:
"""
Deactivate customer account.
@@ -352,10 +352,7 @@ class CustomerService:
return customer
def update_customer_stats(
- self,
- db: Session,
- customer_id: int,
- order_total: float
+ self, db: Session, customer_id: int, order_total: float
) -> None:
"""
Update customer statistics after order.
@@ -377,10 +374,7 @@ class CustomerService:
logger.debug(f"Updated stats for customer {customer.email}")
def _generate_customer_number(
- self,
- db: Session,
- vendor_id: int,
- vendor_code: str
+ self, db: Session, vendor_id: int, vendor_code: str
) -> str:
"""
Generate unique customer number for vendor.
@@ -397,21 +391,23 @@ class CustomerService:
str: Unique customer number
"""
# Get count of customers for this vendor
- count = db.query(Customer).filter(
- Customer.vendor_id == vendor_id
- ).count()
+ count = db.query(Customer).filter(Customer.vendor_id == vendor_id).count()
# Generate number with padding
sequence = str(count + 1).zfill(5)
customer_number = f"{vendor_code.upper()}-CUST-{sequence}"
# Ensure uniqueness (in case of deletions)
- while db.query(Customer).filter(
+ while (
+ db.query(Customer)
+ .filter(
and_(
Customer.vendor_id == vendor_id,
- Customer.customer_number == customer_number
+ Customer.customer_number == customer_number,
)
- ).first():
+ )
+ .first()
+ ):
count += 1
sequence = str(count + 1).zfill(5)
customer_number = f"{vendor_code.upper()}-CUST-{sequence}"
diff --git a/app/services/inventory_service.py b/app/services/inventory_service.py
index 79d0a673..b7f945d2 100644
--- a/app/services/inventory_service.py
+++ b/app/services/inventory_service.py
@@ -5,27 +5,20 @@ from typing import List, Optional
from sqlalchemy.orm import Session
-from app.exceptions import (
- InventoryNotFoundException,
- InsufficientInventoryException,
- InvalidInventoryOperationException,
- InventoryValidationException,
- NegativeInventoryException,
- InvalidQuantityException,
- ValidationException,
- ProductNotFoundException,
-)
-from models.schema.inventory import (
- InventoryCreate,
- InventoryAdjust,
- InventoryUpdate,
- InventoryReserve,
- InventoryLocationResponse,
- ProductInventorySummary
-)
+from app.exceptions import (InsufficientInventoryException,
+ InvalidInventoryOperationException,
+ InvalidQuantityException,
+ InventoryNotFoundException,
+ InventoryValidationException,
+ NegativeInventoryException,
+ ProductNotFoundException, ValidationException)
from models.database.inventory import Inventory
from models.database.product import Product
from models.database.vendor import Vendor
+from models.schema.inventory import (InventoryAdjust, InventoryCreate,
+ InventoryLocationResponse,
+ InventoryReserve, InventoryUpdate,
+ ProductInventorySummary)
logger = logging.getLogger(__name__)
@@ -34,7 +27,7 @@ class InventoryService:
"""Service for inventory operations with vendor isolation."""
def set_inventory(
- self, db: Session, vendor_id: int, inventory_data: InventoryCreate
+ self, db: Session, vendor_id: int, inventory_data: InventoryCreate
) -> Inventory:
"""
Set exact inventory quantity for a product at a location (replaces existing).
@@ -93,7 +86,11 @@ class InventoryService:
)
return new_inventory
- except (ProductNotFoundException, InvalidQuantityException, InventoryValidationException):
+ except (
+ ProductNotFoundException,
+ InvalidQuantityException,
+ InventoryValidationException,
+ ):
db.rollback()
raise
except Exception as e:
@@ -102,7 +99,7 @@ class InventoryService:
raise ValidationException("Failed to set inventory")
def adjust_inventory(
- self, db: Session, vendor_id: int, inventory_data: InventoryAdjust
+ self, db: Session, vendor_id: int, inventory_data: InventoryAdjust
) -> Inventory:
"""
Adjust inventory by adding or removing quantity.
@@ -124,7 +121,9 @@ class InventoryService:
location = self._validate_location(inventory_data.location)
# Check if inventory exists
- existing = self._get_inventory_entry(db, inventory_data.product_id, location)
+ existing = self._get_inventory_entry(
+ db, inventory_data.product_id, location
+ )
if not existing:
# Create new if adding, error if removing
@@ -173,8 +172,12 @@ class InventoryService:
)
return existing
- except (ProductNotFoundException, InventoryNotFoundException,
- InsufficientInventoryException, InventoryValidationException):
+ except (
+ ProductNotFoundException,
+ InventoryNotFoundException,
+ InsufficientInventoryException,
+ InventoryValidationException,
+ ):
db.rollback()
raise
except Exception as e:
@@ -183,7 +186,7 @@ class InventoryService:
raise ValidationException("Failed to adjust inventory")
def reserve_inventory(
- self, db: Session, vendor_id: int, reserve_data: InventoryReserve
+ self, db: Session, vendor_id: int, reserve_data: InventoryReserve
) -> Inventory:
"""
Reserve inventory for an order (increases reserved_quantity).
@@ -231,8 +234,12 @@ class InventoryService:
)
return inventory
- except (ProductNotFoundException, InventoryNotFoundException,
- InsufficientInventoryException, InvalidQuantityException):
+ except (
+ ProductNotFoundException,
+ InventoryNotFoundException,
+ InsufficientInventoryException,
+ InvalidQuantityException,
+ ):
db.rollback()
raise
except Exception as e:
@@ -241,7 +248,7 @@ class InventoryService:
raise ValidationException("Failed to reserve inventory")
def release_reservation(
- self, db: Session, vendor_id: int, reserve_data: InventoryReserve
+ self, db: Session, vendor_id: int, reserve_data: InventoryReserve
) -> Inventory:
"""
Release reserved inventory (decreases reserved_quantity).
@@ -287,7 +294,11 @@ class InventoryService:
)
return inventory
- except (ProductNotFoundException, InventoryNotFoundException, InvalidQuantityException):
+ except (
+ ProductNotFoundException,
+ InventoryNotFoundException,
+ InvalidQuantityException,
+ ):
db.rollback()
raise
except Exception as e:
@@ -296,7 +307,7 @@ class InventoryService:
raise ValidationException("Failed to release reservation")
def fulfill_reservation(
- self, db: Session, vendor_id: int, reserve_data: InventoryReserve
+ self, db: Session, vendor_id: int, reserve_data: InventoryReserve
) -> Inventory:
"""
Fulfill a reservation (decreases both quantity and reserved_quantity).
@@ -349,8 +360,12 @@ class InventoryService:
)
return inventory
- except (ProductNotFoundException, InventoryNotFoundException,
- InsufficientInventoryException, InvalidQuantityException):
+ except (
+ ProductNotFoundException,
+ InventoryNotFoundException,
+ InsufficientInventoryException,
+ InvalidQuantityException,
+ ):
db.rollback()
raise
except Exception as e:
@@ -359,7 +374,7 @@ class InventoryService:
raise ValidationException("Failed to fulfill reservation")
def get_product_inventory(
- self, db: Session, vendor_id: int, product_id: int
+ self, db: Session, vendor_id: int, product_id: int
) -> ProductInventorySummary:
"""
Get inventory summary for a product across all locations.
@@ -376,9 +391,7 @@ class InventoryService:
product = self._get_vendor_product(db, vendor_id, product_id)
inventory_entries = (
- db.query(Inventory)
- .filter(Inventory.product_id == product_id)
- .all()
+ db.query(Inventory).filter(Inventory.product_id == product_id).all()
)
if not inventory_entries:
@@ -425,8 +438,13 @@ class InventoryService:
raise ValidationException("Failed to retrieve product inventory")
def get_vendor_inventory(
- self, db: Session, vendor_id: int, skip: int = 0, limit: int = 100,
- location: Optional[str] = None, low_stock_threshold: Optional[int] = None
+ self,
+ db: Session,
+ vendor_id: int,
+ skip: int = 0,
+ limit: int = 100,
+ location: Optional[str] = None,
+ low_stock_threshold: Optional[int] = None,
) -> List[Inventory]:
"""
Get all inventory for a vendor with filtering.
@@ -458,8 +476,11 @@ class InventoryService:
raise ValidationException("Failed to retrieve vendor inventory")
def update_inventory(
- self, db: Session, vendor_id: int, inventory_id: int,
- inventory_update: InventoryUpdate
+ self,
+ db: Session,
+ vendor_id: int,
+ inventory_id: int,
+ inventory_update: InventoryUpdate,
) -> Inventory:
"""Update inventory entry."""
try:
@@ -475,7 +496,9 @@ class InventoryService:
inventory.quantity = inventory_update.quantity
if inventory_update.reserved_quantity is not None:
- self._validate_quantity(inventory_update.reserved_quantity, allow_zero=True)
+ self._validate_quantity(
+ inventory_update.reserved_quantity, allow_zero=True
+ )
inventory.reserved_quantity = inventory_update.reserved_quantity
if inventory_update.location:
@@ -488,7 +511,11 @@ class InventoryService:
logger.info(f"Updated inventory {inventory_id}")
return inventory
- except (InventoryNotFoundException, InvalidQuantityException, InventoryValidationException):
+ except (
+ InventoryNotFoundException,
+ InvalidQuantityException,
+ InventoryValidationException,
+ ):
db.rollback()
raise
except Exception as e:
@@ -496,9 +523,7 @@ class InventoryService:
logger.error(f"Error updating inventory: {str(e)}")
raise ValidationException("Failed to update inventory")
- def delete_inventory(
- self, db: Session, vendor_id: int, inventory_id: int
- ) -> bool:
+ def delete_inventory(self, db: Session, vendor_id: int, inventory_id: int) -> bool:
"""Delete inventory entry."""
try:
inventory = self._get_inventory_by_id(db, inventory_id)
@@ -521,28 +546,30 @@ class InventoryService:
raise ValidationException("Failed to delete inventory")
# Private helper methods
- def _get_vendor_product(self, db: Session, vendor_id: int, product_id: int) -> Product:
+ def _get_vendor_product(
+ self, db: Session, vendor_id: int, product_id: int
+ ) -> Product:
"""Get product and verify it belongs to vendor."""
- product = db.query(Product).filter(
- Product.id == product_id,
- Product.vendor_id == vendor_id
- ).first()
+ product = (
+ db.query(Product)
+ .filter(Product.id == product_id, Product.vendor_id == vendor_id)
+ .first()
+ )
if not product:
- raise ProductNotFoundException(f"Product {product_id} not found in your catalog")
+ raise ProductNotFoundException(
+ f"Product {product_id} not found in your catalog"
+ )
return product
def _get_inventory_entry(
- self, db: Session, product_id: int, location: str
+ self, db: Session, product_id: int, location: str
) -> Optional[Inventory]:
"""Get inventory entry by product and location."""
return (
db.query(Inventory)
- .filter(
- Inventory.product_id == product_id,
- Inventory.location == location
- )
+ .filter(Inventory.product_id == product_id, Inventory.location == location)
.first()
)
diff --git a/app/services/marketplace_import_job_service.py b/app/services/marketplace_import_job_service.py
index b6a696d1..2007cd2b 100644
--- a/app/services/marketplace_import_job_service.py
+++ b/app/services/marketplace_import_job_service.py
@@ -5,20 +5,15 @@ from typing import List, Optional
from sqlalchemy.orm import Session
-from app.exceptions import (
- ImportJobNotFoundException,
- ImportJobNotOwnedException,
- ImportJobCannotBeCancelledException,
- ImportJobCannotBeDeletedException,
- ValidationException,
-)
-from models.schema.marketplace_import_job import (
- MarketplaceImportJobResponse,
- MarketplaceImportJobRequest
-)
+from app.exceptions import (ImportJobCannotBeCancelledException,
+ ImportJobCannotBeDeletedException,
+ ImportJobNotFoundException,
+ ImportJobNotOwnedException, ValidationException)
from models.database.marketplace_import_job import MarketplaceImportJob
-from models.database.vendor import Vendor
from models.database.user import User
+from models.database.vendor import Vendor
+from models.schema.marketplace_import_job import (MarketplaceImportJobRequest,
+ MarketplaceImportJobResponse)
logger = logging.getLogger(__name__)
@@ -31,7 +26,7 @@ class MarketplaceImportJobService:
db: Session,
request: MarketplaceImportJobRequest,
vendor: Vendor, # CHANGED: Vendor object from middleware
- user: User
+ user: User,
) -> MarketplaceImportJob:
"""
Create a new marketplace import job.
@@ -147,7 +142,9 @@ class MarketplaceImportJobService:
marketplace=job.marketplace,
vendor_id=job.vendor_id,
vendor_code=job.vendor.vendor_code if job.vendor else None, # FIXED
- vendor_name=job.vendor.name if job.vendor else None, # FIXED: from relationship
+ vendor_name=(
+ job.vendor.name if job.vendor else None
+ ), # FIXED: from relationship
source_url=job.source_url,
imported=job.imported_count or 0,
updated=job.updated_count or 0,
diff --git a/app/services/marketplace_product_service.py b/app/services/marketplace_product_service.py
index 7c220160..fd572b32 100644
--- a/app/services/marketplace_product_service.py
+++ b/app/services/marketplace_product_service.py
@@ -17,19 +17,20 @@ from typing import Generator, List, Optional, Tuple
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
-from app.exceptions import (
- MarketplaceProductNotFoundException,
- MarketplaceProductAlreadyExistsException,
- InvalidMarketplaceProductDataException,
- MarketplaceProductValidationException,
- ValidationException,
-)
-from app.services.marketplace_import_job_service import marketplace_import_job_service
-from models.schema.marketplace_product import MarketplaceProductCreate, MarketplaceProductUpdate
-from models.schema.inventory import InventoryLocationResponse, InventorySummaryResponse
-from models.database.marketplace_product import MarketplaceProduct
-from models.database.inventory import Inventory
+from app.exceptions import (InvalidMarketplaceProductDataException,
+ MarketplaceProductAlreadyExistsException,
+ MarketplaceProductNotFoundException,
+ MarketplaceProductValidationException,
+ ValidationException)
+from app.services.marketplace_import_job_service import \
+ marketplace_import_job_service
from app.utils.data_processing import GTINProcessor, PriceProcessor
+from models.database.inventory import Inventory
+from models.database.marketplace_product import MarketplaceProduct
+from models.schema.inventory import (InventoryLocationResponse,
+ InventorySummaryResponse)
+from models.schema.marketplace_product import (MarketplaceProductCreate,
+ MarketplaceProductUpdate)
logger = logging.getLogger(__name__)
@@ -42,14 +43,18 @@ class MarketplaceProductService:
self.gtin_processor = GTINProcessor()
self.price_processor = PriceProcessor()
- def create_product(self, db: Session, product_data: MarketplaceProductCreate) -> MarketplaceProduct:
+ def create_product(
+ self, db: Session, product_data: MarketplaceProductCreate
+ ) -> MarketplaceProduct:
"""Create a new product with validation."""
try:
# Process and validate GTIN if provided
if product_data.gtin:
normalized_gtin = self.gtin_processor.normalize(product_data.gtin)
if not normalized_gtin:
- raise InvalidMarketplaceProductDataException("Invalid GTIN format", field="gtin")
+ raise InvalidMarketplaceProductDataException(
+ "Invalid GTIN format", field="gtin"
+ )
product_data.gtin = normalized_gtin
# Process price if provided
@@ -70,11 +75,18 @@ class MarketplaceProductService:
product_data.marketplace = "Letzshop"
# Validate required fields
- if not product_data.marketplace_product_id or not product_data.marketplace_product_id.strip():
- raise MarketplaceProductValidationException("MarketplaceProduct ID is required", field="marketplace_product_id")
+ if (
+ not product_data.marketplace_product_id
+ or not product_data.marketplace_product_id.strip()
+ ):
+ raise MarketplaceProductValidationException(
+ "MarketplaceProduct ID is required", field="marketplace_product_id"
+ )
if not product_data.title or not product_data.title.strip():
- raise MarketplaceProductValidationException("MarketplaceProduct title is required", field="title")
+ raise MarketplaceProductValidationException(
+ "MarketplaceProduct title is required", field="title"
+ )
db_product = MarketplaceProduct(**product_data.model_dump())
db.add(db_product)
@@ -84,30 +96,47 @@ class MarketplaceProductService:
logger.info(f"Created product {db_product.marketplace_product_id}")
return db_product
- except (InvalidMarketplaceProductDataException, MarketplaceProductValidationException):
+ except (
+ InvalidMarketplaceProductDataException,
+ MarketplaceProductValidationException,
+ ):
db.rollback()
raise # Re-raise custom exceptions
except IntegrityError as e:
db.rollback()
logger.error(f"Database integrity error: {str(e)}")
if "marketplace_product_id" in str(e).lower() or "unique" in str(e).lower():
- raise MarketplaceProductAlreadyExistsException(product_data.marketplace_product_id)
+ raise MarketplaceProductAlreadyExistsException(
+ product_data.marketplace_product_id
+ )
else:
- raise MarketplaceProductValidationException("Data integrity constraint violation")
+ raise MarketplaceProductValidationException(
+ "Data integrity constraint violation"
+ )
except Exception as e:
db.rollback()
logger.error(f"Error creating product: {str(e)}")
raise ValidationException("Failed to create product")
- def get_product_by_id(self, db: Session, marketplace_product_id: str) -> Optional[MarketplaceProduct]:
+ def get_product_by_id(
+ self, db: Session, marketplace_product_id: str
+ ) -> Optional[MarketplaceProduct]:
"""Get a product by its ID."""
try:
- return db.query(MarketplaceProduct).filter(MarketplaceProduct.marketplace_product_id == marketplace_product_id).first()
+ return (
+ db.query(MarketplaceProduct)
+ .filter(
+ MarketplaceProduct.marketplace_product_id == marketplace_product_id
+ )
+ .first()
+ )
except Exception as e:
logger.error(f"Error getting product {marketplace_product_id}: {str(e)}")
return None
- def get_product_by_id_or_raise(self, db: Session, marketplace_product_id: str) -> MarketplaceProduct:
+ def get_product_by_id_or_raise(
+ self, db: Session, marketplace_product_id: str
+ ) -> MarketplaceProduct:
"""
Get a product by its ID or raise exception.
@@ -127,16 +156,16 @@ class MarketplaceProductService:
return product
def get_products_with_filters(
- self,
- db: Session,
- skip: int = 0,
- limit: int = 100,
- brand: Optional[str] = None,
- category: Optional[str] = None,
- availability: Optional[str] = None,
- marketplace: Optional[str] = None,
- vendor_name: Optional[str] = None,
- search: Optional[str] = None,
+ self,
+ db: Session,
+ skip: int = 0,
+ limit: int = 100,
+ brand: Optional[str] = None,
+ category: Optional[str] = None,
+ availability: Optional[str] = None,
+ marketplace: Optional[str] = None,
+ vendor_name: Optional[str] = None,
+ search: Optional[str] = None,
) -> Tuple[List[MarketplaceProduct], int]:
"""
Get products with filtering and pagination.
@@ -162,13 +191,19 @@ class MarketplaceProductService:
if brand:
query = query.filter(MarketplaceProduct.brand.ilike(f"%{brand}%"))
if category:
- query = query.filter(MarketplaceProduct.google_product_category.ilike(f"%{category}%"))
+ query = query.filter(
+ MarketplaceProduct.google_product_category.ilike(f"%{category}%")
+ )
if availability:
query = query.filter(MarketplaceProduct.availability == availability)
if marketplace:
- query = query.filter(MarketplaceProduct.marketplace.ilike(f"%{marketplace}%"))
+ query = query.filter(
+ MarketplaceProduct.marketplace.ilike(f"%{marketplace}%")
+ )
if vendor_name:
- query = query.filter(MarketplaceProduct.vendor_name.ilike(f"%{vendor_name}%"))
+ query = query.filter(
+ MarketplaceProduct.vendor_name.ilike(f"%{vendor_name}%")
+ )
if search:
# Search in title, description, marketplace, and name
search_term = f"%{search}%"
@@ -188,7 +223,12 @@ class MarketplaceProductService:
logger.error(f"Error getting products with filters: {str(e)}")
raise ValidationException("Failed to retrieve products")
- def update_product(self, db: Session, marketplace_product_id: str, product_update: MarketplaceProductUpdate) -> MarketplaceProduct:
+ def update_product(
+ self,
+ db: Session,
+ marketplace_product_id: str,
+ product_update: MarketplaceProductUpdate,
+ ) -> MarketplaceProduct:
"""Update product with validation."""
try:
product = self.get_product_by_id_or_raise(db, marketplace_product_id)
@@ -200,7 +240,9 @@ class MarketplaceProductService:
if "gtin" in update_data and update_data["gtin"]:
normalized_gtin = self.gtin_processor.normalize(update_data["gtin"])
if not normalized_gtin:
- raise InvalidMarketplaceProductDataException("Invalid GTIN format", field="gtin")
+ raise InvalidMarketplaceProductDataException(
+ "Invalid GTIN format", field="gtin"
+ )
update_data["gtin"] = normalized_gtin
# Process price if being updated
@@ -217,8 +259,12 @@ class MarketplaceProductService:
raise InvalidMarketplaceProductDataException(str(e), field="price")
# Validate required fields if being updated
- if "title" in update_data and (not update_data["title"] or not update_data["title"].strip()):
- raise MarketplaceProductValidationException("MarketplaceProduct title cannot be empty", field="title")
+ if "title" in update_data and (
+ not update_data["title"] or not update_data["title"].strip()
+ ):
+ raise MarketplaceProductValidationException(
+ "MarketplaceProduct title cannot be empty", field="title"
+ )
for key, value in update_data.items():
setattr(product, key, value)
@@ -230,7 +276,11 @@ class MarketplaceProductService:
logger.info(f"Updated product {marketplace_product_id}")
return product
- except (MarketplaceProductNotFoundException, InvalidMarketplaceProductDataException, MarketplaceProductValidationException):
+ except (
+ MarketplaceProductNotFoundException,
+ InvalidMarketplaceProductDataException,
+ MarketplaceProductValidationException,
+ ):
db.rollback()
raise # Re-raise custom exceptions
except Exception as e:
@@ -272,7 +322,9 @@ class MarketplaceProductService:
logger.error(f"Error deleting product {marketplace_product_id}: {str(e)}")
raise ValidationException("Failed to delete product")
- def get_inventory_info(self, db: Session, gtin: str) -> Optional[InventorySummaryResponse]:
+ def get_inventory_info(
+ self, db: Session, gtin: str
+ ) -> Optional[InventorySummaryResponse]:
"""
Get inventory information for a product by GTIN.
@@ -290,7 +342,9 @@ class MarketplaceProductService:
total_quantity = sum(entry.quantity for entry in inventory_entries)
locations = [
- InventoryLocationResponse(location=entry.location, quantity=entry.quantity)
+ InventoryLocationResponse(
+ location=entry.location, quantity=entry.quantity
+ )
for entry in inventory_entries
]
@@ -305,13 +359,14 @@ class MarketplaceProductService:
import csv
from io import StringIO
from typing import Generator, Optional
+
from sqlalchemy.orm import Session
def generate_csv_export(
- self,
- db: Session,
- marketplace: Optional[str] = None,
- vendor_name: Optional[str] = None,
+ self,
+ db: Session,
+ marketplace: Optional[str] = None,
+ vendor_name: Optional[str] = None,
) -> Generator[str, None, None]:
"""
Generate CSV export with streaming for memory efficiency and proper CSV escaping.
@@ -331,9 +386,18 @@ class MarketplaceProductService:
# Write header row
headers = [
- "marketplace_product_id", "title", "description", "link", "image_link",
- "availability", "price", "currency", "brand", "gtin",
- "marketplace", "name"
+ "marketplace_product_id",
+ "title",
+ "description",
+ "link",
+ "image_link",
+ "availability",
+ "price",
+ "currency",
+ "brand",
+ "gtin",
+ "marketplace",
+ "name",
]
writer.writerow(headers)
yield output.getvalue()
@@ -350,9 +414,13 @@ class MarketplaceProductService:
# Apply marketplace filters
if marketplace:
- query = query.filter(MarketplaceProduct.marketplace.ilike(f"%{marketplace}%"))
+ query = query.filter(
+ MarketplaceProduct.marketplace.ilike(f"%{marketplace}%")
+ )
if vendor_name:
- query = query.filter(MarketplaceProduct.vendor_name.ilike(f"%{vendor_name}%"))
+ query = query.filter(
+ MarketplaceProduct.vendor_name.ilike(f"%{vendor_name}%")
+ )
products = query.offset(offset).limit(batch_size).all()
if not products:
@@ -392,8 +460,12 @@ class MarketplaceProductService:
"""Check if product exists by ID."""
try:
return (
- db.query(MarketplaceProduct).filter(MarketplaceProduct.marketplace_product_id == marketplace_product_id).first()
- is not None
+ db.query(MarketplaceProduct)
+ .filter(
+ MarketplaceProduct.marketplace_product_id == marketplace_product_id
+ )
+ .first()
+ is not None
)
except Exception as e:
logger.error(f"Error checking if product exists: {str(e)}")
@@ -402,18 +474,27 @@ class MarketplaceProductService:
# Private helper methods
def _validate_product_data(self, product_data: dict) -> None:
"""Validate product data structure."""
- required_fields = ['marketplace_product_id', 'title']
+ required_fields = ["marketplace_product_id", "title"]
for field in required_fields:
if field not in product_data or not product_data[field]:
- raise MarketplaceProductValidationException(f"{field} is required", field=field)
+ raise MarketplaceProductValidationException(
+ f"{field} is required", field=field
+ )
def _normalize_product_data(self, product_data: dict) -> dict:
"""Normalize and clean product data."""
normalized = product_data.copy()
# Trim whitespace from string fields
- string_fields = ['marketplace_product_id', 'title', 'description', 'brand', 'marketplace', 'name']
+ string_fields = [
+ "marketplace_product_id",
+ "title",
+ "description",
+ "brand",
+ "marketplace",
+ "name",
+ ]
for field in string_fields:
if field in normalized and normalized[field]:
normalized[field] = normalized[field].strip()
diff --git a/app/services/media_service.py b/app/services/media_service.py
index 7d008007..0c726681 100644
--- a/app/services/media_service.py
+++ b/app/services/media_service.py
@@ -1 +1 @@
-# File and media management services
+# File and media management services
diff --git a/app/services/monitoring_service.py b/app/services/monitoring_service.py
index f1d92af4..9d5ec783 100644
--- a/app/services/monitoring_service.py
+++ b/app/services/monitoring_service.py
@@ -1 +1 @@
-# Application monitoring services
+# Application monitoring services
diff --git a/app/services/notification_service.py b/app/services/notification_service.py
index c25cb0d9..4eaa423b 100644
--- a/app/services/notification_service.py
+++ b/app/services/notification_service.py
@@ -1 +1 @@
-# Email/notification services
+# Email/notification services
diff --git a/app/services/order_service.py b/app/services/order_service.py
index b4606c43..9e1005fd 100644
--- a/app/services/order_service.py
+++ b/app/services/order_service.py
@@ -9,24 +9,21 @@ This module provides:
"""
import logging
-from datetime import datetime, timezone
-from typing import List, Optional, Tuple
import random
import string
+from datetime import datetime, timezone
+from typing import List, Optional, Tuple
-from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
+from sqlalchemy.orm import Session
-from models.database.order import Order, OrderItem
+from app.exceptions import (CustomerNotFoundException,
+ InsufficientInventoryException,
+ OrderNotFoundException, ValidationException)
from models.database.customer import Customer, CustomerAddress
+from models.database.order import Order, OrderItem
from models.database.product import Product
-from models.schema.order import OrderCreate, OrderUpdate, OrderAddressCreate
-from app.exceptions import (
- OrderNotFoundException,
- ValidationException,
- InsufficientInventoryException,
- CustomerNotFoundException
-)
+from models.schema.order import OrderAddressCreate, OrderCreate, OrderUpdate
logger = logging.getLogger(__name__)
@@ -42,23 +39,27 @@ class OrderService:
Example: ORD-1-20250110-A1B2C3
"""
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d")
- random_suffix = ''.join(random.choices(string.ascii_uppercase + string.digits, k=6))
+ random_suffix = "".join(
+ random.choices(string.ascii_uppercase + string.digits, k=6)
+ )
order_number = f"ORD-{vendor_id}-{timestamp}-{random_suffix}"
# Ensure uniqueness
while db.query(Order).filter(Order.order_number == order_number).first():
- random_suffix = ''.join(random.choices(string.ascii_uppercase + string.digits, k=6))
+ random_suffix = "".join(
+ random.choices(string.ascii_uppercase + string.digits, k=6)
+ )
order_number = f"ORD-{vendor_id}-{timestamp}-{random_suffix}"
return order_number
def _create_customer_address(
- self,
- db: Session,
- vendor_id: int,
- customer_id: int,
- address_data: OrderAddressCreate,
- address_type: str
+ self,
+ db: Session,
+ vendor_id: int,
+ customer_id: int,
+ address_data: OrderAddressCreate,
+ address_type: str,
) -> CustomerAddress:
"""Create a customer address for order."""
address = CustomerAddress(
@@ -73,17 +74,14 @@ class OrderService:
city=address_data.city,
postal_code=address_data.postal_code,
country=address_data.country,
- is_default=False
+ is_default=False,
)
db.add(address)
db.flush() # Get ID without committing
return address
def create_order(
- self,
- db: Session,
- vendor_id: int,
- order_data: OrderCreate
+ self, db: Session, vendor_id: int, order_data: OrderCreate
) -> Order:
"""
Create a new order.
@@ -104,12 +102,15 @@ class OrderService:
# Validate customer exists if provided
customer_id = order_data.customer_id
if customer_id:
- customer = db.query(Customer).filter(
- and_(
- Customer.id == customer_id,
- Customer.vendor_id == vendor_id
+ customer = (
+ db.query(Customer)
+ .filter(
+ and_(
+ Customer.id == customer_id, Customer.vendor_id == vendor_id
+ )
)
- ).first()
+ .first()
+ )
if not customer:
raise CustomerNotFoundException(str(customer_id))
@@ -124,7 +125,7 @@ class OrderService:
vendor_id=vendor_id,
customer_id=customer_id,
address_data=order_data.shipping_address,
- address_type="shipping"
+ address_type="shipping",
)
# Create billing address (use shipping if not provided)
@@ -134,7 +135,7 @@ class OrderService:
vendor_id=vendor_id,
customer_id=customer_id,
address_data=order_data.billing_address,
- address_type="billing"
+ address_type="billing",
)
else:
billing_address = shipping_address
@@ -145,23 +146,29 @@ class OrderService:
for item_data in order_data.items:
# Get product
- product = db.query(Product).filter(
- and_(
- Product.id == item_data.product_id,
- Product.vendor_id == vendor_id,
- Product.is_active == True
+ product = (
+ db.query(Product)
+ .filter(
+ and_(
+ Product.id == item_data.product_id,
+ Product.vendor_id == vendor_id,
+ Product.is_active == True,
+ )
)
- ).first()
+ .first()
+ )
if not product:
- raise ValidationException(f"Product {item_data.product_id} not found")
+ raise ValidationException(
+ f"Product {item_data.product_id} not found"
+ )
# Check inventory
if product.available_inventory < item_data.quantity:
raise InsufficientInventoryException(
product_id=product.id,
requested=item_data.quantity,
- available=product.available_inventory
+ available=product.available_inventory,
)
# Calculate item total
@@ -172,14 +179,16 @@ class OrderService:
item_total = unit_price * item_data.quantity
subtotal += item_total
- order_items_data.append({
- "product_id": product.id,
- "product_name": product.marketplace_product.title,
- "product_sku": product.product_id,
- "quantity": item_data.quantity,
- "unit_price": unit_price,
- "total_price": item_total
- })
+ order_items_data.append(
+ {
+ "product_id": product.id,
+ "product_name": product.marketplace_product.title,
+ "product_sku": product.product_id,
+ "quantity": item_data.quantity,
+ "unit_price": unit_price,
+ "total_price": item_total,
+ }
+ )
# Calculate tax and shipping (simple implementation)
tax_amount = 0.0 # TODO: Implement tax calculation
@@ -205,7 +214,7 @@ class OrderService:
shipping_address_id=shipping_address.id,
billing_address_id=billing_address.id,
shipping_method=order_data.shipping_method,
- customer_notes=order_data.customer_notes
+ customer_notes=order_data.customer_notes,
)
db.add(order)
@@ -213,10 +222,7 @@ class OrderService:
# Create order items
for item_data in order_items_data:
- order_item = OrderItem(
- order_id=order.id,
- **item_data
- )
+ order_item = OrderItem(order_id=order.id, **item_data)
db.add(order_item)
db.commit()
@@ -229,7 +235,11 @@ class OrderService:
return order
- except (ValidationException, InsufficientInventoryException, CustomerNotFoundException):
+ except (
+ ValidationException,
+ InsufficientInventoryException,
+ CustomerNotFoundException,
+ ):
db.rollback()
raise
except Exception as e:
@@ -237,19 +247,13 @@ class OrderService:
logger.error(f"Error creating order: {str(e)}")
raise ValidationException(f"Failed to create order: {str(e)}")
- def get_order(
- self,
- db: Session,
- vendor_id: int,
- order_id: int
- ) -> Order:
+ def get_order(self, db: Session, vendor_id: int, order_id: int) -> Order:
"""Get order by ID."""
- order = db.query(Order).filter(
- and_(
- Order.id == order_id,
- Order.vendor_id == vendor_id
- )
- ).first()
+ order = (
+ db.query(Order)
+ .filter(and_(Order.id == order_id, Order.vendor_id == vendor_id))
+ .first()
+ )
if not order:
raise OrderNotFoundException(str(order_id))
@@ -257,13 +261,13 @@ class OrderService:
return order
def get_vendor_orders(
- self,
- db: Session,
- vendor_id: int,
- skip: int = 0,
- limit: int = 100,
- status: Optional[str] = None,
- customer_id: Optional[int] = None
+ self,
+ db: Session,
+ vendor_id: int,
+ skip: int = 0,
+ limit: int = 100,
+ status: Optional[str] = None,
+ customer_id: Optional[int] = None,
) -> Tuple[List[Order], int]:
"""
Get orders for vendor with filtering.
@@ -296,28 +300,20 @@ class OrderService:
return orders, total
def get_customer_orders(
- self,
- db: Session,
- vendor_id: int,
- customer_id: int,
- skip: int = 0,
- limit: int = 100
+ self,
+ db: Session,
+ vendor_id: int,
+ customer_id: int,
+ skip: int = 0,
+ limit: int = 100,
) -> Tuple[List[Order], int]:
"""Get orders for a specific customer."""
return self.get_vendor_orders(
- db=db,
- vendor_id=vendor_id,
- skip=skip,
- limit=limit,
- customer_id=customer_id
+ db=db, vendor_id=vendor_id, skip=skip, limit=limit, customer_id=customer_id
)
def update_order_status(
- self,
- db: Session,
- vendor_id: int,
- order_id: int,
- order_update: OrderUpdate
+ self, db: Session, vendor_id: int, order_id: int, order_update: OrderUpdate
) -> Order:
"""
Update order status and tracking information.
diff --git a/app/services/payment_service.py b/app/services/payment_service.py
index ec9ec61e..5aef72d7 100644
--- a/app/services/payment_service.py
+++ b/app/services/payment_service.py
@@ -1 +1 @@
-# Payment processing services
+# Payment processing services
diff --git a/app/services/product_service.py b/app/services/product_service.py
index 52db4e09..da7f46ba 100644
--- a/app/services/product_service.py
+++ b/app/services/product_service.py
@@ -14,14 +14,11 @@ from typing import List, Optional, Tuple
from sqlalchemy.orm import Session
-from app.exceptions import (
- ProductNotFoundException,
- ProductAlreadyExistsException,
- ValidationException,
-)
-from models.schema.product import ProductCreate, ProductUpdate
-from models.database.product import Product
+from app.exceptions import (ProductAlreadyExistsException,
+ ProductNotFoundException, ValidationException)
from models.database.marketplace_product import MarketplaceProduct
+from models.database.product import Product
+from models.schema.product import ProductCreate, ProductUpdate
logger = logging.getLogger(__name__)
@@ -45,10 +42,11 @@ class ProductService:
ProductNotFoundException: If product not found
"""
try:
- product = db.query(Product).filter(
- Product.id == product_id,
- Product.vendor_id == vendor_id
- ).first()
+ product = (
+ db.query(Product)
+ .filter(Product.id == product_id, Product.vendor_id == vendor_id)
+ .first()
+ )
if not product:
raise ProductNotFoundException(f"Product {product_id} not found")
@@ -62,7 +60,7 @@ class ProductService:
raise ValidationException("Failed to retrieve product")
def create_product(
- self, db: Session, vendor_id: int, product_data: ProductCreate
+ self, db: Session, vendor_id: int, product_data: ProductCreate
) -> Product:
"""
Add a product from marketplace to vendor catalog.
@@ -81,10 +79,14 @@ class ProductService:
"""
try:
# Verify marketplace product exists and belongs to vendor
- marketplace_product = db.query(MarketplaceProduct).filter(
- MarketplaceProduct.id == product_data.marketplace_product_id,
- MarketplaceProduct.vendor_id == vendor_id
- ).first()
+ marketplace_product = (
+ db.query(MarketplaceProduct)
+ .filter(
+ MarketplaceProduct.id == product_data.marketplace_product_id,
+ MarketplaceProduct.vendor_id == vendor_id,
+ )
+ .first()
+ )
if not marketplace_product:
raise ValidationException(
@@ -92,10 +94,15 @@ class ProductService:
)
# Check if already in catalog
- existing = db.query(Product).filter(
- Product.vendor_id == vendor_id,
- Product.marketplace_product_id == product_data.marketplace_product_id
- ).first()
+ existing = (
+ db.query(Product)
+ .filter(
+ Product.vendor_id == vendor_id,
+ Product.marketplace_product_id
+ == product_data.marketplace_product_id,
+ )
+ .first()
+ )
if existing:
raise ProductAlreadyExistsException(
@@ -122,9 +129,7 @@ class ProductService:
db.commit()
db.refresh(product)
- logger.info(
- f"Added product {product.id} to vendor {vendor_id} catalog"
- )
+ logger.info(f"Added product {product.id} to vendor {vendor_id} catalog")
return product
except (ProductAlreadyExistsException, ValidationException):
@@ -136,7 +141,11 @@ class ProductService:
raise ValidationException("Failed to create product")
def update_product(
- self, db: Session, vendor_id: int, product_id: int, product_update: ProductUpdate
+ self,
+ db: Session,
+ vendor_id: int,
+ product_id: int,
+ product_update: ProductUpdate,
) -> Product:
"""
Update product in vendor catalog.
@@ -202,13 +211,13 @@ class ProductService:
raise ValidationException("Failed to delete product")
def get_vendor_products(
- self,
- db: Session,
- vendor_id: int,
- skip: int = 0,
- limit: int = 100,
- is_active: Optional[bool] = None,
- is_featured: Optional[bool] = None,
+ self,
+ db: Session,
+ vendor_id: int,
+ skip: int = 0,
+ limit: int = 100,
+ is_active: Optional[bool] = None,
+ is_featured: Optional[bool] = None,
) -> Tuple[List[Product], int]:
"""
Get products in vendor catalog with filtering.
diff --git a/app/services/search_service.py b/app/services/search_service.py
index f2f192b8..58830348 100644
--- a/app/services/search_service.py
+++ b/app/services/search_service.py
@@ -1 +1 @@
-# Search and indexing services
+# Search and indexing services
diff --git a/app/services/stats_service.py b/app/services/stats_service.py
index beabc8f1..ba8f381b 100644
--- a/app/services/stats_service.py
+++ b/app/services/stats_service.py
@@ -10,25 +10,21 @@ This module provides:
"""
import logging
-from typing import Any, Dict, List
from datetime import datetime, timedelta
+from typing import Any, Dict, List
from sqlalchemy import func
from sqlalchemy.orm import Session
-from app.exceptions import (
- VendorNotFoundException,
- AdminOperationException,
-)
-
-from models.database.marketplace_product import MarketplaceProduct
-from models.database.product import Product
-from models.database.inventory import Inventory
-from models.database.vendor import Vendor
-from models.database.order import Order
-from models.database.user import User
+from app.exceptions import AdminOperationException, VendorNotFoundException
from models.database.customer import Customer
+from models.database.inventory import Inventory
from models.database.marketplace_import_job import MarketplaceImportJob
+from models.database.marketplace_product import MarketplaceProduct
+from models.database.order import Order
+from models.database.product import Product
+from models.database.user import User
+from models.database.vendor import Vendor
logger = logging.getLogger(__name__)
@@ -62,63 +58,77 @@ class StatsService:
try:
# Catalog statistics
- total_catalog_products = db.query(Product).filter(
- Product.vendor_id == vendor_id,
- Product.is_active == True
- ).count()
+ total_catalog_products = (
+ db.query(Product)
+ .filter(Product.vendor_id == vendor_id, Product.is_active == True)
+ .count()
+ )
- featured_products = db.query(Product).filter(
- Product.vendor_id == vendor_id,
- Product.is_featured == True,
- Product.is_active == True
- ).count()
+ featured_products = (
+ db.query(Product)
+ .filter(
+ Product.vendor_id == vendor_id,
+ Product.is_featured == True,
+ Product.is_active == True,
+ )
+ .count()
+ )
# Staging statistics
# TODO: This is fragile - MarketplaceProduct uses vendor_name (string) not vendor_id
# Should add vendor_id foreign key to MarketplaceProduct for robust querying
# For now, matching by vendor name which could fail if names don't match exactly
- staging_products = db.query(MarketplaceProduct).filter(
- MarketplaceProduct.vendor_name == vendor.name
- ).count()
+ staging_products = (
+ db.query(MarketplaceProduct)
+ .filter(MarketplaceProduct.vendor_name == vendor.name)
+ .count()
+ )
# Inventory statistics
- total_inventory = db.query(
- func.sum(Inventory.quantity)
- ).filter(
- Inventory.vendor_id == vendor_id
- ).scalar() or 0
+ total_inventory = (
+ db.query(func.sum(Inventory.quantity))
+ .filter(Inventory.vendor_id == vendor_id)
+ .scalar()
+ or 0
+ )
- reserved_inventory = db.query(
- func.sum(Inventory.reserved_quantity)
- ).filter(
- Inventory.vendor_id == vendor_id
- ).scalar() or 0
+ reserved_inventory = (
+ db.query(func.sum(Inventory.reserved_quantity))
+ .filter(Inventory.vendor_id == vendor_id)
+ .scalar()
+ or 0
+ )
- inventory_locations = db.query(
- func.count(func.distinct(Inventory.location))
- ).filter(
- Inventory.vendor_id == vendor_id
- ).scalar() or 0
+ inventory_locations = (
+ db.query(func.count(func.distinct(Inventory.location)))
+ .filter(Inventory.vendor_id == vendor_id)
+ .scalar()
+ or 0
+ )
# Import statistics
- total_imports = db.query(MarketplaceImportJob).filter(
- MarketplaceImportJob.vendor_id == vendor_id
- ).count()
+ total_imports = (
+ db.query(MarketplaceImportJob)
+ .filter(MarketplaceImportJob.vendor_id == vendor_id)
+ .count()
+ )
- successful_imports = db.query(MarketplaceImportJob).filter(
- MarketplaceImportJob.vendor_id == vendor_id,
- MarketplaceImportJob.status == "completed"
- ).count()
+ successful_imports = (
+ db.query(MarketplaceImportJob)
+ .filter(
+ MarketplaceImportJob.vendor_id == vendor_id,
+ MarketplaceImportJob.status == "completed",
+ )
+ .count()
+ )
# Orders
- total_orders = db.query(Order).filter(
- Order.vendor_id == vendor_id
- ).count()
+ total_orders = db.query(Order).filter(Order.vendor_id == vendor_id).count()
# Customers
- total_customers = db.query(Customer).filter(
- Customer.vendor_id == vendor_id
- ).count()
+ total_customers = (
+ db.query(Customer).filter(Customer.vendor_id == vendor_id).count()
+ )
return {
"catalog": {
@@ -138,7 +148,11 @@ class StatsService:
"imports": {
"total_imports": total_imports,
"successful_imports": successful_imports,
- "success_rate": (successful_imports / total_imports * 100) if total_imports > 0 else 0,
+ "success_rate": (
+ (successful_imports / total_imports * 100)
+ if total_imports > 0
+ else 0
+ ),
},
"orders": {
"total_orders": total_orders,
@@ -151,16 +165,18 @@ class StatsService:
except VendorNotFoundException:
raise
except Exception as e:
- logger.error(f"Failed to retrieve vendor statistics for vendor {vendor_id}: {str(e)}")
+ logger.error(
+ f"Failed to retrieve vendor statistics for vendor {vendor_id}: {str(e)}"
+ )
raise AdminOperationException(
operation="get_vendor_stats",
reason=f"Database query failed: {str(e)}",
target_type="vendor",
- target_id=str(vendor_id)
+ target_id=str(vendor_id),
)
def get_vendor_analytics(
- self, db: Session, vendor_id: int, period: str = "30d"
+ self, db: Session, vendor_id: int, period: str = "30d"
) -> Dict[str, Any]:
"""
Get a specific vendor analytics for a time period.
@@ -188,21 +204,28 @@ class StatsService:
start_date = datetime.utcnow() - timedelta(days=days)
# Import activity
- recent_imports = db.query(MarketplaceImportJob).filter(
- MarketplaceImportJob.vendor_id == vendor_id,
- MarketplaceImportJob.created_at >= start_date
- ).count()
+ recent_imports = (
+ db.query(MarketplaceImportJob)
+ .filter(
+ MarketplaceImportJob.vendor_id == vendor_id,
+ MarketplaceImportJob.created_at >= start_date,
+ )
+ .count()
+ )
# Products added to catalog
- products_added = db.query(Product).filter(
- Product.vendor_id == vendor_id,
- Product.created_at >= start_date
- ).count()
+ products_added = (
+ db.query(Product)
+ .filter(
+ Product.vendor_id == vendor_id, Product.created_at >= start_date
+ )
+ .count()
+ )
# Inventory changes
- inventory_entries = db.query(Inventory).filter(
- Inventory.vendor_id == vendor_id
- ).count()
+ inventory_entries = (
+ db.query(Inventory).filter(Inventory.vendor_id == vendor_id).count()
+ )
return {
"period": period,
@@ -221,12 +244,14 @@ class StatsService:
except VendorNotFoundException:
raise
except Exception as e:
- logger.error(f"Failed to retrieve vendor analytics for vendor {vendor_id}: {str(e)}")
+ logger.error(
+ f"Failed to retrieve vendor analytics for vendor {vendor_id}: {str(e)}"
+ )
raise AdminOperationException(
operation="get_vendor_analytics",
reason=f"Database query failed: {str(e)}",
target_type="vendor",
- target_id=str(vendor_id)
+ target_id=str(vendor_id),
)
def get_vendor_statistics(self, db: Session) -> dict:
@@ -234,7 +259,9 @@ class StatsService:
try:
total_vendors = db.query(Vendor).count()
active_vendors = db.query(Vendor).filter(Vendor.is_active == True).count()
- verified_vendors = db.query(Vendor).filter(Vendor.is_verified == True).count()
+ verified_vendors = (
+ db.query(Vendor).filter(Vendor.is_verified == True).count()
+ )
inactive_vendors = total_vendors - active_vendors
return {
@@ -242,13 +269,14 @@ class StatsService:
"active_vendors": active_vendors,
"inactive_vendors": inactive_vendors,
"verified_vendors": verified_vendors,
- "verification_rate": (verified_vendors / total_vendors * 100) if total_vendors > 0 else 0
+ "verification_rate": (
+ (verified_vendors / total_vendors * 100) if total_vendors > 0 else 0
+ ),
}
except Exception as e:
logger.error(f"Failed to get vendor statistics: {str(e)}")
raise AdminOperationException(
- operation="get_vendor_statistics",
- reason="Database query failed"
+ operation="get_vendor_statistics", reason="Database query failed"
)
# ========================================================================
@@ -302,7 +330,7 @@ class StatsService:
logger.error(f"Failed to retrieve comprehensive statistics: {str(e)}")
raise AdminOperationException(
operation="get_comprehensive_stats",
- reason=f"Database query failed: {str(e)}"
+ reason=f"Database query failed: {str(e)}",
)
def get_marketplace_breakdown_stats(self, db: Session) -> List[Dict[str, Any]]:
@@ -323,8 +351,12 @@ class StatsService:
db.query(
MarketplaceProduct.marketplace,
func.count(MarketplaceProduct.id).label("total_products"),
- func.count(func.distinct(MarketplaceProduct.vendor_name)).label("unique_vendors"),
- func.count(func.distinct(MarketplaceProduct.brand)).label("unique_brands"),
+ func.count(func.distinct(MarketplaceProduct.vendor_name)).label(
+ "unique_vendors"
+ ),
+ func.count(func.distinct(MarketplaceProduct.brand)).label(
+ "unique_brands"
+ ),
)
.filter(MarketplaceProduct.marketplace.isnot(None))
.group_by(MarketplaceProduct.marketplace)
@@ -342,10 +374,12 @@ class StatsService:
]
except Exception as e:
- logger.error(f"Failed to retrieve marketplace breakdown statistics: {str(e)}")
+ logger.error(
+ f"Failed to retrieve marketplace breakdown statistics: {str(e)}"
+ )
raise AdminOperationException(
operation="get_marketplace_breakdown_stats",
- reason=f"Database query failed: {str(e)}"
+ reason=f"Database query failed: {str(e)}",
)
def get_user_statistics(self, db: Session) -> Dict[str, Any]:
@@ -372,13 +406,14 @@ class StatsService:
"active_users": active_users,
"inactive_users": inactive_users,
"admin_users": admin_users,
- "activation_rate": (active_users / total_users * 100) if total_users > 0 else 0
+ "activation_rate": (
+ (active_users / total_users * 100) if total_users > 0 else 0
+ ),
}
except Exception as e:
logger.error(f"Failed to get user statistics: {str(e)}")
raise AdminOperationException(
- operation="get_user_statistics",
- reason="Database query failed"
+ operation="get_user_statistics", reason="Database query failed"
)
def get_import_statistics(self, db: Session) -> Dict[str, Any]:
@@ -396,18 +431,22 @@ class StatsService:
"""
try:
total = db.query(MarketplaceImportJob).count()
- completed = db.query(MarketplaceImportJob).filter(
- MarketplaceImportJob.status == "completed"
- ).count()
- failed = db.query(MarketplaceImportJob).filter(
- MarketplaceImportJob.status == "failed"
- ).count()
+ completed = (
+ db.query(MarketplaceImportJob)
+ .filter(MarketplaceImportJob.status == "completed")
+ .count()
+ )
+ failed = (
+ db.query(MarketplaceImportJob)
+ .filter(MarketplaceImportJob.status == "failed")
+ .count()
+ )
return {
"total_imports": total,
"completed_imports": completed,
"failed_imports": failed,
- "success_rate": (completed / total * 100) if total > 0 else 0
+ "success_rate": (completed / total * 100) if total > 0 else 0,
}
except Exception as e:
logger.error(f"Failed to get import statistics: {str(e)}")
@@ -415,7 +454,7 @@ class StatsService:
"total_imports": 0,
"completed_imports": 0,
"failed_imports": 0,
- "success_rate": 0
+ "success_rate": 0,
}
def get_order_statistics(self, db: Session) -> Dict[str, Any]:
@@ -431,11 +470,7 @@ class StatsService:
Note:
TODO: Implement when Order model is fully available
"""
- return {
- "total_orders": 0,
- "pending_orders": 0,
- "completed_orders": 0
- }
+ return {"total_orders": 0, "pending_orders": 0, "completed_orders": 0}
def get_product_statistics(self, db: Session) -> Dict[str, Any]:
"""
@@ -450,11 +485,7 @@ class StatsService:
Note:
TODO: Implement when Product model is fully available
"""
- return {
- "total_products": 0,
- "active_products": 0,
- "out_of_stock": 0
- }
+ return {"total_products": 0, "active_products": 0, "out_of_stock": 0}
# ========================================================================
# PRIVATE HELPER METHODS
@@ -491,8 +522,7 @@ class StatsService:
return (
db.query(MarketplaceProduct.brand)
.filter(
- MarketplaceProduct.brand.isnot(None),
- MarketplaceProduct.brand != ""
+ MarketplaceProduct.brand.isnot(None), MarketplaceProduct.brand != ""
)
.distinct()
.count()
diff --git a/app/services/team_service.py b/app/services/team_service.py
index 543df373..84799d22 100644
--- a/app/services/team_service.py
+++ b/app/services/team_service.py
@@ -9,17 +9,15 @@ This module provides:
"""
import logging
-from typing import List, Dict, Any
from datetime import datetime, timezone
+from typing import Any, Dict, List
from sqlalchemy.orm import Session
-from app.exceptions import (
- ValidationException,
- UnauthorizedVendorAccessException,
-)
-from models.database.vendor import VendorUser, Role
+from app.exceptions import (UnauthorizedVendorAccessException,
+ ValidationException)
from models.database.user import User
+from models.database.vendor import Role, VendorUser
logger = logging.getLogger(__name__)
@@ -28,7 +26,7 @@ class TeamService:
"""Service for team management operations."""
def get_team_members(
- self, db: Session, vendor_id: int, current_user: User
+ self, db: Session, vendor_id: int, current_user: User
) -> List[Dict[str, Any]]:
"""
Get all team members for vendor.
@@ -42,23 +40,26 @@ class TeamService:
List of team members
"""
try:
- vendor_users = db.query(VendorUser).filter(
- VendorUser.vendor_id == vendor_id,
- VendorUser.is_active == True
- ).all()
+ vendor_users = (
+ db.query(VendorUser)
+ .filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
+ .all()
+ )
members = []
for vu in vendor_users:
- members.append({
- "id": vu.user_id,
- "email": vu.user.email,
- "first_name": vu.user.first_name,
- "last_name": vu.user.last_name,
- "role": vu.role.name,
- "role_id": vu.role_id,
- "is_active": vu.is_active,
- "joined_at": vu.created_at,
- })
+ members.append(
+ {
+ "id": vu.user_id,
+ "email": vu.user.email,
+ "first_name": vu.user.first_name,
+ "last_name": vu.user.last_name,
+ "role": vu.role.name,
+ "role_id": vu.role_id,
+ "is_active": vu.is_active,
+ "joined_at": vu.created_at,
+ }
+ )
return members
@@ -67,7 +68,7 @@ class TeamService:
raise ValidationException("Failed to retrieve team members")
def invite_team_member(
- self, db: Session, vendor_id: int, invitation_data: dict, current_user: User
+ self, db: Session, vendor_id: int, invitation_data: dict, current_user: User
) -> Dict[str, Any]:
"""
Invite a new team member.
@@ -95,12 +96,12 @@ class TeamService:
raise ValidationException("Failed to invite team member")
def update_team_member(
- self,
- db: Session,
- vendor_id: int,
- user_id: int,
- update_data: dict,
- current_user: User
+ self,
+ db: Session,
+ vendor_id: int,
+ user_id: int,
+ update_data: dict,
+ current_user: User,
) -> Dict[str, Any]:
"""
Update team member role or status.
@@ -116,10 +117,13 @@ class TeamService:
Updated member info
"""
try:
- vendor_user = db.query(VendorUser).filter(
- VendorUser.vendor_id == vendor_id,
- VendorUser.user_id == user_id
- ).first()
+ vendor_user = (
+ db.query(VendorUser)
+ .filter(
+ VendorUser.vendor_id == vendor_id, VendorUser.user_id == user_id
+ )
+ .first()
+ )
if not vendor_user:
raise ValidationException("Team member not found")
@@ -146,7 +150,7 @@ class TeamService:
raise ValidationException("Failed to update team member")
def remove_team_member(
- self, db: Session, vendor_id: int, user_id: int, current_user: User
+ self, db: Session, vendor_id: int, user_id: int, current_user: User
) -> bool:
"""
Remove team member from vendor.
@@ -161,10 +165,13 @@ class TeamService:
True if removed
"""
try:
- vendor_user = db.query(VendorUser).filter(
- VendorUser.vendor_id == vendor_id,
- VendorUser.user_id == user_id
- ).first()
+ vendor_user = (
+ db.query(VendorUser)
+ .filter(
+ VendorUser.vendor_id == vendor_id, VendorUser.user_id == user_id
+ )
+ .first()
+ )
if not vendor_user:
raise ValidationException("Team member not found")
diff --git a/app/services/vendor_domain_service.py b/app/services/vendor_domain_service.py
index 73fd9da3..1d1776fb 100644
--- a/app/services/vendor_domain_service.py
+++ b/app/services/vendor_domain_service.py
@@ -12,30 +12,28 @@ This module provides classes and functions for:
import logging
import secrets
-from typing import List, Tuple, Optional
from datetime import datetime, timezone
+from typing import List, Optional, Tuple
-from sqlalchemy.orm import Session
from sqlalchemy import and_
+from sqlalchemy.orm import Session
-from app.exceptions import (
- VendorNotFoundException,
- VendorDomainNotFoundException,
- VendorDomainAlreadyExistsException,
- InvalidDomainFormatException,
- ReservedDomainException,
- DomainNotVerifiedException,
- DomainVerificationFailedException,
- DomainAlreadyVerifiedException,
- MultiplePrimaryDomainsException,
- DNSVerificationException,
- MaxDomainsReachedException,
- UnauthorizedDomainAccessException,
- ValidationException,
-)
-from models.schema.vendor_domain import VendorDomainCreate, VendorDomainUpdate
+from app.exceptions import (DNSVerificationException,
+ DomainAlreadyVerifiedException,
+ DomainNotVerifiedException,
+ DomainVerificationFailedException,
+ InvalidDomainFormatException,
+ MaxDomainsReachedException,
+ MultiplePrimaryDomainsException,
+ ReservedDomainException,
+ UnauthorizedDomainAccessException,
+ ValidationException,
+ VendorDomainAlreadyExistsException,
+ VendorDomainNotFoundException,
+ VendorNotFoundException)
from models.database.vendor import Vendor
from models.database.vendor_domain import VendorDomain
+from models.schema.vendor_domain import VendorDomainCreate, VendorDomainUpdate
logger = logging.getLogger(__name__)
@@ -45,13 +43,19 @@ class VendorDomainService:
def __init__(self):
self.max_domains_per_vendor = 10 # Configure as needed
- self.reserved_subdomains = ['www', 'admin', 'api', 'mail', 'smtp', 'ftp', 'cpanel', 'webmail']
+ self.reserved_subdomains = [
+ "www",
+ "admin",
+ "api",
+ "mail",
+ "smtp",
+ "ftp",
+ "cpanel",
+ "webmail",
+ ]
def add_domain(
- self,
- db: Session,
- vendor_id: int,
- domain_data: VendorDomainCreate
+ self, db: Session, vendor_id: int, domain_data: VendorDomainCreate
) -> VendorDomain:
"""
Add a custom domain to vendor.
@@ -85,12 +89,14 @@ class VendorDomainService:
# Check if domain already exists
if self._domain_exists(db, normalized_domain):
- existing_domain = db.query(VendorDomain).filter(
- VendorDomain.domain == normalized_domain
- ).first()
+ existing_domain = (
+ db.query(VendorDomain)
+ .filter(VendorDomain.domain == normalized_domain)
+ .first()
+ )
raise VendorDomainAlreadyExistsException(
normalized_domain,
- existing_domain.vendor_id if existing_domain else None
+ existing_domain.vendor_id if existing_domain else None,
)
# If setting as primary, unset other primary domains
@@ -104,8 +110,8 @@ class VendorDomainService:
is_primary=domain_data.is_primary,
verification_token=secrets.token_urlsafe(32),
is_verified=False, # Requires DNS verification
- is_active=False, # Cannot be active until verified
- ssl_status="pending"
+ is_active=False, # Cannot be active until verified
+ ssl_status="pending",
)
db.add(new_domain)
@@ -120,7 +126,7 @@ class VendorDomainService:
VendorDomainAlreadyExistsException,
MaxDomainsReachedException,
InvalidDomainFormatException,
- ReservedDomainException
+ ReservedDomainException,
):
db.rollback()
raise
@@ -129,11 +135,7 @@ class VendorDomainService:
logger.error(f"Error adding domain: {str(e)}")
raise ValidationException("Failed to add domain")
- def get_vendor_domains(
- self,
- db: Session,
- vendor_id: int
- ) -> List[VendorDomain]:
+ def get_vendor_domains(self, db: Session, vendor_id: int) -> List[VendorDomain]:
"""
Get all domains for a vendor.
@@ -151,12 +153,14 @@ class VendorDomainService:
# Verify vendor exists
self._get_vendor_by_id_or_raise(db, vendor_id)
- domains = db.query(VendorDomain).filter(
- VendorDomain.vendor_id == vendor_id
- ).order_by(
- VendorDomain.is_primary.desc(),
- VendorDomain.created_at.desc()
- ).all()
+ domains = (
+ db.query(VendorDomain)
+ .filter(VendorDomain.vendor_id == vendor_id)
+ .order_by(
+ VendorDomain.is_primary.desc(), VendorDomain.created_at.desc()
+ )
+ .all()
+ )
return domains
@@ -166,11 +170,7 @@ class VendorDomainService:
logger.error(f"Error getting vendor domains: {str(e)}")
raise ValidationException("Failed to retrieve domains")
- def get_domain_by_id(
- self,
- db: Session,
- domain_id: int
- ) -> VendorDomain:
+ def get_domain_by_id(self, db: Session, domain_id: int) -> VendorDomain:
"""
Get domain by ID.
@@ -190,10 +190,7 @@ class VendorDomainService:
return domain
def update_domain(
- self,
- db: Session,
- domain_id: int,
- domain_update: VendorDomainUpdate
+ self, db: Session, domain_id: int, domain_update: VendorDomainUpdate
) -> VendorDomain:
"""
Update domain settings.
@@ -215,7 +212,9 @@ class VendorDomainService:
# If setting as primary, unset other primary domains
if domain_update.is_primary:
- self._unset_primary_domains(db, domain.vendor_id, exclude_domain_id=domain_id)
+ self._unset_primary_domains(
+ db, domain.vendor_id, exclude_domain_id=domain_id
+ )
domain.is_primary = True
# If activating, check verification
@@ -240,11 +239,7 @@ class VendorDomainService:
logger.error(f"Error updating domain: {str(e)}")
raise ValidationException("Failed to update domain")
- def delete_domain(
- self,
- db: Session,
- domain_id: int
- ) -> str:
+ def delete_domain(self, db: Session, domain_id: int) -> str:
"""
Delete a custom domain.
@@ -277,11 +272,7 @@ class VendorDomainService:
logger.error(f"Error deleting domain: {str(e)}")
raise ValidationException("Failed to delete domain")
- def verify_domain(
- self,
- db: Session,
- domain_id: int
- ) -> Tuple[VendorDomain, str]:
+ def verify_domain(self, db: Session, domain_id: int) -> Tuple[VendorDomain, str]:
"""
Verify domain ownership via DNS TXT record.
@@ -313,8 +304,7 @@ class VendorDomainService:
# Query DNS TXT records
try:
txt_records = dns.resolver.resolve(
- f"_wizamart-verify.{domain.domain}",
- 'TXT'
+ f"_wizamart-verify.{domain.domain}", "TXT"
)
# Check if verification token is present
@@ -332,42 +322,33 @@ class VendorDomainService:
# Token not found
raise DomainVerificationFailedException(
- domain.domain,
- "Verification token not found in DNS records"
+ domain.domain, "Verification token not found in DNS records"
)
except dns.resolver.NXDOMAIN:
raise DomainVerificationFailedException(
domain.domain,
- f"DNS record _wizamart-verify.{domain.domain} not found"
+ f"DNS record _wizamart-verify.{domain.domain} not found",
)
except dns.resolver.NoAnswer:
raise DomainVerificationFailedException(
- domain.domain,
- "No TXT records found for verification"
+ domain.domain, "No TXT records found for verification"
)
except Exception as dns_error:
- raise DNSVerificationException(
- domain.domain,
- str(dns_error)
- )
+ raise DNSVerificationException(domain.domain, str(dns_error))
except (
VendorDomainNotFoundException,
DomainAlreadyVerifiedException,
DomainVerificationFailedException,
- DNSVerificationException
+ DNSVerificationException,
):
raise
except Exception as e:
logger.error(f"Error verifying domain: {str(e)}")
raise ValidationException("Failed to verify domain")
- def get_verification_instructions(
- self,
- db: Session,
- domain_id: int
- ) -> dict:
+ def get_verification_instructions(self, db: Session, domain_id: int) -> dict:
"""
Get DNS verification instructions for domain.
@@ -390,20 +371,20 @@ class VendorDomainService:
"step1": "Go to your domain's DNS settings (at your domain registrar)",
"step2": "Add a new TXT record with the following values:",
"step3": "Wait for DNS propagation (5-15 minutes)",
- "step4": "Click 'Verify Domain' button in admin panel"
+ "step4": "Click 'Verify Domain' button in admin panel",
},
"txt_record": {
"type": "TXT",
"name": "_wizamart-verify",
"value": domain.verification_token,
- "ttl": 3600
+ "ttl": 3600,
},
"common_registrars": {
"Cloudflare": "https://dash.cloudflare.com",
"GoDaddy": "https://dcc.godaddy.com/manage/dns",
"Namecheap": "https://www.namecheap.com/myaccount/domain-list/",
- "Google Domains": "https://domains.google.com"
- }
+ "Google Domains": "https://domains.google.com",
+ },
}
# Private helper methods
@@ -416,36 +397,33 @@ class VendorDomainService:
def _check_domain_limit(self, db: Session, vendor_id: int) -> None:
"""Check if vendor has reached maximum domain limit."""
- domain_count = db.query(VendorDomain).filter(
- VendorDomain.vendor_id == vendor_id
- ).count()
+ domain_count = (
+ db.query(VendorDomain).filter(VendorDomain.vendor_id == vendor_id).count()
+ )
if domain_count >= self.max_domains_per_vendor:
raise MaxDomainsReachedException(vendor_id, self.max_domains_per_vendor)
def _domain_exists(self, db: Session, domain: str) -> bool:
"""Check if domain already exists in system."""
- return db.query(VendorDomain).filter(
- VendorDomain.domain == domain
- ).first() is not None
+ return (
+ db.query(VendorDomain).filter(VendorDomain.domain == domain).first()
+ is not None
+ )
def _validate_domain_format(self, domain: str) -> None:
"""Validate domain format and check for reserved subdomains."""
# Check for reserved subdomains
- first_part = domain.split('.')[0]
+ first_part = domain.split(".")[0]
if first_part in self.reserved_subdomains:
raise ReservedDomainException(domain, first_part)
def _unset_primary_domains(
- self,
- db: Session,
- vendor_id: int,
- exclude_domain_id: Optional[int] = None
+ self, db: Session, vendor_id: int, exclude_domain_id: Optional[int] = None
) -> None:
"""Unset all primary domains for vendor."""
query = db.query(VendorDomain).filter(
- VendorDomain.vendor_id == vendor_id,
- VendorDomain.is_primary == True
+ VendorDomain.vendor_id == vendor_id, VendorDomain.is_primary == True
)
if exclude_domain_id:
diff --git a/app/services/vendor_service.py b/app/services/vendor_service.py
index 3185e56b..d686c201 100644
--- a/app/services/vendor_service.py
+++ b/app/services/vendor_service.py
@@ -15,22 +15,19 @@ from typing import List, Optional, Tuple
from sqlalchemy import func
from sqlalchemy.orm import Session
-from app.exceptions import (
- VendorNotFoundException,
- VendorAlreadyExistsException,
- UnauthorizedVendorAccessException,
- InvalidVendorDataException,
- MarketplaceProductNotFoundException,
- ProductAlreadyExistsException,
- MaxVendorsReachedException,
- ValidationException,
-)
-from models.schema.vendor import VendorCreate
-from models.schema.product import ProductCreate
+from app.exceptions import (InvalidVendorDataException,
+ MarketplaceProductNotFoundException,
+ MaxVendorsReachedException,
+ ProductAlreadyExistsException,
+ UnauthorizedVendorAccessException,
+ ValidationException, VendorAlreadyExistsException,
+ VendorNotFoundException)
from models.database.marketplace_product import MarketplaceProduct
-from models.database.vendor import Vendor
from models.database.product import Product
from models.database.user import User
+from models.database.vendor import Vendor
+from models.schema.product import ProductCreate
+from models.schema.vendor import VendorCreate
logger = logging.getLogger(__name__)
@@ -39,7 +36,7 @@ class VendorService:
"""Service class for vendor operations following the application's service pattern."""
def create_vendor(
- self, db: Session, vendor_data: VendorCreate, current_user: User
+ self, db: Session, vendor_data: VendorCreate, current_user: User
) -> Vendor:
"""
Create a new vendor.
@@ -47,7 +44,7 @@ class VendorService:
Args:
db: Database session
vendor_data: Vendor creation data
- current_user: User creating the vendor
+ current_user: User creating the vendor
Returns:
Created vendor object
@@ -91,7 +88,11 @@ class VendorService:
)
return new_vendor
- except (VendorAlreadyExistsException, MaxVendorsReachedException, InvalidVendorDataException):
+ except (
+ VendorAlreadyExistsException,
+ MaxVendorsReachedException,
+ InvalidVendorDataException,
+ ):
db.rollback()
raise # Re-raise custom exceptions
except Exception as e:
@@ -100,13 +101,13 @@ class VendorService:
raise ValidationException("Failed to create vendor ")
def get_vendors(
- self,
- db: Session,
- current_user: User,
- skip: int = 0,
- limit: int = 100,
- active_only: bool = True,
- verified_only: bool = False,
+ self,
+ db: Session,
+ current_user: User,
+ skip: int = 0,
+ limit: int = 100,
+ active_only: bool = True,
+ verified_only: bool = False,
) -> Tuple[List[Vendor], int]:
"""
Get vendors with filtering.
@@ -129,7 +130,10 @@ class VendorService:
if current_user.role != "admin":
query = query.filter(
(Vendor.is_active == True)
- & ((Vendor.is_verified == True) | (Vendor.owner_user_id == current_user.id))
+ & (
+ (Vendor.is_verified == True)
+ | (Vendor.owner_user_id == current_user.id)
+ )
)
else:
# Admin can apply filters
@@ -147,14 +151,16 @@ class VendorService:
logger.error(f"Error getting vendors: {str(e)}")
raise ValidationException("Failed to retrieve vendors")
- def get_vendor_by_code(self, db: Session, vendor_code: str, current_user: User) -> Vendor:
+ def get_vendor_by_code(
+ self, db: Session, vendor_code: str, current_user: User
+ ) -> Vendor:
"""
Get vendor by vendor code with access control.
Args:
db: Database session
vendor_code: Vendor code to find
- current_user: Current user requesting the vendor
+ current_user: Current user requesting the vendor
Returns:
Vendor object
@@ -170,14 +176,14 @@ class VendorService:
.first()
)
- if not vendor :
+ if not vendor:
raise VendorNotFoundException(vendor_code)
# Check access permissions
if not self._can_access_vendor(vendor, current_user):
raise UnauthorizedVendorAccessException(vendor_code, current_user.id)
- return vendor
+ return vendor
except (VendorNotFoundException, UnauthorizedVendorAccessException):
raise # Re-raise custom exceptions
@@ -186,7 +192,7 @@ class VendorService:
raise ValidationException("Failed to retrieve vendor ")
def add_product_to_catalog(
- self, db: Session, vendor : Vendor, product: ProductCreate
+ self, db: Session, vendor: Vendor, product: ProductCreate
) -> Product:
"""
Add existing product to vendor catalog with vendor -specific settings.
@@ -201,15 +207,19 @@ class VendorService:
Raises:
MarketplaceProductNotFoundException: If product not found
- ProductAlreadyExistsException: If product already in vendor
+ ProductAlreadyExistsException: If product already in vendor
"""
try:
# Check if product exists
- marketplace_product = self._get_product_by_id_or_raise(db, product.marketplace_product_id)
+ marketplace_product = self._get_product_by_id_or_raise(
+ db, product.marketplace_product_id
+ )
- # Check if product already in vendor
+ # Check if product already in vendor
if self._product_in_catalog(db, vendor.id, marketplace_product.id):
- raise ProductAlreadyExistsException(vendor.vendor_code, product.marketplace_product_id)
+ raise ProductAlreadyExistsException(
+ vendor.vendor_code, product.marketplace_product_id
+ )
# Create vendor -product association
new_product = Product(
@@ -225,7 +235,9 @@ class VendorService:
# Load the product relationship
db.refresh(new_product)
- logger.info(f"MarketplaceProduct {product.marketplace_product_id} added to vendor {vendor.vendor_code}")
+ logger.info(
+ f"MarketplaceProduct {product.marketplace_product_id} added to vendor {vendor.vendor_code}"
+ )
return new_product
except (MarketplaceProductNotFoundException, ProductAlreadyExistsException):
@@ -237,14 +249,14 @@ class VendorService:
raise ValidationException("Failed to add product to vendor ")
def get_products(
- self,
- db: Session,
- vendor : Vendor,
- current_user: User,
- skip: int = 0,
- limit: int = 100,
- active_only: bool = True,
- featured_only: bool = False,
+ self,
+ db: Session,
+ vendor: Vendor,
+ current_user: User,
+ skip: int = 0,
+ limit: int = 100,
+ active_only: bool = True,
+ featured_only: bool = False,
) -> Tuple[List[Product], int]:
"""
Get products in vendor catalog with filtering.
@@ -267,7 +279,9 @@ class VendorService:
try:
# Check access permissions
if not self._can_access_vendor(vendor, current_user):
- raise UnauthorizedVendorAccessException(vendor.vendor_code, current_user.id)
+ raise UnauthorizedVendorAccessException(
+ vendor.vendor_code, current_user.id
+ )
# Query vendor products
query = db.query(Product).filter(Product.vendor_id == vendor.id)
@@ -292,17 +306,20 @@ class VendorService:
def _validate_vendor_data(self, vendor_data: VendorCreate) -> None:
"""Validate vendor creation data."""
if not vendor_data.vendor_code or not vendor_data.vendor_code.strip():
- raise InvalidVendorDataException("Vendor code is required", field="vendor_code")
+ raise InvalidVendorDataException(
+ "Vendor code is required", field="vendor_code"
+ )
if not vendor_data.vendor_name or not vendor_data.vendor_name.strip():
raise InvalidVendorDataException("Vendor name is required", field="name")
# Validate vendor code format (alphanumeric, underscores, hyphens)
import re
- if not re.match(r'^[A-Za-z0-9_-]+$', vendor_data.vendor_code):
+
+ if not re.match(r"^[A-Za-z0-9_-]+$", vendor_data.vendor_code):
raise InvalidVendorDataException(
"Vendor code can only contain letters, numbers, underscores, and hyphens",
- field="vendor_code"
+ field="vendor_code",
)
def _check_vendor_limit(self, db: Session, user: User) -> None:
@@ -310,7 +327,9 @@ class VendorService:
if user.role == "admin":
return # Admins have no limit
- user_vendor_count = db.query(Vendor).filter(Vendor.owner_user_id == user.id).count()
+ user_vendor_count = (
+ db.query(Vendor).filter(Vendor.owner_user_id == user.id).count()
+ )
max_vendors = 5 # Configure this as needed
if user_vendor_count >= max_vendors:
@@ -319,30 +338,40 @@ class VendorService:
def _vendor_code_exists(self, db: Session, vendor_code: str) -> bool:
"""Check if vendor code already exists (case-insensitive)."""
return (
- db.query(Vendor)
- .filter(func.upper(Vendor.vendor_code) == vendor_code.upper())
- .first() is not None
+ db.query(Vendor)
+ .filter(func.upper(Vendor.vendor_code) == vendor_code.upper())
+ .first()
+ is not None
)
- def _get_product_by_id_or_raise(self, db: Session, marketplace_product_id: str) -> MarketplaceProduct:
+ def _get_product_by_id_or_raise(
+ self, db: Session, marketplace_product_id: str
+ ) -> MarketplaceProduct:
"""Get product by ID or raise exception."""
- product = db.query(MarketplaceProduct).filter(MarketplaceProduct.marketplace_product_id == marketplace_product_id).first()
+ product = (
+ db.query(MarketplaceProduct)
+ .filter(MarketplaceProduct.marketplace_product_id == marketplace_product_id)
+ .first()
+ )
if not product:
raise MarketplaceProductNotFoundException(marketplace_product_id)
return product
- def _product_in_catalog(self, db: Session, vendor_id: int, marketplace_product_id: int) -> bool:
+ def _product_in_catalog(
+ self, db: Session, vendor_id: int, marketplace_product_id: int
+ ) -> bool:
"""Check if product is already in vendor."""
return (
- db.query(Product)
- .filter(
- Product.vendor_id == vendor_id,
- Product.marketplace_product_id == marketplace_product_id
- )
- .first() is not None
+ db.query(Product)
+ .filter(
+ Product.vendor_id == vendor_id,
+ Product.marketplace_product_id == marketplace_product_id,
+ )
+ .first()
+ is not None
)
- def _can_access_vendor(self, vendor : Vendor, user: User) -> bool:
+ def _can_access_vendor(self, vendor: Vendor, user: User) -> bool:
"""Check if user can access vendor."""
# Admins and owners can always access
if user.role == "admin" or vendor.owner_user_id == user.id:
@@ -351,9 +380,10 @@ class VendorService:
# Others can only access active and verified vendors
return vendor.is_active and vendor.is_verified
- def _is_vendor_owner(self, vendor : Vendor, user: User) -> bool:
+ def _is_vendor_owner(self, vendor: Vendor, user: User) -> bool:
"""Check if user is vendor owner."""
return vendor.owner_user_id == user.id
+
# Create service instance following the same pattern as other services
vendor_service = VendorService()
diff --git a/app/services/vendor_team_service.py b/app/services/vendor_team_service.py
index 763ea8f1..2ed95c29 100644
--- a/app/services/vendor_team_service.py
+++ b/app/services/vendor_team_service.py
@@ -11,23 +11,20 @@ Handles:
import logging
import secrets
from datetime import datetime, timedelta
-from typing import Optional, List, Dict, Any
+from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session
from app.core.permissions import get_preset_permissions
-from app.exceptions import (
- TeamMemberAlreadyExistsException,
- InvalidInvitationTokenException,
- TeamInvitationAlreadyAcceptedException,
- MaxTeamMembersReachedException,
- UserNotFoundException,
- VendorNotFoundException,
- CannotRemoveOwnerException,
-)
-from models.database.user import User
-from models.database.vendor import Vendor, VendorUser, VendorUserType, Role
+from app.exceptions import (CannotRemoveOwnerException,
+ InvalidInvitationTokenException,
+ MaxTeamMembersReachedException,
+ TeamInvitationAlreadyAcceptedException,
+ TeamMemberAlreadyExistsException,
+ UserNotFoundException, VendorNotFoundException)
from middleware.auth import AuthManager
+from models.database.user import User
+from models.database.vendor import Role, Vendor, VendorUser, VendorUserType
logger = logging.getLogger(__name__)
@@ -40,13 +37,13 @@ class VendorTeamService:
self.max_team_members = 50 # Configure as needed
def invite_team_member(
- self,
- db: Session,
- vendor: Vendor,
- inviter: User,
- email: str,
- role_name: str,
- custom_permissions: Optional[List[str]] = None,
+ self,
+ db: Session,
+ vendor: Vendor,
+ inviter: User,
+ email: str,
+ role_name: str,
+ custom_permissions: Optional[List[str]] = None,
) -> Dict[str, Any]:
"""
Invite a new team member to a vendor.
@@ -69,10 +66,14 @@ class VendorTeamService:
"""
try:
# Check team size limit
- current_team_size = db.query(VendorUser).filter(
- VendorUser.vendor_id == vendor.id,
- VendorUser.is_active == True,
- ).count()
+ current_team_size = (
+ db.query(VendorUser)
+ .filter(
+ VendorUser.vendor_id == vendor.id,
+ VendorUser.is_active == True,
+ )
+ .count()
+ )
if current_team_size >= self.max_team_members:
raise MaxTeamMembersReachedException(
@@ -85,22 +86,34 @@ class VendorTeamService:
if user:
# Check if already a member
- existing_membership = db.query(VendorUser).filter(
- VendorUser.vendor_id == vendor.id,
- VendorUser.user_id == user.id,
- ).first()
+ existing_membership = (
+ db.query(VendorUser)
+ .filter(
+ VendorUser.vendor_id == vendor.id,
+ VendorUser.user_id == user.id,
+ )
+ .first()
+ )
if existing_membership:
if existing_membership.is_active:
- raise TeamMemberAlreadyExistsException(email, vendor.vendor_code)
+ raise TeamMemberAlreadyExistsException(
+ email, vendor.vendor_code
+ )
# Reactivate old membership
- existing_membership.is_active = False # Will be activated on acceptance
- existing_membership.invitation_token = self._generate_invitation_token()
+ existing_membership.is_active = (
+ False # Will be activated on acceptance
+ )
+ existing_membership.invitation_token = (
+ self._generate_invitation_token()
+ )
existing_membership.invitation_sent_at = datetime.utcnow()
existing_membership.invitation_accepted_at = None
db.commit()
- logger.info(f"Re-invited user {email} to vendor {vendor.vendor_code}")
+ logger.info(
+ f"Re-invited user {email} to vendor {vendor.vendor_code}"
+ )
return {
"invitation_token": existing_membership.invitation_token,
"email": email,
@@ -108,7 +121,7 @@ class VendorTeamService:
}
else:
# Create new user account (inactive until invitation accepted)
- username = email.split('@')[0]
+ username = email.split("@")[0]
# Ensure unique username
base_username = username
counter = 1
@@ -179,12 +192,12 @@ class VendorTeamService:
raise
def accept_invitation(
- self,
- db: Session,
- invitation_token: str,
- password: str,
- first_name: Optional[str] = None,
- last_name: Optional[str] = None,
+ self,
+ db: Session,
+ invitation_token: str,
+ password: str,
+ first_name: Optional[str] = None,
+ last_name: Optional[str] = None,
) -> Dict[str, Any]:
"""
Accept a team invitation and activate account.
@@ -201,9 +214,13 @@ class VendorTeamService:
"""
try:
# Find invitation
- vendor_user = db.query(VendorUser).filter(
- VendorUser.invitation_token == invitation_token,
- ).first()
+ vendor_user = (
+ db.query(VendorUser)
+ .filter(
+ VendorUser.invitation_token == invitation_token,
+ )
+ .first()
+ )
if not vendor_user:
raise InvalidInvitationTokenException()
@@ -247,7 +264,10 @@ class VendorTeamService:
"role": vendor_user.role.name if vendor_user.role else "member",
}
- except (InvalidInvitationTokenException, TeamInvitationAlreadyAcceptedException):
+ except (
+ InvalidInvitationTokenException,
+ TeamInvitationAlreadyAcceptedException,
+ ):
raise
except Exception as e:
db.rollback()
@@ -255,10 +275,10 @@ class VendorTeamService:
raise
def remove_team_member(
- self,
- db: Session,
- vendor: Vendor,
- user_id: int,
+ self,
+ db: Session,
+ vendor: Vendor,
+ user_id: int,
) -> bool:
"""
Remove a team member from a vendor.
@@ -274,10 +294,14 @@ class VendorTeamService:
True if removed
"""
try:
- vendor_user = db.query(VendorUser).filter(
- VendorUser.vendor_id == vendor.id,
- VendorUser.user_id == user_id,
- ).first()
+ vendor_user = (
+ db.query(VendorUser)
+ .filter(
+ VendorUser.vendor_id == vendor.id,
+ VendorUser.user_id == user_id,
+ )
+ .first()
+ )
if not vendor_user:
raise UserNotFoundException(str(user_id))
@@ -301,12 +325,12 @@ class VendorTeamService:
raise
def update_member_role(
- self,
- db: Session,
- vendor: Vendor,
- user_id: int,
- new_role_name: str,
- custom_permissions: Optional[List[str]] = None,
+ self,
+ db: Session,
+ vendor: Vendor,
+ user_id: int,
+ new_role_name: str,
+ custom_permissions: Optional[List[str]] = None,
) -> VendorUser:
"""
Update a team member's role.
@@ -322,10 +346,14 @@ class VendorTeamService:
Updated VendorUser
"""
try:
- vendor_user = db.query(VendorUser).filter(
- VendorUser.vendor_id == vendor.id,
- VendorUser.user_id == user_id,
- ).first()
+ vendor_user = (
+ db.query(VendorUser)
+ .filter(
+ VendorUser.vendor_id == vendor.id,
+ VendorUser.user_id == user_id,
+ )
+ .first()
+ )
if not vendor_user:
raise UserNotFoundException(str(user_id))
@@ -360,10 +388,10 @@ class VendorTeamService:
raise
def get_team_members(
- self,
- db: Session,
- vendor: Vendor,
- include_inactive: bool = False,
+ self,
+ db: Session,
+ vendor: Vendor,
+ include_inactive: bool = False,
) -> List[Dict[str, Any]]:
"""
Get all team members for a vendor.
@@ -387,20 +415,22 @@ class VendorTeamService:
members = []
for vu in vendor_users:
- members.append({
- "id": vu.user.id,
- "email": vu.user.email,
- "username": vu.user.username,
- "full_name": vu.user.full_name,
- "user_type": vu.user_type,
- "role": vu.role.name if vu.role else "owner",
- "permissions": vu.get_all_permissions(),
- "is_active": vu.is_active,
- "is_owner": vu.is_owner,
- "invitation_pending": vu.is_invitation_pending,
- "invited_at": vu.invitation_sent_at,
- "accepted_at": vu.invitation_accepted_at,
- })
+ members.append(
+ {
+ "id": vu.user.id,
+ "email": vu.user.email,
+ "username": vu.user.username,
+ "full_name": vu.user.full_name,
+ "user_type": vu.user_type,
+ "role": vu.role.name if vu.role else "owner",
+ "permissions": vu.get_all_permissions(),
+ "is_active": vu.is_active,
+ "is_owner": vu.is_owner,
+ "invitation_pending": vu.is_invitation_pending,
+ "invited_at": vu.invitation_sent_at,
+ "accepted_at": vu.invitation_accepted_at,
+ }
+ )
return members
@@ -411,18 +441,22 @@ class VendorTeamService:
return secrets.token_urlsafe(32)
def _get_or_create_role(
- self,
- db: Session,
- vendor: Vendor,
- role_name: str,
- custom_permissions: Optional[List[str]] = None,
+ self,
+ db: Session,
+ vendor: Vendor,
+ role_name: str,
+ custom_permissions: Optional[List[str]] = None,
) -> Role:
"""Get existing role or create new one with preset/custom permissions."""
# Try to find existing role with same name
- role = db.query(Role).filter(
- Role.vendor_id == vendor.id,
- Role.name == role_name,
- ).first()
+ role = (
+ db.query(Role)
+ .filter(
+ Role.vendor_id == vendor.id,
+ Role.name == role_name,
+ )
+ .first()
+ )
if role and custom_permissions is None:
# Use existing role
diff --git a/app/services/vendor_theme_service.py b/app/services/vendor_theme_service.py
index 067f022b..46cd3e3b 100644
--- a/app/services/vendor_theme_service.py
+++ b/app/services/vendor_theme_service.py
@@ -8,32 +8,24 @@ Handles theme CRUD operations, preset application, and validation.
import logging
import re
-from typing import Optional, Dict, List
+from typing import Dict, List, Optional
+
from sqlalchemy.orm import Session
+from app.core.theme_presets import (THEME_PRESETS, apply_preset,
+ get_available_presets, get_preset_preview)
+from app.exceptions.vendor import VendorNotFoundException
+from app.exceptions.vendor_theme import (InvalidColorFormatException,
+ InvalidFontFamilyException,
+ InvalidThemeDataException,
+ ThemeOperationException,
+ ThemePresetAlreadyAppliedException,
+ ThemePresetNotFoundException,
+ ThemeValidationException,
+ VendorThemeNotFoundException)
from models.database.vendor import Vendor
from models.database.vendor_theme import VendorTheme
-from models.schema.vendor_theme import (
- VendorThemeUpdate,
- ThemePresetPreview
-)
-from app.exceptions.vendor import VendorNotFoundException
-from app.exceptions.vendor_theme import (
- VendorThemeNotFoundException,
- InvalidThemeDataException,
- ThemePresetNotFoundException,
- ThemeValidationException,
- InvalidColorFormatException,
- InvalidFontFamilyException,
- ThemePresetAlreadyAppliedException,
- ThemeOperationException
-)
-from app.core.theme_presets import (
- apply_preset,
- get_available_presets,
- get_preset_preview,
- THEME_PRESETS
-)
+from models.schema.vendor_theme import ThemePresetPreview, VendorThemeUpdate
logger = logging.getLogger(__name__)
@@ -71,9 +63,9 @@ class VendorThemeService:
Raises:
VendorNotFoundException: If vendor not found
"""
- vendor = db.query(Vendor).filter(
- Vendor.vendor_code == vendor_code.upper()
- ).first()
+ vendor = (
+ db.query(Vendor).filter(Vendor.vendor_code == vendor_code.upper()).first()
+ )
if not vendor:
self.logger.warning(f"Vendor not found: {vendor_code}")
@@ -105,12 +97,12 @@ class VendorThemeService:
vendor = self._get_vendor_by_code(db, vendor_code)
# Get theme
- theme = db.query(VendorTheme).filter(
- VendorTheme.vendor_id == vendor.id
- ).first()
+ theme = db.query(VendorTheme).filter(VendorTheme.vendor_id == vendor.id).first()
if not theme:
- self.logger.info(f"No custom theme for vendor {vendor_code}, returning default")
+ self.logger.info(
+ f"No custom theme for vendor {vendor_code}, returning default"
+ )
return self._get_default_theme()
return theme.to_dict()
@@ -130,23 +122,16 @@ class VendorThemeService:
"accent": "#ec4899",
"background": "#ffffff",
"text": "#1f2937",
- "border": "#e5e7eb"
- },
- "fonts": {
- "heading": "Inter, sans-serif",
- "body": "Inter, sans-serif"
+ "border": "#e5e7eb",
},
+ "fonts": {"heading": "Inter, sans-serif", "body": "Inter, sans-serif"},
"branding": {
"logo": None,
"logo_dark": None,
"favicon": None,
- "banner": None
- },
- "layout": {
- "style": "grid",
- "header": "fixed",
- "product_card": "modern"
+ "banner": None,
},
+ "layout": {"style": "grid", "header": "fixed", "product_card": "modern"},
"social_links": {},
"custom_css": None,
"css_variables": {
@@ -158,7 +143,7 @@ class VendorThemeService:
"--color-border": "#e5e7eb",
"--font-heading": "Inter, sans-serif",
"--font-body": "Inter, sans-serif",
- }
+ },
}
# ============================================================================
@@ -166,10 +151,7 @@ class VendorThemeService:
# ============================================================================
def update_theme(
- self,
- db: Session,
- vendor_code: str,
- theme_data: VendorThemeUpdate
+ self, db: Session, vendor_code: str, theme_data: VendorThemeUpdate
) -> VendorTheme:
"""
Update or create theme for vendor.
@@ -194,9 +176,9 @@ class VendorThemeService:
vendor = self._get_vendor_by_code(db, vendor_code)
# Get or create theme
- theme = db.query(VendorTheme).filter(
- VendorTheme.vendor_id == vendor.id
- ).first()
+ theme = (
+ db.query(VendorTheme).filter(VendorTheme.vendor_id == vendor.id).first()
+ )
if not theme:
self.logger.info(f"Creating new theme for vendor {vendor_code}")
@@ -224,15 +206,11 @@ class VendorThemeService:
db.rollback()
self.logger.error(f"Failed to update theme for vendor {vendor_code}: {e}")
raise ThemeOperationException(
- operation="update",
- vendor_code=vendor_code,
- reason=str(e)
+ operation="update", vendor_code=vendor_code, reason=str(e)
)
def _apply_theme_updates(
- self,
- theme: VendorTheme,
- theme_data: VendorThemeUpdate
+ self, theme: VendorTheme, theme_data: VendorThemeUpdate
) -> None:
"""
Apply theme updates to theme object.
@@ -251,30 +229,30 @@ class VendorThemeService:
# Update fonts
if theme_data.fonts:
- if theme_data.fonts.get('heading'):
- theme.font_family_heading = theme_data.fonts['heading']
- if theme_data.fonts.get('body'):
- theme.font_family_body = theme_data.fonts['body']
+ if theme_data.fonts.get("heading"):
+ theme.font_family_heading = theme_data.fonts["heading"]
+ if theme_data.fonts.get("body"):
+ theme.font_family_body = theme_data.fonts["body"]
# Update branding
if theme_data.branding:
- if theme_data.branding.get('logo') is not None:
- theme.logo_url = theme_data.branding['logo']
- if theme_data.branding.get('logo_dark') is not None:
- theme.logo_dark_url = theme_data.branding['logo_dark']
- if theme_data.branding.get('favicon') is not None:
- theme.favicon_url = theme_data.branding['favicon']
- if theme_data.branding.get('banner') is not None:
- theme.banner_url = theme_data.branding['banner']
+ if theme_data.branding.get("logo") is not None:
+ theme.logo_url = theme_data.branding["logo"]
+ if theme_data.branding.get("logo_dark") is not None:
+ theme.logo_dark_url = theme_data.branding["logo_dark"]
+ if theme_data.branding.get("favicon") is not None:
+ theme.favicon_url = theme_data.branding["favicon"]
+ if theme_data.branding.get("banner") is not None:
+ theme.banner_url = theme_data.branding["banner"]
# Update layout
if theme_data.layout:
- if theme_data.layout.get('style'):
- theme.layout_style = theme_data.layout['style']
- if theme_data.layout.get('header'):
- theme.header_style = theme_data.layout['header']
- if theme_data.layout.get('product_card'):
- theme.product_card_style = theme_data.layout['product_card']
+ if theme_data.layout.get("style"):
+ theme.layout_style = theme_data.layout["style"]
+ if theme_data.layout.get("header"):
+ theme.header_style = theme_data.layout["header"]
+ if theme_data.layout.get("product_card"):
+ theme.product_card_style = theme_data.layout["product_card"]
# Update custom CSS
if theme_data.custom_css is not None:
@@ -289,10 +267,7 @@ class VendorThemeService:
# ============================================================================
def apply_theme_preset(
- self,
- db: Session,
- vendor_code: str,
- preset_name: str
+ self, db: Session, vendor_code: str, preset_name: str
) -> VendorTheme:
"""
Apply a theme preset to vendor.
@@ -322,9 +297,9 @@ class VendorThemeService:
vendor = self._get_vendor_by_code(db, vendor_code)
# Get or create theme
- theme = db.query(VendorTheme).filter(
- VendorTheme.vendor_id == vendor.id
- ).first()
+ theme = (
+ db.query(VendorTheme).filter(VendorTheme.vendor_id == vendor.id).first()
+ )
if not theme:
self.logger.info(f"Creating new theme for vendor {vendor_code}")
@@ -338,7 +313,9 @@ class VendorThemeService:
db.commit()
db.refresh(theme)
- self.logger.info(f"Preset '{preset_name}' applied successfully to vendor {vendor_code}")
+ self.logger.info(
+ f"Preset '{preset_name}' applied successfully to vendor {vendor_code}"
+ )
return theme
except (VendorNotFoundException, ThemePresetNotFoundException):
@@ -349,9 +326,7 @@ class VendorThemeService:
db.rollback()
self.logger.error(f"Failed to apply preset to vendor {vendor_code}: {e}")
raise ThemeOperationException(
- operation="apply_preset",
- vendor_code=vendor_code,
- reason=str(e)
+ operation="apply_preset", vendor_code=vendor_code, reason=str(e)
)
def get_available_presets(self) -> List[ThemePresetPreview]:
@@ -399,9 +374,9 @@ class VendorThemeService:
vendor = self._get_vendor_by_code(db, vendor_code)
# Get theme
- theme = db.query(VendorTheme).filter(
- VendorTheme.vendor_id == vendor.id
- ).first()
+ theme = (
+ db.query(VendorTheme).filter(VendorTheme.vendor_id == vendor.id).first()
+ )
if not theme:
raise VendorThemeNotFoundException(vendor_code)
@@ -423,9 +398,7 @@ class VendorThemeService:
db.rollback()
self.logger.error(f"Failed to delete theme for vendor {vendor_code}: {e}")
raise ThemeOperationException(
- operation="delete",
- vendor_code=vendor_code,
- reason=str(e)
+ operation="delete", vendor_code=vendor_code, reason=str(e)
)
# ============================================================================
@@ -459,9 +432,9 @@ class VendorThemeService:
# Validate layout values
if theme_data.layout:
valid_layouts = {
- 'style': ['grid', 'list', 'masonry'],
- 'header': ['fixed', 'static', 'transparent'],
- 'product_card': ['modern', 'classic', 'minimal']
+ "style": ["grid", "list", "masonry"],
+ "header": ["fixed", "static", "transparent"],
+ "product_card": ["modern", "classic", "minimal"],
}
for layout_key, layout_value in theme_data.layout.items():
@@ -472,7 +445,7 @@ class VendorThemeService:
field=layout_key,
validation_errors={
layout_key: f"Must be one of: {', '.join(valid_layouts[layout_key])}"
- }
+ },
)
def _is_valid_color(self, color: str) -> bool:
@@ -489,7 +462,7 @@ class VendorThemeService:
return False
# Check for hex color format (#RGB or #RRGGBB)
- hex_pattern = r'^#([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$'
+ hex_pattern = r"^#([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$"
return bool(re.match(hex_pattern, color))
def _is_valid_font(self, font: str) -> bool:
diff --git a/app/tasks/background_tasks.py b/app/tasks/background_tasks.py
index 7f1ae503..87ca42f4 100644
--- a/app/tasks/background_tasks.py
+++ b/app/tasks/background_tasks.py
@@ -3,9 +3,9 @@ import logging
from datetime import datetime, timezone
from app.core.database import SessionLocal
+from app.utils.csv_processor import CSVProcessor
from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.vendor import Vendor
-from app.utils.csv_processor import CSVProcessor
logger = logging.getLogger(__name__)
@@ -15,7 +15,7 @@ async def process_marketplace_import(
url: str,
marketplace: str,
vendor_id: int, # FIXED: Changed from vendor_name to vendor_id
- batch_size: int = 1000
+ batch_size: int = 1000,
):
"""Background task to process marketplace CSV import."""
db = SessionLocal()
@@ -59,7 +59,7 @@ async def process_marketplace_import(
marketplace,
vendor_id, # FIXED: Pass vendor_id instead of vendor_name
batch_size,
- db
+ db,
)
# Update job with results
diff --git a/app/utils/csv_processor.py b/app/utils/csv_processor.py
index 7bfd7300..6fb9b148 100644
--- a/app/utils/csv_processor.py
+++ b/app/utils/csv_processor.py
@@ -195,7 +195,7 @@ class CSVProcessor:
Args:
url: URL to the CSV file
marketplace: Name of the marketplace (e.g., 'Letzshop', 'Amazon')
- vendor_name: Name of the vendor
+ vendor_name: Name of the vendor
batch_size: Number of rows to process in each batch
db: Database session
@@ -267,7 +267,9 @@ class CSVProcessor:
# Validate required fields
if not product_data.get("marketplace_product_id"):
- logger.warning(f"Row {index}: Missing marketplace_product_id, skipping")
+ logger.warning(
+ f"Row {index}: Missing marketplace_product_id, skipping"
+ )
errors += 1
continue
@@ -279,7 +281,10 @@ class CSVProcessor:
# Check if product exists
existing_product = (
db.query(MarketplaceProduct)
- .filter(MarketplaceProduct.marketplace_product_id == literal(product_data["marketplace_product_id"]))
+ .filter(
+ MarketplaceProduct.marketplace_product_id
+ == literal(product_data["marketplace_product_id"])
+ )
.first()
)
diff --git a/app/utils/data_processing.py b/app/utils/data_processing.py
index fe44665b..85168a39 100644
--- a/app/utils/data_processing.py
+++ b/app/utils/data_processing.py
@@ -109,7 +109,9 @@ class PriceProcessor:
r"([A-Z]{3})\s*([0-9.,]+)": lambda m: (m.group(2), m.group(1)),
}
- def parse_price_currency(self, price_str: any) -> Tuple[Optional[str], Optional[str]]:
+ def parse_price_currency(
+ self, price_str: any
+ ) -> Tuple[Optional[str], Optional[str]]:
"""
Parse a price string to extract the numeric value and currency.
diff --git a/app/utils/database.py b/app/utils/database.py
index 29c75b22..72fa3e58 100644
--- a/app/utils/database.py
+++ b/app/utils/database.py
@@ -7,6 +7,7 @@ This module provides utility functions and classes to interact with a database u
"""
import logging
+
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import QueuePool
diff --git a/main.py b/main.py
index 4f306a30..b1af97ab 100644
--- a/main.py
+++ b/main.py
@@ -9,13 +9,13 @@ Multi-tenant e-commerce marketplace platform with:
- Middleware stack for context injection
"""
-import sys
import io
+import sys
# Fix Windows console encoding issues (must be at the very top)
-if sys.platform == 'win32':
- sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
- sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
+if sys.platform == "win32":
+ sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")
+ sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8")
import logging
from datetime import datetime, timezone
@@ -23,27 +23,25 @@ from pathlib import Path
from fastapi import Depends, FastAPI, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import HTMLResponse, RedirectResponse, FileResponse
+from fastapi.responses import FileResponse, HTMLResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from sqlalchemy import text
from sqlalchemy.orm import Session
from app.api.main import api_router
-
-# Import page routers
-from app.routes import admin_pages, vendor_pages, shop_pages
from app.core.config import settings
from app.core.database import get_db
from app.core.lifespan import lifespan
-from app.exceptions.handler import setup_exception_handlers
from app.exceptions import ServiceUnavailableException
-
+from app.exceptions.handler import setup_exception_handlers
+# Import page routers
+from app.routes import admin_pages, shop_pages, vendor_pages
+from middleware.context import ContextMiddleware
+from middleware.logging import LoggingMiddleware
+from middleware.theme_context import ThemeContextMiddleware
# Import REFACTORED class-based middleware
from middleware.vendor_context import VendorContextMiddleware
-from middleware.context import ContextMiddleware
-from middleware.theme_context import ThemeContextMiddleware
-from middleware.logging import LoggingMiddleware
logger = logging.getLogger(__name__)
@@ -146,6 +144,7 @@ app.include_router(api_router, prefix="/api")
# FAVICON ROUTES (Must be registered BEFORE page routers)
# ============================================================================
+
def serve_favicon() -> Response:
"""
Serve favicon with caching headers.
@@ -164,7 +163,7 @@ def serve_favicon() -> Response:
media_type="image/x-icon",
headers={
"Cache-Control": "public, max-age=86400", # Cache for 1 day
- }
+ },
)
return Response(status_code=204)
@@ -194,10 +193,7 @@ logger.info("=" * 80)
# Admin pages
logger.info("Registering admin page routes: /admin/*")
app.include_router(
- admin_pages.router,
- prefix="/admin",
- tags=["admin-pages"],
- include_in_schema=False
+ admin_pages.router, prefix="/admin", tags=["admin-pages"], include_in_schema=False
)
# Vendor management pages (dashboard, products, orders, etc.)
@@ -206,7 +202,7 @@ app.include_router(
vendor_pages.router,
prefix="/vendor",
tags=["vendor-pages"],
- include_in_schema=False
+ include_in_schema=False,
)
# Customer shop pages - Register at TWO prefixes:
@@ -217,46 +213,42 @@ logger.info(" - /shop/* (subdomain/custom domain mode)")
logger.info(" - /vendors/{code}/shop/* (path-based development mode)")
app.include_router(
- shop_pages.router,
- prefix="/shop",
- tags=["shop-pages"],
- include_in_schema=False
+ shop_pages.router, prefix="/shop", tags=["shop-pages"], include_in_schema=False
)
app.include_router(
shop_pages.router,
prefix="/vendors/{vendor_code}/shop",
tags=["shop-pages"],
- include_in_schema=False
+ include_in_schema=False,
)
+
# Add handler for /vendors/{vendor_code}/ root path
-@app.get("/vendors/{vendor_code}/", response_class=HTMLResponse, include_in_schema=False)
-async def vendor_root_path(vendor_code: str, request: Request, db: Session = Depends(get_db)):
+@app.get(
+ "/vendors/{vendor_code}/", response_class=HTMLResponse, include_in_schema=False
+)
+async def vendor_root_path(
+ vendor_code: str, request: Request, db: Session = Depends(get_db)
+):
"""Handle vendor root path (e.g., /vendors/wizamart/)"""
# Vendor should already be in request.state from middleware
- vendor = getattr(request.state, 'vendor', None)
+ vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(status_code=404, detail=f"Vendor '{vendor_code}' not found")
- from app.services.content_page_service import content_page_service
from app.routes.shop_pages import get_shop_context
+ from app.services.content_page_service import content_page_service
# Try to find landing page
landing_page = content_page_service.get_page_for_vendor(
- db,
- slug="landing",
- vendor_id=vendor.id,
- include_unpublished=False
+ db, slug="landing", vendor_id=vendor.id, include_unpublished=False
)
if not landing_page:
landing_page = content_page_service.get_page_for_vendor(
- db,
- slug="home",
- vendor_id=vendor.id,
- include_unpublished=False
+ db, slug="home", vendor_id=vendor.id, include_unpublished=False
)
if landing_page:
@@ -265,8 +257,7 @@ async def vendor_root_path(vendor_code: str, request: Request, db: Session = Dep
template_path = f"vendor/landing-{template_name}.html"
return templates.TemplateResponse(
- template_path,
- get_shop_context(request, db=db, page=landing_page)
+ template_path, get_shop_context(request, db=db, page=landing_page)
)
else:
# No landing page - redirect to shop
@@ -298,22 +289,16 @@ async def platform_homepage(request: Request, db: Session = Depends(get_db)):
db,
slug="platform_homepage",
vendor_id=None, # Platform-level page
- include_unpublished=False
+ include_unpublished=False,
)
# Load header and footer navigation
header_pages = content_page_service.list_pages_for_vendor(
- db,
- vendor_id=None,
- header_only=True,
- include_unpublished=False
+ db, vendor_id=None, header_only=True, include_unpublished=False
)
footer_pages = content_page_service.list_pages_for_vendor(
- db,
- vendor_id=None,
- footer_only=True,
- include_unpublished=False
+ db, vendor_id=None, footer_only=True, include_unpublished=False
)
if homepage:
@@ -330,7 +315,7 @@ async def platform_homepage(request: Request, db: Session = Depends(get_db)):
"page": homepage,
"header_pages": header_pages,
"footer_pages": footer_pages,
- }
+ },
)
else:
# Fallback to default static template
@@ -342,15 +327,13 @@ async def platform_homepage(request: Request, db: Session = Depends(get_db)):
"request": request,
"header_pages": header_pages,
"footer_pages": footer_pages,
- }
+ },
)
@app.get("/{slug}", response_class=HTMLResponse, include_in_schema=False)
async def platform_content_page(
- request: Request,
- slug: str,
- db: Session = Depends(get_db)
+ request: Request, slug: str, db: Session = Depends(get_db)
):
"""
Platform content pages: /about, /faq, /terms, /contact, etc.
@@ -366,10 +349,7 @@ async def platform_content_page(
# Load page from CMS
page = content_page_service.get_page_for_vendor(
- db,
- slug=slug,
- vendor_id=None, # Platform pages only
- include_unpublished=False
+ db, slug=slug, vendor_id=None, include_unpublished=False # Platform pages only
)
if not page:
@@ -378,17 +358,11 @@ async def platform_content_page(
# Load header and footer navigation
header_pages = content_page_service.list_pages_for_vendor(
- db,
- vendor_id=None,
- header_only=True,
- include_unpublished=False
+ db, vendor_id=None, header_only=True, include_unpublished=False
)
footer_pages = content_page_service.list_pages_for_vendor(
- db,
- vendor_id=None,
- footer_only=True,
- include_unpublished=False
+ db, vendor_id=None, footer_only=True, include_unpublished=False
)
logger.info(f"[PLATFORM] Rendering content page: {page.title} (/{slug})")
@@ -400,7 +374,7 @@ async def platform_content_page(
"page": page,
"header_pages": header_pages,
"footer_pages": footer_pages,
- }
+ },
)
@@ -411,8 +385,8 @@ logger.info("=" * 80)
logger.info("REGISTERED ROUTES SUMMARY")
logger.info("=" * 80)
for route in app.routes:
- if hasattr(route, 'methods') and hasattr(route, 'path'):
- methods = ', '.join(route.methods) if route.methods else 'N/A'
+ if hasattr(route, "methods") and hasattr(route, "path"):
+ methods = ", ".join(route.methods) if route.methods else "N/A"
logger.info(f" {methods:<10} {route.path:<60}")
logger.info("=" * 80)
@@ -420,6 +394,7 @@ logger.info("=" * 80)
# API ROUTES (JSON Responses)
# ============================================================================
+
# Public Routes (no authentication required)
@app.get("/", response_class=HTMLResponse, include_in_schema=False)
async def root(request: Request, db: Session = Depends(get_db)):
@@ -428,7 +403,7 @@ async def root(request: Request, db: Session = Depends(get_db)):
- If vendor detected (domain/subdomain): Show vendor landing page or redirect to shop
- If no vendor (platform root): Redirect to documentation
"""
- vendor = getattr(request.state, 'vendor', None)
+ vendor = getattr(request.state, "vendor", None)
if vendor:
# Vendor context detected - serve landing page
@@ -436,19 +411,13 @@ async def root(request: Request, db: Session = Depends(get_db)):
# Try to find landing page (slug='landing' or 'home')
landing_page = content_page_service.get_page_for_vendor(
- db,
- slug="landing",
- vendor_id=vendor.id,
- include_unpublished=False
+ db, slug="landing", vendor_id=vendor.id, include_unpublished=False
)
if not landing_page:
# Try 'home' slug as fallback
landing_page = content_page_service.get_page_for_vendor(
- db,
- slug="home",
- vendor_id=vendor.id,
- include_unpublished=False
+ db, slug="home", vendor_id=vendor.id, include_unpublished=False
)
if landing_page:
@@ -459,17 +428,26 @@ async def root(request: Request, db: Session = Depends(get_db)):
template_path = f"vendor/landing-{template_name}.html"
return templates.TemplateResponse(
- template_path,
- get_shop_context(request, db=db, page=landing_page)
+ template_path, get_shop_context(request, db=db, page=landing_page)
)
else:
# No landing page - redirect to shop
- vendor_context = getattr(request.state, 'vendor_context', None)
- access_method = vendor_context.get('detection_method', 'unknown') if vendor_context else 'unknown'
+ vendor_context = getattr(request.state, "vendor_context", None)
+ access_method = (
+ vendor_context.get("detection_method", "unknown")
+ if vendor_context
+ else "unknown"
+ )
if access_method == "path":
- full_prefix = vendor_context.get('full_prefix', '/vendor/') if vendor_context else '/vendor/'
- return RedirectResponse(url=f"{full_prefix}{vendor.subdomain}/shop/", status_code=302)
+ full_prefix = (
+ vendor_context.get("full_prefix", "/vendor/")
+ if vendor_context
+ else "/vendor/"
+ )
+ return RedirectResponse(
+ url=f"{full_prefix}{vendor.subdomain}/shop/", status_code=302
+ )
else:
# Domain/subdomain
return RedirectResponse(url="/shop/", status_code=302)
diff --git a/middleware/auth.py b/middleware/auth.py
index 81110a33..d37207ce 100644
--- a/middleware/auth.py
+++ b/middleware/auth.py
@@ -26,14 +26,10 @@ from jose import jwt
from passlib.context import CryptContext
from sqlalchemy.orm import Session
-from app.exceptions import (
- AdminRequiredException,
- InvalidTokenException,
- TokenExpiredException,
- UserNotActiveException,
- InvalidCredentialsException,
- InsufficientPermissionsException
-)
+from app.exceptions import (AdminRequiredException,
+ InsufficientPermissionsException,
+ InvalidCredentialsException, InvalidTokenException,
+ TokenExpiredException, UserNotActiveException)
from models.database.user import User
logger = logging.getLogger(__name__)
@@ -99,7 +95,9 @@ class AuthManager:
"""
return pwd_context.verify(plain_password, hashed_password)
- def authenticate_user(self, db: Session, username: str, password: str) -> Optional[User]:
+ def authenticate_user(
+ self, db: Session, username: str, password: str
+ ) -> Optional[User]:
"""Authenticate user credentials against the database.
Supports authentication using either username or email address.
@@ -201,7 +199,9 @@ class AuthManager:
raise InvalidTokenException("Token missing expiration")
# Check if token has expired (additional check beyond jwt.decode)
- if datetime.now(timezone.utc) > datetime.fromtimestamp(exp, tz=timezone.utc):
+ if datetime.now(timezone.utc) > datetime.fromtimestamp(
+ exp, tz=timezone.utc
+ ):
raise TokenExpiredException()
# Validate user identifier claim exists
@@ -214,7 +214,9 @@ class AuthManager:
"user_id": int(user_id),
"username": payload.get("username"),
"email": payload.get("email"),
- "role": payload.get("role", "user"), # Default to "user" role if not specified
+ "role": payload.get(
+ "role", "user"
+ ), # Default to "user" role if not specified
}
except jwt.ExpiredSignatureError:
@@ -232,7 +234,9 @@ class AuthManager:
logger.error(f"Token verification error: {e}")
raise InvalidTokenException("Authentication failed")
- def get_current_user(self, db: Session, credentials: HTTPAuthorizationCredentials) -> User:
+ def get_current_user(
+ self, db: Session, credentials: HTTPAuthorizationCredentials
+ ) -> User:
"""Extract and validate the current authenticated user from request credentials.
Verifies the JWT token from the Authorization header, looks up the user
@@ -286,8 +290,10 @@ class AuthManager:
# This will only execute if user has "admin" role
pass
"""
+
def decorator(func):
"""Decorator that wraps the function with role checking."""
+
def wrapper(current_user: User, *args, **kwargs):
# Check if current user has the required role
if current_user.role != required_role:
@@ -339,8 +345,7 @@ class AuthManager:
# Check if user has vendor or admin role (admins have full access)
if current_user.role not in ["vendor", "admin"]:
raise InsufficientPermissionsException(
- message="Vendor access required",
- required_permission="vendor"
+ message="Vendor access required", required_permission="vendor"
)
return current_user
@@ -363,7 +368,7 @@ class AuthManager:
if current_user.role not in ["customer", "admin"]:
raise InsufficientPermissionsException(
message="Customer account access required",
- required_permission="customer"
+ required_permission="customer",
)
return current_user
diff --git a/middleware/context.py b/middleware/context.py
index 53b6e075..00b88fbd 100644
--- a/middleware/context.py
+++ b/middleware/context.py
@@ -17,14 +17,16 @@ Class-based middleware provides:
import logging
from enum import Enum
-from starlette.middleware.base import BaseHTTPMiddleware
+
from fastapi import Request
+from starlette.middleware.base import BaseHTTPMiddleware
logger = logging.getLogger(__name__)
class RequestContext(str, Enum):
"""Request context types for the application."""
+
API = "api"
ADMIN = "admin"
VENDOR_DASHBOARD = "vendor"
@@ -59,7 +61,7 @@ class ContextManager:
# Use clean_path if available (extracted by vendor_context_middleware)
# Falls back to original path if clean_path not set
# This is critical for correct context detection with path-based routing
- path = getattr(request.state, 'clean_path', request.url.path)
+ path = getattr(request.state, "clean_path", request.url.path)
host = request.headers.get("host", "")
@@ -71,10 +73,10 @@ class ContextManager:
f"[CONTEXT] Detecting context",
extra={
"original_path": request.url.path,
- "clean_path": getattr(request.state, 'clean_path', 'NOT SET'),
+ "clean_path": getattr(request.state, "clean_path", "NOT SET"),
"path_to_check": path,
"host": host,
- }
+ },
)
# 1. API context (highest priority)
@@ -84,24 +86,30 @@ class ContextManager:
# 2. Admin context
if ContextManager._is_admin_context(request, host, path):
- logger.debug("[CONTEXT] Detected as ADMIN", extra={"path": path, "host": host})
+ logger.debug(
+ "[CONTEXT] Detected as ADMIN", extra={"path": path, "host": host}
+ )
return RequestContext.ADMIN
# 3. Vendor Dashboard context (vendor management area)
# Check both clean_path and original path for vendor dashboard
original_path = request.url.path
- if ContextManager._is_vendor_dashboard_context(path) or \
- ContextManager._is_vendor_dashboard_context(original_path):
- logger.debug("[CONTEXT] Detected as VENDOR_DASHBOARD", extra={"path": path, "original_path": original_path})
+ if ContextManager._is_vendor_dashboard_context(
+ path
+ ) or ContextManager._is_vendor_dashboard_context(original_path):
+ logger.debug(
+ "[CONTEXT] Detected as VENDOR_DASHBOARD",
+ extra={"path": path, "original_path": original_path},
+ )
return RequestContext.VENDOR_DASHBOARD
# 4. Shop context (vendor storefront)
# Check if vendor context exists (set by vendor_context_middleware)
- if hasattr(request.state, 'vendor') and request.state.vendor:
+ if hasattr(request.state, "vendor") and request.state.vendor:
# If we have a vendor and it's not admin or vendor dashboard, it's shop
logger.debug(
"[CONTEXT] Detected as SHOP (has vendor context)",
- extra={"vendor": request.state.vendor.name}
+ extra={"vendor": request.state.vendor.name},
)
return RequestContext.SHOP
@@ -173,11 +181,12 @@ class ContextMiddleware(BaseHTTPMiddleware):
f"[CONTEXT_MIDDLEWARE] Context detected: {context_type.value}",
extra={
"path": request.url.path,
- "clean_path": getattr(request.state, 'clean_path', 'NOT SET'),
+ "clean_path": getattr(request.state, "clean_path", "NOT SET"),
"host": request.headers.get("host", ""),
"context": context_type.value,
- "has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None,
- }
+ "has_vendor": hasattr(request.state, "vendor")
+ and request.state.vendor is not None,
+ },
)
# Continue processing
diff --git a/middleware/decorators.py b/middleware/decorators.py
index f068f6fe..2766301d 100644
--- a/middleware/decorators.py
+++ b/middleware/decorators.py
@@ -9,12 +9,14 @@ This module provides classes and functions for:
"""
from functools import wraps
+
from app.exceptions.base import RateLimitException # Add this import
from middleware.rate_limiter import RateLimiter
# Initialize rate limiter instance
rate_limiter = RateLimiter()
+
def rate_limit(max_requests: int = 100, window_seconds: int = 3600):
"""Rate limiting decorator for FastAPI endpoints."""
@@ -26,10 +28,11 @@ def rate_limit(max_requests: int = 100, window_seconds: int = 3600):
if not rate_limiter.allow_request(client_id, max_requests, window_seconds):
# Use custom exception instead of HTTPException
raise RateLimitException(
- message="Rate limit exceeded",
- retry_after=window_seconds
+ message="Rate limit exceeded", retry_after=window_seconds
)
return await func(*args, **kwargs)
+
return wrapper
+
return decorator
diff --git a/middleware/theme_context.py b/middleware/theme_context.py
index a7531eae..e89e5117 100644
--- a/middleware/theme_context.py
+++ b/middleware/theme_context.py
@@ -11,9 +11,10 @@ Class-based middleware provides:
"""
import logging
-from starlette.middleware.base import BaseHTTPMiddleware
+
from fastapi import Request
from sqlalchemy.orm import Session
+from starlette.middleware.base import BaseHTTPMiddleware
from app.core.database import get_db
from models.database.vendor_theme import VendorTheme
@@ -30,10 +31,11 @@ class ThemeContextManager:
Get theme configuration for vendor.
Returns default theme if no custom theme is configured.
"""
- theme = db.query(VendorTheme).filter(
- VendorTheme.vendor_id == vendor_id,
- VendorTheme.is_active == True
- ).first()
+ theme = (
+ db.query(VendorTheme)
+ .filter(VendorTheme.vendor_id == vendor_id, VendorTheme.is_active == True)
+ .first()
+ )
if theme:
return theme.to_dict()
@@ -52,23 +54,16 @@ class ThemeContextManager:
"accent": "#ec4899",
"background": "#ffffff",
"text": "#1f2937",
- "border": "#e5e7eb"
- },
- "fonts": {
- "heading": "Inter, sans-serif",
- "body": "Inter, sans-serif"
+ "border": "#e5e7eb",
},
+ "fonts": {"heading": "Inter, sans-serif", "body": "Inter, sans-serif"},
"branding": {
"logo": None,
"logo_dark": None,
"favicon": None,
- "banner": None
- },
- "layout": {
- "style": "grid",
- "header": "fixed",
- "product_card": "modern"
+ "banner": None,
},
+ "layout": {"style": "grid", "header": "fixed", "product_card": "modern"},
"social_links": {},
"custom_css": None,
"css_variables": {
@@ -80,7 +75,7 @@ class ThemeContextManager:
"--color-border": "#e5e7eb",
"--font-heading": "Inter, sans-serif",
"--font-body": "Inter, sans-serif",
- }
+ },
}
@@ -106,7 +101,7 @@ class ThemeContextMiddleware(BaseHTTPMiddleware):
Load and inject theme context.
"""
# Only inject theme for shop pages (not admin or API)
- if hasattr(request.state, 'vendor') and request.state.vendor:
+ if hasattr(request.state, "vendor") and request.state.vendor:
vendor = request.state.vendor
# Get database session
@@ -123,13 +118,13 @@ class ThemeContextMiddleware(BaseHTTPMiddleware):
extra={
"vendor_id": vendor.id,
"vendor_name": vendor.name,
- "theme_name": theme.get('theme_name', 'default'),
- }
+ "theme_name": theme.get("theme_name", "default"),
+ },
)
except Exception as e:
logger.error(
f"[THEME] Failed to load theme for vendor {vendor.id}: {e}",
- exc_info=True
+ exc_info=True,
)
# Fallback to default theme
request.state.theme = ThemeContextManager.get_default_theme()
@@ -140,7 +135,7 @@ class ThemeContextMiddleware(BaseHTTPMiddleware):
request.state.theme = ThemeContextManager.get_default_theme()
logger.debug(
"[THEME] No vendor context, using default theme",
- extra={"has_vendor": False}
+ extra={"has_vendor": False},
)
# Continue processing
diff --git a/middleware/vendor_context.py b/middleware/vendor_context.py
index 33e3aa70..982afee6 100644
--- a/middleware/vendor_context.py
+++ b/middleware/vendor_context.py
@@ -13,10 +13,11 @@ Also extracts clean_path for nested routing patterns.
import logging
from typing import Optional
-from sqlalchemy.orm import Session
-from sqlalchemy import func
-from starlette.middleware.base import BaseHTTPMiddleware
+
from fastapi import Request
+from sqlalchemy import func
+from sqlalchemy.orm import Session
+from starlette.middleware.base import BaseHTTPMiddleware
from app.core.config import settings
from app.core.database import get_db
@@ -50,14 +51,15 @@ class VendorContextManager:
# Method 1: Custom domain detection (HIGHEST PRIORITY)
# Check if this is a custom domain (not platform.com and not localhost)
- platform_domain = getattr(settings, 'platform_domain', 'platform.com')
+ platform_domain = getattr(settings, "platform_domain", "platform.com")
is_custom_domain = (
- host and
- not host.endswith(f".{platform_domain}") and
- host != platform_domain and
- host not in ["localhost", "127.0.0.1", "admin.localhost", "admin.127.0.0.1"] and
- not host.startswith("admin.")
+ host
+ and not host.endswith(f".{platform_domain}")
+ and host != platform_domain
+ and host
+ not in ["localhost", "127.0.0.1", "admin.localhost", "admin.127.0.0.1"]
+ and not host.startswith("admin.")
)
if is_custom_domain:
@@ -66,7 +68,7 @@ class VendorContextManager:
"domain": normalized_domain,
"detection_method": "custom_domain",
"host": host,
- "original_host": request.headers.get("host", "")
+ "original_host": request.headers.get("host", ""),
}
# Method 2: Subdomain detection (vendor1.platform.com)
@@ -78,7 +80,7 @@ class VendorContextManager:
return {
"subdomain": subdomain,
"detection_method": "subdomain",
- "host": host
+ "host": host,
}
# Method 3: Path-based detection (/vendor/vendorname/ or /vendors/vendorname/)
@@ -96,9 +98,9 @@ class VendorContextManager:
return {
"subdomain": vendor_code,
"detection_method": "path",
- "path_prefix": path[:prefix_len + len(vendor_code)],
+ "path_prefix": path[: prefix_len + len(vendor_code)],
"full_prefix": path[:prefix_len], # /vendor/ or /vendors/
- "host": host
+ "host": host,
}
return None
@@ -136,10 +138,14 @@ class VendorContextManager:
logger.warning(f"Vendor for domain {domain} is not active")
return None
- logger.info(f"[OK] Vendor found via custom domain: {domain} → {vendor.name}")
+ logger.info(
+ f"[OK] Vendor found via custom domain: {domain} → {vendor.name}"
+ )
return vendor
else:
- logger.warning(f"No active vendor found for custom domain: {domain}")
+ logger.warning(
+ f"No active vendor found for custom domain: {domain}"
+ )
return None
# Method 2 & 3: Subdomain or path-based lookup
@@ -154,7 +160,9 @@ class VendorContextManager:
if vendor:
method = context.get("detection_method", "unknown")
- logger.info(f"[OK] Vendor found via {method}: {subdomain} → {vendor.name}")
+ logger.info(
+ f"[OK] Vendor found via {method}: {subdomain} → {vendor.name}"
+ )
else:
logger.warning(f"No active vendor found for subdomain: {subdomain}")
@@ -176,7 +184,7 @@ class VendorContextManager:
path_prefix = vendor_context.get("path_prefix", "")
if path.startswith(path_prefix):
- clean_path = path[len(path_prefix):]
+ clean_path = path[len(path_prefix) :]
return clean_path if clean_path else "/"
return request.url.path
@@ -232,6 +240,7 @@ class VendorContextManager:
try:
from urllib.parse import urlparse
+
parsed = urlparse(referer)
referer_host = parsed.hostname or ""
referer_path = parsed.path or ""
@@ -246,27 +255,33 @@ class VendorContextManager:
"referer": referer,
"referer_host": referer_host,
"referer_path": referer_path,
- }
+ },
)
# Method 1: Path-based detection from referer path
# /vendors/wizamart/shop/products → wizamart
- if referer_path.startswith("/vendors/") or referer_path.startswith("/vendor/"):
- prefix = "/vendors/" if referer_path.startswith("/vendors/") else "/vendor/"
- path_parts = referer_path[len(prefix):].split("/")
+ if referer_path.startswith("/vendors/") or referer_path.startswith(
+ "/vendor/"
+ ):
+ prefix = (
+ "/vendors/" if referer_path.startswith("/vendors/") else "/vendor/"
+ )
+ path_parts = referer_path[len(prefix) :].split("/")
if len(path_parts) >= 1 and path_parts[0]:
vendor_code = path_parts[0]
prefix_len = len(prefix)
logger.debug(
f"[VENDOR] Extracted vendor from Referer path: {vendor_code}",
- extra={"vendor_code": vendor_code, "method": "referer_path"}
+ extra={"vendor_code": vendor_code, "method": "referer_path"},
)
# Use "path" as detection_method to be consistent with direct path detection
# This allows cookie path logic to work the same way
return {
"subdomain": vendor_code,
"detection_method": "path", # Consistent with direct path detection
- "path_prefix": referer_path[:prefix_len + len(vendor_code)], # /vendor/vendor1
+ "path_prefix": referer_path[
+ : prefix_len + len(vendor_code)
+ ], # /vendor/vendor1
"full_prefix": prefix, # /vendor/ or /vendors/
"host": referer_host,
"referer": referer,
@@ -274,7 +289,7 @@ class VendorContextManager:
# Method 2: Subdomain detection from referer host
# wizamart.platform.com → wizamart
- platform_domain = getattr(settings, 'platform_domain', 'platform.com')
+ platform_domain = getattr(settings, "platform_domain", "platform.com")
if "." in referer_host:
parts = referer_host.split(".")
if len(parts) >= 2 and parts[0] not in ["www", "admin", "api"]:
@@ -283,7 +298,10 @@ class VendorContextManager:
subdomain = parts[0]
logger.debug(
f"[VENDOR] Extracted vendor from Referer subdomain: {subdomain}",
- extra={"subdomain": subdomain, "method": "referer_subdomain"}
+ extra={
+ "subdomain": subdomain,
+ "method": "referer_subdomain",
+ },
)
return {
"subdomain": subdomain,
@@ -295,19 +313,23 @@ class VendorContextManager:
# Method 3: Custom domain detection from referer host
# custom-shop.com → custom-shop.com
is_custom_domain = (
- referer_host and
- not referer_host.endswith(f".{platform_domain}") and
- referer_host != platform_domain and
- referer_host not in ["localhost", "127.0.0.1"] and
- not referer_host.startswith("admin.")
+ referer_host
+ and not referer_host.endswith(f".{platform_domain}")
+ and referer_host != platform_domain
+ and referer_host not in ["localhost", "127.0.0.1"]
+ and not referer_host.startswith("admin.")
)
if is_custom_domain:
from models.database.vendor_domain import VendorDomain
+
normalized_domain = VendorDomain.normalize_domain(referer_host)
logger.debug(
f"[VENDOR] Extracted vendor from Referer custom domain: {normalized_domain}",
- extra={"domain": normalized_domain, "method": "referer_custom_domain"}
+ extra={
+ "domain": normalized_domain,
+ "method": "referer_custom_domain",
+ },
)
return {
"domain": normalized_domain,
@@ -319,7 +341,7 @@ class VendorContextManager:
except Exception as e:
logger.warning(
f"[VENDOR] Failed to extract vendor from Referer: {e}",
- extra={"referer": referer, "error": str(e)}
+ extra={"referer": referer, "error": str(e)},
)
return None
@@ -330,12 +352,28 @@ class VendorContextManager:
path = request.url.path.lower()
static_extensions = (
- '.ico', '.css', '.js', '.png', '.jpg', '.jpeg', '.gif', '.svg',
- '.woff', '.woff2', '.ttf', '.eot', '.webp', '.map', '.json',
- '.xml', '.txt', '.pdf', '.webmanifest'
+ ".ico",
+ ".css",
+ ".js",
+ ".png",
+ ".jpg",
+ ".jpeg",
+ ".gif",
+ ".svg",
+ ".woff",
+ ".woff2",
+ ".ttf",
+ ".eot",
+ ".webp",
+ ".map",
+ ".json",
+ ".xml",
+ ".txt",
+ ".pdf",
+ ".webmanifest",
)
- static_paths = ('/static/', '/media/', '/assets/', '/.well-known/')
+ static_paths = ("/static/", "/media/", "/assets/", "/.well-known/")
if path.endswith(static_extensions):
return True
@@ -343,7 +381,7 @@ class VendorContextManager:
if any(path.startswith(static_path) for static_path in static_paths):
return True
- if 'favicon.ico' in path:
+ if "favicon.ico" in path:
return True
return False
@@ -372,13 +410,13 @@ class VendorContextMiddleware(BaseHTTPMiddleware):
"""
# Skip vendor detection for admin, static files, and system requests
if (
- VendorContextManager.is_admin_request(request) or
- VendorContextManager.is_static_file_request(request) or
- request.url.path in ["/", "/health", "/docs", "/redoc", "/openapi.json"]
+ VendorContextManager.is_admin_request(request)
+ or VendorContextManager.is_static_file_request(request)
+ or request.url.path in ["/", "/health", "/docs", "/redoc", "/openapi.json"]
):
logger.debug(
f"[VENDOR] Skipping vendor detection: {request.url.path}",
- extra={"path": request.url.path, "reason": "admin/static/system"}
+ extra={"path": request.url.path, "reason": "admin/static/system"},
)
request.state.vendor = None
request.state.vendor_context = None
@@ -389,7 +427,10 @@ class VendorContextMiddleware(BaseHTTPMiddleware):
if VendorContextManager.is_shop_api_request(request):
logger.debug(
f"[VENDOR] Shop API request detected: {request.url.path}",
- extra={"path": request.url.path, "referer": request.headers.get("referer", "")}
+ extra={
+ "path": request.url.path,
+ "referer": request.headers.get("referer", ""),
+ },
)
vendor_context = VendorContextManager.extract_vendor_from_referer(request)
@@ -398,7 +439,9 @@ class VendorContextMiddleware(BaseHTTPMiddleware):
db_gen = get_db()
db = next(db_gen)
try:
- vendor = VendorContextManager.get_vendor_from_context(db, vendor_context)
+ vendor = VendorContextManager.get_vendor_from_context(
+ db, vendor_context
+ )
if vendor:
request.state.vendor = vendor
@@ -411,19 +454,23 @@ class VendorContextMiddleware(BaseHTTPMiddleware):
"vendor_id": vendor.id,
"vendor_name": vendor.name,
"vendor_subdomain": vendor.subdomain,
- "detection_method": vendor_context.get("detection_method"),
+ "detection_method": vendor_context.get(
+ "detection_method"
+ ),
"api_path": request.url.path,
"referer": vendor_context.get("referer", ""),
- }
+ },
)
else:
logger.warning(
f"[WARNING] Vendor context from Referer but vendor not found",
extra={
"context": vendor_context,
- "detection_method": vendor_context.get("detection_method"),
+ "detection_method": vendor_context.get(
+ "detection_method"
+ ),
"api_path": request.url.path,
- }
+ },
)
request.state.vendor = None
request.state.vendor_context = vendor_context
@@ -433,7 +480,7 @@ class VendorContextMiddleware(BaseHTTPMiddleware):
else:
logger.warning(
f"[VENDOR] Shop API request without Referer header",
- extra={"path": request.url.path}
+ extra={"path": request.url.path},
)
request.state.vendor = None
request.state.vendor_context = None
@@ -445,7 +492,7 @@ class VendorContextMiddleware(BaseHTTPMiddleware):
if VendorContextManager.is_api_request(request):
logger.debug(
f"[VENDOR] Skipping vendor detection for non-shop API: {request.url.path}",
- extra={"path": request.url.path, "reason": "api"}
+ extra={"path": request.url.path, "reason": "api"},
)
request.state.vendor = None
request.state.vendor_context = None
@@ -459,7 +506,9 @@ class VendorContextMiddleware(BaseHTTPMiddleware):
db_gen = get_db()
db = next(db_gen)
try:
- vendor = VendorContextManager.get_vendor_from_context(db, vendor_context)
+ vendor = VendorContextManager.get_vendor_from_context(
+ db, vendor_context
+ )
if vendor:
request.state.vendor = vendor
@@ -477,7 +526,7 @@ class VendorContextMiddleware(BaseHTTPMiddleware):
"detection_method": vendor_context.get("detection_method"),
"original_path": request.url.path,
"clean_path": request.state.clean_path,
- }
+ },
)
else:
logger.warning(
@@ -485,7 +534,7 @@ class VendorContextMiddleware(BaseHTTPMiddleware):
extra={
"context": vendor_context,
"detection_method": vendor_context.get("detection_method"),
- }
+ },
)
request.state.vendor = None
request.state.vendor_context = vendor_context
@@ -498,7 +547,7 @@ class VendorContextMiddleware(BaseHTTPMiddleware):
extra={
"path": request.url.path,
"host": request.headers.get("host", ""),
- }
+ },
)
request.state.vendor = None
request.state.vendor_context = None
@@ -520,9 +569,9 @@ def require_vendor_context():
vendor = get_current_vendor(request)
if not vendor:
from fastapi import HTTPException
+
raise HTTPException(
- status_code=404,
- detail="Vendor not found or not active"
+ status_code=404, detail="Vendor not found or not active"
)
return vendor
diff --git a/models/__init__.py b/models/__init__.py
index 8d148728..ed80aa1c 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -1,17 +1,16 @@
# models/__init__.py
"""Models package - Database and API models."""
-# Database models (SQLAlchemy)
-from .database.base import Base
-from .database.user import User
-from .database.marketplace_product import MarketplaceProduct
-from .database.inventory import Inventory
-from .database.vendor import Vendor
-from .database.product import Product
-from .database.marketplace_import_job import MarketplaceImportJob
-
# API models (Pydantic) - import the modules, not all classes
from . import schema
+# Database models (SQLAlchemy)
+from .database.base import Base
+from .database.inventory import Inventory
+from .database.marketplace_import_job import MarketplaceImportJob
+from .database.marketplace_product import MarketplaceProduct
+from .database.product import Product
+from .database.user import User
+from .database.vendor import Vendor
# Export database models for Alembic
__all__ = [
diff --git a/models/database/__init__.py b/models/database/__init__.py
index 24ed1b4e..374c8e70 100644
--- a/models/database/__init__.py
+++ b/models/database/__init__.py
@@ -1,17 +1,18 @@
# models/database/__init__.py
"""Database models package."""
-from .admin import AdminAuditLog, AdminNotification, AdminSetting, PlatformAlert, AdminSession
+from .admin import (AdminAuditLog, AdminNotification, AdminSession,
+ AdminSetting, PlatformAlert)
from .base import Base
from .customer import Customer, CustomerAddress
-from .order import Order, OrderItem
-from .user import User
-from .marketplace_product import MarketplaceProduct
from .inventory import Inventory
-from .vendor import Vendor, Role, VendorUser
+from .marketplace_import_job import MarketplaceImportJob
+from .marketplace_product import MarketplaceProduct
+from .order import Order, OrderItem
+from .product import Product
+from .user import User
+from .vendor import Role, Vendor, VendorUser
from .vendor_domain import VendorDomain
from .vendor_theme import VendorTheme
-from .product import Product
-from .marketplace_import_job import MarketplaceImportJob
__all__ = [
# Admin-specific models
@@ -34,5 +35,5 @@ __all__ = [
"MarketplaceImportJob",
"MarketplaceProduct",
"VendorDomain",
- "VendorTheme"
+ "VendorTheme",
]
diff --git a/models/database/admin.py b/models/database/admin.py
index 32b10ae8..45c04cee 100644
--- a/models/database/admin.py
+++ b/models/database/admin.py
@@ -1,4 +1,4 @@
-# Admin-specific models
+# Admin-specific models
# models/database/admin.py
"""
Admin-specific database models.
@@ -10,9 +10,12 @@ This module provides models for:
- Platform alerts (system-wide issues)
"""
-from sqlalchemy import Column, Integer, String, DateTime, Boolean, Text, JSON, ForeignKey
+from sqlalchemy import (JSON, Boolean, Column, DateTime, ForeignKey, Integer,
+ String, Text)
from sqlalchemy.orm import relationship
+
from app.core.database import Base
+
from .base import TimestampMixin
@@ -23,12 +26,17 @@ class AdminAuditLog(Base, TimestampMixin):
Separate from regular audit logs - focuses on admin-specific operations
like vendor creation, user management, and system configuration changes.
"""
+
__tablename__ = "admin_audit_logs"
id = Column(Integer, primary_key=True, index=True)
admin_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
- action = Column(String(100), nullable=False, index=True) # create_vendor, delete_vendor, etc.
- target_type = Column(String(50), nullable=False, index=True) # vendor, user, import_job, setting
+ action = Column(
+ String(100), nullable=False, index=True
+ ) # create_vendor, delete_vendor, etc.
+ target_type = Column(
+ String(50), nullable=False, index=True
+ ) # vendor, user, import_job, setting
target_id = Column(String(100), nullable=False, index=True)
details = Column(JSON) # Additional context about the action
ip_address = Column(String(45)) # IPv4 or IPv6
@@ -49,11 +57,16 @@ class AdminNotification(Base, TimestampMixin):
Different from vendor/customer notifications - these are for platform
administrators to track system health and issues requiring attention.
"""
+
__tablename__ = "admin_notifications"
id = Column(Integer, primary_key=True, index=True)
- type = Column(String(50), nullable=False, index=True) # system_alert, vendor_issue, import_failure
- priority = Column(String(20), default="normal", index=True) # low, normal, high, critical
+ type = Column(
+ String(50), nullable=False, index=True
+ ) # system_alert, vendor_issue, import_failure
+ priority = Column(
+ String(20), default="normal", index=True
+ ) # low, normal, high, critical
title = Column(String(200), nullable=False)
message = Column(Text, nullable=False)
is_read = Column(Boolean, default=False, index=True)
@@ -84,13 +97,16 @@ class AdminSetting(Base, TimestampMixin):
- smtp_settings
- stripe_api_keys (encrypted)
"""
+
__tablename__ = "admin_settings"
id = Column(Integer, primary_key=True, index=True)
key = Column(String(100), unique=True, nullable=False, index=True)
value = Column(Text, nullable=False)
value_type = Column(String(20), default="string") # string, integer, boolean, json
- category = Column(String(50), index=True) # system, security, marketplace, notifications
+ category = Column(
+ String(50), index=True
+ ) # system, security, marketplace, notifications
description = Column(Text)
is_encrypted = Column(Boolean, default=False)
is_public = Column(Boolean, default=False) # Can be exposed to frontend?
@@ -110,11 +126,16 @@ class PlatformAlert(Base, TimestampMixin):
Tracks platform issues, performance problems, security incidents,
and other system-level concerns that require admin attention.
"""
+
__tablename__ = "platform_alerts"
id = Column(Integer, primary_key=True, index=True)
- alert_type = Column(String(50), nullable=False, index=True) # security, performance, capacity, integration
- severity = Column(String(20), nullable=False, index=True) # info, warning, error, critical
+ alert_type = Column(
+ String(50), nullable=False, index=True
+ ) # security, performance, capacity, integration
+ severity = Column(
+ String(20), nullable=False, index=True
+ ) # info, warning, error, critical
title = Column(String(200), nullable=False)
description = Column(Text)
affected_vendors = Column(JSON) # List of affected vendor IDs
@@ -142,6 +163,7 @@ class AdminSession(Base, TimestampMixin):
Helps identify suspicious login patterns, track concurrent sessions,
and enforce session policies for admin users.
"""
+
__tablename__ = "admin_sessions"
id = Column(Integer, primary_key=True, index=True)
diff --git a/models/database/audit.py b/models/database/audit.py
index a4093794..05461a88 100644
--- a/models/database/audit.py
+++ b/models/database/audit.py
@@ -1 +1 @@
-# AuditLog, DataExportLog models
+# AuditLog, DataExportLog models
diff --git a/models/database/backup.py b/models/database/backup.py
index 7efe94f4..a9ac9d64 100644
--- a/models/database/backup.py
+++ b/models/database/backup.py
@@ -1 +1 @@
-# BackupLog, RestoreLog models
+# BackupLog, RestoreLog models
diff --git a/models/database/base.py b/models/database/base.py
index 5c675ba2..052c3c31 100644
--- a/models/database/base.py
+++ b/models/database/base.py
@@ -10,5 +10,8 @@ class TimestampMixin:
created_at = Column(DateTime, default=datetime.now(timezone.utc), nullable=False)
updated_at = Column(
- DateTime, default=datetime.now(timezone.utc), onupdate=datetime.now(timezone.utc), nullable=False
+ DateTime,
+ default=datetime.now(timezone.utc),
+ onupdate=datetime.now(timezone.utc),
+ nullable=False,
)
diff --git a/models/database/cart.py b/models/database/cart.py
index 46198d44..c1763eb5 100644
--- a/models/database/cart.py
+++ b/models/database/cart.py
@@ -1,7 +1,9 @@
# models/database/cart.py
"""Cart item database model."""
from datetime import datetime
-from sqlalchemy import Column, Float, ForeignKey, Index, Integer, String, UniqueConstraint
+
+from sqlalchemy import (Column, Float, ForeignKey, Index, Integer, String,
+ UniqueConstraint)
from sqlalchemy.orm import relationship
from app.core.database import Base
@@ -15,6 +17,7 @@ class CartItem(Base, TimestampMixin):
Stores cart items per session, vendor, and product.
Sessions are identified by a session_id string (from browser cookies).
"""
+
__tablename__ = "cart_items"
id = Column(Integer, primary_key=True, index=True)
diff --git a/models/database/configuration.py b/models/database/configuration.py
index a6d610ee..ae38c260 100644
--- a/models/database/configuration.py
+++ b/models/database/configuration.py
@@ -1 +1 @@
-# PlatformConfig, VendorConfig, FeatureFlag models
+# PlatformConfig, VendorConfig, FeatureFlag models
diff --git a/models/database/content_page.py b/models/database/content_page.py
index 687dfc9f..32223bf0 100644
--- a/models/database/content_page.py
+++ b/models/database/content_page.py
@@ -15,7 +15,9 @@ Features:
"""
from datetime import datetime, timezone
-from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Text, UniqueConstraint, Index
+
+from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Index, Integer,
+ String, Text, UniqueConstraint)
from sqlalchemy.orm import relationship
from app.core.database import Base
@@ -34,15 +36,20 @@ class ContentPage(Base):
2. If not found, use platform default (slug only)
3. If neither exists, show 404 or default template
"""
+
__tablename__ = "content_pages"
id = Column(Integer, primary_key=True, index=True)
# Vendor association (NULL = platform default)
- vendor_id = Column(Integer, ForeignKey("vendors.id", ondelete="CASCADE"), nullable=True, index=True)
+ vendor_id = Column(
+ Integer, ForeignKey("vendors.id", ondelete="CASCADE"), nullable=True, index=True
+ )
# Page identification
- slug = Column(String(100), nullable=False, index=True) # about, faq, contact, shipping, returns, etc.
+ slug = Column(
+ String(100), nullable=False, index=True
+ ) # about, faq, contact, shipping, returns, etc.
title = Column(String(200), nullable=False)
# Content
@@ -68,12 +75,25 @@ class ContentPage(Base):
show_in_header = Column(Boolean, default=False)
# Timestamps
- created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False)
- updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc), nullable=False)
+ created_at = Column(
+ DateTime(timezone=True),
+ default=lambda: datetime.now(timezone.utc),
+ nullable=False,
+ )
+ updated_at = Column(
+ DateTime(timezone=True),
+ default=lambda: datetime.now(timezone.utc),
+ onupdate=lambda: datetime.now(timezone.utc),
+ nullable=False,
+ )
# Author tracking (admin or vendor user who created/updated)
- created_by = Column(Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
- updated_by = Column(Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
+ created_by = Column(
+ Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True
+ )
+ updated_by = Column(
+ Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True
+ )
# Relationships
vendor = relationship("Vendor", back_populates="content_pages")
@@ -84,11 +104,10 @@ class ContentPage(Base):
__table_args__ = (
# Unique combination: vendor can only have one page per slug
# Platform defaults (vendor_id=NULL) can only have one page per slug
- UniqueConstraint('vendor_id', 'slug', name='uq_vendor_slug'),
-
+ UniqueConstraint("vendor_id", "slug", name="uq_vendor_slug"),
# Indexes for performance
- Index('idx_vendor_published', 'vendor_id', 'is_published'),
- Index('idx_slug_published', 'slug', 'is_published'),
+ Index("idx_vendor_published", "vendor_id", "is_published"),
+ Index("idx_slug_published", "slug", "is_published"),
)
def __repr__(self):
@@ -119,7 +138,9 @@ class ContentPage(Base):
"meta_description": self.meta_description,
"meta_keywords": self.meta_keywords,
"is_published": self.is_published,
- "published_at": self.published_at.isoformat() if self.published_at else None,
+ "published_at": (
+ self.published_at.isoformat() if self.published_at else None
+ ),
"display_order": self.display_order,
"show_in_footer": self.show_in_footer,
"show_in_header": self.show_in_header,
diff --git a/models/database/customer.py b/models/database/customer.py
index ec4b67e0..8972f806 100644
--- a/models/database/customer.py
+++ b/models/database/customer.py
@@ -1,8 +1,12 @@
from datetime import datetime
from decimal import Decimal
-from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey, Text, JSON, Numeric
+
+from sqlalchemy import (JSON, Boolean, Column, DateTime, ForeignKey, Integer,
+ Numeric, String, Text)
from sqlalchemy.orm import relationship
+
from app.core.database import Base
+
from .base import TimestampMixin
@@ -11,12 +15,16 @@ class Customer(Base, TimestampMixin):
id = Column(Integer, primary_key=True, index=True)
vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False)
- email = Column(String(255), nullable=False, index=True) # Unique within vendor scope
+ email = Column(
+ String(255), nullable=False, index=True
+ ) # Unique within vendor scope
hashed_password = Column(String(255), nullable=False)
first_name = Column(String(100))
last_name = Column(String(100))
phone = Column(String(50))
- customer_number = Column(String(100), nullable=False, index=True) # Vendor-specific ID
+ customer_number = Column(
+ String(100), nullable=False, index=True
+ ) # Vendor-specific ID
preferences = Column(JSON, default=dict)
marketing_consent = Column(Boolean, default=False)
last_order_date = Column(DateTime)
diff --git a/models/database/inventory.py b/models/database/inventory.py
index fa22744f..9318c0d8 100644
--- a/models/database/inventory.py
+++ b/models/database/inventory.py
@@ -1,6 +1,8 @@
# models/database/inventory.py
from datetime import datetime
-from sqlalchemy import Column, ForeignKey, Index, Integer, String, UniqueConstraint
+
+from sqlalchemy import (Column, ForeignKey, Index, Integer, String,
+ UniqueConstraint)
from sqlalchemy.orm import relationship
from app.core.database import Base
@@ -27,7 +29,9 @@ class Inventory(Base, TimestampMixin):
# Constraints
__table_args__ = (
- UniqueConstraint("product_id", "location", name="uq_inventory_product_location"),
+ UniqueConstraint(
+ "product_id", "location", name="uq_inventory_product_location"
+ ),
Index("idx_inventory_vendor_product", "vendor_id", "product_id"),
Index("idx_inventory_product_location", "product_id", "location"),
)
diff --git a/models/database/marketplace.py b/models/database/marketplace.py
index 9e5c025c..5cae6517 100644
--- a/models/database/marketplace.py
+++ b/models/database/marketplace.py
@@ -1 +1 @@
-# MarketplaceImportJob model
+# MarketplaceImportJob model
diff --git a/models/database/marketplace_import_job.py b/models/database/marketplace_import_job.py
index 27412cb4..c6587d21 100644
--- a/models/database/marketplace_import_job.py
+++ b/models/database/marketplace_import_job.py
@@ -1,6 +1,7 @@
from datetime import datetime, timezone
-from sqlalchemy import Column, DateTime, ForeignKey, Index, Integer, String, Text
+from sqlalchemy import (Column, DateTime, ForeignKey, Index, Integer, String,
+ Text)
from sqlalchemy.orm import relationship
from app.core.database import Base
diff --git a/models/database/marketplace_product.py b/models/database/marketplace_product.py
index 389c0112..e248c81d 100644
--- a/models/database/marketplace_product.py
+++ b/models/database/marketplace_product.py
@@ -55,7 +55,9 @@ class MarketplaceProduct(Base, TimestampMixin):
marketplace = Column(
String, index=True, nullable=True, default="Letzshop"
) # Index for marketplace filtering
- vendor_name = Column(String, index=True, nullable=True) # Index for vendor filtering
+ vendor_name = Column(
+ String, index=True, nullable=True
+ ) # Index for vendor filtering
product = relationship("Product", back_populates="marketplace_product")
diff --git a/models/database/media.py b/models/database/media.py
index 9e51d54d..9246810c 100644
--- a/models/database/media.py
+++ b/models/database/media.py
@@ -1 +1 @@
-# MediaFile, ProductMedia models
+# MediaFile, ProductMedia models
diff --git a/models/database/monitoring.py b/models/database/monitoring.py
index ed2edb20..22938930 100644
--- a/models/database/monitoring.py
+++ b/models/database/monitoring.py
@@ -1 +1 @@
-# PerformanceMetric, ErrorLog, SystemAlert models
+# PerformanceMetric, ErrorLog, SystemAlert models
diff --git a/models/database/notification.py b/models/database/notification.py
index 5224c7a9..3eb5495b 100644
--- a/models/database/notification.py
+++ b/models/database/notification.py
@@ -1 +1 @@
-# NotificationTemplate, NotificationQueue, NotificationLog models
+# NotificationTemplate, NotificationQueue, NotificationLog models
diff --git a/models/database/order.py b/models/database/order.py
index 562450a6..a62febdb 100644
--- a/models/database/order.py
+++ b/models/database/order.py
@@ -1,6 +1,8 @@
# models/database/order.py
from datetime import datetime
-from sqlalchemy import Column, DateTime, Float, ForeignKey, Integer, String, Text, Boolean
+
+from sqlalchemy import (Boolean, Column, DateTime, Float, ForeignKey, Integer,
+ String, Text)
from sqlalchemy.orm import relationship
from app.core.database import Base
@@ -9,11 +11,14 @@ from models.database.base import TimestampMixin
class Order(Base, TimestampMixin):
"""Customer orders."""
+
__tablename__ = "orders"
id = Column(Integer, primary_key=True, index=True)
vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False, index=True)
- customer_id = Column(Integer, ForeignKey("customers.id"), nullable=False, index=True)
+ customer_id = Column(
+ Integer, ForeignKey("customers.id"), nullable=False, index=True
+ )
order_number = Column(String, nullable=False, unique=True, index=True)
@@ -30,8 +35,12 @@ class Order(Base, TimestampMixin):
currency = Column(String, default="EUR")
# Addresses (stored as IDs)
- shipping_address_id = Column(Integer, ForeignKey("customer_addresses.id"), nullable=False)
- billing_address_id = Column(Integer, ForeignKey("customer_addresses.id"), nullable=False)
+ shipping_address_id = Column(
+ Integer, ForeignKey("customer_addresses.id"), nullable=False
+ )
+ billing_address_id = Column(
+ Integer, ForeignKey("customer_addresses.id"), nullable=False
+ )
# Shipping
shipping_method = Column(String, nullable=True)
@@ -50,8 +59,12 @@ class Order(Base, TimestampMixin):
# Relationships
vendor = relationship("Vendor")
customer = relationship("Customer", back_populates="orders")
- items = relationship("OrderItem", back_populates="order", cascade="all, delete-orphan")
- shipping_address = relationship("CustomerAddress", foreign_keys=[shipping_address_id])
+ items = relationship(
+ "OrderItem", back_populates="order", cascade="all, delete-orphan"
+ )
+ shipping_address = relationship(
+ "CustomerAddress", foreign_keys=[shipping_address_id]
+ )
billing_address = relationship("CustomerAddress", foreign_keys=[billing_address_id])
def __repr__(self):
@@ -60,6 +73,7 @@ class Order(Base, TimestampMixin):
class OrderItem(Base, TimestampMixin):
"""Individual items in an order."""
+
__tablename__ = "order_items"
id = Column(Integer, primary_key=True, index=True)
diff --git a/models/database/payment.py b/models/database/payment.py
index 4ebd4ebd..cf4bb549 100644
--- a/models/database/payment.py
+++ b/models/database/payment.py
@@ -1 +1 @@
-# Payment, PaymentMethod, VendorPaymentConfig models
+# Payment, PaymentMethod, VendorPaymentConfig models
diff --git a/models/database/product.py b/models/database/product.py
index 8e8262d9..269a6885 100644
--- a/models/database/product.py
+++ b/models/database/product.py
@@ -1,6 +1,8 @@
# models/database/product.py
from datetime import datetime
-from sqlalchemy import Boolean, Column, Float, ForeignKey, Index, Integer, String, UniqueConstraint
+
+from sqlalchemy import (Boolean, Column, Float, ForeignKey, Index, Integer,
+ String, UniqueConstraint)
from sqlalchemy.orm import relationship
from app.core.database import Base
@@ -12,7 +14,9 @@ class Product(Base, TimestampMixin):
id = Column(Integer, primary_key=True, index=True)
vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False)
- marketplace_product_id = Column(Integer, ForeignKey("marketplace_products.id"), nullable=False)
+ marketplace_product_id = Column(
+ Integer, ForeignKey("marketplace_products.id"), nullable=False
+ )
# Vendor-specific overrides
product_id = Column(String) # Vendor's internal SKU
@@ -34,7 +38,9 @@ class Product(Base, TimestampMixin):
# Relationships
vendor = relationship("Vendor", back_populates="products")
marketplace_product = relationship("MarketplaceProduct", back_populates="product")
- inventory_entries = relationship("Inventory", back_populates="product", cascade="all, delete-orphan")
+ inventory_entries = relationship(
+ "Inventory", back_populates="product", cascade="all, delete-orphan"
+ )
# Constraints
__table_args__ = (
diff --git a/models/database/search.py b/models/database/search.py
index e13e7439..1d8f6e66 100644
--- a/models/database/search.py
+++ b/models/database/search.py
@@ -1 +1 @@
-# SearchIndex, SearchQuery models
+# SearchIndex, SearchQuery models
diff --git a/models/database/task.py b/models/database/task.py
index 34945e62..ef22426a 100644
--- a/models/database/task.py
+++ b/models/database/task.py
@@ -1 +1 @@
-# TaskLog model
+# TaskLog model
diff --git a/models/database/user.py b/models/database/user.py
index 48650405..dcce9329 100644
--- a/models/database/user.py
+++ b/models/database/user.py
@@ -10,16 +10,18 @@ ROLE CLARIFICATION:
- Vendor-specific roles (manager, staff, etc.) are stored in VendorUser.role
- Customers are NOT in the User table - they use the Customer model
"""
-from sqlalchemy import Boolean, Column, DateTime, Integer, String, Enum
-from sqlalchemy.orm import relationship
import enum
+from sqlalchemy import Boolean, Column, DateTime, Enum, Integer, String
+from sqlalchemy.orm import relationship
+
from app.core.database import Base
from models.database.base import TimestampMixin
class UserRole(str, enum.Enum):
"""Platform-level user roles."""
+
ADMIN = "admin" # Platform administrator
VENDOR = "vendor" # Vendor owner or team member
@@ -44,12 +46,12 @@ class User(Base, TimestampMixin):
last_login = Column(DateTime, nullable=True)
# Relationships
- marketplace_import_jobs = relationship("MarketplaceImportJob", back_populates="user")
+ marketplace_import_jobs = relationship(
+ "MarketplaceImportJob", back_populates="user"
+ )
owned_vendors = relationship("Vendor", back_populates="owner")
vendor_memberships = relationship(
- "VendorUser",
- foreign_keys="[VendorUser.user_id]",
- back_populates="user"
+ "VendorUser", foreign_keys="[VendorUser.user_id]", back_populates="user"
)
def __repr__(self):
@@ -84,8 +86,7 @@ class User(Base, TimestampMixin):
return True
# Check if team member
return any(
- vm.vendor_id == vendor_id and vm.is_active
- for vm in self.vendor_memberships
+ vm.vendor_id == vendor_id and vm.is_active for vm in self.vendor_memberships
)
def get_vendor_role(self, vendor_id: int) -> str:
diff --git a/models/database/vendor.py b/models/database/vendor.py
index 813db3ce..c962e852 100644
--- a/models/database/vendor.py
+++ b/models/database/vendor.py
@@ -5,29 +5,37 @@ Vendor model representing entities that sell products or services.
This module defines the Vendor model along with its relationships to
other models such as User (owner), Product, Customer, Order, and MarketplaceImportJob.
"""
-from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Text, JSON, DateTime
+import enum
+
+from sqlalchemy import (JSON, Boolean, Column, DateTime, ForeignKey, Integer,
+ String, Text)
from sqlalchemy.orm import relationship
+from app.core.config import settings
# Import Base from the central database module instead of creating a new one
from app.core.database import Base
from models.database.base import TimestampMixin
-from app.core.config import settings
-import enum
+
class Vendor(Base, TimestampMixin):
"""Represents a vendor in the system."""
__tablename__ = "vendors" # Name of the table in the database
- id = Column(Integer, primary_key=True, index=True) # Primary key and indexed column for vendor ID
- vendor_code = Column(String, unique=True, index=True,
- nullable=False) # Unique, indexed, non-nullable vendor code column
- subdomain = Column(String(100), unique=True, nullable=False,
- index=True) # Unique, non-nullable subdomain column with indexing
+ id = Column(
+ Integer, primary_key=True, index=True
+ ) # Primary key and indexed column for vendor ID
+ vendor_code = Column(
+ String, unique=True, index=True, nullable=False
+ ) # Unique, indexed, non-nullable vendor code column
+ subdomain = Column(
+ String(100), unique=True, nullable=False, index=True
+ ) # Unique, non-nullable subdomain column with indexing
name = Column(String, nullable=False) # Non-nullable name column for the vendor
description = Column(Text) # Optional text description column for the vendor
- owner_user_id = Column(Integer, ForeignKey("users.id"),
- nullable=False) # Foreign key to user ID of the vendor's owner
+ owner_user_id = Column(
+ Integer, ForeignKey("users.id"), nullable=False
+ ) # Foreign key to user ID of the vendor's owner
# Contact information
contact_email = Column(String) # Optional email column for contact information
@@ -40,34 +48,46 @@ class Vendor(Base, TimestampMixin):
letzshop_csv_url_de = Column(String) # URL for German CSV in Letzshop
# Business information
- business_address = Column(Text) # Optional text address column for business information
+ business_address = Column(
+ Text
+ ) # Optional text address column for business information
tax_number = Column(String) # Optional tax number column for business information
# Status
- is_active = Column(Boolean, default=True) # Boolean to indicate if the vendor is active
- is_verified = Column(Boolean, default=False) # Boolean to indicate if the vendor is verified
-
+ is_active = Column(
+ Boolean, default=True
+ ) # Boolean to indicate if the vendor is active
+ is_verified = Column(
+ Boolean, default=False
+ ) # Boolean to indicate if the vendor is verified
# ========================================================================
# Relationships
# ========================================================================
- owner = relationship("User", back_populates="owned_vendors") # Relationship with User model for the vendor's owner
- vendor_users = relationship("VendorUser",
- back_populates="vendor") # Relationship with VendorUser model for users in this vendor
- products = relationship("Product",
- back_populates="vendor") # Relationship with Product model for products of this vendor
- customers = relationship("Customer",
- back_populates="vendor") # Relationship with Customer model for customers of this vendor
- orders = relationship("Order",
- back_populates="vendor") # Relationship with Order model for orders placed by this vendor
- marketplace_import_jobs = relationship("MarketplaceImportJob",
- back_populates="vendor") # Relationship with MarketplaceImportJob model for import jobs related to this vendor
+ owner = relationship(
+ "User", back_populates="owned_vendors"
+ ) # Relationship with User model for the vendor's owner
+ vendor_users = relationship(
+ "VendorUser", back_populates="vendor"
+ ) # Relationship with VendorUser model for users in this vendor
+ products = relationship(
+ "Product", back_populates="vendor"
+ ) # Relationship with Product model for products of this vendor
+ customers = relationship(
+ "Customer", back_populates="vendor"
+ ) # Relationship with Customer model for customers of this vendor
+ orders = relationship(
+ "Order", back_populates="vendor"
+ ) # Relationship with Order model for orders placed by this vendor
+ marketplace_import_jobs = relationship(
+ "MarketplaceImportJob", back_populates="vendor"
+ ) # Relationship with MarketplaceImportJob model for import jobs related to this vendor
domains = relationship(
"VendorDomain",
back_populates="vendor",
cascade="all, delete-orphan",
- order_by="VendorDomain.is_primary.desc()"
+ order_by="VendorDomain.is_primary.desc()",
) # Relationship with VendorDomain model for custom domains of the vendor
# Single theme relationship (ONE vendor = ONE theme)
@@ -77,14 +97,12 @@ class Vendor(Base, TimestampMixin):
"VendorTheme",
back_populates="vendor",
uselist=False,
- cascade="all, delete-orphan"
+ cascade="all, delete-orphan",
) # Relationship with VendorTheme model for the active theme of the vendor
# Content pages relationship (vendor can override platform default pages)
content_pages = relationship(
- "ContentPage",
- back_populates="vendor",
- cascade="all, delete-orphan"
+ "ContentPage", back_populates="vendor", cascade="all, delete-orphan"
) # Relationship with ContentPage model for vendor-specific content pages
def __repr__(self):
@@ -121,23 +139,16 @@ class Vendor(Base, TimestampMixin):
"accent": "#ec4899",
"background": "#ffffff",
"text": "#1f2937",
- "border": "#e5e7eb"
- },
- "fonts": {
- "heading": "Inter, sans-serif",
- "body": "Inter, sans-serif"
+ "border": "#e5e7eb",
},
+ "fonts": {"heading": "Inter, sans-serif", "body": "Inter, sans-serif"},
"branding": {
"logo": None,
"logo_dark": None,
"favicon": None,
- "banner": None
- },
- "layout": {
- "style": "grid",
- "header": "fixed",
- "product_card": "modern"
+ "banner": None,
},
+ "layout": {"style": "grid", "header": "fixed", "product_card": "modern"},
"social_links": {},
"custom_css": None,
"css_variables": {
@@ -149,18 +160,22 @@ class Vendor(Base, TimestampMixin):
"--color-border": "#e5e7eb",
"--font-heading": "Inter, sans-serif",
"--font-body": "Inter, sans-serif",
- }
+ },
}
def get_primary_color(self) -> str:
"""Get primary color from active theme."""
theme = self.get_effective_theme()
- return theme.get("colors", {}).get("primary", "#6366f1") # Default to default theme if not found
+ return theme.get("colors", {}).get(
+ "primary", "#6366f1"
+ ) # Default to default theme if not found
def get_logo_url(self) -> str:
"""Get logo URL from active theme."""
theme = self.get_effective_theme()
- return theme.get("branding", {}).get("logo") # Return None or the logo URL if found
+ return theme.get("branding", {}).get(
+ "logo"
+ ) # Return None or the logo URL if found
# ========================================================================
# Domain Helper Methods
@@ -177,7 +192,9 @@ class Vendor(Base, TimestampMixin):
@property
def all_domains(self):
"""Get all active domains (subdomain + custom domains)."""
- domains = [f"{self.subdomain}.{settings.platform_domain}"] # Start with the main subdomain
+ domains = [
+ f"{self.subdomain}.{settings.platform_domain}"
+ ] # Start with the main subdomain
for domain in self.domains:
if domain.is_active:
domains.append(domain.domain) # Add other active custom domains
@@ -186,6 +203,7 @@ class Vendor(Base, TimestampMixin):
class VendorUserType(str, enum.Enum):
"""Types of vendor users."""
+
OWNER = "owner" # Vendor owner (full access to vendor area)
TEAM_MEMBER = "member" # Team member (role-based access to vendor area)
@@ -222,14 +240,18 @@ class VendorUser(Base, TimestampMixin):
invitation_sent_at = Column(DateTime, nullable=True)
invitation_accepted_at = Column(DateTime, nullable=True)
- is_active = Column(Boolean, default=False, nullable=False) # False until invitation accepted
+ is_active = Column(
+ Boolean, default=False, nullable=False
+ ) # False until invitation accepted
"""Indicates whether the VendorUser role is active."""
# Relationships
vendor = relationship("Vendor", back_populates="vendor_users")
"""Relationship to the Vendor model, representing the associated vendor."""
- user = relationship("User", foreign_keys=[user_id], back_populates="vendor_memberships")
+ user = relationship(
+ "User", foreign_keys=[user_id], back_populates="vendor_memberships"
+ )
"""Relationship to the User model, representing the user who holds this role within the vendor."""
inviter = relationship("User", foreign_keys=[invited_by])
@@ -287,6 +309,7 @@ class VendorUser(Base, TimestampMixin):
if self.is_owner:
# Return all possible permissions
from app.core.permissions import VendorPermissions
+
return list(VendorPermissions.__members__.values())
if self.role and self.role.permissions:
@@ -294,6 +317,7 @@ class VendorUser(Base, TimestampMixin):
return []
+
class Role(Base, TimestampMixin):
"""Represents a role within a vendor's system."""
diff --git a/models/database/vendor_domain.py b/models/database/vendor_domain.py
index 5bafea68..e22512e1 100644
--- a/models/database/vendor_domain.py
+++ b/models/database/vendor_domain.py
@@ -3,11 +3,11 @@
Vendor Domain Model - Maps custom domains to vendors
"""
from datetime import datetime, timezone
-from sqlalchemy import (
- Column, Integer, String, Boolean, DateTime,
- ForeignKey, UniqueConstraint, Index
-)
+
+from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Index, Integer,
+ String, UniqueConstraint)
from sqlalchemy.orm import relationship
+
from app.core.database import Base
from models.database.base import TimestampMixin
@@ -21,10 +21,13 @@ class VendorDomain(Base, TimestampMixin):
- shop.mybusiness.com → Vendor 2
- www.customdomain1.com → Vendor 1 (www is stripped)
"""
+
__tablename__ = "vendor_domains"
id = Column(Integer, primary_key=True, index=True)
- vendor_id = Column(Integer, ForeignKey("vendors.id", ondelete="CASCADE"), nullable=False)
+ vendor_id = Column(
+ Integer, ForeignKey("vendors.id", ondelete="CASCADE"), nullable=False
+ )
# Domain configuration
domain = Column(String(255), nullable=False, unique=True, index=True)
@@ -32,7 +35,9 @@ class VendorDomain(Base, TimestampMixin):
is_active = Column(Boolean, default=True, nullable=False)
# SSL/TLS status (for monitoring)
- ssl_status = Column(String(50), default="pending") # pending, active, expired, error
+ ssl_status = Column(
+ String(50), default="pending"
+ ) # pending, active, expired, error
ssl_verified_at = Column(DateTime(timezone=True), nullable=True)
# DNS verification (to confirm domain ownership)
@@ -45,9 +50,9 @@ class VendorDomain(Base, TimestampMixin):
# Constraints
__table_args__ = (
- UniqueConstraint('vendor_id', 'domain', name='uq_vendor_domain'),
- Index('idx_domain_active', 'domain', 'is_active'),
- Index('idx_vendor_primary', 'vendor_id', 'is_primary'),
+ UniqueConstraint("vendor_id", "domain", name="uq_vendor_domain"),
+ Index("idx_domain_active", "domain", "is_active"),
+ Index("idx_vendor_primary", "vendor_id", "is_primary"),
)
def __repr__(self):
diff --git a/models/database/vendor_theme.py b/models/database/vendor_theme.py
index 60db45e5..fcb83cb4 100644
--- a/models/database/vendor_theme.py
+++ b/models/database/vendor_theme.py
@@ -3,8 +3,9 @@
Vendor Theme Configuration Model
Allows each vendor to customize their shop's appearance
"""
-from sqlalchemy import Column, Integer, String, Boolean, Text, JSON, ForeignKey
+from sqlalchemy import JSON, Boolean, Column, ForeignKey, Integer, String, Text
from sqlalchemy.orm import relationship
+
from app.core.database import Base
from models.database.base import TimestampMixin
@@ -22,6 +23,7 @@ class VendorTheme(Base, TimestampMixin):
Theme presets available: default, modern, classic, minimal, vibrant
"""
+
__tablename__ = "vendor_themes"
id = Column(Integer, primary_key=True, index=True)
@@ -29,22 +31,27 @@ class VendorTheme(Base, TimestampMixin):
Integer,
ForeignKey("vendors.id", ondelete="CASCADE"),
nullable=False,
- unique=True # ONE vendor = ONE theme
+ unique=True, # ONE vendor = ONE theme
)
# Basic Theme Settings
- theme_name = Column(String(100), default="default") # default, modern, classic, minimal, vibrant
+ theme_name = Column(
+ String(100), default="default"
+ ) # default, modern, classic, minimal, vibrant
is_active = Column(Boolean, default=True)
# Color Scheme (JSON for flexibility)
- colors = Column(JSON, default={
- "primary": "#6366f1", # Indigo
- "secondary": "#8b5cf6", # Purple
- "accent": "#ec4899", # Pink
- "background": "#ffffff", # White
- "text": "#1f2937", # Gray-800
- "border": "#e5e7eb" # Gray-200
- })
+ colors = Column(
+ JSON,
+ default={
+ "primary": "#6366f1", # Indigo
+ "secondary": "#8b5cf6", # Purple
+ "accent": "#ec4899", # Pink
+ "background": "#ffffff", # White
+ "text": "#1f2937", # Gray-800
+ "border": "#e5e7eb", # Gray-200
+ },
+ )
# Typography
font_family_heading = Column(String(100), default="Inter, sans-serif")
@@ -59,7 +66,9 @@ class VendorTheme(Base, TimestampMixin):
# Layout Preferences
layout_style = Column(String(50), default="grid") # grid, list, masonry
header_style = Column(String(50), default="fixed") # fixed, static, transparent
- product_card_style = Column(String(50), default="modern") # modern, classic, minimal
+ product_card_style = Column(
+ String(50), default="modern"
+ ) # modern, classic, minimal
# Custom CSS (for advanced customization)
custom_css = Column(Text, nullable=True)
@@ -68,14 +77,18 @@ class VendorTheme(Base, TimestampMixin):
social_links = Column(JSON, default={}) # {facebook: "url", instagram: "url", etc.}
# SEO & Meta
- meta_title_template = Column(String(200), nullable=True) # e.g., "{product_name} - {shop_name}"
+ meta_title_template = Column(
+ String(200), nullable=True
+ ) # e.g., "{product_name} - {shop_name}"
meta_description = Column(Text, nullable=True)
# Relationships - FIXED: back_populates must match the relationship name in Vendor model
vendor = relationship("Vendor", back_populates="vendor_theme")
def __repr__(self):
- return f""
+ return (
+ f""
+ )
@property
def primary_color(self):
diff --git a/models/schema/__init__.py b/models/schema/__init__.py
index aff6c19d..9358039d 100644
--- a/models/schema/__init__.py
+++ b/models/schema/__init__.py
@@ -1,14 +1,9 @@
# models/schema/__init__.py
"""API models package - Pydantic models for request/response validation."""
-from . import auth
# Import API model modules
-from . import base
-from . import marketplace_import_job
-from . import marketplace_product
-from . import stats
-from . import inventory
-from . import vendor
+from . import (auth, base, inventory, marketplace_import_job,
+ marketplace_product, stats, vendor)
# Common imports for convenience
from .base import * # Base Pydantic models
diff --git a/models/schema/admin.py b/models/schema/admin.py
index 22318417..03a533b2 100644
--- a/models/schema/admin.py
+++ b/models/schema/admin.py
@@ -12,16 +12,18 @@ This module provides schemas for:
"""
from datetime import datetime
-from typing import Optional, List, Dict, Any
-from pydantic import BaseModel, Field, field_validator
+from typing import Any, Dict, List, Optional
+from pydantic import BaseModel, Field, field_validator
# ============================================================================
# ADMIN AUDIT LOG SCHEMAS
# ============================================================================
+
class AdminAuditLogResponse(BaseModel):
"""Response model for admin audit logs."""
+
id: int
admin_user_id: int
admin_username: Optional[str] = None
@@ -39,6 +41,7 @@ class AdminAuditLogResponse(BaseModel):
class AdminAuditLogFilters(BaseModel):
"""Filters for querying audit logs."""
+
admin_user_id: Optional[int] = None
action: Optional[str] = None
target_type: Optional[str] = None
@@ -50,6 +53,7 @@ class AdminAuditLogFilters(BaseModel):
class AdminAuditLogListResponse(BaseModel):
"""Paginated list of audit logs."""
+
logs: List[AdminAuditLogResponse]
total: int
skip: int
@@ -60,8 +64,10 @@ class AdminAuditLogListResponse(BaseModel):
# ADMIN NOTIFICATION SCHEMAS
# ============================================================================
+
class AdminNotificationCreate(BaseModel):
"""Create admin notification."""
+
type: str = Field(..., max_length=50, description="Notification type")
priority: str = Field(default="normal", description="Priority level")
title: str = Field(..., max_length=200)
@@ -70,10 +76,10 @@ class AdminNotificationCreate(BaseModel):
action_url: Optional[str] = Field(None, max_length=500)
metadata: Optional[Dict[str, Any]] = None
- @field_validator('priority')
+ @field_validator("priority")
@classmethod
def validate_priority(cls, v):
- allowed = ['low', 'normal', 'high', 'critical']
+ allowed = ["low", "normal", "high", "critical"]
if v not in allowed:
raise ValueError(f"Priority must be one of: {', '.join(allowed)}")
return v
@@ -81,6 +87,7 @@ class AdminNotificationCreate(BaseModel):
class AdminNotificationResponse(BaseModel):
"""Admin notification response."""
+
id: int
type: str
priority: str
@@ -99,11 +106,13 @@ class AdminNotificationResponse(BaseModel):
class AdminNotificationUpdate(BaseModel):
"""Mark notification as read."""
+
is_read: bool = True
class AdminNotificationListResponse(BaseModel):
"""Paginated list of notifications."""
+
notifications: List[AdminNotificationResponse]
total: int
unread_count: int
@@ -115,8 +124,10 @@ class AdminNotificationListResponse(BaseModel):
# ADMIN SETTINGS SCHEMAS
# ============================================================================
+
class AdminSettingCreate(BaseModel):
"""Create or update admin setting."""
+
key: str = Field(..., max_length=100, description="Unique setting key")
value: str = Field(..., description="Setting value")
value_type: str = Field(default="string", description="Data type")
@@ -125,25 +136,28 @@ class AdminSettingCreate(BaseModel):
is_encrypted: bool = Field(default=False)
is_public: bool = Field(default=False, description="Can be exposed to frontend")
- @field_validator('value_type')
+ @field_validator("value_type")
@classmethod
def validate_value_type(cls, v):
- allowed = ['string', 'integer', 'boolean', 'json', 'float']
+ allowed = ["string", "integer", "boolean", "json", "float"]
if v not in allowed:
raise ValueError(f"Value type must be one of: {', '.join(allowed)}")
return v
- @field_validator('key')
+ @field_validator("key")
@classmethod
def validate_key_format(cls, v):
# Setting keys should be lowercase with underscores
- if not v.replace('_', '').isalnum():
- raise ValueError("Setting key must contain only letters, numbers, and underscores")
+ if not v.replace("_", "").isalnum():
+ raise ValueError(
+ "Setting key must contain only letters, numbers, and underscores"
+ )
return v.lower()
class AdminSettingResponse(BaseModel):
"""Admin setting response."""
+
id: int
key: str
value: str
@@ -160,12 +174,14 @@ class AdminSettingResponse(BaseModel):
class AdminSettingUpdate(BaseModel):
"""Update admin setting value."""
+
value: str
description: Optional[str] = None
class AdminSettingListResponse(BaseModel):
"""List of settings by category."""
+
settings: List[AdminSettingResponse]
total: int
category: Optional[str] = None
@@ -175,8 +191,10 @@ class AdminSettingListResponse(BaseModel):
# PLATFORM ALERT SCHEMAS
# ============================================================================
+
class PlatformAlertCreate(BaseModel):
"""Create platform alert."""
+
alert_type: str = Field(..., max_length=50)
severity: str = Field(..., description="Alert severity")
title: str = Field(..., max_length=200)
@@ -185,18 +203,25 @@ class PlatformAlertCreate(BaseModel):
affected_systems: Optional[List[str]] = None
auto_generated: bool = Field(default=True)
- @field_validator('severity')
+ @field_validator("severity")
@classmethod
def validate_severity(cls, v):
- allowed = ['info', 'warning', 'error', 'critical']
+ allowed = ["info", "warning", "error", "critical"]
if v not in allowed:
raise ValueError(f"Severity must be one of: {', '.join(allowed)}")
return v
- @field_validator('alert_type')
+ @field_validator("alert_type")
@classmethod
def validate_alert_type(cls, v):
- allowed = ['security', 'performance', 'capacity', 'integration', 'database', 'system']
+ allowed = [
+ "security",
+ "performance",
+ "capacity",
+ "integration",
+ "database",
+ "system",
+ ]
if v not in allowed:
raise ValueError(f"Alert type must be one of: {', '.join(allowed)}")
return v
@@ -204,6 +229,7 @@ class PlatformAlertCreate(BaseModel):
class PlatformAlertResponse(BaseModel):
"""Platform alert response."""
+
id: int
alert_type: str
severity: str
@@ -226,12 +252,14 @@ class PlatformAlertResponse(BaseModel):
class PlatformAlertResolve(BaseModel):
"""Resolve platform alert."""
+
is_resolved: bool = True
resolution_notes: Optional[str] = None
class PlatformAlertListResponse(BaseModel):
"""Paginated list of platform alerts."""
+
alerts: List[PlatformAlertResponse]
total: int
active_count: int
@@ -244,17 +272,19 @@ class PlatformAlertListResponse(BaseModel):
# BULK OPERATION SCHEMAS
# ============================================================================
+
class BulkVendorAction(BaseModel):
"""Bulk actions on vendors."""
+
vendor_ids: List[int] = Field(..., min_length=1, max_length=100)
action: str = Field(..., description="Action to perform")
confirm: bool = Field(default=False, description="Required for destructive actions")
reason: Optional[str] = Field(None, description="Reason for bulk action")
- @field_validator('action')
+ @field_validator("action")
@classmethod
def validate_action(cls, v):
- allowed = ['activate', 'deactivate', 'verify', 'unverify', 'delete']
+ allowed = ["activate", "deactivate", "verify", "unverify", "delete"]
if v not in allowed:
raise ValueError(f"Action must be one of: {', '.join(allowed)}")
return v
@@ -262,6 +292,7 @@ class BulkVendorAction(BaseModel):
class BulkVendorActionResponse(BaseModel):
"""Response for bulk vendor actions."""
+
successful: List[int]
failed: Dict[int, str] # vendor_id -> error_message
total_processed: int
@@ -271,15 +302,16 @@ class BulkVendorActionResponse(BaseModel):
class BulkUserAction(BaseModel):
"""Bulk actions on users."""
+
user_ids: List[int] = Field(..., min_length=1, max_length=100)
action: str = Field(..., description="Action to perform")
confirm: bool = Field(default=False)
reason: Optional[str] = None
- @field_validator('action')
+ @field_validator("action")
@classmethod
def validate_action(cls, v):
- allowed = ['activate', 'deactivate', 'delete']
+ allowed = ["activate", "deactivate", "delete"]
if v not in allowed:
raise ValueError(f"Action must be one of: {', '.join(allowed)}")
return v
@@ -287,6 +319,7 @@ class BulkUserAction(BaseModel):
class BulkUserActionResponse(BaseModel):
"""Response for bulk user actions."""
+
successful: List[int]
failed: Dict[int, str]
total_processed: int
@@ -298,8 +331,10 @@ class BulkUserActionResponse(BaseModel):
# ADMIN DASHBOARD SCHEMAS
# ============================================================================
+
class AdminDashboardStats(BaseModel):
"""Comprehensive admin dashboard statistics."""
+
platform: Dict[str, Any]
users: Dict[str, Any]
vendors: Dict[str, Any]
@@ -317,8 +352,10 @@ class AdminDashboardStats(BaseModel):
# SYSTEM HEALTH SCHEMAS
# ============================================================================
+
class ComponentHealthStatus(BaseModel):
"""Health status for a system component."""
+
status: str # healthy, degraded, unhealthy
response_time_ms: Optional[float] = None
error_message: Optional[str] = None
@@ -328,6 +365,7 @@ class ComponentHealthStatus(BaseModel):
class SystemHealthResponse(BaseModel):
"""System health check response."""
+
overall_status: str # healthy, degraded, critical
database: ComponentHealthStatus
redis: ComponentHealthStatus
@@ -342,8 +380,10 @@ class SystemHealthResponse(BaseModel):
# ADMIN SESSION SCHEMAS
# ============================================================================
+
class AdminSessionResponse(BaseModel):
"""Admin session information."""
+
id: int
admin_user_id: int
admin_username: Optional[str] = None
@@ -360,6 +400,7 @@ class AdminSessionResponse(BaseModel):
class AdminSessionListResponse(BaseModel):
"""List of admin sessions."""
+
sessions: List[AdminSessionResponse]
total: int
active_count: int
diff --git a/models/schema/auth.py b/models/schema/auth.py
index 95ced43c..5d57aedc 100644
--- a/models/schema/auth.py
+++ b/models/schema/auth.py
@@ -17,7 +17,9 @@ class UserRegister(BaseModel):
@classmethod
def validate_username(cls, v):
if not re.match(r"^[a-zA-Z0-9_]+$", v):
- raise ValueError("Username must contain only letters, numbers, or underscores")
+ raise ValueError(
+ "Username must contain only letters, numbers, or underscores"
+ )
return v.lower().strip()
@field_validator("password")
@@ -31,7 +33,9 @@ class UserRegister(BaseModel):
class UserLogin(BaseModel):
email_or_username: str = Field(..., description="Username or email address")
password: str = Field(..., description="Password")
- vendor_code: Optional[str] = Field(None, description="Optional vendor code for context")
+ vendor_code: Optional[str] = Field(
+ None, description="Optional vendor code for context"
+ )
@field_validator("email_or_username")
@classmethod
diff --git a/models/schema/cart.py b/models/schema/cart.py
index be4511c0..a5c705b7 100644
--- a/models/schema/cart.py
+++ b/models/schema/cart.py
@@ -5,21 +5,24 @@ Pydantic schemas for shopping cart operations.
from datetime import datetime
from typing import List, Optional
-from pydantic import BaseModel, Field, ConfigDict
+from pydantic import BaseModel, ConfigDict, Field
# ============================================================================
# Request Schemas
# ============================================================================
+
class AddToCartRequest(BaseModel):
"""Request model for adding items to cart."""
+
product_id: int = Field(..., description="Product ID to add", gt=0)
quantity: int = Field(1, ge=1, description="Quantity to add")
class UpdateCartItemRequest(BaseModel):
"""Request model for updating cart item quantity."""
+
quantity: int = Field(..., ge=1, description="New quantity (must be >= 1)")
@@ -27,23 +30,30 @@ class UpdateCartItemRequest(BaseModel):
# Response Schemas
# ============================================================================
+
class CartItemResponse(BaseModel):
"""Response model for a single cart item."""
+
model_config = ConfigDict(from_attributes=True)
product_id: int = Field(..., description="Product ID")
product_name: str = Field(..., description="Product name")
quantity: int = Field(..., description="Quantity in cart")
price: float = Field(..., description="Price per unit when added to cart")
- line_total: float = Field(..., description="Total price for this line (price * quantity)")
+ line_total: float = Field(
+ ..., description="Total price for this line (price * quantity)"
+ )
image_url: Optional[str] = Field(None, description="Product image URL")
class CartResponse(BaseModel):
"""Response model for shopping cart."""
+
vendor_id: int = Field(..., description="Vendor ID")
session_id: str = Field(..., description="Shopping session ID")
- items: List[CartItemResponse] = Field(default_factory=list, description="Cart items")
+ items: List[CartItemResponse] = Field(
+ default_factory=list, description="Cart items"
+ )
subtotal: float = Field(..., description="Subtotal of all items")
total: float = Field(..., description="Total amount (currently same as subtotal)")
item_count: int = Field(..., description="Total number of items in cart")
@@ -63,18 +73,22 @@ class CartResponse(BaseModel):
items=items,
subtotal=cart_dict["subtotal"],
total=cart_dict["total"],
- item_count=len(items)
+ item_count=len(items),
)
class CartOperationResponse(BaseModel):
"""Response model for cart operations (add, update, remove)."""
+
message: str = Field(..., description="Operation result message")
product_id: int = Field(..., description="Product ID affected")
- quantity: Optional[int] = Field(None, description="New quantity (for add/update operations)")
+ quantity: Optional[int] = Field(
+ None, description="New quantity (for add/update operations)"
+ )
class ClearCartResponse(BaseModel):
"""Response model for clearing cart."""
+
message: str = Field(..., description="Operation result message")
items_removed: int = Field(..., description="Number of items removed from cart")
diff --git a/models/schema/customer.py b/models/schema/customer.py
index 3be4c89a..9c25ac28 100644
--- a/models/schema/customer.py
+++ b/models/schema/customer.py
@@ -5,35 +5,34 @@ Pydantic schema for customer-related operations.
from datetime import datetime
from decimal import Decimal
-from typing import Optional, Dict, Any, List
-from pydantic import BaseModel, EmailStr, Field, field_validator
+from typing import Any, Dict, List, Optional
+from pydantic import BaseModel, EmailStr, Field, field_validator
# ============================================================================
# Customer Registration & Authentication
# ============================================================================
+
class CustomerRegister(BaseModel):
"""Schema for customer registration."""
email: EmailStr = Field(..., description="Customer email address")
password: str = Field(
- ...,
- min_length=8,
- description="Password (minimum 8 characters)"
+ ..., min_length=8, description="Password (minimum 8 characters)"
)
first_name: str = Field(..., min_length=1, max_length=100)
last_name: str = Field(..., min_length=1, max_length=100)
phone: Optional[str] = Field(None, max_length=50)
marketing_consent: bool = Field(default=False)
- @field_validator('email')
+ @field_validator("email")
@classmethod
def email_lowercase(cls, v: str) -> str:
"""Convert email to lowercase."""
return v.lower()
- @field_validator('password')
+ @field_validator("password")
@classmethod
def password_strength(cls, v: str) -> str:
"""Validate password strength."""
@@ -55,7 +54,7 @@ class CustomerUpdate(BaseModel):
phone: Optional[str] = Field(None, max_length=50)
marketing_consent: Optional[bool] = None
- @field_validator('email')
+ @field_validator("email")
@classmethod
def email_lowercase(cls, v: Optional[str]) -> Optional[str]:
"""Convert email to lowercase."""
@@ -66,6 +65,7 @@ class CustomerUpdate(BaseModel):
# Customer Response
# ============================================================================
+
class CustomerResponse(BaseModel):
"""Schema for customer response (excludes password)."""
@@ -84,9 +84,7 @@ class CustomerResponse(BaseModel):
created_at: datetime
updated_at: datetime
- model_config = {
- "from_attributes": True
- }
+ model_config = {"from_attributes": True}
@property
def full_name(self) -> str:
@@ -110,6 +108,7 @@ class CustomerListResponse(BaseModel):
# Customer Address
# ============================================================================
+
class CustomerAddressCreate(BaseModel):
"""Schema for creating customer address."""
@@ -159,14 +158,14 @@ class CustomerAddressResponse(BaseModel):
created_at: datetime
updated_at: datetime
- model_config = {
- "from_attributes": True
- }
+ model_config = {"from_attributes": True}
+
# ============================================================================
# Customer Preferences
# ============================================================================
+
class CustomerPreferencesUpdate(BaseModel):
"""Schema for updating customer preferences."""
diff --git a/models/schema/inventory.py b/models/schema/inventory.py
index 759383e2..89b3b584 100644
--- a/models/schema/inventory.py
+++ b/models/schema/inventory.py
@@ -1,6 +1,7 @@
# models/schema/inventory.py
from datetime import datetime
from typing import List, Optional
+
from pydantic import BaseModel, ConfigDict, Field
@@ -11,16 +12,21 @@ class InventoryBase(BaseModel):
class InventoryCreate(InventoryBase):
"""Set exact inventory quantity (replaces existing)."""
+
quantity: int = Field(..., description="Exact inventory quantity", ge=0)
class InventoryAdjust(InventoryBase):
"""Add or remove inventory quantity."""
- quantity: int = Field(..., description="Quantity to add (positive) or remove (negative)")
+
+ quantity: int = Field(
+ ..., description="Quantity to add (positive) or remove (negative)"
+ )
class InventoryUpdate(BaseModel):
"""Update inventory fields."""
+
quantity: Optional[int] = Field(None, ge=0)
reserved_quantity: Optional[int] = Field(None, ge=0)
location: Optional[str] = None
@@ -28,6 +34,7 @@ class InventoryUpdate(BaseModel):
class InventoryReserve(BaseModel):
"""Reserve inventory for orders."""
+
product_id: int
location: str
quantity: int = Field(..., gt=0)
@@ -60,6 +67,7 @@ class InventoryLocationResponse(BaseModel):
class ProductInventorySummary(BaseModel):
"""Inventory summary for a product."""
+
product_id: int
vendor_id: int
product_sku: Optional[str]
diff --git a/models/schema/marketplace.py b/models/schema/marketplace.py
index 22a01dd7..4097cbf4 100644
--- a/models/schema/marketplace.py
+++ b/models/schema/marketplace.py
@@ -1 +1 @@
-# Marketplace import job models
+# Marketplace import job models
diff --git a/models/schema/marketplace_import_job.py b/models/schema/marketplace_import_job.py
index 5a72c244..ece440be 100644
--- a/models/schema/marketplace_import_job.py
+++ b/models/schema/marketplace_import_job.py
@@ -1,5 +1,6 @@
from datetime import datetime
from typing import Optional
+
from pydantic import BaseModel, ConfigDict, Field, field_validator
@@ -8,9 +9,12 @@ class MarketplaceImportJobRequest(BaseModel):
Note: vendor_id is injected by middleware, not from request body.
"""
+
source_url: str = Field(..., description="URL to CSV file from marketplace")
marketplace: str = Field(default="Letzshop", description="Marketplace name")
- batch_size: Optional[int] = Field(1000, description="Processing batch size", ge=100, le=10000)
+ batch_size: Optional[int] = Field(
+ 1000, description="Processing batch size", ge=100, le=10000
+ )
@field_validator("source_url")
@classmethod
@@ -28,6 +32,7 @@ class MarketplaceImportJobRequest(BaseModel):
class MarketplaceImportJobResponse(BaseModel):
"""Response schema for marketplace import job."""
+
model_config = ConfigDict(from_attributes=True)
job_id: int
@@ -55,6 +60,7 @@ class MarketplaceImportJobResponse(BaseModel):
class MarketplaceImportJobListResponse(BaseModel):
"""Response schema for list of import jobs."""
+
jobs: list[MarketplaceImportJobResponse]
total: int
skip: int
@@ -63,6 +69,7 @@ class MarketplaceImportJobListResponse(BaseModel):
class MarketplaceImportJobStatusUpdate(BaseModel):
"""Schema for updating import job status (internal use)."""
+
status: str
imported_count: Optional[int] = None
updated_count: Optional[int] = None
diff --git a/models/schema/marketplace_product.py b/models/schema/marketplace_product.py
index 56e5406a..e459f1a3 100644
--- a/models/schema/marketplace_product.py
+++ b/models/schema/marketplace_product.py
@@ -1,9 +1,12 @@
# models/schema/marketplace_products.py - Simplified validation
from datetime import datetime
from typing import List, Optional
+
from pydantic import BaseModel, ConfigDict, Field
+
from models.schema.inventory import ProductInventorySummary
+
class MarketplaceProductBase(BaseModel):
marketplace_product_id: Optional[str] = None
title: Optional[str] = None
@@ -45,27 +48,34 @@ class MarketplaceProductBase(BaseModel):
marketplace: Optional[str] = None
vendor_name: Optional[str] = None
+
class MarketplaceProductCreate(MarketplaceProductBase):
- marketplace_product_id: str = Field(..., description="MarketplaceProduct identifier")
+ marketplace_product_id: str = Field(
+ ..., description="MarketplaceProduct identifier"
+ )
title: str = Field(..., description="MarketplaceProduct title")
# Removed: min_length constraints and custom validators
# Service will handle empty string validation with proper domain exceptions
+
class MarketplaceProductUpdate(MarketplaceProductBase):
pass
+
class MarketplaceProductResponse(MarketplaceProductBase):
model_config = ConfigDict(from_attributes=True)
id: int
created_at: datetime
updated_at: datetime
+
class MarketplaceProductListResponse(BaseModel):
products: List[MarketplaceProductResponse]
total: int
skip: int
limit: int
+
class MarketplaceProductDetailResponse(BaseModel):
product: MarketplaceProductResponse
inventory_info: Optional[ProductInventorySummary] = None
diff --git a/models/schema/media.py b/models/schema/media.py
index cd0a71f0..cb2935b4 100644
--- a/models/schema/media.py
+++ b/models/schema/media.py
@@ -1 +1 @@
-# Media/file management models
+# Media/file management models
diff --git a/models/schema/monitoring.py b/models/schema/monitoring.py
index 84f2d5a3..c7ab11c1 100644
--- a/models/schema/monitoring.py
+++ b/models/schema/monitoring.py
@@ -1 +1 @@
-# Monitoring models
+# Monitoring models
diff --git a/models/schema/notification.py b/models/schema/notification.py
index cb27dfba..36f54ac4 100644
--- a/models/schema/notification.py
+++ b/models/schema/notification.py
@@ -1 +1 @@
-# Notification models
+# Notification models
diff --git a/models/schema/order.py b/models/schema/order.py
index 6d34e9ef..a5a1d371 100644
--- a/models/schema/order.py
+++ b/models/schema/order.py
@@ -4,23 +4,26 @@ Pydantic schema for order operations.
"""
from datetime import datetime
-from typing import List, Optional
from decimal import Decimal
-from pydantic import BaseModel, Field, ConfigDict
+from typing import List, Optional
+from pydantic import BaseModel, ConfigDict, Field
# ============================================================================
# Order Item Schemas
# ============================================================================
+
class OrderItemCreate(BaseModel):
"""Schema for creating an order item."""
+
product_id: int
quantity: int = Field(..., ge=1)
class OrderItemResponse(BaseModel):
"""Schema for order item response."""
+
model_config = ConfigDict(from_attributes=True)
id: int
@@ -41,8 +44,10 @@ class OrderItemResponse(BaseModel):
# Order Address Schemas
# ============================================================================
+
class OrderAddressCreate(BaseModel):
"""Schema for order address (shipping/billing)."""
+
first_name: str = Field(..., min_length=1, max_length=100)
last_name: str = Field(..., min_length=1, max_length=100)
company: Optional[str] = Field(None, max_length=200)
@@ -55,6 +60,7 @@ class OrderAddressCreate(BaseModel):
class OrderAddressResponse(BaseModel):
"""Schema for order address response."""
+
model_config = ConfigDict(from_attributes=True)
id: int
@@ -73,8 +79,10 @@ class OrderAddressResponse(BaseModel):
# Order Create/Update Schemas
# ============================================================================
+
class OrderCreate(BaseModel):
"""Schema for creating an order."""
+
customer_id: Optional[int] = None # Optional for guest checkout
items: List[OrderItemCreate] = Field(..., min_length=1)
@@ -92,9 +100,9 @@ class OrderCreate(BaseModel):
class OrderUpdate(BaseModel):
"""Schema for updating order status."""
+
status: Optional[str] = Field(
- None,
- pattern="^(pending|processing|shipped|delivered|cancelled|refunded)$"
+ None, pattern="^(pending|processing|shipped|delivered|cancelled|refunded)$"
)
tracking_number: Optional[str] = None
internal_notes: Optional[str] = None
@@ -104,8 +112,10 @@ class OrderUpdate(BaseModel):
# Order Response Schemas
# ============================================================================
+
class OrderResponse(BaseModel):
"""Schema for order response."""
+
model_config = ConfigDict(from_attributes=True)
id: int
@@ -141,6 +151,7 @@ class OrderResponse(BaseModel):
class OrderDetailResponse(OrderResponse):
"""Schema for detailed order response with items and addresses."""
+
items: List[OrderItemResponse]
shipping_address: OrderAddressResponse
billing_address: OrderAddressResponse
@@ -148,6 +159,7 @@ class OrderDetailResponse(OrderResponse):
class OrderListResponse(BaseModel):
"""Schema for paginated order list."""
+
orders: List[OrderResponse]
total: int
skip: int
diff --git a/models/schema/payment.py b/models/schema/payment.py
index 1442207e..1ad2d971 100644
--- a/models/schema/payment.py
+++ b/models/schema/payment.py
@@ -1 +1 @@
-# Payment models
+# Payment models
diff --git a/models/schema/product.py b/models/schema/product.py
index 7af82de7..1844e955 100644
--- a/models/schema/product.py
+++ b/models/schema/product.py
@@ -1,14 +1,20 @@
# models/schema/product.py
from datetime import datetime
from typing import List, Optional
+
from pydantic import BaseModel, ConfigDict, Field
-from models.schema.marketplace_product import MarketplaceProductResponse
+
from models.schema.inventory import InventoryLocationResponse
+from models.schema.marketplace_product import MarketplaceProductResponse
class ProductCreate(BaseModel):
- marketplace_product_id: int = Field(..., description="MarketplaceProduct ID to add to vendor catalog")
- product_id: Optional[str] = Field(None, description="Vendor's internal SKU/product ID")
+ marketplace_product_id: int = Field(
+ ..., description="MarketplaceProduct ID to add to vendor catalog"
+ )
+ product_id: Optional[str] = Field(
+ None, description="Vendor's internal SKU/product ID"
+ )
price: Optional[float] = Field(None, ge=0)
sale_price: Optional[float] = Field(None, ge=0)
currency: Optional[str] = None
@@ -59,6 +65,7 @@ class ProductResponse(BaseModel):
class ProductDetailResponse(ProductResponse):
"""Product with full inventory details."""
+
inventory_locations: List[InventoryLocationResponse] = []
diff --git a/models/schema/search.py b/models/schema/search.py
index 69abe85d..2e8b1772 100644
--- a/models/schema/search.py
+++ b/models/schema/search.py
@@ -1 +1 @@
-# Search models
+# Search models
diff --git a/models/schema/stats.py b/models/schema/stats.py
index 0bdf9b65..4ef278fd 100644
--- a/models/schema/stats.py
+++ b/models/schema/stats.py
@@ -1,6 +1,6 @@
import re
-from decimal import Decimal
from datetime import datetime
+from decimal import Decimal
from typing import List, Optional
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator
@@ -22,10 +22,12 @@ class MarketplaceStatsResponse(BaseModel):
unique_vendors: int
unique_brands: int
+
# ============================================================================
# Customer Statistics
# ============================================================================
+
class CustomerStatsResponse(BaseModel):
"""Schema for customer statistics."""
@@ -42,8 +44,10 @@ class CustomerStatsResponse(BaseModel):
# Order Statistics
# ============================================================================
+
class OrderStatsResponse(BaseModel):
"""Schema for order statistics."""
+
total_orders: int
pending_orders: int
processing_orders: int
@@ -53,13 +57,16 @@ class OrderStatsResponse(BaseModel):
total_revenue: Decimal
average_order_value: Decimal
+
# ============================================================================
# Vendor Statistics
# ============================================================================
+
class VendorStatsResponse(BaseModel):
"""Vendor statistics response schema."""
+
total: int = Field(..., description="Total number of vendors")
verified: int = Field(..., description="Number of verified vendors")
pending: int = Field(..., description="Number of pending verification vendors")
- inactive: int = Field(..., description="Number of inactive vendors")
\ No newline at end of file
+ inactive: int = Field(..., description="Number of inactive vendors")
diff --git a/models/schema/team.py b/models/schema/team.py
index ec716979..5d3f5ca0 100644
--- a/models/schema/team.py
+++ b/models/schema/team.py
@@ -10,33 +10,40 @@ This module defines request/response schemas for:
"""
from datetime import datetime
-from typing import Optional, List
-from pydantic import BaseModel, EmailStr, Field, field_validator
+from typing import List, Optional
+from pydantic import BaseModel, EmailStr, Field, field_validator
# ============================================================================
# Role Schemas
# ============================================================================
+
class RoleBase(BaseModel):
"""Base role schema."""
+
name: str = Field(..., min_length=1, max_length=100, description="Role name")
- permissions: List[str] = Field(default_factory=list, description="List of permission strings")
+ permissions: List[str] = Field(
+ default_factory=list, description="List of permission strings"
+ )
class RoleCreate(RoleBase):
"""Schema for creating a role."""
+
pass
class RoleUpdate(BaseModel):
"""Schema for updating a role."""
+
name: Optional[str] = Field(None, min_length=1, max_length=100)
permissions: Optional[List[str]] = None
class RoleResponse(RoleBase):
"""Schema for role response."""
+
id: int
vendor_id: int
created_at: datetime
@@ -48,6 +55,7 @@ class RoleResponse(RoleBase):
class RoleListResponse(BaseModel):
"""Schema for role list response."""
+
roles: List[RoleResponse]
total: int
@@ -56,8 +64,10 @@ class RoleListResponse(BaseModel):
# Team Member Schemas
# ============================================================================
+
class TeamMemberBase(BaseModel):
"""Base team member schema."""
+
email: EmailStr = Field(..., description="Team member email address")
first_name: Optional[str] = Field(None, max_length=100)
last_name: Optional[str] = Field(None, max_length=100)
@@ -65,30 +75,34 @@ class TeamMemberBase(BaseModel):
class TeamMemberInvite(TeamMemberBase):
"""Schema for inviting a team member."""
- role_id: Optional[int] = Field(None, description="Role ID to assign (for preset roles)")
- role_name: Optional[str] = Field(None, description="Role name (manager, staff, support, etc.)")
+
+ role_id: Optional[int] = Field(
+ None, description="Role ID to assign (for preset roles)"
+ )
+ role_name: Optional[str] = Field(
+ None, description="Role name (manager, staff, support, etc.)"
+ )
custom_permissions: Optional[List[str]] = Field(
- None,
- description="Custom permissions (overrides role preset)"
+ None, description="Custom permissions (overrides role preset)"
)
- @field_validator('role_name')
+ @field_validator("role_name")
def validate_role_name(cls, v):
"""Validate role name is in allowed presets."""
if v is not None:
- allowed_roles = ['manager', 'staff', 'support', 'viewer', 'marketing']
+ allowed_roles = ["manager", "staff", "support", "viewer", "marketing"]
if v.lower() not in allowed_roles:
raise ValueError(
f"Role name must be one of: {', '.join(allowed_roles)}"
)
return v.lower() if v else v
- @field_validator('custom_permissions')
+ @field_validator("custom_permissions")
def validate_custom_permissions(cls, v, values):
"""Ensure either role_id/role_name OR custom_permissions is provided."""
if v is not None and len(v) > 0:
# If custom permissions provided, role_name should be provided too
- if 'role_name' not in values or not values['role_name']:
+ if "role_name" not in values or not values["role_name"]:
raise ValueError(
"role_name is required when providing custom_permissions"
)
@@ -97,12 +111,14 @@ class TeamMemberInvite(TeamMemberBase):
class TeamMemberUpdate(BaseModel):
"""Schema for updating a team member."""
+
role_id: Optional[int] = Field(None, description="New role ID")
is_active: Optional[bool] = Field(None, description="Active status")
class TeamMemberResponse(BaseModel):
"""Schema for team member response."""
+
id: int = Field(..., description="User ID")
email: EmailStr
username: str
@@ -112,15 +128,18 @@ class TeamMemberResponse(BaseModel):
user_type: str = Field(..., description="'owner' or 'member'")
role_name: str = Field(..., description="Role name")
role_id: Optional[int]
- permissions: List[str] = Field(default_factory=list, description="User's permissions")
+ permissions: List[str] = Field(
+ default_factory=list, description="User's permissions"
+ )
is_active: bool
is_owner: bool
invitation_pending: bool = Field(
- default=False,
- description="True if invitation not yet accepted"
+ default=False, description="True if invitation not yet accepted"
)
invited_at: Optional[datetime] = Field(None, description="When invitation was sent")
- accepted_at: Optional[datetime] = Field(None, description="When invitation was accepted")
+ accepted_at: Optional[datetime] = Field(
+ None, description="When invitation was accepted"
+ )
joined_at: datetime = Field(..., description="When user joined vendor")
class Config:
@@ -129,6 +148,7 @@ class TeamMemberResponse(BaseModel):
class TeamMemberListResponse(BaseModel):
"""Schema for team member list response."""
+
members: List[TeamMemberResponse]
total: int
active_count: int
@@ -139,19 +159,20 @@ class TeamMemberListResponse(BaseModel):
# Invitation Schemas
# ============================================================================
+
class InvitationAccept(BaseModel):
"""Schema for accepting a team invitation."""
- invitation_token: str = Field(..., min_length=32, description="Invitation token from email")
+
+ invitation_token: str = Field(
+ ..., min_length=32, description="Invitation token from email"
+ )
password: str = Field(
- ...,
- min_length=8,
- max_length=128,
- description="Password for new account"
+ ..., min_length=8, max_length=128, description="Password for new account"
)
first_name: str = Field(..., min_length=1, max_length=100)
last_name: str = Field(..., min_length=1, max_length=100)
- @field_validator('password')
+ @field_validator("password")
def validate_password_strength(cls, v):
"""Validate password meets minimum requirements."""
if len(v) < 8:
@@ -172,18 +193,19 @@ class InvitationAccept(BaseModel):
class InvitationResponse(BaseModel):
"""Schema for invitation response."""
+
message: str
email: EmailStr
role: str
invitation_token: Optional[str] = Field(
- None,
- description="Token (only returned in dev/test environments)"
+ None, description="Token (only returned in dev/test environments)"
)
invitation_sent: bool = Field(default=True)
class InvitationAcceptResponse(BaseModel):
"""Schema for invitation acceptance response."""
+
message: str
vendor: dict = Field(..., description="Vendor information")
user: dict = Field(..., description="User information")
@@ -194,8 +216,10 @@ class InvitationAcceptResponse(BaseModel):
# Team Statistics Schema
# ============================================================================
+
class TeamStatistics(BaseModel):
"""Schema for team statistics."""
+
total_members: int
active_members: int
inactive_members: int
@@ -203,8 +227,7 @@ class TeamStatistics(BaseModel):
owners: int
team_members: int
roles_breakdown: dict = Field(
- default_factory=dict,
- description="Count of members per role"
+ default_factory=dict, description="Count of members per role"
)
@@ -212,13 +235,18 @@ class TeamStatistics(BaseModel):
# Bulk Operations Schemas
# ============================================================================
+
class BulkRemoveRequest(BaseModel):
"""Schema for bulk removing team members."""
- user_ids: List[int] = Field(..., min_items=1, description="List of user IDs to remove")
+
+ user_ids: List[int] = Field(
+ ..., min_items=1, description="List of user IDs to remove"
+ )
class BulkRemoveResponse(BaseModel):
"""Schema for bulk remove response."""
+
success_count: int
failed_count: int
errors: List[dict] = Field(default_factory=list)
@@ -228,21 +256,27 @@ class BulkRemoveResponse(BaseModel):
# Permission Check Schemas
# ============================================================================
+
class PermissionCheckRequest(BaseModel):
"""Schema for checking permissions."""
+
permissions: List[str] = Field(..., min_items=1, description="Permissions to check")
class PermissionCheckResponse(BaseModel):
"""Schema for permission check response."""
+
has_all: bool = Field(..., description="True if user has all permissions")
has_any: bool = Field(..., description="True if user has any permission")
granted: List[str] = Field(default_factory=list, description="Permissions user has")
- denied: List[str] = Field(default_factory=list, description="Permissions user lacks")
+ denied: List[str] = Field(
+ default_factory=list, description="Permissions user lacks"
+ )
class UserPermissionsResponse(BaseModel):
"""Schema for user's permissions response."""
+
permissions: List[str] = Field(default_factory=list)
permission_count: int
is_owner: bool
@@ -253,8 +287,10 @@ class UserPermissionsResponse(BaseModel):
# Error Response Schema
# ============================================================================
+
class TeamErrorResponse(BaseModel):
"""Schema for team operation errors."""
+
error_code: str
message: str
details: Optional[dict] = None
diff --git a/models/schema/vendor.py b/models/schema/vendor.py
index f6080a69..e622f407 100644
--- a/models/schema/vendor.py
+++ b/models/schema/vendor.py
@@ -15,7 +15,8 @@ Schemas include:
import re
from datetime import datetime
-from typing import Dict, List, Optional, Any
+from typing import Any, Dict, List, Optional
+
from pydantic import BaseModel, ConfigDict, Field, field_validator
@@ -27,32 +28,26 @@ class VendorCreate(BaseModel):
...,
description="Unique vendor identifier (e.g., TECHSTORE)",
min_length=2,
- max_length=50
+ max_length=50,
)
subdomain: str = Field(
- ...,
- description="Unique subdomain for the vendor",
- min_length=2,
- max_length=100
+ ..., description="Unique subdomain for the vendor", min_length=2, max_length=100
)
name: str = Field(
- ...,
- description="Display name of the vendor",
- min_length=2,
- max_length=255
+ ..., description="Display name of the vendor", min_length=2, max_length=255
)
description: Optional[str] = Field(None, description="Vendor description")
# Owner Information (Creates User Account)
owner_email: str = Field(
...,
- description="Email for the vendor owner (used for login and authentication)"
+ description="Email for the vendor owner (used for login and authentication)",
)
# Business Contact Information (Vendor Fields)
contact_email: Optional[str] = Field(
None,
- description="Public business contact email (defaults to owner_email if not provided)"
+ description="Public business contact email (defaults to owner_email if not provided)",
)
contact_phone: Optional[str] = Field(None, description="Contact phone number")
website: Optional[str] = Field(None, description="Website URL")
@@ -78,8 +73,10 @@ class VendorCreate(BaseModel):
@classmethod
def validate_subdomain(cls, v):
"""Validate subdomain format: lowercase alphanumeric with hyphens."""
- if v and not re.match(r'^[a-z0-9][a-z0-9-]*[a-z0-9]$', v):
- raise ValueError("Subdomain must contain only lowercase letters, numbers, and hyphens")
+ if v and not re.match(r"^[a-z0-9][a-z0-9-]*[a-z0-9]$", v):
+ raise ValueError(
+ "Subdomain must contain only lowercase letters, numbers, and hyphens"
+ )
return v.lower() if v else v
@field_validator("vendor_code")
@@ -104,8 +101,7 @@ class VendorUpdate(BaseModel):
# Business Contact Information (Vendor Fields)
contact_email: Optional[str] = Field(
- None,
- description="Public business contact email"
+ None, description="Public business contact email"
)
contact_phone: Optional[str] = None
website: Optional[str] = None
@@ -142,6 +138,7 @@ class VendorUpdate(BaseModel):
class VendorResponse(BaseModel):
"""Standard schema for vendor response data."""
+
model_config = ConfigDict(from_attributes=True)
id: int
@@ -184,13 +181,9 @@ class VendorDetailResponse(VendorResponse):
"""
owner_email: str = Field(
- ...,
- description="Email of the vendor owner (for login/authentication)"
- )
- owner_username: str = Field(
- ...,
- description="Username of the vendor owner"
+ ..., description="Email of the vendor owner (for login/authentication)"
)
+ owner_username: str = Field(..., description="Username of the vendor owner")
class VendorCreateResponse(VendorDetailResponse):
@@ -201,17 +194,14 @@ class VendorCreateResponse(VendorDetailResponse):
"""
temporary_password: str = Field(
- ...,
- description="Temporary password for owner (SHOWN ONLY ONCE)"
- )
- login_url: Optional[str] = Field(
- None,
- description="URL for vendor owner to login"
+ ..., description="Temporary password for owner (SHOWN ONLY ONCE)"
)
+ login_url: Optional[str] = Field(None, description="URL for vendor owner to login")
class VendorListResponse(BaseModel):
"""Schema for paginated vendor list."""
+
vendors: List[VendorResponse]
total: int
skip: int
@@ -220,6 +210,7 @@ class VendorListResponse(BaseModel):
class VendorSummary(BaseModel):
"""Lightweight vendor summary for dropdowns and quick references."""
+
model_config = ConfigDict(from_attributes=True)
id: int
@@ -239,20 +230,17 @@ class VendorTransferOwnership(BaseModel):
"""
new_owner_user_id: int = Field(
- ...,
- description="ID of the user who will become the new owner",
- gt=0
+ ..., description="ID of the user who will become the new owner", gt=0
)
confirm_transfer: bool = Field(
- ...,
- description="Must be true to confirm ownership transfer"
+ ..., description="Must be true to confirm ownership transfer"
)
transfer_reason: Optional[str] = Field(
None,
max_length=500,
- description="Reason for ownership transfer (for audit logs)"
+ description="Reason for ownership transfer (for audit logs)",
)
@field_validator("confirm_transfer")
@@ -273,12 +261,10 @@ class VendorTransferOwnershipResponse(BaseModel):
vendor_name: str
old_owner: Dict[str, Any] = Field(
- ...,
- description="Information about the previous owner"
+ ..., description="Information about the previous owner"
)
new_owner: Dict[str, Any] = Field(
- ...,
- description="Information about the new owner"
+ ..., description="Information about the new owner"
)
transferred_at: datetime
diff --git a/models/schema/vendor_domain.py b/models/schema/vendor_domain.py
index aca0eab3..a99b47bb 100644
--- a/models/schema/vendor_domain.py
+++ b/models/schema/vendor_domain.py
@@ -12,7 +12,8 @@ Schemas include:
import re
from datetime import datetime
-from typing import List, Optional, Dict
+from typing import Dict, List, Optional
+
from pydantic import BaseModel, ConfigDict, Field, field_validator
@@ -23,14 +24,13 @@ class VendorDomainCreate(BaseModel):
...,
description="Custom domain (e.g., myshop.com or shop.mybrand.com)",
min_length=3,
- max_length=255
+ max_length=255,
)
is_primary: bool = Field(
- default=False,
- description="Set as primary domain for the vendor"
+ default=False, description="Set as primary domain for the vendor"
)
- @field_validator('domain')
+ @field_validator("domain")
@classmethod
def validate_domain(cls, v: str) -> str:
"""Validate and normalize domain."""
@@ -44,20 +44,22 @@ class VendorDomainCreate(BaseModel):
domain = domain.lower().strip()
# Basic validation
- if not domain or '/' in domain:
+ if not domain or "/" in domain:
raise ValueError("Invalid domain format")
- if '.' not in domain:
+ if "." not in domain:
raise ValueError("Domain must have at least one dot")
# Check for reserved subdomains
- reserved = ['www', 'admin', 'api', 'mail', 'smtp', 'ftp', 'cpanel', 'webmail']
- first_part = domain.split('.')[0]
+ reserved = ["www", "admin", "api", "mail", "smtp", "ftp", "cpanel", "webmail"]
+ first_part = domain.split(".")[0]
if first_part in reserved:
- raise ValueError(f"Domain cannot start with reserved subdomain: {first_part}")
+ raise ValueError(
+ f"Domain cannot start with reserved subdomain: {first_part}"
+ )
# Validate domain format (basic regex)
- domain_pattern = r'^[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?(\.[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?)*$'
+ domain_pattern = r"^[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?(\.[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?)*$"
if not re.match(domain_pattern, domain):
raise ValueError("Invalid domain format")
@@ -75,6 +77,7 @@ class VendorDomainUpdate(BaseModel):
class VendorDomainResponse(BaseModel):
"""Standard schema for vendor domain response."""
+
model_config = ConfigDict(from_attributes=True)
id: int
diff --git a/models/schema/vendor_theme.py b/models/schema/vendor_theme.py
index 5697ed18..9ba482fc 100644
--- a/models/schema/vendor_theme.py
+++ b/models/schema/vendor_theme.py
@@ -3,12 +3,14 @@
Pydantic schemas for vendor theme operations.
"""
-from typing import Dict, Optional, List
+from typing import Dict, List, Optional
+
from pydantic import BaseModel, Field
class VendorThemeColors(BaseModel):
"""Color scheme for vendor theme."""
+
primary: Optional[str] = Field(None, description="Primary brand color")
secondary: Optional[str] = Field(None, description="Secondary color")
accent: Optional[str] = Field(None, description="Accent/CTA color")
@@ -19,12 +21,14 @@ class VendorThemeColors(BaseModel):
class VendorThemeFonts(BaseModel):
"""Typography settings for vendor theme."""
+
heading: Optional[str] = Field(None, description="Font for headings")
body: Optional[str] = Field(None, description="Font for body text")
class VendorThemeBranding(BaseModel):
"""Branding assets for vendor theme."""
+
logo: Optional[str] = Field(None, description="Logo URL")
logo_dark: Optional[str] = Field(None, description="Dark mode logo URL")
favicon: Optional[str] = Field(None, description="Favicon URL")
@@ -33,36 +37,54 @@ class VendorThemeBranding(BaseModel):
class VendorThemeLayout(BaseModel):
"""Layout settings for vendor theme."""
- style: Optional[str] = Field(None, description="Product layout style (grid, list, masonry)")
- header: Optional[str] = Field(None, description="Header style (fixed, static, transparent)")
- product_card: Optional[str] = Field(None, description="Product card style (modern, classic, minimal)")
+
+ style: Optional[str] = Field(
+ None, description="Product layout style (grid, list, masonry)"
+ )
+ header: Optional[str] = Field(
+ None, description="Header style (fixed, static, transparent)"
+ )
+ product_card: Optional[str] = Field(
+ None, description="Product card style (modern, classic, minimal)"
+ )
class VendorThemeUpdate(BaseModel):
"""Schema for updating vendor theme (partial updates allowed)."""
+
theme_name: Optional[str] = Field(None, description="Theme preset name")
colors: Optional[Dict[str, str]] = Field(None, description="Color scheme")
fonts: Optional[Dict[str, str]] = Field(None, description="Font settings")
- branding: Optional[Dict[str, Optional[str]]] = Field(None, description="Branding assets")
+ branding: Optional[Dict[str, Optional[str]]] = Field(
+ None, description="Branding assets"
+ )
layout: Optional[Dict[str, str]] = Field(None, description="Layout settings")
custom_css: Optional[str] = Field(None, description="Custom CSS rules")
- social_links: Optional[Dict[str, str]] = Field(None, description="Social media links")
+ social_links: Optional[Dict[str, str]] = Field(
+ None, description="Social media links"
+ )
class VendorThemeResponse(BaseModel):
"""Schema for vendor theme response."""
+
theme_name: str = Field(..., description="Theme name")
colors: Dict[str, str] = Field(..., description="Color scheme")
fonts: Dict[str, str] = Field(..., description="Font settings")
branding: Dict[str, Optional[str]] = Field(..., description="Branding assets")
layout: Dict[str, str] = Field(..., description="Layout settings")
- social_links: Optional[Dict[str, str]] = Field(default_factory=dict, description="Social links")
+ social_links: Optional[Dict[str, str]] = Field(
+ default_factory=dict, description="Social links"
+ )
custom_css: Optional[str] = Field(None, description="Custom CSS")
- css_variables: Optional[Dict[str, str]] = Field(None, description="CSS custom properties")
+ css_variables: Optional[Dict[str, str]] = Field(
+ None, description="CSS custom properties"
+ )
class ThemePresetPreview(BaseModel):
"""Preview information for a theme preset."""
+
name: str = Field(..., description="Preset name")
description: str = Field(..., description="Preset description")
primary_color: str = Field(..., description="Primary color")
@@ -75,10 +97,12 @@ class ThemePresetPreview(BaseModel):
class ThemePresetResponse(BaseModel):
"""Response after applying a preset."""
+
message: str = Field(..., description="Success message")
theme: VendorThemeResponse = Field(..., description="Applied theme")
class ThemePresetListResponse(BaseModel):
"""List of available theme presets."""
+
presets: List[ThemePresetPreview] = Field(..., description="Available presets")
diff --git a/scripts/backup_database.py b/scripts/backup_database.py
index 52fbfcd1..41385cd5 100644
--- a/scripts/backup_database.py
+++ b/scripts/backup_database.py
@@ -2,9 +2,9 @@
"""Database backup utility that uses project configuration."""
import os
-import sys
import shutil
import sqlite3
+import sys
from datetime import datetime
from pathlib import Path
from urllib.parse import urlparse
@@ -19,10 +19,10 @@ def get_database_path():
"""Extract database path from DATABASE_URL."""
db_url = settings.database_url
- if db_url.startswith('sqlite:///'):
+ if db_url.startswith("sqlite:///"):
# Remove sqlite:/// prefix and handle relative paths
- db_path = db_url.replace('sqlite:///', '')
- if db_path.startswith('./'):
+ db_path = db_url.replace("sqlite:///", "")
+ if db_path.startswith("./"):
db_path = db_path[2:] # Remove ./ prefix
return db_path
else:
@@ -76,7 +76,7 @@ def backup_database():
print("[BACKUP] Starting database backup...")
print(f"[INFO] Database URL: {settings.database_url}")
- if settings.database_url.startswith('sqlite'):
+ if settings.database_url.startswith("sqlite"):
return backup_sqlite_database()
else:
print("[INFO] For PostgreSQL databases, use pg_dump:")
diff --git a/scripts/create_default_content_pages.py b/scripts/create_default_content_pages.py
index 927559c3..14fc665d 100755
--- a/scripts/create_default_content_pages.py
+++ b/scripts/create_default_content_pages.py
@@ -26,20 +26,19 @@ Usage:
"""
import sys
-from pathlib import Path
from datetime import datetime, timezone
+from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
-from sqlalchemy.orm import Session
from sqlalchemy import select
+from sqlalchemy.orm import Session
from app.core.database import SessionLocal
from models.database.content_page import ContentPage
-
# ============================================================================
# DEFAULT PAGE CONTENT
# ============================================================================
@@ -458,6 +457,7 @@ DEFAULT_PAGES = [
# SCRIPT FUNCTIONS
# ============================================================================
+
def create_default_pages(db: Session) -> None:
"""
Create default platform content pages.
@@ -475,13 +475,14 @@ def create_default_pages(db: Session) -> None:
# Check if page already exists (platform default with this slug)
existing = db.execute(
select(ContentPage).where(
- ContentPage.vendor_id == None,
- ContentPage.slug == page_data["slug"]
+ ContentPage.vendor_id == None, ContentPage.slug == page_data["slug"]
)
).scalar_one_or_none()
if existing:
- print(f" ⏭️ Skipped: {page_data['title']} (/{page_data['slug']}) - already exists")
+ print(
+ f" ⏭️ Skipped: {page_data['title']} (/{page_data['slug']}) - already exists"
+ )
skipped_count += 1
continue
@@ -519,7 +520,9 @@ def create_default_pages(db: Session) -> None:
if created_count > 0:
print("✅ Default platform content pages created successfully!\n")
print("Next steps:")
- print(" 1. View pages at: /about, /contact, /faq, /shipping, /returns, /privacy, /terms")
+ print(
+ " 1. View pages at: /about, /contact, /faq, /shipping, /returns, /privacy, /terms"
+ )
print(" 2. Vendors can override these pages through the vendor dashboard")
print(" 3. Edit platform defaults through the admin panel\n")
else:
@@ -530,6 +533,7 @@ def create_default_pages(db: Session) -> None:
# MAIN EXECUTION
# ============================================================================
+
def main():
"""Main execution function."""
print("\n🚀 Starting Default Content Pages Creation Script...\n")
diff --git a/scripts/create_inventory.py b/scripts/create_inventory.py
index a69a341e..d1bc55a0 100755
--- a/scripts/create_inventory.py
+++ b/scripts/create_inventory.py
@@ -12,17 +12,18 @@ from pathlib import Path
# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent))
-import sqlite3
-from datetime import datetime, UTC
import os
+import sqlite3
+from datetime import UTC, datetime
+
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
# Get database URL
-database_url = os.getenv('DATABASE_URL', 'wizamart.db')
-db_path = database_url.replace('sqlite:///', '')
+database_url = os.getenv("DATABASE_URL", "wizamart.db")
+db_path = database_url.replace("sqlite:///", "")
print(f"📦 Creating inventory entries in {db_path}...")
@@ -30,12 +31,14 @@ conn = sqlite3.connect(db_path)
cursor = conn.cursor()
# Get products without inventory
-cursor.execute("""
+cursor.execute(
+ """
SELECT p.id, p.vendor_id, p.product_id
FROM products p
LEFT JOIN inventory i ON p.id = i.product_id
WHERE i.id IS NULL
-""")
+"""
+)
products_without_inventory = cursor.fetchall()
if not products_without_inventory:
@@ -47,7 +50,8 @@ print(f"📦 Creating inventory for {len(products_without_inventory)} products..
# Create inventory entries
for product_id, vendor_id, sku in products_without_inventory:
- cursor.execute("""
+ cursor.execute(
+ """
INSERT INTO inventory (
vendor_id,
product_id,
@@ -57,15 +61,17 @@ for product_id, vendor_id, sku in products_without_inventory:
created_at,
updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?)
- """, (
- vendor_id,
- product_id,
- 'Main Warehouse',
- 100, # Total quantity
- 0, # Reserved quantity
- datetime.now(UTC),
- datetime.now(UTC)
- ))
+ """,
+ (
+ vendor_id,
+ product_id,
+ "Main Warehouse",
+ 100, # Total quantity
+ 0, # Reserved quantity
+ datetime.now(UTC),
+ datetime.now(UTC),
+ ),
+ )
conn.commit()
@@ -76,7 +82,9 @@ cursor.execute("SELECT COUNT(*) FROM inventory")
total_count = cursor.fetchone()[0]
print(f"\n📊 Total inventory entries: {total_count}")
-cursor.execute("SELECT product_id, location, quantity, reserved_quantity FROM inventory LIMIT 5")
+cursor.execute(
+ "SELECT product_id, location, quantity, reserved_quantity FROM inventory LIMIT 5"
+)
print("\n📦 Sample inventory:")
for row in cursor.fetchall():
available = row[2] - row[3]
diff --git a/scripts/create_landing_page.py b/scripts/create_landing_page.py
index e5d463fc..203c33f3 100755
--- a/scripts/create_landing_page.py
+++ b/scripts/create_landing_page.py
@@ -12,18 +12,20 @@ from pathlib import Path
# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent))
-from sqlalchemy.orm import Session
-from app.core.database import SessionLocal
-from models.database.vendor import Vendor
-from models.database.content_page import ContentPage
from datetime import datetime, timezone
+from sqlalchemy.orm import Session
+
+from app.core.database import SessionLocal
+from models.database.content_page import ContentPage
+from models.database.vendor import Vendor
+
def create_landing_page(
vendor_subdomain: str,
template: str = "default",
title: str = None,
- content: str = None
+ content: str = None,
):
"""
Create a landing page for a vendor.
@@ -38,9 +40,7 @@ def create_landing_page(
try:
# Find vendor
- vendor = db.query(Vendor).filter(
- Vendor.subdomain == vendor_subdomain
- ).first()
+ vendor = db.query(Vendor).filter(Vendor.subdomain == vendor_subdomain).first()
if not vendor:
print(f"❌ Vendor '{vendor_subdomain}' not found!")
@@ -49,10 +49,11 @@ def create_landing_page(
print(f"✅ Found vendor: {vendor.name} (ID: {vendor.id})")
# Check if landing page already exists
- existing = db.query(ContentPage).filter(
- ContentPage.vendor_id == vendor.id,
- ContentPage.slug == "landing"
- ).first()
+ existing = (
+ db.query(ContentPage)
+ .filter(ContentPage.vendor_id == vendor.id, ContentPage.slug == "landing")
+ .first()
+ )
if existing:
print(f"⚠️ Landing page already exists (ID: {existing.id})")
@@ -74,7 +75,8 @@ def create_landing_page(
vendor_id=vendor.id,
slug="landing",
title=title or f"Welcome to {vendor.name}",
- content=content or f"""
+ content=content
+ or f"""
About {vendor.name}
{vendor.description or 'Your trusted shopping destination for quality products.'}
@@ -96,7 +98,7 @@ def create_landing_page(
published_at=datetime.now(timezone.utc),
show_in_footer=False,
show_in_header=False,
- display_order=0
+ display_order=0,
)
db.add(landing_page)
@@ -141,10 +143,13 @@ def list_vendors():
print(f" Code: {vendor.vendor_code}")
# Check if has landing page
- landing = db.query(ContentPage).filter(
- ContentPage.vendor_id == vendor.id,
- ContentPage.slug == "landing"
- ).first()
+ landing = (
+ db.query(ContentPage)
+ .filter(
+ ContentPage.vendor_id == vendor.id, ContentPage.slug == "landing"
+ )
+ .first()
+ )
if landing:
print(f" Landing Page: ✅ ({landing.template})")
@@ -165,7 +170,7 @@ def show_templates():
("default", "Clean professional layout with 3-column quick links"),
("minimal", "Ultra-simple centered design with single CTA"),
("modern", "Full-screen hero with animations and features"),
- ("full", "Maximum features with split-screen hero and stats")
+ ("full", "Maximum features with split-screen hero and stats"),
]
for name, desc in templates:
@@ -209,7 +214,7 @@ if __name__ == "__main__":
success = create_landing_page(
vendor_subdomain=vendor_subdomain,
template=template,
- title=title if title else None
+ title=title if title else None,
)
if success:
diff --git a/scripts/create_platform_pages.py b/scripts/create_platform_pages.py
index e757beb4..cd694bad 100755
--- a/scripts/create_platform_pages.py
+++ b/scripts/create_platform_pages.py
@@ -46,16 +46,22 @@ def create_platform_pages():
print("1. Creating Platform Homepage...")
# Check if already exists
- existing = db.query(ContentPage).filter_by(vendor_id=None, slug="platform_homepage").first()
+ existing = (
+ db.query(ContentPage)
+ .filter_by(vendor_id=None, slug="platform_homepage")
+ .first()
+ )
if existing:
- print(f" ⚠️ Skipped: Platform Homepage - already exists (ID: {existing.id})")
+ print(
+ f" ⚠️ Skipped: Platform Homepage - already exists (ID: {existing.id})"
+ )
else:
try:
homepage = content_page_service.create_page(
- db,
- slug="platform_homepage",
- title="Welcome to Our Multi-Vendor Marketplace",
- content="""
+ db,
+ slug="platform_homepage",
+ title="Welcome to Our Multi-Vendor Marketplace",
+ content="""
Connect vendors with customers worldwide. Build your online store and reach millions of shoppers.
@@ -64,14 +70,14 @@ def create_platform_pages():
with minimal effort and maximum impact.
""",
- template="modern", # Uses platform/homepage-modern.html
- vendor_id=None, # Platform-level page
- is_published=True,
- show_in_header=False, # Homepage is not in menu (it's the root)
- show_in_footer=False,
- display_order=0,
+ template="modern", # Uses platform/homepage-modern.html
+ vendor_id=None, # Platform-level page
+ is_published=True,
+ show_in_header=False, # Homepage is not in menu (it's the root)
+ show_in_footer=False,
+ display_order=0,
meta_description="Leading multi-vendor marketplace platform. Connect with thousands of vendors and discover millions of products.",
- meta_keywords="marketplace, multi-vendor, e-commerce, online shopping, platform"
+ meta_keywords="marketplace, multi-vendor, e-commerce, online shopping, platform",
)
print(f" ✅ Created: {homepage.title} (/{homepage.slug})")
except Exception as e:
@@ -88,10 +94,10 @@ def create_platform_pages():
else:
try:
about = content_page_service.create_page(
- db,
- slug="about",
- title="About Us",
- content="""
+ db,
+ slug="about",
+ title="About Us",
+ content="""
Our Mission
We're on a mission to democratize e-commerce by providing powerful,
@@ -121,13 +127,13 @@ def create_platform_pages():
Excellence: We strive for the highest quality in everything we do
""",
- vendor_id=None,
- is_published=True,
- show_in_header=True, # Show in header navigation
- show_in_footer=True, # Show in footer
- display_order=1,
+ vendor_id=None,
+ is_published=True,
+ show_in_header=True, # Show in header navigation
+ show_in_footer=True, # Show in footer
+ display_order=1,
meta_description="Learn about our mission to democratize e-commerce and empower entrepreneurs worldwide.",
- meta_keywords="about us, mission, vision, values, company"
+ meta_keywords="about us, mission, vision, values, company",
)
print(f" ✅ Created: {about.title} (/{about.slug})")
except Exception as e:
@@ -144,10 +150,10 @@ def create_platform_pages():
else:
try:
faq = content_page_service.create_page(
- db,
- slug="faq",
- title="Frequently Asked Questions",
- content="""
+ db,
+ slug="faq",
+ title="Frequently Asked Questions",
+ content="""
Getting Started
How do I create a vendor account?
@@ -204,13 +210,13 @@ def create_platform_pages():
and marketing tools.
""",
- vendor_id=None,
- is_published=True,
- show_in_header=True, # Show in header navigation
- show_in_footer=True,
- display_order=2,
+ vendor_id=None,
+ is_published=True,
+ show_in_header=True, # Show in header navigation
+ show_in_footer=True,
+ display_order=2,
meta_description="Find answers to common questions about our marketplace platform.",
- meta_keywords="faq, frequently asked questions, help, support"
+ meta_keywords="faq, frequently asked questions, help, support",
)
print(f" ✅ Created: {faq.title} (/{faq.slug})")
except Exception as e:
@@ -221,16 +227,18 @@ def create_platform_pages():
# ========================================================================
print("4. Creating Contact Us page...")
- existing = db.query(ContentPage).filter_by(vendor_id=None, slug="contact").first()
+ existing = (
+ db.query(ContentPage).filter_by(vendor_id=None, slug="contact").first()
+ )
if existing:
print(f" ⚠️ Skipped: Contact Us - already exists (ID: {existing.id})")
else:
try:
contact = content_page_service.create_page(
- db,
- slug="contact",
- title="Contact Us",
- content="""
+ db,
+ slug="contact",
+ title="Contact Us",
+ content="""
Get in Touch
We'd love to hear from you! Whether you have questions about our platform,
@@ -271,13 +279,13 @@ def create_platform_pages():
Email: press@marketplace.com
""",
- vendor_id=None,
- is_published=True,
- show_in_header=True, # Show in header navigation
- show_in_footer=True,
- display_order=3,
+ vendor_id=None,
+ is_published=True,
+ show_in_header=True, # Show in header navigation
+ show_in_footer=True,
+ display_order=3,
meta_description="Get in touch with our team. We're here to help you succeed.",
- meta_keywords="contact, support, email, phone, address"
+ meta_keywords="contact, support, email, phone, address",
)
print(f" ✅ Created: {contact.title} (/{contact.slug})")
except Exception as e:
@@ -290,14 +298,16 @@ def create_platform_pages():
existing = db.query(ContentPage).filter_by(vendor_id=None, slug="terms").first()
if existing:
- print(f" ⚠️ Skipped: Terms of Service - already exists (ID: {existing.id})")
+ print(
+ f" ⚠️ Skipped: Terms of Service - already exists (ID: {existing.id})"
+ )
else:
try:
terms = content_page_service.create_page(
- db,
- slug="terms",
- title="Terms of Service",
- content="""
+ db,
+ slug="terms",
+ title="Terms of Service",
+ content="""
Last updated: January 1, 2024
1. Acceptance of Terms
@@ -361,13 +371,13 @@ def create_platform_pages():
legal@marketplace.com.
""",
- vendor_id=None,
- is_published=True,
- show_in_header=False, # Too legal for header
- show_in_footer=True, # Show in footer
- display_order=10,
+ vendor_id=None,
+ is_published=True,
+ show_in_header=False, # Too legal for header
+ show_in_footer=True, # Show in footer
+ display_order=10,
meta_description="Read our terms of service and platform usage policies.",
- meta_keywords="terms of service, terms, legal, policy, agreement"
+ meta_keywords="terms of service, terms, legal, policy, agreement",
)
print(f" ✅ Created: {terms.title} (/{terms.slug})")
except Exception as e:
@@ -378,16 +388,18 @@ def create_platform_pages():
# ========================================================================
print("6. Creating Privacy Policy page...")
- existing = db.query(ContentPage).filter_by(vendor_id=None, slug="privacy").first()
+ existing = (
+ db.query(ContentPage).filter_by(vendor_id=None, slug="privacy").first()
+ )
if existing:
print(f" ⚠️ Skipped: Privacy Policy - already exists (ID: {existing.id})")
else:
try:
privacy = content_page_service.create_page(
- db,
- slug="privacy",
- title="Privacy Policy",
- content="""
+ db,
+ slug="privacy",
+ title="Privacy Policy",
+ content="""
Last updated: January 1, 2024
1. Information We Collect
@@ -453,13 +465,13 @@ def create_platform_pages():
privacy@marketplace.com.
""",
- vendor_id=None,
- is_published=True,
- show_in_header=False, # Too legal for header
- show_in_footer=True, # Show in footer
- display_order=11,
+ vendor_id=None,
+ is_published=True,
+ show_in_header=False, # Too legal for header
+ show_in_footer=True, # Show in footer
+ display_order=11,
meta_description="Learn how we collect, use, and protect your personal information.",
- meta_keywords="privacy policy, privacy, data protection, gdpr, personal information"
+ meta_keywords="privacy policy, privacy, data protection, gdpr, personal information",
)
print(f" ✅ Created: {privacy.title} (/{privacy.slug})")
except Exception as e:
diff --git a/scripts/init_production.py b/scripts/init_production.py
index 6a86f412..882d8f70 100644
--- a/scripts/init_production.py
+++ b/scripts/init_production.py
@@ -17,29 +17,30 @@ This script is idempotent - safe to run multiple times.
"""
import sys
-from pathlib import Path
from datetime import datetime, timezone
+from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
-from sqlalchemy.orm import Session
from sqlalchemy import select
+from sqlalchemy.orm import Session
+from app.core.config import (print_environment_info, settings,
+ validate_production_settings)
from app.core.database import SessionLocal
-from app.core.config import settings, print_environment_info, validate_production_settings
-from app.core.environment import is_production, get_environment
-from models.database.user import User
-from models.database.admin import AdminSetting
-from middleware.auth import AuthManager
+from app.core.environment import get_environment, is_production
from app.core.permissions import PermissionGroups
-
+from middleware.auth import AuthManager
+from models.database.admin import AdminSetting
+from models.database.user import User
# =============================================================================
# HELPER FUNCTIONS
# =============================================================================
+
def print_header(text: str):
"""Print formatted header."""
print("\n" + "=" * 70)
@@ -71,6 +72,7 @@ def print_error(text: str):
# INITIALIZATION FUNCTIONS
# =============================================================================
+
def create_admin_user(db: Session, auth_manager: AuthManager) -> User:
"""Create or get the platform admin user."""
@@ -206,9 +208,7 @@ def create_admin_settings(db: Session) -> int:
for setting_data in default_settings:
# Check if setting already exists
existing = db.execute(
- select(AdminSetting).where(
- AdminSetting.key == setting_data["key"]
- )
+ select(AdminSetting).where(AdminSetting.key == setting_data["key"])
).scalar_one_or_none()
if not existing:
@@ -281,6 +281,7 @@ def verify_rbac_schema(db: Session) -> bool:
# MAIN INITIALIZATION
# =============================================================================
+
def initialize_production(db: Session, auth_manager: AuthManager):
"""Initialize production database with essential data."""
@@ -362,6 +363,7 @@ def print_summary(db: Session):
# MAIN ENTRY POINT
# =============================================================================
+
def main():
"""Main entry point."""
@@ -405,6 +407,7 @@ def main():
print_header("❌ INITIALIZATION FAILED")
print(f"\nError: {e}\n")
import traceback
+
traceback.print_exc()
sys.exit(1)
diff --git a/scripts/route_diagnostics.py b/scripts/route_diagnostics.py
index 0616e221..febda516 100644
--- a/scripts/route_diagnostics.py
+++ b/scripts/route_diagnostics.py
@@ -7,7 +7,7 @@ Usage: python route_diagnostics.py
"""
import sys
-from typing import List, Dict
+from typing import Dict, List
def check_route_order():
@@ -29,23 +29,23 @@ def check_route_order():
html_routes = []
for route in routes:
- if hasattr(route, 'path'):
+ if hasattr(route, "path"):
path = route.path
- methods = getattr(route, 'methods', set())
+ methods = getattr(route, "methods", set())
# Determine if JSON or HTML based on common patterns
- if 'login' in path or 'dashboard' in path or 'products' in path:
+ if "login" in path or "dashboard" in path or "products" in path:
# Check response class if available
- if hasattr(route, 'response_class'):
+ if hasattr(route, "response_class"):
response_class = str(route.response_class)
- if 'HTML' in response_class:
+ if "HTML" in response_class:
html_routes.append((path, methods))
else:
json_routes.append((path, methods))
else:
# Check endpoint name/tags
- endpoint = getattr(route, 'endpoint', None)
- if endpoint and 'page' in endpoint.__name__.lower():
+ endpoint = getattr(route, "endpoint", None)
+ if endpoint and "page" in endpoint.__name__.lower():
html_routes.append((path, methods))
else:
json_routes.append((path, methods))
@@ -57,8 +57,10 @@ def check_route_order():
# Check for specific vendor info route
vendor_info_found = False
for route in routes:
- if hasattr(route, 'path'):
- if '/{vendor_code}' == route.path and 'GET' in getattr(route, 'methods', set()):
+ if hasattr(route, "path"):
+ if "/{vendor_code}" == route.path and "GET" in getattr(
+ route, "methods", set()
+ ):
vendor_info_found = True
print("✅ Found vendor info endpoint: GET /{vendor_code}")
break
@@ -100,14 +102,14 @@ def test_vendor_endpoint():
print(f" Content-Type: {response.headers.get('content-type', 'N/A')}")
# Check if response is JSON
- content_type = response.headers.get('content-type', '')
- if 'application/json' in content_type:
+ content_type = response.headers.get("content-type", "")
+ if "application/json" in content_type:
print("✅ Response is JSON")
data = response.json()
print(f" Vendor: {data.get('name', 'N/A')}")
print(f" Code: {data.get('vendor_code', 'N/A')}")
return True
- elif 'text/html' in content_type:
+ elif "text/html" in content_type:
print("❌ ERROR: Response is HTML, not JSON!")
print(" This confirms the route ordering issue")
print(" The HTML page route is catching the API request")
diff --git a/scripts/seed_demo.py b/scripts/seed_demo.py
index 6e224ea7..13cade60 100644
--- a/scripts/seed_demo.py
+++ b/scripts/seed_demo.py
@@ -26,39 +26,39 @@ This script is idempotent when run normally.
"""
import sys
+from datetime import datetime, timedelta, timezone
+from decimal import Decimal
from pathlib import Path
from typing import Dict, List
-from datetime import datetime, timezone, timedelta
-from decimal import Decimal
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
-from sqlalchemy.orm import Session
-from sqlalchemy import select, delete
-
-from app.core.database import SessionLocal
-from app.core.config import settings
-from app.core.environment import is_production, get_environment
-from models.database.user import User
-from models.database.vendor import Vendor, VendorUser, Role
-from models.database.vendor_domain import VendorDomain
-from models.database.vendor_theme import VendorTheme
-from models.database.customer import Customer, CustomerAddress
-from models.database.product import Product
-from models.database.marketplace_product import MarketplaceProduct
-from models.database.marketplace_import_job import MarketplaceImportJob
-from models.database.order import Order, OrderItem
-from models.database.admin import PlatformAlert
-from middleware.auth import AuthManager
-
# =============================================================================
# MODE DETECTION (from environment variable set by Makefile)
# =============================================================================
import os
-SEED_MODE = os.getenv('SEED_MODE', 'normal') # normal, minimal, reset
+from sqlalchemy import delete, select
+from sqlalchemy.orm import Session
+
+from app.core.config import settings
+from app.core.database import SessionLocal
+from app.core.environment import get_environment, is_production
+from middleware.auth import AuthManager
+from models.database.admin import PlatformAlert
+from models.database.customer import Customer, CustomerAddress
+from models.database.marketplace_import_job import MarketplaceImportJob
+from models.database.marketplace_product import MarketplaceProduct
+from models.database.order import Order, OrderItem
+from models.database.product import Product
+from models.database.user import User
+from models.database.vendor import Role, Vendor, VendorUser
+from models.database.vendor_domain import VendorDomain
+from models.database.vendor_theme import VendorTheme
+
+SEED_MODE = os.getenv("SEED_MODE", "normal") # normal, minimal, reset
# =============================================================================
# DEMO DATA CONFIGURATION
@@ -147,6 +147,7 @@ THEME_PRESETS = {
# HELPER FUNCTIONS
# =============================================================================
+
def print_header(text: str):
"""Print formatted header."""
print("\n" + "=" * 70)
@@ -178,6 +179,7 @@ def print_error(text: str):
# SAFETY CHECKS
# =============================================================================
+
def check_environment():
"""Prevent running demo seed in production."""
@@ -196,9 +198,7 @@ def check_environment():
def check_admin_exists(db: Session) -> bool:
"""Check if admin user exists."""
- admin = db.execute(
- select(User).where(User.role == "admin")
- ).scalar_one_or_none()
+ admin = db.execute(select(User).where(User.role == "admin")).scalar_one_or_none()
if not admin:
print_error("No admin user found!")
@@ -214,6 +214,7 @@ def check_admin_exists(db: Session) -> bool:
# DATA DELETION (for reset mode)
# =============================================================================
+
def reset_all_data(db: Session):
"""Delete ALL data from database (except admin user)."""
@@ -258,13 +259,14 @@ def reset_all_data(db: Session):
# SEEDING FUNCTIONS
# =============================================================================
+
def create_demo_vendors(db: Session, auth_manager: AuthManager) -> List[Vendor]:
"""Create demo vendors with users."""
vendors = []
# Determine how many vendors to create based on mode
- vendor_count = 1 if SEED_MODE == 'minimal' else settings.seed_demo_vendors
+ vendor_count = 1 if SEED_MODE == "minimal" else settings.seed_demo_vendors
vendors_to_create = DEMO_VENDORS[:vendor_count]
users_to_create = DEMO_VENDOR_USERS[:vendor_count]
@@ -320,7 +322,9 @@ def create_demo_vendors(db: Session, auth_manager: AuthManager) -> List[Vendor]:
db.add(vendor_user_link)
# Create vendor theme
- theme_colors = THEME_PRESETS.get(vendor_data["theme_preset"], THEME_PRESETS["modern"])
+ theme_colors = THEME_PRESETS.get(
+ vendor_data["theme_preset"], THEME_PRESETS["modern"]
+ )
theme = VendorTheme(
vendor_id=vendor.id,
theme_name=vendor_data["theme_preset"],
@@ -355,7 +359,9 @@ def create_demo_vendors(db: Session, auth_manager: AuthManager) -> List[Vendor]:
return vendors
-def create_demo_customers(db: Session, vendor: Vendor, auth_manager: AuthManager, count: int) -> List[Customer]:
+def create_demo_customers(
+ db: Session, vendor: Vendor, auth_manager: AuthManager, count: int
+) -> List[Customer]:
"""Create demo customers for a vendor."""
customers = []
@@ -367,10 +373,11 @@ def create_demo_customers(db: Session, vendor: Vendor, auth_manager: AuthManager
customer_number = f"CUST-{vendor.vendor_code}-{i:04d}"
# Check if customer already exists
- existing_customer = db.query(Customer).filter(
- Customer.vendor_id == vendor.id,
- Customer.email == email
- ).first()
+ existing_customer = (
+ db.query(Customer)
+ .filter(Customer.vendor_id == vendor.id, Customer.email == email)
+ .first()
+ )
if existing_customer:
customers.append(existing_customer)
@@ -412,19 +419,22 @@ def create_demo_products(db: Session, vendor: Vendor, count: int) -> List[Produc
product_id = f"{vendor.vendor_code}-PROD-{i:03d}"
# Check if this product already exists
- existing_product = db.query(Product).filter(
- Product.vendor_id == vendor.id,
- Product.product_id == product_id
- ).first()
+ existing_product = (
+ db.query(Product)
+ .filter(Product.vendor_id == vendor.id, Product.product_id == product_id)
+ .first()
+ )
if existing_product:
products.append(existing_product)
continue # Skip creation, product already exists
# Check if marketplace product already exists
- existing_mp = db.query(MarketplaceProduct).filter(
- MarketplaceProduct.marketplace_product_id == marketplace_product_id
- ).first()
+ existing_mp = (
+ db.query(MarketplaceProduct)
+ .filter(MarketplaceProduct.marketplace_product_id == marketplace_product_id)
+ .first()
+ )
if existing_mp:
marketplace_product = existing_mp
@@ -485,6 +495,7 @@ def create_demo_products(db: Session, vendor: Vendor, count: int) -> List[Produc
# MAIN SEEDING
# =============================================================================
+
def seed_demo_data(db: Session, auth_manager: AuthManager):
"""Seed demo data for development."""
@@ -501,7 +512,7 @@ def seed_demo_data(db: Session, auth_manager: AuthManager):
sys.exit(1)
# Step 3: Reset data if in reset mode
- if SEED_MODE == 'reset':
+ if SEED_MODE == "reset":
print_step(3, "Resetting data...")
reset_all_data(db)
@@ -513,20 +524,13 @@ def seed_demo_data(db: Session, auth_manager: AuthManager):
print_step(5, "Creating demo customers...")
for vendor in vendors:
create_demo_customers(
- db,
- vendor,
- auth_manager,
- count=settings.seed_customers_per_vendor
+ db, vendor, auth_manager, count=settings.seed_customers_per_vendor
)
# Step 6: Create products
print_step(6, "Creating demo products...")
for vendor in vendors:
- create_demo_products(
- db,
- vendor,
- count=settings.seed_products_per_vendor
- )
+ create_demo_products(db, vendor, count=settings.seed_products_per_vendor)
# Commit all changes
db.commit()
@@ -558,17 +562,18 @@ def print_summary(db: Session):
print(f" Subdomain: {vendor.subdomain}.{settings.platform_domain}")
# Query custom domains separately
- custom_domain = db.query(VendorDomain).filter(
- VendorDomain.vendor_id == vendor.id,
- VendorDomain.is_active == True
- ).first()
+ custom_domain = (
+ db.query(VendorDomain)
+ .filter(VendorDomain.vendor_id == vendor.id, VendorDomain.is_active == True)
+ .first()
+ )
if custom_domain:
# Try different possible field names (model field might vary)
domain_value = (
- getattr(custom_domain, 'domain', None) or
- getattr(custom_domain, 'domain_name', None) or
- getattr(custom_domain, 'name', None)
+ getattr(custom_domain, "domain", None)
+ or getattr(custom_domain, "domain_name", None)
+ or getattr(custom_domain, "name", None)
)
if domain_value:
print(f" Custom: {domain_value}")
@@ -584,8 +589,12 @@ def print_summary(db: Session):
print(f" Email: {vendor_data['email']}")
print(f" Password: {vendor_data['password']}")
if vendor:
- print(f" Login: http://localhost:8000/vendor/{vendor.vendor_code}/login")
- print(f" or http://{vendor.subdomain}.localhost:8000/vendor/login")
+ print(
+ f" Login: http://localhost:8000/vendor/{vendor.vendor_code}/login"
+ )
+ print(
+ f" or http://{vendor.subdomain}.localhost:8000/vendor/login"
+ )
print()
print(f"\n🛒 Demo Customer Credentials:")
@@ -600,7 +609,9 @@ def print_summary(db: Session):
print("─" * 70)
for vendor in vendors:
print(f" {vendor.name}:")
- print(f" Path-based: http://localhost:8000/vendors/{vendor.vendor_code}/shop/")
+ print(
+ f" Path-based: http://localhost:8000/vendors/{vendor.vendor_code}/shop/"
+ )
print(f" Subdomain: http://{vendor.subdomain}.localhost:8000/")
print()
@@ -621,6 +632,7 @@ def print_summary(db: Session):
# MAIN ENTRY POINT
# =============================================================================
+
def main():
"""Main entry point."""
@@ -647,6 +659,7 @@ def main():
print_header("❌ SEEDING FAILED")
print(f"\nError: {e}\n")
import traceback
+
traceback.print_exc()
sys.exit(1)
diff --git a/scripts/show_structure.py b/scripts/show_structure.py
index 6d3eeb43..83d382ae 100644
--- a/scripts/show_structure.py
+++ b/scripts/show_structure.py
@@ -11,11 +11,11 @@ Usage:
"""
import os
-import sys
import subprocess
+import sys
from datetime import datetime
from pathlib import Path
-from typing import List, Dict
+from typing import Dict, List
def count_files(directory: str, pattern: str) -> int:
@@ -26,10 +26,14 @@ def count_files(directory: str, pattern: str) -> int:
count = 0
for root, dirs, files in os.walk(directory):
# Skip __pycache__ and other cache directories
- dirs[:] = [d for d in dirs if d not in ['__pycache__', '.pytest_cache', '.git', 'node_modules']]
+ dirs[:] = [
+ d
+ for d in dirs
+ if d not in ["__pycache__", ".pytest_cache", ".git", "node_modules"]
+ ]
for file in files:
- if pattern == '*' or file.endswith(pattern):
+ if pattern == "*" or file.endswith(pattern):
count += 1
return count
@@ -41,26 +45,26 @@ def get_tree_structure(directory: str, exclude_patterns: List[str] = None) -> st
# Try to use system tree command first
try:
- if sys.platform == 'win32':
+ if sys.platform == "win32":
# Windows tree command
result = subprocess.run(
- ['tree', '/F', '/A', directory],
+ ["tree", "/F", "/A", directory],
capture_output=True,
text=True,
- encoding='utf-8',
- errors='replace'
+ encoding="utf-8",
+ errors="replace",
)
return result.stdout
else:
# Linux/Mac tree command with exclusions
exclude_args = []
if exclude_patterns:
- exclude_args = ['-I', '|'.join(exclude_patterns)]
+ exclude_args = ["-I", "|".join(exclude_patterns)]
result = subprocess.run(
- ['tree', '-F', '-a'] + exclude_args + [directory],
+ ["tree", "-F", "-a"] + exclude_args + [directory],
capture_output=True,
- text=True
+ text=True,
)
if result.returncode == 0:
return result.stdout
@@ -71,10 +75,19 @@ def get_tree_structure(directory: str, exclude_patterns: List[str] = None) -> st
return generate_manual_tree(directory, exclude_patterns)
-def generate_manual_tree(directory: str, exclude_patterns: List[str] = None, prefix: str = "") -> str:
+def generate_manual_tree(
+ directory: str, exclude_patterns: List[str] = None, prefix: str = ""
+) -> str:
"""Generate tree structure manually when tree command is not available."""
if exclude_patterns is None:
- exclude_patterns = ['__pycache__', '.pytest_cache', '.git', 'node_modules', '*.pyc', '*.pyo']
+ exclude_patterns = [
+ "__pycache__",
+ ".pytest_cache",
+ ".git",
+ "node_modules",
+ "*.pyc",
+ "*.pyo",
+ ]
output = []
path = Path(directory)
@@ -86,7 +99,7 @@ def generate_manual_tree(directory: str, exclude_patterns: List[str] = None, pre
# Skip excluded patterns
skip = False
for pattern in exclude_patterns:
- if pattern.startswith('*'):
+ if pattern.startswith("*"):
# File extension pattern
if item.name.endswith(pattern[1:]):
skip = True
@@ -106,7 +119,9 @@ def generate_manual_tree(directory: str, exclude_patterns: List[str] = None, pre
if item.is_dir():
output.append(f"{prefix}{current_prefix}{item.name}/")
extension = " " if is_last else "│ "
- subtree = generate_manual_tree(str(item), exclude_patterns, prefix + extension)
+ subtree = generate_manual_tree(
+ str(item), exclude_patterns, prefix + extension
+ )
if subtree:
output.append(subtree)
else:
@@ -127,20 +142,36 @@ def generate_frontend_structure() -> str:
# Templates section
output.append("")
- output.append("╔══════════════════════════════════════════════════════════════════╗")
- output.append("║ JINJA2 TEMPLATES ║")
- output.append("║ Location: app/templates ║")
- output.append("╚══════════════════════════════════════════════════════════════════╝")
+ output.append(
+ "╔══════════════════════════════════════════════════════════════════╗"
+ )
+ output.append(
+ "║ JINJA2 TEMPLATES ║"
+ )
+ output.append(
+ "║ Location: app/templates ║"
+ )
+ output.append(
+ "╚══════════════════════════════════════════════════════════════════╝"
+ )
output.append("")
output.append(get_tree_structure("app/templates"))
# Static assets section
output.append("")
output.append("")
- output.append("╔══════════════════════════════════════════════════════════════════╗")
- output.append("║ STATIC ASSETS ║")
- output.append("║ Location: static ║")
- output.append("╚══════════════════════════════════════════════════════════════════╝")
+ output.append(
+ "╔══════════════════════════════════════════════════════════════════╗"
+ )
+ output.append(
+ "║ STATIC ASSETS ║"
+ )
+ output.append(
+ "║ Location: static ║"
+ )
+ output.append(
+ "╚══════════════════════════════════════════════════════════════════╝"
+ )
output.append("")
output.append(get_tree_structure("static"))
@@ -148,11 +179,21 @@ def generate_frontend_structure() -> str:
if os.path.exists("docs"):
output.append("")
output.append("")
- output.append("╔══════════════════════════════════════════════════════════════════╗")
- output.append("║ DOCUMENTATION ║")
- output.append("║ Location: docs ║")
- output.append("║ (also listed in tools structure) ║")
- output.append("╚══════════════════════════════════════════════════════════════════╝")
+ output.append(
+ "╔══════════════════════════════════════════════════════════════════╗"
+ )
+ output.append(
+ "║ DOCUMENTATION ║"
+ )
+ output.append(
+ "║ Location: docs ║"
+ )
+ output.append(
+ "║ (also listed in tools structure) ║"
+ )
+ output.append(
+ "╚══════════════════════════════════════════════════════════════════╝"
+ )
output.append("")
output.append("Note: Documentation is also included in tools structure")
output.append(" for infrastructure/DevOps context.")
@@ -160,9 +201,15 @@ def generate_frontend_structure() -> str:
# Statistics section
output.append("")
output.append("")
- output.append("╔══════════════════════════════════════════════════════════════════╗")
- output.append("║ STATISTICS ║")
- output.append("╚══════════════════════════════════════════════════════════════════╝")
+ output.append(
+ "╔══════════════════════════════════════════════════════════════════╗"
+ )
+ output.append(
+ "║ STATISTICS ║"
+ )
+ output.append(
+ "╚══════════════════════════════════════════════════════════════════╝"
+ )
output.append("")
output.append("Templates:")
@@ -174,8 +221,8 @@ def generate_frontend_structure() -> str:
output.append(f" - JavaScript files: {count_files('static', '.js')}")
output.append(f" - CSS files: {count_files('static', '.css')}")
output.append(" - Image files:")
- for ext in ['png', 'jpg', 'jpeg', 'gif', 'svg', 'webp', 'ico']:
- count = count_files('static', f'.{ext}')
+ for ext in ["png", "jpg", "jpeg", "gif", "svg", "webp", "ico"]:
+ count = count_files("static", f".{ext}")
if count > 0:
output.append(f" - .{ext}: {count}")
@@ -200,41 +247,58 @@ def generate_backend_structure() -> str:
output.append("=" * 78)
output.append("")
- exclude = ['__pycache__', '*.pyc', '*.pyo', '.pytest_cache', '*.egg-info', 'templates']
+ exclude = [
+ "__pycache__",
+ "*.pyc",
+ "*.pyo",
+ ".pytest_cache",
+ "*.egg-info",
+ "templates",
+ ]
# Backend directories to include
backend_dirs = [
- ('app', 'Application Code'),
- ('middleware', 'Middleware Components'),
- ('models', 'Database Models'),
- ('storage', 'File Storage'),
- ('tasks', 'Background Tasks'),
- ('logs', 'Application Logs'),
+ ("app", "Application Code"),
+ ("middleware", "Middleware Components"),
+ ("models", "Database Models"),
+ ("storage", "File Storage"),
+ ("tasks", "Background Tasks"),
+ ("logs", "Application Logs"),
]
for directory, title in backend_dirs:
if os.path.exists(directory):
output.append("")
- output.append("╔══════════════════════════════════════════════════════════════════╗")
+ output.append(
+ "╔══════════════════════════════════════════════════════════════════╗"
+ )
output.append(f"║ {title.upper().center(62)} ║")
output.append(f"║ Location: {directory + '/'.ljust(51)} ║")
- output.append("╚══════════════════════════════════════════════════════════════════╝")
+ output.append(
+ "╚══════════════════════════════════════════════════════════════════╝"
+ )
output.append("")
output.append(get_tree_structure(directory, exclude))
# Statistics section
output.append("")
output.append("")
- output.append("╔══════════════════════════════════════════════════════════════════╗")
- output.append("║ STATISTICS ║")
- output.append("╚══════════════════════════════════════════════════════════════════╝")
+ output.append(
+ "╔══════════════════════════════════════════════════════════════════╗"
+ )
+ output.append(
+ "║ STATISTICS ║"
+ )
+ output.append(
+ "╚══════════════════════════════════════════════════════════════════╝"
+ )
output.append("")
output.append("Python Files by Directory:")
total_py_files = 0
for directory, title in backend_dirs:
if os.path.exists(directory):
- count = count_files(directory, '.py')
+ count = count_files(directory, ".py")
total_py_files += count
output.append(f" - {directory}/: {count} files")
@@ -242,11 +306,11 @@ def generate_backend_structure() -> str:
output.append("")
output.append("Application Components (if in app/):")
- components = ['routes', 'services', 'schemas', 'exceptions', 'utils']
+ components = ["routes", "services", "schemas", "exceptions", "utils"]
for component in components:
component_path = f"app/{component}"
if os.path.exists(component_path):
- count = count_files(component_path, '.py')
+ count = count_files(component_path, ".py")
output.append(f" - app/{component}: {count} files")
output.append("")
@@ -264,48 +328,58 @@ def generate_tools_structure() -> str:
output.append("=" * 78)
output.append("")
- exclude = ['__pycache__', '*.pyc', '*.pyo', '.pytest_cache', '*.egg-info']
+ exclude = ["__pycache__", "*.pyc", "*.pyo", ".pytest_cache", "*.egg-info"]
# Tools directories to include
tools_dirs = [
- ('alembic', 'Database Migrations'),
- ('scripts', 'Utility Scripts'),
- ('docker', 'Docker Configuration'),
- ('docs', 'Documentation'),
+ ("alembic", "Database Migrations"),
+ ("scripts", "Utility Scripts"),
+ ("docker", "Docker Configuration"),
+ ("docs", "Documentation"),
]
for directory, title in tools_dirs:
if os.path.exists(directory):
output.append("")
- output.append("╔══════════════════════════════════════════════════════════════════╗")
+ output.append(
+ "╔══════════════════════════════════════════════════════════════════╗"
+ )
output.append(f"║ {title.upper().center(62)} ║")
output.append(f"║ Location: {directory + '/'.ljust(51)} ║")
- output.append("╚══════════════════════════════════════════════════════════════════╝")
+ output.append(
+ "╚══════════════════════════════════════════════════════════════════╝"
+ )
output.append("")
output.append(get_tree_structure(directory, exclude))
# Configuration files section
output.append("")
output.append("")
- output.append("╔══════════════════════════════════════════════════════════════════╗")
- output.append("║ CONFIGURATION FILES ║")
- output.append("╚══════════════════════════════════════════════════════════════════╝")
+ output.append(
+ "╔══════════════════════════════════════════════════════════════════╗"
+ )
+ output.append(
+ "║ CONFIGURATION FILES ║"
+ )
+ output.append(
+ "╚══════════════════════════════════════════════════════════════════╝"
+ )
output.append("")
output.append("Root configuration files:")
config_files = [
- ('Makefile', 'Build automation'),
- ('requirements.txt', 'Python dependencies'),
- ('pyproject.toml', 'Python project config'),
- ('setup.py', 'Python setup script'),
- ('setup.cfg', 'Setup configuration'),
- ('alembic.ini', 'Alembic migrations config'),
- ('mkdocs.yml', 'MkDocs documentation config'),
- ('Dockerfile', 'Docker image definition'),
- ('docker-compose.yml', 'Docker services'),
- ('.dockerignore', 'Docker ignore patterns'),
- ('.gitignore', 'Git ignore patterns'),
- ('.env.example', 'Environment variables template'),
+ ("Makefile", "Build automation"),
+ ("requirements.txt", "Python dependencies"),
+ ("pyproject.toml", "Python project config"),
+ ("setup.py", "Python setup script"),
+ ("setup.cfg", "Setup configuration"),
+ ("alembic.ini", "Alembic migrations config"),
+ ("mkdocs.yml", "MkDocs documentation config"),
+ ("Dockerfile", "Docker image definition"),
+ ("docker-compose.yml", "Docker services"),
+ (".dockerignore", "Docker ignore patterns"),
+ (".gitignore", "Git ignore patterns"),
+ (".env.example", "Environment variables template"),
]
for file, description in config_files:
@@ -315,9 +389,15 @@ def generate_tools_structure() -> str:
# Statistics section
output.append("")
output.append("")
- output.append("╔══════════════════════════════════════════════════════════════════╗")
- output.append("║ STATISTICS ║")
- output.append("╚══════════════════════════════════════════════════════════════════╝")
+ output.append(
+ "╔══════════════════════════════════════════════════════════════════╗"
+ )
+ output.append(
+ "║ STATISTICS ║"
+ )
+ output.append(
+ "╚══════════════════════════════════════════════════════════════════╝"
+ )
output.append("")
output.append("Database Migrations:")
@@ -327,7 +407,9 @@ def generate_tools_structure() -> str:
if migration_count > 0:
# Get first and last migration
try:
- migrations = sorted([f for f in os.listdir("alembic/versions") if f.endswith('.py')])
+ migrations = sorted(
+ [f for f in os.listdir("alembic/versions") if f.endswith(".py")]
+ )
if migrations:
output.append(f" - First: {migrations[0][:40]}...")
if len(migrations) > 1:
@@ -341,9 +423,9 @@ def generate_tools_structure() -> str:
output.append("Scripts:")
if os.path.exists("scripts"):
script_types = {
- '.py': 'Python scripts',
- '.sh': 'Shell scripts',
- '.bat': 'Batch scripts',
+ ".py": "Python scripts",
+ ".sh": "Shell scripts",
+ ".bat": "Batch scripts",
}
for ext, desc in script_types.items():
count = count_files("scripts", ext)
@@ -356,8 +438,8 @@ def generate_tools_structure() -> str:
output.append("Documentation:")
if os.path.exists("docs"):
doc_types = {
- '.md': 'Markdown files',
- '.rst': 'reStructuredText files',
+ ".md": "Markdown files",
+ ".rst": "reStructuredText files",
}
for ext, desc in doc_types.items():
count = count_files("docs", ext)
@@ -368,7 +450,7 @@ def generate_tools_structure() -> str:
output.append("")
output.append("Docker:")
- docker_files = ['Dockerfile', 'docker-compose.yml', '.dockerignore']
+ docker_files = ["Dockerfile", "docker-compose.yml", ".dockerignore"]
docker_exists = any(os.path.exists(f) for f in docker_files)
if docker_exists:
output.append(" ✓ Docker configuration present")
@@ -394,25 +476,44 @@ def generate_test_structure() -> str:
# Test files section
output.append("")
- output.append("╔══════════════════════════════════════════════════════════════════╗")
- output.append("║ TEST FILES ║")
- output.append("║ Location: tests/ ║")
- output.append("╚══════════════════════════════════════════════════════════════════╝")
+ output.append(
+ "╔══════════════════════════════════════════════════════════════════╗"
+ )
+ output.append(
+ "║ TEST FILES ║"
+ )
+ output.append(
+ "║ Location: tests/ ║"
+ )
+ output.append(
+ "╚══════════════════════════════════════════════════════════════════╝"
+ )
output.append("")
- exclude = ['__pycache__', '*.pyc', '*.pyo', '.pytest_cache', '*.egg-info']
+ exclude = ["__pycache__", "*.pyc", "*.pyo", ".pytest_cache", "*.egg-info"]
output.append(get_tree_structure("tests", exclude))
# Configuration section
output.append("")
output.append("")
- output.append("╔══════════════════════════════════════════════════════════════════╗")
- output.append("║ TEST CONFIGURATION ║")
- output.append("╚══════════════════════════════════════════════════════════════════╝")
+ output.append(
+ "╔══════════════════════════════════════════════════════════════════╗"
+ )
+ output.append(
+ "║ TEST CONFIGURATION ║"
+ )
+ output.append(
+ "╚══════════════════════════════════════════════════════════════════╝"
+ )
output.append("")
output.append("Test configuration files:")
- test_config_files = ['pytest.ini', 'conftest.py', 'tests/conftest.py', '.coveragerc']
+ test_config_files = [
+ "pytest.ini",
+ "conftest.py",
+ "tests/conftest.py",
+ ".coveragerc",
+ ]
for file in test_config_files:
if os.path.exists(file):
output.append(f" ✓ {file}")
@@ -420,16 +521,22 @@ def generate_test_structure() -> str:
# Statistics section
output.append("")
output.append("")
- output.append("╔══════════════════════════════════════════════════════════════════╗")
- output.append("║ STATISTICS ║")
- output.append("╚══════════════════════════════════════════════════════════════════╝")
+ output.append(
+ "╔══════════════════════════════════════════════════════════════════╗"
+ )
+ output.append(
+ "║ STATISTICS ║"
+ )
+ output.append(
+ "╚══════════════════════════════════════════════════════════════════╝"
+ )
output.append("")
# Count test files
test_file_count = 0
if os.path.exists("tests"):
for root, dirs, files in os.walk("tests"):
- dirs[:] = [d for d in dirs if d != '__pycache__']
+ dirs[:] = [d for d in dirs if d != "__pycache__"]
for file in files:
if file.startswith("test_") and file.endswith(".py"):
test_file_count += 1
@@ -439,13 +546,13 @@ def generate_test_structure() -> str:
output.append("")
output.append("By Category:")
- categories = ['unit', 'integration', 'system', 'e2e', 'performance']
+ categories = ["unit", "integration", "system", "e2e", "performance"]
for category in categories:
category_path = f"tests/{category}"
if os.path.exists(category_path):
count = 0
for root, dirs, files in os.walk(category_path):
- dirs[:] = [d for d in dirs if d != '__pycache__']
+ dirs[:] = [d for d in dirs if d != "__pycache__"]
for file in files:
if file.startswith("test_") and file.endswith(".py"):
count += 1
@@ -455,12 +562,12 @@ def generate_test_structure() -> str:
test_function_count = 0
if os.path.exists("tests"):
for root, dirs, files in os.walk("tests"):
- dirs[:] = [d for d in dirs if d != '__pycache__']
+ dirs[:] = [d for d in dirs if d != "__pycache__"]
for file in files:
if file.startswith("test_") and file.endswith(".py"):
filepath = os.path.join(root, file)
try:
- with open(filepath, 'r', encoding='utf-8') as f:
+ with open(filepath, "r", encoding="utf-8") as f:
for line in f:
if line.strip().startswith("def test_"):
test_function_count += 1
@@ -494,19 +601,19 @@ def main():
structure_type = sys.argv[1].lower()
generators = {
- 'frontend': ('frontend-structure.txt', generate_frontend_structure),
- 'backend': ('backend-structure.txt', generate_backend_structure),
- 'tests': ('test-structure.txt', generate_test_structure),
- 'tools': ('tools-structure.txt', generate_tools_structure),
+ "frontend": ("frontend-structure.txt", generate_frontend_structure),
+ "backend": ("backend-structure.txt", generate_backend_structure),
+ "tests": ("test-structure.txt", generate_test_structure),
+ "tools": ("tools-structure.txt", generate_tools_structure),
}
- if structure_type == 'all':
+ if structure_type == "all":
for name, (filename, generator) in generators.items():
print(f"\n{'=' * 60}")
print(f"Generating {name} structure...")
- print('=' * 60)
+ print("=" * 60)
content = generator()
- with open(filename, 'w', encoding='utf-8') as f:
+ with open(filename, "w", encoding="utf-8") as f:
f.write(content)
print(f"✅ {name.capitalize()} structure saved to {filename}")
print(f"\n{content}\n")
@@ -514,7 +621,7 @@ def main():
filename, generator = generators[structure_type]
print(f"Generating {structure_type} structure...")
content = generator()
- with open(filename, 'w', encoding='utf-8') as f:
+ with open(filename, "w", encoding="utf-8") as f:
f.write(content)
print(f"\n✅ Structure saved to {filename}\n")
print(content)
diff --git a/scripts/test_auth_complete.py b/scripts/test_auth_complete.py
index 82e78a14..abc5c18a 100644
--- a/scripts/test_auth_complete.py
+++ b/scripts/test_auth_complete.py
@@ -20,22 +20,26 @@ Requirements:
* Customer: username=customer, password=customer123, vendor_id=1
"""
-import requests
import json
from typing import Dict, Optional
+import requests
+
BASE_URL = "http://localhost:8000"
+
class Color:
"""Terminal colors for pretty output"""
- GREEN = '\033[92m'
- RED = '\033[91m'
- YELLOW = '\033[93m'
- BLUE = '\033[94m'
- MAGENTA = '\033[95m'
- CYAN = '\033[96m'
- BOLD = '\033[1m'
- END = '\033[0m'
+
+ GREEN = "\033[92m"
+ RED = "\033[91m"
+ YELLOW = "\033[93m"
+ BLUE = "\033[94m"
+ MAGENTA = "\033[95m"
+ CYAN = "\033[96m"
+ BOLD = "\033[1m"
+ END = "\033[0m"
+
def print_section(name: str):
"""Print section header"""
@@ -43,66 +47,70 @@ def print_section(name: str):
print(f" {name}")
print(f"{'═' * 60}{Color.END}")
+
def print_test(name: str):
"""Print test name"""
print(f"\n{Color.BOLD}{Color.BLUE}🧪 Test: {name}{Color.END}")
+
def print_success(message: str):
"""Print success message"""
print(f"{Color.GREEN}✅ {message}{Color.END}")
+
def print_error(message: str):
"""Print error message"""
print(f"{Color.RED}❌ {message}{Color.END}")
+
def print_info(message: str):
"""Print info message"""
print(f"{Color.YELLOW}ℹ️ {message}{Color.END}")
+
def print_warning(message: str):
"""Print warning message"""
print(f"{Color.MAGENTA}⚠️ {message}{Color.END}")
+
# ============================================================================
# ADMIN AUTHENTICATION TESTS
# ============================================================================
+
def test_admin_login() -> Optional[Dict]:
"""Test admin login and cookie configuration"""
print_test("Admin Login")
-
+
try:
response = requests.post(
f"{BASE_URL}/api/v1/admin/auth/login",
- json={"username": "admin", "password": "admin123"}
+ json={"username": "admin", "password": "admin123"},
)
-
+
if response.status_code == 200:
data = response.json()
cookies = response.cookies
-
+
if "access_token" in data:
print_success("Admin login successful")
print_success(f"Received access token: {data['access_token'][:20]}...")
else:
print_error("No access token in response")
return None
-
+
if "admin_token" in cookies:
print_success("admin_token cookie set")
print_info("Cookie path should be /admin (verify in browser)")
else:
print_error("admin_token cookie NOT set")
-
- return {
- "token": data["access_token"],
- "user": data.get("user", {})
- }
+
+ return {"token": data["access_token"], "user": data.get("user", {})}
else:
print_error(f"Login failed: {response.status_code}")
print_error(f"Response: {response.text}")
return None
-
+
except Exception as e:
print_error(f"Exception during admin login: {str(e)}")
return None
@@ -111,13 +119,13 @@ def test_admin_login() -> Optional[Dict]:
def test_admin_cannot_access_vendor_api(admin_token: str):
"""Test that admin token cannot access vendor API"""
print_test("Admin Token on Vendor API (Should Block)")
-
+
try:
response = requests.get(
f"{BASE_URL}/api/v1/vendor/TESTVENDOR/products",
- headers={"Authorization": f"Bearer {admin_token}"}
+ headers={"Authorization": f"Bearer {admin_token}"},
)
-
+
if response.status_code in [401, 403]:
data = response.json()
print_success("Admin correctly blocked from vendor API")
@@ -129,7 +137,7 @@ def test_admin_cannot_access_vendor_api(admin_token: str):
else:
print_warning(f"Unexpected status code: {response.status_code}")
return False
-
+
except Exception as e:
print_error(f"Exception: {str(e)}")
return False
@@ -138,13 +146,13 @@ def test_admin_cannot_access_vendor_api(admin_token: str):
def test_admin_cannot_access_customer_api(admin_token: str):
"""Test that admin token cannot access customer account pages"""
print_test("Admin Token on Customer API (Should Block)")
-
+
try:
response = requests.get(
f"{BASE_URL}/shop/account/dashboard",
- headers={"Authorization": f"Bearer {admin_token}"}
+ headers={"Authorization": f"Bearer {admin_token}"},
)
-
+
# Customer pages may return HTML or JSON error
if response.status_code in [401, 403]:
print_success("Admin correctly blocked from customer pages")
@@ -155,7 +163,7 @@ def test_admin_cannot_access_customer_api(admin_token: str):
else:
print_warning(f"Unexpected status code: {response.status_code}")
return False
-
+
except Exception as e:
print_error(f"Exception: {str(e)}")
return False
@@ -165,46 +173,47 @@ def test_admin_cannot_access_customer_api(admin_token: str):
# VENDOR AUTHENTICATION TESTS
# ============================================================================
+
def test_vendor_login() -> Optional[Dict]:
"""Test vendor login and cookie configuration"""
print_test("Vendor Login")
-
+
try:
response = requests.post(
f"{BASE_URL}/api/v1/vendor/auth/login",
- json={"username": "vendor", "password": "vendor123"}
+ json={"username": "vendor", "password": "vendor123"},
)
-
+
if response.status_code == 200:
data = response.json()
cookies = response.cookies
-
+
if "access_token" in data:
print_success("Vendor login successful")
print_success(f"Received access token: {data['access_token'][:20]}...")
else:
print_error("No access token in response")
return None
-
+
if "vendor_token" in cookies:
print_success("vendor_token cookie set")
print_info("Cookie path should be /vendor (verify in browser)")
else:
print_error("vendor_token cookie NOT set")
-
+
if "vendor" in data:
print_success(f"Vendor: {data['vendor'].get('vendor_code', 'N/A')}")
-
+
return {
"token": data["access_token"],
"user": data.get("user", {}),
- "vendor": data.get("vendor", {})
+ "vendor": data.get("vendor", {}),
}
else:
print_error(f"Login failed: {response.status_code}")
print_error(f"Response: {response.text}")
return None
-
+
except Exception as e:
print_error(f"Exception during vendor login: {str(e)}")
return None
@@ -213,13 +222,13 @@ def test_vendor_login() -> Optional[Dict]:
def test_vendor_cannot_access_admin_api(vendor_token: str):
"""Test that vendor token cannot access admin API"""
print_test("Vendor Token on Admin API (Should Block)")
-
+
try:
response = requests.get(
f"{BASE_URL}/api/v1/admin/vendors",
- headers={"Authorization": f"Bearer {vendor_token}"}
+ headers={"Authorization": f"Bearer {vendor_token}"},
)
-
+
if response.status_code in [401, 403]:
data = response.json()
print_success("Vendor correctly blocked from admin API")
@@ -231,7 +240,7 @@ def test_vendor_cannot_access_admin_api(vendor_token: str):
else:
print_warning(f"Unexpected status code: {response.status_code}")
return False
-
+
except Exception as e:
print_error(f"Exception: {str(e)}")
return False
@@ -240,13 +249,13 @@ def test_vendor_cannot_access_admin_api(vendor_token: str):
def test_vendor_cannot_access_customer_api(vendor_token: str):
"""Test that vendor token cannot access customer account pages"""
print_test("Vendor Token on Customer API (Should Block)")
-
+
try:
response = requests.get(
f"{BASE_URL}/shop/account/dashboard",
- headers={"Authorization": f"Bearer {vendor_token}"}
+ headers={"Authorization": f"Bearer {vendor_token}"},
)
-
+
if response.status_code in [401, 403]:
print_success("Vendor correctly blocked from customer pages")
return True
@@ -256,7 +265,7 @@ def test_vendor_cannot_access_customer_api(vendor_token: str):
else:
print_warning(f"Unexpected status code: {response.status_code}")
return False
-
+
except Exception as e:
print_error(f"Exception: {str(e)}")
return False
@@ -266,42 +275,40 @@ def test_vendor_cannot_access_customer_api(vendor_token: str):
# CUSTOMER AUTHENTICATION TESTS
# ============================================================================
+
def test_customer_login() -> Optional[Dict]:
"""Test customer login and cookie configuration"""
print_test("Customer Login")
-
+
try:
response = requests.post(
f"{BASE_URL}/api/v1/public/vendors/1/customers/login",
- json={"username": "customer", "password": "customer123"}
+ json={"username": "customer", "password": "customer123"},
)
-
+
if response.status_code == 200:
data = response.json()
cookies = response.cookies
-
+
if "access_token" in data:
print_success("Customer login successful")
print_success(f"Received access token: {data['access_token'][:20]}...")
else:
print_error("No access token in response")
return None
-
+
if "customer_token" in cookies:
print_success("customer_token cookie set")
print_info("Cookie path should be /shop (verify in browser)")
else:
print_error("customer_token cookie NOT set")
-
- return {
- "token": data["access_token"],
- "user": data.get("user", {})
- }
+
+ return {"token": data["access_token"], "user": data.get("user", {})}
else:
print_error(f"Login failed: {response.status_code}")
print_error(f"Response: {response.text}")
return None
-
+
except Exception as e:
print_error(f"Exception during customer login: {str(e)}")
return None
@@ -310,13 +317,13 @@ def test_customer_login() -> Optional[Dict]:
def test_customer_cannot_access_admin_api(customer_token: str):
"""Test that customer token cannot access admin API"""
print_test("Customer Token on Admin API (Should Block)")
-
+
try:
response = requests.get(
f"{BASE_URL}/api/v1/admin/vendors",
- headers={"Authorization": f"Bearer {customer_token}"}
+ headers={"Authorization": f"Bearer {customer_token}"},
)
-
+
if response.status_code in [401, 403]:
data = response.json()
print_success("Customer correctly blocked from admin API")
@@ -328,7 +335,7 @@ def test_customer_cannot_access_admin_api(customer_token: str):
else:
print_warning(f"Unexpected status code: {response.status_code}")
return False
-
+
except Exception as e:
print_error(f"Exception: {str(e)}")
return False
@@ -337,13 +344,13 @@ def test_customer_cannot_access_admin_api(customer_token: str):
def test_customer_cannot_access_vendor_api(customer_token: str):
"""Test that customer token cannot access vendor API"""
print_test("Customer Token on Vendor API (Should Block)")
-
+
try:
response = requests.get(
f"{BASE_URL}/api/v1/vendor/TESTVENDOR/products",
- headers={"Authorization": f"Bearer {customer_token}"}
+ headers={"Authorization": f"Bearer {customer_token}"},
)
-
+
if response.status_code in [401, 403]:
data = response.json()
print_success("Customer correctly blocked from vendor API")
@@ -355,7 +362,7 @@ def test_customer_cannot_access_vendor_api(customer_token: str):
else:
print_warning(f"Unexpected status code: {response.status_code}")
return False
-
+
except Exception as e:
print_error(f"Exception: {str(e)}")
return False
@@ -364,17 +371,17 @@ def test_customer_cannot_access_vendor_api(customer_token: str):
def test_public_shop_access():
"""Test that public shop pages are accessible without authentication"""
print_test("Public Shop Access (No Auth Required)")
-
+
try:
response = requests.get(f"{BASE_URL}/shop/products")
-
+
if response.status_code == 200:
print_success("Public shop pages accessible without auth")
return True
else:
print_error(f"Failed to access public shop: {response.status_code}")
return False
-
+
except Exception as e:
print_error(f"Exception: {str(e)}")
return False
@@ -383,10 +390,10 @@ def test_public_shop_access():
def test_health_check():
"""Test health check endpoint"""
print_test("Health Check")
-
+
try:
response = requests.get(f"{BASE_URL}/health")
-
+
if response.status_code == 200:
data = response.json()
print_success("Health check passed")
@@ -395,7 +402,7 @@ def test_health_check():
else:
print_error(f"Health check failed: {response.status_code}")
return False
-
+
except Exception as e:
print_error(f"Exception: {str(e)}")
return False
@@ -405,19 +412,16 @@ def test_health_check():
# MAIN TEST RUNNER
# ============================================================================
+
def main():
"""Run all tests"""
print(f"\n{Color.BOLD}{Color.CYAN}{'═' * 60}")
print(f" 🔒 COMPLETE AUTHENTICATION SYSTEM TEST SUITE")
print(f"{'═' * 60}{Color.END}")
print(f"Testing server at: {BASE_URL}")
-
- results = {
- "passed": 0,
- "failed": 0,
- "total": 0
- }
-
+
+ results = {"passed": 0, "failed": 0, "total": 0}
+
# Health check first
print_section("🏥 System Health")
results["total"] += 1
@@ -427,12 +431,12 @@ def main():
results["failed"] += 1
print_error("Server not responding. Is it running?")
return
-
+
# ========================================================================
# ADMIN TESTS
# ========================================================================
print_section("👤 Admin Authentication Tests")
-
+
# Admin login
results["total"] += 1
admin_auth = test_admin_login()
@@ -441,7 +445,7 @@ def main():
else:
results["failed"] += 1
admin_auth = None
-
+
# Admin cross-context tests
if admin_auth:
results["total"] += 1
@@ -449,18 +453,18 @@ def main():
results["passed"] += 1
else:
results["failed"] += 1
-
+
results["total"] += 1
if test_admin_cannot_access_customer_api(admin_auth["token"]):
results["passed"] += 1
else:
results["failed"] += 1
-
+
# ========================================================================
# VENDOR TESTS
# ========================================================================
print_section("🏪 Vendor Authentication Tests")
-
+
# Vendor login
results["total"] += 1
vendor_auth = test_vendor_login()
@@ -469,7 +473,7 @@ def main():
else:
results["failed"] += 1
vendor_auth = None
-
+
# Vendor cross-context tests
if vendor_auth:
results["total"] += 1
@@ -477,25 +481,25 @@ def main():
results["passed"] += 1
else:
results["failed"] += 1
-
+
results["total"] += 1
if test_vendor_cannot_access_customer_api(vendor_auth["token"]):
results["passed"] += 1
else:
results["failed"] += 1
-
+
# ========================================================================
# CUSTOMER TESTS
# ========================================================================
print_section("🛒 Customer Authentication Tests")
-
+
# Public shop access
results["total"] += 1
if test_public_shop_access():
results["passed"] += 1
else:
results["failed"] += 1
-
+
# Customer login
results["total"] += 1
customer_auth = test_customer_login()
@@ -504,7 +508,7 @@ def main():
else:
results["failed"] += 1
customer_auth = None
-
+
# Customer cross-context tests
if customer_auth:
results["total"] += 1
@@ -512,29 +516,31 @@ def main():
results["passed"] += 1
else:
results["failed"] += 1
-
+
results["total"] += 1
if test_customer_cannot_access_vendor_api(customer_auth["token"]):
results["passed"] += 1
else:
results["failed"] += 1
-
+
# ========================================================================
# RESULTS
# ========================================================================
print_section("📊 Test Results")
print(f"\n{Color.BOLD}Total Tests: {results['total']}{Color.END}")
print_success(f"Passed: {results['passed']}")
- if results['failed'] > 0:
+ if results["failed"] > 0:
print_error(f"Failed: {results['failed']}")
-
- if results['failed'] == 0:
+
+ if results["failed"] == 0:
print(f"\n{Color.GREEN}{Color.BOLD}🎉 ALL TESTS PASSED!{Color.END}")
- print(f"{Color.GREEN}Your authentication system is properly isolated!{Color.END}")
+ print(
+ f"{Color.GREEN}Your authentication system is properly isolated!{Color.END}"
+ )
else:
print(f"\n{Color.RED}{Color.BOLD}⚠️ SOME TESTS FAILED{Color.END}")
print(f"{Color.RED}Please review the output above.{Color.END}")
-
+
# Browser tests reminder
print_section("🌐 Manual Browser Tests")
print("Please also verify in browser:")
diff --git a/scripts/test_vendor_management.py b/scripts/test_vendor_management.py
index 751f59c2..71292c8b 100644
--- a/scripts/test_vendor_management.py
+++ b/scripts/test_vendor_management.py
@@ -9,10 +9,11 @@ Tests:
- Transfer ownership
"""
-import requests
import json
from pprint import pprint
+import requests
+
BASE_URL = "http://localhost:8000"
ADMIN_TOKEN = None # Will be set after login
@@ -25,8 +26,8 @@ def login_admin():
f"{BASE_URL}/api/v1/admin/auth/login", # ✅ Changed: added /admin/
json={
"username": "admin", # Replace with your admin username
- "password": "admin123" # Replace with your admin password
- }
+ "password": "admin123", # Replace with your admin password
+ },
)
if response.status_code == 200:
@@ -44,7 +45,7 @@ def get_headers():
"""Get authorization headers."""
return {
"Authorization": f"Bearer {ADMIN_TOKEN}",
- "Content-Type": "application/json"
+ "Content-Type": "application/json",
}
@@ -60,13 +61,11 @@ def test_create_vendor_with_both_emails():
"subdomain": "testdual",
"owner_email": "owner@testdual.com",
"contact_email": "contact@testdual.com",
- "description": "Test vendor with separate emails"
+ "description": "Test vendor with separate emails",
}
response = requests.post(
- f"{BASE_URL}/api/v1/admin/vendors",
- headers=get_headers(),
- json=vendor_data
+ f"{BASE_URL}/api/v1/admin/vendors", headers=get_headers(), json=vendor_data
)
if response.status_code == 200:
@@ -79,7 +78,7 @@ def test_create_vendor_with_both_emails():
print(f" Username: {data['owner_username']}")
print(f" Password: {data['temporary_password']}")
print(f"\n🔗 Login URL: {data['login_url']}")
- return data['id']
+ return data["id"]
else:
print(f"❌ Failed: {response.status_code}")
print(response.text)
@@ -97,13 +96,11 @@ def test_create_vendor_single_email():
"name": "Test Single Email Vendor",
"subdomain": "testsingle",
"owner_email": "owner@testsingle.com",
- "description": "Test vendor with single email"
+ "description": "Test vendor with single email",
}
response = requests.post(
- f"{BASE_URL}/api/v1/admin/vendors",
- headers=get_headers(),
- json=vendor_data
+ f"{BASE_URL}/api/v1/admin/vendors", headers=get_headers(), json=vendor_data
)
if response.status_code == 200:
@@ -113,12 +110,12 @@ def test_create_vendor_single_email():
print(f" Owner Email: {data['owner_email']}")
print(f" Contact Email: {data['contact_email']}")
- if data['owner_email'] == data['contact_email']:
+ if data["owner_email"] == data["contact_email"]:
print("✅ Contact email correctly defaulted to owner email")
else:
print("❌ Contact email should have defaulted to owner email")
- return data['id']
+ return data["id"]
else:
print(f"❌ Failed: {response.status_code}")
print(response.text)
@@ -133,13 +130,13 @@ def test_update_vendor_contact_email(vendor_id):
update_data = {
"contact_email": "newcontact@business.com",
- "name": "Updated Vendor Name"
+ "name": "Updated Vendor Name",
}
response = requests.put(
f"{BASE_URL}/api/v1/admin/vendors/{vendor_id}",
headers=get_headers(),
- json=update_data
+ json=update_data,
)
if response.status_code == 200:
@@ -163,8 +160,7 @@ def test_get_vendor_details(vendor_id):
print("=" * 60)
response = requests.get(
- f"{BASE_URL}/api/v1/admin/vendors/{vendor_id}",
- headers=get_headers()
+ f"{BASE_URL}/api/v1/admin/vendors/{vendor_id}", headers=get_headers()
)
if response.status_code == 200:
@@ -196,8 +192,7 @@ def test_transfer_ownership(vendor_id, new_owner_user_id):
# First, get current owner info
response = requests.get(
- f"{BASE_URL}/api/v1/admin/vendors/{vendor_id}",
- headers=get_headers()
+ f"{BASE_URL}/api/v1/admin/vendors/{vendor_id}", headers=get_headers()
)
if response.status_code == 200:
@@ -211,13 +206,13 @@ def test_transfer_ownership(vendor_id, new_owner_user_id):
transfer_data = {
"new_owner_user_id": new_owner_user_id,
"confirm_transfer": True,
- "transfer_reason": "Testing ownership transfer feature"
+ "transfer_reason": "Testing ownership transfer feature",
}
response = requests.post(
f"{BASE_URL}/api/v1/admin/vendors/{vendor_id}/transfer-ownership",
headers=get_headers(),
- json=transfer_data
+ json=transfer_data,
)
if response.status_code == 200:
diff --git a/scripts/validate_architecture.py b/scripts/validate_architecture.py
index dc186de8..6e654957 100755
--- a/scripts/validate_architecture.py
+++ b/scripts/validate_architecture.py
@@ -22,15 +22,17 @@ import argparse
import ast
import re
import sys
-from pathlib import Path
-from typing import List, Dict, Any, Tuple
from dataclasses import dataclass, field
from enum import Enum
+from pathlib import Path
+from typing import Any, Dict, List, Tuple
+
import yaml
class Severity(Enum):
"""Validation severity levels"""
+
ERROR = "error"
WARNING = "warning"
INFO = "info"
@@ -39,6 +41,7 @@ class Severity(Enum):
@dataclass
class Violation:
"""Represents an architectural rule violation"""
+
rule_id: str
rule_name: str
severity: Severity
@@ -52,6 +55,7 @@ class Violation:
@dataclass
class ValidationResult:
"""Results of architecture validation"""
+
violations: List[Violation] = field(default_factory=list)
files_checked: int = 0
rules_applied: int = 0
@@ -82,7 +86,7 @@ class ArchitectureValidator:
print(f"❌ Configuration file not found: {self.config_path}")
sys.exit(1)
- with open(self.config_path, 'r') as f:
+ with open(self.config_path, "r") as f:
config = yaml.safe_load(f)
print(f"📋 Loaded architecture rules: {config.get('project', 'unknown')}")
@@ -126,7 +130,7 @@ class ArchitectureValidator:
continue
content = file_path.read_text()
- lines = content.split('\n')
+ lines = content.split("\n")
# API-001: Check for Pydantic model usage
self._check_pydantic_usage(file_path, content, lines)
@@ -147,8 +151,8 @@ class ArchitectureValidator:
return
# Check for response_model in route decorators
- route_pattern = r'@router\.(get|post|put|delete|patch)'
- dict_return_pattern = r'return\s+\{.*\}'
+ route_pattern = r"@router\.(get|post|put|delete|patch)"
+ dict_return_pattern = r"return\s+\{.*\}"
for i, line in enumerate(lines, 1):
# Check for dict returns in endpoints
@@ -166,47 +170,51 @@ class ArchitectureValidator:
if re.search(dict_return_pattern, func_line):
self._add_violation(
rule_id="API-001",
- rule_name=rule['name'],
+ rule_name=rule["name"],
severity=Severity.ERROR,
file_path=file_path,
line_number=j + 1,
message="Endpoint returns raw dict instead of Pydantic model",
context=func_line.strip(),
- suggestion="Define a Pydantic response model and use response_model parameter"
+ suggestion="Define a Pydantic response model and use response_model parameter",
)
- def _check_no_business_logic_in_endpoints(self, file_path: Path, content: str, lines: List[str]):
+ def _check_no_business_logic_in_endpoints(
+ self, file_path: Path, content: str, lines: List[str]
+ ):
"""API-002: Ensure no business logic in endpoints"""
rule = self._get_rule("API-002")
if not rule:
return
anti_patterns = [
- (r'db\.add\(', "Database operations should be in service layer"),
- (r'db\.commit\(\)', "Database commits should be in service layer"),
- (r'db\.query\(', "Database queries should be in service layer"),
- (r'db\.execute\(', "Database operations should be in service layer"),
+ (r"db\.add\(", "Database operations should be in service layer"),
+ (r"db\.commit\(\)", "Database commits should be in service layer"),
+ (r"db\.query\(", "Database queries should be in service layer"),
+ (r"db\.execute\(", "Database operations should be in service layer"),
]
for i, line in enumerate(lines, 1):
# Skip service method calls (allowed)
- if '_service.' in line or 'service.' in line:
+ if "_service." in line or "service." in line:
continue
for pattern, message in anti_patterns:
if re.search(pattern, line):
self._add_violation(
rule_id="API-002",
- rule_name=rule['name'],
+ rule_name=rule["name"],
severity=Severity.ERROR,
file_path=file_path,
line_number=i,
message=message,
context=line.strip(),
- suggestion="Move database operations to service layer"
+ suggestion="Move database operations to service layer",
)
- def _check_endpoint_exception_handling(self, file_path: Path, content: str, lines: List[str]):
+ def _check_endpoint_exception_handling(
+ self, file_path: Path, content: str, lines: List[str]
+ ):
"""API-003: Check proper exception handling in endpoints"""
rule = self._get_rule("API-003")
if not rule:
@@ -222,40 +230,41 @@ class ArchitectureValidator:
if isinstance(node, ast.FunctionDef):
# Check if it's a route handler
has_router_decorator = any(
- isinstance(d, ast.Call) and
- isinstance(d.func, ast.Attribute) and
- getattr(d.func.value, 'id', None) == 'router'
+ isinstance(d, ast.Call)
+ and isinstance(d.func, ast.Attribute)
+ and getattr(d.func.value, "id", None) == "router"
for d in node.decorator_list
)
if has_router_decorator:
# Check if function body has try/except
has_try_except = any(
- isinstance(child, ast.Try)
- for child in ast.walk(node)
+ isinstance(child, ast.Try) for child in ast.walk(node)
)
# Check if function calls service methods
has_service_call = any(
- isinstance(child, ast.Call) and
- isinstance(child.func, ast.Attribute) and
- 'service' in getattr(child.func.value, 'id', '').lower()
+ isinstance(child, ast.Call)
+ and isinstance(child.func, ast.Attribute)
+ and "service" in getattr(child.func.value, "id", "").lower()
for child in ast.walk(node)
)
if has_service_call and not has_try_except:
self._add_violation(
rule_id="API-003",
- rule_name=rule['name'],
+ rule_name=rule["name"],
severity=Severity.WARNING,
file_path=file_path,
line_number=node.lineno,
message=f"Endpoint '{node.name}' calls service but lacks exception handling",
context=f"def {node.name}(...)",
- suggestion="Wrap service calls in try/except and convert to HTTPException"
+ suggestion="Wrap service calls in try/except and convert to HTTPException",
)
- def _check_endpoint_authentication(self, file_path: Path, content: str, lines: List[str]):
+ def _check_endpoint_authentication(
+ self, file_path: Path, content: str, lines: List[str]
+ ):
"""API-004: Check authentication on endpoints"""
rule = self._get_rule("API-004")
if not rule:
@@ -264,24 +273,28 @@ class ArchitectureValidator:
# This is a warning-level check
# Look for endpoints without Depends(get_current_*)
for i, line in enumerate(lines, 1):
- if '@router.' in line and ('post' in line or 'put' in line or 'delete' in line):
+ if "@router." in line and (
+ "post" in line or "put" in line or "delete" in line
+ ):
# Check next 5 lines for auth
has_auth = False
for j in range(i, min(i + 5, len(lines))):
- if 'Depends(get_current_' in lines[j]:
+ if "Depends(get_current_" in lines[j]:
has_auth = True
break
- if not has_auth and 'include_in_schema=False' not in ' '.join(lines[i:i+5]):
+ if not has_auth and "include_in_schema=False" not in " ".join(
+ lines[i : i + 5]
+ ):
self._add_violation(
rule_id="API-004",
- rule_name=rule['name'],
+ rule_name=rule["name"],
severity=Severity.WARNING,
file_path=file_path,
line_number=i,
message="Endpoint may be missing authentication",
context=line.strip(),
- suggestion="Add Depends(get_current_user) or similar if endpoint should be protected"
+ suggestion="Add Depends(get_current_user) or similar if endpoint should be protected",
)
def _validate_service_layer(self, target_path: Path):
@@ -296,7 +309,7 @@ class ArchitectureValidator:
continue
content = file_path.read_text()
- lines = content.split('\n')
+ lines = content.split("\n")
# SVC-001: No HTTPException in services
self._check_no_http_exception_in_services(file_path, content, lines)
@@ -307,38 +320,45 @@ class ArchitectureValidator:
# SVC-003: DB session as parameter
self._check_db_session_parameter(file_path, content, lines)
- def _check_no_http_exception_in_services(self, file_path: Path, content: str, lines: List[str]):
+ def _check_no_http_exception_in_services(
+ self, file_path: Path, content: str, lines: List[str]
+ ):
"""SVC-001: Services must not raise HTTPException"""
rule = self._get_rule("SVC-001")
if not rule:
return
for i, line in enumerate(lines, 1):
- if 'raise HTTPException' in line:
+ if "raise HTTPException" in line:
self._add_violation(
rule_id="SVC-001",
- rule_name=rule['name'],
+ rule_name=rule["name"],
severity=Severity.ERROR,
file_path=file_path,
line_number=i,
message="Service raises HTTPException - use domain exceptions instead",
context=line.strip(),
- suggestion="Create custom exception class (e.g., VendorNotFoundError) and raise that"
+ suggestion="Create custom exception class (e.g., VendorNotFoundError) and raise that",
)
- if 'from fastapi import HTTPException' in line or 'from fastapi.exceptions import HTTPException' in line:
+ if (
+ "from fastapi import HTTPException" in line
+ or "from fastapi.exceptions import HTTPException" in line
+ ):
self._add_violation(
rule_id="SVC-001",
- rule_name=rule['name'],
+ rule_name=rule["name"],
severity=Severity.ERROR,
file_path=file_path,
line_number=i,
message="Service imports HTTPException - services should not know about HTTP",
context=line.strip(),
- suggestion="Remove HTTPException import and use domain exceptions"
+ suggestion="Remove HTTPException import and use domain exceptions",
)
- def _check_service_exceptions(self, file_path: Path, content: str, lines: List[str]):
+ def _check_service_exceptions(
+ self, file_path: Path, content: str, lines: List[str]
+ ):
"""SVC-002: Check for proper exception handling"""
rule = self._get_rule("SVC-002")
if not rule:
@@ -346,19 +366,21 @@ class ArchitectureValidator:
for i, line in enumerate(lines, 1):
# Check for generic Exception raises
- if re.match(r'\s*raise Exception\(', line):
+ if re.match(r"\s*raise Exception\(", line):
self._add_violation(
rule_id="SVC-002",
- rule_name=rule['name'],
+ rule_name=rule["name"],
severity=Severity.WARNING,
file_path=file_path,
line_number=i,
message="Service raises generic Exception - use specific domain exception",
context=line.strip(),
- suggestion="Create custom exception class for this error case"
+ suggestion="Create custom exception class for this error case",
)
- def _check_db_session_parameter(self, file_path: Path, content: str, lines: List[str]):
+ def _check_db_session_parameter(
+ self, file_path: Path, content: str, lines: List[str]
+ ):
"""SVC-003: Service methods should accept db session as parameter"""
rule = self._get_rule("SVC-003")
if not rule:
@@ -366,16 +388,16 @@ class ArchitectureValidator:
# Check for SessionLocal() creation in service files
for i, line in enumerate(lines, 1):
- if 'SessionLocal()' in line and 'class' not in line:
+ if "SessionLocal()" in line and "class" not in line:
self._add_violation(
rule_id="SVC-003",
- rule_name=rule['name'],
+ rule_name=rule["name"],
severity=Severity.ERROR,
file_path=file_path,
line_number=i,
message="Service creates database session internally",
context=line.strip(),
- suggestion="Accept db: Session as method parameter instead"
+ suggestion="Accept db: Session as method parameter instead",
)
def _validate_models(self, target_path: Path):
@@ -391,11 +413,11 @@ class ArchitectureValidator:
continue
content = file_path.read_text()
- lines = content.split('\n')
+ lines = content.split("\n")
# Check for mixing SQLAlchemy and Pydantic
for i, line in enumerate(lines, 1):
- if re.search(r'class.*\(Base.*,.*BaseModel.*\)', line):
+ if re.search(r"class.*\(Base.*,.*BaseModel.*\)", line):
self._add_violation(
rule_id="MDL-002",
rule_name="Separate SQLAlchemy and Pydantic models",
@@ -404,7 +426,7 @@ class ArchitectureValidator:
line_number=i,
message="Model mixes SQLAlchemy Base and Pydantic BaseModel",
context=line.strip(),
- suggestion="Keep SQLAlchemy models and Pydantic models separate"
+ suggestion="Keep SQLAlchemy models and Pydantic models separate",
)
def _validate_exceptions(self, target_path: Path):
@@ -418,11 +440,11 @@ class ArchitectureValidator:
continue
content = file_path.read_text()
- lines = content.split('\n')
+ lines = content.split("\n")
# EXC-002: Check for bare except
for i, line in enumerate(lines, 1):
- if re.match(r'\s*except\s*:', line):
+ if re.match(r"\s*except\s*:", line):
self._add_violation(
rule_id="EXC-002",
rule_name="No bare except clauses",
@@ -431,7 +453,7 @@ class ArchitectureValidator:
line_number=i,
message="Bare except clause catches all exceptions including system exits",
context=line.strip(),
- suggestion="Specify exception type: except ValueError: or except Exception:"
+ suggestion="Specify exception type: except ValueError: or except Exception:",
)
def _validate_javascript(self, target_path: Path):
@@ -443,11 +465,16 @@ class ArchitectureValidator:
for file_path in js_files:
content = file_path.read_text()
- lines = content.split('\n')
+ lines = content.split("\n")
# JS-001: Check for window.apiClient
for i, line in enumerate(lines, 1):
- if 'window.apiClient' in line and '//' not in line[:line.find('window.apiClient')] if 'window.apiClient' in line else True:
+ if (
+ "window.apiClient" in line
+ and "//" not in line[: line.find("window.apiClient")]
+ if "window.apiClient" in line
+ else True
+ ):
self._add_violation(
rule_id="JS-001",
rule_name="Use apiClient directly",
@@ -456,14 +483,14 @@ class ArchitectureValidator:
line_number=i,
message="Use apiClient directly instead of window.apiClient",
context=line.strip(),
- suggestion="Replace window.apiClient with apiClient"
+ suggestion="Replace window.apiClient with apiClient",
)
# JS-002: Check for console usage
for i, line in enumerate(lines, 1):
- if re.search(r'console\.(log|warn|error)', line):
+ if re.search(r"console\.(log|warn|error)", line):
# Skip if it's a comment or bootstrap message
- if '//' in line or '✅' in line or 'eslint-disable' in line:
+ if "//" in line or "✅" in line or "eslint-disable" in line:
continue
self._add_violation(
@@ -474,7 +501,7 @@ class ArchitectureValidator:
line_number=i,
message="Use centralized logger instead of console",
context=line.strip()[:80],
- suggestion="Use window.LogConfig.createLogger('moduleName')"
+ suggestion="Use window.LogConfig.createLogger('moduleName')",
)
def _validate_templates(self, target_path: Path):
@@ -486,14 +513,16 @@ class ArchitectureValidator:
for file_path in template_files:
# Skip base template and partials
- if 'base.html' in file_path.name or 'partials' in str(file_path):
+ if "base.html" in file_path.name or "partials" in str(file_path):
continue
content = file_path.read_text()
- lines = content.split('\n')
+ lines = content.split("\n")
# TPL-001: Check for extends
- has_extends = any('{% extends' in line and 'admin/base.html' in line for line in lines)
+ has_extends = any(
+ "{% extends" in line and "admin/base.html" in line for line in lines
+ )
if not has_extends:
self._add_violation(
@@ -504,23 +533,29 @@ class ArchitectureValidator:
line_number=1,
message="Admin template does not extend admin/base.html",
context=file_path.name,
- suggestion="Add {% extends 'admin/base.html' %} at the top"
+ suggestion="Add {% extends 'admin/base.html' %} at the top",
)
def _get_rule(self, rule_id: str) -> Dict[str, Any]:
"""Get rule configuration by ID"""
# Look in different rule categories
- for category in ['api_endpoint_rules', 'service_layer_rules', 'model_rules',
- 'exception_rules', 'javascript_rules', 'template_rules']:
+ for category in [
+ "api_endpoint_rules",
+ "service_layer_rules",
+ "model_rules",
+ "exception_rules",
+ "javascript_rules",
+ "template_rules",
+ ]:
rules = self.config.get(category, [])
for rule in rules:
- if rule.get('id') == rule_id:
+ if rule.get("id") == rule_id:
return rule
return None
def _should_ignore_file(self, file_path: Path) -> bool:
"""Check if file should be ignored"""
- ignore_patterns = self.config.get('ignore', {}).get('files', [])
+ ignore_patterns = self.config.get("ignore", {}).get("files", [])
for pattern in ignore_patterns:
if file_path.match(pattern):
@@ -528,9 +563,17 @@ class ArchitectureValidator:
return False
- def _add_violation(self, rule_id: str, rule_name: str, severity: Severity,
- file_path: Path, line_number: int, message: str,
- context: str = "", suggestion: str = ""):
+ def _add_violation(
+ self,
+ rule_id: str,
+ rule_name: str,
+ severity: Severity,
+ file_path: Path,
+ line_number: int,
+ message: str,
+ context: str = "",
+ suggestion: str = "",
+ ):
"""Add a violation to results"""
violation = Violation(
rule_id=rule_id,
@@ -540,7 +583,7 @@ class ArchitectureValidator:
line_number=line_number,
message=message,
context=context,
- suggestion=suggestion
+ suggestion=suggestion,
)
self.result.violations.append(violation)
@@ -590,24 +633,34 @@ class ArchitectureValidator:
violations_json = []
for v in self.result.violations:
- rel_path = str(v.file_path.relative_to(self.project_root)) if self.project_root in v.file_path.parents else str(v.file_path)
- violations_json.append({
- 'rule_id': v.rule_id,
- 'rule_name': v.rule_name,
- 'severity': v.severity.value,
- 'file_path': rel_path,
- 'line_number': v.line_number,
- 'message': v.message,
- 'context': v.context or '',
- 'suggestion': v.suggestion or ''
- })
+ rel_path = (
+ str(v.file_path.relative_to(self.project_root))
+ if self.project_root in v.file_path.parents
+ else str(v.file_path)
+ )
+ violations_json.append(
+ {
+ "rule_id": v.rule_id,
+ "rule_name": v.rule_name,
+ "severity": v.severity.value,
+ "file_path": rel_path,
+ "line_number": v.line_number,
+ "message": v.message,
+ "context": v.context or "",
+ "suggestion": v.suggestion or "",
+ }
+ )
output = {
- 'files_checked': self.result.files_checked,
- 'total_violations': len(self.result.violations),
- 'errors': len([v for v in self.result.violations if v.severity == Severity.ERROR]),
- 'warnings': len([v for v in self.result.violations if v.severity == Severity.WARNING]),
- 'violations': violations_json
+ "files_checked": self.result.files_checked,
+ "total_violations": len(self.result.violations),
+ "errors": len(
+ [v for v in self.result.violations if v.severity == Severity.ERROR]
+ ),
+ "warnings": len(
+ [v for v in self.result.violations if v.severity == Severity.WARNING]
+ ),
+ "violations": violations_json,
}
print(json.dumps(output, indent=2))
@@ -616,7 +669,11 @@ class ArchitectureValidator:
def _print_violation(self, v: Violation):
"""Print a single violation"""
- rel_path = v.file_path.relative_to(self.project_root) if self.project_root in v.file_path.parents else v.file_path
+ rel_path = (
+ v.file_path.relative_to(self.project_root)
+ if self.project_root in v.file_path.parents
+ else v.file_path
+ )
print(f"\n [{v.rule_id}] {v.rule_name}")
print(f" File: {rel_path}:{v.line_number}")
@@ -634,40 +691,40 @@ def main():
parser = argparse.ArgumentParser(
description="Validate architecture patterns in codebase",
formatter_class=argparse.RawDescriptionHelpFormatter,
- epilog=__doc__
+ epilog=__doc__,
)
parser.add_argument(
- 'path',
- nargs='?',
+ "path",
+ nargs="?",
type=Path,
default=Path.cwd(),
- help="Path to validate (default: current directory)"
+ help="Path to validate (default: current directory)",
)
parser.add_argument(
- '-c', '--config',
+ "-c",
+ "--config",
type=Path,
- default=Path.cwd() / '.architecture-rules.yaml',
- help="Path to architecture rules config (default: .architecture-rules.yaml)"
+ default=Path.cwd() / ".architecture-rules.yaml",
+ help="Path to architecture rules config (default: .architecture-rules.yaml)",
)
parser.add_argument(
- '-v', '--verbose',
- action='store_true',
- help="Show detailed output including context"
+ "-v",
+ "--verbose",
+ action="store_true",
+ help="Show detailed output including context",
)
parser.add_argument(
- '--errors-only',
- action='store_true',
- help="Only show errors, suppress warnings"
+ "--errors-only", action="store_true", help="Only show errors, suppress warnings"
)
parser.add_argument(
- '--json',
- action='store_true',
- help="Output results as JSON (for programmatic use)"
+ "--json",
+ action="store_true",
+ help="Output results as JSON (for programmatic use)",
)
args = parser.parse_args()
@@ -687,5 +744,5 @@ def main():
sys.exit(exit_code)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/scripts/verify_setup.py b/scripts/verify_setup.py
index f7b5edf6..432550bd 100644
--- a/scripts/verify_setup.py
+++ b/scripts/verify_setup.py
@@ -3,10 +3,12 @@
import os
import sys
+from urllib.parse import urlparse
+
from sqlalchemy import create_engine, text
+
from alembic import command
from alembic.config import Config
-from urllib.parse import urlparse
# Add project root to Python path
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
@@ -20,10 +22,10 @@ def get_database_info():
parsed = urlparse(db_url)
db_type = parsed.scheme
- if db_type == 'sqlite':
+ if db_type == "sqlite":
# Extract path from sqlite:///./path or sqlite:///path
- db_path = db_url.replace('sqlite:///', '')
- if db_path.startswith('./'):
+ db_path = db_url.replace("sqlite:///", "")
+ if db_path.startswith("./"):
db_path = db_path[2:]
return db_type, db_path
else:
@@ -40,7 +42,7 @@ def verify_database_setup():
db_type, db_path = get_database_info()
print(f"[INFO] Database type: {db_type}")
- if db_type == 'sqlite':
+ if db_type == "sqlite":
if not os.path.exists(db_path):
print(f"[ERROR] Database file not found: {db_path}")
return False
@@ -54,10 +56,14 @@ def verify_database_setup():
with engine.connect() as conn:
# Get table list (works for both SQLite and PostgreSQL)
- if db_type == 'sqlite':
- result = conn.execute(text("SELECT name FROM sqlite_master WHERE type='table'"))
+ if db_type == "sqlite":
+ result = conn.execute(
+ text("SELECT name FROM sqlite_master WHERE type='table'")
+ )
else:
- result = conn.execute(text("SELECT tablename FROM pg_tables WHERE schemaname='public'"))
+ result = conn.execute(
+ text("SELECT tablename FROM pg_tables WHERE schemaname='public'")
+ )
tables = [row[0] for row in result.fetchall()]
@@ -65,12 +71,17 @@ def verify_database_setup():
# Expected tables from your models
expected_tables = [
- 'users', 'products', 'inventory', 'vendors', 'products',
- 'marketplace_import_jobs', 'alembic_version'
+ "users",
+ "products",
+ "inventory",
+ "vendors",
+ "products",
+ "marketplace_import_jobs",
+ "alembic_version",
]
for table in sorted(tables):
- if table == 'alembic_version':
+ if table == "alembic_version":
print(f" * {table} (migration tracking)")
elif table in expected_tables:
print(f" * {table}")
@@ -78,12 +89,12 @@ def verify_database_setup():
print(f" ? {table} (unexpected)")
# Check for missing expected tables
- missing_tables = set(expected_tables) - set(tables) - {'alembic_version'}
+ missing_tables = set(expected_tables) - set(tables) - {"alembic_version"}
if missing_tables:
print(f"[WARNING] Missing expected tables: {missing_tables}")
# Check if alembic_version table exists
- if 'alembic_version' in tables:
+ if "alembic_version" in tables:
result = conn.execute(text("SELECT version_num FROM alembic_version"))
version = result.fetchone()
if version:
@@ -126,16 +137,19 @@ def verify_model_structure():
# Test database models
try:
from models.database.base import Base
+
print(f"[OK] Database Base imported")
- print(f"[INFO] Found {len(Base.metadata.tables)} database tables: {list(Base.metadata.tables.keys())}")
+ print(
+ f"[INFO] Found {len(Base.metadata.tables)} database tables: {list(Base.metadata.tables.keys())}"
+ )
# Import specific models
- from models.database.user import User
- from models.database.marketplace_product import MarketplaceProduct
from models.database.inventory import Inventory
- from models.database.vendor import Vendor
- from models.database.product import Product
from models.database.marketplace_import_job import MarketplaceImportJob
+ from models.database.marketplace_product import MarketplaceProduct
+ from models.database.product import Product
+ from models.database.user import User
+ from models.database.vendor import Vendor
print("[OK] All database models imported successfully")
@@ -146,13 +160,23 @@ def verify_model_structure():
# Test API models
try:
import models.schema
+
print("[OK] API models package imported")
# Test specific API model imports
- api_modules = ['base', 'auth', 'product', 'inventory', 'vendor ', 'marketplace', 'admin', 'stats']
+ api_modules = [
+ "base",
+ "auth",
+ "product",
+ "inventory",
+ "vendor ",
+ "marketplace",
+ "admin",
+ "stats",
+ ]
for module in api_modules:
try:
- __import__(f'models.api.{module}')
+ __import__(f"models.api.{module}")
print(f" * models.api.{module}")
except ImportError:
print(f" ? models.api.{module} (not found, optional)")
@@ -176,7 +200,7 @@ def check_project_structure():
"models/database/inventory.py",
"app/core/config.py",
"alembic/env.py",
- "alembic.ini"
+ "alembic.ini",
]
for path in critical_paths:
@@ -189,7 +213,7 @@ def check_project_structure():
init_files = [
"models/__init__.py",
"models/database/__init__.py",
- "models/api/__init__.py"
+ "models/api/__init__.py",
]
print(f"\n[INIT] Checking __init__.py files...")
@@ -221,7 +245,9 @@ if __name__ == "__main__":
print("Next steps:")
print(" 1. Run 'make dev' to start your FastAPI server")
print(" 2. Visit http://localhost:8000/docs for interactive API docs")
- print(" 3. Use your API endpoints for authentication, products, inventory, etc.")
+ print(
+ " 3. Use your API endpoints for authentication, products, inventory, etc."
+ )
sys.exit(0)
else:
print()
diff --git a/storage/backends.py b/storage/backends.py
index 4eb1b1ea..999a775e 100644
--- a/storage/backends.py
+++ b/storage/backends.py
@@ -1 +1 @@
-# Storage backend implementations
+# Storage backend implementations
diff --git a/storage/utils.py b/storage/utils.py
index 99c27a1d..7d8066b1 100644
--- a/storage/utils.py
+++ b/storage/utils.py
@@ -1 +1 @@
-# Storage utilities
+# Storage utilities
diff --git a/tasks/analytics_tasks.py b/tasks/analytics_tasks.py
index 6adffe9b..7882ffad 100644
--- a/tasks/analytics_tasks.py
+++ b/tasks/analytics_tasks.py
@@ -1 +1 @@
-# Analytics and reporting tasks
+# Analytics and reporting tasks
diff --git a/tasks/backup_tasks.py b/tasks/backup_tasks.py
index ce44c3ec..13c5b701 100644
--- a/tasks/backup_tasks.py
+++ b/tasks/backup_tasks.py
@@ -1 +1 @@
-# Backup and recovery tasks
+# Backup and recovery tasks
diff --git a/tasks/cleanup_tasks.py b/tasks/cleanup_tasks.py
index 822a2e92..da7c9643 100644
--- a/tasks/cleanup_tasks.py
+++ b/tasks/cleanup_tasks.py
@@ -1 +1 @@
-# Data cleanup and maintenance tasks
+# Data cleanup and maintenance tasks
diff --git a/tasks/email_tasks.py b/tasks/email_tasks.py
index 81bf1fd3..3576fcb6 100644
--- a/tasks/email_tasks.py
+++ b/tasks/email_tasks.py
@@ -1 +1 @@
-# Email sending tasks
+# Email sending tasks
diff --git a/tasks/marketplace_import.py b/tasks/marketplace_import.py
index 6c318097..4af01878 100644
--- a/tasks/marketplace_import.py
+++ b/tasks/marketplace_import.py
@@ -1,12 +1,12 @@
-# Marketplace CSV import tasks
+# Marketplace CSV import tasks
# app/tasks/background_tasks.py
import logging
from datetime import datetime, timezone
from app.core.database import SessionLocal
+from app.utils.csv_processor import CSVProcessor
from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.vendor import Vendor
-from app.utils.csv_processor import CSVProcessor
logger = logging.getLogger(__name__)
@@ -16,7 +16,7 @@ async def process_marketplace_import(
url: str,
marketplace: str,
vendor_id: int, # FIXED: Changed from vendor_name to vendor_id
- batch_size: int = 1000
+ batch_size: int = 1000,
):
"""Background task to process marketplace CSV import."""
db = SessionLocal()
@@ -60,7 +60,7 @@ async def process_marketplace_import(
marketplace,
vendor_id, # FIXED: Pass vendor_id instead of vendor_name
batch_size,
- db
+ db,
)
# Update job with results
diff --git a/tasks/media_processing.py b/tasks/media_processing.py
index 5a57bed2..f2e05cc6 100644
--- a/tasks/media_processing.py
+++ b/tasks/media_processing.py
@@ -1 +1 @@
-# Image processing and optimization tasks
+# Image processing and optimization tasks
diff --git a/tasks/search_indexing.py b/tasks/search_indexing.py
index 64f776af..c73a2b86 100644
--- a/tasks/search_indexing.py
+++ b/tasks/search_indexing.py
@@ -1 +1 @@
-# Search index maintenance tasks
+# Search index maintenance tasks
diff --git a/tasks/task_manager.py b/tasks/task_manager.py
index d7fe33a0..655b273c 100644
--- a/tasks/task_manager.py
+++ b/tasks/task_manager.py
@@ -1 +1 @@
-# Celery configuration and task management
+# Celery configuration and task management
diff --git a/tests/conftest.py b/tests/conftest.py
index 0a800a28..6f922731 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,13 +7,13 @@ from sqlalchemy.pool import StaticPool
from app.core.database import Base, get_db
from main import app
+from models.database.inventory import Inventory
# Import all models to ensure they're registered with Base metadata
from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.marketplace_product import MarketplaceProduct
-from models.database.vendor import Vendor
from models.database.product import Product
-from models.database.inventory import Inventory
from models.database.user import User
+from models.database.vendor import Vendor
# Use in-memory SQLite database for tests
SQLALCHEMY_TEST_DATABASE_URL = "sqlite:///:memory:"
diff --git a/tests/fixtures/auth_fixtures.py b/tests/fixtures/auth_fixtures.py
index fe712109..c2618c66 100644
--- a/tests/fixtures/auth_fixtures.py
+++ b/tests/fixtures/auth_fixtures.py
@@ -53,6 +53,7 @@ def test_admin(db, auth_manager):
db.expunge(admin)
return admin
+
@pytest.fixture
def another_admin(db, auth_manager):
"""Create another test admin user for testing admin-to-admin interactions"""
@@ -72,6 +73,7 @@ def another_admin(db, auth_manager):
db.expunge(admin)
return admin
+
@pytest.fixture
def other_user(db, auth_manager):
"""Create a different user for testing access controls"""
diff --git a/tests/fixtures/customer_fixtures.py b/tests/fixtures/customer_fixtures.py
index b9235a88..5666834b 100644
--- a/tests/fixtures/customer_fixtures.py
+++ b/tests/fixtures/customer_fixtures.py
@@ -1,5 +1,6 @@
# tests/fixtures/customer_fixtures.py
import pytest
+
from models.database.customer import Customer, CustomerAddress
diff --git a/tests/fixtures/marketplace_import_job_fixtures.py b/tests/fixtures/marketplace_import_job_fixtures.py
index d61e2441..c79decd7 100644
--- a/tests/fixtures/marketplace_import_job_fixtures.py
+++ b/tests/fixtures/marketplace_import_job_fixtures.py
@@ -1,5 +1,6 @@
# tests/fixtures/marketplace_import_job_fixtures.py
import pytest
+
from models.database.marketplace_import_job import MarketplaceImportJob
diff --git a/tests/fixtures/marketplace_product_fixtures.py b/tests/fixtures/marketplace_product_fixtures.py
index cb2b148d..a27f9b87 100644
--- a/tests/fixtures/marketplace_product_fixtures.py
+++ b/tests/fixtures/marketplace_product_fixtures.py
@@ -117,7 +117,9 @@ def marketplace_product_factory():
@pytest.fixture
-def test_marketplace_product_with_inventory(db, test_marketplace_product, test_inventory):
+def test_marketplace_product_with_inventory(
+ db, test_marketplace_product, test_inventory
+):
"""MarketplaceProduct with associated inventory record."""
# Ensure they're linked by GTIN
if test_marketplace_product.gtin != test_inventory.gtin:
@@ -126,6 +128,6 @@ def test_marketplace_product_with_inventory(db, test_marketplace_product, test_i
db.refresh(test_inventory)
return {
- 'marketplace_product': test_marketplace_product,
- 'inventory': test_inventory
+ "marketplace_product": test_marketplace_product,
+ "inventory": test_inventory,
}
diff --git a/tests/fixtures/testing_fixtures.py b/tests/fixtures/testing_fixtures.py
index ffbcb335..5e3f0cc0 100644
--- a/tests/fixtures/testing_fixtures.py
+++ b/tests/fixtures/testing_fixtures.py
@@ -8,8 +8,9 @@ This module provides fixtures for:
- Additional testing utilities
"""
-import pytest
from unittest.mock import Mock
+
+import pytest
from sqlalchemy.exc import SQLAlchemyError
@@ -26,7 +27,7 @@ def empty_db(db):
"inventory", # Fixed: singular not plural
"products", # Referenced by products
"vendors", # Has foreign key to users
- "users" # Base table
+ "users", # Base table
]
for table in tables_to_clear:
diff --git a/tests/fixtures/vendor_fixtures.py b/tests/fixtures/vendor_fixtures.py
index 04cb902c..4b2443e4 100644
--- a/tests/fixtures/vendor_fixtures.py
+++ b/tests/fixtures/vendor_fixtures.py
@@ -1,10 +1,11 @@
# tests/fixtures/vendor_fixtures.py
import uuid
+
import pytest
-from models.database.vendor import Vendor
-from models.database.product import Product
from models.database.inventory import Inventory
+from models.database.product import Product
+from models.database.vendor import Vendor
@pytest.fixture
@@ -185,6 +186,7 @@ def multiple_inventory_entries(db, multiple_products, test_vendor):
def create_unique_vendor_factory():
"""Factory function to create unique vendors in tests"""
+
def _create_vendor(db, owner_user_id, **kwargs):
unique_id = str(uuid.uuid4())[:8]
defaults = {
diff --git a/tests/integration/api/v1/test_admin_endpoints.py b/tests/integration/api/v1/test_admin_endpoints.py
index 61b940a1..0ede38e9 100644
--- a/tests/integration/api/v1/test_admin_endpoints.py
+++ b/tests/integration/api/v1/test_admin_endpoints.py
@@ -59,7 +59,9 @@ class TestAdminAPI:
assert response.status_code == 400 # Business logic error
data = response.json()
assert data["error_code"] == "CANNOT_MODIFY_SELF"
- assert "Cannot perform 'deactivate account' on your own account" in data["message"]
+ assert (
+ "Cannot perform 'deactivate account' on your own account" in data["message"]
+ )
def test_toggle_user_status_cannot_modify_admin(
self, client, admin_headers, test_admin, another_admin
@@ -85,7 +87,9 @@ class TestAdminAPI:
# Check that test_vendor is in the response
vendor_codes = [
- vendor ["vendor_code"] for vendor in data["vendors"] if "vendor_code" in vendor
+ vendor["vendor_code"]
+ for vendor in data["vendors"]
+ if "vendor_code" in vendor
]
assert test_vendor.vendor_code in vendor_codes
@@ -98,7 +102,7 @@ class TestAdminAPI:
assert data["error_code"] == "ADMIN_REQUIRED"
def test_verify_vendor_admin(self, client, admin_headers, test_vendor):
- """Test admin verifying/unverifying vendor """
+ """Test admin verifying/unverifying vendor"""
response = client.put(
f"/api/v1/admin/vendors/{test_vendor.id}/verify", headers=admin_headers
)
@@ -109,8 +113,10 @@ class TestAdminAPI:
assert test_vendor.vendor_code in message
def test_verify_vendor_not_found(self, client, admin_headers):
- """Test admin verifying non-existent vendor """
- response = client.put("/api/v1/admin/vendors/99999/verify", headers=admin_headers)
+ """Test admin verifying non-existent vendor"""
+ response = client.put(
+ "/api/v1/admin/vendors/99999/verify", headers=admin_headers
+ )
assert response.status_code == 404
data = response.json()
@@ -129,8 +135,10 @@ class TestAdminAPI:
assert test_vendor.vendor_code in message
def test_toggle_vendor_status_not_found(self, client, admin_headers):
- """Test admin toggling status for non-existent vendor """
- response = client.put("/api/v1/admin/vendors/99999/status", headers=admin_headers)
+ """Test admin toggling status for non-existent vendor"""
+ response = client.put(
+ "/api/v1/admin/vendors/99999/status", headers=admin_headers
+ )
assert response.status_code == 404
data = response.json()
@@ -166,7 +174,8 @@ class TestAdminAPI:
data = response.json()
assert len(data) >= 1
assert all(
- job["marketplace"] == test_marketplace_import_job.marketplace for job in data
+ job["marketplace"] == test_marketplace_import_job.marketplace
+ for job in data
)
def test_get_marketplace_import_jobs_non_admin(self, client, auth_headers):
diff --git a/tests/integration/api/v1/test_auth_endpoints.py b/tests/integration/api/v1/test_auth_endpoints.py
index dd41d40d..83e0ca07 100644
--- a/tests/integration/api/v1/test_auth_endpoints.py
+++ b/tests/integration/api/v1/test_auth_endpoints.py
@@ -1,7 +1,8 @@
# tests/integration/api/v1/test_auth_endpoints.py
+from datetime import datetime, timedelta, timezone
+
import pytest
from jose import jwt
-from datetime import datetime, timedelta, timezone
@pytest.mark.integration
@@ -178,8 +179,7 @@ class TestAuthenticationAPI:
def test_get_current_user_invalid_token(self, client):
"""Test getting current user with invalid token"""
response = client.get(
- "/api/v1/auth/me",
- headers={"Authorization": "Bearer invalid_token_here"}
+ "/api/v1/auth/me", headers={"Authorization": "Bearer invalid_token_here"}
)
assert response.status_code == 401
@@ -201,14 +201,11 @@ class TestAuthenticationAPI:
}
expired_token = jwt.encode(
- expired_payload,
- auth_manager.secret_key,
- algorithm=auth_manager.algorithm
+ expired_payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
)
response = client.get(
- "/api/v1/auth/me",
- headers={"Authorization": f"Bearer {expired_token}"}
+ "/api/v1/auth/me", headers={"Authorization": f"Bearer {expired_token}"}
)
assert response.status_code == 401
@@ -321,4 +318,3 @@ class TestAuthManager:
user = auth_manager.authenticate_user(db, "nonexistent", "password")
assert user is None
-
diff --git a/tests/integration/api/v1/test_filtering.py b/tests/integration/api/v1/test_filtering.py
index b28d3862..8e89fe2e 100644
--- a/tests/integration/api/v1/test_filtering.py
+++ b/tests/integration/api/v1/test_filtering.py
@@ -14,19 +14,34 @@ class TestFiltering:
"""Test filtering products by brand successfully"""
# Create products with different brands using unique IDs
import uuid
+
unique_suffix = str(uuid.uuid4())[:8]
products = [
- MarketplaceProduct(marketplace_product_id=f"BRAND1_{unique_suffix}", title="MarketplaceProduct 1", brand="BrandA"),
- MarketplaceProduct(marketplace_product_id=f"BRAND2_{unique_suffix}", title="MarketplaceProduct 2", brand="BrandB"),
- MarketplaceProduct(marketplace_product_id=f"BRAND3_{unique_suffix}", title="MarketplaceProduct 3", brand="BrandA"),
+ MarketplaceProduct(
+ marketplace_product_id=f"BRAND1_{unique_suffix}",
+ title="MarketplaceProduct 1",
+ brand="BrandA",
+ ),
+ MarketplaceProduct(
+ marketplace_product_id=f"BRAND2_{unique_suffix}",
+ title="MarketplaceProduct 2",
+ brand="BrandB",
+ ),
+ MarketplaceProduct(
+ marketplace_product_id=f"BRAND3_{unique_suffix}",
+ title="MarketplaceProduct 3",
+ brand="BrandA",
+ ),
]
db.add_all(products)
db.commit()
# Filter by BrandA
- response = client.get("/api/v1/marketplace/product?brand=BrandA", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product?brand=BrandA", headers=auth_headers
+ )
assert response.status_code == 200
data = response.json()
assert data["total"] >= 2 # At least our test products
@@ -37,7 +52,9 @@ class TestFiltering:
assert product["brand"] == "BrandA"
# Filter by BrandB
- response = client.get("/api/v1/marketplace/product?brand=BrandB", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product?brand=BrandB", headers=auth_headers
+ )
assert response.status_code == 200
data = response.json()
assert data["total"] >= 1 # At least our test product
@@ -45,37 +62,57 @@ class TestFiltering:
def test_product_marketplace_filter_success(self, client, auth_headers, db):
"""Test filtering products by marketplace successfully"""
import uuid
+
unique_suffix = str(uuid.uuid4())[:8]
products = [
- MarketplaceProduct(marketplace_product_id=f"MKT1_{unique_suffix}", title="MarketplaceProduct 1", marketplace="Amazon"),
- MarketplaceProduct(marketplace_product_id=f"MKT2_{unique_suffix}", title="MarketplaceProduct 2", marketplace="eBay"),
- MarketplaceProduct(marketplace_product_id=f"MKT3_{unique_suffix}", title="MarketplaceProduct 3", marketplace="Amazon"),
+ MarketplaceProduct(
+ marketplace_product_id=f"MKT1_{unique_suffix}",
+ title="MarketplaceProduct 1",
+ marketplace="Amazon",
+ ),
+ MarketplaceProduct(
+ marketplace_product_id=f"MKT2_{unique_suffix}",
+ title="MarketplaceProduct 2",
+ marketplace="eBay",
+ ),
+ MarketplaceProduct(
+ marketplace_product_id=f"MKT3_{unique_suffix}",
+ title="MarketplaceProduct 3",
+ marketplace="Amazon",
+ ),
]
db.add_all(products)
db.commit()
- response = client.get("/api/v1/marketplace/product?marketplace=Amazon", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product?marketplace=Amazon", headers=auth_headers
+ )
assert response.status_code == 200
data = response.json()
assert data["total"] >= 2 # At least our test products
# Verify all returned products have Amazon marketplace
- amazon_products = [p for p in data["products"] if p["marketplace_product_id"].endswith(unique_suffix)]
+ amazon_products = [
+ p
+ for p in data["products"]
+ if p["marketplace_product_id"].endswith(unique_suffix)
+ ]
for product in amazon_products:
assert product["marketplace"] == "Amazon"
def test_product_search_filter_success(self, client, auth_headers, db):
"""Test searching products by text successfully"""
import uuid
+
unique_suffix = str(uuid.uuid4())[:8]
products = [
MarketplaceProduct(
marketplace_product_id=f"SEARCH1_{unique_suffix}",
title=f"Apple iPhone {unique_suffix}",
- description="Smartphone"
+ description="Smartphone",
),
MarketplaceProduct(
marketplace_product_id=f"SEARCH2_{unique_suffix}",
@@ -85,7 +122,7 @@ class TestFiltering:
MarketplaceProduct(
marketplace_product_id=f"SEARCH3_{unique_suffix}",
title=f"iPad Tablet {unique_suffix}",
- description="Apple tablet"
+ description="Apple tablet",
),
]
@@ -93,13 +130,17 @@ class TestFiltering:
db.commit()
# Search for "Apple"
- response = client.get(f"/api/v1/marketplace/product?search=Apple", headers=auth_headers)
+ response = client.get(
+ f"/api/v1/marketplace/product?search=Apple", headers=auth_headers
+ )
assert response.status_code == 200
data = response.json()
assert data["total"] >= 2 # iPhone and iPad
# Search for "phone"
- response = client.get(f"/api/v1/marketplace/product?search=phone", headers=auth_headers)
+ response = client.get(
+ f"/api/v1/marketplace/product?search=phone", headers=auth_headers
+ )
assert response.status_code == 200
data = response.json()
assert data["total"] >= 2 # iPhone and Galaxy
@@ -107,6 +148,7 @@ class TestFiltering:
def test_combined_filters_success(self, client, auth_headers, db):
"""Test combining multiple filters successfully"""
import uuid
+
unique_suffix = str(uuid.uuid4())[:8]
products = [
@@ -135,14 +177,19 @@ class TestFiltering:
# Filter by brand AND marketplace
response = client.get(
- "/api/v1/marketplace/product?brand=Apple&marketplace=Amazon", headers=auth_headers
+ "/api/v1/marketplace/product?brand=Apple&marketplace=Amazon",
+ headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data["total"] >= 1 # At least iPhone matches both
# Find our specific test product
- matching_products = [p for p in data["products"] if p["marketplace_product_id"].endswith(unique_suffix)]
+ matching_products = [
+ p
+ for p in data["products"]
+ if p["marketplace_product_id"].endswith(unique_suffix)
+ ]
for product in matching_products:
assert product["brand"] == "Apple"
assert product["marketplace"] == "Amazon"
@@ -150,7 +197,8 @@ class TestFiltering:
def test_filter_with_no_results(self, client, auth_headers):
"""Test filtering with criteria that returns no results"""
response = client.get(
- "/api/v1/marketplace/product?brand=NonexistentBrand123456", headers=auth_headers
+ "/api/v1/marketplace/product?brand=NonexistentBrand123456",
+ headers=auth_headers,
)
assert response.status_code == 200
@@ -161,6 +209,7 @@ class TestFiltering:
def test_filter_case_insensitive(self, client, auth_headers, db):
"""Test that filters are case-insensitive"""
import uuid
+
unique_suffix = str(uuid.uuid4())[:8]
product = MarketplaceProduct(
@@ -174,7 +223,10 @@ class TestFiltering:
# Test different case variations
for brand_filter in ["TestBrand", "testbrand", "TESTBRAND"]:
- response = client.get(f"/api/v1/marketplace/product?brand={brand_filter}", headers=auth_headers)
+ response = client.get(
+ f"/api/v1/marketplace/product?brand={brand_filter}",
+ headers=auth_headers,
+ )
assert response.status_code == 200
data = response.json()
assert data["total"] >= 1
@@ -183,9 +235,14 @@ class TestFiltering:
"""Test behavior with invalid filter parameters"""
# Test with very long filter values
long_brand = "A" * 1000
- response = client.get(f"/api/v1/marketplace/product?brand={long_brand}", headers=auth_headers)
+ response = client.get(
+ f"/api/v1/marketplace/product?brand={long_brand}", headers=auth_headers
+ )
assert response.status_code == 200 # Should handle gracefully
# Test with special characters
- response = client.get("/api/v1/marketplace/product?brand=", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product?brand=",
+ headers=auth_headers,
+ )
assert response.status_code == 200 # Should handle gracefully
diff --git a/tests/integration/api/v1/test_inventory_endpoints.py b/tests/integration/api/v1/test_inventory_endpoints.py
index bf121bba..1e00dc34 100644
--- a/tests/integration/api/v1/test_inventory_endpoints.py
+++ b/tests/integration/api/v1/test_inventory_endpoints.py
@@ -17,7 +17,9 @@ class TestInventoryAPI:
"quantity": 100,
}
- response = client.post("/api/v1/inventory", headers=auth_headers, json=inventory_data)
+ response = client.post(
+ "/api/v1/inventory", headers=auth_headers, json=inventory_data
+ )
assert response.status_code == 200
data = response.json()
@@ -38,7 +40,9 @@ class TestInventoryAPI:
"quantity": 75,
}
- response = client.post("/api/v1/inventory", headers=auth_headers, json=inventory_data)
+ response = client.post(
+ "/api/v1/inventory", headers=auth_headers, json=inventory_data
+ )
assert response.status_code == 200
data = response.json()
@@ -52,7 +56,9 @@ class TestInventoryAPI:
"quantity": 100,
}
- response = client.post("/api/v1/inventory", headers=auth_headers, json=inventory_data)
+ response = client.post(
+ "/api/v1/inventory", headers=auth_headers, json=inventory_data
+ )
assert response.status_code == 422
data = response.json()
@@ -60,7 +66,9 @@ class TestInventoryAPI:
assert data["status_code"] == 422
assert "GTIN is required" in data["message"]
- def test_set_inventory_invalid_quantity_validation_error(self, client, auth_headers):
+ def test_set_inventory_invalid_quantity_validation_error(
+ self, client, auth_headers
+ ):
"""Test setting inventory with invalid quantity returns InvalidQuantityException"""
inventory_data = {
"gtin": "1234567890123",
@@ -68,7 +76,9 @@ class TestInventoryAPI:
"quantity": -10, # Negative quantity
}
- response = client.post("/api/v1/inventory", headers=auth_headers, json=inventory_data)
+ response = client.post(
+ "/api/v1/inventory", headers=auth_headers, json=inventory_data
+ )
assert response.status_code in [400, 422]
data = response.json()
@@ -138,7 +148,9 @@ class TestInventoryAPI:
data = response.json()
assert data["quantity"] == 35 # 50 - 15
- def test_remove_inventory_insufficient_returns_business_logic_error(self, client, auth_headers, db):
+ def test_remove_inventory_insufficient_returns_business_logic_error(
+ self, client, auth_headers, db
+ ):
"""Test removing more inventory than available returns InsufficientInventoryException"""
# Create initial inventory
inventory = Inventory(gtin="1234567890123", location="WAREHOUSE_A", quantity=10)
@@ -184,7 +196,9 @@ class TestInventoryAPI:
assert data["error_code"] == "INVENTORY_NOT_FOUND"
assert data["status_code"] == 404
- def test_negative_inventory_not_allowed_business_logic_error(self, client, auth_headers, db):
+ def test_negative_inventory_not_allowed_business_logic_error(
+ self, client, auth_headers, db
+ ):
"""Test operations resulting in negative inventory returns NegativeInventoryException"""
# Create initial inventory
inventory = Inventory(gtin="1234567890123", location="WAREHOUSE_A", quantity=5)
@@ -204,14 +218,21 @@ class TestInventoryAPI:
assert response.status_code == 400
data = response.json()
# This might be caught as INSUFFICIENT_INVENTORY or NEGATIVE_INVENTORY_NOT_ALLOWED
- assert data["error_code"] in ["INSUFFICIENT_INVENTORY", "NEGATIVE_INVENTORY_NOT_ALLOWED"]
+ assert data["error_code"] in [
+ "INSUFFICIENT_INVENTORY",
+ "NEGATIVE_INVENTORY_NOT_ALLOWED",
+ ]
assert data["status_code"] == 400
def test_get_inventory_by_gtin_success(self, client, auth_headers, db):
"""Test getting inventory summary for GTIN successfully"""
# Create inventory in multiple locations
- inventory1 = Inventory(gtin="1234567890123", location="WAREHOUSE_A", quantity=50)
- inventory2 = Inventory(gtin="1234567890123", location="WAREHOUSE_B", quantity=25)
+ inventory1 = Inventory(
+ gtin="1234567890123", location="WAREHOUSE_A", quantity=50
+ )
+ inventory2 = Inventory(
+ gtin="1234567890123", location="WAREHOUSE_B", quantity=25
+ )
db.add_all([inventory1, inventory2])
db.commit()
@@ -238,12 +259,18 @@ class TestInventoryAPI:
def test_get_total_inventory_success(self, client, auth_headers, db):
"""Test getting total inventory for GTIN successfully"""
# Create inventory in multiple locations
- inventory1 = Inventory(gtin="1234567890123", location="WAREHOUSE_A", quantity=50)
- inventory2 = Inventory(gtin="1234567890123", location="WAREHOUSE_B", quantity=25)
+ inventory1 = Inventory(
+ gtin="1234567890123", location="WAREHOUSE_A", quantity=50
+ )
+ inventory2 = Inventory(
+ gtin="1234567890123", location="WAREHOUSE_B", quantity=25
+ )
db.add_all([inventory1, inventory2])
db.commit()
- response = client.get("/api/v1/inventory/1234567890123/total", headers=auth_headers)
+ response = client.get(
+ "/api/v1/inventory/1234567890123/total", headers=auth_headers
+ )
assert response.status_code == 200
data = response.json()
@@ -253,7 +280,9 @@ class TestInventoryAPI:
def test_get_total_inventory_not_found(self, client, auth_headers):
"""Test getting total inventory for nonexistent GTIN returns InventoryNotFoundException"""
- response = client.get("/api/v1/inventory/9999999999999/total", headers=auth_headers)
+ response = client.get(
+ "/api/v1/inventory/9999999999999/total", headers=auth_headers
+ )
assert response.status_code == 404
data = response.json()
@@ -263,8 +292,12 @@ class TestInventoryAPI:
def test_get_all_inventory_success(self, client, auth_headers, db):
"""Test getting all inventory entries successfully"""
# Create some inventory entries
- inventory1 = Inventory(gtin="1234567890123", location="WAREHOUSE_A", quantity=50)
- inventory2 = Inventory(gtin="9876543210987", location="WAREHOUSE_B", quantity=25)
+ inventory1 = Inventory(
+ gtin="1234567890123", location="WAREHOUSE_A", quantity=50
+ )
+ inventory2 = Inventory(
+ gtin="9876543210987", location="WAREHOUSE_B", quantity=25
+ )
db.add_all([inventory1, inventory2])
db.commit()
@@ -277,20 +310,28 @@ class TestInventoryAPI:
def test_get_all_inventory_with_filters(self, client, auth_headers, db):
"""Test getting inventory entries with filtering"""
# Create inventory entries
- inventory1 = Inventory(gtin="1234567890123", location="WAREHOUSE_A", quantity=50)
- inventory2 = Inventory(gtin="9876543210987", location="WAREHOUSE_B", quantity=25)
+ inventory1 = Inventory(
+ gtin="1234567890123", location="WAREHOUSE_A", quantity=50
+ )
+ inventory2 = Inventory(
+ gtin="9876543210987", location="WAREHOUSE_B", quantity=25
+ )
db.add_all([inventory1, inventory2])
db.commit()
# Filter by location
- response = client.get("/api/v1/inventory?location=WAREHOUSE_A", headers=auth_headers)
+ response = client.get(
+ "/api/v1/inventory?location=WAREHOUSE_A", headers=auth_headers
+ )
assert response.status_code == 200
data = response.json()
for inventory in data:
assert inventory["location"] == "WAREHOUSE_A"
# Filter by GTIN
- response = client.get("/api/v1/inventory?gtin=1234567890123", headers=auth_headers)
+ response = client.get(
+ "/api/v1/inventory?gtin=1234567890123", headers=auth_headers
+ )
assert response.status_code == 200
data = response.json()
for inventory in data:
@@ -390,7 +431,9 @@ class TestInventoryAPI:
"quantity": 100,
}
- response = client.post("/api/v1/inventory", headers=auth_headers, json=inventory_data)
+ response = client.post(
+ "/api/v1/inventory", headers=auth_headers, json=inventory_data
+ )
# This depends on whether your service validates locations
if response.status_code == 404:
diff --git a/tests/integration/api/v1/test_marketplace_import_job_endpoints.py b/tests/integration/api/v1/test_marketplace_import_job_endpoints.py
index 9e1988cc..37b617f1 100644
--- a/tests/integration/api/v1/test_marketplace_import_job_endpoints.py
+++ b/tests/integration/api/v1/test_marketplace_import_job_endpoints.py
@@ -8,9 +8,11 @@ import pytest
@pytest.mark.api
@pytest.mark.marketplace
class TestMarketplaceImportJobAPI:
- def test_import_from_marketplace(self, client, auth_headers, test_vendor, test_user):
+ def test_import_from_marketplace(
+ self, client, auth_headers, test_vendor, test_user
+ ):
"""Test marketplace import endpoint - just test job creation"""
- # Ensure user owns the vendor
+ # Ensure user owns the vendor
test_vendor.owner_user_id = test_user.id
import_data = {
@@ -32,7 +34,7 @@ class TestMarketplaceImportJobAPI:
assert data["vendor_id"] == test_vendor.id
def test_import_from_marketplace_invalid_vendor(self, client, auth_headers):
- """Test marketplace import with invalid vendor """
+ """Test marketplace import with invalid vendor"""
import_data = {
"url": "https://example.com/products.csv",
"marketplace": "TestMarket",
@@ -48,7 +50,9 @@ class TestMarketplaceImportJobAPI:
assert data["error_code"] == "VENDOR_NOT_FOUND"
assert "NONEXISTENT" in data["message"]
- def test_import_from_marketplace_unauthorized_vendor(self, client, auth_headers, test_vendor, other_user):
+ def test_import_from_marketplace_unauthorized_vendor(
+ self, client, auth_headers, test_vendor, other_user
+ ):
"""Test marketplace import with unauthorized vendor access"""
# Set vendor owner to different user
test_vendor.owner_user_id = other_user.id
@@ -85,8 +89,10 @@ class TestMarketplaceImportJobAPI:
assert data["error_code"] == "VALIDATION_ERROR"
assert "Request validation failed" in data["message"]
- def test_import_from_marketplace_admin_access(self, client, admin_headers, test_vendor):
- """Test that admin can import for any vendor """
+ def test_import_from_marketplace_admin_access(
+ self, client, admin_headers, test_vendor
+ ):
+ """Test that admin can import for any vendor"""
import_data = {
"url": "https://example.com/products.csv",
"marketplace": "AdminMarket",
@@ -94,7 +100,9 @@ class TestMarketplaceImportJobAPI:
}
response = client.post(
- "/api/v1/marketplace/import-product", headers=admin_headers, json=import_data
+ "/api/v1/marketplace/import-product",
+ headers=admin_headers,
+ json=import_data,
)
assert response.status_code == 200
@@ -102,11 +110,13 @@ class TestMarketplaceImportJobAPI:
assert data["marketplace"] == "AdminMarket"
assert data["vendor_code"] == test_vendor.vendor_code
- def test_get_marketplace_import_status(self, client, auth_headers, test_marketplace_import_job):
+ def test_get_marketplace_import_status(
+ self, client, auth_headers, test_marketplace_import_job
+ ):
"""Test getting marketplace import status"""
response = client.get(
f"/api/v1/marketplace/import-status/{test_marketplace_import_job.id}",
- headers=auth_headers
+ headers=auth_headers,
)
assert response.status_code == 200
@@ -118,8 +128,7 @@ class TestMarketplaceImportJobAPI:
def test_get_marketplace_import_status_not_found(self, client, auth_headers):
"""Test getting status of non-existent import job"""
response = client.get(
- "/api/v1/marketplace/import-status/99999",
- headers=auth_headers
+ "/api/v1/marketplace/import-status/99999", headers=auth_headers
)
assert response.status_code == 404
@@ -127,21 +136,25 @@ class TestMarketplaceImportJobAPI:
assert data["error_code"] == "IMPORT_JOB_NOT_FOUND"
assert "99999" in data["message"]
- def test_get_marketplace_import_status_unauthorized(self, client, auth_headers, test_marketplace_import_job, other_user):
+ def test_get_marketplace_import_status_unauthorized(
+ self, client, auth_headers, test_marketplace_import_job, other_user
+ ):
"""Test getting status of unauthorized import job"""
# Change job owner to other user
test_marketplace_import_job.user_id = other_user.id
response = client.get(
f"/api/v1/marketplace/import-status/{test_marketplace_import_job.id}",
- headers=auth_headers
+ headers=auth_headers,
)
assert response.status_code == 403
data = response.json()
assert data["error_code"] == "IMPORT_JOB_NOT_OWNED"
- def test_get_marketplace_import_jobs(self, client, auth_headers, test_marketplace_import_job):
+ def test_get_marketplace_import_jobs(
+ self, client, auth_headers, test_marketplace_import_job
+ ):
"""Test getting marketplace import jobs"""
response = client.get("/api/v1/marketplace/import-jobs", headers=auth_headers)
@@ -154,11 +167,13 @@ class TestMarketplaceImportJobAPI:
job_ids = [job["job_id"] for job in data]
assert test_marketplace_import_job.id in job_ids
- def test_get_marketplace_import_jobs_with_filters(self, client, auth_headers, test_marketplace_import_job):
+ def test_get_marketplace_import_jobs_with_filters(
+ self, client, auth_headers, test_marketplace_import_job
+ ):
"""Test getting import jobs with filters"""
response = client.get(
f"/api/v1/marketplace/import-jobs?marketplace={test_marketplace_import_job.marketplace}",
- headers=auth_headers
+ headers=auth_headers,
)
assert response.status_code == 200
@@ -167,13 +182,15 @@ class TestMarketplaceImportJobAPI:
assert len(data) >= 1
for job in data:
- assert test_marketplace_import_job.marketplace.lower() in job["marketplace"].lower()
+ assert (
+ test_marketplace_import_job.marketplace.lower()
+ in job["marketplace"].lower()
+ )
def test_get_marketplace_import_jobs_pagination(self, client, auth_headers):
"""Test import jobs pagination"""
response = client.get(
- "/api/v1/marketplace/import-jobs?skip=0&limit=5",
- headers=auth_headers
+ "/api/v1/marketplace/import-jobs?skip=0&limit=5", headers=auth_headers
)
assert response.status_code == 200
@@ -181,9 +198,13 @@ class TestMarketplaceImportJobAPI:
assert isinstance(data, list)
assert len(data) <= 5
- def test_get_marketplace_import_stats(self, client, auth_headers, test_marketplace_import_job):
+ def test_get_marketplace_import_stats(
+ self, client, auth_headers, test_marketplace_import_job
+ ):
"""Test getting marketplace import statistics"""
- response = client.get("/api/v1/marketplace/marketplace-import-stats", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/marketplace-import-stats", headers=auth_headers
+ )
assert response.status_code == 200
data = response.json()
@@ -195,12 +216,15 @@ class TestMarketplaceImportJobAPI:
assert isinstance(data["total_jobs"], int)
assert data["total_jobs"] >= 1
- def test_cancel_marketplace_import_job(self, client, auth_headers, test_user, test_vendor, db):
+ def test_cancel_marketplace_import_job(
+ self, client, auth_headers, test_user, test_vendor, db
+ ):
"""Test cancelling a marketplace import job"""
# Create a pending job that can be cancelled
- from models.database.marketplace_import_job import MarketplaceImportJob
import uuid
+ from models.database.marketplace_import_job import MarketplaceImportJob
+
unique_id = str(uuid.uuid4())[:8]
job = MarketplaceImportJob(
status="pending",
@@ -219,8 +243,7 @@ class TestMarketplaceImportJobAPI:
db.refresh(job)
response = client.put(
- f"/api/v1/marketplace/import-jobs/{job.id}/cancel",
- headers=auth_headers
+ f"/api/v1/marketplace/import-jobs/{job.id}/cancel", headers=auth_headers
)
assert response.status_code == 200
@@ -232,15 +255,16 @@ class TestMarketplaceImportJobAPI:
def test_cancel_marketplace_import_job_not_found(self, client, auth_headers):
"""Test cancelling non-existent import job"""
response = client.put(
- "/api/v1/marketplace/import-jobs/99999/cancel",
- headers=auth_headers
+ "/api/v1/marketplace/import-jobs/99999/cancel", headers=auth_headers
)
assert response.status_code == 404
data = response.json()
assert data["error_code"] == "IMPORT_JOB_NOT_FOUND"
- def test_cancel_marketplace_import_job_cannot_cancel(self, client, auth_headers, test_marketplace_import_job, db):
+ def test_cancel_marketplace_import_job_cannot_cancel(
+ self, client, auth_headers, test_marketplace_import_job, db
+ ):
"""Test cancelling a job that cannot be cancelled"""
# Set job to completed status
test_marketplace_import_job.status = "completed"
@@ -248,7 +272,7 @@ class TestMarketplaceImportJobAPI:
response = client.put(
f"/api/v1/marketplace/import-jobs/{test_marketplace_import_job.id}/cancel",
- headers=auth_headers
+ headers=auth_headers,
)
assert response.status_code == 400
@@ -256,12 +280,15 @@ class TestMarketplaceImportJobAPI:
assert data["error_code"] == "IMPORT_JOB_CANNOT_BE_CANCELLED"
assert "completed" in data["message"]
- def test_delete_marketplace_import_job(self, client, auth_headers, test_user, test_vendor, db):
+ def test_delete_marketplace_import_job(
+ self, client, auth_headers, test_user, test_vendor, db
+ ):
"""Test deleting a marketplace import job"""
# Create a completed job that can be deleted
- from models.database.marketplace_import_job import MarketplaceImportJob
import uuid
+ from models.database.marketplace_import_job import MarketplaceImportJob
+
unique_id = str(uuid.uuid4())[:8]
job = MarketplaceImportJob(
status="completed",
@@ -280,8 +307,7 @@ class TestMarketplaceImportJobAPI:
db.refresh(job)
response = client.delete(
- f"/api/v1/marketplace/import-jobs/{job.id}",
- headers=auth_headers
+ f"/api/v1/marketplace/import-jobs/{job.id}", headers=auth_headers
)
assert response.status_code == 200
@@ -291,20 +317,22 @@ class TestMarketplaceImportJobAPI:
def test_delete_marketplace_import_job_not_found(self, client, auth_headers):
"""Test deleting non-existent import job"""
response = client.delete(
- "/api/v1/marketplace/import-jobs/99999",
- headers=auth_headers
+ "/api/v1/marketplace/import-jobs/99999", headers=auth_headers
)
assert response.status_code == 404
data = response.json()
assert data["error_code"] == "IMPORT_JOB_NOT_FOUND"
- def test_delete_marketplace_import_job_cannot_delete(self, client, auth_headers, test_user, test_vendor, db):
+ def test_delete_marketplace_import_job_cannot_delete(
+ self, client, auth_headers, test_user, test_vendor, db
+ ):
"""Test deleting a job that cannot be deleted"""
# Create a pending job that cannot be deleted
- from models.database.marketplace_import_job import MarketplaceImportJob
import uuid
+ from models.database.marketplace_import_job import MarketplaceImportJob
+
unique_id = str(uuid.uuid4())[:8]
job = MarketplaceImportJob(
status="pending",
@@ -323,8 +351,7 @@ class TestMarketplaceImportJobAPI:
db.refresh(job)
response = client.delete(
- f"/api/v1/marketplace/import-jobs/{job.id}",
- headers=auth_headers
+ f"/api/v1/marketplace/import-jobs/{job.id}", headers=auth_headers
)
assert response.status_code == 400
@@ -352,7 +379,9 @@ class TestMarketplaceImportJobAPI:
data = response.json()
assert data["error_code"] == "INVALID_TOKEN"
- def test_admin_can_access_all_jobs(self, client, admin_headers, test_marketplace_import_job):
+ def test_admin_can_access_all_jobs(
+ self, client, admin_headers, test_marketplace_import_job
+ ):
"""Test that admin can access all import jobs"""
response = client.get("/api/v1/marketplace/import-jobs", headers=admin_headers)
@@ -363,23 +392,28 @@ class TestMarketplaceImportJobAPI:
job_ids = [job["job_id"] for job in data]
assert test_marketplace_import_job.id in job_ids
- def test_admin_can_view_any_job_status(self, client, admin_headers, test_marketplace_import_job):
+ def test_admin_can_view_any_job_status(
+ self, client, admin_headers, test_marketplace_import_job
+ ):
"""Test that admin can view any job status"""
response = client.get(
f"/api/v1/marketplace/import-status/{test_marketplace_import_job.id}",
- headers=admin_headers
+ headers=admin_headers,
)
assert response.status_code == 200
data = response.json()
assert data["job_id"] == test_marketplace_import_job.id
- def test_admin_can_cancel_any_job(self, client, admin_headers, test_user, test_vendor, db):
+ def test_admin_can_cancel_any_job(
+ self, client, admin_headers, test_user, test_vendor, db
+ ):
"""Test that admin can cancel any job"""
# Create a pending job owned by different user
- from models.database.marketplace_import_job import MarketplaceImportJob
import uuid
+ from models.database.marketplace_import_job import MarketplaceImportJob
+
unique_id = str(uuid.uuid4())[:8]
job = MarketplaceImportJob(
status="pending",
@@ -398,8 +432,7 @@ class TestMarketplaceImportJobAPI:
db.refresh(job)
response = client.put(
- f"/api/v1/marketplace/import-jobs/{job.id}/cancel",
- headers=admin_headers
+ f"/api/v1/marketplace/import-jobs/{job.id}/cancel", headers=admin_headers
)
assert response.status_code == 200
diff --git a/tests/integration/api/v1/test_marketplace_product_export.py b/tests/integration/api/v1/test_marketplace_product_export.py
index 02438470..272d1a6e 100644
--- a/tests/integration/api/v1/test_marketplace_product_export.py
+++ b/tests/integration/api/v1/test_marketplace_product_export.py
@@ -1,7 +1,7 @@
# tests/integration/api/v1/test_export.py
import csv
-from io import StringIO
import uuid
+from io import StringIO
import pytest
@@ -13,9 +13,13 @@ from models.database.marketplace_product import MarketplaceProduct
@pytest.mark.performance # for the performance test
class TestExportFunctionality:
- def test_csv_export_basic_success(self, client, auth_headers, test_marketplace_product):
+ def test_csv_export_basic_success(
+ self, client, auth_headers, test_marketplace_product
+ ):
"""Test basic CSV export functionality successfully"""
- response = client.get("/api/v1/marketplace/product/export-csv", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product/export-csv", headers=auth_headers
+ )
assert response.status_code == 200
assert response.headers["content-type"] == "text/csv; charset=utf-8"
@@ -26,13 +30,22 @@ class TestExportFunctionality:
# Check header row
header = next(csv_reader)
- expected_fields = ["marketplace_product_id", "title", "description", "price", "marketplace"]
+ expected_fields = [
+ "marketplace_product_id",
+ "title",
+ "description",
+ "price",
+ "marketplace",
+ ]
for field in expected_fields:
assert field in header
# Verify test product appears in export
- csv_lines = csv_content.split('\n')
- test_product_found = any(test_marketplace_product.marketplace_product_id in line for line in csv_lines)
+ csv_lines = csv_content.split("\n")
+ test_product_found = any(
+ test_marketplace_product.marketplace_product_id in line
+ for line in csv_lines
+ )
assert test_product_found, "Test product should appear in CSV export"
def test_csv_export_with_marketplace_filter_success(self, client, auth_headers, db):
@@ -43,12 +56,12 @@ class TestExportFunctionality:
MarketplaceProduct(
marketplace_product_id=f"EXP1_{unique_suffix}",
title=f"Amazon MarketplaceProduct {unique_suffix}",
- marketplace="Amazon"
+ marketplace="Amazon",
),
MarketplaceProduct(
marketplace_product_id=f"EXP2_{unique_suffix}",
title=f"eBay MarketplaceProduct {unique_suffix}",
- marketplace="eBay"
+ marketplace="eBay",
),
]
@@ -56,7 +69,8 @@ class TestExportFunctionality:
db.commit()
response = client.get(
- "/api/v1/marketplace/product/export-csv?marketplace=Amazon", headers=auth_headers
+ "/api/v1/marketplace/product/export-csv?marketplace=Amazon",
+ headers=auth_headers,
)
assert response.status_code == 200
assert response.headers["content-type"] == "text/csv; charset=utf-8"
@@ -72,12 +86,12 @@ class TestExportFunctionality:
MarketplaceProduct(
marketplace_product_id=f"VENDOR1_{unique_suffix}",
title=f"Vendor1 MarketplaceProduct {unique_suffix}",
- vendor_name="TestVendor1"
+ vendor_name="TestVendor1",
),
MarketplaceProduct(
marketplace_product_id=f"VENDOR2_{unique_suffix}",
title=f"Vendor2 MarketplaceProduct {unique_suffix}",
- vendor_name="TestVendor2"
+ vendor_name="TestVendor2",
),
]
@@ -101,19 +115,19 @@ class TestExportFunctionality:
marketplace_product_id=f"COMBO1_{unique_suffix}",
title=f"Combo MarketplaceProduct 1 {unique_suffix}",
marketplace="Amazon",
- vendor_name="TestVendor"
+ vendor_name="TestVendor",
),
MarketplaceProduct(
marketplace_product_id=f"COMBO2_{unique_suffix}",
title=f"Combo MarketplaceProduct 2 {unique_suffix}",
marketplace="eBay",
- vendor_name="TestVendor"
+ vendor_name="TestVendor",
),
MarketplaceProduct(
marketplace_product_id=f"COMBO3_{unique_suffix}",
title=f"Combo MarketplaceProduct 3 {unique_suffix}",
marketplace="Amazon",
- vendor_name="OtherVendor"
+ vendor_name="OtherVendor",
),
]
@@ -122,27 +136,27 @@ class TestExportFunctionality:
response = client.get(
"/api/v1/marketplace/product?marketplace=Amazon&name=TestVendor",
- headers=auth_headers
+ headers=auth_headers,
)
assert response.status_code == 200
csv_content = response.content.decode("utf-8")
assert f"COMBO1_{unique_suffix}" in csv_content # Matches both filters
assert f"COMBO2_{unique_suffix}" not in csv_content # Wrong marketplace
- assert f"COMBO3_{unique_suffix}" not in csv_content # Wrong vendor
+ assert f"COMBO3_{unique_suffix}" not in csv_content # Wrong vendor
def test_csv_export_no_results(self, client, auth_headers):
"""Test CSV export with filters that return no results"""
response = client.get(
"/api/v1/marketplace/product/export-csv?marketplace=NonexistentMarketplace12345",
- headers=auth_headers
+ headers=auth_headers,
)
assert response.status_code == 200
assert response.headers["content-type"] == "text/csv; charset=utf-8"
csv_content = response.content.decode("utf-8")
- csv_lines = csv_content.strip().split('\n')
+ csv_lines = csv_content.strip().split("\n")
# Should have header row even with no data
assert len(csv_lines) >= 1
# First line should be headers
@@ -161,27 +175,32 @@ class TestExportFunctionality:
title=f"Performance MarketplaceProduct {i}",
marketplace="Performance",
description=f"Performance test product {i}",
- price="10.99"
+ price="10.99",
)
products.append(product)
# Add in batches to avoid memory issues
for i in range(0, len(products), 50):
- batch = products[i:i + 50]
+ batch = products[i : i + 50]
db.add_all(batch)
db.commit()
import time
+
start_time = time.time()
- response = client.get("/api/v1/marketplace/product/export-csv", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product/export-csv", headers=auth_headers
+ )
end_time = time.time()
execution_time = end_time - start_time
assert response.status_code == 200
assert response.headers["content-type"] == "text/csv; charset=utf-8"
- assert execution_time < 10.0, f"Export took {execution_time:.2f} seconds, should be under 10s"
+ assert (
+ execution_time < 10.0
+ ), f"Export took {execution_time:.2f} seconds, should be under 10s"
# Verify content contains our test data
csv_content = response.content.decode("utf-8")
@@ -197,9 +216,13 @@ class TestExportFunctionality:
assert data["error_code"] == "INVALID_TOKEN"
assert data["status_code"] == 401
- def test_csv_export_streaming_response_format(self, client, auth_headers, test_marketplace_product):
+ def test_csv_export_streaming_response_format(
+ self, client, auth_headers, test_marketplace_product
+ ):
"""Test that CSV export returns proper streaming response with correct headers"""
- response = client.get("/api/v1/marketplace/product/export-csv", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product/export-csv", headers=auth_headers
+ )
assert response.status_code == 200
assert response.headers["content-type"] == "text/csv; charset=utf-8"
@@ -217,23 +240,27 @@ class TestExportFunctionality:
# Create product with special characters that might break CSV
product = MarketplaceProduct(
marketplace_product_id=f"SPECIAL_{unique_suffix}",
- title=f'MarketplaceProduct with quotes and commas {unique_suffix}', # Simplified to avoid CSV escaping issues
+ title=f"MarketplaceProduct with quotes and commas {unique_suffix}", # Simplified to avoid CSV escaping issues
description=f"Description with special chars {unique_suffix}",
marketplace="Test Market",
- price="19.99"
+ price="19.99",
)
db.add(product)
db.commit()
- response = client.get("/api/v1/marketplace/product/export-csv", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product/export-csv", headers=auth_headers
+ )
assert response.status_code == 200
csv_content = response.content.decode("utf-8")
# Verify our test product appears in the CSV content
assert f"SPECIAL_{unique_suffix}" in csv_content
- assert f"MarketplaceProduct with quotes and commas {unique_suffix}" in csv_content
+ assert (
+ f"MarketplaceProduct with quotes and commas {unique_suffix}" in csv_content
+ )
assert "Test Market" in csv_content
assert "19.99" in csv_content
@@ -243,7 +270,12 @@ class TestExportFunctionality:
header = next(csv_reader)
# Verify header contains expected fields
- expected_fields = ["marketplace_product_id", "title", "marketplace", "price"]
+ expected_fields = [
+ "marketplace_product_id",
+ "title",
+ "marketplace",
+ "price",
+ ]
for field in expected_fields:
assert field in header
@@ -264,12 +296,15 @@ class TestExportFunctionality:
assert parsed_successfully, "CSV should be parseable despite special characters"
- def test_csv_export_error_handling_service_failure(self, client, auth_headers, monkeypatch):
+ def test_csv_export_error_handling_service_failure(
+ self, client, auth_headers, monkeypatch
+ ):
"""Test CSV export handles service failures gracefully"""
# Mock the service to raise an exception
def mock_generate_csv_export(*args, **kwargs):
from app.exceptions import ValidationException
+
raise ValidationException("Mocked service failure")
# This would require access to your service instance to mock properly
@@ -293,7 +328,9 @@ class TestExportFunctionality:
def test_csv_export_filename_generation(self, client, auth_headers):
"""Test CSV export generates appropriate filenames based on filters"""
# Test basic export filename
- response = client.get("/api/v1/marketplace/product/export-csv", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product/export-csv", headers=auth_headers
+ )
assert response.status_code == 200
content_disposition = response.headers.get("content-disposition", "")
@@ -302,7 +339,7 @@ class TestExportFunctionality:
# Test with marketplace filter
response = client.get(
"/api/v1/marketplace/product/export-csv?marketplace=Amazon",
- headers=auth_headers
+ headers=auth_headers,
)
assert response.status_code == 200
diff --git a/tests/integration/api/v1/test_marketplace_products_endpoints.py b/tests/integration/api/v1/test_marketplace_products_endpoints.py
index 3e01c0e8..dbda5e8d 100644
--- a/tests/integration/api/v1/test_marketplace_products_endpoints.py
+++ b/tests/integration/api/v1/test_marketplace_products_endpoints.py
@@ -16,7 +16,9 @@ class TestMarketplaceProductsAPI:
assert data["products"] == []
assert data["total"] == 0
- def test_get_products_with_data(self, client, auth_headers, test_marketplace_product):
+ def test_get_products_with_data(
+ self, client, auth_headers, test_marketplace_product
+ ):
"""Test getting products with data"""
response = client.get("/api/v1/marketplace/product", headers=auth_headers)
@@ -25,25 +27,39 @@ class TestMarketplaceProductsAPI:
assert len(data["products"]) >= 1
assert data["total"] >= 1
# Find our test product
- test_product_found = any(p["marketplace_product_id"] == test_marketplace_product.marketplace_product_id for p in data["products"])
+ test_product_found = any(
+ p["marketplace_product_id"]
+ == test_marketplace_product.marketplace_product_id
+ for p in data["products"]
+ )
assert test_product_found
- def test_get_products_with_filters(self, client, auth_headers, test_marketplace_product):
+ def test_get_products_with_filters(
+ self, client, auth_headers, test_marketplace_product
+ ):
"""Test filtering products"""
# Test brand filter
- response = client.get(f"/api/v1/marketplace/product?brand={test_marketplace_product.brand}", headers=auth_headers)
+ response = client.get(
+ f"/api/v1/marketplace/product?brand={test_marketplace_product.brand}",
+ headers=auth_headers,
+ )
assert response.status_code == 200
data = response.json()
assert data["total"] >= 1
# Test marketplace filter
- response = client.get(f"/api/v1/marketplace/product?marketplace={test_marketplace_product.marketplace}", headers=auth_headers)
+ response = client.get(
+ f"/api/v1/marketplace/product?marketplace={test_marketplace_product.marketplace}",
+ headers=auth_headers,
+ )
assert response.status_code == 200
data = response.json()
assert data["total"] >= 1
# Test search
- response = client.get("/api/v1/marketplace/product?search=Test", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product?search=Test", headers=auth_headers
+ )
assert response.status_code == 200
data = response.json()
assert data["total"] >= 1
@@ -71,7 +87,9 @@ class TestMarketplaceProductsAPI:
assert data["title"] == "New MarketplaceProduct"
assert data["marketplace"] == "Amazon"
- def test_create_product_duplicate_id_returns_conflict(self, client, auth_headers, test_marketplace_product):
+ def test_create_product_duplicate_id_returns_conflict(
+ self, client, auth_headers, test_marketplace_product
+ ):
"""Test creating product with duplicate ID returns MarketplaceProductAlreadyExistsException"""
product_data = {
"marketplace_product_id": test_marketplace_product.marketplace_product_id,
@@ -93,7 +111,10 @@ class TestMarketplaceProductsAPI:
assert data["error_code"] == "PRODUCT_ALREADY_EXISTS"
assert data["status_code"] == 409
assert test_marketplace_product.marketplace_product_id in data["message"]
- assert data["details"]["marketplace_product_id"] == test_marketplace_product.marketplace_product_id
+ assert (
+ data["details"]["marketplace_product_id"]
+ == test_marketplace_product.marketplace_product_id
+ )
def test_create_product_missing_title_validation_error(self, client, auth_headers):
"""Test creating product without title returns ValidationException"""
@@ -114,7 +135,9 @@ class TestMarketplaceProductsAPI:
assert "MarketplaceProduct title is required" in data["message"]
assert data["details"]["field"] == "title"
- def test_create_product_missing_product_id_validation_error(self, client, auth_headers):
+ def test_create_product_missing_product_id_validation_error(
+ self, client, auth_headers
+ ):
"""Test creating product without marketplace_product_id returns ValidationException"""
product_data = {
"marketplace_product_id": "", # Empty product ID
@@ -191,20 +214,28 @@ class TestMarketplaceProductsAPI:
assert "Request validation failed" in data["message"]
assert "validation_errors" in data["details"]
- def test_get_product_by_id_success(self, client, auth_headers, test_marketplace_product):
+ def test_get_product_by_id_success(
+ self, client, auth_headers, test_marketplace_product
+ ):
"""Test getting specific product successfully"""
response = client.get(
- f"/api/v1/marketplace/product/{test_marketplace_product.marketplace_product_id}", headers=auth_headers
+ f"/api/v1/marketplace/product/{test_marketplace_product.marketplace_product_id}",
+ headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
- assert data["product"]["marketplace_product_id"] == test_marketplace_product.marketplace_product_id
+ assert (
+ data["product"]["marketplace_product_id"]
+ == test_marketplace_product.marketplace_product_id
+ )
assert data["product"]["title"] == test_marketplace_product.title
def test_get_nonexistent_product_returns_not_found(self, client, auth_headers):
"""Test getting nonexistent product returns MarketplaceProductNotFoundException"""
- response = client.get("/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers
+ )
assert response.status_code == 404
data = response.json()
@@ -214,7 +245,9 @@ class TestMarketplaceProductsAPI:
assert data["details"]["resource_type"] == "MarketplaceProduct"
assert data["details"]["identifier"] == "NONEXISTENT"
- def test_update_product_success(self, client, auth_headers, test_marketplace_product):
+ def test_update_product_success(
+ self, client, auth_headers, test_marketplace_product
+ ):
"""Test updating product successfully"""
update_data = {"title": "Updated MarketplaceProduct Title", "price": "25.99"}
@@ -247,7 +280,9 @@ class TestMarketplaceProductsAPI:
assert data["details"]["resource_type"] == "MarketplaceProduct"
assert data["details"]["identifier"] == "NONEXISTENT"
- def test_update_product_empty_title_validation_error(self, client, auth_headers, test_marketplace_product):
+ def test_update_product_empty_title_validation_error(
+ self, client, auth_headers, test_marketplace_product
+ ):
"""Test updating product with empty title returns MarketplaceProductValidationException"""
update_data = {"title": ""}
@@ -264,7 +299,9 @@ class TestMarketplaceProductsAPI:
assert "MarketplaceProduct title cannot be empty" in data["message"]
assert data["details"]["field"] == "title"
- def test_update_product_invalid_gtin_data_error(self, client, auth_headers, test_marketplace_product):
+ def test_update_product_invalid_gtin_data_error(
+ self, client, auth_headers, test_marketplace_product
+ ):
"""Test updating product with invalid GTIN returns InvalidMarketplaceProductDataException"""
update_data = {"gtin": "invalid_gtin"}
@@ -281,7 +318,9 @@ class TestMarketplaceProductsAPI:
assert "Invalid GTIN format" in data["message"]
assert data["details"]["field"] == "gtin"
- def test_update_product_invalid_price_data_error(self, client, auth_headers, test_marketplace_product):
+ def test_update_product_invalid_price_data_error(
+ self, client, auth_headers, test_marketplace_product
+ ):
"""Test updating product with invalid price returns InvalidMarketplaceProductDataException"""
update_data = {"price": "invalid_price"}
@@ -298,10 +337,13 @@ class TestMarketplaceProductsAPI:
assert "Invalid price format" in data["message"]
assert data["details"]["field"] == "price"
- def test_delete_product_success(self, client, auth_headers, test_marketplace_product):
+ def test_delete_product_success(
+ self, client, auth_headers, test_marketplace_product
+ ):
"""Test deleting product successfully"""
response = client.delete(
- f"/api/v1/marketplace/product/{test_marketplace_product.marketplace_product_id}", headers=auth_headers
+ f"/api/v1/marketplace/product/{test_marketplace_product.marketplace_product_id}",
+ headers=auth_headers,
)
assert response.status_code == 200
@@ -309,7 +351,9 @@ class TestMarketplaceProductsAPI:
def test_delete_nonexistent_product_returns_not_found(self, client, auth_headers):
"""Test deleting nonexistent product returns MarketplaceProductNotFoundException"""
- response = client.delete("/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers)
+ response = client.delete(
+ "/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers
+ )
assert response.status_code == 404
data = response.json()
@@ -331,7 +375,9 @@ class TestMarketplaceProductsAPI:
def test_exception_structure_consistency(self, client, auth_headers):
"""Test that all exceptions follow the consistent WizamartException structure"""
# Test with a known error case
- response = client.get("/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers
+ )
assert response.status_code == 404
data = response.json()
diff --git a/tests/integration/api/v1/test_pagination.py b/tests/integration/api/v1/test_pagination.py
index c2440c18..22ead4b0 100644
--- a/tests/integration/api/v1/test_pagination.py
+++ b/tests/integration/api/v1/test_pagination.py
@@ -4,6 +4,7 @@ import pytest
from models.database.marketplace_product import MarketplaceProduct
from models.database.vendor import Vendor
+
@pytest.mark.integration
@pytest.mark.api
@pytest.mark.database
@@ -13,6 +14,7 @@ class TestPagination:
def test_product_pagination_success(self, client, auth_headers, db):
"""Test pagination for product listing successfully"""
import uuid
+
unique_suffix = str(uuid.uuid4())[:8]
# Create multiple products
@@ -29,7 +31,9 @@ class TestPagination:
db.commit()
# Test first page
- response = client.get("/api/v1/marketplace/product?limit=10&skip=0", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product?limit=10&skip=0", headers=auth_headers
+ )
assert response.status_code == 200
data = response.json()
assert len(data["products"]) == 10
@@ -38,21 +42,29 @@ class TestPagination:
assert data["limit"] == 10
# Test second page
- response = client.get("/api/v1/marketplace/product?limit=10&skip=10", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product?limit=10&skip=10", headers=auth_headers
+ )
assert response.status_code == 200
data = response.json()
assert len(data["products"]) == 10
assert data["skip"] == 10
# Test last page (should have remaining products)
- response = client.get("/api/v1/marketplace/product?limit=10&skip=20", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product?limit=10&skip=20", headers=auth_headers
+ )
assert response.status_code == 200
data = response.json()
assert len(data["products"]) >= 5 # At least 5 remaining from our test set
- def test_pagination_boundary_negative_skip_validation_error(self, client, auth_headers):
+ def test_pagination_boundary_negative_skip_validation_error(
+ self, client, auth_headers
+ ):
"""Test negative skip parameter returns ValidationException"""
- response = client.get("/api/v1/marketplace/product?skip=-1", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product?skip=-1", headers=auth_headers
+ )
assert response.status_code == 422
data = response.json()
@@ -61,9 +73,13 @@ class TestPagination:
assert "Request validation failed" in data["message"]
assert "validation_errors" in data["details"]
- def test_pagination_boundary_zero_limit_validation_error(self, client, auth_headers):
+ def test_pagination_boundary_zero_limit_validation_error(
+ self, client, auth_headers
+ ):
"""Test zero limit parameter returns ValidationException"""
- response = client.get("/api/v1/marketplace/product?limit=0", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product?limit=0", headers=auth_headers
+ )
assert response.status_code == 422
data = response.json()
@@ -71,9 +87,13 @@ class TestPagination:
assert data["status_code"] == 422
assert "Request validation failed" in data["message"]
- def test_pagination_boundary_excessive_limit_validation_error(self, client, auth_headers):
+ def test_pagination_boundary_excessive_limit_validation_error(
+ self, client, auth_headers
+ ):
"""Test excessive limit parameter returns ValidationException"""
- response = client.get("/api/v1/marketplace/product?limit=10000", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product?limit=10000", headers=auth_headers
+ )
assert response.status_code == 422
data = response.json()
@@ -84,7 +104,9 @@ class TestPagination:
def test_pagination_beyond_available_records(self, client, auth_headers, db):
"""Test pagination beyond available records returns empty results"""
# Test skip beyond available records
- response = client.get("/api/v1/marketplace/product?skip=10000&limit=10", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product?skip=10000&limit=10", headers=auth_headers
+ )
assert response.status_code == 200
data = response.json()
@@ -96,6 +118,7 @@ class TestPagination:
def test_pagination_with_filters(self, client, auth_headers, db):
"""Test pagination combined with filtering"""
import uuid
+
unique_suffix = str(uuid.uuid4())[:8]
# Create products with same brand for filtering
@@ -115,7 +138,7 @@ class TestPagination:
# Test first page with filter
response = client.get(
"/api/v1/marketplace/product?brand=FilterBrand&limit=5&skip=0",
- headers=auth_headers
+ headers=auth_headers,
)
assert response.status_code == 200
@@ -124,14 +147,18 @@ class TestPagination:
assert data["total"] >= 15 # At least our test products
# Verify all products have the filtered brand
- test_products = [p for p in data["products"] if p["marketplace_product_id"].endswith(unique_suffix)]
+ test_products = [
+ p
+ for p in data["products"]
+ if p["marketplace_product_id"].endswith(unique_suffix)
+ ]
for product in test_products:
assert product["brand"] == "FilterBrand"
# Test second page with same filter
response = client.get(
"/api/v1/marketplace/product?brand=FilterBrand&limit=5&skip=5",
- headers=auth_headers
+ headers=auth_headers,
)
assert response.status_code == 200
@@ -152,6 +179,7 @@ class TestPagination:
def test_pagination_consistency(self, client, auth_headers, db):
"""Test pagination consistency across multiple requests"""
import uuid
+
unique_suffix = str(uuid.uuid4())[:8]
# Create products with predictable ordering
@@ -168,14 +196,22 @@ class TestPagination:
db.commit()
# Get first page
- response1 = client.get("/api/v1/marketplace/product?limit=5&skip=0", headers=auth_headers)
+ response1 = client.get(
+ "/api/v1/marketplace/product?limit=5&skip=0", headers=auth_headers
+ )
assert response1.status_code == 200
- first_page_ids = [p["marketplace_product_id"] for p in response1.json()["products"]]
+ first_page_ids = [
+ p["marketplace_product_id"] for p in response1.json()["products"]
+ ]
# Get second page
- response2 = client.get("/api/v1/marketplace/product?limit=5&skip=5", headers=auth_headers)
+ response2 = client.get(
+ "/api/v1/marketplace/product?limit=5&skip=5", headers=auth_headers
+ )
assert response2.status_code == 200
- second_page_ids = [p["marketplace_product_id"] for p in response2.json()["products"]]
+ second_page_ids = [
+ p["marketplace_product_id"] for p in response2.json()["products"]
+ ]
# Verify no overlap between pages
overlap = set(first_page_ids) & set(second_page_ids)
@@ -184,11 +220,13 @@ class TestPagination:
def test_vendor_pagination_success(self, client, admin_headers, db, test_user):
"""Test pagination for vendor listing successfully"""
import uuid
+
unique_suffix = str(uuid.uuid4())[:8]
# Create multiple vendors for pagination testing
from models.database.vendor import Vendor
- vendors =[]
+
+ vendors = []
for i in range(15):
vendor = Vendor(
vendor_code=f"PAGEVENDOR{i:03d}_{unique_suffix}",
@@ -202,9 +240,7 @@ class TestPagination:
db.commit()
# Test first page (assuming admin endpoint exists)
- response = client.get(
- "/api/v1/vendor?limit=5&skip=0", headers=admin_headers
- )
+ response = client.get("/api/v1/vendor?limit=5&skip=0", headers=admin_headers)
assert response.status_code == 200
data = response.json()
assert len(data["vendors"]) == 5
@@ -215,10 +251,12 @@ class TestPagination:
def test_inventory_pagination_success(self, client, auth_headers, db):
"""Test pagination for inventory listing successfully"""
import uuid
+
unique_suffix = str(uuid.uuid4())[:8]
# Create multiple inventory entries
from models.database.inventory import Inventory
+
inventory_entries = []
for i in range(20):
inventory = Inventory(
@@ -249,7 +287,9 @@ class TestPagination:
import time
start_time = time.time()
- response = client.get("/api/v1/marketplace/product?skip=1000&limit=10", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product?skip=1000&limit=10", headers=auth_headers
+ )
end_time = time.time()
assert response.status_code == 200
@@ -262,19 +302,25 @@ class TestPagination:
def test_pagination_with_invalid_parameters_types(self, client, auth_headers):
"""Test pagination with invalid parameter types returns ValidationException"""
# Test non-numeric skip
- response = client.get("/api/v1/marketplace/product?skip=invalid", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product?skip=invalid", headers=auth_headers
+ )
assert response.status_code == 422
data = response.json()
assert data["error_code"] == "VALIDATION_ERROR"
# Test non-numeric limit
- response = client.get("/api/v1/marketplace/product?limit=invalid", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product?limit=invalid", headers=auth_headers
+ )
assert response.status_code == 422
data = response.json()
assert data["error_code"] == "VALIDATION_ERROR"
# Test float values (should be converted or rejected)
- response = client.get("/api/v1/marketplace/product?skip=10.5&limit=5.5", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product?skip=10.5&limit=5.5", headers=auth_headers
+ )
assert response.status_code in [200, 422] # Depends on implementation
def test_empty_dataset_pagination(self, client, auth_headers):
@@ -282,7 +328,7 @@ class TestPagination:
# Use a filter that should return no results
response = client.get(
"/api/v1/marketplace/product?brand=NonexistentBrand999&limit=10&skip=0",
- headers=auth_headers
+ headers=auth_headers,
)
assert response.status_code == 200
@@ -294,7 +340,9 @@ class TestPagination:
def test_exception_structure_in_pagination_errors(self, client, auth_headers):
"""Test that pagination validation errors follow consistent exception structure"""
- response = client.get("/api/v1/marketplace/product?skip=-1", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product?skip=-1", headers=auth_headers
+ )
assert response.status_code == 422
data = response.json()
diff --git a/tests/integration/api/v1/test_stats_endpoints.py b/tests/integration/api/v1/test_stats_endpoints.py
index eade22f2..253615c9 100644
--- a/tests/integration/api/v1/test_stats_endpoints.py
+++ b/tests/integration/api/v1/test_stats_endpoints.py
@@ -1,6 +1,7 @@
# tests/integration/api/v1/test_stats_endpoints.py
import pytest
+
@pytest.mark.integration
@pytest.mark.api
@pytest.mark.stats
@@ -18,7 +19,9 @@ class TestStatsAPI:
assert "unique_vendors" in data
assert data["total_products"] >= 1
- def test_get_marketplace_stats(self, client, auth_headers, test_marketplace_product):
+ def test_get_marketplace_stats(
+ self, client, auth_headers, test_marketplace_product
+ ):
"""Test getting marketplace statistics"""
response = client.get("/api/v1/stats/marketplace", headers=auth_headers)
diff --git a/tests/integration/api/v1/test_vendor_endpoints.py b/tests/integration/api/v1/test_vendor_endpoints.py
index 663cca06..6b689d8c 100644
--- a/tests/integration/api/v1/test_vendor_endpoints.py
+++ b/tests/integration/api/v1/test_vendor_endpoints.py
@@ -23,7 +23,9 @@ class TestVendorsAPI:
assert data["name"] == "New Vendor"
assert data["is_active"] is True
- def test_create_vendor_duplicate_code_returns_conflict(self, client, auth_headers, test_vendor):
+ def test_create_vendor_duplicate_code_returns_conflict(
+ self, client, auth_headers, test_vendor
+ ):
"""Test creating vendor with duplicate code returns VendorAlreadyExistsException"""
vendor_data = {
"vendor_code": test_vendor.vendor_code,
@@ -40,7 +42,9 @@ class TestVendorsAPI:
assert test_vendor.vendor_code in data["message"]
assert data["details"]["vendor_code"] == test_vendor.vendor_code
- def test_create_vendor_missing_vendor_code_validation_error(self, client, auth_headers):
+ def test_create_vendor_missing_vendor_code_validation_error(
+ self, client, auth_headers
+ ):
"""Test creating vendor without vendor_code returns ValidationException"""
vendor_data = {
"name": "Vendor without Code",
@@ -56,7 +60,9 @@ class TestVendorsAPI:
assert "Request validation failed" in data["message"]
assert "validation_errors" in data["details"]
- def test_create_vendor_empty_vendor_name_validation_error(self, client, auth_headers):
+ def test_create_vendor_empty_vendor_name_validation_error(
+ self, client, auth_headers
+ ):
"""Test creating vendor with empty name returns VendorValidationException"""
vendor_data = {
"vendor_code": "EMPTYNAME",
@@ -73,7 +79,9 @@ class TestVendorsAPI:
assert "Vendor name is required" in data["message"]
assert data["details"]["field"] == "name"
- def test_create_vendor_max_vendors_reached_business_logic_error(self, client, auth_headers, db, test_user):
+ def test_create_vendor_max_vendors_reached_business_logic_error(
+ self, client, auth_headers, db, test_user
+ ):
"""Test creating vendor when max vendors reached returns MaxVendorsReachedException"""
# This test would require creating the maximum allowed vendors first
# The exact implementation depends on your business rules
@@ -94,8 +102,10 @@ class TestVendorsAPI:
assert data["total"] >= 1
assert len(data["vendors"]) >= 1
- # Find our test vendor
- test_vendor_found = any(s["vendor_code"] == test_vendor.vendor_code for s in data["vendors"])
+ # Find our test vendor
+ test_vendor_found = any(
+ s["vendor_code"] == test_vendor.vendor_code for s in data["vendors"]
+ )
assert test_vendor_found
def test_get_vendors_with_filters(self, client, auth_headers, test_vendor):
@@ -105,7 +115,7 @@ class TestVendorsAPI:
assert response.status_code == 200
data = response.json()
for vendor in data["vendors"]:
- assert vendor ["is_active"] is True
+ assert vendor["is_active"] is True
# Test verified_only filter
response = client.get("/api/v1/vendor?verified_only=true", headers=auth_headers)
@@ -135,7 +145,9 @@ class TestVendorsAPI:
assert data["details"]["resource_type"] == "Vendor"
assert data["details"]["identifier"] == "NONEXISTENT"
- def test_get_vendor_unauthorized_access(self, client, auth_headers, test_vendor, other_user, db):
+ def test_get_vendor_unauthorized_access(
+ self, client, auth_headers, test_vendor, other_user, db
+ ):
"""Test accessing vendor owned by another user returns UnauthorizedVendorAccessException"""
# Change vendor owner to other user AND make it unverified/inactive
# so that non-owner users cannot access it
@@ -154,7 +166,9 @@ class TestVendorsAPI:
assert test_vendor.vendor_code in data["message"]
assert data["details"]["vendor_code"] == test_vendor.vendor_code
- def test_get_vendor_unauthorized_access_with_inactive_vendor(self, client, auth_headers, inactive_vendor):
+ def test_get_vendor_unauthorized_access_with_inactive_vendor(
+ self, client, auth_headers, inactive_vendor
+ ):
"""Test accessing inactive vendor owned by another user returns UnauthorizedVendorAccessException"""
# inactive_vendor fixture already creates an unverified, inactive vendor owned by other_user
response = client.get(
@@ -168,7 +182,9 @@ class TestVendorsAPI:
assert inactive_vendor.vendor_code in data["message"]
assert data["details"]["vendor_code"] == inactive_vendor.vendor_code
- def test_get_vendor_public_access_allowed(self, client, auth_headers, verified_vendor):
+ def test_get_vendor_public_access_allowed(
+ self, client, auth_headers, verified_vendor
+ ):
"""Test accessing verified vendor owned by another user is allowed (public access)"""
# verified_vendor fixture creates a verified, active vendor owned by other_user
# This should allow public access per your business logic
@@ -181,7 +197,9 @@ class TestVendorsAPI:
assert data["vendor_code"] == verified_vendor.vendor_code
assert data["name"] == verified_vendor.name
- def test_add_product_to_vendor_success(self, client, auth_headers, test_vendor, unique_product):
+ def test_add_product_to_vendor_success(
+ self, client, auth_headers, test_vendor, unique_product
+ ):
"""Test adding product to vendor successfully"""
product_data = {
"marketplace_product_id": unique_product.marketplace_product_id, # Use string marketplace_product_id, not database id
@@ -193,7 +211,7 @@ class TestVendorsAPI:
response = client.post(
f"/api/v1/vendor/{test_vendor.vendor_code}/products",
headers=auth_headers,
- json=product_data
+ json=product_data,
)
assert response.status_code == 200
@@ -207,10 +225,15 @@ class TestVendorsAPI:
# MarketplaceProduct details are nested in the 'marketplace_product' field
assert "marketplace_product" in data
- assert data["marketplace_product"]["marketplace_product_id"] == unique_product.marketplace_product_id
+ assert (
+ data["marketplace_product"]["marketplace_product_id"]
+ == unique_product.marketplace_product_id
+ )
assert data["marketplace_product"]["id"] == unique_product.id
- def test_add_product_to_vendor_already_exists_conflict(self, client, auth_headers, test_vendor, test_product):
+ def test_add_product_to_vendor_already_exists_conflict(
+ self, client, auth_headers, test_vendor, test_product
+ ):
"""Test adding product that already exists in vendor returns ProductAlreadyExistsException"""
# test_product fixture already creates a relationship, get the marketplace_product_id string
existing_product = test_product.marketplace_product
@@ -223,7 +246,7 @@ class TestVendorsAPI:
response = client.post(
f"/api/v1/vendor/{test_vendor.vendor_code}/products",
headers=auth_headers,
- json=product_data
+ json=product_data,
)
assert response.status_code == 409
@@ -233,7 +256,9 @@ class TestVendorsAPI:
assert test_vendor.vendor_code in data["message"]
assert existing_product.marketplace_product_id in data["message"]
- def test_add_nonexistent_product_to_vendor_not_found(self, client, auth_headers, test_vendor):
+ def test_add_nonexistent_product_to_vendor_not_found(
+ self, client, auth_headers, test_vendor
+ ):
"""Test adding nonexistent product to vendor returns MarketplaceProductNotFoundException"""
product_data = {
"marketplace_product_id": "NONEXISTENT_PRODUCT", # Use string marketplace_product_id that doesn't exist
@@ -243,7 +268,7 @@ class TestVendorsAPI:
response = client.post(
f"/api/v1/vendor/{test_vendor.vendor_code}/products",
headers=auth_headers,
- json=product_data
+ json=product_data,
)
assert response.status_code == 404
@@ -252,11 +277,12 @@ class TestVendorsAPI:
assert data["status_code"] == 404
assert "NONEXISTENT_PRODUCT" in data["message"]
- def test_get_products_success(self, client, auth_headers, test_vendor, test_product):
+ def test_get_products_success(
+ self, client, auth_headers, test_vendor, test_product
+ ):
"""Test getting vendor products successfully"""
response = client.get(
- f"/api/v1/vendor/{test_vendor.vendor_code}/products",
- headers=auth_headers
+ f"/api/v1/vendor/{test_vendor.vendor_code}/products", headers=auth_headers
)
assert response.status_code == 200
@@ -271,22 +297,21 @@ class TestVendorsAPI:
# Test active_only filter
response = client.get(
f"/api/v1/vendor/{test_vendor.vendor_code}/products?active_only=true",
- headers=auth_headers
+ headers=auth_headers,
)
assert response.status_code == 200
# Test featured_only filter
response = client.get(
f"/api/v1/vendor/{test_vendor.vendor_code}/products?featured_only=true",
- headers=auth_headers
+ headers=auth_headers,
)
assert response.status_code == 200
def test_get_products_from_nonexistent_vendor_not_found(self, client, auth_headers):
"""Test getting products from nonexistent vendor returns VendorNotFoundException"""
response = client.get(
- "/api/v1/vendor/NONEXISTENT/products",
- headers=auth_headers
+ "/api/v1/vendor/NONEXISTENT/products", headers=auth_headers
)
assert response.status_code == 404
@@ -295,7 +320,9 @@ class TestVendorsAPI:
assert data["status_code"] == 404
assert "NONEXISTENT" in data["message"]
- def test_vendor_not_active_business_logic_error(self, client, auth_headers, test_vendor, db):
+ def test_vendor_not_active_business_logic_error(
+ self, client, auth_headers, test_vendor, db
+ ):
"""Test accessing inactive vendor returns VendorNotActiveException (if enforced)"""
# Set vendor to inactive
test_vendor.is_active = False
@@ -313,7 +340,9 @@ class TestVendorsAPI:
assert data["status_code"] == 400
assert test_vendor.vendor_code in data["message"]
- def test_vendor_not_verified_business_logic_error(self, client, auth_headers, test_vendor, db):
+ def test_vendor_not_verified_business_logic_error(
+ self, client, auth_headers, test_vendor, db
+ ):
"""Test operations requiring verification returns VendorNotVerifiedException (if enforced)"""
# Set vendor to unverified
test_vendor.is_verified = False
@@ -328,7 +357,7 @@ class TestVendorsAPI:
response = client.post(
f"/api/v1/vendor/{test_vendor.vendor_code}/products",
headers=auth_headers,
- json=product_data
+ json=product_data,
)
# If your service requires verification for adding products
diff --git a/tests/integration/api/v1/vendor/test_authentication.py b/tests/integration/api/v1/vendor/test_authentication.py
index 8a76417b..a5cf271c 100644
--- a/tests/integration/api/v1/vendor/test_authentication.py
+++ b/tests/integration/api/v1/vendor/test_authentication.py
@@ -10,8 +10,9 @@ These tests verify that:
5. Vendor context middleware works correctly with API authentication
"""
-import pytest
from datetime import datetime, timedelta, timezone
+
+import pytest
from jose import jwt
@@ -26,7 +27,9 @@ class TestVendorAPIAuthentication:
# Authentication Tests - /api/v1/vendor/auth/me
# ========================================================================
- def test_vendor_auth_me_success(self, client, vendor_user_headers, test_vendor_user):
+ def test_vendor_auth_me_success(
+ self, client, vendor_user_headers, test_vendor_user
+ ):
"""Test /auth/me endpoint with valid vendor user token"""
response = client.get("/api/v1/vendor/auth/me", headers=vendor_user_headers)
@@ -50,7 +53,7 @@ class TestVendorAPIAuthentication:
"""Test /auth/me endpoint with invalid token format"""
response = client.get(
"/api/v1/vendor/auth/me",
- headers={"Authorization": "Bearer invalid_token_here"}
+ headers={"Authorization": "Bearer invalid_token_here"},
)
assert response.status_code == 401
@@ -66,7 +69,9 @@ class TestVendorAPIAuthentication:
assert data["error_code"] == "FORBIDDEN"
assert "Admin users cannot access vendor API" in data["message"]
- def test_vendor_auth_me_with_regular_user_token(self, client, auth_headers, test_user):
+ def test_vendor_auth_me_with_regular_user_token(
+ self, client, auth_headers, test_user
+ ):
"""Test /auth/me endpoint rejects regular users"""
response = client.get("/api/v1/vendor/auth/me", headers=auth_headers)
@@ -88,14 +93,12 @@ class TestVendorAPIAuthentication:
}
expired_token = jwt.encode(
- expired_payload,
- auth_manager.secret_key,
- algorithm=auth_manager.algorithm
+ expired_payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
)
response = client.get(
"/api/v1/vendor/auth/me",
- headers={"Authorization": f"Bearer {expired_token}"}
+ headers={"Authorization": f"Bearer {expired_token}"},
)
assert response.status_code == 401
@@ -111,8 +114,7 @@ class TestVendorAPIAuthentication:
):
"""Test dashboard stats with valid vendor authentication"""
response = client.get(
- "/api/v1/vendor/dashboard/stats",
- headers=vendor_user_headers
+ "/api/v1/vendor/dashboard/stats", headers=vendor_user_headers
)
assert response.status_code == 200
@@ -131,10 +133,7 @@ class TestVendorAPIAuthentication:
def test_vendor_dashboard_stats_with_admin(self, client, admin_headers):
"""Test dashboard stats rejects admin users"""
- response = client.get(
- "/api/v1/vendor/dashboard/stats",
- headers=admin_headers
- )
+ response = client.get("/api/v1/vendor/dashboard/stats", headers=admin_headers)
assert response.status_code == 403
data = response.json()
@@ -145,10 +144,7 @@ class TestVendorAPIAuthentication:
# Login to get session cookie
login_response = client.post(
"/api/v1/vendor/auth/login",
- json={
- "username": test_vendor_user.username,
- "password": "vendorpass123"
- }
+ json={"username": test_vendor_user.username, "password": "vendorpass123"},
)
assert login_response.status_code == 200
@@ -169,10 +165,7 @@ class TestVendorAPIAuthentication:
# Get a valid session by logging in
login_response = client.post(
"/api/v1/vendor/auth/login",
- json={
- "username": test_vendor_user.username,
- "password": "vendorpass123"
- }
+ json={"username": test_vendor_user.username, "password": "vendorpass123"},
)
assert login_response.status_code == 200
@@ -191,8 +184,9 @@ class TestVendorAPIAuthentication:
response = client.get(endpoint)
# All should fail with 401 (header required)
- assert response.status_code == 401, \
- f"Endpoint {endpoint} should reject cookie-only auth"
+ assert (
+ response.status_code == 401
+ ), f"Endpoint {endpoint} should reject cookie-only auth"
# ========================================================================
# Role-Based Access Control Tests
@@ -211,13 +205,15 @@ class TestVendorAPIAuthentication:
for endpoint in endpoints:
# Test with regular user token
response = client.get(endpoint, headers=auth_headers)
- assert response.status_code == 403, \
- f"Endpoint {endpoint} should reject regular users"
+ assert (
+ response.status_code == 403
+ ), f"Endpoint {endpoint} should reject regular users"
# Test with admin token
response = client.get(endpoint, headers=admin_headers)
- assert response.status_code == 403, \
- f"Endpoint {endpoint} should reject admin users"
+ assert (
+ response.status_code == 403
+ ), f"Endpoint {endpoint} should reject admin users"
def test_vendor_api_accepts_only_vendor_role(
self, client, vendor_user_headers, test_vendor_user
@@ -229,8 +225,10 @@ class TestVendorAPIAuthentication:
for endpoint in endpoints:
response = client.get(endpoint, headers=vendor_user_headers)
- assert response.status_code in [200, 404], \
- f"Endpoint {endpoint} should accept vendor users (got {response.status_code})"
+ assert response.status_code in [
+ 200,
+ 404,
+ ], f"Endpoint {endpoint} should accept vendor users (got {response.status_code})"
# ========================================================================
# Token Validation Tests
@@ -248,8 +246,9 @@ class TestVendorAPIAuthentication:
for headers in malformed_headers:
response = client.get("/api/v1/vendor/auth/me", headers=headers)
- assert response.status_code == 401, \
- f"Should reject malformed header: {headers}"
+ assert (
+ response.status_code == 401
+ ), f"Should reject malformed header: {headers}"
def test_token_with_missing_claims(self, client, auth_manager):
"""Test token missing required claims"""
@@ -261,14 +260,12 @@ class TestVendorAPIAuthentication:
}
invalid_token = jwt.encode(
- invalid_payload,
- auth_manager.secret_key,
- algorithm=auth_manager.algorithm
+ invalid_payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
)
response = client.get(
"/api/v1/vendor/auth/me",
- headers={"Authorization": f"Bearer {invalid_token}"}
+ headers={"Authorization": f"Bearer {invalid_token}"},
)
assert response.status_code == 401
@@ -298,9 +295,7 @@ class TestVendorAPIAuthentication:
db.add(test_vendor_user)
db.commit()
- def test_concurrent_requests_with_same_token(
- self, client, vendor_user_headers
- ):
+ def test_concurrent_requests_with_same_token(self, client, vendor_user_headers):
"""Test that the same token can be used for multiple concurrent requests"""
# Make multiple requests with the same token
responses = []
@@ -314,10 +309,7 @@ class TestVendorAPIAuthentication:
def test_vendor_api_with_empty_authorization_header(self, client):
"""Test vendor API with empty Authorization header value"""
- response = client.get(
- "/api/v1/vendor/auth/me",
- headers={"Authorization": ""}
- )
+ response = client.get("/api/v1/vendor/auth/me", headers={"Authorization": ""})
assert response.status_code == 401
@@ -328,17 +320,12 @@ class TestVendorAPIAuthentication:
class TestVendorAPIConsistency:
"""Test that all vendor API endpoints use consistent authentication"""
- def test_all_vendor_endpoints_require_header_auth(
- self, client, test_vendor_user
- ):
+ def test_all_vendor_endpoints_require_header_auth(self, client, test_vendor_user):
"""Verify all vendor API endpoints require Authorization header"""
# Login to establish session
client.post(
"/api/v1/vendor/auth/login",
- json={
- "username": test_vendor_user.username,
- "password": "vendorpass123"
- }
+ json={"username": test_vendor_user.username, "password": "vendorpass123"},
)
# All vendor API endpoints (excluding public endpoints like /info)
@@ -361,8 +348,9 @@ class TestVendorAPIConsistency:
response = client.post(endpoint, json={})
# All should reject cookie-only auth with 401
- assert response.status_code == 401, \
- f"Endpoint {endpoint} should require Authorization header (got {response.status_code})"
+ assert (
+ response.status_code == 401
+ ), f"Endpoint {endpoint} should require Authorization header (got {response.status_code})"
def test_vendor_endpoints_accept_vendor_token(
self, client, vendor_user_headers, test_vendor_with_vendor_user
@@ -380,5 +368,7 @@ class TestVendorAPIConsistency:
response = client.get(endpoint, headers=vendor_user_headers)
# Should not be authentication/authorization errors
- assert response.status_code not in [401, 403], \
- f"Endpoint {endpoint} should accept vendor token (got {response.status_code}: {response.text})"
+ assert response.status_code not in [
+ 401,
+ 403,
+ ], f"Endpoint {endpoint} should accept vendor token (got {response.status_code}: {response.text})"
diff --git a/tests/integration/api/v1/vendor/test_dashboard.py b/tests/integration/api/v1/vendor/test_dashboard.py
index f2e4d67f..5fe90cc3 100644
--- a/tests/integration/api/v1/vendor/test_dashboard.py
+++ b/tests/integration/api/v1/vendor/test_dashboard.py
@@ -23,8 +23,7 @@ class TestVendorDashboardAPI:
):
"""Test dashboard stats returns correct data structure"""
response = client.get(
- "/api/v1/vendor/dashboard/stats",
- headers=vendor_user_headers
+ "/api/v1/vendor/dashboard/stats", headers=vendor_user_headers
)
assert response.status_code == 200
@@ -66,9 +65,9 @@ class TestVendorDashboardAPI:
self, client, db, test_vendor_user, auth_manager
):
"""Test that dashboard stats only show data for the authenticated vendor"""
- from models.database.vendor import Vendor, VendorUser
- from models.database.product import Product
from models.database.marketplace_product import MarketplaceProduct
+ from models.database.product import Product
+ from models.database.vendor import Vendor, VendorUser
# Create two separate vendors with different data
vendor1 = Vendor(
@@ -118,10 +117,7 @@ class TestVendorDashboardAPI:
vendor1_headers = {"Authorization": f"Bearer {token_data['access_token']}"}
# Get stats for vendor1
- response = client.get(
- "/api/v1/vendor/dashboard/stats",
- headers=vendor1_headers
- )
+ response = client.get("/api/v1/vendor/dashboard/stats", headers=vendor1_headers)
assert response.status_code == 200
data = response.json()
@@ -130,9 +126,7 @@ class TestVendorDashboardAPI:
assert data["vendor"]["id"] == vendor1.id
assert data["products"]["total"] == 3
- def test_dashboard_stats_without_vendor_association(
- self, client, db, auth_manager
- ):
+ def test_dashboard_stats_without_vendor_association(self, client, db, auth_manager):
"""Test dashboard stats for user not associated with any vendor"""
from models.database.user import User
@@ -206,8 +200,7 @@ class TestVendorDashboardAPI:
):
"""Test dashboard stats for vendor with no data"""
response = client.get(
- "/api/v1/vendor/dashboard/stats",
- headers=vendor_user_headers
+ "/api/v1/vendor/dashboard/stats", headers=vendor_user_headers
)
assert response.status_code == 200
@@ -224,8 +217,8 @@ class TestVendorDashboardAPI:
self, client, db, vendor_user_headers, test_vendor_with_vendor_user
):
"""Test dashboard stats accuracy with actual products"""
- from models.database.product import Product
from models.database.marketplace_product import MarketplaceProduct
+ from models.database.product import Product
# Create marketplace products
mp = MarketplaceProduct(
@@ -249,8 +242,7 @@ class TestVendorDashboardAPI:
# Get stats
response = client.get(
- "/api/v1/vendor/dashboard/stats",
- headers=vendor_user_headers
+ "/api/v1/vendor/dashboard/stats", headers=vendor_user_headers
)
assert response.status_code == 200
@@ -267,8 +259,7 @@ class TestVendorDashboardAPI:
start_time = time.time()
response = client.get(
- "/api/v1/vendor/dashboard/stats",
- headers=vendor_user_headers
+ "/api/v1/vendor/dashboard/stats", headers=vendor_user_headers
)
end_time = time.time()
@@ -284,8 +275,7 @@ class TestVendorDashboardAPI:
responses = []
for _ in range(3):
response = client.get(
- "/api/v1/vendor/dashboard/stats",
- headers=vendor_user_headers
+ "/api/v1/vendor/dashboard/stats", headers=vendor_user_headers
)
responses.append(response)
diff --git a/tests/integration/middleware/conftest.py b/tests/integration/middleware/conftest.py
index a5e253b4..663dd723 100644
--- a/tests/integration/middleware/conftest.py
+++ b/tests/integration/middleware/conftest.py
@@ -3,6 +3,7 @@
Fixtures specific to middleware integration tests.
"""
import pytest
+
from models.database.vendor import Vendor
from models.database.vendor_domain import VendorDomain
from models.database.vendor_theme import VendorTheme
@@ -12,10 +13,7 @@ from models.database.vendor_theme import VendorTheme
def vendor_with_subdomain(db):
"""Create a vendor with subdomain for testing."""
vendor = Vendor(
- name="Test Vendor",
- code="testvendor",
- subdomain="testvendor",
- is_active=True
+ name="Test Vendor", code="testvendor", subdomain="testvendor", is_active=True
)
db.add(vendor)
db.commit()
@@ -30,7 +28,7 @@ def vendor_with_custom_domain(db):
name="Custom Domain Vendor",
code="customvendor",
subdomain="customvendor",
- is_active=True
+ is_active=True,
)
db.add(vendor)
db.commit()
@@ -38,10 +36,7 @@ def vendor_with_custom_domain(db):
# Add custom domain
domain = VendorDomain(
- vendor_id=vendor.id,
- domain="customdomain.com",
- is_active=True,
- is_primary=True
+ vendor_id=vendor.id, domain="customdomain.com", is_active=True, is_primary=True
)
db.add(domain)
db.commit()
@@ -56,7 +51,7 @@ def vendor_with_theme(db):
name="Themed Vendor",
code="themedvendor",
subdomain="themedvendor",
- is_active=True
+ is_active=True,
)
db.add(vendor)
db.commit()
@@ -69,7 +64,7 @@ def vendor_with_theme(db):
secondary_color="#33FF57",
logo_url="/static/vendors/themedvendor/logo.png",
favicon_url="/static/vendors/themedvendor/favicon.ico",
- custom_css="body { background: #FF5733; }"
+ custom_css="body { background: #FF5733; }",
)
db.add(theme)
db.commit()
@@ -81,10 +76,7 @@ def vendor_with_theme(db):
def inactive_vendor(db):
"""Create an inactive vendor for testing."""
vendor = Vendor(
- name="Inactive Vendor",
- code="inactive",
- subdomain="inactive",
- is_active=False
+ name="Inactive Vendor", code="inactive", subdomain="inactive", is_active=False
)
db.add(vendor)
db.commit()
diff --git a/tests/integration/middleware/test_context_detection_flow.py b/tests/integration/middleware/test_context_detection_flow.py
index 301fbd29..8c06f3f7 100644
--- a/tests/integration/middleware/test_context_detection_flow.py
+++ b/tests/integration/middleware/test_context_detection_flow.py
@@ -5,8 +5,10 @@ Integration tests for request context detection end-to-end flow.
These tests verify that context type (API, ADMIN, VENDOR_DASHBOARD, SHOP, FALLBACK)
is correctly detected through real HTTP requests.
"""
-import pytest
from unittest.mock import patch
+
+import pytest
+
from middleware.context import RequestContext
@@ -23,13 +25,22 @@ class TestContextDetectionFlow:
def test_api_path_detected_as_api_context(self, client):
"""Test that /api/* paths are detected as API context."""
from fastapi import Request
+
from main import app
@app.get("/api/test-api-context")
async def test_api(request: Request):
return {
- "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
- "context_enum": request.state.context_type.name if hasattr(request.state, 'context_type') else None
+ "context_type": (
+ request.state.context_type.value
+ if hasattr(request.state, "context_type")
+ else None
+ ),
+ "context_enum": (
+ request.state.context_type.name
+ if hasattr(request.state, "context_type")
+ else None
+ ),
}
response = client.get("/api/test-api-context")
@@ -42,12 +53,17 @@ class TestContextDetectionFlow:
def test_nested_api_path_detected_as_api_context(self, client):
"""Test that nested /api/ paths are detected as API context."""
from fastapi import Request
+
from main import app
@app.get("/api/v1/vendor/products")
async def test_nested_api(request: Request):
return {
- "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None
+ "context_type": (
+ request.state.context_type.value
+ if hasattr(request.state, "context_type")
+ else None
+ )
}
response = client.get("/api/v1/vendor/products")
@@ -63,13 +79,22 @@ class TestContextDetectionFlow:
def test_admin_path_detected_as_admin_context(self, client):
"""Test that /admin/* paths are detected as ADMIN context."""
from fastapi import Request
+
from main import app
@app.get("/admin/test-admin-context")
async def test_admin(request: Request):
return {
- "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
- "context_enum": request.state.context_type.name if hasattr(request.state, 'context_type') else None
+ "context_type": (
+ request.state.context_type.value
+ if hasattr(request.state, "context_type")
+ else None
+ ),
+ "context_enum": (
+ request.state.context_type.name
+ if hasattr(request.state, "context_type")
+ else None
+ ),
}
response = client.get("/admin/test-admin-context")
@@ -82,19 +107,23 @@ class TestContextDetectionFlow:
def test_admin_subdomain_detected_as_admin_context(self, client):
"""Test that admin.* subdomain is detected as ADMIN context."""
from fastapi import Request
+
from main import app
@app.get("/test-admin-subdomain-context")
async def test_admin_subdomain(request: Request):
return {
- "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None
+ "context_type": (
+ request.state.context_type.value
+ if hasattr(request.state, "context_type")
+ else None
+ )
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
- "/test-admin-subdomain-context",
- headers={"host": "admin.platform.com"}
+ "/test-admin-subdomain-context", headers={"host": "admin.platform.com"}
)
assert response.status_code == 200
@@ -104,12 +133,17 @@ class TestContextDetectionFlow:
def test_nested_admin_path_detected_as_admin_context(self, client):
"""Test that nested /admin/ paths are detected as ADMIN context."""
from fastapi import Request
+
from main import app
@app.get("/admin/vendors/123/edit")
async def test_nested_admin(request: Request):
return {
- "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None
+ "context_type": (
+ request.state.context_type.value
+ if hasattr(request.state, "context_type")
+ else None
+ )
}
response = client.get("/admin/vendors/123/edit")
@@ -125,21 +159,31 @@ class TestContextDetectionFlow:
def test_vendor_dashboard_path_detected(self, client, vendor_with_subdomain):
"""Test that /vendor/* paths are detected as VENDOR_DASHBOARD context."""
from fastapi import Request
+
from main import app
@app.get("/vendor/test-vendor-dashboard")
async def test_vendor_dashboard(request: Request):
return {
- "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
- "context_enum": request.state.context_type.name if hasattr(request.state, 'context_type') else None,
- "has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None
+ "context_type": (
+ request.state.context_type.value
+ if hasattr(request.state, "context_type")
+ else None
+ ),
+ "context_enum": (
+ request.state.context_type.name
+ if hasattr(request.state, "context_type")
+ else None
+ ),
+ "has_vendor": hasattr(request.state, "vendor")
+ and request.state.vendor is not None,
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/vendor/test-vendor-dashboard",
- headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
+ headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -151,19 +195,24 @@ class TestContextDetectionFlow:
def test_nested_vendor_dashboard_path_detected(self, client, vendor_with_subdomain):
"""Test that nested /vendor/ paths are detected as VENDOR_DASHBOARD context."""
from fastapi import Request
+
from main import app
@app.get("/vendor/products/123/edit")
async def test_nested_vendor(request: Request):
return {
- "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None
+ "context_type": (
+ request.state.context_type.value
+ if hasattr(request.state, "context_type")
+ else None
+ )
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/vendor/products/123/edit",
- headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
+ headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -174,24 +223,36 @@ class TestContextDetectionFlow:
# Shop Context Detection Tests
# ========================================================================
- def test_shop_path_with_vendor_detected_as_shop(self, client, vendor_with_subdomain):
+ def test_shop_path_with_vendor_detected_as_shop(
+ self, client, vendor_with_subdomain
+ ):
"""Test that /shop/* paths with vendor are detected as SHOP context."""
from fastapi import Request
+
from main import app
@app.get("/shop/test-shop-context")
async def test_shop(request: Request):
return {
- "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
- "context_enum": request.state.context_type.name if hasattr(request.state, 'context_type') else None,
- "has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None
+ "context_type": (
+ request.state.context_type.value
+ if hasattr(request.state, "context_type")
+ else None
+ ),
+ "context_enum": (
+ request.state.context_type.name
+ if hasattr(request.state, "context_type")
+ else None
+ ),
+ "has_vendor": hasattr(request.state, "vendor")
+ and request.state.vendor is not None,
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/shop/test-shop-context",
- headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
+ headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -200,23 +261,31 @@ class TestContextDetectionFlow:
assert data["context_enum"] == "SHOP"
assert data["has_vendor"] is True
- def test_root_path_with_vendor_detected_as_shop(self, client, vendor_with_subdomain):
+ def test_root_path_with_vendor_detected_as_shop(
+ self, client, vendor_with_subdomain
+ ):
"""Test that root path with vendor is detected as SHOP context."""
from fastapi import Request
+
from main import app
@app.get("/test-root-shop")
async def test_root_shop(request: Request):
return {
- "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
- "has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None
+ "context_type": (
+ request.state.context_type.value
+ if hasattr(request.state, "context_type")
+ else None
+ ),
+ "has_vendor": hasattr(request.state, "vendor")
+ and request.state.vendor is not None,
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-root-shop",
- headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
+ headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -228,21 +297,27 @@ class TestContextDetectionFlow:
def test_custom_domain_shop_detected(self, client, vendor_with_custom_domain):
"""Test that custom domain shop is detected as SHOP context."""
from fastapi import Request
+
from main import app
@app.get("/products")
async def test_custom_domain_shop(request: Request):
return {
- "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
- "vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None
+ "context_type": (
+ request.state.context_type.value
+ if hasattr(request.state, "context_type")
+ else None
+ ),
+ "vendor_code": (
+ request.state.vendor.code
+ if hasattr(request.state, "vendor") and request.state.vendor
+ else None
+ ),
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
- response = client.get(
- "/products",
- headers={"host": "customdomain.com"}
- )
+ response = client.get("/products", headers={"host": "customdomain.com"})
assert response.status_code == 200
data = response.json()
@@ -256,21 +331,30 @@ class TestContextDetectionFlow:
def test_unknown_path_without_vendor_fallback_context(self, client):
"""Test that unknown paths without vendor get FALLBACK context."""
from fastapi import Request
+
from main import app
@app.get("/test-fallback-context")
async def test_fallback(request: Request):
return {
- "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
- "context_enum": request.state.context_type.name if hasattr(request.state, 'context_type') else None,
- "has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None
+ "context_type": (
+ request.state.context_type.value
+ if hasattr(request.state, "context_type")
+ else None
+ ),
+ "context_enum": (
+ request.state.context_type.name
+ if hasattr(request.state, "context_type")
+ else None
+ ),
+ "has_vendor": hasattr(request.state, "vendor")
+ and request.state.vendor is not None,
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
- "/test-fallback-context",
- headers={"host": "platform.com"}
+ "/test-fallback-context", headers={"host": "platform.com"}
)
assert response.status_code == 200
@@ -286,20 +370,26 @@ class TestContextDetectionFlow:
def test_api_path_overrides_vendor_context(self, client, vendor_with_subdomain):
"""Test that /api/* path sets API context even with vendor."""
from fastapi import Request
+
from main import app
@app.get("/api/test-api-priority")
async def test_api_priority(request: Request):
return {
- "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
- "has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None
+ "context_type": (
+ request.state.context_type.value
+ if hasattr(request.state, "context_type")
+ else None
+ ),
+ "has_vendor": hasattr(request.state, "vendor")
+ and request.state.vendor is not None,
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/api/test-api-priority",
- headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
+ headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -312,20 +402,26 @@ class TestContextDetectionFlow:
def test_admin_path_overrides_vendor_context(self, client, vendor_with_subdomain):
"""Test that /admin/* path sets ADMIN context even with vendor."""
from fastapi import Request
+
from main import app
@app.get("/admin/test-admin-priority")
async def test_admin_priority(request: Request):
return {
- "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
- "has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None
+ "context_type": (
+ request.state.context_type.value
+ if hasattr(request.state, "context_type")
+ else None
+ ),
+ "has_vendor": hasattr(request.state, "vendor")
+ and request.state.vendor is not None,
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/admin/test-admin-priority",
- headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
+ headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -333,22 +429,29 @@ class TestContextDetectionFlow:
# Admin path should override vendor context
assert data["context_type"] == "admin"
- def test_vendor_dashboard_overrides_shop_context(self, client, vendor_with_subdomain):
+ def test_vendor_dashboard_overrides_shop_context(
+ self, client, vendor_with_subdomain
+ ):
"""Test that /vendor/* path sets VENDOR_DASHBOARD, not SHOP."""
from fastapi import Request
+
from main import app
@app.get("/vendor/test-priority")
async def test_vendor_priority(request: Request):
return {
- "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None
+ "context_type": (
+ request.state.context_type.value
+ if hasattr(request.state, "context_type")
+ else None
+ )
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/vendor/test-priority",
- headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
+ headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -363,21 +466,30 @@ class TestContextDetectionFlow:
def test_context_uses_clean_path_for_detection(self, client, vendor_with_subdomain):
"""Test that context detection uses clean_path, not original path."""
from fastapi import Request
+
from main import app
@app.get("/vendors/{vendor_code}/shop/products")
async def test_clean_path_context(vendor_code: str, request: Request):
return {
- "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
- "clean_path": request.state.clean_path if hasattr(request.state, 'clean_path') else None,
- "original_path": request.url.path
+ "context_type": (
+ request.state.context_type.value
+ if hasattr(request.state, "context_type")
+ else None
+ ),
+ "clean_path": (
+ request.state.clean_path
+ if hasattr(request.state, "clean_path")
+ else None
+ ),
+ "original_path": request.url.path,
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
f"/vendors/{vendor_with_subdomain.code}/shop/products",
- headers={"host": "localhost:8000"}
+ headers={"host": "localhost:8000"},
)
assert response.status_code == 200
@@ -394,15 +506,20 @@ class TestContextDetectionFlow:
def test_context_type_is_enum_instance(self, client):
"""Test that context_type is a RequestContext enum instance."""
from fastapi import Request
+
from main import app
@app.get("/api/test-enum")
async def test_enum(request: Request):
- context = request.state.context_type if hasattr(request.state, 'context_type') else None
+ context = (
+ request.state.context_type
+ if hasattr(request.state, "context_type")
+ else None
+ )
return {
"is_enum": isinstance(context, RequestContext) if context else False,
"enum_name": context.name if context else None,
- "enum_value": context.value if context else None
+ "enum_value": context.value if context else None,
}
response = client.get("/api/test-enum")
@@ -417,23 +534,30 @@ class TestContextDetectionFlow:
# Edge Cases
# ========================================================================
- def test_empty_path_with_vendor_detected_as_shop(self, client, vendor_with_subdomain):
+ def test_empty_path_with_vendor_detected_as_shop(
+ self, client, vendor_with_subdomain
+ ):
"""Test that empty/root path with vendor is detected as SHOP."""
from fastapi import Request
+
from main import app
@app.get("/")
async def test_root(request: Request):
return {
- "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
- "has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None
+ "context_type": (
+ request.state.context_type.value
+ if hasattr(request.state, "context_type")
+ else None
+ ),
+ "has_vendor": hasattr(request.state, "vendor")
+ and request.state.vendor is not None,
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
- "/",
- headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
+ "/", headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
)
assert response.status_code in [200, 404] # Might not have root handler
@@ -445,13 +569,18 @@ class TestContextDetectionFlow:
def test_case_insensitive_context_detection(self, client):
"""Test that context detection is case insensitive for paths."""
from fastapi import Request
+
from main import app
@app.get("/API/test-case")
@app.get("/api/test-case")
async def test_case(request: Request):
return {
- "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None
+ "context_type": (
+ request.state.context_type.value
+ if hasattr(request.state, "context_type")
+ else None
+ )
}
# Test uppercase
diff --git a/tests/integration/middleware/test_middleware_stack.py b/tests/integration/middleware/test_middleware_stack.py
index f61efea3..4578159f 100644
--- a/tests/integration/middleware/test_middleware_stack.py
+++ b/tests/integration/middleware/test_middleware_stack.py
@@ -5,8 +5,10 @@ Integration tests for the complete middleware stack.
These tests verify that all middleware components work together correctly
through real HTTP requests, ensuring proper execution order and state injection.
"""
-import pytest
from unittest.mock import patch
+
+import pytest
+
from middleware.context import RequestContext
@@ -23,14 +25,20 @@ class TestMiddlewareStackIntegration:
"""Test that /admin/* paths set ADMIN context type."""
# Create a simple endpoint to inspect request state
from fastapi import Request
+
from main import app
@app.get("/admin/test-context")
async def test_admin_context(request: Request):
return {
- "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
- "has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None,
- "has_theme": hasattr(request.state, 'theme')
+ "context_type": (
+ request.state.context_type.value
+ if hasattr(request.state, "context_type")
+ else None
+ ),
+ "has_vendor": hasattr(request.state, "vendor")
+ and request.state.vendor is not None,
+ "has_theme": hasattr(request.state, "theme"),
}
response = client.get("/admin/test-context")
@@ -44,20 +52,24 @@ class TestMiddlewareStackIntegration:
def test_admin_subdomain_sets_admin_context(self, client):
"""Test that admin.* subdomain sets ADMIN context type."""
from fastapi import Request
+
from main import app
@app.get("/test-admin-subdomain")
async def test_admin_subdomain(request: Request):
return {
- "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None
+ "context_type": (
+ request.state.context_type.value
+ if hasattr(request.state, "context_type")
+ else None
+ )
}
# Simulate request with admin subdomain
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
- "/test-admin-subdomain",
- headers={"host": "admin.platform.com"}
+ "/test-admin-subdomain", headers={"host": "admin.platform.com"}
)
assert response.status_code == 200
@@ -71,12 +83,17 @@ class TestMiddlewareStackIntegration:
def test_api_path_sets_api_context(self, client):
"""Test that /api/* paths set API context type."""
from fastapi import Request
+
from main import app
@app.get("/api/test-context")
async def test_api_context(request: Request):
return {
- "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None
+ "context_type": (
+ request.state.context_type.value
+ if hasattr(request.state, "context_type")
+ else None
+ )
}
response = client.get("/api/test-context")
@@ -89,25 +106,40 @@ class TestMiddlewareStackIntegration:
# Vendor Dashboard Context Tests
# ========================================================================
- def test_vendor_dashboard_path_sets_vendor_context(self, client, vendor_with_subdomain):
+ def test_vendor_dashboard_path_sets_vendor_context(
+ self, client, vendor_with_subdomain
+ ):
"""Test that /vendor/* paths with vendor set VENDOR_DASHBOARD context."""
from fastapi import Request
+
from main import app
@app.get("/vendor/test-context")
async def test_vendor_context(request: Request):
return {
- "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
- "vendor_id": request.state.vendor_id if hasattr(request.state, 'vendor_id') else None,
- "vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') else None
+ "context_type": (
+ request.state.context_type.value
+ if hasattr(request.state, "context_type")
+ else None
+ ),
+ "vendor_id": (
+ request.state.vendor_id
+ if hasattr(request.state, "vendor_id")
+ else None
+ ),
+ "vendor_code": (
+ request.state.vendor.code
+ if hasattr(request.state, "vendor")
+ else None
+ ),
}
# Request with vendor subdomain
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/vendor/test-context",
- headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
+ headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -120,24 +152,35 @@ class TestMiddlewareStackIntegration:
# Shop Context Tests
# ========================================================================
- def test_shop_path_with_subdomain_sets_shop_context(self, client, vendor_with_subdomain):
+ def test_shop_path_with_subdomain_sets_shop_context(
+ self, client, vendor_with_subdomain
+ ):
"""Test that /shop/* paths with vendor subdomain set SHOP context."""
from fastapi import Request
+
from main import app
@app.get("/shop/test-context")
async def test_shop_context(request: Request):
return {
- "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
- "vendor_id": request.state.vendor_id if hasattr(request.state, 'vendor_id') else None,
- "has_theme": hasattr(request.state, 'theme')
+ "context_type": (
+ request.state.context_type.value
+ if hasattr(request.state, "context_type")
+ else None
+ ),
+ "vendor_id": (
+ request.state.vendor_id
+ if hasattr(request.state, "vendor_id")
+ else None
+ ),
+ "has_theme": hasattr(request.state, "theme"),
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/shop/test-context",
- headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
+ headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -146,23 +189,33 @@ class TestMiddlewareStackIntegration:
assert data["vendor_id"] == vendor_with_subdomain.id
assert data["has_theme"] is True
- def test_shop_path_with_custom_domain_sets_shop_context(self, client, vendor_with_custom_domain):
+ def test_shop_path_with_custom_domain_sets_shop_context(
+ self, client, vendor_with_custom_domain
+ ):
"""Test that /shop/* paths with custom domain set SHOP context."""
from fastapi import Request
+
from main import app
@app.get("/shop/test-custom-domain")
async def test_shop_custom_domain(request: Request):
return {
- "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
- "vendor_id": request.state.vendor_id if hasattr(request.state, 'vendor_id') else None
+ "context_type": (
+ request.state.context_type.value
+ if hasattr(request.state, "context_type")
+ else None
+ ),
+ "vendor_id": (
+ request.state.vendor_id
+ if hasattr(request.state, "vendor_id")
+ else None
+ ),
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
- "/shop/test-custom-domain",
- headers={"host": "customdomain.com"}
+ "/shop/test-custom-domain", headers={"host": "customdomain.com"}
)
assert response.status_code == 200
@@ -174,9 +227,12 @@ class TestMiddlewareStackIntegration:
# Middleware Execution Order Tests
# ========================================================================
- def test_vendor_context_runs_before_context_detection(self, client, vendor_with_subdomain):
+ def test_vendor_context_runs_before_context_detection(
+ self, client, vendor_with_subdomain
+ ):
"""Test that VendorContextMiddleware runs before ContextDetectionMiddleware."""
from fastapi import Request
+
from main import app
@app.get("/test-execution-order")
@@ -184,16 +240,16 @@ class TestMiddlewareStackIntegration:
# If vendor context runs first, clean_path should be available
# before context detection uses it
return {
- "has_vendor": hasattr(request.state, 'vendor'),
- "has_clean_path": hasattr(request.state, 'clean_path'),
- "has_context_type": hasattr(request.state, 'context_type')
+ "has_vendor": hasattr(request.state, "vendor"),
+ "has_clean_path": hasattr(request.state, "clean_path"),
+ "has_context_type": hasattr(request.state, "context_type"),
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-execution-order",
- headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
+ headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -206,21 +262,26 @@ class TestMiddlewareStackIntegration:
def test_theme_context_runs_after_vendor_context(self, client, vendor_with_theme):
"""Test that ThemeContextMiddleware runs after VendorContextMiddleware."""
from fastapi import Request
+
from main import app
@app.get("/test-theme-loading")
async def test_theme_loading(request: Request):
return {
- "has_vendor": hasattr(request.state, 'vendor'),
- "has_theme": hasattr(request.state, 'theme'),
- "theme_primary_color": request.state.theme.get('primary_color') if hasattr(request.state, 'theme') else None
+ "has_vendor": hasattr(request.state, "vendor"),
+ "has_theme": hasattr(request.state, "theme"),
+ "theme_primary_color": (
+ request.state.theme.get("primary_color")
+ if hasattr(request.state, "theme")
+ else None
+ ),
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-theme-loading",
- headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
+ headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -249,21 +310,27 @@ class TestMiddlewareStackIntegration:
def test_missing_vendor_graceful_handling(self, client):
"""Test that missing vendor is handled gracefully."""
from fastapi import Request
+
from main import app
@app.get("/test-missing-vendor")
async def test_missing_vendor(request: Request):
return {
- "has_vendor": hasattr(request.state, 'vendor'),
- "vendor": request.state.vendor if hasattr(request.state, 'vendor') else None,
- "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None
+ "has_vendor": hasattr(request.state, "vendor"),
+ "vendor": (
+ request.state.vendor if hasattr(request.state, "vendor") else None
+ ),
+ "context_type": (
+ request.state.context_type.value
+ if hasattr(request.state, "context_type")
+ else None
+ ),
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
- "/test-missing-vendor",
- headers={"host": "nonexistent.platform.com"}
+ "/test-missing-vendor", headers={"host": "nonexistent.platform.com"}
)
assert response.status_code == 200
@@ -276,20 +343,23 @@ class TestMiddlewareStackIntegration:
def test_inactive_vendor_not_loaded(self, client, inactive_vendor):
"""Test that inactive vendors are not loaded."""
from fastapi import Request
+
from main import app
@app.get("/test-inactive-vendor")
async def test_inactive_vendor_endpoint(request: Request):
return {
- "has_vendor": hasattr(request.state, 'vendor'),
- "vendor": request.state.vendor if hasattr(request.state, 'vendor') else None
+ "has_vendor": hasattr(request.state, "vendor"),
+ "vendor": (
+ request.state.vendor if hasattr(request.state, "vendor") else None
+ ),
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-inactive-vendor",
- headers={"host": f"{inactive_vendor.subdomain}.platform.com"}
+ headers={"host": f"{inactive_vendor.subdomain}.platform.com"},
)
assert response.status_code == 200
diff --git a/tests/integration/middleware/test_theme_loading_flow.py b/tests/integration/middleware/test_theme_loading_flow.py
index 9b19bcfc..00bb55cc 100644
--- a/tests/integration/middleware/test_theme_loading_flow.py
+++ b/tests/integration/middleware/test_theme_loading_flow.py
@@ -5,9 +5,10 @@ Integration tests for theme loading end-to-end flow.
These tests verify that vendor themes are correctly loaded and injected
into request.state through real HTTP requests.
"""
-import pytest
from unittest.mock import patch
+import pytest
+
@pytest.mark.integration
@pytest.mark.middleware
@@ -22,21 +23,22 @@ class TestThemeLoadingFlow:
def test_theme_loaded_for_vendor_with_custom_theme(self, client, vendor_with_theme):
"""Test that custom theme is loaded for vendor with theme."""
from fastapi import Request
+
from main import app
@app.get("/test-theme-loading")
async def test_theme(request: Request):
- theme = request.state.theme if hasattr(request.state, 'theme') else None
+ theme = request.state.theme if hasattr(request.state, "theme") else None
return {
"has_theme": theme is not None,
- "theme_data": theme if theme else None
+ "theme_data": theme if theme else None,
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-theme-loading",
- headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
+ headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -46,24 +48,27 @@ class TestThemeLoadingFlow:
assert data["theme_data"]["primary_color"] == "#FF5733"
assert data["theme_data"]["secondary_color"] == "#33FF57"
- def test_default_theme_loaded_for_vendor_without_theme(self, client, vendor_with_subdomain):
+ def test_default_theme_loaded_for_vendor_without_theme(
+ self, client, vendor_with_subdomain
+ ):
"""Test that default theme is loaded for vendor without custom theme."""
from fastapi import Request
+
from main import app
@app.get("/test-default-theme")
async def test_default_theme(request: Request):
- theme = request.state.theme if hasattr(request.state, 'theme') else None
+ theme = request.state.theme if hasattr(request.state, "theme") else None
return {
"has_theme": theme is not None,
- "theme_data": theme if theme else None
+ "theme_data": theme if theme else None,
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-default-theme",
- headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
+ headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -76,21 +81,20 @@ class TestThemeLoadingFlow:
def test_no_theme_loaded_without_vendor(self, client):
"""Test that no theme is loaded when there's no vendor."""
from fastapi import Request
+
from main import app
@app.get("/test-no-theme")
async def test_no_theme(request: Request):
return {
- "has_theme": hasattr(request.state, 'theme'),
- "has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None
+ "has_theme": hasattr(request.state, "theme"),
+ "has_vendor": hasattr(request.state, "vendor")
+ and request.state.vendor is not None,
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
- response = client.get(
- "/test-no-theme",
- headers={"host": "platform.com"}
- )
+ response = client.get("/test-no-theme", headers={"host": "platform.com"})
assert response.status_code == 200
data = response.json()
@@ -105,24 +109,25 @@ class TestThemeLoadingFlow:
def test_custom_theme_contains_all_fields(self, client, vendor_with_theme):
"""Test that custom theme contains all expected fields."""
from fastapi import Request
+
from main import app
@app.get("/test-theme-fields")
async def test_theme_fields(request: Request):
- theme = request.state.theme if hasattr(request.state, 'theme') else {}
+ theme = request.state.theme if hasattr(request.state, "theme") else {}
return {
"primary_color": theme.get("primary_color"),
"secondary_color": theme.get("secondary_color"),
"logo_url": theme.get("logo_url"),
"favicon_url": theme.get("favicon_url"),
- "custom_css": theme.get("custom_css")
+ "custom_css": theme.get("custom_css"),
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-theme-fields",
- headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
+ headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -136,24 +141,25 @@ class TestThemeLoadingFlow:
def test_default_theme_structure(self, client, vendor_with_subdomain):
"""Test that default theme has expected structure."""
from fastapi import Request
+
from main import app
@app.get("/test-default-theme-structure")
async def test_default_structure(request: Request):
- theme = request.state.theme if hasattr(request.state, 'theme') else {}
+ theme = request.state.theme if hasattr(request.state, "theme") else {}
return {
"has_primary_color": "primary_color" in theme,
"has_secondary_color": "secondary_color" in theme,
"has_logo_url": "logo_url" in theme,
"has_favicon_url": "favicon_url" in theme,
- "has_custom_css": "custom_css" in theme
+ "has_custom_css": "custom_css" in theme,
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-default-theme-structure",
- headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
+ headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -169,21 +175,30 @@ class TestThemeLoadingFlow:
def test_theme_loaded_in_shop_context(self, client, vendor_with_theme):
"""Test that theme is loaded in SHOP context."""
from fastapi import Request
+
from main import app
@app.get("/shop/test-shop-theme")
async def test_shop_theme(request: Request):
return {
- "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
- "has_theme": hasattr(request.state, 'theme'),
- "theme_primary": request.state.theme.get("primary_color") if hasattr(request.state, 'theme') else None
+ "context_type": (
+ request.state.context_type.value
+ if hasattr(request.state, "context_type")
+ else None
+ ),
+ "has_theme": hasattr(request.state, "theme"),
+ "theme_primary": (
+ request.state.theme.get("primary_color")
+ if hasattr(request.state, "theme")
+ else None
+ ),
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/shop/test-shop-theme",
- headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
+ headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -195,21 +210,30 @@ class TestThemeLoadingFlow:
def test_theme_loaded_in_vendor_dashboard_context(self, client, vendor_with_theme):
"""Test that theme is loaded in VENDOR_DASHBOARD context."""
from fastapi import Request
+
from main import app
@app.get("/vendor/test-dashboard-theme")
async def test_dashboard_theme(request: Request):
return {
- "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
- "has_theme": hasattr(request.state, 'theme'),
- "theme_secondary": request.state.theme.get("secondary_color") if hasattr(request.state, 'theme') else None
+ "context_type": (
+ request.state.context_type.value
+ if hasattr(request.state, "context_type")
+ else None
+ ),
+ "has_theme": hasattr(request.state, "theme"),
+ "theme_secondary": (
+ request.state.theme.get("secondary_color")
+ if hasattr(request.state, "theme")
+ else None
+ ),
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/vendor/test-dashboard-theme",
- headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
+ headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -221,21 +245,27 @@ class TestThemeLoadingFlow:
def test_theme_loaded_in_api_context_with_vendor(self, client, vendor_with_theme):
"""Test that theme is loaded in API context when vendor is present."""
from fastapi import Request
+
from main import app
@app.get("/api/test-api-theme")
async def test_api_theme(request: Request):
return {
- "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
- "has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None,
- "has_theme": hasattr(request.state, 'theme')
+ "context_type": (
+ request.state.context_type.value
+ if hasattr(request.state, "context_type")
+ else None
+ ),
+ "has_vendor": hasattr(request.state, "vendor")
+ and request.state.vendor is not None,
+ "has_theme": hasattr(request.state, "theme"),
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/api/test-api-theme",
- headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
+ headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -248,13 +278,18 @@ class TestThemeLoadingFlow:
def test_no_theme_in_admin_context(self, client):
"""Test that theme is not loaded in ADMIN context (no vendor)."""
from fastapi import Request
+
from main import app
@app.get("/admin/test-admin-no-theme")
async def test_admin_no_theme(request: Request):
return {
- "context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
- "has_theme": hasattr(request.state, 'theme')
+ "context_type": (
+ request.state.context_type.value
+ if hasattr(request.state, "context_type")
+ else None
+ ),
+ "has_theme": hasattr(request.state, "theme"),
}
response = client.get("/admin/test-admin-no-theme")
@@ -272,20 +307,29 @@ class TestThemeLoadingFlow:
def test_theme_loaded_with_subdomain_routing(self, client, vendor_with_theme):
"""Test theme loading with subdomain routing."""
from fastapi import Request
+
from main import app
@app.get("/test-subdomain-theme")
async def test_subdomain_theme(request: Request):
return {
- "vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None,
- "theme_logo": request.state.theme.get("logo_url") if hasattr(request.state, 'theme') else None
+ "vendor_code": (
+ request.state.vendor.code
+ if hasattr(request.state, "vendor") and request.state.vendor
+ else None
+ ),
+ "theme_logo": (
+ request.state.theme.get("logo_url")
+ if hasattr(request.state, "theme")
+ else None
+ ),
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-subdomain-theme",
- headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
+ headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -293,33 +337,40 @@ class TestThemeLoadingFlow:
assert data["vendor_code"] == vendor_with_theme.code
assert data["theme_logo"] == "/static/vendors/themedvendor/logo.png"
- def test_theme_loaded_with_custom_domain_routing(self, client, vendor_with_custom_domain, db):
+ def test_theme_loaded_with_custom_domain_routing(
+ self, client, vendor_with_custom_domain, db
+ ):
"""Test theme loading with custom domain routing."""
# Add theme to custom domain vendor
from models.database.vendor_theme import VendorTheme
+
theme = VendorTheme(
vendor_id=vendor_with_custom_domain.id,
primary_color="#123456",
- secondary_color="#654321"
+ secondary_color="#654321",
)
db.add(theme)
db.commit()
from fastapi import Request
+
from main import app
@app.get("/test-custom-domain-theme")
async def test_custom_domain_theme(request: Request):
return {
- "has_theme": hasattr(request.state, 'theme'),
- "theme_primary": request.state.theme.get("primary_color") if hasattr(request.state, 'theme') else None
+ "has_theme": hasattr(request.state, "theme"),
+ "theme_primary": (
+ request.state.theme.get("primary_color")
+ if hasattr(request.state, "theme")
+ else None
+ ),
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
- "/test-custom-domain-theme",
- headers={"host": "customdomain.com"}
+ "/test-custom-domain-theme", headers={"host": "customdomain.com"}
)
assert response.status_code == 200
@@ -331,28 +382,37 @@ class TestThemeLoadingFlow:
# Theme Dependency on Vendor Context Tests
# ========================================================================
- def test_theme_middleware_depends_on_vendor_middleware(self, client, vendor_with_theme):
+ def test_theme_middleware_depends_on_vendor_middleware(
+ self, client, vendor_with_theme
+ ):
"""Test that theme loading depends on vendor being detected first."""
from fastapi import Request
+
from main import app
@app.get("/test-theme-vendor-dependency")
async def test_dependency(request: Request):
return {
- "has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None,
- "vendor_id": request.state.vendor_id if hasattr(request.state, 'vendor_id') else None,
- "has_theme": hasattr(request.state, 'theme'),
+ "has_vendor": hasattr(request.state, "vendor")
+ and request.state.vendor is not None,
+ "vendor_id": (
+ request.state.vendor_id
+ if hasattr(request.state, "vendor_id")
+ else None
+ ),
+ "has_theme": hasattr(request.state, "theme"),
"vendor_matches_theme": (
request.state.vendor_id == vendor_with_theme.id
- if hasattr(request.state, 'vendor_id') else False
- )
+ if hasattr(request.state, "vendor_id")
+ else False
+ ),
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-theme-vendor-dependency",
- headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
+ headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -369,15 +429,20 @@ class TestThemeLoadingFlow:
def test_theme_loaded_consistently_across_requests(self, client, vendor_with_theme):
"""Test that theme is loaded consistently across multiple requests."""
from fastapi import Request
+
from main import app
@app.get("/test-theme-consistency")
async def test_consistency(request: Request):
return {
- "theme_primary": request.state.theme.get("primary_color") if hasattr(request.state, 'theme') else None
+ "theme_primary": (
+ request.state.theme.get("primary_color")
+ if hasattr(request.state, "theme")
+ else None
+ )
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
# Make multiple requests
@@ -385,7 +450,7 @@ class TestThemeLoadingFlow:
for _ in range(3):
response = client.get(
"/test-theme-consistency",
- headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
+ headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
)
responses.append(response.json())
@@ -396,25 +461,28 @@ class TestThemeLoadingFlow:
# Edge Cases and Error Handling Tests
# ========================================================================
- def test_theme_gracefully_handles_missing_theme_fields(self, client, vendor_with_subdomain):
+ def test_theme_gracefully_handles_missing_theme_fields(
+ self, client, vendor_with_subdomain
+ ):
"""Test that missing theme fields are handled gracefully."""
from fastapi import Request
+
from main import app
@app.get("/test-partial-theme")
async def test_partial_theme(request: Request):
- theme = request.state.theme if hasattr(request.state, 'theme') else {}
+ theme = request.state.theme if hasattr(request.state, "theme") else {}
return {
"has_theme": bool(theme),
"primary_color": theme.get("primary_color", "default"),
- "logo_url": theme.get("logo_url", "default")
+ "logo_url": theme.get("logo_url", "default"),
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-partial-theme",
- headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
+ headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -427,23 +495,21 @@ class TestThemeLoadingFlow:
def test_theme_dict_is_mutable(self, client, vendor_with_theme):
"""Test that theme dict can be accessed and read from."""
from fastapi import Request
+
from main import app
@app.get("/test-theme-mutable")
async def test_mutable(request: Request):
- theme = request.state.theme if hasattr(request.state, 'theme') else {}
+ theme = request.state.theme if hasattr(request.state, "theme") else {}
# Try to access theme values
primary = theme.get("primary_color")
- return {
- "can_read": primary is not None,
- "value": primary
- }
+ return {"can_read": primary is not None, "value": primary}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-theme-mutable",
- headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
+ headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
)
assert response.status_code == 200
diff --git a/tests/integration/middleware/test_vendor_context_flow.py b/tests/integration/middleware/test_vendor_context_flow.py
index 25a5bf94..657ab068 100644
--- a/tests/integration/middleware/test_vendor_context_flow.py
+++ b/tests/integration/middleware/test_vendor_context_flow.py
@@ -5,9 +5,10 @@ Integration tests for vendor context detection end-to-end flow.
These tests verify that vendor detection works correctly through real HTTP requests
for all routing modes: subdomain, custom domain, and path-based.
"""
-import pytest
from unittest.mock import patch
+import pytest
+
@pytest.mark.integration
@pytest.mark.middleware
@@ -22,23 +23,37 @@ class TestVendorContextFlow:
def test_subdomain_vendor_detection(self, client, vendor_with_subdomain):
"""Test vendor detection via subdomain routing."""
from fastapi import Request
+
from main import app
@app.get("/test-subdomain-detection")
async def test_subdomain(request: Request):
return {
- "vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None,
- "vendor_id": request.state.vendor_id if hasattr(request.state, 'vendor_id') else None,
- "vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None,
- "vendor_name": request.state.vendor.name if hasattr(request.state, 'vendor') and request.state.vendor else None,
- "detection_method": "subdomain"
+ "vendor_detected": hasattr(request.state, "vendor")
+ and request.state.vendor is not None,
+ "vendor_id": (
+ request.state.vendor_id
+ if hasattr(request.state, "vendor_id")
+ else None
+ ),
+ "vendor_code": (
+ request.state.vendor.code
+ if hasattr(request.state, "vendor") and request.state.vendor
+ else None
+ ),
+ "vendor_name": (
+ request.state.vendor.name
+ if hasattr(request.state, "vendor") and request.state.vendor
+ else None
+ ),
+ "detection_method": "subdomain",
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-subdomain-detection",
- headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
+ headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -51,20 +66,28 @@ class TestVendorContextFlow:
def test_subdomain_with_port_detection(self, client, vendor_with_subdomain):
"""Test vendor detection via subdomain with port number."""
from fastapi import Request
+
from main import app
@app.get("/test-subdomain-port")
async def test_subdomain_port(request: Request):
return {
- "vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None,
- "vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None
+ "vendor_detected": hasattr(request.state, "vendor")
+ and request.state.vendor is not None,
+ "vendor_code": (
+ request.state.vendor.code
+ if hasattr(request.state, "vendor") and request.state.vendor
+ else None
+ ),
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-subdomain-port",
- headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com:8000"}
+ headers={
+ "host": f"{vendor_with_subdomain.subdomain}.platform.com:8000"
+ },
)
assert response.status_code == 200
@@ -75,20 +98,24 @@ class TestVendorContextFlow:
def test_nonexistent_subdomain_returns_no_vendor(self, client):
"""Test that nonexistent subdomain doesn't crash and returns no vendor."""
from fastapi import Request
+
from main import app
@app.get("/test-nonexistent-subdomain")
async def test_nonexistent(request: Request):
return {
- "vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None,
- "vendor": request.state.vendor if hasattr(request.state, 'vendor') else None
+ "vendor_detected": hasattr(request.state, "vendor")
+ and request.state.vendor is not None,
+ "vendor": (
+ request.state.vendor if hasattr(request.state, "vendor") else None
+ ),
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-nonexistent-subdomain",
- headers={"host": "nonexistent.platform.com"}
+ headers={"host": "nonexistent.platform.com"},
)
assert response.status_code == 200
@@ -102,22 +129,31 @@ class TestVendorContextFlow:
def test_custom_domain_vendor_detection(self, client, vendor_with_custom_domain):
"""Test vendor detection via custom domain."""
from fastapi import Request
+
from main import app
@app.get("/test-custom-domain")
async def test_custom_domain(request: Request):
return {
- "vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None,
- "vendor_id": request.state.vendor_id if hasattr(request.state, 'vendor_id') else None,
- "vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None,
- "detection_method": "custom_domain"
+ "vendor_detected": hasattr(request.state, "vendor")
+ and request.state.vendor is not None,
+ "vendor_id": (
+ request.state.vendor_id
+ if hasattr(request.state, "vendor_id")
+ else None
+ ),
+ "vendor_code": (
+ request.state.vendor.code
+ if hasattr(request.state, "vendor") and request.state.vendor
+ else None
+ ),
+ "detection_method": "custom_domain",
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
- "/test-custom-domain",
- headers={"host": "customdomain.com"}
+ "/test-custom-domain", headers={"host": "customdomain.com"}
)
assert response.status_code == 200
@@ -129,21 +165,26 @@ class TestVendorContextFlow:
def test_custom_domain_with_www_detection(self, client, vendor_with_custom_domain):
"""Test vendor detection via custom domain with www prefix."""
from fastapi import Request
+
from main import app
@app.get("/test-custom-domain-www")
async def test_custom_domain_www(request: Request):
return {
- "vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None,
- "vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None
+ "vendor_detected": hasattr(request.state, "vendor")
+ and request.state.vendor is not None,
+ "vendor_code": (
+ request.state.vendor.code
+ if hasattr(request.state, "vendor") and request.state.vendor
+ else None
+ ),
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
# Test with www prefix - should still detect vendor
response = client.get(
- "/test-custom-domain-www",
- headers={"host": "www.customdomain.com"}
+ "/test-custom-domain-www", headers={"host": "www.customdomain.com"}
)
# This might fail if your implementation doesn't strip www
@@ -154,25 +195,37 @@ class TestVendorContextFlow:
# Path-Based Detection Tests (Development Mode)
# ========================================================================
- def test_path_based_vendor_detection_vendors_prefix(self, client, vendor_with_subdomain):
+ def test_path_based_vendor_detection_vendors_prefix(
+ self, client, vendor_with_subdomain
+ ):
"""Test vendor detection via path-based routing with /vendors/ prefix."""
from fastapi import Request
+
from main import app
@app.get("/vendors/{vendor_code}/test-path")
async def test_path_based(vendor_code: str, request: Request):
return {
- "vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None,
+ "vendor_detected": hasattr(request.state, "vendor")
+ and request.state.vendor is not None,
"vendor_code_param": vendor_code,
- "vendor_code_state": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None,
- "clean_path": request.state.clean_path if hasattr(request.state, 'clean_path') else None
+ "vendor_code_state": (
+ request.state.vendor.code
+ if hasattr(request.state, "vendor") and request.state.vendor
+ else None
+ ),
+ "clean_path": (
+ request.state.clean_path
+ if hasattr(request.state, "clean_path")
+ else None
+ ),
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
f"/vendors/{vendor_with_subdomain.code}/test-path",
- headers={"host": "localhost:8000"}
+ headers={"host": "localhost:8000"},
)
assert response.status_code == 200
@@ -181,23 +234,31 @@ class TestVendorContextFlow:
assert data["vendor_code_param"] == vendor_with_subdomain.code
assert data["vendor_code_state"] == vendor_with_subdomain.code
- def test_path_based_vendor_detection_vendor_prefix(self, client, vendor_with_subdomain):
+ def test_path_based_vendor_detection_vendor_prefix(
+ self, client, vendor_with_subdomain
+ ):
"""Test vendor detection via path-based routing with /vendor/ prefix."""
from fastapi import Request
+
from main import app
@app.get("/vendor/{vendor_code}/test")
async def test_vendor_path(vendor_code: str, request: Request):
return {
- "vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None,
- "vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None
+ "vendor_detected": hasattr(request.state, "vendor")
+ and request.state.vendor is not None,
+ "vendor_code": (
+ request.state.vendor.code
+ if hasattr(request.state, "vendor") and request.state.vendor
+ else None
+ ),
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
f"/vendor/{vendor_with_subdomain.code}/test",
- headers={"host": "localhost:8000"}
+ headers={"host": "localhost:8000"},
)
assert response.status_code == 200
@@ -209,48 +270,63 @@ class TestVendorContextFlow:
# Clean Path Extraction Tests
# ========================================================================
- def test_clean_path_extracted_from_vendor_prefix(self, client, vendor_with_subdomain):
+ def test_clean_path_extracted_from_vendor_prefix(
+ self, client, vendor_with_subdomain
+ ):
"""Test that clean_path is correctly extracted from path-based routing."""
from fastapi import Request
+
from main import app
@app.get("/vendors/{vendor_code}/shop/products")
async def test_clean_path(vendor_code: str, request: Request):
return {
- "clean_path": request.state.clean_path if hasattr(request.state, 'clean_path') else None,
- "original_path": request.url.path
+ "clean_path": (
+ request.state.clean_path
+ if hasattr(request.state, "clean_path")
+ else None
+ ),
+ "original_path": request.url.path,
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
f"/vendors/{vendor_with_subdomain.code}/shop/products",
- headers={"host": "localhost:8000"}
+ headers={"host": "localhost:8000"},
)
assert response.status_code == 200
data = response.json()
# Clean path should have vendor prefix removed
assert data["clean_path"] == "/shop/products"
- assert f"/vendors/{vendor_with_subdomain.code}/shop/products" in data["original_path"]
+ assert (
+ f"/vendors/{vendor_with_subdomain.code}/shop/products"
+ in data["original_path"]
+ )
def test_clean_path_unchanged_for_subdomain(self, client, vendor_with_subdomain):
"""Test that clean_path equals original path for subdomain routing."""
from fastapi import Request
+
from main import app
@app.get("/shop/test-clean-path")
async def test_subdomain_clean_path(request: Request):
return {
- "clean_path": request.state.clean_path if hasattr(request.state, 'clean_path') else None,
- "original_path": request.url.path
+ "clean_path": (
+ request.state.clean_path
+ if hasattr(request.state, "clean_path")
+ else None
+ ),
+ "original_path": request.url.path,
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/shop/test-clean-path",
- headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
+ headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -266,21 +342,30 @@ class TestVendorContextFlow:
def test_vendor_id_injected_into_request_state(self, client, vendor_with_subdomain):
"""Test that vendor_id is correctly injected into request.state."""
from fastapi import Request
+
from main import app
@app.get("/test-vendor-id-injection")
async def test_vendor_id(request: Request):
return {
- "has_vendor_id": hasattr(request.state, 'vendor_id'),
- "vendor_id": request.state.vendor_id if hasattr(request.state, 'vendor_id') else None,
- "vendor_id_type": type(request.state.vendor_id).__name__ if hasattr(request.state, 'vendor_id') else None
+ "has_vendor_id": hasattr(request.state, "vendor_id"),
+ "vendor_id": (
+ request.state.vendor_id
+ if hasattr(request.state, "vendor_id")
+ else None
+ ),
+ "vendor_id_type": (
+ type(request.state.vendor_id).__name__
+ if hasattr(request.state, "vendor_id")
+ else None
+ ),
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-vendor-id-injection",
- headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
+ headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -289,30 +374,41 @@ class TestVendorContextFlow:
assert data["vendor_id"] == vendor_with_subdomain.id
assert data["vendor_id_type"] == "int"
- def test_vendor_object_injected_into_request_state(self, client, vendor_with_subdomain):
+ def test_vendor_object_injected_into_request_state(
+ self, client, vendor_with_subdomain
+ ):
"""Test that full vendor object is injected into request.state."""
from fastapi import Request
+
from main import app
@app.get("/test-vendor-object-injection")
async def test_vendor_object(request: Request):
- vendor = request.state.vendor if hasattr(request.state, 'vendor') and request.state.vendor else None
+ vendor = (
+ request.state.vendor
+ if hasattr(request.state, "vendor") and request.state.vendor
+ else None
+ )
return {
"has_vendor": vendor is not None,
- "vendor_attributes": {
- "id": vendor.id if vendor else None,
- "name": vendor.name if vendor else None,
- "code": vendor.code if vendor else None,
- "subdomain": vendor.subdomain if vendor else None,
- "is_active": vendor.is_active if vendor else None
- } if vendor else None
+ "vendor_attributes": (
+ {
+ "id": vendor.id if vendor else None,
+ "name": vendor.name if vendor else None,
+ "code": vendor.code if vendor else None,
+ "subdomain": vendor.subdomain if vendor else None,
+ "is_active": vendor.is_active if vendor else None,
+ }
+ if vendor
+ else None
+ ),
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-vendor-object-injection",
- headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
+ headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -330,19 +426,21 @@ class TestVendorContextFlow:
def test_inactive_vendor_not_detected(self, client, inactive_vendor):
"""Test that inactive vendors are not detected."""
from fastapi import Request
+
from main import app
@app.get("/test-inactive-vendor-detection")
async def test_inactive(request: Request):
return {
- "vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None
+ "vendor_detected": hasattr(request.state, "vendor")
+ and request.state.vendor is not None
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-inactive-vendor-detection",
- headers={"host": f"{inactive_vendor.subdomain}.platform.com"}
+ headers={"host": f"{inactive_vendor.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -352,19 +450,20 @@ class TestVendorContextFlow:
def test_platform_domain_without_subdomain_no_vendor(self, client):
"""Test that platform domain without subdomain doesn't detect vendor."""
from fastapi import Request
+
from main import app
@app.get("/test-platform-domain")
async def test_platform(request: Request):
return {
- "vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None
+ "vendor_detected": hasattr(request.state, "vendor")
+ and request.state.vendor is not None
}
- with patch('app.core.config.settings') as mock_settings:
+ with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
- "/test-platform-domain",
- headers={"host": "platform.com"}
+ "/test-platform-domain", headers={"host": "platform.com"}
)
assert response.status_code == 200
diff --git a/tests/integration/security/test_input_validation.py b/tests/integration/security/test_input_validation.py
index 1b44136f..17cb6433 100644
--- a/tests/integration/security/test_input_validation.py
+++ b/tests/integration/security/test_input_validation.py
@@ -11,7 +11,8 @@ class TestInputValidation:
malicious_search = "'; DROP TABLE products; --"
response = client.get(
- f"/api/v1/marketplace/product?search={malicious_search}", headers=auth_headers
+ f"/api/v1/marketplace/product?search={malicious_search}",
+ headers=auth_headers,
)
# Should not crash and should return normal response
@@ -40,17 +41,23 @@ class TestInputValidation:
def test_parameter_validation(self, client, auth_headers):
"""Test parameter validation for API endpoints"""
# Test invalid pagination parameters
- response = client.get("/api/v1/marketplace/product?limit=-1", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product?limit=-1", headers=auth_headers
+ )
assert response.status_code == 422 # Validation error
- response = client.get("/api/v1/marketplace/product?skip=-1", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product?skip=-1", headers=auth_headers
+ )
assert response.status_code == 422 # Validation error
def test_json_validation(self, client, auth_headers):
"""Test JSON validation for POST requests"""
# Test invalid JSON structure
response = client.post(
- "/api/v1/marketplace/product", headers=auth_headers, content="invalid json content"
+ "/api/v1/marketplace/product",
+ headers=auth_headers,
+ content="invalid json content",
)
assert response.status_code == 422 # JSON decode error
@@ -58,6 +65,8 @@ class TestInputValidation:
response = client.post(
"/api/v1/marketplace/product",
headers=auth_headers,
- json={"title": "Test MarketplaceProduct"}, # Missing required marketplace_product_id
+ json={
+ "title": "Test MarketplaceProduct"
+ }, # Missing required marketplace_product_id
)
assert response.status_code == 422 # Validation error
diff --git a/tests/integration/workflows/test_integration.py b/tests/integration/workflows/test_integration.py
index a10cfe01..ba50d374 100644
--- a/tests/integration/workflows/test_integration.py
+++ b/tests/integration/workflows/test_integration.py
@@ -33,12 +33,15 @@ class TestIntegrationFlows:
"quantity": 50,
}
- response = client.post("/api/v1/inventory", headers=auth_headers, json=inventory_data)
+ response = client.post(
+ "/api/v1/inventory", headers=auth_headers, json=inventory_data
+ )
assert response.status_code == 200
# 3. Get product with inventory info
response = client.get(
- f"/api/v1/marketplace/product/{product['marketplace_product_id']}", headers=auth_headers
+ f"/api/v1/marketplace/product/{product['marketplace_product_id']}",
+ headers=auth_headers,
)
assert response.status_code == 200
product_detail = response.json()
@@ -55,14 +58,15 @@ class TestIntegrationFlows:
# 5. Search for product
response = client.get(
- "/api/v1/marketplace/product?search=Updated Integration", headers=auth_headers
+ "/api/v1/marketplace/product?search=Updated Integration",
+ headers=auth_headers,
)
assert response.status_code == 200
assert response.json()["total"] == 1
def test_product_workflow(self, client, auth_headers):
"""Test vendor creation and product management workflow"""
- # 1. Create a vendor
+ # 1. Create a vendor
vendor_data = {
"vendor_code": "FLOWVENDOR",
"name": "Integration Flow Vendor",
@@ -91,7 +95,9 @@ class TestIntegrationFlows:
# This would test the vendor -product association
# 4. Get vendor details
- response = client.get(f"/api/v1/vendor/{vendor ['vendor_code']}", headers=auth_headers)
+ response = client.get(
+ f"/api/v1/vendor/{vendor ['vendor_code']}", headers=auth_headers
+ )
assert response.status_code == 200
def test_inventory_operations_workflow(self, client, auth_headers):
diff --git a/tests/performance/test_api_performance.py b/tests/performance/test_api_performance.py
index e60f7888..5d442052 100644
--- a/tests/performance/test_api_performance.py
+++ b/tests/performance/test_api_performance.py
@@ -28,7 +28,9 @@ class TestPerformance:
# Time the request
start_time = time.time()
- response = client.get("/api/v1/marketplace/product?limit=100", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product?limit=100", headers=auth_headers
+ )
end_time = time.time()
assert response.status_code == 200
@@ -54,7 +56,9 @@ class TestPerformance:
# Time search request
start_time = time.time()
- response = client.get("/api/v1/marketplace/product?search=Searchable", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product?search=Searchable", headers=auth_headers
+ )
end_time = time.time()
assert response.status_code == 200
@@ -115,7 +119,8 @@ class TestPerformance:
for offset in offsets:
start_time = time.time()
response = client.get(
- f"/api/v1/marketplace/product?skip={offset}&limit=20", headers=auth_headers
+ f"/api/v1/marketplace/product?skip={offset}&limit=20",
+ headers=auth_headers,
)
end_time = time.time()
diff --git a/tests/system/test_error_handling.py b/tests/system/test_error_handling.py
index 9247035f..1cde868b 100644
--- a/tests/system/test_error_handling.py
+++ b/tests/system/test_error_handling.py
@@ -5,9 +5,10 @@ System tests for error handling across the LetzVendor API.
Tests the complete error handling flow from FastAPI through custom exception handlers
to ensure proper HTTP status codes, error structures, and client-friendly responses.
"""
-import pytest
import json
+import pytest
+
@pytest.mark.system
class TestErrorHandling:
@@ -16,9 +17,7 @@ class TestErrorHandling:
def test_invalid_json_request(self, client, auth_headers):
"""Test handling of malformed JSON requests"""
response = client.post(
- "/api/v1/vendor",
- headers=auth_headers,
- content="{ invalid json syntax"
+ "/api/v1/vendor", headers=auth_headers, content="{ invalid json syntax"
)
assert response.status_code == 422
@@ -31,9 +30,7 @@ class TestErrorHandling:
"""Test validation errors for missing required fields"""
# Missing name
response = client.post(
- "/api/v1/vendor",
- headers=auth_headers,
- json={"vendor_code": "TESTVENDOR"}
+ "/api/v1/vendor", headers=auth_headers, json={"vendor_code": "TESTVENDOR"}
)
assert response.status_code == 422
@@ -48,10 +45,7 @@ class TestErrorHandling:
response = client.post(
"/api/v1/vendor",
headers=auth_headers,
- json={
- "vendor_code": "INVALID@VENDOR!",
- "name": "Test Vendor"
- }
+ json={"vendor_code": "INVALID@VENDOR!", "name": "Test Vendor"},
)
assert response.status_code == 422
@@ -92,7 +86,7 @@ class TestErrorHandling:
assert data["status_code"] == 401
def test_vendor_not_found(self, client, auth_headers):
- """Test accessing non-existent vendor """
+ """Test accessing non-existent vendor"""
response = client.get("/api/v1/vendor/NONEXISTENT", headers=auth_headers)
assert response.status_code == 404
@@ -104,7 +98,9 @@ class TestErrorHandling:
def test_product_not_found(self, client, auth_headers):
"""Test accessing non-existent product"""
- response = client.get("/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers)
+ response = client.get(
+ "/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers
+ )
assert response.status_code == 404
data = response.json()
@@ -117,7 +113,7 @@ class TestErrorHandling:
"""Test creating vendor with duplicate vendor code"""
vendor_data = {
"vendor_code": test_vendor.vendor_code,
- "name": "Duplicate Vendor"
+ "name": "Duplicate Vendor",
}
response = client.post("/api/v1/vendor", headers=auth_headers, json=vendor_data)
@@ -128,25 +124,34 @@ class TestErrorHandling:
assert data["status_code"] == 409
assert data["details"]["vendor_code"] == test_vendor.vendor_code.upper()
- def test_duplicate_product_creation(self, client, auth_headers, test_marketplace_product):
+ def test_duplicate_product_creation(
+ self, client, auth_headers, test_marketplace_product
+ ):
"""Test creating product with duplicate product ID"""
product_data = {
"marketplace_product_id": test_marketplace_product.marketplace_product_id,
"title": "Duplicate MarketplaceProduct",
- "gtin": "1234567890123"
+ "gtin": "1234567890123",
}
- response = client.post("/api/v1/marketplace/product", headers=auth_headers, json=product_data)
+ response = client.post(
+ "/api/v1/marketplace/product", headers=auth_headers, json=product_data
+ )
assert response.status_code == 409
data = response.json()
assert data["error_code"] == "PRODUCT_ALREADY_EXISTS"
assert data["status_code"] == 409
- assert data["details"]["marketplace_product_id"] == test_marketplace_product.marketplace_product_id
+ assert (
+ data["details"]["marketplace_product_id"]
+ == test_marketplace_product.marketplace_product_id
+ )
def test_unauthorized_vendor_access(self, client, auth_headers, inactive_vendor):
"""Test accessing vendor without proper permissions"""
- response = client.get(f"/api/v1/vendor/{inactive_vendor.vendor_code}", headers=auth_headers)
+ response = client.get(
+ f"/api/v1/vendor/{inactive_vendor.vendor_code}", headers=auth_headers
+ )
assert response.status_code == 403
data = response.json()
@@ -154,27 +159,33 @@ class TestErrorHandling:
assert data["status_code"] == 403
assert data["details"]["vendor_code"] == inactive_vendor.vendor_code
- def test_insufficient_permissions(self, client, auth_headers, admin_only_endpoint="/api/v1/admin/users"):
+ def test_insufficient_permissions(
+ self, client, auth_headers, admin_only_endpoint="/api/v1/admin/users"
+ ):
"""Test accessing admin endpoints with regular user"""
response = client.get(admin_only_endpoint, headers=auth_headers)
- assert response.status_code in [403, 404] # 403 for permission denied, 404 if endpoint doesn't exist
+ assert response.status_code in [
+ 403,
+ 404,
+ ] # 403 for permission denied, 404 if endpoint doesn't exist
if response.status_code == 403:
data = response.json()
assert data["error_code"] in ["ADMIN_REQUIRED", "INSUFFICIENT_PERMISSIONS"]
assert data["status_code"] == 403
- def test_business_logic_violation_max_vendors(self, client, auth_headers, monkeypatch):
+ def test_business_logic_violation_max_vendors(
+ self, client, auth_headers, monkeypatch
+ ):
"""Test business logic violation - creating too many vendors"""
# This test would require mocking the vendor limit check
# For now, test the error structure when creating multiple vendors
vendors_created = []
for i in range(6): # Assume limit is 5
- vendor_data = {
- "vendor_code": f"VENDOR{i:03d}",
- "name": f"Test Vendor {i}"
- }
- response = client.post("/api/v1/vendor", headers=auth_headers, json=vendor_data)
+ vendor_data = {"vendor_code": f"VENDOR{i:03d}", "name": f"Test Vendor {i}"}
+ response = client.post(
+ "/api/v1/vendor", headers=auth_headers, json=vendor_data
+ )
vendors_created.append(response)
# At least one should succeed, and if limit is enforced, later ones should fail
@@ -193,10 +204,12 @@ class TestErrorHandling:
product_data = {
"marketplace_product_id": "TESTPROD001",
"title": "Test MarketplaceProduct",
- "gtin": "invalid_gtin_format"
+ "gtin": "invalid_gtin_format",
}
- response = client.post("/api/v1/marketplace/product", headers=auth_headers, json=product_data)
+ response = client.post(
+ "/api/v1/marketplace/product", headers=auth_headers, json=product_data
+ )
assert response.status_code == 422
data = response.json()
@@ -204,13 +217,15 @@ class TestErrorHandling:
assert data["status_code"] == 422
assert data["details"]["field"] == "gtin"
- def test_inventory_insufficient_quantity(self, client, auth_headers, test_vendor, test_marketplace_product):
+ def test_inventory_insufficient_quantity(
+ self, client, auth_headers, test_vendor, test_marketplace_product
+ ):
"""Test business logic error for insufficient inventory"""
# First create some inventory
inventory_data = {
"gtin": test_marketplace_product.gtin,
"location": "WAREHOUSE_A",
- "quantity": 5
+ "quantity": 5,
}
client.post("/api/v1/inventory", headers=auth_headers, json=inventory_data)
@@ -218,9 +233,11 @@ class TestErrorHandling:
remove_data = {
"gtin": test_marketplace_product.gtin,
"location": "WAREHOUSE_A",
- "quantity": 10 # More than the 5 we added
+ "quantity": 10, # More than the 5 we added
}
- response = client.post("/api/v1/inventory/remove", headers=auth_headers, json=remove_data)
+ response = client.post(
+ "/api/v1/inventory/remove", headers=auth_headers, json=remove_data
+ )
# This should ALWAYS fail with insufficient inventory error
assert response.status_code == 400
@@ -257,7 +274,7 @@ class TestErrorHandling:
response = client.post(
"/api/v1/vendor",
headers=headers,
- content="TEST"
+ content="TEST",
)
assert response.status_code in [400, 415, 422]
@@ -268,7 +285,7 @@ class TestErrorHandling:
vendor_data = {
"vendor_code": "LARGEVENDOR",
"name": "Large Vendor",
- "description": large_description
+ "description": large_description,
}
response = client.post("/api/v1/vendor", headers=auth_headers, json=vendor_data)
@@ -310,10 +327,12 @@ class TestErrorHandling:
# Test invalid marketplace
import_data = {
"marketplace": "INVALID_MARKETPLACE",
- "vendor_code": test_vendor.vendor_code
+ "vendor_code": test_vendor.vendor_code,
}
- response = client.post("/api/v1/imports", headers=auth_headers, json=import_data)
+ response = client.post(
+ "/api/v1/imports", headers=auth_headers, json=import_data
+ )
if response.status_code == 422:
data = response.json()
@@ -329,10 +348,12 @@ class TestErrorHandling:
# Test with potentially problematic external data
import_data = {
"marketplace": "LETZSHOP",
- "external_url": "https://nonexistent-marketplace.com/api"
+ "external_url": "https://nonexistent-marketplace.com/api",
}
- response = client.post("/api/v1/imports", headers=auth_headers, json=import_data)
+ response = client.post(
+ "/api/v1/imports", headers=auth_headers, json=import_data
+ )
# If it's a real external service error, check structure
if response.status_code == 502:
@@ -356,7 +377,9 @@ class TestErrorHandling:
# All error responses should have these fields
required_fields = ["error_code", "message", "status_code"]
for field in required_fields:
- assert field in data, f"Missing {field} in error response for {endpoint}"
+ assert (
+ field in data
+ ), f"Missing {field} in error response for {endpoint}"
# Details field should be present (can be empty dict)
assert "details" in data
@@ -420,5 +443,7 @@ class TestErrorRecovery:
# Check that error was logged (if your app logs 404s as errors)
# Adjust based on your logging configuration
- error_logs = [record for record in caplog.records if record.levelno >= logging.ERROR]
+ error_logs = [
+ record for record in caplog.records if record.levelno >= logging.ERROR
+ ]
# May or may not have logs depending on whether 404s are logged as errors
diff --git a/tests/unit/middleware/test_auth.py b/tests/unit/middleware/test_auth.py
index d998b4cb..d78e5b6d 100644
--- a/tests/unit/middleware/test_auth.py
+++ b/tests/unit/middleware/test_auth.py
@@ -12,21 +12,18 @@ Tests cover:
- Error handling and edge cases
"""
-import pytest
-from unittest.mock import Mock, MagicMock, patch
from datetime import datetime, timedelta, timezone
-from jose import jwt
-from fastapi import HTTPException
+from unittest.mock import MagicMock, Mock, patch
+import pytest
+from fastapi import HTTPException
+from jose import jwt
+
+from app.exceptions import (AdminRequiredException,
+ InsufficientPermissionsException,
+ InvalidCredentialsException, InvalidTokenException,
+ TokenExpiredException, UserNotActiveException)
from middleware.auth import AuthManager
-from app.exceptions import (
- InvalidTokenException,
- TokenExpiredException,
- UserNotActiveException,
- InvalidCredentialsException,
- AdminRequiredException,
- InsufficientPermissionsException,
-)
from models.database.user import User
@@ -124,7 +121,9 @@ class TestUserAuthentication:
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
- result = auth_manager.authenticate_user(mock_db, "test@example.com", "password123")
+ result = auth_manager.authenticate_user(
+ mock_db, "test@example.com", "password123"
+ )
assert result is mock_user
@@ -192,7 +191,9 @@ class TestJWTTokenCreation:
token = token_data["access_token"]
# Decode without verification to check payload
- payload = jwt.decode(token, auth_manager.secret_key, algorithms=[auth_manager.algorithm])
+ payload = jwt.decode(
+ token, auth_manager.secret_key, algorithms=[auth_manager.algorithm]
+ )
assert payload["sub"] == "42"
assert payload["username"] == "testuser"
@@ -205,8 +206,12 @@ class TestJWTTokenCreation:
"""Test tokens are different for different users."""
auth_manager = AuthManager()
- user1 = Mock(spec=User, id=1, username="user1", email="user1@test.com", role="customer")
- user2 = Mock(spec=User, id=2, username="user2", email="user2@test.com", role="vendor")
+ user1 = Mock(
+ spec=User, id=1, username="user1", email="user1@test.com", role="customer"
+ )
+ user2 = Mock(
+ spec=User, id=2, username="user2", email="user2@test.com", role="vendor"
+ )
token1 = auth_manager.create_access_token(user1)["access_token"]
token2 = auth_manager.create_access_token(user2)["access_token"]
@@ -227,7 +232,7 @@ class TestJWTTokenCreation:
payload = jwt.decode(
token_data["access_token"],
auth_manager.secret_key,
- algorithms=[auth_manager.algorithm]
+ algorithms=[auth_manager.algorithm],
)
assert payload["role"] == "admin"
@@ -311,9 +316,11 @@ class TestJWTTokenVerification:
# Create token without 'sub' field
payload = {
"username": "testuser",
- "exp": datetime.now(timezone.utc) + timedelta(minutes=30)
+ "exp": datetime.now(timezone.utc) + timedelta(minutes=30),
}
- token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
+ token = jwt.encode(
+ payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
+ )
with pytest.raises(InvalidTokenException) as exc_info:
auth_manager.verify_token(token)
@@ -325,11 +332,10 @@ class TestJWTTokenVerification:
auth_manager = AuthManager()
# Create token without 'exp' field
- payload = {
- "sub": "1",
- "username": "testuser"
- }
- token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
+ payload = {"sub": "1", "username": "testuser"}
+ token = jwt.encode(
+ payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
+ )
with pytest.raises(InvalidTokenException) as exc_info:
auth_manager.verify_token(token)
@@ -343,7 +349,7 @@ class TestJWTTokenVerification:
payload = {
"sub": "1",
"username": "testuser",
- "exp": datetime.now(timezone.utc) + timedelta(minutes=30)
+ "exp": datetime.now(timezone.utc) + timedelta(minutes=30),
}
# Create token with different algorithm
token = jwt.encode(payload, auth_manager.secret_key, algorithm="HS512")
@@ -357,15 +363,13 @@ class TestJWTTokenVerification:
# Create a token with expiration in the past
past_time = datetime.now(timezone.utc) - timedelta(minutes=1)
- payload = {
- "sub": "1",
- "username": "testuser",
- "exp": past_time.timestamp()
- }
- token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
+ payload = {"sub": "1", "username": "testuser", "exp": past_time.timestamp()}
+ token = jwt.encode(
+ payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
+ )
# Mock jwt.decode to bypass its expiration check and test line 205
- with patch('middleware.auth.jwt.decode') as mock_decode:
+ with patch("middleware.auth.jwt.decode") as mock_decode:
mock_decode.return_value = payload
with pytest.raises(TokenExpiredException):
@@ -580,7 +584,9 @@ class TestCreateDefaultAdminUser:
# Existing admin user
existing_admin = Mock(spec=User)
- mock_db.query.return_value.filter.return_value.first.return_value = existing_admin
+ mock_db.query.return_value.filter.return_value.first.return_value = (
+ existing_admin
+ )
result = auth_manager.create_default_admin_user(mock_db)
@@ -599,19 +605,21 @@ class TestAuthManagerConfiguration:
def test_default_configuration(self):
"""Test AuthManager uses default configuration."""
- with patch.dict('os.environ', {}, clear=True):
+ with patch.dict("os.environ", {}, clear=True):
auth_manager = AuthManager()
assert auth_manager.algorithm == "HS256"
assert auth_manager.token_expire_minutes == 30
- assert auth_manager.secret_key == "your-secret-key-change-in-production-please"
+ assert (
+ auth_manager.secret_key == "your-secret-key-change-in-production-please"
+ )
def test_custom_configuration(self):
"""Test AuthManager uses environment variables."""
- with patch.dict('os.environ', {
- 'JWT_SECRET_KEY': 'custom-secret-key',
- 'JWT_EXPIRE_MINUTES': '60'
- }):
+ with patch.dict(
+ "os.environ",
+ {"JWT_SECRET_KEY": "custom-secret-key", "JWT_EXPIRE_MINUTES": "60"},
+ ):
auth_manager = AuthManager()
assert auth_manager.secret_key == "custom-secret-key"
@@ -619,9 +627,7 @@ class TestAuthManagerConfiguration:
def test_partial_custom_configuration(self):
"""Test AuthManager with partial environment configuration."""
- with patch.dict('os.environ', {
- 'JWT_EXPIRE_MINUTES': '120'
- }, clear=False):
+ with patch.dict("os.environ", {"JWT_EXPIRE_MINUTES": "120"}, clear=False):
auth_manager = AuthManager()
assert auth_manager.token_expire_minutes == 120
@@ -656,9 +662,11 @@ class TestEdgeCases:
"sub": "1",
"username": "testuser",
"iat": datetime.now(timezone.utc) + timedelta(hours=1), # Future time
- "exp": datetime.now(timezone.utc) + timedelta(hours=2)
+ "exp": datetime.now(timezone.utc) + timedelta(hours=2),
}
- token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
+ token = jwt.encode(
+ payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
+ )
# Should still verify successfully (JWT doesn't validate iat by default)
result = auth_manager.verify_token(token)
@@ -698,7 +706,9 @@ class TestEdgeCases:
token = token_data["access_token"]
# Mock jose.jwt.decode to raise an unexpected exception
- with patch('middleware.auth.jwt.decode', side_effect=RuntimeError("Unexpected error")):
+ with patch(
+ "middleware.auth.jwt.decode", side_effect=RuntimeError("Unexpected error")
+ ):
with pytest.raises(InvalidTokenException) as exc_info:
auth_manager.verify_token(token)
diff --git a/tests/unit/middleware/test_context.py b/tests/unit/middleware/test_context.py
index 95786c33..d2bf6f1b 100644
--- a/tests/unit/middleware/test_context.py
+++ b/tests/unit/middleware/test_context.py
@@ -10,16 +10,13 @@ Tests cover:
- Edge cases and error handling
"""
+from unittest.mock import AsyncMock, Mock, patch
+
import pytest
-from unittest.mock import Mock, AsyncMock, patch
from fastapi import Request
-from middleware.context import (
- ContextManager,
- ContextMiddleware,
- RequestContext,
- get_request_context,
-)
+from middleware.context import (ContextManager, ContextMiddleware,
+ RequestContext, get_request_context)
@pytest.mark.unit
@@ -321,22 +318,38 @@ class TestContextManagerHelpers:
def test_is_admin_context_from_subdomain(self):
"""Test _is_admin_context with admin subdomain."""
request = Mock()
- assert ContextManager._is_admin_context(request, "admin.platform.com", "/dashboard") is True
+ assert (
+ ContextManager._is_admin_context(
+ request, "admin.platform.com", "/dashboard"
+ )
+ is True
+ )
def test_is_admin_context_from_path(self):
"""Test _is_admin_context with admin path."""
request = Mock()
- assert ContextManager._is_admin_context(request, "localhost", "/admin/users") is True
+ assert (
+ ContextManager._is_admin_context(request, "localhost", "/admin/users")
+ is True
+ )
def test_is_admin_context_both(self):
"""Test _is_admin_context with both subdomain and path."""
request = Mock()
- assert ContextManager._is_admin_context(request, "admin.platform.com", "/admin/users") is True
+ assert (
+ ContextManager._is_admin_context(
+ request, "admin.platform.com", "/admin/users"
+ )
+ is True
+ )
def test_is_not_admin_context(self):
"""Test _is_admin_context returns False for non-admin."""
request = Mock()
- assert ContextManager._is_admin_context(request, "vendor.platform.com", "/shop") is False
+ assert (
+ ContextManager._is_admin_context(request, "vendor.platform.com", "/shop")
+ is False
+ )
def test_is_vendor_dashboard_context(self):
"""Test _is_vendor_dashboard_context with /vendor/ path."""
@@ -344,11 +357,16 @@ class TestContextManagerHelpers:
def test_is_vendor_dashboard_context_nested(self):
"""Test _is_vendor_dashboard_context with nested vendor path."""
- assert ContextManager._is_vendor_dashboard_context("/vendor/products/list") is True
+ assert (
+ ContextManager._is_vendor_dashboard_context("/vendor/products/list") is True
+ )
def test_is_not_vendor_dashboard_context_vendors_plural(self):
"""Test _is_vendor_dashboard_context excludes /vendors/ path."""
- assert ContextManager._is_vendor_dashboard_context("/vendors/shop123/products") is False
+ assert (
+ ContextManager._is_vendor_dashboard_context("/vendors/shop123/products")
+ is False
+ )
def test_is_not_vendor_dashboard_context(self):
"""Test _is_vendor_dashboard_context returns False for non-vendor paths."""
@@ -373,7 +391,7 @@ class TestContextMiddleware:
await middleware.dispatch(request, call_next)
- assert hasattr(request.state, 'context_type')
+ assert hasattr(request.state, "context_type")
assert request.state.context_type == RequestContext.API
call_next.assert_called_once_with(request)
@@ -565,7 +583,7 @@ class TestEdgeCases:
request.url = Mock(path="/api/vendors")
request.headers = {"host": "localhost"}
# No state attribute at all
- delattr(request, 'state')
+ delattr(request, "state")
# Should still work, falling back to url.path
with pytest.raises(AttributeError):
diff --git a/tests/unit/middleware/test_decorators.py b/tests/unit/middleware/test_decorators.py
index 1b957a90..19d4ed25 100644
--- a/tests/unit/middleware/test_decorators.py
+++ b/tests/unit/middleware/test_decorators.py
@@ -12,16 +12,18 @@ Tests cover:
- Edge cases and isolation
"""
-import pytest
from unittest.mock import Mock
-from middleware.decorators import rate_limit, rate_limiter
-from app.exceptions.base import RateLimitException
+import pytest
+
+from app.exceptions.base import RateLimitException
+from middleware.decorators import rate_limit, rate_limiter
# =============================================================================
# Fixtures
# =============================================================================
+
@pytest.fixture(autouse=True)
def reset_rate_limiter():
"""Reset rate limiter state before each test to ensure isolation."""
@@ -34,6 +36,7 @@ def reset_rate_limiter():
# Rate Limit Decorator Tests
# =============================================================================
+
@pytest.mark.unit
@pytest.mark.auth
class TestRateLimitDecorator:
@@ -42,6 +45,7 @@ class TestRateLimitDecorator:
@pytest.mark.asyncio
async def test_decorator_allows_within_limit(self):
"""Test decorator allows requests within rate limit."""
+
@rate_limit(max_requests=10, window_seconds=3600)
async def test_endpoint():
return {"status": "ok"}
@@ -53,6 +57,7 @@ class TestRateLimitDecorator:
@pytest.mark.asyncio
async def test_decorator_blocks_exceeding_limit(self):
"""Test decorator blocks requests exceeding rate limit."""
+
@rate_limit(max_requests=2, window_seconds=3600)
async def test_endpoint_blocked():
return {"status": "ok"}
@@ -71,6 +76,7 @@ class TestRateLimitDecorator:
@pytest.mark.asyncio
async def test_decorator_preserves_function_metadata(self):
"""Test decorator preserves original function metadata."""
+
@rate_limit(max_requests=10, window_seconds=3600)
async def test_endpoint():
"""Test endpoint docstring."""
@@ -82,21 +88,19 @@ class TestRateLimitDecorator:
@pytest.mark.asyncio
async def test_decorator_with_args_and_kwargs(self):
"""Test decorator works with function arguments."""
+
@rate_limit(max_requests=10, window_seconds=3600)
async def test_endpoint(arg1, arg2, kwarg1=None):
return {"arg1": arg1, "arg2": arg2, "kwarg1": kwarg1}
result = await test_endpoint("value1", "value2", kwarg1="value3")
- assert result == {
- "arg1": "value1",
- "arg2": "value2",
- "kwarg1": "value3"
- }
+ assert result == {"arg1": "value1", "arg2": "value2", "kwarg1": "value3"}
@pytest.mark.asyncio
async def test_decorator_default_parameters(self):
"""Test decorator uses default parameters."""
+
@rate_limit() # Use defaults
async def test_endpoint():
return {"status": "ok"}
@@ -108,6 +112,7 @@ class TestRateLimitDecorator:
@pytest.mark.asyncio
async def test_decorator_exception_includes_retry_after(self):
"""Test rate limit exception includes retry_after."""
+
@rate_limit(max_requests=1, window_seconds=60)
async def test_endpoint_retry():
return {"status": "ok"}
@@ -128,6 +133,7 @@ class TestRateLimitDecoratorEdgeCases:
@pytest.mark.asyncio
async def test_decorator_with_zero_max_requests(self):
"""Test decorator with max_requests=0 blocks all requests."""
+
@rate_limit(max_requests=0, window_seconds=3600)
async def test_endpoint_zero():
return {"status": "ok"}
@@ -139,6 +145,7 @@ class TestRateLimitDecoratorEdgeCases:
@pytest.mark.asyncio
async def test_decorator_with_very_short_window(self):
"""Test decorator with very short time window."""
+
@rate_limit(max_requests=1, window_seconds=1)
async def test_endpoint_short():
return {"status": "ok"}
@@ -154,6 +161,7 @@ class TestRateLimitDecoratorEdgeCases:
@pytest.mark.asyncio
async def test_decorator_multiple_functions_separate_limits(self):
"""Test that different functions have separate rate limits."""
+
@rate_limit(max_requests=1, window_seconds=3600)
async def endpoint1():
return {"endpoint": "1"}
@@ -178,6 +186,7 @@ class TestRateLimitDecoratorEdgeCases:
@pytest.mark.asyncio
async def test_decorator_with_exception_in_function(self):
"""Test decorator handles exceptions from wrapped function."""
+
@rate_limit(max_requests=10, window_seconds=3600)
async def test_endpoint_error():
raise ValueError("Function error")
@@ -191,6 +200,7 @@ class TestRateLimitDecoratorEdgeCases:
@pytest.mark.asyncio
async def test_decorator_isolation_between_tests(self):
"""Test that rate limiter state is properly isolated between tests."""
+
@rate_limit(max_requests=2, window_seconds=3600)
async def test_endpoint_isolation():
return {"status": "ok"}
@@ -212,6 +222,7 @@ class TestRateLimitDecoratorReturnValues:
@pytest.mark.asyncio
async def test_decorator_returns_dict(self):
"""Test decorator correctly returns dictionary."""
+
@rate_limit(max_requests=10, window_seconds=3600)
async def return_dict():
return {"key": "value", "number": 42}
@@ -222,6 +233,7 @@ class TestRateLimitDecoratorReturnValues:
@pytest.mark.asyncio
async def test_decorator_returns_list(self):
"""Test decorator correctly returns list."""
+
@rate_limit(max_requests=10, window_seconds=3600)
async def return_list():
return [1, 2, 3, 4, 5]
@@ -232,6 +244,7 @@ class TestRateLimitDecoratorReturnValues:
@pytest.mark.asyncio
async def test_decorator_returns_none(self):
"""Test decorator correctly returns None."""
+
@rate_limit(max_requests=10, window_seconds=3600)
async def return_none():
return None
@@ -242,6 +255,7 @@ class TestRateLimitDecoratorReturnValues:
@pytest.mark.asyncio
async def test_decorator_returns_object(self):
"""Test decorator correctly returns custom objects."""
+
class TestObject:
def __init__(self):
self.name = "test_object"
diff --git a/tests/unit/middleware/test_logging.py b/tests/unit/middleware/test_logging.py
index 5bb88434..6fda9238 100644
--- a/tests/unit/middleware/test_logging.py
+++ b/tests/unit/middleware/test_logging.py
@@ -11,9 +11,10 @@ Tests cover:
- Edge cases (missing client info, etc.)
"""
-import pytest
import asyncio
-from unittest.mock import Mock, AsyncMock, patch
+from unittest.mock import AsyncMock, Mock, patch
+
+import pytest
from fastapi import Request
from middleware.logging import LoggingMiddleware
@@ -40,7 +41,7 @@ class TestLoggingMiddleware:
call_next = AsyncMock(return_value=response)
- with patch('middleware.logging.logger') as mock_logger:
+ with patch("middleware.logging.logger") as mock_logger:
await middleware.dispatch(request, call_next)
# Verify request was logged
@@ -65,7 +66,7 @@ class TestLoggingMiddleware:
call_next = AsyncMock(return_value=response)
- with patch('middleware.logging.logger') as mock_logger:
+ with patch("middleware.logging.logger") as mock_logger:
result = await middleware.dispatch(request, call_next)
# Verify response was logged
@@ -89,7 +90,7 @@ class TestLoggingMiddleware:
call_next = AsyncMock(return_value=response)
- with patch('middleware.logging.logger'):
+ with patch("middleware.logging.logger"):
result = await middleware.dispatch(request, call_next)
assert "X-Process-Time" in response.headers
@@ -113,11 +114,13 @@ class TestLoggingMiddleware:
call_next = AsyncMock(return_value=response)
- with patch('middleware.logging.logger') as mock_logger:
+ with patch("middleware.logging.logger") as mock_logger:
await middleware.dispatch(request, call_next)
# Should log "unknown" for client
- assert any("unknown" in str(call) for call in mock_logger.info.call_args_list)
+ assert any(
+ "unknown" in str(call) for call in mock_logger.info.call_args_list
+ )
@pytest.mark.asyncio
async def test_middleware_logs_exceptions(self):
@@ -131,8 +134,9 @@ class TestLoggingMiddleware:
call_next = AsyncMock(side_effect=Exception("Test error"))
- with patch('middleware.logging.logger') as mock_logger, \
- pytest.raises(Exception):
+ with patch("middleware.logging.logger") as mock_logger, pytest.raises(
+ Exception
+ ):
await middleware.dispatch(request, call_next)
# Verify error was logged
@@ -156,7 +160,7 @@ class TestLoggingMiddleware:
call_next = slow_call_next
- with patch('middleware.logging.logger'):
+ with patch("middleware.logging.logger"):
result = await middleware.dispatch(request, call_next)
process_time = float(result.headers["X-Process-Time"])
@@ -181,7 +185,7 @@ class TestLoggingEdgeCases:
response = Mock(status_code=200, headers={})
call_next = AsyncMock(return_value=response)
- with patch('middleware.logging.logger'):
+ with patch("middleware.logging.logger"):
result = await middleware.dispatch(request, call_next)
# Should still have process time, even if very small
@@ -205,11 +209,13 @@ class TestLoggingEdgeCases:
response = Mock(status_code=200, headers={})
call_next = AsyncMock(return_value=response)
- with patch('middleware.logging.logger') as mock_logger:
+ with patch("middleware.logging.logger") as mock_logger:
await middleware.dispatch(request, call_next)
# Verify method was logged
- assert any(method in str(call) for call in mock_logger.info.call_args_list)
+ assert any(
+ method in str(call) for call in mock_logger.info.call_args_list
+ )
@pytest.mark.asyncio
async def test_middleware_logs_different_status_codes(self):
@@ -227,8 +233,11 @@ class TestLoggingEdgeCases:
response = Mock(status_code=status_code, headers={})
call_next = AsyncMock(return_value=response)
- with patch('middleware.logging.logger') as mock_logger:
+ with patch("middleware.logging.logger") as mock_logger:
await middleware.dispatch(request, call_next)
# Verify status code was logged
- assert any(str(status_code) in str(call) for call in mock_logger.info.call_args_list)
+ assert any(
+ str(status_code) in str(call)
+ for call in mock_logger.info.call_args_list
+ )
diff --git a/tests/unit/middleware/test_rate_limiter.py b/tests/unit/middleware/test_rate_limiter.py
index 94c5418d..c16e2495 100644
--- a/tests/unit/middleware/test_rate_limiter.py
+++ b/tests/unit/middleware/test_rate_limiter.py
@@ -11,10 +11,11 @@ Tests cover:
- Edge cases and concurrency scenarios
"""
-import pytest
-from unittest.mock import Mock, patch
-from datetime import datetime, timedelta, timezone
from collections import deque
+from datetime import datetime, timedelta, timezone
+from unittest.mock import Mock, patch
+
+import pytest
from middleware.rate_limiter import RateLimiter
@@ -306,8 +307,8 @@ class TestRateLimiterStatistics:
# Add requests at different times
now = datetime.now(timezone.utc)
limiter.clients[client_id].append(now - timedelta(minutes=30)) # Within hour
- limiter.clients[client_id].append(now - timedelta(hours=2)) # Within day
- limiter.clients[client_id].append(now - timedelta(hours=12)) # Within day
+ limiter.clients[client_id].append(now - timedelta(hours=2)) # Within day
+ limiter.clients[client_id].append(now - timedelta(hours=12)) # Within day
stats = limiter.get_client_stats(client_id)
@@ -411,7 +412,9 @@ class TestRateLimiterEdgeCases:
limiter = RateLimiter()
client_id = "long_window_client"
- result = limiter.allow_request(client_id, max_requests=10, window_seconds=86400*365)
+ result = limiter.allow_request(
+ client_id, max_requests=10, window_seconds=86400 * 365
+ )
assert result is True
@@ -421,10 +424,16 @@ class TestRateLimiterEdgeCases:
client_id = "same_client"
# Allow with one limit
- assert limiter.allow_request(client_id, max_requests=10, window_seconds=3600) is True
+ assert (
+ limiter.allow_request(client_id, max_requests=10, window_seconds=3600)
+ is True
+ )
# Check with stricter limit
- assert limiter.allow_request(client_id, max_requests=1, window_seconds=3600) is False
+ assert (
+ limiter.allow_request(client_id, max_requests=1, window_seconds=3600)
+ is False
+ )
def test_rate_limiter_unicode_client_id(self):
"""Test rate limiter with unicode client ID."""
diff --git a/tests/unit/middleware/test_theme_context.py b/tests/unit/middleware/test_theme_context.py
index 885d902a..2cb9887b 100644
--- a/tests/unit/middleware/test_theme_context.py
+++ b/tests/unit/middleware/test_theme_context.py
@@ -11,15 +11,14 @@ Tests cover:
- Edge cases and error handling
"""
+from unittest.mock import AsyncMock, MagicMock, Mock, patch
+
import pytest
-from unittest.mock import Mock, AsyncMock, MagicMock, patch
from fastapi import Request
-from middleware.theme_context import (
- ThemeContextManager,
- ThemeContextMiddleware,
- get_current_theme,
-)
+from middleware.theme_context import (ThemeContextManager,
+ ThemeContextMiddleware,
+ get_current_theme)
@pytest.mark.unit
@@ -42,7 +41,14 @@ class TestThemeContextManager:
"""Test default theme has all required colors."""
theme = ThemeContextManager.get_default_theme()
- required_colors = ["primary", "secondary", "accent", "background", "text", "border"]
+ required_colors = [
+ "primary",
+ "secondary",
+ "accent",
+ "background",
+ "text",
+ "border",
+ ]
for color in required_colors:
assert color in theme["colors"]
assert theme["colors"][color].startswith("#")
@@ -79,10 +85,7 @@ class TestThemeContextManager:
mock_theme = Mock()
# Mock to_dict to return actual dictionary
- custom_theme_dict = {
- "theme_name": "custom",
- "colors": {"primary": "#ff0000"}
- }
+ custom_theme_dict = {"theme_name": "custom", "colors": {"primary": "#ff0000"}}
mock_theme.to_dict.return_value = custom_theme_dict
# Correct filter chain: query().filter().first()
@@ -141,8 +144,11 @@ class TestThemeContextMiddleware:
mock_db = MagicMock()
mock_theme = {"theme_name": "test_theme"}
- with patch('middleware.theme_context.get_db', return_value=iter([mock_db])), \
- patch.object(ThemeContextManager, 'get_vendor_theme', return_value=mock_theme):
+ with patch(
+ "middleware.theme_context.get_db", return_value=iter([mock_db])
+ ), patch.object(
+ ThemeContextManager, "get_vendor_theme", return_value=mock_theme
+ ):
await middleware.dispatch(request, call_next)
@@ -161,7 +167,7 @@ class TestThemeContextMiddleware:
await middleware.dispatch(request, call_next)
- assert hasattr(request.state, 'theme')
+ assert hasattr(request.state, "theme")
assert request.state.theme["theme_name"] == "default"
call_next.assert_called_once()
@@ -178,8 +184,11 @@ class TestThemeContextMiddleware:
mock_db = MagicMock()
- with patch('middleware.theme_context.get_db', return_value=iter([mock_db])), \
- patch.object(ThemeContextManager, 'get_vendor_theme', side_effect=Exception("DB Error")):
+ with patch(
+ "middleware.theme_context.get_db", return_value=iter([mock_db])
+ ), patch.object(
+ ThemeContextManager, "get_vendor_theme", side_effect=Exception("DB Error")
+ ):
await middleware.dispatch(request, call_next)
@@ -224,7 +233,7 @@ class TestThemeEdgeCases:
mock_db = MagicMock()
- with patch('middleware.theme_context.get_db', return_value=iter([mock_db])):
+ with patch("middleware.theme_context.get_db", return_value=iter([mock_db])):
await middleware.dispatch(request, call_next)
# Verify database was closed
diff --git a/tests/unit/middleware/test_vendor_context.py b/tests/unit/middleware/test_vendor_context.py
index 3071b092..8a3c44b2 100644
--- a/tests/unit/middleware/test_vendor_context.py
+++ b/tests/unit/middleware/test_vendor_context.py
@@ -11,17 +11,16 @@ Tests cover:
- Edge cases and error handling
"""
+from unittest.mock import AsyncMock, MagicMock, Mock, patch
+
import pytest
-from unittest.mock import Mock, MagicMock, patch, AsyncMock
-from fastapi import Request, HTTPException
+from fastapi import HTTPException, Request
from sqlalchemy.orm import Session
-from middleware.vendor_context import (
- VendorContextManager,
- VendorContextMiddleware,
- get_current_vendor,
- require_vendor_context,
-)
+from middleware.vendor_context import (VendorContextManager,
+ VendorContextMiddleware,
+ get_current_vendor,
+ require_vendor_context)
@pytest.mark.unit
@@ -39,7 +38,7 @@ class TestVendorContextManager:
request.headers = {"host": "customdomain1.com"}
request.url = Mock(path="/")
- with patch('middleware.vendor_context.settings') as mock_settings:
+ with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -55,7 +54,7 @@ class TestVendorContextManager:
request.headers = {"host": "customdomain1.com:8000"}
request.url = Mock(path="/")
- with patch('middleware.vendor_context.settings') as mock_settings:
+ with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -71,7 +70,7 @@ class TestVendorContextManager:
request.headers = {"host": "vendor1.platform.com"}
request.url = Mock(path="/")
- with patch('middleware.vendor_context.settings') as mock_settings:
+ with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -87,7 +86,7 @@ class TestVendorContextManager:
request.headers = {"host": "vendor1.platform.com:8000"}
request.url = Mock(path="/")
- with patch('middleware.vendor_context.settings') as mock_settings:
+ with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -140,7 +139,7 @@ class TestVendorContextManager:
request.headers = {"host": "admin.platform.com"}
request.url = Mock(path="/")
- with patch('middleware.vendor_context.settings') as mock_settings:
+ with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -153,7 +152,7 @@ class TestVendorContextManager:
request.headers = {"host": "www.platform.com"}
request.url = Mock(path="/")
- with patch('middleware.vendor_context.settings') as mock_settings:
+ with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -166,7 +165,7 @@ class TestVendorContextManager:
request.headers = {"host": "api.platform.com"}
request.url = Mock(path="/")
- with patch('middleware.vendor_context.settings') as mock_settings:
+ with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -179,7 +178,7 @@ class TestVendorContextManager:
request.headers = {"host": "localhost"}
request.url = Mock(path="/")
- with patch('middleware.vendor_context.settings') as mock_settings:
+ with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -198,12 +197,11 @@ class TestVendorContextManager:
mock_vendor.is_active = True
mock_vendor_domain.vendor = mock_vendor
- mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor_domain
+ mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = (
+ mock_vendor_domain
+ )
- context = {
- "detection_method": "custom_domain",
- "domain": "customdomain1.com"
- }
+ context = {"detection_method": "custom_domain", "domain": "customdomain1.com"}
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
@@ -218,12 +216,11 @@ class TestVendorContextManager:
mock_vendor.is_active = False
mock_vendor_domain.vendor = mock_vendor
- mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor_domain
+ mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = (
+ mock_vendor_domain
+ )
- context = {
- "detection_method": "custom_domain",
- "domain": "customdomain1.com"
- }
+ context = {"detection_method": "custom_domain", "domain": "customdomain1.com"}
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
@@ -232,12 +229,11 @@ class TestVendorContextManager:
def test_get_vendor_from_custom_domain_not_found(self):
"""Test custom domain not found in database."""
mock_db = Mock(spec=Session)
- mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
+ mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = (
+ None
+ )
- context = {
- "detection_method": "custom_domain",
- "domain": "nonexistent.com"
- }
+ context = {"detection_method": "custom_domain", "domain": "nonexistent.com"}
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
@@ -249,12 +245,11 @@ class TestVendorContextManager:
mock_vendor = Mock()
mock_vendor.is_active = True
- mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor
+ mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = (
+ mock_vendor
+ )
- context = {
- "detection_method": "subdomain",
- "subdomain": "vendor1"
- }
+ context = {"detection_method": "subdomain", "subdomain": "vendor1"}
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
@@ -266,12 +261,11 @@ class TestVendorContextManager:
mock_vendor = Mock()
mock_vendor.is_active = True
- mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor
+ mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = (
+ mock_vendor
+ )
- context = {
- "detection_method": "path",
- "subdomain": "vendor1"
- }
+ context = {"detection_method": "path", "subdomain": "vendor1"}
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
@@ -291,12 +285,11 @@ class TestVendorContextManager:
mock_vendor = Mock()
mock_vendor.is_active = True
- mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor
+ mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = (
+ mock_vendor
+ )
- context = {
- "detection_method": "subdomain",
- "subdomain": "VENDOR1" # Uppercase
- }
+ context = {"detection_method": "subdomain", "subdomain": "VENDOR1"} # Uppercase
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
@@ -311,10 +304,7 @@ class TestVendorContextManager:
request = Mock(spec=Request)
request.url = Mock(path="/vendor/vendor1/shop/products")
- vendor_context = {
- "detection_method": "path",
- "path_prefix": "/vendor/vendor1"
- }
+ vendor_context = {"detection_method": "path", "path_prefix": "/vendor/vendor1"}
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
@@ -325,10 +315,7 @@ class TestVendorContextManager:
request = Mock(spec=Request)
request.url = Mock(path="/vendors/vendor1/shop/products")
- vendor_context = {
- "detection_method": "path",
- "path_prefix": "/vendors/vendor1"
- }
+ vendor_context = {"detection_method": "path", "path_prefix": "/vendors/vendor1"}
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
@@ -339,10 +326,7 @@ class TestVendorContextManager:
request = Mock(spec=Request)
request.url = Mock(path="/vendor/vendor1")
- vendor_context = {
- "detection_method": "path",
- "path_prefix": "/vendor/vendor1"
- }
+ vendor_context = {"detection_method": "path", "path_prefix": "/vendor/vendor1"}
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
@@ -353,10 +337,7 @@ class TestVendorContextManager:
request = Mock(spec=Request)
request.url = Mock(path="/shop/products")
- vendor_context = {
- "detection_method": "subdomain",
- "subdomain": "vendor1"
- }
+ vendor_context = {"detection_method": "subdomain", "subdomain": "vendor1"}
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
@@ -425,21 +406,24 @@ class TestVendorContextManager:
# Static File Detection Tests
# ========================================================================
- @pytest.mark.parametrize("path", [
- "/static/css/style.css",
- "/static/js/app.js",
- "/media/images/product.png",
- "/assets/logo.svg",
- "/.well-known/security.txt",
- "/favicon.ico",
- "/image.jpg",
- "/style.css",
- "/app.webmanifest",
- "/static/", # Path starting with /static/ but no extension
- "/media/uploads", # Path starting with /media/ but no extension
- "/subfolder/favicon.ico", # favicon.ico in subfolder
- "/favicon.ico.bak", # Contains favicon.ico but doesn't end with static extension (hits line 226)
- ])
+ @pytest.mark.parametrize(
+ "path",
+ [
+ "/static/css/style.css",
+ "/static/js/app.js",
+ "/media/images/product.png",
+ "/assets/logo.svg",
+ "/.well-known/security.txt",
+ "/favicon.ico",
+ "/image.jpg",
+ "/style.css",
+ "/app.webmanifest",
+ "/static/", # Path starting with /static/ but no extension
+ "/media/uploads", # Path starting with /media/ but no extension
+ "/subfolder/favicon.ico", # favicon.ico in subfolder
+ "/favicon.ico.bak", # Contains favicon.ico but doesn't end with static extension (hits line 226)
+ ],
+ )
def test_is_static_file_request(self, path):
"""Test static file detection for various paths and extensions."""
request = Mock(spec=Request)
@@ -447,12 +431,15 @@ class TestVendorContextManager:
assert VendorContextManager.is_static_file_request(request) is True
- @pytest.mark.parametrize("path", [
- "/shop/products",
- "/admin/dashboard",
- "/api/vendors",
- "/about",
- ])
+ @pytest.mark.parametrize(
+ "path",
+ [
+ "/shop/products",
+ "/admin/dashboard",
+ "/api/vendors",
+ "/about",
+ ],
+ )
def test_is_not_static_file_request(self, path):
"""Test non-static file paths."""
request = Mock(spec=Request)
@@ -478,7 +465,7 @@ class TestVendorContextMiddleware:
call_next = AsyncMock(return_value=Mock())
- with patch.object(VendorContextManager, 'is_admin_request', return_value=True):
+ with patch.object(VendorContextManager, "is_admin_request", return_value=True):
await middleware.dispatch(request, call_next)
assert request.state.vendor is None
@@ -498,7 +485,7 @@ class TestVendorContextMiddleware:
call_next = AsyncMock(return_value=Mock())
- with patch.object(VendorContextManager, 'is_api_request', return_value=True):
+ with patch.object(VendorContextManager, "is_api_request", return_value=True):
await middleware.dispatch(request, call_next)
assert request.state.vendor is None
@@ -517,7 +504,9 @@ class TestVendorContextMiddleware:
call_next = AsyncMock(return_value=Mock())
- with patch.object(VendorContextManager, 'is_static_file_request', return_value=True):
+ with patch.object(
+ VendorContextManager, "is_static_file_request", return_value=True
+ ):
await middleware.dispatch(request, call_next)
assert request.state.vendor is None
@@ -540,17 +529,19 @@ class TestVendorContextMiddleware:
mock_vendor.name = "Test Vendor"
mock_vendor.subdomain = "vendor1"
- vendor_context = {
- "detection_method": "subdomain",
- "subdomain": "vendor1"
- }
+ vendor_context = {"detection_method": "subdomain", "subdomain": "vendor1"}
mock_db = MagicMock()
- with patch.object(VendorContextManager, 'detect_vendor_context', return_value=vendor_context), \
- patch.object(VendorContextManager, 'get_vendor_from_context', return_value=mock_vendor), \
- patch.object(VendorContextManager, 'extract_clean_path', return_value="/shop/products"), \
- patch('middleware.vendor_context.get_db', return_value=iter([mock_db])):
+ with patch.object(
+ VendorContextManager, "detect_vendor_context", return_value=vendor_context
+ ), patch.object(
+ VendorContextManager, "get_vendor_from_context", return_value=mock_vendor
+ ), patch.object(
+ VendorContextManager, "extract_clean_path", return_value="/shop/products"
+ ), patch(
+ "middleware.vendor_context.get_db", return_value=iter([mock_db])
+ ):
await middleware.dispatch(request, call_next)
@@ -571,16 +562,17 @@ class TestVendorContextMiddleware:
call_next = AsyncMock(return_value=Mock())
- vendor_context = {
- "detection_method": "subdomain",
- "subdomain": "nonexistent"
- }
+ vendor_context = {"detection_method": "subdomain", "subdomain": "nonexistent"}
mock_db = MagicMock()
- with patch.object(VendorContextManager, 'detect_vendor_context', return_value=vendor_context), \
- patch.object(VendorContextManager, 'get_vendor_from_context', return_value=None), \
- patch('middleware.vendor_context.get_db', return_value=iter([mock_db])):
+ with patch.object(
+ VendorContextManager, "detect_vendor_context", return_value=vendor_context
+ ), patch.object(
+ VendorContextManager, "get_vendor_from_context", return_value=None
+ ), patch(
+ "middleware.vendor_context.get_db", return_value=iter([mock_db])
+ ):
await middleware.dispatch(request, call_next)
@@ -601,7 +593,9 @@ class TestVendorContextMiddleware:
call_next = AsyncMock(return_value=Mock())
- with patch.object(VendorContextManager, 'detect_vendor_context', return_value=None):
+ with patch.object(
+ VendorContextManager, "detect_vendor_context", return_value=None
+ ):
await middleware.dispatch(request, call_next)
assert request.state.vendor is None
@@ -714,7 +708,7 @@ class TestEdgeCases:
request.headers = {"host": "shop.vendor1.platform.com"}
request.url = Mock(path="/")
- with patch('middleware.vendor_context.settings') as mock_settings:
+ with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -735,11 +729,14 @@ class TestEdgeCases:
context = {"subdomain": "nonexistent", "detection_method": "subdomain"}
- with patch('middleware.vendor_context.logger') as mock_logger:
+ with patch("middleware.vendor_context.logger") as mock_logger:
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
assert vendor is None
# Verify warning was logged
mock_logger.warning.assert_called()
warning_message = str(mock_logger.warning.call_args)
- assert "No active vendor found for subdomain" in warning_message and "nonexistent" in warning_message
+ assert (
+ "No active vendor found for subdomain" in warning_message
+ and "nonexistent" in warning_message
+ )
diff --git a/tests/unit/models/test_database_models.py b/tests/unit/models/test_database_models.py
index 8cae8dcf..a25997b4 100644
--- a/tests/unit/models/test_database_models.py
+++ b/tests/unit/models/test_database_models.py
@@ -1,16 +1,17 @@
# tests/unit/models/test_database_models.py
-import pytest
from datetime import datetime, timezone
+
+import pytest
from sqlalchemy.exc import IntegrityError
-from models.database.marketplace_product import MarketplaceProduct
-from models.database.vendor import Vendor, VendorUser, Role
-from models.database.inventory import Inventory
-from models.database.user import User
-from models.database.marketplace_import_job import MarketplaceImportJob
-from models.database.product import Product
from models.database.customer import Customer, CustomerAddress
+from models.database.inventory import Inventory
+from models.database.marketplace_import_job import MarketplaceImportJob
+from models.database.marketplace_product import MarketplaceProduct
from models.database.order import Order, OrderItem
+from models.database.product import Product
+from models.database.user import User
+from models.database.vendor import Role, Vendor, VendorUser
@pytest.mark.unit
@@ -277,7 +278,7 @@ class TestMarketplaceProductModel:
vendor_id=test_vendor.id,
marketplace_product_id="UNIQUE_001",
title="Product 1",
- marketplace="Letzshop"
+ marketplace="Letzshop",
)
db.add(product1)
db.commit()
@@ -288,7 +289,7 @@ class TestMarketplaceProductModel:
vendor_id=test_vendor.id,
marketplace_product_id="UNIQUE_001",
title="Product 2",
- marketplace="Letzshop"
+ marketplace="Letzshop",
)
db.add(product2)
db.commit()
@@ -515,7 +516,9 @@ class TestCustomerModel:
class TestOrderModel:
"""Test Order model"""
- def test_order_creation(self, db, test_vendor, test_customer, test_customer_address):
+ def test_order_creation(
+ self, db, test_vendor, test_customer, test_customer_address
+ ):
"""Test Order model with customer relationship"""
order = Order(
vendor_id=test_vendor.id,
@@ -563,7 +566,9 @@ class TestOrderModel:
assert float(order_item.unit_price) == 49.99
assert float(order_item.total_price) == 99.98
- def test_order_number_uniqueness(self, db, test_vendor, test_customer, test_customer_address):
+ def test_order_number_uniqueness(
+ self, db, test_vendor, test_customer, test_customer_address
+ ):
"""Test order_number unique constraint"""
order1 = Order(
vendor_id=test_vendor.id,
diff --git a/tests/unit/services/test_admin_service.py b/tests/unit/services/test_admin_service.py
index 01b9374a..000ec495 100644
--- a/tests/unit/services/test_admin_service.py
+++ b/tests/unit/services/test_admin_service.py
@@ -1,14 +1,10 @@
# tests/unit/services/test_admin_service.py
import pytest
-from app.exceptions import (
- UserNotFoundException,
- UserStatusChangeException,
- CannotModifySelfException,
- VendorNotFoundException,
- VendorVerificationException,
- AdminOperationException,
-)
+from app.exceptions import (AdminOperationException, CannotModifySelfException,
+ UserNotFoundException, UserStatusChangeException,
+ VendorNotFoundException,
+ VendorVerificationException)
from app.services.admin_service import AdminService
from app.services.stats_service import stats_service
from models.database.marketplace_import_job import MarketplaceImportJob
@@ -85,7 +81,9 @@ class TestAdminService:
assert exception.error_code == "CANNOT_MODIFY_SELF"
assert "deactivate account" in exception.message
- def test_toggle_user_status_cannot_modify_admin(self, db, test_admin, another_admin):
+ def test_toggle_user_status_cannot_modify_admin(
+ self, db, test_admin, another_admin
+ ):
"""Test that admin cannot modify another admin"""
with pytest.raises(UserStatusChangeException) as exc_info:
self.service.toggle_user_status(db, another_admin.id, test_admin.id)
@@ -148,7 +146,7 @@ class TestAdminService:
assert "99999" in exception.message
def test_toggle_vendor_status_deactivate(self, db, test_vendor):
- """Test deactivating a vendor """
+ """Test deactivating a vendor"""
original_status = test_vendor.is_active
vendor, message = self.service.toggle_vendor_status(db, test_vendor.id)
@@ -170,21 +168,26 @@ class TestAdminService:
assert exception.error_code == "VENDOR_NOT_FOUND"
# Marketplace Import Jobs Tests
- def test_get_marketplace_import_jobs_no_filters(self, db, test_marketplace_import_job):
+ def test_get_marketplace_import_jobs_no_filters(
+ self, db, test_marketplace_import_job
+ ):
"""Test getting marketplace import jobs without filters"""
result = self.service.get_marketplace_import_jobs(db, skip=0, limit=10)
assert len(result) >= 1
# Find our test job in the results
test_job = next(
- (job for job in result if job.job_id == test_marketplace_import_job.id), None
+ (job for job in result if job.job_id == test_marketplace_import_job.id),
+ None,
)
assert test_job is not None
assert test_job.marketplace == test_marketplace_import_job.marketplace
assert test_job.vendor_name == test_marketplace_import_job.name
assert test_job.status == test_marketplace_import_job.status
- def test_get_marketplace_import_jobs_with_marketplace_filter(self, db, test_marketplace_import_job):
+ def test_get_marketplace_import_jobs_with_marketplace_filter(
+ self, db, test_marketplace_import_job
+ ):
"""Test filtering marketplace import jobs by marketplace"""
result = self.service.get_marketplace_import_jobs(
db, marketplace=test_marketplace_import_job.marketplace, skip=0, limit=10
@@ -192,9 +195,14 @@ class TestAdminService:
assert len(result) >= 1
for job in result:
- assert test_marketplace_import_job.marketplace.lower() in job.marketplace.lower()
+ assert (
+ test_marketplace_import_job.marketplace.lower()
+ in job.marketplace.lower()
+ )
- def test_get_marketplace_import_jobs_with_vendor_filter(self, db, test_marketplace_import_job):
+ def test_get_marketplace_import_jobs_with_vendor_filter(
+ self, db, test_marketplace_import_job
+ ):
"""Test filtering marketplace import jobs by vendor name"""
result = self.service.get_marketplace_import_jobs(
db, vendor_name=test_marketplace_import_job.name, skip=0, limit=10
@@ -204,7 +212,9 @@ class TestAdminService:
for job in result:
assert test_marketplace_import_job.name.lower() in job.vendor_name.lower()
- def test_get_marketplace_import_jobs_with_status_filter(self, db, test_marketplace_import_job):
+ def test_get_marketplace_import_jobs_with_status_filter(
+ self, db, test_marketplace_import_job
+ ):
"""Test filtering marketplace import jobs by status"""
result = self.service.get_marketplace_import_jobs(
db, status=test_marketplace_import_job.status, skip=0, limit=10
@@ -214,7 +224,9 @@ class TestAdminService:
for job in result:
assert job.status == test_marketplace_import_job.status
- def test_get_marketplace_import_jobs_pagination(self, db, test_marketplace_import_job):
+ def test_get_marketplace_import_jobs_pagination(
+ self, db, test_marketplace_import_job
+ ):
"""Test marketplace import jobs pagination"""
result_page1 = self.service.get_marketplace_import_jobs(db, skip=0, limit=1)
result_page2 = self.service.get_marketplace_import_jobs(db, skip=1, limit=1)
diff --git a/tests/unit/services/test_auth_service.py b/tests/unit/services/test_auth_service.py
index e46809c9..5198e31a 100644
--- a/tests/unit/services/test_auth_service.py
+++ b/tests/unit/services/test_auth_service.py
@@ -1,11 +1,9 @@
# tests/test_auth_service.py
import pytest
-from app.exceptions.auth import (
- UserAlreadyExistsException,
- InvalidCredentialsException,
- UserNotActiveException,
-)
+from app.exceptions.auth import (InvalidCredentialsException,
+ UserAlreadyExistsException,
+ UserNotActiveException)
from app.exceptions.base import ValidationException
from app.services.auth_service import AuthService
from models.schema.auth import UserLogin, UserRegister
@@ -218,11 +216,14 @@ class TestAuthService:
def test_create_access_token_failure(self, test_user, monkeypatch):
"""Test creating access token handles failures"""
+
# Mock the auth_manager to raise an exception
def mock_create_token(*args, **kwargs):
raise Exception("Token creation failed")
- monkeypatch.setattr(self.service.auth_manager, "create_access_token", mock_create_token)
+ monkeypatch.setattr(
+ self.service.auth_manager, "create_access_token", mock_create_token
+ )
with pytest.raises(ValidationException) as exc_info:
self.service.create_access_token(test_user)
@@ -250,11 +251,14 @@ class TestAuthService:
def test_hash_password_failure(self, monkeypatch):
"""Test password hashing handles failures"""
+
# Mock the auth_manager to raise an exception
def mock_hash_password(*args, **kwargs):
raise Exception("Hashing failed")
- monkeypatch.setattr(self.service.auth_manager, "hash_password", mock_hash_password)
+ monkeypatch.setattr(
+ self.service.auth_manager, "hash_password", mock_hash_password
+ )
with pytest.raises(ValidationException) as exc_info:
self.service.hash_password("testpassword")
@@ -267,9 +271,7 @@ class TestAuthService:
def test_register_user_database_error(self, db_with_error):
"""Test user registration handles database errors"""
user_data = UserRegister(
- email="test@example.com",
- username="testuser",
- password="password123"
+ email="test@example.com", username="testuser", password="password123"
)
with pytest.raises(ValidationException) as exc_info:
diff --git a/tests/unit/services/test_inventory_service.py b/tests/unit/services/test_inventory_service.py
index 78942474..d1b38364 100644
--- a/tests/unit/services/test_inventory_service.py
+++ b/tests/unit/services/test_inventory_service.py
@@ -3,19 +3,17 @@ import uuid
import pytest
+from app.exceptions import (InsufficientInventoryException,
+ InvalidInventoryOperationException,
+ InvalidQuantityException,
+ InventoryNotFoundException,
+ InventoryValidationException,
+ NegativeInventoryException, ValidationException)
from app.services.inventory_service import InventoryService
-from app.exceptions import (
- InventoryNotFoundException,
- InsufficientInventoryException,
- InvalidInventoryOperationException,
- InventoryValidationException,
- NegativeInventoryException,
- InvalidQuantityException,
- ValidationException,
-)
-from models.schema.inventory import InventoryAdd, InventoryCreate, InventoryUpdate
-from models.database.marketplace_product import MarketplaceProduct
from models.database.inventory import Inventory
+from models.database.marketplace_product import MarketplaceProduct
+from models.schema.inventory import (InventoryAdd, InventoryCreate,
+ InventoryUpdate)
@pytest.mark.unit
@@ -40,10 +38,14 @@ class TestInventoryService:
def test_normalize_gtin_valid(self):
"""Test GTIN normalization with valid GTINs."""
# Test various valid GTIN formats - these should remain unchanged
- assert self.service._normalize_gtin("1234567890123") == "1234567890123" # EAN-13
+ assert (
+ self.service._normalize_gtin("1234567890123") == "1234567890123"
+ ) # EAN-13
assert self.service._normalize_gtin("123456789012") == "123456789012" # UPC-A
assert self.service._normalize_gtin("12345678") == "12345678" # EAN-8
- assert self.service._normalize_gtin("12345678901234") == "12345678901234" # GTIN-14
+ assert (
+ self.service._normalize_gtin("12345678901234") == "12345678901234"
+ ) # GTIN-14
# Test with decimal points (should be removed)
assert self.service._normalize_gtin("1234567890123.0") == "1234567890123"
@@ -52,11 +54,17 @@ class TestInventoryService:
assert self.service._normalize_gtin(" 1234567890123 ") == "1234567890123"
# Test short GTINs being padded
- assert self.service._normalize_gtin("123") == "0000000000123" # Padded to EAN-13
- assert self.service._normalize_gtin("12345") == "0000000012345" # Padded to EAN-13
+ assert (
+ self.service._normalize_gtin("123") == "0000000000123"
+ ) # Padded to EAN-13
+ assert (
+ self.service._normalize_gtin("12345") == "0000000012345"
+ ) # Padded to EAN-13
# Test long GTINs being truncated
- assert self.service._normalize_gtin("123456789012345") == "3456789012345" # Truncated to 13
+ assert (
+ self.service._normalize_gtin("123456789012345") == "3456789012345"
+ ) # Truncated to 13
def test_normalize_gtin_edge_cases(self):
"""Test GTIN normalization edge cases."""
@@ -65,9 +73,15 @@ class TestInventoryService:
assert self.service._normalize_gtin(123) == "0000000000123"
# Test mixed valid/invalid characters
- assert self.service._normalize_gtin("123-456-789-012") == "123456789012" # Dashes removed
- assert self.service._normalize_gtin("123 456 789 012") == "123456789012" # Spaces removed
- assert self.service._normalize_gtin("ABC123456789012DEF") == "123456789012" # Letters removed
+ assert (
+ self.service._normalize_gtin("123-456-789-012") == "123456789012"
+ ) # Dashes removed
+ assert (
+ self.service._normalize_gtin("123 456 789 012") == "123456789012"
+ ) # Spaces removed
+ assert (
+ self.service._normalize_gtin("ABC123456789012DEF") == "123456789012"
+ ) # Letters removed
def test_set_inventory_new_entry_success(self, db):
"""Test setting inventory for a new GTIN/location combination successfully."""
@@ -162,7 +176,9 @@ class TestInventoryService:
def test_add_inventory_invalid_gtin_validation_error(self, db):
"""Test adding inventory with invalid GTIN returns InventoryValidationException."""
- inventory_data = InventoryAdd(gtin="invalid_gtin", location="WAREHOUSE_A", quantity=50)
+ inventory_data = InventoryAdd(
+ gtin="invalid_gtin", location="WAREHOUSE_A", quantity=50
+ )
with pytest.raises(InventoryValidationException) as exc_info:
self.service.add_inventory(db, inventory_data)
@@ -180,11 +196,12 @@ class TestInventoryService:
assert exc_info.value.error_code == "INVALID_QUANTITY"
assert "Quantity must be positive" in str(exc_info.value)
-
def test_remove_inventory_success(self, db, test_inventory):
"""Test removing inventory successfully."""
original_quantity = test_inventory.quantity
- remove_quantity = min(10, original_quantity) # Ensure we don't remove more than available
+ remove_quantity = min(
+ 10, original_quantity
+ ) # Ensure we don't remove more than available
inventory_data = InventoryAdd(
gtin=test_inventory.gtin,
@@ -212,7 +229,9 @@ class TestInventoryService:
assert exc_info.value.error_code == "INSUFFICIENT_INVENTORY"
assert exc_info.value.details["gtin"] == test_inventory.gtin
assert exc_info.value.details["location"] == test_inventory.location
- assert exc_info.value.details["requested_quantity"] == test_inventory.quantity + 10
+ assert (
+ exc_info.value.details["requested_quantity"] == test_inventory.quantity + 10
+ )
assert exc_info.value.details["available_quantity"] == test_inventory.quantity
def test_remove_inventory_nonexistent_entry_not_found(self, db):
@@ -231,7 +250,9 @@ class TestInventoryService:
def test_remove_inventory_invalid_gtin_validation_error(self, db):
"""Test removing inventory with invalid GTIN returns InventoryValidationException."""
- inventory_data = InventoryAdd(gtin="invalid_gtin", location="WAREHOUSE_A", quantity=10)
+ inventory_data = InventoryAdd(
+ gtin="invalid_gtin", location="WAREHOUSE_A", quantity=10
+ )
with pytest.raises(InventoryValidationException) as exc_info:
self.service.remove_inventory(db, inventory_data)
@@ -254,7 +275,9 @@ class TestInventoryService:
# The service prevents negative inventory through InsufficientInventoryException
assert exc_info.value.error_code == "INSUFFICIENT_INVENTORY"
- def test_get_inventory_by_gtin_success(self, db, test_inventory, test_marketplace_product):
+ def test_get_inventory_by_gtin_success(
+ self, db, test_inventory, test_marketplace_product
+ ):
"""Test getting inventory summary by GTIN successfully."""
result = self.service.get_inventory_by_gtin(db, test_inventory.gtin)
@@ -265,14 +288,20 @@ class TestInventoryService:
assert result.locations[0].quantity == test_inventory.quantity
assert result.product_title == test_marketplace_product.title
- def test_get_inventory_by_gtin_multiple_locations_success(self, db, test_marketplace_product):
+ def test_get_inventory_by_gtin_multiple_locations_success(
+ self, db, test_marketplace_product
+ ):
"""Test getting inventory summary with multiple locations successfully."""
unique_gtin = test_marketplace_product.gtin
unique_id = str(uuid.uuid4())[:8]
# Create multiple inventory entries for the same GTIN with unique locations
- inventory1 = Inventory(gtin=unique_gtin, location=f"WAREHOUSE_A_{unique_id}", quantity=50)
- inventory2 = Inventory(gtin=unique_gtin, location=f"WAREHOUSE_B_{unique_id}", quantity=30)
+ inventory1 = Inventory(
+ gtin=unique_gtin, location=f"WAREHOUSE_A_{unique_id}", quantity=50
+ )
+ inventory2 = Inventory(
+ gtin=unique_gtin, location=f"WAREHOUSE_B_{unique_id}", quantity=30
+ )
db.add(inventory1)
db.add(inventory2)
@@ -301,7 +330,9 @@ class TestInventoryService:
assert exc_info.value.error_code == "INVENTORY_VALIDATION_FAILED"
assert "Invalid GTIN format" in str(exc_info.value)
- def test_get_total_inventory_success(self, db, test_inventory, test_marketplace_product):
+ def test_get_total_inventory_success(
+ self, db, test_inventory, test_marketplace_product
+ ):
"""Test getting total inventory for a GTIN successfully."""
result = self.service.get_total_inventory(db, test_inventory.gtin)
@@ -364,7 +395,9 @@ class TestInventoryService:
result = self.service.get_all_inventory(db, skip=2, limit=2)
- assert len(result) <= 2 # Should be at most 2, might be less if other records exist
+ assert (
+ len(result) <= 2
+ ) # Should be at most 2, might be less if other records exist
def test_update_inventory_success(self, db, test_inventory):
"""Test updating inventory quantity successfully."""
@@ -404,7 +437,9 @@ class TestInventoryService:
assert result is True
# Verify the inventory is actually deleted
- deleted_inventory = db.query(Inventory).filter(Inventory.id == inventory_id).first()
+ deleted_inventory = (
+ db.query(Inventory).filter(Inventory.id == inventory_id).first()
+ )
assert deleted_inventory is None
def test_delete_inventory_not_found_error(self, db):
@@ -415,7 +450,9 @@ class TestInventoryService:
assert exc_info.value.error_code == "INVENTORY_NOT_FOUND"
assert "99999" in str(exc_info.value)
- def test_get_low_inventory_items_success(self, db, test_inventory, test_marketplace_product):
+ def test_get_low_inventory_items_success(
+ self, db, test_inventory, test_marketplace_product
+ ):
"""Test getting low inventory items successfully."""
# Set inventory to a low value
test_inventory.quantity = 5
@@ -424,7 +461,9 @@ class TestInventoryService:
result = self.service.get_low_inventory_items(db, threshold=10)
assert len(result) >= 1
- low_inventory_item = next((item for item in result if item["gtin"] == test_inventory.gtin), None)
+ low_inventory_item = next(
+ (item for item in result if item["gtin"] == test_inventory.gtin), None
+ )
assert low_inventory_item is not None
assert low_inventory_item["current_quantity"] == 5
assert low_inventory_item["location"] == test_inventory.location
@@ -440,9 +479,13 @@ class TestInventoryService:
def test_get_inventory_summary_by_location_success(self, db, test_inventory):
"""Test getting inventory summary by location successfully."""
- result = self.service.get_inventory_summary_by_location(db, test_inventory.location)
+ result = self.service.get_inventory_summary_by_location(
+ db, test_inventory.location
+ )
- assert result["location"] == test_inventory.location.upper() # Service normalizes to uppercase
+ assert (
+ result["location"] == test_inventory.location.upper()
+ ) # Service normalizes to uppercase
assert result["total_items"] >= 1
assert result["total_quantity"] >= test_inventory.quantity
assert result["unique_gtins"] >= 1
@@ -450,7 +493,9 @@ class TestInventoryService:
def test_get_inventory_summary_by_location_empty_result(self, db):
"""Test getting inventory summary for location with no inventory."""
unique_id = str(uuid.uuid4())[:8]
- result = self.service.get_inventory_summary_by_location(db, f"EMPTY_LOCATION_{unique_id}")
+ result = self.service.get_inventory_summary_by_location(
+ db, f"EMPTY_LOCATION_{unique_id}"
+ )
assert result["total_items"] == 0
assert result["total_quantity"] == 0
@@ -459,12 +504,16 @@ class TestInventoryService:
def test_validate_quantity_edge_cases(self, db):
"""Test quantity validation with edge cases."""
# Test zero quantity with allow_zero=True (should succeed)
- inventory_data = InventoryCreate(gtin="1234567890123", location="WAREHOUSE_A", quantity=0)
+ inventory_data = InventoryCreate(
+ gtin="1234567890123", location="WAREHOUSE_A", quantity=0
+ )
result = self.service.set_inventory(db, inventory_data)
assert result.quantity == 0
# Test zero quantity with add_inventory (should fail - doesn't allow zero)
- inventory_data_add = InventoryAdd(gtin="1234567890123", location="WAREHOUSE_B", quantity=0)
+ inventory_data_add = InventoryAdd(
+ gtin="1234567890123", location="WAREHOUSE_B", quantity=0
+ )
with pytest.raises(InvalidQuantityException):
self.service.add_inventory(db, inventory_data_add)
@@ -477,10 +526,10 @@ class TestInventoryService:
exception = exc_info.value
# Verify exception structure matches WizamartException.to_dict()
- assert hasattr(exception, 'error_code')
- assert hasattr(exception, 'message')
- assert hasattr(exception, 'status_code')
- assert hasattr(exception, 'details')
+ assert hasattr(exception, "error_code")
+ assert hasattr(exception, "message")
+ assert hasattr(exception, "status_code")
+ assert hasattr(exception, "details")
assert isinstance(exception.error_code, str)
assert isinstance(exception.message, str)
diff --git a/tests/unit/services/test_marketplace_service.py b/tests/unit/services/test_marketplace_service.py
index 5a2151b4..a43d7add 100644
--- a/tests/unit/services/test_marketplace_service.py
+++ b/tests/unit/services/test_marketplace_service.py
@@ -4,19 +4,18 @@ from datetime import datetime
import pytest
-from app.exceptions.marketplace_import_job import (
- ImportJobNotFoundException,
- ImportJobNotOwnedException,
- ImportJobCannotBeCancelledException,
- ImportJobCannotBeDeletedException,
-)
-from app.exceptions.vendor import VendorNotFoundException, UnauthorizedVendorAccessException
from app.exceptions.base import ValidationException
-from app.services.marketplace_import_job_service import MarketplaceImportJobService
-from models.schema.marketplace_import_job import MarketplaceImportJobRequest
+from app.exceptions.marketplace_import_job import (
+ ImportJobCannotBeCancelledException, ImportJobCannotBeDeletedException,
+ ImportJobNotFoundException, ImportJobNotOwnedException)
+from app.exceptions.vendor import (UnauthorizedVendorAccessException,
+ VendorNotFoundException)
+from app.services.marketplace_import_job_service import \
+ MarketplaceImportJobService
from models.database.marketplace_import_job import MarketplaceImportJob
-from models.database.vendor import Vendor
from models.database.user import User
+from models.database.vendor import Vendor
+from models.schema.marketplace_import_job import MarketplaceImportJobRequest
@pytest.mark.unit
@@ -31,7 +30,9 @@ class TestMarketplaceService:
test_vendor.owner_user_id = test_user.id
db.commit()
- result = self.service.validate_vendor_access(db, test_vendor.vendor_code, test_user)
+ result = self.service.validate_vendor_access(
+ db, test_vendor.vendor_code, test_user
+ )
assert result.vendor_code == test_vendor.vendor_code
assert result.owner_user_id == test_user.id
@@ -39,8 +40,10 @@ class TestMarketplaceService:
def test_validate_vendor_access_admin_can_access_any_vendor(
self, db, test_vendor, test_admin
):
- """Test that admin users can access any vendor """
- result = self.service.validate_vendor_access(db, test_vendor.vendor_code, test_admin)
+ """Test that admin users can access any vendor"""
+ result = self.service.validate_vendor_access(
+ db, test_vendor.vendor_code, test_admin
+ )
assert result.vendor_code == test_vendor.vendor_code
@@ -57,7 +60,7 @@ class TestMarketplaceService:
def test_validate_vendor_access_permission_denied(
self, db, test_vendor, test_user, other_user
):
- """Test vendor access validation when user doesn't own the vendor """
+ """Test vendor access validation when user doesn't own the vendor"""
# Set the vendor owner to a different user
test_vendor.owner_user_id = other_user.id
db.commit()
@@ -93,7 +96,7 @@ class TestMarketplaceService:
assert result.vendor_name == test_vendor.name
def test_create_import_job_invalid_vendor(self, db, test_user):
- """Test import job creation with invalid vendor """
+ """Test import job creation with invalid vendor"""
request = MarketplaceImportJobRequest(
url="https://example.com/products.csv",
marketplace="Amazon",
@@ -108,7 +111,9 @@ class TestMarketplaceService:
assert exception.error_code == "VENDOR_NOT_FOUND"
assert "INVALID_VENDOR" in exception.message
- def test_create_import_job_unauthorized_access(self, db, test_vendor, test_user, other_user):
+ def test_create_import_job_unauthorized_access(
+ self, db, test_vendor, test_user, other_user
+ ):
"""Test import job creation with unauthorized vendor access"""
# Set the vendor owner to a different user
test_vendor.owner_user_id = other_user.id
@@ -127,7 +132,9 @@ class TestMarketplaceService:
exception = exc_info.value
assert exception.error_code == "UNAUTHORIZED_VENDOR_ACCESS"
- def test_get_import_job_by_id_success(self, db, test_marketplace_import_job, test_user):
+ def test_get_import_job_by_id_success(
+ self, db, test_marketplace_import_job, test_user
+ ):
"""Test getting import job by ID for job owner"""
result = self.service.get_import_job_by_id(
db, test_marketplace_import_job.id, test_user
@@ -161,14 +168,18 @@ class TestMarketplaceService:
):
"""Test access denied when user doesn't own the job"""
with pytest.raises(ImportJobNotOwnedException) as exc_info:
- self.service.get_import_job_by_id(db, test_marketplace_import_job.id, other_user)
+ self.service.get_import_job_by_id(
+ db, test_marketplace_import_job.id, other_user
+ )
exception = exc_info.value
assert exception.error_code == "IMPORT_JOB_NOT_OWNED"
assert exception.status_code == 403
assert str(test_marketplace_import_job.id) in exception.message
- def test_get_import_jobs_user_filter(self, db, test_marketplace_import_job, test_user):
+ def test_get_import_jobs_user_filter(
+ self, db, test_marketplace_import_job, test_user
+ ):
"""Test getting import jobs filtered by user"""
jobs = self.service.get_import_jobs(db, test_user)
@@ -176,7 +187,9 @@ class TestMarketplaceService:
assert any(job.id == test_marketplace_import_job.id for job in jobs)
assert test_marketplace_import_job.user_id == test_user.id
- def test_get_import_jobs_admin_sees_all(self, db, test_marketplace_import_job, test_admin):
+ def test_get_import_jobs_admin_sees_all(
+ self, db, test_marketplace_import_job, test_admin
+ ):
"""Test that admin sees all import jobs"""
jobs = self.service.get_import_jobs(db, test_admin)
@@ -192,7 +205,9 @@ class TestMarketplaceService:
)
assert len(jobs) >= 1
- assert any(job.marketplace == test_marketplace_import_job.marketplace for job in jobs)
+ assert any(
+ job.marketplace == test_marketplace_import_job.marketplace for job in jobs
+ )
def test_get_import_jobs_with_pagination(self, db, test_user, test_vendor):
"""Test getting import jobs with pagination"""
@@ -330,10 +345,14 @@ class TestMarketplaceService:
exception = exc_info.value
assert exception.error_code == "IMPORT_JOB_NOT_FOUND"
- def test_cancel_import_job_access_denied(self, db, test_marketplace_import_job, other_user):
+ def test_cancel_import_job_access_denied(
+ self, db, test_marketplace_import_job, other_user
+ ):
"""Test cancelling import job without access"""
with pytest.raises(ImportJobNotOwnedException) as exc_info:
- self.service.cancel_import_job(db, test_marketplace_import_job.id, other_user)
+ self.service.cancel_import_job(
+ db, test_marketplace_import_job.id, other_user
+ )
exception = exc_info.value
assert exception.error_code == "IMPORT_JOB_NOT_OWNED"
@@ -347,7 +366,9 @@ class TestMarketplaceService:
db.commit()
with pytest.raises(ImportJobCannotBeCancelledException) as exc_info:
- self.service.cancel_import_job(db, test_marketplace_import_job.id, test_user)
+ self.service.cancel_import_job(
+ db, test_marketplace_import_job.id, test_user
+ )
exception = exc_info.value
assert exception.error_code == "IMPORT_JOB_CANNOT_BE_CANCELLED"
@@ -396,10 +417,14 @@ class TestMarketplaceService:
exception = exc_info.value
assert exception.error_code == "IMPORT_JOB_NOT_FOUND"
- def test_delete_import_job_access_denied(self, db, test_marketplace_import_job, other_user):
+ def test_delete_import_job_access_denied(
+ self, db, test_marketplace_import_job, other_user
+ ):
"""Test deleting import job without access"""
with pytest.raises(ImportJobNotOwnedException) as exc_info:
- self.service.delete_import_job(db, test_marketplace_import_job.id, other_user)
+ self.service.delete_import_job(
+ db, test_marketplace_import_job.id, other_user
+ )
exception = exc_info.value
assert exception.error_code == "IMPORT_JOB_NOT_OWNED"
@@ -440,11 +465,15 @@ class TestMarketplaceService:
db.commit()
# Test with lowercase vendor code
- result = self.service.validate_vendor_access(db, test_vendor.vendor_code.lower(), test_user)
+ result = self.service.validate_vendor_access(
+ db, test_vendor.vendor_code.lower(), test_user
+ )
assert result.vendor_code == test_vendor.vendor_code
# Test with uppercase vendor code
- result = self.service.validate_vendor_access(db, test_vendor.vendor_code.upper(), test_user)
+ result = self.service.validate_vendor_access(
+ db, test_vendor.vendor_code.upper(), test_user
+ )
assert result.vendor_code == test_vendor.vendor_code
def test_create_import_job_database_error(self, db_with_error, test_user):
diff --git a/tests/unit/services/test_product_service.py b/tests/unit/services/test_product_service.py
index 1a8f9367..1028b7a1 100644
--- a/tests/unit/services/test_product_service.py
+++ b/tests/unit/services/test_product_service.py
@@ -1,16 +1,15 @@
# tests/test_product_service.py
import pytest
+from app.exceptions import (InvalidMarketplaceProductDataException,
+ MarketplaceProductAlreadyExistsException,
+ MarketplaceProductNotFoundException,
+ MarketplaceProductValidationException,
+ ValidationException)
from app.services.marketplace_product_service import MarketplaceProductService
-from app.exceptions import (
- MarketplaceProductNotFoundException,
- MarketplaceProductAlreadyExistsException,
- InvalidMarketplaceProductDataException,
- MarketplaceProductValidationException,
- ValidationException,
-)
-from models.schema.marketplace_product import MarketplaceProductCreate, MarketplaceProductUpdate
from models.database.marketplace_product import MarketplaceProduct
+from models.schema.marketplace_product import (MarketplaceProductCreate,
+ MarketplaceProductUpdate)
@pytest.mark.unit
@@ -98,7 +97,10 @@ class TestProductService:
assert exc_info.value.error_code == "PRODUCT_ALREADY_EXISTS"
assert test_marketplace_product.marketplace_product_id in str(exc_info.value)
assert exc_info.value.status_code == 409
- assert exc_info.value.details.get("marketplace_product_id") == test_marketplace_product.marketplace_product_id
+ assert (
+ exc_info.value.details.get("marketplace_product_id")
+ == test_marketplace_product.marketplace_product_id
+ )
def test_create_product_invalid_price(self, db):
"""Test product creation with invalid price raises InvalidMarketplaceProductDataException"""
@@ -117,9 +119,14 @@ class TestProductService:
def test_get_product_by_id_or_raise_success(self, db, test_marketplace_product):
"""Test successful product retrieval by ID"""
- product = self.service.get_product_by_id_or_raise(db, test_marketplace_product.marketplace_product_id)
+ product = self.service.get_product_by_id_or_raise(
+ db, test_marketplace_product.marketplace_product_id
+ )
- assert product.marketplace_product_id == test_marketplace_product.marketplace_product_id
+ assert (
+ product.marketplace_product_id
+ == test_marketplace_product.marketplace_product_id
+ )
assert product.title == test_marketplace_product.title
def test_get_product_by_id_or_raise_not_found(self, db):
@@ -152,21 +159,35 @@ class TestProductService:
assert total >= 1
assert len(products) >= 1
# Verify search worked by checking that title contains search term
- found_product = next((p for p in products if p.marketplace_product_id == test_marketplace_product.marketplace_product_id), None)
+ found_product = next(
+ (
+ p
+ for p in products
+ if p.marketplace_product_id
+ == test_marketplace_product.marketplace_product_id
+ ),
+ None,
+ )
assert found_product is not None
def test_update_product_success(self, db, test_marketplace_product):
"""Test successful product update"""
update_data = MarketplaceProductUpdate(
- title="Updated MarketplaceProduct Title",
- price="39.99"
+ title="Updated MarketplaceProduct Title", price="39.99"
)
- updated_product = self.service.update_product(db, test_marketplace_product.marketplace_product_id, update_data)
+ updated_product = self.service.update_product(
+ db, test_marketplace_product.marketplace_product_id, update_data
+ )
assert updated_product.title == "Updated MarketplaceProduct Title"
- assert updated_product.price == "39.99" # Price is stored as string after processing
- assert updated_product.marketplace_product_id == test_marketplace_product.marketplace_product_id # ID unchanged
+ assert (
+ updated_product.price == "39.99"
+ ) # Price is stored as string after processing
+ assert (
+ updated_product.marketplace_product_id
+ == test_marketplace_product.marketplace_product_id
+ ) # ID unchanged
def test_update_product_not_found(self, db):
"""Test updating non-existent product raises MarketplaceProductNotFoundException"""
@@ -183,7 +204,9 @@ class TestProductService:
update_data = MarketplaceProductUpdate(gtin="invalid_gtin")
with pytest.raises(InvalidMarketplaceProductDataException) as exc_info:
- self.service.update_product(db, test_marketplace_product.marketplace_product_id, update_data)
+ self.service.update_product(
+ db, test_marketplace_product.marketplace_product_id, update_data
+ )
assert exc_info.value.error_code == "INVALID_PRODUCT_DATA"
assert "Invalid GTIN format" in str(exc_info.value)
@@ -194,7 +217,9 @@ class TestProductService:
update_data = MarketplaceProductUpdate(title="")
with pytest.raises(MarketplaceProductValidationException) as exc_info:
- self.service.update_product(db, test_marketplace_product.marketplace_product_id, update_data)
+ self.service.update_product(
+ db, test_marketplace_product.marketplace_product_id, update_data
+ )
assert exc_info.value.error_code == "PRODUCT_VALIDATION_FAILED"
assert "MarketplaceProduct title cannot be empty" in str(exc_info.value)
@@ -205,7 +230,9 @@ class TestProductService:
update_data = MarketplaceProductUpdate(price="invalid_price")
with pytest.raises(InvalidMarketplaceProductDataException) as exc_info:
- self.service.update_product(db, test_marketplace_product.marketplace_product_id, update_data)
+ self.service.update_product(
+ db, test_marketplace_product.marketplace_product_id, update_data
+ )
assert exc_info.value.error_code == "INVALID_PRODUCT_DATA"
assert "Invalid price format" in str(exc_info.value)
@@ -213,12 +240,16 @@ class TestProductService:
def test_delete_product_success(self, db, test_marketplace_product):
"""Test successful product deletion"""
- result = self.service.delete_product(db, test_marketplace_product.marketplace_product_id)
+ result = self.service.delete_product(
+ db, test_marketplace_product.marketplace_product_id
+ )
assert result is True
# Verify product is deleted
- deleted_product = self.service.get_product_by_id(db, test_marketplace_product.marketplace_product_id)
+ deleted_product = self.service.get_product_by_id(
+ db, test_marketplace_product.marketplace_product_id
+ )
assert deleted_product is None
def test_delete_product_not_found(self, db):
@@ -229,10 +260,14 @@ class TestProductService:
assert exc_info.value.error_code == "PRODUCT_NOT_FOUND"
assert "NONEXISTENT" in str(exc_info.value)
- def test_get_inventory_info_success(self, db, test_marketplace_product_with_inventory):
+ def test_get_inventory_info_success(
+ self, db, test_marketplace_product_with_inventory
+ ):
"""Test getting inventory info for product with inventory"""
# Extract the product from the dictionary
- marketplace_product = test_marketplace_product_with_inventory['marketplace_product']
+ marketplace_product = test_marketplace_product_with_inventory[
+ "marketplace_product"
+ ]
inventory_info = self.service.get_inventory_info(db, marketplace_product.gtin)
@@ -243,13 +278,17 @@ class TestProductService:
def test_get_inventory_info_no_inventory(self, db, test_marketplace_product):
"""Test getting inventory info for product without inventory"""
- inventory_info = self.service.get_inventory_info(db, test_marketplace_product.gtin or "1234567890123")
+ inventory_info = self.service.get_inventory_info(
+ db, test_marketplace_product.gtin or "1234567890123"
+ )
assert inventory_info is None
def test_product_exists_true(self, db, test_marketplace_product):
"""Test product_exists returns True for existing product"""
- exists = self.service.product_exists(db, test_marketplace_product.marketplace_product_id)
+ exists = self.service.product_exists(
+ db, test_marketplace_product.marketplace_product_id
+ )
assert exists is True
def test_product_exists_false(self, db):
@@ -265,7 +304,9 @@ class TestProductService:
csv_lines = list(csv_generator)
assert len(csv_lines) > 1 # Header + at least one data row
- assert csv_lines[0].startswith("marketplace_product_id,title,description") # Check header
+ assert csv_lines[0].startswith(
+ "marketplace_product_id,title,description"
+ ) # Check header
# Check that test product appears in CSV
csv_content = "".join(csv_lines)
@@ -274,8 +315,7 @@ class TestProductService:
def test_generate_csv_export_with_filters(self, db, test_marketplace_product):
"""Test CSV export with marketplace filter"""
csv_generator = self.service.generate_csv_export(
- db,
- marketplace=test_marketplace_product.marketplace
+ db, marketplace=test_marketplace_product.marketplace
)
csv_lines = list(csv_generator)
diff --git a/tests/unit/services/test_stats_service.py b/tests/unit/services/test_stats_service.py
index 55582672..249aedfa 100644
--- a/tests/unit/services/test_stats_service.py
+++ b/tests/unit/services/test_stats_service.py
@@ -2,8 +2,8 @@
import pytest
from app.services.stats_service import StatsService
-from models.database.marketplace_product import MarketplaceProduct
from models.database.inventory import Inventory
+from models.database.marketplace_product import MarketplaceProduct
@pytest.mark.unit
@@ -15,7 +15,9 @@ class TestStatsService:
"""Setup method following the same pattern as other service tests"""
self.service = StatsService()
- def test_get_comprehensive_stats_basic(self, db, test_marketplace_product, test_inventory):
+ def test_get_comprehensive_stats_basic(
+ self, db, test_marketplace_product, test_inventory
+ ):
"""Test getting comprehensive stats with basic data"""
stats = self.service.get_comprehensive_stats(db)
@@ -31,7 +33,9 @@ class TestStatsService:
assert stats["total_inventory_entries"] >= 1
assert stats["total_inventory_quantity"] >= 10 # test_inventory has quantity 10
- def test_get_comprehensive_stats_multiple_products(self, db, test_marketplace_product):
+ def test_get_comprehensive_stats_multiple_products(
+ self, db, test_marketplace_product
+ ):
"""Test comprehensive stats with multiple products across different dimensions"""
# Create products with different brands, categories, marketplaces
additional_products = [
@@ -87,7 +91,7 @@ class TestStatsService:
brand=None, # Null brand
google_product_category=None, # Null category
marketplace=None, # Null marketplace
- vendor_name=None, # Null vendor
+ vendor_name=None, # Null vendor
price="10.00",
currency="EUR",
),
@@ -97,7 +101,7 @@ class TestStatsService:
brand="", # Empty brand
google_product_category="", # Empty category
marketplace="", # Empty marketplace
- vendor_name="", # Empty vendor
+ vendor_name="", # Empty vendor
price="15.00",
currency="EUR",
),
@@ -124,7 +128,11 @@ class TestStatsService:
# Find our test marketplace in the results
test_marketplace_stat = next(
- (stat for stat in stats if stat["marketplace"] == test_marketplace_product.marketplace),
+ (
+ stat
+ for stat in stats
+ if stat["marketplace"] == test_marketplace_product.marketplace
+ ),
None,
)
assert test_marketplace_stat is not None
@@ -309,7 +317,9 @@ class TestStatsService:
count = self.service._get_unique_marketplaces_count(db)
- assert count >= 2 # At least Amazon and eBay, plus test_marketplace_product marketplace
+ assert (
+ count >= 2
+ ) # At least Amazon and eBay, plus test_marketplace_product marketplace
assert isinstance(count, int)
def test_get_unique_vendors_count(self, db, test_marketplace_product):
@@ -338,7 +348,9 @@ class TestStatsService:
count = self.service._get_unique_vendors_count(db)
- assert count >= 2 # At least VendorA and VendorB, plus test_marketplace_product vendor
+ assert (
+ count >= 2
+ ) # At least VendorA and VendorB, plus test_marketplace_product vendor
assert isinstance(count, int)
def test_get_inventory_statistics(self, db, test_inventory):
@@ -438,7 +450,7 @@ class TestStatsService:
db.add_all(marketplace_products)
db.commit()
- vendors =self.service._get_vendors_by_marketplace(db, "TestMarketplace")
+ vendors = self.service._get_vendors_by_marketplace(db, "TestMarketplace")
assert len(vendors) == 2
assert "TestVendor1" in vendors
@@ -482,7 +494,9 @@ class TestStatsService:
def test_get_products_by_marketplace_not_found(self, db):
"""Test getting product count for non-existent marketplace"""
- count = self.service._get_products_by_marketplace_count(db, "NonExistentMarketplace")
+ count = self.service._get_products_by_marketplace_count(
+ db, "NonExistentMarketplace"
+ )
assert count == 0
diff --git a/tests/unit/services/test_vendor_service.py b/tests/unit/services/test_vendor_service.py
index c0a1bc1a..286ea923 100644
--- a/tests/unit/services/test_vendor_service.py
+++ b/tests/unit/services/test_vendor_service.py
@@ -1,19 +1,16 @@
# tests/test_vendor_service.py (updated to use custom exceptions)
import pytest
+from app.exceptions import (InvalidVendorDataException,
+ MarketplaceProductNotFoundException,
+ MaxVendorsReachedException,
+ ProductAlreadyExistsException,
+ UnauthorizedVendorAccessException,
+ ValidationException, VendorAlreadyExistsException,
+ VendorNotFoundException)
from app.services.vendor_service import VendorService
-from app.exceptions import (
- VendorNotFoundException,
- VendorAlreadyExistsException,
- UnauthorizedVendorAccessException,
- InvalidVendorDataException,
- MarketplaceProductNotFoundException,
- ProductAlreadyExistsException,
- MaxVendorsReachedException,
- ValidationException,
-)
-from models.schema.vendor import VendorCreate
from models.schema.product import ProductCreate
+from models.schema.vendor import VendorCreate
@pytest.mark.unit
@@ -38,15 +35,17 @@ class TestVendorService:
assert vendor is not None
assert vendor.vendor_code == "NEWVENDOR"
assert vendor.owner_user_id == test_user.id
- assert vendor.is_verified is False # Regular user creates unverified vendor
+ assert vendor.is_verified is False # Regular user creates unverified vendor
def test_create_vendor_admin_auto_verify(self, db, test_admin, vendor_factory):
"""Test admin creates verified vendor automatically"""
- vendor_data = VendorCreate(vendor_code="ADMINVENDOR", vendor_name="Admin Test Vendor")
+ vendor_data = VendorCreate(
+ vendor_code="ADMINVENDOR", vendor_name="Admin Test Vendor"
+ )
vendor = self.service.create_vendor(db, vendor_data, test_admin)
- assert vendor.is_verified is True # Admin creates verified vendor
+ assert vendor.is_verified is True # Admin creates verified vendor
def test_create_vendor_duplicate_code(self, db, test_user, test_vendor):
"""Test vendor creation fails with duplicate vendor code"""
@@ -88,7 +87,9 @@ class TestVendorService:
def test_create_vendor_invalid_code_format(self, db, test_user):
"""Test vendor creation fails with invalid vendor code format"""
- vendor_data = VendorCreate(vendor_code="INVALID@CODE!", vendor_name="Test Vendor")
+ vendor_data = VendorCreate(
+ vendor_code="INVALID@CODE!", vendor_name="Test Vendor"
+ )
with pytest.raises(InvalidVendorDataException) as exc_info:
self.service.create_vendor(db, vendor_data, test_user)
@@ -105,7 +106,9 @@ class TestVendorService:
def mock_check_vendor_limit(self, db, user):
raise MaxVendorsReachedException(max_vendors=5, user_id=user.id)
- monkeypatch.setattr(VendorService, "_check_vendor_limit", mock_check_vendor_limit)
+ monkeypatch.setattr(
+ VendorService, "_check_vendor_limit", mock_check_vendor_limit
+ )
vendor_data = VendorCreate(vendor_code="NEWVENDOR", vendor_name="New Vendor")
@@ -118,7 +121,9 @@ class TestVendorService:
assert exception.details["max_vendors"] == 5
assert exception.details["user_id"] == test_user.id
- def test_get_vendors_regular_user(self, db, test_user, test_vendor, inactive_vendor):
+ def test_get_vendors_regular_user(
+ self, db, test_user, test_vendor, inactive_vendor
+ ):
"""Test regular user can only see active verified vendors and own vendors"""
vendors, total = self.service.get_vendors(db, test_user, skip=0, limit=10)
@@ -127,7 +132,7 @@ class TestVendorService:
assert inactive_vendor.vendor_code not in vendor_codes
def test_get_vendors_admin_user(
- self, db, test_admin, test_vendor, inactive_vendor, verified_vendor
+ self, db, test_admin, test_vendor, inactive_vendor, verified_vendor
):
"""Test admin user can see all vendors with filters"""
vendors, total = self.service.get_vendors(
@@ -140,14 +145,16 @@ class TestVendorService:
assert verified_vendor.vendor_code in vendor_codes
def test_get_vendor_by_code_owner_access(self, db, test_user, test_vendor):
- """Test vendor owner can access their own vendor """
- vendor = self.service.get_vendor_by_code(db, test_vendor.vendor_code.lower(), test_user)
+ """Test vendor owner can access their own vendor"""
+ vendor = self.service.get_vendor_by_code(
+ db, test_vendor.vendor_code.lower(), test_user
+ )
assert vendor is not None
assert vendor.id == test_vendor.id
def test_get_vendor_by_code_admin_access(self, db, test_admin, test_vendor):
- """Test admin can access any vendor """
+ """Test admin can access any vendor"""
vendor = self.service.get_vendor_by_code(
db, test_vendor.vendor_code.lower(), test_admin
)
@@ -178,16 +185,14 @@ class TestVendorService:
assert exception.details["user_id"] == test_user.id
def test_add_product_to_vendor_success(self, db, test_vendor, unique_product):
- """Test successfully adding product to vendor """
+ """Test successfully adding product to vendor"""
product_data = ProductCreate(
marketplace_product_id=unique_product.marketplace_product_id,
price="15.99",
is_featured=True,
)
- product = self.service.add_product_to_catalog(
- db, test_vendor, product_data
- )
+ product = self.service.add_product_to_catalog(db, test_vendor, product_data)
assert product is not None
assert product.vendor_id == test_vendor.id
@@ -195,7 +200,9 @@ class TestVendorService:
def test_add_product_to_vendor_product_not_found(self, db, test_vendor):
"""Test adding non-existent product to vendor fails"""
- product_data = ProductCreate(marketplace_product_id="NONEXISTENT", price="15.99")
+ product_data = ProductCreate(
+ marketplace_product_id="NONEXISTENT", price="15.99"
+ )
with pytest.raises(MarketplaceProductNotFoundException) as exc_info:
self.service.add_product_to_catalog(db, test_vendor, product_data)
@@ -209,7 +216,8 @@ class TestVendorService:
def test_add_product_to_vendor_already_exists(self, db, test_vendor, test_product):
"""Test adding product that's already in vendor fails"""
product_data = ProductCreate(
- marketplace_product_id=test_product.marketplace_product.marketplace_product_id, price="15.99"
+ marketplace_product_id=test_product.marketplace_product.marketplace_product_id,
+ price="15.99",
)
with pytest.raises(ProductAlreadyExistsException) as exc_info:
@@ -219,11 +227,12 @@ class TestVendorService:
assert exception.status_code == 409
assert exception.error_code == "PRODUCT_ALREADY_EXISTS"
assert exception.details["vendor_code"] == test_vendor.vendor_code
- assert exception.details["marketplace_product_id"] == test_product.marketplace_product.marketplace_product_id
+ assert (
+ exception.details["marketplace_product_id"]
+ == test_product.marketplace_product.marketplace_product_id
+ )
- def test_get_products_owner_access(
- self, db, test_user, test_vendor, test_product
- ):
+ def test_get_products_owner_access(self, db, test_user, test_vendor, test_product):
"""Test vendor owner can get vendor products"""
products, total = self.service.get_products(db, test_vendor, test_user)
@@ -291,7 +300,9 @@ class TestVendorService:
assert exception.error_code == "VALIDATION_ERROR"
assert "Failed to retrieve vendors" in exception.message
- def test_add_product_database_error(self, db, test_vendor, unique_product, monkeypatch):
+ def test_add_product_database_error(
+ self, db, test_vendor, unique_product, monkeypatch
+ ):
"""Test add product handles database errors gracefully"""
def mock_commit():
diff --git a/tests/unit/utils/test_csv_processor.py b/tests/unit/utils/test_csv_processor.py
index 4f35ee16..85615a10 100644
--- a/tests/unit/utils/test_csv_processor.py
+++ b/tests/unit/utils/test_csv_processor.py
@@ -18,7 +18,9 @@ class TestCSVProcessor:
def test_download_csv_encoding_fallback(self, mock_get):
"""Test CSV download with encoding fallback"""
# Create content with special characters that would fail UTF-8 if not properly encoded
- special_content = "marketplace_product_id,title,price\nTEST001,Café MarketplaceProduct,10.99"
+ special_content = (
+ "marketplace_product_id,title,price\nTEST001,Café MarketplaceProduct,10.99"
+ )
mock_response = Mock()
mock_response.status_code = 200
@@ -40,9 +42,7 @@ class TestCSVProcessor:
mock_response = Mock()
mock_response.status_code = 200
# Create bytes that will fail most encodings
- mock_response.content = (
- b"marketplace_product_id,title,price\nTEST001,\xff\xfe MarketplaceProduct,10.99"
- )
+ mock_response.content = b"marketplace_product_id,title,price\nTEST001,\xff\xfe MarketplaceProduct,10.99"
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response