style: apply black and isort formatting across entire codebase

- Standardize quote style (single to double quotes)
- Reorder and group imports alphabetically
- Fix line breaks and indentation for consistency
- Apply PEP 8 formatting standards

Also updated Makefile to exclude both venv and .venv from code quality checks.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
2025-11-28 19:30:17 +01:00
parent 13f0094743
commit 21c13ca39b
236 changed files with 8450 additions and 6545 deletions

View File

@@ -176,19 +176,19 @@ test-inventory:
format: format:
@echo "Running black..." @echo "Running black..."
$(PYTHON) -m black . --exclude venv $(PYTHON) -m black . --exclude '/(\.)?venv/'
@echo "Running isort..." @echo "Running isort..."
$(PYTHON) -m isort . --skip venv $(PYTHON) -m isort . --skip venv --skip .venv
lint: lint:
@echo "Running linting..." @echo "Running linting..."
$(PYTHON) -m ruff check . --exclude venv $(PYTHON) -m ruff check . --exclude venv --exclude .venv
$(PYTHON) -m mypy . --ignore-missing-imports --exclude venv $(PYTHON) -m mypy . --ignore-missing-imports --exclude '.*(\.)?venv.*'
lint-flake8: lint-flake8:
@echo "Running linting..." @echo "Running linting..."
$(PYTHON) -m flake8 . --max-line-length=120 --extend-ignore=E203,W503,I201,I100 --exclude=venv,__pycache__,.git $(PYTHON) -m flake8 . --max-line-length=120 --extend-ignore=E203,W503,I201,I100 --exclude=venv,.venv,__pycache__,.git
$(PYTHON) -m mypy . --ignore-missing-imports --exclude venv $(PYTHON) -m mypy . --ignore-missing-imports --exclude '.*(\.)?venv.*'
check: format lint check: format lint

View File

@@ -15,6 +15,7 @@ import sys
from logging.config import fileConfig from logging.config import fileConfig
from sqlalchemy import engine_from_config, pool from sqlalchemy import engine_from_config, pool
from alembic import context from alembic import context
# Add your project directory to the Python path # Add your project directory to the Python path
@@ -39,13 +40,9 @@ print("=" * 70)
# ADMIN MODELS # ADMIN MODELS
# ---------------------------------------------------------------------------- # ----------------------------------------------------------------------------
try: try:
from models.database.admin import ( from models.database.admin import (AdminAuditLog, AdminNotification,
AdminAuditLog, AdminSession, AdminSetting,
AdminNotification, PlatformAlert)
AdminSetting,
PlatformAlert,
AdminSession
)
print(" ✓ Admin models imported (5 models)") print(" ✓ Admin models imported (5 models)")
print(" - AdminAuditLog") print(" - AdminAuditLog")
@@ -70,7 +67,7 @@ except ImportError as e:
# VENDOR MODELS # VENDOR MODELS
# ---------------------------------------------------------------------------- # ----------------------------------------------------------------------------
try: try:
from models.database.vendor import Vendor, VendorUser, Role from models.database.vendor import Role, Vendor, VendorUser
print(" ✓ Vendor models imported (3 models)") print(" ✓ Vendor models imported (3 models)")
print(" - Vendor") print(" - Vendor")
@@ -248,10 +245,7 @@ def run_migrations_online() -> None:
) )
with connectable.connect() as connection: with connectable.connect() as connection:
context.configure( context.configure(connection=connection, target_metadata=target_metadata)
connection=connection,
target_metadata=target_metadata
)
with context.begin_transaction(): with context.begin_transaction():
context.run_migrations() context.run_migrations()

File diff suppressed because it is too large Load Diff

View File

@@ -5,59 +5,72 @@ Revises: fef1d20ce8b4
Create Date: 2025-11-22 15:16:13.213613 Create Date: 2025-11-22 15:16:13.213613
""" """
from typing import Sequence, Union from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = '72aa309d4007' revision: str = "72aa309d4007"
down_revision: Union[str, None] = 'fef1d20ce8b4' down_revision: Union[str, None] = "fef1d20ce8b4"
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.create_table('content_pages', op.create_table(
sa.Column('id', sa.Integer(), nullable=False), "content_pages",
sa.Column('vendor_id', sa.Integer(), nullable=True), sa.Column("id", sa.Integer(), nullable=False),
sa.Column('slug', sa.String(length=100), nullable=False), sa.Column("vendor_id", sa.Integer(), nullable=True),
sa.Column('title', sa.String(length=200), nullable=False), sa.Column("slug", sa.String(length=100), nullable=False),
sa.Column('content', sa.Text(), nullable=False), sa.Column("title", sa.String(length=200), nullable=False),
sa.Column('content_format', sa.String(length=20), nullable=True), sa.Column("content", sa.Text(), nullable=False),
sa.Column('meta_description', sa.String(length=300), nullable=True), sa.Column("content_format", sa.String(length=20), nullable=True),
sa.Column('meta_keywords', sa.String(length=300), nullable=True), sa.Column("meta_description", sa.String(length=300), nullable=True),
sa.Column('is_published', sa.Boolean(), nullable=False), sa.Column("meta_keywords", sa.String(length=300), nullable=True),
sa.Column('published_at', sa.DateTime(timezone=True), nullable=True), sa.Column("is_published", sa.Boolean(), nullable=False),
sa.Column('display_order', sa.Integer(), nullable=True), sa.Column("published_at", sa.DateTime(timezone=True), nullable=True),
sa.Column('show_in_footer', sa.Boolean(), nullable=True), sa.Column("display_order", sa.Integer(), nullable=True),
sa.Column('show_in_header', sa.Boolean(), nullable=True), sa.Column("show_in_footer", sa.Boolean(), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), sa.Column("show_in_header", sa.Boolean(), nullable=True),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False), sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column('created_by', sa.Integer(), nullable=True), sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_by', sa.Integer(), nullable=True), sa.Column("created_by", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['created_by'], ['users.id'], ondelete='SET NULL'), sa.Column("updated_by", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['updated_by'], ['users.id'], ondelete='SET NULL'), sa.ForeignKeyConstraint(["created_by"], ["users.id"], ondelete="SET NULL"),
sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ondelete='CASCADE'), sa.ForeignKeyConstraint(["updated_by"], ["users.id"], ondelete="SET NULL"),
sa.PrimaryKeyConstraint('id'), sa.ForeignKeyConstraint(["vendor_id"], ["vendors.id"], ondelete="CASCADE"),
sa.UniqueConstraint('vendor_id', 'slug', name='uq_vendor_slug') sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("vendor_id", "slug", name="uq_vendor_slug"),
)
op.create_index(
"idx_slug_published", "content_pages", ["slug", "is_published"], unique=False
)
op.create_index(
"idx_vendor_published",
"content_pages",
["vendor_id", "is_published"],
unique=False,
)
op.create_index(op.f("ix_content_pages_id"), "content_pages", ["id"], unique=False)
op.create_index(
op.f("ix_content_pages_slug"), "content_pages", ["slug"], unique=False
)
op.create_index(
op.f("ix_content_pages_vendor_id"), "content_pages", ["vendor_id"], unique=False
) )
op.create_index('idx_slug_published', 'content_pages', ['slug', 'is_published'], unique=False)
op.create_index('idx_vendor_published', 'content_pages', ['vendor_id', 'is_published'], unique=False)
op.create_index(op.f('ix_content_pages_id'), 'content_pages', ['id'], unique=False)
op.create_index(op.f('ix_content_pages_slug'), 'content_pages', ['slug'], unique=False)
op.create_index(op.f('ix_content_pages_vendor_id'), 'content_pages', ['vendor_id'], unique=False)
# ### end Alembic commands ### # ### end Alembic commands ###
def downgrade() -> None: def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_content_pages_vendor_id'), table_name='content_pages') op.drop_index(op.f("ix_content_pages_vendor_id"), table_name="content_pages")
op.drop_index(op.f('ix_content_pages_slug'), table_name='content_pages') op.drop_index(op.f("ix_content_pages_slug"), table_name="content_pages")
op.drop_index(op.f('ix_content_pages_id'), table_name='content_pages') op.drop_index(op.f("ix_content_pages_id"), table_name="content_pages")
op.drop_index('idx_vendor_published', table_name='content_pages') op.drop_index("idx_vendor_published", table_name="content_pages")
op.drop_index('idx_slug_published', table_name='content_pages') op.drop_index("idx_slug_published", table_name="content_pages")
op.drop_table('content_pages') op.drop_table("content_pages")
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@@ -5,15 +5,17 @@ Revises: a2064e1dfcd4
Create Date: 2025-11-28 09:21:16.545203 Create Date: 2025-11-28 09:21:16.545203
""" """
from typing import Sequence, Union from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = '7a7ce92593d5' revision: str = "7a7ce92593d5"
down_revision: Union[str, None] = 'a2064e1dfcd4' down_revision: Union[str, None] = "a2064e1dfcd4"
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
@@ -21,127 +23,269 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
# Create architecture_scans table # Create architecture_scans table
op.create_table( op.create_table(
'architecture_scans', "architecture_scans",
sa.Column('id', sa.Integer(), nullable=False), sa.Column("id", sa.Integer(), nullable=False),
sa.Column('timestamp', sa.DateTime(timezone=True), server_default=sa.text("(datetime('now'))"), nullable=False), sa.Column(
sa.Column('total_files', sa.Integer(), nullable=True), "timestamp",
sa.Column('total_violations', sa.Integer(), nullable=True), sa.DateTime(timezone=True),
sa.Column('errors', sa.Integer(), nullable=True), server_default=sa.text("(datetime('now'))"),
sa.Column('warnings', sa.Integer(), nullable=True), nullable=False,
sa.Column('duration_seconds', sa.Float(), nullable=True), ),
sa.Column('triggered_by', sa.String(length=100), nullable=True), sa.Column("total_files", sa.Integer(), nullable=True),
sa.Column('git_commit_hash', sa.String(length=40), nullable=True), sa.Column("total_violations", sa.Integer(), nullable=True),
sa.PrimaryKeyConstraint('id') sa.Column("errors", sa.Integer(), nullable=True),
sa.Column("warnings", sa.Integer(), nullable=True),
sa.Column("duration_seconds", sa.Float(), nullable=True),
sa.Column("triggered_by", sa.String(length=100), nullable=True),
sa.Column("git_commit_hash", sa.String(length=40), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_architecture_scans_id"), "architecture_scans", ["id"], unique=False
)
op.create_index(
op.f("ix_architecture_scans_timestamp"),
"architecture_scans",
["timestamp"],
unique=False,
) )
op.create_index(op.f('ix_architecture_scans_id'), 'architecture_scans', ['id'], unique=False)
op.create_index(op.f('ix_architecture_scans_timestamp'), 'architecture_scans', ['timestamp'], unique=False)
# Create architecture_rules table # Create architecture_rules table
op.create_table( op.create_table(
'architecture_rules', "architecture_rules",
sa.Column('id', sa.Integer(), nullable=False), sa.Column("id", sa.Integer(), nullable=False),
sa.Column('rule_id', sa.String(length=20), nullable=False), sa.Column("rule_id", sa.String(length=20), nullable=False),
sa.Column('category', sa.String(length=50), nullable=False), sa.Column("category", sa.String(length=50), nullable=False),
sa.Column('name', sa.String(length=200), nullable=False), sa.Column("name", sa.String(length=200), nullable=False),
sa.Column('description', sa.Text(), nullable=True), sa.Column("description", sa.Text(), nullable=True),
sa.Column('severity', sa.String(length=10), nullable=False), sa.Column("severity", sa.String(length=10), nullable=False),
sa.Column('enabled', sa.Boolean(), nullable=False, server_default='1'), sa.Column("enabled", sa.Boolean(), nullable=False, server_default="1"),
sa.Column('custom_config', sa.JSON(), nullable=True), sa.Column("custom_config", sa.JSON(), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text("(datetime('now'))"), nullable=False), sa.Column(
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text("(datetime('now'))"), nullable=False), "created_at",
sa.PrimaryKeyConstraint('id'), sa.DateTime(timezone=True),
sa.UniqueConstraint('rule_id') server_default=sa.text("(datetime('now'))"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("(datetime('now'))"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("rule_id"),
)
op.create_index(
op.f("ix_architecture_rules_id"), "architecture_rules", ["id"], unique=False
)
op.create_index(
op.f("ix_architecture_rules_rule_id"),
"architecture_rules",
["rule_id"],
unique=True,
) )
op.create_index(op.f('ix_architecture_rules_id'), 'architecture_rules', ['id'], unique=False)
op.create_index(op.f('ix_architecture_rules_rule_id'), 'architecture_rules', ['rule_id'], unique=True)
# Create architecture_violations table # Create architecture_violations table
op.create_table( op.create_table(
'architecture_violations', "architecture_violations",
sa.Column('id', sa.Integer(), nullable=False), sa.Column("id", sa.Integer(), nullable=False),
sa.Column('scan_id', sa.Integer(), nullable=False), sa.Column("scan_id", sa.Integer(), nullable=False),
sa.Column('rule_id', sa.String(length=20), nullable=False), sa.Column("rule_id", sa.String(length=20), nullable=False),
sa.Column('rule_name', sa.String(length=200), nullable=False), sa.Column("rule_name", sa.String(length=200), nullable=False),
sa.Column('severity', sa.String(length=10), nullable=False), sa.Column("severity", sa.String(length=10), nullable=False),
sa.Column('file_path', sa.String(length=500), nullable=False), sa.Column("file_path", sa.String(length=500), nullable=False),
sa.Column('line_number', sa.Integer(), nullable=False), sa.Column("line_number", sa.Integer(), nullable=False),
sa.Column('message', sa.Text(), nullable=False), sa.Column("message", sa.Text(), nullable=False),
sa.Column('context', sa.Text(), nullable=True), sa.Column("context", sa.Text(), nullable=True),
sa.Column('suggestion', sa.Text(), nullable=True), sa.Column("suggestion", sa.Text(), nullable=True),
sa.Column('status', sa.String(length=20), server_default='open', nullable=True), sa.Column("status", sa.String(length=20), server_default="open", nullable=True),
sa.Column('assigned_to', sa.Integer(), nullable=True), sa.Column("assigned_to", sa.Integer(), nullable=True),
sa.Column('resolved_at', sa.DateTime(timezone=True), nullable=True), sa.Column("resolved_at", sa.DateTime(timezone=True), nullable=True),
sa.Column('resolved_by', sa.Integer(), nullable=True), sa.Column("resolved_by", sa.Integer(), nullable=True),
sa.Column('resolution_note', sa.Text(), nullable=True), sa.Column("resolution_note", sa.Text(), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text("(datetime('now'))"), nullable=False), sa.Column(
sa.ForeignKeyConstraint(['assigned_to'], ['users.id'], ), "created_at",
sa.ForeignKeyConstraint(['resolved_by'], ['users.id'], ), sa.DateTime(timezone=True),
sa.ForeignKeyConstraint(['scan_id'], ['architecture_scans.id'], ), server_default=sa.text("(datetime('now'))"),
sa.PrimaryKeyConstraint('id') nullable=False,
),
sa.ForeignKeyConstraint(
["assigned_to"],
["users.id"],
),
sa.ForeignKeyConstraint(
["resolved_by"],
["users.id"],
),
sa.ForeignKeyConstraint(
["scan_id"],
["architecture_scans.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_architecture_violations_file_path"),
"architecture_violations",
["file_path"],
unique=False,
)
op.create_index(
op.f("ix_architecture_violations_id"),
"architecture_violations",
["id"],
unique=False,
)
op.create_index(
op.f("ix_architecture_violations_rule_id"),
"architecture_violations",
["rule_id"],
unique=False,
)
op.create_index(
op.f("ix_architecture_violations_scan_id"),
"architecture_violations",
["scan_id"],
unique=False,
)
op.create_index(
op.f("ix_architecture_violations_severity"),
"architecture_violations",
["severity"],
unique=False,
)
op.create_index(
op.f("ix_architecture_violations_status"),
"architecture_violations",
["status"],
unique=False,
) )
op.create_index(op.f('ix_architecture_violations_file_path'), 'architecture_violations', ['file_path'], unique=False)
op.create_index(op.f('ix_architecture_violations_id'), 'architecture_violations', ['id'], unique=False)
op.create_index(op.f('ix_architecture_violations_rule_id'), 'architecture_violations', ['rule_id'], unique=False)
op.create_index(op.f('ix_architecture_violations_scan_id'), 'architecture_violations', ['scan_id'], unique=False)
op.create_index(op.f('ix_architecture_violations_severity'), 'architecture_violations', ['severity'], unique=False)
op.create_index(op.f('ix_architecture_violations_status'), 'architecture_violations', ['status'], unique=False)
# Create violation_assignments table # Create violation_assignments table
op.create_table( op.create_table(
'violation_assignments', "violation_assignments",
sa.Column('id', sa.Integer(), nullable=False), sa.Column("id", sa.Integer(), nullable=False),
sa.Column('violation_id', sa.Integer(), nullable=False), sa.Column("violation_id", sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False), sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column('assigned_at', sa.DateTime(timezone=True), server_default=sa.text("(datetime('now'))"), nullable=False), sa.Column(
sa.Column('assigned_by', sa.Integer(), nullable=True), "assigned_at",
sa.Column('due_date', sa.DateTime(timezone=True), nullable=True), sa.DateTime(timezone=True),
sa.Column('priority', sa.String(length=10), server_default='medium', nullable=True), server_default=sa.text("(datetime('now'))"),
sa.ForeignKeyConstraint(['assigned_by'], ['users.id'], ), nullable=False,
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), ),
sa.ForeignKeyConstraint(['violation_id'], ['architecture_violations.id'], ), sa.Column("assigned_by", sa.Integer(), nullable=True),
sa.PrimaryKeyConstraint('id') sa.Column("due_date", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"priority", sa.String(length=10), server_default="medium", nullable=True
),
sa.ForeignKeyConstraint(
["assigned_by"],
["users.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.ForeignKeyConstraint(
["violation_id"],
["architecture_violations.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_violation_assignments_id"),
"violation_assignments",
["id"],
unique=False,
)
op.create_index(
op.f("ix_violation_assignments_violation_id"),
"violation_assignments",
["violation_id"],
unique=False,
) )
op.create_index(op.f('ix_violation_assignments_id'), 'violation_assignments', ['id'], unique=False)
op.create_index(op.f('ix_violation_assignments_violation_id'), 'violation_assignments', ['violation_id'], unique=False)
# Create violation_comments table # Create violation_comments table
op.create_table( op.create_table(
'violation_comments', "violation_comments",
sa.Column('id', sa.Integer(), nullable=False), sa.Column("id", sa.Integer(), nullable=False),
sa.Column('violation_id', sa.Integer(), nullable=False), sa.Column("violation_id", sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False), sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column('comment', sa.Text(), nullable=False), sa.Column("comment", sa.Text(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text("(datetime('now'))"), nullable=False), sa.Column(
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), "created_at",
sa.ForeignKeyConstraint(['violation_id'], ['architecture_violations.id'], ), sa.DateTime(timezone=True),
sa.PrimaryKeyConstraint('id') server_default=sa.text("(datetime('now'))"),
nullable=False,
),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.ForeignKeyConstraint(
["violation_id"],
["architecture_violations.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_violation_comments_id"), "violation_comments", ["id"], unique=False
)
op.create_index(
op.f("ix_violation_comments_violation_id"),
"violation_comments",
["violation_id"],
unique=False,
) )
op.create_index(op.f('ix_violation_comments_id'), 'violation_comments', ['id'], unique=False)
op.create_index(op.f('ix_violation_comments_violation_id'), 'violation_comments', ['violation_id'], unique=False)
def downgrade() -> None: def downgrade() -> None:
# Drop tables in reverse order (to respect foreign key constraints) # Drop tables in reverse order (to respect foreign key constraints)
op.drop_index(op.f('ix_violation_comments_violation_id'), table_name='violation_comments') op.drop_index(
op.drop_index(op.f('ix_violation_comments_id'), table_name='violation_comments') op.f("ix_violation_comments_violation_id"), table_name="violation_comments"
op.drop_table('violation_comments') )
op.drop_index(op.f("ix_violation_comments_id"), table_name="violation_comments")
op.drop_table("violation_comments")
op.drop_index(op.f('ix_violation_assignments_violation_id'), table_name='violation_assignments') op.drop_index(
op.drop_index(op.f('ix_violation_assignments_id'), table_name='violation_assignments') op.f("ix_violation_assignments_violation_id"),
op.drop_table('violation_assignments') table_name="violation_assignments",
)
op.drop_index(
op.f("ix_violation_assignments_id"), table_name="violation_assignments"
)
op.drop_table("violation_assignments")
op.drop_index(op.f('ix_architecture_violations_status'), table_name='architecture_violations') op.drop_index(
op.drop_index(op.f('ix_architecture_violations_severity'), table_name='architecture_violations') op.f("ix_architecture_violations_status"), table_name="architecture_violations"
op.drop_index(op.f('ix_architecture_violations_scan_id'), table_name='architecture_violations') )
op.drop_index(op.f('ix_architecture_violations_rule_id'), table_name='architecture_violations') op.drop_index(
op.drop_index(op.f('ix_architecture_violations_id'), table_name='architecture_violations') op.f("ix_architecture_violations_severity"),
op.drop_index(op.f('ix_architecture_violations_file_path'), table_name='architecture_violations') table_name="architecture_violations",
op.drop_table('architecture_violations') )
op.drop_index(
op.f("ix_architecture_violations_scan_id"), table_name="architecture_violations"
)
op.drop_index(
op.f("ix_architecture_violations_rule_id"), table_name="architecture_violations"
)
op.drop_index(
op.f("ix_architecture_violations_id"), table_name="architecture_violations"
)
op.drop_index(
op.f("ix_architecture_violations_file_path"),
table_name="architecture_violations",
)
op.drop_table("architecture_violations")
op.drop_index(op.f('ix_architecture_rules_rule_id'), table_name='architecture_rules') op.drop_index(
op.drop_index(op.f('ix_architecture_rules_id'), table_name='architecture_rules') op.f("ix_architecture_rules_rule_id"), table_name="architecture_rules"
op.drop_table('architecture_rules') )
op.drop_index(op.f("ix_architecture_rules_id"), table_name="architecture_rules")
op.drop_table("architecture_rules")
op.drop_index(op.f('ix_architecture_scans_timestamp'), table_name='architecture_scans') op.drop_index(
op.drop_index(op.f('ix_architecture_scans_id'), table_name='architecture_scans') op.f("ix_architecture_scans_timestamp"), table_name="architecture_scans"
op.drop_table('architecture_scans') )
op.drop_index(op.f("ix_architecture_scans_id"), table_name="architecture_scans")
op.drop_table("architecture_scans")

View File

@@ -5,15 +5,16 @@ Revises: f68d8da5315a
Create Date: 2025-11-23 19:52:40.509538 Create Date: 2025-11-23 19:52:40.509538
""" """
from typing import Sequence, Union from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = 'a2064e1dfcd4' revision: str = "a2064e1dfcd4"
down_revision: Union[str, None] = 'f68d8da5315a' down_revision: Union[str, None] = "f68d8da5315a"
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
@@ -21,34 +22,46 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
# Create cart_items table # Create cart_items table
op.create_table( op.create_table(
'cart_items', "cart_items",
sa.Column('id', sa.Integer(), nullable=False), sa.Column("id", sa.Integer(), nullable=False),
sa.Column('vendor_id', sa.Integer(), nullable=False), sa.Column("vendor_id", sa.Integer(), nullable=False),
sa.Column('product_id', sa.Integer(), nullable=False), sa.Column("product_id", sa.Integer(), nullable=False),
sa.Column('session_id', sa.String(length=255), nullable=False), sa.Column("session_id", sa.String(length=255), nullable=False),
sa.Column('quantity', sa.Integer(), nullable=False), sa.Column("quantity", sa.Integer(), nullable=False),
sa.Column('price_at_add', sa.Float(), nullable=False), sa.Column("price_at_add", sa.Float(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=True), sa.Column("created_at", sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True), sa.Column("updated_at", sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['product_id'], ['products.id'], ), sa.ForeignKeyConstraint(
sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ), ["product_id"],
sa.PrimaryKeyConstraint('id'), ["products.id"],
sa.UniqueConstraint('vendor_id', 'session_id', 'product_id', name='uq_cart_item') ),
sa.ForeignKeyConstraint(
["vendor_id"],
["vendors.id"],
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint(
"vendor_id", "session_id", "product_id", name="uq_cart_item"
),
) )
# Create indexes # Create indexes
op.create_index('idx_cart_session', 'cart_items', ['vendor_id', 'session_id'], unique=False) op.create_index(
op.create_index('idx_cart_created', 'cart_items', ['created_at'], unique=False) "idx_cart_session", "cart_items", ["vendor_id", "session_id"], unique=False
op.create_index(op.f('ix_cart_items_id'), 'cart_items', ['id'], unique=False) )
op.create_index(op.f('ix_cart_items_session_id'), 'cart_items', ['session_id'], unique=False) op.create_index("idx_cart_created", "cart_items", ["created_at"], unique=False)
op.create_index(op.f("ix_cart_items_id"), "cart_items", ["id"], unique=False)
op.create_index(
op.f("ix_cart_items_session_id"), "cart_items", ["session_id"], unique=False
)
def downgrade() -> None: def downgrade() -> None:
# Drop indexes # Drop indexes
op.drop_index(op.f('ix_cart_items_session_id'), table_name='cart_items') op.drop_index(op.f("ix_cart_items_session_id"), table_name="cart_items")
op.drop_index(op.f('ix_cart_items_id'), table_name='cart_items') op.drop_index(op.f("ix_cart_items_id"), table_name="cart_items")
op.drop_index('idx_cart_created', table_name='cart_items') op.drop_index("idx_cart_created", table_name="cart_items")
op.drop_index('idx_cart_session', table_name='cart_items') op.drop_index("idx_cart_session", table_name="cart_items")
# Drop table # Drop table
op.drop_table('cart_items') op.drop_table("cart_items")

View File

@@ -5,24 +5,30 @@ Revises: 72aa309d4007
Create Date: 2025-11-22 23:51:40.694983 Create Date: 2025-11-22 23:51:40.694983
""" """
from typing import Sequence, Union from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = 'f68d8da5315a' revision: str = "f68d8da5315a"
down_revision: Union[str, None] = '72aa309d4007' down_revision: Union[str, None] = "72aa309d4007"
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
# Add template column to content_pages table # Add template column to content_pages table
op.add_column('content_pages', sa.Column('template', sa.String(length=50), nullable=False, server_default='default')) op.add_column(
"content_pages",
sa.Column(
"template", sa.String(length=50), nullable=False, server_default="default"
),
)
def downgrade() -> None: def downgrade() -> None:
# Remove template column from content_pages table # Remove template column from content_pages table
op.drop_column('content_pages', 'template') op.drop_column("content_pages", "template")

View File

@@ -6,15 +6,16 @@ Create Date: 2025-11-13 16:51:25.010057
SQLite-compatible version SQLite-compatible version
""" """
from typing import Sequence, Union from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = 'fa7d4d10e358' revision: str = "fa7d4d10e358"
down_revision: Union[str, None] = '4951b2e50581' down_revision: Union[str, None] = "4951b2e50581"
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
@@ -28,9 +29,14 @@ def upgrade():
# ======================================================================== # ========================================================================
# User table changes # User table changes
# ======================================================================== # ========================================================================
with op.batch_alter_table('users', schema=None) as batch_op: with op.batch_alter_table("users", schema=None) as batch_op:
batch_op.add_column( batch_op.add_column(
sa.Column('is_email_verified', sa.Boolean(), nullable=False, server_default='false') sa.Column(
"is_email_verified",
sa.Boolean(),
nullable=False,
server_default="false",
)
) )
# Set existing active users as verified # Set existing active users as verified
@@ -39,68 +45,65 @@ def upgrade():
# ======================================================================== # ========================================================================
# VendorUser table changes (requires table recreation for SQLite) # VendorUser table changes (requires table recreation for SQLite)
# ======================================================================== # ========================================================================
with op.batch_alter_table('vendor_users', schema=None) as batch_op: with op.batch_alter_table("vendor_users", schema=None) as batch_op:
# Add new columns # Add new columns
batch_op.add_column( batch_op.add_column(
sa.Column('user_type', sa.String(length=20), nullable=False, server_default='member') sa.Column(
"user_type",
sa.String(length=20),
nullable=False,
server_default="member",
)
) )
batch_op.add_column( batch_op.add_column(
sa.Column('invitation_token', sa.String(length=100), nullable=True) sa.Column("invitation_token", sa.String(length=100), nullable=True)
) )
batch_op.add_column( batch_op.add_column(
sa.Column('invitation_sent_at', sa.DateTime(), nullable=True) sa.Column("invitation_sent_at", sa.DateTime(), nullable=True)
) )
batch_op.add_column( batch_op.add_column(
sa.Column('invitation_accepted_at', sa.DateTime(), nullable=True) sa.Column("invitation_accepted_at", sa.DateTime(), nullable=True)
) )
# Create index on invitation_token # Create index on invitation_token
batch_op.create_index( batch_op.create_index("idx_vendor_users_invitation_token", ["invitation_token"])
'idx_vendor_users_invitation_token',
['invitation_token']
)
# Modify role_id to be nullable (this recreates the table in SQLite) # Modify role_id to be nullable (this recreates the table in SQLite)
batch_op.alter_column( batch_op.alter_column("role_id", existing_type=sa.Integer(), nullable=True)
'role_id',
existing_type=sa.Integer(),
nullable=True
)
# Change is_active default (this recreates the table in SQLite) # Change is_active default (this recreates the table in SQLite)
batch_op.alter_column( batch_op.alter_column(
'is_active', "is_active", existing_type=sa.Boolean(), server_default="false"
existing_type=sa.Boolean(),
server_default='false'
) )
# Set owners correctly (after table modifications) # Set owners correctly (after table modifications)
# SQLite-compatible UPDATE with subquery # SQLite-compatible UPDATE with subquery
op.execute(""" op.execute(
"""
UPDATE vendor_users UPDATE vendor_users
SET user_type = 'owner' SET user_type = 'owner'
WHERE (vendor_id, user_id) IN ( WHERE (vendor_id, user_id) IN (
SELECT id, owner_user_id SELECT id, owner_user_id
FROM vendors FROM vendors
) )
""") """
)
# Set existing owners as active # Set existing owners as active
op.execute(""" op.execute(
"""
UPDATE vendor_users UPDATE vendor_users
SET is_active = TRUE SET is_active = TRUE
WHERE user_type = 'owner' WHERE user_type = 'owner'
""") """
)
# ======================================================================== # ========================================================================
# Role table changes # Role table changes
# ======================================================================== # ========================================================================
with op.batch_alter_table('roles', schema=None) as batch_op: with op.batch_alter_table("roles", schema=None) as batch_op:
# Create index on vendor_id and name # Create index on vendor_id and name
batch_op.create_index( batch_op.create_index("idx_roles_vendor_name", ["vendor_id", "name"])
'idx_roles_vendor_name',
['vendor_id', 'name']
)
# Note: JSONB conversion only for PostgreSQL # Note: JSONB conversion only for PostgreSQL
# SQLite stores JSON as TEXT by default, no conversion needed # SQLite stores JSON as TEXT by default, no conversion needed
@@ -115,37 +118,31 @@ def downgrade():
# ======================================================================== # ========================================================================
# Role table changes # Role table changes
# ======================================================================== # ========================================================================
with op.batch_alter_table('roles', schema=None) as batch_op: with op.batch_alter_table("roles", schema=None) as batch_op:
batch_op.drop_index('idx_roles_vendor_name') batch_op.drop_index("idx_roles_vendor_name")
# ======================================================================== # ========================================================================
# VendorUser table changes # VendorUser table changes
# ======================================================================== # ========================================================================
with op.batch_alter_table('vendor_users', schema=None) as batch_op: with op.batch_alter_table("vendor_users", schema=None) as batch_op:
# Revert is_active default # Revert is_active default
batch_op.alter_column( batch_op.alter_column(
'is_active', "is_active", existing_type=sa.Boolean(), server_default="true"
existing_type=sa.Boolean(),
server_default='true'
) )
# Revert role_id to NOT NULL # Revert role_id to NOT NULL
# Note: This might fail if there are NULL values # Note: This might fail if there are NULL values
batch_op.alter_column( batch_op.alter_column("role_id", existing_type=sa.Integer(), nullable=False)
'role_id',
existing_type=sa.Integer(),
nullable=False
)
# Drop indexes and columns # Drop indexes and columns
batch_op.drop_index('idx_vendor_users_invitation_token') batch_op.drop_index("idx_vendor_users_invitation_token")
batch_op.drop_column('invitation_accepted_at') batch_op.drop_column("invitation_accepted_at")
batch_op.drop_column('invitation_sent_at') batch_op.drop_column("invitation_sent_at")
batch_op.drop_column('invitation_token') batch_op.drop_column("invitation_token")
batch_op.drop_column('user_type') batch_op.drop_column("user_type")
# ======================================================================== # ========================================================================
# User table changes # User table changes
# ======================================================================== # ========================================================================
with op.batch_alter_table('users', schema=None) as batch_op: with op.batch_alter_table("users", schema=None) as batch_op:
batch_op.drop_column('is_email_verified') batch_op.drop_column("is_email_verified")

View File

@@ -5,30 +5,43 @@ Revises: fa7d4d10e358
Create Date: 2025-11-22 13:41:18.069674 Create Date: 2025-11-22 13:41:18.069674
""" """
from typing import Sequence, Union from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = 'fef1d20ce8b4' revision: str = "fef1d20ce8b4"
down_revision: Union[str, None] = 'fa7d4d10e358' down_revision: Union[str, None] = "fa7d4d10e358"
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.drop_index('idx_roles_vendor_name', table_name='roles') op.drop_index("idx_roles_vendor_name", table_name="roles")
op.drop_index('idx_vendor_users_invitation_token', table_name='vendor_users') op.drop_index("idx_vendor_users_invitation_token", table_name="vendor_users")
op.create_index(op.f('ix_vendor_users_invitation_token'), 'vendor_users', ['invitation_token'], unique=False) op.create_index(
op.f("ix_vendor_users_invitation_token"),
"vendor_users",
["invitation_token"],
unique=False,
)
# ### end Alembic commands ### # ### end Alembic commands ###
def downgrade() -> None: def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_vendor_users_invitation_token'), table_name='vendor_users') op.drop_index(op.f("ix_vendor_users_invitation_token"), table_name="vendor_users")
op.create_index('idx_vendor_users_invitation_token', 'vendor_users', ['invitation_token'], unique=False) op.create_index(
op.create_index('idx_roles_vendor_name', 'roles', ['vendor_id', 'name'], unique=False) "idx_vendor_users_invitation_token",
"vendor_users",
["invitation_token"],
unique=False,
)
op.create_index(
"idx_roles_vendor_name", "roles", ["vendor_id", "name"], unique=False
)
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@@ -34,22 +34,20 @@ The cookie path restrictions prevent cross-context cookie leakage:
import logging import logging
from typing import Optional from typing import Optional
from fastapi import Depends, Request, Cookie from fastapi import Cookie, Depends, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.database import get_db from app.core.database import get_db
from app.exceptions import (AdminRequiredException,
InsufficientPermissionsException,
InvalidTokenException,
UnauthorizedVendorAccessException,
VendorNotFoundException)
from middleware.auth import AuthManager from middleware.auth import AuthManager
from middleware.rate_limiter import RateLimiter from middleware.rate_limiter import RateLimiter
from models.database.vendor import Vendor
from models.database.user import User from models.database.user import User
from app.exceptions import ( from models.database.vendor import Vendor
AdminRequiredException,
InvalidTokenException,
InsufficientPermissionsException,
VendorNotFoundException,
UnauthorizedVendorAccessException
)
# Initialize dependencies # Initialize dependencies
security = HTTPBearer(auto_error=False) # auto_error=False prevents automatic 403 security = HTTPBearer(auto_error=False) # auto_error=False prevents automatic 403
@@ -62,11 +60,12 @@ logger = logging.getLogger(__name__)
# HELPER FUNCTIONS # HELPER FUNCTIONS
# ============================================================================ # ============================================================================
def _get_token_from_request( def _get_token_from_request(
credentials: Optional[HTTPAuthorizationCredentials], credentials: Optional[HTTPAuthorizationCredentials],
cookie_value: Optional[str], cookie_value: Optional[str],
cookie_name: str, cookie_name: str,
request_path: str request_path: str,
) -> tuple[Optional[str], Optional[str]]: ) -> tuple[Optional[str], Optional[str]]:
""" """
Extract token from Authorization header or cookie. Extract token from Authorization header or cookie.
@@ -108,10 +107,7 @@ def _validate_user_token(token: str, db: Session) -> User:
Raises: Raises:
InvalidTokenException: If token is invalid InvalidTokenException: If token is invalid
""" """
mock_credentials = HTTPAuthorizationCredentials( mock_credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
scheme="Bearer",
credentials=token
)
return auth_manager.get_current_user(db, mock_credentials) return auth_manager.get_current_user(db, mock_credentials)
@@ -119,6 +115,7 @@ def _validate_user_token(token: str, db: Session) -> User:
# ADMIN AUTHENTICATION # ADMIN AUTHENTICATION
# ============================================================================ # ============================================================================
def get_current_admin_from_cookie_or_header( def get_current_admin_from_cookie_or_header(
request: Request, request: Request,
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
@@ -148,10 +145,7 @@ def get_current_admin_from_cookie_or_header(
AdminRequiredException: If user is not admin AdminRequiredException: If user is not admin
""" """
token, source = _get_token_from_request( token, source = _get_token_from_request(
credentials, credentials, admin_token, "admin_token", str(request.url.path)
admin_token,
"admin_token",
str(request.url.path)
) )
if not token: if not token:
@@ -208,6 +202,7 @@ def get_current_admin_api(
# VENDOR AUTHENTICATION # VENDOR AUTHENTICATION
# ============================================================================ # ============================================================================
def get_current_vendor_from_cookie_or_header( def get_current_vendor_from_cookie_or_header(
request: Request, request: Request,
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
@@ -237,10 +232,7 @@ def get_current_vendor_from_cookie_or_header(
InsufficientPermissionsException: If user is not vendor or is admin InsufficientPermissionsException: If user is not vendor or is admin
""" """
token, source = _get_token_from_request( token, source = _get_token_from_request(
credentials, credentials, vendor_token, "vendor_token", str(request.url.path)
vendor_token,
"vendor_token",
str(request.url.path)
) )
if not token: if not token:
@@ -310,6 +302,7 @@ def get_current_vendor_api(
# CUSTOMER AUTHENTICATION (SHOP) # CUSTOMER AUTHENTICATION (SHOP)
# ============================================================================ # ============================================================================
def get_current_customer_from_cookie_or_header( def get_current_customer_from_cookie_or_header(
request: Request, request: Request,
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
@@ -338,15 +331,14 @@ def get_current_customer_from_cookie_or_header(
Raises: Raises:
InvalidTokenException: If no token or invalid token InvalidTokenException: If no token or invalid token
""" """
from models.database.customer import Customer
from jose import jwt, JWTError
from datetime import datetime, timezone from datetime import datetime, timezone
from jose import JWTError, jwt
from models.database.customer import Customer
token, source = _get_token_from_request( token, source = _get_token_from_request(
credentials, credentials, customer_token, "customer_token", str(request.url.path)
customer_token,
"customer_token",
str(request.url.path)
) )
if not token: if not token:
@@ -356,9 +348,7 @@ def get_current_customer_from_cookie_or_header(
# Decode and validate customer JWT token # Decode and validate customer JWT token
try: try:
payload = jwt.decode( payload = jwt.decode(
token, token, auth_manager.secret_key, algorithms=[auth_manager.algorithm]
auth_manager.secret_key,
algorithms=[auth_manager.algorithm]
) )
# Verify this is a customer token # Verify this is a customer token
@@ -375,7 +365,9 @@ def get_current_customer_from_cookie_or_header(
# Verify token hasn't expired # Verify token hasn't expired
exp = payload.get("exp") exp = payload.get("exp")
if exp and datetime.fromtimestamp(exp, tz=timezone.utc) < datetime.now(timezone.utc): if exp and datetime.fromtimestamp(exp, tz=timezone.utc) < datetime.now(
timezone.utc
):
logger.warning(f"Expired customer token for customer_id={customer_id}") logger.warning(f"Expired customer token for customer_id={customer_id}")
raise InvalidTokenException("Token has expired") raise InvalidTokenException("Token has expired")
@@ -445,9 +437,10 @@ def get_current_customer_api(
# GENERIC AUTHENTICATION (for mixed-use endpoints) # GENERIC AUTHENTICATION (for mixed-use endpoints)
# ============================================================================ # ============================================================================
def get_current_user( def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security), credentials: HTTPAuthorizationCredentials = Depends(security),
db: Session = Depends(get_db) db: Session = Depends(get_db),
) -> User: ) -> User:
""" """
Get current authenticated user from Authorization header only. Get current authenticated user from Authorization header only.
@@ -475,6 +468,7 @@ def get_current_user(
# VENDOR OWNERSHIP VERIFICATION # VENDOR OWNERSHIP VERIFICATION
# ============================================================================ # ============================================================================
def get_user_vendor( def get_user_vendor(
vendor_code: str, vendor_code: str,
current_user: User = Depends(get_current_vendor_from_cookie_or_header), current_user: User = Depends(get_current_vendor_from_cookie_or_header),
@@ -500,9 +494,7 @@ def get_user_vendor(
VendorNotFoundException: If vendor doesn't exist VendorNotFoundException: If vendor doesn't exist
UnauthorizedVendorAccessException: If user doesn't have access UnauthorizedVendorAccessException: If user doesn't have access
""" """
vendor = db.query(Vendor).filter( vendor = db.query(Vendor).filter(Vendor.vendor_code == vendor_code.upper()).first()
Vendor.vendor_code == vendor_code.upper()
).first()
if not vendor: if not vendor:
raise VendorNotFoundException(vendor_code) raise VendorNotFoundException(vendor_code)
@@ -517,10 +509,12 @@ def get_user_vendor(
# User doesn't have access to this vendor # User doesn't have access to this vendor
raise UnauthorizedVendorAccessException(vendor_code, current_user.id) raise UnauthorizedVendorAccessException(vendor_code, current_user.id)
# ============================================================================ # ============================================================================
# PERMISSIONS CHECKING # PERMISSIONS CHECKING
# ============================================================================ # ============================================================================
def require_vendor_permission(permission: str): def require_vendor_permission(permission: str):
""" """
Dependency factory to require a specific vendor permission. Dependency factory to require a specific vendor permission.
@@ -610,8 +604,7 @@ def require_any_vendor_permission(*permissions: str):
# Check if user has ANY of the required permissions # Check if user has ANY of the required permissions
has_permission = any( has_permission = any(
current_user.has_vendor_permission(vendor.id, perm) current_user.has_vendor_permission(vendor.id, perm) for perm in permissions
for perm in permissions
) )
if not has_permission: if not has_permission:
@@ -651,7 +644,8 @@ def require_all_vendor_permissions(*permissions: str):
# Check if user has ALL required permissions # Check if user has ALL required permissions
missing_permissions = [ missing_permissions = [
perm for perm in permissions perm
for perm in permissions
if not current_user.has_vendor_permission(vendor.id, perm) if not current_user.has_vendor_permission(vendor.id, perm)
] ]
@@ -682,6 +676,7 @@ def get_user_permissions(
# If owner, return all permissions # If owner, return all permissions
if current_user.is_owner_of(vendor.id): if current_user.is_owner_of(vendor.id):
from app.core.permissions import VendorPermissions from app.core.permissions import VendorPermissions
return [p.value for p in VendorPermissions] return [p.value for p in VendorPermissions]
# Get permissions from vendor membership # Get permissions from vendor membership
@@ -696,6 +691,7 @@ def get_user_permissions(
# OPTIONAL AUTHENTICATION (For Login Page Redirects) # OPTIONAL AUTHENTICATION (For Login Page Redirects)
# ============================================================================ # ============================================================================
def get_current_admin_optional( def get_current_admin_optional(
request: Request, request: Request,
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
@@ -723,10 +719,7 @@ def get_current_admin_optional(
None: If no token, invalid token, or user is not admin None: If no token, invalid token, or user is not admin
""" """
token, source = _get_token_from_request( token, source = _get_token_from_request(
credentials, credentials, admin_token, "admin_token", str(request.url.path)
admin_token,
"admin_token",
str(request.url.path)
) )
if not token: if not token:
@@ -773,10 +766,7 @@ def get_current_vendor_optional(
None: If no token, invalid token, or user is not vendor None: If no token, invalid token, or user is not vendor
""" """
token, source = _get_token_from_request( token, source = _get_token_from_request(
credentials, credentials, vendor_token, "vendor_token", str(request.url.path)
vendor_token,
"vendor_token",
str(request.url.path)
) )
if not token: if not token:
@@ -823,10 +813,7 @@ def get_current_customer_optional(
None: If no token, invalid token, or user is not customer None: If no token, invalid token, or user is not customer
""" """
token, source = _get_token_from_request( token, source = _get_token_from_request(
credentials, credentials, customer_token, "customer_token", str(request.url.path)
customer_token,
"customer_token",
str(request.url.path)
) )
if not token: if not token:
@@ -844,5 +831,3 @@ def get_current_customer_optional(
pass pass
return None return None

View File

@@ -9,7 +9,8 @@ This module provides:
""" """
from fastapi import APIRouter from fastapi import APIRouter
from app.api.v1 import admin, vendor, shop
from app.api.v1 import admin, shop, vendor
api_router = APIRouter() api_router = APIRouter()
@@ -18,31 +19,18 @@ api_router = APIRouter()
# Prefix: /api/v1/admin # Prefix: /api/v1/admin
# ============================================================================ # ============================================================================
api_router.include_router( api_router.include_router(admin.router, prefix="/v1/admin", tags=["admin"])
admin.router,
prefix="/v1/admin",
tags=["admin"]
)
# ============================================================================ # ============================================================================
# VENDOR ROUTES (Vendor-scoped operations) # VENDOR ROUTES (Vendor-scoped operations)
# Prefix: /api/v1/vendor # Prefix: /api/v1/vendor
# ============================================================================ # ============================================================================
api_router.include_router( api_router.include_router(vendor.router, prefix="/v1/vendor", tags=["vendor"])
vendor.router,
prefix="/v1/vendor",
tags=["vendor"]
)
# ============================================================================ # ============================================================================
# SHOP ROUTES (Public shop frontend API) # SHOP ROUTES (Public shop frontend API)
# Prefix: /api/v1/shop # Prefix: /api/v1/shop
# ============================================================================ # ============================================================================
api_router.include_router( api_router.include_router(shop.router, prefix="/v1/shop", tags=["shop"])
shop.router,
prefix="/v1/shop",
tags=["shop"]
)

View File

@@ -3,6 +3,6 @@
API Version 1 - All endpoints API Version 1 - All endpoints
""" """
from . import admin, vendor, shop from . import admin, shop, vendor
__all__ = ["admin", "vendor", "shop"] __all__ = ["admin", "vendor", "shop"]

View File

@@ -24,21 +24,9 @@ IMPORTANT:
from fastapi import APIRouter from fastapi import APIRouter
# Import all admin routers # Import all admin routers
from . import ( from . import (audit, auth, code_quality, content_pages, dashboard,
auth, marketplace, monitoring, notifications, settings, users,
vendors, vendor_domains, vendor_themes, vendors)
vendor_domains,
vendor_themes,
users,
dashboard,
marketplace,
monitoring,
audit,
settings,
notifications,
content_pages,
code_quality
)
# Create admin router # Create admin router
router = APIRouter() router = APIRouter()
@@ -66,7 +54,9 @@ router.include_router(vendor_domains.router, tags=["admin-vendor-domains"])
router.include_router(vendor_themes.router, tags=["admin-vendor-themes"]) router.include_router(vendor_themes.router, tags=["admin-vendor-themes"])
# Include content pages management endpoints # Include content pages management endpoints
router.include_router(content_pages.router, prefix="/content-pages", tags=["admin-content-pages"]) router.include_router(
content_pages.router, prefix="/content-pages", tags=["admin-content-pages"]
)
# ============================================================================ # ============================================================================
@@ -115,7 +105,9 @@ router.include_router(notifications.router, tags=["admin-notifications"])
# ============================================================================ # ============================================================================
# Include code quality and architecture validation endpoints # Include code quality and architecture validation endpoints
router.include_router(code_quality.router, prefix="/code-quality", tags=["admin-code-quality"]) router.include_router(
code_quality.router, prefix="/code-quality", tags=["admin-code-quality"]
)
# Export the router # Export the router
__all__ = ["router"] __all__ = ["router"]

View File

@@ -9,8 +9,8 @@ Provides endpoints for:
""" """
import logging import logging
from typing import Optional
from datetime import datetime from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -18,12 +18,10 @@ from sqlalchemy.orm import Session
from app.api.deps import get_current_admin_api from app.api.deps import get_current_admin_api
from app.core.database import get_db from app.core.database import get_db
from app.services.admin_audit_service import admin_audit_service from app.services.admin_audit_service import admin_audit_service
from models.schema.admin import (
AdminAuditLogResponse,
AdminAuditLogFilters,
AdminAuditLogListResponse
)
from models.database.user import User from models.database.user import User
from models.schema.admin import (AdminAuditLogFilters,
AdminAuditLogListResponse,
AdminAuditLogResponse)
router = APIRouter(prefix="/audit") router = APIRouter(prefix="/audit")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -54,7 +52,7 @@ def get_audit_logs(
date_from=date_from, date_from=date_from,
date_to=date_to, date_to=date_to,
skip=skip, skip=skip,
limit=limit limit=limit,
) )
logs = admin_audit_service.get_audit_logs(db, filters) logs = admin_audit_service.get_audit_logs(db, filters)
@@ -62,12 +60,7 @@ def get_audit_logs(
logger.info(f"Admin {current_admin.username} retrieved {len(logs)} audit logs") logger.info(f"Admin {current_admin.username} retrieved {len(logs)} audit logs")
return AdminAuditLogListResponse( return AdminAuditLogListResponse(logs=logs, total=total, skip=skip, limit=limit)
logs=logs,
total=total,
skip=skip,
limit=limit
)
@router.get("/logs/recent", response_model=list[AdminAuditLogResponse]) @router.get("/logs/recent", response_model=list[AdminAuditLogResponse])
@@ -89,9 +82,7 @@ def get_my_actions(
): ):
"""Get audit logs for current admin's actions.""" """Get audit logs for current admin's actions."""
return admin_audit_service.get_recent_actions_by_admin( return admin_audit_service.get_recent_actions_by_admin(
db=db, db=db, admin_user_id=current_admin.id, limit=limit
admin_user_id=current_admin.id,
limit=limit
) )
@@ -109,8 +100,5 @@ def get_actions_by_target(
Useful for tracking the history of a specific vendor, user, or entity. Useful for tracking the history of a specific vendor, user, or entity.
""" """
return admin_audit_service.get_actions_by_target( return admin_audit_service.get_actions_by_target(
db=db, db=db, target_type=target_type, target_id=target_id, limit=limit
target_type=target_type,
target_id=target_id,
limit=limit
) )

View File

@@ -10,16 +10,17 @@ This prevents admin cookies from being sent to vendor routes.
""" """
import logging import logging
from fastapi import APIRouter, Depends, Response from fastapi import APIRouter, Depends, Response
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.api.deps import get_current_admin_api
from app.core.database import get_db from app.core.database import get_db
from app.core.environment import should_use_secure_cookies from app.core.environment import should_use_secure_cookies
from app.services.auth_service import auth_service
from app.exceptions import InvalidCredentialsException from app.exceptions import InvalidCredentialsException
from models.schema.auth import LoginResponse, UserLogin, UserResponse from app.services.auth_service import auth_service
from models.database.user import User from models.database.user import User
from app.api.deps import get_current_admin_api from models.schema.auth import LoginResponse, UserLogin, UserResponse
router = APIRouter(prefix="/auth") router = APIRouter(prefix="/auth")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -27,9 +28,7 @@ logger = logging.getLogger(__name__)
@router.post("/login", response_model=LoginResponse) @router.post("/login", response_model=LoginResponse)
def admin_login( def admin_login(
user_credentials: UserLogin, user_credentials: UserLogin, response: Response, db: Session = Depends(get_db)
response: Response,
db: Session = Depends(get_db)
): ):
""" """
Admin login endpoint. Admin login endpoint.
@@ -49,7 +48,9 @@ def admin_login(
# Verify user is admin # Verify user is admin
if login_result["user"].role != "admin": if login_result["user"].role != "admin":
logger.warning(f"Non-admin user attempted admin login: {user_credentials.email_or_username}") logger.warning(
f"Non-admin user attempted admin login: {user_credentials.email_or_username}"
)
raise InvalidCredentialsException("Admin access required") raise InvalidCredentialsException("Admin access required")
logger.info(f"Admin login successful: {login_result['user'].username}") logger.info(f"Admin login successful: {login_result['user'].username}")

View File

@@ -3,25 +3,27 @@ Code Quality API Endpoints
RESTful API for architecture validation and violation management RESTful API for architecture validation and violation management
""" """
from typing import Optional
from datetime import datetime from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, Query from typing import Optional
from sqlalchemy.orm import Session
from pydantic import BaseModel, Field
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from app.api.deps import get_current_admin_api
from app.core.database import get_db from app.core.database import get_db
from app.services.code_quality_service import code_quality_service from app.services.code_quality_service import code_quality_service
from app.api.deps import get_current_admin_api
from models.database.user import User from models.database.user import User
router = APIRouter() router = APIRouter()
# Pydantic Models for API # Pydantic Models for API
class ScanResponse(BaseModel): class ScanResponse(BaseModel):
"""Response model for a scan""" """Response model for a scan"""
id: int id: int
timestamp: str timestamp: str
total_files: int total_files: int
@@ -38,6 +40,7 @@ class ScanResponse(BaseModel):
class ViolationResponse(BaseModel): class ViolationResponse(BaseModel):
"""Response model for a violation""" """Response model for a violation"""
id: int id: int
scan_id: int scan_id: int
rule_id: str rule_id: str
@@ -61,6 +64,7 @@ class ViolationResponse(BaseModel):
class ViolationListResponse(BaseModel): class ViolationListResponse(BaseModel):
"""Response model for paginated violations list""" """Response model for paginated violations list"""
violations: list[ViolationResponse] violations: list[ViolationResponse]
total: int total: int
page: int page: int
@@ -70,34 +74,42 @@ class ViolationListResponse(BaseModel):
class ViolationDetailResponse(ViolationResponse): class ViolationDetailResponse(ViolationResponse):
"""Response model for single violation with relationships""" """Response model for single violation with relationships"""
assignments: list = [] assignments: list = []
comments: list = [] comments: list = []
class AssignViolationRequest(BaseModel): class AssignViolationRequest(BaseModel):
"""Request model for assigning a violation""" """Request model for assigning a violation"""
user_id: int = Field(..., description="User ID to assign to") user_id: int = Field(..., description="User ID to assign to")
due_date: Optional[datetime] = Field(None, description="Due date for resolution") due_date: Optional[datetime] = Field(None, description="Due date for resolution")
priority: str = Field("medium", description="Priority level (low, medium, high, critical)") priority: str = Field(
"medium", description="Priority level (low, medium, high, critical)"
)
class ResolveViolationRequest(BaseModel): class ResolveViolationRequest(BaseModel):
"""Request model for resolving a violation""" """Request model for resolving a violation"""
resolution_note: str = Field(..., description="Note about the resolution") resolution_note: str = Field(..., description="Note about the resolution")
class IgnoreViolationRequest(BaseModel): class IgnoreViolationRequest(BaseModel):
"""Request model for ignoring a violation""" """Request model for ignoring a violation"""
reason: str = Field(..., description="Reason for ignoring") reason: str = Field(..., description="Reason for ignoring")
class AddCommentRequest(BaseModel): class AddCommentRequest(BaseModel):
"""Request model for adding a comment""" """Request model for adding a comment"""
comment: str = Field(..., min_length=1, description="Comment text") comment: str = Field(..., min_length=1, description="Comment text")
class DashboardStatsResponse(BaseModel): class DashboardStatsResponse(BaseModel):
"""Response model for dashboard statistics""" """Response model for dashboard statistics"""
total_violations: int total_violations: int
errors: int errors: int
warnings: int warnings: int
@@ -116,10 +128,10 @@ class DashboardStatsResponse(BaseModel):
# API Endpoints # API Endpoints
@router.post("/scan", response_model=ScanResponse) @router.post("/scan", response_model=ScanResponse)
async def trigger_scan( async def trigger_scan(
db: Session = Depends(get_db), db: Session = Depends(get_db), current_user: User = Depends(get_current_admin_api)
current_user: User = Depends(get_current_admin_api)
): ):
""" """
Trigger a new architecture scan Trigger a new architecture scan
@@ -127,7 +139,9 @@ async def trigger_scan(
Requires authentication. Runs the validator script and stores results. Requires authentication. Runs the validator script and stores results.
""" """
try: try:
scan = code_quality_service.run_scan(db, triggered_by=f"manual:{current_user.username}") scan = code_quality_service.run_scan(
db, triggered_by=f"manual:{current_user.username}"
)
return ScanResponse( return ScanResponse(
id=scan.id, id=scan.id,
@@ -138,7 +152,7 @@ async def trigger_scan(
warnings=scan.warnings, warnings=scan.warnings,
duration_seconds=scan.duration_seconds, duration_seconds=scan.duration_seconds,
triggered_by=scan.triggered_by, triggered_by=scan.triggered_by,
git_commit_hash=scan.git_commit_hash git_commit_hash=scan.git_commit_hash,
) )
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"Scan failed: {str(e)}") raise HTTPException(status_code=500, detail=f"Scan failed: {str(e)}")
@@ -148,7 +162,7 @@ async def trigger_scan(
async def list_scans( async def list_scans(
limit: int = Query(30, ge=1, le=100, description="Number of scans to return"), limit: int = Query(30, ge=1, le=100, description="Number of scans to return"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_admin_api) current_user: User = Depends(get_current_admin_api),
): ):
""" """
Get scan history Get scan history
@@ -167,7 +181,7 @@ async def list_scans(
warnings=scan.warnings, warnings=scan.warnings,
duration_seconds=scan.duration_seconds, duration_seconds=scan.duration_seconds,
triggered_by=scan.triggered_by, triggered_by=scan.triggered_by,
git_commit_hash=scan.git_commit_hash git_commit_hash=scan.git_commit_hash,
) )
for scan in scans for scan in scans
] ]
@@ -175,15 +189,23 @@ async def list_scans(
@router.get("/violations", response_model=ViolationListResponse) @router.get("/violations", response_model=ViolationListResponse)
async def list_violations( async def list_violations(
scan_id: Optional[int] = Query(None, description="Filter by scan ID (defaults to latest)"), scan_id: Optional[int] = Query(
severity: Optional[str] = Query(None, description="Filter by severity (error, warning)"), None, description="Filter by scan ID (defaults to latest)"
status: Optional[str] = Query(None, description="Filter by status (open, assigned, resolved, ignored)"), ),
severity: Optional[str] = Query(
None, description="Filter by severity (error, warning)"
),
status: Optional[str] = Query(
None, description="Filter by status (open, assigned, resolved, ignored)"
),
rule_id: Optional[str] = Query(None, description="Filter by rule ID"), rule_id: Optional[str] = Query(None, description="Filter by rule ID"),
file_path: Optional[str] = Query(None, description="Filter by file path (partial match)"), file_path: Optional[str] = Query(
None, description="Filter by file path (partial match)"
),
page: int = Query(1, ge=1, description="Page number"), page: int = Query(1, ge=1, description="Page number"),
page_size: int = Query(50, ge=1, le=200, description="Items per page"), page_size: int = Query(50, ge=1, le=200, description="Items per page"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_admin_api) current_user: User = Depends(get_current_admin_api),
): ):
""" """
Get violations with filtering and pagination Get violations with filtering and pagination
@@ -200,7 +222,7 @@ async def list_violations(
rule_id=rule_id, rule_id=rule_id,
file_path=file_path, file_path=file_path,
limit=page_size, limit=page_size,
offset=offset offset=offset,
) )
total_pages = (total + page_size - 1) // page_size total_pages = (total + page_size - 1) // page_size
@@ -223,14 +245,14 @@ async def list_violations(
resolved_at=v.resolved_at.isoformat() if v.resolved_at else None, resolved_at=v.resolved_at.isoformat() if v.resolved_at else None,
resolved_by=v.resolved_by, resolved_by=v.resolved_by,
resolution_note=v.resolution_note, resolution_note=v.resolution_note,
created_at=v.created_at.isoformat() created_at=v.created_at.isoformat(),
) )
for v in violations for v in violations
], ],
total=total, total=total,
page=page, page=page,
page_size=page_size, page_size=page_size,
total_pages=total_pages total_pages=total_pages,
) )
@@ -238,7 +260,7 @@ async def list_violations(
async def get_violation( async def get_violation(
violation_id: int, violation_id: int,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_admin_api) current_user: User = Depends(get_current_admin_api),
): ):
""" """
Get single violation with details Get single violation with details
@@ -253,12 +275,12 @@ async def get_violation(
# Format assignments # Format assignments
assignments = [ assignments = [
{ {
'id': a.id, "id": a.id,
'user_id': a.user_id, "user_id": a.user_id,
'assigned_at': a.assigned_at.isoformat(), "assigned_at": a.assigned_at.isoformat(),
'assigned_by': a.assigned_by, "assigned_by": a.assigned_by,
'due_date': a.due_date.isoformat() if a.due_date else None, "due_date": a.due_date.isoformat() if a.due_date else None,
'priority': a.priority "priority": a.priority,
} }
for a in violation.assignments for a in violation.assignments
] ]
@@ -266,10 +288,10 @@ async def get_violation(
# Format comments # Format comments
comments = [ comments = [
{ {
'id': c.id, "id": c.id,
'user_id': c.user_id, "user_id": c.user_id,
'comment': c.comment, "comment": c.comment,
'created_at': c.created_at.isoformat() "created_at": c.created_at.isoformat(),
} }
for c in violation.comments for c in violation.comments
] ]
@@ -287,12 +309,14 @@ async def get_violation(
suggestion=violation.suggestion, suggestion=violation.suggestion,
status=violation.status, status=violation.status,
assigned_to=violation.assigned_to, assigned_to=violation.assigned_to,
resolved_at=violation.resolved_at.isoformat() if violation.resolved_at else None, resolved_at=(
violation.resolved_at.isoformat() if violation.resolved_at else None
),
resolved_by=violation.resolved_by, resolved_by=violation.resolved_by,
resolution_note=violation.resolution_note, resolution_note=violation.resolution_note,
created_at=violation.created_at.isoformat(), created_at=violation.created_at.isoformat(),
assignments=assignments, assignments=assignments,
comments=comments comments=comments,
) )
@@ -301,7 +325,7 @@ async def assign_violation(
violation_id: int, violation_id: int,
request: AssignViolationRequest, request: AssignViolationRequest,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_admin_api) current_user: User = Depends(get_current_admin_api),
): ):
""" """
Assign violation to a developer Assign violation to a developer
@@ -315,17 +339,19 @@ async def assign_violation(
user_id=request.user_id, user_id=request.user_id,
assigned_by=current_user.id, assigned_by=current_user.id,
due_date=request.due_date, due_date=request.due_date,
priority=request.priority priority=request.priority,
) )
return { return {
'id': assignment.id, "id": assignment.id,
'violation_id': assignment.violation_id, "violation_id": assignment.violation_id,
'user_id': assignment.user_id, "user_id": assignment.user_id,
'assigned_at': assignment.assigned_at.isoformat(), "assigned_at": assignment.assigned_at.isoformat(),
'assigned_by': assignment.assigned_by, "assigned_by": assignment.assigned_by,
'due_date': assignment.due_date.isoformat() if assignment.due_date else None, "due_date": (
'priority': assignment.priority assignment.due_date.isoformat() if assignment.due_date else None
),
"priority": assignment.priority,
} }
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
@@ -336,7 +362,7 @@ async def resolve_violation(
violation_id: int, violation_id: int,
request: ResolveViolationRequest, request: ResolveViolationRequest,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_admin_api) current_user: User = Depends(get_current_admin_api),
): ):
""" """
Mark violation as resolved Mark violation as resolved
@@ -348,15 +374,17 @@ async def resolve_violation(
db, db,
violation_id=violation_id, violation_id=violation_id,
resolved_by=current_user.id, resolved_by=current_user.id,
resolution_note=request.resolution_note resolution_note=request.resolution_note,
) )
return { return {
'id': violation.id, "id": violation.id,
'status': violation.status, "status": violation.status,
'resolved_at': violation.resolved_at.isoformat() if violation.resolved_at else None, "resolved_at": (
'resolved_by': violation.resolved_by, violation.resolved_at.isoformat() if violation.resolved_at else None
'resolution_note': violation.resolution_note ),
"resolved_by": violation.resolved_by,
"resolution_note": violation.resolution_note,
} }
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
@@ -369,7 +397,7 @@ async def ignore_violation(
violation_id: int, violation_id: int,
request: IgnoreViolationRequest, request: IgnoreViolationRequest,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_admin_api) current_user: User = Depends(get_current_admin_api),
): ):
""" """
Mark violation as ignored (won't fix) Mark violation as ignored (won't fix)
@@ -381,15 +409,17 @@ async def ignore_violation(
db, db,
violation_id=violation_id, violation_id=violation_id,
ignored_by=current_user.id, ignored_by=current_user.id,
reason=request.reason reason=request.reason,
) )
return { return {
'id': violation.id, "id": violation.id,
'status': violation.status, "status": violation.status,
'resolved_at': violation.resolved_at.isoformat() if violation.resolved_at else None, "resolved_at": (
'resolved_by': violation.resolved_by, violation.resolved_at.isoformat() if violation.resolved_at else None
'resolution_note': violation.resolution_note ),
"resolved_by": violation.resolved_by,
"resolution_note": violation.resolution_note,
} }
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
@@ -402,7 +432,7 @@ async def add_comment(
violation_id: int, violation_id: int,
request: AddCommentRequest, request: AddCommentRequest,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_admin_api) current_user: User = Depends(get_current_admin_api),
): ):
""" """
Add comment to violation Add comment to violation
@@ -414,15 +444,15 @@ async def add_comment(
db, db,
violation_id=violation_id, violation_id=violation_id,
user_id=current_user.id, user_id=current_user.id,
comment=request.comment comment=request.comment,
) )
return { return {
'id': comment.id, "id": comment.id,
'violation_id': comment.violation_id, "violation_id": comment.violation_id,
'user_id': comment.user_id, "user_id": comment.user_id,
'comment': comment.comment, "comment": comment.comment,
'created_at': comment.created_at.isoformat() "created_at": comment.created_at.isoformat(),
} }
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
@@ -430,8 +460,7 @@ async def add_comment(
@router.get("/stats", response_model=DashboardStatsResponse) @router.get("/stats", response_model=DashboardStatsResponse)
async def get_dashboard_stats( async def get_dashboard_stats(
db: Session = Depends(get_db), db: Session = Depends(get_db), current_user: User = Depends(get_current_admin_api)
current_user: User = Depends(get_current_admin_api)
): ):
""" """
Get dashboard statistics Get dashboard statistics

View File

@@ -10,6 +10,7 @@ Platform administrators can:
import logging import logging
from typing import List, Optional from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -26,24 +27,43 @@ logger = logging.getLogger(__name__)
# REQUEST/RESPONSE SCHEMAS # REQUEST/RESPONSE SCHEMAS
# ============================================================================ # ============================================================================
class ContentPageCreate(BaseModel): class ContentPageCreate(BaseModel):
"""Schema for creating a content page.""" """Schema for creating a content page."""
slug: str = Field(..., max_length=100, description="URL-safe identifier (about, faq, contact, etc.)")
slug: str = Field(
...,
max_length=100,
description="URL-safe identifier (about, faq, contact, etc.)",
)
title: str = Field(..., max_length=200, description="Page title") title: str = Field(..., max_length=200, description="Page title")
content: str = Field(..., description="HTML or Markdown content") content: str = Field(..., description="HTML or Markdown content")
content_format: str = Field(default="html", description="Content format: html or markdown") content_format: str = Field(
template: str = Field(default="default", max_length=50, description="Template name (default, minimal, modern)") default="html", description="Content format: html or markdown"
meta_description: Optional[str] = Field(None, max_length=300, description="SEO meta description") )
meta_keywords: Optional[str] = Field(None, max_length=300, description="SEO keywords") template: str = Field(
default="default",
max_length=50,
description="Template name (default, minimal, modern)",
)
meta_description: Optional[str] = Field(
None, max_length=300, description="SEO meta description"
)
meta_keywords: Optional[str] = Field(
None, max_length=300, description="SEO keywords"
)
is_published: bool = Field(default=False, description="Publish immediately") is_published: bool = Field(default=False, description="Publish immediately")
show_in_footer: bool = Field(default=True, description="Show in footer navigation") show_in_footer: bool = Field(default=True, description="Show in footer navigation")
show_in_header: bool = Field(default=False, description="Show in header navigation") show_in_header: bool = Field(default=False, description="Show in header navigation")
display_order: int = Field(default=0, description="Display order (lower = first)") display_order: int = Field(default=0, description="Display order (lower = first)")
vendor_id: Optional[int] = Field(None, description="Vendor ID (None for platform default)") vendor_id: Optional[int] = Field(
None, description="Vendor ID (None for platform default)"
)
class ContentPageUpdate(BaseModel): class ContentPageUpdate(BaseModel):
"""Schema for updating a content page.""" """Schema for updating a content page."""
title: Optional[str] = Field(None, max_length=200) title: Optional[str] = Field(None, max_length=200)
content: Optional[str] = None content: Optional[str] = None
content_format: Optional[str] = None content_format: Optional[str] = None
@@ -58,6 +78,7 @@ class ContentPageUpdate(BaseModel):
class ContentPageResponse(BaseModel): class ContentPageResponse(BaseModel):
"""Schema for content page response.""" """Schema for content page response."""
id: int id: int
vendor_id: Optional[int] vendor_id: Optional[int]
vendor_name: Optional[str] vendor_name: Optional[str]
@@ -84,11 +105,12 @@ class ContentPageResponse(BaseModel):
# PLATFORM DEFAULT PAGES (vendor_id=NULL) # PLATFORM DEFAULT PAGES (vendor_id=NULL)
# ============================================================================ # ============================================================================
@router.get("/platform", response_model=List[ContentPageResponse]) @router.get("/platform", response_model=List[ContentPageResponse])
def list_platform_pages( def list_platform_pages(
include_unpublished: bool = Query(False, description="Include draft pages"), include_unpublished: bool = Query(False, description="Include draft pages"),
current_user: User = Depends(get_current_admin_api), current_user: User = Depends(get_current_admin_api),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
List all platform default content pages. List all platform default content pages.
@@ -96,8 +118,7 @@ def list_platform_pages(
These are used as fallbacks when vendors haven't created custom pages. These are used as fallbacks when vendors haven't created custom pages.
""" """
pages = content_page_service.list_all_platform_pages( pages = content_page_service.list_all_platform_pages(
db, db, include_unpublished=include_unpublished
include_unpublished=include_unpublished
) )
return [page.to_dict() for page in pages] return [page.to_dict() for page in pages]
@@ -107,7 +128,7 @@ def list_platform_pages(
def create_platform_page( def create_platform_page(
page_data: ContentPageCreate, page_data: ContentPageCreate,
current_user: User = Depends(get_current_admin_api), current_user: User = Depends(get_current_admin_api),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Create a new platform default content page. Create a new platform default content page.
@@ -129,7 +150,7 @@ def create_platform_page(
show_in_footer=page_data.show_in_footer, show_in_footer=page_data.show_in_footer,
show_in_header=page_data.show_in_header, show_in_header=page_data.show_in_header,
display_order=page_data.display_order, display_order=page_data.display_order,
created_by=current_user.id created_by=current_user.id,
) )
return page.to_dict() return page.to_dict()
@@ -139,12 +160,13 @@ def create_platform_page(
# ALL CONTENT PAGES (Platform + Vendors) # ALL CONTENT PAGES (Platform + Vendors)
# ============================================================================ # ============================================================================
@router.get("/", response_model=List[ContentPageResponse]) @router.get("/", response_model=List[ContentPageResponse])
def list_all_pages( def list_all_pages(
vendor_id: Optional[int] = Query(None, description="Filter by vendor ID"), vendor_id: Optional[int] = Query(None, description="Filter by vendor ID"),
include_unpublished: bool = Query(False, description="Include draft pages"), include_unpublished: bool = Query(False, description="Include draft pages"),
current_user: User = Depends(get_current_admin_api), current_user: User = Depends(get_current_admin_api),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
List all content pages (platform defaults and vendor overrides). List all content pages (platform defaults and vendor overrides).
@@ -153,15 +175,14 @@ def list_all_pages(
""" """
if vendor_id: if vendor_id:
pages = content_page_service.list_all_vendor_pages( pages = content_page_service.list_all_vendor_pages(
db, db, vendor_id=vendor_id, include_unpublished=include_unpublished
vendor_id=vendor_id,
include_unpublished=include_unpublished
) )
else: else:
# Get all pages (both platform and vendor) # Get all pages (both platform and vendor)
from models.database.content_page import ContentPage
from sqlalchemy import and_ from sqlalchemy import and_
from models.database.content_page import ContentPage
filters = [] filters = []
if not include_unpublished: if not include_unpublished:
filters.append(ContentPage.is_published == True) filters.append(ContentPage.is_published == True)
@@ -169,7 +190,9 @@ def list_all_pages(
pages = ( pages = (
db.query(ContentPage) db.query(ContentPage)
.filter(and_(*filters) if filters else True) .filter(and_(*filters) if filters else True)
.order_by(ContentPage.vendor_id, ContentPage.display_order, ContentPage.title) .order_by(
ContentPage.vendor_id, ContentPage.display_order, ContentPage.title
)
.all() .all()
) )
@@ -180,7 +203,7 @@ def list_all_pages(
def get_page( def get_page(
page_id: int, page_id: int,
current_user: User = Depends(get_current_admin_api), current_user: User = Depends(get_current_admin_api),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
"""Get a specific content page by ID.""" """Get a specific content page by ID."""
page = content_page_service.get_page_by_id(db, page_id) page = content_page_service.get_page_by_id(db, page_id)
@@ -196,7 +219,7 @@ def update_page(
page_id: int, page_id: int,
page_data: ContentPageUpdate, page_data: ContentPageUpdate,
current_user: User = Depends(get_current_admin_api), current_user: User = Depends(get_current_admin_api),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
"""Update a content page (platform or vendor).""" """Update a content page (platform or vendor)."""
page = content_page_service.update_page( page = content_page_service.update_page(
@@ -212,7 +235,7 @@ def update_page(
show_in_footer=page_data.show_in_footer, show_in_footer=page_data.show_in_footer,
show_in_header=page_data.show_in_header, show_in_header=page_data.show_in_header,
display_order=page_data.display_order, display_order=page_data.display_order,
updated_by=current_user.id updated_by=current_user.id,
) )
if not page: if not page:
@@ -225,7 +248,7 @@ def update_page(
def delete_page( def delete_page(
page_id: int, page_id: int,
current_user: User = Depends(get_current_admin_api), current_user: User = Depends(get_current_admin_api),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
"""Delete a content page.""" """Delete a content page."""
success = content_page_service.delete_page(db, page_id) success = content_page_service.delete_page(db, page_id)

View File

@@ -5,6 +5,7 @@ Admin dashboard and statistics endpoints.
import logging import logging
from typing import List from typing import List
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session from sqlalchemy.orm import Session

View File

@@ -13,8 +13,8 @@ from app.api.deps import get_current_admin_api
from app.core.database import get_db from app.core.database import get_db
from app.services.admin_service import admin_service from app.services.admin_service import admin_service
from app.services.stats_service import stats_service from app.services.stats_service import stats_service
from models.schema.marketplace_import_job import MarketplaceImportJobResponse
from models.database.user import User from models.database.user import User
from models.schema.marketplace_import_job import MarketplaceImportJobResponse
router = APIRouter(prefix="/marketplace-import-jobs") router = APIRouter(prefix="/marketplace-import-jobs")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -16,16 +16,13 @@ from sqlalchemy.orm import Session
from app.api.deps import get_current_admin_api from app.api.deps import get_current_admin_api
from app.core.database import get_db from app.core.database import get_db
from models.schema.admin import (
AdminNotificationCreate,
AdminNotificationResponse,
AdminNotificationListResponse,
PlatformAlertCreate,
PlatformAlertResponse,
PlatformAlertListResponse,
PlatformAlertResolve
)
from models.database.user import User from models.database.user import User
from models.schema.admin import (AdminNotificationCreate,
AdminNotificationListResponse,
AdminNotificationResponse,
PlatformAlertCreate,
PlatformAlertListResponse,
PlatformAlertResolve, PlatformAlertResponse)
router = APIRouter(prefix="/notifications") router = APIRouter(prefix="/notifications")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -35,6 +32,7 @@ logger = logging.getLogger(__name__)
# ADMIN NOTIFICATIONS # ADMIN NOTIFICATIONS
# ============================================================================ # ============================================================================
@router.get("", response_model=AdminNotificationListResponse) @router.get("", response_model=AdminNotificationListResponse)
def get_notifications( def get_notifications(
priority: Optional[str] = Query(None, description="Filter by priority"), priority: Optional[str] = Query(None, description="Filter by priority"),
@@ -47,11 +45,7 @@ def get_notifications(
"""Get admin notifications with filtering.""" """Get admin notifications with filtering."""
# TODO: Implement notification service # TODO: Implement notification service
return AdminNotificationListResponse( return AdminNotificationListResponse(
notifications=[], notifications=[], total=0, unread_count=0, skip=skip, limit=limit
total=0,
unread_count=0,
skip=skip,
limit=limit
) )
@@ -90,10 +84,13 @@ def mark_all_as_read(
# PLATFORM ALERTS # PLATFORM ALERTS
# ============================================================================ # ============================================================================
@router.get("/alerts", response_model=PlatformAlertListResponse) @router.get("/alerts", response_model=PlatformAlertListResponse)
def get_platform_alerts( def get_platform_alerts(
severity: Optional[str] = Query(None, description="Filter by severity"), severity: Optional[str] = Query(None, description="Filter by severity"),
is_resolved: Optional[bool] = Query(None, description="Filter by resolution status"), is_resolved: Optional[bool] = Query(
None, description="Filter by resolution status"
),
skip: int = Query(0, ge=0), skip: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=100), limit: int = Query(50, ge=1, le=100),
db: Session = Depends(get_db), db: Session = Depends(get_db),
@@ -102,12 +99,7 @@ def get_platform_alerts(
"""Get platform alerts with filtering.""" """Get platform alerts with filtering."""
# TODO: Implement alert service # TODO: Implement alert service
return PlatformAlertListResponse( return PlatformAlertListResponse(
alerts=[], alerts=[], total=0, active_count=0, critical_count=0, skip=skip, limit=limit
total=0,
active_count=0,
critical_count=0,
skip=skip,
limit=limit
) )
@@ -147,5 +139,5 @@ def get_alert_statistics(
"total_alerts": 0, "total_alerts": 0,
"active_alerts": 0, "active_alerts": 0,
"critical_alerts": 0, "critical_alerts": 0,
"resolved_today": 0 "resolved_today": 0,
} }

View File

@@ -16,15 +16,11 @@ from sqlalchemy.orm import Session
from app.api.deps import get_current_admin_api from app.api.deps import get_current_admin_api
from app.core.database import get_db from app.core.database import get_db
from app.services.admin_settings_service import admin_settings_service
from app.services.admin_audit_service import admin_audit_service from app.services.admin_audit_service import admin_audit_service
from models.schema.admin import ( from app.services.admin_settings_service import admin_settings_service
AdminSettingCreate,
AdminSettingResponse,
AdminSettingUpdate,
AdminSettingListResponse
)
from models.database.user import User from models.database.user import User
from models.schema.admin import (AdminSettingCreate, AdminSettingListResponse,
AdminSettingResponse, AdminSettingUpdate)
router = APIRouter(prefix="/settings") router = APIRouter(prefix="/settings")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -46,9 +42,7 @@ def get_all_settings(
settings = admin_settings_service.get_all_settings(db, category, is_public) settings = admin_settings_service.get_all_settings(db, category, is_public)
return AdminSettingListResponse( return AdminSettingListResponse(
settings=settings, settings=settings, total=len(settings), category=category
total=len(settings),
category=category
) )
@@ -66,7 +60,7 @@ def get_setting_categories(
"marketplace", "marketplace",
"notifications", "notifications",
"integrations", "integrations",
"payments" "payments",
] ]
} }
@@ -82,6 +76,7 @@ def get_setting(
if not setting: if not setting:
from fastapi import HTTPException from fastapi import HTTPException
raise HTTPException(status_code=404, detail=f"Setting '{key}' not found") raise HTTPException(status_code=404, detail=f"Setting '{key}' not found")
return AdminSettingResponse.model_validate(setting) return AdminSettingResponse.model_validate(setting)
@@ -99,9 +94,7 @@ def create_setting(
Setting keys should be lowercase with underscores (e.g., max_vendors_allowed). Setting keys should be lowercase with underscores (e.g., max_vendors_allowed).
""" """
result = admin_settings_service.create_setting( result = admin_settings_service.create_setting(
db=db, db=db, setting_data=setting_data, admin_user_id=current_admin.id
setting_data=setting_data,
admin_user_id=current_admin.id
) )
# Log action # Log action
@@ -111,7 +104,10 @@ def create_setting(
action="create_setting", action="create_setting",
target_type="setting", target_type="setting",
target_id=setting_data.key, target_id=setting_data.key,
details={"category": setting_data.category, "value_type": setting_data.value_type} details={
"category": setting_data.category,
"value_type": setting_data.value_type,
},
) )
return result return result
@@ -128,10 +124,7 @@ def update_setting(
old_value = admin_settings_service.get_setting_value(db, key) old_value = admin_settings_service.get_setting_value(db, key)
result = admin_settings_service.update_setting( result = admin_settings_service.update_setting(
db=db, db=db, key=key, update_data=update_data, admin_user_id=current_admin.id
key=key,
update_data=update_data,
admin_user_id=current_admin.id
) )
# Log action # Log action
@@ -141,7 +134,7 @@ def update_setting(
action="update_setting", action="update_setting",
target_type="setting", target_type="setting",
target_id=key, target_id=key,
details={"old_value": str(old_value), "new_value": update_data.value} details={"old_value": str(old_value), "new_value": update_data.value},
) )
return result return result
@@ -159,9 +152,7 @@ def upsert_setting(
If setting exists, updates its value. If not, creates new setting. If setting exists, updates its value. If not, creates new setting.
""" """
result = admin_settings_service.upsert_setting( result = admin_settings_service.upsert_setting(
db=db, db=db, setting_data=setting_data, admin_user_id=current_admin.id
setting_data=setting_data,
admin_user_id=current_admin.id
) )
# Log action # Log action
@@ -171,7 +162,7 @@ def upsert_setting(
action="upsert_setting", action="upsert_setting",
target_type="setting", target_type="setting",
target_id=setting_data.key, target_id=setting_data.key,
details={"category": setting_data.category} details={"category": setting_data.category},
) )
return result return result
@@ -195,13 +186,11 @@ def delete_setting(
if not confirm: if not confirm:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail="Deletion requires confirmation parameter: confirm=true" detail="Deletion requires confirmation parameter: confirm=true",
) )
message = admin_settings_service.delete_setting( message = admin_settings_service.delete_setting(
db=db, db=db, key=key, admin_user_id=current_admin.id
key=key,
admin_user_id=current_admin.id
) )
# Log action # Log action
@@ -211,7 +200,7 @@ def delete_setting(
action="delete_setting", action="delete_setting",
target_type="setting", target_type="setting",
target_id=key, target_id=key,
details={} details={},
) )
return {"message": message} return {"message": message}

View File

@@ -13,8 +13,8 @@ from app.api.deps import get_current_admin_api
from app.core.database import get_db from app.core.database import get_db
from app.services.admin_service import admin_service from app.services.admin_service import admin_service
from app.services.stats_service import stats_service from app.services.stats_service import stats_service
from models.schema.auth import UserResponse
from models.database.user import User from models.database.user import User
from models.schema.auth import UserResponse
router = APIRouter(prefix="/users") router = APIRouter(prefix="/users")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -12,24 +12,22 @@ Follows the architecture pattern:
import logging import logging
from typing import List from typing import List
from fastapi import APIRouter, Depends, Path, Body, Query from fastapi import APIRouter, Body, Depends, Path, Query
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.api.deps import get_current_admin_api from app.api.deps import get_current_admin_api
from app.core.database import get_db from app.core.database import get_db
from app.services.vendor_domain_service import vendor_domain_service
from app.exceptions import VendorNotFoundException from app.exceptions import VendorNotFoundException
from models.schema.vendor_domain import ( from app.services.vendor_domain_service import vendor_domain_service
VendorDomainCreate,
VendorDomainUpdate,
VendorDomainResponse,
VendorDomainListResponse,
DomainVerificationInstructions,
DomainVerificationResponse,
DomainDeletionResponse,
)
from models.database.user import User from models.database.user import User
from models.database.vendor import Vendor from models.database.vendor import Vendor
from models.schema.vendor_domain import (DomainDeletionResponse,
DomainVerificationInstructions,
DomainVerificationResponse,
VendorDomainCreate,
VendorDomainListResponse,
VendorDomainResponse,
VendorDomainUpdate)
router = APIRouter(prefix="/vendors") router = APIRouter(prefix="/vendors")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -88,9 +86,7 @@ def add_vendor_domain(
- 422: Invalid domain format or reserved subdomain - 422: Invalid domain format or reserved subdomain
""" """
domain = vendor_domain_service.add_domain( domain = vendor_domain_service.add_domain(
db=db, db=db, vendor_id=vendor_id, domain_data=domain_data
vendor_id=vendor_id,
domain_data=domain_data
) )
return VendorDomainResponse( return VendorDomainResponse(
@@ -148,7 +144,7 @@ def list_vendor_domains(
) )
for d in domains for d in domains
], ],
total=len(domains) total=len(domains),
) )
@@ -174,7 +170,9 @@ def get_domain_details(
is_active=domain.is_active, is_active=domain.is_active,
is_verified=domain.is_verified, is_verified=domain.is_verified,
ssl_status=domain.ssl_status, ssl_status=domain.ssl_status,
verification_token=domain.verification_token if not domain.is_verified else None, verification_token=(
domain.verification_token if not domain.is_verified else None
),
verified_at=domain.verified_at, verified_at=domain.verified_at,
ssl_verified_at=domain.ssl_verified_at, ssl_verified_at=domain.ssl_verified_at,
created_at=domain.created_at, created_at=domain.created_at,
@@ -206,9 +204,7 @@ def update_vendor_domain(
- 400: Cannot activate unverified domain - 400: Cannot activate unverified domain
""" """
domain = vendor_domain_service.update_domain( domain = vendor_domain_service.update_domain(
db=db, db=db, domain_id=domain_id, domain_update=domain_update
domain_id=domain_id,
domain_update=domain_update
) )
return VendorDomainResponse( return VendorDomainResponse(
@@ -250,9 +246,7 @@ def delete_vendor_domain(
message = vendor_domain_service.delete_domain(db, domain_id) message = vendor_domain_service.delete_domain(db, domain_id)
return DomainDeletionResponse( return DomainDeletionResponse(
message=message, message=message, domain=domain_name, vendor_id=vendor_id
domain=domain_name,
vendor_id=vendor_id
) )
@@ -290,11 +284,14 @@ def verify_domain_ownership(
message=message, message=message,
domain=domain.domain, domain=domain.domain,
verified_at=domain.verified_at, verified_at=domain.verified_at,
is_verified=domain.is_verified is_verified=domain.is_verified,
) )
@router.get("/domains/{domain_id}/verification-instructions", response_model=DomainVerificationInstructions) @router.get(
"/domains/{domain_id}/verification-instructions",
response_model=DomainVerificationInstructions,
)
def get_domain_verification_instructions( def get_domain_verification_instructions(
domain_id: int = Path(..., description="Domain ID", gt=0), domain_id: int = Path(..., description="Domain ID", gt=0),
db: Session = Depends(get_db), db: Session = Depends(get_db),
@@ -324,5 +321,5 @@ def get_domain_verification_instructions(
verification_token=instructions["verification_token"], verification_token=instructions["verification_token"],
instructions=instructions["instructions"], instructions=instructions["instructions"],
txt_record=instructions["txt_record"], txt_record=instructions["txt_record"],
common_registrars=instructions["common_registrars"] common_registrars=instructions["common_registrars"],
) )

View File

@@ -20,11 +20,8 @@ from sqlalchemy.orm import Session
from app.api.deps import get_current_admin_api, get_db from app.api.deps import get_current_admin_api, get_db
from app.services.vendor_theme_service import vendor_theme_service from app.services.vendor_theme_service import vendor_theme_service
from models.database.user import User from models.database.user import User
from models.schema.vendor_theme import ( from models.schema.vendor_theme import (ThemePresetListResponse,
VendorThemeResponse, VendorThemeResponse, VendorThemeUpdate)
VendorThemeUpdate,
ThemePresetListResponse
)
router = APIRouter(prefix="/vendor-themes") router = APIRouter(prefix="/vendor-themes")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -34,10 +31,9 @@ logger = logging.getLogger(__name__)
# PRESET ENDPOINTS # PRESET ENDPOINTS
# ============================================================================ # ============================================================================
@router.get("/presets", response_model=ThemePresetListResponse) @router.get("/presets", response_model=ThemePresetListResponse)
async def get_theme_presets( async def get_theme_presets(current_admin: User = Depends(get_current_admin_api)):
current_admin: User = Depends(get_current_admin_api)
):
""" """
Get all available theme presets with preview information. Get all available theme presets with preview information.
@@ -59,11 +55,12 @@ async def get_theme_presets(
# THEME RETRIEVAL # THEME RETRIEVAL
# ============================================================================ # ============================================================================
@router.get("/{vendor_code}", response_model=VendorThemeResponse) @router.get("/{vendor_code}", response_model=VendorThemeResponse)
async def get_vendor_theme( async def get_vendor_theme(
vendor_code: str = Path(..., description="Vendor code"), vendor_code: str = Path(..., description="Vendor code"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_admin: User = Depends(get_current_admin_api) current_admin: User = Depends(get_current_admin_api),
): ):
""" """
Get theme configuration for a vendor. Get theme configuration for a vendor.
@@ -93,12 +90,13 @@ async def get_vendor_theme(
# THEME UPDATE # THEME UPDATE
# ============================================================================ # ============================================================================
@router.put("/{vendor_code}", response_model=VendorThemeResponse) @router.put("/{vendor_code}", response_model=VendorThemeResponse)
async def update_vendor_theme( async def update_vendor_theme(
vendor_code: str = Path(..., description="Vendor code"), vendor_code: str = Path(..., description="Vendor code"),
theme_data: VendorThemeUpdate = None, theme_data: VendorThemeUpdate = None,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_admin: User = Depends(get_current_admin_api) current_admin: User = Depends(get_current_admin_api),
): ):
""" """
Update or create theme for a vendor. Update or create theme for a vendor.
@@ -140,12 +138,13 @@ async def update_vendor_theme(
# PRESET APPLICATION # PRESET APPLICATION
# ============================================================================ # ============================================================================
@router.post("/{vendor_code}/preset/{preset_name}") @router.post("/{vendor_code}/preset/{preset_name}")
async def apply_theme_preset( async def apply_theme_preset(
vendor_code: str = Path(..., description="Vendor code"), vendor_code: str = Path(..., description="Vendor code"),
preset_name: str = Path(..., description="Preset name"), preset_name: str = Path(..., description="Preset name"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_admin: User = Depends(get_current_admin_api) current_admin: User = Depends(get_current_admin_api),
): ):
""" """
Apply a theme preset to a vendor. Apply a theme preset to a vendor.
@@ -184,7 +183,7 @@ async def apply_theme_preset(
return { return {
"message": f"Applied {preset_name} preset successfully", "message": f"Applied {preset_name} preset successfully",
"theme": theme.to_dict() "theme": theme.to_dict(),
} }
@@ -192,11 +191,12 @@ async def apply_theme_preset(
# THEME DELETION # THEME DELETION
# ============================================================================ # ============================================================================
@router.delete("/{vendor_code}") @router.delete("/{vendor_code}")
async def delete_vendor_theme( async def delete_vendor_theme(
vendor_code: str = Path(..., description="Vendor code"), vendor_code: str = Path(..., description="Vendor code"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_admin: User = Depends(get_current_admin_api) current_admin: User = Depends(get_current_admin_api),
): ):
""" """
Delete custom theme for a vendor. Delete custom theme for a vendor.

View File

@@ -6,28 +6,24 @@ Vendor management endpoints for admin.
import logging import logging
from typing import Optional from typing import Optional
from fastapi import APIRouter, Depends, Query, Path, Body from fastapi import APIRouter, Body, Depends, Path, Query
from sqlalchemy import func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.api.deps import get_current_admin_api from app.api.deps import get_current_admin_api
from app.core.database import get_db from app.core.database import get_db
from app.exceptions import (ConfirmationRequiredException,
VendorNotFoundException)
from app.services.admin_service import admin_service from app.services.admin_service import admin_service
from app.services.stats_service import stats_service from app.services.stats_service import stats_service
from app.exceptions import VendorNotFoundException, ConfirmationRequiredException
from models.schema.stats import VendorStatsResponse
from models.schema.vendor import (
VendorListResponse,
VendorResponse,
VendorDetailResponse,
VendorCreate,
VendorCreateResponse,
VendorUpdate,
VendorTransferOwnership,
VendorTransferOwnershipResponse,
)
from models.database.user import User from models.database.user import User
from models.database.vendor import Vendor from models.database.vendor import Vendor
from sqlalchemy import func from models.schema.stats import VendorStatsResponse
from models.schema.vendor import (VendorCreate, VendorCreateResponse,
VendorDetailResponse, VendorListResponse,
VendorResponse, VendorTransferOwnership,
VendorTransferOwnershipResponse,
VendorUpdate)
router = APIRouter(prefix="/vendors") router = APIRouter(prefix="/vendors")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -60,9 +56,11 @@ def _get_vendor_by_identifier(db: Session, identifier: str) -> Vendor:
pass pass
# Try as vendor_code (case-insensitive) # Try as vendor_code (case-insensitive)
vendor = db.query(Vendor).filter( vendor = (
func.upper(Vendor.vendor_code) == identifier.upper() db.query(Vendor)
).first() .filter(func.upper(Vendor.vendor_code) == identifier.upper())
.first()
)
if not vendor: if not vendor:
raise VendorNotFoundException(identifier, identifier_type="code") raise VendorNotFoundException(identifier, identifier_type="code")
@@ -93,8 +91,7 @@ def create_vendor_with_owner(
Returns vendor details with owner credentials. Returns vendor details with owner credentials.
""" """
vendor, owner_user, temp_password = admin_service.create_vendor_with_owner( vendor, owner_user, temp_password = admin_service.create_vendor_with_owner(
db=db, db=db, vendor_data=vendor_data
vendor_data=vendor_data
) )
return VendorCreateResponse( return VendorCreateResponse(
@@ -121,7 +118,7 @@ def create_vendor_with_owner(
owner_email=owner_user.email, owner_email=owner_user.email,
owner_username=owner_user.username, owner_username=owner_user.username,
temporary_password=temp_password, temporary_password=temp_password,
login_url=f"http://localhost:8000/vendor/{vendor.subdomain}/login" login_url=f"http://localhost:8000/vendor/{vendor.subdomain}/login",
) )
@@ -142,7 +139,7 @@ def get_all_vendors_admin(
limit=limit, limit=limit,
search=search, search=search,
is_active=is_active, is_active=is_active,
is_verified=is_verified is_verified=is_verified,
) )
return VendorListResponse(vendors=vendors, total=total, skip=skip, limit=limit) return VendorListResponse(vendors=vendors, total=total, skip=skip, limit=limit)
@@ -257,7 +254,10 @@ def update_vendor(
) )
@router.post("/{vendor_identifier}/transfer-ownership", response_model=VendorTransferOwnershipResponse) @router.post(
"/{vendor_identifier}/transfer-ownership",
response_model=VendorTransferOwnershipResponse,
)
def transfer_vendor_ownership( def transfer_vendor_ownership(
vendor_identifier: str = Path(..., description="Vendor ID or vendor_code"), vendor_identifier: str = Path(..., description="Vendor ID or vendor_code"),
transfer_data: VendorTransferOwnership = Body(...), transfer_data: VendorTransferOwnership = Body(...),
@@ -436,7 +436,7 @@ def delete_vendor(
if not confirm: if not confirm:
raise ConfirmationRequiredException( raise ConfirmationRequiredException(
operation="delete_vendor", operation="delete_vendor",
message="Deletion requires confirmation parameter: confirm=true" message="Deletion requires confirmation parameter: confirm=true",
) )
vendor = _get_vendor_by_identifier(db, vendor_identifier) vendor = _get_vendor_by_identifier(db, vendor_identifier)

View File

@@ -21,7 +21,7 @@ Authentication:
from fastapi import APIRouter from fastapi import APIRouter
# Import shop routers # Import shop routers
from . import products, cart, orders, auth, content_pages from . import auth, cart, content_pages, orders, products
# Create shop router # Create shop router
router = APIRouter() router = APIRouter()
@@ -43,6 +43,8 @@ router.include_router(cart.router, tags=["shop-cart"])
router.include_router(orders.router, tags=["shop-orders"]) router.include_router(orders.router, tags=["shop-orders"])
# Content pages (public) # Content pages (public)
router.include_router(content_pages.router, prefix="/content-pages", tags=["shop-content-pages"]) router.include_router(
content_pages.router, prefix="/content-pages", tags=["shop-content-pages"]
)
__all__ = ["router"] __all__ = ["router"]

View File

@@ -15,15 +15,16 @@ This prevents:
""" """
import logging import logging
from fastapi import APIRouter, Depends, Response, Request, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from pydantic import BaseModel
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.database import get_db from app.core.database import get_db
from app.core.environment import should_use_secure_cookies
from app.services.customer_service import customer_service from app.services.customer_service import customer_service
from models.schema.auth import UserLogin from models.schema.auth import UserLogin
from models.schema.customer import CustomerRegister, CustomerResponse from models.schema.customer import CustomerRegister, CustomerResponse
from app.core.environment import should_use_secure_cookies
from pydantic import BaseModel
router = APIRouter() router = APIRouter()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -32,6 +33,7 @@ logger = logging.getLogger(__name__)
# Response model for customer login # Response model for customer login
class CustomerLoginResponse(BaseModel): class CustomerLoginResponse(BaseModel):
"""Customer login response with token and customer data.""" """Customer login response with token and customer data."""
access_token: str access_token: str
token_type: str token_type: str
expires_in: int expires_in: int
@@ -40,9 +42,7 @@ class CustomerLoginResponse(BaseModel):
@router.post("/auth/register", response_model=CustomerResponse) @router.post("/auth/register", response_model=CustomerResponse)
def register_customer( def register_customer(
request: Request, request: Request, customer_data: CustomerRegister, db: Session = Depends(get_db)
customer_data: CustomerRegister,
db: Session = Depends(get_db)
): ):
""" """
Register a new customer for current vendor. Register a new customer for current vendor.
@@ -59,12 +59,12 @@ def register_customer(
- phone: Customer phone number (optional) - phone: Customer phone number (optional)
""" """
# Get vendor from middleware # Get vendor from middleware
vendor = getattr(request.state, 'vendor', None) vendor = getattr(request.state, "vendor", None)
if not vendor: if not vendor:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path." detail="Vendor not found. Please access via vendor domain/subdomain/path.",
) )
logger.debug( logger.debug(
@@ -73,14 +73,12 @@ def register_customer(
"vendor_id": vendor.id, "vendor_id": vendor.id,
"vendor_code": vendor.subdomain, "vendor_code": vendor.subdomain,
"email": customer_data.email, "email": customer_data.email,
} },
) )
# Create customer account # Create customer account
customer = customer_service.register_customer( customer = customer_service.register_customer(
db=db, db=db, vendor_id=vendor.id, customer_data=customer_data
vendor_id=vendor.id,
customer_data=customer_data
) )
logger.info( logger.info(
@@ -89,7 +87,7 @@ def register_customer(
"customer_id": customer.id, "customer_id": customer.id,
"vendor_id": vendor.id, "vendor_id": vendor.id,
"email": customer.email, "email": customer.email,
} },
) )
return CustomerResponse.model_validate(customer) return CustomerResponse.model_validate(customer)
@@ -100,7 +98,7 @@ def customer_login(
request: Request, request: Request,
user_credentials: UserLogin, user_credentials: UserLogin,
response: Response, response: Response,
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Customer login for current vendor. Customer login for current vendor.
@@ -121,12 +119,12 @@ def customer_login(
- password: Customer password - password: Customer password
""" """
# Get vendor from middleware # Get vendor from middleware
vendor = getattr(request.state, 'vendor', None) vendor = getattr(request.state, "vendor", None)
if not vendor: if not vendor:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path." detail="Vendor not found. Please access via vendor domain/subdomain/path.",
) )
logger.debug( logger.debug(
@@ -135,33 +133,39 @@ def customer_login(
"vendor_id": vendor.id, "vendor_id": vendor.id,
"vendor_code": vendor.subdomain, "vendor_code": vendor.subdomain,
"email_or_username": user_credentials.email_or_username, "email_or_username": user_credentials.email_or_username,
} },
) )
# Authenticate customer # Authenticate customer
login_result = customer_service.login_customer( login_result = customer_service.login_customer(
db=db, db=db, vendor_id=vendor.id, credentials=user_credentials
vendor_id=vendor.id,
credentials=user_credentials
) )
logger.info( logger.info(
f"Customer login successful: {login_result['customer'].email} for vendor {vendor.subdomain}", f"Customer login successful: {login_result['customer'].email} for vendor {vendor.subdomain}",
extra={ extra={
"customer_id": login_result['customer'].id, "customer_id": login_result["customer"].id,
"vendor_id": vendor.id, "vendor_id": vendor.id,
"email": login_result['customer'].email, "email": login_result["customer"].email,
} },
) )
# Calculate cookie path based on vendor access method # Calculate cookie path based on vendor access method
vendor_context = getattr(request.state, 'vendor_context', None) vendor_context = getattr(request.state, "vendor_context", None)
access_method = vendor_context.get('detection_method', 'unknown') if vendor_context else 'unknown' access_method = (
vendor_context.get("detection_method", "unknown")
if vendor_context
else "unknown"
)
cookie_path = "/shop" # Default for domain/subdomain access cookie_path = "/shop" # Default for domain/subdomain access
if access_method == "path": if access_method == "path":
# For path-based access like /vendors/wizamart/shop # For path-based access like /vendors/wizamart/shop
full_prefix = vendor_context.get('full_prefix', '/vendor/') if vendor_context else '/vendor/' full_prefix = (
vendor_context.get("full_prefix", "/vendor/")
if vendor_context
else "/vendor/"
)
cookie_path = f"{full_prefix}{vendor.subdomain}/shop" cookie_path = f"{full_prefix}{vendor.subdomain}/shop"
# Set HTTP-only cookie for browser navigation # Set HTTP-only cookie for browser navigation
@@ -180,10 +184,10 @@ def customer_login(
f"Set customer_token cookie with {login_result['token_data']['expires_in']}s expiry " f"Set customer_token cookie with {login_result['token_data']['expires_in']}s expiry "
f"(path={cookie_path}, httponly=True, secure={should_use_secure_cookies()})", f"(path={cookie_path}, httponly=True, secure={should_use_secure_cookies()})",
extra={ extra={
"expires_in": login_result['token_data']['expires_in'], "expires_in": login_result["token_data"]["expires_in"],
"secure": should_use_secure_cookies(), "secure": should_use_secure_cookies(),
"cookie_path": cookie_path, "cookie_path": cookie_path,
} },
) )
# Return full login response # Return full login response
@@ -196,10 +200,7 @@ def customer_login(
@router.post("/auth/logout") @router.post("/auth/logout")
def customer_logout( def customer_logout(request: Request, response: Response):
request: Request,
response: Response
):
""" """
Customer logout for current vendor. Customer logout for current vendor.
@@ -208,24 +209,32 @@ def customer_logout(
Client should also remove token from localStorage. Client should also remove token from localStorage.
""" """
# Get vendor from middleware (for logging) # Get vendor from middleware (for logging)
vendor = getattr(request.state, 'vendor', None) vendor = getattr(request.state, "vendor", None)
logger.info( logger.info(
f"Customer logout for vendor {vendor.subdomain if vendor else 'unknown'}", f"Customer logout for vendor {vendor.subdomain if vendor else 'unknown'}",
extra={ extra={
"vendor_id": vendor.id if vendor else None, "vendor_id": vendor.id if vendor else None,
"vendor_code": vendor.subdomain if vendor else None, "vendor_code": vendor.subdomain if vendor else None,
} },
) )
# Calculate cookie path based on vendor access method (must match login) # Calculate cookie path based on vendor access method (must match login)
vendor_context = getattr(request.state, 'vendor_context', None) vendor_context = getattr(request.state, "vendor_context", None)
access_method = vendor_context.get('detection_method', 'unknown') if vendor_context else 'unknown' access_method = (
vendor_context.get("detection_method", "unknown")
if vendor_context
else "unknown"
)
cookie_path = "/shop" # Default for domain/subdomain access cookie_path = "/shop" # Default for domain/subdomain access
if access_method == "path" and vendor: if access_method == "path" and vendor:
# For path-based access like /vendors/wizamart/shop # For path-based access like /vendors/wizamart/shop
full_prefix = vendor_context.get('full_prefix', '/vendor/') if vendor_context else '/vendor/' full_prefix = (
vendor_context.get("full_prefix", "/vendor/")
if vendor_context
else "/vendor/"
)
cookie_path = f"{full_prefix}{vendor.subdomain}/shop" cookie_path = f"{full_prefix}{vendor.subdomain}/shop"
# Clear the cookie (must match path used when setting) # Clear the cookie (must match path used when setting)
@@ -240,11 +249,7 @@ def customer_logout(
@router.post("/auth/forgot-password") @router.post("/auth/forgot-password")
def forgot_password( def forgot_password(request: Request, email: str, db: Session = Depends(get_db)):
request: Request,
email: str,
db: Session = Depends(get_db)
):
""" """
Request password reset for customer. Request password reset for customer.
@@ -255,12 +260,12 @@ def forgot_password(
- email: Customer email address - email: Customer email address
""" """
# Get vendor from middleware # Get vendor from middleware
vendor = getattr(request.state, 'vendor', None) vendor = getattr(request.state, "vendor", None)
if not vendor: if not vendor:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path." detail="Vendor not found. Please access via vendor domain/subdomain/path.",
) )
logger.debug( logger.debug(
@@ -269,7 +274,7 @@ def forgot_password(
"vendor_id": vendor.id, "vendor_id": vendor.id,
"vendor_code": vendor.subdomain, "vendor_code": vendor.subdomain,
"email": email, "email": email,
} },
) )
# TODO: Implement password reset functionality # TODO: Implement password reset functionality
@@ -278,9 +283,7 @@ def forgot_password(
# - Send reset email to customer # - Send reset email to customer
# - Return success message (don't reveal if email exists) # - Return success message (don't reveal if email exists)
logger.info( logger.info(f"Password reset requested for {email} (vendor: {vendor.subdomain})")
f"Password reset requested for {email} (vendor: {vendor.subdomain})"
)
return { return {
"message": "If an account exists with this email, a password reset link has been sent." "message": "If an account exists with this email, a password reset link has been sent."
@@ -289,10 +292,7 @@ def forgot_password(
@router.post("/auth/reset-password") @router.post("/auth/reset-password")
def reset_password( def reset_password(
request: Request, request: Request, reset_token: str, new_password: str, db: Session = Depends(get_db)
reset_token: str,
new_password: str,
db: Session = Depends(get_db)
): ):
""" """
Reset customer password using reset token. Reset customer password using reset token.
@@ -304,12 +304,12 @@ def reset_password(
- new_password: New password - new_password: New password
""" """
# Get vendor from middleware # Get vendor from middleware
vendor = getattr(request.state, 'vendor', None) vendor = getattr(request.state, "vendor", None)
if not vendor: if not vendor:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path." detail="Vendor not found. Please access via vendor domain/subdomain/path.",
) )
logger.debug( logger.debug(
@@ -317,7 +317,7 @@ def reset_password(
extra={ extra={
"vendor_id": vendor.id, "vendor_id": vendor.id,
"vendor_code": vendor.subdomain, "vendor_code": vendor.subdomain,
} },
) )
# TODO: Implement password reset # TODO: Implement password reset
@@ -327,9 +327,7 @@ def reset_password(
# - Invalidate reset token # - Invalidate reset token
# - Return success # - Return success
logger.info( logger.info(f"Password reset completed (vendor: {vendor.subdomain})")
f"Password reset completed (vendor: {vendor.subdomain})"
)
return { return {
"message": "Password reset successfully. You can now log in with your new password." "message": "Password reset successfully. You can now log in with your new password."

View File

@@ -8,18 +8,15 @@ No authentication required - uses session ID for cart tracking.
""" """
import logging import logging
from fastapi import APIRouter, Depends, Path, Body, Request, HTTPException
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Request
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.database import get_db from app.core.database import get_db
from app.services.cart_service import cart_service from app.services.cart_service import cart_service
from models.schema.cart import ( from models.schema.cart import (AddToCartRequest, CartOperationResponse,
AddToCartRequest, CartResponse, ClearCartResponse,
UpdateCartItemRequest, UpdateCartItemRequest)
CartResponse,
CartOperationResponse,
ClearCartResponse,
)
router = APIRouter() router = APIRouter()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -29,6 +26,7 @@ logger = logging.getLogger(__name__)
# CART ENDPOINTS # CART ENDPOINTS
# ============================================================================ # ============================================================================
@router.get("/cart/{session_id}", response_model=CartResponse) @router.get("/cart/{session_id}", response_model=CartResponse)
def get_cart( def get_cart(
request: Request, request: Request,
@@ -45,12 +43,12 @@ def get_cart(
- session_id: Unique session identifier for the cart - session_id: Unique session identifier for the cart
""" """
# Get vendor from middleware # Get vendor from middleware
vendor = getattr(request.state, 'vendor', None) vendor = getattr(request.state, "vendor", None)
if not vendor: if not vendor:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path." detail="Vendor not found. Please access via vendor domain/subdomain/path.",
) )
logger.info( logger.info(
@@ -59,23 +57,19 @@ def get_cart(
"vendor_id": vendor.id, "vendor_id": vendor.id,
"vendor_code": vendor.subdomain, "vendor_code": vendor.subdomain,
"session_id": session_id, "session_id": session_id,
} },
) )
cart = cart_service.get_cart( cart = cart_service.get_cart(db=db, vendor_id=vendor.id, session_id=session_id)
db=db,
vendor_id=vendor.id,
session_id=session_id
)
logger.info( logger.info(
f"[SHOP_API] get_cart result: {len(cart.get('items', []))} items in cart", f"[SHOP_API] get_cart result: {len(cart.get('items', []))} items in cart",
extra={ extra={
"session_id": session_id, "session_id": session_id,
"vendor_id": vendor.id, "vendor_id": vendor.id,
"item_count": len(cart.get('items', [])), "item_count": len(cart.get("items", [])),
"total": cart.get('total', 0), "total": cart.get("total", 0),
} },
) )
return CartResponse.from_service_dict(cart) return CartResponse.from_service_dict(cart)
@@ -102,12 +96,12 @@ def add_to_cart(
- quantity: Quantity to add (default: 1) - quantity: Quantity to add (default: 1)
""" """
# Get vendor from middleware # Get vendor from middleware
vendor = getattr(request.state, 'vendor', None) vendor = getattr(request.state, "vendor", None)
if not vendor: if not vendor:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path." detail="Vendor not found. Please access via vendor domain/subdomain/path.",
) )
logger.info( logger.info(
@@ -118,7 +112,7 @@ def add_to_cart(
"session_id": session_id, "session_id": session_id,
"product_id": cart_data.product_id, "product_id": cart_data.product_id,
"quantity": cart_data.quantity, "quantity": cart_data.quantity,
} },
) )
result = cart_service.add_to_cart( result = cart_service.add_to_cart(
@@ -126,7 +120,7 @@ def add_to_cart(
vendor_id=vendor.id, vendor_id=vendor.id,
session_id=session_id, session_id=session_id,
product_id=cart_data.product_id, product_id=cart_data.product_id,
quantity=cart_data.quantity quantity=cart_data.quantity,
) )
logger.info( logger.info(
@@ -134,13 +128,15 @@ def add_to_cart(
extra={ extra={
"session_id": session_id, "session_id": session_id,
"result": result, "result": result,
} },
) )
return CartOperationResponse(**result) return CartOperationResponse(**result)
@router.put("/cart/{session_id}/items/{product_id}", response_model=CartOperationResponse) @router.put(
"/cart/{session_id}/items/{product_id}", response_model=CartOperationResponse
)
def update_cart_item( def update_cart_item(
request: Request, request: Request,
session_id: str = Path(..., description="Shopping session ID"), session_id: str = Path(..., description="Shopping session ID"),
@@ -162,12 +158,12 @@ def update_cart_item(
- quantity: New quantity (must be >= 1) - quantity: New quantity (must be >= 1)
""" """
# Get vendor from middleware # Get vendor from middleware
vendor = getattr(request.state, 'vendor', None) vendor = getattr(request.state, "vendor", None)
if not vendor: if not vendor:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path." detail="Vendor not found. Please access via vendor domain/subdomain/path.",
) )
logger.debug( logger.debug(
@@ -178,7 +174,7 @@ def update_cart_item(
"session_id": session_id, "session_id": session_id,
"product_id": product_id, "product_id": product_id,
"quantity": cart_data.quantity, "quantity": cart_data.quantity,
} },
) )
result = cart_service.update_cart_item( result = cart_service.update_cart_item(
@@ -186,13 +182,15 @@ def update_cart_item(
vendor_id=vendor.id, vendor_id=vendor.id,
session_id=session_id, session_id=session_id,
product_id=product_id, product_id=product_id,
quantity=cart_data.quantity quantity=cart_data.quantity,
) )
return CartOperationResponse(**result) return CartOperationResponse(**result)
@router.delete("/cart/{session_id}/items/{product_id}", response_model=CartOperationResponse) @router.delete(
"/cart/{session_id}/items/{product_id}", response_model=CartOperationResponse
)
def remove_from_cart( def remove_from_cart(
request: Request, request: Request,
session_id: str = Path(..., description="Shopping session ID"), session_id: str = Path(..., description="Shopping session ID"),
@@ -210,12 +208,12 @@ def remove_from_cart(
- product_id: ID of product to remove - product_id: ID of product to remove
""" """
# Get vendor from middleware # Get vendor from middleware
vendor = getattr(request.state, 'vendor', None) vendor = getattr(request.state, "vendor", None)
if not vendor: if not vendor:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path." detail="Vendor not found. Please access via vendor domain/subdomain/path.",
) )
logger.debug( logger.debug(
@@ -225,14 +223,11 @@ def remove_from_cart(
"vendor_code": vendor.subdomain, "vendor_code": vendor.subdomain,
"session_id": session_id, "session_id": session_id,
"product_id": product_id, "product_id": product_id,
} },
) )
result = cart_service.remove_from_cart( result = cart_service.remove_from_cart(
db=db, db=db, vendor_id=vendor.id, session_id=session_id, product_id=product_id
vendor_id=vendor.id,
session_id=session_id,
product_id=product_id
) )
return CartOperationResponse(**result) return CartOperationResponse(**result)
@@ -254,12 +249,12 @@ def clear_cart(
- session_id: Unique session identifier for the cart - session_id: Unique session identifier for the cart
""" """
# Get vendor from middleware # Get vendor from middleware
vendor = getattr(request.state, 'vendor', None) vendor = getattr(request.state, "vendor", None)
if not vendor: if not vendor:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path." detail="Vendor not found. Please access via vendor domain/subdomain/path.",
) )
logger.debug( logger.debug(
@@ -268,13 +263,9 @@ def clear_cart(
"vendor_id": vendor.id, "vendor_id": vendor.id,
"vendor_code": vendor.subdomain, "vendor_code": vendor.subdomain,
"session_id": session_id, "session_id": session_id,
} },
) )
result = cart_service.clear_cart( result = cart_service.clear_cart(db=db, vendor_id=vendor.id, session_id=session_id)
db=db,
vendor_id=vendor.id,
session_id=session_id
)
return ClearCartResponse(**result) return ClearCartResponse(**result)

View File

@@ -8,6 +8,7 @@ No authentication required.
import logging import logging
from typing import List from typing import List
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -23,8 +24,10 @@ logger = logging.getLogger(__name__)
# RESPONSE SCHEMAS # RESPONSE SCHEMAS
# ============================================================================ # ============================================================================
class PublicContentPageResponse(BaseModel): class PublicContentPageResponse(BaseModel):
"""Public content page response (no internal IDs).""" """Public content page response (no internal IDs)."""
slug: str slug: str
title: str title: str
content: str content: str
@@ -36,6 +39,7 @@ class PublicContentPageResponse(BaseModel):
class ContentPageListItem(BaseModel): class ContentPageListItem(BaseModel):
"""Content page list item for navigation.""" """Content page list item for navigation."""
slug: str slug: str
title: str title: str
show_in_footer: bool show_in_footer: bool
@@ -47,25 +51,21 @@ class ContentPageListItem(BaseModel):
# PUBLIC ENDPOINTS # PUBLIC ENDPOINTS
# ============================================================================ # ============================================================================
@router.get("/navigation", response_model=List[ContentPageListItem]) @router.get("/navigation", response_model=List[ContentPageListItem])
def get_navigation_pages( def get_navigation_pages(request: Request, db: Session = Depends(get_db)):
request: Request,
db: Session = Depends(get_db)
):
""" """
Get list of content pages for navigation (footer/header). Get list of content pages for navigation (footer/header).
Uses vendor from request.state (set by middleware). Uses vendor from request.state (set by middleware).
Returns vendor overrides + platform defaults. Returns vendor overrides + platform defaults.
""" """
vendor = getattr(request.state, 'vendor', None) vendor = getattr(request.state, "vendor", None)
vendor_id = vendor.id if vendor else None vendor_id = vendor.id if vendor else None
# Get all published pages for this vendor # Get all published pages for this vendor
pages = content_page_service.list_pages_for_vendor( pages = content_page_service.list_pages_for_vendor(
db, db, vendor_id=vendor_id, include_unpublished=False
vendor_id=vendor_id,
include_unpublished=False
) )
return [ return [
@@ -81,25 +81,21 @@ def get_navigation_pages(
@router.get("/{slug}", response_model=PublicContentPageResponse) @router.get("/{slug}", response_model=PublicContentPageResponse)
def get_content_page( def get_content_page(slug: str, request: Request, db: Session = Depends(get_db)):
slug: str,
request: Request,
db: Session = Depends(get_db)
):
""" """
Get a specific content page by slug. Get a specific content page by slug.
Uses vendor from request.state (set by middleware). Uses vendor from request.state (set by middleware).
Returns vendor override if exists, otherwise platform default. Returns vendor override if exists, otherwise platform default.
""" """
vendor = getattr(request.state, 'vendor', None) vendor = getattr(request.state, "vendor", None)
vendor_id = vendor.id if vendor else None vendor_id = vendor.id if vendor else None
page = content_page_service.get_page_for_vendor( page = content_page_service.get_page_for_vendor(
db, db,
slug=slug, slug=slug,
vendor_id=vendor_id, vendor_id=vendor_id,
include_unpublished=False # Only show published pages include_unpublished=False, # Only show published pages
) )
if not page: if not page:

View File

@@ -10,31 +10,23 @@ Requires customer authentication for most operations.
import logging import logging
from typing import Optional from typing import Optional
from fastapi import APIRouter, Depends, Path, Query, Request, HTTPException from fastapi import APIRouter, Depends, HTTPException, Path, Query, Request
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.database import get_db
from app.api.deps import get_current_customer_api from app.api.deps import get_current_customer_api
from app.services.order_service import order_service from app.core.database import get_db
from app.services.customer_service import customer_service from app.services.customer_service import customer_service
from models.schema.order import ( from app.services.order_service import order_service
OrderCreate,
OrderResponse,
OrderDetailResponse,
OrderListResponse
)
from models.database.user import User
from models.database.customer import Customer from models.database.customer import Customer
from models.database.user import User
from models.schema.order import (OrderCreate, OrderDetailResponse,
OrderListResponse, OrderResponse)
router = APIRouter() router = APIRouter()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_customer_from_user( def get_customer_from_user(request: Request, user: User, db: Session) -> Customer:
request: Request,
user: User,
db: Session
) -> Customer:
""" """
Helper to get Customer record from authenticated User. Helper to get Customer record from authenticated User.
@@ -49,25 +41,22 @@ def get_customer_from_user(
Raises: Raises:
HTTPException: If customer not found or vendor mismatch HTTPException: If customer not found or vendor mismatch
""" """
vendor = getattr(request.state, 'vendor', None) vendor = getattr(request.state, "vendor", None)
if not vendor: if not vendor:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path." detail="Vendor not found. Please access via vendor domain/subdomain/path.",
) )
# Find customer record for this user and vendor # Find customer record for this user and vendor
customer = customer_service.get_customer_by_user_id( customer = customer_service.get_customer_by_user_id(
db=db, db=db, vendor_id=vendor.id, user_id=user.id
vendor_id=vendor.id,
user_id=user.id
) )
if not customer: if not customer:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404, detail="Customer account not found for current vendor"
detail="Customer account not found for current vendor"
) )
return customer return customer
@@ -91,12 +80,12 @@ def place_order(
- Order data including shipping address, payment method, etc. - Order data including shipping address, payment method, etc.
""" """
# Get vendor from middleware # Get vendor from middleware
vendor = getattr(request.state, 'vendor', None) vendor = getattr(request.state, "vendor", None)
if not vendor: if not vendor:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path." detail="Vendor not found. Please access via vendor domain/subdomain/path.",
) )
# Get customer record # Get customer record
@@ -109,14 +98,12 @@ def place_order(
"vendor_code": vendor.subdomain, "vendor_code": vendor.subdomain,
"customer_id": customer.id, "customer_id": customer.id,
"user_id": current_user.id, "user_id": current_user.id,
} },
) )
# Create order # Create order
order = order_service.create_order( order = order_service.create_order(
db=db, db=db, vendor_id=vendor.id, order_data=order_data
vendor_id=vendor.id,
order_data=order_data
) )
logger.info( logger.info(
@@ -127,7 +114,7 @@ def place_order(
"order_number": order.order_number, "order_number": order.order_number,
"customer_id": customer.id, "customer_id": customer.id,
"total_amount": float(order.total_amount), "total_amount": float(order.total_amount),
} },
) )
# TODO: Update customer stats # TODO: Update customer stats
@@ -156,12 +143,12 @@ def get_my_orders(
- limit: Maximum number of orders to return - limit: Maximum number of orders to return
""" """
# Get vendor from middleware # Get vendor from middleware
vendor = getattr(request.state, 'vendor', None) vendor = getattr(request.state, "vendor", None)
if not vendor: if not vendor:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path." detail="Vendor not found. Please access via vendor domain/subdomain/path.",
) )
# Get customer record # Get customer record
@@ -175,23 +162,19 @@ def get_my_orders(
"customer_id": customer.id, "customer_id": customer.id,
"skip": skip, "skip": skip,
"limit": limit, "limit": limit,
} },
) )
# Get orders # Get orders
orders, total = order_service.get_customer_orders( orders, total = order_service.get_customer_orders(
db=db, db=db, vendor_id=vendor.id, customer_id=customer.id, skip=skip, limit=limit
vendor_id=vendor.id,
customer_id=customer.id,
skip=skip,
limit=limit
) )
return OrderListResponse( return OrderListResponse(
orders=[OrderResponse.model_validate(o) for o in orders], orders=[OrderResponse.model_validate(o) for o in orders],
total=total, total=total,
skip=skip, skip=skip,
limit=limit limit=limit,
) )
@@ -212,12 +195,12 @@ def get_order_details(
- order_id: ID of the order to retrieve - order_id: ID of the order to retrieve
""" """
# Get vendor from middleware # Get vendor from middleware
vendor = getattr(request.state, 'vendor', None) vendor = getattr(request.state, "vendor", None)
if not vendor: if not vendor:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path." detail="Vendor not found. Please access via vendor domain/subdomain/path.",
) )
# Get customer record # Get customer record
@@ -230,19 +213,16 @@ def get_order_details(
"vendor_code": vendor.subdomain, "vendor_code": vendor.subdomain,
"customer_id": customer.id, "customer_id": customer.id,
"order_id": order_id, "order_id": order_id,
} },
) )
# Get order # Get order
order = order_service.get_order( order = order_service.get_order(db=db, vendor_id=vendor.id, order_id=order_id)
db=db,
vendor_id=vendor.id,
order_id=order_id
)
# Verify order belongs to customer # Verify order belongs to customer
if order.customer_id != customer.id: if order.customer_id != customer.id:
from app.exceptions import OrderNotFoundException from app.exceptions import OrderNotFoundException
raise OrderNotFoundException(str(order_id)) raise OrderNotFoundException(str(order_id))
return OrderDetailResponse.model_validate(order) return OrderDetailResponse.model_validate(order)

View File

@@ -10,12 +10,13 @@ No authentication required.
import logging import logging
from typing import Optional from typing import Optional
from fastapi import APIRouter, Depends, Query, Path, Request, HTTPException from fastapi import APIRouter, Depends, HTTPException, Path, Query, Request
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.database import get_db from app.core.database import get_db
from app.services.product_service import product_service from app.services.product_service import product_service
from models.schema.product import ProductResponse, ProductDetailResponse, ProductListResponse from models.schema.product import (ProductDetailResponse, ProductListResponse,
ProductResponse)
router = APIRouter() router = APIRouter()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -27,7 +28,9 @@ def get_product_catalog(
skip: int = Query(0, ge=0), skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000), limit: int = Query(100, ge=1, le=1000),
search: Optional[str] = Query(None, description="Search products by name"), search: Optional[str] = Query(None, description="Search products by name"),
is_featured: Optional[bool] = Query(None, description="Filter by featured products"), is_featured: Optional[bool] = Query(
None, description="Filter by featured products"
),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
""" """
@@ -44,12 +47,12 @@ def get_product_catalog(
- is_featured: Filter by featured products only - is_featured: Filter by featured products only
""" """
# Get vendor from middleware (injected by VendorContextMiddleware) # Get vendor from middleware (injected by VendorContextMiddleware)
vendor = getattr(request.state, 'vendor', None) vendor = getattr(request.state, "vendor", None)
if not vendor: if not vendor:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path." detail="Vendor not found. Please access via vendor domain/subdomain/path.",
) )
logger.debug( logger.debug(
@@ -61,7 +64,7 @@ def get_product_catalog(
"limit": limit, "limit": limit,
"search": search, "search": search,
"is_featured": is_featured, "is_featured": is_featured,
} },
) )
# Get only active products for public view # Get only active products for public view
@@ -71,14 +74,14 @@ def get_product_catalog(
skip=skip, skip=skip,
limit=limit, limit=limit,
is_active=True, # Only show active products to customers is_active=True, # Only show active products to customers
is_featured=is_featured is_featured=is_featured,
) )
return ProductListResponse( return ProductListResponse(
products=[ProductResponse.model_validate(p) for p in products], products=[ProductResponse.model_validate(p) for p in products],
total=total, total=total,
skip=skip, skip=skip,
limit=limit limit=limit,
) )
@@ -98,12 +101,12 @@ def get_product_details(
- product_id: ID of the product to retrieve - product_id: ID of the product to retrieve
""" """
# Get vendor from middleware # Get vendor from middleware
vendor = getattr(request.state, 'vendor', None) vendor = getattr(request.state, "vendor", None)
if not vendor: if not vendor:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path." detail="Vendor not found. Please access via vendor domain/subdomain/path.",
) )
logger.debug( logger.debug(
@@ -112,18 +115,17 @@ def get_product_details(
"vendor_id": vendor.id, "vendor_id": vendor.id,
"vendor_code": vendor.subdomain, "vendor_code": vendor.subdomain,
"product_id": product_id, "product_id": product_id,
} },
) )
product = product_service.get_product( product = product_service.get_product(
db=db, db=db, vendor_id=vendor.id, product_id=product_id
vendor_id=vendor.id,
product_id=product_id
) )
# Check if product is active # Check if product is active
if not product.is_active: if not product.is_active:
from app.exceptions import ProductNotActiveException from app.exceptions import ProductNotActiveException
raise ProductNotActiveException(str(product_id)) raise ProductNotActiveException(str(product_id))
return ProductDetailResponse.model_validate(product) return ProductDetailResponse.model_validate(product)
@@ -150,12 +152,12 @@ def search_products(
- limit: Maximum number of results to return - limit: Maximum number of results to return
""" """
# Get vendor from middleware # Get vendor from middleware
vendor = getattr(request.state, 'vendor', None) vendor = getattr(request.state, "vendor", None)
if not vendor: if not vendor:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path." detail="Vendor not found. Please access via vendor domain/subdomain/path.",
) )
logger.debug( logger.debug(
@@ -166,22 +168,18 @@ def search_products(
"query": q, "query": q,
"skip": skip, "skip": skip,
"limit": limit, "limit": limit,
} },
) )
# TODO: Implement full-text search functionality # TODO: Implement full-text search functionality
# For now, return filtered products # For now, return filtered products
products, total = product_service.get_vendor_products( products, total = product_service.get_vendor_products(
db=db, db=db, vendor_id=vendor.id, skip=skip, limit=limit, is_active=True
vendor_id=vendor.id,
skip=skip,
limit=limit,
is_active=True
) )
return ProductListResponse( return ProductListResponse(
products=[ProductResponse.model_validate(p) for p in products], products=[ProductResponse.model_validate(p) for p in products],
total=total, total=total,
skip=skip, skip=skip,
limit=limit limit=limit,
) )

View File

@@ -13,25 +13,9 @@ IMPORTANT:
from fastapi import APIRouter from fastapi import APIRouter
# Import all sub-routers (JSON API only) # Import all sub-routers (JSON API only)
from . import ( from . import (analytics, auth, content_pages, customers, dashboard, info,
info, inventory, marketplace, media, notifications, orders, payments,
auth, products, profile, settings, team)
dashboard,
profile,
settings,
products,
orders,
customers,
team,
inventory,
marketplace,
payments,
media,
notifications,
analytics,
content_pages,
)
# Create vendor router # Create vendor router
router = APIRouter() router = APIRouter()
@@ -68,7 +52,11 @@ router.include_router(notifications.router, tags=["vendor-notifications"])
router.include_router(analytics.router, tags=["vendor-analytics"]) router.include_router(analytics.router, tags=["vendor-analytics"])
# Content pages management # Content pages management
router.include_router(content_pages.router, prefix="/{vendor_code}/content-pages", tags=["vendor-content-pages"]) router.include_router(
content_pages.router,
prefix="/{vendor_code}/content-pages",
tags=["vendor-content-pages"],
)
# Vendor info endpoint - MUST BE LAST! Has catch-all GET /{vendor_code} # Vendor info endpoint - MUST BE LAST! Has catch-all GET /{vendor_code}
router.include_router(info.router, tags=["vendor-info"]) router.include_router(info.router, tags=["vendor-info"])

View File

@@ -4,13 +4,14 @@ Vendor analytics and reporting endpoints.
""" """
import logging import logging
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api from app.api.deps import get_current_vendor_api
from app.core.database import get_db from app.core.database import get_db
from middleware.vendor_context import require_vendor_context
from app.services.stats_service import stats_service from app.services.stats_service import stats_service
from middleware.vendor_context import require_vendor_context
from models.database.user import User from models.database.user import User
from models.database.vendor import Vendor from models.database.vendor import Vendor

View File

@@ -13,19 +13,20 @@ This prevents:
""" """
import logging import logging
from fastapi import APIRouter, Depends, Request, Response from fastapi import APIRouter, Depends, Request, Response
from pydantic import BaseModel
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.database import get_db
from app.services.auth_service import auth_service
from app.exceptions import InvalidCredentialsException
from middleware.vendor_context import get_current_vendor
from models.schema.auth import UserLogin
from models.database.vendor import Vendor, VendorUser, Role
from models.database.user import User
from pydantic import BaseModel
from app.api.deps import get_current_vendor_api from app.api.deps import get_current_vendor_api
from app.core.database import get_db
from app.core.environment import should_use_secure_cookies from app.core.environment import should_use_secure_cookies
from app.exceptions import InvalidCredentialsException
from app.services.auth_service import auth_service
from middleware.vendor_context import get_current_vendor
from models.database.user import User
from models.database.vendor import Role, Vendor, VendorUser
from models.schema.auth import UserLogin
router = APIRouter(prefix="/auth") router = APIRouter(prefix="/auth")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -46,7 +47,7 @@ def vendor_login(
user_credentials: UserLogin, user_credentials: UserLogin,
request: Request, request: Request,
response: Response, response: Response,
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Vendor team member login. Vendor team member login.
@@ -64,13 +65,16 @@ def vendor_login(
vendor = get_current_vendor(request) vendor = get_current_vendor(request)
# If no vendor from middleware, try to get from request body # If no vendor from middleware, try to get from request body
if not vendor and hasattr(user_credentials, 'vendor_code'): if not vendor and hasattr(user_credentials, "vendor_code"):
vendor_code = getattr(user_credentials, 'vendor_code', None) vendor_code = getattr(user_credentials, "vendor_code", None)
if vendor_code: if vendor_code:
vendor = db.query(Vendor).filter( vendor = (
Vendor.vendor_code == vendor_code.upper(), db.query(Vendor)
Vendor.is_active == True .filter(
).first() Vendor.vendor_code == vendor_code.upper(), Vendor.is_active == True
)
.first()
)
# Authenticate user # Authenticate user
login_result = auth_service.login_user(db=db, user_credentials=user_credentials) login_result = auth_service.login_user(db=db, user_credentials=user_credentials)
@@ -79,7 +83,9 @@ def vendor_login(
# CRITICAL: Prevent admin users from using vendor login # CRITICAL: Prevent admin users from using vendor login
if user.role == "admin": if user.role == "admin":
logger.warning(f"Admin user attempted vendor login: {user.username}") logger.warning(f"Admin user attempted vendor login: {user.username}")
raise InvalidCredentialsException("Admins cannot access vendor portal. Please use admin portal.") raise InvalidCredentialsException(
"Admins cannot access vendor portal. Please use admin portal."
)
# Determine vendor and role # Determine vendor and role
vendor_role = "Member" vendor_role = "Member"
@@ -92,11 +98,16 @@ def vendor_login(
vendor_role = "Owner" vendor_role = "Owner"
else: else:
# Check if user is team member # Check if user is team member
vendor_user = db.query(VendorUser).join(Role).filter( vendor_user = (
db.query(VendorUser)
.join(Role)
.filter(
VendorUser.user_id == user.id, VendorUser.user_id == user.id,
VendorUser.vendor_id == vendor.id, VendorUser.vendor_id == vendor.id,
VendorUser.is_active == True VendorUser.is_active == True,
).first() )
.first()
)
if vendor_user: if vendor_user:
vendor_role = vendor_user.role.name vendor_role = vendor_user.role.name
@@ -117,17 +128,14 @@ def vendor_login(
# Check vendor memberships # Check vendor memberships
elif user.vendor_memberships: elif user.vendor_memberships:
active_membership = next( active_membership = next(
(vm for vm in user.vendor_memberships if vm.is_active), (vm for vm in user.vendor_memberships if vm.is_active), None
None
) )
if active_membership: if active_membership:
vendor = active_membership.vendor vendor = active_membership.vendor
vendor_role = active_membership.role.name vendor_role = active_membership.role.name
if not vendor: if not vendor:
raise InvalidCredentialsException( raise InvalidCredentialsException("User is not associated with any vendor")
"User is not associated with any vendor"
)
logger.info( logger.info(
f"Vendor team login successful: {user.username} " f"Vendor team login successful: {user.username} "
@@ -161,7 +169,7 @@ def vendor_login(
"username": user.username, "username": user.username,
"email": user.email, "email": user.email,
"role": user.role, "role": user.role,
"is_active": user.is_active "is_active": user.is_active,
}, },
vendor={ vendor={
"id": vendor.id, "id": vendor.id,
@@ -169,9 +177,9 @@ def vendor_login(
"subdomain": vendor.subdomain, "subdomain": vendor.subdomain,
"name": vendor.name, "name": vendor.name,
"is_active": vendor.is_active, "is_active": vendor.is_active,
"is_verified": vendor.is_verified "is_verified": vendor.is_verified,
}, },
vendor_role=vendor_role vendor_role=vendor_role,
) )
@@ -198,8 +206,7 @@ def vendor_logout(response: Response):
@router.get("/me") @router.get("/me")
def get_current_vendor_user( def get_current_vendor_user(
user: User = Depends(get_current_vendor_api), user: User = Depends(get_current_vendor_api), db: Session = Depends(get_db)
db: Session = Depends(get_db)
): ):
""" """
Get current authenticated vendor user. Get current authenticated vendor user.
@@ -212,5 +219,5 @@ def get_current_vendor_user(
"username": user.username, "username": user.username,
"email": user.email, "email": user.email,
"role": user.role, "role": user.role,
"is_active": user.is_active "is_active": user.is_active,
} }

View File

@@ -10,6 +10,7 @@ Vendors can:
import logging import logging
from typing import List, Optional from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -26,14 +27,26 @@ logger = logging.getLogger(__name__)
# REQUEST/RESPONSE SCHEMAS # REQUEST/RESPONSE SCHEMAS
# ============================================================================ # ============================================================================
class VendorContentPageCreate(BaseModel): class VendorContentPageCreate(BaseModel):
"""Schema for creating a vendor content page.""" """Schema for creating a vendor content page."""
slug: str = Field(..., max_length=100, description="URL-safe identifier (about, faq, contact, etc.)")
slug: str = Field(
...,
max_length=100,
description="URL-safe identifier (about, faq, contact, etc.)",
)
title: str = Field(..., max_length=200, description="Page title") title: str = Field(..., max_length=200, description="Page title")
content: str = Field(..., description="HTML or Markdown content") content: str = Field(..., description="HTML or Markdown content")
content_format: str = Field(default="html", description="Content format: html or markdown") content_format: str = Field(
meta_description: Optional[str] = Field(None, max_length=300, description="SEO meta description") default="html", description="Content format: html or markdown"
meta_keywords: Optional[str] = Field(None, max_length=300, description="SEO keywords") )
meta_description: Optional[str] = Field(
None, max_length=300, description="SEO meta description"
)
meta_keywords: Optional[str] = Field(
None, max_length=300, description="SEO keywords"
)
is_published: bool = Field(default=False, description="Publish immediately") is_published: bool = Field(default=False, description="Publish immediately")
show_in_footer: bool = Field(default=True, description="Show in footer navigation") show_in_footer: bool = Field(default=True, description="Show in footer navigation")
show_in_header: bool = Field(default=False, description="Show in header navigation") show_in_header: bool = Field(default=False, description="Show in header navigation")
@@ -42,6 +55,7 @@ class VendorContentPageCreate(BaseModel):
class VendorContentPageUpdate(BaseModel): class VendorContentPageUpdate(BaseModel):
"""Schema for updating a vendor content page.""" """Schema for updating a vendor content page."""
title: Optional[str] = Field(None, max_length=200) title: Optional[str] = Field(None, max_length=200)
content: Optional[str] = None content: Optional[str] = None
content_format: Optional[str] = None content_format: Optional[str] = None
@@ -55,6 +69,7 @@ class VendorContentPageUpdate(BaseModel):
class ContentPageResponse(BaseModel): class ContentPageResponse(BaseModel):
"""Schema for content page response.""" """Schema for content page response."""
id: int id: int
vendor_id: Optional[int] vendor_id: Optional[int]
vendor_name: Optional[str] vendor_name: Optional[str]
@@ -81,11 +96,12 @@ class ContentPageResponse(BaseModel):
# VENDOR CONTENT PAGES # VENDOR CONTENT PAGES
# ============================================================================ # ============================================================================
@router.get("/", response_model=List[ContentPageResponse]) @router.get("/", response_model=List[ContentPageResponse])
def list_vendor_pages( def list_vendor_pages(
include_unpublished: bool = Query(False, description="Include draft pages"), include_unpublished: bool = Query(False, description="Include draft pages"),
current_user: User = Depends(get_current_vendor_api), current_user: User = Depends(get_current_vendor_api),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
List all content pages available for this vendor. List all content pages available for this vendor.
@@ -93,12 +109,12 @@ def list_vendor_pages(
Returns vendor-specific overrides + platform defaults (vendor overrides take precedence). Returns vendor-specific overrides + platform defaults (vendor overrides take precedence).
""" """
if not current_user.vendor_id: if not current_user.vendor_id:
raise HTTPException(status_code=403, detail="User is not associated with a vendor") raise HTTPException(
status_code=403, detail="User is not associated with a vendor"
)
pages = content_page_service.list_pages_for_vendor( pages = content_page_service.list_pages_for_vendor(
db, db, vendor_id=current_user.vendor_id, include_unpublished=include_unpublished
vendor_id=current_user.vendor_id,
include_unpublished=include_unpublished
) )
return [page.to_dict() for page in pages] return [page.to_dict() for page in pages]
@@ -108,7 +124,7 @@ def list_vendor_pages(
def list_vendor_overrides( def list_vendor_overrides(
include_unpublished: bool = Query(False, description="Include draft pages"), include_unpublished: bool = Query(False, description="Include draft pages"),
current_user: User = Depends(get_current_vendor_api), current_user: User = Depends(get_current_vendor_api),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
List only vendor-specific content pages (no platform defaults). List only vendor-specific content pages (no platform defaults).
@@ -116,12 +132,12 @@ def list_vendor_overrides(
Shows what the vendor has customized. Shows what the vendor has customized.
""" """
if not current_user.vendor_id: if not current_user.vendor_id:
raise HTTPException(status_code=403, detail="User is not associated with a vendor") raise HTTPException(
status_code=403, detail="User is not associated with a vendor"
)
pages = content_page_service.list_all_vendor_pages( pages = content_page_service.list_all_vendor_pages(
db, db, vendor_id=current_user.vendor_id, include_unpublished=include_unpublished
vendor_id=current_user.vendor_id,
include_unpublished=include_unpublished
) )
return [page.to_dict() for page in pages] return [page.to_dict() for page in pages]
@@ -132,7 +148,7 @@ def get_page(
slug: str, slug: str,
include_unpublished: bool = Query(False, description="Include draft pages"), include_unpublished: bool = Query(False, description="Include draft pages"),
current_user: User = Depends(get_current_vendor_api), current_user: User = Depends(get_current_vendor_api),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Get a specific content page by slug. Get a specific content page by slug.
@@ -140,13 +156,15 @@ def get_page(
Returns vendor override if exists, otherwise platform default. Returns vendor override if exists, otherwise platform default.
""" """
if not current_user.vendor_id: if not current_user.vendor_id:
raise HTTPException(status_code=403, detail="User is not associated with a vendor") raise HTTPException(
status_code=403, detail="User is not associated with a vendor"
)
page = content_page_service.get_page_for_vendor( page = content_page_service.get_page_for_vendor(
db, db,
slug=slug, slug=slug,
vendor_id=current_user.vendor_id, vendor_id=current_user.vendor_id,
include_unpublished=include_unpublished include_unpublished=include_unpublished,
) )
if not page: if not page:
@@ -159,7 +177,7 @@ def get_page(
def create_vendor_page( def create_vendor_page(
page_data: VendorContentPageCreate, page_data: VendorContentPageCreate,
current_user: User = Depends(get_current_vendor_api), current_user: User = Depends(get_current_vendor_api),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Create a vendor-specific content page override. Create a vendor-specific content page override.
@@ -167,7 +185,9 @@ def create_vendor_page(
This will be shown instead of the platform default for this vendor. This will be shown instead of the platform default for this vendor.
""" """
if not current_user.vendor_id: if not current_user.vendor_id:
raise HTTPException(status_code=403, detail="User is not associated with a vendor") raise HTTPException(
status_code=403, detail="User is not associated with a vendor"
)
page = content_page_service.create_page( page = content_page_service.create_page(
db, db,
@@ -182,7 +202,7 @@ def create_vendor_page(
show_in_footer=page_data.show_in_footer, show_in_footer=page_data.show_in_footer,
show_in_header=page_data.show_in_header, show_in_header=page_data.show_in_header,
display_order=page_data.display_order, display_order=page_data.display_order,
created_by=current_user.id created_by=current_user.id,
) )
return page.to_dict() return page.to_dict()
@@ -193,7 +213,7 @@ def update_vendor_page(
page_id: int, page_id: int,
page_data: VendorContentPageUpdate, page_data: VendorContentPageUpdate,
current_user: User = Depends(get_current_vendor_api), current_user: User = Depends(get_current_vendor_api),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Update a vendor-specific content page. Update a vendor-specific content page.
@@ -201,7 +221,9 @@ def update_vendor_page(
Can only update pages owned by this vendor. Can only update pages owned by this vendor.
""" """
if not current_user.vendor_id: if not current_user.vendor_id:
raise HTTPException(status_code=403, detail="User is not associated with a vendor") raise HTTPException(
status_code=403, detail="User is not associated with a vendor"
)
# Verify ownership # Verify ownership
existing_page = content_page_service.get_page_by_id(db, page_id) existing_page = content_page_service.get_page_by_id(db, page_id)
@@ -209,7 +231,9 @@ def update_vendor_page(
raise HTTPException(status_code=404, detail="Content page not found") raise HTTPException(status_code=404, detail="Content page not found")
if existing_page.vendor_id != current_user.vendor_id: if existing_page.vendor_id != current_user.vendor_id:
raise HTTPException(status_code=403, detail="Cannot edit pages from other vendors") raise HTTPException(
status_code=403, detail="Cannot edit pages from other vendors"
)
# Update # Update
page = content_page_service.update_page( page = content_page_service.update_page(
@@ -224,7 +248,7 @@ def update_vendor_page(
show_in_footer=page_data.show_in_footer, show_in_footer=page_data.show_in_footer,
show_in_header=page_data.show_in_header, show_in_header=page_data.show_in_header,
display_order=page_data.display_order, display_order=page_data.display_order,
updated_by=current_user.id updated_by=current_user.id,
) )
return page.to_dict() return page.to_dict()
@@ -234,7 +258,7 @@ def update_vendor_page(
def delete_vendor_page( def delete_vendor_page(
page_id: int, page_id: int,
current_user: User = Depends(get_current_vendor_api), current_user: User = Depends(get_current_vendor_api),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Delete a vendor-specific content page. Delete a vendor-specific content page.
@@ -243,7 +267,9 @@ def delete_vendor_page(
After deletion, platform default will be shown (if exists). After deletion, platform default will be shown (if exists).
""" """
if not current_user.vendor_id: if not current_user.vendor_id:
raise HTTPException(status_code=403, detail="User is not associated with a vendor") raise HTTPException(
status_code=403, detail="User is not associated with a vendor"
)
# Verify ownership # Verify ownership
existing_page = content_page_service.get_page_by_id(db, page_id) existing_page = content_page_service.get_page_by_id(db, page_id)
@@ -251,7 +277,9 @@ def delete_vendor_page(
raise HTTPException(status_code=404, detail="Content page not found") raise HTTPException(status_code=404, detail="Content page not found")
if existing_page.vendor_id != current_user.vendor_id: if existing_page.vendor_id != current_user.vendor_id:
raise HTTPException(status_code=403, detail="Cannot delete pages from other vendors") raise HTTPException(
status_code=403, detail="Cannot delete pages from other vendors"
)
# Delete # Delete
content_page_service.delete_page(db, page_id) content_page_service.delete_page(db, page_id)

View File

@@ -6,6 +6,7 @@ Vendor customer management endpoints.
import logging import logging
from typing import Optional from typing import Optional
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -43,7 +44,7 @@ def get_vendor_customers(
"total": 0, "total": 0,
"skip": skip, "skip": skip,
"limit": limit, "limit": limit,
"message": "Customer management coming in Slice 4" "message": "Customer management coming in Slice 4",
} }
@@ -63,9 +64,7 @@ def get_customer_details(
- Include order history - Include order history
- Include total spent, etc. - Include total spent, etc.
""" """
return { return {"message": "Customer details coming in Slice 4"}
"message": "Customer details coming in Slice 4"
}
@router.get("/{customer_id}/orders") @router.get("/{customer_id}/orders")
@@ -83,10 +82,7 @@ def get_customer_orders(
- Filter by vendor_id - Filter by vendor_id
- Return order details - Return order details
""" """
return { return {"orders": [], "message": "Customer orders coming in Slice 5"}
"orders": [],
"message": "Customer orders coming in Slice 5"
}
@router.put("/{customer_id}") @router.put("/{customer_id}")
@@ -105,9 +101,7 @@ def update_customer(
- Verify customer belongs to vendor - Verify customer belongs to vendor
- Update customer preferences - Update customer preferences
""" """
return { return {"message": "Customer update coming in Slice 4"}
"message": "Customer update coming in Slice 4"
}
@router.put("/{customer_id}/status") @router.put("/{customer_id}/status")
@@ -125,9 +119,7 @@ def toggle_customer_status(
- Verify customer belongs to vendor - Verify customer belongs to vendor
- Log the change - Log the change
""" """
return { return {"message": "Customer status toggle coming in Slice 4"}
"message": "Customer status toggle coming in Slice 4"
}
@router.get("/{customer_id}/stats") @router.get("/{customer_id}/stats")
@@ -151,6 +143,5 @@ def get_customer_statistics(
"total_spent": 0.0, "total_spent": 0.0,
"average_order_value": 0.0, "average_order_value": 0.0,
"last_order_date": None, "last_order_date": None,
"message": "Customer statistics coming in Slice 4" "message": "Customer statistics coming in Slice 4",
} }

View File

@@ -4,13 +4,14 @@ Vendor dashboard and statistics endpoints.
""" """
import logging import logging
from fastapi import APIRouter, Depends, Request from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api from app.api.deps import get_current_vendor_api
from app.core.database import get_db from app.core.database import get_db
from middleware.vendor_context import require_vendor_context
from app.services.stats_service import stats_service from app.services.stats_service import stats_service
from middleware.vendor_context import require_vendor_context
from models.database.user import User from models.database.user import User
from models.database.vendor import Vendor from models.database.vendor import Vendor
@@ -38,24 +39,23 @@ def get_vendor_dashboard_stats(
""" """
# Get vendor from authenticated user's vendor_user record # Get vendor from authenticated user's vendor_user record
from models.database.vendor import VendorUser from models.database.vendor import VendorUser
vendor_user = db.query(VendorUser).filter(
VendorUser.user_id == current_user.id vendor_user = (
).first() db.query(VendorUser).filter(VendorUser.user_id == current_user.id).first()
)
if not vendor_user: if not vendor_user:
from fastapi import HTTPException from fastapi import HTTPException
raise HTTPException( raise HTTPException(
status_code=403, status_code=403, detail="User is not associated with any vendor"
detail="User is not associated with any vendor"
) )
vendor = vendor_user.vendor vendor = vendor_user.vendor
if not vendor or not vendor.is_active: if not vendor or not vendor.is_active:
from fastapi import HTTPException from fastapi import HTTPException
raise HTTPException(
status_code=404, raise HTTPException(status_code=404, detail="Vendor not found or inactive")
detail="Vendor not found or inactive"
)
# Get vendor-scoped statistics # Get vendor-scoped statistics
stats_data = stats_service.get_vendor_stats(db=db, vendor_id=vendor.id) stats_data = stats_service.get_vendor_stats(db=db, vendor_id=vendor.id)
@@ -82,5 +82,5 @@ def get_vendor_dashboard_stats(
"revenue": { "revenue": {
"total": stats_data.get("total_revenue", 0), "total": stats_data.get("total_revenue", 0),
"this_month": stats_data.get("revenue_this_month", 0), "this_month": stats_data.get("revenue_this_month", 0),
} },
} }

View File

@@ -8,14 +8,15 @@ This module provides:
""" """
import logging import logging
from fastapi import APIRouter, Path, Depends
from sqlalchemy.orm import Session from fastapi import APIRouter, Depends, Path
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.orm import Session
from app.core.database import get_db from app.core.database import get_db
from app.exceptions import VendorNotFoundException from app.exceptions import VendorNotFoundException
from models.schema.vendor import VendorResponse, VendorDetailResponse
from models.database.vendor import Vendor from models.database.vendor import Vendor
from models.schema.vendor import VendorDetailResponse, VendorResponse
router = APIRouter() router = APIRouter()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -35,10 +36,14 @@ def _get_vendor_by_code(db: Session, vendor_code: str) -> Vendor:
Raises: Raises:
VendorNotFoundException: If vendor not found or inactive VendorNotFoundException: If vendor not found or inactive
""" """
vendor = db.query(Vendor).filter( vendor = (
db.query(Vendor)
.filter(
func.upper(Vendor.vendor_code) == vendor_code.upper(), func.upper(Vendor.vendor_code) == vendor_code.upper(),
Vendor.is_active == True Vendor.is_active == True,
).first() )
.first()
)
if not vendor: if not vendor:
logger.warning(f"Vendor not found or inactive: {vendor_code}") logger.warning(f"Vendor not found or inactive: {vendor_code}")
@@ -50,7 +55,7 @@ def _get_vendor_by_code(db: Session, vendor_code: str) -> Vendor:
@router.get("/{vendor_code}", response_model=VendorDetailResponse) @router.get("/{vendor_code}", response_model=VendorDetailResponse)
def get_vendor_info( def get_vendor_info(
vendor_code: str = Path(..., description="Vendor code"), vendor_code: str = Path(..., description="Vendor code"),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Get public vendor information by vendor code. Get public vendor information by vendor code.

View File

@@ -7,19 +7,14 @@ from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api from app.api.deps import get_current_vendor_api
from app.core.database import get_db from app.core.database import get_db
from middleware.vendor_context import require_vendor_context
from app.services.inventory_service import inventory_service from app.services.inventory_service import inventory_service
from models.schema.inventory import ( from middleware.vendor_context import require_vendor_context
InventoryCreate,
InventoryAdjust,
InventoryUpdate,
InventoryReserve,
InventoryResponse,
ProductInventorySummary,
InventoryListResponse
)
from models.database.user import User from models.database.user import User
from models.database.vendor import Vendor from models.database.vendor import Vendor
from models.schema.inventory import (InventoryAdjust, InventoryCreate,
InventoryListResponse, InventoryReserve,
InventoryResponse, InventoryUpdate,
ProductInventorySummary)
router = APIRouter() router = APIRouter()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -110,10 +105,7 @@ def get_vendor_inventory(
total = len(inventories) # You might want a separate count query for large datasets total = len(inventories) # You might want a separate count query for large datasets
return InventoryListResponse( return InventoryListResponse(
inventories=inventories, inventories=inventories, total=total, skip=skip, limit=limit
total=total,
skip=skip,
limit=limit
) )
@@ -126,7 +118,9 @@ def update_inventory(
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""Update inventory entry.""" """Update inventory entry."""
return inventory_service.update_inventory(db, vendor.id, inventory_id, inventory_update) return inventory_service.update_inventory(
db, vendor.id, inventory_id, inventory_update
)
@router.delete("/inventory/{inventory_id}") @router.delete("/inventory/{inventory_id}")

View File

@@ -12,16 +12,15 @@ from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api from app.api.deps import get_current_vendor_api
from app.core.database import get_db from app.core.database import get_db
from middleware.vendor_context import require_vendor_context # IMPORTANT from app.services.marketplace_import_job_service import \
from app.services.marketplace_import_job_service import marketplace_import_job_service marketplace_import_job_service
from app.tasks.background_tasks import process_marketplace_import from app.tasks.background_tasks import process_marketplace_import
from middleware.decorators import rate_limit from middleware.decorators import rate_limit
from models.schema.marketplace_import_job import ( from middleware.vendor_context import require_vendor_context # IMPORTANT
MarketplaceImportJobResponse,
MarketplaceImportJobRequest
)
from models.database.user import User from models.database.user import User
from models.database.vendor import Vendor from models.database.vendor import Vendor
from models.schema.marketplace_import_job import (MarketplaceImportJobRequest,
MarketplaceImportJobResponse)
router = APIRouter() router = APIRouter()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -88,6 +87,7 @@ def get_marketplace_import_status(
# Verify job belongs to current vendor # Verify job belongs to current vendor
if job.vendor_id != vendor.id: if job.vendor_id != vendor.id:
from app.exceptions import UnauthorizedVendorAccessException from app.exceptions import UnauthorizedVendorAccessException
raise UnauthorizedVendorAccessException(vendor.vendor_code, current_user.id) raise UnauthorizedVendorAccessException(vendor.vendor_code, current_user.id)
return marketplace_import_job_service.convert_to_response_model(job) return marketplace_import_job_service.convert_to_response_model(job)
@@ -112,4 +112,6 @@ def get_marketplace_import_jobs(
limit=limit, limit=limit,
) )
return [marketplace_import_job_service.convert_to_response_model(job) for job in jobs] return [
marketplace_import_job_service.convert_to_response_model(job) for job in jobs
]

View File

@@ -6,7 +6,8 @@ Vendor media and file management endpoints.
import logging import logging
from typing import Optional from typing import Optional
from fastapi import APIRouter, Depends, Query, UploadFile, File
from fastapi import APIRouter, Depends, File, Query, UploadFile
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api from app.api.deps import get_current_vendor_api
@@ -44,7 +45,7 @@ def get_media_library(
"total": 0, "total": 0,
"skip": skip, "skip": skip,
"limit": limit, "limit": limit,
"message": "Media library coming in Slice 3" "message": "Media library coming in Slice 3",
} }
@@ -70,7 +71,7 @@ async def upload_media(
return { return {
"file_url": None, "file_url": None,
"thumbnail_url": None, "thumbnail_url": None,
"message": "Media upload coming in Slice 3" "message": "Media upload coming in Slice 3",
} }
@@ -94,7 +95,7 @@ async def upload_multiple_media(
return { return {
"uploaded_files": [], "uploaded_files": [],
"failed_files": [], "failed_files": [],
"message": "Multiple upload coming in Slice 3" "message": "Multiple upload coming in Slice 3",
} }
@@ -113,9 +114,7 @@ def get_media_details(
- Return file URL - Return file URL
- Return usage information (which products use this file) - Return usage information (which products use this file)
""" """
return { return {"message": "Media details coming in Slice 3"}
"message": "Media details coming in Slice 3"
}
@router.put("/{media_id}") @router.put("/{media_id}")
@@ -135,9 +134,7 @@ def update_media_metadata(
- Update tags/categories - Update tags/categories
- Update description - Update description
""" """
return { return {"message": "Media update coming in Slice 3"}
"message": "Media update coming in Slice 3"
}
@router.delete("/{media_id}") @router.delete("/{media_id}")
@@ -157,9 +154,7 @@ def delete_media(
- Delete database record - Delete database record
- Return success/error - Return success/error
""" """
return { return {"message": "Media deletion coming in Slice 3"}
"message": "Media deletion coming in Slice 3"
}
@router.get("/{media_id}/usage") @router.get("/{media_id}/usage")
@@ -180,7 +175,7 @@ def get_media_usage(
return { return {
"products": [], "products": [],
"other_usage": [], "other_usage": [],
"message": "Media usage tracking coming in Slice 3" "message": "Media usage tracking coming in Slice 3",
} }
@@ -200,7 +195,4 @@ def optimize_media(
- Keep original - Keep original
- Update database with new versions - Update database with new versions
""" """
return { return {"message": "Media optimization coming in Slice 3"}
"message": "Media optimization coming in Slice 3"
}

View File

@@ -6,6 +6,7 @@ Vendor notification management endpoints.
import logging import logging
from typing import Optional from typing import Optional
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -41,7 +42,7 @@ def get_notifications(
"notifications": [], "notifications": [],
"total": 0, "total": 0,
"unread_count": 0, "unread_count": 0,
"message": "Notifications coming in Slice 5" "message": "Notifications coming in Slice 5",
} }
@@ -58,10 +59,7 @@ def get_unread_count(
- Count unread notifications for vendor - Count unread notifications for vendor
- Used for notification badge - Used for notification badge
""" """
return { return {"unread_count": 0, "message": "Unread count coming in Slice 5"}
"unread_count": 0,
"message": "Unread count coming in Slice 5"
}
@router.put("/{notification_id}/read") @router.put("/{notification_id}/read")
@@ -78,9 +76,7 @@ def mark_as_read(
- Mark single notification as read - Mark single notification as read
- Update read timestamp - Update read timestamp
""" """
return { return {"message": "Mark as read coming in Slice 5"}
"message": "Mark as read coming in Slice 5"
}
@router.put("/mark-all-read") @router.put("/mark-all-read")
@@ -96,9 +92,7 @@ def mark_all_as_read(
- Mark all vendor notifications as read - Mark all vendor notifications as read
- Update timestamps - Update timestamps
""" """
return { return {"message": "Mark all as read coming in Slice 5"}
"message": "Mark all as read coming in Slice 5"
}
@router.delete("/{notification_id}") @router.delete("/{notification_id}")
@@ -115,9 +109,7 @@ def delete_notification(
- Delete single notification - Delete single notification
- Verify notification belongs to vendor - Verify notification belongs to vendor
""" """
return { return {"message": "Notification deletion coming in Slice 5"}
"message": "Notification deletion coming in Slice 5"
}
@router.get("/settings") @router.get("/settings")
@@ -138,7 +130,7 @@ def get_notification_settings(
"email_notifications": True, "email_notifications": True,
"in_app_notifications": True, "in_app_notifications": True,
"notification_types": {}, "notification_types": {},
"message": "Notification settings coming in Slice 5" "message": "Notification settings coming in Slice 5",
} }
@@ -157,9 +149,7 @@ def update_notification_settings(
- Update in-app notification settings - Update in-app notification settings
- Enable/disable specific notification types - Enable/disable specific notification types
""" """
return { return {"message": "Notification settings update coming in Slice 5"}
"message": "Notification settings update coming in Slice 5"
}
@router.get("/templates") @router.get("/templates")
@@ -176,10 +166,7 @@ def get_notification_templates(
- Include: order confirmation, shipping notification, etc. - Include: order confirmation, shipping notification, etc.
- Return template details - Return template details
""" """
return { return {"templates": [], "message": "Notification templates coming in Slice 5"}
"templates": [],
"message": "Notification templates coming in Slice 5"
}
@router.put("/templates/{template_id}") @router.put("/templates/{template_id}")
@@ -199,9 +186,7 @@ def update_notification_template(
- Validate template variables - Validate template variables
- Preview template - Preview template
""" """
return { return {"message": "Template update coming in Slice 5"}
"message": "Template update coming in Slice 5"
}
@router.post("/test") @router.post("/test")
@@ -219,6 +204,4 @@ def send_test_notification(
- Use specified template - Use specified template
- Send to current user's email - Send to current user's email
""" """
return { return {"message": "Test notification coming in Slice 5"}
"message": "Test notification coming in Slice 5"
}

View File

@@ -6,21 +6,17 @@ Vendor order management endpoints.
import logging import logging
from typing import Optional from typing import Optional
from fastapi import APIRouter, Depends, Query, Request, HTTPException from fastapi import APIRouter, Depends, HTTPException, Query, Request
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api from app.api.deps import get_current_vendor_api
from app.core.database import get_db from app.core.database import get_db
from middleware.vendor_context import require_vendor_context
from app.services.order_service import order_service from app.services.order_service import order_service
from models.schema.order import ( from middleware.vendor_context import require_vendor_context
OrderResponse,
OrderDetailResponse,
OrderListResponse,
OrderUpdate
)
from models.database.user import User from models.database.user import User
from models.database.vendor import Vendor, VendorUser from models.database.vendor import Vendor, VendorUser
from models.schema.order import (OrderDetailResponse, OrderListResponse,
OrderResponse, OrderUpdate)
router = APIRouter(prefix="/orders") router = APIRouter(prefix="/orders")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -51,14 +47,14 @@ def get_vendor_orders(
skip=skip, skip=skip,
limit=limit, limit=limit,
status=status, status=status,
customer_id=customer_id customer_id=customer_id,
) )
return OrderListResponse( return OrderListResponse(
orders=[OrderResponse.model_validate(o) for o in orders], orders=[OrderResponse.model_validate(o) for o in orders],
total=total, total=total,
skip=skip, skip=skip,
limit=limit limit=limit,
) )
@@ -74,11 +70,7 @@ def get_order_details(
Requires Authorization header (API endpoint). Requires Authorization header (API endpoint).
""" """
order = order_service.get_order( order = order_service.get_order(db=db, vendor_id=vendor.id, order_id=order_id)
db=db,
vendor_id=vendor.id,
order_id=order_id
)
return OrderDetailResponse.model_validate(order) return OrderDetailResponse.model_validate(order)
@@ -105,10 +97,7 @@ def update_order_status(
Requires Authorization header (API endpoint). Requires Authorization header (API endpoint).
""" """
order = order_service.update_order_status( order = order_service.update_order_status(
db=db, db=db, vendor_id=vendor.id, order_id=order_id, order_update=order_update
vendor_id=vendor.id,
order_id=order_id,
order_update=order_update
) )
logger.info( logger.info(

View File

@@ -5,6 +5,7 @@ Vendor payment configuration and processing endpoints.
""" """
import logging import logging
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -38,7 +39,7 @@ def get_payment_configuration(
"accepted_methods": [], "accepted_methods": [],
"currency": "EUR", "currency": "EUR",
"stripe_connected": False, "stripe_connected": False,
"message": "Payment configuration coming in Slice 5" "message": "Payment configuration coming in Slice 5",
} }
@@ -58,9 +59,7 @@ def update_payment_configuration(
- Update accepted payment methods - Update accepted payment methods
- Validate configuration before saving - Validate configuration before saving
""" """
return { return {"message": "Payment configuration update coming in Slice 5"}
"message": "Payment configuration update coming in Slice 5"
}
@router.post("/stripe/connect") @router.post("/stripe/connect")
@@ -79,9 +78,7 @@ def connect_stripe_account(
- Verify Stripe account is active - Verify Stripe account is active
- Enable payment processing - Enable payment processing
""" """
return { return {"message": "Stripe connection coming in Slice 5"}
"message": "Stripe connection coming in Slice 5"
}
@router.delete("/stripe/disconnect") @router.delete("/stripe/disconnect")
@@ -98,9 +95,7 @@ def disconnect_stripe_account(
- Disable payment processing - Disable payment processing
- Warn about pending payments - Warn about pending payments
""" """
return { return {"message": "Stripe disconnection coming in Slice 5"}
"message": "Stripe disconnection coming in Slice 5"
}
@router.get("/methods") @router.get("/methods")
@@ -116,10 +111,7 @@ def get_payment_methods(
- Return list of enabled payment methods - Return list of enabled payment methods
- Include: credit card, PayPal, bank transfer, etc. - Include: credit card, PayPal, bank transfer, etc.
""" """
return { return {"methods": [], "message": "Payment methods coming in Slice 5"}
"methods": [],
"message": "Payment methods coming in Slice 5"
}
@router.get("/transactions") @router.get("/transactions")
@@ -140,7 +132,7 @@ def get_payment_transactions(
return { return {
"transactions": [], "transactions": [],
"total": 0, "total": 0,
"message": "Payment transactions coming in Slice 5" "message": "Payment transactions coming in Slice 5",
} }
@@ -164,7 +156,7 @@ def get_payment_balance(
"pending_balance": 0.0, "pending_balance": 0.0,
"currency": "EUR", "currency": "EUR",
"next_payout_date": None, "next_payout_date": None,
"message": "Payment balance coming in Slice 5" "message": "Payment balance coming in Slice 5",
} }
@@ -185,6 +177,4 @@ def refund_payment(
- Update order status - Update order status
- Send refund notification to customer - Send refund notification to customer
""" """
return { return {"message": "Payment refund coming in Slice 5"}
"message": "Payment refund coming in Slice 5"
}

View File

@@ -11,17 +11,13 @@ from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api from app.api.deps import get_current_vendor_api
from app.core.database import get_db from app.core.database import get_db
from middleware.vendor_context import require_vendor_context
from app.services.product_service import product_service from app.services.product_service import product_service
from models.schema.product import ( from middleware.vendor_context import require_vendor_context
ProductCreate,
ProductUpdate,
ProductResponse,
ProductDetailResponse,
ProductListResponse
)
from models.database.user import User from models.database.user import User
from models.database.vendor import Vendor from models.database.vendor import Vendor
from models.schema.product import (ProductCreate, ProductDetailResponse,
ProductListResponse, ProductResponse,
ProductUpdate)
router = APIRouter(prefix="/products") router = APIRouter(prefix="/products")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -50,14 +46,14 @@ def get_vendor_products(
skip=skip, skip=skip,
limit=limit, limit=limit,
is_active=is_active, is_active=is_active,
is_featured=is_featured is_featured=is_featured,
) )
return ProductListResponse( return ProductListResponse(
products=[ProductResponse.model_validate(p) for p in products], products=[ProductResponse.model_validate(p) for p in products],
total=total, total=total,
skip=skip, skip=skip,
limit=limit limit=limit,
) )
@@ -70,9 +66,7 @@ def get_product_details(
): ):
"""Get detailed product information including inventory.""" """Get detailed product information including inventory."""
product = product_service.get_product( product = product_service.get_product(
db=db, db=db, vendor_id=vendor.id, product_id=product_id
vendor_id=vendor.id,
product_id=product_id
) )
return ProductDetailResponse.model_validate(product) return ProductDetailResponse.model_validate(product)
@@ -91,9 +85,7 @@ def add_product_to_catalog(
This publishes a MarketplaceProduct to the vendor's public catalog. This publishes a MarketplaceProduct to the vendor's public catalog.
""" """
product = product_service.create_product( product = product_service.create_product(
db=db, db=db, vendor_id=vendor.id, product_data=product_data
vendor_id=vendor.id,
product_data=product_data
) )
logger.info( logger.info(
@@ -114,10 +106,7 @@ def update_product(
): ):
"""Update product in vendor catalog.""" """Update product in vendor catalog."""
product = product_service.update_product( product = product_service.update_product(
db=db, db=db, vendor_id=vendor.id, product_id=product_id, product_update=product_data
vendor_id=vendor.id,
product_id=product_id,
product_update=product_data
) )
logger.info( logger.info(
@@ -136,11 +125,7 @@ def remove_product_from_catalog(
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""Remove product from vendor catalog.""" """Remove product from vendor catalog."""
product_service.delete_product( product_service.delete_product(db=db, vendor_id=vendor.id, product_id=product_id)
db=db,
vendor_id=vendor.id,
product_id=product_id
)
logger.info( logger.info(
f"Product {product_id} removed from catalog by user {current_user.username} " f"Product {product_id} removed from catalog by user {current_user.username} "
@@ -163,14 +148,11 @@ def publish_from_marketplace(
Shortcut endpoint for publishing directly from marketplace import. Shortcut endpoint for publishing directly from marketplace import.
""" """
product_data = ProductCreate( product_data = ProductCreate(
marketplace_product_id=marketplace_product_id, marketplace_product_id=marketplace_product_id, is_active=True
is_active=True
) )
product = product_service.create_product( product = product_service.create_product(
db=db, db=db, vendor_id=vendor.id, product_data=product_data
vendor_id=vendor.id,
product_data=product_data
) )
logger.info( logger.info(
@@ -198,10 +180,7 @@ def toggle_product_active(
status = "activated" if product.is_active else "deactivated" status = "activated" if product.is_active else "deactivated"
logger.info(f"Product {product_id} {status} for vendor {vendor.vendor_code}") logger.info(f"Product {product_id} {status} for vendor {vendor.vendor_code}")
return { return {"message": f"Product {status}", "is_active": product.is_active}
"message": f"Product {status}",
"is_active": product.is_active
}
@router.put("/{product_id}/toggle-featured") @router.put("/{product_id}/toggle-featured")
@@ -221,7 +200,4 @@ def toggle_product_featured(
status = "featured" if product.is_featured else "unfeatured" status = "featured" if product.is_featured else "unfeatured"
logger.info(f"Product {product_id} {status} for vendor {vendor.vendor_code}") logger.info(f"Product {product_id} {status} for vendor {vendor.vendor_code}")
return { return {"message": f"Product {status}", "is_featured": product.is_featured}
"message": f"Product {status}",
"is_featured": product.is_featured
}

View File

@@ -4,16 +4,17 @@ Vendor profile management endpoints.
""" """
import logging import logging
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api from app.api.deps import get_current_vendor_api
from app.core.database import get_db from app.core.database import get_db
from middleware.vendor_context import require_vendor_context
from app.services.vendor_service import vendor_service from app.services.vendor_service import vendor_service
from models.schema.vendor import VendorUpdate, VendorResponse from middleware.vendor_context import require_vendor_context
from models.database.user import User from models.database.user import User
from models.database.vendor import Vendor from models.database.vendor import Vendor
from models.schema.vendor import VendorResponse, VendorUpdate
router = APIRouter(prefix="/profile") router = APIRouter(prefix="/profile")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -4,13 +4,14 @@ Vendor settings and configuration endpoints.
""" """
import logging import logging
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api from app.api.deps import get_current_vendor_api
from app.core.database import get_db from app.core.database import get_db
from middleware.vendor_context import require_vendor_context
from app.services.vendor_service import vendor_service from app.services.vendor_service import vendor_service
from middleware.vendor_context import require_vendor_context
from models.database.user import User from models.database.user import User
from models.database.vendor import Vendor from models.database.vendor import Vendor

View File

@@ -12,35 +12,24 @@ Implements complete team management with:
import logging import logging
from typing import List from typing import List
from fastapi import APIRouter, Depends, Request from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.api.deps import (get_current_vendor_api, get_user_permissions,
require_vendor_owner, require_vendor_permission)
from app.core.database import get_db from app.core.database import get_db
from app.core.permissions import VendorPermissions from app.core.permissions import VendorPermissions
from app.api.deps import (
get_current_vendor_api,
require_vendor_owner,
require_vendor_permission,
get_user_permissions
)
from app.services.vendor_team_service import vendor_team_service from app.services.vendor_team_service import vendor_team_service
from models.database.user import User from models.database.user import User
from models.database.vendor import Vendor from models.database.vendor import Vendor
from models.schema.team import ( from models.schema.team import (BulkRemoveRequest, BulkRemoveResponse,
TeamMemberInvite, InvitationAccept, InvitationAcceptResponse,
TeamMemberUpdate, InvitationResponse, RoleListResponse,
TeamMemberResponse, RoleResponse, TeamMemberInvite,
TeamMemberListResponse, TeamMemberListResponse, TeamMemberResponse,
InvitationAccept, TeamMemberUpdate, TeamStatistics,
InvitationResponse, UserPermissionsResponse)
InvitationAcceptResponse,
RoleResponse,
RoleListResponse,
UserPermissionsResponse,
TeamStatistics,
BulkRemoveRequest,
BulkRemoveResponse,
)
router = APIRouter(prefix="/team") router = APIRouter(prefix="/team")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -50,14 +39,15 @@ logger = logging.getLogger(__name__)
# Team Member Routes # Team Member Routes
# ============================================================================ # ============================================================================
@router.get("/members", response_model=TeamMemberListResponse) @router.get("/members", response_model=TeamMemberListResponse)
def list_team_members( def list_team_members(
request: Request, request: Request,
include_inactive: bool = False, include_inactive: bool = False,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_vendor_permission( current_user: User = Depends(
VendorPermissions.TEAM_VIEW.value require_vendor_permission(VendorPermissions.TEAM_VIEW.value)
)) ),
): ):
""" """
Get all team members for current vendor. Get all team members for current vendor.
@@ -74,9 +64,7 @@ def list_team_members(
vendor = request.state.vendor vendor = request.state.vendor
members = vendor_team_service.get_team_members( members = vendor_team_service.get_team_members(
db=db, db=db, vendor=vendor, include_inactive=include_inactive
vendor=vendor,
include_inactive=include_inactive
) )
# Calculate statistics # Calculate statistics
@@ -90,10 +78,7 @@ def list_team_members(
) )
return TeamMemberListResponse( return TeamMemberListResponse(
members=members, members=members, total=total, active_count=active, pending_invitations=pending
total=total,
active_count=active,
pending_invitations=pending
) )
@@ -102,7 +87,7 @@ def invite_team_member(
invitation: TeamMemberInvite, invitation: TeamMemberInvite,
request: Request, request: Request,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_vendor_owner) # Owner only current_user: User = Depends(require_vendor_owner), # Owner only
): ):
""" """
Invite a new team member to the vendor. Invite a new team member to the vendor.
@@ -135,7 +120,7 @@ def invite_team_member(
vendor=vendor, vendor=vendor,
inviter=current_user, inviter=current_user,
email=invitation.email, email=invitation.email,
role_id=invitation.role_id role_id=invitation.role_id,
) )
elif invitation.role_name: elif invitation.role_name:
# Use role name with optional custom permissions # Use role name with optional custom permissions
@@ -145,7 +130,7 @@ def invite_team_member(
inviter=current_user, inviter=current_user,
email=invitation.email, email=invitation.email,
role_name=invitation.role_name, role_name=invitation.role_name,
custom_permissions=invitation.custom_permissions custom_permissions=invitation.custom_permissions,
) )
else: else:
# Default to Staff role # Default to Staff role
@@ -154,7 +139,7 @@ def invite_team_member(
vendor=vendor, vendor=vendor,
inviter=current_user, inviter=current_user,
email=invitation.email, email=invitation.email,
role_name="staff" role_name="staff",
) )
logger.info( logger.info(
@@ -166,15 +151,12 @@ def invite_team_member(
message="Invitation sent successfully", message="Invitation sent successfully",
email=result["email"], email=result["email"],
role=result["role"], role=result["role"],
invitation_sent=True invitation_sent=True,
) )
@router.post("/accept-invitation", response_model=InvitationAcceptResponse) @router.post("/accept-invitation", response_model=InvitationAcceptResponse)
def accept_invitation( def accept_invitation(acceptance: InvitationAccept, db: Session = Depends(get_db)):
acceptance: InvitationAccept,
db: Session = Depends(get_db)
):
""" """
Accept a team invitation and activate account. Accept a team invitation and activate account.
@@ -196,7 +178,7 @@ def accept_invitation(
invitation_token=acceptance.invitation_token, invitation_token=acceptance.invitation_token,
password=acceptance.password, password=acceptance.password,
first_name=acceptance.first_name, first_name=acceptance.first_name,
last_name=acceptance.last_name last_name=acceptance.last_name,
) )
logger.info( logger.info(
@@ -210,15 +192,15 @@ def accept_invitation(
"id": result["vendor"].id, "id": result["vendor"].id,
"vendor_code": result["vendor"].vendor_code, "vendor_code": result["vendor"].vendor_code,
"name": result["vendor"].name, "name": result["vendor"].name,
"subdomain": result["vendor"].subdomain "subdomain": result["vendor"].subdomain,
}, },
user={ user={
"id": result["user"].id, "id": result["user"].id,
"email": result["user"].email, "email": result["user"].email,
"username": result["user"].username, "username": result["user"].username,
"full_name": result["user"].full_name "full_name": result["user"].full_name,
}, },
role=result["role"] role=result["role"],
) )
@@ -227,9 +209,9 @@ def get_team_member(
user_id: int, user_id: int,
request: Request, request: Request,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_vendor_permission( current_user: User = Depends(
VendorPermissions.TEAM_VIEW.value require_vendor_permission(VendorPermissions.TEAM_VIEW.value)
)) ),
): ):
""" """
Get details of a specific team member. Get details of a specific team member.
@@ -239,14 +221,13 @@ def get_team_member(
vendor = request.state.vendor vendor = request.state.vendor
members = vendor_team_service.get_team_members( members = vendor_team_service.get_team_members(
db=db, db=db, vendor=vendor, include_inactive=True
vendor=vendor,
include_inactive=True
) )
member = next((m for m in members if m["id"] == user_id), None) member = next((m for m in members if m["id"] == user_id), None)
if not member: if not member:
from app.exceptions import UserNotFoundException from app.exceptions import UserNotFoundException
raise UserNotFoundException(str(user_id)) raise UserNotFoundException(str(user_id))
return TeamMemberResponse(**member) return TeamMemberResponse(**member)
@@ -258,7 +239,7 @@ def update_team_member(
update_data: TeamMemberUpdate, update_data: TeamMemberUpdate,
request: Request, request: Request,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_vendor_owner) # Owner only current_user: User = Depends(require_vendor_owner), # Owner only
): ):
""" """
Update a team member's role or status. Update a team member's role or status.
@@ -280,7 +261,7 @@ def update_team_member(
vendor=vendor, vendor=vendor,
user_id=user_id, user_id=user_id,
new_role_id=update_data.role_id, new_role_id=update_data.role_id,
is_active=update_data.is_active is_active=update_data.is_active,
) )
logger.info( logger.info(
@@ -300,7 +281,7 @@ def remove_team_member(
user_id: int, user_id: int,
request: Request, request: Request,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_vendor_owner) # Owner only current_user: User = Depends(require_vendor_owner), # Owner only
): ):
""" """
Remove a team member from the vendor. Remove a team member from the vendor.
@@ -316,21 +297,14 @@ def remove_team_member(
""" """
vendor = request.state.vendor vendor = request.state.vendor
vendor_team_service.remove_team_member( vendor_team_service.remove_team_member(db=db, vendor=vendor, user_id=user_id)
db=db,
vendor=vendor,
user_id=user_id
)
logger.info( logger.info(
f"Team member removed: {user_id} from {vendor.vendor_code} " f"Team member removed: {user_id} from {vendor.vendor_code} "
f"by {current_user.username}" f"by {current_user.username}"
) )
return { return {"message": "Team member removed successfully", "user_id": user_id}
"message": "Team member removed successfully",
"user_id": user_id
}
@router.post("/members/bulk-remove", response_model=BulkRemoveResponse) @router.post("/members/bulk-remove", response_model=BulkRemoveResponse)
@@ -338,7 +312,7 @@ def bulk_remove_team_members(
bulk_remove: BulkRemoveRequest, bulk_remove: BulkRemoveRequest,
request: Request, request: Request,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_vendor_owner) current_user: User = Depends(require_vendor_owner),
): ):
""" """
Remove multiple team members at once. Remove multiple team members at once.
@@ -354,17 +328,12 @@ def bulk_remove_team_members(
for user_id in bulk_remove.user_ids: for user_id in bulk_remove.user_ids:
try: try:
vendor_team_service.remove_team_member( vendor_team_service.remove_team_member(
db=db, db=db, vendor=vendor, user_id=user_id
vendor=vendor,
user_id=user_id
) )
success_count += 1 success_count += 1
except Exception as e: except Exception as e:
failed_count += 1 failed_count += 1
errors.append({ errors.append({"user_id": user_id, "error": str(e)})
"user_id": user_id,
"error": str(e)
})
logger.info( logger.info(
f"Bulk remove completed: {success_count} removed, {failed_count} failed " f"Bulk remove completed: {success_count} removed, {failed_count} failed "
@@ -372,9 +341,7 @@ def bulk_remove_team_members(
) )
return BulkRemoveResponse( return BulkRemoveResponse(
success_count=success_count, success_count=success_count, failed_count=failed_count, errors=errors
failed_count=failed_count,
errors=errors
) )
@@ -382,13 +349,14 @@ def bulk_remove_team_members(
# Role Management Routes # Role Management Routes
# ============================================================================ # ============================================================================
@router.get("/roles", response_model=RoleListResponse) @router.get("/roles", response_model=RoleListResponse)
def list_roles( def list_roles(
request: Request, request: Request,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_vendor_permission( current_user: User = Depends(
VendorPermissions.TEAM_VIEW.value require_vendor_permission(VendorPermissions.TEAM_VIEW.value)
)) ),
): ):
""" """
Get all available roles for the vendor. Get all available roles for the vendor.
@@ -403,21 +371,19 @@ def list_roles(
roles = vendor_team_service.get_vendor_roles(db=db, vendor_id=vendor.id) roles = vendor_team_service.get_vendor_roles(db=db, vendor_id=vendor.id)
return RoleListResponse( return RoleListResponse(roles=roles, total=len(roles))
roles=roles,
total=len(roles)
)
# ============================================================================ # ============================================================================
# Permission Routes # Permission Routes
# ============================================================================ # ============================================================================
@router.get("/me/permissions", response_model=UserPermissionsResponse) @router.get("/me/permissions", response_model=UserPermissionsResponse)
def get_my_permissions( def get_my_permissions(
request: Request, request: Request,
permissions: List[str] = Depends(get_user_permissions), permissions: List[str] = Depends(get_user_permissions),
current_user: User = Depends(get_current_vendor_api) current_user: User = Depends(get_current_vendor_api),
): ):
""" """
Get current user's permissions in this vendor. Get current user's permissions in this vendor.
@@ -443,7 +409,7 @@ def get_my_permissions(
permissions=permissions, permissions=permissions,
permission_count=len(permissions), permission_count=len(permissions),
is_owner=is_owner, is_owner=is_owner,
role_name=role_name role_name=role_name,
) )
@@ -451,13 +417,14 @@ def get_my_permissions(
# Statistics Routes # Statistics Routes
# ============================================================================ # ============================================================================
@router.get("/statistics", response_model=TeamStatistics) @router.get("/statistics", response_model=TeamStatistics)
def get_team_statistics( def get_team_statistics(
request: Request, request: Request,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_vendor_permission( current_user: User = Depends(
VendorPermissions.TEAM_VIEW.value require_vendor_permission(VendorPermissions.TEAM_VIEW.value)
)) ),
): ):
""" """
Get team statistics for the vendor. Get team statistics for the vendor.
@@ -474,9 +441,7 @@ def get_team_statistics(
vendor = request.state.vendor vendor = request.state.vendor
members = vendor_team_service.get_team_members( members = vendor_team_service.get_team_members(
db=db, db=db, vendor=vendor, include_inactive=True
vendor=vendor,
include_inactive=True
) )
# Calculate statistics # Calculate statistics
@@ -500,5 +465,5 @@ def get_team_statistics(
pending_invitations=pending, pending_invitations=pending,
owners=owners, owners=owners,
team_members=team_members, team_members=team_members,
roles_breakdown=roles_breakdown roles_breakdown=roles_breakdown,
) )

View File

@@ -14,6 +14,7 @@ This module focuses purely on configuration storage and validation.
""" """
from typing import List, Optional from typing import List, Optional
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
@@ -137,13 +138,9 @@ settings = Settings()
# ENVIRONMENT UTILITIES - Module-level functions # ENVIRONMENT UTILITIES - Module-level functions
# ============================================================================= # =============================================================================
# Import environment detection utilities # Import environment detection utilities
from app.core.environment import ( from app.core.environment import (get_environment, is_development,
get_environment, is_production, is_staging,
is_development, should_use_secure_cookies)
is_production,
is_staging,
should_use_secure_cookies
)
def get_current_environment() -> str: def get_current_environment() -> str:
@@ -190,6 +187,7 @@ def is_staging_environment() -> bool:
# VALIDATION FUNCTIONS # VALIDATION FUNCTIONS
# ============================================================================= # =============================================================================
def validate_production_settings() -> List[str]: def validate_production_settings() -> List[str]:
""" """
Validate settings for production environment. Validate settings for production environment.
@@ -243,22 +241,19 @@ def print_environment_info():
# ============================================================================= # =============================================================================
__all__ = [ __all__ = [
# Settings singleton # Settings singleton
'settings', "settings",
# Environment detection (re-exported from app.core.environment) # Environment detection (re-exported from app.core.environment)
'get_environment', "get_environment",
'is_development', "is_development",
'is_production', "is_production",
'is_staging', "is_staging",
'should_use_secure_cookies', "should_use_secure_cookies",
# Convenience functions # Convenience functions
'get_current_environment', "get_current_environment",
'is_production_environment', "is_production_environment",
'is_development_environment', "is_development_environment",
'is_staging_environment', "is_staging_environment",
# Validation # Validation
'validate_production_settings', "validate_production_settings",
'print_environment_info', "print_environment_info",
] ]

View File

@@ -53,7 +53,10 @@ def get_environment() -> EnvironmentType:
# Check common development indicators # Check common development indicators
hostname = os.getenv("HOSTNAME", "").lower() hostname = os.getenv("HOSTNAME", "").lower()
if any(dev_indicator in hostname for dev_indicator in ["local", "dev", "laptop", "desktop"]): if any(
dev_indicator in hostname
for dev_indicator in ["local", "dev", "laptop", "desktop"]
):
return "development" return "development"
# Check for staging indicators # Check for staging indicators

View File

@@ -15,11 +15,13 @@ from fastapi import FastAPI
from sqlalchemy import text from sqlalchemy import text
from middleware.auth import AuthManager from middleware.auth import AuthManager
# Remove this import if not needed: from models.database.base import Base
from .database import SessionLocal, engine from .database import SessionLocal, engine
from .logging import setup_logging from .logging import setup_logging
# Remove this import if not needed: from models.database.base import Base
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
auth_manager = AuthManager() auth_manager = AuthManager()
@@ -46,7 +48,9 @@ def check_database_ready():
try: try:
with engine.connect() as conn: with engine.connect() as conn:
# Try to query a table that should exist # Try to query a table that should exist
result = conn.execute(text("SELECT name FROM sqlite_master WHERE type='table' LIMIT 1")) result = conn.execute(
text("SELECT name FROM sqlite_master WHERE type='table' LIMIT 1")
)
tables = result.fetchall() tables = result.fetchall()
return len(tables) > 0 return len(tables) > 0
except Exception: except Exception:
@@ -93,6 +97,7 @@ def verify_startup_requirements():
logger.info("[OK] Startup verification passed") logger.info("[OK] Startup verification passed")
return True return True
# You can call this in your main.py if desired: # You can call this in your main.py if desired:
# if not verify_startup_requirements(): # if not verify_startup_requirements():
# raise RuntimeError("Application startup requirements not met") # raise RuntimeError("Application startup requirements not met")

View File

@@ -17,6 +17,7 @@ class VendorPermissions(str, Enum):
Naming convention: RESOURCE_ACTION Naming convention: RESOURCE_ACTION
""" """
# Dashboard # Dashboard
DASHBOARD_VIEW = "dashboard.view" DASHBOARD_VIEW = "dashboard.view"
@@ -166,17 +167,23 @@ class PermissionChecker:
return required_permission in permissions return required_permission in permissions
@staticmethod @staticmethod
def has_any_permission(permissions: List[str], required_permissions: List[str]) -> bool: def has_any_permission(
permissions: List[str], required_permissions: List[str]
) -> bool:
"""Check if a permission list contains ANY of the required permissions.""" """Check if a permission list contains ANY of the required permissions."""
return any(perm in permissions for perm in required_permissions) return any(perm in permissions for perm in required_permissions)
@staticmethod @staticmethod
def has_all_permissions(permissions: List[str], required_permissions: List[str]) -> bool: def has_all_permissions(
permissions: List[str], required_permissions: List[str]
) -> bool:
"""Check if a permission list contains ALL of the required permissions.""" """Check if a permission list contains ALL of the required permissions."""
return all(perm in permissions for perm in required_permissions) return all(perm in permissions for perm in required_permissions)
@staticmethod @staticmethod
def get_missing_permissions(permissions: List[str], required_permissions: List[str]) -> List[str]: def get_missing_permissions(
permissions: List[str], required_permissions: List[str]
) -> List[str]:
"""Get list of missing permissions.""" """Get list of missing permissions."""
return [perm for perm in required_permissions if perm not in permissions] return [perm for perm in required_permissions if perm not in permissions]

View File

@@ -16,19 +16,11 @@ THEME_PRESETS = {
"accent": "#ec4899", # Pink "accent": "#ec4899", # Pink
"background": "#ffffff", # White "background": "#ffffff", # White
"text": "#1f2937", # Gray-800 "text": "#1f2937", # Gray-800
"border": "#e5e7eb" # Gray-200 "border": "#e5e7eb", # Gray-200
}, },
"fonts": { "fonts": {"heading": "Inter, sans-serif", "body": "Inter, sans-serif"},
"heading": "Inter, sans-serif", "layout": {"style": "grid", "header": "fixed", "product_card": "modern"},
"body": "Inter, sans-serif"
}, },
"layout": {
"style": "grid",
"header": "fixed",
"product_card": "modern"
}
},
"modern": { "modern": {
"colors": { "colors": {
"primary": "#6366f1", # Indigo - Modern tech look "primary": "#6366f1", # Indigo - Modern tech look
@@ -36,19 +28,11 @@ THEME_PRESETS = {
"accent": "#ec4899", # Pink "accent": "#ec4899", # Pink
"background": "#ffffff", # White "background": "#ffffff", # White
"text": "#1f2937", # Gray-800 "text": "#1f2937", # Gray-800
"border": "#e5e7eb" # Gray-200 "border": "#e5e7eb", # Gray-200
}, },
"fonts": { "fonts": {"heading": "Inter, sans-serif", "body": "Inter, sans-serif"},
"heading": "Inter, sans-serif", "layout": {"style": "grid", "header": "fixed", "product_card": "modern"},
"body": "Inter, sans-serif"
}, },
"layout": {
"style": "grid",
"header": "fixed",
"product_card": "modern"
}
},
"classic": { "classic": {
"colors": { "colors": {
"primary": "#1e40af", # Dark blue - Traditional "primary": "#1e40af", # Dark blue - Traditional
@@ -56,19 +40,11 @@ THEME_PRESETS = {
"accent": "#dc2626", # Red "accent": "#dc2626", # Red
"background": "#ffffff", # White "background": "#ffffff", # White
"text": "#1f2937", # Gray-800 "text": "#1f2937", # Gray-800
"border": "#d1d5db" # Gray-300 "border": "#d1d5db", # Gray-300
}, },
"fonts": { "fonts": {"heading": "Georgia, serif", "body": "Arial, sans-serif"},
"heading": "Georgia, serif", "layout": {"style": "list", "header": "static", "product_card": "classic"},
"body": "Arial, sans-serif"
}, },
"layout": {
"style": "list",
"header": "static",
"product_card": "classic"
}
},
"minimal": { "minimal": {
"colors": { "colors": {
"primary": "#000000", # Black - Ultra minimal "primary": "#000000", # Black - Ultra minimal
@@ -76,19 +52,11 @@ THEME_PRESETS = {
"accent": "#666666", # Medium gray "accent": "#666666", # Medium gray
"background": "#ffffff", # White "background": "#ffffff", # White
"text": "#000000", # Black "text": "#000000", # Black
"border": "#e5e7eb" # Light gray "border": "#e5e7eb", # Light gray
}, },
"fonts": { "fonts": {"heading": "Helvetica, sans-serif", "body": "Helvetica, sans-serif"},
"heading": "Helvetica, sans-serif", "layout": {"style": "grid", "header": "transparent", "product_card": "minimal"},
"body": "Helvetica, sans-serif"
}, },
"layout": {
"style": "grid",
"header": "transparent",
"product_card": "minimal"
}
},
"vibrant": { "vibrant": {
"colors": { "colors": {
"primary": "#f59e0b", # Orange - Bold & energetic "primary": "#f59e0b", # Orange - Bold & energetic
@@ -96,19 +64,11 @@ THEME_PRESETS = {
"accent": "#8b5cf6", # Purple "accent": "#8b5cf6", # Purple
"background": "#ffffff", # White "background": "#ffffff", # White
"text": "#1f2937", # Gray-800 "text": "#1f2937", # Gray-800
"border": "#fbbf24" # Yellow "border": "#fbbf24", # Yellow
}, },
"fonts": { "fonts": {"heading": "Poppins, sans-serif", "body": "Open Sans, sans-serif"},
"heading": "Poppins, sans-serif", "layout": {"style": "masonry", "header": "fixed", "product_card": "modern"},
"body": "Open Sans, sans-serif"
}, },
"layout": {
"style": "masonry",
"header": "fixed",
"product_card": "modern"
}
},
"elegant": { "elegant": {
"colors": { "colors": {
"primary": "#6b7280", # Gray - Sophisticated "primary": "#6b7280", # Gray - Sophisticated
@@ -116,19 +76,11 @@ THEME_PRESETS = {
"accent": "#d97706", # Amber "accent": "#d97706", # Amber
"background": "#ffffff", # White "background": "#ffffff", # White
"text": "#1f2937", # Gray-800 "text": "#1f2937", # Gray-800
"border": "#e5e7eb" # Gray-200 "border": "#e5e7eb", # Gray-200
}, },
"fonts": { "fonts": {"heading": "Playfair Display, serif", "body": "Lato, sans-serif"},
"heading": "Playfair Display, serif", "layout": {"style": "grid", "header": "fixed", "product_card": "classic"},
"body": "Lato, sans-serif"
}, },
"layout": {
"style": "grid",
"header": "fixed",
"product_card": "classic"
}
},
"nature": { "nature": {
"colors": { "colors": {
"primary": "#059669", # Green - Natural & eco "primary": "#059669", # Green - Natural & eco
@@ -136,18 +88,11 @@ THEME_PRESETS = {
"accent": "#f59e0b", # Amber "accent": "#f59e0b", # Amber
"background": "#ffffff", # White "background": "#ffffff", # White
"text": "#1f2937", # Gray-800 "text": "#1f2937", # Gray-800
"border": "#d1fae5" # Light green "border": "#d1fae5", # Light green
}, },
"fonts": { "fonts": {"heading": "Montserrat, sans-serif", "body": "Open Sans, sans-serif"},
"heading": "Montserrat, sans-serif", "layout": {"style": "grid", "header": "fixed", "product_card": "modern"},
"body": "Open Sans, sans-serif"
}, },
"layout": {
"style": "grid",
"header": "fixed",
"product_card": "modern"
}
}
} }
@@ -243,7 +188,7 @@ def get_preset_preview(preset_name: str) -> dict:
"minimal": "Ultra-clean black and white aesthetic", "minimal": "Ultra-clean black and white aesthetic",
"vibrant": "Bold and energetic with bright accent colors", "vibrant": "Bold and energetic with bright accent colors",
"elegant": "Sophisticated gray tones with refined typography", "elegant": "Sophisticated gray tones with refined typography",
"nature": "Fresh and eco-friendly green color palette" "nature": "Fresh and eco-friendly green color palette",
} }
return { return {
@@ -259,10 +204,7 @@ def get_preset_preview(preset_name: str) -> dict:
def create_custom_preset( def create_custom_preset(
colors: dict, colors: dict, fonts: dict, layout: dict, name: str = "custom"
fonts: dict,
layout: dict,
name: str = "custom"
) -> dict: ) -> dict:
""" """
Create a custom preset from provided settings. Create a custom preset from provided settings.
@@ -304,8 +246,4 @@ def create_custom_preset(
if "product_card" not in layout: if "product_card" not in layout:
layout["product_card"] = "modern" layout["product_card"] = "modern"
return { return {"colors": colors, "fonts": fonts, "layout": layout}
"colors": colors,
"fonts": fonts,
"layout": layout
}

View File

@@ -6,179 +6,109 @@ This module provides frontend-friendly exceptions with consistent error codes,
messages, and HTTP status mappings. messages, and HTTP status mappings.
""" """
# Base exceptions
from .base import (
WizamartException,
ValidationException,
AuthenticationException,
AuthorizationException,
ResourceNotFoundException,
ConflictException,
BusinessLogicException,
ExternalServiceException,
RateLimitException,
ServiceUnavailableException,
)
# Authentication exceptions
from .auth import (
InvalidCredentialsException,
TokenExpiredException,
InvalidTokenException,
InsufficientPermissionsException,
UserNotActiveException,
AdminRequiredException,
UserAlreadyExistsException
)
# Admin exceptions # Admin exceptions
from .admin import ( from .admin import (AdminOperationException, BulkOperationException,
UserNotFoundException, CannotModifyAdminException, CannotModifySelfException,
UserStatusChangeException, ConfirmationRequiredException, InvalidAdminActionException,
VendorVerificationException, UserNotFoundException, UserStatusChangeException,
AdminOperationException, VendorVerificationException)
CannotModifyAdminException, # Authentication exceptions
CannotModifySelfException, from .auth import (AdminRequiredException, InsufficientPermissionsException,
InvalidAdminActionException, InvalidCredentialsException, InvalidTokenException,
BulkOperationException, TokenExpiredException, UserAlreadyExistsException,
ConfirmationRequiredException, UserNotActiveException)
) # Base exceptions
from .base import (AuthenticationException, AuthorizationException,
BusinessLogicException, ConflictException,
ExternalServiceException, RateLimitException,
ResourceNotFoundException, ServiceUnavailableException,
ValidationException, WizamartException)
# Cart exceptions
from .cart import (CartItemNotFoundException, CartValidationException,
EmptyCartException, InsufficientInventoryForCartException,
InvalidCartQuantityException,
ProductNotAvailableForCartException)
# Customer exceptions
from .customer import (CustomerAlreadyExistsException,
CustomerAuthorizationException,
CustomerNotActiveException, CustomerNotFoundException,
CustomerValidationException,
DuplicateCustomerEmailException,
InvalidCustomerCredentialsException)
# Inventory exceptions
from .inventory import (InsufficientInventoryException,
InvalidInventoryOperationException,
InvalidQuantityException, InventoryNotFoundException,
InventoryValidationException,
LocationNotFoundException, NegativeInventoryException)
# Marketplace import job exceptions # Marketplace import job exceptions
from .marketplace_import_job import ( from .marketplace_import_job import (ImportJobAlreadyProcessingException,
MarketplaceImportException,
ImportJobNotFoundException,
ImportJobNotOwnedException,
InvalidImportDataException,
ImportJobCannotBeCancelledException, ImportJobCannotBeCancelledException,
ImportJobCannotBeDeletedException, ImportJobCannotBeDeletedException,
ImportJobNotFoundException,
ImportJobNotOwnedException,
ImportRateLimitException,
InvalidImportDataException,
InvalidMarketplaceException,
MarketplaceConnectionException, MarketplaceConnectionException,
MarketplaceDataParsingException, MarketplaceDataParsingException,
ImportRateLimitException, MarketplaceImportException)
InvalidMarketplaceException,
ImportJobAlreadyProcessingException,
)
# Marketplace product exceptions # Marketplace product exceptions
from .marketplace_product import ( from .marketplace_product import (InvalidGTINException,
MarketplaceProductNotFoundException,
MarketplaceProductAlreadyExistsException,
InvalidMarketplaceProductDataException, InvalidMarketplaceProductDataException,
MarketplaceProductValidationException, MarketplaceProductAlreadyExistsException,
InvalidGTINException,
MarketplaceProductCSVImportException, MarketplaceProductCSVImportException,
) MarketplaceProductNotFoundException,
MarketplaceProductValidationException)
# Inventory exceptions # Order exceptions
from .inventory import ( from .order import (InvalidOrderStatusException, OrderAlreadyExistsException,
InventoryNotFoundException, OrderCannotBeCancelledException, OrderNotFoundException,
InsufficientInventoryException, OrderValidationException)
InvalidInventoryOperationException, # Product exceptions
InventoryValidationException, from .product import (CannotDeleteProductWithInventoryException,
NegativeInventoryException, CannotDeleteProductWithOrdersException,
InvalidQuantityException, InvalidProductDataException,
LocationNotFoundException ProductAlreadyExistsException, ProductNotActiveException,
) ProductNotFoundException, ProductNotInCatalogException,
ProductValidationException)
# Team exceptions
from .team import (CannotModifyOwnRoleException, CannotRemoveOwnerException,
InsufficientTeamPermissionsException,
InvalidInvitationDataException,
InvalidInvitationTokenException, InvalidRoleException,
MaxTeamMembersReachedException, RoleNotFoundException,
TeamInvitationAlreadyAcceptedException,
TeamInvitationExpiredException,
TeamInvitationNotFoundException,
TeamMemberAlreadyExistsException,
TeamMemberNotFoundException, TeamValidationException,
UnauthorizedTeamActionException)
# Vendor exceptions # Vendor exceptions
from .vendor import ( from .vendor import (InvalidVendorDataException, MaxVendorsReachedException,
VendorNotFoundException,
VendorAlreadyExistsException,
VendorNotActiveException,
VendorNotVerifiedException,
UnauthorizedVendorAccessException, UnauthorizedVendorAccessException,
InvalidVendorDataException, VendorAlreadyExistsException, VendorNotActiveException,
MaxVendorsReachedException, VendorNotFoundException, VendorNotVerifiedException,
VendorValidationException, VendorValidationException)
)
# Vendor domain exceptions # Vendor domain exceptions
from .vendor_domain import ( from .vendor_domain import (DNSVerificationException,
VendorDomainNotFoundException, DomainAlreadyVerifiedException,
VendorDomainAlreadyExistsException,
InvalidDomainFormatException,
ReservedDomainException,
DomainNotVerifiedException, DomainNotVerifiedException,
DomainVerificationFailedException, DomainVerificationFailedException,
DomainAlreadyVerifiedException, InvalidDomainFormatException,
MultiplePrimaryDomainsException,
DNSVerificationException,
MaxDomainsReachedException, MaxDomainsReachedException,
MultiplePrimaryDomainsException,
ReservedDomainException,
UnauthorizedDomainAccessException, UnauthorizedDomainAccessException,
) VendorDomainAlreadyExistsException,
VendorDomainNotFoundException)
# Vendor theme exceptions # Vendor theme exceptions
from .vendor_theme import ( from .vendor_theme import (InvalidColorFormatException,
VendorThemeNotFoundException, InvalidFontFamilyException,
InvalidThemeDataException, InvalidThemeDataException, ThemeOperationException,
ThemePresetAlreadyAppliedException,
ThemePresetNotFoundException, ThemePresetNotFoundException,
ThemeValidationException, ThemeValidationException,
ThemePresetAlreadyAppliedException, VendorThemeNotFoundException)
InvalidColorFormatException,
InvalidFontFamilyException,
ThemeOperationException,
)
# Customer exceptions
from .customer import (
CustomerNotFoundException,
CustomerAlreadyExistsException,
DuplicateCustomerEmailException,
CustomerNotActiveException,
InvalidCustomerCredentialsException,
CustomerValidationException,
CustomerAuthorizationException,
)
# Team exceptions
from .team import (
TeamMemberNotFoundException,
TeamMemberAlreadyExistsException,
TeamInvitationNotFoundException,
TeamInvitationExpiredException,
TeamInvitationAlreadyAcceptedException,
UnauthorizedTeamActionException,
CannotRemoveOwnerException,
CannotModifyOwnRoleException,
RoleNotFoundException,
InvalidRoleException,
InsufficientTeamPermissionsException,
MaxTeamMembersReachedException,
TeamValidationException,
InvalidInvitationDataException,
InvalidInvitationTokenException,
)
# Product exceptions
from .product import (
ProductNotFoundException,
ProductAlreadyExistsException,
ProductNotInCatalogException,
ProductNotActiveException,
InvalidProductDataException,
ProductValidationException,
CannotDeleteProductWithInventoryException,
CannotDeleteProductWithOrdersException,
)
# Order exceptions
from .order import (
OrderNotFoundException,
OrderAlreadyExistsException,
OrderValidationException,
InvalidOrderStatusException,
OrderCannotBeCancelledException,
)
# Cart exceptions
from .cart import (
CartItemNotFoundException,
EmptyCartException,
CartValidationException,
InsufficientInventoryForCartException,
InvalidCartQuantityException,
ProductNotAvailableForCartException,
)
__all__ = [ __all__ = [
# Base exceptions # Base exceptions
@@ -192,7 +122,6 @@ __all__ = [
"ExternalServiceException", "ExternalServiceException",
"RateLimitException", "RateLimitException",
"ServiceUnavailableException", "ServiceUnavailableException",
# Auth exceptions # Auth exceptions
"InvalidCredentialsException", "InvalidCredentialsException",
"TokenExpiredException", "TokenExpiredException",
@@ -201,7 +130,6 @@ __all__ = [
"UserNotActiveException", "UserNotActiveException",
"AdminRequiredException", "AdminRequiredException",
"UserAlreadyExistsException", "UserAlreadyExistsException",
# Customer exceptions # Customer exceptions
"CustomerNotFoundException", "CustomerNotFoundException",
"CustomerAlreadyExistsException", "CustomerAlreadyExistsException",
@@ -210,7 +138,6 @@ __all__ = [
"InvalidCustomerCredentialsException", "InvalidCustomerCredentialsException",
"CustomerValidationException", "CustomerValidationException",
"CustomerAuthorizationException", "CustomerAuthorizationException",
# Team exceptions # Team exceptions
"TeamMemberNotFoundException", "TeamMemberNotFoundException",
"TeamMemberAlreadyExistsException", "TeamMemberAlreadyExistsException",
@@ -227,7 +154,6 @@ __all__ = [
"TeamValidationException", "TeamValidationException",
"InvalidInvitationDataException", "InvalidInvitationDataException",
"InvalidInvitationTokenException", "InvalidInvitationTokenException",
# Inventory exceptions # Inventory exceptions
"InventoryNotFoundException", "InventoryNotFoundException",
"InsufficientInventoryException", "InsufficientInventoryException",
@@ -236,7 +162,6 @@ __all__ = [
"NegativeInventoryException", "NegativeInventoryException",
"InvalidQuantityException", "InvalidQuantityException",
"LocationNotFoundException", "LocationNotFoundException",
# Vendor exceptions # Vendor exceptions
"VendorNotFoundException", "VendorNotFoundException",
"VendorAlreadyExistsException", "VendorAlreadyExistsException",
@@ -246,7 +171,6 @@ __all__ = [
"InvalidVendorDataException", "InvalidVendorDataException",
"MaxVendorsReachedException", "MaxVendorsReachedException",
"VendorValidationException", "VendorValidationException",
# Vendor Domain # Vendor Domain
"VendorDomainNotFoundException", "VendorDomainNotFoundException",
"VendorDomainAlreadyExistsException", "VendorDomainAlreadyExistsException",
@@ -259,7 +183,6 @@ __all__ = [
"DNSVerificationException", "DNSVerificationException",
"MaxDomainsReachedException", "MaxDomainsReachedException",
"UnauthorizedDomainAccessException", "UnauthorizedDomainAccessException",
# Vendor Theme # Vendor Theme
"VendorThemeNotFoundException", "VendorThemeNotFoundException",
"InvalidThemeDataException", "InvalidThemeDataException",
@@ -269,7 +192,6 @@ __all__ = [
"InvalidColorFormatException", "InvalidColorFormatException",
"InvalidFontFamilyException", "InvalidFontFamilyException",
"ThemeOperationException", "ThemeOperationException",
# Product exceptions # Product exceptions
"ProductNotFoundException", "ProductNotFoundException",
"ProductAlreadyExistsException", "ProductAlreadyExistsException",
@@ -279,14 +201,12 @@ __all__ = [
"ProductValidationException", "ProductValidationException",
"CannotDeleteProductWithInventoryException", "CannotDeleteProductWithInventoryException",
"CannotDeleteProductWithOrdersException", "CannotDeleteProductWithOrdersException",
# Order exceptions # Order exceptions
"OrderNotFoundException", "OrderNotFoundException",
"OrderAlreadyExistsException", "OrderAlreadyExistsException",
"OrderValidationException", "OrderValidationException",
"InvalidOrderStatusException", "InvalidOrderStatusException",
"OrderCannotBeCancelledException", "OrderCannotBeCancelledException",
# Cart exceptions # Cart exceptions
"CartItemNotFoundException", "CartItemNotFoundException",
"EmptyCartException", "EmptyCartException",
@@ -294,7 +214,6 @@ __all__ = [
"InsufficientInventoryForCartException", "InsufficientInventoryForCartException",
"InvalidCartQuantityException", "InvalidCartQuantityException",
"ProductNotAvailableForCartException", "ProductNotAvailableForCartException",
# MarketplaceProduct exceptions # MarketplaceProduct exceptions
"MarketplaceProductNotFoundException", "MarketplaceProductNotFoundException",
"MarketplaceProductAlreadyExistsException", "MarketplaceProductAlreadyExistsException",
@@ -302,7 +221,6 @@ __all__ = [
"MarketplaceProductValidationException", "MarketplaceProductValidationException",
"InvalidGTINException", "InvalidGTINException",
"MarketplaceProductCSVImportException", "MarketplaceProductCSVImportException",
# Marketplace import exceptions # Marketplace import exceptions
"MarketplaceImportException", "MarketplaceImportException",
"ImportJobNotFoundException", "ImportJobNotFoundException",
@@ -315,7 +233,6 @@ __all__ = [
"ImportRateLimitException", "ImportRateLimitException",
"InvalidMarketplaceException", "InvalidMarketplaceException",
"ImportJobAlreadyProcessingException", "ImportJobAlreadyProcessingException",
# Admin exceptions # Admin exceptions
"UserNotFoundException", "UserNotFoundException",
"UserStatusChangeException", "UserStatusChangeException",

View File

@@ -4,12 +4,9 @@ Admin operations specific exceptions.
""" """
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from .base import (
ResourceNotFoundException, from .base import (AuthorizationException, BusinessLogicException,
BusinessLogicException, ResourceNotFoundException, ValidationException)
AuthorizationException,
ValidationException
)
class UserNotFoundException(ResourceNotFoundException): class UserNotFoundException(ResourceNotFoundException):
@@ -198,7 +195,7 @@ class ConfirmationRequiredException(BusinessLogicException):
self, self,
operation: str, operation: str,
message: Optional[str] = None, message: Optional[str] = None,
confirmation_param: str = "confirm" confirmation_param: str = "confirm",
): ):
if not message: if not message:
message = f"Operation '{operation}' requires confirmation parameter: {confirmation_param}=true" message = f"Operation '{operation}' requires confirmation parameter: {confirmation_param}=true"

View File

@@ -4,7 +4,9 @@ Authentication and authorization specific exceptions.
""" """
from typing import Optional from typing import Optional
from .base import AuthenticationException, AuthorizationException, ConflictException
from .base import (AuthenticationException, AuthorizationException,
ConflictException)
class InvalidCredentialsException(AuthenticationException): class InvalidCredentialsException(AuthenticationException):

View File

@@ -39,8 +39,6 @@ class WizamartException(Exception):
return result return result
class ValidationException(WizamartException): class ValidationException(WizamartException):
"""Raised when request validation fails.""" """Raised when request validation fails."""
@@ -62,8 +60,6 @@ class ValidationException(WizamartException):
) )
class AuthenticationException(WizamartException): class AuthenticationException(WizamartException):
"""Raised when authentication fails.""" """Raised when authentication fails."""
@@ -97,6 +93,7 @@ class AuthorizationException(WizamartException):
details=details, details=details,
) )
class ResourceNotFoundException(WizamartException): class ResourceNotFoundException(WizamartException):
"""Raised when a requested resource is not found.""" """Raised when a requested resource is not found."""
@@ -122,6 +119,7 @@ class ResourceNotFoundException(WizamartException):
}, },
) )
class ConflictException(WizamartException): class ConflictException(WizamartException):
"""Raised when a resource conflict occurs.""" """Raised when a resource conflict occurs."""
@@ -138,6 +136,7 @@ class ConflictException(WizamartException):
details=details, details=details,
) )
class BusinessLogicException(WizamartException): class BusinessLogicException(WizamartException):
"""Raised when business logic rules are violated.""" """Raised when business logic rules are violated."""
@@ -196,6 +195,7 @@ class RateLimitException(WizamartException):
details=rate_limit_details, details=rate_limit_details,
) )
class ServiceUnavailableException(WizamartException): class ServiceUnavailableException(WizamartException):
"""Raised when service is unavailable.""" """Raised when service is unavailable."""
@@ -206,6 +206,7 @@ class ServiceUnavailableException(WizamartException):
status_code=503, status_code=503,
) )
# Note: Domain-specific exceptions like VendorNotFoundException, UserNotFoundException, etc. # Note: Domain-specific exceptions like VendorNotFoundException, UserNotFoundException, etc.
# are defined in their respective domain modules (vendor.py, admin.py, etc.) # are defined in their respective domain modules (vendor.py, admin.py, etc.)
# to keep domain-specific logic separate from base exceptions. # to keep domain-specific logic separate from base exceptions.

View File

@@ -4,11 +4,9 @@ Shopping cart specific exceptions.
""" """
from typing import Optional from typing import Optional
from .base import (
ResourceNotFoundException, from .base import (BusinessLogicException, ResourceNotFoundException,
ValidationException, ValidationException)
BusinessLogicException
)
class CartItemNotFoundException(ResourceNotFoundException): class CartItemNotFoundException(ResourceNotFoundException):
@@ -19,22 +17,16 @@ class CartItemNotFoundException(ResourceNotFoundException):
resource_type="CartItem", resource_type="CartItem",
identifier=str(product_id), identifier=str(product_id),
message=f"Product {product_id} not found in cart", message=f"Product {product_id} not found in cart",
error_code="CART_ITEM_NOT_FOUND" error_code="CART_ITEM_NOT_FOUND",
) )
self.details.update({ self.details.update({"product_id": product_id, "session_id": session_id})
"product_id": product_id,
"session_id": session_id
})
class EmptyCartException(ValidationException): class EmptyCartException(ValidationException):
"""Raised when trying to perform operations on an empty cart.""" """Raised when trying to perform operations on an empty cart."""
def __init__(self, session_id: str): def __init__(self, session_id: str):
super().__init__( super().__init__(message="Cart is empty", details={"session_id": session_id})
message="Cart is empty",
details={"session_id": session_id}
)
self.error_code = "CART_EMPTY" self.error_code = "CART_EMPTY"
@@ -82,7 +74,9 @@ class InsufficientInventoryForCartException(BusinessLogicException):
class InvalidCartQuantityException(ValidationException): class InvalidCartQuantityException(ValidationException):
"""Raised when cart quantity is invalid.""" """Raised when cart quantity is invalid."""
def __init__(self, quantity: int, min_quantity: int = 1, max_quantity: Optional[int] = None): def __init__(
self, quantity: int, min_quantity: int = 1, max_quantity: Optional[int] = None
):
if quantity < min_quantity: if quantity < min_quantity:
message = f"Quantity must be at least {min_quantity}" message = f"Quantity must be at least {min_quantity}"
elif max_quantity and quantity > max_quantity: elif max_quantity and quantity > max_quantity:

View File

@@ -4,13 +4,10 @@ Customer management specific exceptions.
""" """
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from .base import (
ResourceNotFoundException, from .base import (AuthenticationException, BusinessLogicException,
ConflictException, ConflictException, ResourceNotFoundException,
ValidationException, ValidationException)
AuthenticationException,
BusinessLogicException
)
class CustomerNotFoundException(ResourceNotFoundException): class CustomerNotFoundException(ResourceNotFoundException):
@@ -21,7 +18,7 @@ class CustomerNotFoundException(ResourceNotFoundException):
resource_type="Customer", resource_type="Customer",
identifier=customer_identifier, identifier=customer_identifier,
message=f"Customer '{customer_identifier}' not found", message=f"Customer '{customer_identifier}' not found",
error_code="CUSTOMER_NOT_FOUND" error_code="CUSTOMER_NOT_FOUND",
) )
@@ -32,7 +29,7 @@ class CustomerAlreadyExistsException(ConflictException):
super().__init__( super().__init__(
message=f"Customer with email '{email}' already exists", message=f"Customer with email '{email}' already exists",
error_code="CUSTOMER_ALREADY_EXISTS", error_code="CUSTOMER_ALREADY_EXISTS",
details={"email": email} details={"email": email},
) )
@@ -43,10 +40,7 @@ class DuplicateCustomerEmailException(ConflictException):
super().__init__( super().__init__(
message=f"Email '{email}' is already registered for this vendor", message=f"Email '{email}' is already registered for this vendor",
error_code="DUPLICATE_CUSTOMER_EMAIL", error_code="DUPLICATE_CUSTOMER_EMAIL",
details={ details={"email": email, "vendor_code": vendor_code},
"email": email,
"vendor_code": vendor_code
}
) )
@@ -57,7 +51,7 @@ class CustomerNotActiveException(BusinessLogicException):
super().__init__( super().__init__(
message=f"Customer account '{email}' is not active", message=f"Customer account '{email}' is not active",
error_code="CUSTOMER_NOT_ACTIVE", error_code="CUSTOMER_NOT_ACTIVE",
details={"email": email} details={"email": email},
) )
@@ -67,7 +61,7 @@ class InvalidCustomerCredentialsException(AuthenticationException):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
message="Invalid email or password", message="Invalid email or password",
error_code="INVALID_CUSTOMER_CREDENTIALS" error_code="INVALID_CUSTOMER_CREDENTIALS",
) )
@@ -78,13 +72,9 @@ class CustomerValidationException(ValidationException):
self, self,
message: str = "Customer validation failed", message: str = "Customer validation failed",
field: Optional[str] = None, field: Optional[str] = None,
details: Optional[Dict[str, Any]] = None details: Optional[Dict[str, Any]] = None,
): ):
super().__init__( super().__init__(message=message, field=field, details=details)
message=message,
field=field,
details=details
)
self.error_code = "CUSTOMER_VALIDATION_FAILED" self.error_code = "CUSTOMER_VALIDATION_FAILED"
@@ -95,8 +85,5 @@ class CustomerAuthorizationException(BusinessLogicException):
super().__init__( super().__init__(
message=f"Customer '{customer_email}' not authorized for: {operation}", message=f"Customer '{customer_email}' not authorized for: {operation}",
error_code="CUSTOMER_NOT_AUTHORIZED", error_code="CUSTOMER_NOT_AUTHORIZED",
details={ details={"customer_email": customer_email, "operation": operation},
"customer_email": customer_email,
"operation": operation
}
) )

View File

@@ -7,7 +7,7 @@ Handles fallback logic and context-specific customization.
""" """
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Optional, Dict, Any from typing import Any, Dict, Optional
from fastapi import Request from fastapi import Request
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
@@ -114,7 +114,7 @@ class ErrorPageRenderer:
"error_code": error_code, "error_code": error_code,
"context": context_type.value, "context": context_type.value,
"template": template_path, "template": template_path,
} },
) )
try: try:
@@ -129,8 +129,7 @@ class ErrorPageRenderer:
) )
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Failed to render error template {template_path}: {e}", f"Failed to render error template {template_path}: {e}", exc_info=True
exc_info=True
) )
# Return basic HTML as absolute fallback # Return basic HTML as absolute fallback
return ErrorPageRenderer._render_basic_html_fallback( return ErrorPageRenderer._render_basic_html_fallback(
@@ -228,7 +227,9 @@ class ErrorPageRenderer:
} }
@staticmethod @staticmethod
def _get_context_data(request: Request, context_type: RequestContext) -> Dict[str, Any]: def _get_context_data(
request: Request, context_type: RequestContext
) -> Dict[str, Any]:
"""Get context-specific data for error templates.""" """Get context-specific data for error templates."""
data = {} data = {}
@@ -261,11 +262,19 @@ class ErrorPageRenderer:
# Calculate base_url for shop links # Calculate base_url for shop links
vendor_context = getattr(request.state, "vendor_context", None) vendor_context = getattr(request.state, "vendor_context", None)
access_method = vendor_context.get('detection_method', 'unknown') if vendor_context else 'unknown' access_method = (
vendor_context.get("detection_method", "unknown")
if vendor_context
else "unknown"
)
base_url = "/" base_url = "/"
if access_method == "path" and vendor: if access_method == "path" and vendor:
# Use the full_prefix from vendor_context to determine which pattern was used # Use the full_prefix from vendor_context to determine which pattern was used
full_prefix = vendor_context.get('full_prefix', '/vendor/') if vendor_context else '/vendor/' full_prefix = (
vendor_context.get("full_prefix", "/vendor/")
if vendor_context
else "/vendor/"
)
base_url = f"{full_prefix}{vendor.subdomain}/" base_url = f"{full_prefix}{vendor.subdomain}/"
data["base_url"] = base_url data["base_url"] = base_url

View File

@@ -13,13 +13,14 @@ This module provides classes and functions for:
import logging import logging
from typing import Union from typing import Union
from fastapi import Request, HTTPException from fastapi import HTTPException, Request
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, RedirectResponse from fastapi.responses import JSONResponse, RedirectResponse
from middleware.context import RequestContext, get_request_context
from .base import WizamartException from .base import WizamartException
from .error_renderer import ErrorPageRenderer from .error_renderer import ErrorPageRenderer
from middleware.context import RequestContext, get_request_context
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -38,8 +39,8 @@ def setup_exception_handlers(app):
extra={ extra={
"path": request.url.path, "path": request.url.path,
"accept": request.headers.get("accept", ""), "accept": request.headers.get("accept", ""),
"method": request.method "method": request.method,
} },
) )
# Redirect to appropriate login page based on context # Redirect to appropriate login page based on context
@@ -56,15 +57,12 @@ def setup_exception_handlers(app):
"url": str(request.url), "url": str(request.url),
"method": request.method, "method": request.method,
"exception_type": type(exc).__name__, "exception_type": type(exc).__name__,
} },
) )
# Check if this is an API request # Check if this is an API request
if _is_api_request(request): if _is_api_request(request):
return JSONResponse( return JSONResponse(status_code=exc.status_code, content=exc.to_dict())
status_code=exc.status_code,
content=exc.to_dict()
)
# Check if this is an HTML page request # Check if this is an HTML page request
if _is_html_page_request(request): if _is_html_page_request(request):
@@ -78,10 +76,7 @@ def setup_exception_handlers(app):
) )
# Default to JSON for unknown request types # Default to JSON for unknown request types
return JSONResponse( return JSONResponse(status_code=exc.status_code, content=exc.to_dict())
status_code=exc.status_code,
content=exc.to_dict()
)
@app.exception_handler(HTTPException) @app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException): async def http_exception_handler(request: Request, exc: HTTPException):
@@ -96,7 +91,7 @@ def setup_exception_handlers(app):
"url": str(request.url), "url": str(request.url),
"method": request.method, "method": request.method,
"exception_type": "HTTPException", "exception_type": "HTTPException",
} },
) )
# Check if this is an API request # Check if this is an API request
@@ -107,7 +102,7 @@ def setup_exception_handlers(app):
"error_code": f"HTTP_{exc.status_code}", "error_code": f"HTTP_{exc.status_code}",
"message": exc.detail, "message": exc.detail,
"status_code": exc.status_code, "status_code": exc.status_code,
} },
) )
# Check if this is an HTML page request # Check if this is an HTML page request
@@ -128,11 +123,13 @@ def setup_exception_handlers(app):
"error_code": f"HTTP_{exc.status_code}", "error_code": f"HTTP_{exc.status_code}",
"message": exc.detail, "message": exc.detail,
"status_code": exc.status_code, "status_code": exc.status_code,
} },
) )
@app.exception_handler(RequestValidationError) @app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError): async def validation_exception_handler(
request: Request, exc: RequestValidationError
):
"""Handle Pydantic validation errors with consistent format.""" """Handle Pydantic validation errors with consistent format."""
# Sanitize errors to remove sensitive data from logs # Sanitize errors to remove sensitive data from logs
@@ -140,8 +137,8 @@ def setup_exception_handlers(app):
for error in exc.errors(): for error in exc.errors():
sanitized_error = error.copy() sanitized_error = error.copy()
# Remove 'input' field which may contain passwords # Remove 'input' field which may contain passwords
if 'input' in sanitized_error: if "input" in sanitized_error:
sanitized_error['input'] = '<redacted>' sanitized_error["input"] = "<redacted>"
sanitized_errors.append(sanitized_error) sanitized_errors.append(sanitized_error)
logger.error( logger.error(
@@ -151,7 +148,7 @@ def setup_exception_handlers(app):
"url": str(request.url), "url": str(request.url),
"method": request.method, "method": request.method,
"exception_type": "RequestValidationError", "exception_type": "RequestValidationError",
} },
) )
# Clean up validation errors to ensure JSON serializability # Clean up validation errors to ensure JSON serializability
@@ -159,15 +156,17 @@ def setup_exception_handlers(app):
for error in exc.errors(): for error in exc.errors():
clean_error = {} clean_error = {}
for key, value in error.items(): for key, value in error.items():
if key == 'input' and isinstance(value, bytes): if key == "input" and isinstance(value, bytes):
# Convert bytes to string representation for JSON serialization # Convert bytes to string representation for JSON serialization
clean_error[key] = f"<bytes: {len(value)} bytes>" clean_error[key] = f"<bytes: {len(value)} bytes>"
elif key == 'ctx' and isinstance(value, dict): elif key == "ctx" and isinstance(value, dict):
# Handle the 'ctx' field that contains ValueError objects # Handle the 'ctx' field that contains ValueError objects
clean_ctx = {} clean_ctx = {}
for ctx_key, ctx_value in value.items(): for ctx_key, ctx_value in value.items():
if isinstance(ctx_value, Exception): if isinstance(ctx_value, Exception):
clean_ctx[ctx_key] = str(ctx_value) # Convert exception to string clean_ctx[ctx_key] = str(
ctx_value
) # Convert exception to string
else: else:
clean_ctx[ctx_key] = ctx_value clean_ctx[ctx_key] = ctx_value
clean_error[key] = clean_ctx clean_error[key] = clean_ctx
@@ -186,10 +185,8 @@ def setup_exception_handlers(app):
"error_code": "VALIDATION_ERROR", "error_code": "VALIDATION_ERROR",
"message": "Request validation failed", "message": "Request validation failed",
"status_code": 422, "status_code": 422,
"details": { "details": {"validation_errors": clean_errors},
"validation_errors": clean_errors },
}
}
) )
# Check if this is an HTML page request # Check if this is an HTML page request
@@ -210,10 +207,8 @@ def setup_exception_handlers(app):
"error_code": "VALIDATION_ERROR", "error_code": "VALIDATION_ERROR",
"message": "Request validation failed", "message": "Request validation failed",
"status_code": 422, "status_code": 422,
"details": { "details": {"validation_errors": clean_errors},
"validation_errors": clean_errors },
}
}
) )
@app.exception_handler(Exception) @app.exception_handler(Exception)
@@ -227,7 +222,7 @@ def setup_exception_handlers(app):
"url": str(request.url), "url": str(request.url),
"method": request.method, "method": request.method,
"exception_type": type(exc).__name__, "exception_type": type(exc).__name__,
} },
) )
# Check if this is an API request # Check if this is an API request
@@ -238,7 +233,7 @@ def setup_exception_handlers(app):
"error_code": "INTERNAL_SERVER_ERROR", "error_code": "INTERNAL_SERVER_ERROR",
"message": "Internal server error", "message": "Internal server error",
"status_code": 500, "status_code": 500,
} },
) )
# Check if this is an HTML page request # Check if this is an HTML page request
@@ -259,7 +254,7 @@ def setup_exception_handlers(app):
"error_code": "INTERNAL_SERVER_ERROR", "error_code": "INTERNAL_SERVER_ERROR",
"message": "Internal server error", "message": "Internal server error",
"status_code": 500, "status_code": 500,
} },
) )
@app.exception_handler(404) @app.exception_handler(404)
@@ -275,11 +270,8 @@ def setup_exception_handlers(app):
"error_code": "ENDPOINT_NOT_FOUND", "error_code": "ENDPOINT_NOT_FOUND",
"message": f"Endpoint not found: {request.url.path}", "message": f"Endpoint not found: {request.url.path}",
"status_code": 404, "status_code": 404,
"details": { "details": {"path": request.url.path, "method": request.method},
"path": request.url.path, },
"method": request.method
}
}
) )
# Check if this is an HTML page request # Check if this is an HTML page request
@@ -300,11 +292,8 @@ def setup_exception_handlers(app):
"error_code": "ENDPOINT_NOT_FOUND", "error_code": "ENDPOINT_NOT_FOUND",
"message": f"Endpoint not found: {request.url.path}", "message": f"Endpoint not found: {request.url.path}",
"status_code": 404, "status_code": 404,
"details": { "details": {"path": request.url.path, "method": request.method},
"path": request.url.path, },
"method": request.method
}
}
) )
@@ -332,8 +321,8 @@ def _is_html_page_request(request: Request) -> bool:
extra={ extra={
"path": request.url.path, "path": request.url.path,
"method": request.method, "method": request.method,
"accept": request.headers.get("accept", "") "accept": request.headers.get("accept", ""),
} },
) )
# Don't redirect API calls # Don't redirect API calls
@@ -354,7 +343,9 @@ def _is_html_page_request(request: Request) -> bool:
# MUST explicitly accept HTML (strict check) # MUST explicitly accept HTML (strict check)
accept_header = request.headers.get("accept", "") accept_header = request.headers.get("accept", "")
if "text/html" not in accept_header: if "text/html" not in accept_header:
logger.debug(f"Not HTML page: Accept header doesn't include text/html: {accept_header}") logger.debug(
f"Not HTML page: Accept header doesn't include text/html: {accept_header}"
)
return False return False
logger.debug("IS HTML page request") logger.debug("IS HTML page request")
@@ -379,13 +370,21 @@ def _redirect_to_login(request: Request) -> RedirectResponse:
elif context_type == RequestContext.SHOP: elif context_type == RequestContext.SHOP:
# For shop context, redirect to shop login (customer login) # For shop context, redirect to shop login (customer login)
# Calculate base_url for proper routing (supports domain, subdomain, and path-based access) # Calculate base_url for proper routing (supports domain, subdomain, and path-based access)
vendor = getattr(request.state, 'vendor', None) vendor = getattr(request.state, "vendor", None)
vendor_context = getattr(request.state, 'vendor_context', None) vendor_context = getattr(request.state, "vendor_context", None)
access_method = vendor_context.get('detection_method', 'unknown') if vendor_context else 'unknown' access_method = (
vendor_context.get("detection_method", "unknown")
if vendor_context
else "unknown"
)
base_url = "/" base_url = "/"
if access_method == "path" and vendor: if access_method == "path" and vendor:
full_prefix = vendor_context.get('full_prefix', '/vendor/') if vendor_context else '/vendor/' full_prefix = (
vendor_context.get("full_prefix", "/vendor/")
if vendor_context
else "/vendor/"
)
base_url = f"{full_prefix}{vendor.subdomain}/" base_url = f"{full_prefix}{vendor.subdomain}/"
login_url = f"{base_url}shop/account/login" login_url = f"{base_url}shop/account/login"
@@ -401,22 +400,28 @@ def _redirect_to_login(request: Request) -> RedirectResponse:
def raise_not_found(resource_type: str, identifier: str) -> None: def raise_not_found(resource_type: str, identifier: str) -> None:
"""Convenience function to raise ResourceNotFoundException.""" """Convenience function to raise ResourceNotFoundException."""
from .base import ResourceNotFoundException from .base import ResourceNotFoundException
raise ResourceNotFoundException(resource_type, identifier) raise ResourceNotFoundException(resource_type, identifier)
def raise_validation_error(message: str, field: str = None, details: dict = None) -> None: def raise_validation_error(
message: str, field: str = None, details: dict = None
) -> None:
"""Convenience function to raise ValidationException.""" """Convenience function to raise ValidationException."""
from .base import ValidationException from .base import ValidationException
raise ValidationException(message, field, details) raise ValidationException(message, field, details)
def raise_auth_error(message: str = "Authentication failed") -> None: def raise_auth_error(message: str = "Authentication failed") -> None:
"""Convenience function to raise AuthenticationException.""" """Convenience function to raise AuthenticationException."""
from .base import AuthenticationException from .base import AuthenticationException
raise AuthenticationException(message) raise AuthenticationException(message)
def raise_permission_error(message: str = "Access denied") -> None: def raise_permission_error(message: str = "Access denied") -> None:
"""Convenience function to raise AuthorizationException.""" """Convenience function to raise AuthorizationException."""
from .base import AuthorizationException from .base import AuthorizationException
raise AuthorizationException(message) raise AuthorizationException(message)

View File

@@ -4,7 +4,9 @@ Inventory management specific exceptions.
""" """
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from .base import ResourceNotFoundException, ValidationException, BusinessLogicException
from .base import (BusinessLogicException, ResourceNotFoundException,
ValidationException)
class InventoryNotFoundException(ResourceNotFoundException): class InventoryNotFoundException(ResourceNotFoundException):
@@ -14,7 +16,9 @@ class InventoryNotFoundException(ResourceNotFoundException):
if identifier_type.lower() == "gtin": if identifier_type.lower() == "gtin":
message = f"No inventory found for GTIN '{identifier}'" message = f"No inventory found for GTIN '{identifier}'"
else: else:
message = f"Inventory record with {identifier_type} '{identifier}' not found" message = (
f"Inventory record with {identifier_type} '{identifier}' not found"
)
super().__init__( super().__init__(
resource_type="Inventory", resource_type="Inventory",

View File

@@ -4,13 +4,10 @@ Marketplace import specific exceptions.
""" """
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from .base import (
ResourceNotFoundException, from .base import (AuthorizationException, BusinessLogicException,
ValidationException, ExternalServiceException, ResourceNotFoundException,
BusinessLogicException, ValidationException)
AuthorizationException,
ExternalServiceException
)
class MarketplaceImportException(BusinessLogicException): class MarketplaceImportException(BusinessLogicException):
@@ -118,7 +115,9 @@ class ImportJobCannotBeDeletedException(BusinessLogicException):
class MarketplaceConnectionException(ExternalServiceException): class MarketplaceConnectionException(ExternalServiceException):
"""Raised when marketplace connection fails.""" """Raised when marketplace connection fails."""
def __init__(self, marketplace: str, message: str = "Failed to connect to marketplace"): def __init__(
self, marketplace: str, message: str = "Failed to connect to marketplace"
):
super().__init__( super().__init__(
service=marketplace, service=marketplace,
message=f"{message}: {marketplace}", message=f"{message}: {marketplace}",

View File

@@ -4,7 +4,9 @@ MarketplaceProduct management specific exceptions.
""" """
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from .base import ResourceNotFoundException, ConflictException, ValidationException, BusinessLogicException
from .base import (BusinessLogicException, ConflictException,
ResourceNotFoundException, ValidationException)
class MarketplaceProductNotFoundException(ResourceNotFoundException): class MarketplaceProductNotFoundException(ResourceNotFoundException):

View File

@@ -4,11 +4,9 @@ Order management specific exceptions.
""" """
from typing import Optional from typing import Optional
from .base import (
ResourceNotFoundException, from .base import (BusinessLogicException, ResourceNotFoundException,
ValidationException, ValidationException)
BusinessLogicException
)
class OrderNotFoundException(ResourceNotFoundException): class OrderNotFoundException(ResourceNotFoundException):
@@ -19,7 +17,7 @@ class OrderNotFoundException(ResourceNotFoundException):
resource_type="Order", resource_type="Order",
identifier=order_identifier, identifier=order_identifier,
message=f"Order '{order_identifier}' not found", message=f"Order '{order_identifier}' not found",
error_code="ORDER_NOT_FOUND" error_code="ORDER_NOT_FOUND",
) )
@@ -30,7 +28,7 @@ class OrderAlreadyExistsException(ValidationException):
super().__init__( super().__init__(
message=f"Order with number '{order_number}' already exists", message=f"Order with number '{order_number}' already exists",
error_code="ORDER_ALREADY_EXISTS", error_code="ORDER_ALREADY_EXISTS",
details={"order_number": order_number} details={"order_number": order_number},
) )
@@ -39,9 +37,7 @@ class OrderValidationException(ValidationException):
def __init__(self, message: str, details: Optional[dict] = None): def __init__(self, message: str, details: Optional[dict] = None):
super().__init__( super().__init__(
message=message, message=message, error_code="ORDER_VALIDATION_FAILED", details=details
error_code="ORDER_VALIDATION_FAILED",
details=details
) )
@@ -52,10 +48,7 @@ class InvalidOrderStatusException(BusinessLogicException):
super().__init__( super().__init__(
message=f"Cannot change order status from '{current_status}' to '{new_status}'", message=f"Cannot change order status from '{current_status}' to '{new_status}'",
error_code="INVALID_ORDER_STATUS_CHANGE", error_code="INVALID_ORDER_STATUS_CHANGE",
details={ details={"current_status": current_status, "new_status": new_status},
"current_status": current_status,
"new_status": new_status
}
) )
@@ -66,8 +59,5 @@ class OrderCannotBeCancelledException(BusinessLogicException):
super().__init__( super().__init__(
message=f"Order '{order_number}' cannot be cancelled: {reason}", message=f"Order '{order_number}' cannot be cancelled: {reason}",
error_code="ORDER_CANNOT_BE_CANCELLED", error_code="ORDER_CANNOT_BE_CANCELLED",
details={ details={"order_number": order_number, "reason": reason},
"order_number": order_number,
"reason": reason
}
) )

View File

@@ -4,12 +4,9 @@ Product (vendor catalog) specific exceptions.
""" """
from typing import Optional from typing import Optional
from .base import (
ResourceNotFoundException, from .base import (BusinessLogicException, ConflictException,
ConflictException, ResourceNotFoundException, ValidationException)
ValidationException,
BusinessLogicException
)
class ProductNotFoundException(ResourceNotFoundException): class ProductNotFoundException(ResourceNotFoundException):

View File

@@ -4,13 +4,10 @@ Team management specific exceptions.
""" """
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from .base import (
ResourceNotFoundException, from .base import (AuthorizationException, BusinessLogicException,
ConflictException, ConflictException, ResourceNotFoundException,
ValidationException, ValidationException)
AuthorizationException,
BusinessLogicException
)
class TeamMemberNotFoundException(ResourceNotFoundException): class TeamMemberNotFoundException(ResourceNotFoundException):
@@ -20,7 +17,9 @@ class TeamMemberNotFoundException(ResourceNotFoundException):
details = {"user_id": user_id} details = {"user_id": user_id}
if vendor_id: if vendor_id:
details["vendor_id"] = vendor_id details["vendor_id"] = vendor_id
message = f"Team member with user ID '{user_id}' not found in vendor {vendor_id}" message = (
f"Team member with user ID '{user_id}' not found in vendor {vendor_id}"
)
else: else:
message = f"Team member with user ID '{user_id}' not found" message = f"Team member with user ID '{user_id}' not found"
@@ -84,7 +83,12 @@ class TeamInvitationAlreadyAcceptedException(ConflictException):
class UnauthorizedTeamActionException(AuthorizationException): class UnauthorizedTeamActionException(AuthorizationException):
"""Raised when user tries to perform team action without permission.""" """Raised when user tries to perform team action without permission."""
def __init__(self, action: str, user_id: Optional[int] = None, required_permission: Optional[str] = None): def __init__(
self,
action: str,
user_id: Optional[int] = None,
required_permission: Optional[str] = None,
):
details = {"action": action} details = {"action": action}
if user_id: if user_id:
details["user_id"] = user_id details["user_id"] = user_id
@@ -240,6 +244,7 @@ class InvalidInvitationDataException(ValidationException):
# NEW: Add InvalidInvitationTokenException # NEW: Add InvalidInvitationTokenException
# ============================================================================ # ============================================================================
class InvalidInvitationTokenException(ValidationException): class InvalidInvitationTokenException(ValidationException):
"""Raised when invitation token is invalid, expired, or already used. """Raised when invitation token is invalid, expired, or already used.
@@ -250,7 +255,7 @@ class InvalidInvitationTokenException(ValidationException):
def __init__( def __init__(
self, self,
message: str = "Invalid or expired invitation token", message: str = "Invalid or expired invitation token",
invitation_token: Optional[str] = None invitation_token: Optional[str] = None,
): ):
details = {} details = {}
if invitation_token: if invitation_token:

View File

@@ -4,13 +4,10 @@ Vendor management specific exceptions.
""" """
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from .base import (
ResourceNotFoundException, from .base import (AuthorizationException, BusinessLogicException,
ConflictException, ConflictException, ResourceNotFoundException,
ValidationException, ValidationException)
AuthorizationException,
BusinessLogicException
)
class VendorNotFoundException(ResourceNotFoundException): class VendorNotFoundException(ResourceNotFoundException):

View File

@@ -4,13 +4,10 @@ Vendor domain management specific exceptions.
""" """
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from .base import (
ResourceNotFoundException, from .base import (BusinessLogicException, ConflictException,
ConflictException, ExternalServiceException, ResourceNotFoundException,
ValidationException, ValidationException)
BusinessLogicException,
ExternalServiceException
)
class VendorDomainNotFoundException(ResourceNotFoundException): class VendorDomainNotFoundException(ResourceNotFoundException):
@@ -64,10 +61,7 @@ class ReservedDomainException(ValidationException):
super().__init__( super().__init__(
message=f"Domain cannot use reserved subdomain: {reserved_part}", message=f"Domain cannot use reserved subdomain: {reserved_part}",
field="domain", field="domain",
details={ details={"domain": domain, "reserved_part": reserved_part},
"domain": domain,
"reserved_part": reserved_part
},
) )
self.error_code = "RESERVED_DOMAIN" self.error_code = "RESERVED_DOMAIN"
@@ -79,10 +73,7 @@ class DomainNotVerifiedException(BusinessLogicException):
super().__init__( super().__init__(
message=f"Domain '{domain}' must be verified before activation", message=f"Domain '{domain}' must be verified before activation",
error_code="DOMAIN_NOT_VERIFIED", error_code="DOMAIN_NOT_VERIFIED",
details={ details={"domain_id": domain_id, "domain": domain},
"domain_id": domain_id,
"domain": domain
},
) )
@@ -93,10 +84,7 @@ class DomainVerificationFailedException(BusinessLogicException):
super().__init__( super().__init__(
message=f"Domain verification failed for '{domain}': {reason}", message=f"Domain verification failed for '{domain}': {reason}",
error_code="DOMAIN_VERIFICATION_FAILED", error_code="DOMAIN_VERIFICATION_FAILED",
details={ details={"domain": domain, "reason": reason},
"domain": domain,
"reason": reason
},
) )
@@ -107,10 +95,7 @@ class DomainAlreadyVerifiedException(BusinessLogicException):
super().__init__( super().__init__(
message=f"Domain '{domain}' is already verified", message=f"Domain '{domain}' is already verified",
error_code="DOMAIN_ALREADY_VERIFIED", error_code="DOMAIN_ALREADY_VERIFIED",
details={ details={"domain_id": domain_id, "domain": domain},
"domain_id": domain_id,
"domain": domain
},
) )
@@ -133,10 +118,7 @@ class DNSVerificationException(ExternalServiceException):
service_name="DNS", service_name="DNS",
message=f"DNS verification failed for '{domain}': {reason}", message=f"DNS verification failed for '{domain}': {reason}",
error_code="DNS_VERIFICATION_ERROR", error_code="DNS_VERIFICATION_ERROR",
details={ details={"domain": domain, "reason": reason},
"domain": domain,
"reason": reason
},
) )
@@ -147,10 +129,7 @@ class MaxDomainsReachedException(BusinessLogicException):
super().__init__( super().__init__(
message=f"Maximum number of domains reached ({max_domains})", message=f"Maximum number of domains reached ({max_domains})",
error_code="MAX_DOMAINS_REACHED", error_code="MAX_DOMAINS_REACHED",
details={ details={"vendor_id": vendor_id, "max_domains": max_domains},
"vendor_id": vendor_id,
"max_domains": max_domains
},
) )
@@ -161,8 +140,5 @@ class UnauthorizedDomainAccessException(BusinessLogicException):
super().__init__( super().__init__(
message=f"Unauthorized access to domain {domain_id}", message=f"Unauthorized access to domain {domain_id}",
error_code="UNAUTHORIZED_DOMAIN_ACCESS", error_code="UNAUTHORIZED_DOMAIN_ACCESS",
details={ details={"domain_id": domain_id, "vendor_id": vendor_id},
"domain_id": domain_id,
"vendor_id": vendor_id
},
) )

View File

@@ -4,12 +4,9 @@ Vendor theme management specific exceptions.
""" """
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from .base import (
ResourceNotFoundException, from .base import (BusinessLogicException, ConflictException,
ConflictException, ResourceNotFoundException, ValidationException)
ValidationException,
BusinessLogicException
)
class VendorThemeNotFoundException(ResourceNotFoundException): class VendorThemeNotFoundException(ResourceNotFoundException):
@@ -86,10 +83,7 @@ class ThemePresetAlreadyAppliedException(BusinessLogicException):
super().__init__( super().__init__(
message=f"Preset '{preset_name}' is already applied to vendor '{vendor_code}'", message=f"Preset '{preset_name}' is already applied to vendor '{vendor_code}'",
error_code="THEME_PRESET_ALREADY_APPLIED", error_code="THEME_PRESET_ALREADY_APPLIED",
details={ details={"preset_name": preset_name, "vendor_code": vendor_code},
"preset_name": preset_name,
"vendor_code": vendor_code
},
) )
@@ -120,18 +114,13 @@ class InvalidFontFamilyException(ValidationException):
class ThemeOperationException(BusinessLogicException): class ThemeOperationException(BusinessLogicException):
"""Raised when theme operation fails.""" """Raised when theme operation fails."""
def __init__( def __init__(self, operation: str, vendor_code: str, reason: str):
self,
operation: str,
vendor_code: str,
reason: str
):
super().__init__( super().__init__(
message=f"Theme operation '{operation}' failed for vendor '{vendor_code}': {reason}", message=f"Theme operation '{operation}' failed for vendor '{vendor_code}': {reason}",
error_code="THEME_OPERATION_FAILED", error_code="THEME_OPERATION_FAILED",
details={ details={
"operation": operation, "operation": operation,
"vendor_code": vendor_code, "vendor_code": vendor_code,
"reason": reason "reason": reason,
}, },
) )

View File

@@ -3,18 +3,23 @@ Architecture Scan Models
Database models for tracking code quality scans and violations Database models for tracking code quality scans and violations
""" """
from sqlalchemy import Column, Integer, String, Float, DateTime, Text, Boolean, ForeignKey, JSON from sqlalchemy import (JSON, Boolean, Column, DateTime, Float, ForeignKey,
Integer, String, Text)
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from sqlalchemy.sql import func from sqlalchemy.sql import func
from app.core.database import Base from app.core.database import Base
class ArchitectureScan(Base): class ArchitectureScan(Base):
"""Represents a single run of the architecture validator""" """Represents a single run of the architecture validator"""
__tablename__ = "architecture_scans" __tablename__ = "architecture_scans"
id = Column(Integer, primary_key=True, index=True) id = Column(Integer, primary_key=True, index=True)
timestamp = Column(DateTime(timezone=True), server_default=func.now(), nullable=False, index=True) timestamp = Column(
DateTime(timezone=True), server_default=func.now(), nullable=False, index=True
)
total_files = Column(Integer, default=0) total_files = Column(Integer, default=0)
total_violations = Column(Integer, default=0) total_violations = Column(Integer, default=0)
errors = Column(Integer, default=0) errors = Column(Integer, default=0)
@@ -24,7 +29,9 @@ class ArchitectureScan(Base):
git_commit_hash = Column(String(40)) git_commit_hash = Column(String(40))
# Relationship to violations # Relationship to violations
violations = relationship("ArchitectureViolation", back_populates="scan", cascade="all, delete-orphan") violations = relationship(
"ArchitectureViolation", back_populates="scan", cascade="all, delete-orphan"
)
def __repr__(self): def __repr__(self):
return f"<ArchitectureScan(id={self.id}, violations={self.total_violations}, errors={self.errors})>" return f"<ArchitectureScan(id={self.id}, violations={self.total_violations}, errors={self.errors})>"
@@ -32,31 +39,48 @@ class ArchitectureScan(Base):
class ArchitectureViolation(Base): class ArchitectureViolation(Base):
"""Represents a single architectural violation found during a scan""" """Represents a single architectural violation found during a scan"""
__tablename__ = "architecture_violations" __tablename__ = "architecture_violations"
id = Column(Integer, primary_key=True, index=True) id = Column(Integer, primary_key=True, index=True)
scan_id = Column(Integer, ForeignKey("architecture_scans.id"), nullable=False, index=True) scan_id = Column(
Integer, ForeignKey("architecture_scans.id"), nullable=False, index=True
)
rule_id = Column(String(20), nullable=False, index=True) # e.g., 'API-001' rule_id = Column(String(20), nullable=False, index=True) # e.g., 'API-001'
rule_name = Column(String(200), nullable=False) rule_name = Column(String(200), nullable=False)
severity = Column(String(10), nullable=False, index=True) # 'error', 'warning', 'info' severity = Column(
String(10), nullable=False, index=True
) # 'error', 'warning', 'info'
file_path = Column(String(500), nullable=False, index=True) file_path = Column(String(500), nullable=False, index=True)
line_number = Column(Integer, nullable=False) line_number = Column(Integer, nullable=False)
message = Column(Text, nullable=False) message = Column(Text, nullable=False)
context = Column(Text) # Code snippet context = Column(Text) # Code snippet
suggestion = Column(Text) suggestion = Column(Text)
status = Column(String(20), default='open', index=True) # 'open', 'assigned', 'resolved', 'ignored', 'technical_debt' status = Column(
String(20), default="open", index=True
) # 'open', 'assigned', 'resolved', 'ignored', 'technical_debt'
assigned_to = Column(Integer, ForeignKey("users.id")) assigned_to = Column(Integer, ForeignKey("users.id"))
resolved_at = Column(DateTime(timezone=True)) resolved_at = Column(DateTime(timezone=True))
resolved_by = Column(Integer, ForeignKey("users.id")) resolved_by = Column(Integer, ForeignKey("users.id"))
resolution_note = Column(Text) resolution_note = Column(Text)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) created_at = Column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
# Relationships # Relationships
scan = relationship("ArchitectureScan", back_populates="violations") scan = relationship("ArchitectureScan", back_populates="violations")
assigned_user = relationship("User", foreign_keys=[assigned_to], backref="assigned_violations") assigned_user = relationship(
resolver = relationship("User", foreign_keys=[resolved_by], backref="resolved_violations") "User", foreign_keys=[assigned_to], backref="assigned_violations"
assignments = relationship("ViolationAssignment", back_populates="violation", cascade="all, delete-orphan") )
comments = relationship("ViolationComment", back_populates="violation", cascade="all, delete-orphan") resolver = relationship(
"User", foreign_keys=[resolved_by], backref="resolved_violations"
)
assignments = relationship(
"ViolationAssignment", back_populates="violation", cascade="all, delete-orphan"
)
comments = relationship(
"ViolationComment", back_populates="violation", cascade="all, delete-orphan"
)
def __repr__(self): def __repr__(self):
return f"<ArchitectureViolation(id={self.id}, rule={self.rule_id}, file={self.file_path}:{self.line_number})>" return f"<ArchitectureViolation(id={self.id}, rule={self.rule_id}, file={self.file_path}:{self.line_number})>"
@@ -64,18 +88,30 @@ class ArchitectureViolation(Base):
class ArchitectureRule(Base): class ArchitectureRule(Base):
"""Architecture rules configuration (from YAML with database overrides)""" """Architecture rules configuration (from YAML with database overrides)"""
__tablename__ = "architecture_rules" __tablename__ = "architecture_rules"
id = Column(Integer, primary_key=True, index=True) id = Column(Integer, primary_key=True, index=True)
rule_id = Column(String(20), unique=True, nullable=False, index=True) # e.g., 'API-001' rule_id = Column(
category = Column(String(50), nullable=False) # 'api_endpoint', 'service_layer', etc. String(20), unique=True, nullable=False, index=True
) # e.g., 'API-001'
category = Column(
String(50), nullable=False
) # 'api_endpoint', 'service_layer', etc.
name = Column(String(200), nullable=False) name = Column(String(200), nullable=False)
description = Column(Text) description = Column(Text)
severity = Column(String(10), nullable=False) # Can override default from YAML severity = Column(String(10), nullable=False) # Can override default from YAML
enabled = Column(Boolean, default=True, nullable=False) enabled = Column(Boolean, default=True, nullable=False)
custom_config = Column(JSON) # For rule-specific settings custom_config = Column(JSON) # For rule-specific settings
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) created_at = Column(
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) DateTime(timezone=True), server_default=func.now(), nullable=False
)
updated_at = Column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
def __repr__(self): def __repr__(self):
return f"<ArchitectureRule(id={self.rule_id}, name={self.name}, enabled={self.enabled})>" return f"<ArchitectureRule(id={self.rule_id}, name={self.name}, enabled={self.enabled})>"
@@ -83,20 +119,29 @@ class ArchitectureRule(Base):
class ViolationAssignment(Base): class ViolationAssignment(Base):
"""Tracks assignment of violations to developers""" """Tracks assignment of violations to developers"""
__tablename__ = "violation_assignments" __tablename__ = "violation_assignments"
id = Column(Integer, primary_key=True, index=True) id = Column(Integer, primary_key=True, index=True)
violation_id = Column(Integer, ForeignKey("architecture_violations.id"), nullable=False, index=True) violation_id = Column(
Integer, ForeignKey("architecture_violations.id"), nullable=False, index=True
)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False) user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
assigned_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) assigned_at = Column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
assigned_by = Column(Integer, ForeignKey("users.id")) assigned_by = Column(Integer, ForeignKey("users.id"))
due_date = Column(DateTime(timezone=True)) due_date = Column(DateTime(timezone=True))
priority = Column(String(10), default='medium') # 'low', 'medium', 'high', 'critical' priority = Column(
String(10), default="medium"
) # 'low', 'medium', 'high', 'critical'
# Relationships # Relationships
violation = relationship("ArchitectureViolation", back_populates="assignments") violation = relationship("ArchitectureViolation", back_populates="assignments")
user = relationship("User", foreign_keys=[user_id], backref="violation_assignments") user = relationship("User", foreign_keys=[user_id], backref="violation_assignments")
assigner = relationship("User", foreign_keys=[assigned_by], backref="assigned_by_me") assigner = relationship(
"User", foreign_keys=[assigned_by], backref="assigned_by_me"
)
def __repr__(self): def __repr__(self):
return f"<ViolationAssignment(id={self.id}, violation_id={self.violation_id}, user_id={self.user_id})>" return f"<ViolationAssignment(id={self.id}, violation_id={self.violation_id}, user_id={self.user_id})>"
@@ -104,13 +149,18 @@ class ViolationAssignment(Base):
class ViolationComment(Base): class ViolationComment(Base):
"""Comments on violations for collaboration""" """Comments on violations for collaboration"""
__tablename__ = "violation_comments" __tablename__ = "violation_comments"
id = Column(Integer, primary_key=True, index=True) id = Column(Integer, primary_key=True, index=True)
violation_id = Column(Integer, ForeignKey("architecture_violations.id"), nullable=False, index=True) violation_id = Column(
Integer, ForeignKey("architecture_violations.id"), nullable=False, index=True
)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False) user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
comment = Column(Text, nullable=False) comment = Column(Text, nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) created_at = Column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
# Relationships # Relationships
violation = relationship("ArchitectureViolation", back_populates="comments") violation = relationship("ArchitectureViolation", back_populates="comments")

View File

@@ -30,17 +30,15 @@ Routes:
- GET /code-quality/violations/{violation_id} → Violation details (auth required) - GET /code-quality/violations/{violation_id} → Violation details (auth required)
""" """
from fastapi import APIRouter, Request, Depends, Path from typing import Optional
from fastapi import APIRouter, Depends, Path, Request
from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing import Optional
from app.api.deps import ( from app.api.deps import (get_current_admin_from_cookie_or_header,
get_current_admin_from_cookie_or_header, get_current_admin_optional, get_db)
get_current_admin_optional,
get_db
)
from models.database.user import User from models.database.user import User
router = APIRouter() router = APIRouter()
@@ -51,9 +49,10 @@ templates = Jinja2Templates(directory="app/templates")
# PUBLIC ROUTES (No Authentication Required) # PUBLIC ROUTES (No Authentication Required)
# ============================================================================ # ============================================================================
@router.get("/", response_class=RedirectResponse, include_in_schema=False) @router.get("/", response_class=RedirectResponse, include_in_schema=False)
async def admin_root( async def admin_root(
current_user: Optional[User] = Depends(get_current_admin_optional) current_user: Optional[User] = Depends(get_current_admin_optional),
): ):
""" """
Redirect /admin/ based on authentication status. Redirect /admin/ based on authentication status.
@@ -70,8 +69,7 @@ async def admin_root(
@router.get("/login", response_class=HTMLResponse, include_in_schema=False) @router.get("/login", response_class=HTMLResponse, include_in_schema=False)
async def admin_login_page( async def admin_login_page(
request: Request, request: Request, current_user: Optional[User] = Depends(get_current_admin_optional)
current_user: Optional[User] = Depends(get_current_admin_optional)
): ):
""" """
Render admin login page. Render admin login page.
@@ -83,21 +81,19 @@ async def admin_login_page(
# User is already logged in as admin, redirect to dashboard # User is already logged in as admin, redirect to dashboard
return RedirectResponse(url="/admin/dashboard", status_code=302) return RedirectResponse(url="/admin/dashboard", status_code=302)
return templates.TemplateResponse( return templates.TemplateResponse("admin/login.html", {"request": request})
"admin/login.html",
{"request": request}
)
# ============================================================================ # ============================================================================
# AUTHENTICATED ROUTES (Admin Only) # AUTHENTICATED ROUTES (Admin Only)
# ============================================================================ # ============================================================================
@router.get("/dashboard", response_class=HTMLResponse, include_in_schema=False) @router.get("/dashboard", response_class=HTMLResponse, include_in_schema=False)
async def admin_dashboard_page( async def admin_dashboard_page(
request: Request, request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header), current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render admin dashboard page. Render admin dashboard page.
@@ -108,7 +104,7 @@ async def admin_dashboard_page(
{ {
"request": request, "request": request,
"user": current_user, "user": current_user,
} },
) )
@@ -116,11 +112,12 @@ async def admin_dashboard_page(
# VENDOR MANAGEMENT ROUTES # VENDOR MANAGEMENT ROUTES
# ============================================================================ # ============================================================================
@router.get("/vendors", response_class=HTMLResponse, include_in_schema=False) @router.get("/vendors", response_class=HTMLResponse, include_in_schema=False)
async def admin_vendors_list_page( async def admin_vendors_list_page(
request: Request, request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header), current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render vendors management page. Render vendors management page.
@@ -131,7 +128,7 @@ async def admin_vendors_list_page(
{ {
"request": request, "request": request,
"user": current_user, "user": current_user,
} },
) )
@@ -139,7 +136,7 @@ async def admin_vendors_list_page(
async def admin_vendor_create_page( async def admin_vendor_create_page(
request: Request, request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header), current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render vendor creation form. Render vendor creation form.
@@ -149,16 +146,18 @@ async def admin_vendor_create_page(
{ {
"request": request, "request": request,
"user": current_user, "user": current_user,
} },
) )
@router.get("/vendors/{vendor_code}", response_class=HTMLResponse, include_in_schema=False) @router.get(
"/vendors/{vendor_code}", response_class=HTMLResponse, include_in_schema=False
)
async def admin_vendor_detail_page( async def admin_vendor_detail_page(
request: Request, request: Request,
vendor_code: str = Path(..., description="Vendor code"), vendor_code: str = Path(..., description="Vendor code"),
current_user: User = Depends(get_current_admin_from_cookie_or_header), current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render vendor detail page. Render vendor detail page.
@@ -170,16 +169,18 @@ async def admin_vendor_detail_page(
"request": request, "request": request,
"user": current_user, "user": current_user,
"vendor_code": vendor_code, "vendor_code": vendor_code,
} },
) )
@router.get("/vendors/{vendor_code}/edit", response_class=HTMLResponse, include_in_schema=False) @router.get(
"/vendors/{vendor_code}/edit", response_class=HTMLResponse, include_in_schema=False
)
async def admin_vendor_edit_page( async def admin_vendor_edit_page(
request: Request, request: Request,
vendor_code: str = Path(..., description="Vendor code"), vendor_code: str = Path(..., description="Vendor code"),
current_user: User = Depends(get_current_admin_from_cookie_or_header), current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render vendor edit form. Render vendor edit form.
@@ -190,7 +191,7 @@ async def admin_vendor_edit_page(
"request": request, "request": request,
"user": current_user, "user": current_user,
"vendor_code": vendor_code, "vendor_code": vendor_code,
} },
) )
@@ -198,12 +199,17 @@ async def admin_vendor_edit_page(
# VENDOR DOMAINS ROUTES # VENDOR DOMAINS ROUTES
# ============================================================================ # ============================================================================
@router.get("/vendors/{vendor_code}/domains", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/vendors/{vendor_code}/domains",
response_class=HTMLResponse,
include_in_schema=False,
)
async def admin_vendor_domains_page( async def admin_vendor_domains_page(
request: Request, request: Request,
vendor_code: str = Path(..., description="Vendor code"), vendor_code: str = Path(..., description="Vendor code"),
current_user: User = Depends(get_current_admin_from_cookie_or_header), current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render vendor domains management page. Render vendor domains management page.
@@ -215,7 +221,7 @@ async def admin_vendor_domains_page(
"request": request, "request": request,
"user": current_user, "user": current_user,
"vendor_code": vendor_code, "vendor_code": vendor_code,
} },
) )
@@ -223,12 +229,15 @@ async def admin_vendor_domains_page(
# VENDOR THEMES ROUTES # VENDOR THEMES ROUTES
# ============================================================================ # ============================================================================
@router.get("/vendors/{vendor_code}/theme", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/vendors/{vendor_code}/theme", response_class=HTMLResponse, include_in_schema=False
)
async def admin_vendor_theme_page( async def admin_vendor_theme_page(
request: Request, request: Request,
vendor_code: str = Path(..., description="Vendor code"), vendor_code: str = Path(..., description="Vendor code"),
current_user: User = Depends(get_current_admin_from_cookie_or_header), current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render vendor theme customization page. Render vendor theme customization page.
@@ -240,7 +249,7 @@ async def admin_vendor_theme_page(
"request": request, "request": request,
"user": current_user, "user": current_user,
"vendor_code": vendor_code, "vendor_code": vendor_code,
} },
) )
@@ -248,11 +257,12 @@ async def admin_vendor_theme_page(
# USER MANAGEMENT ROUTES # USER MANAGEMENT ROUTES
# ============================================================================ # ============================================================================
@router.get("/users", response_class=HTMLResponse, include_in_schema=False) @router.get("/users", response_class=HTMLResponse, include_in_schema=False)
async def admin_users_page( async def admin_users_page(
request: Request, request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header), current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render users management page. Render users management page.
@@ -263,7 +273,7 @@ async def admin_users_page(
{ {
"request": request, "request": request,
"user": current_user, "user": current_user,
} },
) )
@@ -271,11 +281,12 @@ async def admin_users_page(
# IMPORT MANAGEMENT ROUTES # IMPORT MANAGEMENT ROUTES
# ============================================================================ # ============================================================================
@router.get("/imports", response_class=HTMLResponse, include_in_schema=False) @router.get("/imports", response_class=HTMLResponse, include_in_schema=False)
async def admin_imports_page( async def admin_imports_page(
request: Request, request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header), current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render imports management page. Render imports management page.
@@ -286,7 +297,7 @@ async def admin_imports_page(
{ {
"request": request, "request": request,
"user": current_user, "user": current_user,
} },
) )
@@ -294,11 +305,12 @@ async def admin_imports_page(
# SETTINGS ROUTES # SETTINGS ROUTES
# ============================================================================ # ============================================================================
@router.get("/settings", response_class=HTMLResponse, include_in_schema=False) @router.get("/settings", response_class=HTMLResponse, include_in_schema=False)
async def admin_settings_page( async def admin_settings_page(
request: Request, request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header), current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render admin settings page. Render admin settings page.
@@ -309,7 +321,7 @@ async def admin_settings_page(
{ {
"request": request, "request": request,
"user": current_user, "user": current_user,
} },
) )
@@ -317,11 +329,12 @@ async def admin_settings_page(
# CONTENT MANAGEMENT SYSTEM (CMS) ROUTES # CONTENT MANAGEMENT SYSTEM (CMS) ROUTES
# ============================================================================ # ============================================================================
@router.get("/platform-homepage", response_class=HTMLResponse, include_in_schema=False) @router.get("/platform-homepage", response_class=HTMLResponse, include_in_schema=False)
async def admin_platform_homepage_manager( async def admin_platform_homepage_manager(
request: Request, request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header), current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render platform homepage manager. Render platform homepage manager.
@@ -332,7 +345,7 @@ async def admin_platform_homepage_manager(
{ {
"request": request, "request": request,
"user": current_user, "user": current_user,
} },
) )
@@ -340,7 +353,7 @@ async def admin_platform_homepage_manager(
async def admin_content_pages_list( async def admin_content_pages_list(
request: Request, request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header), current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render content pages list. Render content pages list.
@@ -351,15 +364,17 @@ async def admin_content_pages_list(
{ {
"request": request, "request": request,
"user": current_user, "user": current_user,
} },
) )
@router.get("/content-pages/create", response_class=HTMLResponse, include_in_schema=False) @router.get(
"/content-pages/create", response_class=HTMLResponse, include_in_schema=False
)
async def admin_content_page_create( async def admin_content_page_create(
request: Request, request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header), current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render create content page form. Render create content page form.
@@ -371,16 +386,20 @@ async def admin_content_page_create(
"request": request, "request": request,
"user": current_user, "user": current_user,
"page_id": None, # Indicates this is a create operation "page_id": None, # Indicates this is a create operation
} },
) )
@router.get("/content-pages/{page_id}/edit", response_class=HTMLResponse, include_in_schema=False) @router.get(
"/content-pages/{page_id}/edit",
response_class=HTMLResponse,
include_in_schema=False,
)
async def admin_content_page_edit( async def admin_content_page_edit(
request: Request, request: Request,
page_id: int = Path(..., description="Content page ID"), page_id: int = Path(..., description="Content page ID"),
current_user: User = Depends(get_current_admin_from_cookie_or_header), current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render edit content page form. Render edit content page form.
@@ -392,7 +411,7 @@ async def admin_content_page_edit(
"request": request, "request": request,
"user": current_user, "user": current_user,
"page_id": page_id, "page_id": page_id,
} },
) )
@@ -400,11 +419,12 @@ async def admin_content_page_edit(
# DEVELOPER TOOLS - COMPONENTS & TESTING # DEVELOPER TOOLS - COMPONENTS & TESTING
# ============================================================================ # ============================================================================
@router.get("/components", response_class=HTMLResponse, include_in_schema=False) @router.get("/components", response_class=HTMLResponse, include_in_schema=False)
async def admin_components_page( async def admin_components_page(
request: Request, request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header), current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render UI components library page. Render UI components library page.
@@ -415,7 +435,7 @@ async def admin_components_page(
{ {
"request": request, "request": request,
"user": current_user, "user": current_user,
} },
) )
@@ -423,7 +443,7 @@ async def admin_components_page(
async def admin_icons_page( async def admin_icons_page(
request: Request, request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header), current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render icons browser page. Render icons browser page.
@@ -434,7 +454,7 @@ async def admin_icons_page(
{ {
"request": request, "request": request,
"user": current_user, "user": current_user,
} },
) )
@@ -442,7 +462,7 @@ async def admin_icons_page(
async def admin_testing_hub( async def admin_testing_hub(
request: Request, request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header), current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render testing hub page. Render testing hub page.
@@ -453,7 +473,7 @@ async def admin_testing_hub(
{ {
"request": request, "request": request,
"user": current_user, "user": current_user,
} },
) )
@@ -461,7 +481,7 @@ async def admin_testing_hub(
async def admin_test_auth_flow( async def admin_test_auth_flow(
request: Request, request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header), current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render authentication flow testing page. Render authentication flow testing page.
@@ -472,15 +492,19 @@ async def admin_test_auth_flow(
{ {
"request": request, "request": request,
"user": current_user, "user": current_user,
} },
) )
@router.get("/test/vendors-users-migration", response_class=HTMLResponse, include_in_schema=False) @router.get(
"/test/vendors-users-migration",
response_class=HTMLResponse,
include_in_schema=False,
)
async def admin_test_vendors_users_migration( async def admin_test_vendors_users_migration(
request: Request, request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header), current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render vendors and users migration testing page. Render vendors and users migration testing page.
@@ -491,7 +515,7 @@ async def admin_test_vendors_users_migration(
{ {
"request": request, "request": request,
"user": current_user, "user": current_user,
} },
) )
@@ -499,11 +523,12 @@ async def admin_test_vendors_users_migration(
# CODE QUALITY & ARCHITECTURE ROUTES # CODE QUALITY & ARCHITECTURE ROUTES
# ============================================================================ # ============================================================================
@router.get("/code-quality", response_class=HTMLResponse, include_in_schema=False) @router.get("/code-quality", response_class=HTMLResponse, include_in_schema=False)
async def admin_code_quality_dashboard( async def admin_code_quality_dashboard(
request: Request, request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header), current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render code quality dashboard. Render code quality dashboard.
@@ -514,15 +539,17 @@ async def admin_code_quality_dashboard(
{ {
"request": request, "request": request,
"user": current_user, "user": current_user,
} },
) )
@router.get("/code-quality/violations", response_class=HTMLResponse, include_in_schema=False) @router.get(
"/code-quality/violations", response_class=HTMLResponse, include_in_schema=False
)
async def admin_code_quality_violations( async def admin_code_quality_violations(
request: Request, request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header), current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render violations list page. Render violations list page.
@@ -533,16 +560,20 @@ async def admin_code_quality_violations(
{ {
"request": request, "request": request,
"user": current_user, "user": current_user,
} },
) )
@router.get("/code-quality/violations/{violation_id}", response_class=HTMLResponse, include_in_schema=False) @router.get(
"/code-quality/violations/{violation_id}",
response_class=HTMLResponse,
include_in_schema=False,
)
async def admin_code_quality_violation_detail( async def admin_code_quality_violation_detail(
request: Request, request: Request,
violation_id: int = Path(..., description="Violation ID"), violation_id: int = Path(..., description="Violation ID"),
current_user: User = Depends(get_current_admin_from_cookie_or_header), current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render violation detail page. Render violation detail page.
@@ -554,5 +585,5 @@ async def admin_code_quality_violation_detail(
"request": request, "request": request,
"user": current_user, "user": current_user,
"violation_id": violation_id, "violation_id": violation_id,
} },
) )

View File

@@ -31,7 +31,8 @@ Routes (all mounted at /shop/* or /vendors/{code}/shop/* prefix):
""" """
import logging import logging
from fastapi import APIRouter, Request, Depends, Path
from fastapi import APIRouter, Depends, Path, Request
from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -50,6 +51,7 @@ logger = logging.getLogger(__name__)
# HELPER: Build Shop Template Context # HELPER: Build Shop Template Context
# ============================================================================ # ============================================================================
def get_shop_context(request: Request, db: Session = None, **extra_context) -> dict: def get_shop_context(request: Request, db: Session = None, **extra_context) -> dict:
""" """
Build template context for shop pages. Build template context for shop pages.
@@ -76,13 +78,17 @@ def get_shop_context(request: Request, db: Session = None, **extra_context) -> d
get_shop_context(request, db=db, user=current_user, product_id=123) get_shop_context(request, db=db, user=current_user, product_id=123)
""" """
# Extract from middleware state # Extract from middleware state
vendor = getattr(request.state, 'vendor', None) vendor = getattr(request.state, "vendor", None)
theme = getattr(request.state, 'theme', None) theme = getattr(request.state, "theme", None)
clean_path = getattr(request.state, 'clean_path', request.url.path) clean_path = getattr(request.state, "clean_path", request.url.path)
vendor_context = getattr(request.state, 'vendor_context', None) vendor_context = getattr(request.state, "vendor_context", None)
# Get detection method from vendor_context # Get detection method from vendor_context
access_method = vendor_context.get('detection_method', 'unknown') if vendor_context else 'unknown' access_method = (
vendor_context.get("detection_method", "unknown")
if vendor_context
else "unknown"
)
if vendor is None: if vendor is None:
logger.warning( logger.warning(
@@ -91,7 +97,7 @@ def get_shop_context(request: Request, db: Session = None, **extra_context) -> d
"path": request.url.path, "path": request.url.path,
"host": request.headers.get("host", ""), "host": request.headers.get("host", ""),
"has_vendor": False, "has_vendor": False,
} },
) )
# Calculate base URL for links # Calculate base URL for links
@@ -100,7 +106,11 @@ def get_shop_context(request: Request, db: Session = None, **extra_context) -> d
base_url = "/" base_url = "/"
if access_method == "path" and vendor: if access_method == "path" and vendor:
# Use the full_prefix from vendor_context to determine which pattern was used # Use the full_prefix from vendor_context to determine which pattern was used
full_prefix = vendor_context.get('full_prefix', '/vendor/') if vendor_context else '/vendor/' full_prefix = (
vendor_context.get("full_prefix", "/vendor/")
if vendor_context
else "/vendor/"
)
base_url = f"{full_prefix}{vendor.subdomain}/" base_url = f"{full_prefix}{vendor.subdomain}/"
# Load footer navigation pages from CMS if db session provided # Load footer navigation pages from CMS if db session provided
@@ -111,22 +121,16 @@ def get_shop_context(request: Request, db: Session = None, **extra_context) -> d
vendor_id = vendor.id vendor_id = vendor.id
# Get pages configured to show in footer # Get pages configured to show in footer
footer_pages = content_page_service.list_pages_for_vendor( footer_pages = content_page_service.list_pages_for_vendor(
db, db, vendor_id=vendor_id, footer_only=True, include_unpublished=False
vendor_id=vendor_id,
footer_only=True,
include_unpublished=False
) )
# Get pages configured to show in header # Get pages configured to show in header
header_pages = content_page_service.list_pages_for_vendor( header_pages = content_page_service.list_pages_for_vendor(
db, db, vendor_id=vendor_id, header_only=True, include_unpublished=False
vendor_id=vendor_id,
header_only=True,
include_unpublished=False
) )
except Exception as e: except Exception as e:
logger.error( logger.error(
f"[SHOP_CONTEXT] Failed to load navigation pages", f"[SHOP_CONTEXT] Failed to load navigation pages",
extra={"error": str(e), "vendor_id": vendor.id if vendor else None} extra={"error": str(e), "vendor_id": vendor.id if vendor else None},
) )
context = { context = {
@@ -156,7 +160,7 @@ def get_shop_context(request: Request, db: Session = None, **extra_context) -> d
"footer_pages_count": len(footer_pages), "footer_pages_count": len(footer_pages),
"header_pages_count": len(header_pages), "header_pages_count": len(header_pages),
"extra_keys": list(extra_context.keys()) if extra_context else [], "extra_keys": list(extra_context.keys()) if extra_context else [],
} },
) )
return context return context
@@ -166,6 +170,7 @@ def get_shop_context(request: Request, db: Session = None, **extra_context) -> d
# PUBLIC SHOP ROUTES (No Authentication Required) # PUBLIC SHOP ROUTES (No Authentication Required)
# ============================================================================ # ============================================================================
@router.get("/", response_class=HTMLResponse, include_in_schema=False) @router.get("/", response_class=HTMLResponse, include_in_schema=False)
@router.get("/products", response_class=HTMLResponse, include_in_schema=False) @router.get("/products", response_class=HTMLResponse, include_in_schema=False)
async def shop_products_page(request: Request, db: Session = Depends(get_db)): async def shop_products_page(request: Request, db: Session = Depends(get_db)):
@@ -177,21 +182,21 @@ async def shop_products_page(request: Request, db: Session = Depends(get_db)):
f"[SHOP_HANDLER] shop_products_page REACHED", f"[SHOP_HANDLER] shop_products_page REACHED",
extra={ extra={
"path": request.url.path, "path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'), "vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, 'context_type', 'NOT SET'), "context": getattr(request.state, "context_type", "NOT SET"),
} },
) )
return templates.TemplateResponse( return templates.TemplateResponse(
"shop/products.html", "shop/products.html", get_shop_context(request, db=db)
get_shop_context(request, db=db)
) )
@router.get("/products/{product_id}", response_class=HTMLResponse, include_in_schema=False) @router.get(
"/products/{product_id}", response_class=HTMLResponse, include_in_schema=False
)
async def shop_product_detail_page( async def shop_product_detail_page(
request: Request, request: Request, product_id: int = Path(..., description="Product ID")
product_id: int = Path(..., description="Product ID")
): ):
""" """
Render product detail page. Render product detail page.
@@ -201,21 +206,21 @@ async def shop_product_detail_page(
f"[SHOP_HANDLER] shop_products_page REACHED", f"[SHOP_HANDLER] shop_products_page REACHED",
extra={ extra={
"path": request.url.path, "path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'), "vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, 'context_type', 'NOT SET'), "context": getattr(request.state, "context_type", "NOT SET"),
} },
) )
return templates.TemplateResponse( return templates.TemplateResponse(
"shop/product.html", "shop/product.html", get_shop_context(request, product_id=product_id)
get_shop_context(request, product_id=product_id)
) )
@router.get("/categories/{category_slug}", response_class=HTMLResponse, include_in_schema=False) @router.get(
"/categories/{category_slug}", response_class=HTMLResponse, include_in_schema=False
)
async def shop_category_page( async def shop_category_page(
request: Request, request: Request, category_slug: str = Path(..., description="Category slug")
category_slug: str = Path(..., description="Category slug")
): ):
""" """
Render category products page. Render category products page.
@@ -225,14 +230,13 @@ async def shop_category_page(
f"[SHOP_HANDLER] shop_products_page REACHED", f"[SHOP_HANDLER] shop_products_page REACHED",
extra={ extra={
"path": request.url.path, "path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'), "vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, 'context_type', 'NOT SET'), "context": getattr(request.state, "context_type", "NOT SET"),
} },
) )
return templates.TemplateResponse( return templates.TemplateResponse(
"shop/category.html", "shop/category.html", get_shop_context(request, category_slug=category_slug)
get_shop_context(request, category_slug=category_slug)
) )
@@ -246,15 +250,12 @@ async def shop_cart_page(request: Request):
f"[SHOP_HANDLER] shop_products_page REACHED", f"[SHOP_HANDLER] shop_products_page REACHED",
extra={ extra={
"path": request.url.path, "path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'), "vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, 'context_type', 'NOT SET'), "context": getattr(request.state, "context_type", "NOT SET"),
} },
) )
return templates.TemplateResponse( return templates.TemplateResponse("shop/cart.html", get_shop_context(request))
"shop/cart.html",
get_shop_context(request)
)
@router.get("/checkout", response_class=HTMLResponse, include_in_schema=False) @router.get("/checkout", response_class=HTMLResponse, include_in_schema=False)
@@ -267,15 +268,12 @@ async def shop_checkout_page(request: Request):
f"[SHOP_HANDLER] shop_products_page REACHED", f"[SHOP_HANDLER] shop_products_page REACHED",
extra={ extra={
"path": request.url.path, "path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'), "vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, 'context_type', 'NOT SET'), "context": getattr(request.state, "context_type", "NOT SET"),
} },
) )
return templates.TemplateResponse( return templates.TemplateResponse("shop/checkout.html", get_shop_context(request))
"shop/checkout.html",
get_shop_context(request)
)
@router.get("/search", response_class=HTMLResponse, include_in_schema=False) @router.get("/search", response_class=HTMLResponse, include_in_schema=False)
@@ -288,21 +286,19 @@ async def shop_search_page(request: Request):
f"[SHOP_HANDLER] shop_products_page REACHED", f"[SHOP_HANDLER] shop_products_page REACHED",
extra={ extra={
"path": request.url.path, "path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'), "vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, 'context_type', 'NOT SET'), "context": getattr(request.state, "context_type", "NOT SET"),
} },
) )
return templates.TemplateResponse( return templates.TemplateResponse("shop/search.html", get_shop_context(request))
"shop/search.html",
get_shop_context(request)
)
# ============================================================================ # ============================================================================
# CUSTOMER ACCOUNT - PUBLIC ROUTES (No Authentication) # CUSTOMER ACCOUNT - PUBLIC ROUTES (No Authentication)
# ============================================================================ # ============================================================================
@router.get("/account/register", response_class=HTMLResponse, include_in_schema=False) @router.get("/account/register", response_class=HTMLResponse, include_in_schema=False)
async def shop_register_page(request: Request): async def shop_register_page(request: Request):
""" """
@@ -313,14 +309,13 @@ async def shop_register_page(request: Request):
f"[SHOP_HANDLER] shop_products_page REACHED", f"[SHOP_HANDLER] shop_products_page REACHED",
extra={ extra={
"path": request.url.path, "path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'), "vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, 'context_type', 'NOT SET'), "context": getattr(request.state, "context_type", "NOT SET"),
} },
) )
return templates.TemplateResponse( return templates.TemplateResponse(
"shop/account/register.html", "shop/account/register.html", get_shop_context(request)
get_shop_context(request)
) )
@@ -334,18 +329,19 @@ async def shop_login_page(request: Request):
f"[SHOP_HANDLER] shop_products_page REACHED", f"[SHOP_HANDLER] shop_products_page REACHED",
extra={ extra={
"path": request.url.path, "path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'), "vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, 'context_type', 'NOT SET'), "context": getattr(request.state, "context_type", "NOT SET"),
} },
) )
return templates.TemplateResponse( return templates.TemplateResponse(
"shop/account/login.html", "shop/account/login.html", get_shop_context(request)
get_shop_context(request)
) )
@router.get("/account/forgot-password", response_class=HTMLResponse, include_in_schema=False) @router.get(
"/account/forgot-password", response_class=HTMLResponse, include_in_schema=False
)
async def shop_forgot_password_page(request: Request): async def shop_forgot_password_page(request: Request):
""" """
Render forgot password page. Render forgot password page.
@@ -355,14 +351,13 @@ async def shop_forgot_password_page(request: Request):
f"[SHOP_HANDLER] shop_products_page REACHED", f"[SHOP_HANDLER] shop_products_page REACHED",
extra={ extra={
"path": request.url.path, "path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'), "vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, 'context_type', 'NOT SET'), "context": getattr(request.state, "context_type", "NOT SET"),
} },
) )
return templates.TemplateResponse( return templates.TemplateResponse(
"shop/account/forgot-password.html", "shop/account/forgot-password.html", get_shop_context(request)
get_shop_context(request)
) )
@@ -370,6 +365,7 @@ async def shop_forgot_password_page(request: Request):
# CUSTOMER ACCOUNT - AUTHENTICATED ROUTES # CUSTOMER ACCOUNT - AUTHENTICATED ROUTES
# ============================================================================ # ============================================================================
@router.get("/account", response_class=RedirectResponse, include_in_schema=False) @router.get("/account", response_class=RedirectResponse, include_in_schema=False)
@router.get("/account/", response_class=RedirectResponse, include_in_schema=False) @router.get("/account/", response_class=RedirectResponse, include_in_schema=False)
async def shop_account_root(request: Request): async def shop_account_root(request: Request):
@@ -380,19 +376,27 @@ async def shop_account_root(request: Request):
f"[SHOP_HANDLER] shop_products_page REACHED", f"[SHOP_HANDLER] shop_products_page REACHED",
extra={ extra={
"path": request.url.path, "path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'), "vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, 'context_type', 'NOT SET'), "context": getattr(request.state, "context_type", "NOT SET"),
} },
) )
# Get base_url from context for proper redirect # Get base_url from context for proper redirect
vendor = getattr(request.state, 'vendor', None) vendor = getattr(request.state, "vendor", None)
vendor_context = getattr(request.state, 'vendor_context', None) vendor_context = getattr(request.state, "vendor_context", None)
access_method = vendor_context.get('detection_method', 'unknown') if vendor_context else 'unknown' access_method = (
vendor_context.get("detection_method", "unknown")
if vendor_context
else "unknown"
)
base_url = "/" base_url = "/"
if access_method == "path" and vendor: if access_method == "path" and vendor:
full_prefix = vendor_context.get('full_prefix', '/vendor/') if vendor_context else '/vendor/' full_prefix = (
vendor_context.get("full_prefix", "/vendor/")
if vendor_context
else "/vendor/"
)
base_url = f"{full_prefix}{vendor.subdomain}/" base_url = f"{full_prefix}{vendor.subdomain}/"
return RedirectResponse(url=f"{base_url}shop/account/dashboard", status_code=302) return RedirectResponse(url=f"{base_url}shop/account/dashboard", status_code=302)
@@ -402,7 +406,7 @@ async def shop_account_root(request: Request):
async def shop_account_dashboard_page( async def shop_account_dashboard_page(
request: Request, request: Request,
current_customer: Customer = Depends(get_current_customer_from_cookie_or_header), current_customer: Customer = Depends(get_current_customer_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render customer account dashboard. Render customer account dashboard.
@@ -413,14 +417,13 @@ async def shop_account_dashboard_page(
f"[SHOP_HANDLER] shop_products_page REACHED", f"[SHOP_HANDLER] shop_products_page REACHED",
extra={ extra={
"path": request.url.path, "path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'), "vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, 'context_type', 'NOT SET'), "context": getattr(request.state, "context_type", "NOT SET"),
} },
) )
return templates.TemplateResponse( return templates.TemplateResponse(
"shop/account/dashboard.html", "shop/account/dashboard.html", get_shop_context(request, user=current_customer)
get_shop_context(request, user=current_customer)
) )
@@ -428,7 +431,7 @@ async def shop_account_dashboard_page(
async def shop_orders_page( async def shop_orders_page(
request: Request, request: Request,
current_customer: Customer = Depends(get_current_customer_from_cookie_or_header), current_customer: Customer = Depends(get_current_customer_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render customer orders history page. Render customer orders history page.
@@ -439,23 +442,24 @@ async def shop_orders_page(
f"[SHOP_HANDLER] shop_products_page REACHED", f"[SHOP_HANDLER] shop_products_page REACHED",
extra={ extra={
"path": request.url.path, "path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'), "vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, 'context_type', 'NOT SET'), "context": getattr(request.state, "context_type", "NOT SET"),
} },
) )
return templates.TemplateResponse( return templates.TemplateResponse(
"shop/account/orders.html", "shop/account/orders.html", get_shop_context(request, user=current_customer)
get_shop_context(request, user=current_customer)
) )
@router.get("/account/orders/{order_id}", response_class=HTMLResponse, include_in_schema=False) @router.get(
"/account/orders/{order_id}", response_class=HTMLResponse, include_in_schema=False
)
async def shop_order_detail_page( async def shop_order_detail_page(
request: Request, request: Request,
order_id: int = Path(..., description="Order ID"), order_id: int = Path(..., description="Order ID"),
current_customer: Customer = Depends(get_current_customer_from_cookie_or_header), current_customer: Customer = Depends(get_current_customer_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render customer order detail page. Render customer order detail page.
@@ -466,14 +470,14 @@ async def shop_order_detail_page(
f"[SHOP_HANDLER] shop_products_page REACHED", f"[SHOP_HANDLER] shop_products_page REACHED",
extra={ extra={
"path": request.url.path, "path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'), "vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, 'context_type', 'NOT SET'), "context": getattr(request.state, "context_type", "NOT SET"),
} },
) )
return templates.TemplateResponse( return templates.TemplateResponse(
"shop/account/order-detail.html", "shop/account/order-detail.html",
get_shop_context(request, user=current_customer, order_id=order_id) get_shop_context(request, user=current_customer, order_id=order_id),
) )
@@ -481,7 +485,7 @@ async def shop_order_detail_page(
async def shop_profile_page( async def shop_profile_page(
request: Request, request: Request,
current_customer: Customer = Depends(get_current_customer_from_cookie_or_header), current_customer: Customer = Depends(get_current_customer_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render customer profile page. Render customer profile page.
@@ -492,14 +496,13 @@ async def shop_profile_page(
f"[SHOP_HANDLER] shop_products_page REACHED", f"[SHOP_HANDLER] shop_products_page REACHED",
extra={ extra={
"path": request.url.path, "path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'), "vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, 'context_type', 'NOT SET'), "context": getattr(request.state, "context_type", "NOT SET"),
} },
) )
return templates.TemplateResponse( return templates.TemplateResponse(
"shop/account/profile.html", "shop/account/profile.html", get_shop_context(request, user=current_customer)
get_shop_context(request, user=current_customer)
) )
@@ -507,7 +510,7 @@ async def shop_profile_page(
async def shop_addresses_page( async def shop_addresses_page(
request: Request, request: Request,
current_customer: Customer = Depends(get_current_customer_from_cookie_or_header), current_customer: Customer = Depends(get_current_customer_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render customer addresses management page. Render customer addresses management page.
@@ -518,14 +521,13 @@ async def shop_addresses_page(
f"[SHOP_HANDLER] shop_products_page REACHED", f"[SHOP_HANDLER] shop_products_page REACHED",
extra={ extra={
"path": request.url.path, "path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'), "vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, 'context_type', 'NOT SET'), "context": getattr(request.state, "context_type", "NOT SET"),
} },
) )
return templates.TemplateResponse( return templates.TemplateResponse(
"shop/account/addresses.html", "shop/account/addresses.html", get_shop_context(request, user=current_customer)
get_shop_context(request, user=current_customer)
) )
@@ -533,7 +535,7 @@ async def shop_addresses_page(
async def shop_wishlist_page( async def shop_wishlist_page(
request: Request, request: Request,
current_customer: Customer = Depends(get_current_customer_from_cookie_or_header), current_customer: Customer = Depends(get_current_customer_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render customer wishlist page. Render customer wishlist page.
@@ -544,14 +546,13 @@ async def shop_wishlist_page(
f"[SHOP_HANDLER] shop_products_page REACHED", f"[SHOP_HANDLER] shop_products_page REACHED",
extra={ extra={
"path": request.url.path, "path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'), "vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, 'context_type', 'NOT SET'), "context": getattr(request.state, "context_type", "NOT SET"),
} },
) )
return templates.TemplateResponse( return templates.TemplateResponse(
"shop/account/wishlist.html", "shop/account/wishlist.html", get_shop_context(request, user=current_customer)
get_shop_context(request, user=current_customer)
) )
@@ -559,7 +560,7 @@ async def shop_wishlist_page(
async def shop_settings_page( async def shop_settings_page(
request: Request, request: Request,
current_customer: Customer = Depends(get_current_customer_from_cookie_or_header), current_customer: Customer = Depends(get_current_customer_from_cookie_or_header),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Render customer account settings page. Render customer account settings page.
@@ -570,14 +571,13 @@ async def shop_settings_page(
f"[SHOP_HANDLER] shop_products_page REACHED", f"[SHOP_HANDLER] shop_products_page REACHED",
extra={ extra={
"path": request.url.path, "path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'), "vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, 'context_type', 'NOT SET'), "context": getattr(request.state, "context_type", "NOT SET"),
} },
) )
return templates.TemplateResponse( return templates.TemplateResponse(
"shop/account/settings.html", "shop/account/settings.html", get_shop_context(request, user=current_customer)
get_shop_context(request, user=current_customer)
) )
@@ -585,11 +585,12 @@ async def shop_settings_page(
# DYNAMIC CONTENT PAGES (CMS) # DYNAMIC CONTENT PAGES (CMS)
# ============================================================================ # ============================================================================
@router.get("/{slug}", response_class=HTMLResponse, include_in_schema=False) @router.get("/{slug}", response_class=HTMLResponse, include_in_schema=False)
async def generic_content_page( async def generic_content_page(
request: Request, request: Request,
slug: str = Path(..., description="Content page slug"), slug: str = Path(..., description="Content page slug"),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Generic content page handler (CMS). Generic content page handler (CMS).
@@ -612,20 +613,17 @@ async def generic_content_page(
extra={ extra={
"path": request.url.path, "path": request.url.path,
"slug": slug, "slug": slug,
"vendor": getattr(request.state, 'vendor', 'NOT SET'), "vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, 'context_type', 'NOT SET'), "context": getattr(request.state, "context_type", "NOT SET"),
} },
) )
vendor = getattr(request.state, 'vendor', None) vendor = getattr(request.state, "vendor", None)
vendor_id = vendor.id if vendor else None vendor_id = vendor.id if vendor else None
# Load content page from database (vendor override → platform default) # Load content page from database (vendor override → platform default)
page = content_page_service.get_page_for_vendor( page = content_page_service.get_page_for_vendor(
db, db, slug=slug, vendor_id=vendor_id, include_unpublished=False
slug=slug,
vendor_id=vendor_id,
include_unpublished=False
) )
if not page: if not page:
@@ -635,7 +633,7 @@ async def generic_content_page(
"slug": slug, "slug": slug,
"vendor_id": vendor_id, "vendor_id": vendor_id,
"vendor_name": vendor.name if vendor else None, "vendor_name": vendor.name if vendor else None,
} },
) )
raise HTTPException(status_code=404, detail=f"Page not found: {slug}") raise HTTPException(status_code=404, detail=f"Page not found: {slug}")
@@ -647,12 +645,11 @@ async def generic_content_page(
"page_title": page.title, "page_title": page.title,
"is_vendor_override": page.vendor_id is not None, "is_vendor_override": page.vendor_id is not None,
"vendor_id": vendor_id, "vendor_id": vendor_id,
} },
) )
return templates.TemplateResponse( return templates.TemplateResponse(
"shop/content-page.html", "shop/content-page.html", get_shop_context(request, page=page)
get_shop_context(request, page=page)
) )
@@ -660,6 +657,7 @@ async def generic_content_page(
# DEBUG ENDPOINTS - For troubleshooting context issues # DEBUG ENDPOINTS - For troubleshooting context issues
# ============================================================================ # ============================================================================
@router.get("/debug/context", response_class=HTMLResponse, include_in_schema=False) @router.get("/debug/context", response_class=HTMLResponse, include_in_schema=False)
async def debug_context(request: Request): async def debug_context(request: Request):
""" """
@@ -670,8 +668,8 @@ async def debug_context(request: Request):
URL: /shop/debug/context URL: /shop/debug/context
""" """
vendor = getattr(request.state, 'vendor', None) vendor = getattr(request.state, "vendor", None)
theme = getattr(request.state, 'theme', None) theme = getattr(request.state, "theme", None)
debug_info = { debug_info = {
"path": request.url.path, "path": request.url.path,
@@ -687,12 +685,13 @@ async def debug_context(request: Request):
"found": theme is not None, "found": theme is not None,
"name": theme.get("theme_name") if theme else None, "name": theme.get("theme_name") if theme else None,
}, },
"clean_path": getattr(request.state, 'clean_path', 'NOT SET'), "clean_path": getattr(request.state, "clean_path", "NOT SET"),
"context_type": str(getattr(request.state, 'context_type', 'NOT SET')), "context_type": str(getattr(request.state, "context_type", "NOT SET")),
} }
# Return as JSON-like HTML for easy reading # Return as JSON-like HTML for easy reading
import json import json
html_content = f""" html_content = f"""
<!DOCTYPE html> <!DOCTYPE html>
<html> <html>

View File

@@ -21,18 +21,16 @@ Routes:
- GET /vendor/{vendor_code}/settings → Vendor settings - GET /vendor/{vendor_code}/settings → Vendor settings
""" """
from fastapi import APIRouter, Request, Depends, Path, HTTPException import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Path, Request
from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing import Optional
import logging
from app.api.deps import ( from app.api.deps import (get_current_vendor_from_cookie_or_header,
get_current_vendor_from_cookie_or_header, get_current_vendor_optional, get_db)
get_current_vendor_optional,
get_db
)
from app.services.content_page_service import content_page_service from app.services.content_page_service import content_page_service
from models.database.user import User from models.database.user import User
@@ -46,6 +44,7 @@ templates = Jinja2Templates(directory="app/templates")
# PUBLIC ROUTES (No Authentication Required) # PUBLIC ROUTES (No Authentication Required)
# ============================================================================ # ============================================================================
@router.get("/{vendor_code}", response_class=RedirectResponse, include_in_schema=False) @router.get("/{vendor_code}", response_class=RedirectResponse, include_in_schema=False)
async def vendor_root_no_slash(vendor_code: str = Path(..., description="Vendor code")): async def vendor_root_no_slash(vendor_code: str = Path(..., description="Vendor code")):
""" """
@@ -58,7 +57,7 @@ async def vendor_root_no_slash(vendor_code: str = Path(..., description="Vendor
@router.get("/{vendor_code}/", response_class=RedirectResponse, include_in_schema=False) @router.get("/{vendor_code}/", response_class=RedirectResponse, include_in_schema=False)
async def vendor_root( async def vendor_root(
vendor_code: str = Path(..., description="Vendor code"), vendor_code: str = Path(..., description="Vendor code"),
current_user: Optional[User] = Depends(get_current_vendor_optional) current_user: Optional[User] = Depends(get_current_vendor_optional),
): ):
""" """
Redirect /vendor/{code}/ based on authentication status. Redirect /vendor/{code}/ based on authentication status.
@@ -73,11 +72,13 @@ async def vendor_root(
return RedirectResponse(url=f"/vendor/{vendor_code}/login", status_code=302) return RedirectResponse(url=f"/vendor/{vendor_code}/login", status_code=302)
@router.get("/{vendor_code}/login", response_class=HTMLResponse, include_in_schema=False) @router.get(
"/{vendor_code}/login", response_class=HTMLResponse, include_in_schema=False
)
async def vendor_login_page( async def vendor_login_page(
request: Request, request: Request,
vendor_code: str = Path(..., description="Vendor code"), vendor_code: str = Path(..., description="Vendor code"),
current_user: Optional[User] = Depends(get_current_vendor_optional) current_user: Optional[User] = Depends(get_current_vendor_optional),
): ):
""" """
Render vendor login page. Render vendor login page.
@@ -99,7 +100,7 @@ async def vendor_login_page(
{ {
"request": request, "request": request,
"vendor_code": vendor_code, "vendor_code": vendor_code,
} },
) )
@@ -107,11 +108,14 @@ async def vendor_login_page(
# AUTHENTICATED ROUTES (Vendor Users Only) # AUTHENTICATED ROUTES (Vendor Users Only)
# ============================================================================ # ============================================================================
@router.get("/{vendor_code}/dashboard", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/{vendor_code}/dashboard", response_class=HTMLResponse, include_in_schema=False
)
async def vendor_dashboard_page( async def vendor_dashboard_page(
request: Request, request: Request,
vendor_code: str = Path(..., description="Vendor code"), vendor_code: str = Path(..., description="Vendor code"),
current_user: User = Depends(get_current_vendor_from_cookie_or_header) current_user: User = Depends(get_current_vendor_from_cookie_or_header),
): ):
""" """
Render vendor dashboard. Render vendor dashboard.
@@ -128,7 +132,7 @@ async def vendor_dashboard_page(
"request": request, "request": request,
"user": current_user, "user": current_user,
"vendor_code": vendor_code, "vendor_code": vendor_code,
} },
) )
@@ -136,11 +140,14 @@ async def vendor_dashboard_page(
# PRODUCT MANAGEMENT # PRODUCT MANAGEMENT
# ============================================================================ # ============================================================================
@router.get("/{vendor_code}/products", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/{vendor_code}/products", response_class=HTMLResponse, include_in_schema=False
)
async def vendor_products_page( async def vendor_products_page(
request: Request, request: Request,
vendor_code: str = Path(..., description="Vendor code"), vendor_code: str = Path(..., description="Vendor code"),
current_user: User = Depends(get_current_vendor_from_cookie_or_header) current_user: User = Depends(get_current_vendor_from_cookie_or_header),
): ):
""" """
Render products management page. Render products management page.
@@ -152,7 +159,7 @@ async def vendor_products_page(
"request": request, "request": request,
"user": current_user, "user": current_user,
"vendor_code": vendor_code, "vendor_code": vendor_code,
} },
) )
@@ -160,11 +167,14 @@ async def vendor_products_page(
# ORDER MANAGEMENT # ORDER MANAGEMENT
# ============================================================================ # ============================================================================
@router.get("/{vendor_code}/orders", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/{vendor_code}/orders", response_class=HTMLResponse, include_in_schema=False
)
async def vendor_orders_page( async def vendor_orders_page(
request: Request, request: Request,
vendor_code: str = Path(..., description="Vendor code"), vendor_code: str = Path(..., description="Vendor code"),
current_user: User = Depends(get_current_vendor_from_cookie_or_header) current_user: User = Depends(get_current_vendor_from_cookie_or_header),
): ):
""" """
Render orders management page. Render orders management page.
@@ -176,7 +186,7 @@ async def vendor_orders_page(
"request": request, "request": request,
"user": current_user, "user": current_user,
"vendor_code": vendor_code, "vendor_code": vendor_code,
} },
) )
@@ -184,11 +194,14 @@ async def vendor_orders_page(
# CUSTOMER MANAGEMENT # CUSTOMER MANAGEMENT
# ============================================================================ # ============================================================================
@router.get("/{vendor_code}/customers", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/{vendor_code}/customers", response_class=HTMLResponse, include_in_schema=False
)
async def vendor_customers_page( async def vendor_customers_page(
request: Request, request: Request,
vendor_code: str = Path(..., description="Vendor code"), vendor_code: str = Path(..., description="Vendor code"),
current_user: User = Depends(get_current_vendor_from_cookie_or_header) current_user: User = Depends(get_current_vendor_from_cookie_or_header),
): ):
""" """
Render customers management page. Render customers management page.
@@ -200,7 +213,7 @@ async def vendor_customers_page(
"request": request, "request": request,
"user": current_user, "user": current_user,
"vendor_code": vendor_code, "vendor_code": vendor_code,
} },
) )
@@ -208,11 +221,14 @@ async def vendor_customers_page(
# INVENTORY MANAGEMENT # INVENTORY MANAGEMENT
# ============================================================================ # ============================================================================
@router.get("/{vendor_code}/inventory", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/{vendor_code}/inventory", response_class=HTMLResponse, include_in_schema=False
)
async def vendor_inventory_page( async def vendor_inventory_page(
request: Request, request: Request,
vendor_code: str = Path(..., description="Vendor code"), vendor_code: str = Path(..., description="Vendor code"),
current_user: User = Depends(get_current_vendor_from_cookie_or_header) current_user: User = Depends(get_current_vendor_from_cookie_or_header),
): ):
""" """
Render inventory management page. Render inventory management page.
@@ -224,7 +240,7 @@ async def vendor_inventory_page(
"request": request, "request": request,
"user": current_user, "user": current_user,
"vendor_code": vendor_code, "vendor_code": vendor_code,
} },
) )
@@ -232,11 +248,14 @@ async def vendor_inventory_page(
# MARKETPLACE IMPORTS # MARKETPLACE IMPORTS
# ============================================================================ # ============================================================================
@router.get("/{vendor_code}/marketplace", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/{vendor_code}/marketplace", response_class=HTMLResponse, include_in_schema=False
)
async def vendor_marketplace_page( async def vendor_marketplace_page(
request: Request, request: Request,
vendor_code: str = Path(..., description="Vendor code"), vendor_code: str = Path(..., description="Vendor code"),
current_user: User = Depends(get_current_vendor_from_cookie_or_header) current_user: User = Depends(get_current_vendor_from_cookie_or_header),
): ):
""" """
Render marketplace import page. Render marketplace import page.
@@ -248,7 +267,7 @@ async def vendor_marketplace_page(
"request": request, "request": request,
"user": current_user, "user": current_user,
"vendor_code": vendor_code, "vendor_code": vendor_code,
} },
) )
@@ -256,11 +275,12 @@ async def vendor_marketplace_page(
# TEAM MANAGEMENT # TEAM MANAGEMENT
# ============================================================================ # ============================================================================
@router.get("/{vendor_code}/team", response_class=HTMLResponse, include_in_schema=False) @router.get("/{vendor_code}/team", response_class=HTMLResponse, include_in_schema=False)
async def vendor_team_page( async def vendor_team_page(
request: Request, request: Request,
vendor_code: str = Path(..., description="Vendor code"), vendor_code: str = Path(..., description="Vendor code"),
current_user: User = Depends(get_current_vendor_from_cookie_or_header) current_user: User = Depends(get_current_vendor_from_cookie_or_header),
): ):
""" """
Render team management page. Render team management page.
@@ -272,7 +292,7 @@ async def vendor_team_page(
"request": request, "request": request,
"user": current_user, "user": current_user,
"vendor_code": vendor_code, "vendor_code": vendor_code,
} },
) )
@@ -280,11 +300,14 @@ async def vendor_team_page(
# PROFILE & SETTINGS # PROFILE & SETTINGS
# ============================================================================ # ============================================================================
@router.get("/{vendor_code}/profile", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/{vendor_code}/profile", response_class=HTMLResponse, include_in_schema=False
)
async def vendor_profile_page( async def vendor_profile_page(
request: Request, request: Request,
vendor_code: str = Path(..., description="Vendor code"), vendor_code: str = Path(..., description="Vendor code"),
current_user: User = Depends(get_current_vendor_from_cookie_or_header) current_user: User = Depends(get_current_vendor_from_cookie_or_header),
): ):
""" """
Render vendor profile page. Render vendor profile page.
@@ -296,15 +319,17 @@ async def vendor_profile_page(
"request": request, "request": request,
"user": current_user, "user": current_user,
"vendor_code": vendor_code, "vendor_code": vendor_code,
} },
) )
@router.get("/{vendor_code}/settings", response_class=HTMLResponse, include_in_schema=False) @router.get(
"/{vendor_code}/settings", response_class=HTMLResponse, include_in_schema=False
)
async def vendor_settings_page( async def vendor_settings_page(
request: Request, request: Request,
vendor_code: str = Path(..., description="Vendor code"), vendor_code: str = Path(..., description="Vendor code"),
current_user: User = Depends(get_current_vendor_from_cookie_or_header) current_user: User = Depends(get_current_vendor_from_cookie_or_header),
): ):
""" """
Render vendor settings page. Render vendor settings page.
@@ -316,7 +341,7 @@ async def vendor_settings_page(
"request": request, "request": request,
"user": current_user, "user": current_user,
"vendor_code": vendor_code, "vendor_code": vendor_code,
} },
) )
@@ -324,12 +349,15 @@ async def vendor_settings_page(
# DYNAMIC CONTENT PAGES (CMS) # DYNAMIC CONTENT PAGES (CMS)
# ============================================================================ # ============================================================================
@router.get("/{vendor_code}/{slug}", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/{vendor_code}/{slug}", response_class=HTMLResponse, include_in_schema=False
)
async def vendor_content_page( async def vendor_content_page(
request: Request, request: Request,
vendor_code: str = Path(..., description="Vendor code"), vendor_code: str = Path(..., description="Vendor code"),
slug: str = Path(..., description="Content page slug"), slug: str = Path(..., description="Content page slug"),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
Generic content page handler for vendor shop (CMS). Generic content page handler for vendor shop (CMS).
@@ -351,20 +379,17 @@ async def vendor_content_page(
"path": request.url.path, "path": request.url.path,
"vendor_code": vendor_code, "vendor_code": vendor_code,
"slug": slug, "slug": slug,
"vendor": getattr(request.state, 'vendor', 'NOT SET'), "vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, 'context_type', 'NOT SET'), "context": getattr(request.state, "context_type", "NOT SET"),
} },
) )
vendor = getattr(request.state, 'vendor', None) vendor = getattr(request.state, "vendor", None)
vendor_id = vendor.id if vendor else None vendor_id = vendor.id if vendor else None
# Load content page from database (vendor override → platform default) # Load content page from database (vendor override → platform default)
page = content_page_service.get_page_for_vendor( page = content_page_service.get_page_for_vendor(
db, db, slug=slug, vendor_id=vendor_id, include_unpublished=False
slug=slug,
vendor_id=vendor_id,
include_unpublished=False
) )
if not page: if not page:
@@ -374,7 +399,7 @@ async def vendor_content_page(
"slug": slug, "slug": slug,
"vendor_code": vendor_code, "vendor_code": vendor_code,
"vendor_id": vendor_id, "vendor_id": vendor_id,
} },
) )
raise HTTPException(status_code=404, detail="Page not found") raise HTTPException(status_code=404, detail="Page not found")
@@ -385,7 +410,7 @@ async def vendor_content_page(
"page_id": page.id, "page_id": page.id,
"is_vendor_override": page.vendor_id is not None, "is_vendor_override": page.vendor_id is not None,
"vendor_id": vendor_id, "vendor_id": vendor_id,
} },
) )
return templates.TemplateResponse( return templates.TemplateResponse(
@@ -394,5 +419,5 @@ async def vendor_content_page(
"request": request, "request": request,
"page": page, "page": page,
"vendor_code": vendor_code, "vendor_code": vendor_code,
} },
) )

View File

@@ -10,15 +10,15 @@ This module provides functions for:
import logging import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import List, Optional, Dict, Any from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_ from sqlalchemy import and_, or_
from sqlalchemy.orm import Session
from app.exceptions import AdminOperationException
from models.database.admin import AdminAuditLog from models.database.admin import AdminAuditLog
from models.database.user import User from models.database.user import User
from models.schema.admin import AdminAuditLogFilters, AdminAuditLogResponse from models.schema.admin import AdminAuditLogFilters, AdminAuditLogResponse
from app.exceptions import AdminOperationException
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -36,7 +36,7 @@ class AdminAuditService:
details: Optional[Dict[str, Any]] = None, details: Optional[Dict[str, Any]] = None,
ip_address: Optional[str] = None, ip_address: Optional[str] = None,
user_agent: Optional[str] = None, user_agent: Optional[str] = None,
request_id: Optional[str] = None request_id: Optional[str] = None,
) -> AdminAuditLog: ) -> AdminAuditLog:
""" """
Log an admin action to the audit trail. Log an admin action to the audit trail.
@@ -63,7 +63,7 @@ class AdminAuditService:
details=details or {}, details=details or {},
ip_address=ip_address, ip_address=ip_address,
user_agent=user_agent, user_agent=user_agent,
request_id=request_id request_id=request_id,
) )
db.add(audit_log) db.add(audit_log)
@@ -84,9 +84,7 @@ class AdminAuditService:
return None return None
def get_audit_logs( def get_audit_logs(
self, self, db: Session, filters: AdminAuditLogFilters
db: Session,
filters: AdminAuditLogFilters
) -> List[AdminAuditLogResponse]: ) -> List[AdminAuditLogResponse]:
""" """
Get filtered admin audit logs with pagination. Get filtered admin audit logs with pagination.
@@ -98,7 +96,9 @@ class AdminAuditService:
List of audit log responses List of audit log responses
""" """
try: try:
query = db.query(AdminAuditLog).join(User, AdminAuditLog.admin_user_id == User.id) query = db.query(AdminAuditLog).join(
User, AdminAuditLog.admin_user_id == User.id
)
# Apply filters # Apply filters
conditions = [] conditions = []
@@ -123,8 +123,7 @@ class AdminAuditService:
# Execute query with pagination # Execute query with pagination
logs = ( logs = (
query query.order_by(AdminAuditLog.created_at.desc())
.order_by(AdminAuditLog.created_at.desc())
.offset(filters.skip) .offset(filters.skip)
.limit(filters.limit) .limit(filters.limit)
.all() .all()
@@ -143,7 +142,7 @@ class AdminAuditService:
ip_address=log.ip_address, ip_address=log.ip_address,
user_agent=log.user_agent, user_agent=log.user_agent,
request_id=log.request_id, request_id=log.request_id,
created_at=log.created_at created_at=log.created_at,
) )
for log in logs for log in logs
] ]
@@ -151,15 +150,10 @@ class AdminAuditService:
except Exception as e: except Exception as e:
logger.error(f"Failed to retrieve audit logs: {str(e)}") logger.error(f"Failed to retrieve audit logs: {str(e)}")
raise AdminOperationException( raise AdminOperationException(
operation="get_audit_logs", operation="get_audit_logs", reason="Database query failed"
reason="Database query failed"
) )
def get_audit_logs_count( def get_audit_logs_count(self, db: Session, filters: AdminAuditLogFilters) -> int:
self,
db: Session,
filters: AdminAuditLogFilters
) -> int:
"""Get total count of audit logs matching filters.""" """Get total count of audit logs matching filters."""
try: try:
query = db.query(AdminAuditLog) query = db.query(AdminAuditLog)
@@ -192,24 +186,14 @@ class AdminAuditService:
return 0 return 0
def get_recent_actions_by_admin( def get_recent_actions_by_admin(
self, self, db: Session, admin_user_id: int, limit: int = 10
db: Session,
admin_user_id: int,
limit: int = 10
) -> List[AdminAuditLogResponse]: ) -> List[AdminAuditLogResponse]:
"""Get recent actions by a specific admin.""" """Get recent actions by a specific admin."""
filters = AdminAuditLogFilters( filters = AdminAuditLogFilters(admin_user_id=admin_user_id, limit=limit)
admin_user_id=admin_user_id,
limit=limit
)
return self.get_audit_logs(db, filters) return self.get_audit_logs(db, filters)
def get_actions_by_target( def get_actions_by_target(
self, self, db: Session, target_type: str, target_id: str, limit: int = 50
db: Session,
target_type: str,
target_id: str,
limit: int = 50
) -> List[AdminAuditLogResponse]: ) -> List[AdminAuditLogResponse]:
"""Get all actions performed on a specific target.""" """Get all actions performed on a specific target."""
try: try:
@@ -218,7 +202,7 @@ class AdminAuditService:
.filter( .filter(
and_( and_(
AdminAuditLog.target_type == target_type, AdminAuditLog.target_type == target_type,
AdminAuditLog.target_id == str(target_id) AdminAuditLog.target_id == str(target_id),
) )
) )
.order_by(AdminAuditLog.created_at.desc()) .order_by(AdminAuditLog.created_at.desc())
@@ -236,7 +220,7 @@ class AdminAuditService:
target_id=log.target_id, target_id=log.target_id,
details=log.details, details=log.details,
ip_address=log.ip_address, ip_address=log.ip_address,
created_at=log.created_at created_at=log.created_at,
) )
for log in logs for log in logs
] ]

View File

@@ -16,24 +16,19 @@ import string
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import func, or_ from sqlalchemy import func, or_
from sqlalchemy.orm import Session
from app.exceptions import ( from app.exceptions import (AdminOperationException, CannotModifySelfException,
UserNotFoundException, UserNotFoundException, UserStatusChangeException,
UserStatusChangeException, ValidationException, VendorAlreadyExistsException,
CannotModifySelfException,
VendorNotFoundException, VendorNotFoundException,
VendorAlreadyExistsException, VendorVerificationException)
VendorVerificationException, from models.database.marketplace_import_job import MarketplaceImportJob
AdminOperationException, from models.database.user import User
ValidationException, from models.database.vendor import Role, Vendor, VendorUser
)
from models.schema.marketplace_import_job import MarketplaceImportJobResponse from models.schema.marketplace_import_job import MarketplaceImportJobResponse
from models.schema.vendor import VendorCreate from models.schema.vendor import VendorCreate
from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.vendor import Vendor, Role, VendorUser
from models.database.user import User
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -52,8 +47,7 @@ class AdminService:
except Exception as e: except Exception as e:
logger.error(f"Failed to retrieve users: {str(e)}") logger.error(f"Failed to retrieve users: {str(e)}")
raise AdminOperationException( raise AdminOperationException(
operation="get_all_users", operation="get_all_users", reason="Database query failed"
reason="Database query failed"
) )
def toggle_user_status( def toggle_user_status(
@@ -72,7 +66,7 @@ class AdminService:
user_id=user_id, user_id=user_id,
current_status="admin", current_status="admin",
attempted_action="toggle status", attempted_action="toggle status",
reason="Cannot modify another admin user" reason="Cannot modify another admin user",
) )
try: try:
@@ -95,7 +89,7 @@ class AdminService:
user_id=user_id, user_id=user_id,
current_status="active" if original_status else "inactive", current_status="active" if original_status else "inactive",
attempted_action="toggle status", attempted_action="toggle status",
reason="Database update failed" reason="Database update failed",
) )
# ============================================================================ # ============================================================================
@@ -118,17 +112,23 @@ class AdminService:
""" """
try: try:
# Check if vendor code already exists # Check if vendor code already exists
existing_vendor = db.query(Vendor).filter( existing_vendor = (
db.query(Vendor)
.filter(
func.upper(Vendor.vendor_code) == vendor_data.vendor_code.upper() func.upper(Vendor.vendor_code) == vendor_data.vendor_code.upper()
).first() )
.first()
)
if existing_vendor: if existing_vendor:
raise VendorAlreadyExistsException(vendor_data.vendor_code) raise VendorAlreadyExistsException(vendor_data.vendor_code)
# Check if subdomain already exists # Check if subdomain already exists
existing_subdomain = db.query(Vendor).filter( existing_subdomain = (
func.lower(Vendor.subdomain) == vendor_data.subdomain.lower() db.query(Vendor)
).first() .filter(func.lower(Vendor.subdomain) == vendor_data.subdomain.lower())
.first()
)
if existing_subdomain: if existing_subdomain:
raise ValidationException( raise ValidationException(
@@ -140,15 +140,14 @@ class AdminService:
# Create owner user with owner_email # Create owner user with owner_email
from middleware.auth import AuthManager from middleware.auth import AuthManager
auth_manager = AuthManager() auth_manager = AuthManager()
owner_username = f"{vendor_data.subdomain}_owner" owner_username = f"{vendor_data.subdomain}_owner"
owner_email = vendor_data.owner_email # ✅ For User authentication owner_email = vendor_data.owner_email # ✅ For User authentication
# Check if user with this email already exists # Check if user with this email already exists
existing_user = db.query(User).filter( existing_user = db.query(User).filter(User.email == owner_email).first()
User.email == owner_email
).first()
if existing_user: if existing_user:
# Use existing user as owner # Use existing user as owner
@@ -215,7 +214,7 @@ class AdminService:
logger.error(f"Failed to create vendor: {str(e)}") logger.error(f"Failed to create vendor: {str(e)}")
raise AdminOperationException( raise AdminOperationException(
operation="create_vendor_with_owner", operation="create_vendor_with_owner",
reason=f"Failed to create vendor: {str(e)}" reason=f"Failed to create vendor: {str(e)}",
) )
def get_all_vendors( def get_all_vendors(
@@ -225,7 +224,7 @@ class AdminService:
limit: int = 100, limit: int = 100,
search: Optional[str] = None, search: Optional[str] = None,
is_active: Optional[bool] = None, is_active: Optional[bool] = None,
is_verified: Optional[bool] = None is_verified: Optional[bool] = None,
) -> Tuple[List[Vendor], int]: ) -> Tuple[List[Vendor], int]:
"""Get paginated list of all vendors with filtering.""" """Get paginated list of all vendors with filtering."""
try: try:
@@ -238,7 +237,7 @@ class AdminService:
or_( or_(
Vendor.name.ilike(search_term), Vendor.name.ilike(search_term),
Vendor.vendor_code.ilike(search_term), Vendor.vendor_code.ilike(search_term),
Vendor.subdomain.ilike(search_term) Vendor.subdomain.ilike(search_term),
) )
) )
@@ -255,8 +254,7 @@ class AdminService:
except Exception as e: except Exception as e:
logger.error(f"Failed to retrieve vendors: {str(e)}") logger.error(f"Failed to retrieve vendors: {str(e)}")
raise AdminOperationException( raise AdminOperationException(
operation="get_all_vendors", operation="get_all_vendors", reason="Database query failed"
reason="Database query failed"
) )
def get_vendor_by_id(self, db: Session, vendor_id: int) -> Vendor: def get_vendor_by_id(self, db: Session, vendor_id: int) -> Vendor:
@@ -290,7 +288,7 @@ class AdminService:
raise VendorVerificationException( raise VendorVerificationException(
vendor_id=vendor_id, vendor_id=vendor_id,
reason="Database update failed", reason="Database update failed",
current_verification_status=original_status current_verification_status=original_status,
) )
def toggle_vendor_status(self, db: Session, vendor_id: int) -> Tuple[Vendor, str]: def toggle_vendor_status(self, db: Session, vendor_id: int) -> Tuple[Vendor, str]:
@@ -317,7 +315,7 @@ class AdminService:
operation="toggle_vendor_status", operation="toggle_vendor_status",
reason="Database update failed", reason="Database update failed",
target_type="vendor", target_type="vendor",
target_id=str(vendor_id) target_id=str(vendor_id),
) )
def delete_vendor(self, db: Session, vendor_id: int) -> str: def delete_vendor(self, db: Session, vendor_id: int) -> str:
@@ -345,15 +343,11 @@ class AdminService:
db.rollback() db.rollback()
logger.error(f"Failed to delete vendor {vendor_id}: {str(e)}") logger.error(f"Failed to delete vendor {vendor_id}: {str(e)}")
raise AdminOperationException( raise AdminOperationException(
operation="delete_vendor", operation="delete_vendor", reason="Database deletion failed"
reason="Database deletion failed"
) )
def update_vendor( def update_vendor(
self, self, db: Session, vendor_id: int, vendor_update # VendorUpdate schema
db: Session,
vendor_id: int,
vendor_update # VendorUpdate schema
) -> Vendor: ) -> Vendor:
""" """
Update vendor information (Admin only). Update vendor information (Admin only).
@@ -387,11 +381,18 @@ class AdminService:
update_data = vendor_update.model_dump(exclude_unset=True) update_data = vendor_update.model_dump(exclude_unset=True)
# Check subdomain uniqueness if changing # Check subdomain uniqueness if changing
if 'subdomain' in update_data and update_data['subdomain'] != vendor.subdomain: if (
existing = db.query(Vendor).filter( "subdomain" in update_data
Vendor.subdomain == update_data['subdomain'], and update_data["subdomain"] != vendor.subdomain
Vendor.id != vendor_id ):
).first() existing = (
db.query(Vendor)
.filter(
Vendor.subdomain == update_data["subdomain"],
Vendor.id != vendor_id,
)
.first()
)
if existing: if existing:
raise ValidationException( raise ValidationException(
f"Subdomain '{update_data['subdomain']}' is already taken" f"Subdomain '{update_data['subdomain']}' is already taken"
@@ -419,8 +420,7 @@ class AdminService:
db.rollback() db.rollback()
logger.error(f"Failed to update vendor {vendor_id}: {str(e)}") logger.error(f"Failed to update vendor {vendor_id}: {str(e)}")
raise AdminOperationException( raise AdminOperationException(
operation="update_vendor", operation="update_vendor", reason=f"Database update failed: {str(e)}"
reason=f"Database update failed: {str(e)}"
) )
# Add this NEW method for transferring ownership: # Add this NEW method for transferring ownership:
@@ -429,7 +429,7 @@ class AdminService:
self, self,
db: Session, db: Session,
vendor_id: int, vendor_id: int,
transfer_data # VendorTransferOwnership schema transfer_data, # VendorTransferOwnership schema
) -> Tuple[Vendor, User, User]: ) -> Tuple[Vendor, User, User]:
""" """
Transfer vendor ownership to another user. Transfer vendor ownership to another user.
@@ -466,9 +466,9 @@ class AdminService:
old_owner = vendor.owner old_owner = vendor.owner
# Get new owner # Get new owner
new_owner = db.query(User).filter( new_owner = (
User.id == transfer_data.new_owner_user_id db.query(User).filter(User.id == transfer_data.new_owner_user_id).first()
).first() )
if not new_owner: if not new_owner:
raise UserNotFoundException(str(transfer_data.new_owner_user_id)) raise UserNotFoundException(str(transfer_data.new_owner_user_id))
@@ -487,26 +487,32 @@ class AdminService:
try: try:
# Get Owner role for this vendor # Get Owner role for this vendor
owner_role = db.query(Role).filter( owner_role = (
Role.vendor_id == vendor_id, db.query(Role)
Role.name == "Owner" .filter(Role.vendor_id == vendor_id, Role.name == "Owner")
).first() .first()
)
if not owner_role: if not owner_role:
raise ValidationException("Owner role not found for vendor") raise ValidationException("Owner role not found for vendor")
# Get Manager role (to demote old owner) # Get Manager role (to demote old owner)
manager_role = db.query(Role).filter( manager_role = (
Role.vendor_id == vendor_id, db.query(Role)
Role.name == "Manager" .filter(Role.vendor_id == vendor_id, Role.name == "Manager")
).first() .first()
)
# Remove old owner from Owner role # Remove old owner from Owner role
old_owner_link = db.query(VendorUser).filter( old_owner_link = (
db.query(VendorUser)
.filter(
VendorUser.vendor_id == vendor_id, VendorUser.vendor_id == vendor_id,
VendorUser.user_id == old_owner.id, VendorUser.user_id == old_owner.id,
VendorUser.role_id == owner_role.id VendorUser.role_id == owner_role.id,
).first() )
.first()
)
if old_owner_link: if old_owner_link:
if manager_role: if manager_role:
@@ -525,10 +531,14 @@ class AdminService:
) )
# Check if new owner already has a vendor_user link # Check if new owner already has a vendor_user link
new_owner_link = db.query(VendorUser).filter( new_owner_link = (
db.query(VendorUser)
.filter(
VendorUser.vendor_id == vendor_id, VendorUser.vendor_id == vendor_id,
VendorUser.user_id == new_owner.id VendorUser.user_id == new_owner.id,
).first() )
.first()
)
if new_owner_link: if new_owner_link:
# Update existing link to Owner role # Update existing link to Owner role
@@ -540,7 +550,7 @@ class AdminService:
vendor_id=vendor_id, vendor_id=vendor_id,
user_id=new_owner.id, user_id=new_owner.id,
role_id=owner_role.id, role_id=owner_role.id,
is_active=True is_active=True,
) )
db.add(new_owner_link) db.add(new_owner_link)
@@ -568,10 +578,12 @@ class AdminService:
raise raise
except Exception as e: except Exception as e:
db.rollback() db.rollback()
logger.error(f"Failed to transfer ownership for vendor {vendor_id}: {str(e)}") logger.error(
f"Failed to transfer ownership for vendor {vendor_id}: {str(e)}"
)
raise AdminOperationException( raise AdminOperationException(
operation="transfer_vendor_ownership", operation="transfer_vendor_ownership",
reason=f"Ownership transfer failed: {str(e)}" reason=f"Ownership transfer failed: {str(e)}",
) )
# ============================================================================ # ============================================================================
@@ -596,7 +608,9 @@ class AdminService:
MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%") MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%")
) )
if vendor_name: if vendor_name:
query = query.filter(MarketplaceImportJob.vendor_name.ilike(f"%{vendor_name}%")) query = query.filter(
MarketplaceImportJob.vendor_name.ilike(f"%{vendor_name}%")
)
if status: if status:
query = query.filter(MarketplaceImportJob.status == status) query = query.filter(MarketplaceImportJob.status == status)
@@ -612,8 +626,7 @@ class AdminService:
except Exception as e: except Exception as e:
logger.error(f"Failed to retrieve marketplace import jobs: {str(e)}") logger.error(f"Failed to retrieve marketplace import jobs: {str(e)}")
raise AdminOperationException( raise AdminOperationException(
operation="get_marketplace_import_jobs", operation="get_marketplace_import_jobs", reason="Database query failed"
reason="Database query failed"
) )
# ============================================================================ # ============================================================================
@@ -624,10 +637,7 @@ class AdminService:
"""Get recently created vendors.""" """Get recently created vendors."""
try: try:
vendors = ( vendors = (
db.query(Vendor) db.query(Vendor).order_by(Vendor.created_at.desc()).limit(limit).all()
.order_by(Vendor.created_at.desc())
.limit(limit)
.all()
) )
return [ return [
@@ -638,7 +648,7 @@ class AdminService:
"subdomain": v.subdomain, "subdomain": v.subdomain,
"is_active": v.is_active, "is_active": v.is_active,
"is_verified": v.is_verified, "is_verified": v.is_verified,
"created_at": v.created_at "created_at": v.created_at,
} }
for v in vendors for v in vendors
] ]
@@ -663,7 +673,7 @@ class AdminService:
"vendor_name": j.vendor_name, "vendor_name": j.vendor_name,
"status": j.status, "status": j.status,
"total_processed": j.total_processed or 0, "total_processed": j.total_processed or 0,
"created_at": j.created_at "created_at": j.created_at,
} }
for j in jobs for j in jobs
] ]
@@ -692,47 +702,53 @@ class AdminService:
def _generate_temp_password(self, length: int = 12) -> str: def _generate_temp_password(self, length: int = 12) -> str:
"""Generate secure temporary password.""" """Generate secure temporary password."""
alphabet = string.ascii_letters + string.digits + "!@#$%^&*" alphabet = string.ascii_letters + string.digits + "!@#$%^&*"
return ''.join(secrets.choice(alphabet) for _ in range(length)) return "".join(secrets.choice(alphabet) for _ in range(length))
def _create_default_roles(self, db: Session, vendor_id: int): def _create_default_roles(self, db: Session, vendor_id: int):
"""Create default roles for a new vendor.""" """Create default roles for a new vendor."""
default_roles = [ default_roles = [
{ {"name": "Owner", "permissions": ["*"]}, # Full access
"name": "Owner",
"permissions": ["*"] # Full access
},
{ {
"name": "Manager", "name": "Manager",
"permissions": [ "permissions": [
"products.*", "orders.*", "customers.view", "products.*",
"inventory.*", "team.view" "orders.*",
] "customers.view",
"inventory.*",
"team.view",
],
}, },
{ {
"name": "Editor", "name": "Editor",
"permissions": [ "permissions": [
"products.view", "products.edit", "products.view",
"orders.view", "inventory.view" "products.edit",
] "orders.view",
"inventory.view",
],
}, },
{ {
"name": "Viewer", "name": "Viewer",
"permissions": [ "permissions": [
"products.view", "orders.view", "products.view",
"customers.view", "inventory.view" "orders.view",
] "customers.view",
} "inventory.view",
],
},
] ]
for role_data in default_roles: for role_data in default_roles:
role = Role( role = Role(
vendor_id=vendor_id, vendor_id=vendor_id,
name=role_data["name"], name=role_data["name"],
permissions=role_data["permissions"] permissions=role_data["permissions"],
) )
db.add(role) db.add(role)
def _convert_job_to_response(self, job: MarketplaceImportJob) -> MarketplaceImportJobResponse: def _convert_job_to_response(
self, job: MarketplaceImportJob
) -> MarketplaceImportJobResponse:
"""Convert database model to response schema.""" """Convert database model to response schema."""
return MarketplaceImportJobResponse( return MarketplaceImportJobResponse(
job_id=job.id, job_id=job.id,

View File

@@ -8,25 +8,19 @@ This module provides functions for:
- Encrypting sensitive settings - Encrypting sensitive settings
""" """
import logging
import json import json
from typing import Optional, List, Any, Dict import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.orm import Session
from app.exceptions import (AdminOperationException, ResourceNotFoundException,
ValidationException)
from models.database.admin import AdminSetting from models.database.admin import AdminSetting
from models.schema.admin import ( from models.schema.admin import (AdminSettingCreate, AdminSettingResponse,
AdminSettingCreate, AdminSettingUpdate)
AdminSettingResponse,
AdminSettingUpdate
)
from app.exceptions import (
AdminOperationException,
ValidationException,
ResourceNotFoundException
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -34,26 +28,19 @@ logger = logging.getLogger(__name__)
class AdminSettingsService: class AdminSettingsService:
"""Service for managing platform-wide settings.""" """Service for managing platform-wide settings."""
def get_setting_by_key( def get_setting_by_key(self, db: Session, key: str) -> Optional[AdminSetting]:
self,
db: Session,
key: str
) -> Optional[AdminSetting]:
"""Get setting by key.""" """Get setting by key."""
try: try:
return db.query(AdminSetting).filter( return (
func.lower(AdminSetting.key) == key.lower() db.query(AdminSetting)
).first() .filter(func.lower(AdminSetting.key) == key.lower())
.first()
)
except Exception as e: except Exception as e:
logger.error(f"Failed to get setting {key}: {str(e)}") logger.error(f"Failed to get setting {key}: {str(e)}")
return None return None
def get_setting_value( def get_setting_value(self, db: Session, key: str, default: Any = None) -> Any:
self,
db: Session,
key: str,
default: Any = None
) -> Any:
""" """
Get setting value with type conversion. Get setting value with type conversion.
@@ -76,7 +63,7 @@ class AdminSettingsService:
elif setting.value_type == "float": elif setting.value_type == "float":
return float(setting.value) return float(setting.value)
elif setting.value_type == "boolean": elif setting.value_type == "boolean":
return setting.value.lower() in ('true', '1', 'yes') return setting.value.lower() in ("true", "1", "yes")
elif setting.value_type == "json": elif setting.value_type == "json":
return json.loads(setting.value) return json.loads(setting.value)
else: else:
@@ -89,7 +76,7 @@ class AdminSettingsService:
self, self,
db: Session, db: Session,
category: Optional[str] = None, category: Optional[str] = None,
is_public: Optional[bool] = None is_public: Optional[bool] = None,
) -> List[AdminSettingResponse]: ) -> List[AdminSettingResponse]:
"""Get all settings with optional filtering.""" """Get all settings with optional filtering."""
try: try:
@@ -104,22 +91,16 @@ class AdminSettingsService:
settings = query.order_by(AdminSetting.category, AdminSetting.key).all() settings = query.order_by(AdminSetting.category, AdminSetting.key).all()
return [ return [
AdminSettingResponse.model_validate(setting) AdminSettingResponse.model_validate(setting) for setting in settings
for setting in settings
] ]
except Exception as e: except Exception as e:
logger.error(f"Failed to get settings: {str(e)}") logger.error(f"Failed to get settings: {str(e)}")
raise AdminOperationException( raise AdminOperationException(
operation="get_all_settings", operation="get_all_settings", reason="Database query failed"
reason="Database query failed"
) )
def get_settings_by_category( def get_settings_by_category(self, db: Session, category: str) -> Dict[str, Any]:
self,
db: Session,
category: str
) -> Dict[str, Any]:
""" """
Get all settings in a category as a dictionary. Get all settings in a category as a dictionary.
@@ -136,7 +117,7 @@ class AdminSettingsService:
elif setting.value_type == "float": elif setting.value_type == "float":
result[setting.key] = float(setting.value) result[setting.key] = float(setting.value)
elif setting.value_type == "boolean": elif setting.value_type == "boolean":
result[setting.key] = setting.value.lower() in ('true', '1', 'yes') result[setting.key] = setting.value.lower() in ("true", "1", "yes")
elif setting.value_type == "json": elif setting.value_type == "json":
result[setting.key] = json.loads(setting.value) result[setting.key] = json.loads(setting.value)
else: else:
@@ -145,10 +126,7 @@ class AdminSettingsService:
return result return result
def create_setting( def create_setting(
self, self, db: Session, setting_data: AdminSettingCreate, admin_user_id: int
db: Session,
setting_data: AdminSettingCreate,
admin_user_id: int
) -> AdminSettingResponse: ) -> AdminSettingResponse:
"""Create new setting.""" """Create new setting."""
try: try:
@@ -176,7 +154,7 @@ class AdminSettingsService:
description=setting_data.description, description=setting_data.description,
is_encrypted=setting_data.is_encrypted, is_encrypted=setting_data.is_encrypted,
is_public=setting_data.is_public, is_public=setting_data.is_public,
last_modified_by_user_id=admin_user_id last_modified_by_user_id=admin_user_id,
) )
db.add(setting) db.add(setting)
@@ -194,25 +172,17 @@ class AdminSettingsService:
db.rollback() db.rollback()
logger.error(f"Failed to create setting: {str(e)}") logger.error(f"Failed to create setting: {str(e)}")
raise AdminOperationException( raise AdminOperationException(
operation="create_setting", operation="create_setting", reason="Database operation failed"
reason="Database operation failed"
) )
def update_setting( def update_setting(
self, self, db: Session, key: str, update_data: AdminSettingUpdate, admin_user_id: int
db: Session,
key: str,
update_data: AdminSettingUpdate,
admin_user_id: int
) -> AdminSettingResponse: ) -> AdminSettingResponse:
"""Update existing setting.""" """Update existing setting."""
setting = self.get_setting_by_key(db, key) setting = self.get_setting_by_key(db, key)
if not setting: if not setting:
raise ResourceNotFoundException( raise ResourceNotFoundException(resource_type="setting", identifier=key)
resource_type="setting",
identifier=key
)
try: try:
# Validate new value # Validate new value
@@ -244,42 +214,29 @@ class AdminSettingsService:
db.rollback() db.rollback()
logger.error(f"Failed to update setting {key}: {str(e)}") logger.error(f"Failed to update setting {key}: {str(e)}")
raise AdminOperationException( raise AdminOperationException(
operation="update_setting", operation="update_setting", reason="Database operation failed"
reason="Database operation failed"
) )
def upsert_setting( def upsert_setting(
self, self, db: Session, setting_data: AdminSettingCreate, admin_user_id: int
db: Session,
setting_data: AdminSettingCreate,
admin_user_id: int
) -> AdminSettingResponse: ) -> AdminSettingResponse:
"""Create or update setting (upsert).""" """Create or update setting (upsert)."""
existing = self.get_setting_by_key(db, setting_data.key) existing = self.get_setting_by_key(db, setting_data.key)
if existing: if existing:
update_data = AdminSettingUpdate( update_data = AdminSettingUpdate(
value=setting_data.value, value=setting_data.value, description=setting_data.description
description=setting_data.description
) )
return self.update_setting(db, setting_data.key, update_data, admin_user_id) return self.update_setting(db, setting_data.key, update_data, admin_user_id)
else: else:
return self.create_setting(db, setting_data, admin_user_id) return self.create_setting(db, setting_data, admin_user_id)
def delete_setting( def delete_setting(self, db: Session, key: str, admin_user_id: int) -> str:
self,
db: Session,
key: str,
admin_user_id: int
) -> str:
"""Delete setting.""" """Delete setting."""
setting = self.get_setting_by_key(db, key) setting = self.get_setting_by_key(db, key)
if not setting: if not setting:
raise ResourceNotFoundException( raise ResourceNotFoundException(resource_type="setting", identifier=key)
resource_type="setting",
identifier=key
)
try: try:
db.delete(setting) db.delete(setting)
@@ -293,8 +250,7 @@ class AdminSettingsService:
db.rollback() db.rollback()
logger.error(f"Failed to delete setting {key}: {str(e)}") logger.error(f"Failed to delete setting {key}: {str(e)}")
raise AdminOperationException( raise AdminOperationException(
operation="delete_setting", operation="delete_setting", reason="Database operation failed"
reason="Database operation failed"
) )
# ============================================================================ # ============================================================================
@@ -309,7 +265,7 @@ class AdminSettingsService:
elif value_type == "float": elif value_type == "float":
float(value) float(value)
elif value_type == "boolean": elif value_type == "boolean":
if value.lower() not in ('true', 'false', '1', '0', 'yes', 'no'): if value.lower() not in ("true", "false", "1", "0", "yes", "no"):
raise ValueError("Invalid boolean value") raise ValueError("Invalid boolean value")
elif value_type == "json": elif value_type == "json":
json.loads(value) json.loads(value)

View File

@@ -13,15 +13,12 @@ from typing import Any, Dict, Optional
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.exceptions import ( from app.exceptions import (InvalidCredentialsException,
UserAlreadyExistsException, UserAlreadyExistsException, UserNotActiveException,
InvalidCredentialsException, ValidationException)
UserNotActiveException,
ValidationException,
)
from middleware.auth import AuthManager from middleware.auth import AuthManager
from models.schema.auth import UserLogin, UserRegister
from models.database.user import User from models.database.user import User
from models.schema.auth import UserLogin, UserRegister
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -51,11 +48,15 @@ class AuthService:
try: try:
# Check if email already exists # Check if email already exists
if self._email_exists(db, user_data.email): if self._email_exists(db, user_data.email):
raise UserAlreadyExistsException("Email already registered", field="email") raise UserAlreadyExistsException(
"Email already registered", field="email"
)
# Check if username already exists # Check if username already exists
if self._username_exists(db, user_data.username): if self._username_exists(db, user_data.username):
raise UserAlreadyExistsException("Username already taken", field="username") raise UserAlreadyExistsException(
"Username already taken", field="username"
)
# Hash password and create user # Hash password and create user
hashed_password = self.auth_manager.hash_password(user_data.password) hashed_password = self.auth_manager.hash_password(user_data.password)
@@ -182,7 +183,9 @@ class AuthService:
Dictionary with access_token, token_type, and expires_in Dictionary with access_token, token_type, and expires_in
""" """
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from jose import jwt from jose import jwt
from app.core.config import settings from app.core.config import settings
try: try:
@@ -217,6 +220,5 @@ class AuthService:
return db.query(User).filter(User.username == username).first() is not None return db.query(User).filter(User.username == username).first() is not None
# Create service instance following the same pattern as other services # Create service instance following the same pattern as other services
auth_service = AuthService() auth_service = AuthService()

View File

@@ -9,23 +9,20 @@ This module provides:
""" """
import logging import logging
from typing import Dict, List, Optional
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Dict, List, Optional
from sqlalchemy.orm import Session
from sqlalchemy import and_ from sqlalchemy import and_
from sqlalchemy.orm import Session
from models.database.product import Product from app.exceptions import (CartItemNotFoundException, CartValidationException,
from models.database.vendor import Vendor
from models.database.cart import CartItem
from app.exceptions import (
ProductNotFoundException,
CartItemNotFoundException,
CartValidationException,
InsufficientInventoryForCartException, InsufficientInventoryForCartException,
InvalidCartQuantityException, InvalidCartQuantityException,
ProductNotAvailableForCartException, ProductNotAvailableForCartException,
) ProductNotFoundException)
from models.database.cart import CartItem
from models.database.product import Product
from models.database.vendor import Vendor
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -33,12 +30,7 @@ logger = logging.getLogger(__name__)
class CartService: class CartService:
"""Service for managing shopping carts.""" """Service for managing shopping carts."""
def get_cart( def get_cart(self, db: Session, vendor_id: int, session_id: str) -> Dict:
self,
db: Session,
vendor_id: int,
session_id: str
) -> Dict:
""" """
Get cart contents for a session. Get cart contents for a session.
@@ -55,20 +47,21 @@ class CartService:
extra={ extra={
"vendor_id": vendor_id, "vendor_id": vendor_id,
"session_id": session_id, "session_id": session_id,
} },
) )
# Fetch cart items from database # Fetch cart items from database
cart_items = db.query(CartItem).filter( cart_items = (
and_( db.query(CartItem)
CartItem.vendor_id == vendor_id, .filter(
CartItem.session_id == session_id and_(CartItem.vendor_id == vendor_id, CartItem.session_id == session_id)
)
.all()
) )
).all()
logger.info( logger.info(
f"[CART_SERVICE] Found {len(cart_items)} items in database", f"[CART_SERVICE] Found {len(cart_items)} items in database",
extra={"item_count": len(cart_items)} extra={"item_count": len(cart_items)},
) )
# Build response # Build response
@@ -79,14 +72,20 @@ class CartService:
product = cart_item.product product = cart_item.product
line_total = cart_item.line_total line_total = cart_item.line_total
items.append({ items.append(
{
"product_id": product.id, "product_id": product.id,
"product_name": product.marketplace_product.title, "product_name": product.marketplace_product.title,
"quantity": cart_item.quantity, "quantity": cart_item.quantity,
"price": cart_item.price_at_add, "price": cart_item.price_at_add,
"line_total": line_total, "line_total": line_total,
"image_url": product.marketplace_product.image_link if product.marketplace_product else None, "image_url": (
}) product.marketplace_product.image_link
if product.marketplace_product
else None
),
}
)
subtotal += line_total subtotal += line_total
@@ -95,12 +94,12 @@ class CartService:
"session_id": session_id, "session_id": session_id,
"items": items, "items": items,
"subtotal": subtotal, "subtotal": subtotal,
"total": subtotal # Could add tax/shipping later "total": subtotal, # Could add tax/shipping later
} }
logger.info( logger.info(
f"[CART_SERVICE] get_cart returning: {len(cart_data['items'])} items, total: {cart_data['total']}", f"[CART_SERVICE] get_cart returning: {len(cart_data['items'])} items, total: {cart_data['total']}",
extra={"cart": cart_data} extra={"cart": cart_data},
) )
return cart_data return cart_data
@@ -111,7 +110,7 @@ class CartService:
vendor_id: int, vendor_id: int,
session_id: str, session_id: str,
product_id: int, product_id: int,
quantity: int = 1 quantity: int = 1,
) -> Dict: ) -> Dict:
""" """
Add product to cart. Add product to cart.
@@ -136,23 +135,27 @@ class CartService:
"vendor_id": vendor_id, "vendor_id": vendor_id,
"session_id": session_id, "session_id": session_id,
"product_id": product_id, "product_id": product_id,
"quantity": quantity "quantity": quantity,
} },
) )
# Verify product exists and belongs to vendor # Verify product exists and belongs to vendor
product = db.query(Product).filter( product = (
db.query(Product)
.filter(
and_( and_(
Product.id == product_id, Product.id == product_id,
Product.vendor_id == vendor_id, Product.vendor_id == vendor_id,
Product.is_active == True Product.is_active == True,
)
)
.first()
) )
).first()
if not product: if not product:
logger.error( logger.error(
f"[CART_SERVICE] Product not found", f"[CART_SERVICE] Product not found",
extra={"product_id": product_id, "vendor_id": vendor_id} extra={"product_id": product_id, "vendor_id": vendor_id},
) )
raise ProductNotFoundException(product_id=product_id, vendor_id=vendor_id) raise ProductNotFoundException(product_id=product_id, vendor_id=vendor_id)
@@ -161,21 +164,25 @@ class CartService:
extra={ extra={
"product_id": product_id, "product_id": product_id,
"product_name": product.marketplace_product.title, "product_name": product.marketplace_product.title,
"available_inventory": product.available_inventory "available_inventory": product.available_inventory,
} },
) )
# Get current price (use sale_price if available, otherwise regular price) # Get current price (use sale_price if available, otherwise regular price)
current_price = product.sale_price if product.sale_price else product.price current_price = product.sale_price if product.sale_price else product.price
# Check if item already exists in cart # Check if item already exists in cart
existing_item = db.query(CartItem).filter( existing_item = (
db.query(CartItem)
.filter(
and_( and_(
CartItem.vendor_id == vendor_id, CartItem.vendor_id == vendor_id,
CartItem.session_id == session_id, CartItem.session_id == session_id,
CartItem.product_id == product_id CartItem.product_id == product_id,
)
)
.first()
) )
).first()
if existing_item: if existing_item:
# Update quantity # Update quantity
@@ -190,14 +197,14 @@ class CartService:
"current_in_cart": existing_item.quantity, "current_in_cart": existing_item.quantity,
"adding": quantity, "adding": quantity,
"requested_total": new_quantity, "requested_total": new_quantity,
"available": product.available_inventory "available": product.available_inventory,
} },
) )
raise InsufficientInventoryForCartException( raise InsufficientInventoryForCartException(
product_id=product_id, product_id=product_id,
product_name=product.marketplace_product.title, product_name=product.marketplace_product.title,
requested=new_quantity, requested=new_quantity,
available=product.available_inventory available=product.available_inventory,
) )
existing_item.quantity = new_quantity existing_item.quantity = new_quantity
@@ -206,16 +213,13 @@ class CartService:
logger.info( logger.info(
f"[CART_SERVICE] Updated existing cart item", f"[CART_SERVICE] Updated existing cart item",
extra={ extra={"cart_item_id": existing_item.id, "new_quantity": new_quantity},
"cart_item_id": existing_item.id,
"new_quantity": new_quantity
}
) )
return { return {
"message": "Product quantity updated in cart", "message": "Product quantity updated in cart",
"product_id": product_id, "product_id": product_id,
"quantity": new_quantity "quantity": new_quantity,
} }
else: else:
# Check inventory for new item # Check inventory for new item
@@ -225,14 +229,14 @@ class CartService:
extra={ extra={
"product_id": product_id, "product_id": product_id,
"requested": quantity, "requested": quantity,
"available": product.available_inventory "available": product.available_inventory,
} },
) )
raise InsufficientInventoryForCartException( raise InsufficientInventoryForCartException(
product_id=product_id, product_id=product_id,
product_name=product.marketplace_product.title, product_name=product.marketplace_product.title,
requested=quantity, requested=quantity,
available=product.available_inventory available=product.available_inventory,
) )
# Create new cart item # Create new cart item
@@ -241,7 +245,7 @@ class CartService:
session_id=session_id, session_id=session_id,
product_id=product_id, product_id=product_id,
quantity=quantity, quantity=quantity,
price_at_add=current_price price_at_add=current_price,
) )
db.add(cart_item) db.add(cart_item)
db.commit() db.commit()
@@ -252,14 +256,14 @@ class CartService:
extra={ extra={
"cart_item_id": cart_item.id, "cart_item_id": cart_item.id,
"quantity": quantity, "quantity": quantity,
"price": current_price "price": current_price,
} },
) )
return { return {
"message": "Product added to cart", "message": "Product added to cart",
"product_id": product_id, "product_id": product_id,
"quantity": quantity "quantity": quantity,
} }
def update_cart_item( def update_cart_item(
@@ -268,7 +272,7 @@ class CartService:
vendor_id: int, vendor_id: int,
session_id: str, session_id: str,
product_id: int, product_id: int,
quantity: int quantity: int,
) -> Dict: ) -> Dict:
""" """
Update quantity of item in cart. Update quantity of item in cart.
@@ -292,25 +296,35 @@ class CartService:
raise InvalidCartQuantityException(quantity=quantity, min_quantity=1) raise InvalidCartQuantityException(quantity=quantity, min_quantity=1)
# Find cart item # Find cart item
cart_item = db.query(CartItem).filter( cart_item = (
db.query(CartItem)
.filter(
and_( and_(
CartItem.vendor_id == vendor_id, CartItem.vendor_id == vendor_id,
CartItem.session_id == session_id, CartItem.session_id == session_id,
CartItem.product_id == product_id CartItem.product_id == product_id,
)
)
.first()
) )
).first()
if not cart_item: if not cart_item:
raise CartItemNotFoundException(product_id=product_id, session_id=session_id) raise CartItemNotFoundException(
product_id=product_id, session_id=session_id
)
# Verify product still exists and is active # Verify product still exists and is active
product = db.query(Product).filter( product = (
db.query(Product)
.filter(
and_( and_(
Product.id == product_id, Product.id == product_id,
Product.vendor_id == vendor_id, Product.vendor_id == vendor_id,
Product.is_active == True Product.is_active == True,
)
)
.first()
) )
).first()
if not product: if not product:
raise ProductNotFoundException(str(product_id)) raise ProductNotFoundException(str(product_id))
@@ -321,7 +335,7 @@ class CartService:
product_id=product_id, product_id=product_id,
product_name=product.marketplace_product.title, product_name=product.marketplace_product.title,
requested=quantity, requested=quantity,
available=product.available_inventory available=product.available_inventory,
) )
# Update quantity # Update quantity
@@ -334,22 +348,18 @@ class CartService:
extra={ extra={
"cart_item_id": cart_item.id, "cart_item_id": cart_item.id,
"product_id": product_id, "product_id": product_id,
"new_quantity": quantity "new_quantity": quantity,
} },
) )
return { return {
"message": "Cart updated", "message": "Cart updated",
"product_id": product_id, "product_id": product_id,
"quantity": quantity "quantity": quantity,
} }
def remove_from_cart( def remove_from_cart(
self, self, db: Session, vendor_id: int, session_id: str, product_id: int
db: Session,
vendor_id: int,
session_id: str,
product_id: int
) -> Dict: ) -> Dict:
""" """
Remove item from cart. Remove item from cart.
@@ -367,16 +377,22 @@ class CartService:
ProductNotFoundException: If product not in cart ProductNotFoundException: If product not in cart
""" """
# Find and delete cart item # Find and delete cart item
cart_item = db.query(CartItem).filter( cart_item = (
db.query(CartItem)
.filter(
and_( and_(
CartItem.vendor_id == vendor_id, CartItem.vendor_id == vendor_id,
CartItem.session_id == session_id, CartItem.session_id == session_id,
CartItem.product_id == product_id CartItem.product_id == product_id,
)
)
.first()
) )
).first()
if not cart_item: if not cart_item:
raise CartItemNotFoundException(product_id=product_id, session_id=session_id) raise CartItemNotFoundException(
product_id=product_id, session_id=session_id
)
db.delete(cart_item) db.delete(cart_item)
db.commit() db.commit()
@@ -386,21 +402,13 @@ class CartService:
extra={ extra={
"cart_item_id": cart_item.id, "cart_item_id": cart_item.id,
"product_id": product_id, "product_id": product_id,
"session_id": session_id "session_id": session_id,
} },
) )
return { return {"message": "Item removed from cart", "product_id": product_id}
"message": "Item removed from cart",
"product_id": product_id
}
def clear_cart( def clear_cart(self, db: Session, vendor_id: int, session_id: str) -> Dict:
self,
db: Session,
vendor_id: int,
session_id: str
) -> Dict:
""" """
Clear all items from cart. Clear all items from cart.
@@ -413,12 +421,13 @@ class CartService:
Success message with count of items removed Success message with count of items removed
""" """
# Delete all cart items for this session # Delete all cart items for this session
deleted_count = db.query(CartItem).filter( deleted_count = (
and_( db.query(CartItem)
CartItem.vendor_id == vendor_id, .filter(
CartItem.session_id == session_id and_(CartItem.vendor_id == vendor_id, CartItem.session_id == session_id)
)
.delete()
) )
).delete()
db.commit() db.commit()
@@ -427,14 +436,11 @@ class CartService:
extra={ extra={
"session_id": session_id, "session_id": session_id,
"vendor_id": vendor_id, "vendor_id": vendor_id,
"items_removed": deleted_count "items_removed": deleted_count,
} },
) )
return { return {"message": "Cart cleared", "items_removed": deleted_count}
"message": "Cart cleared",
"items_removed": deleted_count
}
# Create service instance # Create service instance

View File

@@ -3,22 +3,20 @@ Code Quality Service
Business logic for managing architecture scans and violations Business logic for managing architecture scans and violations
""" """
import subprocess
import json import json
import logging import logging
import subprocess
from datetime import datetime from datetime import datetime
from typing import List, Tuple, Optional, Dict
from pathlib import Path from pathlib import Path
from sqlalchemy.orm import Session from typing import Dict, List, Optional, Tuple
from sqlalchemy import func, desc
from app.models.architecture_scan import ( from sqlalchemy import desc, func
ArchitectureScan, from sqlalchemy.orm import Session
from app.models.architecture_scan import (ArchitectureRule, ArchitectureScan,
ArchitectureViolation, ArchitectureViolation,
ArchitectureRule,
ViolationAssignment, ViolationAssignment,
ViolationComment ViolationComment)
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -26,7 +24,7 @@ logger = logging.getLogger(__name__)
class CodeQualityService: class CodeQualityService:
"""Service for managing code quality scans and violations""" """Service for managing code quality scans and violations"""
def run_scan(self, db: Session, triggered_by: str = 'manual') -> ArchitectureScan: def run_scan(self, db: Session, triggered_by: str = "manual") -> ArchitectureScan:
""" """
Run architecture validator and store results in database Run architecture validator and store results in database
@@ -49,10 +47,10 @@ class CodeQualityService:
start_time = datetime.now() start_time = datetime.now()
try: try:
result = subprocess.run( result = subprocess.run(
['python', 'scripts/validate_architecture.py', '--json'], ["python", "scripts/validate_architecture.py", "--json"],
capture_output=True, capture_output=True,
text=True, text=True,
timeout=300 # 5 minute timeout timeout=300, # 5 minute timeout
) )
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
logger.error("Architecture scan timed out after 5 minutes") logger.error("Architecture scan timed out after 5 minutes")
@@ -63,17 +61,17 @@ class CodeQualityService:
# Parse JSON output (get only the JSON part, skip progress messages) # Parse JSON output (get only the JSON part, skip progress messages)
try: try:
# Find the JSON part in stdout # Find the JSON part in stdout
lines = result.stdout.strip().split('\n') lines = result.stdout.strip().split("\n")
json_start = -1 json_start = -1
for i, line in enumerate(lines): for i, line in enumerate(lines):
if line.strip().startswith('{'): if line.strip().startswith("{"):
json_start = i json_start = i
break break
if json_start == -1: if json_start == -1:
raise ValueError("No JSON output found") raise ValueError("No JSON output found")
json_output = '\n'.join(lines[json_start:]) json_output = "\n".join(lines[json_start:])
data = json.loads(json_output) data = json.loads(json_output)
except (json.JSONDecodeError, ValueError) as e: except (json.JSONDecodeError, ValueError) as e:
logger.error(f"Failed to parse validator output: {e}") logger.error(f"Failed to parse validator output: {e}")
@@ -84,33 +82,33 @@ class CodeQualityService:
# Create scan record # Create scan record
scan = ArchitectureScan( scan = ArchitectureScan(
timestamp=datetime.now(), timestamp=datetime.now(),
total_files=data.get('files_checked', 0), total_files=data.get("files_checked", 0),
total_violations=data.get('total_violations', 0), total_violations=data.get("total_violations", 0),
errors=data.get('errors', 0), errors=data.get("errors", 0),
warnings=data.get('warnings', 0), warnings=data.get("warnings", 0),
duration_seconds=duration, duration_seconds=duration,
triggered_by=triggered_by, triggered_by=triggered_by,
git_commit_hash=git_commit git_commit_hash=git_commit,
) )
db.add(scan) db.add(scan)
db.flush() # Get scan.id db.flush() # Get scan.id
# Create violation records # Create violation records
violations_data = data.get('violations', []) violations_data = data.get("violations", [])
logger.info(f"Creating {len(violations_data)} violation records") logger.info(f"Creating {len(violations_data)} violation records")
for v in violations_data: for v in violations_data:
violation = ArchitectureViolation( violation = ArchitectureViolation(
scan_id=scan.id, scan_id=scan.id,
rule_id=v['rule_id'], rule_id=v["rule_id"],
rule_name=v['rule_name'], rule_name=v["rule_name"],
severity=v['severity'], severity=v["severity"],
file_path=v['file_path'], file_path=v["file_path"],
line_number=v['line_number'], line_number=v["line_number"],
message=v['message'], message=v["message"],
context=v.get('context', ''), context=v.get("context", ""),
suggestion=v.get('suggestion', ''), suggestion=v.get("suggestion", ""),
status='open' status="open",
) )
db.add(violation) db.add(violation)
@@ -122,7 +120,11 @@ class CodeQualityService:
def get_latest_scan(self, db: Session) -> Optional[ArchitectureScan]: def get_latest_scan(self, db: Session) -> Optional[ArchitectureScan]:
"""Get the most recent scan""" """Get the most recent scan"""
return db.query(ArchitectureScan).order_by(desc(ArchitectureScan.timestamp)).first() return (
db.query(ArchitectureScan)
.order_by(desc(ArchitectureScan.timestamp))
.first()
)
def get_scan_by_id(self, db: Session, scan_id: int) -> Optional[ArchitectureScan]: def get_scan_by_id(self, db: Session, scan_id: int) -> Optional[ArchitectureScan]:
"""Get scan by ID""" """Get scan by ID"""
@@ -139,10 +141,12 @@ class CodeQualityService:
Returns: Returns:
List of ArchitectureScan objects, newest first List of ArchitectureScan objects, newest first
""" """
return db.query(ArchitectureScan)\ return (
.order_by(desc(ArchitectureScan.timestamp))\ db.query(ArchitectureScan)
.limit(limit)\ .order_by(desc(ArchitectureScan.timestamp))
.limit(limit)
.all() .all()
)
def get_violations( def get_violations(
self, self,
@@ -153,7 +157,7 @@ class CodeQualityService:
rule_id: str = None, rule_id: str = None,
file_path: str = None, file_path: str = None,
limit: int = 100, limit: int = 100,
offset: int = 0 offset: int = 0,
) -> Tuple[List[ArchitectureViolation], int]: ) -> Tuple[List[ArchitectureViolation], int]:
""" """
Get violations with filtering and pagination Get violations with filtering and pagination
@@ -194,24 +198,32 @@ class CodeQualityService:
query = query.filter(ArchitectureViolation.rule_id == rule_id) query = query.filter(ArchitectureViolation.rule_id == rule_id)
if file_path: if file_path:
query = query.filter(ArchitectureViolation.file_path.like(f'%{file_path}%')) query = query.filter(ArchitectureViolation.file_path.like(f"%{file_path}%"))
# Get total count # Get total count
total = query.count() total = query.count()
# Get page of results # Get page of results
violations = query.order_by( violations = (
ArchitectureViolation.severity.desc(), query.order_by(
ArchitectureViolation.file_path ArchitectureViolation.severity.desc(), ArchitectureViolation.file_path
).limit(limit).offset(offset).all() )
.limit(limit)
.offset(offset)
.all()
)
return violations, total return violations, total
def get_violation_by_id(self, db: Session, violation_id: int) -> Optional[ArchitectureViolation]: def get_violation_by_id(
self, db: Session, violation_id: int
) -> Optional[ArchitectureViolation]:
"""Get single violation with details""" """Get single violation with details"""
return db.query(ArchitectureViolation).filter( return (
ArchitectureViolation.id == violation_id db.query(ArchitectureViolation)
).first() .filter(ArchitectureViolation.id == violation_id)
.first()
)
def assign_violation( def assign_violation(
self, self,
@@ -220,7 +232,7 @@ class CodeQualityService:
user_id: int, user_id: int,
assigned_by: int, assigned_by: int,
due_date: datetime = None, due_date: datetime = None,
priority: str = 'medium' priority: str = "medium",
) -> ViolationAssignment: ) -> ViolationAssignment:
""" """
Assign violation to a developer Assign violation to a developer
@@ -239,7 +251,7 @@ class CodeQualityService:
# Update violation status # Update violation status
violation = self.get_violation_by_id(db, violation_id) violation = self.get_violation_by_id(db, violation_id)
if violation: if violation:
violation.status = 'assigned' violation.status = "assigned"
violation.assigned_to = user_id violation.assigned_to = user_id
# Create assignment record # Create assignment record
@@ -248,7 +260,7 @@ class CodeQualityService:
user_id=user_id, user_id=user_id,
assigned_by=assigned_by, assigned_by=assigned_by,
due_date=due_date, due_date=due_date,
priority=priority priority=priority,
) )
db.add(assignment) db.add(assignment)
db.commit() db.commit()
@@ -257,11 +269,7 @@ class CodeQualityService:
return assignment return assignment
def resolve_violation( def resolve_violation(
self, self, db: Session, violation_id: int, resolved_by: int, resolution_note: str
db: Session,
violation_id: int,
resolved_by: int,
resolution_note: str
) -> ArchitectureViolation: ) -> ArchitectureViolation:
""" """
Mark violation as resolved Mark violation as resolved
@@ -279,7 +287,7 @@ class CodeQualityService:
if not violation: if not violation:
raise ValueError(f"Violation {violation_id} not found") raise ValueError(f"Violation {violation_id} not found")
violation.status = 'resolved' violation.status = "resolved"
violation.resolved_at = datetime.now() violation.resolved_at = datetime.now()
violation.resolved_by = resolved_by violation.resolved_by = resolved_by
violation.resolution_note = resolution_note violation.resolution_note = resolution_note
@@ -289,11 +297,7 @@ class CodeQualityService:
return violation return violation
def ignore_violation( def ignore_violation(
self, self, db: Session, violation_id: int, ignored_by: int, reason: str
db: Session,
violation_id: int,
ignored_by: int,
reason: str
) -> ArchitectureViolation: ) -> ArchitectureViolation:
""" """
Mark violation as ignored/won't fix Mark violation as ignored/won't fix
@@ -311,7 +315,7 @@ class CodeQualityService:
if not violation: if not violation:
raise ValueError(f"Violation {violation_id} not found") raise ValueError(f"Violation {violation_id} not found")
violation.status = 'ignored' violation.status = "ignored"
violation.resolved_at = datetime.now() violation.resolved_at = datetime.now()
violation.resolved_by = ignored_by violation.resolved_by = ignored_by
violation.resolution_note = f"Ignored: {reason}" violation.resolution_note = f"Ignored: {reason}"
@@ -321,11 +325,7 @@ class CodeQualityService:
return violation return violation
def add_comment( def add_comment(
self, self, db: Session, violation_id: int, user_id: int, comment: str
db: Session,
violation_id: int,
user_id: int,
comment: str
) -> ViolationComment: ) -> ViolationComment:
""" """
Add comment to violation Add comment to violation
@@ -340,9 +340,7 @@ class CodeQualityService:
ViolationComment object ViolationComment object
""" """
comment_obj = ViolationComment( comment_obj = ViolationComment(
violation_id=violation_id, violation_id=violation_id, user_id=user_id, comment=comment
user_id=user_id,
comment=comment
) )
db.add(comment_obj) db.add(comment_obj)
db.commit() db.commit()
@@ -360,79 +358,95 @@ class CodeQualityService:
latest_scan = self.get_latest_scan(db) latest_scan = self.get_latest_scan(db)
if not latest_scan: if not latest_scan:
return { return {
'total_violations': 0, "total_violations": 0,
'errors': 0, "errors": 0,
'warnings': 0, "warnings": 0,
'open': 0, "open": 0,
'assigned': 0, "assigned": 0,
'resolved': 0, "resolved": 0,
'ignored': 0, "ignored": 0,
'technical_debt_score': 100, "technical_debt_score": 100,
'trend': [], "trend": [],
'by_severity': {}, "by_severity": {},
'by_rule': {}, "by_rule": {},
'by_module': {}, "by_module": {},
'top_files': [] "top_files": [],
} }
# Get violation counts by status # Get violation counts by status
status_counts = db.query( status_counts = (
ArchitectureViolation.status, db.query(ArchitectureViolation.status, func.count(ArchitectureViolation.id))
func.count(ArchitectureViolation.id) .filter(ArchitectureViolation.scan_id == latest_scan.id)
).filter( .group_by(ArchitectureViolation.status)
ArchitectureViolation.scan_id == latest_scan.id .all()
).group_by(ArchitectureViolation.status).all() )
status_dict = {status: count for status, count in status_counts} status_dict = {status: count for status, count in status_counts}
# Get violations by severity # Get violations by severity
severity_counts = db.query( severity_counts = (
ArchitectureViolation.severity, db.query(
func.count(ArchitectureViolation.id) ArchitectureViolation.severity, func.count(ArchitectureViolation.id)
).filter( )
ArchitectureViolation.scan_id == latest_scan.id .filter(ArchitectureViolation.scan_id == latest_scan.id)
).group_by(ArchitectureViolation.severity).all() .group_by(ArchitectureViolation.severity)
.all()
)
by_severity = {sev: count for sev, count in severity_counts} by_severity = {sev: count for sev, count in severity_counts}
# Get violations by rule # Get violations by rule
rule_counts = db.query( rule_counts = (
ArchitectureViolation.rule_id, db.query(
func.count(ArchitectureViolation.id) ArchitectureViolation.rule_id, func.count(ArchitectureViolation.id)
).filter( )
ArchitectureViolation.scan_id == latest_scan.id .filter(ArchitectureViolation.scan_id == latest_scan.id)
).group_by(ArchitectureViolation.rule_id).all() .group_by(ArchitectureViolation.rule_id)
.all()
)
by_rule = {rule: count for rule, count in sorted(rule_counts, key=lambda x: x[1], reverse=True)[:10]} by_rule = {
rule: count
for rule, count in sorted(rule_counts, key=lambda x: x[1], reverse=True)[
:10
]
}
# Get top violating files # Get top violating files
file_counts = db.query( file_counts = (
db.query(
ArchitectureViolation.file_path, ArchitectureViolation.file_path,
func.count(ArchitectureViolation.id).label('count') func.count(ArchitectureViolation.id).label("count"),
).filter( )
ArchitectureViolation.scan_id == latest_scan.id .filter(ArchitectureViolation.scan_id == latest_scan.id)
).group_by(ArchitectureViolation.file_path)\ .group_by(ArchitectureViolation.file_path)
.order_by(desc('count'))\ .order_by(desc("count"))
.limit(10).all() .limit(10)
.all()
)
top_files = [{'file': file, 'count': count} for file, count in file_counts] top_files = [{"file": file, "count": count} for file, count in file_counts]
# Get violations by module (extract module from file path) # Get violations by module (extract module from file path)
by_module = {} by_module = {}
violations = db.query(ArchitectureViolation.file_path).filter( violations = (
ArchitectureViolation.scan_id == latest_scan.id db.query(ArchitectureViolation.file_path)
).all() .filter(ArchitectureViolation.scan_id == latest_scan.id)
.all()
)
for v in violations: for v in violations:
path_parts = v.file_path.split('/') path_parts = v.file_path.split("/")
if len(path_parts) >= 2: if len(path_parts) >= 2:
module = '/'.join(path_parts[:2]) # e.g., 'app/api' module = "/".join(path_parts[:2]) # e.g., 'app/api'
else: else:
module = path_parts[0] module = path_parts[0]
by_module[module] = by_module.get(module, 0) + 1 by_module[module] = by_module.get(module, 0) + 1
# Sort by count and take top 10 # Sort by count and take top 10
by_module = dict(sorted(by_module.items(), key=lambda x: x[1], reverse=True)[:10]) by_module = dict(
sorted(by_module.items(), key=lambda x: x[1], reverse=True)[:10]
)
# Calculate technical debt score # Calculate technical debt score
tech_debt_score = self.calculate_technical_debt_score(db, latest_scan.id) tech_debt_score = self.calculate_technical_debt_score(db, latest_scan.id)
@@ -441,29 +455,29 @@ class CodeQualityService:
trend_scans = self.get_scan_history(db, limit=7) trend_scans = self.get_scan_history(db, limit=7)
trend = [ trend = [
{ {
'timestamp': scan.timestamp.isoformat(), "timestamp": scan.timestamp.isoformat(),
'violations': scan.total_violations, "violations": scan.total_violations,
'errors': scan.errors, "errors": scan.errors,
'warnings': scan.warnings "warnings": scan.warnings,
} }
for scan in reversed(trend_scans) # Oldest first for chart for scan in reversed(trend_scans) # Oldest first for chart
] ]
return { return {
'total_violations': latest_scan.total_violations, "total_violations": latest_scan.total_violations,
'errors': latest_scan.errors, "errors": latest_scan.errors,
'warnings': latest_scan.warnings, "warnings": latest_scan.warnings,
'open': status_dict.get('open', 0), "open": status_dict.get("open", 0),
'assigned': status_dict.get('assigned', 0), "assigned": status_dict.get("assigned", 0),
'resolved': status_dict.get('resolved', 0), "resolved": status_dict.get("resolved", 0),
'ignored': status_dict.get('ignored', 0), "ignored": status_dict.get("ignored", 0),
'technical_debt_score': tech_debt_score, "technical_debt_score": tech_debt_score,
'trend': trend, "trend": trend,
'by_severity': by_severity, "by_severity": by_severity,
'by_rule': by_rule, "by_rule": by_rule,
'by_module': by_module, "by_module": by_module,
'top_files': top_files, "top_files": top_files,
'last_scan': latest_scan.timestamp.isoformat() if latest_scan else None "last_scan": latest_scan.timestamp.isoformat() if latest_scan else None,
} }
def calculate_technical_debt_score(self, db: Session, scan_id: int = None) -> int: def calculate_technical_debt_score(self, db: Session, scan_id: int = None) -> int:
@@ -497,10 +511,7 @@ class CodeQualityService:
"""Get current git commit hash""" """Get current git commit hash"""
try: try:
result = subprocess.run( result = subprocess.run(
['git', 'rev-parse', 'HEAD'], ["git", "rev-parse", "HEAD"], capture_output=True, text=True, timeout=5
capture_output=True,
text=True,
timeout=5
) )
if result.returncode == 0: if result.returncode == 0:
return result.stdout.strip()[:40] return result.stdout.strip()[:40]

View File

@@ -19,8 +19,9 @@ This allows:
import logging import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import List, Optional from typing import List, Optional
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_ from sqlalchemy import and_, or_
from sqlalchemy.orm import Session
from models.database.content_page import ContentPage from models.database.content_page import ContentPage
@@ -35,7 +36,7 @@ class ContentPageService:
db: Session, db: Session,
slug: str, slug: str,
vendor_id: Optional[int] = None, vendor_id: Optional[int] = None,
include_unpublished: bool = False include_unpublished: bool = False,
) -> Optional[ContentPage]: ) -> Optional[ContentPage]:
""" """
Get content page for a vendor with fallback to platform default. Get content page for a vendor with fallback to platform default.
@@ -62,28 +63,20 @@ class ContentPageService:
if vendor_id: if vendor_id:
vendor_page = ( vendor_page = (
db.query(ContentPage) db.query(ContentPage)
.filter( .filter(and_(ContentPage.vendor_id == vendor_id, *filters))
and_(
ContentPage.vendor_id == vendor_id,
*filters
)
)
.first() .first()
) )
if vendor_page: if vendor_page:
logger.debug(f"Found vendor-specific page: {slug} for vendor_id={vendor_id}") logger.debug(
f"Found vendor-specific page: {slug} for vendor_id={vendor_id}"
)
return vendor_page return vendor_page
# Fallback to platform default # Fallback to platform default
platform_page = ( platform_page = (
db.query(ContentPage) db.query(ContentPage)
.filter( .filter(and_(ContentPage.vendor_id == None, *filters))
and_(
ContentPage.vendor_id == None,
*filters
)
)
.first() .first()
) )
@@ -100,7 +93,7 @@ class ContentPageService:
vendor_id: Optional[int] = None, vendor_id: Optional[int] = None,
include_unpublished: bool = False, include_unpublished: bool = False,
footer_only: bool = False, footer_only: bool = False,
header_only: bool = False header_only: bool = False,
) -> List[ContentPage]: ) -> List[ContentPage]:
""" """
List all available pages for a vendor (includes vendor overrides + platform defaults). List all available pages for a vendor (includes vendor overrides + platform defaults).
@@ -133,12 +126,7 @@ class ContentPageService:
if vendor_id: if vendor_id:
vendor_pages = ( vendor_pages = (
db.query(ContentPage) db.query(ContentPage)
.filter( .filter(and_(ContentPage.vendor_id == vendor_id, *filters))
and_(
ContentPage.vendor_id == vendor_id,
*filters
)
)
.order_by(ContentPage.display_order, ContentPage.title) .order_by(ContentPage.display_order, ContentPage.title)
.all() .all()
) )
@@ -146,12 +134,7 @@ class ContentPageService:
# Get platform defaults # Get platform defaults
platform_pages = ( platform_pages = (
db.query(ContentPage) db.query(ContentPage)
.filter( .filter(and_(ContentPage.vendor_id == None, *filters))
and_(
ContentPage.vendor_id == None,
*filters
)
)
.order_by(ContentPage.display_order, ContentPage.title) .order_by(ContentPage.display_order, ContentPage.title)
.all() .all()
) )
@@ -159,8 +142,7 @@ class ContentPageService:
# Merge: vendor overrides take precedence # Merge: vendor overrides take precedence
vendor_slugs = {page.slug for page in vendor_pages} vendor_slugs = {page.slug for page in vendor_pages}
all_pages = vendor_pages + [ all_pages = vendor_pages + [
page for page in platform_pages page for page in platform_pages if page.slug not in vendor_slugs
if page.slug not in vendor_slugs
] ]
# Sort by display_order # Sort by display_order
@@ -183,7 +165,7 @@ class ContentPageService:
show_in_footer: bool = True, show_in_footer: bool = True,
show_in_header: bool = False, show_in_header: bool = False,
display_order: int = 0, display_order: int = 0,
created_by: Optional[int] = None created_by: Optional[int] = None,
) -> ContentPage: ) -> ContentPage:
""" """
Create a new content page. Create a new content page.
@@ -229,7 +211,9 @@ class ContentPageService:
db.commit() db.commit()
db.refresh(page) db.refresh(page)
logger.info(f"Created content page: {slug} (vendor_id={vendor_id}, id={page.id})") logger.info(
f"Created content page: {slug} (vendor_id={vendor_id}, id={page.id})"
)
return page return page
@staticmethod @staticmethod
@@ -246,7 +230,7 @@ class ContentPageService:
show_in_footer: Optional[bool] = None, show_in_footer: Optional[bool] = None,
show_in_header: Optional[bool] = None, show_in_header: Optional[bool] = None,
display_order: Optional[int] = None, display_order: Optional[int] = None,
updated_by: Optional[int] = None updated_by: Optional[int] = None,
) -> Optional[ContentPage]: ) -> Optional[ContentPage]:
""" """
Update an existing content page. Update an existing content page.
@@ -338,9 +322,7 @@ class ContentPageService:
@staticmethod @staticmethod
def list_all_vendor_pages( def list_all_vendor_pages(
db: Session, db: Session, vendor_id: int, include_unpublished: bool = False
vendor_id: int,
include_unpublished: bool = False
) -> List[ContentPage]: ) -> List[ContentPage]:
""" """
List only vendor-specific pages (no platform defaults). List only vendor-specific pages (no platform defaults).
@@ -367,8 +349,7 @@ class ContentPageService:
@staticmethod @staticmethod
def list_all_platform_pages( def list_all_platform_pages(
db: Session, db: Session, include_unpublished: bool = False
include_unpublished: bool = False
) -> List[ContentPage]: ) -> List[ContentPage]:
""" """
List only platform default pages. List only platform default pages.

View File

@@ -8,24 +8,24 @@ with complete vendor isolation.
import logging import logging
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Optional, Dict, Any from typing import Any, Dict, Optional
from sqlalchemy.orm import Session
from sqlalchemy import and_
from sqlalchemy import and_
from sqlalchemy.orm import Session
from app.exceptions.customer import (CustomerAlreadyExistsException,
CustomerNotActiveException,
CustomerNotFoundException,
CustomerValidationException,
DuplicateCustomerEmailException,
InvalidCustomerCredentialsException)
from app.exceptions.vendor import (VendorNotActiveException,
VendorNotFoundException)
from app.services.auth_service import AuthService
from models.database.customer import Customer, CustomerAddress from models.database.customer import Customer, CustomerAddress
from models.database.vendor import Vendor from models.database.vendor import Vendor
from models.schema.customer import CustomerRegister, CustomerUpdate
from models.schema.auth import UserLogin from models.schema.auth import UserLogin
from app.exceptions.customer import ( from models.schema.customer import CustomerRegister, CustomerUpdate
CustomerNotFoundException,
CustomerAlreadyExistsException,
CustomerNotActiveException,
InvalidCustomerCredentialsException,
CustomerValidationException,
DuplicateCustomerEmailException
)
from app.exceptions.vendor import VendorNotFoundException, VendorNotActiveException
from app.services.auth_service import AuthService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -37,10 +37,7 @@ class CustomerService:
self.auth_service = AuthService() self.auth_service = AuthService()
def register_customer( def register_customer(
self, self, db: Session, vendor_id: int, customer_data: CustomerRegister
db: Session,
vendor_id: int,
customer_data: CustomerRegister
) -> Customer: ) -> Customer:
""" """
Register a new customer for a specific vendor. Register a new customer for a specific vendor.
@@ -68,18 +65,26 @@ class CustomerService:
raise VendorNotActiveException(vendor.vendor_code) raise VendorNotActiveException(vendor.vendor_code)
# Check if email already exists for this vendor # Check if email already exists for this vendor
existing_customer = db.query(Customer).filter( existing_customer = (
db.query(Customer)
.filter(
and_( and_(
Customer.vendor_id == vendor_id, Customer.vendor_id == vendor_id,
Customer.email == customer_data.email.lower() Customer.email == customer_data.email.lower(),
)
)
.first()
) )
).first()
if existing_customer: if existing_customer:
raise DuplicateCustomerEmailException(customer_data.email, vendor.vendor_code) raise DuplicateCustomerEmailException(
customer_data.email, vendor.vendor_code
)
# Generate unique customer number for this vendor # Generate unique customer number for this vendor
customer_number = self._generate_customer_number(db, vendor_id, vendor.vendor_code) customer_number = self._generate_customer_number(
db, vendor_id, vendor.vendor_code
)
# Hash password # Hash password
hashed_password = self.auth_service.hash_password(customer_data.password) hashed_password = self.auth_service.hash_password(customer_data.password)
@@ -93,8 +98,12 @@ class CustomerService:
last_name=customer_data.last_name, last_name=customer_data.last_name,
phone=customer_data.phone, phone=customer_data.phone,
customer_number=customer_number, customer_number=customer_number,
marketing_consent=customer_data.marketing_consent if hasattr(customer_data, 'marketing_consent') else False, marketing_consent=(
is_active=True customer_data.marketing_consent
if hasattr(customer_data, "marketing_consent")
else False
),
is_active=True,
) )
try: try:
@@ -114,15 +123,11 @@ class CustomerService:
db.rollback() db.rollback()
logger.error(f"Error registering customer: {str(e)}") logger.error(f"Error registering customer: {str(e)}")
raise CustomerValidationException( raise CustomerValidationException(
message="Failed to register customer", message="Failed to register customer", details={"error": str(e)}
details={"error": str(e)}
) )
def login_customer( def login_customer(
self, self, db: Session, vendor_id: int, credentials: UserLogin
db: Session,
vendor_id: int,
credentials: UserLogin
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Authenticate customer and generate JWT token. Authenticate customer and generate JWT token.
@@ -146,20 +151,23 @@ class CustomerService:
raise VendorNotFoundException(str(vendor_id), identifier_type="id") raise VendorNotFoundException(str(vendor_id), identifier_type="id")
# Find customer by email (vendor-scoped) # Find customer by email (vendor-scoped)
customer = db.query(Customer).filter( customer = (
db.query(Customer)
.filter(
and_( and_(
Customer.vendor_id == vendor_id, Customer.vendor_id == vendor_id,
Customer.email == credentials.email_or_username.lower() Customer.email == credentials.email_or_username.lower(),
)
)
.first()
) )
).first()
if not customer: if not customer:
raise InvalidCustomerCredentialsException() raise InvalidCustomerCredentialsException()
# Verify password using auth_manager directly # Verify password using auth_manager directly
if not self.auth_service.auth_manager.verify_password( if not self.auth_service.auth_manager.verify_password(
credentials.password, credentials.password, customer.hashed_password
customer.hashed_password
): ):
raise InvalidCustomerCredentialsException() raise InvalidCustomerCredentialsException()
@@ -170,6 +178,7 @@ class CustomerService:
# Generate JWT token with customer context # Generate JWT token with customer context
# Use auth_manager directly since Customer is not a User model # Use auth_manager directly since Customer is not a User model
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from jose import jwt from jose import jwt
auth_manager = self.auth_service.auth_manager auth_manager = self.auth_service.auth_manager
@@ -185,7 +194,9 @@ class CustomerService:
"iat": datetime.now(timezone.utc), "iat": datetime.now(timezone.utc),
} }
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm) token = jwt.encode(
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
)
token_data = { token_data = {
"access_token": token, "access_token": token,
@@ -198,17 +209,9 @@ class CustomerService:
f"for vendor {vendor.vendor_code}" f"for vendor {vendor.vendor_code}"
) )
return { return {"customer": customer, "token_data": token_data}
"customer": customer,
"token_data": token_data
}
def get_customer( def get_customer(self, db: Session, vendor_id: int, customer_id: int) -> Customer:
self,
db: Session,
vendor_id: int,
customer_id: int
) -> Customer:
""" """
Get customer by ID with vendor isolation. Get customer by ID with vendor isolation.
@@ -223,12 +226,11 @@ class CustomerService:
Raises: Raises:
CustomerNotFoundException: If customer not found CustomerNotFoundException: If customer not found
""" """
customer = db.query(Customer).filter( customer = (
and_( db.query(Customer)
Customer.id == customer_id, .filter(and_(Customer.id == customer_id, Customer.vendor_id == vendor_id))
Customer.vendor_id == vendor_id .first()
) )
).first()
if not customer: if not customer:
raise CustomerNotFoundException(str(customer_id)) raise CustomerNotFoundException(str(customer_id))
@@ -236,10 +238,7 @@ class CustomerService:
return customer return customer
def get_customer_by_email( def get_customer_by_email(
self, self, db: Session, vendor_id: int, email: str
db: Session,
vendor_id: int,
email: str
) -> Optional[Customer]: ) -> Optional[Customer]:
""" """
Get customer by email (vendor-scoped). Get customer by email (vendor-scoped).
@@ -252,19 +251,20 @@ class CustomerService:
Returns: Returns:
Optional[Customer]: Customer object or None Optional[Customer]: Customer object or None
""" """
return db.query(Customer).filter( return (
and_( db.query(Customer)
Customer.vendor_id == vendor_id, .filter(
Customer.email == email.lower() and_(Customer.vendor_id == vendor_id, Customer.email == email.lower())
)
.first()
) )
).first()
def update_customer( def update_customer(
self, self,
db: Session, db: Session,
vendor_id: int, vendor_id: int,
customer_id: int, customer_id: int,
customer_data: CustomerUpdate customer_data: CustomerUpdate,
) -> Customer: ) -> Customer:
""" """
Update customer profile. Update customer profile.
@@ -290,13 +290,17 @@ class CustomerService:
for field, value in update_data.items(): for field, value in update_data.items():
if field == "email" and value: if field == "email" and value:
# Check if new email already exists for this vendor # Check if new email already exists for this vendor
existing = db.query(Customer).filter( existing = (
db.query(Customer)
.filter(
and_( and_(
Customer.vendor_id == vendor_id, Customer.vendor_id == vendor_id,
Customer.email == value.lower(), Customer.email == value.lower(),
Customer.id != customer_id Customer.id != customer_id,
)
)
.first()
) )
).first()
if existing: if existing:
raise DuplicateCustomerEmailException(value, "vendor") raise DuplicateCustomerEmailException(value, "vendor")
@@ -317,15 +321,11 @@ class CustomerService:
db.rollback() db.rollback()
logger.error(f"Error updating customer: {str(e)}") logger.error(f"Error updating customer: {str(e)}")
raise CustomerValidationException( raise CustomerValidationException(
message="Failed to update customer", message="Failed to update customer", details={"error": str(e)}
details={"error": str(e)}
) )
def deactivate_customer( def deactivate_customer(
self, self, db: Session, vendor_id: int, customer_id: int
db: Session,
vendor_id: int,
customer_id: int
) -> Customer: ) -> Customer:
""" """
Deactivate customer account. Deactivate customer account.
@@ -352,10 +352,7 @@ class CustomerService:
return customer return customer
def update_customer_stats( def update_customer_stats(
self, self, db: Session, customer_id: int, order_total: float
db: Session,
customer_id: int,
order_total: float
) -> None: ) -> None:
""" """
Update customer statistics after order. Update customer statistics after order.
@@ -377,10 +374,7 @@ class CustomerService:
logger.debug(f"Updated stats for customer {customer.email}") logger.debug(f"Updated stats for customer {customer.email}")
def _generate_customer_number( def _generate_customer_number(
self, self, db: Session, vendor_id: int, vendor_code: str
db: Session,
vendor_id: int,
vendor_code: str
) -> str: ) -> str:
""" """
Generate unique customer number for vendor. Generate unique customer number for vendor.
@@ -397,21 +391,23 @@ class CustomerService:
str: Unique customer number str: Unique customer number
""" """
# Get count of customers for this vendor # Get count of customers for this vendor
count = db.query(Customer).filter( count = db.query(Customer).filter(Customer.vendor_id == vendor_id).count()
Customer.vendor_id == vendor_id
).count()
# Generate number with padding # Generate number with padding
sequence = str(count + 1).zfill(5) sequence = str(count + 1).zfill(5)
customer_number = f"{vendor_code.upper()}-CUST-{sequence}" customer_number = f"{vendor_code.upper()}-CUST-{sequence}"
# Ensure uniqueness (in case of deletions) # Ensure uniqueness (in case of deletions)
while db.query(Customer).filter( while (
db.query(Customer)
.filter(
and_( and_(
Customer.vendor_id == vendor_id, Customer.vendor_id == vendor_id,
Customer.customer_number == customer_number Customer.customer_number == customer_number,
) )
).first(): )
.first()
):
count += 1 count += 1
sequence = str(count + 1).zfill(5) sequence = str(count + 1).zfill(5)
customer_number = f"{vendor_code.upper()}-CUST-{sequence}" customer_number = f"{vendor_code.upper()}-CUST-{sequence}"

View File

@@ -5,27 +5,20 @@ from typing import List, Optional
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.exceptions import ( from app.exceptions import (InsufficientInventoryException,
InventoryNotFoundException,
InsufficientInventoryException,
InvalidInventoryOperationException, InvalidInventoryOperationException,
InvalidQuantityException,
InventoryNotFoundException,
InventoryValidationException, InventoryValidationException,
NegativeInventoryException, NegativeInventoryException,
InvalidQuantityException, ProductNotFoundException, ValidationException)
ValidationException,
ProductNotFoundException,
)
from models.schema.inventory import (
InventoryCreate,
InventoryAdjust,
InventoryUpdate,
InventoryReserve,
InventoryLocationResponse,
ProductInventorySummary
)
from models.database.inventory import Inventory from models.database.inventory import Inventory
from models.database.product import Product from models.database.product import Product
from models.database.vendor import Vendor from models.database.vendor import Vendor
from models.schema.inventory import (InventoryAdjust, InventoryCreate,
InventoryLocationResponse,
InventoryReserve, InventoryUpdate,
ProductInventorySummary)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -93,7 +86,11 @@ class InventoryService:
) )
return new_inventory return new_inventory
except (ProductNotFoundException, InvalidQuantityException, InventoryValidationException): except (
ProductNotFoundException,
InvalidQuantityException,
InventoryValidationException,
):
db.rollback() db.rollback()
raise raise
except Exception as e: except Exception as e:
@@ -124,7 +121,9 @@ class InventoryService:
location = self._validate_location(inventory_data.location) location = self._validate_location(inventory_data.location)
# Check if inventory exists # Check if inventory exists
existing = self._get_inventory_entry(db, inventory_data.product_id, location) existing = self._get_inventory_entry(
db, inventory_data.product_id, location
)
if not existing: if not existing:
# Create new if adding, error if removing # Create new if adding, error if removing
@@ -173,8 +172,12 @@ class InventoryService:
) )
return existing return existing
except (ProductNotFoundException, InventoryNotFoundException, except (
InsufficientInventoryException, InventoryValidationException): ProductNotFoundException,
InventoryNotFoundException,
InsufficientInventoryException,
InventoryValidationException,
):
db.rollback() db.rollback()
raise raise
except Exception as e: except Exception as e:
@@ -231,8 +234,12 @@ class InventoryService:
) )
return inventory return inventory
except (ProductNotFoundException, InventoryNotFoundException, except (
InsufficientInventoryException, InvalidQuantityException): ProductNotFoundException,
InventoryNotFoundException,
InsufficientInventoryException,
InvalidQuantityException,
):
db.rollback() db.rollback()
raise raise
except Exception as e: except Exception as e:
@@ -287,7 +294,11 @@ class InventoryService:
) )
return inventory return inventory
except (ProductNotFoundException, InventoryNotFoundException, InvalidQuantityException): except (
ProductNotFoundException,
InventoryNotFoundException,
InvalidQuantityException,
):
db.rollback() db.rollback()
raise raise
except Exception as e: except Exception as e:
@@ -349,8 +360,12 @@ class InventoryService:
) )
return inventory return inventory
except (ProductNotFoundException, InventoryNotFoundException, except (
InsufficientInventoryException, InvalidQuantityException): ProductNotFoundException,
InventoryNotFoundException,
InsufficientInventoryException,
InvalidQuantityException,
):
db.rollback() db.rollback()
raise raise
except Exception as e: except Exception as e:
@@ -376,9 +391,7 @@ class InventoryService:
product = self._get_vendor_product(db, vendor_id, product_id) product = self._get_vendor_product(db, vendor_id, product_id)
inventory_entries = ( inventory_entries = (
db.query(Inventory) db.query(Inventory).filter(Inventory.product_id == product_id).all()
.filter(Inventory.product_id == product_id)
.all()
) )
if not inventory_entries: if not inventory_entries:
@@ -425,8 +438,13 @@ class InventoryService:
raise ValidationException("Failed to retrieve product inventory") raise ValidationException("Failed to retrieve product inventory")
def get_vendor_inventory( def get_vendor_inventory(
self, db: Session, vendor_id: int, skip: int = 0, limit: int = 100, self,
location: Optional[str] = None, low_stock_threshold: Optional[int] = None db: Session,
vendor_id: int,
skip: int = 0,
limit: int = 100,
location: Optional[str] = None,
low_stock_threshold: Optional[int] = None,
) -> List[Inventory]: ) -> List[Inventory]:
""" """
Get all inventory for a vendor with filtering. Get all inventory for a vendor with filtering.
@@ -458,8 +476,11 @@ class InventoryService:
raise ValidationException("Failed to retrieve vendor inventory") raise ValidationException("Failed to retrieve vendor inventory")
def update_inventory( def update_inventory(
self, db: Session, vendor_id: int, inventory_id: int, self,
inventory_update: InventoryUpdate db: Session,
vendor_id: int,
inventory_id: int,
inventory_update: InventoryUpdate,
) -> Inventory: ) -> Inventory:
"""Update inventory entry.""" """Update inventory entry."""
try: try:
@@ -475,7 +496,9 @@ class InventoryService:
inventory.quantity = inventory_update.quantity inventory.quantity = inventory_update.quantity
if inventory_update.reserved_quantity is not None: if inventory_update.reserved_quantity is not None:
self._validate_quantity(inventory_update.reserved_quantity, allow_zero=True) self._validate_quantity(
inventory_update.reserved_quantity, allow_zero=True
)
inventory.reserved_quantity = inventory_update.reserved_quantity inventory.reserved_quantity = inventory_update.reserved_quantity
if inventory_update.location: if inventory_update.location:
@@ -488,7 +511,11 @@ class InventoryService:
logger.info(f"Updated inventory {inventory_id}") logger.info(f"Updated inventory {inventory_id}")
return inventory return inventory
except (InventoryNotFoundException, InvalidQuantityException, InventoryValidationException): except (
InventoryNotFoundException,
InvalidQuantityException,
InventoryValidationException,
):
db.rollback() db.rollback()
raise raise
except Exception as e: except Exception as e:
@@ -496,9 +523,7 @@ class InventoryService:
logger.error(f"Error updating inventory: {str(e)}") logger.error(f"Error updating inventory: {str(e)}")
raise ValidationException("Failed to update inventory") raise ValidationException("Failed to update inventory")
def delete_inventory( def delete_inventory(self, db: Session, vendor_id: int, inventory_id: int) -> bool:
self, db: Session, vendor_id: int, inventory_id: int
) -> bool:
"""Delete inventory entry.""" """Delete inventory entry."""
try: try:
inventory = self._get_inventory_by_id(db, inventory_id) inventory = self._get_inventory_by_id(db, inventory_id)
@@ -521,15 +546,20 @@ class InventoryService:
raise ValidationException("Failed to delete inventory") raise ValidationException("Failed to delete inventory")
# Private helper methods # Private helper methods
def _get_vendor_product(self, db: Session, vendor_id: int, product_id: int) -> Product: def _get_vendor_product(
self, db: Session, vendor_id: int, product_id: int
) -> Product:
"""Get product and verify it belongs to vendor.""" """Get product and verify it belongs to vendor."""
product = db.query(Product).filter( product = (
Product.id == product_id, db.query(Product)
Product.vendor_id == vendor_id .filter(Product.id == product_id, Product.vendor_id == vendor_id)
).first() .first()
)
if not product: if not product:
raise ProductNotFoundException(f"Product {product_id} not found in your catalog") raise ProductNotFoundException(
f"Product {product_id} not found in your catalog"
)
return product return product
@@ -539,10 +569,7 @@ class InventoryService:
"""Get inventory entry by product and location.""" """Get inventory entry by product and location."""
return ( return (
db.query(Inventory) db.query(Inventory)
.filter( .filter(Inventory.product_id == product_id, Inventory.location == location)
Inventory.product_id == product_id,
Inventory.location == location
)
.first() .first()
) )

View File

@@ -5,20 +5,15 @@ from typing import List, Optional
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.exceptions import ( from app.exceptions import (ImportJobCannotBeCancelledException,
ImportJobNotFoundException,
ImportJobNotOwnedException,
ImportJobCannotBeCancelledException,
ImportJobCannotBeDeletedException, ImportJobCannotBeDeletedException,
ValidationException, ImportJobNotFoundException,
) ImportJobNotOwnedException, ValidationException)
from models.schema.marketplace_import_job import (
MarketplaceImportJobResponse,
MarketplaceImportJobRequest
)
from models.database.marketplace_import_job import MarketplaceImportJob from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.vendor import Vendor
from models.database.user import User from models.database.user import User
from models.database.vendor import Vendor
from models.schema.marketplace_import_job import (MarketplaceImportJobRequest,
MarketplaceImportJobResponse)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -31,7 +26,7 @@ class MarketplaceImportJobService:
db: Session, db: Session,
request: MarketplaceImportJobRequest, request: MarketplaceImportJobRequest,
vendor: Vendor, # CHANGED: Vendor object from middleware vendor: Vendor, # CHANGED: Vendor object from middleware
user: User user: User,
) -> MarketplaceImportJob: ) -> MarketplaceImportJob:
""" """
Create a new marketplace import job. Create a new marketplace import job.
@@ -147,7 +142,9 @@ class MarketplaceImportJobService:
marketplace=job.marketplace, marketplace=job.marketplace,
vendor_id=job.vendor_id, vendor_id=job.vendor_id,
vendor_code=job.vendor.vendor_code if job.vendor else None, # FIXED vendor_code=job.vendor.vendor_code if job.vendor else None, # FIXED
vendor_name=job.vendor.name if job.vendor else None, # FIXED: from relationship vendor_name=(
job.vendor.name if job.vendor else None
), # FIXED: from relationship
source_url=job.source_url, source_url=job.source_url,
imported=job.imported_count or 0, imported=job.imported_count or 0,
updated=job.updated_count or 0, updated=job.updated_count or 0,

View File

@@ -17,19 +17,20 @@ from typing import Generator, List, Optional, Tuple
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.exceptions import ( from app.exceptions import (InvalidMarketplaceProductDataException,
MarketplaceProductNotFoundException,
MarketplaceProductAlreadyExistsException, MarketplaceProductAlreadyExistsException,
InvalidMarketplaceProductDataException, MarketplaceProductNotFoundException,
MarketplaceProductValidationException, MarketplaceProductValidationException,
ValidationException, ValidationException)
) from app.services.marketplace_import_job_service import \
from app.services.marketplace_import_job_service import marketplace_import_job_service marketplace_import_job_service
from models.schema.marketplace_product import MarketplaceProductCreate, MarketplaceProductUpdate
from models.schema.inventory import InventoryLocationResponse, InventorySummaryResponse
from models.database.marketplace_product import MarketplaceProduct
from models.database.inventory import Inventory
from app.utils.data_processing import GTINProcessor, PriceProcessor from app.utils.data_processing import GTINProcessor, PriceProcessor
from models.database.inventory import Inventory
from models.database.marketplace_product import MarketplaceProduct
from models.schema.inventory import (InventoryLocationResponse,
InventorySummaryResponse)
from models.schema.marketplace_product import (MarketplaceProductCreate,
MarketplaceProductUpdate)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -42,14 +43,18 @@ class MarketplaceProductService:
self.gtin_processor = GTINProcessor() self.gtin_processor = GTINProcessor()
self.price_processor = PriceProcessor() self.price_processor = PriceProcessor()
def create_product(self, db: Session, product_data: MarketplaceProductCreate) -> MarketplaceProduct: def create_product(
self, db: Session, product_data: MarketplaceProductCreate
) -> MarketplaceProduct:
"""Create a new product with validation.""" """Create a new product with validation."""
try: try:
# Process and validate GTIN if provided # Process and validate GTIN if provided
if product_data.gtin: if product_data.gtin:
normalized_gtin = self.gtin_processor.normalize(product_data.gtin) normalized_gtin = self.gtin_processor.normalize(product_data.gtin)
if not normalized_gtin: if not normalized_gtin:
raise InvalidMarketplaceProductDataException("Invalid GTIN format", field="gtin") raise InvalidMarketplaceProductDataException(
"Invalid GTIN format", field="gtin"
)
product_data.gtin = normalized_gtin product_data.gtin = normalized_gtin
# Process price if provided # Process price if provided
@@ -70,11 +75,18 @@ class MarketplaceProductService:
product_data.marketplace = "Letzshop" product_data.marketplace = "Letzshop"
# Validate required fields # Validate required fields
if not product_data.marketplace_product_id or not product_data.marketplace_product_id.strip(): if (
raise MarketplaceProductValidationException("MarketplaceProduct ID is required", field="marketplace_product_id") not product_data.marketplace_product_id
or not product_data.marketplace_product_id.strip()
):
raise MarketplaceProductValidationException(
"MarketplaceProduct ID is required", field="marketplace_product_id"
)
if not product_data.title or not product_data.title.strip(): if not product_data.title or not product_data.title.strip():
raise MarketplaceProductValidationException("MarketplaceProduct title is required", field="title") raise MarketplaceProductValidationException(
"MarketplaceProduct title is required", field="title"
)
db_product = MarketplaceProduct(**product_data.model_dump()) db_product = MarketplaceProduct(**product_data.model_dump())
db.add(db_product) db.add(db_product)
@@ -84,30 +96,47 @@ class MarketplaceProductService:
logger.info(f"Created product {db_product.marketplace_product_id}") logger.info(f"Created product {db_product.marketplace_product_id}")
return db_product return db_product
except (InvalidMarketplaceProductDataException, MarketplaceProductValidationException): except (
InvalidMarketplaceProductDataException,
MarketplaceProductValidationException,
):
db.rollback() db.rollback()
raise # Re-raise custom exceptions raise # Re-raise custom exceptions
except IntegrityError as e: except IntegrityError as e:
db.rollback() db.rollback()
logger.error(f"Database integrity error: {str(e)}") logger.error(f"Database integrity error: {str(e)}")
if "marketplace_product_id" in str(e).lower() or "unique" in str(e).lower(): if "marketplace_product_id" in str(e).lower() or "unique" in str(e).lower():
raise MarketplaceProductAlreadyExistsException(product_data.marketplace_product_id) raise MarketplaceProductAlreadyExistsException(
product_data.marketplace_product_id
)
else: else:
raise MarketplaceProductValidationException("Data integrity constraint violation") raise MarketplaceProductValidationException(
"Data integrity constraint violation"
)
except Exception as e: except Exception as e:
db.rollback() db.rollback()
logger.error(f"Error creating product: {str(e)}") logger.error(f"Error creating product: {str(e)}")
raise ValidationException("Failed to create product") raise ValidationException("Failed to create product")
def get_product_by_id(self, db: Session, marketplace_product_id: str) -> Optional[MarketplaceProduct]: def get_product_by_id(
self, db: Session, marketplace_product_id: str
) -> Optional[MarketplaceProduct]:
"""Get a product by its ID.""" """Get a product by its ID."""
try: try:
return db.query(MarketplaceProduct).filter(MarketplaceProduct.marketplace_product_id == marketplace_product_id).first() return (
db.query(MarketplaceProduct)
.filter(
MarketplaceProduct.marketplace_product_id == marketplace_product_id
)
.first()
)
except Exception as e: except Exception as e:
logger.error(f"Error getting product {marketplace_product_id}: {str(e)}") logger.error(f"Error getting product {marketplace_product_id}: {str(e)}")
return None return None
def get_product_by_id_or_raise(self, db: Session, marketplace_product_id: str) -> MarketplaceProduct: def get_product_by_id_or_raise(
self, db: Session, marketplace_product_id: str
) -> MarketplaceProduct:
""" """
Get a product by its ID or raise exception. Get a product by its ID or raise exception.
@@ -162,13 +191,19 @@ class MarketplaceProductService:
if brand: if brand:
query = query.filter(MarketplaceProduct.brand.ilike(f"%{brand}%")) query = query.filter(MarketplaceProduct.brand.ilike(f"%{brand}%"))
if category: if category:
query = query.filter(MarketplaceProduct.google_product_category.ilike(f"%{category}%")) query = query.filter(
MarketplaceProduct.google_product_category.ilike(f"%{category}%")
)
if availability: if availability:
query = query.filter(MarketplaceProduct.availability == availability) query = query.filter(MarketplaceProduct.availability == availability)
if marketplace: if marketplace:
query = query.filter(MarketplaceProduct.marketplace.ilike(f"%{marketplace}%")) query = query.filter(
MarketplaceProduct.marketplace.ilike(f"%{marketplace}%")
)
if vendor_name: if vendor_name:
query = query.filter(MarketplaceProduct.vendor_name.ilike(f"%{vendor_name}%")) query = query.filter(
MarketplaceProduct.vendor_name.ilike(f"%{vendor_name}%")
)
if search: if search:
# Search in title, description, marketplace, and name # Search in title, description, marketplace, and name
search_term = f"%{search}%" search_term = f"%{search}%"
@@ -188,7 +223,12 @@ class MarketplaceProductService:
logger.error(f"Error getting products with filters: {str(e)}") logger.error(f"Error getting products with filters: {str(e)}")
raise ValidationException("Failed to retrieve products") raise ValidationException("Failed to retrieve products")
def update_product(self, db: Session, marketplace_product_id: str, product_update: MarketplaceProductUpdate) -> MarketplaceProduct: def update_product(
self,
db: Session,
marketplace_product_id: str,
product_update: MarketplaceProductUpdate,
) -> MarketplaceProduct:
"""Update product with validation.""" """Update product with validation."""
try: try:
product = self.get_product_by_id_or_raise(db, marketplace_product_id) product = self.get_product_by_id_or_raise(db, marketplace_product_id)
@@ -200,7 +240,9 @@ class MarketplaceProductService:
if "gtin" in update_data and update_data["gtin"]: if "gtin" in update_data and update_data["gtin"]:
normalized_gtin = self.gtin_processor.normalize(update_data["gtin"]) normalized_gtin = self.gtin_processor.normalize(update_data["gtin"])
if not normalized_gtin: if not normalized_gtin:
raise InvalidMarketplaceProductDataException("Invalid GTIN format", field="gtin") raise InvalidMarketplaceProductDataException(
"Invalid GTIN format", field="gtin"
)
update_data["gtin"] = normalized_gtin update_data["gtin"] = normalized_gtin
# Process price if being updated # Process price if being updated
@@ -217,8 +259,12 @@ class MarketplaceProductService:
raise InvalidMarketplaceProductDataException(str(e), field="price") raise InvalidMarketplaceProductDataException(str(e), field="price")
# Validate required fields if being updated # Validate required fields if being updated
if "title" in update_data and (not update_data["title"] or not update_data["title"].strip()): if "title" in update_data and (
raise MarketplaceProductValidationException("MarketplaceProduct title cannot be empty", field="title") not update_data["title"] or not update_data["title"].strip()
):
raise MarketplaceProductValidationException(
"MarketplaceProduct title cannot be empty", field="title"
)
for key, value in update_data.items(): for key, value in update_data.items():
setattr(product, key, value) setattr(product, key, value)
@@ -230,7 +276,11 @@ class MarketplaceProductService:
logger.info(f"Updated product {marketplace_product_id}") logger.info(f"Updated product {marketplace_product_id}")
return product return product
except (MarketplaceProductNotFoundException, InvalidMarketplaceProductDataException, MarketplaceProductValidationException): except (
MarketplaceProductNotFoundException,
InvalidMarketplaceProductDataException,
MarketplaceProductValidationException,
):
db.rollback() db.rollback()
raise # Re-raise custom exceptions raise # Re-raise custom exceptions
except Exception as e: except Exception as e:
@@ -272,7 +322,9 @@ class MarketplaceProductService:
logger.error(f"Error deleting product {marketplace_product_id}: {str(e)}") logger.error(f"Error deleting product {marketplace_product_id}: {str(e)}")
raise ValidationException("Failed to delete product") raise ValidationException("Failed to delete product")
def get_inventory_info(self, db: Session, gtin: str) -> Optional[InventorySummaryResponse]: def get_inventory_info(
self, db: Session, gtin: str
) -> Optional[InventorySummaryResponse]:
""" """
Get inventory information for a product by GTIN. Get inventory information for a product by GTIN.
@@ -290,7 +342,9 @@ class MarketplaceProductService:
total_quantity = sum(entry.quantity for entry in inventory_entries) total_quantity = sum(entry.quantity for entry in inventory_entries)
locations = [ locations = [
InventoryLocationResponse(location=entry.location, quantity=entry.quantity) InventoryLocationResponse(
location=entry.location, quantity=entry.quantity
)
for entry in inventory_entries for entry in inventory_entries
] ]
@@ -305,6 +359,7 @@ class MarketplaceProductService:
import csv import csv
from io import StringIO from io import StringIO
from typing import Generator, Optional from typing import Generator, Optional
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
def generate_csv_export( def generate_csv_export(
@@ -331,9 +386,18 @@ class MarketplaceProductService:
# Write header row # Write header row
headers = [ headers = [
"marketplace_product_id", "title", "description", "link", "image_link", "marketplace_product_id",
"availability", "price", "currency", "brand", "gtin", "title",
"marketplace", "name" "description",
"link",
"image_link",
"availability",
"price",
"currency",
"brand",
"gtin",
"marketplace",
"name",
] ]
writer.writerow(headers) writer.writerow(headers)
yield output.getvalue() yield output.getvalue()
@@ -350,9 +414,13 @@ class MarketplaceProductService:
# Apply marketplace filters # Apply marketplace filters
if marketplace: if marketplace:
query = query.filter(MarketplaceProduct.marketplace.ilike(f"%{marketplace}%")) query = query.filter(
MarketplaceProduct.marketplace.ilike(f"%{marketplace}%")
)
if vendor_name: if vendor_name:
query = query.filter(MarketplaceProduct.vendor_name.ilike(f"%{vendor_name}%")) query = query.filter(
MarketplaceProduct.vendor_name.ilike(f"%{vendor_name}%")
)
products = query.offset(offset).limit(batch_size).all() products = query.offset(offset).limit(batch_size).all()
if not products: if not products:
@@ -392,7 +460,11 @@ class MarketplaceProductService:
"""Check if product exists by ID.""" """Check if product exists by ID."""
try: try:
return ( return (
db.query(MarketplaceProduct).filter(MarketplaceProduct.marketplace_product_id == marketplace_product_id).first() db.query(MarketplaceProduct)
.filter(
MarketplaceProduct.marketplace_product_id == marketplace_product_id
)
.first()
is not None is not None
) )
except Exception as e: except Exception as e:
@@ -402,18 +474,27 @@ class MarketplaceProductService:
# Private helper methods # Private helper methods
def _validate_product_data(self, product_data: dict) -> None: def _validate_product_data(self, product_data: dict) -> None:
"""Validate product data structure.""" """Validate product data structure."""
required_fields = ['marketplace_product_id', 'title'] required_fields = ["marketplace_product_id", "title"]
for field in required_fields: for field in required_fields:
if field not in product_data or not product_data[field]: if field not in product_data or not product_data[field]:
raise MarketplaceProductValidationException(f"{field} is required", field=field) raise MarketplaceProductValidationException(
f"{field} is required", field=field
)
def _normalize_product_data(self, product_data: dict) -> dict: def _normalize_product_data(self, product_data: dict) -> dict:
"""Normalize and clean product data.""" """Normalize and clean product data."""
normalized = product_data.copy() normalized = product_data.copy()
# Trim whitespace from string fields # Trim whitespace from string fields
string_fields = ['marketplace_product_id', 'title', 'description', 'brand', 'marketplace', 'name'] string_fields = [
"marketplace_product_id",
"title",
"description",
"brand",
"marketplace",
"name",
]
for field in string_fields: for field in string_fields:
if field in normalized and normalized[field]: if field in normalized and normalized[field]:
normalized[field] = normalized[field].strip() normalized[field] = normalized[field].strip()

View File

@@ -9,24 +9,21 @@ This module provides:
""" """
import logging import logging
from datetime import datetime, timezone
from typing import List, Optional, Tuple
import random import random
import string import string
from datetime import datetime, timezone
from typing import List, Optional, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_ from sqlalchemy import and_, or_
from sqlalchemy.orm import Session
from models.database.order import Order, OrderItem from app.exceptions import (CustomerNotFoundException,
from models.database.customer import Customer, CustomerAddress
from models.database.product import Product
from models.schema.order import OrderCreate, OrderUpdate, OrderAddressCreate
from app.exceptions import (
OrderNotFoundException,
ValidationException,
InsufficientInventoryException, InsufficientInventoryException,
CustomerNotFoundException OrderNotFoundException, ValidationException)
) from models.database.customer import Customer, CustomerAddress
from models.database.order import Order, OrderItem
from models.database.product import Product
from models.schema.order import OrderAddressCreate, OrderCreate, OrderUpdate
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -42,12 +39,16 @@ class OrderService:
Example: ORD-1-20250110-A1B2C3 Example: ORD-1-20250110-A1B2C3
""" """
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d") timestamp = datetime.now(timezone.utc).strftime("%Y%m%d")
random_suffix = ''.join(random.choices(string.ascii_uppercase + string.digits, k=6)) random_suffix = "".join(
random.choices(string.ascii_uppercase + string.digits, k=6)
)
order_number = f"ORD-{vendor_id}-{timestamp}-{random_suffix}" order_number = f"ORD-{vendor_id}-{timestamp}-{random_suffix}"
# Ensure uniqueness # Ensure uniqueness
while db.query(Order).filter(Order.order_number == order_number).first(): while db.query(Order).filter(Order.order_number == order_number).first():
random_suffix = ''.join(random.choices(string.ascii_uppercase + string.digits, k=6)) random_suffix = "".join(
random.choices(string.ascii_uppercase + string.digits, k=6)
)
order_number = f"ORD-{vendor_id}-{timestamp}-{random_suffix}" order_number = f"ORD-{vendor_id}-{timestamp}-{random_suffix}"
return order_number return order_number
@@ -58,7 +59,7 @@ class OrderService:
vendor_id: int, vendor_id: int,
customer_id: int, customer_id: int,
address_data: OrderAddressCreate, address_data: OrderAddressCreate,
address_type: str address_type: str,
) -> CustomerAddress: ) -> CustomerAddress:
"""Create a customer address for order.""" """Create a customer address for order."""
address = CustomerAddress( address = CustomerAddress(
@@ -73,17 +74,14 @@ class OrderService:
city=address_data.city, city=address_data.city,
postal_code=address_data.postal_code, postal_code=address_data.postal_code,
country=address_data.country, country=address_data.country,
is_default=False is_default=False,
) )
db.add(address) db.add(address)
db.flush() # Get ID without committing db.flush() # Get ID without committing
return address return address
def create_order( def create_order(
self, self, db: Session, vendor_id: int, order_data: OrderCreate
db: Session,
vendor_id: int,
order_data: OrderCreate
) -> Order: ) -> Order:
""" """
Create a new order. Create a new order.
@@ -104,12 +102,15 @@ class OrderService:
# Validate customer exists if provided # Validate customer exists if provided
customer_id = order_data.customer_id customer_id = order_data.customer_id
if customer_id: if customer_id:
customer = db.query(Customer).filter( customer = (
db.query(Customer)
.filter(
and_( and_(
Customer.id == customer_id, Customer.id == customer_id, Customer.vendor_id == vendor_id
Customer.vendor_id == vendor_id )
)
.first()
) )
).first()
if not customer: if not customer:
raise CustomerNotFoundException(str(customer_id)) raise CustomerNotFoundException(str(customer_id))
@@ -124,7 +125,7 @@ class OrderService:
vendor_id=vendor_id, vendor_id=vendor_id,
customer_id=customer_id, customer_id=customer_id,
address_data=order_data.shipping_address, address_data=order_data.shipping_address,
address_type="shipping" address_type="shipping",
) )
# Create billing address (use shipping if not provided) # Create billing address (use shipping if not provided)
@@ -134,7 +135,7 @@ class OrderService:
vendor_id=vendor_id, vendor_id=vendor_id,
customer_id=customer_id, customer_id=customer_id,
address_data=order_data.billing_address, address_data=order_data.billing_address,
address_type="billing" address_type="billing",
) )
else: else:
billing_address = shipping_address billing_address = shipping_address
@@ -145,23 +146,29 @@ class OrderService:
for item_data in order_data.items: for item_data in order_data.items:
# Get product # Get product
product = db.query(Product).filter( product = (
db.query(Product)
.filter(
and_( and_(
Product.id == item_data.product_id, Product.id == item_data.product_id,
Product.vendor_id == vendor_id, Product.vendor_id == vendor_id,
Product.is_active == True Product.is_active == True,
)
)
.first()
) )
).first()
if not product: if not product:
raise ValidationException(f"Product {item_data.product_id} not found") raise ValidationException(
f"Product {item_data.product_id} not found"
)
# Check inventory # Check inventory
if product.available_inventory < item_data.quantity: if product.available_inventory < item_data.quantity:
raise InsufficientInventoryException( raise InsufficientInventoryException(
product_id=product.id, product_id=product.id,
requested=item_data.quantity, requested=item_data.quantity,
available=product.available_inventory available=product.available_inventory,
) )
# Calculate item total # Calculate item total
@@ -172,14 +179,16 @@ class OrderService:
item_total = unit_price * item_data.quantity item_total = unit_price * item_data.quantity
subtotal += item_total subtotal += item_total
order_items_data.append({ order_items_data.append(
{
"product_id": product.id, "product_id": product.id,
"product_name": product.marketplace_product.title, "product_name": product.marketplace_product.title,
"product_sku": product.product_id, "product_sku": product.product_id,
"quantity": item_data.quantity, "quantity": item_data.quantity,
"unit_price": unit_price, "unit_price": unit_price,
"total_price": item_total "total_price": item_total,
}) }
)
# Calculate tax and shipping (simple implementation) # Calculate tax and shipping (simple implementation)
tax_amount = 0.0 # TODO: Implement tax calculation tax_amount = 0.0 # TODO: Implement tax calculation
@@ -205,7 +214,7 @@ class OrderService:
shipping_address_id=shipping_address.id, shipping_address_id=shipping_address.id,
billing_address_id=billing_address.id, billing_address_id=billing_address.id,
shipping_method=order_data.shipping_method, shipping_method=order_data.shipping_method,
customer_notes=order_data.customer_notes customer_notes=order_data.customer_notes,
) )
db.add(order) db.add(order)
@@ -213,10 +222,7 @@ class OrderService:
# Create order items # Create order items
for item_data in order_items_data: for item_data in order_items_data:
order_item = OrderItem( order_item = OrderItem(order_id=order.id, **item_data)
order_id=order.id,
**item_data
)
db.add(order_item) db.add(order_item)
db.commit() db.commit()
@@ -229,7 +235,11 @@ class OrderService:
return order return order
except (ValidationException, InsufficientInventoryException, CustomerNotFoundException): except (
ValidationException,
InsufficientInventoryException,
CustomerNotFoundException,
):
db.rollback() db.rollback()
raise raise
except Exception as e: except Exception as e:
@@ -237,19 +247,13 @@ class OrderService:
logger.error(f"Error creating order: {str(e)}") logger.error(f"Error creating order: {str(e)}")
raise ValidationException(f"Failed to create order: {str(e)}") raise ValidationException(f"Failed to create order: {str(e)}")
def get_order( def get_order(self, db: Session, vendor_id: int, order_id: int) -> Order:
self,
db: Session,
vendor_id: int,
order_id: int
) -> Order:
"""Get order by ID.""" """Get order by ID."""
order = db.query(Order).filter( order = (
and_( db.query(Order)
Order.id == order_id, .filter(and_(Order.id == order_id, Order.vendor_id == vendor_id))
Order.vendor_id == vendor_id .first()
) )
).first()
if not order: if not order:
raise OrderNotFoundException(str(order_id)) raise OrderNotFoundException(str(order_id))
@@ -263,7 +267,7 @@ class OrderService:
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
status: Optional[str] = None, status: Optional[str] = None,
customer_id: Optional[int] = None customer_id: Optional[int] = None,
) -> Tuple[List[Order], int]: ) -> Tuple[List[Order], int]:
""" """
Get orders for vendor with filtering. Get orders for vendor with filtering.
@@ -301,23 +305,15 @@ class OrderService:
vendor_id: int, vendor_id: int,
customer_id: int, customer_id: int,
skip: int = 0, skip: int = 0,
limit: int = 100 limit: int = 100,
) -> Tuple[List[Order], int]: ) -> Tuple[List[Order], int]:
"""Get orders for a specific customer.""" """Get orders for a specific customer."""
return self.get_vendor_orders( return self.get_vendor_orders(
db=db, db=db, vendor_id=vendor_id, skip=skip, limit=limit, customer_id=customer_id
vendor_id=vendor_id,
skip=skip,
limit=limit,
customer_id=customer_id
) )
def update_order_status( def update_order_status(
self, self, db: Session, vendor_id: int, order_id: int, order_update: OrderUpdate
db: Session,
vendor_id: int,
order_id: int,
order_update: OrderUpdate
) -> Order: ) -> Order:
""" """
Update order status and tracking information. Update order status and tracking information.

View File

@@ -14,14 +14,11 @@ from typing import List, Optional, Tuple
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.exceptions import ( from app.exceptions import (ProductAlreadyExistsException,
ProductNotFoundException, ProductNotFoundException, ValidationException)
ProductAlreadyExistsException,
ValidationException,
)
from models.schema.product import ProductCreate, ProductUpdate
from models.database.product import Product
from models.database.marketplace_product import MarketplaceProduct from models.database.marketplace_product import MarketplaceProduct
from models.database.product import Product
from models.schema.product import ProductCreate, ProductUpdate
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -45,10 +42,11 @@ class ProductService:
ProductNotFoundException: If product not found ProductNotFoundException: If product not found
""" """
try: try:
product = db.query(Product).filter( product = (
Product.id == product_id, db.query(Product)
Product.vendor_id == vendor_id .filter(Product.id == product_id, Product.vendor_id == vendor_id)
).first() .first()
)
if not product: if not product:
raise ProductNotFoundException(f"Product {product_id} not found") raise ProductNotFoundException(f"Product {product_id} not found")
@@ -81,10 +79,14 @@ class ProductService:
""" """
try: try:
# Verify marketplace product exists and belongs to vendor # Verify marketplace product exists and belongs to vendor
marketplace_product = db.query(MarketplaceProduct).filter( marketplace_product = (
db.query(MarketplaceProduct)
.filter(
MarketplaceProduct.id == product_data.marketplace_product_id, MarketplaceProduct.id == product_data.marketplace_product_id,
MarketplaceProduct.vendor_id == vendor_id MarketplaceProduct.vendor_id == vendor_id,
).first() )
.first()
)
if not marketplace_product: if not marketplace_product:
raise ValidationException( raise ValidationException(
@@ -92,10 +94,15 @@ class ProductService:
) )
# Check if already in catalog # Check if already in catalog
existing = db.query(Product).filter( existing = (
db.query(Product)
.filter(
Product.vendor_id == vendor_id, Product.vendor_id == vendor_id,
Product.marketplace_product_id == product_data.marketplace_product_id Product.marketplace_product_id
).first() == product_data.marketplace_product_id,
)
.first()
)
if existing: if existing:
raise ProductAlreadyExistsException( raise ProductAlreadyExistsException(
@@ -122,9 +129,7 @@ class ProductService:
db.commit() db.commit()
db.refresh(product) db.refresh(product)
logger.info( logger.info(f"Added product {product.id} to vendor {vendor_id} catalog")
f"Added product {product.id} to vendor {vendor_id} catalog"
)
return product return product
except (ProductAlreadyExistsException, ValidationException): except (ProductAlreadyExistsException, ValidationException):
@@ -136,7 +141,11 @@ class ProductService:
raise ValidationException("Failed to create product") raise ValidationException("Failed to create product")
def update_product( def update_product(
self, db: Session, vendor_id: int, product_id: int, product_update: ProductUpdate self,
db: Session,
vendor_id: int,
product_id: int,
product_update: ProductUpdate,
) -> Product: ) -> Product:
""" """
Update product in vendor catalog. Update product in vendor catalog.

View File

@@ -10,25 +10,21 @@ This module provides:
""" """
import logging import logging
from typing import Any, Dict, List
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Dict, List
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.exceptions import ( from app.exceptions import AdminOperationException, VendorNotFoundException
VendorNotFoundException,
AdminOperationException,
)
from models.database.marketplace_product import MarketplaceProduct
from models.database.product import Product
from models.database.inventory import Inventory
from models.database.vendor import Vendor
from models.database.order import Order
from models.database.user import User
from models.database.customer import Customer from models.database.customer import Customer
from models.database.inventory import Inventory
from models.database.marketplace_import_job import MarketplaceImportJob from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.marketplace_product import MarketplaceProduct
from models.database.order import Order
from models.database.product import Product
from models.database.user import User
from models.database.vendor import Vendor
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -62,63 +58,77 @@ class StatsService:
try: try:
# Catalog statistics # Catalog statistics
total_catalog_products = db.query(Product).filter( total_catalog_products = (
Product.vendor_id == vendor_id, db.query(Product)
Product.is_active == True .filter(Product.vendor_id == vendor_id, Product.is_active == True)
).count() .count()
)
featured_products = db.query(Product).filter( featured_products = (
db.query(Product)
.filter(
Product.vendor_id == vendor_id, Product.vendor_id == vendor_id,
Product.is_featured == True, Product.is_featured == True,
Product.is_active == True Product.is_active == True,
).count() )
.count()
)
# Staging statistics # Staging statistics
# TODO: This is fragile - MarketplaceProduct uses vendor_name (string) not vendor_id # TODO: This is fragile - MarketplaceProduct uses vendor_name (string) not vendor_id
# Should add vendor_id foreign key to MarketplaceProduct for robust querying # Should add vendor_id foreign key to MarketplaceProduct for robust querying
# For now, matching by vendor name which could fail if names don't match exactly # For now, matching by vendor name which could fail if names don't match exactly
staging_products = db.query(MarketplaceProduct).filter( staging_products = (
MarketplaceProduct.vendor_name == vendor.name db.query(MarketplaceProduct)
).count() .filter(MarketplaceProduct.vendor_name == vendor.name)
.count()
)
# Inventory statistics # Inventory statistics
total_inventory = db.query( total_inventory = (
func.sum(Inventory.quantity) db.query(func.sum(Inventory.quantity))
).filter( .filter(Inventory.vendor_id == vendor_id)
Inventory.vendor_id == vendor_id .scalar()
).scalar() or 0 or 0
)
reserved_inventory = db.query( reserved_inventory = (
func.sum(Inventory.reserved_quantity) db.query(func.sum(Inventory.reserved_quantity))
).filter( .filter(Inventory.vendor_id == vendor_id)
Inventory.vendor_id == vendor_id .scalar()
).scalar() or 0 or 0
)
inventory_locations = db.query( inventory_locations = (
func.count(func.distinct(Inventory.location)) db.query(func.count(func.distinct(Inventory.location)))
).filter( .filter(Inventory.vendor_id == vendor_id)
Inventory.vendor_id == vendor_id .scalar()
).scalar() or 0 or 0
)
# Import statistics # Import statistics
total_imports = db.query(MarketplaceImportJob).filter( total_imports = (
MarketplaceImportJob.vendor_id == vendor_id db.query(MarketplaceImportJob)
).count() .filter(MarketplaceImportJob.vendor_id == vendor_id)
.count()
)
successful_imports = db.query(MarketplaceImportJob).filter( successful_imports = (
db.query(MarketplaceImportJob)
.filter(
MarketplaceImportJob.vendor_id == vendor_id, MarketplaceImportJob.vendor_id == vendor_id,
MarketplaceImportJob.status == "completed" MarketplaceImportJob.status == "completed",
).count() )
.count()
)
# Orders # Orders
total_orders = db.query(Order).filter( total_orders = db.query(Order).filter(Order.vendor_id == vendor_id).count()
Order.vendor_id == vendor_id
).count()
# Customers # Customers
total_customers = db.query(Customer).filter( total_customers = (
Customer.vendor_id == vendor_id db.query(Customer).filter(Customer.vendor_id == vendor_id).count()
).count() )
return { return {
"catalog": { "catalog": {
@@ -138,7 +148,11 @@ class StatsService:
"imports": { "imports": {
"total_imports": total_imports, "total_imports": total_imports,
"successful_imports": successful_imports, "successful_imports": successful_imports,
"success_rate": (successful_imports / total_imports * 100) if total_imports > 0 else 0, "success_rate": (
(successful_imports / total_imports * 100)
if total_imports > 0
else 0
),
}, },
"orders": { "orders": {
"total_orders": total_orders, "total_orders": total_orders,
@@ -151,12 +165,14 @@ class StatsService:
except VendorNotFoundException: except VendorNotFoundException:
raise raise
except Exception as e: except Exception as e:
logger.error(f"Failed to retrieve vendor statistics for vendor {vendor_id}: {str(e)}") logger.error(
f"Failed to retrieve vendor statistics for vendor {vendor_id}: {str(e)}"
)
raise AdminOperationException( raise AdminOperationException(
operation="get_vendor_stats", operation="get_vendor_stats",
reason=f"Database query failed: {str(e)}", reason=f"Database query failed: {str(e)}",
target_type="vendor", target_type="vendor",
target_id=str(vendor_id) target_id=str(vendor_id),
) )
def get_vendor_analytics( def get_vendor_analytics(
@@ -188,21 +204,28 @@ class StatsService:
start_date = datetime.utcnow() - timedelta(days=days) start_date = datetime.utcnow() - timedelta(days=days)
# Import activity # Import activity
recent_imports = db.query(MarketplaceImportJob).filter( recent_imports = (
db.query(MarketplaceImportJob)
.filter(
MarketplaceImportJob.vendor_id == vendor_id, MarketplaceImportJob.vendor_id == vendor_id,
MarketplaceImportJob.created_at >= start_date MarketplaceImportJob.created_at >= start_date,
).count() )
.count()
)
# Products added to catalog # Products added to catalog
products_added = db.query(Product).filter( products_added = (
Product.vendor_id == vendor_id, db.query(Product)
Product.created_at >= start_date .filter(
).count() Product.vendor_id == vendor_id, Product.created_at >= start_date
)
.count()
)
# Inventory changes # Inventory changes
inventory_entries = db.query(Inventory).filter( inventory_entries = (
Inventory.vendor_id == vendor_id db.query(Inventory).filter(Inventory.vendor_id == vendor_id).count()
).count() )
return { return {
"period": period, "period": period,
@@ -221,12 +244,14 @@ class StatsService:
except VendorNotFoundException: except VendorNotFoundException:
raise raise
except Exception as e: except Exception as e:
logger.error(f"Failed to retrieve vendor analytics for vendor {vendor_id}: {str(e)}") logger.error(
f"Failed to retrieve vendor analytics for vendor {vendor_id}: {str(e)}"
)
raise AdminOperationException( raise AdminOperationException(
operation="get_vendor_analytics", operation="get_vendor_analytics",
reason=f"Database query failed: {str(e)}", reason=f"Database query failed: {str(e)}",
target_type="vendor", target_type="vendor",
target_id=str(vendor_id) target_id=str(vendor_id),
) )
def get_vendor_statistics(self, db: Session) -> dict: def get_vendor_statistics(self, db: Session) -> dict:
@@ -234,7 +259,9 @@ class StatsService:
try: try:
total_vendors = db.query(Vendor).count() total_vendors = db.query(Vendor).count()
active_vendors = db.query(Vendor).filter(Vendor.is_active == True).count() active_vendors = db.query(Vendor).filter(Vendor.is_active == True).count()
verified_vendors = db.query(Vendor).filter(Vendor.is_verified == True).count() verified_vendors = (
db.query(Vendor).filter(Vendor.is_verified == True).count()
)
inactive_vendors = total_vendors - active_vendors inactive_vendors = total_vendors - active_vendors
return { return {
@@ -242,13 +269,14 @@ class StatsService:
"active_vendors": active_vendors, "active_vendors": active_vendors,
"inactive_vendors": inactive_vendors, "inactive_vendors": inactive_vendors,
"verified_vendors": verified_vendors, "verified_vendors": verified_vendors,
"verification_rate": (verified_vendors / total_vendors * 100) if total_vendors > 0 else 0 "verification_rate": (
(verified_vendors / total_vendors * 100) if total_vendors > 0 else 0
),
} }
except Exception as e: except Exception as e:
logger.error(f"Failed to get vendor statistics: {str(e)}") logger.error(f"Failed to get vendor statistics: {str(e)}")
raise AdminOperationException( raise AdminOperationException(
operation="get_vendor_statistics", operation="get_vendor_statistics", reason="Database query failed"
reason="Database query failed"
) )
# ======================================================================== # ========================================================================
@@ -302,7 +330,7 @@ class StatsService:
logger.error(f"Failed to retrieve comprehensive statistics: {str(e)}") logger.error(f"Failed to retrieve comprehensive statistics: {str(e)}")
raise AdminOperationException( raise AdminOperationException(
operation="get_comprehensive_stats", operation="get_comprehensive_stats",
reason=f"Database query failed: {str(e)}" reason=f"Database query failed: {str(e)}",
) )
def get_marketplace_breakdown_stats(self, db: Session) -> List[Dict[str, Any]]: def get_marketplace_breakdown_stats(self, db: Session) -> List[Dict[str, Any]]:
@@ -323,8 +351,12 @@ class StatsService:
db.query( db.query(
MarketplaceProduct.marketplace, MarketplaceProduct.marketplace,
func.count(MarketplaceProduct.id).label("total_products"), func.count(MarketplaceProduct.id).label("total_products"),
func.count(func.distinct(MarketplaceProduct.vendor_name)).label("unique_vendors"), func.count(func.distinct(MarketplaceProduct.vendor_name)).label(
func.count(func.distinct(MarketplaceProduct.brand)).label("unique_brands"), "unique_vendors"
),
func.count(func.distinct(MarketplaceProduct.brand)).label(
"unique_brands"
),
) )
.filter(MarketplaceProduct.marketplace.isnot(None)) .filter(MarketplaceProduct.marketplace.isnot(None))
.group_by(MarketplaceProduct.marketplace) .group_by(MarketplaceProduct.marketplace)
@@ -342,10 +374,12 @@ class StatsService:
] ]
except Exception as e: except Exception as e:
logger.error(f"Failed to retrieve marketplace breakdown statistics: {str(e)}") logger.error(
f"Failed to retrieve marketplace breakdown statistics: {str(e)}"
)
raise AdminOperationException( raise AdminOperationException(
operation="get_marketplace_breakdown_stats", operation="get_marketplace_breakdown_stats",
reason=f"Database query failed: {str(e)}" reason=f"Database query failed: {str(e)}",
) )
def get_user_statistics(self, db: Session) -> Dict[str, Any]: def get_user_statistics(self, db: Session) -> Dict[str, Any]:
@@ -372,13 +406,14 @@ class StatsService:
"active_users": active_users, "active_users": active_users,
"inactive_users": inactive_users, "inactive_users": inactive_users,
"admin_users": admin_users, "admin_users": admin_users,
"activation_rate": (active_users / total_users * 100) if total_users > 0 else 0 "activation_rate": (
(active_users / total_users * 100) if total_users > 0 else 0
),
} }
except Exception as e: except Exception as e:
logger.error(f"Failed to get user statistics: {str(e)}") logger.error(f"Failed to get user statistics: {str(e)}")
raise AdminOperationException( raise AdminOperationException(
operation="get_user_statistics", operation="get_user_statistics", reason="Database query failed"
reason="Database query failed"
) )
def get_import_statistics(self, db: Session) -> Dict[str, Any]: def get_import_statistics(self, db: Session) -> Dict[str, Any]:
@@ -396,18 +431,22 @@ class StatsService:
""" """
try: try:
total = db.query(MarketplaceImportJob).count() total = db.query(MarketplaceImportJob).count()
completed = db.query(MarketplaceImportJob).filter( completed = (
MarketplaceImportJob.status == "completed" db.query(MarketplaceImportJob)
).count() .filter(MarketplaceImportJob.status == "completed")
failed = db.query(MarketplaceImportJob).filter( .count()
MarketplaceImportJob.status == "failed" )
).count() failed = (
db.query(MarketplaceImportJob)
.filter(MarketplaceImportJob.status == "failed")
.count()
)
return { return {
"total_imports": total, "total_imports": total,
"completed_imports": completed, "completed_imports": completed,
"failed_imports": failed, "failed_imports": failed,
"success_rate": (completed / total * 100) if total > 0 else 0 "success_rate": (completed / total * 100) if total > 0 else 0,
} }
except Exception as e: except Exception as e:
logger.error(f"Failed to get import statistics: {str(e)}") logger.error(f"Failed to get import statistics: {str(e)}")
@@ -415,7 +454,7 @@ class StatsService:
"total_imports": 0, "total_imports": 0,
"completed_imports": 0, "completed_imports": 0,
"failed_imports": 0, "failed_imports": 0,
"success_rate": 0 "success_rate": 0,
} }
def get_order_statistics(self, db: Session) -> Dict[str, Any]: def get_order_statistics(self, db: Session) -> Dict[str, Any]:
@@ -431,11 +470,7 @@ class StatsService:
Note: Note:
TODO: Implement when Order model is fully available TODO: Implement when Order model is fully available
""" """
return { return {"total_orders": 0, "pending_orders": 0, "completed_orders": 0}
"total_orders": 0,
"pending_orders": 0,
"completed_orders": 0
}
def get_product_statistics(self, db: Session) -> Dict[str, Any]: def get_product_statistics(self, db: Session) -> Dict[str, Any]:
""" """
@@ -450,11 +485,7 @@ class StatsService:
Note: Note:
TODO: Implement when Product model is fully available TODO: Implement when Product model is fully available
""" """
return { return {"total_products": 0, "active_products": 0, "out_of_stock": 0}
"total_products": 0,
"active_products": 0,
"out_of_stock": 0
}
# ======================================================================== # ========================================================================
# PRIVATE HELPER METHODS # PRIVATE HELPER METHODS
@@ -491,8 +522,7 @@ class StatsService:
return ( return (
db.query(MarketplaceProduct.brand) db.query(MarketplaceProduct.brand)
.filter( .filter(
MarketplaceProduct.brand.isnot(None), MarketplaceProduct.brand.isnot(None), MarketplaceProduct.brand != ""
MarketplaceProduct.brand != ""
) )
.distinct() .distinct()
.count() .count()

View File

@@ -9,17 +9,15 @@ This module provides:
""" """
import logging import logging
from typing import List, Dict, Any
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, Dict, List
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.exceptions import ( from app.exceptions import (UnauthorizedVendorAccessException,
ValidationException, ValidationException)
UnauthorizedVendorAccessException,
)
from models.database.vendor import VendorUser, Role
from models.database.user import User from models.database.user import User
from models.database.vendor import Role, VendorUser
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -42,14 +40,16 @@ class TeamService:
List of team members List of team members
""" """
try: try:
vendor_users = db.query(VendorUser).filter( vendor_users = (
VendorUser.vendor_id == vendor_id, db.query(VendorUser)
VendorUser.is_active == True .filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
).all() .all()
)
members = [] members = []
for vu in vendor_users: for vu in vendor_users:
members.append({ members.append(
{
"id": vu.user_id, "id": vu.user_id,
"email": vu.user.email, "email": vu.user.email,
"first_name": vu.user.first_name, "first_name": vu.user.first_name,
@@ -58,7 +58,8 @@ class TeamService:
"role_id": vu.role_id, "role_id": vu.role_id,
"is_active": vu.is_active, "is_active": vu.is_active,
"joined_at": vu.created_at, "joined_at": vu.created_at,
}) }
)
return members return members
@@ -100,7 +101,7 @@ class TeamService:
vendor_id: int, vendor_id: int,
user_id: int, user_id: int,
update_data: dict, update_data: dict,
current_user: User current_user: User,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Update team member role or status. Update team member role or status.
@@ -116,10 +117,13 @@ class TeamService:
Updated member info Updated member info
""" """
try: try:
vendor_user = db.query(VendorUser).filter( vendor_user = (
VendorUser.vendor_id == vendor_id, db.query(VendorUser)
VendorUser.user_id == user_id .filter(
).first() VendorUser.vendor_id == vendor_id, VendorUser.user_id == user_id
)
.first()
)
if not vendor_user: if not vendor_user:
raise ValidationException("Team member not found") raise ValidationException("Team member not found")
@@ -161,10 +165,13 @@ class TeamService:
True if removed True if removed
""" """
try: try:
vendor_user = db.query(VendorUser).filter( vendor_user = (
VendorUser.vendor_id == vendor_id, db.query(VendorUser)
VendorUser.user_id == user_id .filter(
).first() VendorUser.vendor_id == vendor_id, VendorUser.user_id == user_id
)
.first()
)
if not vendor_user: if not vendor_user:
raise ValidationException("Team member not found") raise ValidationException("Team member not found")

View File

@@ -12,30 +12,28 @@ This module provides classes and functions for:
import logging import logging
import secrets import secrets
from typing import List, Tuple, Optional
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import List, Optional, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import and_ from sqlalchemy import and_
from sqlalchemy.orm import Session
from app.exceptions import ( from app.exceptions import (DNSVerificationException,
VendorNotFoundException, DomainAlreadyVerifiedException,
VendorDomainNotFoundException,
VendorDomainAlreadyExistsException,
InvalidDomainFormatException,
ReservedDomainException,
DomainNotVerifiedException, DomainNotVerifiedException,
DomainVerificationFailedException, DomainVerificationFailedException,
DomainAlreadyVerifiedException, InvalidDomainFormatException,
MultiplePrimaryDomainsException,
DNSVerificationException,
MaxDomainsReachedException, MaxDomainsReachedException,
MultiplePrimaryDomainsException,
ReservedDomainException,
UnauthorizedDomainAccessException, UnauthorizedDomainAccessException,
ValidationException, ValidationException,
) VendorDomainAlreadyExistsException,
from models.schema.vendor_domain import VendorDomainCreate, VendorDomainUpdate VendorDomainNotFoundException,
VendorNotFoundException)
from models.database.vendor import Vendor from models.database.vendor import Vendor
from models.database.vendor_domain import VendorDomain from models.database.vendor_domain import VendorDomain
from models.schema.vendor_domain import VendorDomainCreate, VendorDomainUpdate
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -45,13 +43,19 @@ class VendorDomainService:
def __init__(self): def __init__(self):
self.max_domains_per_vendor = 10 # Configure as needed self.max_domains_per_vendor = 10 # Configure as needed
self.reserved_subdomains = ['www', 'admin', 'api', 'mail', 'smtp', 'ftp', 'cpanel', 'webmail'] self.reserved_subdomains = [
"www",
"admin",
"api",
"mail",
"smtp",
"ftp",
"cpanel",
"webmail",
]
def add_domain( def add_domain(
self, self, db: Session, vendor_id: int, domain_data: VendorDomainCreate
db: Session,
vendor_id: int,
domain_data: VendorDomainCreate
) -> VendorDomain: ) -> VendorDomain:
""" """
Add a custom domain to vendor. Add a custom domain to vendor.
@@ -85,12 +89,14 @@ class VendorDomainService:
# Check if domain already exists # Check if domain already exists
if self._domain_exists(db, normalized_domain): if self._domain_exists(db, normalized_domain):
existing_domain = db.query(VendorDomain).filter( existing_domain = (
VendorDomain.domain == normalized_domain db.query(VendorDomain)
).first() .filter(VendorDomain.domain == normalized_domain)
.first()
)
raise VendorDomainAlreadyExistsException( raise VendorDomainAlreadyExistsException(
normalized_domain, normalized_domain,
existing_domain.vendor_id if existing_domain else None existing_domain.vendor_id if existing_domain else None,
) )
# If setting as primary, unset other primary domains # If setting as primary, unset other primary domains
@@ -105,7 +111,7 @@ class VendorDomainService:
verification_token=secrets.token_urlsafe(32), verification_token=secrets.token_urlsafe(32),
is_verified=False, # Requires DNS verification is_verified=False, # Requires DNS verification
is_active=False, # Cannot be active until verified is_active=False, # Cannot be active until verified
ssl_status="pending" ssl_status="pending",
) )
db.add(new_domain) db.add(new_domain)
@@ -120,7 +126,7 @@ class VendorDomainService:
VendorDomainAlreadyExistsException, VendorDomainAlreadyExistsException,
MaxDomainsReachedException, MaxDomainsReachedException,
InvalidDomainFormatException, InvalidDomainFormatException,
ReservedDomainException ReservedDomainException,
): ):
db.rollback() db.rollback()
raise raise
@@ -129,11 +135,7 @@ class VendorDomainService:
logger.error(f"Error adding domain: {str(e)}") logger.error(f"Error adding domain: {str(e)}")
raise ValidationException("Failed to add domain") raise ValidationException("Failed to add domain")
def get_vendor_domains( def get_vendor_domains(self, db: Session, vendor_id: int) -> List[VendorDomain]:
self,
db: Session,
vendor_id: int
) -> List[VendorDomain]:
""" """
Get all domains for a vendor. Get all domains for a vendor.
@@ -151,12 +153,14 @@ class VendorDomainService:
# Verify vendor exists # Verify vendor exists
self._get_vendor_by_id_or_raise(db, vendor_id) self._get_vendor_by_id_or_raise(db, vendor_id)
domains = db.query(VendorDomain).filter( domains = (
VendorDomain.vendor_id == vendor_id db.query(VendorDomain)
).order_by( .filter(VendorDomain.vendor_id == vendor_id)
VendorDomain.is_primary.desc(), .order_by(
VendorDomain.created_at.desc() VendorDomain.is_primary.desc(), VendorDomain.created_at.desc()
).all() )
.all()
)
return domains return domains
@@ -166,11 +170,7 @@ class VendorDomainService:
logger.error(f"Error getting vendor domains: {str(e)}") logger.error(f"Error getting vendor domains: {str(e)}")
raise ValidationException("Failed to retrieve domains") raise ValidationException("Failed to retrieve domains")
def get_domain_by_id( def get_domain_by_id(self, db: Session, domain_id: int) -> VendorDomain:
self,
db: Session,
domain_id: int
) -> VendorDomain:
""" """
Get domain by ID. Get domain by ID.
@@ -190,10 +190,7 @@ class VendorDomainService:
return domain return domain
def update_domain( def update_domain(
self, self, db: Session, domain_id: int, domain_update: VendorDomainUpdate
db: Session,
domain_id: int,
domain_update: VendorDomainUpdate
) -> VendorDomain: ) -> VendorDomain:
""" """
Update domain settings. Update domain settings.
@@ -215,7 +212,9 @@ class VendorDomainService:
# If setting as primary, unset other primary domains # If setting as primary, unset other primary domains
if domain_update.is_primary: if domain_update.is_primary:
self._unset_primary_domains(db, domain.vendor_id, exclude_domain_id=domain_id) self._unset_primary_domains(
db, domain.vendor_id, exclude_domain_id=domain_id
)
domain.is_primary = True domain.is_primary = True
# If activating, check verification # If activating, check verification
@@ -240,11 +239,7 @@ class VendorDomainService:
logger.error(f"Error updating domain: {str(e)}") logger.error(f"Error updating domain: {str(e)}")
raise ValidationException("Failed to update domain") raise ValidationException("Failed to update domain")
def delete_domain( def delete_domain(self, db: Session, domain_id: int) -> str:
self,
db: Session,
domain_id: int
) -> str:
""" """
Delete a custom domain. Delete a custom domain.
@@ -277,11 +272,7 @@ class VendorDomainService:
logger.error(f"Error deleting domain: {str(e)}") logger.error(f"Error deleting domain: {str(e)}")
raise ValidationException("Failed to delete domain") raise ValidationException("Failed to delete domain")
def verify_domain( def verify_domain(self, db: Session, domain_id: int) -> Tuple[VendorDomain, str]:
self,
db: Session,
domain_id: int
) -> Tuple[VendorDomain, str]:
""" """
Verify domain ownership via DNS TXT record. Verify domain ownership via DNS TXT record.
@@ -313,8 +304,7 @@ class VendorDomainService:
# Query DNS TXT records # Query DNS TXT records
try: try:
txt_records = dns.resolver.resolve( txt_records = dns.resolver.resolve(
f"_wizamart-verify.{domain.domain}", f"_wizamart-verify.{domain.domain}", "TXT"
'TXT'
) )
# Check if verification token is present # Check if verification token is present
@@ -332,42 +322,33 @@ class VendorDomainService:
# Token not found # Token not found
raise DomainVerificationFailedException( raise DomainVerificationFailedException(
domain.domain, domain.domain, "Verification token not found in DNS records"
"Verification token not found in DNS records"
) )
except dns.resolver.NXDOMAIN: except dns.resolver.NXDOMAIN:
raise DomainVerificationFailedException( raise DomainVerificationFailedException(
domain.domain, domain.domain,
f"DNS record _wizamart-verify.{domain.domain} not found" f"DNS record _wizamart-verify.{domain.domain} not found",
) )
except dns.resolver.NoAnswer: except dns.resolver.NoAnswer:
raise DomainVerificationFailedException( raise DomainVerificationFailedException(
domain.domain, domain.domain, "No TXT records found for verification"
"No TXT records found for verification"
) )
except Exception as dns_error: except Exception as dns_error:
raise DNSVerificationException( raise DNSVerificationException(domain.domain, str(dns_error))
domain.domain,
str(dns_error)
)
except ( except (
VendorDomainNotFoundException, VendorDomainNotFoundException,
DomainAlreadyVerifiedException, DomainAlreadyVerifiedException,
DomainVerificationFailedException, DomainVerificationFailedException,
DNSVerificationException DNSVerificationException,
): ):
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error verifying domain: {str(e)}") logger.error(f"Error verifying domain: {str(e)}")
raise ValidationException("Failed to verify domain") raise ValidationException("Failed to verify domain")
def get_verification_instructions( def get_verification_instructions(self, db: Session, domain_id: int) -> dict:
self,
db: Session,
domain_id: int
) -> dict:
""" """
Get DNS verification instructions for domain. Get DNS verification instructions for domain.
@@ -390,20 +371,20 @@ class VendorDomainService:
"step1": "Go to your domain's DNS settings (at your domain registrar)", "step1": "Go to your domain's DNS settings (at your domain registrar)",
"step2": "Add a new TXT record with the following values:", "step2": "Add a new TXT record with the following values:",
"step3": "Wait for DNS propagation (5-15 minutes)", "step3": "Wait for DNS propagation (5-15 minutes)",
"step4": "Click 'Verify Domain' button in admin panel" "step4": "Click 'Verify Domain' button in admin panel",
}, },
"txt_record": { "txt_record": {
"type": "TXT", "type": "TXT",
"name": "_wizamart-verify", "name": "_wizamart-verify",
"value": domain.verification_token, "value": domain.verification_token,
"ttl": 3600 "ttl": 3600,
}, },
"common_registrars": { "common_registrars": {
"Cloudflare": "https://dash.cloudflare.com", "Cloudflare": "https://dash.cloudflare.com",
"GoDaddy": "https://dcc.godaddy.com/manage/dns", "GoDaddy": "https://dcc.godaddy.com/manage/dns",
"Namecheap": "https://www.namecheap.com/myaccount/domain-list/", "Namecheap": "https://www.namecheap.com/myaccount/domain-list/",
"Google Domains": "https://domains.google.com" "Google Domains": "https://domains.google.com",
} },
} }
# Private helper methods # Private helper methods
@@ -416,36 +397,33 @@ class VendorDomainService:
def _check_domain_limit(self, db: Session, vendor_id: int) -> None: def _check_domain_limit(self, db: Session, vendor_id: int) -> None:
"""Check if vendor has reached maximum domain limit.""" """Check if vendor has reached maximum domain limit."""
domain_count = db.query(VendorDomain).filter( domain_count = (
VendorDomain.vendor_id == vendor_id db.query(VendorDomain).filter(VendorDomain.vendor_id == vendor_id).count()
).count() )
if domain_count >= self.max_domains_per_vendor: if domain_count >= self.max_domains_per_vendor:
raise MaxDomainsReachedException(vendor_id, self.max_domains_per_vendor) raise MaxDomainsReachedException(vendor_id, self.max_domains_per_vendor)
def _domain_exists(self, db: Session, domain: str) -> bool: def _domain_exists(self, db: Session, domain: str) -> bool:
"""Check if domain already exists in system.""" """Check if domain already exists in system."""
return db.query(VendorDomain).filter( return (
VendorDomain.domain == domain db.query(VendorDomain).filter(VendorDomain.domain == domain).first()
).first() is not None is not None
)
def _validate_domain_format(self, domain: str) -> None: def _validate_domain_format(self, domain: str) -> None:
"""Validate domain format and check for reserved subdomains.""" """Validate domain format and check for reserved subdomains."""
# Check for reserved subdomains # Check for reserved subdomains
first_part = domain.split('.')[0] first_part = domain.split(".")[0]
if first_part in self.reserved_subdomains: if first_part in self.reserved_subdomains:
raise ReservedDomainException(domain, first_part) raise ReservedDomainException(domain, first_part)
def _unset_primary_domains( def _unset_primary_domains(
self, self, db: Session, vendor_id: int, exclude_domain_id: Optional[int] = None
db: Session,
vendor_id: int,
exclude_domain_id: Optional[int] = None
) -> None: ) -> None:
"""Unset all primary domains for vendor.""" """Unset all primary domains for vendor."""
query = db.query(VendorDomain).filter( query = db.query(VendorDomain).filter(
VendorDomain.vendor_id == vendor_id, VendorDomain.vendor_id == vendor_id, VendorDomain.is_primary == True
VendorDomain.is_primary == True
) )
if exclude_domain_id: if exclude_domain_id:

View File

@@ -15,22 +15,19 @@ from typing import List, Optional, Tuple
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.exceptions import ( from app.exceptions import (InvalidVendorDataException,
VendorNotFoundException,
VendorAlreadyExistsException,
UnauthorizedVendorAccessException,
InvalidVendorDataException,
MarketplaceProductNotFoundException, MarketplaceProductNotFoundException,
ProductAlreadyExistsException,
MaxVendorsReachedException, MaxVendorsReachedException,
ValidationException, ProductAlreadyExistsException,
) UnauthorizedVendorAccessException,
from models.schema.vendor import VendorCreate ValidationException, VendorAlreadyExistsException,
from models.schema.product import ProductCreate VendorNotFoundException)
from models.database.marketplace_product import MarketplaceProduct from models.database.marketplace_product import MarketplaceProduct
from models.database.vendor import Vendor
from models.database.product import Product from models.database.product import Product
from models.database.user import User from models.database.user import User
from models.database.vendor import Vendor
from models.schema.product import ProductCreate
from models.schema.vendor import VendorCreate
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -91,7 +88,11 @@ class VendorService:
) )
return new_vendor return new_vendor
except (VendorAlreadyExistsException, MaxVendorsReachedException, InvalidVendorDataException): except (
VendorAlreadyExistsException,
MaxVendorsReachedException,
InvalidVendorDataException,
):
db.rollback() db.rollback()
raise # Re-raise custom exceptions raise # Re-raise custom exceptions
except Exception as e: except Exception as e:
@@ -129,7 +130,10 @@ class VendorService:
if current_user.role != "admin": if current_user.role != "admin":
query = query.filter( query = query.filter(
(Vendor.is_active == True) (Vendor.is_active == True)
& ((Vendor.is_verified == True) | (Vendor.owner_user_id == current_user.id)) & (
(Vendor.is_verified == True)
| (Vendor.owner_user_id == current_user.id)
)
) )
else: else:
# Admin can apply filters # Admin can apply filters
@@ -147,7 +151,9 @@ class VendorService:
logger.error(f"Error getting vendors: {str(e)}") logger.error(f"Error getting vendors: {str(e)}")
raise ValidationException("Failed to retrieve vendors") raise ValidationException("Failed to retrieve vendors")
def get_vendor_by_code(self, db: Session, vendor_code: str, current_user: User) -> Vendor: def get_vendor_by_code(
self, db: Session, vendor_code: str, current_user: User
) -> Vendor:
""" """
Get vendor by vendor code with access control. Get vendor by vendor code with access control.
@@ -205,11 +211,15 @@ class VendorService:
""" """
try: try:
# Check if product exists # Check if product exists
marketplace_product = self._get_product_by_id_or_raise(db, product.marketplace_product_id) marketplace_product = self._get_product_by_id_or_raise(
db, product.marketplace_product_id
)
# Check if product already in vendor # Check if product already in vendor
if self._product_in_catalog(db, vendor.id, marketplace_product.id): if self._product_in_catalog(db, vendor.id, marketplace_product.id):
raise ProductAlreadyExistsException(vendor.vendor_code, product.marketplace_product_id) raise ProductAlreadyExistsException(
vendor.vendor_code, product.marketplace_product_id
)
# Create vendor -product association # Create vendor -product association
new_product = Product( new_product = Product(
@@ -225,7 +235,9 @@ class VendorService:
# Load the product relationship # Load the product relationship
db.refresh(new_product) db.refresh(new_product)
logger.info(f"MarketplaceProduct {product.marketplace_product_id} added to vendor {vendor.vendor_code}") logger.info(
f"MarketplaceProduct {product.marketplace_product_id} added to vendor {vendor.vendor_code}"
)
return new_product return new_product
except (MarketplaceProductNotFoundException, ProductAlreadyExistsException): except (MarketplaceProductNotFoundException, ProductAlreadyExistsException):
@@ -267,7 +279,9 @@ class VendorService:
try: try:
# Check access permissions # Check access permissions
if not self._can_access_vendor(vendor, current_user): if not self._can_access_vendor(vendor, current_user):
raise UnauthorizedVendorAccessException(vendor.vendor_code, current_user.id) raise UnauthorizedVendorAccessException(
vendor.vendor_code, current_user.id
)
# Query vendor products # Query vendor products
query = db.query(Product).filter(Product.vendor_id == vendor.id) query = db.query(Product).filter(Product.vendor_id == vendor.id)
@@ -292,17 +306,20 @@ class VendorService:
def _validate_vendor_data(self, vendor_data: VendorCreate) -> None: def _validate_vendor_data(self, vendor_data: VendorCreate) -> None:
"""Validate vendor creation data.""" """Validate vendor creation data."""
if not vendor_data.vendor_code or not vendor_data.vendor_code.strip(): if not vendor_data.vendor_code or not vendor_data.vendor_code.strip():
raise InvalidVendorDataException("Vendor code is required", field="vendor_code") raise InvalidVendorDataException(
"Vendor code is required", field="vendor_code"
)
if not vendor_data.vendor_name or not vendor_data.vendor_name.strip(): if not vendor_data.vendor_name or not vendor_data.vendor_name.strip():
raise InvalidVendorDataException("Vendor name is required", field="name") raise InvalidVendorDataException("Vendor name is required", field="name")
# Validate vendor code format (alphanumeric, underscores, hyphens) # Validate vendor code format (alphanumeric, underscores, hyphens)
import re import re
if not re.match(r'^[A-Za-z0-9_-]+$', vendor_data.vendor_code):
if not re.match(r"^[A-Za-z0-9_-]+$", vendor_data.vendor_code):
raise InvalidVendorDataException( raise InvalidVendorDataException(
"Vendor code can only contain letters, numbers, underscores, and hyphens", "Vendor code can only contain letters, numbers, underscores, and hyphens",
field="vendor_code" field="vendor_code",
) )
def _check_vendor_limit(self, db: Session, user: User) -> None: def _check_vendor_limit(self, db: Session, user: User) -> None:
@@ -310,7 +327,9 @@ class VendorService:
if user.role == "admin": if user.role == "admin":
return # Admins have no limit return # Admins have no limit
user_vendor_count = db.query(Vendor).filter(Vendor.owner_user_id == user.id).count() user_vendor_count = (
db.query(Vendor).filter(Vendor.owner_user_id == user.id).count()
)
max_vendors = 5 # Configure this as needed max_vendors = 5 # Configure this as needed
if user_vendor_count >= max_vendors: if user_vendor_count >= max_vendors:
@@ -321,25 +340,35 @@ class VendorService:
return ( return (
db.query(Vendor) db.query(Vendor)
.filter(func.upper(Vendor.vendor_code) == vendor_code.upper()) .filter(func.upper(Vendor.vendor_code) == vendor_code.upper())
.first() is not None .first()
is not None
) )
def _get_product_by_id_or_raise(self, db: Session, marketplace_product_id: str) -> MarketplaceProduct: def _get_product_by_id_or_raise(
self, db: Session, marketplace_product_id: str
) -> MarketplaceProduct:
"""Get product by ID or raise exception.""" """Get product by ID or raise exception."""
product = db.query(MarketplaceProduct).filter(MarketplaceProduct.marketplace_product_id == marketplace_product_id).first() product = (
db.query(MarketplaceProduct)
.filter(MarketplaceProduct.marketplace_product_id == marketplace_product_id)
.first()
)
if not product: if not product:
raise MarketplaceProductNotFoundException(marketplace_product_id) raise MarketplaceProductNotFoundException(marketplace_product_id)
return product return product
def _product_in_catalog(self, db: Session, vendor_id: int, marketplace_product_id: int) -> bool: def _product_in_catalog(
self, db: Session, vendor_id: int, marketplace_product_id: int
) -> bool:
"""Check if product is already in vendor.""" """Check if product is already in vendor."""
return ( return (
db.query(Product) db.query(Product)
.filter( .filter(
Product.vendor_id == vendor_id, Product.vendor_id == vendor_id,
Product.marketplace_product_id == marketplace_product_id Product.marketplace_product_id == marketplace_product_id,
) )
.first() is not None .first()
is not None
) )
def _can_access_vendor(self, vendor: Vendor, user: User) -> bool: def _can_access_vendor(self, vendor: Vendor, user: User) -> bool:
@@ -355,5 +384,6 @@ class VendorService:
"""Check if user is vendor owner.""" """Check if user is vendor owner."""
return vendor.owner_user_id == user.id return vendor.owner_user_id == user.id
# Create service instance following the same pattern as other services # Create service instance following the same pattern as other services
vendor_service = VendorService() vendor_service = VendorService()

View File

@@ -11,23 +11,20 @@ Handles:
import logging import logging
import secrets import secrets
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.permissions import get_preset_permissions from app.core.permissions import get_preset_permissions
from app.exceptions import ( from app.exceptions import (CannotRemoveOwnerException,
TeamMemberAlreadyExistsException,
InvalidInvitationTokenException, InvalidInvitationTokenException,
TeamInvitationAlreadyAcceptedException,
MaxTeamMembersReachedException, MaxTeamMembersReachedException,
UserNotFoundException, TeamInvitationAlreadyAcceptedException,
VendorNotFoundException, TeamMemberAlreadyExistsException,
CannotRemoveOwnerException, UserNotFoundException, VendorNotFoundException)
)
from models.database.user import User
from models.database.vendor import Vendor, VendorUser, VendorUserType, Role
from middleware.auth import AuthManager from middleware.auth import AuthManager
from models.database.user import User
from models.database.vendor import Role, Vendor, VendorUser, VendorUserType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -69,10 +66,14 @@ class VendorTeamService:
""" """
try: try:
# Check team size limit # Check team size limit
current_team_size = db.query(VendorUser).filter( current_team_size = (
db.query(VendorUser)
.filter(
VendorUser.vendor_id == vendor.id, VendorUser.vendor_id == vendor.id,
VendorUser.is_active == True, VendorUser.is_active == True,
).count() )
.count()
)
if current_team_size >= self.max_team_members: if current_team_size >= self.max_team_members:
raise MaxTeamMembersReachedException( raise MaxTeamMembersReachedException(
@@ -85,22 +86,34 @@ class VendorTeamService:
if user: if user:
# Check if already a member # Check if already a member
existing_membership = db.query(VendorUser).filter( existing_membership = (
db.query(VendorUser)
.filter(
VendorUser.vendor_id == vendor.id, VendorUser.vendor_id == vendor.id,
VendorUser.user_id == user.id, VendorUser.user_id == user.id,
).first() )
.first()
)
if existing_membership: if existing_membership:
if existing_membership.is_active: if existing_membership.is_active:
raise TeamMemberAlreadyExistsException(email, vendor.vendor_code) raise TeamMemberAlreadyExistsException(
email, vendor.vendor_code
)
# Reactivate old membership # Reactivate old membership
existing_membership.is_active = False # Will be activated on acceptance existing_membership.is_active = (
existing_membership.invitation_token = self._generate_invitation_token() False # Will be activated on acceptance
)
existing_membership.invitation_token = (
self._generate_invitation_token()
)
existing_membership.invitation_sent_at = datetime.utcnow() existing_membership.invitation_sent_at = datetime.utcnow()
existing_membership.invitation_accepted_at = None existing_membership.invitation_accepted_at = None
db.commit() db.commit()
logger.info(f"Re-invited user {email} to vendor {vendor.vendor_code}") logger.info(
f"Re-invited user {email} to vendor {vendor.vendor_code}"
)
return { return {
"invitation_token": existing_membership.invitation_token, "invitation_token": existing_membership.invitation_token,
"email": email, "email": email,
@@ -108,7 +121,7 @@ class VendorTeamService:
} }
else: else:
# Create new user account (inactive until invitation accepted) # Create new user account (inactive until invitation accepted)
username = email.split('@')[0] username = email.split("@")[0]
# Ensure unique username # Ensure unique username
base_username = username base_username = username
counter = 1 counter = 1
@@ -201,9 +214,13 @@ class VendorTeamService:
""" """
try: try:
# Find invitation # Find invitation
vendor_user = db.query(VendorUser).filter( vendor_user = (
db.query(VendorUser)
.filter(
VendorUser.invitation_token == invitation_token, VendorUser.invitation_token == invitation_token,
).first() )
.first()
)
if not vendor_user: if not vendor_user:
raise InvalidInvitationTokenException() raise InvalidInvitationTokenException()
@@ -247,7 +264,10 @@ class VendorTeamService:
"role": vendor_user.role.name if vendor_user.role else "member", "role": vendor_user.role.name if vendor_user.role else "member",
} }
except (InvalidInvitationTokenException, TeamInvitationAlreadyAcceptedException): except (
InvalidInvitationTokenException,
TeamInvitationAlreadyAcceptedException,
):
raise raise
except Exception as e: except Exception as e:
db.rollback() db.rollback()
@@ -274,10 +294,14 @@ class VendorTeamService:
True if removed True if removed
""" """
try: try:
vendor_user = db.query(VendorUser).filter( vendor_user = (
db.query(VendorUser)
.filter(
VendorUser.vendor_id == vendor.id, VendorUser.vendor_id == vendor.id,
VendorUser.user_id == user_id, VendorUser.user_id == user_id,
).first() )
.first()
)
if not vendor_user: if not vendor_user:
raise UserNotFoundException(str(user_id)) raise UserNotFoundException(str(user_id))
@@ -322,10 +346,14 @@ class VendorTeamService:
Updated VendorUser Updated VendorUser
""" """
try: try:
vendor_user = db.query(VendorUser).filter( vendor_user = (
db.query(VendorUser)
.filter(
VendorUser.vendor_id == vendor.id, VendorUser.vendor_id == vendor.id,
VendorUser.user_id == user_id, VendorUser.user_id == user_id,
).first() )
.first()
)
if not vendor_user: if not vendor_user:
raise UserNotFoundException(str(user_id)) raise UserNotFoundException(str(user_id))
@@ -387,7 +415,8 @@ class VendorTeamService:
members = [] members = []
for vu in vendor_users: for vu in vendor_users:
members.append({ members.append(
{
"id": vu.user.id, "id": vu.user.id,
"email": vu.user.email, "email": vu.user.email,
"username": vu.user.username, "username": vu.user.username,
@@ -400,7 +429,8 @@ class VendorTeamService:
"invitation_pending": vu.is_invitation_pending, "invitation_pending": vu.is_invitation_pending,
"invited_at": vu.invitation_sent_at, "invited_at": vu.invitation_sent_at,
"accepted_at": vu.invitation_accepted_at, "accepted_at": vu.invitation_accepted_at,
}) }
)
return members return members
@@ -419,10 +449,14 @@ class VendorTeamService:
) -> Role: ) -> Role:
"""Get existing role or create new one with preset/custom permissions.""" """Get existing role or create new one with preset/custom permissions."""
# Try to find existing role with same name # Try to find existing role with same name
role = db.query(Role).filter( role = (
db.query(Role)
.filter(
Role.vendor_id == vendor.id, Role.vendor_id == vendor.id,
Role.name == role_name, Role.name == role_name,
).first() )
.first()
)
if role and custom_permissions is None: if role and custom_permissions is None:
# Use existing role # Use existing role

View File

@@ -8,32 +8,24 @@ Handles theme CRUD operations, preset application, and validation.
import logging import logging
import re import re
from typing import Optional, Dict, List from typing import Dict, List, Optional
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from models.database.vendor import Vendor from app.core.theme_presets import (THEME_PRESETS, apply_preset,
from models.database.vendor_theme import VendorTheme get_available_presets, get_preset_preview)
from models.schema.vendor_theme import (
VendorThemeUpdate,
ThemePresetPreview
)
from app.exceptions.vendor import VendorNotFoundException from app.exceptions.vendor import VendorNotFoundException
from app.exceptions.vendor_theme import ( from app.exceptions.vendor_theme import (InvalidColorFormatException,
VendorThemeNotFoundException, InvalidFontFamilyException,
InvalidThemeDataException, InvalidThemeDataException,
ThemeOperationException,
ThemePresetAlreadyAppliedException,
ThemePresetNotFoundException, ThemePresetNotFoundException,
ThemeValidationException, ThemeValidationException,
InvalidColorFormatException, VendorThemeNotFoundException)
InvalidFontFamilyException, from models.database.vendor import Vendor
ThemePresetAlreadyAppliedException, from models.database.vendor_theme import VendorTheme
ThemeOperationException from models.schema.vendor_theme import ThemePresetPreview, VendorThemeUpdate
)
from app.core.theme_presets import (
apply_preset,
get_available_presets,
get_preset_preview,
THEME_PRESETS
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -71,9 +63,9 @@ class VendorThemeService:
Raises: Raises:
VendorNotFoundException: If vendor not found VendorNotFoundException: If vendor not found
""" """
vendor = db.query(Vendor).filter( vendor = (
Vendor.vendor_code == vendor_code.upper() db.query(Vendor).filter(Vendor.vendor_code == vendor_code.upper()).first()
).first() )
if not vendor: if not vendor:
self.logger.warning(f"Vendor not found: {vendor_code}") self.logger.warning(f"Vendor not found: {vendor_code}")
@@ -105,12 +97,12 @@ class VendorThemeService:
vendor = self._get_vendor_by_code(db, vendor_code) vendor = self._get_vendor_by_code(db, vendor_code)
# Get theme # Get theme
theme = db.query(VendorTheme).filter( theme = db.query(VendorTheme).filter(VendorTheme.vendor_id == vendor.id).first()
VendorTheme.vendor_id == vendor.id
).first()
if not theme: if not theme:
self.logger.info(f"No custom theme for vendor {vendor_code}, returning default") self.logger.info(
f"No custom theme for vendor {vendor_code}, returning default"
)
return self._get_default_theme() return self._get_default_theme()
return theme.to_dict() return theme.to_dict()
@@ -130,23 +122,16 @@ class VendorThemeService:
"accent": "#ec4899", "accent": "#ec4899",
"background": "#ffffff", "background": "#ffffff",
"text": "#1f2937", "text": "#1f2937",
"border": "#e5e7eb" "border": "#e5e7eb",
},
"fonts": {
"heading": "Inter, sans-serif",
"body": "Inter, sans-serif"
}, },
"fonts": {"heading": "Inter, sans-serif", "body": "Inter, sans-serif"},
"branding": { "branding": {
"logo": None, "logo": None,
"logo_dark": None, "logo_dark": None,
"favicon": None, "favicon": None,
"banner": None "banner": None,
},
"layout": {
"style": "grid",
"header": "fixed",
"product_card": "modern"
}, },
"layout": {"style": "grid", "header": "fixed", "product_card": "modern"},
"social_links": {}, "social_links": {},
"custom_css": None, "custom_css": None,
"css_variables": { "css_variables": {
@@ -158,7 +143,7 @@ class VendorThemeService:
"--color-border": "#e5e7eb", "--color-border": "#e5e7eb",
"--font-heading": "Inter, sans-serif", "--font-heading": "Inter, sans-serif",
"--font-body": "Inter, sans-serif", "--font-body": "Inter, sans-serif",
} },
} }
# ============================================================================ # ============================================================================
@@ -166,10 +151,7 @@ class VendorThemeService:
# ============================================================================ # ============================================================================
def update_theme( def update_theme(
self, self, db: Session, vendor_code: str, theme_data: VendorThemeUpdate
db: Session,
vendor_code: str,
theme_data: VendorThemeUpdate
) -> VendorTheme: ) -> VendorTheme:
""" """
Update or create theme for vendor. Update or create theme for vendor.
@@ -194,9 +176,9 @@ class VendorThemeService:
vendor = self._get_vendor_by_code(db, vendor_code) vendor = self._get_vendor_by_code(db, vendor_code)
# Get or create theme # Get or create theme
theme = db.query(VendorTheme).filter( theme = (
VendorTheme.vendor_id == vendor.id db.query(VendorTheme).filter(VendorTheme.vendor_id == vendor.id).first()
).first() )
if not theme: if not theme:
self.logger.info(f"Creating new theme for vendor {vendor_code}") self.logger.info(f"Creating new theme for vendor {vendor_code}")
@@ -224,15 +206,11 @@ class VendorThemeService:
db.rollback() db.rollback()
self.logger.error(f"Failed to update theme for vendor {vendor_code}: {e}") self.logger.error(f"Failed to update theme for vendor {vendor_code}: {e}")
raise ThemeOperationException( raise ThemeOperationException(
operation="update", operation="update", vendor_code=vendor_code, reason=str(e)
vendor_code=vendor_code,
reason=str(e)
) )
def _apply_theme_updates( def _apply_theme_updates(
self, self, theme: VendorTheme, theme_data: VendorThemeUpdate
theme: VendorTheme,
theme_data: VendorThemeUpdate
) -> None: ) -> None:
""" """
Apply theme updates to theme object. Apply theme updates to theme object.
@@ -251,30 +229,30 @@ class VendorThemeService:
# Update fonts # Update fonts
if theme_data.fonts: if theme_data.fonts:
if theme_data.fonts.get('heading'): if theme_data.fonts.get("heading"):
theme.font_family_heading = theme_data.fonts['heading'] theme.font_family_heading = theme_data.fonts["heading"]
if theme_data.fonts.get('body'): if theme_data.fonts.get("body"):
theme.font_family_body = theme_data.fonts['body'] theme.font_family_body = theme_data.fonts["body"]
# Update branding # Update branding
if theme_data.branding: if theme_data.branding:
if theme_data.branding.get('logo') is not None: if theme_data.branding.get("logo") is not None:
theme.logo_url = theme_data.branding['logo'] theme.logo_url = theme_data.branding["logo"]
if theme_data.branding.get('logo_dark') is not None: if theme_data.branding.get("logo_dark") is not None:
theme.logo_dark_url = theme_data.branding['logo_dark'] theme.logo_dark_url = theme_data.branding["logo_dark"]
if theme_data.branding.get('favicon') is not None: if theme_data.branding.get("favicon") is not None:
theme.favicon_url = theme_data.branding['favicon'] theme.favicon_url = theme_data.branding["favicon"]
if theme_data.branding.get('banner') is not None: if theme_data.branding.get("banner") is not None:
theme.banner_url = theme_data.branding['banner'] theme.banner_url = theme_data.branding["banner"]
# Update layout # Update layout
if theme_data.layout: if theme_data.layout:
if theme_data.layout.get('style'): if theme_data.layout.get("style"):
theme.layout_style = theme_data.layout['style'] theme.layout_style = theme_data.layout["style"]
if theme_data.layout.get('header'): if theme_data.layout.get("header"):
theme.header_style = theme_data.layout['header'] theme.header_style = theme_data.layout["header"]
if theme_data.layout.get('product_card'): if theme_data.layout.get("product_card"):
theme.product_card_style = theme_data.layout['product_card'] theme.product_card_style = theme_data.layout["product_card"]
# Update custom CSS # Update custom CSS
if theme_data.custom_css is not None: if theme_data.custom_css is not None:
@@ -289,10 +267,7 @@ class VendorThemeService:
# ============================================================================ # ============================================================================
def apply_theme_preset( def apply_theme_preset(
self, self, db: Session, vendor_code: str, preset_name: str
db: Session,
vendor_code: str,
preset_name: str
) -> VendorTheme: ) -> VendorTheme:
""" """
Apply a theme preset to vendor. Apply a theme preset to vendor.
@@ -322,9 +297,9 @@ class VendorThemeService:
vendor = self._get_vendor_by_code(db, vendor_code) vendor = self._get_vendor_by_code(db, vendor_code)
# Get or create theme # Get or create theme
theme = db.query(VendorTheme).filter( theme = (
VendorTheme.vendor_id == vendor.id db.query(VendorTheme).filter(VendorTheme.vendor_id == vendor.id).first()
).first() )
if not theme: if not theme:
self.logger.info(f"Creating new theme for vendor {vendor_code}") self.logger.info(f"Creating new theme for vendor {vendor_code}")
@@ -338,7 +313,9 @@ class VendorThemeService:
db.commit() db.commit()
db.refresh(theme) db.refresh(theme)
self.logger.info(f"Preset '{preset_name}' applied successfully to vendor {vendor_code}") self.logger.info(
f"Preset '{preset_name}' applied successfully to vendor {vendor_code}"
)
return theme return theme
except (VendorNotFoundException, ThemePresetNotFoundException): except (VendorNotFoundException, ThemePresetNotFoundException):
@@ -349,9 +326,7 @@ class VendorThemeService:
db.rollback() db.rollback()
self.logger.error(f"Failed to apply preset to vendor {vendor_code}: {e}") self.logger.error(f"Failed to apply preset to vendor {vendor_code}: {e}")
raise ThemeOperationException( raise ThemeOperationException(
operation="apply_preset", operation="apply_preset", vendor_code=vendor_code, reason=str(e)
vendor_code=vendor_code,
reason=str(e)
) )
def get_available_presets(self) -> List[ThemePresetPreview]: def get_available_presets(self) -> List[ThemePresetPreview]:
@@ -399,9 +374,9 @@ class VendorThemeService:
vendor = self._get_vendor_by_code(db, vendor_code) vendor = self._get_vendor_by_code(db, vendor_code)
# Get theme # Get theme
theme = db.query(VendorTheme).filter( theme = (
VendorTheme.vendor_id == vendor.id db.query(VendorTheme).filter(VendorTheme.vendor_id == vendor.id).first()
).first() )
if not theme: if not theme:
raise VendorThemeNotFoundException(vendor_code) raise VendorThemeNotFoundException(vendor_code)
@@ -423,9 +398,7 @@ class VendorThemeService:
db.rollback() db.rollback()
self.logger.error(f"Failed to delete theme for vendor {vendor_code}: {e}") self.logger.error(f"Failed to delete theme for vendor {vendor_code}: {e}")
raise ThemeOperationException( raise ThemeOperationException(
operation="delete", operation="delete", vendor_code=vendor_code, reason=str(e)
vendor_code=vendor_code,
reason=str(e)
) )
# ============================================================================ # ============================================================================
@@ -459,9 +432,9 @@ class VendorThemeService:
# Validate layout values # Validate layout values
if theme_data.layout: if theme_data.layout:
valid_layouts = { valid_layouts = {
'style': ['grid', 'list', 'masonry'], "style": ["grid", "list", "masonry"],
'header': ['fixed', 'static', 'transparent'], "header": ["fixed", "static", "transparent"],
'product_card': ['modern', 'classic', 'minimal'] "product_card": ["modern", "classic", "minimal"],
} }
for layout_key, layout_value in theme_data.layout.items(): for layout_key, layout_value in theme_data.layout.items():
@@ -472,7 +445,7 @@ class VendorThemeService:
field=layout_key, field=layout_key,
validation_errors={ validation_errors={
layout_key: f"Must be one of: {', '.join(valid_layouts[layout_key])}" layout_key: f"Must be one of: {', '.join(valid_layouts[layout_key])}"
} },
) )
def _is_valid_color(self, color: str) -> bool: def _is_valid_color(self, color: str) -> bool:
@@ -489,7 +462,7 @@ class VendorThemeService:
return False return False
# Check for hex color format (#RGB or #RRGGBB) # Check for hex color format (#RGB or #RRGGBB)
hex_pattern = r'^#([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$' hex_pattern = r"^#([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$"
return bool(re.match(hex_pattern, color)) return bool(re.match(hex_pattern, color))
def _is_valid_font(self, font: str) -> bool: def _is_valid_font(self, font: str) -> bool:

View File

@@ -3,9 +3,9 @@ import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from app.core.database import SessionLocal from app.core.database import SessionLocal
from app.utils.csv_processor import CSVProcessor
from models.database.marketplace_import_job import MarketplaceImportJob from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.vendor import Vendor from models.database.vendor import Vendor
from app.utils.csv_processor import CSVProcessor
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -15,7 +15,7 @@ async def process_marketplace_import(
url: str, url: str,
marketplace: str, marketplace: str,
vendor_id: int, # FIXED: Changed from vendor_name to vendor_id vendor_id: int, # FIXED: Changed from vendor_name to vendor_id
batch_size: int = 1000 batch_size: int = 1000,
): ):
"""Background task to process marketplace CSV import.""" """Background task to process marketplace CSV import."""
db = SessionLocal() db = SessionLocal()
@@ -59,7 +59,7 @@ async def process_marketplace_import(
marketplace, marketplace,
vendor_id, # FIXED: Pass vendor_id instead of vendor_name vendor_id, # FIXED: Pass vendor_id instead of vendor_name
batch_size, batch_size,
db db,
) )
# Update job with results # Update job with results

View File

@@ -267,7 +267,9 @@ class CSVProcessor:
# Validate required fields # Validate required fields
if not product_data.get("marketplace_product_id"): if not product_data.get("marketplace_product_id"):
logger.warning(f"Row {index}: Missing marketplace_product_id, skipping") logger.warning(
f"Row {index}: Missing marketplace_product_id, skipping"
)
errors += 1 errors += 1
continue continue
@@ -279,7 +281,10 @@ class CSVProcessor:
# Check if product exists # Check if product exists
existing_product = ( existing_product = (
db.query(MarketplaceProduct) db.query(MarketplaceProduct)
.filter(MarketplaceProduct.marketplace_product_id == literal(product_data["marketplace_product_id"])) .filter(
MarketplaceProduct.marketplace_product_id
== literal(product_data["marketplace_product_id"])
)
.first() .first()
) )

View File

@@ -109,7 +109,9 @@ class PriceProcessor:
r"([A-Z]{3})\s*([0-9.,]+)": lambda m: (m.group(2), m.group(1)), r"([A-Z]{3})\s*([0-9.,]+)": lambda m: (m.group(2), m.group(1)),
} }
def parse_price_currency(self, price_str: any) -> Tuple[Optional[str], Optional[str]]: def parse_price_currency(
self, price_str: any
) -> Tuple[Optional[str], Optional[str]]:
""" """
Parse a price string to extract the numeric value and currency. Parse a price string to extract the numeric value and currency.

View File

@@ -7,6 +7,7 @@ This module provides utility functions and classes to interact with a database u
""" """
import logging import logging
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import QueuePool from sqlalchemy.pool import QueuePool

142
main.py
View File

@@ -9,13 +9,13 @@ Multi-tenant e-commerce marketplace platform with:
- Middleware stack for context injection - Middleware stack for context injection
""" """
import sys
import io import io
import sys
# Fix Windows console encoding issues (must be at the very top) # Fix Windows console encoding issues (must be at the very top)
if sys.platform == 'win32': if sys.platform == "win32":
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8') sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8")
import logging import logging
from datetime import datetime, timezone from datetime import datetime, timezone
@@ -23,27 +23,25 @@ from pathlib import Path
from fastapi import Depends, FastAPI, HTTPException, Request, Response from fastapi import Depends, FastAPI, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, RedirectResponse, FileResponse from fastapi.responses import FileResponse, HTMLResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from sqlalchemy import text from sqlalchemy import text
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.api.main import api_router from app.api.main import api_router
# Import page routers
from app.routes import admin_pages, vendor_pages, shop_pages
from app.core.config import settings from app.core.config import settings
from app.core.database import get_db from app.core.database import get_db
from app.core.lifespan import lifespan from app.core.lifespan import lifespan
from app.exceptions.handler import setup_exception_handlers
from app.exceptions import ServiceUnavailableException from app.exceptions import ServiceUnavailableException
from app.exceptions.handler import setup_exception_handlers
# Import page routers
from app.routes import admin_pages, shop_pages, vendor_pages
from middleware.context import ContextMiddleware
from middleware.logging import LoggingMiddleware
from middleware.theme_context import ThemeContextMiddleware
# Import REFACTORED class-based middleware # Import REFACTORED class-based middleware
from middleware.vendor_context import VendorContextMiddleware from middleware.vendor_context import VendorContextMiddleware
from middleware.context import ContextMiddleware
from middleware.theme_context import ThemeContextMiddleware
from middleware.logging import LoggingMiddleware
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -146,6 +144,7 @@ app.include_router(api_router, prefix="/api")
# FAVICON ROUTES (Must be registered BEFORE page routers) # FAVICON ROUTES (Must be registered BEFORE page routers)
# ============================================================================ # ============================================================================
def serve_favicon() -> Response: def serve_favicon() -> Response:
""" """
Serve favicon with caching headers. Serve favicon with caching headers.
@@ -164,7 +163,7 @@ def serve_favicon() -> Response:
media_type="image/x-icon", media_type="image/x-icon",
headers={ headers={
"Cache-Control": "public, max-age=86400", # Cache for 1 day "Cache-Control": "public, max-age=86400", # Cache for 1 day
} },
) )
return Response(status_code=204) return Response(status_code=204)
@@ -194,10 +193,7 @@ logger.info("=" * 80)
# Admin pages # Admin pages
logger.info("Registering admin page routes: /admin/*") logger.info("Registering admin page routes: /admin/*")
app.include_router( app.include_router(
admin_pages.router, admin_pages.router, prefix="/admin", tags=["admin-pages"], include_in_schema=False
prefix="/admin",
tags=["admin-pages"],
include_in_schema=False
) )
# Vendor management pages (dashboard, products, orders, etc.) # Vendor management pages (dashboard, products, orders, etc.)
@@ -206,7 +202,7 @@ app.include_router(
vendor_pages.router, vendor_pages.router,
prefix="/vendor", prefix="/vendor",
tags=["vendor-pages"], tags=["vendor-pages"],
include_in_schema=False include_in_schema=False,
) )
# Customer shop pages - Register at TWO prefixes: # Customer shop pages - Register at TWO prefixes:
@@ -217,46 +213,42 @@ logger.info(" - /shop/* (subdomain/custom domain mode)")
logger.info(" - /vendors/{code}/shop/* (path-based development mode)") logger.info(" - /vendors/{code}/shop/* (path-based development mode)")
app.include_router( app.include_router(
shop_pages.router, shop_pages.router, prefix="/shop", tags=["shop-pages"], include_in_schema=False
prefix="/shop",
tags=["shop-pages"],
include_in_schema=False
) )
app.include_router( app.include_router(
shop_pages.router, shop_pages.router,
prefix="/vendors/{vendor_code}/shop", prefix="/vendors/{vendor_code}/shop",
tags=["shop-pages"], tags=["shop-pages"],
include_in_schema=False include_in_schema=False,
) )
# Add handler for /vendors/{vendor_code}/ root path # Add handler for /vendors/{vendor_code}/ root path
@app.get("/vendors/{vendor_code}/", response_class=HTMLResponse, include_in_schema=False) @app.get(
async def vendor_root_path(vendor_code: str, request: Request, db: Session = Depends(get_db)): "/vendors/{vendor_code}/", response_class=HTMLResponse, include_in_schema=False
)
async def vendor_root_path(
vendor_code: str, request: Request, db: Session = Depends(get_db)
):
"""Handle vendor root path (e.g., /vendors/wizamart/)""" """Handle vendor root path (e.g., /vendors/wizamart/)"""
# Vendor should already be in request.state from middleware # Vendor should already be in request.state from middleware
vendor = getattr(request.state, 'vendor', None) vendor = getattr(request.state, "vendor", None)
if not vendor: if not vendor:
raise HTTPException(status_code=404, detail=f"Vendor '{vendor_code}' not found") raise HTTPException(status_code=404, detail=f"Vendor '{vendor_code}' not found")
from app.services.content_page_service import content_page_service
from app.routes.shop_pages import get_shop_context from app.routes.shop_pages import get_shop_context
from app.services.content_page_service import content_page_service
# Try to find landing page # Try to find landing page
landing_page = content_page_service.get_page_for_vendor( landing_page = content_page_service.get_page_for_vendor(
db, db, slug="landing", vendor_id=vendor.id, include_unpublished=False
slug="landing",
vendor_id=vendor.id,
include_unpublished=False
) )
if not landing_page: if not landing_page:
landing_page = content_page_service.get_page_for_vendor( landing_page = content_page_service.get_page_for_vendor(
db, db, slug="home", vendor_id=vendor.id, include_unpublished=False
slug="home",
vendor_id=vendor.id,
include_unpublished=False
) )
if landing_page: if landing_page:
@@ -265,8 +257,7 @@ async def vendor_root_path(vendor_code: str, request: Request, db: Session = Dep
template_path = f"vendor/landing-{template_name}.html" template_path = f"vendor/landing-{template_name}.html"
return templates.TemplateResponse( return templates.TemplateResponse(
template_path, template_path, get_shop_context(request, db=db, page=landing_page)
get_shop_context(request, db=db, page=landing_page)
) )
else: else:
# No landing page - redirect to shop # No landing page - redirect to shop
@@ -298,22 +289,16 @@ async def platform_homepage(request: Request, db: Session = Depends(get_db)):
db, db,
slug="platform_homepage", slug="platform_homepage",
vendor_id=None, # Platform-level page vendor_id=None, # Platform-level page
include_unpublished=False include_unpublished=False,
) )
# Load header and footer navigation # Load header and footer navigation
header_pages = content_page_service.list_pages_for_vendor( header_pages = content_page_service.list_pages_for_vendor(
db, db, vendor_id=None, header_only=True, include_unpublished=False
vendor_id=None,
header_only=True,
include_unpublished=False
) )
footer_pages = content_page_service.list_pages_for_vendor( footer_pages = content_page_service.list_pages_for_vendor(
db, db, vendor_id=None, footer_only=True, include_unpublished=False
vendor_id=None,
footer_only=True,
include_unpublished=False
) )
if homepage: if homepage:
@@ -330,7 +315,7 @@ async def platform_homepage(request: Request, db: Session = Depends(get_db)):
"page": homepage, "page": homepage,
"header_pages": header_pages, "header_pages": header_pages,
"footer_pages": footer_pages, "footer_pages": footer_pages,
} },
) )
else: else:
# Fallback to default static template # Fallback to default static template
@@ -342,15 +327,13 @@ async def platform_homepage(request: Request, db: Session = Depends(get_db)):
"request": request, "request": request,
"header_pages": header_pages, "header_pages": header_pages,
"footer_pages": footer_pages, "footer_pages": footer_pages,
} },
) )
@app.get("/{slug}", response_class=HTMLResponse, include_in_schema=False) @app.get("/{slug}", response_class=HTMLResponse, include_in_schema=False)
async def platform_content_page( async def platform_content_page(
request: Request, request: Request, slug: str, db: Session = Depends(get_db)
slug: str,
db: Session = Depends(get_db)
): ):
""" """
Platform content pages: /about, /faq, /terms, /contact, etc. Platform content pages: /about, /faq, /terms, /contact, etc.
@@ -366,10 +349,7 @@ async def platform_content_page(
# Load page from CMS # Load page from CMS
page = content_page_service.get_page_for_vendor( page = content_page_service.get_page_for_vendor(
db, db, slug=slug, vendor_id=None, include_unpublished=False # Platform pages only
slug=slug,
vendor_id=None, # Platform pages only
include_unpublished=False
) )
if not page: if not page:
@@ -378,17 +358,11 @@ async def platform_content_page(
# Load header and footer navigation # Load header and footer navigation
header_pages = content_page_service.list_pages_for_vendor( header_pages = content_page_service.list_pages_for_vendor(
db, db, vendor_id=None, header_only=True, include_unpublished=False
vendor_id=None,
header_only=True,
include_unpublished=False
) )
footer_pages = content_page_service.list_pages_for_vendor( footer_pages = content_page_service.list_pages_for_vendor(
db, db, vendor_id=None, footer_only=True, include_unpublished=False
vendor_id=None,
footer_only=True,
include_unpublished=False
) )
logger.info(f"[PLATFORM] Rendering content page: {page.title} (/{slug})") logger.info(f"[PLATFORM] Rendering content page: {page.title} (/{slug})")
@@ -400,7 +374,7 @@ async def platform_content_page(
"page": page, "page": page,
"header_pages": header_pages, "header_pages": header_pages,
"footer_pages": footer_pages, "footer_pages": footer_pages,
} },
) )
@@ -411,8 +385,8 @@ logger.info("=" * 80)
logger.info("REGISTERED ROUTES SUMMARY") logger.info("REGISTERED ROUTES SUMMARY")
logger.info("=" * 80) logger.info("=" * 80)
for route in app.routes: for route in app.routes:
if hasattr(route, 'methods') and hasattr(route, 'path'): if hasattr(route, "methods") and hasattr(route, "path"):
methods = ', '.join(route.methods) if route.methods else 'N/A' methods = ", ".join(route.methods) if route.methods else "N/A"
logger.info(f" {methods:<10} {route.path:<60}") logger.info(f" {methods:<10} {route.path:<60}")
logger.info("=" * 80) logger.info("=" * 80)
@@ -420,6 +394,7 @@ logger.info("=" * 80)
# API ROUTES (JSON Responses) # API ROUTES (JSON Responses)
# ============================================================================ # ============================================================================
# Public Routes (no authentication required) # Public Routes (no authentication required)
@app.get("/", response_class=HTMLResponse, include_in_schema=False) @app.get("/", response_class=HTMLResponse, include_in_schema=False)
async def root(request: Request, db: Session = Depends(get_db)): async def root(request: Request, db: Session = Depends(get_db)):
@@ -428,7 +403,7 @@ async def root(request: Request, db: Session = Depends(get_db)):
- If vendor detected (domain/subdomain): Show vendor landing page or redirect to shop - If vendor detected (domain/subdomain): Show vendor landing page or redirect to shop
- If no vendor (platform root): Redirect to documentation - If no vendor (platform root): Redirect to documentation
""" """
vendor = getattr(request.state, 'vendor', None) vendor = getattr(request.state, "vendor", None)
if vendor: if vendor:
# Vendor context detected - serve landing page # Vendor context detected - serve landing page
@@ -436,19 +411,13 @@ async def root(request: Request, db: Session = Depends(get_db)):
# Try to find landing page (slug='landing' or 'home') # Try to find landing page (slug='landing' or 'home')
landing_page = content_page_service.get_page_for_vendor( landing_page = content_page_service.get_page_for_vendor(
db, db, slug="landing", vendor_id=vendor.id, include_unpublished=False
slug="landing",
vendor_id=vendor.id,
include_unpublished=False
) )
if not landing_page: if not landing_page:
# Try 'home' slug as fallback # Try 'home' slug as fallback
landing_page = content_page_service.get_page_for_vendor( landing_page = content_page_service.get_page_for_vendor(
db, db, slug="home", vendor_id=vendor.id, include_unpublished=False
slug="home",
vendor_id=vendor.id,
include_unpublished=False
) )
if landing_page: if landing_page:
@@ -459,17 +428,26 @@ async def root(request: Request, db: Session = Depends(get_db)):
template_path = f"vendor/landing-{template_name}.html" template_path = f"vendor/landing-{template_name}.html"
return templates.TemplateResponse( return templates.TemplateResponse(
template_path, template_path, get_shop_context(request, db=db, page=landing_page)
get_shop_context(request, db=db, page=landing_page)
) )
else: else:
# No landing page - redirect to shop # No landing page - redirect to shop
vendor_context = getattr(request.state, 'vendor_context', None) vendor_context = getattr(request.state, "vendor_context", None)
access_method = vendor_context.get('detection_method', 'unknown') if vendor_context else 'unknown' access_method = (
vendor_context.get("detection_method", "unknown")
if vendor_context
else "unknown"
)
if access_method == "path": if access_method == "path":
full_prefix = vendor_context.get('full_prefix', '/vendor/') if vendor_context else '/vendor/' full_prefix = (
return RedirectResponse(url=f"{full_prefix}{vendor.subdomain}/shop/", status_code=302) vendor_context.get("full_prefix", "/vendor/")
if vendor_context
else "/vendor/"
)
return RedirectResponse(
url=f"{full_prefix}{vendor.subdomain}/shop/", status_code=302
)
else: else:
# Domain/subdomain # Domain/subdomain
return RedirectResponse(url="/shop/", status_code=302) return RedirectResponse(url="/shop/", status_code=302)

View File

@@ -26,14 +26,10 @@ from jose import jwt
from passlib.context import CryptContext from passlib.context import CryptContext
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.exceptions import ( from app.exceptions import (AdminRequiredException,
AdminRequiredException, InsufficientPermissionsException,
InvalidTokenException, InvalidCredentialsException, InvalidTokenException,
TokenExpiredException, TokenExpiredException, UserNotActiveException)
UserNotActiveException,
InvalidCredentialsException,
InsufficientPermissionsException
)
from models.database.user import User from models.database.user import User
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -99,7 +95,9 @@ class AuthManager:
""" """
return pwd_context.verify(plain_password, hashed_password) return pwd_context.verify(plain_password, hashed_password)
def authenticate_user(self, db: Session, username: str, password: str) -> Optional[User]: def authenticate_user(
self, db: Session, username: str, password: str
) -> Optional[User]:
"""Authenticate user credentials against the database. """Authenticate user credentials against the database.
Supports authentication using either username or email address. Supports authentication using either username or email address.
@@ -201,7 +199,9 @@ class AuthManager:
raise InvalidTokenException("Token missing expiration") raise InvalidTokenException("Token missing expiration")
# Check if token has expired (additional check beyond jwt.decode) # Check if token has expired (additional check beyond jwt.decode)
if datetime.now(timezone.utc) > datetime.fromtimestamp(exp, tz=timezone.utc): if datetime.now(timezone.utc) > datetime.fromtimestamp(
exp, tz=timezone.utc
):
raise TokenExpiredException() raise TokenExpiredException()
# Validate user identifier claim exists # Validate user identifier claim exists
@@ -214,7 +214,9 @@ class AuthManager:
"user_id": int(user_id), "user_id": int(user_id),
"username": payload.get("username"), "username": payload.get("username"),
"email": payload.get("email"), "email": payload.get("email"),
"role": payload.get("role", "user"), # Default to "user" role if not specified "role": payload.get(
"role", "user"
), # Default to "user" role if not specified
} }
except jwt.ExpiredSignatureError: except jwt.ExpiredSignatureError:
@@ -232,7 +234,9 @@ class AuthManager:
logger.error(f"Token verification error: {e}") logger.error(f"Token verification error: {e}")
raise InvalidTokenException("Authentication failed") raise InvalidTokenException("Authentication failed")
def get_current_user(self, db: Session, credentials: HTTPAuthorizationCredentials) -> User: def get_current_user(
self, db: Session, credentials: HTTPAuthorizationCredentials
) -> User:
"""Extract and validate the current authenticated user from request credentials. """Extract and validate the current authenticated user from request credentials.
Verifies the JWT token from the Authorization header, looks up the user Verifies the JWT token from the Authorization header, looks up the user
@@ -286,8 +290,10 @@ class AuthManager:
# This will only execute if user has "admin" role # This will only execute if user has "admin" role
pass pass
""" """
def decorator(func): def decorator(func):
"""Decorator that wraps the function with role checking.""" """Decorator that wraps the function with role checking."""
def wrapper(current_user: User, *args, **kwargs): def wrapper(current_user: User, *args, **kwargs):
# Check if current user has the required role # Check if current user has the required role
if current_user.role != required_role: if current_user.role != required_role:
@@ -339,8 +345,7 @@ class AuthManager:
# Check if user has vendor or admin role (admins have full access) # Check if user has vendor or admin role (admins have full access)
if current_user.role not in ["vendor", "admin"]: if current_user.role not in ["vendor", "admin"]:
raise InsufficientPermissionsException( raise InsufficientPermissionsException(
message="Vendor access required", message="Vendor access required", required_permission="vendor"
required_permission="vendor"
) )
return current_user return current_user
@@ -363,7 +368,7 @@ class AuthManager:
if current_user.role not in ["customer", "admin"]: if current_user.role not in ["customer", "admin"]:
raise InsufficientPermissionsException( raise InsufficientPermissionsException(
message="Customer account access required", message="Customer account access required",
required_permission="customer" required_permission="customer",
) )
return current_user return current_user

View File

@@ -17,14 +17,16 @@ Class-based middleware provides:
import logging import logging
from enum import Enum from enum import Enum
from starlette.middleware.base import BaseHTTPMiddleware
from fastapi import Request from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RequestContext(str, Enum): class RequestContext(str, Enum):
"""Request context types for the application.""" """Request context types for the application."""
API = "api" API = "api"
ADMIN = "admin" ADMIN = "admin"
VENDOR_DASHBOARD = "vendor" VENDOR_DASHBOARD = "vendor"
@@ -59,7 +61,7 @@ class ContextManager:
# Use clean_path if available (extracted by vendor_context_middleware) # Use clean_path if available (extracted by vendor_context_middleware)
# Falls back to original path if clean_path not set # Falls back to original path if clean_path not set
# This is critical for correct context detection with path-based routing # This is critical for correct context detection with path-based routing
path = getattr(request.state, 'clean_path', request.url.path) path = getattr(request.state, "clean_path", request.url.path)
host = request.headers.get("host", "") host = request.headers.get("host", "")
@@ -71,10 +73,10 @@ class ContextManager:
f"[CONTEXT] Detecting context", f"[CONTEXT] Detecting context",
extra={ extra={
"original_path": request.url.path, "original_path": request.url.path,
"clean_path": getattr(request.state, 'clean_path', 'NOT SET'), "clean_path": getattr(request.state, "clean_path", "NOT SET"),
"path_to_check": path, "path_to_check": path,
"host": host, "host": host,
} },
) )
# 1. API context (highest priority) # 1. API context (highest priority)
@@ -84,24 +86,30 @@ class ContextManager:
# 2. Admin context # 2. Admin context
if ContextManager._is_admin_context(request, host, path): if ContextManager._is_admin_context(request, host, path):
logger.debug("[CONTEXT] Detected as ADMIN", extra={"path": path, "host": host}) logger.debug(
"[CONTEXT] Detected as ADMIN", extra={"path": path, "host": host}
)
return RequestContext.ADMIN return RequestContext.ADMIN
# 3. Vendor Dashboard context (vendor management area) # 3. Vendor Dashboard context (vendor management area)
# Check both clean_path and original path for vendor dashboard # Check both clean_path and original path for vendor dashboard
original_path = request.url.path original_path = request.url.path
if ContextManager._is_vendor_dashboard_context(path) or \ if ContextManager._is_vendor_dashboard_context(
ContextManager._is_vendor_dashboard_context(original_path): path
logger.debug("[CONTEXT] Detected as VENDOR_DASHBOARD", extra={"path": path, "original_path": original_path}) ) or ContextManager._is_vendor_dashboard_context(original_path):
logger.debug(
"[CONTEXT] Detected as VENDOR_DASHBOARD",
extra={"path": path, "original_path": original_path},
)
return RequestContext.VENDOR_DASHBOARD return RequestContext.VENDOR_DASHBOARD
# 4. Shop context (vendor storefront) # 4. Shop context (vendor storefront)
# Check if vendor context exists (set by vendor_context_middleware) # Check if vendor context exists (set by vendor_context_middleware)
if hasattr(request.state, 'vendor') and request.state.vendor: if hasattr(request.state, "vendor") and request.state.vendor:
# If we have a vendor and it's not admin or vendor dashboard, it's shop # If we have a vendor and it's not admin or vendor dashboard, it's shop
logger.debug( logger.debug(
"[CONTEXT] Detected as SHOP (has vendor context)", "[CONTEXT] Detected as SHOP (has vendor context)",
extra={"vendor": request.state.vendor.name} extra={"vendor": request.state.vendor.name},
) )
return RequestContext.SHOP return RequestContext.SHOP
@@ -173,11 +181,12 @@ class ContextMiddleware(BaseHTTPMiddleware):
f"[CONTEXT_MIDDLEWARE] Context detected: {context_type.value}", f"[CONTEXT_MIDDLEWARE] Context detected: {context_type.value}",
extra={ extra={
"path": request.url.path, "path": request.url.path,
"clean_path": getattr(request.state, 'clean_path', 'NOT SET'), "clean_path": getattr(request.state, "clean_path", "NOT SET"),
"host": request.headers.get("host", ""), "host": request.headers.get("host", ""),
"context": context_type.value, "context": context_type.value,
"has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None, "has_vendor": hasattr(request.state, "vendor")
} and request.state.vendor is not None,
},
) )
# Continue processing # Continue processing

Some files were not shown because too many files have changed in this diff Show More