style: apply black and isort formatting across entire codebase

- Standardize quote style (single to double quotes)
- Reorder and group imports alphabetically
- Fix line breaks and indentation for consistency
- Apply PEP 8 formatting standards

Also updated Makefile to exclude both venv and .venv from code quality checks.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
2025-11-28 19:30:17 +01:00
parent 13f0094743
commit 21c13ca39b
236 changed files with 8450 additions and 6545 deletions

View File

@@ -176,19 +176,19 @@ test-inventory:
format:
@echo "Running black..."
$(PYTHON) -m black . --exclude venv
$(PYTHON) -m black . --exclude '/(\.)?venv/'
@echo "Running isort..."
$(PYTHON) -m isort . --skip venv
$(PYTHON) -m isort . --skip venv --skip .venv
lint:
@echo "Running linting..."
$(PYTHON) -m ruff check . --exclude venv
$(PYTHON) -m mypy . --ignore-missing-imports --exclude venv
$(PYTHON) -m ruff check . --exclude venv --exclude .venv
$(PYTHON) -m mypy . --ignore-missing-imports --exclude '.*(\.)?venv.*'
lint-flake8:
@echo "Running linting..."
$(PYTHON) -m flake8 . --max-line-length=120 --extend-ignore=E203,W503,I201,I100 --exclude=venv,__pycache__,.git
$(PYTHON) -m mypy . --ignore-missing-imports --exclude venv
$(PYTHON) -m flake8 . --max-line-length=120 --extend-ignore=E203,W503,I201,I100 --exclude=venv,.venv,__pycache__,.git
$(PYTHON) -m mypy . --ignore-missing-imports --exclude '.*(\.)?venv.*'
check: format lint

View File

@@ -15,6 +15,7 @@ import sys
from logging.config import fileConfig
from sqlalchemy import engine_from_config, pool
from alembic import context
# Add your project directory to the Python path
@@ -39,13 +40,9 @@ print("=" * 70)
# ADMIN MODELS
# ----------------------------------------------------------------------------
try:
from models.database.admin import (
AdminAuditLog,
AdminNotification,
AdminSetting,
PlatformAlert,
AdminSession
)
from models.database.admin import (AdminAuditLog, AdminNotification,
AdminSession, AdminSetting,
PlatformAlert)
print(" ✓ Admin models imported (5 models)")
print(" - AdminAuditLog")
@@ -70,7 +67,7 @@ except ImportError as e:
# VENDOR MODELS
# ----------------------------------------------------------------------------
try:
from models.database.vendor import Vendor, VendorUser, Role
from models.database.vendor import Role, Vendor, VendorUser
print(" ✓ Vendor models imported (3 models)")
print(" - Vendor")
@@ -248,10 +245,7 @@ def run_migrations_online() -> None:
)
with connectable.connect() as connection:
context.configure(
connection=connection,
target_metadata=target_metadata
)
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()

File diff suppressed because it is too large Load Diff

View File

@@ -5,59 +5,72 @@ Revises: fef1d20ce8b4
Create Date: 2025-11-22 15:16:13.213613
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '72aa309d4007'
down_revision: Union[str, None] = 'fef1d20ce8b4'
revision: str = "72aa309d4007"
down_revision: Union[str, None] = "fef1d20ce8b4"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('content_pages',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('vendor_id', sa.Integer(), nullable=True),
sa.Column('slug', sa.String(length=100), nullable=False),
sa.Column('title', sa.String(length=200), nullable=False),
sa.Column('content', sa.Text(), nullable=False),
sa.Column('content_format', sa.String(length=20), nullable=True),
sa.Column('meta_description', sa.String(length=300), nullable=True),
sa.Column('meta_keywords', sa.String(length=300), nullable=True),
sa.Column('is_published', sa.Boolean(), nullable=False),
sa.Column('published_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('display_order', sa.Integer(), nullable=True),
sa.Column('show_in_footer', sa.Boolean(), nullable=True),
sa.Column('show_in_header', sa.Boolean(), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('created_by', sa.Integer(), nullable=True),
sa.Column('updated_by', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['created_by'], ['users.id'], ondelete='SET NULL'),
sa.ForeignKeyConstraint(['updated_by'], ['users.id'], ondelete='SET NULL'),
sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('vendor_id', 'slug', name='uq_vendor_slug')
op.create_table(
"content_pages",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("vendor_id", sa.Integer(), nullable=True),
sa.Column("slug", sa.String(length=100), nullable=False),
sa.Column("title", sa.String(length=200), nullable=False),
sa.Column("content", sa.Text(), nullable=False),
sa.Column("content_format", sa.String(length=20), nullable=True),
sa.Column("meta_description", sa.String(length=300), nullable=True),
sa.Column("meta_keywords", sa.String(length=300), nullable=True),
sa.Column("is_published", sa.Boolean(), nullable=False),
sa.Column("published_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("display_order", sa.Integer(), nullable=True),
sa.Column("show_in_footer", sa.Boolean(), nullable=True),
sa.Column("show_in_header", sa.Boolean(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("created_by", sa.Integer(), nullable=True),
sa.Column("updated_by", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(["created_by"], ["users.id"], ondelete="SET NULL"),
sa.ForeignKeyConstraint(["updated_by"], ["users.id"], ondelete="SET NULL"),
sa.ForeignKeyConstraint(["vendor_id"], ["vendors.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("vendor_id", "slug", name="uq_vendor_slug"),
)
op.create_index(
"idx_slug_published", "content_pages", ["slug", "is_published"], unique=False
)
op.create_index(
"idx_vendor_published",
"content_pages",
["vendor_id", "is_published"],
unique=False,
)
op.create_index(op.f("ix_content_pages_id"), "content_pages", ["id"], unique=False)
op.create_index(
op.f("ix_content_pages_slug"), "content_pages", ["slug"], unique=False
)
op.create_index(
op.f("ix_content_pages_vendor_id"), "content_pages", ["vendor_id"], unique=False
)
op.create_index('idx_slug_published', 'content_pages', ['slug', 'is_published'], unique=False)
op.create_index('idx_vendor_published', 'content_pages', ['vendor_id', 'is_published'], unique=False)
op.create_index(op.f('ix_content_pages_id'), 'content_pages', ['id'], unique=False)
op.create_index(op.f('ix_content_pages_slug'), 'content_pages', ['slug'], unique=False)
op.create_index(op.f('ix_content_pages_vendor_id'), 'content_pages', ['vendor_id'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_content_pages_vendor_id'), table_name='content_pages')
op.drop_index(op.f('ix_content_pages_slug'), table_name='content_pages')
op.drop_index(op.f('ix_content_pages_id'), table_name='content_pages')
op.drop_index('idx_vendor_published', table_name='content_pages')
op.drop_index('idx_slug_published', table_name='content_pages')
op.drop_table('content_pages')
op.drop_index(op.f("ix_content_pages_vendor_id"), table_name="content_pages")
op.drop_index(op.f("ix_content_pages_slug"), table_name="content_pages")
op.drop_index(op.f("ix_content_pages_id"), table_name="content_pages")
op.drop_index("idx_vendor_published", table_name="content_pages")
op.drop_index("idx_slug_published", table_name="content_pages")
op.drop_table("content_pages")
# ### end Alembic commands ###

View File

@@ -5,15 +5,17 @@ Revises: a2064e1dfcd4
Create Date: 2025-11-28 09:21:16.545203
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '7a7ce92593d5'
down_revision: Union[str, None] = 'a2064e1dfcd4'
revision: str = "7a7ce92593d5"
down_revision: Union[str, None] = "a2064e1dfcd4"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@@ -21,127 +23,269 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Create architecture_scans table
op.create_table(
'architecture_scans',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('timestamp', sa.DateTime(timezone=True), server_default=sa.text("(datetime('now'))"), nullable=False),
sa.Column('total_files', sa.Integer(), nullable=True),
sa.Column('total_violations', sa.Integer(), nullable=True),
sa.Column('errors', sa.Integer(), nullable=True),
sa.Column('warnings', sa.Integer(), nullable=True),
sa.Column('duration_seconds', sa.Float(), nullable=True),
sa.Column('triggered_by', sa.String(length=100), nullable=True),
sa.Column('git_commit_hash', sa.String(length=40), nullable=True),
sa.PrimaryKeyConstraint('id')
"architecture_scans",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"timestamp",
sa.DateTime(timezone=True),
server_default=sa.text("(datetime('now'))"),
nullable=False,
),
sa.Column("total_files", sa.Integer(), nullable=True),
sa.Column("total_violations", sa.Integer(), nullable=True),
sa.Column("errors", sa.Integer(), nullable=True),
sa.Column("warnings", sa.Integer(), nullable=True),
sa.Column("duration_seconds", sa.Float(), nullable=True),
sa.Column("triggered_by", sa.String(length=100), nullable=True),
sa.Column("git_commit_hash", sa.String(length=40), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_architecture_scans_id"), "architecture_scans", ["id"], unique=False
)
op.create_index(
op.f("ix_architecture_scans_timestamp"),
"architecture_scans",
["timestamp"],
unique=False,
)
op.create_index(op.f('ix_architecture_scans_id'), 'architecture_scans', ['id'], unique=False)
op.create_index(op.f('ix_architecture_scans_timestamp'), 'architecture_scans', ['timestamp'], unique=False)
# Create architecture_rules table
op.create_table(
'architecture_rules',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('rule_id', sa.String(length=20), nullable=False),
sa.Column('category', sa.String(length=50), nullable=False),
sa.Column('name', sa.String(length=200), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('severity', sa.String(length=10), nullable=False),
sa.Column('enabled', sa.Boolean(), nullable=False, server_default='1'),
sa.Column('custom_config', sa.JSON(), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text("(datetime('now'))"), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text("(datetime('now'))"), nullable=False),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('rule_id')
"architecture_rules",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("rule_id", sa.String(length=20), nullable=False),
sa.Column("category", sa.String(length=50), nullable=False),
sa.Column("name", sa.String(length=200), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("severity", sa.String(length=10), nullable=False),
sa.Column("enabled", sa.Boolean(), nullable=False, server_default="1"),
sa.Column("custom_config", sa.JSON(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("(datetime('now'))"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("(datetime('now'))"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("rule_id"),
)
op.create_index(
op.f("ix_architecture_rules_id"), "architecture_rules", ["id"], unique=False
)
op.create_index(
op.f("ix_architecture_rules_rule_id"),
"architecture_rules",
["rule_id"],
unique=True,
)
op.create_index(op.f('ix_architecture_rules_id'), 'architecture_rules', ['id'], unique=False)
op.create_index(op.f('ix_architecture_rules_rule_id'), 'architecture_rules', ['rule_id'], unique=True)
# Create architecture_violations table
op.create_table(
'architecture_violations',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('scan_id', sa.Integer(), nullable=False),
sa.Column('rule_id', sa.String(length=20), nullable=False),
sa.Column('rule_name', sa.String(length=200), nullable=False),
sa.Column('severity', sa.String(length=10), nullable=False),
sa.Column('file_path', sa.String(length=500), nullable=False),
sa.Column('line_number', sa.Integer(), nullable=False),
sa.Column('message', sa.Text(), nullable=False),
sa.Column('context', sa.Text(), nullable=True),
sa.Column('suggestion', sa.Text(), nullable=True),
sa.Column('status', sa.String(length=20), server_default='open', nullable=True),
sa.Column('assigned_to', sa.Integer(), nullable=True),
sa.Column('resolved_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('resolved_by', sa.Integer(), nullable=True),
sa.Column('resolution_note', sa.Text(), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text("(datetime('now'))"), nullable=False),
sa.ForeignKeyConstraint(['assigned_to'], ['users.id'], ),
sa.ForeignKeyConstraint(['resolved_by'], ['users.id'], ),
sa.ForeignKeyConstraint(['scan_id'], ['architecture_scans.id'], ),
sa.PrimaryKeyConstraint('id')
"architecture_violations",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("scan_id", sa.Integer(), nullable=False),
sa.Column("rule_id", sa.String(length=20), nullable=False),
sa.Column("rule_name", sa.String(length=200), nullable=False),
sa.Column("severity", sa.String(length=10), nullable=False),
sa.Column("file_path", sa.String(length=500), nullable=False),
sa.Column("line_number", sa.Integer(), nullable=False),
sa.Column("message", sa.Text(), nullable=False),
sa.Column("context", sa.Text(), nullable=True),
sa.Column("suggestion", sa.Text(), nullable=True),
sa.Column("status", sa.String(length=20), server_default="open", nullable=True),
sa.Column("assigned_to", sa.Integer(), nullable=True),
sa.Column("resolved_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("resolved_by", sa.Integer(), nullable=True),
sa.Column("resolution_note", sa.Text(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("(datetime('now'))"),
nullable=False,
),
sa.ForeignKeyConstraint(
["assigned_to"],
["users.id"],
),
sa.ForeignKeyConstraint(
["resolved_by"],
["users.id"],
),
sa.ForeignKeyConstraint(
["scan_id"],
["architecture_scans.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_architecture_violations_file_path"),
"architecture_violations",
["file_path"],
unique=False,
)
op.create_index(
op.f("ix_architecture_violations_id"),
"architecture_violations",
["id"],
unique=False,
)
op.create_index(
op.f("ix_architecture_violations_rule_id"),
"architecture_violations",
["rule_id"],
unique=False,
)
op.create_index(
op.f("ix_architecture_violations_scan_id"),
"architecture_violations",
["scan_id"],
unique=False,
)
op.create_index(
op.f("ix_architecture_violations_severity"),
"architecture_violations",
["severity"],
unique=False,
)
op.create_index(
op.f("ix_architecture_violations_status"),
"architecture_violations",
["status"],
unique=False,
)
op.create_index(op.f('ix_architecture_violations_file_path'), 'architecture_violations', ['file_path'], unique=False)
op.create_index(op.f('ix_architecture_violations_id'), 'architecture_violations', ['id'], unique=False)
op.create_index(op.f('ix_architecture_violations_rule_id'), 'architecture_violations', ['rule_id'], unique=False)
op.create_index(op.f('ix_architecture_violations_scan_id'), 'architecture_violations', ['scan_id'], unique=False)
op.create_index(op.f('ix_architecture_violations_severity'), 'architecture_violations', ['severity'], unique=False)
op.create_index(op.f('ix_architecture_violations_status'), 'architecture_violations', ['status'], unique=False)
# Create violation_assignments table
op.create_table(
'violation_assignments',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('violation_id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('assigned_at', sa.DateTime(timezone=True), server_default=sa.text("(datetime('now'))"), nullable=False),
sa.Column('assigned_by', sa.Integer(), nullable=True),
sa.Column('due_date', sa.DateTime(timezone=True), nullable=True),
sa.Column('priority', sa.String(length=10), server_default='medium', nullable=True),
sa.ForeignKeyConstraint(['assigned_by'], ['users.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.ForeignKeyConstraint(['violation_id'], ['architecture_violations.id'], ),
sa.PrimaryKeyConstraint('id')
"violation_assignments",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("violation_id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column(
"assigned_at",
sa.DateTime(timezone=True),
server_default=sa.text("(datetime('now'))"),
nullable=False,
),
sa.Column("assigned_by", sa.Integer(), nullable=True),
sa.Column("due_date", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"priority", sa.String(length=10), server_default="medium", nullable=True
),
sa.ForeignKeyConstraint(
["assigned_by"],
["users.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.ForeignKeyConstraint(
["violation_id"],
["architecture_violations.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_violation_assignments_id"),
"violation_assignments",
["id"],
unique=False,
)
op.create_index(
op.f("ix_violation_assignments_violation_id"),
"violation_assignments",
["violation_id"],
unique=False,
)
op.create_index(op.f('ix_violation_assignments_id'), 'violation_assignments', ['id'], unique=False)
op.create_index(op.f('ix_violation_assignments_violation_id'), 'violation_assignments', ['violation_id'], unique=False)
# Create violation_comments table
op.create_table(
'violation_comments',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('violation_id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('comment', sa.Text(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text("(datetime('now'))"), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.ForeignKeyConstraint(['violation_id'], ['architecture_violations.id'], ),
sa.PrimaryKeyConstraint('id')
"violation_comments",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("violation_id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("comment", sa.Text(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("(datetime('now'))"),
nullable=False,
),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.ForeignKeyConstraint(
["violation_id"],
["architecture_violations.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_violation_comments_id"), "violation_comments", ["id"], unique=False
)
op.create_index(
op.f("ix_violation_comments_violation_id"),
"violation_comments",
["violation_id"],
unique=False,
)
op.create_index(op.f('ix_violation_comments_id'), 'violation_comments', ['id'], unique=False)
op.create_index(op.f('ix_violation_comments_violation_id'), 'violation_comments', ['violation_id'], unique=False)
def downgrade() -> None:
# Drop tables in reverse order (to respect foreign key constraints)
op.drop_index(op.f('ix_violation_comments_violation_id'), table_name='violation_comments')
op.drop_index(op.f('ix_violation_comments_id'), table_name='violation_comments')
op.drop_table('violation_comments')
op.drop_index(
op.f("ix_violation_comments_violation_id"), table_name="violation_comments"
)
op.drop_index(op.f("ix_violation_comments_id"), table_name="violation_comments")
op.drop_table("violation_comments")
op.drop_index(op.f('ix_violation_assignments_violation_id'), table_name='violation_assignments')
op.drop_index(op.f('ix_violation_assignments_id'), table_name='violation_assignments')
op.drop_table('violation_assignments')
op.drop_index(
op.f("ix_violation_assignments_violation_id"),
table_name="violation_assignments",
)
op.drop_index(
op.f("ix_violation_assignments_id"), table_name="violation_assignments"
)
op.drop_table("violation_assignments")
op.drop_index(op.f('ix_architecture_violations_status'), table_name='architecture_violations')
op.drop_index(op.f('ix_architecture_violations_severity'), table_name='architecture_violations')
op.drop_index(op.f('ix_architecture_violations_scan_id'), table_name='architecture_violations')
op.drop_index(op.f('ix_architecture_violations_rule_id'), table_name='architecture_violations')
op.drop_index(op.f('ix_architecture_violations_id'), table_name='architecture_violations')
op.drop_index(op.f('ix_architecture_violations_file_path'), table_name='architecture_violations')
op.drop_table('architecture_violations')
op.drop_index(
op.f("ix_architecture_violations_status"), table_name="architecture_violations"
)
op.drop_index(
op.f("ix_architecture_violations_severity"),
table_name="architecture_violations",
)
op.drop_index(
op.f("ix_architecture_violations_scan_id"), table_name="architecture_violations"
)
op.drop_index(
op.f("ix_architecture_violations_rule_id"), table_name="architecture_violations"
)
op.drop_index(
op.f("ix_architecture_violations_id"), table_name="architecture_violations"
)
op.drop_index(
op.f("ix_architecture_violations_file_path"),
table_name="architecture_violations",
)
op.drop_table("architecture_violations")
op.drop_index(op.f('ix_architecture_rules_rule_id'), table_name='architecture_rules')
op.drop_index(op.f('ix_architecture_rules_id'), table_name='architecture_rules')
op.drop_table('architecture_rules')
op.drop_index(
op.f("ix_architecture_rules_rule_id"), table_name="architecture_rules"
)
op.drop_index(op.f("ix_architecture_rules_id"), table_name="architecture_rules")
op.drop_table("architecture_rules")
op.drop_index(op.f('ix_architecture_scans_timestamp'), table_name='architecture_scans')
op.drop_index(op.f('ix_architecture_scans_id'), table_name='architecture_scans')
op.drop_table('architecture_scans')
op.drop_index(
op.f("ix_architecture_scans_timestamp"), table_name="architecture_scans"
)
op.drop_index(op.f("ix_architecture_scans_id"), table_name="architecture_scans")
op.drop_table("architecture_scans")

View File

@@ -5,15 +5,16 @@ Revises: f68d8da5315a
Create Date: 2025-11-23 19:52:40.509538
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = 'a2064e1dfcd4'
down_revision: Union[str, None] = 'f68d8da5315a'
revision: str = "a2064e1dfcd4"
down_revision: Union[str, None] = "f68d8da5315a"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@@ -21,34 +22,46 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Create cart_items table
op.create_table(
'cart_items',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('vendor_id', sa.Integer(), nullable=False),
sa.Column('product_id', sa.Integer(), nullable=False),
sa.Column('session_id', sa.String(length=255), nullable=False),
sa.Column('quantity', sa.Integer(), nullable=False),
sa.Column('price_at_add', sa.Float(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['product_id'], ['products.id'], ),
sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('vendor_id', 'session_id', 'product_id', name='uq_cart_item')
"cart_items",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("vendor_id", sa.Integer(), nullable=False),
sa.Column("product_id", sa.Integer(), nullable=False),
sa.Column("session_id", sa.String(length=255), nullable=False),
sa.Column("quantity", sa.Integer(), nullable=False),
sa.Column("price_at_add", sa.Float(), nullable=False),
sa.Column("created_at", sa.DateTime(), nullable=True),
sa.Column("updated_at", sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(
["product_id"],
["products.id"],
),
sa.ForeignKeyConstraint(
["vendor_id"],
["vendors.id"],
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint(
"vendor_id", "session_id", "product_id", name="uq_cart_item"
),
)
# Create indexes
op.create_index('idx_cart_session', 'cart_items', ['vendor_id', 'session_id'], unique=False)
op.create_index('idx_cart_created', 'cart_items', ['created_at'], unique=False)
op.create_index(op.f('ix_cart_items_id'), 'cart_items', ['id'], unique=False)
op.create_index(op.f('ix_cart_items_session_id'), 'cart_items', ['session_id'], unique=False)
op.create_index(
"idx_cart_session", "cart_items", ["vendor_id", "session_id"], unique=False
)
op.create_index("idx_cart_created", "cart_items", ["created_at"], unique=False)
op.create_index(op.f("ix_cart_items_id"), "cart_items", ["id"], unique=False)
op.create_index(
op.f("ix_cart_items_session_id"), "cart_items", ["session_id"], unique=False
)
def downgrade() -> None:
# Drop indexes
op.drop_index(op.f('ix_cart_items_session_id'), table_name='cart_items')
op.drop_index(op.f('ix_cart_items_id'), table_name='cart_items')
op.drop_index('idx_cart_created', table_name='cart_items')
op.drop_index('idx_cart_session', table_name='cart_items')
op.drop_index(op.f("ix_cart_items_session_id"), table_name="cart_items")
op.drop_index(op.f("ix_cart_items_id"), table_name="cart_items")
op.drop_index("idx_cart_created", table_name="cart_items")
op.drop_index("idx_cart_session", table_name="cart_items")
# Drop table
op.drop_table('cart_items')
op.drop_table("cart_items")

View File

@@ -5,24 +5,30 @@ Revises: 72aa309d4007
Create Date: 2025-11-22 23:51:40.694983
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = 'f68d8da5315a'
down_revision: Union[str, None] = '72aa309d4007'
revision: str = "f68d8da5315a"
down_revision: Union[str, None] = "72aa309d4007"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add template column to content_pages table
op.add_column('content_pages', sa.Column('template', sa.String(length=50), nullable=False, server_default='default'))
op.add_column(
"content_pages",
sa.Column(
"template", sa.String(length=50), nullable=False, server_default="default"
),
)
def downgrade() -> None:
# Remove template column from content_pages table
op.drop_column('content_pages', 'template')
op.drop_column("content_pages", "template")

View File

@@ -6,15 +6,16 @@ Create Date: 2025-11-13 16:51:25.010057
SQLite-compatible version
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = 'fa7d4d10e358'
down_revision: Union[str, None] = '4951b2e50581'
revision: str = "fa7d4d10e358"
down_revision: Union[str, None] = "4951b2e50581"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@@ -28,9 +29,14 @@ def upgrade():
# ========================================================================
# User table changes
# ========================================================================
with op.batch_alter_table('users', schema=None) as batch_op:
with op.batch_alter_table("users", schema=None) as batch_op:
batch_op.add_column(
sa.Column('is_email_verified', sa.Boolean(), nullable=False, server_default='false')
sa.Column(
"is_email_verified",
sa.Boolean(),
nullable=False,
server_default="false",
)
)
# Set existing active users as verified
@@ -39,68 +45,65 @@ def upgrade():
# ========================================================================
# VendorUser table changes (requires table recreation for SQLite)
# ========================================================================
with op.batch_alter_table('vendor_users', schema=None) as batch_op:
with op.batch_alter_table("vendor_users", schema=None) as batch_op:
# Add new columns
batch_op.add_column(
sa.Column('user_type', sa.String(length=20), nullable=False, server_default='member')
sa.Column(
"user_type",
sa.String(length=20),
nullable=False,
server_default="member",
)
)
batch_op.add_column(
sa.Column('invitation_token', sa.String(length=100), nullable=True)
sa.Column("invitation_token", sa.String(length=100), nullable=True)
)
batch_op.add_column(
sa.Column('invitation_sent_at', sa.DateTime(), nullable=True)
sa.Column("invitation_sent_at", sa.DateTime(), nullable=True)
)
batch_op.add_column(
sa.Column('invitation_accepted_at', sa.DateTime(), nullable=True)
sa.Column("invitation_accepted_at", sa.DateTime(), nullable=True)
)
# Create index on invitation_token
batch_op.create_index(
'idx_vendor_users_invitation_token',
['invitation_token']
)
batch_op.create_index("idx_vendor_users_invitation_token", ["invitation_token"])
# Modify role_id to be nullable (this recreates the table in SQLite)
batch_op.alter_column(
'role_id',
existing_type=sa.Integer(),
nullable=True
)
batch_op.alter_column("role_id", existing_type=sa.Integer(), nullable=True)
# Change is_active default (this recreates the table in SQLite)
batch_op.alter_column(
'is_active',
existing_type=sa.Boolean(),
server_default='false'
"is_active", existing_type=sa.Boolean(), server_default="false"
)
# Set owners correctly (after table modifications)
# SQLite-compatible UPDATE with subquery
op.execute("""
op.execute(
"""
UPDATE vendor_users
SET user_type = 'owner'
WHERE (vendor_id, user_id) IN (
SELECT id, owner_user_id
FROM vendors
)
""")
"""
)
# Set existing owners as active
op.execute("""
op.execute(
"""
UPDATE vendor_users
SET is_active = TRUE
WHERE user_type = 'owner'
""")
"""
)
# ========================================================================
# Role table changes
# ========================================================================
with op.batch_alter_table('roles', schema=None) as batch_op:
with op.batch_alter_table("roles", schema=None) as batch_op:
# Create index on vendor_id and name
batch_op.create_index(
'idx_roles_vendor_name',
['vendor_id', 'name']
)
batch_op.create_index("idx_roles_vendor_name", ["vendor_id", "name"])
# Note: JSONB conversion only for PostgreSQL
# SQLite stores JSON as TEXT by default, no conversion needed
@@ -115,37 +118,31 @@ def downgrade():
# ========================================================================
# Role table changes
# ========================================================================
with op.batch_alter_table('roles', schema=None) as batch_op:
batch_op.drop_index('idx_roles_vendor_name')
with op.batch_alter_table("roles", schema=None) as batch_op:
batch_op.drop_index("idx_roles_vendor_name")
# ========================================================================
# VendorUser table changes
# ========================================================================
with op.batch_alter_table('vendor_users', schema=None) as batch_op:
with op.batch_alter_table("vendor_users", schema=None) as batch_op:
# Revert is_active default
batch_op.alter_column(
'is_active',
existing_type=sa.Boolean(),
server_default='true'
"is_active", existing_type=sa.Boolean(), server_default="true"
)
# Revert role_id to NOT NULL
# Note: This might fail if there are NULL values
batch_op.alter_column(
'role_id',
existing_type=sa.Integer(),
nullable=False
)
batch_op.alter_column("role_id", existing_type=sa.Integer(), nullable=False)
# Drop indexes and columns
batch_op.drop_index('idx_vendor_users_invitation_token')
batch_op.drop_column('invitation_accepted_at')
batch_op.drop_column('invitation_sent_at')
batch_op.drop_column('invitation_token')
batch_op.drop_column('user_type')
batch_op.drop_index("idx_vendor_users_invitation_token")
batch_op.drop_column("invitation_accepted_at")
batch_op.drop_column("invitation_sent_at")
batch_op.drop_column("invitation_token")
batch_op.drop_column("user_type")
# ========================================================================
# User table changes
# ========================================================================
with op.batch_alter_table('users', schema=None) as batch_op:
batch_op.drop_column('is_email_verified')
with op.batch_alter_table("users", schema=None) as batch_op:
batch_op.drop_column("is_email_verified")

View File

@@ -5,30 +5,43 @@ Revises: fa7d4d10e358
Create Date: 2025-11-22 13:41:18.069674
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = 'fef1d20ce8b4'
down_revision: Union[str, None] = 'fa7d4d10e358'
revision: str = "fef1d20ce8b4"
down_revision: Union[str, None] = "fa7d4d10e358"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index('idx_roles_vendor_name', table_name='roles')
op.drop_index('idx_vendor_users_invitation_token', table_name='vendor_users')
op.create_index(op.f('ix_vendor_users_invitation_token'), 'vendor_users', ['invitation_token'], unique=False)
op.drop_index("idx_roles_vendor_name", table_name="roles")
op.drop_index("idx_vendor_users_invitation_token", table_name="vendor_users")
op.create_index(
op.f("ix_vendor_users_invitation_token"),
"vendor_users",
["invitation_token"],
unique=False,
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_vendor_users_invitation_token'), table_name='vendor_users')
op.create_index('idx_vendor_users_invitation_token', 'vendor_users', ['invitation_token'], unique=False)
op.create_index('idx_roles_vendor_name', 'roles', ['vendor_id', 'name'], unique=False)
op.drop_index(op.f("ix_vendor_users_invitation_token"), table_name="vendor_users")
op.create_index(
"idx_vendor_users_invitation_token",
"vendor_users",
["invitation_token"],
unique=False,
)
op.create_index(
"idx_roles_vendor_name", "roles", ["vendor_id", "name"], unique=False
)
# ### end Alembic commands ###

View File

@@ -34,22 +34,20 @@ The cookie path restrictions prevent cross-context cookie leakage:
import logging
from typing import Optional
from fastapi import Depends, Request, Cookie
from fastapi import Cookie, Depends, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.exceptions import (AdminRequiredException,
InsufficientPermissionsException,
InvalidTokenException,
UnauthorizedVendorAccessException,
VendorNotFoundException)
from middleware.auth import AuthManager
from middleware.rate_limiter import RateLimiter
from models.database.vendor import Vendor
from models.database.user import User
from app.exceptions import (
AdminRequiredException,
InvalidTokenException,
InsufficientPermissionsException,
VendorNotFoundException,
UnauthorizedVendorAccessException
)
from models.database.vendor import Vendor
# Initialize dependencies
security = HTTPBearer(auto_error=False) # auto_error=False prevents automatic 403
@@ -62,11 +60,12 @@ logger = logging.getLogger(__name__)
# HELPER FUNCTIONS
# ============================================================================
def _get_token_from_request(
credentials: Optional[HTTPAuthorizationCredentials],
cookie_value: Optional[str],
cookie_name: str,
request_path: str
request_path: str,
) -> tuple[Optional[str], Optional[str]]:
"""
Extract token from Authorization header or cookie.
@@ -108,10 +107,7 @@ def _validate_user_token(token: str, db: Session) -> User:
Raises:
InvalidTokenException: If token is invalid
"""
mock_credentials = HTTPAuthorizationCredentials(
scheme="Bearer",
credentials=token
)
mock_credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
return auth_manager.get_current_user(db, mock_credentials)
@@ -119,6 +115,7 @@ def _validate_user_token(token: str, db: Session) -> User:
# ADMIN AUTHENTICATION
# ============================================================================
def get_current_admin_from_cookie_or_header(
request: Request,
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
@@ -148,10 +145,7 @@ def get_current_admin_from_cookie_or_header(
AdminRequiredException: If user is not admin
"""
token, source = _get_token_from_request(
credentials,
admin_token,
"admin_token",
str(request.url.path)
credentials, admin_token, "admin_token", str(request.url.path)
)
if not token:
@@ -208,6 +202,7 @@ def get_current_admin_api(
# VENDOR AUTHENTICATION
# ============================================================================
def get_current_vendor_from_cookie_or_header(
request: Request,
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
@@ -237,10 +232,7 @@ def get_current_vendor_from_cookie_or_header(
InsufficientPermissionsException: If user is not vendor or is admin
"""
token, source = _get_token_from_request(
credentials,
vendor_token,
"vendor_token",
str(request.url.path)
credentials, vendor_token, "vendor_token", str(request.url.path)
)
if not token:
@@ -310,6 +302,7 @@ def get_current_vendor_api(
# CUSTOMER AUTHENTICATION (SHOP)
# ============================================================================
def get_current_customer_from_cookie_or_header(
request: Request,
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
@@ -338,15 +331,14 @@ def get_current_customer_from_cookie_or_header(
Raises:
InvalidTokenException: If no token or invalid token
"""
from models.database.customer import Customer
from jose import jwt, JWTError
from datetime import datetime, timezone
from jose import JWTError, jwt
from models.database.customer import Customer
token, source = _get_token_from_request(
credentials,
customer_token,
"customer_token",
str(request.url.path)
credentials, customer_token, "customer_token", str(request.url.path)
)
if not token:
@@ -356,9 +348,7 @@ def get_current_customer_from_cookie_or_header(
# Decode and validate customer JWT token
try:
payload = jwt.decode(
token,
auth_manager.secret_key,
algorithms=[auth_manager.algorithm]
token, auth_manager.secret_key, algorithms=[auth_manager.algorithm]
)
# Verify this is a customer token
@@ -375,7 +365,9 @@ def get_current_customer_from_cookie_or_header(
# Verify token hasn't expired
exp = payload.get("exp")
if exp and datetime.fromtimestamp(exp, tz=timezone.utc) < datetime.now(timezone.utc):
if exp and datetime.fromtimestamp(exp, tz=timezone.utc) < datetime.now(
timezone.utc
):
logger.warning(f"Expired customer token for customer_id={customer_id}")
raise InvalidTokenException("Token has expired")
@@ -445,9 +437,10 @@ def get_current_customer_api(
# GENERIC AUTHENTICATION (for mixed-use endpoints)
# ============================================================================
def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
) -> User:
"""
Get current authenticated user from Authorization header only.
@@ -475,6 +468,7 @@ def get_current_user(
# VENDOR OWNERSHIP VERIFICATION
# ============================================================================
def get_user_vendor(
vendor_code: str,
current_user: User = Depends(get_current_vendor_from_cookie_or_header),
@@ -500,9 +494,7 @@ def get_user_vendor(
VendorNotFoundException: If vendor doesn't exist
UnauthorizedVendorAccessException: If user doesn't have access
"""
vendor = db.query(Vendor).filter(
Vendor.vendor_code == vendor_code.upper()
).first()
vendor = db.query(Vendor).filter(Vendor.vendor_code == vendor_code.upper()).first()
if not vendor:
raise VendorNotFoundException(vendor_code)
@@ -517,10 +509,12 @@ def get_user_vendor(
# User doesn't have access to this vendor
raise UnauthorizedVendorAccessException(vendor_code, current_user.id)
# ============================================================================
# PERMISSIONS CHECKING
# ============================================================================
def require_vendor_permission(permission: str):
"""
Dependency factory to require a specific vendor permission.
@@ -610,8 +604,7 @@ def require_any_vendor_permission(*permissions: str):
# Check if user has ANY of the required permissions
has_permission = any(
current_user.has_vendor_permission(vendor.id, perm)
for perm in permissions
current_user.has_vendor_permission(vendor.id, perm) for perm in permissions
)
if not has_permission:
@@ -651,7 +644,8 @@ def require_all_vendor_permissions(*permissions: str):
# Check if user has ALL required permissions
missing_permissions = [
perm for perm in permissions
perm
for perm in permissions
if not current_user.has_vendor_permission(vendor.id, perm)
]
@@ -682,6 +676,7 @@ def get_user_permissions(
# If owner, return all permissions
if current_user.is_owner_of(vendor.id):
from app.core.permissions import VendorPermissions
return [p.value for p in VendorPermissions]
# Get permissions from vendor membership
@@ -696,6 +691,7 @@ def get_user_permissions(
# OPTIONAL AUTHENTICATION (For Login Page Redirects)
# ============================================================================
def get_current_admin_optional(
request: Request,
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
@@ -723,10 +719,7 @@ def get_current_admin_optional(
None: If no token, invalid token, or user is not admin
"""
token, source = _get_token_from_request(
credentials,
admin_token,
"admin_token",
str(request.url.path)
credentials, admin_token, "admin_token", str(request.url.path)
)
if not token:
@@ -773,10 +766,7 @@ def get_current_vendor_optional(
None: If no token, invalid token, or user is not vendor
"""
token, source = _get_token_from_request(
credentials,
vendor_token,
"vendor_token",
str(request.url.path)
credentials, vendor_token, "vendor_token", str(request.url.path)
)
if not token:
@@ -823,10 +813,7 @@ def get_current_customer_optional(
None: If no token, invalid token, or user is not customer
"""
token, source = _get_token_from_request(
credentials,
customer_token,
"customer_token",
str(request.url.path)
credentials, customer_token, "customer_token", str(request.url.path)
)
if not token:
@@ -844,5 +831,3 @@ def get_current_customer_optional(
pass
return None

View File

@@ -9,7 +9,8 @@ This module provides:
"""
from fastapi import APIRouter
from app.api.v1 import admin, vendor, shop
from app.api.v1 import admin, shop, vendor
api_router = APIRouter()
@@ -18,31 +19,18 @@ api_router = APIRouter()
# Prefix: /api/v1/admin
# ============================================================================
api_router.include_router(
admin.router,
prefix="/v1/admin",
tags=["admin"]
)
api_router.include_router(admin.router, prefix="/v1/admin", tags=["admin"])
# ============================================================================
# VENDOR ROUTES (Vendor-scoped operations)
# Prefix: /api/v1/vendor
# ============================================================================
api_router.include_router(
vendor.router,
prefix="/v1/vendor",
tags=["vendor"]
)
api_router.include_router(vendor.router, prefix="/v1/vendor", tags=["vendor"])
# ============================================================================
# SHOP ROUTES (Public shop frontend API)
# Prefix: /api/v1/shop
# ============================================================================
api_router.include_router(
shop.router,
prefix="/v1/shop",
tags=["shop"]
)
api_router.include_router(shop.router, prefix="/v1/shop", tags=["shop"])

View File

@@ -3,6 +3,6 @@
API Version 1 - All endpoints
"""
from . import admin, vendor, shop
from . import admin, shop, vendor
__all__ = ["admin", "vendor", "shop"]

View File

@@ -24,21 +24,9 @@ IMPORTANT:
from fastapi import APIRouter
# Import all admin routers
from . import (
auth,
vendors,
vendor_domains,
vendor_themes,
users,
dashboard,
marketplace,
monitoring,
audit,
settings,
notifications,
content_pages,
code_quality
)
from . import (audit, auth, code_quality, content_pages, dashboard,
marketplace, monitoring, notifications, settings, users,
vendor_domains, vendor_themes, vendors)
# Create admin router
router = APIRouter()
@@ -66,7 +54,9 @@ router.include_router(vendor_domains.router, tags=["admin-vendor-domains"])
router.include_router(vendor_themes.router, tags=["admin-vendor-themes"])
# Include content pages management endpoints
router.include_router(content_pages.router, prefix="/content-pages", tags=["admin-content-pages"])
router.include_router(
content_pages.router, prefix="/content-pages", tags=["admin-content-pages"]
)
# ============================================================================
@@ -115,7 +105,9 @@ router.include_router(notifications.router, tags=["admin-notifications"])
# ============================================================================
# Include code quality and architecture validation endpoints
router.include_router(code_quality.router, prefix="/code-quality", tags=["admin-code-quality"])
router.include_router(
code_quality.router, prefix="/code-quality", tags=["admin-code-quality"]
)
# Export the router
__all__ = ["router"]

View File

@@ -9,8 +9,8 @@ Provides endpoints for:
"""
import logging
from typing import Optional
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
@@ -18,12 +18,10 @@ from sqlalchemy.orm import Session
from app.api.deps import get_current_admin_api
from app.core.database import get_db
from app.services.admin_audit_service import admin_audit_service
from models.schema.admin import (
AdminAuditLogResponse,
AdminAuditLogFilters,
AdminAuditLogListResponse
)
from models.database.user import User
from models.schema.admin import (AdminAuditLogFilters,
AdminAuditLogListResponse,
AdminAuditLogResponse)
router = APIRouter(prefix="/audit")
logger = logging.getLogger(__name__)
@@ -54,7 +52,7 @@ def get_audit_logs(
date_from=date_from,
date_to=date_to,
skip=skip,
limit=limit
limit=limit,
)
logs = admin_audit_service.get_audit_logs(db, filters)
@@ -62,12 +60,7 @@ def get_audit_logs(
logger.info(f"Admin {current_admin.username} retrieved {len(logs)} audit logs")
return AdminAuditLogListResponse(
logs=logs,
total=total,
skip=skip,
limit=limit
)
return AdminAuditLogListResponse(logs=logs, total=total, skip=skip, limit=limit)
@router.get("/logs/recent", response_model=list[AdminAuditLogResponse])
@@ -89,9 +82,7 @@ def get_my_actions(
):
"""Get audit logs for current admin's actions."""
return admin_audit_service.get_recent_actions_by_admin(
db=db,
admin_user_id=current_admin.id,
limit=limit
db=db, admin_user_id=current_admin.id, limit=limit
)
@@ -109,8 +100,5 @@ def get_actions_by_target(
Useful for tracking the history of a specific vendor, user, or entity.
"""
return admin_audit_service.get_actions_by_target(
db=db,
target_type=target_type,
target_id=target_id,
limit=limit
db=db, target_type=target_type, target_id=target_id, limit=limit
)

View File

@@ -10,16 +10,17 @@ This prevents admin cookies from being sent to vendor routes.
"""
import logging
from fastapi import APIRouter, Depends, Response
from sqlalchemy.orm import Session
from app.api.deps import get_current_admin_api
from app.core.database import get_db
from app.core.environment import should_use_secure_cookies
from app.services.auth_service import auth_service
from app.exceptions import InvalidCredentialsException
from models.schema.auth import LoginResponse, UserLogin, UserResponse
from app.services.auth_service import auth_service
from models.database.user import User
from app.api.deps import get_current_admin_api
from models.schema.auth import LoginResponse, UserLogin, UserResponse
router = APIRouter(prefix="/auth")
logger = logging.getLogger(__name__)
@@ -27,9 +28,7 @@ logger = logging.getLogger(__name__)
@router.post("/login", response_model=LoginResponse)
def admin_login(
user_credentials: UserLogin,
response: Response,
db: Session = Depends(get_db)
user_credentials: UserLogin, response: Response, db: Session = Depends(get_db)
):
"""
Admin login endpoint.
@@ -49,7 +48,9 @@ def admin_login(
# Verify user is admin
if login_result["user"].role != "admin":
logger.warning(f"Non-admin user attempted admin login: {user_credentials.email_or_username}")
logger.warning(
f"Non-admin user attempted admin login: {user_credentials.email_or_username}"
)
raise InvalidCredentialsException("Admin access required")
logger.info(f"Admin login successful: {login_result['user'].username}")

View File

@@ -3,25 +3,27 @@ Code Quality API Endpoints
RESTful API for architecture validation and violation management
"""
from typing import Optional
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from pydantic import BaseModel, Field
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from app.api.deps import get_current_admin_api
from app.core.database import get_db
from app.services.code_quality_service import code_quality_service
from app.api.deps import get_current_admin_api
from models.database.user import User
router = APIRouter()
# Pydantic Models for API
class ScanResponse(BaseModel):
"""Response model for a scan"""
id: int
timestamp: str
total_files: int
@@ -38,6 +40,7 @@ class ScanResponse(BaseModel):
class ViolationResponse(BaseModel):
"""Response model for a violation"""
id: int
scan_id: int
rule_id: str
@@ -61,6 +64,7 @@ class ViolationResponse(BaseModel):
class ViolationListResponse(BaseModel):
"""Response model for paginated violations list"""
violations: list[ViolationResponse]
total: int
page: int
@@ -70,34 +74,42 @@ class ViolationListResponse(BaseModel):
class ViolationDetailResponse(ViolationResponse):
"""Response model for single violation with relationships"""
assignments: list = []
comments: list = []
class AssignViolationRequest(BaseModel):
"""Request model for assigning a violation"""
user_id: int = Field(..., description="User ID to assign to")
due_date: Optional[datetime] = Field(None, description="Due date for resolution")
priority: str = Field("medium", description="Priority level (low, medium, high, critical)")
priority: str = Field(
"medium", description="Priority level (low, medium, high, critical)"
)
class ResolveViolationRequest(BaseModel):
"""Request model for resolving a violation"""
resolution_note: str = Field(..., description="Note about the resolution")
class IgnoreViolationRequest(BaseModel):
"""Request model for ignoring a violation"""
reason: str = Field(..., description="Reason for ignoring")
class AddCommentRequest(BaseModel):
"""Request model for adding a comment"""
comment: str = Field(..., min_length=1, description="Comment text")
class DashboardStatsResponse(BaseModel):
"""Response model for dashboard statistics"""
total_violations: int
errors: int
warnings: int
@@ -116,10 +128,10 @@ class DashboardStatsResponse(BaseModel):
# API Endpoints
@router.post("/scan", response_model=ScanResponse)
async def trigger_scan(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_admin_api)
db: Session = Depends(get_db), current_user: User = Depends(get_current_admin_api)
):
"""
Trigger a new architecture scan
@@ -127,7 +139,9 @@ async def trigger_scan(
Requires authentication. Runs the validator script and stores results.
"""
try:
scan = code_quality_service.run_scan(db, triggered_by=f"manual:{current_user.username}")
scan = code_quality_service.run_scan(
db, triggered_by=f"manual:{current_user.username}"
)
return ScanResponse(
id=scan.id,
@@ -138,7 +152,7 @@ async def trigger_scan(
warnings=scan.warnings,
duration_seconds=scan.duration_seconds,
triggered_by=scan.triggered_by,
git_commit_hash=scan.git_commit_hash
git_commit_hash=scan.git_commit_hash,
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Scan failed: {str(e)}")
@@ -148,7 +162,7 @@ async def trigger_scan(
async def list_scans(
limit: int = Query(30, ge=1, le=100, description="Number of scans to return"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_admin_api)
current_user: User = Depends(get_current_admin_api),
):
"""
Get scan history
@@ -167,7 +181,7 @@ async def list_scans(
warnings=scan.warnings,
duration_seconds=scan.duration_seconds,
triggered_by=scan.triggered_by,
git_commit_hash=scan.git_commit_hash
git_commit_hash=scan.git_commit_hash,
)
for scan in scans
]
@@ -175,15 +189,23 @@ async def list_scans(
@router.get("/violations", response_model=ViolationListResponse)
async def list_violations(
scan_id: Optional[int] = Query(None, description="Filter by scan ID (defaults to latest)"),
severity: Optional[str] = Query(None, description="Filter by severity (error, warning)"),
status: Optional[str] = Query(None, description="Filter by status (open, assigned, resolved, ignored)"),
scan_id: Optional[int] = Query(
None, description="Filter by scan ID (defaults to latest)"
),
severity: Optional[str] = Query(
None, description="Filter by severity (error, warning)"
),
status: Optional[str] = Query(
None, description="Filter by status (open, assigned, resolved, ignored)"
),
rule_id: Optional[str] = Query(None, description="Filter by rule ID"),
file_path: Optional[str] = Query(None, description="Filter by file path (partial match)"),
file_path: Optional[str] = Query(
None, description="Filter by file path (partial match)"
),
page: int = Query(1, ge=1, description="Page number"),
page_size: int = Query(50, ge=1, le=200, description="Items per page"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_admin_api)
current_user: User = Depends(get_current_admin_api),
):
"""
Get violations with filtering and pagination
@@ -200,7 +222,7 @@ async def list_violations(
rule_id=rule_id,
file_path=file_path,
limit=page_size,
offset=offset
offset=offset,
)
total_pages = (total + page_size - 1) // page_size
@@ -223,14 +245,14 @@ async def list_violations(
resolved_at=v.resolved_at.isoformat() if v.resolved_at else None,
resolved_by=v.resolved_by,
resolution_note=v.resolution_note,
created_at=v.created_at.isoformat()
created_at=v.created_at.isoformat(),
)
for v in violations
],
total=total,
page=page,
page_size=page_size,
total_pages=total_pages
total_pages=total_pages,
)
@@ -238,7 +260,7 @@ async def list_violations(
async def get_violation(
violation_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_admin_api)
current_user: User = Depends(get_current_admin_api),
):
"""
Get single violation with details
@@ -253,12 +275,12 @@ async def get_violation(
# Format assignments
assignments = [
{
'id': a.id,
'user_id': a.user_id,
'assigned_at': a.assigned_at.isoformat(),
'assigned_by': a.assigned_by,
'due_date': a.due_date.isoformat() if a.due_date else None,
'priority': a.priority
"id": a.id,
"user_id": a.user_id,
"assigned_at": a.assigned_at.isoformat(),
"assigned_by": a.assigned_by,
"due_date": a.due_date.isoformat() if a.due_date else None,
"priority": a.priority,
}
for a in violation.assignments
]
@@ -266,10 +288,10 @@ async def get_violation(
# Format comments
comments = [
{
'id': c.id,
'user_id': c.user_id,
'comment': c.comment,
'created_at': c.created_at.isoformat()
"id": c.id,
"user_id": c.user_id,
"comment": c.comment,
"created_at": c.created_at.isoformat(),
}
for c in violation.comments
]
@@ -287,12 +309,14 @@ async def get_violation(
suggestion=violation.suggestion,
status=violation.status,
assigned_to=violation.assigned_to,
resolved_at=violation.resolved_at.isoformat() if violation.resolved_at else None,
resolved_at=(
violation.resolved_at.isoformat() if violation.resolved_at else None
),
resolved_by=violation.resolved_by,
resolution_note=violation.resolution_note,
created_at=violation.created_at.isoformat(),
assignments=assignments,
comments=comments
comments=comments,
)
@@ -301,7 +325,7 @@ async def assign_violation(
violation_id: int,
request: AssignViolationRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_admin_api)
current_user: User = Depends(get_current_admin_api),
):
"""
Assign violation to a developer
@@ -315,17 +339,19 @@ async def assign_violation(
user_id=request.user_id,
assigned_by=current_user.id,
due_date=request.due_date,
priority=request.priority
priority=request.priority,
)
return {
'id': assignment.id,
'violation_id': assignment.violation_id,
'user_id': assignment.user_id,
'assigned_at': assignment.assigned_at.isoformat(),
'assigned_by': assignment.assigned_by,
'due_date': assignment.due_date.isoformat() if assignment.due_date else None,
'priority': assignment.priority
"id": assignment.id,
"violation_id": assignment.violation_id,
"user_id": assignment.user_id,
"assigned_at": assignment.assigned_at.isoformat(),
"assigned_by": assignment.assigned_by,
"due_date": (
assignment.due_date.isoformat() if assignment.due_date else None
),
"priority": assignment.priority,
}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -336,7 +362,7 @@ async def resolve_violation(
violation_id: int,
request: ResolveViolationRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_admin_api)
current_user: User = Depends(get_current_admin_api),
):
"""
Mark violation as resolved
@@ -348,15 +374,17 @@ async def resolve_violation(
db,
violation_id=violation_id,
resolved_by=current_user.id,
resolution_note=request.resolution_note
resolution_note=request.resolution_note,
)
return {
'id': violation.id,
'status': violation.status,
'resolved_at': violation.resolved_at.isoformat() if violation.resolved_at else None,
'resolved_by': violation.resolved_by,
'resolution_note': violation.resolution_note
"id": violation.id,
"status": violation.status,
"resolved_at": (
violation.resolved_at.isoformat() if violation.resolved_at else None
),
"resolved_by": violation.resolved_by,
"resolution_note": violation.resolution_note,
}
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@@ -369,7 +397,7 @@ async def ignore_violation(
violation_id: int,
request: IgnoreViolationRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_admin_api)
current_user: User = Depends(get_current_admin_api),
):
"""
Mark violation as ignored (won't fix)
@@ -381,15 +409,17 @@ async def ignore_violation(
db,
violation_id=violation_id,
ignored_by=current_user.id,
reason=request.reason
reason=request.reason,
)
return {
'id': violation.id,
'status': violation.status,
'resolved_at': violation.resolved_at.isoformat() if violation.resolved_at else None,
'resolved_by': violation.resolved_by,
'resolution_note': violation.resolution_note
"id": violation.id,
"status": violation.status,
"resolved_at": (
violation.resolved_at.isoformat() if violation.resolved_at else None
),
"resolved_by": violation.resolved_by,
"resolution_note": violation.resolution_note,
}
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@@ -402,7 +432,7 @@ async def add_comment(
violation_id: int,
request: AddCommentRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_admin_api)
current_user: User = Depends(get_current_admin_api),
):
"""
Add comment to violation
@@ -414,15 +444,15 @@ async def add_comment(
db,
violation_id=violation_id,
user_id=current_user.id,
comment=request.comment
comment=request.comment,
)
return {
'id': comment.id,
'violation_id': comment.violation_id,
'user_id': comment.user_id,
'comment': comment.comment,
'created_at': comment.created_at.isoformat()
"id": comment.id,
"violation_id": comment.violation_id,
"user_id": comment.user_id,
"comment": comment.comment,
"created_at": comment.created_at.isoformat(),
}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -430,8 +460,7 @@ async def add_comment(
@router.get("/stats", response_model=DashboardStatsResponse)
async def get_dashboard_stats(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_admin_api)
db: Session = Depends(get_db), current_user: User = Depends(get_current_admin_api)
):
"""
Get dashboard statistics

View File

@@ -10,6 +10,7 @@ Platform administrators can:
import logging
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
@@ -26,24 +27,43 @@ logger = logging.getLogger(__name__)
# REQUEST/RESPONSE SCHEMAS
# ============================================================================
class ContentPageCreate(BaseModel):
"""Schema for creating a content page."""
slug: str = Field(..., max_length=100, description="URL-safe identifier (about, faq, contact, etc.)")
slug: str = Field(
...,
max_length=100,
description="URL-safe identifier (about, faq, contact, etc.)",
)
title: str = Field(..., max_length=200, description="Page title")
content: str = Field(..., description="HTML or Markdown content")
content_format: str = Field(default="html", description="Content format: html or markdown")
template: str = Field(default="default", max_length=50, description="Template name (default, minimal, modern)")
meta_description: Optional[str] = Field(None, max_length=300, description="SEO meta description")
meta_keywords: Optional[str] = Field(None, max_length=300, description="SEO keywords")
content_format: str = Field(
default="html", description="Content format: html or markdown"
)
template: str = Field(
default="default",
max_length=50,
description="Template name (default, minimal, modern)",
)
meta_description: Optional[str] = Field(
None, max_length=300, description="SEO meta description"
)
meta_keywords: Optional[str] = Field(
None, max_length=300, description="SEO keywords"
)
is_published: bool = Field(default=False, description="Publish immediately")
show_in_footer: bool = Field(default=True, description="Show in footer navigation")
show_in_header: bool = Field(default=False, description="Show in header navigation")
display_order: int = Field(default=0, description="Display order (lower = first)")
vendor_id: Optional[int] = Field(None, description="Vendor ID (None for platform default)")
vendor_id: Optional[int] = Field(
None, description="Vendor ID (None for platform default)"
)
class ContentPageUpdate(BaseModel):
"""Schema for updating a content page."""
title: Optional[str] = Field(None, max_length=200)
content: Optional[str] = None
content_format: Optional[str] = None
@@ -58,6 +78,7 @@ class ContentPageUpdate(BaseModel):
class ContentPageResponse(BaseModel):
"""Schema for content page response."""
id: int
vendor_id: Optional[int]
vendor_name: Optional[str]
@@ -84,11 +105,12 @@ class ContentPageResponse(BaseModel):
# PLATFORM DEFAULT PAGES (vendor_id=NULL)
# ============================================================================
@router.get("/platform", response_model=List[ContentPageResponse])
def list_platform_pages(
include_unpublished: bool = Query(False, description="Include draft pages"),
current_user: User = Depends(get_current_admin_api),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
List all platform default content pages.
@@ -96,8 +118,7 @@ def list_platform_pages(
These are used as fallbacks when vendors haven't created custom pages.
"""
pages = content_page_service.list_all_platform_pages(
db,
include_unpublished=include_unpublished
db, include_unpublished=include_unpublished
)
return [page.to_dict() for page in pages]
@@ -107,7 +128,7 @@ def list_platform_pages(
def create_platform_page(
page_data: ContentPageCreate,
current_user: User = Depends(get_current_admin_api),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Create a new platform default content page.
@@ -129,7 +150,7 @@ def create_platform_page(
show_in_footer=page_data.show_in_footer,
show_in_header=page_data.show_in_header,
display_order=page_data.display_order,
created_by=current_user.id
created_by=current_user.id,
)
return page.to_dict()
@@ -139,12 +160,13 @@ def create_platform_page(
# ALL CONTENT PAGES (Platform + Vendors)
# ============================================================================
@router.get("/", response_model=List[ContentPageResponse])
def list_all_pages(
vendor_id: Optional[int] = Query(None, description="Filter by vendor ID"),
include_unpublished: bool = Query(False, description="Include draft pages"),
current_user: User = Depends(get_current_admin_api),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
List all content pages (platform defaults and vendor overrides).
@@ -153,15 +175,14 @@ def list_all_pages(
"""
if vendor_id:
pages = content_page_service.list_all_vendor_pages(
db,
vendor_id=vendor_id,
include_unpublished=include_unpublished
db, vendor_id=vendor_id, include_unpublished=include_unpublished
)
else:
# Get all pages (both platform and vendor)
from models.database.content_page import ContentPage
from sqlalchemy import and_
from models.database.content_page import ContentPage
filters = []
if not include_unpublished:
filters.append(ContentPage.is_published == True)
@@ -169,7 +190,9 @@ def list_all_pages(
pages = (
db.query(ContentPage)
.filter(and_(*filters) if filters else True)
.order_by(ContentPage.vendor_id, ContentPage.display_order, ContentPage.title)
.order_by(
ContentPage.vendor_id, ContentPage.display_order, ContentPage.title
)
.all()
)
@@ -180,7 +203,7 @@ def list_all_pages(
def get_page(
page_id: int,
current_user: User = Depends(get_current_admin_api),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""Get a specific content page by ID."""
page = content_page_service.get_page_by_id(db, page_id)
@@ -196,7 +219,7 @@ def update_page(
page_id: int,
page_data: ContentPageUpdate,
current_user: User = Depends(get_current_admin_api),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""Update a content page (platform or vendor)."""
page = content_page_service.update_page(
@@ -212,7 +235,7 @@ def update_page(
show_in_footer=page_data.show_in_footer,
show_in_header=page_data.show_in_header,
display_order=page_data.display_order,
updated_by=current_user.id
updated_by=current_user.id,
)
if not page:
@@ -225,7 +248,7 @@ def update_page(
def delete_page(
page_id: int,
current_user: User = Depends(get_current_admin_api),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""Delete a content page."""
success = content_page_service.delete_page(db, page_id)

View File

@@ -5,6 +5,7 @@ Admin dashboard and statistics endpoints.
import logging
from typing import List
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session

View File

@@ -13,8 +13,8 @@ from app.api.deps import get_current_admin_api
from app.core.database import get_db
from app.services.admin_service import admin_service
from app.services.stats_service import stats_service
from models.schema.marketplace_import_job import MarketplaceImportJobResponse
from models.database.user import User
from models.schema.marketplace_import_job import MarketplaceImportJobResponse
router = APIRouter(prefix="/marketplace-import-jobs")
logger = logging.getLogger(__name__)

View File

@@ -16,16 +16,13 @@ from sqlalchemy.orm import Session
from app.api.deps import get_current_admin_api
from app.core.database import get_db
from models.schema.admin import (
AdminNotificationCreate,
AdminNotificationResponse,
AdminNotificationListResponse,
PlatformAlertCreate,
PlatformAlertResponse,
PlatformAlertListResponse,
PlatformAlertResolve
)
from models.database.user import User
from models.schema.admin import (AdminNotificationCreate,
AdminNotificationListResponse,
AdminNotificationResponse,
PlatformAlertCreate,
PlatformAlertListResponse,
PlatformAlertResolve, PlatformAlertResponse)
router = APIRouter(prefix="/notifications")
logger = logging.getLogger(__name__)
@@ -35,6 +32,7 @@ logger = logging.getLogger(__name__)
# ADMIN NOTIFICATIONS
# ============================================================================
@router.get("", response_model=AdminNotificationListResponse)
def get_notifications(
priority: Optional[str] = Query(None, description="Filter by priority"),
@@ -47,11 +45,7 @@ def get_notifications(
"""Get admin notifications with filtering."""
# TODO: Implement notification service
return AdminNotificationListResponse(
notifications=[],
total=0,
unread_count=0,
skip=skip,
limit=limit
notifications=[], total=0, unread_count=0, skip=skip, limit=limit
)
@@ -90,10 +84,13 @@ def mark_all_as_read(
# PLATFORM ALERTS
# ============================================================================
@router.get("/alerts", response_model=PlatformAlertListResponse)
def get_platform_alerts(
severity: Optional[str] = Query(None, description="Filter by severity"),
is_resolved: Optional[bool] = Query(None, description="Filter by resolution status"),
is_resolved: Optional[bool] = Query(
None, description="Filter by resolution status"
),
skip: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=100),
db: Session = Depends(get_db),
@@ -102,12 +99,7 @@ def get_platform_alerts(
"""Get platform alerts with filtering."""
# TODO: Implement alert service
return PlatformAlertListResponse(
alerts=[],
total=0,
active_count=0,
critical_count=0,
skip=skip,
limit=limit
alerts=[], total=0, active_count=0, critical_count=0, skip=skip, limit=limit
)
@@ -147,5 +139,5 @@ def get_alert_statistics(
"total_alerts": 0,
"active_alerts": 0,
"critical_alerts": 0,
"resolved_today": 0
"resolved_today": 0,
}

View File

@@ -16,15 +16,11 @@ from sqlalchemy.orm import Session
from app.api.deps import get_current_admin_api
from app.core.database import get_db
from app.services.admin_settings_service import admin_settings_service
from app.services.admin_audit_service import admin_audit_service
from models.schema.admin import (
AdminSettingCreate,
AdminSettingResponse,
AdminSettingUpdate,
AdminSettingListResponse
)
from app.services.admin_settings_service import admin_settings_service
from models.database.user import User
from models.schema.admin import (AdminSettingCreate, AdminSettingListResponse,
AdminSettingResponse, AdminSettingUpdate)
router = APIRouter(prefix="/settings")
logger = logging.getLogger(__name__)
@@ -46,9 +42,7 @@ def get_all_settings(
settings = admin_settings_service.get_all_settings(db, category, is_public)
return AdminSettingListResponse(
settings=settings,
total=len(settings),
category=category
settings=settings, total=len(settings), category=category
)
@@ -66,7 +60,7 @@ def get_setting_categories(
"marketplace",
"notifications",
"integrations",
"payments"
"payments",
]
}
@@ -82,6 +76,7 @@ def get_setting(
if not setting:
from fastapi import HTTPException
raise HTTPException(status_code=404, detail=f"Setting '{key}' not found")
return AdminSettingResponse.model_validate(setting)
@@ -99,9 +94,7 @@ def create_setting(
Setting keys should be lowercase with underscores (e.g., max_vendors_allowed).
"""
result = admin_settings_service.create_setting(
db=db,
setting_data=setting_data,
admin_user_id=current_admin.id
db=db, setting_data=setting_data, admin_user_id=current_admin.id
)
# Log action
@@ -111,7 +104,10 @@ def create_setting(
action="create_setting",
target_type="setting",
target_id=setting_data.key,
details={"category": setting_data.category, "value_type": setting_data.value_type}
details={
"category": setting_data.category,
"value_type": setting_data.value_type,
},
)
return result
@@ -128,10 +124,7 @@ def update_setting(
old_value = admin_settings_service.get_setting_value(db, key)
result = admin_settings_service.update_setting(
db=db,
key=key,
update_data=update_data,
admin_user_id=current_admin.id
db=db, key=key, update_data=update_data, admin_user_id=current_admin.id
)
# Log action
@@ -141,7 +134,7 @@ def update_setting(
action="update_setting",
target_type="setting",
target_id=key,
details={"old_value": str(old_value), "new_value": update_data.value}
details={"old_value": str(old_value), "new_value": update_data.value},
)
return result
@@ -159,9 +152,7 @@ def upsert_setting(
If setting exists, updates its value. If not, creates new setting.
"""
result = admin_settings_service.upsert_setting(
db=db,
setting_data=setting_data,
admin_user_id=current_admin.id
db=db, setting_data=setting_data, admin_user_id=current_admin.id
)
# Log action
@@ -171,7 +162,7 @@ def upsert_setting(
action="upsert_setting",
target_type="setting",
target_id=setting_data.key,
details={"category": setting_data.category}
details={"category": setting_data.category},
)
return result
@@ -195,13 +186,11 @@ def delete_setting(
if not confirm:
raise HTTPException(
status_code=400,
detail="Deletion requires confirmation parameter: confirm=true"
detail="Deletion requires confirmation parameter: confirm=true",
)
message = admin_settings_service.delete_setting(
db=db,
key=key,
admin_user_id=current_admin.id
db=db, key=key, admin_user_id=current_admin.id
)
# Log action
@@ -211,7 +200,7 @@ def delete_setting(
action="delete_setting",
target_type="setting",
target_id=key,
details={}
details={},
)
return {"message": message}

View File

@@ -13,8 +13,8 @@ from app.api.deps import get_current_admin_api
from app.core.database import get_db
from app.services.admin_service import admin_service
from app.services.stats_service import stats_service
from models.schema.auth import UserResponse
from models.database.user import User
from models.schema.auth import UserResponse
router = APIRouter(prefix="/users")
logger = logging.getLogger(__name__)

View File

@@ -12,24 +12,22 @@ Follows the architecture pattern:
import logging
from typing import List
from fastapi import APIRouter, Depends, Path, Body, Query
from fastapi import APIRouter, Body, Depends, Path, Query
from sqlalchemy.orm import Session
from app.api.deps import get_current_admin_api
from app.core.database import get_db
from app.services.vendor_domain_service import vendor_domain_service
from app.exceptions import VendorNotFoundException
from models.schema.vendor_domain import (
VendorDomainCreate,
VendorDomainUpdate,
VendorDomainResponse,
VendorDomainListResponse,
DomainVerificationInstructions,
DomainVerificationResponse,
DomainDeletionResponse,
)
from app.services.vendor_domain_service import vendor_domain_service
from models.database.user import User
from models.database.vendor import Vendor
from models.schema.vendor_domain import (DomainDeletionResponse,
DomainVerificationInstructions,
DomainVerificationResponse,
VendorDomainCreate,
VendorDomainListResponse,
VendorDomainResponse,
VendorDomainUpdate)
router = APIRouter(prefix="/vendors")
logger = logging.getLogger(__name__)
@@ -88,9 +86,7 @@ def add_vendor_domain(
- 422: Invalid domain format or reserved subdomain
"""
domain = vendor_domain_service.add_domain(
db=db,
vendor_id=vendor_id,
domain_data=domain_data
db=db, vendor_id=vendor_id, domain_data=domain_data
)
return VendorDomainResponse(
@@ -148,7 +144,7 @@ def list_vendor_domains(
)
for d in domains
],
total=len(domains)
total=len(domains),
)
@@ -174,7 +170,9 @@ def get_domain_details(
is_active=domain.is_active,
is_verified=domain.is_verified,
ssl_status=domain.ssl_status,
verification_token=domain.verification_token if not domain.is_verified else None,
verification_token=(
domain.verification_token if not domain.is_verified else None
),
verified_at=domain.verified_at,
ssl_verified_at=domain.ssl_verified_at,
created_at=domain.created_at,
@@ -206,9 +204,7 @@ def update_vendor_domain(
- 400: Cannot activate unverified domain
"""
domain = vendor_domain_service.update_domain(
db=db,
domain_id=domain_id,
domain_update=domain_update
db=db, domain_id=domain_id, domain_update=domain_update
)
return VendorDomainResponse(
@@ -250,9 +246,7 @@ def delete_vendor_domain(
message = vendor_domain_service.delete_domain(db, domain_id)
return DomainDeletionResponse(
message=message,
domain=domain_name,
vendor_id=vendor_id
message=message, domain=domain_name, vendor_id=vendor_id
)
@@ -290,11 +284,14 @@ def verify_domain_ownership(
message=message,
domain=domain.domain,
verified_at=domain.verified_at,
is_verified=domain.is_verified
is_verified=domain.is_verified,
)
@router.get("/domains/{domain_id}/verification-instructions", response_model=DomainVerificationInstructions)
@router.get(
"/domains/{domain_id}/verification-instructions",
response_model=DomainVerificationInstructions,
)
def get_domain_verification_instructions(
domain_id: int = Path(..., description="Domain ID", gt=0),
db: Session = Depends(get_db),
@@ -324,5 +321,5 @@ def get_domain_verification_instructions(
verification_token=instructions["verification_token"],
instructions=instructions["instructions"],
txt_record=instructions["txt_record"],
common_registrars=instructions["common_registrars"]
common_registrars=instructions["common_registrars"],
)

View File

@@ -20,11 +20,8 @@ from sqlalchemy.orm import Session
from app.api.deps import get_current_admin_api, get_db
from app.services.vendor_theme_service import vendor_theme_service
from models.database.user import User
from models.schema.vendor_theme import (
VendorThemeResponse,
VendorThemeUpdate,
ThemePresetListResponse
)
from models.schema.vendor_theme import (ThemePresetListResponse,
VendorThemeResponse, VendorThemeUpdate)
router = APIRouter(prefix="/vendor-themes")
logger = logging.getLogger(__name__)
@@ -34,10 +31,9 @@ logger = logging.getLogger(__name__)
# PRESET ENDPOINTS
# ============================================================================
@router.get("/presets", response_model=ThemePresetListResponse)
async def get_theme_presets(
current_admin: User = Depends(get_current_admin_api)
):
async def get_theme_presets(current_admin: User = Depends(get_current_admin_api)):
"""
Get all available theme presets with preview information.
@@ -59,11 +55,12 @@ async def get_theme_presets(
# THEME RETRIEVAL
# ============================================================================
@router.get("/{vendor_code}", response_model=VendorThemeResponse)
async def get_vendor_theme(
vendor_code: str = Path(..., description="Vendor code"),
db: Session = Depends(get_db),
current_admin: User = Depends(get_current_admin_api)
current_admin: User = Depends(get_current_admin_api),
):
"""
Get theme configuration for a vendor.
@@ -93,12 +90,13 @@ async def get_vendor_theme(
# THEME UPDATE
# ============================================================================
@router.put("/{vendor_code}", response_model=VendorThemeResponse)
async def update_vendor_theme(
vendor_code: str = Path(..., description="Vendor code"),
theme_data: VendorThemeUpdate = None,
db: Session = Depends(get_db),
current_admin: User = Depends(get_current_admin_api)
current_admin: User = Depends(get_current_admin_api),
):
"""
Update or create theme for a vendor.
@@ -140,12 +138,13 @@ async def update_vendor_theme(
# PRESET APPLICATION
# ============================================================================
@router.post("/{vendor_code}/preset/{preset_name}")
async def apply_theme_preset(
vendor_code: str = Path(..., description="Vendor code"),
preset_name: str = Path(..., description="Preset name"),
db: Session = Depends(get_db),
current_admin: User = Depends(get_current_admin_api)
current_admin: User = Depends(get_current_admin_api),
):
"""
Apply a theme preset to a vendor.
@@ -184,7 +183,7 @@ async def apply_theme_preset(
return {
"message": f"Applied {preset_name} preset successfully",
"theme": theme.to_dict()
"theme": theme.to_dict(),
}
@@ -192,11 +191,12 @@ async def apply_theme_preset(
# THEME DELETION
# ============================================================================
@router.delete("/{vendor_code}")
async def delete_vendor_theme(
vendor_code: str = Path(..., description="Vendor code"),
db: Session = Depends(get_db),
current_admin: User = Depends(get_current_admin_api)
current_admin: User = Depends(get_current_admin_api),
):
"""
Delete custom theme for a vendor.

View File

@@ -6,28 +6,24 @@ Vendor management endpoints for admin.
import logging
from typing import Optional
from fastapi import APIRouter, Depends, Query, Path, Body
from fastapi import APIRouter, Body, Depends, Path, Query
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.api.deps import get_current_admin_api
from app.core.database import get_db
from app.exceptions import (ConfirmationRequiredException,
VendorNotFoundException)
from app.services.admin_service import admin_service
from app.services.stats_service import stats_service
from app.exceptions import VendorNotFoundException, ConfirmationRequiredException
from models.schema.stats import VendorStatsResponse
from models.schema.vendor import (
VendorListResponse,
VendorResponse,
VendorDetailResponse,
VendorCreate,
VendorCreateResponse,
VendorUpdate,
VendorTransferOwnership,
VendorTransferOwnershipResponse,
)
from models.database.user import User
from models.database.vendor import Vendor
from sqlalchemy import func
from models.schema.stats import VendorStatsResponse
from models.schema.vendor import (VendorCreate, VendorCreateResponse,
VendorDetailResponse, VendorListResponse,
VendorResponse, VendorTransferOwnership,
VendorTransferOwnershipResponse,
VendorUpdate)
router = APIRouter(prefix="/vendors")
logger = logging.getLogger(__name__)
@@ -60,9 +56,11 @@ def _get_vendor_by_identifier(db: Session, identifier: str) -> Vendor:
pass
# Try as vendor_code (case-insensitive)
vendor = db.query(Vendor).filter(
func.upper(Vendor.vendor_code) == identifier.upper()
).first()
vendor = (
db.query(Vendor)
.filter(func.upper(Vendor.vendor_code) == identifier.upper())
.first()
)
if not vendor:
raise VendorNotFoundException(identifier, identifier_type="code")
@@ -93,8 +91,7 @@ def create_vendor_with_owner(
Returns vendor details with owner credentials.
"""
vendor, owner_user, temp_password = admin_service.create_vendor_with_owner(
db=db,
vendor_data=vendor_data
db=db, vendor_data=vendor_data
)
return VendorCreateResponse(
@@ -121,7 +118,7 @@ def create_vendor_with_owner(
owner_email=owner_user.email,
owner_username=owner_user.username,
temporary_password=temp_password,
login_url=f"http://localhost:8000/vendor/{vendor.subdomain}/login"
login_url=f"http://localhost:8000/vendor/{vendor.subdomain}/login",
)
@@ -142,7 +139,7 @@ def get_all_vendors_admin(
limit=limit,
search=search,
is_active=is_active,
is_verified=is_verified
is_verified=is_verified,
)
return VendorListResponse(vendors=vendors, total=total, skip=skip, limit=limit)
@@ -257,7 +254,10 @@ def update_vendor(
)
@router.post("/{vendor_identifier}/transfer-ownership", response_model=VendorTransferOwnershipResponse)
@router.post(
"/{vendor_identifier}/transfer-ownership",
response_model=VendorTransferOwnershipResponse,
)
def transfer_vendor_ownership(
vendor_identifier: str = Path(..., description="Vendor ID or vendor_code"),
transfer_data: VendorTransferOwnership = Body(...),
@@ -436,7 +436,7 @@ def delete_vendor(
if not confirm:
raise ConfirmationRequiredException(
operation="delete_vendor",
message="Deletion requires confirmation parameter: confirm=true"
message="Deletion requires confirmation parameter: confirm=true",
)
vendor = _get_vendor_by_identifier(db, vendor_identifier)

View File

@@ -21,7 +21,7 @@ Authentication:
from fastapi import APIRouter
# Import shop routers
from . import products, cart, orders, auth, content_pages
from . import auth, cart, content_pages, orders, products
# Create shop router
router = APIRouter()
@@ -43,6 +43,8 @@ router.include_router(cart.router, tags=["shop-cart"])
router.include_router(orders.router, tags=["shop-orders"])
# Content pages (public)
router.include_router(content_pages.router, prefix="/content-pages", tags=["shop-content-pages"])
router.include_router(
content_pages.router, prefix="/content-pages", tags=["shop-content-pages"]
)
__all__ = ["router"]

View File

@@ -15,15 +15,16 @@ This prevents:
"""
import logging
from fastapi import APIRouter, Depends, Response, Request, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.core.environment import should_use_secure_cookies
from app.services.customer_service import customer_service
from models.schema.auth import UserLogin
from models.schema.customer import CustomerRegister, CustomerResponse
from app.core.environment import should_use_secure_cookies
from pydantic import BaseModel
router = APIRouter()
logger = logging.getLogger(__name__)
@@ -32,6 +33,7 @@ logger = logging.getLogger(__name__)
# Response model for customer login
class CustomerLoginResponse(BaseModel):
"""Customer login response with token and customer data."""
access_token: str
token_type: str
expires_in: int
@@ -40,9 +42,7 @@ class CustomerLoginResponse(BaseModel):
@router.post("/auth/register", response_model=CustomerResponse)
def register_customer(
request: Request,
customer_data: CustomerRegister,
db: Session = Depends(get_db)
request: Request, customer_data: CustomerRegister, db: Session = Depends(get_db)
):
"""
Register a new customer for current vendor.
@@ -59,12 +59,12 @@ def register_customer(
- phone: Customer phone number (optional)
"""
# Get vendor from middleware
vendor = getattr(request.state, 'vendor', None)
vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path."
detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
logger.debug(
@@ -73,14 +73,12 @@ def register_customer(
"vendor_id": vendor.id,
"vendor_code": vendor.subdomain,
"email": customer_data.email,
}
},
)
# Create customer account
customer = customer_service.register_customer(
db=db,
vendor_id=vendor.id,
customer_data=customer_data
db=db, vendor_id=vendor.id, customer_data=customer_data
)
logger.info(
@@ -89,7 +87,7 @@ def register_customer(
"customer_id": customer.id,
"vendor_id": vendor.id,
"email": customer.email,
}
},
)
return CustomerResponse.model_validate(customer)
@@ -100,7 +98,7 @@ def customer_login(
request: Request,
user_credentials: UserLogin,
response: Response,
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Customer login for current vendor.
@@ -121,12 +119,12 @@ def customer_login(
- password: Customer password
"""
# Get vendor from middleware
vendor = getattr(request.state, 'vendor', None)
vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path."
detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
logger.debug(
@@ -135,33 +133,39 @@ def customer_login(
"vendor_id": vendor.id,
"vendor_code": vendor.subdomain,
"email_or_username": user_credentials.email_or_username,
}
},
)
# Authenticate customer
login_result = customer_service.login_customer(
db=db,
vendor_id=vendor.id,
credentials=user_credentials
db=db, vendor_id=vendor.id, credentials=user_credentials
)
logger.info(
f"Customer login successful: {login_result['customer'].email} for vendor {vendor.subdomain}",
extra={
"customer_id": login_result['customer'].id,
"customer_id": login_result["customer"].id,
"vendor_id": vendor.id,
"email": login_result['customer'].email,
}
"email": login_result["customer"].email,
},
)
# Calculate cookie path based on vendor access method
vendor_context = getattr(request.state, 'vendor_context', None)
access_method = vendor_context.get('detection_method', 'unknown') if vendor_context else 'unknown'
vendor_context = getattr(request.state, "vendor_context", None)
access_method = (
vendor_context.get("detection_method", "unknown")
if vendor_context
else "unknown"
)
cookie_path = "/shop" # Default for domain/subdomain access
if access_method == "path":
# For path-based access like /vendors/wizamart/shop
full_prefix = vendor_context.get('full_prefix', '/vendor/') if vendor_context else '/vendor/'
full_prefix = (
vendor_context.get("full_prefix", "/vendor/")
if vendor_context
else "/vendor/"
)
cookie_path = f"{full_prefix}{vendor.subdomain}/shop"
# Set HTTP-only cookie for browser navigation
@@ -180,10 +184,10 @@ def customer_login(
f"Set customer_token cookie with {login_result['token_data']['expires_in']}s expiry "
f"(path={cookie_path}, httponly=True, secure={should_use_secure_cookies()})",
extra={
"expires_in": login_result['token_data']['expires_in'],
"expires_in": login_result["token_data"]["expires_in"],
"secure": should_use_secure_cookies(),
"cookie_path": cookie_path,
}
},
)
# Return full login response
@@ -196,10 +200,7 @@ def customer_login(
@router.post("/auth/logout")
def customer_logout(
request: Request,
response: Response
):
def customer_logout(request: Request, response: Response):
"""
Customer logout for current vendor.
@@ -208,24 +209,32 @@ def customer_logout(
Client should also remove token from localStorage.
"""
# Get vendor from middleware (for logging)
vendor = getattr(request.state, 'vendor', None)
vendor = getattr(request.state, "vendor", None)
logger.info(
f"Customer logout for vendor {vendor.subdomain if vendor else 'unknown'}",
extra={
"vendor_id": vendor.id if vendor else None,
"vendor_code": vendor.subdomain if vendor else None,
}
},
)
# Calculate cookie path based on vendor access method (must match login)
vendor_context = getattr(request.state, 'vendor_context', None)
access_method = vendor_context.get('detection_method', 'unknown') if vendor_context else 'unknown'
vendor_context = getattr(request.state, "vendor_context", None)
access_method = (
vendor_context.get("detection_method", "unknown")
if vendor_context
else "unknown"
)
cookie_path = "/shop" # Default for domain/subdomain access
if access_method == "path" and vendor:
# For path-based access like /vendors/wizamart/shop
full_prefix = vendor_context.get('full_prefix', '/vendor/') if vendor_context else '/vendor/'
full_prefix = (
vendor_context.get("full_prefix", "/vendor/")
if vendor_context
else "/vendor/"
)
cookie_path = f"{full_prefix}{vendor.subdomain}/shop"
# Clear the cookie (must match path used when setting)
@@ -240,11 +249,7 @@ def customer_logout(
@router.post("/auth/forgot-password")
def forgot_password(
request: Request,
email: str,
db: Session = Depends(get_db)
):
def forgot_password(request: Request, email: str, db: Session = Depends(get_db)):
"""
Request password reset for customer.
@@ -255,12 +260,12 @@ def forgot_password(
- email: Customer email address
"""
# Get vendor from middleware
vendor = getattr(request.state, 'vendor', None)
vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path."
detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
logger.debug(
@@ -269,7 +274,7 @@ def forgot_password(
"vendor_id": vendor.id,
"vendor_code": vendor.subdomain,
"email": email,
}
},
)
# TODO: Implement password reset functionality
@@ -278,9 +283,7 @@ def forgot_password(
# - Send reset email to customer
# - Return success message (don't reveal if email exists)
logger.info(
f"Password reset requested for {email} (vendor: {vendor.subdomain})"
)
logger.info(f"Password reset requested for {email} (vendor: {vendor.subdomain})")
return {
"message": "If an account exists with this email, a password reset link has been sent."
@@ -289,10 +292,7 @@ def forgot_password(
@router.post("/auth/reset-password")
def reset_password(
request: Request,
reset_token: str,
new_password: str,
db: Session = Depends(get_db)
request: Request, reset_token: str, new_password: str, db: Session = Depends(get_db)
):
"""
Reset customer password using reset token.
@@ -304,12 +304,12 @@ def reset_password(
- new_password: New password
"""
# Get vendor from middleware
vendor = getattr(request.state, 'vendor', None)
vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path."
detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
logger.debug(
@@ -317,7 +317,7 @@ def reset_password(
extra={
"vendor_id": vendor.id,
"vendor_code": vendor.subdomain,
}
},
)
# TODO: Implement password reset
@@ -327,9 +327,7 @@ def reset_password(
# - Invalidate reset token
# - Return success
logger.info(
f"Password reset completed (vendor: {vendor.subdomain})"
)
logger.info(f"Password reset completed (vendor: {vendor.subdomain})")
return {
"message": "Password reset successfully. You can now log in with your new password."

View File

@@ -8,18 +8,15 @@ No authentication required - uses session ID for cart tracking.
"""
import logging
from fastapi import APIRouter, Depends, Path, Body, Request, HTTPException
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Request
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.services.cart_service import cart_service
from models.schema.cart import (
AddToCartRequest,
UpdateCartItemRequest,
CartResponse,
CartOperationResponse,
ClearCartResponse,
)
from models.schema.cart import (AddToCartRequest, CartOperationResponse,
CartResponse, ClearCartResponse,
UpdateCartItemRequest)
router = APIRouter()
logger = logging.getLogger(__name__)
@@ -29,6 +26,7 @@ logger = logging.getLogger(__name__)
# CART ENDPOINTS
# ============================================================================
@router.get("/cart/{session_id}", response_model=CartResponse)
def get_cart(
request: Request,
@@ -45,12 +43,12 @@ def get_cart(
- session_id: Unique session identifier for the cart
"""
# Get vendor from middleware
vendor = getattr(request.state, 'vendor', None)
vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path."
detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
logger.info(
@@ -59,23 +57,19 @@ def get_cart(
"vendor_id": vendor.id,
"vendor_code": vendor.subdomain,
"session_id": session_id,
}
},
)
cart = cart_service.get_cart(
db=db,
vendor_id=vendor.id,
session_id=session_id
)
cart = cart_service.get_cart(db=db, vendor_id=vendor.id, session_id=session_id)
logger.info(
f"[SHOP_API] get_cart result: {len(cart.get('items', []))} items in cart",
extra={
"session_id": session_id,
"vendor_id": vendor.id,
"item_count": len(cart.get('items', [])),
"total": cart.get('total', 0),
}
"item_count": len(cart.get("items", [])),
"total": cart.get("total", 0),
},
)
return CartResponse.from_service_dict(cart)
@@ -102,12 +96,12 @@ def add_to_cart(
- quantity: Quantity to add (default: 1)
"""
# Get vendor from middleware
vendor = getattr(request.state, 'vendor', None)
vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path."
detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
logger.info(
@@ -118,7 +112,7 @@ def add_to_cart(
"session_id": session_id,
"product_id": cart_data.product_id,
"quantity": cart_data.quantity,
}
},
)
result = cart_service.add_to_cart(
@@ -126,7 +120,7 @@ def add_to_cart(
vendor_id=vendor.id,
session_id=session_id,
product_id=cart_data.product_id,
quantity=cart_data.quantity
quantity=cart_data.quantity,
)
logger.info(
@@ -134,13 +128,15 @@ def add_to_cart(
extra={
"session_id": session_id,
"result": result,
}
},
)
return CartOperationResponse(**result)
@router.put("/cart/{session_id}/items/{product_id}", response_model=CartOperationResponse)
@router.put(
"/cart/{session_id}/items/{product_id}", response_model=CartOperationResponse
)
def update_cart_item(
request: Request,
session_id: str = Path(..., description="Shopping session ID"),
@@ -162,12 +158,12 @@ def update_cart_item(
- quantity: New quantity (must be >= 1)
"""
# Get vendor from middleware
vendor = getattr(request.state, 'vendor', None)
vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path."
detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
logger.debug(
@@ -178,7 +174,7 @@ def update_cart_item(
"session_id": session_id,
"product_id": product_id,
"quantity": cart_data.quantity,
}
},
)
result = cart_service.update_cart_item(
@@ -186,13 +182,15 @@ def update_cart_item(
vendor_id=vendor.id,
session_id=session_id,
product_id=product_id,
quantity=cart_data.quantity
quantity=cart_data.quantity,
)
return CartOperationResponse(**result)
@router.delete("/cart/{session_id}/items/{product_id}", response_model=CartOperationResponse)
@router.delete(
"/cart/{session_id}/items/{product_id}", response_model=CartOperationResponse
)
def remove_from_cart(
request: Request,
session_id: str = Path(..., description="Shopping session ID"),
@@ -210,12 +208,12 @@ def remove_from_cart(
- product_id: ID of product to remove
"""
# Get vendor from middleware
vendor = getattr(request.state, 'vendor', None)
vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path."
detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
logger.debug(
@@ -225,14 +223,11 @@ def remove_from_cart(
"vendor_code": vendor.subdomain,
"session_id": session_id,
"product_id": product_id,
}
},
)
result = cart_service.remove_from_cart(
db=db,
vendor_id=vendor.id,
session_id=session_id,
product_id=product_id
db=db, vendor_id=vendor.id, session_id=session_id, product_id=product_id
)
return CartOperationResponse(**result)
@@ -254,12 +249,12 @@ def clear_cart(
- session_id: Unique session identifier for the cart
"""
# Get vendor from middleware
vendor = getattr(request.state, 'vendor', None)
vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path."
detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
logger.debug(
@@ -268,13 +263,9 @@ def clear_cart(
"vendor_id": vendor.id,
"vendor_code": vendor.subdomain,
"session_id": session_id,
}
},
)
result = cart_service.clear_cart(
db=db,
vendor_id=vendor.id,
session_id=session_id
)
result = cart_service.clear_cart(db=db, vendor_id=vendor.id, session_id=session_id)
return ClearCartResponse(**result)

View File

@@ -8,6 +8,7 @@ No authentication required.
import logging
from typing import List
from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel
from sqlalchemy.orm import Session
@@ -23,8 +24,10 @@ logger = logging.getLogger(__name__)
# RESPONSE SCHEMAS
# ============================================================================
class PublicContentPageResponse(BaseModel):
"""Public content page response (no internal IDs)."""
slug: str
title: str
content: str
@@ -36,6 +39,7 @@ class PublicContentPageResponse(BaseModel):
class ContentPageListItem(BaseModel):
"""Content page list item for navigation."""
slug: str
title: str
show_in_footer: bool
@@ -47,25 +51,21 @@ class ContentPageListItem(BaseModel):
# PUBLIC ENDPOINTS
# ============================================================================
@router.get("/navigation", response_model=List[ContentPageListItem])
def get_navigation_pages(
request: Request,
db: Session = Depends(get_db)
):
def get_navigation_pages(request: Request, db: Session = Depends(get_db)):
"""
Get list of content pages for navigation (footer/header).
Uses vendor from request.state (set by middleware).
Returns vendor overrides + platform defaults.
"""
vendor = getattr(request.state, 'vendor', None)
vendor = getattr(request.state, "vendor", None)
vendor_id = vendor.id if vendor else None
# Get all published pages for this vendor
pages = content_page_service.list_pages_for_vendor(
db,
vendor_id=vendor_id,
include_unpublished=False
db, vendor_id=vendor_id, include_unpublished=False
)
return [
@@ -81,25 +81,21 @@ def get_navigation_pages(
@router.get("/{slug}", response_model=PublicContentPageResponse)
def get_content_page(
slug: str,
request: Request,
db: Session = Depends(get_db)
):
def get_content_page(slug: str, request: Request, db: Session = Depends(get_db)):
"""
Get a specific content page by slug.
Uses vendor from request.state (set by middleware).
Returns vendor override if exists, otherwise platform default.
"""
vendor = getattr(request.state, 'vendor', None)
vendor = getattr(request.state, "vendor", None)
vendor_id = vendor.id if vendor else None
page = content_page_service.get_page_for_vendor(
db,
slug=slug,
vendor_id=vendor_id,
include_unpublished=False # Only show published pages
include_unpublished=False, # Only show published pages
)
if not page:

View File

@@ -10,31 +10,23 @@ Requires customer authentication for most operations.
import logging
from typing import Optional
from fastapi import APIRouter, Depends, Path, Query, Request, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Path, Query, Request
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.api.deps import get_current_customer_api
from app.services.order_service import order_service
from app.core.database import get_db
from app.services.customer_service import customer_service
from models.schema.order import (
OrderCreate,
OrderResponse,
OrderDetailResponse,
OrderListResponse
)
from models.database.user import User
from app.services.order_service import order_service
from models.database.customer import Customer
from models.database.user import User
from models.schema.order import (OrderCreate, OrderDetailResponse,
OrderListResponse, OrderResponse)
router = APIRouter()
logger = logging.getLogger(__name__)
def get_customer_from_user(
request: Request,
user: User,
db: Session
) -> Customer:
def get_customer_from_user(request: Request, user: User, db: Session) -> Customer:
"""
Helper to get Customer record from authenticated User.
@@ -49,25 +41,22 @@ def get_customer_from_user(
Raises:
HTTPException: If customer not found or vendor mismatch
"""
vendor = getattr(request.state, 'vendor', None)
vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path."
detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
# Find customer record for this user and vendor
customer = customer_service.get_customer_by_user_id(
db=db,
vendor_id=vendor.id,
user_id=user.id
db=db, vendor_id=vendor.id, user_id=user.id
)
if not customer:
raise HTTPException(
status_code=404,
detail="Customer account not found for current vendor"
status_code=404, detail="Customer account not found for current vendor"
)
return customer
@@ -91,12 +80,12 @@ def place_order(
- Order data including shipping address, payment method, etc.
"""
# Get vendor from middleware
vendor = getattr(request.state, 'vendor', None)
vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path."
detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
# Get customer record
@@ -109,14 +98,12 @@ def place_order(
"vendor_code": vendor.subdomain,
"customer_id": customer.id,
"user_id": current_user.id,
}
},
)
# Create order
order = order_service.create_order(
db=db,
vendor_id=vendor.id,
order_data=order_data
db=db, vendor_id=vendor.id, order_data=order_data
)
logger.info(
@@ -127,7 +114,7 @@ def place_order(
"order_number": order.order_number,
"customer_id": customer.id,
"total_amount": float(order.total_amount),
}
},
)
# TODO: Update customer stats
@@ -156,12 +143,12 @@ def get_my_orders(
- limit: Maximum number of orders to return
"""
# Get vendor from middleware
vendor = getattr(request.state, 'vendor', None)
vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path."
detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
# Get customer record
@@ -175,23 +162,19 @@ def get_my_orders(
"customer_id": customer.id,
"skip": skip,
"limit": limit,
}
},
)
# Get orders
orders, total = order_service.get_customer_orders(
db=db,
vendor_id=vendor.id,
customer_id=customer.id,
skip=skip,
limit=limit
db=db, vendor_id=vendor.id, customer_id=customer.id, skip=skip, limit=limit
)
return OrderListResponse(
orders=[OrderResponse.model_validate(o) for o in orders],
total=total,
skip=skip,
limit=limit
limit=limit,
)
@@ -212,12 +195,12 @@ def get_order_details(
- order_id: ID of the order to retrieve
"""
# Get vendor from middleware
vendor = getattr(request.state, 'vendor', None)
vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path."
detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
# Get customer record
@@ -230,19 +213,16 @@ def get_order_details(
"vendor_code": vendor.subdomain,
"customer_id": customer.id,
"order_id": order_id,
}
},
)
# Get order
order = order_service.get_order(
db=db,
vendor_id=vendor.id,
order_id=order_id
)
order = order_service.get_order(db=db, vendor_id=vendor.id, order_id=order_id)
# Verify order belongs to customer
if order.customer_id != customer.id:
from app.exceptions import OrderNotFoundException
raise OrderNotFoundException(str(order_id))
return OrderDetailResponse.model_validate(order)

View File

@@ -10,12 +10,13 @@ No authentication required.
import logging
from typing import Optional
from fastapi import APIRouter, Depends, Query, Path, Request, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Path, Query, Request
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.services.product_service import product_service
from models.schema.product import ProductResponse, ProductDetailResponse, ProductListResponse
from models.schema.product import (ProductDetailResponse, ProductListResponse,
ProductResponse)
router = APIRouter()
logger = logging.getLogger(__name__)
@@ -27,7 +28,9 @@ def get_product_catalog(
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000),
search: Optional[str] = Query(None, description="Search products by name"),
is_featured: Optional[bool] = Query(None, description="Filter by featured products"),
is_featured: Optional[bool] = Query(
None, description="Filter by featured products"
),
db: Session = Depends(get_db),
):
"""
@@ -44,12 +47,12 @@ def get_product_catalog(
- is_featured: Filter by featured products only
"""
# Get vendor from middleware (injected by VendorContextMiddleware)
vendor = getattr(request.state, 'vendor', None)
vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path."
detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
logger.debug(
@@ -61,7 +64,7 @@ def get_product_catalog(
"limit": limit,
"search": search,
"is_featured": is_featured,
}
},
)
# Get only active products for public view
@@ -71,14 +74,14 @@ def get_product_catalog(
skip=skip,
limit=limit,
is_active=True, # Only show active products to customers
is_featured=is_featured
is_featured=is_featured,
)
return ProductListResponse(
products=[ProductResponse.model_validate(p) for p in products],
total=total,
skip=skip,
limit=limit
limit=limit,
)
@@ -98,12 +101,12 @@ def get_product_details(
- product_id: ID of the product to retrieve
"""
# Get vendor from middleware
vendor = getattr(request.state, 'vendor', None)
vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path."
detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
logger.debug(
@@ -112,18 +115,17 @@ def get_product_details(
"vendor_id": vendor.id,
"vendor_code": vendor.subdomain,
"product_id": product_id,
}
},
)
product = product_service.get_product(
db=db,
vendor_id=vendor.id,
product_id=product_id
db=db, vendor_id=vendor.id, product_id=product_id
)
# Check if product is active
if not product.is_active:
from app.exceptions import ProductNotActiveException
raise ProductNotActiveException(str(product_id))
return ProductDetailResponse.model_validate(product)
@@ -150,12 +152,12 @@ def search_products(
- limit: Maximum number of results to return
"""
# Get vendor from middleware
vendor = getattr(request.state, 'vendor', None)
vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(
status_code=404,
detail="Vendor not found. Please access via vendor domain/subdomain/path."
detail="Vendor not found. Please access via vendor domain/subdomain/path.",
)
logger.debug(
@@ -166,22 +168,18 @@ def search_products(
"query": q,
"skip": skip,
"limit": limit,
}
},
)
# TODO: Implement full-text search functionality
# For now, return filtered products
products, total = product_service.get_vendor_products(
db=db,
vendor_id=vendor.id,
skip=skip,
limit=limit,
is_active=True
db=db, vendor_id=vendor.id, skip=skip, limit=limit, is_active=True
)
return ProductListResponse(
products=[ProductResponse.model_validate(p) for p in products],
total=total,
skip=skip,
limit=limit
limit=limit,
)

View File

@@ -13,25 +13,9 @@ IMPORTANT:
from fastapi import APIRouter
# Import all sub-routers (JSON API only)
from . import (
info,
auth,
dashboard,
profile,
settings,
products,
orders,
customers,
team,
inventory,
marketplace,
payments,
media,
notifications,
analytics,
content_pages,
)
from . import (analytics, auth, content_pages, customers, dashboard, info,
inventory, marketplace, media, notifications, orders, payments,
products, profile, settings, team)
# Create vendor router
router = APIRouter()
@@ -68,7 +52,11 @@ router.include_router(notifications.router, tags=["vendor-notifications"])
router.include_router(analytics.router, tags=["vendor-analytics"])
# Content pages management
router.include_router(content_pages.router, prefix="/{vendor_code}/content-pages", tags=["vendor-content-pages"])
router.include_router(
content_pages.router,
prefix="/{vendor_code}/content-pages",
tags=["vendor-content-pages"],
)
# Vendor info endpoint - MUST BE LAST! Has catch-all GET /{vendor_code}
router.include_router(info.router, tags=["vendor-info"])

View File

@@ -4,13 +4,14 @@ Vendor analytics and reporting endpoints.
"""
import logging
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api
from app.core.database import get_db
from middleware.vendor_context import require_vendor_context
from app.services.stats_service import stats_service
from middleware.vendor_context import require_vendor_context
from models.database.user import User
from models.database.vendor import Vendor

View File

@@ -13,19 +13,20 @@ This prevents:
"""
import logging
from fastapi import APIRouter, Depends, Request, Response
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.services.auth_service import auth_service
from app.exceptions import InvalidCredentialsException
from middleware.vendor_context import get_current_vendor
from models.schema.auth import UserLogin
from models.database.vendor import Vendor, VendorUser, Role
from models.database.user import User
from pydantic import BaseModel
from app.api.deps import get_current_vendor_api
from app.core.database import get_db
from app.core.environment import should_use_secure_cookies
from app.exceptions import InvalidCredentialsException
from app.services.auth_service import auth_service
from middleware.vendor_context import get_current_vendor
from models.database.user import User
from models.database.vendor import Role, Vendor, VendorUser
from models.schema.auth import UserLogin
router = APIRouter(prefix="/auth")
logger = logging.getLogger(__name__)
@@ -46,7 +47,7 @@ def vendor_login(
user_credentials: UserLogin,
request: Request,
response: Response,
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Vendor team member login.
@@ -64,13 +65,16 @@ def vendor_login(
vendor = get_current_vendor(request)
# If no vendor from middleware, try to get from request body
if not vendor and hasattr(user_credentials, 'vendor_code'):
vendor_code = getattr(user_credentials, 'vendor_code', None)
if not vendor and hasattr(user_credentials, "vendor_code"):
vendor_code = getattr(user_credentials, "vendor_code", None)
if vendor_code:
vendor = db.query(Vendor).filter(
Vendor.vendor_code == vendor_code.upper(),
Vendor.is_active == True
).first()
vendor = (
db.query(Vendor)
.filter(
Vendor.vendor_code == vendor_code.upper(), Vendor.is_active == True
)
.first()
)
# Authenticate user
login_result = auth_service.login_user(db=db, user_credentials=user_credentials)
@@ -79,7 +83,9 @@ def vendor_login(
# CRITICAL: Prevent admin users from using vendor login
if user.role == "admin":
logger.warning(f"Admin user attempted vendor login: {user.username}")
raise InvalidCredentialsException("Admins cannot access vendor portal. Please use admin portal.")
raise InvalidCredentialsException(
"Admins cannot access vendor portal. Please use admin portal."
)
# Determine vendor and role
vendor_role = "Member"
@@ -92,11 +98,16 @@ def vendor_login(
vendor_role = "Owner"
else:
# Check if user is team member
vendor_user = db.query(VendorUser).join(Role).filter(
vendor_user = (
db.query(VendorUser)
.join(Role)
.filter(
VendorUser.user_id == user.id,
VendorUser.vendor_id == vendor.id,
VendorUser.is_active == True
).first()
VendorUser.is_active == True,
)
.first()
)
if vendor_user:
vendor_role = vendor_user.role.name
@@ -117,17 +128,14 @@ def vendor_login(
# Check vendor memberships
elif user.vendor_memberships:
active_membership = next(
(vm for vm in user.vendor_memberships if vm.is_active),
None
(vm for vm in user.vendor_memberships if vm.is_active), None
)
if active_membership:
vendor = active_membership.vendor
vendor_role = active_membership.role.name
if not vendor:
raise InvalidCredentialsException(
"User is not associated with any vendor"
)
raise InvalidCredentialsException("User is not associated with any vendor")
logger.info(
f"Vendor team login successful: {user.username} "
@@ -161,7 +169,7 @@ def vendor_login(
"username": user.username,
"email": user.email,
"role": user.role,
"is_active": user.is_active
"is_active": user.is_active,
},
vendor={
"id": vendor.id,
@@ -169,9 +177,9 @@ def vendor_login(
"subdomain": vendor.subdomain,
"name": vendor.name,
"is_active": vendor.is_active,
"is_verified": vendor.is_verified
"is_verified": vendor.is_verified,
},
vendor_role=vendor_role
vendor_role=vendor_role,
)
@@ -198,8 +206,7 @@ def vendor_logout(response: Response):
@router.get("/me")
def get_current_vendor_user(
user: User = Depends(get_current_vendor_api),
db: Session = Depends(get_db)
user: User = Depends(get_current_vendor_api), db: Session = Depends(get_db)
):
"""
Get current authenticated vendor user.
@@ -212,5 +219,5 @@ def get_current_vendor_user(
"username": user.username,
"email": user.email,
"role": user.role,
"is_active": user.is_active
"is_active": user.is_active,
}

View File

@@ -10,6 +10,7 @@ Vendors can:
import logging
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
@@ -26,14 +27,26 @@ logger = logging.getLogger(__name__)
# REQUEST/RESPONSE SCHEMAS
# ============================================================================
class VendorContentPageCreate(BaseModel):
"""Schema for creating a vendor content page."""
slug: str = Field(..., max_length=100, description="URL-safe identifier (about, faq, contact, etc.)")
slug: str = Field(
...,
max_length=100,
description="URL-safe identifier (about, faq, contact, etc.)",
)
title: str = Field(..., max_length=200, description="Page title")
content: str = Field(..., description="HTML or Markdown content")
content_format: str = Field(default="html", description="Content format: html or markdown")
meta_description: Optional[str] = Field(None, max_length=300, description="SEO meta description")
meta_keywords: Optional[str] = Field(None, max_length=300, description="SEO keywords")
content_format: str = Field(
default="html", description="Content format: html or markdown"
)
meta_description: Optional[str] = Field(
None, max_length=300, description="SEO meta description"
)
meta_keywords: Optional[str] = Field(
None, max_length=300, description="SEO keywords"
)
is_published: bool = Field(default=False, description="Publish immediately")
show_in_footer: bool = Field(default=True, description="Show in footer navigation")
show_in_header: bool = Field(default=False, description="Show in header navigation")
@@ -42,6 +55,7 @@ class VendorContentPageCreate(BaseModel):
class VendorContentPageUpdate(BaseModel):
"""Schema for updating a vendor content page."""
title: Optional[str] = Field(None, max_length=200)
content: Optional[str] = None
content_format: Optional[str] = None
@@ -55,6 +69,7 @@ class VendorContentPageUpdate(BaseModel):
class ContentPageResponse(BaseModel):
"""Schema for content page response."""
id: int
vendor_id: Optional[int]
vendor_name: Optional[str]
@@ -81,11 +96,12 @@ class ContentPageResponse(BaseModel):
# VENDOR CONTENT PAGES
# ============================================================================
@router.get("/", response_model=List[ContentPageResponse])
def list_vendor_pages(
include_unpublished: bool = Query(False, description="Include draft pages"),
current_user: User = Depends(get_current_vendor_api),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
List all content pages available for this vendor.
@@ -93,12 +109,12 @@ def list_vendor_pages(
Returns vendor-specific overrides + platform defaults (vendor overrides take precedence).
"""
if not current_user.vendor_id:
raise HTTPException(status_code=403, detail="User is not associated with a vendor")
raise HTTPException(
status_code=403, detail="User is not associated with a vendor"
)
pages = content_page_service.list_pages_for_vendor(
db,
vendor_id=current_user.vendor_id,
include_unpublished=include_unpublished
db, vendor_id=current_user.vendor_id, include_unpublished=include_unpublished
)
return [page.to_dict() for page in pages]
@@ -108,7 +124,7 @@ def list_vendor_pages(
def list_vendor_overrides(
include_unpublished: bool = Query(False, description="Include draft pages"),
current_user: User = Depends(get_current_vendor_api),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
List only vendor-specific content pages (no platform defaults).
@@ -116,12 +132,12 @@ def list_vendor_overrides(
Shows what the vendor has customized.
"""
if not current_user.vendor_id:
raise HTTPException(status_code=403, detail="User is not associated with a vendor")
raise HTTPException(
status_code=403, detail="User is not associated with a vendor"
)
pages = content_page_service.list_all_vendor_pages(
db,
vendor_id=current_user.vendor_id,
include_unpublished=include_unpublished
db, vendor_id=current_user.vendor_id, include_unpublished=include_unpublished
)
return [page.to_dict() for page in pages]
@@ -132,7 +148,7 @@ def get_page(
slug: str,
include_unpublished: bool = Query(False, description="Include draft pages"),
current_user: User = Depends(get_current_vendor_api),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Get a specific content page by slug.
@@ -140,13 +156,15 @@ def get_page(
Returns vendor override if exists, otherwise platform default.
"""
if not current_user.vendor_id:
raise HTTPException(status_code=403, detail="User is not associated with a vendor")
raise HTTPException(
status_code=403, detail="User is not associated with a vendor"
)
page = content_page_service.get_page_for_vendor(
db,
slug=slug,
vendor_id=current_user.vendor_id,
include_unpublished=include_unpublished
include_unpublished=include_unpublished,
)
if not page:
@@ -159,7 +177,7 @@ def get_page(
def create_vendor_page(
page_data: VendorContentPageCreate,
current_user: User = Depends(get_current_vendor_api),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Create a vendor-specific content page override.
@@ -167,7 +185,9 @@ def create_vendor_page(
This will be shown instead of the platform default for this vendor.
"""
if not current_user.vendor_id:
raise HTTPException(status_code=403, detail="User is not associated with a vendor")
raise HTTPException(
status_code=403, detail="User is not associated with a vendor"
)
page = content_page_service.create_page(
db,
@@ -182,7 +202,7 @@ def create_vendor_page(
show_in_footer=page_data.show_in_footer,
show_in_header=page_data.show_in_header,
display_order=page_data.display_order,
created_by=current_user.id
created_by=current_user.id,
)
return page.to_dict()
@@ -193,7 +213,7 @@ def update_vendor_page(
page_id: int,
page_data: VendorContentPageUpdate,
current_user: User = Depends(get_current_vendor_api),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Update a vendor-specific content page.
@@ -201,7 +221,9 @@ def update_vendor_page(
Can only update pages owned by this vendor.
"""
if not current_user.vendor_id:
raise HTTPException(status_code=403, detail="User is not associated with a vendor")
raise HTTPException(
status_code=403, detail="User is not associated with a vendor"
)
# Verify ownership
existing_page = content_page_service.get_page_by_id(db, page_id)
@@ -209,7 +231,9 @@ def update_vendor_page(
raise HTTPException(status_code=404, detail="Content page not found")
if existing_page.vendor_id != current_user.vendor_id:
raise HTTPException(status_code=403, detail="Cannot edit pages from other vendors")
raise HTTPException(
status_code=403, detail="Cannot edit pages from other vendors"
)
# Update
page = content_page_service.update_page(
@@ -224,7 +248,7 @@ def update_vendor_page(
show_in_footer=page_data.show_in_footer,
show_in_header=page_data.show_in_header,
display_order=page_data.display_order,
updated_by=current_user.id
updated_by=current_user.id,
)
return page.to_dict()
@@ -234,7 +258,7 @@ def update_vendor_page(
def delete_vendor_page(
page_id: int,
current_user: User = Depends(get_current_vendor_api),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Delete a vendor-specific content page.
@@ -243,7 +267,9 @@ def delete_vendor_page(
After deletion, platform default will be shown (if exists).
"""
if not current_user.vendor_id:
raise HTTPException(status_code=403, detail="User is not associated with a vendor")
raise HTTPException(
status_code=403, detail="User is not associated with a vendor"
)
# Verify ownership
existing_page = content_page_service.get_page_by_id(db, page_id)
@@ -251,7 +277,9 @@ def delete_vendor_page(
raise HTTPException(status_code=404, detail="Content page not found")
if existing_page.vendor_id != current_user.vendor_id:
raise HTTPException(status_code=403, detail="Cannot delete pages from other vendors")
raise HTTPException(
status_code=403, detail="Cannot delete pages from other vendors"
)
# Delete
content_page_service.delete_page(db, page_id)

View File

@@ -6,6 +6,7 @@ Vendor customer management endpoints.
import logging
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
@@ -43,7 +44,7 @@ def get_vendor_customers(
"total": 0,
"skip": skip,
"limit": limit,
"message": "Customer management coming in Slice 4"
"message": "Customer management coming in Slice 4",
}
@@ -63,9 +64,7 @@ def get_customer_details(
- Include order history
- Include total spent, etc.
"""
return {
"message": "Customer details coming in Slice 4"
}
return {"message": "Customer details coming in Slice 4"}
@router.get("/{customer_id}/orders")
@@ -83,10 +82,7 @@ def get_customer_orders(
- Filter by vendor_id
- Return order details
"""
return {
"orders": [],
"message": "Customer orders coming in Slice 5"
}
return {"orders": [], "message": "Customer orders coming in Slice 5"}
@router.put("/{customer_id}")
@@ -105,9 +101,7 @@ def update_customer(
- Verify customer belongs to vendor
- Update customer preferences
"""
return {
"message": "Customer update coming in Slice 4"
}
return {"message": "Customer update coming in Slice 4"}
@router.put("/{customer_id}/status")
@@ -125,9 +119,7 @@ def toggle_customer_status(
- Verify customer belongs to vendor
- Log the change
"""
return {
"message": "Customer status toggle coming in Slice 4"
}
return {"message": "Customer status toggle coming in Slice 4"}
@router.get("/{customer_id}/stats")
@@ -151,6 +143,5 @@ def get_customer_statistics(
"total_spent": 0.0,
"average_order_value": 0.0,
"last_order_date": None,
"message": "Customer statistics coming in Slice 4"
"message": "Customer statistics coming in Slice 4",
}

View File

@@ -4,13 +4,14 @@ Vendor dashboard and statistics endpoints.
"""
import logging
from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api
from app.core.database import get_db
from middleware.vendor_context import require_vendor_context
from app.services.stats_service import stats_service
from middleware.vendor_context import require_vendor_context
from models.database.user import User
from models.database.vendor import Vendor
@@ -38,24 +39,23 @@ def get_vendor_dashboard_stats(
"""
# Get vendor from authenticated user's vendor_user record
from models.database.vendor import VendorUser
vendor_user = db.query(VendorUser).filter(
VendorUser.user_id == current_user.id
).first()
vendor_user = (
db.query(VendorUser).filter(VendorUser.user_id == current_user.id).first()
)
if not vendor_user:
from fastapi import HTTPException
raise HTTPException(
status_code=403,
detail="User is not associated with any vendor"
status_code=403, detail="User is not associated with any vendor"
)
vendor = vendor_user.vendor
if not vendor or not vendor.is_active:
from fastapi import HTTPException
raise HTTPException(
status_code=404,
detail="Vendor not found or inactive"
)
raise HTTPException(status_code=404, detail="Vendor not found or inactive")
# Get vendor-scoped statistics
stats_data = stats_service.get_vendor_stats(db=db, vendor_id=vendor.id)
@@ -82,5 +82,5 @@ def get_vendor_dashboard_stats(
"revenue": {
"total": stats_data.get("total_revenue", 0),
"this_month": stats_data.get("revenue_this_month", 0),
}
},
}

View File

@@ -8,14 +8,15 @@ This module provides:
"""
import logging
from fastapi import APIRouter, Path, Depends
from sqlalchemy.orm import Session
from fastapi import APIRouter, Depends, Path
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.exceptions import VendorNotFoundException
from models.schema.vendor import VendorResponse, VendorDetailResponse
from models.database.vendor import Vendor
from models.schema.vendor import VendorDetailResponse, VendorResponse
router = APIRouter()
logger = logging.getLogger(__name__)
@@ -35,10 +36,14 @@ def _get_vendor_by_code(db: Session, vendor_code: str) -> Vendor:
Raises:
VendorNotFoundException: If vendor not found or inactive
"""
vendor = db.query(Vendor).filter(
vendor = (
db.query(Vendor)
.filter(
func.upper(Vendor.vendor_code) == vendor_code.upper(),
Vendor.is_active == True
).first()
Vendor.is_active == True,
)
.first()
)
if not vendor:
logger.warning(f"Vendor not found or inactive: {vendor_code}")
@@ -50,7 +55,7 @@ def _get_vendor_by_code(db: Session, vendor_code: str) -> Vendor:
@router.get("/{vendor_code}", response_model=VendorDetailResponse)
def get_vendor_info(
vendor_code: str = Path(..., description="Vendor code"),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Get public vendor information by vendor code.

View File

@@ -7,19 +7,14 @@ from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api
from app.core.database import get_db
from middleware.vendor_context import require_vendor_context
from app.services.inventory_service import inventory_service
from models.schema.inventory import (
InventoryCreate,
InventoryAdjust,
InventoryUpdate,
InventoryReserve,
InventoryResponse,
ProductInventorySummary,
InventoryListResponse
)
from middleware.vendor_context import require_vendor_context
from models.database.user import User
from models.database.vendor import Vendor
from models.schema.inventory import (InventoryAdjust, InventoryCreate,
InventoryListResponse, InventoryReserve,
InventoryResponse, InventoryUpdate,
ProductInventorySummary)
router = APIRouter()
logger = logging.getLogger(__name__)
@@ -110,10 +105,7 @@ def get_vendor_inventory(
total = len(inventories) # You might want a separate count query for large datasets
return InventoryListResponse(
inventories=inventories,
total=total,
skip=skip,
limit=limit
inventories=inventories, total=total, skip=skip, limit=limit
)
@@ -126,7 +118,9 @@ def update_inventory(
db: Session = Depends(get_db),
):
"""Update inventory entry."""
return inventory_service.update_inventory(db, vendor.id, inventory_id, inventory_update)
return inventory_service.update_inventory(
db, vendor.id, inventory_id, inventory_update
)
@router.delete("/inventory/{inventory_id}")

View File

@@ -12,16 +12,15 @@ from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api
from app.core.database import get_db
from middleware.vendor_context import require_vendor_context # IMPORTANT
from app.services.marketplace_import_job_service import marketplace_import_job_service
from app.services.marketplace_import_job_service import \
marketplace_import_job_service
from app.tasks.background_tasks import process_marketplace_import
from middleware.decorators import rate_limit
from models.schema.marketplace_import_job import (
MarketplaceImportJobResponse,
MarketplaceImportJobRequest
)
from middleware.vendor_context import require_vendor_context # IMPORTANT
from models.database.user import User
from models.database.vendor import Vendor
from models.schema.marketplace_import_job import (MarketplaceImportJobRequest,
MarketplaceImportJobResponse)
router = APIRouter()
logger = logging.getLogger(__name__)
@@ -88,6 +87,7 @@ def get_marketplace_import_status(
# Verify job belongs to current vendor
if job.vendor_id != vendor.id:
from app.exceptions import UnauthorizedVendorAccessException
raise UnauthorizedVendorAccessException(vendor.vendor_code, current_user.id)
return marketplace_import_job_service.convert_to_response_model(job)
@@ -112,4 +112,6 @@ def get_marketplace_import_jobs(
limit=limit,
)
return [marketplace_import_job_service.convert_to_response_model(job) for job in jobs]
return [
marketplace_import_job_service.convert_to_response_model(job) for job in jobs
]

View File

@@ -6,7 +6,8 @@ Vendor media and file management endpoints.
import logging
from typing import Optional
from fastapi import APIRouter, Depends, Query, UploadFile, File
from fastapi import APIRouter, Depends, File, Query, UploadFile
from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api
@@ -44,7 +45,7 @@ def get_media_library(
"total": 0,
"skip": skip,
"limit": limit,
"message": "Media library coming in Slice 3"
"message": "Media library coming in Slice 3",
}
@@ -70,7 +71,7 @@ async def upload_media(
return {
"file_url": None,
"thumbnail_url": None,
"message": "Media upload coming in Slice 3"
"message": "Media upload coming in Slice 3",
}
@@ -94,7 +95,7 @@ async def upload_multiple_media(
return {
"uploaded_files": [],
"failed_files": [],
"message": "Multiple upload coming in Slice 3"
"message": "Multiple upload coming in Slice 3",
}
@@ -113,9 +114,7 @@ def get_media_details(
- Return file URL
- Return usage information (which products use this file)
"""
return {
"message": "Media details coming in Slice 3"
}
return {"message": "Media details coming in Slice 3"}
@router.put("/{media_id}")
@@ -135,9 +134,7 @@ def update_media_metadata(
- Update tags/categories
- Update description
"""
return {
"message": "Media update coming in Slice 3"
}
return {"message": "Media update coming in Slice 3"}
@router.delete("/{media_id}")
@@ -157,9 +154,7 @@ def delete_media(
- Delete database record
- Return success/error
"""
return {
"message": "Media deletion coming in Slice 3"
}
return {"message": "Media deletion coming in Slice 3"}
@router.get("/{media_id}/usage")
@@ -180,7 +175,7 @@ def get_media_usage(
return {
"products": [],
"other_usage": [],
"message": "Media usage tracking coming in Slice 3"
"message": "Media usage tracking coming in Slice 3",
}
@@ -200,7 +195,4 @@ def optimize_media(
- Keep original
- Update database with new versions
"""
return {
"message": "Media optimization coming in Slice 3"
}
return {"message": "Media optimization coming in Slice 3"}

View File

@@ -6,6 +6,7 @@ Vendor notification management endpoints.
import logging
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
@@ -41,7 +42,7 @@ def get_notifications(
"notifications": [],
"total": 0,
"unread_count": 0,
"message": "Notifications coming in Slice 5"
"message": "Notifications coming in Slice 5",
}
@@ -58,10 +59,7 @@ def get_unread_count(
- Count unread notifications for vendor
- Used for notification badge
"""
return {
"unread_count": 0,
"message": "Unread count coming in Slice 5"
}
return {"unread_count": 0, "message": "Unread count coming in Slice 5"}
@router.put("/{notification_id}/read")
@@ -78,9 +76,7 @@ def mark_as_read(
- Mark single notification as read
- Update read timestamp
"""
return {
"message": "Mark as read coming in Slice 5"
}
return {"message": "Mark as read coming in Slice 5"}
@router.put("/mark-all-read")
@@ -96,9 +92,7 @@ def mark_all_as_read(
- Mark all vendor notifications as read
- Update timestamps
"""
return {
"message": "Mark all as read coming in Slice 5"
}
return {"message": "Mark all as read coming in Slice 5"}
@router.delete("/{notification_id}")
@@ -115,9 +109,7 @@ def delete_notification(
- Delete single notification
- Verify notification belongs to vendor
"""
return {
"message": "Notification deletion coming in Slice 5"
}
return {"message": "Notification deletion coming in Slice 5"}
@router.get("/settings")
@@ -138,7 +130,7 @@ def get_notification_settings(
"email_notifications": True,
"in_app_notifications": True,
"notification_types": {},
"message": "Notification settings coming in Slice 5"
"message": "Notification settings coming in Slice 5",
}
@@ -157,9 +149,7 @@ def update_notification_settings(
- Update in-app notification settings
- Enable/disable specific notification types
"""
return {
"message": "Notification settings update coming in Slice 5"
}
return {"message": "Notification settings update coming in Slice 5"}
@router.get("/templates")
@@ -176,10 +166,7 @@ def get_notification_templates(
- Include: order confirmation, shipping notification, etc.
- Return template details
"""
return {
"templates": [],
"message": "Notification templates coming in Slice 5"
}
return {"templates": [], "message": "Notification templates coming in Slice 5"}
@router.put("/templates/{template_id}")
@@ -199,9 +186,7 @@ def update_notification_template(
- Validate template variables
- Preview template
"""
return {
"message": "Template update coming in Slice 5"
}
return {"message": "Template update coming in Slice 5"}
@router.post("/test")
@@ -219,6 +204,4 @@ def send_test_notification(
- Use specified template
- Send to current user's email
"""
return {
"message": "Test notification coming in Slice 5"
}
return {"message": "Test notification coming in Slice 5"}

View File

@@ -6,21 +6,17 @@ Vendor order management endpoints.
import logging
from typing import Optional
from fastapi import APIRouter, Depends, Query, Request, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api
from app.core.database import get_db
from middleware.vendor_context import require_vendor_context
from app.services.order_service import order_service
from models.schema.order import (
OrderResponse,
OrderDetailResponse,
OrderListResponse,
OrderUpdate
)
from middleware.vendor_context import require_vendor_context
from models.database.user import User
from models.database.vendor import Vendor, VendorUser
from models.schema.order import (OrderDetailResponse, OrderListResponse,
OrderResponse, OrderUpdate)
router = APIRouter(prefix="/orders")
logger = logging.getLogger(__name__)
@@ -51,14 +47,14 @@ def get_vendor_orders(
skip=skip,
limit=limit,
status=status,
customer_id=customer_id
customer_id=customer_id,
)
return OrderListResponse(
orders=[OrderResponse.model_validate(o) for o in orders],
total=total,
skip=skip,
limit=limit
limit=limit,
)
@@ -74,11 +70,7 @@ def get_order_details(
Requires Authorization header (API endpoint).
"""
order = order_service.get_order(
db=db,
vendor_id=vendor.id,
order_id=order_id
)
order = order_service.get_order(db=db, vendor_id=vendor.id, order_id=order_id)
return OrderDetailResponse.model_validate(order)
@@ -105,10 +97,7 @@ def update_order_status(
Requires Authorization header (API endpoint).
"""
order = order_service.update_order_status(
db=db,
vendor_id=vendor.id,
order_id=order_id,
order_update=order_update
db=db, vendor_id=vendor.id, order_id=order_id, order_update=order_update
)
logger.info(

View File

@@ -5,6 +5,7 @@ Vendor payment configuration and processing endpoints.
"""
import logging
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
@@ -38,7 +39,7 @@ def get_payment_configuration(
"accepted_methods": [],
"currency": "EUR",
"stripe_connected": False,
"message": "Payment configuration coming in Slice 5"
"message": "Payment configuration coming in Slice 5",
}
@@ -58,9 +59,7 @@ def update_payment_configuration(
- Update accepted payment methods
- Validate configuration before saving
"""
return {
"message": "Payment configuration update coming in Slice 5"
}
return {"message": "Payment configuration update coming in Slice 5"}
@router.post("/stripe/connect")
@@ -79,9 +78,7 @@ def connect_stripe_account(
- Verify Stripe account is active
- Enable payment processing
"""
return {
"message": "Stripe connection coming in Slice 5"
}
return {"message": "Stripe connection coming in Slice 5"}
@router.delete("/stripe/disconnect")
@@ -98,9 +95,7 @@ def disconnect_stripe_account(
- Disable payment processing
- Warn about pending payments
"""
return {
"message": "Stripe disconnection coming in Slice 5"
}
return {"message": "Stripe disconnection coming in Slice 5"}
@router.get("/methods")
@@ -116,10 +111,7 @@ def get_payment_methods(
- Return list of enabled payment methods
- Include: credit card, PayPal, bank transfer, etc.
"""
return {
"methods": [],
"message": "Payment methods coming in Slice 5"
}
return {"methods": [], "message": "Payment methods coming in Slice 5"}
@router.get("/transactions")
@@ -140,7 +132,7 @@ def get_payment_transactions(
return {
"transactions": [],
"total": 0,
"message": "Payment transactions coming in Slice 5"
"message": "Payment transactions coming in Slice 5",
}
@@ -164,7 +156,7 @@ def get_payment_balance(
"pending_balance": 0.0,
"currency": "EUR",
"next_payout_date": None,
"message": "Payment balance coming in Slice 5"
"message": "Payment balance coming in Slice 5",
}
@@ -185,6 +177,4 @@ def refund_payment(
- Update order status
- Send refund notification to customer
"""
return {
"message": "Payment refund coming in Slice 5"
}
return {"message": "Payment refund coming in Slice 5"}

View File

@@ -11,17 +11,13 @@ from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api
from app.core.database import get_db
from middleware.vendor_context import require_vendor_context
from app.services.product_service import product_service
from models.schema.product import (
ProductCreate,
ProductUpdate,
ProductResponse,
ProductDetailResponse,
ProductListResponse
)
from middleware.vendor_context import require_vendor_context
from models.database.user import User
from models.database.vendor import Vendor
from models.schema.product import (ProductCreate, ProductDetailResponse,
ProductListResponse, ProductResponse,
ProductUpdate)
router = APIRouter(prefix="/products")
logger = logging.getLogger(__name__)
@@ -50,14 +46,14 @@ def get_vendor_products(
skip=skip,
limit=limit,
is_active=is_active,
is_featured=is_featured
is_featured=is_featured,
)
return ProductListResponse(
products=[ProductResponse.model_validate(p) for p in products],
total=total,
skip=skip,
limit=limit
limit=limit,
)
@@ -70,9 +66,7 @@ def get_product_details(
):
"""Get detailed product information including inventory."""
product = product_service.get_product(
db=db,
vendor_id=vendor.id,
product_id=product_id
db=db, vendor_id=vendor.id, product_id=product_id
)
return ProductDetailResponse.model_validate(product)
@@ -91,9 +85,7 @@ def add_product_to_catalog(
This publishes a MarketplaceProduct to the vendor's public catalog.
"""
product = product_service.create_product(
db=db,
vendor_id=vendor.id,
product_data=product_data
db=db, vendor_id=vendor.id, product_data=product_data
)
logger.info(
@@ -114,10 +106,7 @@ def update_product(
):
"""Update product in vendor catalog."""
product = product_service.update_product(
db=db,
vendor_id=vendor.id,
product_id=product_id,
product_update=product_data
db=db, vendor_id=vendor.id, product_id=product_id, product_update=product_data
)
logger.info(
@@ -136,11 +125,7 @@ def remove_product_from_catalog(
db: Session = Depends(get_db),
):
"""Remove product from vendor catalog."""
product_service.delete_product(
db=db,
vendor_id=vendor.id,
product_id=product_id
)
product_service.delete_product(db=db, vendor_id=vendor.id, product_id=product_id)
logger.info(
f"Product {product_id} removed from catalog by user {current_user.username} "
@@ -163,14 +148,11 @@ def publish_from_marketplace(
Shortcut endpoint for publishing directly from marketplace import.
"""
product_data = ProductCreate(
marketplace_product_id=marketplace_product_id,
is_active=True
marketplace_product_id=marketplace_product_id, is_active=True
)
product = product_service.create_product(
db=db,
vendor_id=vendor.id,
product_data=product_data
db=db, vendor_id=vendor.id, product_data=product_data
)
logger.info(
@@ -198,10 +180,7 @@ def toggle_product_active(
status = "activated" if product.is_active else "deactivated"
logger.info(f"Product {product_id} {status} for vendor {vendor.vendor_code}")
return {
"message": f"Product {status}",
"is_active": product.is_active
}
return {"message": f"Product {status}", "is_active": product.is_active}
@router.put("/{product_id}/toggle-featured")
@@ -221,7 +200,4 @@ def toggle_product_featured(
status = "featured" if product.is_featured else "unfeatured"
logger.info(f"Product {product_id} {status} for vendor {vendor.vendor_code}")
return {
"message": f"Product {status}",
"is_featured": product.is_featured
}
return {"message": f"Product {status}", "is_featured": product.is_featured}

View File

@@ -4,16 +4,17 @@ Vendor profile management endpoints.
"""
import logging
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api
from app.core.database import get_db
from middleware.vendor_context import require_vendor_context
from app.services.vendor_service import vendor_service
from models.schema.vendor import VendorUpdate, VendorResponse
from middleware.vendor_context import require_vendor_context
from models.database.user import User
from models.database.vendor import Vendor
from models.schema.vendor import VendorResponse, VendorUpdate
router = APIRouter(prefix="/profile")
logger = logging.getLogger(__name__)

View File

@@ -4,13 +4,14 @@ Vendor settings and configuration endpoints.
"""
import logging
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api
from app.core.database import get_db
from middleware.vendor_context import require_vendor_context
from app.services.vendor_service import vendor_service
from middleware.vendor_context import require_vendor_context
from models.database.user import User
from models.database.vendor import Vendor

View File

@@ -12,35 +12,24 @@ Implements complete team management with:
import logging
from typing import List
from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session
from app.api.deps import (get_current_vendor_api, get_user_permissions,
require_vendor_owner, require_vendor_permission)
from app.core.database import get_db
from app.core.permissions import VendorPermissions
from app.api.deps import (
get_current_vendor_api,
require_vendor_owner,
require_vendor_permission,
get_user_permissions
)
from app.services.vendor_team_service import vendor_team_service
from models.database.user import User
from models.database.vendor import Vendor
from models.schema.team import (
TeamMemberInvite,
TeamMemberUpdate,
TeamMemberResponse,
TeamMemberListResponse,
InvitationAccept,
InvitationResponse,
InvitationAcceptResponse,
RoleResponse,
RoleListResponse,
UserPermissionsResponse,
TeamStatistics,
BulkRemoveRequest,
BulkRemoveResponse,
)
from models.schema.team import (BulkRemoveRequest, BulkRemoveResponse,
InvitationAccept, InvitationAcceptResponse,
InvitationResponse, RoleListResponse,
RoleResponse, TeamMemberInvite,
TeamMemberListResponse, TeamMemberResponse,
TeamMemberUpdate, TeamStatistics,
UserPermissionsResponse)
router = APIRouter(prefix="/team")
logger = logging.getLogger(__name__)
@@ -50,14 +39,15 @@ logger = logging.getLogger(__name__)
# Team Member Routes
# ============================================================================
@router.get("/members", response_model=TeamMemberListResponse)
def list_team_members(
request: Request,
include_inactive: bool = False,
db: Session = Depends(get_db),
current_user: User = Depends(require_vendor_permission(
VendorPermissions.TEAM_VIEW.value
))
current_user: User = Depends(
require_vendor_permission(VendorPermissions.TEAM_VIEW.value)
),
):
"""
Get all team members for current vendor.
@@ -74,9 +64,7 @@ def list_team_members(
vendor = request.state.vendor
members = vendor_team_service.get_team_members(
db=db,
vendor=vendor,
include_inactive=include_inactive
db=db, vendor=vendor, include_inactive=include_inactive
)
# Calculate statistics
@@ -90,10 +78,7 @@ def list_team_members(
)
return TeamMemberListResponse(
members=members,
total=total,
active_count=active,
pending_invitations=pending
members=members, total=total, active_count=active, pending_invitations=pending
)
@@ -102,7 +87,7 @@ def invite_team_member(
invitation: TeamMemberInvite,
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(require_vendor_owner) # Owner only
current_user: User = Depends(require_vendor_owner), # Owner only
):
"""
Invite a new team member to the vendor.
@@ -135,7 +120,7 @@ def invite_team_member(
vendor=vendor,
inviter=current_user,
email=invitation.email,
role_id=invitation.role_id
role_id=invitation.role_id,
)
elif invitation.role_name:
# Use role name with optional custom permissions
@@ -145,7 +130,7 @@ def invite_team_member(
inviter=current_user,
email=invitation.email,
role_name=invitation.role_name,
custom_permissions=invitation.custom_permissions
custom_permissions=invitation.custom_permissions,
)
else:
# Default to Staff role
@@ -154,7 +139,7 @@ def invite_team_member(
vendor=vendor,
inviter=current_user,
email=invitation.email,
role_name="staff"
role_name="staff",
)
logger.info(
@@ -166,15 +151,12 @@ def invite_team_member(
message="Invitation sent successfully",
email=result["email"],
role=result["role"],
invitation_sent=True
invitation_sent=True,
)
@router.post("/accept-invitation", response_model=InvitationAcceptResponse)
def accept_invitation(
acceptance: InvitationAccept,
db: Session = Depends(get_db)
):
def accept_invitation(acceptance: InvitationAccept, db: Session = Depends(get_db)):
"""
Accept a team invitation and activate account.
@@ -196,7 +178,7 @@ def accept_invitation(
invitation_token=acceptance.invitation_token,
password=acceptance.password,
first_name=acceptance.first_name,
last_name=acceptance.last_name
last_name=acceptance.last_name,
)
logger.info(
@@ -210,15 +192,15 @@ def accept_invitation(
"id": result["vendor"].id,
"vendor_code": result["vendor"].vendor_code,
"name": result["vendor"].name,
"subdomain": result["vendor"].subdomain
"subdomain": result["vendor"].subdomain,
},
user={
"id": result["user"].id,
"email": result["user"].email,
"username": result["user"].username,
"full_name": result["user"].full_name
"full_name": result["user"].full_name,
},
role=result["role"]
role=result["role"],
)
@@ -227,9 +209,9 @@ def get_team_member(
user_id: int,
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(require_vendor_permission(
VendorPermissions.TEAM_VIEW.value
))
current_user: User = Depends(
require_vendor_permission(VendorPermissions.TEAM_VIEW.value)
),
):
"""
Get details of a specific team member.
@@ -239,14 +221,13 @@ def get_team_member(
vendor = request.state.vendor
members = vendor_team_service.get_team_members(
db=db,
vendor=vendor,
include_inactive=True
db=db, vendor=vendor, include_inactive=True
)
member = next((m for m in members if m["id"] == user_id), None)
if not member:
from app.exceptions import UserNotFoundException
raise UserNotFoundException(str(user_id))
return TeamMemberResponse(**member)
@@ -258,7 +239,7 @@ def update_team_member(
update_data: TeamMemberUpdate,
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(require_vendor_owner) # Owner only
current_user: User = Depends(require_vendor_owner), # Owner only
):
"""
Update a team member's role or status.
@@ -280,7 +261,7 @@ def update_team_member(
vendor=vendor,
user_id=user_id,
new_role_id=update_data.role_id,
is_active=update_data.is_active
is_active=update_data.is_active,
)
logger.info(
@@ -300,7 +281,7 @@ def remove_team_member(
user_id: int,
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(require_vendor_owner) # Owner only
current_user: User = Depends(require_vendor_owner), # Owner only
):
"""
Remove a team member from the vendor.
@@ -316,21 +297,14 @@ def remove_team_member(
"""
vendor = request.state.vendor
vendor_team_service.remove_team_member(
db=db,
vendor=vendor,
user_id=user_id
)
vendor_team_service.remove_team_member(db=db, vendor=vendor, user_id=user_id)
logger.info(
f"Team member removed: {user_id} from {vendor.vendor_code} "
f"by {current_user.username}"
)
return {
"message": "Team member removed successfully",
"user_id": user_id
}
return {"message": "Team member removed successfully", "user_id": user_id}
@router.post("/members/bulk-remove", response_model=BulkRemoveResponse)
@@ -338,7 +312,7 @@ def bulk_remove_team_members(
bulk_remove: BulkRemoveRequest,
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(require_vendor_owner)
current_user: User = Depends(require_vendor_owner),
):
"""
Remove multiple team members at once.
@@ -354,17 +328,12 @@ def bulk_remove_team_members(
for user_id in bulk_remove.user_ids:
try:
vendor_team_service.remove_team_member(
db=db,
vendor=vendor,
user_id=user_id
db=db, vendor=vendor, user_id=user_id
)
success_count += 1
except Exception as e:
failed_count += 1
errors.append({
"user_id": user_id,
"error": str(e)
})
errors.append({"user_id": user_id, "error": str(e)})
logger.info(
f"Bulk remove completed: {success_count} removed, {failed_count} failed "
@@ -372,9 +341,7 @@ def bulk_remove_team_members(
)
return BulkRemoveResponse(
success_count=success_count,
failed_count=failed_count,
errors=errors
success_count=success_count, failed_count=failed_count, errors=errors
)
@@ -382,13 +349,14 @@ def bulk_remove_team_members(
# Role Management Routes
# ============================================================================
@router.get("/roles", response_model=RoleListResponse)
def list_roles(
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(require_vendor_permission(
VendorPermissions.TEAM_VIEW.value
))
current_user: User = Depends(
require_vendor_permission(VendorPermissions.TEAM_VIEW.value)
),
):
"""
Get all available roles for the vendor.
@@ -403,21 +371,19 @@ def list_roles(
roles = vendor_team_service.get_vendor_roles(db=db, vendor_id=vendor.id)
return RoleListResponse(
roles=roles,
total=len(roles)
)
return RoleListResponse(roles=roles, total=len(roles))
# ============================================================================
# Permission Routes
# ============================================================================
@router.get("/me/permissions", response_model=UserPermissionsResponse)
def get_my_permissions(
request: Request,
permissions: List[str] = Depends(get_user_permissions),
current_user: User = Depends(get_current_vendor_api)
current_user: User = Depends(get_current_vendor_api),
):
"""
Get current user's permissions in this vendor.
@@ -443,7 +409,7 @@ def get_my_permissions(
permissions=permissions,
permission_count=len(permissions),
is_owner=is_owner,
role_name=role_name
role_name=role_name,
)
@@ -451,13 +417,14 @@ def get_my_permissions(
# Statistics Routes
# ============================================================================
@router.get("/statistics", response_model=TeamStatistics)
def get_team_statistics(
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(require_vendor_permission(
VendorPermissions.TEAM_VIEW.value
))
current_user: User = Depends(
require_vendor_permission(VendorPermissions.TEAM_VIEW.value)
),
):
"""
Get team statistics for the vendor.
@@ -474,9 +441,7 @@ def get_team_statistics(
vendor = request.state.vendor
members = vendor_team_service.get_team_members(
db=db,
vendor=vendor,
include_inactive=True
db=db, vendor=vendor, include_inactive=True
)
# Calculate statistics
@@ -500,5 +465,5 @@ def get_team_statistics(
pending_invitations=pending,
owners=owners,
team_members=team_members,
roles_breakdown=roles_breakdown
roles_breakdown=roles_breakdown,
)

View File

@@ -14,6 +14,7 @@ This module focuses purely on configuration storage and validation.
"""
from typing import List, Optional
from pydantic_settings import BaseSettings
@@ -137,13 +138,9 @@ settings = Settings()
# ENVIRONMENT UTILITIES - Module-level functions
# =============================================================================
# Import environment detection utilities
from app.core.environment import (
get_environment,
is_development,
is_production,
is_staging,
should_use_secure_cookies
)
from app.core.environment import (get_environment, is_development,
is_production, is_staging,
should_use_secure_cookies)
def get_current_environment() -> str:
@@ -190,6 +187,7 @@ def is_staging_environment() -> bool:
# VALIDATION FUNCTIONS
# =============================================================================
def validate_production_settings() -> List[str]:
"""
Validate settings for production environment.
@@ -243,22 +241,19 @@ def print_environment_info():
# =============================================================================
__all__ = [
# Settings singleton
'settings',
"settings",
# Environment detection (re-exported from app.core.environment)
'get_environment',
'is_development',
'is_production',
'is_staging',
'should_use_secure_cookies',
"get_environment",
"is_development",
"is_production",
"is_staging",
"should_use_secure_cookies",
# Convenience functions
'get_current_environment',
'is_production_environment',
'is_development_environment',
'is_staging_environment',
"get_current_environment",
"is_production_environment",
"is_development_environment",
"is_staging_environment",
# Validation
'validate_production_settings',
'print_environment_info',
"validate_production_settings",
"print_environment_info",
]

View File

@@ -53,7 +53,10 @@ def get_environment() -> EnvironmentType:
# Check common development indicators
hostname = os.getenv("HOSTNAME", "").lower()
if any(dev_indicator in hostname for dev_indicator in ["local", "dev", "laptop", "desktop"]):
if any(
dev_indicator in hostname
for dev_indicator in ["local", "dev", "laptop", "desktop"]
):
return "development"
# Check for staging indicators

View File

@@ -15,11 +15,13 @@ from fastapi import FastAPI
from sqlalchemy import text
from middleware.auth import AuthManager
# Remove this import if not needed: from models.database.base import Base
from .database import SessionLocal, engine
from .logging import setup_logging
# Remove this import if not needed: from models.database.base import Base
logger = logging.getLogger(__name__)
auth_manager = AuthManager()
@@ -46,7 +48,9 @@ def check_database_ready():
try:
with engine.connect() as conn:
# Try to query a table that should exist
result = conn.execute(text("SELECT name FROM sqlite_master WHERE type='table' LIMIT 1"))
result = conn.execute(
text("SELECT name FROM sqlite_master WHERE type='table' LIMIT 1")
)
tables = result.fetchall()
return len(tables) > 0
except Exception:
@@ -93,6 +97,7 @@ def verify_startup_requirements():
logger.info("[OK] Startup verification passed")
return True
# You can call this in your main.py if desired:
# if not verify_startup_requirements():
# raise RuntimeError("Application startup requirements not met")

View File

@@ -17,6 +17,7 @@ class VendorPermissions(str, Enum):
Naming convention: RESOURCE_ACTION
"""
# Dashboard
DASHBOARD_VIEW = "dashboard.view"
@@ -166,17 +167,23 @@ class PermissionChecker:
return required_permission in permissions
@staticmethod
def has_any_permission(permissions: List[str], required_permissions: List[str]) -> bool:
def has_any_permission(
permissions: List[str], required_permissions: List[str]
) -> bool:
"""Check if a permission list contains ANY of the required permissions."""
return any(perm in permissions for perm in required_permissions)
@staticmethod
def has_all_permissions(permissions: List[str], required_permissions: List[str]) -> bool:
def has_all_permissions(
permissions: List[str], required_permissions: List[str]
) -> bool:
"""Check if a permission list contains ALL of the required permissions."""
return all(perm in permissions for perm in required_permissions)
@staticmethod
def get_missing_permissions(permissions: List[str], required_permissions: List[str]) -> List[str]:
def get_missing_permissions(
permissions: List[str], required_permissions: List[str]
) -> List[str]:
"""Get list of missing permissions."""
return [perm for perm in required_permissions if perm not in permissions]

View File

@@ -16,19 +16,11 @@ THEME_PRESETS = {
"accent": "#ec4899", # Pink
"background": "#ffffff", # White
"text": "#1f2937", # Gray-800
"border": "#e5e7eb" # Gray-200
"border": "#e5e7eb", # Gray-200
},
"fonts": {
"heading": "Inter, sans-serif",
"body": "Inter, sans-serif"
"fonts": {"heading": "Inter, sans-serif", "body": "Inter, sans-serif"},
"layout": {"style": "grid", "header": "fixed", "product_card": "modern"},
},
"layout": {
"style": "grid",
"header": "fixed",
"product_card": "modern"
}
},
"modern": {
"colors": {
"primary": "#6366f1", # Indigo - Modern tech look
@@ -36,19 +28,11 @@ THEME_PRESETS = {
"accent": "#ec4899", # Pink
"background": "#ffffff", # White
"text": "#1f2937", # Gray-800
"border": "#e5e7eb" # Gray-200
"border": "#e5e7eb", # Gray-200
},
"fonts": {
"heading": "Inter, sans-serif",
"body": "Inter, sans-serif"
"fonts": {"heading": "Inter, sans-serif", "body": "Inter, sans-serif"},
"layout": {"style": "grid", "header": "fixed", "product_card": "modern"},
},
"layout": {
"style": "grid",
"header": "fixed",
"product_card": "modern"
}
},
"classic": {
"colors": {
"primary": "#1e40af", # Dark blue - Traditional
@@ -56,19 +40,11 @@ THEME_PRESETS = {
"accent": "#dc2626", # Red
"background": "#ffffff", # White
"text": "#1f2937", # Gray-800
"border": "#d1d5db" # Gray-300
"border": "#d1d5db", # Gray-300
},
"fonts": {
"heading": "Georgia, serif",
"body": "Arial, sans-serif"
"fonts": {"heading": "Georgia, serif", "body": "Arial, sans-serif"},
"layout": {"style": "list", "header": "static", "product_card": "classic"},
},
"layout": {
"style": "list",
"header": "static",
"product_card": "classic"
}
},
"minimal": {
"colors": {
"primary": "#000000", # Black - Ultra minimal
@@ -76,19 +52,11 @@ THEME_PRESETS = {
"accent": "#666666", # Medium gray
"background": "#ffffff", # White
"text": "#000000", # Black
"border": "#e5e7eb" # Light gray
"border": "#e5e7eb", # Light gray
},
"fonts": {
"heading": "Helvetica, sans-serif",
"body": "Helvetica, sans-serif"
"fonts": {"heading": "Helvetica, sans-serif", "body": "Helvetica, sans-serif"},
"layout": {"style": "grid", "header": "transparent", "product_card": "minimal"},
},
"layout": {
"style": "grid",
"header": "transparent",
"product_card": "minimal"
}
},
"vibrant": {
"colors": {
"primary": "#f59e0b", # Orange - Bold & energetic
@@ -96,19 +64,11 @@ THEME_PRESETS = {
"accent": "#8b5cf6", # Purple
"background": "#ffffff", # White
"text": "#1f2937", # Gray-800
"border": "#fbbf24" # Yellow
"border": "#fbbf24", # Yellow
},
"fonts": {
"heading": "Poppins, sans-serif",
"body": "Open Sans, sans-serif"
"fonts": {"heading": "Poppins, sans-serif", "body": "Open Sans, sans-serif"},
"layout": {"style": "masonry", "header": "fixed", "product_card": "modern"},
},
"layout": {
"style": "masonry",
"header": "fixed",
"product_card": "modern"
}
},
"elegant": {
"colors": {
"primary": "#6b7280", # Gray - Sophisticated
@@ -116,19 +76,11 @@ THEME_PRESETS = {
"accent": "#d97706", # Amber
"background": "#ffffff", # White
"text": "#1f2937", # Gray-800
"border": "#e5e7eb" # Gray-200
"border": "#e5e7eb", # Gray-200
},
"fonts": {
"heading": "Playfair Display, serif",
"body": "Lato, sans-serif"
"fonts": {"heading": "Playfair Display, serif", "body": "Lato, sans-serif"},
"layout": {"style": "grid", "header": "fixed", "product_card": "classic"},
},
"layout": {
"style": "grid",
"header": "fixed",
"product_card": "classic"
}
},
"nature": {
"colors": {
"primary": "#059669", # Green - Natural & eco
@@ -136,18 +88,11 @@ THEME_PRESETS = {
"accent": "#f59e0b", # Amber
"background": "#ffffff", # White
"text": "#1f2937", # Gray-800
"border": "#d1fae5" # Light green
"border": "#d1fae5", # Light green
},
"fonts": {
"heading": "Montserrat, sans-serif",
"body": "Open Sans, sans-serif"
"fonts": {"heading": "Montserrat, sans-serif", "body": "Open Sans, sans-serif"},
"layout": {"style": "grid", "header": "fixed", "product_card": "modern"},
},
"layout": {
"style": "grid",
"header": "fixed",
"product_card": "modern"
}
}
}
@@ -243,7 +188,7 @@ def get_preset_preview(preset_name: str) -> dict:
"minimal": "Ultra-clean black and white aesthetic",
"vibrant": "Bold and energetic with bright accent colors",
"elegant": "Sophisticated gray tones with refined typography",
"nature": "Fresh and eco-friendly green color palette"
"nature": "Fresh and eco-friendly green color palette",
}
return {
@@ -259,10 +204,7 @@ def get_preset_preview(preset_name: str) -> dict:
def create_custom_preset(
colors: dict,
fonts: dict,
layout: dict,
name: str = "custom"
colors: dict, fonts: dict, layout: dict, name: str = "custom"
) -> dict:
"""
Create a custom preset from provided settings.
@@ -304,8 +246,4 @@ def create_custom_preset(
if "product_card" not in layout:
layout["product_card"] = "modern"
return {
"colors": colors,
"fonts": fonts,
"layout": layout
}
return {"colors": colors, "fonts": fonts, "layout": layout}

View File

@@ -6,179 +6,109 @@ This module provides frontend-friendly exceptions with consistent error codes,
messages, and HTTP status mappings.
"""
# Base exceptions
from .base import (
WizamartException,
ValidationException,
AuthenticationException,
AuthorizationException,
ResourceNotFoundException,
ConflictException,
BusinessLogicException,
ExternalServiceException,
RateLimitException,
ServiceUnavailableException,
)
# Authentication exceptions
from .auth import (
InvalidCredentialsException,
TokenExpiredException,
InvalidTokenException,
InsufficientPermissionsException,
UserNotActiveException,
AdminRequiredException,
UserAlreadyExistsException
)
# Admin exceptions
from .admin import (
UserNotFoundException,
UserStatusChangeException,
VendorVerificationException,
AdminOperationException,
CannotModifyAdminException,
CannotModifySelfException,
InvalidAdminActionException,
BulkOperationException,
ConfirmationRequiredException,
)
from .admin import (AdminOperationException, BulkOperationException,
CannotModifyAdminException, CannotModifySelfException,
ConfirmationRequiredException, InvalidAdminActionException,
UserNotFoundException, UserStatusChangeException,
VendorVerificationException)
# Authentication exceptions
from .auth import (AdminRequiredException, InsufficientPermissionsException,
InvalidCredentialsException, InvalidTokenException,
TokenExpiredException, UserAlreadyExistsException,
UserNotActiveException)
# Base exceptions
from .base import (AuthenticationException, AuthorizationException,
BusinessLogicException, ConflictException,
ExternalServiceException, RateLimitException,
ResourceNotFoundException, ServiceUnavailableException,
ValidationException, WizamartException)
# Cart exceptions
from .cart import (CartItemNotFoundException, CartValidationException,
EmptyCartException, InsufficientInventoryForCartException,
InvalidCartQuantityException,
ProductNotAvailableForCartException)
# Customer exceptions
from .customer import (CustomerAlreadyExistsException,
CustomerAuthorizationException,
CustomerNotActiveException, CustomerNotFoundException,
CustomerValidationException,
DuplicateCustomerEmailException,
InvalidCustomerCredentialsException)
# Inventory exceptions
from .inventory import (InsufficientInventoryException,
InvalidInventoryOperationException,
InvalidQuantityException, InventoryNotFoundException,
InventoryValidationException,
LocationNotFoundException, NegativeInventoryException)
# Marketplace import job exceptions
from .marketplace_import_job import (
MarketplaceImportException,
ImportJobNotFoundException,
ImportJobNotOwnedException,
InvalidImportDataException,
from .marketplace_import_job import (ImportJobAlreadyProcessingException,
ImportJobCannotBeCancelledException,
ImportJobCannotBeDeletedException,
ImportJobNotFoundException,
ImportJobNotOwnedException,
ImportRateLimitException,
InvalidImportDataException,
InvalidMarketplaceException,
MarketplaceConnectionException,
MarketplaceDataParsingException,
ImportRateLimitException,
InvalidMarketplaceException,
ImportJobAlreadyProcessingException,
)
MarketplaceImportException)
# Marketplace product exceptions
from .marketplace_product import (
MarketplaceProductNotFoundException,
MarketplaceProductAlreadyExistsException,
from .marketplace_product import (InvalidGTINException,
InvalidMarketplaceProductDataException,
MarketplaceProductValidationException,
InvalidGTINException,
MarketplaceProductAlreadyExistsException,
MarketplaceProductCSVImportException,
)
# Inventory exceptions
from .inventory import (
InventoryNotFoundException,
InsufficientInventoryException,
InvalidInventoryOperationException,
InventoryValidationException,
NegativeInventoryException,
InvalidQuantityException,
LocationNotFoundException
)
MarketplaceProductNotFoundException,
MarketplaceProductValidationException)
# Order exceptions
from .order import (InvalidOrderStatusException, OrderAlreadyExistsException,
OrderCannotBeCancelledException, OrderNotFoundException,
OrderValidationException)
# Product exceptions
from .product import (CannotDeleteProductWithInventoryException,
CannotDeleteProductWithOrdersException,
InvalidProductDataException,
ProductAlreadyExistsException, ProductNotActiveException,
ProductNotFoundException, ProductNotInCatalogException,
ProductValidationException)
# Team exceptions
from .team import (CannotModifyOwnRoleException, CannotRemoveOwnerException,
InsufficientTeamPermissionsException,
InvalidInvitationDataException,
InvalidInvitationTokenException, InvalidRoleException,
MaxTeamMembersReachedException, RoleNotFoundException,
TeamInvitationAlreadyAcceptedException,
TeamInvitationExpiredException,
TeamInvitationNotFoundException,
TeamMemberAlreadyExistsException,
TeamMemberNotFoundException, TeamValidationException,
UnauthorizedTeamActionException)
# Vendor exceptions
from .vendor import (
VendorNotFoundException,
VendorAlreadyExistsException,
VendorNotActiveException,
VendorNotVerifiedException,
from .vendor import (InvalidVendorDataException, MaxVendorsReachedException,
UnauthorizedVendorAccessException,
InvalidVendorDataException,
MaxVendorsReachedException,
VendorValidationException,
)
VendorAlreadyExistsException, VendorNotActiveException,
VendorNotFoundException, VendorNotVerifiedException,
VendorValidationException)
# Vendor domain exceptions
from .vendor_domain import (
VendorDomainNotFoundException,
VendorDomainAlreadyExistsException,
InvalidDomainFormatException,
ReservedDomainException,
from .vendor_domain import (DNSVerificationException,
DomainAlreadyVerifiedException,
DomainNotVerifiedException,
DomainVerificationFailedException,
DomainAlreadyVerifiedException,
MultiplePrimaryDomainsException,
DNSVerificationException,
InvalidDomainFormatException,
MaxDomainsReachedException,
MultiplePrimaryDomainsException,
ReservedDomainException,
UnauthorizedDomainAccessException,
)
VendorDomainAlreadyExistsException,
VendorDomainNotFoundException)
# Vendor theme exceptions
from .vendor_theme import (
VendorThemeNotFoundException,
InvalidThemeDataException,
from .vendor_theme import (InvalidColorFormatException,
InvalidFontFamilyException,
InvalidThemeDataException, ThemeOperationException,
ThemePresetAlreadyAppliedException,
ThemePresetNotFoundException,
ThemeValidationException,
ThemePresetAlreadyAppliedException,
InvalidColorFormatException,
InvalidFontFamilyException,
ThemeOperationException,
)
# Customer exceptions
from .customer import (
CustomerNotFoundException,
CustomerAlreadyExistsException,
DuplicateCustomerEmailException,
CustomerNotActiveException,
InvalidCustomerCredentialsException,
CustomerValidationException,
CustomerAuthorizationException,
)
# Team exceptions
from .team import (
TeamMemberNotFoundException,
TeamMemberAlreadyExistsException,
TeamInvitationNotFoundException,
TeamInvitationExpiredException,
TeamInvitationAlreadyAcceptedException,
UnauthorizedTeamActionException,
CannotRemoveOwnerException,
CannotModifyOwnRoleException,
RoleNotFoundException,
InvalidRoleException,
InsufficientTeamPermissionsException,
MaxTeamMembersReachedException,
TeamValidationException,
InvalidInvitationDataException,
InvalidInvitationTokenException,
)
# Product exceptions
from .product import (
ProductNotFoundException,
ProductAlreadyExistsException,
ProductNotInCatalogException,
ProductNotActiveException,
InvalidProductDataException,
ProductValidationException,
CannotDeleteProductWithInventoryException,
CannotDeleteProductWithOrdersException,
)
# Order exceptions
from .order import (
OrderNotFoundException,
OrderAlreadyExistsException,
OrderValidationException,
InvalidOrderStatusException,
OrderCannotBeCancelledException,
)
# Cart exceptions
from .cart import (
CartItemNotFoundException,
EmptyCartException,
CartValidationException,
InsufficientInventoryForCartException,
InvalidCartQuantityException,
ProductNotAvailableForCartException,
)
VendorThemeNotFoundException)
__all__ = [
# Base exceptions
@@ -192,7 +122,6 @@ __all__ = [
"ExternalServiceException",
"RateLimitException",
"ServiceUnavailableException",
# Auth exceptions
"InvalidCredentialsException",
"TokenExpiredException",
@@ -201,7 +130,6 @@ __all__ = [
"UserNotActiveException",
"AdminRequiredException",
"UserAlreadyExistsException",
# Customer exceptions
"CustomerNotFoundException",
"CustomerAlreadyExistsException",
@@ -210,7 +138,6 @@ __all__ = [
"InvalidCustomerCredentialsException",
"CustomerValidationException",
"CustomerAuthorizationException",
# Team exceptions
"TeamMemberNotFoundException",
"TeamMemberAlreadyExistsException",
@@ -227,7 +154,6 @@ __all__ = [
"TeamValidationException",
"InvalidInvitationDataException",
"InvalidInvitationTokenException",
# Inventory exceptions
"InventoryNotFoundException",
"InsufficientInventoryException",
@@ -236,7 +162,6 @@ __all__ = [
"NegativeInventoryException",
"InvalidQuantityException",
"LocationNotFoundException",
# Vendor exceptions
"VendorNotFoundException",
"VendorAlreadyExistsException",
@@ -246,7 +171,6 @@ __all__ = [
"InvalidVendorDataException",
"MaxVendorsReachedException",
"VendorValidationException",
# Vendor Domain
"VendorDomainNotFoundException",
"VendorDomainAlreadyExistsException",
@@ -259,7 +183,6 @@ __all__ = [
"DNSVerificationException",
"MaxDomainsReachedException",
"UnauthorizedDomainAccessException",
# Vendor Theme
"VendorThemeNotFoundException",
"InvalidThemeDataException",
@@ -269,7 +192,6 @@ __all__ = [
"InvalidColorFormatException",
"InvalidFontFamilyException",
"ThemeOperationException",
# Product exceptions
"ProductNotFoundException",
"ProductAlreadyExistsException",
@@ -279,14 +201,12 @@ __all__ = [
"ProductValidationException",
"CannotDeleteProductWithInventoryException",
"CannotDeleteProductWithOrdersException",
# Order exceptions
"OrderNotFoundException",
"OrderAlreadyExistsException",
"OrderValidationException",
"InvalidOrderStatusException",
"OrderCannotBeCancelledException",
# Cart exceptions
"CartItemNotFoundException",
"EmptyCartException",
@@ -294,7 +214,6 @@ __all__ = [
"InsufficientInventoryForCartException",
"InvalidCartQuantityException",
"ProductNotAvailableForCartException",
# MarketplaceProduct exceptions
"MarketplaceProductNotFoundException",
"MarketplaceProductAlreadyExistsException",
@@ -302,7 +221,6 @@ __all__ = [
"MarketplaceProductValidationException",
"InvalidGTINException",
"MarketplaceProductCSVImportException",
# Marketplace import exceptions
"MarketplaceImportException",
"ImportJobNotFoundException",
@@ -315,7 +233,6 @@ __all__ = [
"ImportRateLimitException",
"InvalidMarketplaceException",
"ImportJobAlreadyProcessingException",
# Admin exceptions
"UserNotFoundException",
"UserStatusChangeException",

View File

@@ -4,12 +4,9 @@ Admin operations specific exceptions.
"""
from typing import Any, Dict, Optional
from .base import (
ResourceNotFoundException,
BusinessLogicException,
AuthorizationException,
ValidationException
)
from .base import (AuthorizationException, BusinessLogicException,
ResourceNotFoundException, ValidationException)
class UserNotFoundException(ResourceNotFoundException):
@@ -198,7 +195,7 @@ class ConfirmationRequiredException(BusinessLogicException):
self,
operation: str,
message: Optional[str] = None,
confirmation_param: str = "confirm"
confirmation_param: str = "confirm",
):
if not message:
message = f"Operation '{operation}' requires confirmation parameter: {confirmation_param}=true"

View File

@@ -4,7 +4,9 @@ Authentication and authorization specific exceptions.
"""
from typing import Optional
from .base import AuthenticationException, AuthorizationException, ConflictException
from .base import (AuthenticationException, AuthorizationException,
ConflictException)
class InvalidCredentialsException(AuthenticationException):

View File

@@ -39,8 +39,6 @@ class WizamartException(Exception):
return result
class ValidationException(WizamartException):
"""Raised when request validation fails."""
@@ -62,8 +60,6 @@ class ValidationException(WizamartException):
)
class AuthenticationException(WizamartException):
"""Raised when authentication fails."""
@@ -97,6 +93,7 @@ class AuthorizationException(WizamartException):
details=details,
)
class ResourceNotFoundException(WizamartException):
"""Raised when a requested resource is not found."""
@@ -122,6 +119,7 @@ class ResourceNotFoundException(WizamartException):
},
)
class ConflictException(WizamartException):
"""Raised when a resource conflict occurs."""
@@ -138,6 +136,7 @@ class ConflictException(WizamartException):
details=details,
)
class BusinessLogicException(WizamartException):
"""Raised when business logic rules are violated."""
@@ -196,6 +195,7 @@ class RateLimitException(WizamartException):
details=rate_limit_details,
)
class ServiceUnavailableException(WizamartException):
"""Raised when service is unavailable."""
@@ -206,6 +206,7 @@ class ServiceUnavailableException(WizamartException):
status_code=503,
)
# Note: Domain-specific exceptions like VendorNotFoundException, UserNotFoundException, etc.
# are defined in their respective domain modules (vendor.py, admin.py, etc.)
# to keep domain-specific logic separate from base exceptions.

View File

@@ -4,11 +4,9 @@ Shopping cart specific exceptions.
"""
from typing import Optional
from .base import (
ResourceNotFoundException,
ValidationException,
BusinessLogicException
)
from .base import (BusinessLogicException, ResourceNotFoundException,
ValidationException)
class CartItemNotFoundException(ResourceNotFoundException):
@@ -19,22 +17,16 @@ class CartItemNotFoundException(ResourceNotFoundException):
resource_type="CartItem",
identifier=str(product_id),
message=f"Product {product_id} not found in cart",
error_code="CART_ITEM_NOT_FOUND"
error_code="CART_ITEM_NOT_FOUND",
)
self.details.update({
"product_id": product_id,
"session_id": session_id
})
self.details.update({"product_id": product_id, "session_id": session_id})
class EmptyCartException(ValidationException):
"""Raised when trying to perform operations on an empty cart."""
def __init__(self, session_id: str):
super().__init__(
message="Cart is empty",
details={"session_id": session_id}
)
super().__init__(message="Cart is empty", details={"session_id": session_id})
self.error_code = "CART_EMPTY"
@@ -82,7 +74,9 @@ class InsufficientInventoryForCartException(BusinessLogicException):
class InvalidCartQuantityException(ValidationException):
"""Raised when cart quantity is invalid."""
def __init__(self, quantity: int, min_quantity: int = 1, max_quantity: Optional[int] = None):
def __init__(
self, quantity: int, min_quantity: int = 1, max_quantity: Optional[int] = None
):
if quantity < min_quantity:
message = f"Quantity must be at least {min_quantity}"
elif max_quantity and quantity > max_quantity:

View File

@@ -4,13 +4,10 @@ Customer management specific exceptions.
"""
from typing import Any, Dict, Optional
from .base import (
ResourceNotFoundException,
ConflictException,
ValidationException,
AuthenticationException,
BusinessLogicException
)
from .base import (AuthenticationException, BusinessLogicException,
ConflictException, ResourceNotFoundException,
ValidationException)
class CustomerNotFoundException(ResourceNotFoundException):
@@ -21,7 +18,7 @@ class CustomerNotFoundException(ResourceNotFoundException):
resource_type="Customer",
identifier=customer_identifier,
message=f"Customer '{customer_identifier}' not found",
error_code="CUSTOMER_NOT_FOUND"
error_code="CUSTOMER_NOT_FOUND",
)
@@ -32,7 +29,7 @@ class CustomerAlreadyExistsException(ConflictException):
super().__init__(
message=f"Customer with email '{email}' already exists",
error_code="CUSTOMER_ALREADY_EXISTS",
details={"email": email}
details={"email": email},
)
@@ -43,10 +40,7 @@ class DuplicateCustomerEmailException(ConflictException):
super().__init__(
message=f"Email '{email}' is already registered for this vendor",
error_code="DUPLICATE_CUSTOMER_EMAIL",
details={
"email": email,
"vendor_code": vendor_code
}
details={"email": email, "vendor_code": vendor_code},
)
@@ -57,7 +51,7 @@ class CustomerNotActiveException(BusinessLogicException):
super().__init__(
message=f"Customer account '{email}' is not active",
error_code="CUSTOMER_NOT_ACTIVE",
details={"email": email}
details={"email": email},
)
@@ -67,7 +61,7 @@ class InvalidCustomerCredentialsException(AuthenticationException):
def __init__(self):
super().__init__(
message="Invalid email or password",
error_code="INVALID_CUSTOMER_CREDENTIALS"
error_code="INVALID_CUSTOMER_CREDENTIALS",
)
@@ -78,13 +72,9 @@ class CustomerValidationException(ValidationException):
self,
message: str = "Customer validation failed",
field: Optional[str] = None,
details: Optional[Dict[str, Any]] = None
details: Optional[Dict[str, Any]] = None,
):
super().__init__(
message=message,
field=field,
details=details
)
super().__init__(message=message, field=field, details=details)
self.error_code = "CUSTOMER_VALIDATION_FAILED"
@@ -95,8 +85,5 @@ class CustomerAuthorizationException(BusinessLogicException):
super().__init__(
message=f"Customer '{customer_email}' not authorized for: {operation}",
error_code="CUSTOMER_NOT_AUTHORIZED",
details={
"customer_email": customer_email,
"operation": operation
}
details={"customer_email": customer_email, "operation": operation},
)

View File

@@ -7,7 +7,7 @@ Handles fallback logic and context-specific customization.
"""
import logging
from pathlib import Path
from typing import Optional, Dict, Any
from typing import Any, Dict, Optional
from fastapi import Request
from fastapi.responses import HTMLResponse
@@ -114,7 +114,7 @@ class ErrorPageRenderer:
"error_code": error_code,
"context": context_type.value,
"template": template_path,
}
},
)
try:
@@ -129,8 +129,7 @@ class ErrorPageRenderer:
)
except Exception as e:
logger.error(
f"Failed to render error template {template_path}: {e}",
exc_info=True
f"Failed to render error template {template_path}: {e}", exc_info=True
)
# Return basic HTML as absolute fallback
return ErrorPageRenderer._render_basic_html_fallback(
@@ -228,7 +227,9 @@ class ErrorPageRenderer:
}
@staticmethod
def _get_context_data(request: Request, context_type: RequestContext) -> Dict[str, Any]:
def _get_context_data(
request: Request, context_type: RequestContext
) -> Dict[str, Any]:
"""Get context-specific data for error templates."""
data = {}
@@ -261,11 +262,19 @@ class ErrorPageRenderer:
# Calculate base_url for shop links
vendor_context = getattr(request.state, "vendor_context", None)
access_method = vendor_context.get('detection_method', 'unknown') if vendor_context else 'unknown'
access_method = (
vendor_context.get("detection_method", "unknown")
if vendor_context
else "unknown"
)
base_url = "/"
if access_method == "path" and vendor:
# Use the full_prefix from vendor_context to determine which pattern was used
full_prefix = vendor_context.get('full_prefix', '/vendor/') if vendor_context else '/vendor/'
full_prefix = (
vendor_context.get("full_prefix", "/vendor/")
if vendor_context
else "/vendor/"
)
base_url = f"{full_prefix}{vendor.subdomain}/"
data["base_url"] = base_url

View File

@@ -13,13 +13,14 @@ This module provides classes and functions for:
import logging
from typing import Union
from fastapi import Request, HTTPException
from fastapi import HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, RedirectResponse
from middleware.context import RequestContext, get_request_context
from .base import WizamartException
from .error_renderer import ErrorPageRenderer
from middleware.context import RequestContext, get_request_context
logger = logging.getLogger(__name__)
@@ -38,8 +39,8 @@ def setup_exception_handlers(app):
extra={
"path": request.url.path,
"accept": request.headers.get("accept", ""),
"method": request.method
}
"method": request.method,
},
)
# Redirect to appropriate login page based on context
@@ -56,15 +57,12 @@ def setup_exception_handlers(app):
"url": str(request.url),
"method": request.method,
"exception_type": type(exc).__name__,
}
},
)
# Check if this is an API request
if _is_api_request(request):
return JSONResponse(
status_code=exc.status_code,
content=exc.to_dict()
)
return JSONResponse(status_code=exc.status_code, content=exc.to_dict())
# Check if this is an HTML page request
if _is_html_page_request(request):
@@ -78,10 +76,7 @@ def setup_exception_handlers(app):
)
# Default to JSON for unknown request types
return JSONResponse(
status_code=exc.status_code,
content=exc.to_dict()
)
return JSONResponse(status_code=exc.status_code, content=exc.to_dict())
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
@@ -96,7 +91,7 @@ def setup_exception_handlers(app):
"url": str(request.url),
"method": request.method,
"exception_type": "HTTPException",
}
},
)
# Check if this is an API request
@@ -107,7 +102,7 @@ def setup_exception_handlers(app):
"error_code": f"HTTP_{exc.status_code}",
"message": exc.detail,
"status_code": exc.status_code,
}
},
)
# Check if this is an HTML page request
@@ -128,11 +123,13 @@ def setup_exception_handlers(app):
"error_code": f"HTTP_{exc.status_code}",
"message": exc.detail,
"status_code": exc.status_code,
}
},
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
async def validation_exception_handler(
request: Request, exc: RequestValidationError
):
"""Handle Pydantic validation errors with consistent format."""
# Sanitize errors to remove sensitive data from logs
@@ -140,8 +137,8 @@ def setup_exception_handlers(app):
for error in exc.errors():
sanitized_error = error.copy()
# Remove 'input' field which may contain passwords
if 'input' in sanitized_error:
sanitized_error['input'] = '<redacted>'
if "input" in sanitized_error:
sanitized_error["input"] = "<redacted>"
sanitized_errors.append(sanitized_error)
logger.error(
@@ -151,7 +148,7 @@ def setup_exception_handlers(app):
"url": str(request.url),
"method": request.method,
"exception_type": "RequestValidationError",
}
},
)
# Clean up validation errors to ensure JSON serializability
@@ -159,15 +156,17 @@ def setup_exception_handlers(app):
for error in exc.errors():
clean_error = {}
for key, value in error.items():
if key == 'input' and isinstance(value, bytes):
if key == "input" and isinstance(value, bytes):
# Convert bytes to string representation for JSON serialization
clean_error[key] = f"<bytes: {len(value)} bytes>"
elif key == 'ctx' and isinstance(value, dict):
elif key == "ctx" and isinstance(value, dict):
# Handle the 'ctx' field that contains ValueError objects
clean_ctx = {}
for ctx_key, ctx_value in value.items():
if isinstance(ctx_value, Exception):
clean_ctx[ctx_key] = str(ctx_value) # Convert exception to string
clean_ctx[ctx_key] = str(
ctx_value
) # Convert exception to string
else:
clean_ctx[ctx_key] = ctx_value
clean_error[key] = clean_ctx
@@ -186,10 +185,8 @@ def setup_exception_handlers(app):
"error_code": "VALIDATION_ERROR",
"message": "Request validation failed",
"status_code": 422,
"details": {
"validation_errors": clean_errors
}
}
"details": {"validation_errors": clean_errors},
},
)
# Check if this is an HTML page request
@@ -210,10 +207,8 @@ def setup_exception_handlers(app):
"error_code": "VALIDATION_ERROR",
"message": "Request validation failed",
"status_code": 422,
"details": {
"validation_errors": clean_errors
}
}
"details": {"validation_errors": clean_errors},
},
)
@app.exception_handler(Exception)
@@ -227,7 +222,7 @@ def setup_exception_handlers(app):
"url": str(request.url),
"method": request.method,
"exception_type": type(exc).__name__,
}
},
)
# Check if this is an API request
@@ -238,7 +233,7 @@ def setup_exception_handlers(app):
"error_code": "INTERNAL_SERVER_ERROR",
"message": "Internal server error",
"status_code": 500,
}
},
)
# Check if this is an HTML page request
@@ -259,7 +254,7 @@ def setup_exception_handlers(app):
"error_code": "INTERNAL_SERVER_ERROR",
"message": "Internal server error",
"status_code": 500,
}
},
)
@app.exception_handler(404)
@@ -275,11 +270,8 @@ def setup_exception_handlers(app):
"error_code": "ENDPOINT_NOT_FOUND",
"message": f"Endpoint not found: {request.url.path}",
"status_code": 404,
"details": {
"path": request.url.path,
"method": request.method
}
}
"details": {"path": request.url.path, "method": request.method},
},
)
# Check if this is an HTML page request
@@ -300,11 +292,8 @@ def setup_exception_handlers(app):
"error_code": "ENDPOINT_NOT_FOUND",
"message": f"Endpoint not found: {request.url.path}",
"status_code": 404,
"details": {
"path": request.url.path,
"method": request.method
}
}
"details": {"path": request.url.path, "method": request.method},
},
)
@@ -332,8 +321,8 @@ def _is_html_page_request(request: Request) -> bool:
extra={
"path": request.url.path,
"method": request.method,
"accept": request.headers.get("accept", "")
}
"accept": request.headers.get("accept", ""),
},
)
# Don't redirect API calls
@@ -354,7 +343,9 @@ def _is_html_page_request(request: Request) -> bool:
# MUST explicitly accept HTML (strict check)
accept_header = request.headers.get("accept", "")
if "text/html" not in accept_header:
logger.debug(f"Not HTML page: Accept header doesn't include text/html: {accept_header}")
logger.debug(
f"Not HTML page: Accept header doesn't include text/html: {accept_header}"
)
return False
logger.debug("IS HTML page request")
@@ -379,13 +370,21 @@ def _redirect_to_login(request: Request) -> RedirectResponse:
elif context_type == RequestContext.SHOP:
# For shop context, redirect to shop login (customer login)
# Calculate base_url for proper routing (supports domain, subdomain, and path-based access)
vendor = getattr(request.state, 'vendor', None)
vendor_context = getattr(request.state, 'vendor_context', None)
access_method = vendor_context.get('detection_method', 'unknown') if vendor_context else 'unknown'
vendor = getattr(request.state, "vendor", None)
vendor_context = getattr(request.state, "vendor_context", None)
access_method = (
vendor_context.get("detection_method", "unknown")
if vendor_context
else "unknown"
)
base_url = "/"
if access_method == "path" and vendor:
full_prefix = vendor_context.get('full_prefix', '/vendor/') if vendor_context else '/vendor/'
full_prefix = (
vendor_context.get("full_prefix", "/vendor/")
if vendor_context
else "/vendor/"
)
base_url = f"{full_prefix}{vendor.subdomain}/"
login_url = f"{base_url}shop/account/login"
@@ -401,22 +400,28 @@ def _redirect_to_login(request: Request) -> RedirectResponse:
def raise_not_found(resource_type: str, identifier: str) -> None:
"""Convenience function to raise ResourceNotFoundException."""
from .base import ResourceNotFoundException
raise ResourceNotFoundException(resource_type, identifier)
def raise_validation_error(message: str, field: str = None, details: dict = None) -> None:
def raise_validation_error(
message: str, field: str = None, details: dict = None
) -> None:
"""Convenience function to raise ValidationException."""
from .base import ValidationException
raise ValidationException(message, field, details)
def raise_auth_error(message: str = "Authentication failed") -> None:
"""Convenience function to raise AuthenticationException."""
from .base import AuthenticationException
raise AuthenticationException(message)
def raise_permission_error(message: str = "Access denied") -> None:
"""Convenience function to raise AuthorizationException."""
from .base import AuthorizationException
raise AuthorizationException(message)

View File

@@ -4,7 +4,9 @@ Inventory management specific exceptions.
"""
from typing import Any, Dict, Optional
from .base import ResourceNotFoundException, ValidationException, BusinessLogicException
from .base import (BusinessLogicException, ResourceNotFoundException,
ValidationException)
class InventoryNotFoundException(ResourceNotFoundException):
@@ -14,7 +16,9 @@ class InventoryNotFoundException(ResourceNotFoundException):
if identifier_type.lower() == "gtin":
message = f"No inventory found for GTIN '{identifier}'"
else:
message = f"Inventory record with {identifier_type} '{identifier}' not found"
message = (
f"Inventory record with {identifier_type} '{identifier}' not found"
)
super().__init__(
resource_type="Inventory",

View File

@@ -4,13 +4,10 @@ Marketplace import specific exceptions.
"""
from typing import Any, Dict, Optional
from .base import (
ResourceNotFoundException,
ValidationException,
BusinessLogicException,
AuthorizationException,
ExternalServiceException
)
from .base import (AuthorizationException, BusinessLogicException,
ExternalServiceException, ResourceNotFoundException,
ValidationException)
class MarketplaceImportException(BusinessLogicException):
@@ -118,7 +115,9 @@ class ImportJobCannotBeDeletedException(BusinessLogicException):
class MarketplaceConnectionException(ExternalServiceException):
"""Raised when marketplace connection fails."""
def __init__(self, marketplace: str, message: str = "Failed to connect to marketplace"):
def __init__(
self, marketplace: str, message: str = "Failed to connect to marketplace"
):
super().__init__(
service=marketplace,
message=f"{message}: {marketplace}",

View File

@@ -4,7 +4,9 @@ MarketplaceProduct management specific exceptions.
"""
from typing import Any, Dict, Optional
from .base import ResourceNotFoundException, ConflictException, ValidationException, BusinessLogicException
from .base import (BusinessLogicException, ConflictException,
ResourceNotFoundException, ValidationException)
class MarketplaceProductNotFoundException(ResourceNotFoundException):

View File

@@ -4,11 +4,9 @@ Order management specific exceptions.
"""
from typing import Optional
from .base import (
ResourceNotFoundException,
ValidationException,
BusinessLogicException
)
from .base import (BusinessLogicException, ResourceNotFoundException,
ValidationException)
class OrderNotFoundException(ResourceNotFoundException):
@@ -19,7 +17,7 @@ class OrderNotFoundException(ResourceNotFoundException):
resource_type="Order",
identifier=order_identifier,
message=f"Order '{order_identifier}' not found",
error_code="ORDER_NOT_FOUND"
error_code="ORDER_NOT_FOUND",
)
@@ -30,7 +28,7 @@ class OrderAlreadyExistsException(ValidationException):
super().__init__(
message=f"Order with number '{order_number}' already exists",
error_code="ORDER_ALREADY_EXISTS",
details={"order_number": order_number}
details={"order_number": order_number},
)
@@ -39,9 +37,7 @@ class OrderValidationException(ValidationException):
def __init__(self, message: str, details: Optional[dict] = None):
super().__init__(
message=message,
error_code="ORDER_VALIDATION_FAILED",
details=details
message=message, error_code="ORDER_VALIDATION_FAILED", details=details
)
@@ -52,10 +48,7 @@ class InvalidOrderStatusException(BusinessLogicException):
super().__init__(
message=f"Cannot change order status from '{current_status}' to '{new_status}'",
error_code="INVALID_ORDER_STATUS_CHANGE",
details={
"current_status": current_status,
"new_status": new_status
}
details={"current_status": current_status, "new_status": new_status},
)
@@ -66,8 +59,5 @@ class OrderCannotBeCancelledException(BusinessLogicException):
super().__init__(
message=f"Order '{order_number}' cannot be cancelled: {reason}",
error_code="ORDER_CANNOT_BE_CANCELLED",
details={
"order_number": order_number,
"reason": reason
}
details={"order_number": order_number, "reason": reason},
)

View File

@@ -4,12 +4,9 @@ Product (vendor catalog) specific exceptions.
"""
from typing import Optional
from .base import (
ResourceNotFoundException,
ConflictException,
ValidationException,
BusinessLogicException
)
from .base import (BusinessLogicException, ConflictException,
ResourceNotFoundException, ValidationException)
class ProductNotFoundException(ResourceNotFoundException):

View File

@@ -4,13 +4,10 @@ Team management specific exceptions.
"""
from typing import Any, Dict, Optional
from .base import (
ResourceNotFoundException,
ConflictException,
ValidationException,
AuthorizationException,
BusinessLogicException
)
from .base import (AuthorizationException, BusinessLogicException,
ConflictException, ResourceNotFoundException,
ValidationException)
class TeamMemberNotFoundException(ResourceNotFoundException):
@@ -20,7 +17,9 @@ class TeamMemberNotFoundException(ResourceNotFoundException):
details = {"user_id": user_id}
if vendor_id:
details["vendor_id"] = vendor_id
message = f"Team member with user ID '{user_id}' not found in vendor {vendor_id}"
message = (
f"Team member with user ID '{user_id}' not found in vendor {vendor_id}"
)
else:
message = f"Team member with user ID '{user_id}' not found"
@@ -84,7 +83,12 @@ class TeamInvitationAlreadyAcceptedException(ConflictException):
class UnauthorizedTeamActionException(AuthorizationException):
"""Raised when user tries to perform team action without permission."""
def __init__(self, action: str, user_id: Optional[int] = None, required_permission: Optional[str] = None):
def __init__(
self,
action: str,
user_id: Optional[int] = None,
required_permission: Optional[str] = None,
):
details = {"action": action}
if user_id:
details["user_id"] = user_id
@@ -240,6 +244,7 @@ class InvalidInvitationDataException(ValidationException):
# NEW: Add InvalidInvitationTokenException
# ============================================================================
class InvalidInvitationTokenException(ValidationException):
"""Raised when invitation token is invalid, expired, or already used.
@@ -250,7 +255,7 @@ class InvalidInvitationTokenException(ValidationException):
def __init__(
self,
message: str = "Invalid or expired invitation token",
invitation_token: Optional[str] = None
invitation_token: Optional[str] = None,
):
details = {}
if invitation_token:

View File

@@ -4,13 +4,10 @@ Vendor management specific exceptions.
"""
from typing import Any, Dict, Optional
from .base import (
ResourceNotFoundException,
ConflictException,
ValidationException,
AuthorizationException,
BusinessLogicException
)
from .base import (AuthorizationException, BusinessLogicException,
ConflictException, ResourceNotFoundException,
ValidationException)
class VendorNotFoundException(ResourceNotFoundException):

View File

@@ -4,13 +4,10 @@ Vendor domain management specific exceptions.
"""
from typing import Any, Dict, Optional
from .base import (
ResourceNotFoundException,
ConflictException,
ValidationException,
BusinessLogicException,
ExternalServiceException
)
from .base import (BusinessLogicException, ConflictException,
ExternalServiceException, ResourceNotFoundException,
ValidationException)
class VendorDomainNotFoundException(ResourceNotFoundException):
@@ -64,10 +61,7 @@ class ReservedDomainException(ValidationException):
super().__init__(
message=f"Domain cannot use reserved subdomain: {reserved_part}",
field="domain",
details={
"domain": domain,
"reserved_part": reserved_part
},
details={"domain": domain, "reserved_part": reserved_part},
)
self.error_code = "RESERVED_DOMAIN"
@@ -79,10 +73,7 @@ class DomainNotVerifiedException(BusinessLogicException):
super().__init__(
message=f"Domain '{domain}' must be verified before activation",
error_code="DOMAIN_NOT_VERIFIED",
details={
"domain_id": domain_id,
"domain": domain
},
details={"domain_id": domain_id, "domain": domain},
)
@@ -93,10 +84,7 @@ class DomainVerificationFailedException(BusinessLogicException):
super().__init__(
message=f"Domain verification failed for '{domain}': {reason}",
error_code="DOMAIN_VERIFICATION_FAILED",
details={
"domain": domain,
"reason": reason
},
details={"domain": domain, "reason": reason},
)
@@ -107,10 +95,7 @@ class DomainAlreadyVerifiedException(BusinessLogicException):
super().__init__(
message=f"Domain '{domain}' is already verified",
error_code="DOMAIN_ALREADY_VERIFIED",
details={
"domain_id": domain_id,
"domain": domain
},
details={"domain_id": domain_id, "domain": domain},
)
@@ -133,10 +118,7 @@ class DNSVerificationException(ExternalServiceException):
service_name="DNS",
message=f"DNS verification failed for '{domain}': {reason}",
error_code="DNS_VERIFICATION_ERROR",
details={
"domain": domain,
"reason": reason
},
details={"domain": domain, "reason": reason},
)
@@ -147,10 +129,7 @@ class MaxDomainsReachedException(BusinessLogicException):
super().__init__(
message=f"Maximum number of domains reached ({max_domains})",
error_code="MAX_DOMAINS_REACHED",
details={
"vendor_id": vendor_id,
"max_domains": max_domains
},
details={"vendor_id": vendor_id, "max_domains": max_domains},
)
@@ -161,8 +140,5 @@ class UnauthorizedDomainAccessException(BusinessLogicException):
super().__init__(
message=f"Unauthorized access to domain {domain_id}",
error_code="UNAUTHORIZED_DOMAIN_ACCESS",
details={
"domain_id": domain_id,
"vendor_id": vendor_id
},
details={"domain_id": domain_id, "vendor_id": vendor_id},
)

View File

@@ -4,12 +4,9 @@ Vendor theme management specific exceptions.
"""
from typing import Any, Dict, Optional
from .base import (
ResourceNotFoundException,
ConflictException,
ValidationException,
BusinessLogicException
)
from .base import (BusinessLogicException, ConflictException,
ResourceNotFoundException, ValidationException)
class VendorThemeNotFoundException(ResourceNotFoundException):
@@ -86,10 +83,7 @@ class ThemePresetAlreadyAppliedException(BusinessLogicException):
super().__init__(
message=f"Preset '{preset_name}' is already applied to vendor '{vendor_code}'",
error_code="THEME_PRESET_ALREADY_APPLIED",
details={
"preset_name": preset_name,
"vendor_code": vendor_code
},
details={"preset_name": preset_name, "vendor_code": vendor_code},
)
@@ -120,18 +114,13 @@ class InvalidFontFamilyException(ValidationException):
class ThemeOperationException(BusinessLogicException):
"""Raised when theme operation fails."""
def __init__(
self,
operation: str,
vendor_code: str,
reason: str
):
def __init__(self, operation: str, vendor_code: str, reason: str):
super().__init__(
message=f"Theme operation '{operation}' failed for vendor '{vendor_code}': {reason}",
error_code="THEME_OPERATION_FAILED",
details={
"operation": operation,
"vendor_code": vendor_code,
"reason": reason
"reason": reason,
},
)

View File

@@ -3,18 +3,23 @@ Architecture Scan Models
Database models for tracking code quality scans and violations
"""
from sqlalchemy import Column, Integer, String, Float, DateTime, Text, Boolean, ForeignKey, JSON
from sqlalchemy import (JSON, Boolean, Column, DateTime, Float, ForeignKey,
Integer, String, Text)
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
from app.core.database import Base
class ArchitectureScan(Base):
"""Represents a single run of the architecture validator"""
__tablename__ = "architecture_scans"
id = Column(Integer, primary_key=True, index=True)
timestamp = Column(DateTime(timezone=True), server_default=func.now(), nullable=False, index=True)
timestamp = Column(
DateTime(timezone=True), server_default=func.now(), nullable=False, index=True
)
total_files = Column(Integer, default=0)
total_violations = Column(Integer, default=0)
errors = Column(Integer, default=0)
@@ -24,7 +29,9 @@ class ArchitectureScan(Base):
git_commit_hash = Column(String(40))
# Relationship to violations
violations = relationship("ArchitectureViolation", back_populates="scan", cascade="all, delete-orphan")
violations = relationship(
"ArchitectureViolation", back_populates="scan", cascade="all, delete-orphan"
)
def __repr__(self):
return f"<ArchitectureScan(id={self.id}, violations={self.total_violations}, errors={self.errors})>"
@@ -32,31 +39,48 @@ class ArchitectureScan(Base):
class ArchitectureViolation(Base):
"""Represents a single architectural violation found during a scan"""
__tablename__ = "architecture_violations"
id = Column(Integer, primary_key=True, index=True)
scan_id = Column(Integer, ForeignKey("architecture_scans.id"), nullable=False, index=True)
scan_id = Column(
Integer, ForeignKey("architecture_scans.id"), nullable=False, index=True
)
rule_id = Column(String(20), nullable=False, index=True) # e.g., 'API-001'
rule_name = Column(String(200), nullable=False)
severity = Column(String(10), nullable=False, index=True) # 'error', 'warning', 'info'
severity = Column(
String(10), nullable=False, index=True
) # 'error', 'warning', 'info'
file_path = Column(String(500), nullable=False, index=True)
line_number = Column(Integer, nullable=False)
message = Column(Text, nullable=False)
context = Column(Text) # Code snippet
suggestion = Column(Text)
status = Column(String(20), default='open', index=True) # 'open', 'assigned', 'resolved', 'ignored', 'technical_debt'
status = Column(
String(20), default="open", index=True
) # 'open', 'assigned', 'resolved', 'ignored', 'technical_debt'
assigned_to = Column(Integer, ForeignKey("users.id"))
resolved_at = Column(DateTime(timezone=True))
resolved_by = Column(Integer, ForeignKey("users.id"))
resolution_note = Column(Text)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
created_at = Column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
# Relationships
scan = relationship("ArchitectureScan", back_populates="violations")
assigned_user = relationship("User", foreign_keys=[assigned_to], backref="assigned_violations")
resolver = relationship("User", foreign_keys=[resolved_by], backref="resolved_violations")
assignments = relationship("ViolationAssignment", back_populates="violation", cascade="all, delete-orphan")
comments = relationship("ViolationComment", back_populates="violation", cascade="all, delete-orphan")
assigned_user = relationship(
"User", foreign_keys=[assigned_to], backref="assigned_violations"
)
resolver = relationship(
"User", foreign_keys=[resolved_by], backref="resolved_violations"
)
assignments = relationship(
"ViolationAssignment", back_populates="violation", cascade="all, delete-orphan"
)
comments = relationship(
"ViolationComment", back_populates="violation", cascade="all, delete-orphan"
)
def __repr__(self):
return f"<ArchitectureViolation(id={self.id}, rule={self.rule_id}, file={self.file_path}:{self.line_number})>"
@@ -64,18 +88,30 @@ class ArchitectureViolation(Base):
class ArchitectureRule(Base):
"""Architecture rules configuration (from YAML with database overrides)"""
__tablename__ = "architecture_rules"
id = Column(Integer, primary_key=True, index=True)
rule_id = Column(String(20), unique=True, nullable=False, index=True) # e.g., 'API-001'
category = Column(String(50), nullable=False) # 'api_endpoint', 'service_layer', etc.
rule_id = Column(
String(20), unique=True, nullable=False, index=True
) # e.g., 'API-001'
category = Column(
String(50), nullable=False
) # 'api_endpoint', 'service_layer', etc.
name = Column(String(200), nullable=False)
description = Column(Text)
severity = Column(String(10), nullable=False) # Can override default from YAML
enabled = Column(Boolean, default=True, nullable=False)
custom_config = Column(JSON) # For rule-specific settings
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
created_at = Column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
updated_at = Column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
def __repr__(self):
return f"<ArchitectureRule(id={self.rule_id}, name={self.name}, enabled={self.enabled})>"
@@ -83,20 +119,29 @@ class ArchitectureRule(Base):
class ViolationAssignment(Base):
"""Tracks assignment of violations to developers"""
__tablename__ = "violation_assignments"
id = Column(Integer, primary_key=True, index=True)
violation_id = Column(Integer, ForeignKey("architecture_violations.id"), nullable=False, index=True)
violation_id = Column(
Integer, ForeignKey("architecture_violations.id"), nullable=False, index=True
)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
assigned_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
assigned_at = Column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
assigned_by = Column(Integer, ForeignKey("users.id"))
due_date = Column(DateTime(timezone=True))
priority = Column(String(10), default='medium') # 'low', 'medium', 'high', 'critical'
priority = Column(
String(10), default="medium"
) # 'low', 'medium', 'high', 'critical'
# Relationships
violation = relationship("ArchitectureViolation", back_populates="assignments")
user = relationship("User", foreign_keys=[user_id], backref="violation_assignments")
assigner = relationship("User", foreign_keys=[assigned_by], backref="assigned_by_me")
assigner = relationship(
"User", foreign_keys=[assigned_by], backref="assigned_by_me"
)
def __repr__(self):
return f"<ViolationAssignment(id={self.id}, violation_id={self.violation_id}, user_id={self.user_id})>"
@@ -104,13 +149,18 @@ class ViolationAssignment(Base):
class ViolationComment(Base):
"""Comments on violations for collaboration"""
__tablename__ = "violation_comments"
id = Column(Integer, primary_key=True, index=True)
violation_id = Column(Integer, ForeignKey("architecture_violations.id"), nullable=False, index=True)
violation_id = Column(
Integer, ForeignKey("architecture_violations.id"), nullable=False, index=True
)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
comment = Column(Text, nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
created_at = Column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
# Relationships
violation = relationship("ArchitectureViolation", back_populates="comments")

View File

@@ -30,17 +30,15 @@ Routes:
- GET /code-quality/violations/{violation_id} → Violation details (auth required)
"""
from fastapi import APIRouter, Request, Depends, Path
from typing import Optional
from fastapi import APIRouter, Depends, Path, Request
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
from sqlalchemy.orm import Session
from typing import Optional
from app.api.deps import (
get_current_admin_from_cookie_or_header,
get_current_admin_optional,
get_db
)
from app.api.deps import (get_current_admin_from_cookie_or_header,
get_current_admin_optional, get_db)
from models.database.user import User
router = APIRouter()
@@ -51,9 +49,10 @@ templates = Jinja2Templates(directory="app/templates")
# PUBLIC ROUTES (No Authentication Required)
# ============================================================================
@router.get("/", response_class=RedirectResponse, include_in_schema=False)
async def admin_root(
current_user: Optional[User] = Depends(get_current_admin_optional)
current_user: Optional[User] = Depends(get_current_admin_optional),
):
"""
Redirect /admin/ based on authentication status.
@@ -70,8 +69,7 @@ async def admin_root(
@router.get("/login", response_class=HTMLResponse, include_in_schema=False)
async def admin_login_page(
request: Request,
current_user: Optional[User] = Depends(get_current_admin_optional)
request: Request, current_user: Optional[User] = Depends(get_current_admin_optional)
):
"""
Render admin login page.
@@ -83,21 +81,19 @@ async def admin_login_page(
# User is already logged in as admin, redirect to dashboard
return RedirectResponse(url="/admin/dashboard", status_code=302)
return templates.TemplateResponse(
"admin/login.html",
{"request": request}
)
return templates.TemplateResponse("admin/login.html", {"request": request})
# ============================================================================
# AUTHENTICATED ROUTES (Admin Only)
# ============================================================================
@router.get("/dashboard", response_class=HTMLResponse, include_in_schema=False)
async def admin_dashboard_page(
request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render admin dashboard page.
@@ -108,7 +104,7 @@ async def admin_dashboard_page(
{
"request": request,
"user": current_user,
}
},
)
@@ -116,11 +112,12 @@ async def admin_dashboard_page(
# VENDOR MANAGEMENT ROUTES
# ============================================================================
@router.get("/vendors", response_class=HTMLResponse, include_in_schema=False)
async def admin_vendors_list_page(
request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render vendors management page.
@@ -131,7 +128,7 @@ async def admin_vendors_list_page(
{
"request": request,
"user": current_user,
}
},
)
@@ -139,7 +136,7 @@ async def admin_vendors_list_page(
async def admin_vendor_create_page(
request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render vendor creation form.
@@ -149,16 +146,18 @@ async def admin_vendor_create_page(
{
"request": request,
"user": current_user,
}
},
)
@router.get("/vendors/{vendor_code}", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/vendors/{vendor_code}", response_class=HTMLResponse, include_in_schema=False
)
async def admin_vendor_detail_page(
request: Request,
vendor_code: str = Path(..., description="Vendor code"),
current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render vendor detail page.
@@ -170,16 +169,18 @@ async def admin_vendor_detail_page(
"request": request,
"user": current_user,
"vendor_code": vendor_code,
}
},
)
@router.get("/vendors/{vendor_code}/edit", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/vendors/{vendor_code}/edit", response_class=HTMLResponse, include_in_schema=False
)
async def admin_vendor_edit_page(
request: Request,
vendor_code: str = Path(..., description="Vendor code"),
current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render vendor edit form.
@@ -190,7 +191,7 @@ async def admin_vendor_edit_page(
"request": request,
"user": current_user,
"vendor_code": vendor_code,
}
},
)
@@ -198,12 +199,17 @@ async def admin_vendor_edit_page(
# VENDOR DOMAINS ROUTES
# ============================================================================
@router.get("/vendors/{vendor_code}/domains", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/vendors/{vendor_code}/domains",
response_class=HTMLResponse,
include_in_schema=False,
)
async def admin_vendor_domains_page(
request: Request,
vendor_code: str = Path(..., description="Vendor code"),
current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render vendor domains management page.
@@ -215,7 +221,7 @@ async def admin_vendor_domains_page(
"request": request,
"user": current_user,
"vendor_code": vendor_code,
}
},
)
@@ -223,12 +229,15 @@ async def admin_vendor_domains_page(
# VENDOR THEMES ROUTES
# ============================================================================
@router.get("/vendors/{vendor_code}/theme", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/vendors/{vendor_code}/theme", response_class=HTMLResponse, include_in_schema=False
)
async def admin_vendor_theme_page(
request: Request,
vendor_code: str = Path(..., description="Vendor code"),
current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render vendor theme customization page.
@@ -240,7 +249,7 @@ async def admin_vendor_theme_page(
"request": request,
"user": current_user,
"vendor_code": vendor_code,
}
},
)
@@ -248,11 +257,12 @@ async def admin_vendor_theme_page(
# USER MANAGEMENT ROUTES
# ============================================================================
@router.get("/users", response_class=HTMLResponse, include_in_schema=False)
async def admin_users_page(
request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render users management page.
@@ -263,7 +273,7 @@ async def admin_users_page(
{
"request": request,
"user": current_user,
}
},
)
@@ -271,11 +281,12 @@ async def admin_users_page(
# IMPORT MANAGEMENT ROUTES
# ============================================================================
@router.get("/imports", response_class=HTMLResponse, include_in_schema=False)
async def admin_imports_page(
request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render imports management page.
@@ -286,7 +297,7 @@ async def admin_imports_page(
{
"request": request,
"user": current_user,
}
},
)
@@ -294,11 +305,12 @@ async def admin_imports_page(
# SETTINGS ROUTES
# ============================================================================
@router.get("/settings", response_class=HTMLResponse, include_in_schema=False)
async def admin_settings_page(
request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render admin settings page.
@@ -309,7 +321,7 @@ async def admin_settings_page(
{
"request": request,
"user": current_user,
}
},
)
@@ -317,11 +329,12 @@ async def admin_settings_page(
# CONTENT MANAGEMENT SYSTEM (CMS) ROUTES
# ============================================================================
@router.get("/platform-homepage", response_class=HTMLResponse, include_in_schema=False)
async def admin_platform_homepage_manager(
request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render platform homepage manager.
@@ -332,7 +345,7 @@ async def admin_platform_homepage_manager(
{
"request": request,
"user": current_user,
}
},
)
@@ -340,7 +353,7 @@ async def admin_platform_homepage_manager(
async def admin_content_pages_list(
request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render content pages list.
@@ -351,15 +364,17 @@ async def admin_content_pages_list(
{
"request": request,
"user": current_user,
}
},
)
@router.get("/content-pages/create", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/content-pages/create", response_class=HTMLResponse, include_in_schema=False
)
async def admin_content_page_create(
request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render create content page form.
@@ -371,16 +386,20 @@ async def admin_content_page_create(
"request": request,
"user": current_user,
"page_id": None, # Indicates this is a create operation
}
},
)
@router.get("/content-pages/{page_id}/edit", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/content-pages/{page_id}/edit",
response_class=HTMLResponse,
include_in_schema=False,
)
async def admin_content_page_edit(
request: Request,
page_id: int = Path(..., description="Content page ID"),
current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render edit content page form.
@@ -392,7 +411,7 @@ async def admin_content_page_edit(
"request": request,
"user": current_user,
"page_id": page_id,
}
},
)
@@ -400,11 +419,12 @@ async def admin_content_page_edit(
# DEVELOPER TOOLS - COMPONENTS & TESTING
# ============================================================================
@router.get("/components", response_class=HTMLResponse, include_in_schema=False)
async def admin_components_page(
request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render UI components library page.
@@ -415,7 +435,7 @@ async def admin_components_page(
{
"request": request,
"user": current_user,
}
},
)
@@ -423,7 +443,7 @@ async def admin_components_page(
async def admin_icons_page(
request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render icons browser page.
@@ -434,7 +454,7 @@ async def admin_icons_page(
{
"request": request,
"user": current_user,
}
},
)
@@ -442,7 +462,7 @@ async def admin_icons_page(
async def admin_testing_hub(
request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render testing hub page.
@@ -453,7 +473,7 @@ async def admin_testing_hub(
{
"request": request,
"user": current_user,
}
},
)
@@ -461,7 +481,7 @@ async def admin_testing_hub(
async def admin_test_auth_flow(
request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render authentication flow testing page.
@@ -472,15 +492,19 @@ async def admin_test_auth_flow(
{
"request": request,
"user": current_user,
}
},
)
@router.get("/test/vendors-users-migration", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/test/vendors-users-migration",
response_class=HTMLResponse,
include_in_schema=False,
)
async def admin_test_vendors_users_migration(
request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render vendors and users migration testing page.
@@ -491,7 +515,7 @@ async def admin_test_vendors_users_migration(
{
"request": request,
"user": current_user,
}
},
)
@@ -499,11 +523,12 @@ async def admin_test_vendors_users_migration(
# CODE QUALITY & ARCHITECTURE ROUTES
# ============================================================================
@router.get("/code-quality", response_class=HTMLResponse, include_in_schema=False)
async def admin_code_quality_dashboard(
request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render code quality dashboard.
@@ -514,15 +539,17 @@ async def admin_code_quality_dashboard(
{
"request": request,
"user": current_user,
}
},
)
@router.get("/code-quality/violations", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/code-quality/violations", response_class=HTMLResponse, include_in_schema=False
)
async def admin_code_quality_violations(
request: Request,
current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render violations list page.
@@ -533,16 +560,20 @@ async def admin_code_quality_violations(
{
"request": request,
"user": current_user,
}
},
)
@router.get("/code-quality/violations/{violation_id}", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/code-quality/violations/{violation_id}",
response_class=HTMLResponse,
include_in_schema=False,
)
async def admin_code_quality_violation_detail(
request: Request,
violation_id: int = Path(..., description="Violation ID"),
current_user: User = Depends(get_current_admin_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render violation detail page.
@@ -554,5 +585,5 @@ async def admin_code_quality_violation_detail(
"request": request,
"user": current_user,
"violation_id": violation_id,
}
},
)

View File

@@ -31,7 +31,8 @@ Routes (all mounted at /shop/* or /vendors/{code}/shop/* prefix):
"""
import logging
from fastapi import APIRouter, Request, Depends, Path
from fastapi import APIRouter, Depends, Path, Request
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
from sqlalchemy.orm import Session
@@ -50,6 +51,7 @@ logger = logging.getLogger(__name__)
# HELPER: Build Shop Template Context
# ============================================================================
def get_shop_context(request: Request, db: Session = None, **extra_context) -> dict:
"""
Build template context for shop pages.
@@ -76,13 +78,17 @@ def get_shop_context(request: Request, db: Session = None, **extra_context) -> d
get_shop_context(request, db=db, user=current_user, product_id=123)
"""
# Extract from middleware state
vendor = getattr(request.state, 'vendor', None)
theme = getattr(request.state, 'theme', None)
clean_path = getattr(request.state, 'clean_path', request.url.path)
vendor_context = getattr(request.state, 'vendor_context', None)
vendor = getattr(request.state, "vendor", None)
theme = getattr(request.state, "theme", None)
clean_path = getattr(request.state, "clean_path", request.url.path)
vendor_context = getattr(request.state, "vendor_context", None)
# Get detection method from vendor_context
access_method = vendor_context.get('detection_method', 'unknown') if vendor_context else 'unknown'
access_method = (
vendor_context.get("detection_method", "unknown")
if vendor_context
else "unknown"
)
if vendor is None:
logger.warning(
@@ -91,7 +97,7 @@ def get_shop_context(request: Request, db: Session = None, **extra_context) -> d
"path": request.url.path,
"host": request.headers.get("host", ""),
"has_vendor": False,
}
},
)
# Calculate base URL for links
@@ -100,7 +106,11 @@ def get_shop_context(request: Request, db: Session = None, **extra_context) -> d
base_url = "/"
if access_method == "path" and vendor:
# Use the full_prefix from vendor_context to determine which pattern was used
full_prefix = vendor_context.get('full_prefix', '/vendor/') if vendor_context else '/vendor/'
full_prefix = (
vendor_context.get("full_prefix", "/vendor/")
if vendor_context
else "/vendor/"
)
base_url = f"{full_prefix}{vendor.subdomain}/"
# Load footer navigation pages from CMS if db session provided
@@ -111,22 +121,16 @@ def get_shop_context(request: Request, db: Session = None, **extra_context) -> d
vendor_id = vendor.id
# Get pages configured to show in footer
footer_pages = content_page_service.list_pages_for_vendor(
db,
vendor_id=vendor_id,
footer_only=True,
include_unpublished=False
db, vendor_id=vendor_id, footer_only=True, include_unpublished=False
)
# Get pages configured to show in header
header_pages = content_page_service.list_pages_for_vendor(
db,
vendor_id=vendor_id,
header_only=True,
include_unpublished=False
db, vendor_id=vendor_id, header_only=True, include_unpublished=False
)
except Exception as e:
logger.error(
f"[SHOP_CONTEXT] Failed to load navigation pages",
extra={"error": str(e), "vendor_id": vendor.id if vendor else None}
extra={"error": str(e), "vendor_id": vendor.id if vendor else None},
)
context = {
@@ -156,7 +160,7 @@ def get_shop_context(request: Request, db: Session = None, **extra_context) -> d
"footer_pages_count": len(footer_pages),
"header_pages_count": len(header_pages),
"extra_keys": list(extra_context.keys()) if extra_context else [],
}
},
)
return context
@@ -166,6 +170,7 @@ def get_shop_context(request: Request, db: Session = None, **extra_context) -> d
# PUBLIC SHOP ROUTES (No Authentication Required)
# ============================================================================
@router.get("/", response_class=HTMLResponse, include_in_schema=False)
@router.get("/products", response_class=HTMLResponse, include_in_schema=False)
async def shop_products_page(request: Request, db: Session = Depends(get_db)):
@@ -177,21 +182,21 @@ async def shop_products_page(request: Request, db: Session = Depends(get_db)):
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'),
"context": getattr(request.state, 'context_type', 'NOT SET'),
}
"vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, "context_type", "NOT SET"),
},
)
return templates.TemplateResponse(
"shop/products.html",
get_shop_context(request, db=db)
"shop/products.html", get_shop_context(request, db=db)
)
@router.get("/products/{product_id}", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/products/{product_id}", response_class=HTMLResponse, include_in_schema=False
)
async def shop_product_detail_page(
request: Request,
product_id: int = Path(..., description="Product ID")
request: Request, product_id: int = Path(..., description="Product ID")
):
"""
Render product detail page.
@@ -201,21 +206,21 @@ async def shop_product_detail_page(
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'),
"context": getattr(request.state, 'context_type', 'NOT SET'),
}
"vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, "context_type", "NOT SET"),
},
)
return templates.TemplateResponse(
"shop/product.html",
get_shop_context(request, product_id=product_id)
"shop/product.html", get_shop_context(request, product_id=product_id)
)
@router.get("/categories/{category_slug}", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/categories/{category_slug}", response_class=HTMLResponse, include_in_schema=False
)
async def shop_category_page(
request: Request,
category_slug: str = Path(..., description="Category slug")
request: Request, category_slug: str = Path(..., description="Category slug")
):
"""
Render category products page.
@@ -225,14 +230,13 @@ async def shop_category_page(
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'),
"context": getattr(request.state, 'context_type', 'NOT SET'),
}
"vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, "context_type", "NOT SET"),
},
)
return templates.TemplateResponse(
"shop/category.html",
get_shop_context(request, category_slug=category_slug)
"shop/category.html", get_shop_context(request, category_slug=category_slug)
)
@@ -246,15 +250,12 @@ async def shop_cart_page(request: Request):
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'),
"context": getattr(request.state, 'context_type', 'NOT SET'),
}
"vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, "context_type", "NOT SET"),
},
)
return templates.TemplateResponse(
"shop/cart.html",
get_shop_context(request)
)
return templates.TemplateResponse("shop/cart.html", get_shop_context(request))
@router.get("/checkout", response_class=HTMLResponse, include_in_schema=False)
@@ -267,15 +268,12 @@ async def shop_checkout_page(request: Request):
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'),
"context": getattr(request.state, 'context_type', 'NOT SET'),
}
"vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, "context_type", "NOT SET"),
},
)
return templates.TemplateResponse(
"shop/checkout.html",
get_shop_context(request)
)
return templates.TemplateResponse("shop/checkout.html", get_shop_context(request))
@router.get("/search", response_class=HTMLResponse, include_in_schema=False)
@@ -288,21 +286,19 @@ async def shop_search_page(request: Request):
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'),
"context": getattr(request.state, 'context_type', 'NOT SET'),
}
"vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, "context_type", "NOT SET"),
},
)
return templates.TemplateResponse(
"shop/search.html",
get_shop_context(request)
)
return templates.TemplateResponse("shop/search.html", get_shop_context(request))
# ============================================================================
# CUSTOMER ACCOUNT - PUBLIC ROUTES (No Authentication)
# ============================================================================
@router.get("/account/register", response_class=HTMLResponse, include_in_schema=False)
async def shop_register_page(request: Request):
"""
@@ -313,14 +309,13 @@ async def shop_register_page(request: Request):
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'),
"context": getattr(request.state, 'context_type', 'NOT SET'),
}
"vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, "context_type", "NOT SET"),
},
)
return templates.TemplateResponse(
"shop/account/register.html",
get_shop_context(request)
"shop/account/register.html", get_shop_context(request)
)
@@ -334,18 +329,19 @@ async def shop_login_page(request: Request):
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'),
"context": getattr(request.state, 'context_type', 'NOT SET'),
}
"vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, "context_type", "NOT SET"),
},
)
return templates.TemplateResponse(
"shop/account/login.html",
get_shop_context(request)
"shop/account/login.html", get_shop_context(request)
)
@router.get("/account/forgot-password", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/account/forgot-password", response_class=HTMLResponse, include_in_schema=False
)
async def shop_forgot_password_page(request: Request):
"""
Render forgot password page.
@@ -355,14 +351,13 @@ async def shop_forgot_password_page(request: Request):
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'),
"context": getattr(request.state, 'context_type', 'NOT SET'),
}
"vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, "context_type", "NOT SET"),
},
)
return templates.TemplateResponse(
"shop/account/forgot-password.html",
get_shop_context(request)
"shop/account/forgot-password.html", get_shop_context(request)
)
@@ -370,6 +365,7 @@ async def shop_forgot_password_page(request: Request):
# CUSTOMER ACCOUNT - AUTHENTICATED ROUTES
# ============================================================================
@router.get("/account", response_class=RedirectResponse, include_in_schema=False)
@router.get("/account/", response_class=RedirectResponse, include_in_schema=False)
async def shop_account_root(request: Request):
@@ -380,19 +376,27 @@ async def shop_account_root(request: Request):
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'),
"context": getattr(request.state, 'context_type', 'NOT SET'),
}
"vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, "context_type", "NOT SET"),
},
)
# Get base_url from context for proper redirect
vendor = getattr(request.state, 'vendor', None)
vendor_context = getattr(request.state, 'vendor_context', None)
access_method = vendor_context.get('detection_method', 'unknown') if vendor_context else 'unknown'
vendor = getattr(request.state, "vendor", None)
vendor_context = getattr(request.state, "vendor_context", None)
access_method = (
vendor_context.get("detection_method", "unknown")
if vendor_context
else "unknown"
)
base_url = "/"
if access_method == "path" and vendor:
full_prefix = vendor_context.get('full_prefix', '/vendor/') if vendor_context else '/vendor/'
full_prefix = (
vendor_context.get("full_prefix", "/vendor/")
if vendor_context
else "/vendor/"
)
base_url = f"{full_prefix}{vendor.subdomain}/"
return RedirectResponse(url=f"{base_url}shop/account/dashboard", status_code=302)
@@ -402,7 +406,7 @@ async def shop_account_root(request: Request):
async def shop_account_dashboard_page(
request: Request,
current_customer: Customer = Depends(get_current_customer_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render customer account dashboard.
@@ -413,14 +417,13 @@ async def shop_account_dashboard_page(
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'),
"context": getattr(request.state, 'context_type', 'NOT SET'),
}
"vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, "context_type", "NOT SET"),
},
)
return templates.TemplateResponse(
"shop/account/dashboard.html",
get_shop_context(request, user=current_customer)
"shop/account/dashboard.html", get_shop_context(request, user=current_customer)
)
@@ -428,7 +431,7 @@ async def shop_account_dashboard_page(
async def shop_orders_page(
request: Request,
current_customer: Customer = Depends(get_current_customer_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render customer orders history page.
@@ -439,23 +442,24 @@ async def shop_orders_page(
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'),
"context": getattr(request.state, 'context_type', 'NOT SET'),
}
"vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, "context_type", "NOT SET"),
},
)
return templates.TemplateResponse(
"shop/account/orders.html",
get_shop_context(request, user=current_customer)
"shop/account/orders.html", get_shop_context(request, user=current_customer)
)
@router.get("/account/orders/{order_id}", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/account/orders/{order_id}", response_class=HTMLResponse, include_in_schema=False
)
async def shop_order_detail_page(
request: Request,
order_id: int = Path(..., description="Order ID"),
current_customer: Customer = Depends(get_current_customer_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render customer order detail page.
@@ -466,14 +470,14 @@ async def shop_order_detail_page(
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'),
"context": getattr(request.state, 'context_type', 'NOT SET'),
}
"vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, "context_type", "NOT SET"),
},
)
return templates.TemplateResponse(
"shop/account/order-detail.html",
get_shop_context(request, user=current_customer, order_id=order_id)
get_shop_context(request, user=current_customer, order_id=order_id),
)
@@ -481,7 +485,7 @@ async def shop_order_detail_page(
async def shop_profile_page(
request: Request,
current_customer: Customer = Depends(get_current_customer_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render customer profile page.
@@ -492,14 +496,13 @@ async def shop_profile_page(
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'),
"context": getattr(request.state, 'context_type', 'NOT SET'),
}
"vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, "context_type", "NOT SET"),
},
)
return templates.TemplateResponse(
"shop/account/profile.html",
get_shop_context(request, user=current_customer)
"shop/account/profile.html", get_shop_context(request, user=current_customer)
)
@@ -507,7 +510,7 @@ async def shop_profile_page(
async def shop_addresses_page(
request: Request,
current_customer: Customer = Depends(get_current_customer_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render customer addresses management page.
@@ -518,14 +521,13 @@ async def shop_addresses_page(
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'),
"context": getattr(request.state, 'context_type', 'NOT SET'),
}
"vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, "context_type", "NOT SET"),
},
)
return templates.TemplateResponse(
"shop/account/addresses.html",
get_shop_context(request, user=current_customer)
"shop/account/addresses.html", get_shop_context(request, user=current_customer)
)
@@ -533,7 +535,7 @@ async def shop_addresses_page(
async def shop_wishlist_page(
request: Request,
current_customer: Customer = Depends(get_current_customer_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render customer wishlist page.
@@ -544,14 +546,13 @@ async def shop_wishlist_page(
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'),
"context": getattr(request.state, 'context_type', 'NOT SET'),
}
"vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, "context_type", "NOT SET"),
},
)
return templates.TemplateResponse(
"shop/account/wishlist.html",
get_shop_context(request, user=current_customer)
"shop/account/wishlist.html", get_shop_context(request, user=current_customer)
)
@@ -559,7 +560,7 @@ async def shop_wishlist_page(
async def shop_settings_page(
request: Request,
current_customer: Customer = Depends(get_current_customer_from_cookie_or_header),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Render customer account settings page.
@@ -570,14 +571,13 @@ async def shop_settings_page(
f"[SHOP_HANDLER] shop_products_page REACHED",
extra={
"path": request.url.path,
"vendor": getattr(request.state, 'vendor', 'NOT SET'),
"context": getattr(request.state, 'context_type', 'NOT SET'),
}
"vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, "context_type", "NOT SET"),
},
)
return templates.TemplateResponse(
"shop/account/settings.html",
get_shop_context(request, user=current_customer)
"shop/account/settings.html", get_shop_context(request, user=current_customer)
)
@@ -585,11 +585,12 @@ async def shop_settings_page(
# DYNAMIC CONTENT PAGES (CMS)
# ============================================================================
@router.get("/{slug}", response_class=HTMLResponse, include_in_schema=False)
async def generic_content_page(
request: Request,
slug: str = Path(..., description="Content page slug"),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Generic content page handler (CMS).
@@ -612,20 +613,17 @@ async def generic_content_page(
extra={
"path": request.url.path,
"slug": slug,
"vendor": getattr(request.state, 'vendor', 'NOT SET'),
"context": getattr(request.state, 'context_type', 'NOT SET'),
}
"vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, "context_type", "NOT SET"),
},
)
vendor = getattr(request.state, 'vendor', None)
vendor = getattr(request.state, "vendor", None)
vendor_id = vendor.id if vendor else None
# Load content page from database (vendor override → platform default)
page = content_page_service.get_page_for_vendor(
db,
slug=slug,
vendor_id=vendor_id,
include_unpublished=False
db, slug=slug, vendor_id=vendor_id, include_unpublished=False
)
if not page:
@@ -635,7 +633,7 @@ async def generic_content_page(
"slug": slug,
"vendor_id": vendor_id,
"vendor_name": vendor.name if vendor else None,
}
},
)
raise HTTPException(status_code=404, detail=f"Page not found: {slug}")
@@ -647,12 +645,11 @@ async def generic_content_page(
"page_title": page.title,
"is_vendor_override": page.vendor_id is not None,
"vendor_id": vendor_id,
}
},
)
return templates.TemplateResponse(
"shop/content-page.html",
get_shop_context(request, page=page)
"shop/content-page.html", get_shop_context(request, page=page)
)
@@ -660,6 +657,7 @@ async def generic_content_page(
# DEBUG ENDPOINTS - For troubleshooting context issues
# ============================================================================
@router.get("/debug/context", response_class=HTMLResponse, include_in_schema=False)
async def debug_context(request: Request):
"""
@@ -670,8 +668,8 @@ async def debug_context(request: Request):
URL: /shop/debug/context
"""
vendor = getattr(request.state, 'vendor', None)
theme = getattr(request.state, 'theme', None)
vendor = getattr(request.state, "vendor", None)
theme = getattr(request.state, "theme", None)
debug_info = {
"path": request.url.path,
@@ -687,12 +685,13 @@ async def debug_context(request: Request):
"found": theme is not None,
"name": theme.get("theme_name") if theme else None,
},
"clean_path": getattr(request.state, 'clean_path', 'NOT SET'),
"context_type": str(getattr(request.state, 'context_type', 'NOT SET')),
"clean_path": getattr(request.state, "clean_path", "NOT SET"),
"context_type": str(getattr(request.state, "context_type", "NOT SET")),
}
# Return as JSON-like HTML for easy reading
import json
html_content = f"""
<!DOCTYPE html>
<html>

View File

@@ -21,18 +21,16 @@ Routes:
- GET /vendor/{vendor_code}/settings → Vendor settings
"""
from fastapi import APIRouter, Request, Depends, Path, HTTPException
import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Path, Request
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
from sqlalchemy.orm import Session
from typing import Optional
import logging
from app.api.deps import (
get_current_vendor_from_cookie_or_header,
get_current_vendor_optional,
get_db
)
from app.api.deps import (get_current_vendor_from_cookie_or_header,
get_current_vendor_optional, get_db)
from app.services.content_page_service import content_page_service
from models.database.user import User
@@ -46,6 +44,7 @@ templates = Jinja2Templates(directory="app/templates")
# PUBLIC ROUTES (No Authentication Required)
# ============================================================================
@router.get("/{vendor_code}", response_class=RedirectResponse, include_in_schema=False)
async def vendor_root_no_slash(vendor_code: str = Path(..., description="Vendor code")):
"""
@@ -58,7 +57,7 @@ async def vendor_root_no_slash(vendor_code: str = Path(..., description="Vendor
@router.get("/{vendor_code}/", response_class=RedirectResponse, include_in_schema=False)
async def vendor_root(
vendor_code: str = Path(..., description="Vendor code"),
current_user: Optional[User] = Depends(get_current_vendor_optional)
current_user: Optional[User] = Depends(get_current_vendor_optional),
):
"""
Redirect /vendor/{code}/ based on authentication status.
@@ -73,11 +72,13 @@ async def vendor_root(
return RedirectResponse(url=f"/vendor/{vendor_code}/login", status_code=302)
@router.get("/{vendor_code}/login", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/{vendor_code}/login", response_class=HTMLResponse, include_in_schema=False
)
async def vendor_login_page(
request: Request,
vendor_code: str = Path(..., description="Vendor code"),
current_user: Optional[User] = Depends(get_current_vendor_optional)
current_user: Optional[User] = Depends(get_current_vendor_optional),
):
"""
Render vendor login page.
@@ -99,7 +100,7 @@ async def vendor_login_page(
{
"request": request,
"vendor_code": vendor_code,
}
},
)
@@ -107,11 +108,14 @@ async def vendor_login_page(
# AUTHENTICATED ROUTES (Vendor Users Only)
# ============================================================================
@router.get("/{vendor_code}/dashboard", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/{vendor_code}/dashboard", response_class=HTMLResponse, include_in_schema=False
)
async def vendor_dashboard_page(
request: Request,
vendor_code: str = Path(..., description="Vendor code"),
current_user: User = Depends(get_current_vendor_from_cookie_or_header)
current_user: User = Depends(get_current_vendor_from_cookie_or_header),
):
"""
Render vendor dashboard.
@@ -128,7 +132,7 @@ async def vendor_dashboard_page(
"request": request,
"user": current_user,
"vendor_code": vendor_code,
}
},
)
@@ -136,11 +140,14 @@ async def vendor_dashboard_page(
# PRODUCT MANAGEMENT
# ============================================================================
@router.get("/{vendor_code}/products", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/{vendor_code}/products", response_class=HTMLResponse, include_in_schema=False
)
async def vendor_products_page(
request: Request,
vendor_code: str = Path(..., description="Vendor code"),
current_user: User = Depends(get_current_vendor_from_cookie_or_header)
current_user: User = Depends(get_current_vendor_from_cookie_or_header),
):
"""
Render products management page.
@@ -152,7 +159,7 @@ async def vendor_products_page(
"request": request,
"user": current_user,
"vendor_code": vendor_code,
}
},
)
@@ -160,11 +167,14 @@ async def vendor_products_page(
# ORDER MANAGEMENT
# ============================================================================
@router.get("/{vendor_code}/orders", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/{vendor_code}/orders", response_class=HTMLResponse, include_in_schema=False
)
async def vendor_orders_page(
request: Request,
vendor_code: str = Path(..., description="Vendor code"),
current_user: User = Depends(get_current_vendor_from_cookie_or_header)
current_user: User = Depends(get_current_vendor_from_cookie_or_header),
):
"""
Render orders management page.
@@ -176,7 +186,7 @@ async def vendor_orders_page(
"request": request,
"user": current_user,
"vendor_code": vendor_code,
}
},
)
@@ -184,11 +194,14 @@ async def vendor_orders_page(
# CUSTOMER MANAGEMENT
# ============================================================================
@router.get("/{vendor_code}/customers", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/{vendor_code}/customers", response_class=HTMLResponse, include_in_schema=False
)
async def vendor_customers_page(
request: Request,
vendor_code: str = Path(..., description="Vendor code"),
current_user: User = Depends(get_current_vendor_from_cookie_or_header)
current_user: User = Depends(get_current_vendor_from_cookie_or_header),
):
"""
Render customers management page.
@@ -200,7 +213,7 @@ async def vendor_customers_page(
"request": request,
"user": current_user,
"vendor_code": vendor_code,
}
},
)
@@ -208,11 +221,14 @@ async def vendor_customers_page(
# INVENTORY MANAGEMENT
# ============================================================================
@router.get("/{vendor_code}/inventory", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/{vendor_code}/inventory", response_class=HTMLResponse, include_in_schema=False
)
async def vendor_inventory_page(
request: Request,
vendor_code: str = Path(..., description="Vendor code"),
current_user: User = Depends(get_current_vendor_from_cookie_or_header)
current_user: User = Depends(get_current_vendor_from_cookie_or_header),
):
"""
Render inventory management page.
@@ -224,7 +240,7 @@ async def vendor_inventory_page(
"request": request,
"user": current_user,
"vendor_code": vendor_code,
}
},
)
@@ -232,11 +248,14 @@ async def vendor_inventory_page(
# MARKETPLACE IMPORTS
# ============================================================================
@router.get("/{vendor_code}/marketplace", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/{vendor_code}/marketplace", response_class=HTMLResponse, include_in_schema=False
)
async def vendor_marketplace_page(
request: Request,
vendor_code: str = Path(..., description="Vendor code"),
current_user: User = Depends(get_current_vendor_from_cookie_or_header)
current_user: User = Depends(get_current_vendor_from_cookie_or_header),
):
"""
Render marketplace import page.
@@ -248,7 +267,7 @@ async def vendor_marketplace_page(
"request": request,
"user": current_user,
"vendor_code": vendor_code,
}
},
)
@@ -256,11 +275,12 @@ async def vendor_marketplace_page(
# TEAM MANAGEMENT
# ============================================================================
@router.get("/{vendor_code}/team", response_class=HTMLResponse, include_in_schema=False)
async def vendor_team_page(
request: Request,
vendor_code: str = Path(..., description="Vendor code"),
current_user: User = Depends(get_current_vendor_from_cookie_or_header)
current_user: User = Depends(get_current_vendor_from_cookie_or_header),
):
"""
Render team management page.
@@ -272,7 +292,7 @@ async def vendor_team_page(
"request": request,
"user": current_user,
"vendor_code": vendor_code,
}
},
)
@@ -280,11 +300,14 @@ async def vendor_team_page(
# PROFILE & SETTINGS
# ============================================================================
@router.get("/{vendor_code}/profile", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/{vendor_code}/profile", response_class=HTMLResponse, include_in_schema=False
)
async def vendor_profile_page(
request: Request,
vendor_code: str = Path(..., description="Vendor code"),
current_user: User = Depends(get_current_vendor_from_cookie_or_header)
current_user: User = Depends(get_current_vendor_from_cookie_or_header),
):
"""
Render vendor profile page.
@@ -296,15 +319,17 @@ async def vendor_profile_page(
"request": request,
"user": current_user,
"vendor_code": vendor_code,
}
},
)
@router.get("/{vendor_code}/settings", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/{vendor_code}/settings", response_class=HTMLResponse, include_in_schema=False
)
async def vendor_settings_page(
request: Request,
vendor_code: str = Path(..., description="Vendor code"),
current_user: User = Depends(get_current_vendor_from_cookie_or_header)
current_user: User = Depends(get_current_vendor_from_cookie_or_header),
):
"""
Render vendor settings page.
@@ -316,7 +341,7 @@ async def vendor_settings_page(
"request": request,
"user": current_user,
"vendor_code": vendor_code,
}
},
)
@@ -324,12 +349,15 @@ async def vendor_settings_page(
# DYNAMIC CONTENT PAGES (CMS)
# ============================================================================
@router.get("/{vendor_code}/{slug}", response_class=HTMLResponse, include_in_schema=False)
@router.get(
"/{vendor_code}/{slug}", response_class=HTMLResponse, include_in_schema=False
)
async def vendor_content_page(
request: Request,
vendor_code: str = Path(..., description="Vendor code"),
slug: str = Path(..., description="Content page slug"),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""
Generic content page handler for vendor shop (CMS).
@@ -351,20 +379,17 @@ async def vendor_content_page(
"path": request.url.path,
"vendor_code": vendor_code,
"slug": slug,
"vendor": getattr(request.state, 'vendor', 'NOT SET'),
"context": getattr(request.state, 'context_type', 'NOT SET'),
}
"vendor": getattr(request.state, "vendor", "NOT SET"),
"context": getattr(request.state, "context_type", "NOT SET"),
},
)
vendor = getattr(request.state, 'vendor', None)
vendor = getattr(request.state, "vendor", None)
vendor_id = vendor.id if vendor else None
# Load content page from database (vendor override → platform default)
page = content_page_service.get_page_for_vendor(
db,
slug=slug,
vendor_id=vendor_id,
include_unpublished=False
db, slug=slug, vendor_id=vendor_id, include_unpublished=False
)
if not page:
@@ -374,7 +399,7 @@ async def vendor_content_page(
"slug": slug,
"vendor_code": vendor_code,
"vendor_id": vendor_id,
}
},
)
raise HTTPException(status_code=404, detail="Page not found")
@@ -385,7 +410,7 @@ async def vendor_content_page(
"page_id": page.id,
"is_vendor_override": page.vendor_id is not None,
"vendor_id": vendor_id,
}
},
)
return templates.TemplateResponse(
@@ -394,5 +419,5 @@ async def vendor_content_page(
"request": request,
"page": page,
"vendor_code": vendor_code,
}
},
)

View File

@@ -10,15 +10,15 @@ This module provides functions for:
import logging
from datetime import datetime, timezone
from typing import List, Optional, Dict, Any
from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
from sqlalchemy.orm import Session
from app.exceptions import AdminOperationException
from models.database.admin import AdminAuditLog
from models.database.user import User
from models.schema.admin import AdminAuditLogFilters, AdminAuditLogResponse
from app.exceptions import AdminOperationException
logger = logging.getLogger(__name__)
@@ -36,7 +36,7 @@ class AdminAuditService:
details: Optional[Dict[str, Any]] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
request_id: Optional[str] = None
request_id: Optional[str] = None,
) -> AdminAuditLog:
"""
Log an admin action to the audit trail.
@@ -63,7 +63,7 @@ class AdminAuditService:
details=details or {},
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id
request_id=request_id,
)
db.add(audit_log)
@@ -84,9 +84,7 @@ class AdminAuditService:
return None
def get_audit_logs(
self,
db: Session,
filters: AdminAuditLogFilters
self, db: Session, filters: AdminAuditLogFilters
) -> List[AdminAuditLogResponse]:
"""
Get filtered admin audit logs with pagination.
@@ -98,7 +96,9 @@ class AdminAuditService:
List of audit log responses
"""
try:
query = db.query(AdminAuditLog).join(User, AdminAuditLog.admin_user_id == User.id)
query = db.query(AdminAuditLog).join(
User, AdminAuditLog.admin_user_id == User.id
)
# Apply filters
conditions = []
@@ -123,8 +123,7 @@ class AdminAuditService:
# Execute query with pagination
logs = (
query
.order_by(AdminAuditLog.created_at.desc())
query.order_by(AdminAuditLog.created_at.desc())
.offset(filters.skip)
.limit(filters.limit)
.all()
@@ -143,7 +142,7 @@ class AdminAuditService:
ip_address=log.ip_address,
user_agent=log.user_agent,
request_id=log.request_id,
created_at=log.created_at
created_at=log.created_at,
)
for log in logs
]
@@ -151,15 +150,10 @@ class AdminAuditService:
except Exception as e:
logger.error(f"Failed to retrieve audit logs: {str(e)}")
raise AdminOperationException(
operation="get_audit_logs",
reason="Database query failed"
operation="get_audit_logs", reason="Database query failed"
)
def get_audit_logs_count(
self,
db: Session,
filters: AdminAuditLogFilters
) -> int:
def get_audit_logs_count(self, db: Session, filters: AdminAuditLogFilters) -> int:
"""Get total count of audit logs matching filters."""
try:
query = db.query(AdminAuditLog)
@@ -192,24 +186,14 @@ class AdminAuditService:
return 0
def get_recent_actions_by_admin(
self,
db: Session,
admin_user_id: int,
limit: int = 10
self, db: Session, admin_user_id: int, limit: int = 10
) -> List[AdminAuditLogResponse]:
"""Get recent actions by a specific admin."""
filters = AdminAuditLogFilters(
admin_user_id=admin_user_id,
limit=limit
)
filters = AdminAuditLogFilters(admin_user_id=admin_user_id, limit=limit)
return self.get_audit_logs(db, filters)
def get_actions_by_target(
self,
db: Session,
target_type: str,
target_id: str,
limit: int = 50
self, db: Session, target_type: str, target_id: str, limit: int = 50
) -> List[AdminAuditLogResponse]:
"""Get all actions performed on a specific target."""
try:
@@ -218,7 +202,7 @@ class AdminAuditService:
.filter(
and_(
AdminAuditLog.target_type == target_type,
AdminAuditLog.target_id == str(target_id)
AdminAuditLog.target_id == str(target_id),
)
)
.order_by(AdminAuditLog.created_at.desc())
@@ -236,7 +220,7 @@ class AdminAuditService:
target_id=log.target_id,
details=log.details,
ip_address=log.ip_address,
created_at=log.created_at
created_at=log.created_at,
)
for log in logs
]

View File

@@ -16,24 +16,19 @@ import string
from datetime import datetime, timezone
from typing import List, Optional, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import func, or_
from sqlalchemy.orm import Session
from app.exceptions import (
UserNotFoundException,
UserStatusChangeException,
CannotModifySelfException,
from app.exceptions import (AdminOperationException, CannotModifySelfException,
UserNotFoundException, UserStatusChangeException,
ValidationException, VendorAlreadyExistsException,
VendorNotFoundException,
VendorAlreadyExistsException,
VendorVerificationException,
AdminOperationException,
ValidationException,
)
VendorVerificationException)
from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.user import User
from models.database.vendor import Role, Vendor, VendorUser
from models.schema.marketplace_import_job import MarketplaceImportJobResponse
from models.schema.vendor import VendorCreate
from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.vendor import Vendor, Role, VendorUser
from models.database.user import User
logger = logging.getLogger(__name__)
@@ -52,8 +47,7 @@ class AdminService:
except Exception as e:
logger.error(f"Failed to retrieve users: {str(e)}")
raise AdminOperationException(
operation="get_all_users",
reason="Database query failed"
operation="get_all_users", reason="Database query failed"
)
def toggle_user_status(
@@ -72,7 +66,7 @@ class AdminService:
user_id=user_id,
current_status="admin",
attempted_action="toggle status",
reason="Cannot modify another admin user"
reason="Cannot modify another admin user",
)
try:
@@ -95,7 +89,7 @@ class AdminService:
user_id=user_id,
current_status="active" if original_status else "inactive",
attempted_action="toggle status",
reason="Database update failed"
reason="Database update failed",
)
# ============================================================================
@@ -118,17 +112,23 @@ class AdminService:
"""
try:
# Check if vendor code already exists
existing_vendor = db.query(Vendor).filter(
existing_vendor = (
db.query(Vendor)
.filter(
func.upper(Vendor.vendor_code) == vendor_data.vendor_code.upper()
).first()
)
.first()
)
if existing_vendor:
raise VendorAlreadyExistsException(vendor_data.vendor_code)
# Check if subdomain already exists
existing_subdomain = db.query(Vendor).filter(
func.lower(Vendor.subdomain) == vendor_data.subdomain.lower()
).first()
existing_subdomain = (
db.query(Vendor)
.filter(func.lower(Vendor.subdomain) == vendor_data.subdomain.lower())
.first()
)
if existing_subdomain:
raise ValidationException(
@@ -140,15 +140,14 @@ class AdminService:
# Create owner user with owner_email
from middleware.auth import AuthManager
auth_manager = AuthManager()
owner_username = f"{vendor_data.subdomain}_owner"
owner_email = vendor_data.owner_email # ✅ For User authentication
# Check if user with this email already exists
existing_user = db.query(User).filter(
User.email == owner_email
).first()
existing_user = db.query(User).filter(User.email == owner_email).first()
if existing_user:
# Use existing user as owner
@@ -215,7 +214,7 @@ class AdminService:
logger.error(f"Failed to create vendor: {str(e)}")
raise AdminOperationException(
operation="create_vendor_with_owner",
reason=f"Failed to create vendor: {str(e)}"
reason=f"Failed to create vendor: {str(e)}",
)
def get_all_vendors(
@@ -225,7 +224,7 @@ class AdminService:
limit: int = 100,
search: Optional[str] = None,
is_active: Optional[bool] = None,
is_verified: Optional[bool] = None
is_verified: Optional[bool] = None,
) -> Tuple[List[Vendor], int]:
"""Get paginated list of all vendors with filtering."""
try:
@@ -238,7 +237,7 @@ class AdminService:
or_(
Vendor.name.ilike(search_term),
Vendor.vendor_code.ilike(search_term),
Vendor.subdomain.ilike(search_term)
Vendor.subdomain.ilike(search_term),
)
)
@@ -255,8 +254,7 @@ class AdminService:
except Exception as e:
logger.error(f"Failed to retrieve vendors: {str(e)}")
raise AdminOperationException(
operation="get_all_vendors",
reason="Database query failed"
operation="get_all_vendors", reason="Database query failed"
)
def get_vendor_by_id(self, db: Session, vendor_id: int) -> Vendor:
@@ -290,7 +288,7 @@ class AdminService:
raise VendorVerificationException(
vendor_id=vendor_id,
reason="Database update failed",
current_verification_status=original_status
current_verification_status=original_status,
)
def toggle_vendor_status(self, db: Session, vendor_id: int) -> Tuple[Vendor, str]:
@@ -317,7 +315,7 @@ class AdminService:
operation="toggle_vendor_status",
reason="Database update failed",
target_type="vendor",
target_id=str(vendor_id)
target_id=str(vendor_id),
)
def delete_vendor(self, db: Session, vendor_id: int) -> str:
@@ -345,15 +343,11 @@ class AdminService:
db.rollback()
logger.error(f"Failed to delete vendor {vendor_id}: {str(e)}")
raise AdminOperationException(
operation="delete_vendor",
reason="Database deletion failed"
operation="delete_vendor", reason="Database deletion failed"
)
def update_vendor(
self,
db: Session,
vendor_id: int,
vendor_update # VendorUpdate schema
self, db: Session, vendor_id: int, vendor_update # VendorUpdate schema
) -> Vendor:
"""
Update vendor information (Admin only).
@@ -387,11 +381,18 @@ class AdminService:
update_data = vendor_update.model_dump(exclude_unset=True)
# Check subdomain uniqueness if changing
if 'subdomain' in update_data and update_data['subdomain'] != vendor.subdomain:
existing = db.query(Vendor).filter(
Vendor.subdomain == update_data['subdomain'],
Vendor.id != vendor_id
).first()
if (
"subdomain" in update_data
and update_data["subdomain"] != vendor.subdomain
):
existing = (
db.query(Vendor)
.filter(
Vendor.subdomain == update_data["subdomain"],
Vendor.id != vendor_id,
)
.first()
)
if existing:
raise ValidationException(
f"Subdomain '{update_data['subdomain']}' is already taken"
@@ -419,8 +420,7 @@ class AdminService:
db.rollback()
logger.error(f"Failed to update vendor {vendor_id}: {str(e)}")
raise AdminOperationException(
operation="update_vendor",
reason=f"Database update failed: {str(e)}"
operation="update_vendor", reason=f"Database update failed: {str(e)}"
)
# Add this NEW method for transferring ownership:
@@ -429,7 +429,7 @@ class AdminService:
self,
db: Session,
vendor_id: int,
transfer_data # VendorTransferOwnership schema
transfer_data, # VendorTransferOwnership schema
) -> Tuple[Vendor, User, User]:
"""
Transfer vendor ownership to another user.
@@ -466,9 +466,9 @@ class AdminService:
old_owner = vendor.owner
# Get new owner
new_owner = db.query(User).filter(
User.id == transfer_data.new_owner_user_id
).first()
new_owner = (
db.query(User).filter(User.id == transfer_data.new_owner_user_id).first()
)
if not new_owner:
raise UserNotFoundException(str(transfer_data.new_owner_user_id))
@@ -487,26 +487,32 @@ class AdminService:
try:
# Get Owner role for this vendor
owner_role = db.query(Role).filter(
Role.vendor_id == vendor_id,
Role.name == "Owner"
).first()
owner_role = (
db.query(Role)
.filter(Role.vendor_id == vendor_id, Role.name == "Owner")
.first()
)
if not owner_role:
raise ValidationException("Owner role not found for vendor")
# Get Manager role (to demote old owner)
manager_role = db.query(Role).filter(
Role.vendor_id == vendor_id,
Role.name == "Manager"
).first()
manager_role = (
db.query(Role)
.filter(Role.vendor_id == vendor_id, Role.name == "Manager")
.first()
)
# Remove old owner from Owner role
old_owner_link = db.query(VendorUser).filter(
old_owner_link = (
db.query(VendorUser)
.filter(
VendorUser.vendor_id == vendor_id,
VendorUser.user_id == old_owner.id,
VendorUser.role_id == owner_role.id
).first()
VendorUser.role_id == owner_role.id,
)
.first()
)
if old_owner_link:
if manager_role:
@@ -525,10 +531,14 @@ class AdminService:
)
# Check if new owner already has a vendor_user link
new_owner_link = db.query(VendorUser).filter(
new_owner_link = (
db.query(VendorUser)
.filter(
VendorUser.vendor_id == vendor_id,
VendorUser.user_id == new_owner.id
).first()
VendorUser.user_id == new_owner.id,
)
.first()
)
if new_owner_link:
# Update existing link to Owner role
@@ -540,7 +550,7 @@ class AdminService:
vendor_id=vendor_id,
user_id=new_owner.id,
role_id=owner_role.id,
is_active=True
is_active=True,
)
db.add(new_owner_link)
@@ -568,10 +578,12 @@ class AdminService:
raise
except Exception as e:
db.rollback()
logger.error(f"Failed to transfer ownership for vendor {vendor_id}: {str(e)}")
logger.error(
f"Failed to transfer ownership for vendor {vendor_id}: {str(e)}"
)
raise AdminOperationException(
operation="transfer_vendor_ownership",
reason=f"Ownership transfer failed: {str(e)}"
reason=f"Ownership transfer failed: {str(e)}",
)
# ============================================================================
@@ -596,7 +608,9 @@ class AdminService:
MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%")
)
if vendor_name:
query = query.filter(MarketplaceImportJob.vendor_name.ilike(f"%{vendor_name}%"))
query = query.filter(
MarketplaceImportJob.vendor_name.ilike(f"%{vendor_name}%")
)
if status:
query = query.filter(MarketplaceImportJob.status == status)
@@ -612,8 +626,7 @@ class AdminService:
except Exception as e:
logger.error(f"Failed to retrieve marketplace import jobs: {str(e)}")
raise AdminOperationException(
operation="get_marketplace_import_jobs",
reason="Database query failed"
operation="get_marketplace_import_jobs", reason="Database query failed"
)
# ============================================================================
@@ -624,10 +637,7 @@ class AdminService:
"""Get recently created vendors."""
try:
vendors = (
db.query(Vendor)
.order_by(Vendor.created_at.desc())
.limit(limit)
.all()
db.query(Vendor).order_by(Vendor.created_at.desc()).limit(limit).all()
)
return [
@@ -638,7 +648,7 @@ class AdminService:
"subdomain": v.subdomain,
"is_active": v.is_active,
"is_verified": v.is_verified,
"created_at": v.created_at
"created_at": v.created_at,
}
for v in vendors
]
@@ -663,7 +673,7 @@ class AdminService:
"vendor_name": j.vendor_name,
"status": j.status,
"total_processed": j.total_processed or 0,
"created_at": j.created_at
"created_at": j.created_at,
}
for j in jobs
]
@@ -692,47 +702,53 @@ class AdminService:
def _generate_temp_password(self, length: int = 12) -> str:
"""Generate secure temporary password."""
alphabet = string.ascii_letters + string.digits + "!@#$%^&*"
return ''.join(secrets.choice(alphabet) for _ in range(length))
return "".join(secrets.choice(alphabet) for _ in range(length))
def _create_default_roles(self, db: Session, vendor_id: int):
"""Create default roles for a new vendor."""
default_roles = [
{
"name": "Owner",
"permissions": ["*"] # Full access
},
{"name": "Owner", "permissions": ["*"]}, # Full access
{
"name": "Manager",
"permissions": [
"products.*", "orders.*", "customers.view",
"inventory.*", "team.view"
]
"products.*",
"orders.*",
"customers.view",
"inventory.*",
"team.view",
],
},
{
"name": "Editor",
"permissions": [
"products.view", "products.edit",
"orders.view", "inventory.view"
]
"products.view",
"products.edit",
"orders.view",
"inventory.view",
],
},
{
"name": "Viewer",
"permissions": [
"products.view", "orders.view",
"customers.view", "inventory.view"
]
}
"products.view",
"orders.view",
"customers.view",
"inventory.view",
],
},
]
for role_data in default_roles:
role = Role(
vendor_id=vendor_id,
name=role_data["name"],
permissions=role_data["permissions"]
permissions=role_data["permissions"],
)
db.add(role)
def _convert_job_to_response(self, job: MarketplaceImportJob) -> MarketplaceImportJobResponse:
def _convert_job_to_response(
self, job: MarketplaceImportJob
) -> MarketplaceImportJobResponse:
"""Convert database model to response schema."""
return MarketplaceImportJobResponse(
job_id=job.id,

View File

@@ -8,25 +8,19 @@ This module provides functions for:
- Encrypting sensitive settings
"""
import logging
import json
from typing import Optional, List, Any, Dict
import logging
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.exceptions import (AdminOperationException, ResourceNotFoundException,
ValidationException)
from models.database.admin import AdminSetting
from models.schema.admin import (
AdminSettingCreate,
AdminSettingResponse,
AdminSettingUpdate
)
from app.exceptions import (
AdminOperationException,
ValidationException,
ResourceNotFoundException
)
from models.schema.admin import (AdminSettingCreate, AdminSettingResponse,
AdminSettingUpdate)
logger = logging.getLogger(__name__)
@@ -34,26 +28,19 @@ logger = logging.getLogger(__name__)
class AdminSettingsService:
"""Service for managing platform-wide settings."""
def get_setting_by_key(
self,
db: Session,
key: str
) -> Optional[AdminSetting]:
def get_setting_by_key(self, db: Session, key: str) -> Optional[AdminSetting]:
"""Get setting by key."""
try:
return db.query(AdminSetting).filter(
func.lower(AdminSetting.key) == key.lower()
).first()
return (
db.query(AdminSetting)
.filter(func.lower(AdminSetting.key) == key.lower())
.first()
)
except Exception as e:
logger.error(f"Failed to get setting {key}: {str(e)}")
return None
def get_setting_value(
self,
db: Session,
key: str,
default: Any = None
) -> Any:
def get_setting_value(self, db: Session, key: str, default: Any = None) -> Any:
"""
Get setting value with type conversion.
@@ -76,7 +63,7 @@ class AdminSettingsService:
elif setting.value_type == "float":
return float(setting.value)
elif setting.value_type == "boolean":
return setting.value.lower() in ('true', '1', 'yes')
return setting.value.lower() in ("true", "1", "yes")
elif setting.value_type == "json":
return json.loads(setting.value)
else:
@@ -89,7 +76,7 @@ class AdminSettingsService:
self,
db: Session,
category: Optional[str] = None,
is_public: Optional[bool] = None
is_public: Optional[bool] = None,
) -> List[AdminSettingResponse]:
"""Get all settings with optional filtering."""
try:
@@ -104,22 +91,16 @@ class AdminSettingsService:
settings = query.order_by(AdminSetting.category, AdminSetting.key).all()
return [
AdminSettingResponse.model_validate(setting)
for setting in settings
AdminSettingResponse.model_validate(setting) for setting in settings
]
except Exception as e:
logger.error(f"Failed to get settings: {str(e)}")
raise AdminOperationException(
operation="get_all_settings",
reason="Database query failed"
operation="get_all_settings", reason="Database query failed"
)
def get_settings_by_category(
self,
db: Session,
category: str
) -> Dict[str, Any]:
def get_settings_by_category(self, db: Session, category: str) -> Dict[str, Any]:
"""
Get all settings in a category as a dictionary.
@@ -136,7 +117,7 @@ class AdminSettingsService:
elif setting.value_type == "float":
result[setting.key] = float(setting.value)
elif setting.value_type == "boolean":
result[setting.key] = setting.value.lower() in ('true', '1', 'yes')
result[setting.key] = setting.value.lower() in ("true", "1", "yes")
elif setting.value_type == "json":
result[setting.key] = json.loads(setting.value)
else:
@@ -145,10 +126,7 @@ class AdminSettingsService:
return result
def create_setting(
self,
db: Session,
setting_data: AdminSettingCreate,
admin_user_id: int
self, db: Session, setting_data: AdminSettingCreate, admin_user_id: int
) -> AdminSettingResponse:
"""Create new setting."""
try:
@@ -176,7 +154,7 @@ class AdminSettingsService:
description=setting_data.description,
is_encrypted=setting_data.is_encrypted,
is_public=setting_data.is_public,
last_modified_by_user_id=admin_user_id
last_modified_by_user_id=admin_user_id,
)
db.add(setting)
@@ -194,25 +172,17 @@ class AdminSettingsService:
db.rollback()
logger.error(f"Failed to create setting: {str(e)}")
raise AdminOperationException(
operation="create_setting",
reason="Database operation failed"
operation="create_setting", reason="Database operation failed"
)
def update_setting(
self,
db: Session,
key: str,
update_data: AdminSettingUpdate,
admin_user_id: int
self, db: Session, key: str, update_data: AdminSettingUpdate, admin_user_id: int
) -> AdminSettingResponse:
"""Update existing setting."""
setting = self.get_setting_by_key(db, key)
if not setting:
raise ResourceNotFoundException(
resource_type="setting",
identifier=key
)
raise ResourceNotFoundException(resource_type="setting", identifier=key)
try:
# Validate new value
@@ -244,42 +214,29 @@ class AdminSettingsService:
db.rollback()
logger.error(f"Failed to update setting {key}: {str(e)}")
raise AdminOperationException(
operation="update_setting",
reason="Database operation failed"
operation="update_setting", reason="Database operation failed"
)
def upsert_setting(
self,
db: Session,
setting_data: AdminSettingCreate,
admin_user_id: int
self, db: Session, setting_data: AdminSettingCreate, admin_user_id: int
) -> AdminSettingResponse:
"""Create or update setting (upsert)."""
existing = self.get_setting_by_key(db, setting_data.key)
if existing:
update_data = AdminSettingUpdate(
value=setting_data.value,
description=setting_data.description
value=setting_data.value, description=setting_data.description
)
return self.update_setting(db, setting_data.key, update_data, admin_user_id)
else:
return self.create_setting(db, setting_data, admin_user_id)
def delete_setting(
self,
db: Session,
key: str,
admin_user_id: int
) -> str:
def delete_setting(self, db: Session, key: str, admin_user_id: int) -> str:
"""Delete setting."""
setting = self.get_setting_by_key(db, key)
if not setting:
raise ResourceNotFoundException(
resource_type="setting",
identifier=key
)
raise ResourceNotFoundException(resource_type="setting", identifier=key)
try:
db.delete(setting)
@@ -293,8 +250,7 @@ class AdminSettingsService:
db.rollback()
logger.error(f"Failed to delete setting {key}: {str(e)}")
raise AdminOperationException(
operation="delete_setting",
reason="Database operation failed"
operation="delete_setting", reason="Database operation failed"
)
# ============================================================================
@@ -309,7 +265,7 @@ class AdminSettingsService:
elif value_type == "float":
float(value)
elif value_type == "boolean":
if value.lower() not in ('true', 'false', '1', '0', 'yes', 'no'):
if value.lower() not in ("true", "false", "1", "0", "yes", "no"):
raise ValueError("Invalid boolean value")
elif value_type == "json":
json.loads(value)

View File

@@ -13,15 +13,12 @@ from typing import Any, Dict, Optional
from sqlalchemy.orm import Session
from app.exceptions import (
UserAlreadyExistsException,
InvalidCredentialsException,
UserNotActiveException,
ValidationException,
)
from app.exceptions import (InvalidCredentialsException,
UserAlreadyExistsException, UserNotActiveException,
ValidationException)
from middleware.auth import AuthManager
from models.schema.auth import UserLogin, UserRegister
from models.database.user import User
from models.schema.auth import UserLogin, UserRegister
logger = logging.getLogger(__name__)
@@ -51,11 +48,15 @@ class AuthService:
try:
# Check if email already exists
if self._email_exists(db, user_data.email):
raise UserAlreadyExistsException("Email already registered", field="email")
raise UserAlreadyExistsException(
"Email already registered", field="email"
)
# Check if username already exists
if self._username_exists(db, user_data.username):
raise UserAlreadyExistsException("Username already taken", field="username")
raise UserAlreadyExistsException(
"Username already taken", field="username"
)
# Hash password and create user
hashed_password = self.auth_manager.hash_password(user_data.password)
@@ -182,7 +183,9 @@ class AuthService:
Dictionary with access_token, token_type, and expires_in
"""
from datetime import datetime, timedelta, timezone
from jose import jwt
from app.core.config import settings
try:
@@ -217,6 +220,5 @@ class AuthService:
return db.query(User).filter(User.username == username).first() is not None
# Create service instance following the same pattern as other services
auth_service = AuthService()

View File

@@ -9,23 +9,20 @@ This module provides:
"""
import logging
from typing import Dict, List, Optional
from datetime import datetime, timezone
from typing import Dict, List, Optional
from sqlalchemy.orm import Session
from sqlalchemy import and_
from sqlalchemy.orm import Session
from models.database.product import Product
from models.database.vendor import Vendor
from models.database.cart import CartItem
from app.exceptions import (
ProductNotFoundException,
CartItemNotFoundException,
CartValidationException,
from app.exceptions import (CartItemNotFoundException, CartValidationException,
InsufficientInventoryForCartException,
InvalidCartQuantityException,
ProductNotAvailableForCartException,
)
ProductNotFoundException)
from models.database.cart import CartItem
from models.database.product import Product
from models.database.vendor import Vendor
logger = logging.getLogger(__name__)
@@ -33,12 +30,7 @@ logger = logging.getLogger(__name__)
class CartService:
"""Service for managing shopping carts."""
def get_cart(
self,
db: Session,
vendor_id: int,
session_id: str
) -> Dict:
def get_cart(self, db: Session, vendor_id: int, session_id: str) -> Dict:
"""
Get cart contents for a session.
@@ -55,20 +47,21 @@ class CartService:
extra={
"vendor_id": vendor_id,
"session_id": session_id,
}
},
)
# Fetch cart items from database
cart_items = db.query(CartItem).filter(
and_(
CartItem.vendor_id == vendor_id,
CartItem.session_id == session_id
cart_items = (
db.query(CartItem)
.filter(
and_(CartItem.vendor_id == vendor_id, CartItem.session_id == session_id)
)
.all()
)
).all()
logger.info(
f"[CART_SERVICE] Found {len(cart_items)} items in database",
extra={"item_count": len(cart_items)}
extra={"item_count": len(cart_items)},
)
# Build response
@@ -79,14 +72,20 @@ class CartService:
product = cart_item.product
line_total = cart_item.line_total
items.append({
items.append(
{
"product_id": product.id,
"product_name": product.marketplace_product.title,
"quantity": cart_item.quantity,
"price": cart_item.price_at_add,
"line_total": line_total,
"image_url": product.marketplace_product.image_link if product.marketplace_product else None,
})
"image_url": (
product.marketplace_product.image_link
if product.marketplace_product
else None
),
}
)
subtotal += line_total
@@ -95,12 +94,12 @@ class CartService:
"session_id": session_id,
"items": items,
"subtotal": subtotal,
"total": subtotal # Could add tax/shipping later
"total": subtotal, # Could add tax/shipping later
}
logger.info(
f"[CART_SERVICE] get_cart returning: {len(cart_data['items'])} items, total: {cart_data['total']}",
extra={"cart": cart_data}
extra={"cart": cart_data},
)
return cart_data
@@ -111,7 +110,7 @@ class CartService:
vendor_id: int,
session_id: str,
product_id: int,
quantity: int = 1
quantity: int = 1,
) -> Dict:
"""
Add product to cart.
@@ -136,23 +135,27 @@ class CartService:
"vendor_id": vendor_id,
"session_id": session_id,
"product_id": product_id,
"quantity": quantity
}
"quantity": quantity,
},
)
# Verify product exists and belongs to vendor
product = db.query(Product).filter(
product = (
db.query(Product)
.filter(
and_(
Product.id == product_id,
Product.vendor_id == vendor_id,
Product.is_active == True
Product.is_active == True,
)
)
.first()
)
).first()
if not product:
logger.error(
f"[CART_SERVICE] Product not found",
extra={"product_id": product_id, "vendor_id": vendor_id}
extra={"product_id": product_id, "vendor_id": vendor_id},
)
raise ProductNotFoundException(product_id=product_id, vendor_id=vendor_id)
@@ -161,21 +164,25 @@ class CartService:
extra={
"product_id": product_id,
"product_name": product.marketplace_product.title,
"available_inventory": product.available_inventory
}
"available_inventory": product.available_inventory,
},
)
# Get current price (use sale_price if available, otherwise regular price)
current_price = product.sale_price if product.sale_price else product.price
# Check if item already exists in cart
existing_item = db.query(CartItem).filter(
existing_item = (
db.query(CartItem)
.filter(
and_(
CartItem.vendor_id == vendor_id,
CartItem.session_id == session_id,
CartItem.product_id == product_id
CartItem.product_id == product_id,
)
)
.first()
)
).first()
if existing_item:
# Update quantity
@@ -190,14 +197,14 @@ class CartService:
"current_in_cart": existing_item.quantity,
"adding": quantity,
"requested_total": new_quantity,
"available": product.available_inventory
}
"available": product.available_inventory,
},
)
raise InsufficientInventoryForCartException(
product_id=product_id,
product_name=product.marketplace_product.title,
requested=new_quantity,
available=product.available_inventory
available=product.available_inventory,
)
existing_item.quantity = new_quantity
@@ -206,16 +213,13 @@ class CartService:
logger.info(
f"[CART_SERVICE] Updated existing cart item",
extra={
"cart_item_id": existing_item.id,
"new_quantity": new_quantity
}
extra={"cart_item_id": existing_item.id, "new_quantity": new_quantity},
)
return {
"message": "Product quantity updated in cart",
"product_id": product_id,
"quantity": new_quantity
"quantity": new_quantity,
}
else:
# Check inventory for new item
@@ -225,14 +229,14 @@ class CartService:
extra={
"product_id": product_id,
"requested": quantity,
"available": product.available_inventory
}
"available": product.available_inventory,
},
)
raise InsufficientInventoryForCartException(
product_id=product_id,
product_name=product.marketplace_product.title,
requested=quantity,
available=product.available_inventory
available=product.available_inventory,
)
# Create new cart item
@@ -241,7 +245,7 @@ class CartService:
session_id=session_id,
product_id=product_id,
quantity=quantity,
price_at_add=current_price
price_at_add=current_price,
)
db.add(cart_item)
db.commit()
@@ -252,14 +256,14 @@ class CartService:
extra={
"cart_item_id": cart_item.id,
"quantity": quantity,
"price": current_price
}
"price": current_price,
},
)
return {
"message": "Product added to cart",
"product_id": product_id,
"quantity": quantity
"quantity": quantity,
}
def update_cart_item(
@@ -268,7 +272,7 @@ class CartService:
vendor_id: int,
session_id: str,
product_id: int,
quantity: int
quantity: int,
) -> Dict:
"""
Update quantity of item in cart.
@@ -292,25 +296,35 @@ class CartService:
raise InvalidCartQuantityException(quantity=quantity, min_quantity=1)
# Find cart item
cart_item = db.query(CartItem).filter(
cart_item = (
db.query(CartItem)
.filter(
and_(
CartItem.vendor_id == vendor_id,
CartItem.session_id == session_id,
CartItem.product_id == product_id
CartItem.product_id == product_id,
)
)
.first()
)
).first()
if not cart_item:
raise CartItemNotFoundException(product_id=product_id, session_id=session_id)
raise CartItemNotFoundException(
product_id=product_id, session_id=session_id
)
# Verify product still exists and is active
product = db.query(Product).filter(
product = (
db.query(Product)
.filter(
and_(
Product.id == product_id,
Product.vendor_id == vendor_id,
Product.is_active == True
Product.is_active == True,
)
)
.first()
)
).first()
if not product:
raise ProductNotFoundException(str(product_id))
@@ -321,7 +335,7 @@ class CartService:
product_id=product_id,
product_name=product.marketplace_product.title,
requested=quantity,
available=product.available_inventory
available=product.available_inventory,
)
# Update quantity
@@ -334,22 +348,18 @@ class CartService:
extra={
"cart_item_id": cart_item.id,
"product_id": product_id,
"new_quantity": quantity
}
"new_quantity": quantity,
},
)
return {
"message": "Cart updated",
"product_id": product_id,
"quantity": quantity
"quantity": quantity,
}
def remove_from_cart(
self,
db: Session,
vendor_id: int,
session_id: str,
product_id: int
self, db: Session, vendor_id: int, session_id: str, product_id: int
) -> Dict:
"""
Remove item from cart.
@@ -367,16 +377,22 @@ class CartService:
ProductNotFoundException: If product not in cart
"""
# Find and delete cart item
cart_item = db.query(CartItem).filter(
cart_item = (
db.query(CartItem)
.filter(
and_(
CartItem.vendor_id == vendor_id,
CartItem.session_id == session_id,
CartItem.product_id == product_id
CartItem.product_id == product_id,
)
)
.first()
)
).first()
if not cart_item:
raise CartItemNotFoundException(product_id=product_id, session_id=session_id)
raise CartItemNotFoundException(
product_id=product_id, session_id=session_id
)
db.delete(cart_item)
db.commit()
@@ -386,21 +402,13 @@ class CartService:
extra={
"cart_item_id": cart_item.id,
"product_id": product_id,
"session_id": session_id
}
"session_id": session_id,
},
)
return {
"message": "Item removed from cart",
"product_id": product_id
}
return {"message": "Item removed from cart", "product_id": product_id}
def clear_cart(
self,
db: Session,
vendor_id: int,
session_id: str
) -> Dict:
def clear_cart(self, db: Session, vendor_id: int, session_id: str) -> Dict:
"""
Clear all items from cart.
@@ -413,12 +421,13 @@ class CartService:
Success message with count of items removed
"""
# Delete all cart items for this session
deleted_count = db.query(CartItem).filter(
and_(
CartItem.vendor_id == vendor_id,
CartItem.session_id == session_id
deleted_count = (
db.query(CartItem)
.filter(
and_(CartItem.vendor_id == vendor_id, CartItem.session_id == session_id)
)
.delete()
)
).delete()
db.commit()
@@ -427,14 +436,11 @@ class CartService:
extra={
"session_id": session_id,
"vendor_id": vendor_id,
"items_removed": deleted_count
}
"items_removed": deleted_count,
},
)
return {
"message": "Cart cleared",
"items_removed": deleted_count
}
return {"message": "Cart cleared", "items_removed": deleted_count}
# Create service instance

View File

@@ -3,22 +3,20 @@ Code Quality Service
Business logic for managing architecture scans and violations
"""
import subprocess
import json
import logging
import subprocess
from datetime import datetime
from typing import List, Tuple, Optional, Dict
from pathlib import Path
from sqlalchemy.orm import Session
from sqlalchemy import func, desc
from typing import Dict, List, Optional, Tuple
from app.models.architecture_scan import (
ArchitectureScan,
from sqlalchemy import desc, func
from sqlalchemy.orm import Session
from app.models.architecture_scan import (ArchitectureRule, ArchitectureScan,
ArchitectureViolation,
ArchitectureRule,
ViolationAssignment,
ViolationComment
)
ViolationComment)
logger = logging.getLogger(__name__)
@@ -26,7 +24,7 @@ logger = logging.getLogger(__name__)
class CodeQualityService:
"""Service for managing code quality scans and violations"""
def run_scan(self, db: Session, triggered_by: str = 'manual') -> ArchitectureScan:
def run_scan(self, db: Session, triggered_by: str = "manual") -> ArchitectureScan:
"""
Run architecture validator and store results in database
@@ -49,10 +47,10 @@ class CodeQualityService:
start_time = datetime.now()
try:
result = subprocess.run(
['python', 'scripts/validate_architecture.py', '--json'],
["python", "scripts/validate_architecture.py", "--json"],
capture_output=True,
text=True,
timeout=300 # 5 minute timeout
timeout=300, # 5 minute timeout
)
except subprocess.TimeoutExpired:
logger.error("Architecture scan timed out after 5 minutes")
@@ -63,17 +61,17 @@ class CodeQualityService:
# Parse JSON output (get only the JSON part, skip progress messages)
try:
# Find the JSON part in stdout
lines = result.stdout.strip().split('\n')
lines = result.stdout.strip().split("\n")
json_start = -1
for i, line in enumerate(lines):
if line.strip().startswith('{'):
if line.strip().startswith("{"):
json_start = i
break
if json_start == -1:
raise ValueError("No JSON output found")
json_output = '\n'.join(lines[json_start:])
json_output = "\n".join(lines[json_start:])
data = json.loads(json_output)
except (json.JSONDecodeError, ValueError) as e:
logger.error(f"Failed to parse validator output: {e}")
@@ -84,33 +82,33 @@ class CodeQualityService:
# Create scan record
scan = ArchitectureScan(
timestamp=datetime.now(),
total_files=data.get('files_checked', 0),
total_violations=data.get('total_violations', 0),
errors=data.get('errors', 0),
warnings=data.get('warnings', 0),
total_files=data.get("files_checked", 0),
total_violations=data.get("total_violations", 0),
errors=data.get("errors", 0),
warnings=data.get("warnings", 0),
duration_seconds=duration,
triggered_by=triggered_by,
git_commit_hash=git_commit
git_commit_hash=git_commit,
)
db.add(scan)
db.flush() # Get scan.id
# Create violation records
violations_data = data.get('violations', [])
violations_data = data.get("violations", [])
logger.info(f"Creating {len(violations_data)} violation records")
for v in violations_data:
violation = ArchitectureViolation(
scan_id=scan.id,
rule_id=v['rule_id'],
rule_name=v['rule_name'],
severity=v['severity'],
file_path=v['file_path'],
line_number=v['line_number'],
message=v['message'],
context=v.get('context', ''),
suggestion=v.get('suggestion', ''),
status='open'
rule_id=v["rule_id"],
rule_name=v["rule_name"],
severity=v["severity"],
file_path=v["file_path"],
line_number=v["line_number"],
message=v["message"],
context=v.get("context", ""),
suggestion=v.get("suggestion", ""),
status="open",
)
db.add(violation)
@@ -122,7 +120,11 @@ class CodeQualityService:
def get_latest_scan(self, db: Session) -> Optional[ArchitectureScan]:
"""Get the most recent scan"""
return db.query(ArchitectureScan).order_by(desc(ArchitectureScan.timestamp)).first()
return (
db.query(ArchitectureScan)
.order_by(desc(ArchitectureScan.timestamp))
.first()
)
def get_scan_by_id(self, db: Session, scan_id: int) -> Optional[ArchitectureScan]:
"""Get scan by ID"""
@@ -139,10 +141,12 @@ class CodeQualityService:
Returns:
List of ArchitectureScan objects, newest first
"""
return db.query(ArchitectureScan)\
.order_by(desc(ArchitectureScan.timestamp))\
.limit(limit)\
return (
db.query(ArchitectureScan)
.order_by(desc(ArchitectureScan.timestamp))
.limit(limit)
.all()
)
def get_violations(
self,
@@ -153,7 +157,7 @@ class CodeQualityService:
rule_id: str = None,
file_path: str = None,
limit: int = 100,
offset: int = 0
offset: int = 0,
) -> Tuple[List[ArchitectureViolation], int]:
"""
Get violations with filtering and pagination
@@ -194,24 +198,32 @@ class CodeQualityService:
query = query.filter(ArchitectureViolation.rule_id == rule_id)
if file_path:
query = query.filter(ArchitectureViolation.file_path.like(f'%{file_path}%'))
query = query.filter(ArchitectureViolation.file_path.like(f"%{file_path}%"))
# Get total count
total = query.count()
# Get page of results
violations = query.order_by(
ArchitectureViolation.severity.desc(),
ArchitectureViolation.file_path
).limit(limit).offset(offset).all()
violations = (
query.order_by(
ArchitectureViolation.severity.desc(), ArchitectureViolation.file_path
)
.limit(limit)
.offset(offset)
.all()
)
return violations, total
def get_violation_by_id(self, db: Session, violation_id: int) -> Optional[ArchitectureViolation]:
def get_violation_by_id(
self, db: Session, violation_id: int
) -> Optional[ArchitectureViolation]:
"""Get single violation with details"""
return db.query(ArchitectureViolation).filter(
ArchitectureViolation.id == violation_id
).first()
return (
db.query(ArchitectureViolation)
.filter(ArchitectureViolation.id == violation_id)
.first()
)
def assign_violation(
self,
@@ -220,7 +232,7 @@ class CodeQualityService:
user_id: int,
assigned_by: int,
due_date: datetime = None,
priority: str = 'medium'
priority: str = "medium",
) -> ViolationAssignment:
"""
Assign violation to a developer
@@ -239,7 +251,7 @@ class CodeQualityService:
# Update violation status
violation = self.get_violation_by_id(db, violation_id)
if violation:
violation.status = 'assigned'
violation.status = "assigned"
violation.assigned_to = user_id
# Create assignment record
@@ -248,7 +260,7 @@ class CodeQualityService:
user_id=user_id,
assigned_by=assigned_by,
due_date=due_date,
priority=priority
priority=priority,
)
db.add(assignment)
db.commit()
@@ -257,11 +269,7 @@ class CodeQualityService:
return assignment
def resolve_violation(
self,
db: Session,
violation_id: int,
resolved_by: int,
resolution_note: str
self, db: Session, violation_id: int, resolved_by: int, resolution_note: str
) -> ArchitectureViolation:
"""
Mark violation as resolved
@@ -279,7 +287,7 @@ class CodeQualityService:
if not violation:
raise ValueError(f"Violation {violation_id} not found")
violation.status = 'resolved'
violation.status = "resolved"
violation.resolved_at = datetime.now()
violation.resolved_by = resolved_by
violation.resolution_note = resolution_note
@@ -289,11 +297,7 @@ class CodeQualityService:
return violation
def ignore_violation(
self,
db: Session,
violation_id: int,
ignored_by: int,
reason: str
self, db: Session, violation_id: int, ignored_by: int, reason: str
) -> ArchitectureViolation:
"""
Mark violation as ignored/won't fix
@@ -311,7 +315,7 @@ class CodeQualityService:
if not violation:
raise ValueError(f"Violation {violation_id} not found")
violation.status = 'ignored'
violation.status = "ignored"
violation.resolved_at = datetime.now()
violation.resolved_by = ignored_by
violation.resolution_note = f"Ignored: {reason}"
@@ -321,11 +325,7 @@ class CodeQualityService:
return violation
def add_comment(
self,
db: Session,
violation_id: int,
user_id: int,
comment: str
self, db: Session, violation_id: int, user_id: int, comment: str
) -> ViolationComment:
"""
Add comment to violation
@@ -340,9 +340,7 @@ class CodeQualityService:
ViolationComment object
"""
comment_obj = ViolationComment(
violation_id=violation_id,
user_id=user_id,
comment=comment
violation_id=violation_id, user_id=user_id, comment=comment
)
db.add(comment_obj)
db.commit()
@@ -360,79 +358,95 @@ class CodeQualityService:
latest_scan = self.get_latest_scan(db)
if not latest_scan:
return {
'total_violations': 0,
'errors': 0,
'warnings': 0,
'open': 0,
'assigned': 0,
'resolved': 0,
'ignored': 0,
'technical_debt_score': 100,
'trend': [],
'by_severity': {},
'by_rule': {},
'by_module': {},
'top_files': []
"total_violations": 0,
"errors": 0,
"warnings": 0,
"open": 0,
"assigned": 0,
"resolved": 0,
"ignored": 0,
"technical_debt_score": 100,
"trend": [],
"by_severity": {},
"by_rule": {},
"by_module": {},
"top_files": [],
}
# Get violation counts by status
status_counts = db.query(
ArchitectureViolation.status,
func.count(ArchitectureViolation.id)
).filter(
ArchitectureViolation.scan_id == latest_scan.id
).group_by(ArchitectureViolation.status).all()
status_counts = (
db.query(ArchitectureViolation.status, func.count(ArchitectureViolation.id))
.filter(ArchitectureViolation.scan_id == latest_scan.id)
.group_by(ArchitectureViolation.status)
.all()
)
status_dict = {status: count for status, count in status_counts}
# Get violations by severity
severity_counts = db.query(
ArchitectureViolation.severity,
func.count(ArchitectureViolation.id)
).filter(
ArchitectureViolation.scan_id == latest_scan.id
).group_by(ArchitectureViolation.severity).all()
severity_counts = (
db.query(
ArchitectureViolation.severity, func.count(ArchitectureViolation.id)
)
.filter(ArchitectureViolation.scan_id == latest_scan.id)
.group_by(ArchitectureViolation.severity)
.all()
)
by_severity = {sev: count for sev, count in severity_counts}
# Get violations by rule
rule_counts = db.query(
ArchitectureViolation.rule_id,
func.count(ArchitectureViolation.id)
).filter(
ArchitectureViolation.scan_id == latest_scan.id
).group_by(ArchitectureViolation.rule_id).all()
rule_counts = (
db.query(
ArchitectureViolation.rule_id, func.count(ArchitectureViolation.id)
)
.filter(ArchitectureViolation.scan_id == latest_scan.id)
.group_by(ArchitectureViolation.rule_id)
.all()
)
by_rule = {rule: count for rule, count in sorted(rule_counts, key=lambda x: x[1], reverse=True)[:10]}
by_rule = {
rule: count
for rule, count in sorted(rule_counts, key=lambda x: x[1], reverse=True)[
:10
]
}
# Get top violating files
file_counts = db.query(
file_counts = (
db.query(
ArchitectureViolation.file_path,
func.count(ArchitectureViolation.id).label('count')
).filter(
ArchitectureViolation.scan_id == latest_scan.id
).group_by(ArchitectureViolation.file_path)\
.order_by(desc('count'))\
.limit(10).all()
func.count(ArchitectureViolation.id).label("count"),
)
.filter(ArchitectureViolation.scan_id == latest_scan.id)
.group_by(ArchitectureViolation.file_path)
.order_by(desc("count"))
.limit(10)
.all()
)
top_files = [{'file': file, 'count': count} for file, count in file_counts]
top_files = [{"file": file, "count": count} for file, count in file_counts]
# Get violations by module (extract module from file path)
by_module = {}
violations = db.query(ArchitectureViolation.file_path).filter(
ArchitectureViolation.scan_id == latest_scan.id
).all()
violations = (
db.query(ArchitectureViolation.file_path)
.filter(ArchitectureViolation.scan_id == latest_scan.id)
.all()
)
for v in violations:
path_parts = v.file_path.split('/')
path_parts = v.file_path.split("/")
if len(path_parts) >= 2:
module = '/'.join(path_parts[:2]) # e.g., 'app/api'
module = "/".join(path_parts[:2]) # e.g., 'app/api'
else:
module = path_parts[0]
by_module[module] = by_module.get(module, 0) + 1
# Sort by count and take top 10
by_module = dict(sorted(by_module.items(), key=lambda x: x[1], reverse=True)[:10])
by_module = dict(
sorted(by_module.items(), key=lambda x: x[1], reverse=True)[:10]
)
# Calculate technical debt score
tech_debt_score = self.calculate_technical_debt_score(db, latest_scan.id)
@@ -441,29 +455,29 @@ class CodeQualityService:
trend_scans = self.get_scan_history(db, limit=7)
trend = [
{
'timestamp': scan.timestamp.isoformat(),
'violations': scan.total_violations,
'errors': scan.errors,
'warnings': scan.warnings
"timestamp": scan.timestamp.isoformat(),
"violations": scan.total_violations,
"errors": scan.errors,
"warnings": scan.warnings,
}
for scan in reversed(trend_scans) # Oldest first for chart
]
return {
'total_violations': latest_scan.total_violations,
'errors': latest_scan.errors,
'warnings': latest_scan.warnings,
'open': status_dict.get('open', 0),
'assigned': status_dict.get('assigned', 0),
'resolved': status_dict.get('resolved', 0),
'ignored': status_dict.get('ignored', 0),
'technical_debt_score': tech_debt_score,
'trend': trend,
'by_severity': by_severity,
'by_rule': by_rule,
'by_module': by_module,
'top_files': top_files,
'last_scan': latest_scan.timestamp.isoformat() if latest_scan else None
"total_violations": latest_scan.total_violations,
"errors": latest_scan.errors,
"warnings": latest_scan.warnings,
"open": status_dict.get("open", 0),
"assigned": status_dict.get("assigned", 0),
"resolved": status_dict.get("resolved", 0),
"ignored": status_dict.get("ignored", 0),
"technical_debt_score": tech_debt_score,
"trend": trend,
"by_severity": by_severity,
"by_rule": by_rule,
"by_module": by_module,
"top_files": top_files,
"last_scan": latest_scan.timestamp.isoformat() if latest_scan else None,
}
def calculate_technical_debt_score(self, db: Session, scan_id: int = None) -> int:
@@ -497,10 +511,7 @@ class CodeQualityService:
"""Get current git commit hash"""
try:
result = subprocess.run(
['git', 'rev-parse', 'HEAD'],
capture_output=True,
text=True,
timeout=5
["git", "rev-parse", "HEAD"], capture_output=True, text=True, timeout=5
)
if result.returncode == 0:
return result.stdout.strip()[:40]

View File

@@ -19,8 +19,9 @@ This allows:
import logging
from datetime import datetime, timezone
from typing import List, Optional
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
from sqlalchemy.orm import Session
from models.database.content_page import ContentPage
@@ -35,7 +36,7 @@ class ContentPageService:
db: Session,
slug: str,
vendor_id: Optional[int] = None,
include_unpublished: bool = False
include_unpublished: bool = False,
) -> Optional[ContentPage]:
"""
Get content page for a vendor with fallback to platform default.
@@ -62,28 +63,20 @@ class ContentPageService:
if vendor_id:
vendor_page = (
db.query(ContentPage)
.filter(
and_(
ContentPage.vendor_id == vendor_id,
*filters
)
)
.filter(and_(ContentPage.vendor_id == vendor_id, *filters))
.first()
)
if vendor_page:
logger.debug(f"Found vendor-specific page: {slug} for vendor_id={vendor_id}")
logger.debug(
f"Found vendor-specific page: {slug} for vendor_id={vendor_id}"
)
return vendor_page
# Fallback to platform default
platform_page = (
db.query(ContentPage)
.filter(
and_(
ContentPage.vendor_id == None,
*filters
)
)
.filter(and_(ContentPage.vendor_id == None, *filters))
.first()
)
@@ -100,7 +93,7 @@ class ContentPageService:
vendor_id: Optional[int] = None,
include_unpublished: bool = False,
footer_only: bool = False,
header_only: bool = False
header_only: bool = False,
) -> List[ContentPage]:
"""
List all available pages for a vendor (includes vendor overrides + platform defaults).
@@ -133,12 +126,7 @@ class ContentPageService:
if vendor_id:
vendor_pages = (
db.query(ContentPage)
.filter(
and_(
ContentPage.vendor_id == vendor_id,
*filters
)
)
.filter(and_(ContentPage.vendor_id == vendor_id, *filters))
.order_by(ContentPage.display_order, ContentPage.title)
.all()
)
@@ -146,12 +134,7 @@ class ContentPageService:
# Get platform defaults
platform_pages = (
db.query(ContentPage)
.filter(
and_(
ContentPage.vendor_id == None,
*filters
)
)
.filter(and_(ContentPage.vendor_id == None, *filters))
.order_by(ContentPage.display_order, ContentPage.title)
.all()
)
@@ -159,8 +142,7 @@ class ContentPageService:
# Merge: vendor overrides take precedence
vendor_slugs = {page.slug for page in vendor_pages}
all_pages = vendor_pages + [
page for page in platform_pages
if page.slug not in vendor_slugs
page for page in platform_pages if page.slug not in vendor_slugs
]
# Sort by display_order
@@ -183,7 +165,7 @@ class ContentPageService:
show_in_footer: bool = True,
show_in_header: bool = False,
display_order: int = 0,
created_by: Optional[int] = None
created_by: Optional[int] = None,
) -> ContentPage:
"""
Create a new content page.
@@ -229,7 +211,9 @@ class ContentPageService:
db.commit()
db.refresh(page)
logger.info(f"Created content page: {slug} (vendor_id={vendor_id}, id={page.id})")
logger.info(
f"Created content page: {slug} (vendor_id={vendor_id}, id={page.id})"
)
return page
@staticmethod
@@ -246,7 +230,7 @@ class ContentPageService:
show_in_footer: Optional[bool] = None,
show_in_header: Optional[bool] = None,
display_order: Optional[int] = None,
updated_by: Optional[int] = None
updated_by: Optional[int] = None,
) -> Optional[ContentPage]:
"""
Update an existing content page.
@@ -338,9 +322,7 @@ class ContentPageService:
@staticmethod
def list_all_vendor_pages(
db: Session,
vendor_id: int,
include_unpublished: bool = False
db: Session, vendor_id: int, include_unpublished: bool = False
) -> List[ContentPage]:
"""
List only vendor-specific pages (no platform defaults).
@@ -367,8 +349,7 @@ class ContentPageService:
@staticmethod
def list_all_platform_pages(
db: Session,
include_unpublished: bool = False
db: Session, include_unpublished: bool = False
) -> List[ContentPage]:
"""
List only platform default pages.

View File

@@ -8,24 +8,24 @@ with complete vendor isolation.
import logging
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
from sqlalchemy.orm import Session
from sqlalchemy import and_
from typing import Any, Dict, Optional
from sqlalchemy import and_
from sqlalchemy.orm import Session
from app.exceptions.customer import (CustomerAlreadyExistsException,
CustomerNotActiveException,
CustomerNotFoundException,
CustomerValidationException,
DuplicateCustomerEmailException,
InvalidCustomerCredentialsException)
from app.exceptions.vendor import (VendorNotActiveException,
VendorNotFoundException)
from app.services.auth_service import AuthService
from models.database.customer import Customer, CustomerAddress
from models.database.vendor import Vendor
from models.schema.customer import CustomerRegister, CustomerUpdate
from models.schema.auth import UserLogin
from app.exceptions.customer import (
CustomerNotFoundException,
CustomerAlreadyExistsException,
CustomerNotActiveException,
InvalidCustomerCredentialsException,
CustomerValidationException,
DuplicateCustomerEmailException
)
from app.exceptions.vendor import VendorNotFoundException, VendorNotActiveException
from app.services.auth_service import AuthService
from models.schema.customer import CustomerRegister, CustomerUpdate
logger = logging.getLogger(__name__)
@@ -37,10 +37,7 @@ class CustomerService:
self.auth_service = AuthService()
def register_customer(
self,
db: Session,
vendor_id: int,
customer_data: CustomerRegister
self, db: Session, vendor_id: int, customer_data: CustomerRegister
) -> Customer:
"""
Register a new customer for a specific vendor.
@@ -68,18 +65,26 @@ class CustomerService:
raise VendorNotActiveException(vendor.vendor_code)
# Check if email already exists for this vendor
existing_customer = db.query(Customer).filter(
existing_customer = (
db.query(Customer)
.filter(
and_(
Customer.vendor_id == vendor_id,
Customer.email == customer_data.email.lower()
Customer.email == customer_data.email.lower(),
)
)
.first()
)
).first()
if existing_customer:
raise DuplicateCustomerEmailException(customer_data.email, vendor.vendor_code)
raise DuplicateCustomerEmailException(
customer_data.email, vendor.vendor_code
)
# Generate unique customer number for this vendor
customer_number = self._generate_customer_number(db, vendor_id, vendor.vendor_code)
customer_number = self._generate_customer_number(
db, vendor_id, vendor.vendor_code
)
# Hash password
hashed_password = self.auth_service.hash_password(customer_data.password)
@@ -93,8 +98,12 @@ class CustomerService:
last_name=customer_data.last_name,
phone=customer_data.phone,
customer_number=customer_number,
marketing_consent=customer_data.marketing_consent if hasattr(customer_data, 'marketing_consent') else False,
is_active=True
marketing_consent=(
customer_data.marketing_consent
if hasattr(customer_data, "marketing_consent")
else False
),
is_active=True,
)
try:
@@ -114,15 +123,11 @@ class CustomerService:
db.rollback()
logger.error(f"Error registering customer: {str(e)}")
raise CustomerValidationException(
message="Failed to register customer",
details={"error": str(e)}
message="Failed to register customer", details={"error": str(e)}
)
def login_customer(
self,
db: Session,
vendor_id: int,
credentials: UserLogin
self, db: Session, vendor_id: int, credentials: UserLogin
) -> Dict[str, Any]:
"""
Authenticate customer and generate JWT token.
@@ -146,20 +151,23 @@ class CustomerService:
raise VendorNotFoundException(str(vendor_id), identifier_type="id")
# Find customer by email (vendor-scoped)
customer = db.query(Customer).filter(
customer = (
db.query(Customer)
.filter(
and_(
Customer.vendor_id == vendor_id,
Customer.email == credentials.email_or_username.lower()
Customer.email == credentials.email_or_username.lower(),
)
)
.first()
)
).first()
if not customer:
raise InvalidCustomerCredentialsException()
# Verify password using auth_manager directly
if not self.auth_service.auth_manager.verify_password(
credentials.password,
customer.hashed_password
credentials.password, customer.hashed_password
):
raise InvalidCustomerCredentialsException()
@@ -170,6 +178,7 @@ class CustomerService:
# Generate JWT token with customer context
# Use auth_manager directly since Customer is not a User model
from datetime import datetime, timedelta, timezone
from jose import jwt
auth_manager = self.auth_service.auth_manager
@@ -185,7 +194,9 @@ class CustomerService:
"iat": datetime.now(timezone.utc),
}
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
token = jwt.encode(
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
)
token_data = {
"access_token": token,
@@ -198,17 +209,9 @@ class CustomerService:
f"for vendor {vendor.vendor_code}"
)
return {
"customer": customer,
"token_data": token_data
}
return {"customer": customer, "token_data": token_data}
def get_customer(
self,
db: Session,
vendor_id: int,
customer_id: int
) -> Customer:
def get_customer(self, db: Session, vendor_id: int, customer_id: int) -> Customer:
"""
Get customer by ID with vendor isolation.
@@ -223,12 +226,11 @@ class CustomerService:
Raises:
CustomerNotFoundException: If customer not found
"""
customer = db.query(Customer).filter(
and_(
Customer.id == customer_id,
Customer.vendor_id == vendor_id
customer = (
db.query(Customer)
.filter(and_(Customer.id == customer_id, Customer.vendor_id == vendor_id))
.first()
)
).first()
if not customer:
raise CustomerNotFoundException(str(customer_id))
@@ -236,10 +238,7 @@ class CustomerService:
return customer
def get_customer_by_email(
self,
db: Session,
vendor_id: int,
email: str
self, db: Session, vendor_id: int, email: str
) -> Optional[Customer]:
"""
Get customer by email (vendor-scoped).
@@ -252,19 +251,20 @@ class CustomerService:
Returns:
Optional[Customer]: Customer object or None
"""
return db.query(Customer).filter(
and_(
Customer.vendor_id == vendor_id,
Customer.email == email.lower()
return (
db.query(Customer)
.filter(
and_(Customer.vendor_id == vendor_id, Customer.email == email.lower())
)
.first()
)
).first()
def update_customer(
self,
db: Session,
vendor_id: int,
customer_id: int,
customer_data: CustomerUpdate
customer_data: CustomerUpdate,
) -> Customer:
"""
Update customer profile.
@@ -290,13 +290,17 @@ class CustomerService:
for field, value in update_data.items():
if field == "email" and value:
# Check if new email already exists for this vendor
existing = db.query(Customer).filter(
existing = (
db.query(Customer)
.filter(
and_(
Customer.vendor_id == vendor_id,
Customer.email == value.lower(),
Customer.id != customer_id
Customer.id != customer_id,
)
)
.first()
)
).first()
if existing:
raise DuplicateCustomerEmailException(value, "vendor")
@@ -317,15 +321,11 @@ class CustomerService:
db.rollback()
logger.error(f"Error updating customer: {str(e)}")
raise CustomerValidationException(
message="Failed to update customer",
details={"error": str(e)}
message="Failed to update customer", details={"error": str(e)}
)
def deactivate_customer(
self,
db: Session,
vendor_id: int,
customer_id: int
self, db: Session, vendor_id: int, customer_id: int
) -> Customer:
"""
Deactivate customer account.
@@ -352,10 +352,7 @@ class CustomerService:
return customer
def update_customer_stats(
self,
db: Session,
customer_id: int,
order_total: float
self, db: Session, customer_id: int, order_total: float
) -> None:
"""
Update customer statistics after order.
@@ -377,10 +374,7 @@ class CustomerService:
logger.debug(f"Updated stats for customer {customer.email}")
def _generate_customer_number(
self,
db: Session,
vendor_id: int,
vendor_code: str
self, db: Session, vendor_id: int, vendor_code: str
) -> str:
"""
Generate unique customer number for vendor.
@@ -397,21 +391,23 @@ class CustomerService:
str: Unique customer number
"""
# Get count of customers for this vendor
count = db.query(Customer).filter(
Customer.vendor_id == vendor_id
).count()
count = db.query(Customer).filter(Customer.vendor_id == vendor_id).count()
# Generate number with padding
sequence = str(count + 1).zfill(5)
customer_number = f"{vendor_code.upper()}-CUST-{sequence}"
# Ensure uniqueness (in case of deletions)
while db.query(Customer).filter(
while (
db.query(Customer)
.filter(
and_(
Customer.vendor_id == vendor_id,
Customer.customer_number == customer_number
Customer.customer_number == customer_number,
)
).first():
)
.first()
):
count += 1
sequence = str(count + 1).zfill(5)
customer_number = f"{vendor_code.upper()}-CUST-{sequence}"

View File

@@ -5,27 +5,20 @@ from typing import List, Optional
from sqlalchemy.orm import Session
from app.exceptions import (
InventoryNotFoundException,
InsufficientInventoryException,
from app.exceptions import (InsufficientInventoryException,
InvalidInventoryOperationException,
InvalidQuantityException,
InventoryNotFoundException,
InventoryValidationException,
NegativeInventoryException,
InvalidQuantityException,
ValidationException,
ProductNotFoundException,
)
from models.schema.inventory import (
InventoryCreate,
InventoryAdjust,
InventoryUpdate,
InventoryReserve,
InventoryLocationResponse,
ProductInventorySummary
)
ProductNotFoundException, ValidationException)
from models.database.inventory import Inventory
from models.database.product import Product
from models.database.vendor import Vendor
from models.schema.inventory import (InventoryAdjust, InventoryCreate,
InventoryLocationResponse,
InventoryReserve, InventoryUpdate,
ProductInventorySummary)
logger = logging.getLogger(__name__)
@@ -93,7 +86,11 @@ class InventoryService:
)
return new_inventory
except (ProductNotFoundException, InvalidQuantityException, InventoryValidationException):
except (
ProductNotFoundException,
InvalidQuantityException,
InventoryValidationException,
):
db.rollback()
raise
except Exception as e:
@@ -124,7 +121,9 @@ class InventoryService:
location = self._validate_location(inventory_data.location)
# Check if inventory exists
existing = self._get_inventory_entry(db, inventory_data.product_id, location)
existing = self._get_inventory_entry(
db, inventory_data.product_id, location
)
if not existing:
# Create new if adding, error if removing
@@ -173,8 +172,12 @@ class InventoryService:
)
return existing
except (ProductNotFoundException, InventoryNotFoundException,
InsufficientInventoryException, InventoryValidationException):
except (
ProductNotFoundException,
InventoryNotFoundException,
InsufficientInventoryException,
InventoryValidationException,
):
db.rollback()
raise
except Exception as e:
@@ -231,8 +234,12 @@ class InventoryService:
)
return inventory
except (ProductNotFoundException, InventoryNotFoundException,
InsufficientInventoryException, InvalidQuantityException):
except (
ProductNotFoundException,
InventoryNotFoundException,
InsufficientInventoryException,
InvalidQuantityException,
):
db.rollback()
raise
except Exception as e:
@@ -287,7 +294,11 @@ class InventoryService:
)
return inventory
except (ProductNotFoundException, InventoryNotFoundException, InvalidQuantityException):
except (
ProductNotFoundException,
InventoryNotFoundException,
InvalidQuantityException,
):
db.rollback()
raise
except Exception as e:
@@ -349,8 +360,12 @@ class InventoryService:
)
return inventory
except (ProductNotFoundException, InventoryNotFoundException,
InsufficientInventoryException, InvalidQuantityException):
except (
ProductNotFoundException,
InventoryNotFoundException,
InsufficientInventoryException,
InvalidQuantityException,
):
db.rollback()
raise
except Exception as e:
@@ -376,9 +391,7 @@ class InventoryService:
product = self._get_vendor_product(db, vendor_id, product_id)
inventory_entries = (
db.query(Inventory)
.filter(Inventory.product_id == product_id)
.all()
db.query(Inventory).filter(Inventory.product_id == product_id).all()
)
if not inventory_entries:
@@ -425,8 +438,13 @@ class InventoryService:
raise ValidationException("Failed to retrieve product inventory")
def get_vendor_inventory(
self, db: Session, vendor_id: int, skip: int = 0, limit: int = 100,
location: Optional[str] = None, low_stock_threshold: Optional[int] = None
self,
db: Session,
vendor_id: int,
skip: int = 0,
limit: int = 100,
location: Optional[str] = None,
low_stock_threshold: Optional[int] = None,
) -> List[Inventory]:
"""
Get all inventory for a vendor with filtering.
@@ -458,8 +476,11 @@ class InventoryService:
raise ValidationException("Failed to retrieve vendor inventory")
def update_inventory(
self, db: Session, vendor_id: int, inventory_id: int,
inventory_update: InventoryUpdate
self,
db: Session,
vendor_id: int,
inventory_id: int,
inventory_update: InventoryUpdate,
) -> Inventory:
"""Update inventory entry."""
try:
@@ -475,7 +496,9 @@ class InventoryService:
inventory.quantity = inventory_update.quantity
if inventory_update.reserved_quantity is not None:
self._validate_quantity(inventory_update.reserved_quantity, allow_zero=True)
self._validate_quantity(
inventory_update.reserved_quantity, allow_zero=True
)
inventory.reserved_quantity = inventory_update.reserved_quantity
if inventory_update.location:
@@ -488,7 +511,11 @@ class InventoryService:
logger.info(f"Updated inventory {inventory_id}")
return inventory
except (InventoryNotFoundException, InvalidQuantityException, InventoryValidationException):
except (
InventoryNotFoundException,
InvalidQuantityException,
InventoryValidationException,
):
db.rollback()
raise
except Exception as e:
@@ -496,9 +523,7 @@ class InventoryService:
logger.error(f"Error updating inventory: {str(e)}")
raise ValidationException("Failed to update inventory")
def delete_inventory(
self, db: Session, vendor_id: int, inventory_id: int
) -> bool:
def delete_inventory(self, db: Session, vendor_id: int, inventory_id: int) -> bool:
"""Delete inventory entry."""
try:
inventory = self._get_inventory_by_id(db, inventory_id)
@@ -521,15 +546,20 @@ class InventoryService:
raise ValidationException("Failed to delete inventory")
# Private helper methods
def _get_vendor_product(self, db: Session, vendor_id: int, product_id: int) -> Product:
def _get_vendor_product(
self, db: Session, vendor_id: int, product_id: int
) -> Product:
"""Get product and verify it belongs to vendor."""
product = db.query(Product).filter(
Product.id == product_id,
Product.vendor_id == vendor_id
).first()
product = (
db.query(Product)
.filter(Product.id == product_id, Product.vendor_id == vendor_id)
.first()
)
if not product:
raise ProductNotFoundException(f"Product {product_id} not found in your catalog")
raise ProductNotFoundException(
f"Product {product_id} not found in your catalog"
)
return product
@@ -539,10 +569,7 @@ class InventoryService:
"""Get inventory entry by product and location."""
return (
db.query(Inventory)
.filter(
Inventory.product_id == product_id,
Inventory.location == location
)
.filter(Inventory.product_id == product_id, Inventory.location == location)
.first()
)

View File

@@ -5,20 +5,15 @@ from typing import List, Optional
from sqlalchemy.orm import Session
from app.exceptions import (
ImportJobNotFoundException,
ImportJobNotOwnedException,
ImportJobCannotBeCancelledException,
from app.exceptions import (ImportJobCannotBeCancelledException,
ImportJobCannotBeDeletedException,
ValidationException,
)
from models.schema.marketplace_import_job import (
MarketplaceImportJobResponse,
MarketplaceImportJobRequest
)
ImportJobNotFoundException,
ImportJobNotOwnedException, ValidationException)
from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.vendor import Vendor
from models.database.user import User
from models.database.vendor import Vendor
from models.schema.marketplace_import_job import (MarketplaceImportJobRequest,
MarketplaceImportJobResponse)
logger = logging.getLogger(__name__)
@@ -31,7 +26,7 @@ class MarketplaceImportJobService:
db: Session,
request: MarketplaceImportJobRequest,
vendor: Vendor, # CHANGED: Vendor object from middleware
user: User
user: User,
) -> MarketplaceImportJob:
"""
Create a new marketplace import job.
@@ -147,7 +142,9 @@ class MarketplaceImportJobService:
marketplace=job.marketplace,
vendor_id=job.vendor_id,
vendor_code=job.vendor.vendor_code if job.vendor else None, # FIXED
vendor_name=job.vendor.name if job.vendor else None, # FIXED: from relationship
vendor_name=(
job.vendor.name if job.vendor else None
), # FIXED: from relationship
source_url=job.source_url,
imported=job.imported_count or 0,
updated=job.updated_count or 0,

View File

@@ -17,19 +17,20 @@ from typing import Generator, List, Optional, Tuple
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from app.exceptions import (
MarketplaceProductNotFoundException,
from app.exceptions import (InvalidMarketplaceProductDataException,
MarketplaceProductAlreadyExistsException,
InvalidMarketplaceProductDataException,
MarketplaceProductNotFoundException,
MarketplaceProductValidationException,
ValidationException,
)
from app.services.marketplace_import_job_service import marketplace_import_job_service
from models.schema.marketplace_product import MarketplaceProductCreate, MarketplaceProductUpdate
from models.schema.inventory import InventoryLocationResponse, InventorySummaryResponse
from models.database.marketplace_product import MarketplaceProduct
from models.database.inventory import Inventory
ValidationException)
from app.services.marketplace_import_job_service import \
marketplace_import_job_service
from app.utils.data_processing import GTINProcessor, PriceProcessor
from models.database.inventory import Inventory
from models.database.marketplace_product import MarketplaceProduct
from models.schema.inventory import (InventoryLocationResponse,
InventorySummaryResponse)
from models.schema.marketplace_product import (MarketplaceProductCreate,
MarketplaceProductUpdate)
logger = logging.getLogger(__name__)
@@ -42,14 +43,18 @@ class MarketplaceProductService:
self.gtin_processor = GTINProcessor()
self.price_processor = PriceProcessor()
def create_product(self, db: Session, product_data: MarketplaceProductCreate) -> MarketplaceProduct:
def create_product(
self, db: Session, product_data: MarketplaceProductCreate
) -> MarketplaceProduct:
"""Create a new product with validation."""
try:
# Process and validate GTIN if provided
if product_data.gtin:
normalized_gtin = self.gtin_processor.normalize(product_data.gtin)
if not normalized_gtin:
raise InvalidMarketplaceProductDataException("Invalid GTIN format", field="gtin")
raise InvalidMarketplaceProductDataException(
"Invalid GTIN format", field="gtin"
)
product_data.gtin = normalized_gtin
# Process price if provided
@@ -70,11 +75,18 @@ class MarketplaceProductService:
product_data.marketplace = "Letzshop"
# Validate required fields
if not product_data.marketplace_product_id or not product_data.marketplace_product_id.strip():
raise MarketplaceProductValidationException("MarketplaceProduct ID is required", field="marketplace_product_id")
if (
not product_data.marketplace_product_id
or not product_data.marketplace_product_id.strip()
):
raise MarketplaceProductValidationException(
"MarketplaceProduct ID is required", field="marketplace_product_id"
)
if not product_data.title or not product_data.title.strip():
raise MarketplaceProductValidationException("MarketplaceProduct title is required", field="title")
raise MarketplaceProductValidationException(
"MarketplaceProduct title is required", field="title"
)
db_product = MarketplaceProduct(**product_data.model_dump())
db.add(db_product)
@@ -84,30 +96,47 @@ class MarketplaceProductService:
logger.info(f"Created product {db_product.marketplace_product_id}")
return db_product
except (InvalidMarketplaceProductDataException, MarketplaceProductValidationException):
except (
InvalidMarketplaceProductDataException,
MarketplaceProductValidationException,
):
db.rollback()
raise # Re-raise custom exceptions
except IntegrityError as e:
db.rollback()
logger.error(f"Database integrity error: {str(e)}")
if "marketplace_product_id" in str(e).lower() or "unique" in str(e).lower():
raise MarketplaceProductAlreadyExistsException(product_data.marketplace_product_id)
raise MarketplaceProductAlreadyExistsException(
product_data.marketplace_product_id
)
else:
raise MarketplaceProductValidationException("Data integrity constraint violation")
raise MarketplaceProductValidationException(
"Data integrity constraint violation"
)
except Exception as e:
db.rollback()
logger.error(f"Error creating product: {str(e)}")
raise ValidationException("Failed to create product")
def get_product_by_id(self, db: Session, marketplace_product_id: str) -> Optional[MarketplaceProduct]:
def get_product_by_id(
self, db: Session, marketplace_product_id: str
) -> Optional[MarketplaceProduct]:
"""Get a product by its ID."""
try:
return db.query(MarketplaceProduct).filter(MarketplaceProduct.marketplace_product_id == marketplace_product_id).first()
return (
db.query(MarketplaceProduct)
.filter(
MarketplaceProduct.marketplace_product_id == marketplace_product_id
)
.first()
)
except Exception as e:
logger.error(f"Error getting product {marketplace_product_id}: {str(e)}")
return None
def get_product_by_id_or_raise(self, db: Session, marketplace_product_id: str) -> MarketplaceProduct:
def get_product_by_id_or_raise(
self, db: Session, marketplace_product_id: str
) -> MarketplaceProduct:
"""
Get a product by its ID or raise exception.
@@ -162,13 +191,19 @@ class MarketplaceProductService:
if brand:
query = query.filter(MarketplaceProduct.brand.ilike(f"%{brand}%"))
if category:
query = query.filter(MarketplaceProduct.google_product_category.ilike(f"%{category}%"))
query = query.filter(
MarketplaceProduct.google_product_category.ilike(f"%{category}%")
)
if availability:
query = query.filter(MarketplaceProduct.availability == availability)
if marketplace:
query = query.filter(MarketplaceProduct.marketplace.ilike(f"%{marketplace}%"))
query = query.filter(
MarketplaceProduct.marketplace.ilike(f"%{marketplace}%")
)
if vendor_name:
query = query.filter(MarketplaceProduct.vendor_name.ilike(f"%{vendor_name}%"))
query = query.filter(
MarketplaceProduct.vendor_name.ilike(f"%{vendor_name}%")
)
if search:
# Search in title, description, marketplace, and name
search_term = f"%{search}%"
@@ -188,7 +223,12 @@ class MarketplaceProductService:
logger.error(f"Error getting products with filters: {str(e)}")
raise ValidationException("Failed to retrieve products")
def update_product(self, db: Session, marketplace_product_id: str, product_update: MarketplaceProductUpdate) -> MarketplaceProduct:
def update_product(
self,
db: Session,
marketplace_product_id: str,
product_update: MarketplaceProductUpdate,
) -> MarketplaceProduct:
"""Update product with validation."""
try:
product = self.get_product_by_id_or_raise(db, marketplace_product_id)
@@ -200,7 +240,9 @@ class MarketplaceProductService:
if "gtin" in update_data and update_data["gtin"]:
normalized_gtin = self.gtin_processor.normalize(update_data["gtin"])
if not normalized_gtin:
raise InvalidMarketplaceProductDataException("Invalid GTIN format", field="gtin")
raise InvalidMarketplaceProductDataException(
"Invalid GTIN format", field="gtin"
)
update_data["gtin"] = normalized_gtin
# Process price if being updated
@@ -217,8 +259,12 @@ class MarketplaceProductService:
raise InvalidMarketplaceProductDataException(str(e), field="price")
# Validate required fields if being updated
if "title" in update_data and (not update_data["title"] or not update_data["title"].strip()):
raise MarketplaceProductValidationException("MarketplaceProduct title cannot be empty", field="title")
if "title" in update_data and (
not update_data["title"] or not update_data["title"].strip()
):
raise MarketplaceProductValidationException(
"MarketplaceProduct title cannot be empty", field="title"
)
for key, value in update_data.items():
setattr(product, key, value)
@@ -230,7 +276,11 @@ class MarketplaceProductService:
logger.info(f"Updated product {marketplace_product_id}")
return product
except (MarketplaceProductNotFoundException, InvalidMarketplaceProductDataException, MarketplaceProductValidationException):
except (
MarketplaceProductNotFoundException,
InvalidMarketplaceProductDataException,
MarketplaceProductValidationException,
):
db.rollback()
raise # Re-raise custom exceptions
except Exception as e:
@@ -272,7 +322,9 @@ class MarketplaceProductService:
logger.error(f"Error deleting product {marketplace_product_id}: {str(e)}")
raise ValidationException("Failed to delete product")
def get_inventory_info(self, db: Session, gtin: str) -> Optional[InventorySummaryResponse]:
def get_inventory_info(
self, db: Session, gtin: str
) -> Optional[InventorySummaryResponse]:
"""
Get inventory information for a product by GTIN.
@@ -290,7 +342,9 @@ class MarketplaceProductService:
total_quantity = sum(entry.quantity for entry in inventory_entries)
locations = [
InventoryLocationResponse(location=entry.location, quantity=entry.quantity)
InventoryLocationResponse(
location=entry.location, quantity=entry.quantity
)
for entry in inventory_entries
]
@@ -305,6 +359,7 @@ class MarketplaceProductService:
import csv
from io import StringIO
from typing import Generator, Optional
from sqlalchemy.orm import Session
def generate_csv_export(
@@ -331,9 +386,18 @@ class MarketplaceProductService:
# Write header row
headers = [
"marketplace_product_id", "title", "description", "link", "image_link",
"availability", "price", "currency", "brand", "gtin",
"marketplace", "name"
"marketplace_product_id",
"title",
"description",
"link",
"image_link",
"availability",
"price",
"currency",
"brand",
"gtin",
"marketplace",
"name",
]
writer.writerow(headers)
yield output.getvalue()
@@ -350,9 +414,13 @@ class MarketplaceProductService:
# Apply marketplace filters
if marketplace:
query = query.filter(MarketplaceProduct.marketplace.ilike(f"%{marketplace}%"))
query = query.filter(
MarketplaceProduct.marketplace.ilike(f"%{marketplace}%")
)
if vendor_name:
query = query.filter(MarketplaceProduct.vendor_name.ilike(f"%{vendor_name}%"))
query = query.filter(
MarketplaceProduct.vendor_name.ilike(f"%{vendor_name}%")
)
products = query.offset(offset).limit(batch_size).all()
if not products:
@@ -392,7 +460,11 @@ class MarketplaceProductService:
"""Check if product exists by ID."""
try:
return (
db.query(MarketplaceProduct).filter(MarketplaceProduct.marketplace_product_id == marketplace_product_id).first()
db.query(MarketplaceProduct)
.filter(
MarketplaceProduct.marketplace_product_id == marketplace_product_id
)
.first()
is not None
)
except Exception as e:
@@ -402,18 +474,27 @@ class MarketplaceProductService:
# Private helper methods
def _validate_product_data(self, product_data: dict) -> None:
"""Validate product data structure."""
required_fields = ['marketplace_product_id', 'title']
required_fields = ["marketplace_product_id", "title"]
for field in required_fields:
if field not in product_data or not product_data[field]:
raise MarketplaceProductValidationException(f"{field} is required", field=field)
raise MarketplaceProductValidationException(
f"{field} is required", field=field
)
def _normalize_product_data(self, product_data: dict) -> dict:
"""Normalize and clean product data."""
normalized = product_data.copy()
# Trim whitespace from string fields
string_fields = ['marketplace_product_id', 'title', 'description', 'brand', 'marketplace', 'name']
string_fields = [
"marketplace_product_id",
"title",
"description",
"brand",
"marketplace",
"name",
]
for field in string_fields:
if field in normalized and normalized[field]:
normalized[field] = normalized[field].strip()

View File

@@ -9,24 +9,21 @@ This module provides:
"""
import logging
from datetime import datetime, timezone
from typing import List, Optional, Tuple
import random
import string
from datetime import datetime, timezone
from typing import List, Optional, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
from sqlalchemy.orm import Session
from models.database.order import Order, OrderItem
from models.database.customer import Customer, CustomerAddress
from models.database.product import Product
from models.schema.order import OrderCreate, OrderUpdate, OrderAddressCreate
from app.exceptions import (
OrderNotFoundException,
ValidationException,
from app.exceptions import (CustomerNotFoundException,
InsufficientInventoryException,
CustomerNotFoundException
)
OrderNotFoundException, ValidationException)
from models.database.customer import Customer, CustomerAddress
from models.database.order import Order, OrderItem
from models.database.product import Product
from models.schema.order import OrderAddressCreate, OrderCreate, OrderUpdate
logger = logging.getLogger(__name__)
@@ -42,12 +39,16 @@ class OrderService:
Example: ORD-1-20250110-A1B2C3
"""
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d")
random_suffix = ''.join(random.choices(string.ascii_uppercase + string.digits, k=6))
random_suffix = "".join(
random.choices(string.ascii_uppercase + string.digits, k=6)
)
order_number = f"ORD-{vendor_id}-{timestamp}-{random_suffix}"
# Ensure uniqueness
while db.query(Order).filter(Order.order_number == order_number).first():
random_suffix = ''.join(random.choices(string.ascii_uppercase + string.digits, k=6))
random_suffix = "".join(
random.choices(string.ascii_uppercase + string.digits, k=6)
)
order_number = f"ORD-{vendor_id}-{timestamp}-{random_suffix}"
return order_number
@@ -58,7 +59,7 @@ class OrderService:
vendor_id: int,
customer_id: int,
address_data: OrderAddressCreate,
address_type: str
address_type: str,
) -> CustomerAddress:
"""Create a customer address for order."""
address = CustomerAddress(
@@ -73,17 +74,14 @@ class OrderService:
city=address_data.city,
postal_code=address_data.postal_code,
country=address_data.country,
is_default=False
is_default=False,
)
db.add(address)
db.flush() # Get ID without committing
return address
def create_order(
self,
db: Session,
vendor_id: int,
order_data: OrderCreate
self, db: Session, vendor_id: int, order_data: OrderCreate
) -> Order:
"""
Create a new order.
@@ -104,12 +102,15 @@ class OrderService:
# Validate customer exists if provided
customer_id = order_data.customer_id
if customer_id:
customer = db.query(Customer).filter(
customer = (
db.query(Customer)
.filter(
and_(
Customer.id == customer_id,
Customer.vendor_id == vendor_id
Customer.id == customer_id, Customer.vendor_id == vendor_id
)
)
.first()
)
).first()
if not customer:
raise CustomerNotFoundException(str(customer_id))
@@ -124,7 +125,7 @@ class OrderService:
vendor_id=vendor_id,
customer_id=customer_id,
address_data=order_data.shipping_address,
address_type="shipping"
address_type="shipping",
)
# Create billing address (use shipping if not provided)
@@ -134,7 +135,7 @@ class OrderService:
vendor_id=vendor_id,
customer_id=customer_id,
address_data=order_data.billing_address,
address_type="billing"
address_type="billing",
)
else:
billing_address = shipping_address
@@ -145,23 +146,29 @@ class OrderService:
for item_data in order_data.items:
# Get product
product = db.query(Product).filter(
product = (
db.query(Product)
.filter(
and_(
Product.id == item_data.product_id,
Product.vendor_id == vendor_id,
Product.is_active == True
Product.is_active == True,
)
)
.first()
)
).first()
if not product:
raise ValidationException(f"Product {item_data.product_id} not found")
raise ValidationException(
f"Product {item_data.product_id} not found"
)
# Check inventory
if product.available_inventory < item_data.quantity:
raise InsufficientInventoryException(
product_id=product.id,
requested=item_data.quantity,
available=product.available_inventory
available=product.available_inventory,
)
# Calculate item total
@@ -172,14 +179,16 @@ class OrderService:
item_total = unit_price * item_data.quantity
subtotal += item_total
order_items_data.append({
order_items_data.append(
{
"product_id": product.id,
"product_name": product.marketplace_product.title,
"product_sku": product.product_id,
"quantity": item_data.quantity,
"unit_price": unit_price,
"total_price": item_total
})
"total_price": item_total,
}
)
# Calculate tax and shipping (simple implementation)
tax_amount = 0.0 # TODO: Implement tax calculation
@@ -205,7 +214,7 @@ class OrderService:
shipping_address_id=shipping_address.id,
billing_address_id=billing_address.id,
shipping_method=order_data.shipping_method,
customer_notes=order_data.customer_notes
customer_notes=order_data.customer_notes,
)
db.add(order)
@@ -213,10 +222,7 @@ class OrderService:
# Create order items
for item_data in order_items_data:
order_item = OrderItem(
order_id=order.id,
**item_data
)
order_item = OrderItem(order_id=order.id, **item_data)
db.add(order_item)
db.commit()
@@ -229,7 +235,11 @@ class OrderService:
return order
except (ValidationException, InsufficientInventoryException, CustomerNotFoundException):
except (
ValidationException,
InsufficientInventoryException,
CustomerNotFoundException,
):
db.rollback()
raise
except Exception as e:
@@ -237,19 +247,13 @@ class OrderService:
logger.error(f"Error creating order: {str(e)}")
raise ValidationException(f"Failed to create order: {str(e)}")
def get_order(
self,
db: Session,
vendor_id: int,
order_id: int
) -> Order:
def get_order(self, db: Session, vendor_id: int, order_id: int) -> Order:
"""Get order by ID."""
order = db.query(Order).filter(
and_(
Order.id == order_id,
Order.vendor_id == vendor_id
order = (
db.query(Order)
.filter(and_(Order.id == order_id, Order.vendor_id == vendor_id))
.first()
)
).first()
if not order:
raise OrderNotFoundException(str(order_id))
@@ -263,7 +267,7 @@ class OrderService:
skip: int = 0,
limit: int = 100,
status: Optional[str] = None,
customer_id: Optional[int] = None
customer_id: Optional[int] = None,
) -> Tuple[List[Order], int]:
"""
Get orders for vendor with filtering.
@@ -301,23 +305,15 @@ class OrderService:
vendor_id: int,
customer_id: int,
skip: int = 0,
limit: int = 100
limit: int = 100,
) -> Tuple[List[Order], int]:
"""Get orders for a specific customer."""
return self.get_vendor_orders(
db=db,
vendor_id=vendor_id,
skip=skip,
limit=limit,
customer_id=customer_id
db=db, vendor_id=vendor_id, skip=skip, limit=limit, customer_id=customer_id
)
def update_order_status(
self,
db: Session,
vendor_id: int,
order_id: int,
order_update: OrderUpdate
self, db: Session, vendor_id: int, order_id: int, order_update: OrderUpdate
) -> Order:
"""
Update order status and tracking information.

View File

@@ -14,14 +14,11 @@ from typing import List, Optional, Tuple
from sqlalchemy.orm import Session
from app.exceptions import (
ProductNotFoundException,
ProductAlreadyExistsException,
ValidationException,
)
from models.schema.product import ProductCreate, ProductUpdate
from models.database.product import Product
from app.exceptions import (ProductAlreadyExistsException,
ProductNotFoundException, ValidationException)
from models.database.marketplace_product import MarketplaceProduct
from models.database.product import Product
from models.schema.product import ProductCreate, ProductUpdate
logger = logging.getLogger(__name__)
@@ -45,10 +42,11 @@ class ProductService:
ProductNotFoundException: If product not found
"""
try:
product = db.query(Product).filter(
Product.id == product_id,
Product.vendor_id == vendor_id
).first()
product = (
db.query(Product)
.filter(Product.id == product_id, Product.vendor_id == vendor_id)
.first()
)
if not product:
raise ProductNotFoundException(f"Product {product_id} not found")
@@ -81,10 +79,14 @@ class ProductService:
"""
try:
# Verify marketplace product exists and belongs to vendor
marketplace_product = db.query(MarketplaceProduct).filter(
marketplace_product = (
db.query(MarketplaceProduct)
.filter(
MarketplaceProduct.id == product_data.marketplace_product_id,
MarketplaceProduct.vendor_id == vendor_id
).first()
MarketplaceProduct.vendor_id == vendor_id,
)
.first()
)
if not marketplace_product:
raise ValidationException(
@@ -92,10 +94,15 @@ class ProductService:
)
# Check if already in catalog
existing = db.query(Product).filter(
existing = (
db.query(Product)
.filter(
Product.vendor_id == vendor_id,
Product.marketplace_product_id == product_data.marketplace_product_id
).first()
Product.marketplace_product_id
== product_data.marketplace_product_id,
)
.first()
)
if existing:
raise ProductAlreadyExistsException(
@@ -122,9 +129,7 @@ class ProductService:
db.commit()
db.refresh(product)
logger.info(
f"Added product {product.id} to vendor {vendor_id} catalog"
)
logger.info(f"Added product {product.id} to vendor {vendor_id} catalog")
return product
except (ProductAlreadyExistsException, ValidationException):
@@ -136,7 +141,11 @@ class ProductService:
raise ValidationException("Failed to create product")
def update_product(
self, db: Session, vendor_id: int, product_id: int, product_update: ProductUpdate
self,
db: Session,
vendor_id: int,
product_id: int,
product_update: ProductUpdate,
) -> Product:
"""
Update product in vendor catalog.

View File

@@ -10,25 +10,21 @@ This module provides:
"""
import logging
from typing import Any, Dict, List
from datetime import datetime, timedelta
from typing import Any, Dict, List
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.exceptions import (
VendorNotFoundException,
AdminOperationException,
)
from models.database.marketplace_product import MarketplaceProduct
from models.database.product import Product
from models.database.inventory import Inventory
from models.database.vendor import Vendor
from models.database.order import Order
from models.database.user import User
from app.exceptions import AdminOperationException, VendorNotFoundException
from models.database.customer import Customer
from models.database.inventory import Inventory
from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.marketplace_product import MarketplaceProduct
from models.database.order import Order
from models.database.product import Product
from models.database.user import User
from models.database.vendor import Vendor
logger = logging.getLogger(__name__)
@@ -62,63 +58,77 @@ class StatsService:
try:
# Catalog statistics
total_catalog_products = db.query(Product).filter(
Product.vendor_id == vendor_id,
Product.is_active == True
).count()
total_catalog_products = (
db.query(Product)
.filter(Product.vendor_id == vendor_id, Product.is_active == True)
.count()
)
featured_products = db.query(Product).filter(
featured_products = (
db.query(Product)
.filter(
Product.vendor_id == vendor_id,
Product.is_featured == True,
Product.is_active == True
).count()
Product.is_active == True,
)
.count()
)
# Staging statistics
# TODO: This is fragile - MarketplaceProduct uses vendor_name (string) not vendor_id
# Should add vendor_id foreign key to MarketplaceProduct for robust querying
# For now, matching by vendor name which could fail if names don't match exactly
staging_products = db.query(MarketplaceProduct).filter(
MarketplaceProduct.vendor_name == vendor.name
).count()
staging_products = (
db.query(MarketplaceProduct)
.filter(MarketplaceProduct.vendor_name == vendor.name)
.count()
)
# Inventory statistics
total_inventory = db.query(
func.sum(Inventory.quantity)
).filter(
Inventory.vendor_id == vendor_id
).scalar() or 0
total_inventory = (
db.query(func.sum(Inventory.quantity))
.filter(Inventory.vendor_id == vendor_id)
.scalar()
or 0
)
reserved_inventory = db.query(
func.sum(Inventory.reserved_quantity)
).filter(
Inventory.vendor_id == vendor_id
).scalar() or 0
reserved_inventory = (
db.query(func.sum(Inventory.reserved_quantity))
.filter(Inventory.vendor_id == vendor_id)
.scalar()
or 0
)
inventory_locations = db.query(
func.count(func.distinct(Inventory.location))
).filter(
Inventory.vendor_id == vendor_id
).scalar() or 0
inventory_locations = (
db.query(func.count(func.distinct(Inventory.location)))
.filter(Inventory.vendor_id == vendor_id)
.scalar()
or 0
)
# Import statistics
total_imports = db.query(MarketplaceImportJob).filter(
MarketplaceImportJob.vendor_id == vendor_id
).count()
total_imports = (
db.query(MarketplaceImportJob)
.filter(MarketplaceImportJob.vendor_id == vendor_id)
.count()
)
successful_imports = db.query(MarketplaceImportJob).filter(
successful_imports = (
db.query(MarketplaceImportJob)
.filter(
MarketplaceImportJob.vendor_id == vendor_id,
MarketplaceImportJob.status == "completed"
).count()
MarketplaceImportJob.status == "completed",
)
.count()
)
# Orders
total_orders = db.query(Order).filter(
Order.vendor_id == vendor_id
).count()
total_orders = db.query(Order).filter(Order.vendor_id == vendor_id).count()
# Customers
total_customers = db.query(Customer).filter(
Customer.vendor_id == vendor_id
).count()
total_customers = (
db.query(Customer).filter(Customer.vendor_id == vendor_id).count()
)
return {
"catalog": {
@@ -138,7 +148,11 @@ class StatsService:
"imports": {
"total_imports": total_imports,
"successful_imports": successful_imports,
"success_rate": (successful_imports / total_imports * 100) if total_imports > 0 else 0,
"success_rate": (
(successful_imports / total_imports * 100)
if total_imports > 0
else 0
),
},
"orders": {
"total_orders": total_orders,
@@ -151,12 +165,14 @@ class StatsService:
except VendorNotFoundException:
raise
except Exception as e:
logger.error(f"Failed to retrieve vendor statistics for vendor {vendor_id}: {str(e)}")
logger.error(
f"Failed to retrieve vendor statistics for vendor {vendor_id}: {str(e)}"
)
raise AdminOperationException(
operation="get_vendor_stats",
reason=f"Database query failed: {str(e)}",
target_type="vendor",
target_id=str(vendor_id)
target_id=str(vendor_id),
)
def get_vendor_analytics(
@@ -188,21 +204,28 @@ class StatsService:
start_date = datetime.utcnow() - timedelta(days=days)
# Import activity
recent_imports = db.query(MarketplaceImportJob).filter(
recent_imports = (
db.query(MarketplaceImportJob)
.filter(
MarketplaceImportJob.vendor_id == vendor_id,
MarketplaceImportJob.created_at >= start_date
).count()
MarketplaceImportJob.created_at >= start_date,
)
.count()
)
# Products added to catalog
products_added = db.query(Product).filter(
Product.vendor_id == vendor_id,
Product.created_at >= start_date
).count()
products_added = (
db.query(Product)
.filter(
Product.vendor_id == vendor_id, Product.created_at >= start_date
)
.count()
)
# Inventory changes
inventory_entries = db.query(Inventory).filter(
Inventory.vendor_id == vendor_id
).count()
inventory_entries = (
db.query(Inventory).filter(Inventory.vendor_id == vendor_id).count()
)
return {
"period": period,
@@ -221,12 +244,14 @@ class StatsService:
except VendorNotFoundException:
raise
except Exception as e:
logger.error(f"Failed to retrieve vendor analytics for vendor {vendor_id}: {str(e)}")
logger.error(
f"Failed to retrieve vendor analytics for vendor {vendor_id}: {str(e)}"
)
raise AdminOperationException(
operation="get_vendor_analytics",
reason=f"Database query failed: {str(e)}",
target_type="vendor",
target_id=str(vendor_id)
target_id=str(vendor_id),
)
def get_vendor_statistics(self, db: Session) -> dict:
@@ -234,7 +259,9 @@ class StatsService:
try:
total_vendors = db.query(Vendor).count()
active_vendors = db.query(Vendor).filter(Vendor.is_active == True).count()
verified_vendors = db.query(Vendor).filter(Vendor.is_verified == True).count()
verified_vendors = (
db.query(Vendor).filter(Vendor.is_verified == True).count()
)
inactive_vendors = total_vendors - active_vendors
return {
@@ -242,13 +269,14 @@ class StatsService:
"active_vendors": active_vendors,
"inactive_vendors": inactive_vendors,
"verified_vendors": verified_vendors,
"verification_rate": (verified_vendors / total_vendors * 100) if total_vendors > 0 else 0
"verification_rate": (
(verified_vendors / total_vendors * 100) if total_vendors > 0 else 0
),
}
except Exception as e:
logger.error(f"Failed to get vendor statistics: {str(e)}")
raise AdminOperationException(
operation="get_vendor_statistics",
reason="Database query failed"
operation="get_vendor_statistics", reason="Database query failed"
)
# ========================================================================
@@ -302,7 +330,7 @@ class StatsService:
logger.error(f"Failed to retrieve comprehensive statistics: {str(e)}")
raise AdminOperationException(
operation="get_comprehensive_stats",
reason=f"Database query failed: {str(e)}"
reason=f"Database query failed: {str(e)}",
)
def get_marketplace_breakdown_stats(self, db: Session) -> List[Dict[str, Any]]:
@@ -323,8 +351,12 @@ class StatsService:
db.query(
MarketplaceProduct.marketplace,
func.count(MarketplaceProduct.id).label("total_products"),
func.count(func.distinct(MarketplaceProduct.vendor_name)).label("unique_vendors"),
func.count(func.distinct(MarketplaceProduct.brand)).label("unique_brands"),
func.count(func.distinct(MarketplaceProduct.vendor_name)).label(
"unique_vendors"
),
func.count(func.distinct(MarketplaceProduct.brand)).label(
"unique_brands"
),
)
.filter(MarketplaceProduct.marketplace.isnot(None))
.group_by(MarketplaceProduct.marketplace)
@@ -342,10 +374,12 @@ class StatsService:
]
except Exception as e:
logger.error(f"Failed to retrieve marketplace breakdown statistics: {str(e)}")
logger.error(
f"Failed to retrieve marketplace breakdown statistics: {str(e)}"
)
raise AdminOperationException(
operation="get_marketplace_breakdown_stats",
reason=f"Database query failed: {str(e)}"
reason=f"Database query failed: {str(e)}",
)
def get_user_statistics(self, db: Session) -> Dict[str, Any]:
@@ -372,13 +406,14 @@ class StatsService:
"active_users": active_users,
"inactive_users": inactive_users,
"admin_users": admin_users,
"activation_rate": (active_users / total_users * 100) if total_users > 0 else 0
"activation_rate": (
(active_users / total_users * 100) if total_users > 0 else 0
),
}
except Exception as e:
logger.error(f"Failed to get user statistics: {str(e)}")
raise AdminOperationException(
operation="get_user_statistics",
reason="Database query failed"
operation="get_user_statistics", reason="Database query failed"
)
def get_import_statistics(self, db: Session) -> Dict[str, Any]:
@@ -396,18 +431,22 @@ class StatsService:
"""
try:
total = db.query(MarketplaceImportJob).count()
completed = db.query(MarketplaceImportJob).filter(
MarketplaceImportJob.status == "completed"
).count()
failed = db.query(MarketplaceImportJob).filter(
MarketplaceImportJob.status == "failed"
).count()
completed = (
db.query(MarketplaceImportJob)
.filter(MarketplaceImportJob.status == "completed")
.count()
)
failed = (
db.query(MarketplaceImportJob)
.filter(MarketplaceImportJob.status == "failed")
.count()
)
return {
"total_imports": total,
"completed_imports": completed,
"failed_imports": failed,
"success_rate": (completed / total * 100) if total > 0 else 0
"success_rate": (completed / total * 100) if total > 0 else 0,
}
except Exception as e:
logger.error(f"Failed to get import statistics: {str(e)}")
@@ -415,7 +454,7 @@ class StatsService:
"total_imports": 0,
"completed_imports": 0,
"failed_imports": 0,
"success_rate": 0
"success_rate": 0,
}
def get_order_statistics(self, db: Session) -> Dict[str, Any]:
@@ -431,11 +470,7 @@ class StatsService:
Note:
TODO: Implement when Order model is fully available
"""
return {
"total_orders": 0,
"pending_orders": 0,
"completed_orders": 0
}
return {"total_orders": 0, "pending_orders": 0, "completed_orders": 0}
def get_product_statistics(self, db: Session) -> Dict[str, Any]:
"""
@@ -450,11 +485,7 @@ class StatsService:
Note:
TODO: Implement when Product model is fully available
"""
return {
"total_products": 0,
"active_products": 0,
"out_of_stock": 0
}
return {"total_products": 0, "active_products": 0, "out_of_stock": 0}
# ========================================================================
# PRIVATE HELPER METHODS
@@ -491,8 +522,7 @@ class StatsService:
return (
db.query(MarketplaceProduct.brand)
.filter(
MarketplaceProduct.brand.isnot(None),
MarketplaceProduct.brand != ""
MarketplaceProduct.brand.isnot(None), MarketplaceProduct.brand != ""
)
.distinct()
.count()

View File

@@ -9,17 +9,15 @@ This module provides:
"""
import logging
from typing import List, Dict, Any
from datetime import datetime, timezone
from typing import Any, Dict, List
from sqlalchemy.orm import Session
from app.exceptions import (
ValidationException,
UnauthorizedVendorAccessException,
)
from models.database.vendor import VendorUser, Role
from app.exceptions import (UnauthorizedVendorAccessException,
ValidationException)
from models.database.user import User
from models.database.vendor import Role, VendorUser
logger = logging.getLogger(__name__)
@@ -42,14 +40,16 @@ class TeamService:
List of team members
"""
try:
vendor_users = db.query(VendorUser).filter(
VendorUser.vendor_id == vendor_id,
VendorUser.is_active == True
).all()
vendor_users = (
db.query(VendorUser)
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
.all()
)
members = []
for vu in vendor_users:
members.append({
members.append(
{
"id": vu.user_id,
"email": vu.user.email,
"first_name": vu.user.first_name,
@@ -58,7 +58,8 @@ class TeamService:
"role_id": vu.role_id,
"is_active": vu.is_active,
"joined_at": vu.created_at,
})
}
)
return members
@@ -100,7 +101,7 @@ class TeamService:
vendor_id: int,
user_id: int,
update_data: dict,
current_user: User
current_user: User,
) -> Dict[str, Any]:
"""
Update team member role or status.
@@ -116,10 +117,13 @@ class TeamService:
Updated member info
"""
try:
vendor_user = db.query(VendorUser).filter(
VendorUser.vendor_id == vendor_id,
VendorUser.user_id == user_id
).first()
vendor_user = (
db.query(VendorUser)
.filter(
VendorUser.vendor_id == vendor_id, VendorUser.user_id == user_id
)
.first()
)
if not vendor_user:
raise ValidationException("Team member not found")
@@ -161,10 +165,13 @@ class TeamService:
True if removed
"""
try:
vendor_user = db.query(VendorUser).filter(
VendorUser.vendor_id == vendor_id,
VendorUser.user_id == user_id
).first()
vendor_user = (
db.query(VendorUser)
.filter(
VendorUser.vendor_id == vendor_id, VendorUser.user_id == user_id
)
.first()
)
if not vendor_user:
raise ValidationException("Team member not found")

View File

@@ -12,30 +12,28 @@ This module provides classes and functions for:
import logging
import secrets
from typing import List, Tuple, Optional
from datetime import datetime, timezone
from typing import List, Optional, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import and_
from sqlalchemy.orm import Session
from app.exceptions import (
VendorNotFoundException,
VendorDomainNotFoundException,
VendorDomainAlreadyExistsException,
InvalidDomainFormatException,
ReservedDomainException,
from app.exceptions import (DNSVerificationException,
DomainAlreadyVerifiedException,
DomainNotVerifiedException,
DomainVerificationFailedException,
DomainAlreadyVerifiedException,
MultiplePrimaryDomainsException,
DNSVerificationException,
InvalidDomainFormatException,
MaxDomainsReachedException,
MultiplePrimaryDomainsException,
ReservedDomainException,
UnauthorizedDomainAccessException,
ValidationException,
)
from models.schema.vendor_domain import VendorDomainCreate, VendorDomainUpdate
VendorDomainAlreadyExistsException,
VendorDomainNotFoundException,
VendorNotFoundException)
from models.database.vendor import Vendor
from models.database.vendor_domain import VendorDomain
from models.schema.vendor_domain import VendorDomainCreate, VendorDomainUpdate
logger = logging.getLogger(__name__)
@@ -45,13 +43,19 @@ class VendorDomainService:
def __init__(self):
self.max_domains_per_vendor = 10 # Configure as needed
self.reserved_subdomains = ['www', 'admin', 'api', 'mail', 'smtp', 'ftp', 'cpanel', 'webmail']
self.reserved_subdomains = [
"www",
"admin",
"api",
"mail",
"smtp",
"ftp",
"cpanel",
"webmail",
]
def add_domain(
self,
db: Session,
vendor_id: int,
domain_data: VendorDomainCreate
self, db: Session, vendor_id: int, domain_data: VendorDomainCreate
) -> VendorDomain:
"""
Add a custom domain to vendor.
@@ -85,12 +89,14 @@ class VendorDomainService:
# Check if domain already exists
if self._domain_exists(db, normalized_domain):
existing_domain = db.query(VendorDomain).filter(
VendorDomain.domain == normalized_domain
).first()
existing_domain = (
db.query(VendorDomain)
.filter(VendorDomain.domain == normalized_domain)
.first()
)
raise VendorDomainAlreadyExistsException(
normalized_domain,
existing_domain.vendor_id if existing_domain else None
existing_domain.vendor_id if existing_domain else None,
)
# If setting as primary, unset other primary domains
@@ -105,7 +111,7 @@ class VendorDomainService:
verification_token=secrets.token_urlsafe(32),
is_verified=False, # Requires DNS verification
is_active=False, # Cannot be active until verified
ssl_status="pending"
ssl_status="pending",
)
db.add(new_domain)
@@ -120,7 +126,7 @@ class VendorDomainService:
VendorDomainAlreadyExistsException,
MaxDomainsReachedException,
InvalidDomainFormatException,
ReservedDomainException
ReservedDomainException,
):
db.rollback()
raise
@@ -129,11 +135,7 @@ class VendorDomainService:
logger.error(f"Error adding domain: {str(e)}")
raise ValidationException("Failed to add domain")
def get_vendor_domains(
self,
db: Session,
vendor_id: int
) -> List[VendorDomain]:
def get_vendor_domains(self, db: Session, vendor_id: int) -> List[VendorDomain]:
"""
Get all domains for a vendor.
@@ -151,12 +153,14 @@ class VendorDomainService:
# Verify vendor exists
self._get_vendor_by_id_or_raise(db, vendor_id)
domains = db.query(VendorDomain).filter(
VendorDomain.vendor_id == vendor_id
).order_by(
VendorDomain.is_primary.desc(),
VendorDomain.created_at.desc()
).all()
domains = (
db.query(VendorDomain)
.filter(VendorDomain.vendor_id == vendor_id)
.order_by(
VendorDomain.is_primary.desc(), VendorDomain.created_at.desc()
)
.all()
)
return domains
@@ -166,11 +170,7 @@ class VendorDomainService:
logger.error(f"Error getting vendor domains: {str(e)}")
raise ValidationException("Failed to retrieve domains")
def get_domain_by_id(
self,
db: Session,
domain_id: int
) -> VendorDomain:
def get_domain_by_id(self, db: Session, domain_id: int) -> VendorDomain:
"""
Get domain by ID.
@@ -190,10 +190,7 @@ class VendorDomainService:
return domain
def update_domain(
self,
db: Session,
domain_id: int,
domain_update: VendorDomainUpdate
self, db: Session, domain_id: int, domain_update: VendorDomainUpdate
) -> VendorDomain:
"""
Update domain settings.
@@ -215,7 +212,9 @@ class VendorDomainService:
# If setting as primary, unset other primary domains
if domain_update.is_primary:
self._unset_primary_domains(db, domain.vendor_id, exclude_domain_id=domain_id)
self._unset_primary_domains(
db, domain.vendor_id, exclude_domain_id=domain_id
)
domain.is_primary = True
# If activating, check verification
@@ -240,11 +239,7 @@ class VendorDomainService:
logger.error(f"Error updating domain: {str(e)}")
raise ValidationException("Failed to update domain")
def delete_domain(
self,
db: Session,
domain_id: int
) -> str:
def delete_domain(self, db: Session, domain_id: int) -> str:
"""
Delete a custom domain.
@@ -277,11 +272,7 @@ class VendorDomainService:
logger.error(f"Error deleting domain: {str(e)}")
raise ValidationException("Failed to delete domain")
def verify_domain(
self,
db: Session,
domain_id: int
) -> Tuple[VendorDomain, str]:
def verify_domain(self, db: Session, domain_id: int) -> Tuple[VendorDomain, str]:
"""
Verify domain ownership via DNS TXT record.
@@ -313,8 +304,7 @@ class VendorDomainService:
# Query DNS TXT records
try:
txt_records = dns.resolver.resolve(
f"_wizamart-verify.{domain.domain}",
'TXT'
f"_wizamart-verify.{domain.domain}", "TXT"
)
# Check if verification token is present
@@ -332,42 +322,33 @@ class VendorDomainService:
# Token not found
raise DomainVerificationFailedException(
domain.domain,
"Verification token not found in DNS records"
domain.domain, "Verification token not found in DNS records"
)
except dns.resolver.NXDOMAIN:
raise DomainVerificationFailedException(
domain.domain,
f"DNS record _wizamart-verify.{domain.domain} not found"
f"DNS record _wizamart-verify.{domain.domain} not found",
)
except dns.resolver.NoAnswer:
raise DomainVerificationFailedException(
domain.domain,
"No TXT records found for verification"
domain.domain, "No TXT records found for verification"
)
except Exception as dns_error:
raise DNSVerificationException(
domain.domain,
str(dns_error)
)
raise DNSVerificationException(domain.domain, str(dns_error))
except (
VendorDomainNotFoundException,
DomainAlreadyVerifiedException,
DomainVerificationFailedException,
DNSVerificationException
DNSVerificationException,
):
raise
except Exception as e:
logger.error(f"Error verifying domain: {str(e)}")
raise ValidationException("Failed to verify domain")
def get_verification_instructions(
self,
db: Session,
domain_id: int
) -> dict:
def get_verification_instructions(self, db: Session, domain_id: int) -> dict:
"""
Get DNS verification instructions for domain.
@@ -390,20 +371,20 @@ class VendorDomainService:
"step1": "Go to your domain's DNS settings (at your domain registrar)",
"step2": "Add a new TXT record with the following values:",
"step3": "Wait for DNS propagation (5-15 minutes)",
"step4": "Click 'Verify Domain' button in admin panel"
"step4": "Click 'Verify Domain' button in admin panel",
},
"txt_record": {
"type": "TXT",
"name": "_wizamart-verify",
"value": domain.verification_token,
"ttl": 3600
"ttl": 3600,
},
"common_registrars": {
"Cloudflare": "https://dash.cloudflare.com",
"GoDaddy": "https://dcc.godaddy.com/manage/dns",
"Namecheap": "https://www.namecheap.com/myaccount/domain-list/",
"Google Domains": "https://domains.google.com"
}
"Google Domains": "https://domains.google.com",
},
}
# Private helper methods
@@ -416,36 +397,33 @@ class VendorDomainService:
def _check_domain_limit(self, db: Session, vendor_id: int) -> None:
"""Check if vendor has reached maximum domain limit."""
domain_count = db.query(VendorDomain).filter(
VendorDomain.vendor_id == vendor_id
).count()
domain_count = (
db.query(VendorDomain).filter(VendorDomain.vendor_id == vendor_id).count()
)
if domain_count >= self.max_domains_per_vendor:
raise MaxDomainsReachedException(vendor_id, self.max_domains_per_vendor)
def _domain_exists(self, db: Session, domain: str) -> bool:
"""Check if domain already exists in system."""
return db.query(VendorDomain).filter(
VendorDomain.domain == domain
).first() is not None
return (
db.query(VendorDomain).filter(VendorDomain.domain == domain).first()
is not None
)
def _validate_domain_format(self, domain: str) -> None:
"""Validate domain format and check for reserved subdomains."""
# Check for reserved subdomains
first_part = domain.split('.')[0]
first_part = domain.split(".")[0]
if first_part in self.reserved_subdomains:
raise ReservedDomainException(domain, first_part)
def _unset_primary_domains(
self,
db: Session,
vendor_id: int,
exclude_domain_id: Optional[int] = None
self, db: Session, vendor_id: int, exclude_domain_id: Optional[int] = None
) -> None:
"""Unset all primary domains for vendor."""
query = db.query(VendorDomain).filter(
VendorDomain.vendor_id == vendor_id,
VendorDomain.is_primary == True
VendorDomain.vendor_id == vendor_id, VendorDomain.is_primary == True
)
if exclude_domain_id:

View File

@@ -15,22 +15,19 @@ from typing import List, Optional, Tuple
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.exceptions import (
VendorNotFoundException,
VendorAlreadyExistsException,
UnauthorizedVendorAccessException,
InvalidVendorDataException,
from app.exceptions import (InvalidVendorDataException,
MarketplaceProductNotFoundException,
ProductAlreadyExistsException,
MaxVendorsReachedException,
ValidationException,
)
from models.schema.vendor import VendorCreate
from models.schema.product import ProductCreate
ProductAlreadyExistsException,
UnauthorizedVendorAccessException,
ValidationException, VendorAlreadyExistsException,
VendorNotFoundException)
from models.database.marketplace_product import MarketplaceProduct
from models.database.vendor import Vendor
from models.database.product import Product
from models.database.user import User
from models.database.vendor import Vendor
from models.schema.product import ProductCreate
from models.schema.vendor import VendorCreate
logger = logging.getLogger(__name__)
@@ -91,7 +88,11 @@ class VendorService:
)
return new_vendor
except (VendorAlreadyExistsException, MaxVendorsReachedException, InvalidVendorDataException):
except (
VendorAlreadyExistsException,
MaxVendorsReachedException,
InvalidVendorDataException,
):
db.rollback()
raise # Re-raise custom exceptions
except Exception as e:
@@ -129,7 +130,10 @@ class VendorService:
if current_user.role != "admin":
query = query.filter(
(Vendor.is_active == True)
& ((Vendor.is_verified == True) | (Vendor.owner_user_id == current_user.id))
& (
(Vendor.is_verified == True)
| (Vendor.owner_user_id == current_user.id)
)
)
else:
# Admin can apply filters
@@ -147,7 +151,9 @@ class VendorService:
logger.error(f"Error getting vendors: {str(e)}")
raise ValidationException("Failed to retrieve vendors")
def get_vendor_by_code(self, db: Session, vendor_code: str, current_user: User) -> Vendor:
def get_vendor_by_code(
self, db: Session, vendor_code: str, current_user: User
) -> Vendor:
"""
Get vendor by vendor code with access control.
@@ -205,11 +211,15 @@ class VendorService:
"""
try:
# Check if product exists
marketplace_product = self._get_product_by_id_or_raise(db, product.marketplace_product_id)
marketplace_product = self._get_product_by_id_or_raise(
db, product.marketplace_product_id
)
# Check if product already in vendor
if self._product_in_catalog(db, vendor.id, marketplace_product.id):
raise ProductAlreadyExistsException(vendor.vendor_code, product.marketplace_product_id)
raise ProductAlreadyExistsException(
vendor.vendor_code, product.marketplace_product_id
)
# Create vendor -product association
new_product = Product(
@@ -225,7 +235,9 @@ class VendorService:
# Load the product relationship
db.refresh(new_product)
logger.info(f"MarketplaceProduct {product.marketplace_product_id} added to vendor {vendor.vendor_code}")
logger.info(
f"MarketplaceProduct {product.marketplace_product_id} added to vendor {vendor.vendor_code}"
)
return new_product
except (MarketplaceProductNotFoundException, ProductAlreadyExistsException):
@@ -267,7 +279,9 @@ class VendorService:
try:
# Check access permissions
if not self._can_access_vendor(vendor, current_user):
raise UnauthorizedVendorAccessException(vendor.vendor_code, current_user.id)
raise UnauthorizedVendorAccessException(
vendor.vendor_code, current_user.id
)
# Query vendor products
query = db.query(Product).filter(Product.vendor_id == vendor.id)
@@ -292,17 +306,20 @@ class VendorService:
def _validate_vendor_data(self, vendor_data: VendorCreate) -> None:
"""Validate vendor creation data."""
if not vendor_data.vendor_code or not vendor_data.vendor_code.strip():
raise InvalidVendorDataException("Vendor code is required", field="vendor_code")
raise InvalidVendorDataException(
"Vendor code is required", field="vendor_code"
)
if not vendor_data.vendor_name or not vendor_data.vendor_name.strip():
raise InvalidVendorDataException("Vendor name is required", field="name")
# Validate vendor code format (alphanumeric, underscores, hyphens)
import re
if not re.match(r'^[A-Za-z0-9_-]+$', vendor_data.vendor_code):
if not re.match(r"^[A-Za-z0-9_-]+$", vendor_data.vendor_code):
raise InvalidVendorDataException(
"Vendor code can only contain letters, numbers, underscores, and hyphens",
field="vendor_code"
field="vendor_code",
)
def _check_vendor_limit(self, db: Session, user: User) -> None:
@@ -310,7 +327,9 @@ class VendorService:
if user.role == "admin":
return # Admins have no limit
user_vendor_count = db.query(Vendor).filter(Vendor.owner_user_id == user.id).count()
user_vendor_count = (
db.query(Vendor).filter(Vendor.owner_user_id == user.id).count()
)
max_vendors = 5 # Configure this as needed
if user_vendor_count >= max_vendors:
@@ -321,25 +340,35 @@ class VendorService:
return (
db.query(Vendor)
.filter(func.upper(Vendor.vendor_code) == vendor_code.upper())
.first() is not None
.first()
is not None
)
def _get_product_by_id_or_raise(self, db: Session, marketplace_product_id: str) -> MarketplaceProduct:
def _get_product_by_id_or_raise(
self, db: Session, marketplace_product_id: str
) -> MarketplaceProduct:
"""Get product by ID or raise exception."""
product = db.query(MarketplaceProduct).filter(MarketplaceProduct.marketplace_product_id == marketplace_product_id).first()
product = (
db.query(MarketplaceProduct)
.filter(MarketplaceProduct.marketplace_product_id == marketplace_product_id)
.first()
)
if not product:
raise MarketplaceProductNotFoundException(marketplace_product_id)
return product
def _product_in_catalog(self, db: Session, vendor_id: int, marketplace_product_id: int) -> bool:
def _product_in_catalog(
self, db: Session, vendor_id: int, marketplace_product_id: int
) -> bool:
"""Check if product is already in vendor."""
return (
db.query(Product)
.filter(
Product.vendor_id == vendor_id,
Product.marketplace_product_id == marketplace_product_id
Product.marketplace_product_id == marketplace_product_id,
)
.first() is not None
.first()
is not None
)
def _can_access_vendor(self, vendor: Vendor, user: User) -> bool:
@@ -355,5 +384,6 @@ class VendorService:
"""Check if user is vendor owner."""
return vendor.owner_user_id == user.id
# Create service instance following the same pattern as other services
vendor_service = VendorService()

View File

@@ -11,23 +11,20 @@ Handles:
import logging
import secrets
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any
from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session
from app.core.permissions import get_preset_permissions
from app.exceptions import (
TeamMemberAlreadyExistsException,
from app.exceptions import (CannotRemoveOwnerException,
InvalidInvitationTokenException,
TeamInvitationAlreadyAcceptedException,
MaxTeamMembersReachedException,
UserNotFoundException,
VendorNotFoundException,
CannotRemoveOwnerException,
)
from models.database.user import User
from models.database.vendor import Vendor, VendorUser, VendorUserType, Role
TeamInvitationAlreadyAcceptedException,
TeamMemberAlreadyExistsException,
UserNotFoundException, VendorNotFoundException)
from middleware.auth import AuthManager
from models.database.user import User
from models.database.vendor import Role, Vendor, VendorUser, VendorUserType
logger = logging.getLogger(__name__)
@@ -69,10 +66,14 @@ class VendorTeamService:
"""
try:
# Check team size limit
current_team_size = db.query(VendorUser).filter(
current_team_size = (
db.query(VendorUser)
.filter(
VendorUser.vendor_id == vendor.id,
VendorUser.is_active == True,
).count()
)
.count()
)
if current_team_size >= self.max_team_members:
raise MaxTeamMembersReachedException(
@@ -85,22 +86,34 @@ class VendorTeamService:
if user:
# Check if already a member
existing_membership = db.query(VendorUser).filter(
existing_membership = (
db.query(VendorUser)
.filter(
VendorUser.vendor_id == vendor.id,
VendorUser.user_id == user.id,
).first()
)
.first()
)
if existing_membership:
if existing_membership.is_active:
raise TeamMemberAlreadyExistsException(email, vendor.vendor_code)
raise TeamMemberAlreadyExistsException(
email, vendor.vendor_code
)
# Reactivate old membership
existing_membership.is_active = False # Will be activated on acceptance
existing_membership.invitation_token = self._generate_invitation_token()
existing_membership.is_active = (
False # Will be activated on acceptance
)
existing_membership.invitation_token = (
self._generate_invitation_token()
)
existing_membership.invitation_sent_at = datetime.utcnow()
existing_membership.invitation_accepted_at = None
db.commit()
logger.info(f"Re-invited user {email} to vendor {vendor.vendor_code}")
logger.info(
f"Re-invited user {email} to vendor {vendor.vendor_code}"
)
return {
"invitation_token": existing_membership.invitation_token,
"email": email,
@@ -108,7 +121,7 @@ class VendorTeamService:
}
else:
# Create new user account (inactive until invitation accepted)
username = email.split('@')[0]
username = email.split("@")[0]
# Ensure unique username
base_username = username
counter = 1
@@ -201,9 +214,13 @@ class VendorTeamService:
"""
try:
# Find invitation
vendor_user = db.query(VendorUser).filter(
vendor_user = (
db.query(VendorUser)
.filter(
VendorUser.invitation_token == invitation_token,
).first()
)
.first()
)
if not vendor_user:
raise InvalidInvitationTokenException()
@@ -247,7 +264,10 @@ class VendorTeamService:
"role": vendor_user.role.name if vendor_user.role else "member",
}
except (InvalidInvitationTokenException, TeamInvitationAlreadyAcceptedException):
except (
InvalidInvitationTokenException,
TeamInvitationAlreadyAcceptedException,
):
raise
except Exception as e:
db.rollback()
@@ -274,10 +294,14 @@ class VendorTeamService:
True if removed
"""
try:
vendor_user = db.query(VendorUser).filter(
vendor_user = (
db.query(VendorUser)
.filter(
VendorUser.vendor_id == vendor.id,
VendorUser.user_id == user_id,
).first()
)
.first()
)
if not vendor_user:
raise UserNotFoundException(str(user_id))
@@ -322,10 +346,14 @@ class VendorTeamService:
Updated VendorUser
"""
try:
vendor_user = db.query(VendorUser).filter(
vendor_user = (
db.query(VendorUser)
.filter(
VendorUser.vendor_id == vendor.id,
VendorUser.user_id == user_id,
).first()
)
.first()
)
if not vendor_user:
raise UserNotFoundException(str(user_id))
@@ -387,7 +415,8 @@ class VendorTeamService:
members = []
for vu in vendor_users:
members.append({
members.append(
{
"id": vu.user.id,
"email": vu.user.email,
"username": vu.user.username,
@@ -400,7 +429,8 @@ class VendorTeamService:
"invitation_pending": vu.is_invitation_pending,
"invited_at": vu.invitation_sent_at,
"accepted_at": vu.invitation_accepted_at,
})
}
)
return members
@@ -419,10 +449,14 @@ class VendorTeamService:
) -> Role:
"""Get existing role or create new one with preset/custom permissions."""
# Try to find existing role with same name
role = db.query(Role).filter(
role = (
db.query(Role)
.filter(
Role.vendor_id == vendor.id,
Role.name == role_name,
).first()
)
.first()
)
if role and custom_permissions is None:
# Use existing role

View File

@@ -8,32 +8,24 @@ Handles theme CRUD operations, preset application, and validation.
import logging
import re
from typing import Optional, Dict, List
from typing import Dict, List, Optional
from sqlalchemy.orm import Session
from models.database.vendor import Vendor
from models.database.vendor_theme import VendorTheme
from models.schema.vendor_theme import (
VendorThemeUpdate,
ThemePresetPreview
)
from app.core.theme_presets import (THEME_PRESETS, apply_preset,
get_available_presets, get_preset_preview)
from app.exceptions.vendor import VendorNotFoundException
from app.exceptions.vendor_theme import (
VendorThemeNotFoundException,
from app.exceptions.vendor_theme import (InvalidColorFormatException,
InvalidFontFamilyException,
InvalidThemeDataException,
ThemeOperationException,
ThemePresetAlreadyAppliedException,
ThemePresetNotFoundException,
ThemeValidationException,
InvalidColorFormatException,
InvalidFontFamilyException,
ThemePresetAlreadyAppliedException,
ThemeOperationException
)
from app.core.theme_presets import (
apply_preset,
get_available_presets,
get_preset_preview,
THEME_PRESETS
)
VendorThemeNotFoundException)
from models.database.vendor import Vendor
from models.database.vendor_theme import VendorTheme
from models.schema.vendor_theme import ThemePresetPreview, VendorThemeUpdate
logger = logging.getLogger(__name__)
@@ -71,9 +63,9 @@ class VendorThemeService:
Raises:
VendorNotFoundException: If vendor not found
"""
vendor = db.query(Vendor).filter(
Vendor.vendor_code == vendor_code.upper()
).first()
vendor = (
db.query(Vendor).filter(Vendor.vendor_code == vendor_code.upper()).first()
)
if not vendor:
self.logger.warning(f"Vendor not found: {vendor_code}")
@@ -105,12 +97,12 @@ class VendorThemeService:
vendor = self._get_vendor_by_code(db, vendor_code)
# Get theme
theme = db.query(VendorTheme).filter(
VendorTheme.vendor_id == vendor.id
).first()
theme = db.query(VendorTheme).filter(VendorTheme.vendor_id == vendor.id).first()
if not theme:
self.logger.info(f"No custom theme for vendor {vendor_code}, returning default")
self.logger.info(
f"No custom theme for vendor {vendor_code}, returning default"
)
return self._get_default_theme()
return theme.to_dict()
@@ -130,23 +122,16 @@ class VendorThemeService:
"accent": "#ec4899",
"background": "#ffffff",
"text": "#1f2937",
"border": "#e5e7eb"
},
"fonts": {
"heading": "Inter, sans-serif",
"body": "Inter, sans-serif"
"border": "#e5e7eb",
},
"fonts": {"heading": "Inter, sans-serif", "body": "Inter, sans-serif"},
"branding": {
"logo": None,
"logo_dark": None,
"favicon": None,
"banner": None
},
"layout": {
"style": "grid",
"header": "fixed",
"product_card": "modern"
"banner": None,
},
"layout": {"style": "grid", "header": "fixed", "product_card": "modern"},
"social_links": {},
"custom_css": None,
"css_variables": {
@@ -158,7 +143,7 @@ class VendorThemeService:
"--color-border": "#e5e7eb",
"--font-heading": "Inter, sans-serif",
"--font-body": "Inter, sans-serif",
}
},
}
# ============================================================================
@@ -166,10 +151,7 @@ class VendorThemeService:
# ============================================================================
def update_theme(
self,
db: Session,
vendor_code: str,
theme_data: VendorThemeUpdate
self, db: Session, vendor_code: str, theme_data: VendorThemeUpdate
) -> VendorTheme:
"""
Update or create theme for vendor.
@@ -194,9 +176,9 @@ class VendorThemeService:
vendor = self._get_vendor_by_code(db, vendor_code)
# Get or create theme
theme = db.query(VendorTheme).filter(
VendorTheme.vendor_id == vendor.id
).first()
theme = (
db.query(VendorTheme).filter(VendorTheme.vendor_id == vendor.id).first()
)
if not theme:
self.logger.info(f"Creating new theme for vendor {vendor_code}")
@@ -224,15 +206,11 @@ class VendorThemeService:
db.rollback()
self.logger.error(f"Failed to update theme for vendor {vendor_code}: {e}")
raise ThemeOperationException(
operation="update",
vendor_code=vendor_code,
reason=str(e)
operation="update", vendor_code=vendor_code, reason=str(e)
)
def _apply_theme_updates(
self,
theme: VendorTheme,
theme_data: VendorThemeUpdate
self, theme: VendorTheme, theme_data: VendorThemeUpdate
) -> None:
"""
Apply theme updates to theme object.
@@ -251,30 +229,30 @@ class VendorThemeService:
# Update fonts
if theme_data.fonts:
if theme_data.fonts.get('heading'):
theme.font_family_heading = theme_data.fonts['heading']
if theme_data.fonts.get('body'):
theme.font_family_body = theme_data.fonts['body']
if theme_data.fonts.get("heading"):
theme.font_family_heading = theme_data.fonts["heading"]
if theme_data.fonts.get("body"):
theme.font_family_body = theme_data.fonts["body"]
# Update branding
if theme_data.branding:
if theme_data.branding.get('logo') is not None:
theme.logo_url = theme_data.branding['logo']
if theme_data.branding.get('logo_dark') is not None:
theme.logo_dark_url = theme_data.branding['logo_dark']
if theme_data.branding.get('favicon') is not None:
theme.favicon_url = theme_data.branding['favicon']
if theme_data.branding.get('banner') is not None:
theme.banner_url = theme_data.branding['banner']
if theme_data.branding.get("logo") is not None:
theme.logo_url = theme_data.branding["logo"]
if theme_data.branding.get("logo_dark") is not None:
theme.logo_dark_url = theme_data.branding["logo_dark"]
if theme_data.branding.get("favicon") is not None:
theme.favicon_url = theme_data.branding["favicon"]
if theme_data.branding.get("banner") is not None:
theme.banner_url = theme_data.branding["banner"]
# Update layout
if theme_data.layout:
if theme_data.layout.get('style'):
theme.layout_style = theme_data.layout['style']
if theme_data.layout.get('header'):
theme.header_style = theme_data.layout['header']
if theme_data.layout.get('product_card'):
theme.product_card_style = theme_data.layout['product_card']
if theme_data.layout.get("style"):
theme.layout_style = theme_data.layout["style"]
if theme_data.layout.get("header"):
theme.header_style = theme_data.layout["header"]
if theme_data.layout.get("product_card"):
theme.product_card_style = theme_data.layout["product_card"]
# Update custom CSS
if theme_data.custom_css is not None:
@@ -289,10 +267,7 @@ class VendorThemeService:
# ============================================================================
def apply_theme_preset(
self,
db: Session,
vendor_code: str,
preset_name: str
self, db: Session, vendor_code: str, preset_name: str
) -> VendorTheme:
"""
Apply a theme preset to vendor.
@@ -322,9 +297,9 @@ class VendorThemeService:
vendor = self._get_vendor_by_code(db, vendor_code)
# Get or create theme
theme = db.query(VendorTheme).filter(
VendorTheme.vendor_id == vendor.id
).first()
theme = (
db.query(VendorTheme).filter(VendorTheme.vendor_id == vendor.id).first()
)
if not theme:
self.logger.info(f"Creating new theme for vendor {vendor_code}")
@@ -338,7 +313,9 @@ class VendorThemeService:
db.commit()
db.refresh(theme)
self.logger.info(f"Preset '{preset_name}' applied successfully to vendor {vendor_code}")
self.logger.info(
f"Preset '{preset_name}' applied successfully to vendor {vendor_code}"
)
return theme
except (VendorNotFoundException, ThemePresetNotFoundException):
@@ -349,9 +326,7 @@ class VendorThemeService:
db.rollback()
self.logger.error(f"Failed to apply preset to vendor {vendor_code}: {e}")
raise ThemeOperationException(
operation="apply_preset",
vendor_code=vendor_code,
reason=str(e)
operation="apply_preset", vendor_code=vendor_code, reason=str(e)
)
def get_available_presets(self) -> List[ThemePresetPreview]:
@@ -399,9 +374,9 @@ class VendorThemeService:
vendor = self._get_vendor_by_code(db, vendor_code)
# Get theme
theme = db.query(VendorTheme).filter(
VendorTheme.vendor_id == vendor.id
).first()
theme = (
db.query(VendorTheme).filter(VendorTheme.vendor_id == vendor.id).first()
)
if not theme:
raise VendorThemeNotFoundException(vendor_code)
@@ -423,9 +398,7 @@ class VendorThemeService:
db.rollback()
self.logger.error(f"Failed to delete theme for vendor {vendor_code}: {e}")
raise ThemeOperationException(
operation="delete",
vendor_code=vendor_code,
reason=str(e)
operation="delete", vendor_code=vendor_code, reason=str(e)
)
# ============================================================================
@@ -459,9 +432,9 @@ class VendorThemeService:
# Validate layout values
if theme_data.layout:
valid_layouts = {
'style': ['grid', 'list', 'masonry'],
'header': ['fixed', 'static', 'transparent'],
'product_card': ['modern', 'classic', 'minimal']
"style": ["grid", "list", "masonry"],
"header": ["fixed", "static", "transparent"],
"product_card": ["modern", "classic", "minimal"],
}
for layout_key, layout_value in theme_data.layout.items():
@@ -472,7 +445,7 @@ class VendorThemeService:
field=layout_key,
validation_errors={
layout_key: f"Must be one of: {', '.join(valid_layouts[layout_key])}"
}
},
)
def _is_valid_color(self, color: str) -> bool:
@@ -489,7 +462,7 @@ class VendorThemeService:
return False
# Check for hex color format (#RGB or #RRGGBB)
hex_pattern = r'^#([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$'
hex_pattern = r"^#([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$"
return bool(re.match(hex_pattern, color))
def _is_valid_font(self, font: str) -> bool:

View File

@@ -3,9 +3,9 @@ import logging
from datetime import datetime, timezone
from app.core.database import SessionLocal
from app.utils.csv_processor import CSVProcessor
from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.vendor import Vendor
from app.utils.csv_processor import CSVProcessor
logger = logging.getLogger(__name__)
@@ -15,7 +15,7 @@ async def process_marketplace_import(
url: str,
marketplace: str,
vendor_id: int, # FIXED: Changed from vendor_name to vendor_id
batch_size: int = 1000
batch_size: int = 1000,
):
"""Background task to process marketplace CSV import."""
db = SessionLocal()
@@ -59,7 +59,7 @@ async def process_marketplace_import(
marketplace,
vendor_id, # FIXED: Pass vendor_id instead of vendor_name
batch_size,
db
db,
)
# Update job with results

View File

@@ -267,7 +267,9 @@ class CSVProcessor:
# Validate required fields
if not product_data.get("marketplace_product_id"):
logger.warning(f"Row {index}: Missing marketplace_product_id, skipping")
logger.warning(
f"Row {index}: Missing marketplace_product_id, skipping"
)
errors += 1
continue
@@ -279,7 +281,10 @@ class CSVProcessor:
# Check if product exists
existing_product = (
db.query(MarketplaceProduct)
.filter(MarketplaceProduct.marketplace_product_id == literal(product_data["marketplace_product_id"]))
.filter(
MarketplaceProduct.marketplace_product_id
== literal(product_data["marketplace_product_id"])
)
.first()
)

View File

@@ -109,7 +109,9 @@ class PriceProcessor:
r"([A-Z]{3})\s*([0-9.,]+)": lambda m: (m.group(2), m.group(1)),
}
def parse_price_currency(self, price_str: any) -> Tuple[Optional[str], Optional[str]]:
def parse_price_currency(
self, price_str: any
) -> Tuple[Optional[str], Optional[str]]:
"""
Parse a price string to extract the numeric value and currency.

View File

@@ -7,6 +7,7 @@ This module provides utility functions and classes to interact with a database u
"""
import logging
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import QueuePool

142
main.py
View File

@@ -9,13 +9,13 @@ Multi-tenant e-commerce marketplace platform with:
- Middleware stack for context injection
"""
import sys
import io
import sys
# Fix Windows console encoding issues (must be at the very top)
if sys.platform == 'win32':
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
if sys.platform == "win32":
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8")
import logging
from datetime import datetime, timezone
@@ -23,27 +23,25 @@ from pathlib import Path
from fastapi import Depends, FastAPI, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, RedirectResponse, FileResponse
from fastapi.responses import FileResponse, HTMLResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from sqlalchemy import text
from sqlalchemy.orm import Session
from app.api.main import api_router
# Import page routers
from app.routes import admin_pages, vendor_pages, shop_pages
from app.core.config import settings
from app.core.database import get_db
from app.core.lifespan import lifespan
from app.exceptions.handler import setup_exception_handlers
from app.exceptions import ServiceUnavailableException
from app.exceptions.handler import setup_exception_handlers
# Import page routers
from app.routes import admin_pages, shop_pages, vendor_pages
from middleware.context import ContextMiddleware
from middleware.logging import LoggingMiddleware
from middleware.theme_context import ThemeContextMiddleware
# Import REFACTORED class-based middleware
from middleware.vendor_context import VendorContextMiddleware
from middleware.context import ContextMiddleware
from middleware.theme_context import ThemeContextMiddleware
from middleware.logging import LoggingMiddleware
logger = logging.getLogger(__name__)
@@ -146,6 +144,7 @@ app.include_router(api_router, prefix="/api")
# FAVICON ROUTES (Must be registered BEFORE page routers)
# ============================================================================
def serve_favicon() -> Response:
"""
Serve favicon with caching headers.
@@ -164,7 +163,7 @@ def serve_favicon() -> Response:
media_type="image/x-icon",
headers={
"Cache-Control": "public, max-age=86400", # Cache for 1 day
}
},
)
return Response(status_code=204)
@@ -194,10 +193,7 @@ logger.info("=" * 80)
# Admin pages
logger.info("Registering admin page routes: /admin/*")
app.include_router(
admin_pages.router,
prefix="/admin",
tags=["admin-pages"],
include_in_schema=False
admin_pages.router, prefix="/admin", tags=["admin-pages"], include_in_schema=False
)
# Vendor management pages (dashboard, products, orders, etc.)
@@ -206,7 +202,7 @@ app.include_router(
vendor_pages.router,
prefix="/vendor",
tags=["vendor-pages"],
include_in_schema=False
include_in_schema=False,
)
# Customer shop pages - Register at TWO prefixes:
@@ -217,46 +213,42 @@ logger.info(" - /shop/* (subdomain/custom domain mode)")
logger.info(" - /vendors/{code}/shop/* (path-based development mode)")
app.include_router(
shop_pages.router,
prefix="/shop",
tags=["shop-pages"],
include_in_schema=False
shop_pages.router, prefix="/shop", tags=["shop-pages"], include_in_schema=False
)
app.include_router(
shop_pages.router,
prefix="/vendors/{vendor_code}/shop",
tags=["shop-pages"],
include_in_schema=False
include_in_schema=False,
)
# Add handler for /vendors/{vendor_code}/ root path
@app.get("/vendors/{vendor_code}/", response_class=HTMLResponse, include_in_schema=False)
async def vendor_root_path(vendor_code: str, request: Request, db: Session = Depends(get_db)):
@app.get(
"/vendors/{vendor_code}/", response_class=HTMLResponse, include_in_schema=False
)
async def vendor_root_path(
vendor_code: str, request: Request, db: Session = Depends(get_db)
):
"""Handle vendor root path (e.g., /vendors/wizamart/)"""
# Vendor should already be in request.state from middleware
vendor = getattr(request.state, 'vendor', None)
vendor = getattr(request.state, "vendor", None)
if not vendor:
raise HTTPException(status_code=404, detail=f"Vendor '{vendor_code}' not found")
from app.services.content_page_service import content_page_service
from app.routes.shop_pages import get_shop_context
from app.services.content_page_service import content_page_service
# Try to find landing page
landing_page = content_page_service.get_page_for_vendor(
db,
slug="landing",
vendor_id=vendor.id,
include_unpublished=False
db, slug="landing", vendor_id=vendor.id, include_unpublished=False
)
if not landing_page:
landing_page = content_page_service.get_page_for_vendor(
db,
slug="home",
vendor_id=vendor.id,
include_unpublished=False
db, slug="home", vendor_id=vendor.id, include_unpublished=False
)
if landing_page:
@@ -265,8 +257,7 @@ async def vendor_root_path(vendor_code: str, request: Request, db: Session = Dep
template_path = f"vendor/landing-{template_name}.html"
return templates.TemplateResponse(
template_path,
get_shop_context(request, db=db, page=landing_page)
template_path, get_shop_context(request, db=db, page=landing_page)
)
else:
# No landing page - redirect to shop
@@ -298,22 +289,16 @@ async def platform_homepage(request: Request, db: Session = Depends(get_db)):
db,
slug="platform_homepage",
vendor_id=None, # Platform-level page
include_unpublished=False
include_unpublished=False,
)
# Load header and footer navigation
header_pages = content_page_service.list_pages_for_vendor(
db,
vendor_id=None,
header_only=True,
include_unpublished=False
db, vendor_id=None, header_only=True, include_unpublished=False
)
footer_pages = content_page_service.list_pages_for_vendor(
db,
vendor_id=None,
footer_only=True,
include_unpublished=False
db, vendor_id=None, footer_only=True, include_unpublished=False
)
if homepage:
@@ -330,7 +315,7 @@ async def platform_homepage(request: Request, db: Session = Depends(get_db)):
"page": homepage,
"header_pages": header_pages,
"footer_pages": footer_pages,
}
},
)
else:
# Fallback to default static template
@@ -342,15 +327,13 @@ async def platform_homepage(request: Request, db: Session = Depends(get_db)):
"request": request,
"header_pages": header_pages,
"footer_pages": footer_pages,
}
},
)
@app.get("/{slug}", response_class=HTMLResponse, include_in_schema=False)
async def platform_content_page(
request: Request,
slug: str,
db: Session = Depends(get_db)
request: Request, slug: str, db: Session = Depends(get_db)
):
"""
Platform content pages: /about, /faq, /terms, /contact, etc.
@@ -366,10 +349,7 @@ async def platform_content_page(
# Load page from CMS
page = content_page_service.get_page_for_vendor(
db,
slug=slug,
vendor_id=None, # Platform pages only
include_unpublished=False
db, slug=slug, vendor_id=None, include_unpublished=False # Platform pages only
)
if not page:
@@ -378,17 +358,11 @@ async def platform_content_page(
# Load header and footer navigation
header_pages = content_page_service.list_pages_for_vendor(
db,
vendor_id=None,
header_only=True,
include_unpublished=False
db, vendor_id=None, header_only=True, include_unpublished=False
)
footer_pages = content_page_service.list_pages_for_vendor(
db,
vendor_id=None,
footer_only=True,
include_unpublished=False
db, vendor_id=None, footer_only=True, include_unpublished=False
)
logger.info(f"[PLATFORM] Rendering content page: {page.title} (/{slug})")
@@ -400,7 +374,7 @@ async def platform_content_page(
"page": page,
"header_pages": header_pages,
"footer_pages": footer_pages,
}
},
)
@@ -411,8 +385,8 @@ logger.info("=" * 80)
logger.info("REGISTERED ROUTES SUMMARY")
logger.info("=" * 80)
for route in app.routes:
if hasattr(route, 'methods') and hasattr(route, 'path'):
methods = ', '.join(route.methods) if route.methods else 'N/A'
if hasattr(route, "methods") and hasattr(route, "path"):
methods = ", ".join(route.methods) if route.methods else "N/A"
logger.info(f" {methods:<10} {route.path:<60}")
logger.info("=" * 80)
@@ -420,6 +394,7 @@ logger.info("=" * 80)
# API ROUTES (JSON Responses)
# ============================================================================
# Public Routes (no authentication required)
@app.get("/", response_class=HTMLResponse, include_in_schema=False)
async def root(request: Request, db: Session = Depends(get_db)):
@@ -428,7 +403,7 @@ async def root(request: Request, db: Session = Depends(get_db)):
- If vendor detected (domain/subdomain): Show vendor landing page or redirect to shop
- If no vendor (platform root): Redirect to documentation
"""
vendor = getattr(request.state, 'vendor', None)
vendor = getattr(request.state, "vendor", None)
if vendor:
# Vendor context detected - serve landing page
@@ -436,19 +411,13 @@ async def root(request: Request, db: Session = Depends(get_db)):
# Try to find landing page (slug='landing' or 'home')
landing_page = content_page_service.get_page_for_vendor(
db,
slug="landing",
vendor_id=vendor.id,
include_unpublished=False
db, slug="landing", vendor_id=vendor.id, include_unpublished=False
)
if not landing_page:
# Try 'home' slug as fallback
landing_page = content_page_service.get_page_for_vendor(
db,
slug="home",
vendor_id=vendor.id,
include_unpublished=False
db, slug="home", vendor_id=vendor.id, include_unpublished=False
)
if landing_page:
@@ -459,17 +428,26 @@ async def root(request: Request, db: Session = Depends(get_db)):
template_path = f"vendor/landing-{template_name}.html"
return templates.TemplateResponse(
template_path,
get_shop_context(request, db=db, page=landing_page)
template_path, get_shop_context(request, db=db, page=landing_page)
)
else:
# No landing page - redirect to shop
vendor_context = getattr(request.state, 'vendor_context', None)
access_method = vendor_context.get('detection_method', 'unknown') if vendor_context else 'unknown'
vendor_context = getattr(request.state, "vendor_context", None)
access_method = (
vendor_context.get("detection_method", "unknown")
if vendor_context
else "unknown"
)
if access_method == "path":
full_prefix = vendor_context.get('full_prefix', '/vendor/') if vendor_context else '/vendor/'
return RedirectResponse(url=f"{full_prefix}{vendor.subdomain}/shop/", status_code=302)
full_prefix = (
vendor_context.get("full_prefix", "/vendor/")
if vendor_context
else "/vendor/"
)
return RedirectResponse(
url=f"{full_prefix}{vendor.subdomain}/shop/", status_code=302
)
else:
# Domain/subdomain
return RedirectResponse(url="/shop/", status_code=302)

View File

@@ -26,14 +26,10 @@ from jose import jwt
from passlib.context import CryptContext
from sqlalchemy.orm import Session
from app.exceptions import (
AdminRequiredException,
InvalidTokenException,
TokenExpiredException,
UserNotActiveException,
InvalidCredentialsException,
InsufficientPermissionsException
)
from app.exceptions import (AdminRequiredException,
InsufficientPermissionsException,
InvalidCredentialsException, InvalidTokenException,
TokenExpiredException, UserNotActiveException)
from models.database.user import User
logger = logging.getLogger(__name__)
@@ -99,7 +95,9 @@ class AuthManager:
"""
return pwd_context.verify(plain_password, hashed_password)
def authenticate_user(self, db: Session, username: str, password: str) -> Optional[User]:
def authenticate_user(
self, db: Session, username: str, password: str
) -> Optional[User]:
"""Authenticate user credentials against the database.
Supports authentication using either username or email address.
@@ -201,7 +199,9 @@ class AuthManager:
raise InvalidTokenException("Token missing expiration")
# Check if token has expired (additional check beyond jwt.decode)
if datetime.now(timezone.utc) > datetime.fromtimestamp(exp, tz=timezone.utc):
if datetime.now(timezone.utc) > datetime.fromtimestamp(
exp, tz=timezone.utc
):
raise TokenExpiredException()
# Validate user identifier claim exists
@@ -214,7 +214,9 @@ class AuthManager:
"user_id": int(user_id),
"username": payload.get("username"),
"email": payload.get("email"),
"role": payload.get("role", "user"), # Default to "user" role if not specified
"role": payload.get(
"role", "user"
), # Default to "user" role if not specified
}
except jwt.ExpiredSignatureError:
@@ -232,7 +234,9 @@ class AuthManager:
logger.error(f"Token verification error: {e}")
raise InvalidTokenException("Authentication failed")
def get_current_user(self, db: Session, credentials: HTTPAuthorizationCredentials) -> User:
def get_current_user(
self, db: Session, credentials: HTTPAuthorizationCredentials
) -> User:
"""Extract and validate the current authenticated user from request credentials.
Verifies the JWT token from the Authorization header, looks up the user
@@ -286,8 +290,10 @@ class AuthManager:
# This will only execute if user has "admin" role
pass
"""
def decorator(func):
"""Decorator that wraps the function with role checking."""
def wrapper(current_user: User, *args, **kwargs):
# Check if current user has the required role
if current_user.role != required_role:
@@ -339,8 +345,7 @@ class AuthManager:
# Check if user has vendor or admin role (admins have full access)
if current_user.role not in ["vendor", "admin"]:
raise InsufficientPermissionsException(
message="Vendor access required",
required_permission="vendor"
message="Vendor access required", required_permission="vendor"
)
return current_user
@@ -363,7 +368,7 @@ class AuthManager:
if current_user.role not in ["customer", "admin"]:
raise InsufficientPermissionsException(
message="Customer account access required",
required_permission="customer"
required_permission="customer",
)
return current_user

View File

@@ -17,14 +17,16 @@ Class-based middleware provides:
import logging
from enum import Enum
from starlette.middleware.base import BaseHTTPMiddleware
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
logger = logging.getLogger(__name__)
class RequestContext(str, Enum):
"""Request context types for the application."""
API = "api"
ADMIN = "admin"
VENDOR_DASHBOARD = "vendor"
@@ -59,7 +61,7 @@ class ContextManager:
# Use clean_path if available (extracted by vendor_context_middleware)
# Falls back to original path if clean_path not set
# This is critical for correct context detection with path-based routing
path = getattr(request.state, 'clean_path', request.url.path)
path = getattr(request.state, "clean_path", request.url.path)
host = request.headers.get("host", "")
@@ -71,10 +73,10 @@ class ContextManager:
f"[CONTEXT] Detecting context",
extra={
"original_path": request.url.path,
"clean_path": getattr(request.state, 'clean_path', 'NOT SET'),
"clean_path": getattr(request.state, "clean_path", "NOT SET"),
"path_to_check": path,
"host": host,
}
},
)
# 1. API context (highest priority)
@@ -84,24 +86,30 @@ class ContextManager:
# 2. Admin context
if ContextManager._is_admin_context(request, host, path):
logger.debug("[CONTEXT] Detected as ADMIN", extra={"path": path, "host": host})
logger.debug(
"[CONTEXT] Detected as ADMIN", extra={"path": path, "host": host}
)
return RequestContext.ADMIN
# 3. Vendor Dashboard context (vendor management area)
# Check both clean_path and original path for vendor dashboard
original_path = request.url.path
if ContextManager._is_vendor_dashboard_context(path) or \
ContextManager._is_vendor_dashboard_context(original_path):
logger.debug("[CONTEXT] Detected as VENDOR_DASHBOARD", extra={"path": path, "original_path": original_path})
if ContextManager._is_vendor_dashboard_context(
path
) or ContextManager._is_vendor_dashboard_context(original_path):
logger.debug(
"[CONTEXT] Detected as VENDOR_DASHBOARD",
extra={"path": path, "original_path": original_path},
)
return RequestContext.VENDOR_DASHBOARD
# 4. Shop context (vendor storefront)
# Check if vendor context exists (set by vendor_context_middleware)
if hasattr(request.state, 'vendor') and request.state.vendor:
if hasattr(request.state, "vendor") and request.state.vendor:
# If we have a vendor and it's not admin or vendor dashboard, it's shop
logger.debug(
"[CONTEXT] Detected as SHOP (has vendor context)",
extra={"vendor": request.state.vendor.name}
extra={"vendor": request.state.vendor.name},
)
return RequestContext.SHOP
@@ -173,11 +181,12 @@ class ContextMiddleware(BaseHTTPMiddleware):
f"[CONTEXT_MIDDLEWARE] Context detected: {context_type.value}",
extra={
"path": request.url.path,
"clean_path": getattr(request.state, 'clean_path', 'NOT SET'),
"clean_path": getattr(request.state, "clean_path", "NOT SET"),
"host": request.headers.get("host", ""),
"context": context_type.value,
"has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None,
}
"has_vendor": hasattr(request.state, "vendor")
and request.state.vendor is not None,
},
)
# Continue processing

Some files were not shown because too many files have changed in this diff Show More