style: apply black and isort formatting across entire codebase

- Standardize quote style (single to double quotes)
- Reorder and group imports alphabetically
- Fix line breaks and indentation for consistency
- Apply PEP 8 formatting standards

Also updated Makefile to exclude both venv and .venv from code quality checks.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
2025-11-28 19:30:17 +01:00
parent 13f0094743
commit 21c13ca39b
236 changed files with 8450 additions and 6545 deletions

View File

@@ -10,15 +10,15 @@ This module provides functions for:
import logging
from datetime import datetime, timezone
from typing import List, Optional, Dict, Any
from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
from sqlalchemy.orm import Session
from app.exceptions import AdminOperationException
from models.database.admin import AdminAuditLog
from models.database.user import User
from models.schema.admin import AdminAuditLogFilters, AdminAuditLogResponse
from app.exceptions import AdminOperationException
logger = logging.getLogger(__name__)
@@ -36,7 +36,7 @@ class AdminAuditService:
details: Optional[Dict[str, Any]] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
request_id: Optional[str] = None
request_id: Optional[str] = None,
) -> AdminAuditLog:
"""
Log an admin action to the audit trail.
@@ -63,7 +63,7 @@ class AdminAuditService:
details=details or {},
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id
request_id=request_id,
)
db.add(audit_log)
@@ -84,9 +84,7 @@ class AdminAuditService:
return None
def get_audit_logs(
self,
db: Session,
filters: AdminAuditLogFilters
self, db: Session, filters: AdminAuditLogFilters
) -> List[AdminAuditLogResponse]:
"""
Get filtered admin audit logs with pagination.
@@ -98,7 +96,9 @@ class AdminAuditService:
List of audit log responses
"""
try:
query = db.query(AdminAuditLog).join(User, AdminAuditLog.admin_user_id == User.id)
query = db.query(AdminAuditLog).join(
User, AdminAuditLog.admin_user_id == User.id
)
# Apply filters
conditions = []
@@ -123,8 +123,7 @@ class AdminAuditService:
# Execute query with pagination
logs = (
query
.order_by(AdminAuditLog.created_at.desc())
query.order_by(AdminAuditLog.created_at.desc())
.offset(filters.skip)
.limit(filters.limit)
.all()
@@ -143,7 +142,7 @@ class AdminAuditService:
ip_address=log.ip_address,
user_agent=log.user_agent,
request_id=log.request_id,
created_at=log.created_at
created_at=log.created_at,
)
for log in logs
]
@@ -151,15 +150,10 @@ class AdminAuditService:
except Exception as e:
logger.error(f"Failed to retrieve audit logs: {str(e)}")
raise AdminOperationException(
operation="get_audit_logs",
reason="Database query failed"
operation="get_audit_logs", reason="Database query failed"
)
def get_audit_logs_count(
self,
db: Session,
filters: AdminAuditLogFilters
) -> int:
def get_audit_logs_count(self, db: Session, filters: AdminAuditLogFilters) -> int:
"""Get total count of audit logs matching filters."""
try:
query = db.query(AdminAuditLog)
@@ -192,24 +186,14 @@ class AdminAuditService:
return 0
def get_recent_actions_by_admin(
self,
db: Session,
admin_user_id: int,
limit: int = 10
self, db: Session, admin_user_id: int, limit: int = 10
) -> List[AdminAuditLogResponse]:
"""Get recent actions by a specific admin."""
filters = AdminAuditLogFilters(
admin_user_id=admin_user_id,
limit=limit
)
filters = AdminAuditLogFilters(admin_user_id=admin_user_id, limit=limit)
return self.get_audit_logs(db, filters)
def get_actions_by_target(
self,
db: Session,
target_type: str,
target_id: str,
limit: int = 50
self, db: Session, target_type: str, target_id: str, limit: int = 50
) -> List[AdminAuditLogResponse]:
"""Get all actions performed on a specific target."""
try:
@@ -218,7 +202,7 @@ class AdminAuditService:
.filter(
and_(
AdminAuditLog.target_type == target_type,
AdminAuditLog.target_id == str(target_id)
AdminAuditLog.target_id == str(target_id),
)
)
.order_by(AdminAuditLog.created_at.desc())
@@ -236,7 +220,7 @@ class AdminAuditService:
target_id=log.target_id,
details=log.details,
ip_address=log.ip_address,
created_at=log.created_at
created_at=log.created_at,
)
for log in logs
]
@@ -247,4 +231,4 @@ class AdminAuditService:
# Create service instance
admin_audit_service = AdminAuditService()
admin_audit_service = AdminAuditService()

View File

@@ -16,24 +16,19 @@ import string
from datetime import datetime, timezone
from typing import List, Optional, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import func, or_
from sqlalchemy.orm import Session
from app.exceptions import (
UserNotFoundException,
UserStatusChangeException,
CannotModifySelfException,
VendorNotFoundException,
VendorAlreadyExistsException,
VendorVerificationException,
AdminOperationException,
ValidationException,
)
from app.exceptions import (AdminOperationException, CannotModifySelfException,
UserNotFoundException, UserStatusChangeException,
ValidationException, VendorAlreadyExistsException,
VendorNotFoundException,
VendorVerificationException)
from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.user import User
from models.database.vendor import Role, Vendor, VendorUser
from models.schema.marketplace_import_job import MarketplaceImportJobResponse
from models.schema.vendor import VendorCreate
from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.vendor import Vendor, Role, VendorUser
from models.database.user import User
logger = logging.getLogger(__name__)
@@ -52,12 +47,11 @@ class AdminService:
except Exception as e:
logger.error(f"Failed to retrieve users: {str(e)}")
raise AdminOperationException(
operation="get_all_users",
reason="Database query failed"
operation="get_all_users", reason="Database query failed"
)
def toggle_user_status(
self, db: Session, user_id: int, current_admin_id: int
self, db: Session, user_id: int, current_admin_id: int
) -> Tuple[User, str]:
"""Toggle user active status."""
user = self._get_user_by_id_or_raise(db, user_id)
@@ -72,7 +66,7 @@ class AdminService:
user_id=user_id,
current_status="admin",
attempted_action="toggle status",
reason="Cannot modify another admin user"
reason="Cannot modify another admin user",
)
try:
@@ -95,7 +89,7 @@ class AdminService:
user_id=user_id,
current_status="active" if original_status else "inactive",
attempted_action="toggle status",
reason="Database update failed"
reason="Database update failed",
)
# ============================================================================
@@ -103,7 +97,7 @@ class AdminService:
# ============================================================================
def create_vendor_with_owner(
self, db: Session, vendor_data: VendorCreate
self, db: Session, vendor_data: VendorCreate
) -> Tuple[Vendor, User, str]:
"""
Create vendor with owner user account.
@@ -118,17 +112,23 @@ class AdminService:
"""
try:
# Check if vendor code already exists
existing_vendor = db.query(Vendor).filter(
func.upper(Vendor.vendor_code) == vendor_data.vendor_code.upper()
).first()
existing_vendor = (
db.query(Vendor)
.filter(
func.upper(Vendor.vendor_code) == vendor_data.vendor_code.upper()
)
.first()
)
if existing_vendor:
raise VendorAlreadyExistsException(vendor_data.vendor_code)
# Check if subdomain already exists
existing_subdomain = db.query(Vendor).filter(
func.lower(Vendor.subdomain) == vendor_data.subdomain.lower()
).first()
existing_subdomain = (
db.query(Vendor)
.filter(func.lower(Vendor.subdomain) == vendor_data.subdomain.lower())
.first()
)
if existing_subdomain:
raise ValidationException(
@@ -140,15 +140,14 @@ class AdminService:
# Create owner user with owner_email
from middleware.auth import AuthManager
auth_manager = AuthManager()
owner_username = f"{vendor_data.subdomain}_owner"
owner_email = vendor_data.owner_email # ✅ For User authentication
# Check if user with this email already exists
existing_user = db.query(User).filter(
User.email == owner_email
).first()
existing_user = db.query(User).filter(User.email == owner_email).first()
if existing_user:
# Use existing user as owner
@@ -215,17 +214,17 @@ class AdminService:
logger.error(f"Failed to create vendor: {str(e)}")
raise AdminOperationException(
operation="create_vendor_with_owner",
reason=f"Failed to create vendor: {str(e)}"
reason=f"Failed to create vendor: {str(e)}",
)
def get_all_vendors(
self,
db: Session,
skip: int = 0,
limit: int = 100,
search: Optional[str] = None,
is_active: Optional[bool] = None,
is_verified: Optional[bool] = None
self,
db: Session,
skip: int = 0,
limit: int = 100,
search: Optional[str] = None,
is_active: Optional[bool] = None,
is_verified: Optional[bool] = None,
) -> Tuple[List[Vendor], int]:
"""Get paginated list of all vendors with filtering."""
try:
@@ -238,7 +237,7 @@ class AdminService:
or_(
Vendor.name.ilike(search_term),
Vendor.vendor_code.ilike(search_term),
Vendor.subdomain.ilike(search_term)
Vendor.subdomain.ilike(search_term),
)
)
@@ -255,8 +254,7 @@ class AdminService:
except Exception as e:
logger.error(f"Failed to retrieve vendors: {str(e)}")
raise AdminOperationException(
operation="get_all_vendors",
reason="Database query failed"
operation="get_all_vendors", reason="Database query failed"
)
def get_vendor_by_id(self, db: Session, vendor_id: int) -> Vendor:
@@ -290,7 +288,7 @@ class AdminService:
raise VendorVerificationException(
vendor_id=vendor_id,
reason="Database update failed",
current_verification_status=original_status
current_verification_status=original_status,
)
def toggle_vendor_status(self, db: Session, vendor_id: int) -> Tuple[Vendor, str]:
@@ -317,7 +315,7 @@ class AdminService:
operation="toggle_vendor_status",
reason="Database update failed",
target_type="vendor",
target_id=str(vendor_id)
target_id=str(vendor_id),
)
def delete_vendor(self, db: Session, vendor_id: int) -> str:
@@ -345,15 +343,11 @@ class AdminService:
db.rollback()
logger.error(f"Failed to delete vendor {vendor_id}: {str(e)}")
raise AdminOperationException(
operation="delete_vendor",
reason="Database deletion failed"
operation="delete_vendor", reason="Database deletion failed"
)
def update_vendor(
self,
db: Session,
vendor_id: int,
vendor_update # VendorUpdate schema
self, db: Session, vendor_id: int, vendor_update # VendorUpdate schema
) -> Vendor:
"""
Update vendor information (Admin only).
@@ -387,11 +381,18 @@ class AdminService:
update_data = vendor_update.model_dump(exclude_unset=True)
# Check subdomain uniqueness if changing
if 'subdomain' in update_data and update_data['subdomain'] != vendor.subdomain:
existing = db.query(Vendor).filter(
Vendor.subdomain == update_data['subdomain'],
Vendor.id != vendor_id
).first()
if (
"subdomain" in update_data
and update_data["subdomain"] != vendor.subdomain
):
existing = (
db.query(Vendor)
.filter(
Vendor.subdomain == update_data["subdomain"],
Vendor.id != vendor_id,
)
.first()
)
if existing:
raise ValidationException(
f"Subdomain '{update_data['subdomain']}' is already taken"
@@ -419,17 +420,16 @@ class AdminService:
db.rollback()
logger.error(f"Failed to update vendor {vendor_id}: {str(e)}")
raise AdminOperationException(
operation="update_vendor",
reason=f"Database update failed: {str(e)}"
operation="update_vendor", reason=f"Database update failed: {str(e)}"
)
# Add this NEW method for transferring ownership:
def transfer_vendor_ownership(
self,
db: Session,
vendor_id: int,
transfer_data # VendorTransferOwnership schema
self,
db: Session,
vendor_id: int,
transfer_data, # VendorTransferOwnership schema
) -> Tuple[Vendor, User, User]:
"""
Transfer vendor ownership to another user.
@@ -466,9 +466,9 @@ class AdminService:
old_owner = vendor.owner
# Get new owner
new_owner = db.query(User).filter(
User.id == transfer_data.new_owner_user_id
).first()
new_owner = (
db.query(User).filter(User.id == transfer_data.new_owner_user_id).first()
)
if not new_owner:
raise UserNotFoundException(str(transfer_data.new_owner_user_id))
@@ -487,26 +487,32 @@ class AdminService:
try:
# Get Owner role for this vendor
owner_role = db.query(Role).filter(
Role.vendor_id == vendor_id,
Role.name == "Owner"
).first()
owner_role = (
db.query(Role)
.filter(Role.vendor_id == vendor_id, Role.name == "Owner")
.first()
)
if not owner_role:
raise ValidationException("Owner role not found for vendor")
# Get Manager role (to demote old owner)
manager_role = db.query(Role).filter(
Role.vendor_id == vendor_id,
Role.name == "Manager"
).first()
manager_role = (
db.query(Role)
.filter(Role.vendor_id == vendor_id, Role.name == "Manager")
.first()
)
# Remove old owner from Owner role
old_owner_link = db.query(VendorUser).filter(
VendorUser.vendor_id == vendor_id,
VendorUser.user_id == old_owner.id,
VendorUser.role_id == owner_role.id
).first()
old_owner_link = (
db.query(VendorUser)
.filter(
VendorUser.vendor_id == vendor_id,
VendorUser.user_id == old_owner.id,
VendorUser.role_id == owner_role.id,
)
.first()
)
if old_owner_link:
if manager_role:
@@ -525,10 +531,14 @@ class AdminService:
)
# Check if new owner already has a vendor_user link
new_owner_link = db.query(VendorUser).filter(
VendorUser.vendor_id == vendor_id,
VendorUser.user_id == new_owner.id
).first()
new_owner_link = (
db.query(VendorUser)
.filter(
VendorUser.vendor_id == vendor_id,
VendorUser.user_id == new_owner.id,
)
.first()
)
if new_owner_link:
# Update existing link to Owner role
@@ -540,7 +550,7 @@ class AdminService:
vendor_id=vendor_id,
user_id=new_owner.id,
role_id=owner_role.id,
is_active=True
is_active=True,
)
db.add(new_owner_link)
@@ -568,10 +578,12 @@ class AdminService:
raise
except Exception as e:
db.rollback()
logger.error(f"Failed to transfer ownership for vendor {vendor_id}: {str(e)}")
logger.error(
f"Failed to transfer ownership for vendor {vendor_id}: {str(e)}"
)
raise AdminOperationException(
operation="transfer_vendor_ownership",
reason=f"Ownership transfer failed: {str(e)}"
reason=f"Ownership transfer failed: {str(e)}",
)
# ============================================================================
@@ -579,13 +591,13 @@ class AdminService:
# ============================================================================
def get_marketplace_import_jobs(
self,
db: Session,
marketplace: Optional[str] = None,
vendor_name: Optional[str] = None,
status: Optional[str] = None,
skip: int = 0,
limit: int = 100,
self,
db: Session,
marketplace: Optional[str] = None,
vendor_name: Optional[str] = None,
status: Optional[str] = None,
skip: int = 0,
limit: int = 100,
) -> List[MarketplaceImportJobResponse]:
"""Get filtered and paginated marketplace import jobs."""
try:
@@ -596,7 +608,9 @@ class AdminService:
MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%")
)
if vendor_name:
query = query.filter(MarketplaceImportJob.vendor_name.ilike(f"%{vendor_name}%"))
query = query.filter(
MarketplaceImportJob.vendor_name.ilike(f"%{vendor_name}%")
)
if status:
query = query.filter(MarketplaceImportJob.status == status)
@@ -612,8 +626,7 @@ class AdminService:
except Exception as e:
logger.error(f"Failed to retrieve marketplace import jobs: {str(e)}")
raise AdminOperationException(
operation="get_marketplace_import_jobs",
reason="Database query failed"
operation="get_marketplace_import_jobs", reason="Database query failed"
)
# ============================================================================
@@ -624,10 +637,7 @@ class AdminService:
"""Get recently created vendors."""
try:
vendors = (
db.query(Vendor)
.order_by(Vendor.created_at.desc())
.limit(limit)
.all()
db.query(Vendor).order_by(Vendor.created_at.desc()).limit(limit).all()
)
return [
@@ -638,7 +648,7 @@ class AdminService:
"subdomain": v.subdomain,
"is_active": v.is_active,
"is_verified": v.is_verified,
"created_at": v.created_at
"created_at": v.created_at,
}
for v in vendors
]
@@ -663,7 +673,7 @@ class AdminService:
"vendor_name": j.vendor_name,
"status": j.status,
"total_processed": j.total_processed or 0,
"created_at": j.created_at
"created_at": j.created_at,
}
for j in jobs
]
@@ -692,47 +702,53 @@ class AdminService:
def _generate_temp_password(self, length: int = 12) -> str:
"""Generate secure temporary password."""
alphabet = string.ascii_letters + string.digits + "!@#$%^&*"
return ''.join(secrets.choice(alphabet) for _ in range(length))
return "".join(secrets.choice(alphabet) for _ in range(length))
def _create_default_roles(self, db: Session, vendor_id: int):
"""Create default roles for a new vendor."""
default_roles = [
{
"name": "Owner",
"permissions": ["*"] # Full access
},
{"name": "Owner", "permissions": ["*"]}, # Full access
{
"name": "Manager",
"permissions": [
"products.*", "orders.*", "customers.view",
"inventory.*", "team.view"
]
"products.*",
"orders.*",
"customers.view",
"inventory.*",
"team.view",
],
},
{
"name": "Editor",
"permissions": [
"products.view", "products.edit",
"orders.view", "inventory.view"
]
"products.view",
"products.edit",
"orders.view",
"inventory.view",
],
},
{
"name": "Viewer",
"permissions": [
"products.view", "orders.view",
"customers.view", "inventory.view"
]
}
"products.view",
"orders.view",
"customers.view",
"inventory.view",
],
},
]
for role_data in default_roles:
role = Role(
vendor_id=vendor_id,
name=role_data["name"],
permissions=role_data["permissions"]
permissions=role_data["permissions"],
)
db.add(role)
def _convert_job_to_response(self, job: MarketplaceImportJob) -> MarketplaceImportJobResponse:
def _convert_job_to_response(
self, job: MarketplaceImportJob
) -> MarketplaceImportJobResponse:
"""Convert database model to response schema."""
return MarketplaceImportJobResponse(
job_id=job.id,

View File

@@ -8,25 +8,19 @@ This module provides functions for:
- Encrypting sensitive settings
"""
import logging
import json
from typing import Optional, List, Any, Dict
import logging
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.exceptions import (AdminOperationException, ResourceNotFoundException,
ValidationException)
from models.database.admin import AdminSetting
from models.schema.admin import (
AdminSettingCreate,
AdminSettingResponse,
AdminSettingUpdate
)
from app.exceptions import (
AdminOperationException,
ValidationException,
ResourceNotFoundException
)
from models.schema.admin import (AdminSettingCreate, AdminSettingResponse,
AdminSettingUpdate)
logger = logging.getLogger(__name__)
@@ -34,26 +28,19 @@ logger = logging.getLogger(__name__)
class AdminSettingsService:
"""Service for managing platform-wide settings."""
def get_setting_by_key(
self,
db: Session,
key: str
) -> Optional[AdminSetting]:
def get_setting_by_key(self, db: Session, key: str) -> Optional[AdminSetting]:
"""Get setting by key."""
try:
return db.query(AdminSetting).filter(
func.lower(AdminSetting.key) == key.lower()
).first()
return (
db.query(AdminSetting)
.filter(func.lower(AdminSetting.key) == key.lower())
.first()
)
except Exception as e:
logger.error(f"Failed to get setting {key}: {str(e)}")
return None
def get_setting_value(
self,
db: Session,
key: str,
default: Any = None
) -> Any:
def get_setting_value(self, db: Session, key: str, default: Any = None) -> Any:
"""
Get setting value with type conversion.
@@ -76,7 +63,7 @@ class AdminSettingsService:
elif setting.value_type == "float":
return float(setting.value)
elif setting.value_type == "boolean":
return setting.value.lower() in ('true', '1', 'yes')
return setting.value.lower() in ("true", "1", "yes")
elif setting.value_type == "json":
return json.loads(setting.value)
else:
@@ -86,10 +73,10 @@ class AdminSettingsService:
return default
def get_all_settings(
self,
db: Session,
category: Optional[str] = None,
is_public: Optional[bool] = None
self,
db: Session,
category: Optional[str] = None,
is_public: Optional[bool] = None,
) -> List[AdminSettingResponse]:
"""Get all settings with optional filtering."""
try:
@@ -104,22 +91,16 @@ class AdminSettingsService:
settings = query.order_by(AdminSetting.category, AdminSetting.key).all()
return [
AdminSettingResponse.model_validate(setting)
for setting in settings
AdminSettingResponse.model_validate(setting) for setting in settings
]
except Exception as e:
logger.error(f"Failed to get settings: {str(e)}")
raise AdminOperationException(
operation="get_all_settings",
reason="Database query failed"
operation="get_all_settings", reason="Database query failed"
)
def get_settings_by_category(
self,
db: Session,
category: str
) -> Dict[str, Any]:
def get_settings_by_category(self, db: Session, category: str) -> Dict[str, Any]:
"""
Get all settings in a category as a dictionary.
@@ -136,7 +117,7 @@ class AdminSettingsService:
elif setting.value_type == "float":
result[setting.key] = float(setting.value)
elif setting.value_type == "boolean":
result[setting.key] = setting.value.lower() in ('true', '1', 'yes')
result[setting.key] = setting.value.lower() in ("true", "1", "yes")
elif setting.value_type == "json":
result[setting.key] = json.loads(setting.value)
else:
@@ -145,10 +126,7 @@ class AdminSettingsService:
return result
def create_setting(
self,
db: Session,
setting_data: AdminSettingCreate,
admin_user_id: int
self, db: Session, setting_data: AdminSettingCreate, admin_user_id: int
) -> AdminSettingResponse:
"""Create new setting."""
try:
@@ -176,7 +154,7 @@ class AdminSettingsService:
description=setting_data.description,
is_encrypted=setting_data.is_encrypted,
is_public=setting_data.is_public,
last_modified_by_user_id=admin_user_id
last_modified_by_user_id=admin_user_id,
)
db.add(setting)
@@ -194,25 +172,17 @@ class AdminSettingsService:
db.rollback()
logger.error(f"Failed to create setting: {str(e)}")
raise AdminOperationException(
operation="create_setting",
reason="Database operation failed"
operation="create_setting", reason="Database operation failed"
)
def update_setting(
self,
db: Session,
key: str,
update_data: AdminSettingUpdate,
admin_user_id: int
self, db: Session, key: str, update_data: AdminSettingUpdate, admin_user_id: int
) -> AdminSettingResponse:
"""Update existing setting."""
setting = self.get_setting_by_key(db, key)
if not setting:
raise ResourceNotFoundException(
resource_type="setting",
identifier=key
)
raise ResourceNotFoundException(resource_type="setting", identifier=key)
try:
# Validate new value
@@ -244,42 +214,29 @@ class AdminSettingsService:
db.rollback()
logger.error(f"Failed to update setting {key}: {str(e)}")
raise AdminOperationException(
operation="update_setting",
reason="Database operation failed"
operation="update_setting", reason="Database operation failed"
)
def upsert_setting(
self,
db: Session,
setting_data: AdminSettingCreate,
admin_user_id: int
self, db: Session, setting_data: AdminSettingCreate, admin_user_id: int
) -> AdminSettingResponse:
"""Create or update setting (upsert)."""
existing = self.get_setting_by_key(db, setting_data.key)
if existing:
update_data = AdminSettingUpdate(
value=setting_data.value,
description=setting_data.description
value=setting_data.value, description=setting_data.description
)
return self.update_setting(db, setting_data.key, update_data, admin_user_id)
else:
return self.create_setting(db, setting_data, admin_user_id)
def delete_setting(
self,
db: Session,
key: str,
admin_user_id: int
) -> str:
def delete_setting(self, db: Session, key: str, admin_user_id: int) -> str:
"""Delete setting."""
setting = self.get_setting_by_key(db, key)
if not setting:
raise ResourceNotFoundException(
resource_type="setting",
identifier=key
)
raise ResourceNotFoundException(resource_type="setting", identifier=key)
try:
db.delete(setting)
@@ -293,8 +250,7 @@ class AdminSettingsService:
db.rollback()
logger.error(f"Failed to delete setting {key}: {str(e)}")
raise AdminOperationException(
operation="delete_setting",
reason="Database operation failed"
operation="delete_setting", reason="Database operation failed"
)
# ============================================================================
@@ -309,7 +265,7 @@ class AdminSettingsService:
elif value_type == "float":
float(value)
elif value_type == "boolean":
if value.lower() not in ('true', 'false', '1', '0', 'yes', 'no'):
if value.lower() not in ("true", "false", "1", "0", "yes", "no"):
raise ValueError("Invalid boolean value")
elif value_type == "json":
json.loads(value)

View File

@@ -1 +1 @@
# Audit logging services
# Audit logging services

View File

@@ -13,15 +13,12 @@ from typing import Any, Dict, Optional
from sqlalchemy.orm import Session
from app.exceptions import (
UserAlreadyExistsException,
InvalidCredentialsException,
UserNotActiveException,
ValidationException,
)
from app.exceptions import (InvalidCredentialsException,
UserAlreadyExistsException, UserNotActiveException,
ValidationException)
from middleware.auth import AuthManager
from models.schema.auth import UserLogin, UserRegister
from models.database.user import User
from models.schema.auth import UserLogin, UserRegister
logger = logging.getLogger(__name__)
@@ -51,11 +48,15 @@ class AuthService:
try:
# Check if email already exists
if self._email_exists(db, user_data.email):
raise UserAlreadyExistsException("Email already registered", field="email")
raise UserAlreadyExistsException(
"Email already registered", field="email"
)
# Check if username already exists
if self._username_exists(db, user_data.username):
raise UserAlreadyExistsException("Username already taken", field="username")
raise UserAlreadyExistsException(
"Username already taken", field="username"
)
# Hash password and create user
hashed_password = self.auth_manager.hash_password(user_data.password)
@@ -182,7 +183,9 @@ class AuthService:
Dictionary with access_token, token_type, and expires_in
"""
from datetime import datetime, timedelta, timezone
from jose import jwt
from app.core.config import settings
try:
@@ -217,6 +220,5 @@ class AuthService:
return db.query(User).filter(User.username == username).first() is not None
# Create service instance following the same pattern as other services
auth_service = AuthService()

View File

@@ -1 +1 @@
# Backup and recovery services
# Backup and recovery services

View File

@@ -1 +1 @@
# Caching services
# Caching services

View File

@@ -9,23 +9,20 @@ This module provides:
"""
import logging
from typing import Dict, List, Optional
from datetime import datetime, timezone
from typing import Dict, List, Optional
from sqlalchemy.orm import Session
from sqlalchemy import and_
from sqlalchemy.orm import Session
from app.exceptions import (CartItemNotFoundException, CartValidationException,
InsufficientInventoryForCartException,
InvalidCartQuantityException,
ProductNotAvailableForCartException,
ProductNotFoundException)
from models.database.cart import CartItem
from models.database.product import Product
from models.database.vendor import Vendor
from models.database.cart import CartItem
from app.exceptions import (
ProductNotFoundException,
CartItemNotFoundException,
CartValidationException,
InsufficientInventoryForCartException,
InvalidCartQuantityException,
ProductNotAvailableForCartException,
)
logger = logging.getLogger(__name__)
@@ -33,12 +30,7 @@ logger = logging.getLogger(__name__)
class CartService:
"""Service for managing shopping carts."""
def get_cart(
self,
db: Session,
vendor_id: int,
session_id: str
) -> Dict:
def get_cart(self, db: Session, vendor_id: int, session_id: str) -> Dict:
"""
Get cart contents for a session.
@@ -55,20 +47,21 @@ class CartService:
extra={
"vendor_id": vendor_id,
"session_id": session_id,
}
},
)
# Fetch cart items from database
cart_items = db.query(CartItem).filter(
and_(
CartItem.vendor_id == vendor_id,
CartItem.session_id == session_id
cart_items = (
db.query(CartItem)
.filter(
and_(CartItem.vendor_id == vendor_id, CartItem.session_id == session_id)
)
).all()
.all()
)
logger.info(
f"[CART_SERVICE] Found {len(cart_items)} items in database",
extra={"item_count": len(cart_items)}
extra={"item_count": len(cart_items)},
)
# Build response
@@ -79,14 +72,20 @@ class CartService:
product = cart_item.product
line_total = cart_item.line_total
items.append({
"product_id": product.id,
"product_name": product.marketplace_product.title,
"quantity": cart_item.quantity,
"price": cart_item.price_at_add,
"line_total": line_total,
"image_url": product.marketplace_product.image_link if product.marketplace_product else None,
})
items.append(
{
"product_id": product.id,
"product_name": product.marketplace_product.title,
"quantity": cart_item.quantity,
"price": cart_item.price_at_add,
"line_total": line_total,
"image_url": (
product.marketplace_product.image_link
if product.marketplace_product
else None
),
}
)
subtotal += line_total
@@ -95,23 +94,23 @@ class CartService:
"session_id": session_id,
"items": items,
"subtotal": subtotal,
"total": subtotal # Could add tax/shipping later
"total": subtotal, # Could add tax/shipping later
}
logger.info(
f"[CART_SERVICE] get_cart returning: {len(cart_data['items'])} items, total: {cart_data['total']}",
extra={"cart": cart_data}
extra={"cart": cart_data},
)
return cart_data
def add_to_cart(
self,
db: Session,
vendor_id: int,
session_id: str,
product_id: int,
quantity: int = 1
self,
db: Session,
vendor_id: int,
session_id: str,
product_id: int,
quantity: int = 1,
) -> Dict:
"""
Add product to cart.
@@ -136,23 +135,27 @@ class CartService:
"vendor_id": vendor_id,
"session_id": session_id,
"product_id": product_id,
"quantity": quantity
}
"quantity": quantity,
},
)
# Verify product exists and belongs to vendor
product = db.query(Product).filter(
and_(
Product.id == product_id,
Product.vendor_id == vendor_id,
Product.is_active == True
product = (
db.query(Product)
.filter(
and_(
Product.id == product_id,
Product.vendor_id == vendor_id,
Product.is_active == True,
)
)
).first()
.first()
)
if not product:
logger.error(
f"[CART_SERVICE] Product not found",
extra={"product_id": product_id, "vendor_id": vendor_id}
extra={"product_id": product_id, "vendor_id": vendor_id},
)
raise ProductNotFoundException(product_id=product_id, vendor_id=vendor_id)
@@ -161,21 +164,25 @@ class CartService:
extra={
"product_id": product_id,
"product_name": product.marketplace_product.title,
"available_inventory": product.available_inventory
}
"available_inventory": product.available_inventory,
},
)
# Get current price (use sale_price if available, otherwise regular price)
current_price = product.sale_price if product.sale_price else product.price
# Check if item already exists in cart
existing_item = db.query(CartItem).filter(
and_(
CartItem.vendor_id == vendor_id,
CartItem.session_id == session_id,
CartItem.product_id == product_id
existing_item = (
db.query(CartItem)
.filter(
and_(
CartItem.vendor_id == vendor_id,
CartItem.session_id == session_id,
CartItem.product_id == product_id,
)
)
).first()
.first()
)
if existing_item:
# Update quantity
@@ -190,14 +197,14 @@ class CartService:
"current_in_cart": existing_item.quantity,
"adding": quantity,
"requested_total": new_quantity,
"available": product.available_inventory
}
"available": product.available_inventory,
},
)
raise InsufficientInventoryForCartException(
product_id=product_id,
product_name=product.marketplace_product.title,
requested=new_quantity,
available=product.available_inventory
available=product.available_inventory,
)
existing_item.quantity = new_quantity
@@ -206,16 +213,13 @@ class CartService:
logger.info(
f"[CART_SERVICE] Updated existing cart item",
extra={
"cart_item_id": existing_item.id,
"new_quantity": new_quantity
}
extra={"cart_item_id": existing_item.id, "new_quantity": new_quantity},
)
return {
"message": "Product quantity updated in cart",
"product_id": product_id,
"quantity": new_quantity
"quantity": new_quantity,
}
else:
# Check inventory for new item
@@ -225,14 +229,14 @@ class CartService:
extra={
"product_id": product_id,
"requested": quantity,
"available": product.available_inventory
}
"available": product.available_inventory,
},
)
raise InsufficientInventoryForCartException(
product_id=product_id,
product_name=product.marketplace_product.title,
requested=quantity,
available=product.available_inventory
available=product.available_inventory,
)
# Create new cart item
@@ -241,7 +245,7 @@ class CartService:
session_id=session_id,
product_id=product_id,
quantity=quantity,
price_at_add=current_price
price_at_add=current_price,
)
db.add(cart_item)
db.commit()
@@ -252,23 +256,23 @@ class CartService:
extra={
"cart_item_id": cart_item.id,
"quantity": quantity,
"price": current_price
}
"price": current_price,
},
)
return {
"message": "Product added to cart",
"product_id": product_id,
"quantity": quantity
"quantity": quantity,
}
def update_cart_item(
self,
db: Session,
vendor_id: int,
session_id: str,
product_id: int,
quantity: int
self,
db: Session,
vendor_id: int,
session_id: str,
product_id: int,
quantity: int,
) -> Dict:
"""
Update quantity of item in cart.
@@ -292,25 +296,35 @@ class CartService:
raise InvalidCartQuantityException(quantity=quantity, min_quantity=1)
# Find cart item
cart_item = db.query(CartItem).filter(
and_(
CartItem.vendor_id == vendor_id,
CartItem.session_id == session_id,
CartItem.product_id == product_id
cart_item = (
db.query(CartItem)
.filter(
and_(
CartItem.vendor_id == vendor_id,
CartItem.session_id == session_id,
CartItem.product_id == product_id,
)
)
).first()
.first()
)
if not cart_item:
raise CartItemNotFoundException(product_id=product_id, session_id=session_id)
raise CartItemNotFoundException(
product_id=product_id, session_id=session_id
)
# Verify product still exists and is active
product = db.query(Product).filter(
and_(
Product.id == product_id,
Product.vendor_id == vendor_id,
Product.is_active == True
product = (
db.query(Product)
.filter(
and_(
Product.id == product_id,
Product.vendor_id == vendor_id,
Product.is_active == True,
)
)
).first()
.first()
)
if not product:
raise ProductNotFoundException(str(product_id))
@@ -321,7 +335,7 @@ class CartService:
product_id=product_id,
product_name=product.marketplace_product.title,
requested=quantity,
available=product.available_inventory
available=product.available_inventory,
)
# Update quantity
@@ -334,22 +348,18 @@ class CartService:
extra={
"cart_item_id": cart_item.id,
"product_id": product_id,
"new_quantity": quantity
}
"new_quantity": quantity,
},
)
return {
"message": "Cart updated",
"product_id": product_id,
"quantity": quantity
"quantity": quantity,
}
def remove_from_cart(
self,
db: Session,
vendor_id: int,
session_id: str,
product_id: int
self, db: Session, vendor_id: int, session_id: str, product_id: int
) -> Dict:
"""
Remove item from cart.
@@ -367,16 +377,22 @@ class CartService:
ProductNotFoundException: If product not in cart
"""
# Find and delete cart item
cart_item = db.query(CartItem).filter(
and_(
CartItem.vendor_id == vendor_id,
CartItem.session_id == session_id,
CartItem.product_id == product_id
cart_item = (
db.query(CartItem)
.filter(
and_(
CartItem.vendor_id == vendor_id,
CartItem.session_id == session_id,
CartItem.product_id == product_id,
)
)
).first()
.first()
)
if not cart_item:
raise CartItemNotFoundException(product_id=product_id, session_id=session_id)
raise CartItemNotFoundException(
product_id=product_id, session_id=session_id
)
db.delete(cart_item)
db.commit()
@@ -386,21 +402,13 @@ class CartService:
extra={
"cart_item_id": cart_item.id,
"product_id": product_id,
"session_id": session_id
}
"session_id": session_id,
},
)
return {
"message": "Item removed from cart",
"product_id": product_id
}
return {"message": "Item removed from cart", "product_id": product_id}
def clear_cart(
self,
db: Session,
vendor_id: int,
session_id: str
) -> Dict:
def clear_cart(self, db: Session, vendor_id: int, session_id: str) -> Dict:
"""
Clear all items from cart.
@@ -413,12 +421,13 @@ class CartService:
Success message with count of items removed
"""
# Delete all cart items for this session
deleted_count = db.query(CartItem).filter(
and_(
CartItem.vendor_id == vendor_id,
CartItem.session_id == session_id
deleted_count = (
db.query(CartItem)
.filter(
and_(CartItem.vendor_id == vendor_id, CartItem.session_id == session_id)
)
).delete()
.delete()
)
db.commit()
@@ -427,14 +436,11 @@ class CartService:
extra={
"session_id": session_id,
"vendor_id": vendor_id,
"items_removed": deleted_count
}
"items_removed": deleted_count,
},
)
return {
"message": "Cart cleared",
"items_removed": deleted_count
}
return {"message": "Cart cleared", "items_removed": deleted_count}
# Create service instance

View File

@@ -3,22 +3,20 @@ Code Quality Service
Business logic for managing architecture scans and violations
"""
import subprocess
import json
import logging
import subprocess
from datetime import datetime
from typing import List, Tuple, Optional, Dict
from pathlib import Path
from sqlalchemy.orm import Session
from sqlalchemy import func, desc
from typing import Dict, List, Optional, Tuple
from app.models.architecture_scan import (
ArchitectureScan,
ArchitectureViolation,
ArchitectureRule,
ViolationAssignment,
ViolationComment
)
from sqlalchemy import desc, func
from sqlalchemy.orm import Session
from app.models.architecture_scan import (ArchitectureRule, ArchitectureScan,
ArchitectureViolation,
ViolationAssignment,
ViolationComment)
logger = logging.getLogger(__name__)
@@ -26,7 +24,7 @@ logger = logging.getLogger(__name__)
class CodeQualityService:
"""Service for managing code quality scans and violations"""
def run_scan(self, db: Session, triggered_by: str = 'manual') -> ArchitectureScan:
def run_scan(self, db: Session, triggered_by: str = "manual") -> ArchitectureScan:
"""
Run architecture validator and store results in database
@@ -49,10 +47,10 @@ class CodeQualityService:
start_time = datetime.now()
try:
result = subprocess.run(
['python', 'scripts/validate_architecture.py', '--json'],
["python", "scripts/validate_architecture.py", "--json"],
capture_output=True,
text=True,
timeout=300 # 5 minute timeout
timeout=300, # 5 minute timeout
)
except subprocess.TimeoutExpired:
logger.error("Architecture scan timed out after 5 minutes")
@@ -63,17 +61,17 @@ class CodeQualityService:
# Parse JSON output (get only the JSON part, skip progress messages)
try:
# Find the JSON part in stdout
lines = result.stdout.strip().split('\n')
lines = result.stdout.strip().split("\n")
json_start = -1
for i, line in enumerate(lines):
if line.strip().startswith('{'):
if line.strip().startswith("{"):
json_start = i
break
if json_start == -1:
raise ValueError("No JSON output found")
json_output = '\n'.join(lines[json_start:])
json_output = "\n".join(lines[json_start:])
data = json.loads(json_output)
except (json.JSONDecodeError, ValueError) as e:
logger.error(f"Failed to parse validator output: {e}")
@@ -84,33 +82,33 @@ class CodeQualityService:
# Create scan record
scan = ArchitectureScan(
timestamp=datetime.now(),
total_files=data.get('files_checked', 0),
total_violations=data.get('total_violations', 0),
errors=data.get('errors', 0),
warnings=data.get('warnings', 0),
total_files=data.get("files_checked", 0),
total_violations=data.get("total_violations", 0),
errors=data.get("errors", 0),
warnings=data.get("warnings", 0),
duration_seconds=duration,
triggered_by=triggered_by,
git_commit_hash=git_commit
git_commit_hash=git_commit,
)
db.add(scan)
db.flush() # Get scan.id
# Create violation records
violations_data = data.get('violations', [])
violations_data = data.get("violations", [])
logger.info(f"Creating {len(violations_data)} violation records")
for v in violations_data:
violation = ArchitectureViolation(
scan_id=scan.id,
rule_id=v['rule_id'],
rule_name=v['rule_name'],
severity=v['severity'],
file_path=v['file_path'],
line_number=v['line_number'],
message=v['message'],
context=v.get('context', ''),
suggestion=v.get('suggestion', ''),
status='open'
rule_id=v["rule_id"],
rule_name=v["rule_name"],
severity=v["severity"],
file_path=v["file_path"],
line_number=v["line_number"],
message=v["message"],
context=v.get("context", ""),
suggestion=v.get("suggestion", ""),
status="open",
)
db.add(violation)
@@ -122,7 +120,11 @@ class CodeQualityService:
def get_latest_scan(self, db: Session) -> Optional[ArchitectureScan]:
"""Get the most recent scan"""
return db.query(ArchitectureScan).order_by(desc(ArchitectureScan.timestamp)).first()
return (
db.query(ArchitectureScan)
.order_by(desc(ArchitectureScan.timestamp))
.first()
)
def get_scan_by_id(self, db: Session, scan_id: int) -> Optional[ArchitectureScan]:
"""Get scan by ID"""
@@ -139,10 +141,12 @@ class CodeQualityService:
Returns:
List of ArchitectureScan objects, newest first
"""
return db.query(ArchitectureScan)\
.order_by(desc(ArchitectureScan.timestamp))\
.limit(limit)\
return (
db.query(ArchitectureScan)
.order_by(desc(ArchitectureScan.timestamp))
.limit(limit)
.all()
)
def get_violations(
self,
@@ -153,7 +157,7 @@ class CodeQualityService:
rule_id: str = None,
file_path: str = None,
limit: int = 100,
offset: int = 0
offset: int = 0,
) -> Tuple[List[ArchitectureViolation], int]:
"""
Get violations with filtering and pagination
@@ -194,24 +198,32 @@ class CodeQualityService:
query = query.filter(ArchitectureViolation.rule_id == rule_id)
if file_path:
query = query.filter(ArchitectureViolation.file_path.like(f'%{file_path}%'))
query = query.filter(ArchitectureViolation.file_path.like(f"%{file_path}%"))
# Get total count
total = query.count()
# Get page of results
violations = query.order_by(
ArchitectureViolation.severity.desc(),
ArchitectureViolation.file_path
).limit(limit).offset(offset).all()
violations = (
query.order_by(
ArchitectureViolation.severity.desc(), ArchitectureViolation.file_path
)
.limit(limit)
.offset(offset)
.all()
)
return violations, total
def get_violation_by_id(self, db: Session, violation_id: int) -> Optional[ArchitectureViolation]:
def get_violation_by_id(
self, db: Session, violation_id: int
) -> Optional[ArchitectureViolation]:
"""Get single violation with details"""
return db.query(ArchitectureViolation).filter(
ArchitectureViolation.id == violation_id
).first()
return (
db.query(ArchitectureViolation)
.filter(ArchitectureViolation.id == violation_id)
.first()
)
def assign_violation(
self,
@@ -220,7 +232,7 @@ class CodeQualityService:
user_id: int,
assigned_by: int,
due_date: datetime = None,
priority: str = 'medium'
priority: str = "medium",
) -> ViolationAssignment:
"""
Assign violation to a developer
@@ -239,7 +251,7 @@ class CodeQualityService:
# Update violation status
violation = self.get_violation_by_id(db, violation_id)
if violation:
violation.status = 'assigned'
violation.status = "assigned"
violation.assigned_to = user_id
# Create assignment record
@@ -248,7 +260,7 @@ class CodeQualityService:
user_id=user_id,
assigned_by=assigned_by,
due_date=due_date,
priority=priority
priority=priority,
)
db.add(assignment)
db.commit()
@@ -257,11 +269,7 @@ class CodeQualityService:
return assignment
def resolve_violation(
self,
db: Session,
violation_id: int,
resolved_by: int,
resolution_note: str
self, db: Session, violation_id: int, resolved_by: int, resolution_note: str
) -> ArchitectureViolation:
"""
Mark violation as resolved
@@ -279,7 +287,7 @@ class CodeQualityService:
if not violation:
raise ValueError(f"Violation {violation_id} not found")
violation.status = 'resolved'
violation.status = "resolved"
violation.resolved_at = datetime.now()
violation.resolved_by = resolved_by
violation.resolution_note = resolution_note
@@ -289,11 +297,7 @@ class CodeQualityService:
return violation
def ignore_violation(
self,
db: Session,
violation_id: int,
ignored_by: int,
reason: str
self, db: Session, violation_id: int, ignored_by: int, reason: str
) -> ArchitectureViolation:
"""
Mark violation as ignored/won't fix
@@ -311,7 +315,7 @@ class CodeQualityService:
if not violation:
raise ValueError(f"Violation {violation_id} not found")
violation.status = 'ignored'
violation.status = "ignored"
violation.resolved_at = datetime.now()
violation.resolved_by = ignored_by
violation.resolution_note = f"Ignored: {reason}"
@@ -321,11 +325,7 @@ class CodeQualityService:
return violation
def add_comment(
self,
db: Session,
violation_id: int,
user_id: int,
comment: str
self, db: Session, violation_id: int, user_id: int, comment: str
) -> ViolationComment:
"""
Add comment to violation
@@ -340,9 +340,7 @@ class CodeQualityService:
ViolationComment object
"""
comment_obj = ViolationComment(
violation_id=violation_id,
user_id=user_id,
comment=comment
violation_id=violation_id, user_id=user_id, comment=comment
)
db.add(comment_obj)
db.commit()
@@ -360,79 +358,95 @@ class CodeQualityService:
latest_scan = self.get_latest_scan(db)
if not latest_scan:
return {
'total_violations': 0,
'errors': 0,
'warnings': 0,
'open': 0,
'assigned': 0,
'resolved': 0,
'ignored': 0,
'technical_debt_score': 100,
'trend': [],
'by_severity': {},
'by_rule': {},
'by_module': {},
'top_files': []
"total_violations": 0,
"errors": 0,
"warnings": 0,
"open": 0,
"assigned": 0,
"resolved": 0,
"ignored": 0,
"technical_debt_score": 100,
"trend": [],
"by_severity": {},
"by_rule": {},
"by_module": {},
"top_files": [],
}
# Get violation counts by status
status_counts = db.query(
ArchitectureViolation.status,
func.count(ArchitectureViolation.id)
).filter(
ArchitectureViolation.scan_id == latest_scan.id
).group_by(ArchitectureViolation.status).all()
status_counts = (
db.query(ArchitectureViolation.status, func.count(ArchitectureViolation.id))
.filter(ArchitectureViolation.scan_id == latest_scan.id)
.group_by(ArchitectureViolation.status)
.all()
)
status_dict = {status: count for status, count in status_counts}
# Get violations by severity
severity_counts = db.query(
ArchitectureViolation.severity,
func.count(ArchitectureViolation.id)
).filter(
ArchitectureViolation.scan_id == latest_scan.id
).group_by(ArchitectureViolation.severity).all()
severity_counts = (
db.query(
ArchitectureViolation.severity, func.count(ArchitectureViolation.id)
)
.filter(ArchitectureViolation.scan_id == latest_scan.id)
.group_by(ArchitectureViolation.severity)
.all()
)
by_severity = {sev: count for sev, count in severity_counts}
# Get violations by rule
rule_counts = db.query(
ArchitectureViolation.rule_id,
func.count(ArchitectureViolation.id)
).filter(
ArchitectureViolation.scan_id == latest_scan.id
).group_by(ArchitectureViolation.rule_id).all()
rule_counts = (
db.query(
ArchitectureViolation.rule_id, func.count(ArchitectureViolation.id)
)
.filter(ArchitectureViolation.scan_id == latest_scan.id)
.group_by(ArchitectureViolation.rule_id)
.all()
)
by_rule = {rule: count for rule, count in sorted(rule_counts, key=lambda x: x[1], reverse=True)[:10]}
by_rule = {
rule: count
for rule, count in sorted(rule_counts, key=lambda x: x[1], reverse=True)[
:10
]
}
# Get top violating files
file_counts = db.query(
ArchitectureViolation.file_path,
func.count(ArchitectureViolation.id).label('count')
).filter(
ArchitectureViolation.scan_id == latest_scan.id
).group_by(ArchitectureViolation.file_path)\
.order_by(desc('count'))\
.limit(10).all()
file_counts = (
db.query(
ArchitectureViolation.file_path,
func.count(ArchitectureViolation.id).label("count"),
)
.filter(ArchitectureViolation.scan_id == latest_scan.id)
.group_by(ArchitectureViolation.file_path)
.order_by(desc("count"))
.limit(10)
.all()
)
top_files = [{'file': file, 'count': count} for file, count in file_counts]
top_files = [{"file": file, "count": count} for file, count in file_counts]
# Get violations by module (extract module from file path)
by_module = {}
violations = db.query(ArchitectureViolation.file_path).filter(
ArchitectureViolation.scan_id == latest_scan.id
).all()
violations = (
db.query(ArchitectureViolation.file_path)
.filter(ArchitectureViolation.scan_id == latest_scan.id)
.all()
)
for v in violations:
path_parts = v.file_path.split('/')
path_parts = v.file_path.split("/")
if len(path_parts) >= 2:
module = '/'.join(path_parts[:2]) # e.g., 'app/api'
module = "/".join(path_parts[:2]) # e.g., 'app/api'
else:
module = path_parts[0]
by_module[module] = by_module.get(module, 0) + 1
# Sort by count and take top 10
by_module = dict(sorted(by_module.items(), key=lambda x: x[1], reverse=True)[:10])
by_module = dict(
sorted(by_module.items(), key=lambda x: x[1], reverse=True)[:10]
)
# Calculate technical debt score
tech_debt_score = self.calculate_technical_debt_score(db, latest_scan.id)
@@ -441,29 +455,29 @@ class CodeQualityService:
trend_scans = self.get_scan_history(db, limit=7)
trend = [
{
'timestamp': scan.timestamp.isoformat(),
'violations': scan.total_violations,
'errors': scan.errors,
'warnings': scan.warnings
"timestamp": scan.timestamp.isoformat(),
"violations": scan.total_violations,
"errors": scan.errors,
"warnings": scan.warnings,
}
for scan in reversed(trend_scans) # Oldest first for chart
]
return {
'total_violations': latest_scan.total_violations,
'errors': latest_scan.errors,
'warnings': latest_scan.warnings,
'open': status_dict.get('open', 0),
'assigned': status_dict.get('assigned', 0),
'resolved': status_dict.get('resolved', 0),
'ignored': status_dict.get('ignored', 0),
'technical_debt_score': tech_debt_score,
'trend': trend,
'by_severity': by_severity,
'by_rule': by_rule,
'by_module': by_module,
'top_files': top_files,
'last_scan': latest_scan.timestamp.isoformat() if latest_scan else None
"total_violations": latest_scan.total_violations,
"errors": latest_scan.errors,
"warnings": latest_scan.warnings,
"open": status_dict.get("open", 0),
"assigned": status_dict.get("assigned", 0),
"resolved": status_dict.get("resolved", 0),
"ignored": status_dict.get("ignored", 0),
"technical_debt_score": tech_debt_score,
"trend": trend,
"by_severity": by_severity,
"by_rule": by_rule,
"by_module": by_module,
"top_files": top_files,
"last_scan": latest_scan.timestamp.isoformat() if latest_scan else None,
}
def calculate_technical_debt_score(self, db: Session, scan_id: int = None) -> int:
@@ -497,10 +511,7 @@ class CodeQualityService:
"""Get current git commit hash"""
try:
result = subprocess.run(
['git', 'rev-parse', 'HEAD'],
capture_output=True,
text=True,
timeout=5
["git", "rev-parse", "HEAD"], capture_output=True, text=True, timeout=5
)
if result.returncode == 0:
return result.stdout.strip()[:40]

View File

@@ -1 +1 @@
# Configuration management services
# Configuration management services

View File

@@ -19,8 +19,9 @@ This allows:
import logging
from datetime import datetime, timezone
from typing import List, Optional
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
from sqlalchemy.orm import Session
from models.database.content_page import ContentPage
@@ -35,7 +36,7 @@ class ContentPageService:
db: Session,
slug: str,
vendor_id: Optional[int] = None,
include_unpublished: bool = False
include_unpublished: bool = False,
) -> Optional[ContentPage]:
"""
Get content page for a vendor with fallback to platform default.
@@ -62,28 +63,20 @@ class ContentPageService:
if vendor_id:
vendor_page = (
db.query(ContentPage)
.filter(
and_(
ContentPage.vendor_id == vendor_id,
*filters
)
)
.filter(and_(ContentPage.vendor_id == vendor_id, *filters))
.first()
)
if vendor_page:
logger.debug(f"Found vendor-specific page: {slug} for vendor_id={vendor_id}")
logger.debug(
f"Found vendor-specific page: {slug} for vendor_id={vendor_id}"
)
return vendor_page
# Fallback to platform default
platform_page = (
db.query(ContentPage)
.filter(
and_(
ContentPage.vendor_id == None,
*filters
)
)
.filter(and_(ContentPage.vendor_id == None, *filters))
.first()
)
@@ -100,7 +93,7 @@ class ContentPageService:
vendor_id: Optional[int] = None,
include_unpublished: bool = False,
footer_only: bool = False,
header_only: bool = False
header_only: bool = False,
) -> List[ContentPage]:
"""
List all available pages for a vendor (includes vendor overrides + platform defaults).
@@ -133,12 +126,7 @@ class ContentPageService:
if vendor_id:
vendor_pages = (
db.query(ContentPage)
.filter(
and_(
ContentPage.vendor_id == vendor_id,
*filters
)
)
.filter(and_(ContentPage.vendor_id == vendor_id, *filters))
.order_by(ContentPage.display_order, ContentPage.title)
.all()
)
@@ -146,12 +134,7 @@ class ContentPageService:
# Get platform defaults
platform_pages = (
db.query(ContentPage)
.filter(
and_(
ContentPage.vendor_id == None,
*filters
)
)
.filter(and_(ContentPage.vendor_id == None, *filters))
.order_by(ContentPage.display_order, ContentPage.title)
.all()
)
@@ -159,8 +142,7 @@ class ContentPageService:
# Merge: vendor overrides take precedence
vendor_slugs = {page.slug for page in vendor_pages}
all_pages = vendor_pages + [
page for page in platform_pages
if page.slug not in vendor_slugs
page for page in platform_pages if page.slug not in vendor_slugs
]
# Sort by display_order
@@ -183,7 +165,7 @@ class ContentPageService:
show_in_footer: bool = True,
show_in_header: bool = False,
display_order: int = 0,
created_by: Optional[int] = None
created_by: Optional[int] = None,
) -> ContentPage:
"""
Create a new content page.
@@ -229,7 +211,9 @@ class ContentPageService:
db.commit()
db.refresh(page)
logger.info(f"Created content page: {slug} (vendor_id={vendor_id}, id={page.id})")
logger.info(
f"Created content page: {slug} (vendor_id={vendor_id}, id={page.id})"
)
return page
@staticmethod
@@ -246,7 +230,7 @@ class ContentPageService:
show_in_footer: Optional[bool] = None,
show_in_header: Optional[bool] = None,
display_order: Optional[int] = None,
updated_by: Optional[int] = None
updated_by: Optional[int] = None,
) -> Optional[ContentPage]:
"""
Update an existing content page.
@@ -338,9 +322,7 @@ class ContentPageService:
@staticmethod
def list_all_vendor_pages(
db: Session,
vendor_id: int,
include_unpublished: bool = False
db: Session, vendor_id: int, include_unpublished: bool = False
) -> List[ContentPage]:
"""
List only vendor-specific pages (no platform defaults).
@@ -367,8 +349,7 @@ class ContentPageService:
@staticmethod
def list_all_platform_pages(
db: Session,
include_unpublished: bool = False
db: Session, include_unpublished: bool = False
) -> List[ContentPage]:
"""
List only platform default pages.

View File

@@ -8,24 +8,24 @@ with complete vendor isolation.
import logging
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
from sqlalchemy.orm import Session
from sqlalchemy import and_
from typing import Any, Dict, Optional
from sqlalchemy import and_
from sqlalchemy.orm import Session
from app.exceptions.customer import (CustomerAlreadyExistsException,
CustomerNotActiveException,
CustomerNotFoundException,
CustomerValidationException,
DuplicateCustomerEmailException,
InvalidCustomerCredentialsException)
from app.exceptions.vendor import (VendorNotActiveException,
VendorNotFoundException)
from app.services.auth_service import AuthService
from models.database.customer import Customer, CustomerAddress
from models.database.vendor import Vendor
from models.schema.customer import CustomerRegister, CustomerUpdate
from models.schema.auth import UserLogin
from app.exceptions.customer import (
CustomerNotFoundException,
CustomerAlreadyExistsException,
CustomerNotActiveException,
InvalidCustomerCredentialsException,
CustomerValidationException,
DuplicateCustomerEmailException
)
from app.exceptions.vendor import VendorNotFoundException, VendorNotActiveException
from app.services.auth_service import AuthService
from models.schema.customer import CustomerRegister, CustomerUpdate
logger = logging.getLogger(__name__)
@@ -37,10 +37,7 @@ class CustomerService:
self.auth_service = AuthService()
def register_customer(
self,
db: Session,
vendor_id: int,
customer_data: CustomerRegister
self, db: Session, vendor_id: int, customer_data: CustomerRegister
) -> Customer:
"""
Register a new customer for a specific vendor.
@@ -68,18 +65,26 @@ class CustomerService:
raise VendorNotActiveException(vendor.vendor_code)
# Check if email already exists for this vendor
existing_customer = db.query(Customer).filter(
and_(
Customer.vendor_id == vendor_id,
Customer.email == customer_data.email.lower()
existing_customer = (
db.query(Customer)
.filter(
and_(
Customer.vendor_id == vendor_id,
Customer.email == customer_data.email.lower(),
)
)
).first()
.first()
)
if existing_customer:
raise DuplicateCustomerEmailException(customer_data.email, vendor.vendor_code)
raise DuplicateCustomerEmailException(
customer_data.email, vendor.vendor_code
)
# Generate unique customer number for this vendor
customer_number = self._generate_customer_number(db, vendor_id, vendor.vendor_code)
customer_number = self._generate_customer_number(
db, vendor_id, vendor.vendor_code
)
# Hash password
hashed_password = self.auth_service.hash_password(customer_data.password)
@@ -93,8 +98,12 @@ class CustomerService:
last_name=customer_data.last_name,
phone=customer_data.phone,
customer_number=customer_number,
marketing_consent=customer_data.marketing_consent if hasattr(customer_data, 'marketing_consent') else False,
is_active=True
marketing_consent=(
customer_data.marketing_consent
if hasattr(customer_data, "marketing_consent")
else False
),
is_active=True,
)
try:
@@ -114,15 +123,11 @@ class CustomerService:
db.rollback()
logger.error(f"Error registering customer: {str(e)}")
raise CustomerValidationException(
message="Failed to register customer",
details={"error": str(e)}
message="Failed to register customer", details={"error": str(e)}
)
def login_customer(
self,
db: Session,
vendor_id: int,
credentials: UserLogin
self, db: Session, vendor_id: int, credentials: UserLogin
) -> Dict[str, Any]:
"""
Authenticate customer and generate JWT token.
@@ -146,20 +151,23 @@ class CustomerService:
raise VendorNotFoundException(str(vendor_id), identifier_type="id")
# Find customer by email (vendor-scoped)
customer = db.query(Customer).filter(
and_(
Customer.vendor_id == vendor_id,
Customer.email == credentials.email_or_username.lower()
customer = (
db.query(Customer)
.filter(
and_(
Customer.vendor_id == vendor_id,
Customer.email == credentials.email_or_username.lower(),
)
)
).first()
.first()
)
if not customer:
raise InvalidCustomerCredentialsException()
# Verify password using auth_manager directly
if not self.auth_service.auth_manager.verify_password(
credentials.password,
customer.hashed_password
credentials.password, customer.hashed_password
):
raise InvalidCustomerCredentialsException()
@@ -170,6 +178,7 @@ class CustomerService:
# Generate JWT token with customer context
# Use auth_manager directly since Customer is not a User model
from datetime import datetime, timedelta, timezone
from jose import jwt
auth_manager = self.auth_service.auth_manager
@@ -185,7 +194,9 @@ class CustomerService:
"iat": datetime.now(timezone.utc),
}
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
token = jwt.encode(
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
)
token_data = {
"access_token": token,
@@ -198,17 +209,9 @@ class CustomerService:
f"for vendor {vendor.vendor_code}"
)
return {
"customer": customer,
"token_data": token_data
}
return {"customer": customer, "token_data": token_data}
def get_customer(
self,
db: Session,
vendor_id: int,
customer_id: int
) -> Customer:
def get_customer(self, db: Session, vendor_id: int, customer_id: int) -> Customer:
"""
Get customer by ID with vendor isolation.
@@ -223,12 +226,11 @@ class CustomerService:
Raises:
CustomerNotFoundException: If customer not found
"""
customer = db.query(Customer).filter(
and_(
Customer.id == customer_id,
Customer.vendor_id == vendor_id
)
).first()
customer = (
db.query(Customer)
.filter(and_(Customer.id == customer_id, Customer.vendor_id == vendor_id))
.first()
)
if not customer:
raise CustomerNotFoundException(str(customer_id))
@@ -236,10 +238,7 @@ class CustomerService:
return customer
def get_customer_by_email(
self,
db: Session,
vendor_id: int,
email: str
self, db: Session, vendor_id: int, email: str
) -> Optional[Customer]:
"""
Get customer by email (vendor-scoped).
@@ -252,19 +251,20 @@ class CustomerService:
Returns:
Optional[Customer]: Customer object or None
"""
return db.query(Customer).filter(
and_(
Customer.vendor_id == vendor_id,
Customer.email == email.lower()
return (
db.query(Customer)
.filter(
and_(Customer.vendor_id == vendor_id, Customer.email == email.lower())
)
).first()
.first()
)
def update_customer(
self,
db: Session,
vendor_id: int,
customer_id: int,
customer_data: CustomerUpdate
self,
db: Session,
vendor_id: int,
customer_id: int,
customer_data: CustomerUpdate,
) -> Customer:
"""
Update customer profile.
@@ -290,13 +290,17 @@ class CustomerService:
for field, value in update_data.items():
if field == "email" and value:
# Check if new email already exists for this vendor
existing = db.query(Customer).filter(
and_(
Customer.vendor_id == vendor_id,
Customer.email == value.lower(),
Customer.id != customer_id
existing = (
db.query(Customer)
.filter(
and_(
Customer.vendor_id == vendor_id,
Customer.email == value.lower(),
Customer.id != customer_id,
)
)
).first()
.first()
)
if existing:
raise DuplicateCustomerEmailException(value, "vendor")
@@ -317,15 +321,11 @@ class CustomerService:
db.rollback()
logger.error(f"Error updating customer: {str(e)}")
raise CustomerValidationException(
message="Failed to update customer",
details={"error": str(e)}
message="Failed to update customer", details={"error": str(e)}
)
def deactivate_customer(
self,
db: Session,
vendor_id: int,
customer_id: int
self, db: Session, vendor_id: int, customer_id: int
) -> Customer:
"""
Deactivate customer account.
@@ -352,10 +352,7 @@ class CustomerService:
return customer
def update_customer_stats(
self,
db: Session,
customer_id: int,
order_total: float
self, db: Session, customer_id: int, order_total: float
) -> None:
"""
Update customer statistics after order.
@@ -377,10 +374,7 @@ class CustomerService:
logger.debug(f"Updated stats for customer {customer.email}")
def _generate_customer_number(
self,
db: Session,
vendor_id: int,
vendor_code: str
self, db: Session, vendor_id: int, vendor_code: str
) -> str:
"""
Generate unique customer number for vendor.
@@ -397,21 +391,23 @@ class CustomerService:
str: Unique customer number
"""
# Get count of customers for this vendor
count = db.query(Customer).filter(
Customer.vendor_id == vendor_id
).count()
count = db.query(Customer).filter(Customer.vendor_id == vendor_id).count()
# Generate number with padding
sequence = str(count + 1).zfill(5)
customer_number = f"{vendor_code.upper()}-CUST-{sequence}"
# Ensure uniqueness (in case of deletions)
while db.query(Customer).filter(
while (
db.query(Customer)
.filter(
and_(
Customer.vendor_id == vendor_id,
Customer.customer_number == customer_number
Customer.customer_number == customer_number,
)
).first():
)
.first()
):
count += 1
sequence = str(count + 1).zfill(5)
customer_number = f"{vendor_code.upper()}-CUST-{sequence}"

View File

@@ -5,27 +5,20 @@ from typing import List, Optional
from sqlalchemy.orm import Session
from app.exceptions import (
InventoryNotFoundException,
InsufficientInventoryException,
InvalidInventoryOperationException,
InventoryValidationException,
NegativeInventoryException,
InvalidQuantityException,
ValidationException,
ProductNotFoundException,
)
from models.schema.inventory import (
InventoryCreate,
InventoryAdjust,
InventoryUpdate,
InventoryReserve,
InventoryLocationResponse,
ProductInventorySummary
)
from app.exceptions import (InsufficientInventoryException,
InvalidInventoryOperationException,
InvalidQuantityException,
InventoryNotFoundException,
InventoryValidationException,
NegativeInventoryException,
ProductNotFoundException, ValidationException)
from models.database.inventory import Inventory
from models.database.product import Product
from models.database.vendor import Vendor
from models.schema.inventory import (InventoryAdjust, InventoryCreate,
InventoryLocationResponse,
InventoryReserve, InventoryUpdate,
ProductInventorySummary)
logger = logging.getLogger(__name__)
@@ -34,7 +27,7 @@ class InventoryService:
"""Service for inventory operations with vendor isolation."""
def set_inventory(
self, db: Session, vendor_id: int, inventory_data: InventoryCreate
self, db: Session, vendor_id: int, inventory_data: InventoryCreate
) -> Inventory:
"""
Set exact inventory quantity for a product at a location (replaces existing).
@@ -93,7 +86,11 @@ class InventoryService:
)
return new_inventory
except (ProductNotFoundException, InvalidQuantityException, InventoryValidationException):
except (
ProductNotFoundException,
InvalidQuantityException,
InventoryValidationException,
):
db.rollback()
raise
except Exception as e:
@@ -102,7 +99,7 @@ class InventoryService:
raise ValidationException("Failed to set inventory")
def adjust_inventory(
self, db: Session, vendor_id: int, inventory_data: InventoryAdjust
self, db: Session, vendor_id: int, inventory_data: InventoryAdjust
) -> Inventory:
"""
Adjust inventory by adding or removing quantity.
@@ -124,7 +121,9 @@ class InventoryService:
location = self._validate_location(inventory_data.location)
# Check if inventory exists
existing = self._get_inventory_entry(db, inventory_data.product_id, location)
existing = self._get_inventory_entry(
db, inventory_data.product_id, location
)
if not existing:
# Create new if adding, error if removing
@@ -173,8 +172,12 @@ class InventoryService:
)
return existing
except (ProductNotFoundException, InventoryNotFoundException,
InsufficientInventoryException, InventoryValidationException):
except (
ProductNotFoundException,
InventoryNotFoundException,
InsufficientInventoryException,
InventoryValidationException,
):
db.rollback()
raise
except Exception as e:
@@ -183,7 +186,7 @@ class InventoryService:
raise ValidationException("Failed to adjust inventory")
def reserve_inventory(
self, db: Session, vendor_id: int, reserve_data: InventoryReserve
self, db: Session, vendor_id: int, reserve_data: InventoryReserve
) -> Inventory:
"""
Reserve inventory for an order (increases reserved_quantity).
@@ -231,8 +234,12 @@ class InventoryService:
)
return inventory
except (ProductNotFoundException, InventoryNotFoundException,
InsufficientInventoryException, InvalidQuantityException):
except (
ProductNotFoundException,
InventoryNotFoundException,
InsufficientInventoryException,
InvalidQuantityException,
):
db.rollback()
raise
except Exception as e:
@@ -241,7 +248,7 @@ class InventoryService:
raise ValidationException("Failed to reserve inventory")
def release_reservation(
self, db: Session, vendor_id: int, reserve_data: InventoryReserve
self, db: Session, vendor_id: int, reserve_data: InventoryReserve
) -> Inventory:
"""
Release reserved inventory (decreases reserved_quantity).
@@ -287,7 +294,11 @@ class InventoryService:
)
return inventory
except (ProductNotFoundException, InventoryNotFoundException, InvalidQuantityException):
except (
ProductNotFoundException,
InventoryNotFoundException,
InvalidQuantityException,
):
db.rollback()
raise
except Exception as e:
@@ -296,7 +307,7 @@ class InventoryService:
raise ValidationException("Failed to release reservation")
def fulfill_reservation(
self, db: Session, vendor_id: int, reserve_data: InventoryReserve
self, db: Session, vendor_id: int, reserve_data: InventoryReserve
) -> Inventory:
"""
Fulfill a reservation (decreases both quantity and reserved_quantity).
@@ -349,8 +360,12 @@ class InventoryService:
)
return inventory
except (ProductNotFoundException, InventoryNotFoundException,
InsufficientInventoryException, InvalidQuantityException):
except (
ProductNotFoundException,
InventoryNotFoundException,
InsufficientInventoryException,
InvalidQuantityException,
):
db.rollback()
raise
except Exception as e:
@@ -359,7 +374,7 @@ class InventoryService:
raise ValidationException("Failed to fulfill reservation")
def get_product_inventory(
self, db: Session, vendor_id: int, product_id: int
self, db: Session, vendor_id: int, product_id: int
) -> ProductInventorySummary:
"""
Get inventory summary for a product across all locations.
@@ -376,9 +391,7 @@ class InventoryService:
product = self._get_vendor_product(db, vendor_id, product_id)
inventory_entries = (
db.query(Inventory)
.filter(Inventory.product_id == product_id)
.all()
db.query(Inventory).filter(Inventory.product_id == product_id).all()
)
if not inventory_entries:
@@ -425,8 +438,13 @@ class InventoryService:
raise ValidationException("Failed to retrieve product inventory")
def get_vendor_inventory(
self, db: Session, vendor_id: int, skip: int = 0, limit: int = 100,
location: Optional[str] = None, low_stock_threshold: Optional[int] = None
self,
db: Session,
vendor_id: int,
skip: int = 0,
limit: int = 100,
location: Optional[str] = None,
low_stock_threshold: Optional[int] = None,
) -> List[Inventory]:
"""
Get all inventory for a vendor with filtering.
@@ -458,8 +476,11 @@ class InventoryService:
raise ValidationException("Failed to retrieve vendor inventory")
def update_inventory(
self, db: Session, vendor_id: int, inventory_id: int,
inventory_update: InventoryUpdate
self,
db: Session,
vendor_id: int,
inventory_id: int,
inventory_update: InventoryUpdate,
) -> Inventory:
"""Update inventory entry."""
try:
@@ -475,7 +496,9 @@ class InventoryService:
inventory.quantity = inventory_update.quantity
if inventory_update.reserved_quantity is not None:
self._validate_quantity(inventory_update.reserved_quantity, allow_zero=True)
self._validate_quantity(
inventory_update.reserved_quantity, allow_zero=True
)
inventory.reserved_quantity = inventory_update.reserved_quantity
if inventory_update.location:
@@ -488,7 +511,11 @@ class InventoryService:
logger.info(f"Updated inventory {inventory_id}")
return inventory
except (InventoryNotFoundException, InvalidQuantityException, InventoryValidationException):
except (
InventoryNotFoundException,
InvalidQuantityException,
InventoryValidationException,
):
db.rollback()
raise
except Exception as e:
@@ -496,9 +523,7 @@ class InventoryService:
logger.error(f"Error updating inventory: {str(e)}")
raise ValidationException("Failed to update inventory")
def delete_inventory(
self, db: Session, vendor_id: int, inventory_id: int
) -> bool:
def delete_inventory(self, db: Session, vendor_id: int, inventory_id: int) -> bool:
"""Delete inventory entry."""
try:
inventory = self._get_inventory_by_id(db, inventory_id)
@@ -521,28 +546,30 @@ class InventoryService:
raise ValidationException("Failed to delete inventory")
# Private helper methods
def _get_vendor_product(self, db: Session, vendor_id: int, product_id: int) -> Product:
def _get_vendor_product(
self, db: Session, vendor_id: int, product_id: int
) -> Product:
"""Get product and verify it belongs to vendor."""
product = db.query(Product).filter(
Product.id == product_id,
Product.vendor_id == vendor_id
).first()
product = (
db.query(Product)
.filter(Product.id == product_id, Product.vendor_id == vendor_id)
.first()
)
if not product:
raise ProductNotFoundException(f"Product {product_id} not found in your catalog")
raise ProductNotFoundException(
f"Product {product_id} not found in your catalog"
)
return product
def _get_inventory_entry(
self, db: Session, product_id: int, location: str
self, db: Session, product_id: int, location: str
) -> Optional[Inventory]:
"""Get inventory entry by product and location."""
return (
db.query(Inventory)
.filter(
Inventory.product_id == product_id,
Inventory.location == location
)
.filter(Inventory.product_id == product_id, Inventory.location == location)
.first()
)

View File

@@ -5,20 +5,15 @@ from typing import List, Optional
from sqlalchemy.orm import Session
from app.exceptions import (
ImportJobNotFoundException,
ImportJobNotOwnedException,
ImportJobCannotBeCancelledException,
ImportJobCannotBeDeletedException,
ValidationException,
)
from models.schema.marketplace_import_job import (
MarketplaceImportJobResponse,
MarketplaceImportJobRequest
)
from app.exceptions import (ImportJobCannotBeCancelledException,
ImportJobCannotBeDeletedException,
ImportJobNotFoundException,
ImportJobNotOwnedException, ValidationException)
from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.vendor import Vendor
from models.database.user import User
from models.database.vendor import Vendor
from models.schema.marketplace_import_job import (MarketplaceImportJobRequest,
MarketplaceImportJobResponse)
logger = logging.getLogger(__name__)
@@ -31,7 +26,7 @@ class MarketplaceImportJobService:
db: Session,
request: MarketplaceImportJobRequest,
vendor: Vendor, # CHANGED: Vendor object from middleware
user: User
user: User,
) -> MarketplaceImportJob:
"""
Create a new marketplace import job.
@@ -147,7 +142,9 @@ class MarketplaceImportJobService:
marketplace=job.marketplace,
vendor_id=job.vendor_id,
vendor_code=job.vendor.vendor_code if job.vendor else None, # FIXED
vendor_name=job.vendor.name if job.vendor else None, # FIXED: from relationship
vendor_name=(
job.vendor.name if job.vendor else None
), # FIXED: from relationship
source_url=job.source_url,
imported=job.imported_count or 0,
updated=job.updated_count or 0,

View File

@@ -17,19 +17,20 @@ from typing import Generator, List, Optional, Tuple
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from app.exceptions import (
MarketplaceProductNotFoundException,
MarketplaceProductAlreadyExistsException,
InvalidMarketplaceProductDataException,
MarketplaceProductValidationException,
ValidationException,
)
from app.services.marketplace_import_job_service import marketplace_import_job_service
from models.schema.marketplace_product import MarketplaceProductCreate, MarketplaceProductUpdate
from models.schema.inventory import InventoryLocationResponse, InventorySummaryResponse
from models.database.marketplace_product import MarketplaceProduct
from models.database.inventory import Inventory
from app.exceptions import (InvalidMarketplaceProductDataException,
MarketplaceProductAlreadyExistsException,
MarketplaceProductNotFoundException,
MarketplaceProductValidationException,
ValidationException)
from app.services.marketplace_import_job_service import \
marketplace_import_job_service
from app.utils.data_processing import GTINProcessor, PriceProcessor
from models.database.inventory import Inventory
from models.database.marketplace_product import MarketplaceProduct
from models.schema.inventory import (InventoryLocationResponse,
InventorySummaryResponse)
from models.schema.marketplace_product import (MarketplaceProductCreate,
MarketplaceProductUpdate)
logger = logging.getLogger(__name__)
@@ -42,14 +43,18 @@ class MarketplaceProductService:
self.gtin_processor = GTINProcessor()
self.price_processor = PriceProcessor()
def create_product(self, db: Session, product_data: MarketplaceProductCreate) -> MarketplaceProduct:
def create_product(
self, db: Session, product_data: MarketplaceProductCreate
) -> MarketplaceProduct:
"""Create a new product with validation."""
try:
# Process and validate GTIN if provided
if product_data.gtin:
normalized_gtin = self.gtin_processor.normalize(product_data.gtin)
if not normalized_gtin:
raise InvalidMarketplaceProductDataException("Invalid GTIN format", field="gtin")
raise InvalidMarketplaceProductDataException(
"Invalid GTIN format", field="gtin"
)
product_data.gtin = normalized_gtin
# Process price if provided
@@ -70,11 +75,18 @@ class MarketplaceProductService:
product_data.marketplace = "Letzshop"
# Validate required fields
if not product_data.marketplace_product_id or not product_data.marketplace_product_id.strip():
raise MarketplaceProductValidationException("MarketplaceProduct ID is required", field="marketplace_product_id")
if (
not product_data.marketplace_product_id
or not product_data.marketplace_product_id.strip()
):
raise MarketplaceProductValidationException(
"MarketplaceProduct ID is required", field="marketplace_product_id"
)
if not product_data.title or not product_data.title.strip():
raise MarketplaceProductValidationException("MarketplaceProduct title is required", field="title")
raise MarketplaceProductValidationException(
"MarketplaceProduct title is required", field="title"
)
db_product = MarketplaceProduct(**product_data.model_dump())
db.add(db_product)
@@ -84,30 +96,47 @@ class MarketplaceProductService:
logger.info(f"Created product {db_product.marketplace_product_id}")
return db_product
except (InvalidMarketplaceProductDataException, MarketplaceProductValidationException):
except (
InvalidMarketplaceProductDataException,
MarketplaceProductValidationException,
):
db.rollback()
raise # Re-raise custom exceptions
except IntegrityError as e:
db.rollback()
logger.error(f"Database integrity error: {str(e)}")
if "marketplace_product_id" in str(e).lower() or "unique" in str(e).lower():
raise MarketplaceProductAlreadyExistsException(product_data.marketplace_product_id)
raise MarketplaceProductAlreadyExistsException(
product_data.marketplace_product_id
)
else:
raise MarketplaceProductValidationException("Data integrity constraint violation")
raise MarketplaceProductValidationException(
"Data integrity constraint violation"
)
except Exception as e:
db.rollback()
logger.error(f"Error creating product: {str(e)}")
raise ValidationException("Failed to create product")
def get_product_by_id(self, db: Session, marketplace_product_id: str) -> Optional[MarketplaceProduct]:
def get_product_by_id(
self, db: Session, marketplace_product_id: str
) -> Optional[MarketplaceProduct]:
"""Get a product by its ID."""
try:
return db.query(MarketplaceProduct).filter(MarketplaceProduct.marketplace_product_id == marketplace_product_id).first()
return (
db.query(MarketplaceProduct)
.filter(
MarketplaceProduct.marketplace_product_id == marketplace_product_id
)
.first()
)
except Exception as e:
logger.error(f"Error getting product {marketplace_product_id}: {str(e)}")
return None
def get_product_by_id_or_raise(self, db: Session, marketplace_product_id: str) -> MarketplaceProduct:
def get_product_by_id_or_raise(
self, db: Session, marketplace_product_id: str
) -> MarketplaceProduct:
"""
Get a product by its ID or raise exception.
@@ -127,16 +156,16 @@ class MarketplaceProductService:
return product
def get_products_with_filters(
self,
db: Session,
skip: int = 0,
limit: int = 100,
brand: Optional[str] = None,
category: Optional[str] = None,
availability: Optional[str] = None,
marketplace: Optional[str] = None,
vendor_name: Optional[str] = None,
search: Optional[str] = None,
self,
db: Session,
skip: int = 0,
limit: int = 100,
brand: Optional[str] = None,
category: Optional[str] = None,
availability: Optional[str] = None,
marketplace: Optional[str] = None,
vendor_name: Optional[str] = None,
search: Optional[str] = None,
) -> Tuple[List[MarketplaceProduct], int]:
"""
Get products with filtering and pagination.
@@ -162,13 +191,19 @@ class MarketplaceProductService:
if brand:
query = query.filter(MarketplaceProduct.brand.ilike(f"%{brand}%"))
if category:
query = query.filter(MarketplaceProduct.google_product_category.ilike(f"%{category}%"))
query = query.filter(
MarketplaceProduct.google_product_category.ilike(f"%{category}%")
)
if availability:
query = query.filter(MarketplaceProduct.availability == availability)
if marketplace:
query = query.filter(MarketplaceProduct.marketplace.ilike(f"%{marketplace}%"))
query = query.filter(
MarketplaceProduct.marketplace.ilike(f"%{marketplace}%")
)
if vendor_name:
query = query.filter(MarketplaceProduct.vendor_name.ilike(f"%{vendor_name}%"))
query = query.filter(
MarketplaceProduct.vendor_name.ilike(f"%{vendor_name}%")
)
if search:
# Search in title, description, marketplace, and name
search_term = f"%{search}%"
@@ -188,7 +223,12 @@ class MarketplaceProductService:
logger.error(f"Error getting products with filters: {str(e)}")
raise ValidationException("Failed to retrieve products")
def update_product(self, db: Session, marketplace_product_id: str, product_update: MarketplaceProductUpdate) -> MarketplaceProduct:
def update_product(
self,
db: Session,
marketplace_product_id: str,
product_update: MarketplaceProductUpdate,
) -> MarketplaceProduct:
"""Update product with validation."""
try:
product = self.get_product_by_id_or_raise(db, marketplace_product_id)
@@ -200,7 +240,9 @@ class MarketplaceProductService:
if "gtin" in update_data and update_data["gtin"]:
normalized_gtin = self.gtin_processor.normalize(update_data["gtin"])
if not normalized_gtin:
raise InvalidMarketplaceProductDataException("Invalid GTIN format", field="gtin")
raise InvalidMarketplaceProductDataException(
"Invalid GTIN format", field="gtin"
)
update_data["gtin"] = normalized_gtin
# Process price if being updated
@@ -217,8 +259,12 @@ class MarketplaceProductService:
raise InvalidMarketplaceProductDataException(str(e), field="price")
# Validate required fields if being updated
if "title" in update_data and (not update_data["title"] or not update_data["title"].strip()):
raise MarketplaceProductValidationException("MarketplaceProduct title cannot be empty", field="title")
if "title" in update_data and (
not update_data["title"] or not update_data["title"].strip()
):
raise MarketplaceProductValidationException(
"MarketplaceProduct title cannot be empty", field="title"
)
for key, value in update_data.items():
setattr(product, key, value)
@@ -230,7 +276,11 @@ class MarketplaceProductService:
logger.info(f"Updated product {marketplace_product_id}")
return product
except (MarketplaceProductNotFoundException, InvalidMarketplaceProductDataException, MarketplaceProductValidationException):
except (
MarketplaceProductNotFoundException,
InvalidMarketplaceProductDataException,
MarketplaceProductValidationException,
):
db.rollback()
raise # Re-raise custom exceptions
except Exception as e:
@@ -272,7 +322,9 @@ class MarketplaceProductService:
logger.error(f"Error deleting product {marketplace_product_id}: {str(e)}")
raise ValidationException("Failed to delete product")
def get_inventory_info(self, db: Session, gtin: str) -> Optional[InventorySummaryResponse]:
def get_inventory_info(
self, db: Session, gtin: str
) -> Optional[InventorySummaryResponse]:
"""
Get inventory information for a product by GTIN.
@@ -290,7 +342,9 @@ class MarketplaceProductService:
total_quantity = sum(entry.quantity for entry in inventory_entries)
locations = [
InventoryLocationResponse(location=entry.location, quantity=entry.quantity)
InventoryLocationResponse(
location=entry.location, quantity=entry.quantity
)
for entry in inventory_entries
]
@@ -305,13 +359,14 @@ class MarketplaceProductService:
import csv
from io import StringIO
from typing import Generator, Optional
from sqlalchemy.orm import Session
def generate_csv_export(
self,
db: Session,
marketplace: Optional[str] = None,
vendor_name: Optional[str] = None,
self,
db: Session,
marketplace: Optional[str] = None,
vendor_name: Optional[str] = None,
) -> Generator[str, None, None]:
"""
Generate CSV export with streaming for memory efficiency and proper CSV escaping.
@@ -331,9 +386,18 @@ class MarketplaceProductService:
# Write header row
headers = [
"marketplace_product_id", "title", "description", "link", "image_link",
"availability", "price", "currency", "brand", "gtin",
"marketplace", "name"
"marketplace_product_id",
"title",
"description",
"link",
"image_link",
"availability",
"price",
"currency",
"brand",
"gtin",
"marketplace",
"name",
]
writer.writerow(headers)
yield output.getvalue()
@@ -350,9 +414,13 @@ class MarketplaceProductService:
# Apply marketplace filters
if marketplace:
query = query.filter(MarketplaceProduct.marketplace.ilike(f"%{marketplace}%"))
query = query.filter(
MarketplaceProduct.marketplace.ilike(f"%{marketplace}%")
)
if vendor_name:
query = query.filter(MarketplaceProduct.vendor_name.ilike(f"%{vendor_name}%"))
query = query.filter(
MarketplaceProduct.vendor_name.ilike(f"%{vendor_name}%")
)
products = query.offset(offset).limit(batch_size).all()
if not products:
@@ -392,8 +460,12 @@ class MarketplaceProductService:
"""Check if product exists by ID."""
try:
return (
db.query(MarketplaceProduct).filter(MarketplaceProduct.marketplace_product_id == marketplace_product_id).first()
is not None
db.query(MarketplaceProduct)
.filter(
MarketplaceProduct.marketplace_product_id == marketplace_product_id
)
.first()
is not None
)
except Exception as e:
logger.error(f"Error checking if product exists: {str(e)}")
@@ -402,18 +474,27 @@ class MarketplaceProductService:
# Private helper methods
def _validate_product_data(self, product_data: dict) -> None:
"""Validate product data structure."""
required_fields = ['marketplace_product_id', 'title']
required_fields = ["marketplace_product_id", "title"]
for field in required_fields:
if field not in product_data or not product_data[field]:
raise MarketplaceProductValidationException(f"{field} is required", field=field)
raise MarketplaceProductValidationException(
f"{field} is required", field=field
)
def _normalize_product_data(self, product_data: dict) -> dict:
"""Normalize and clean product data."""
normalized = product_data.copy()
# Trim whitespace from string fields
string_fields = ['marketplace_product_id', 'title', 'description', 'brand', 'marketplace', 'name']
string_fields = [
"marketplace_product_id",
"title",
"description",
"brand",
"marketplace",
"name",
]
for field in string_fields:
if field in normalized and normalized[field]:
normalized[field] = normalized[field].strip()

View File

@@ -1 +1 @@
# File and media management services
# File and media management services

View File

@@ -1 +1 @@
# Application monitoring services
# Application monitoring services

View File

@@ -1 +1 @@
# Email/notification services
# Email/notification services

View File

@@ -9,24 +9,21 @@ This module provides:
"""
import logging
from datetime import datetime, timezone
from typing import List, Optional, Tuple
import random
import string
from datetime import datetime, timezone
from typing import List, Optional, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
from sqlalchemy.orm import Session
from models.database.order import Order, OrderItem
from app.exceptions import (CustomerNotFoundException,
InsufficientInventoryException,
OrderNotFoundException, ValidationException)
from models.database.customer import Customer, CustomerAddress
from models.database.order import Order, OrderItem
from models.database.product import Product
from models.schema.order import OrderCreate, OrderUpdate, OrderAddressCreate
from app.exceptions import (
OrderNotFoundException,
ValidationException,
InsufficientInventoryException,
CustomerNotFoundException
)
from models.schema.order import OrderAddressCreate, OrderCreate, OrderUpdate
logger = logging.getLogger(__name__)
@@ -42,23 +39,27 @@ class OrderService:
Example: ORD-1-20250110-A1B2C3
"""
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d")
random_suffix = ''.join(random.choices(string.ascii_uppercase + string.digits, k=6))
random_suffix = "".join(
random.choices(string.ascii_uppercase + string.digits, k=6)
)
order_number = f"ORD-{vendor_id}-{timestamp}-{random_suffix}"
# Ensure uniqueness
while db.query(Order).filter(Order.order_number == order_number).first():
random_suffix = ''.join(random.choices(string.ascii_uppercase + string.digits, k=6))
random_suffix = "".join(
random.choices(string.ascii_uppercase + string.digits, k=6)
)
order_number = f"ORD-{vendor_id}-{timestamp}-{random_suffix}"
return order_number
def _create_customer_address(
self,
db: Session,
vendor_id: int,
customer_id: int,
address_data: OrderAddressCreate,
address_type: str
self,
db: Session,
vendor_id: int,
customer_id: int,
address_data: OrderAddressCreate,
address_type: str,
) -> CustomerAddress:
"""Create a customer address for order."""
address = CustomerAddress(
@@ -73,17 +74,14 @@ class OrderService:
city=address_data.city,
postal_code=address_data.postal_code,
country=address_data.country,
is_default=False
is_default=False,
)
db.add(address)
db.flush() # Get ID without committing
return address
def create_order(
self,
db: Session,
vendor_id: int,
order_data: OrderCreate
self, db: Session, vendor_id: int, order_data: OrderCreate
) -> Order:
"""
Create a new order.
@@ -104,12 +102,15 @@ class OrderService:
# Validate customer exists if provided
customer_id = order_data.customer_id
if customer_id:
customer = db.query(Customer).filter(
and_(
Customer.id == customer_id,
Customer.vendor_id == vendor_id
customer = (
db.query(Customer)
.filter(
and_(
Customer.id == customer_id, Customer.vendor_id == vendor_id
)
)
).first()
.first()
)
if not customer:
raise CustomerNotFoundException(str(customer_id))
@@ -124,7 +125,7 @@ class OrderService:
vendor_id=vendor_id,
customer_id=customer_id,
address_data=order_data.shipping_address,
address_type="shipping"
address_type="shipping",
)
# Create billing address (use shipping if not provided)
@@ -134,7 +135,7 @@ class OrderService:
vendor_id=vendor_id,
customer_id=customer_id,
address_data=order_data.billing_address,
address_type="billing"
address_type="billing",
)
else:
billing_address = shipping_address
@@ -145,23 +146,29 @@ class OrderService:
for item_data in order_data.items:
# Get product
product = db.query(Product).filter(
and_(
Product.id == item_data.product_id,
Product.vendor_id == vendor_id,
Product.is_active == True
product = (
db.query(Product)
.filter(
and_(
Product.id == item_data.product_id,
Product.vendor_id == vendor_id,
Product.is_active == True,
)
)
).first()
.first()
)
if not product:
raise ValidationException(f"Product {item_data.product_id} not found")
raise ValidationException(
f"Product {item_data.product_id} not found"
)
# Check inventory
if product.available_inventory < item_data.quantity:
raise InsufficientInventoryException(
product_id=product.id,
requested=item_data.quantity,
available=product.available_inventory
available=product.available_inventory,
)
# Calculate item total
@@ -172,14 +179,16 @@ class OrderService:
item_total = unit_price * item_data.quantity
subtotal += item_total
order_items_data.append({
"product_id": product.id,
"product_name": product.marketplace_product.title,
"product_sku": product.product_id,
"quantity": item_data.quantity,
"unit_price": unit_price,
"total_price": item_total
})
order_items_data.append(
{
"product_id": product.id,
"product_name": product.marketplace_product.title,
"product_sku": product.product_id,
"quantity": item_data.quantity,
"unit_price": unit_price,
"total_price": item_total,
}
)
# Calculate tax and shipping (simple implementation)
tax_amount = 0.0 # TODO: Implement tax calculation
@@ -205,7 +214,7 @@ class OrderService:
shipping_address_id=shipping_address.id,
billing_address_id=billing_address.id,
shipping_method=order_data.shipping_method,
customer_notes=order_data.customer_notes
customer_notes=order_data.customer_notes,
)
db.add(order)
@@ -213,10 +222,7 @@ class OrderService:
# Create order items
for item_data in order_items_data:
order_item = OrderItem(
order_id=order.id,
**item_data
)
order_item = OrderItem(order_id=order.id, **item_data)
db.add(order_item)
db.commit()
@@ -229,7 +235,11 @@ class OrderService:
return order
except (ValidationException, InsufficientInventoryException, CustomerNotFoundException):
except (
ValidationException,
InsufficientInventoryException,
CustomerNotFoundException,
):
db.rollback()
raise
except Exception as e:
@@ -237,19 +247,13 @@ class OrderService:
logger.error(f"Error creating order: {str(e)}")
raise ValidationException(f"Failed to create order: {str(e)}")
def get_order(
self,
db: Session,
vendor_id: int,
order_id: int
) -> Order:
def get_order(self, db: Session, vendor_id: int, order_id: int) -> Order:
"""Get order by ID."""
order = db.query(Order).filter(
and_(
Order.id == order_id,
Order.vendor_id == vendor_id
)
).first()
order = (
db.query(Order)
.filter(and_(Order.id == order_id, Order.vendor_id == vendor_id))
.first()
)
if not order:
raise OrderNotFoundException(str(order_id))
@@ -257,13 +261,13 @@ class OrderService:
return order
def get_vendor_orders(
self,
db: Session,
vendor_id: int,
skip: int = 0,
limit: int = 100,
status: Optional[str] = None,
customer_id: Optional[int] = None
self,
db: Session,
vendor_id: int,
skip: int = 0,
limit: int = 100,
status: Optional[str] = None,
customer_id: Optional[int] = None,
) -> Tuple[List[Order], int]:
"""
Get orders for vendor with filtering.
@@ -296,28 +300,20 @@ class OrderService:
return orders, total
def get_customer_orders(
self,
db: Session,
vendor_id: int,
customer_id: int,
skip: int = 0,
limit: int = 100
self,
db: Session,
vendor_id: int,
customer_id: int,
skip: int = 0,
limit: int = 100,
) -> Tuple[List[Order], int]:
"""Get orders for a specific customer."""
return self.get_vendor_orders(
db=db,
vendor_id=vendor_id,
skip=skip,
limit=limit,
customer_id=customer_id
db=db, vendor_id=vendor_id, skip=skip, limit=limit, customer_id=customer_id
)
def update_order_status(
self,
db: Session,
vendor_id: int,
order_id: int,
order_update: OrderUpdate
self, db: Session, vendor_id: int, order_id: int, order_update: OrderUpdate
) -> Order:
"""
Update order status and tracking information.

View File

@@ -1 +1 @@
# Payment processing services
# Payment processing services

View File

@@ -14,14 +14,11 @@ from typing import List, Optional, Tuple
from sqlalchemy.orm import Session
from app.exceptions import (
ProductNotFoundException,
ProductAlreadyExistsException,
ValidationException,
)
from models.schema.product import ProductCreate, ProductUpdate
from models.database.product import Product
from app.exceptions import (ProductAlreadyExistsException,
ProductNotFoundException, ValidationException)
from models.database.marketplace_product import MarketplaceProduct
from models.database.product import Product
from models.schema.product import ProductCreate, ProductUpdate
logger = logging.getLogger(__name__)
@@ -45,10 +42,11 @@ class ProductService:
ProductNotFoundException: If product not found
"""
try:
product = db.query(Product).filter(
Product.id == product_id,
Product.vendor_id == vendor_id
).first()
product = (
db.query(Product)
.filter(Product.id == product_id, Product.vendor_id == vendor_id)
.first()
)
if not product:
raise ProductNotFoundException(f"Product {product_id} not found")
@@ -62,7 +60,7 @@ class ProductService:
raise ValidationException("Failed to retrieve product")
def create_product(
self, db: Session, vendor_id: int, product_data: ProductCreate
self, db: Session, vendor_id: int, product_data: ProductCreate
) -> Product:
"""
Add a product from marketplace to vendor catalog.
@@ -81,10 +79,14 @@ class ProductService:
"""
try:
# Verify marketplace product exists and belongs to vendor
marketplace_product = db.query(MarketplaceProduct).filter(
MarketplaceProduct.id == product_data.marketplace_product_id,
MarketplaceProduct.vendor_id == vendor_id
).first()
marketplace_product = (
db.query(MarketplaceProduct)
.filter(
MarketplaceProduct.id == product_data.marketplace_product_id,
MarketplaceProduct.vendor_id == vendor_id,
)
.first()
)
if not marketplace_product:
raise ValidationException(
@@ -92,10 +94,15 @@ class ProductService:
)
# Check if already in catalog
existing = db.query(Product).filter(
Product.vendor_id == vendor_id,
Product.marketplace_product_id == product_data.marketplace_product_id
).first()
existing = (
db.query(Product)
.filter(
Product.vendor_id == vendor_id,
Product.marketplace_product_id
== product_data.marketplace_product_id,
)
.first()
)
if existing:
raise ProductAlreadyExistsException(
@@ -122,9 +129,7 @@ class ProductService:
db.commit()
db.refresh(product)
logger.info(
f"Added product {product.id} to vendor {vendor_id} catalog"
)
logger.info(f"Added product {product.id} to vendor {vendor_id} catalog")
return product
except (ProductAlreadyExistsException, ValidationException):
@@ -136,7 +141,11 @@ class ProductService:
raise ValidationException("Failed to create product")
def update_product(
self, db: Session, vendor_id: int, product_id: int, product_update: ProductUpdate
self,
db: Session,
vendor_id: int,
product_id: int,
product_update: ProductUpdate,
) -> Product:
"""
Update product in vendor catalog.
@@ -202,13 +211,13 @@ class ProductService:
raise ValidationException("Failed to delete product")
def get_vendor_products(
self,
db: Session,
vendor_id: int,
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None,
is_featured: Optional[bool] = None,
self,
db: Session,
vendor_id: int,
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None,
is_featured: Optional[bool] = None,
) -> Tuple[List[Product], int]:
"""
Get products in vendor catalog with filtering.

View File

@@ -1 +1 @@
# Search and indexing services
# Search and indexing services

View File

@@ -10,25 +10,21 @@ This module provides:
"""
import logging
from typing import Any, Dict, List
from datetime import datetime, timedelta
from typing import Any, Dict, List
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.exceptions import (
VendorNotFoundException,
AdminOperationException,
)
from models.database.marketplace_product import MarketplaceProduct
from models.database.product import Product
from models.database.inventory import Inventory
from models.database.vendor import Vendor
from models.database.order import Order
from models.database.user import User
from app.exceptions import AdminOperationException, VendorNotFoundException
from models.database.customer import Customer
from models.database.inventory import Inventory
from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.marketplace_product import MarketplaceProduct
from models.database.order import Order
from models.database.product import Product
from models.database.user import User
from models.database.vendor import Vendor
logger = logging.getLogger(__name__)
@@ -62,63 +58,77 @@ class StatsService:
try:
# Catalog statistics
total_catalog_products = db.query(Product).filter(
Product.vendor_id == vendor_id,
Product.is_active == True
).count()
total_catalog_products = (
db.query(Product)
.filter(Product.vendor_id == vendor_id, Product.is_active == True)
.count()
)
featured_products = db.query(Product).filter(
Product.vendor_id == vendor_id,
Product.is_featured == True,
Product.is_active == True
).count()
featured_products = (
db.query(Product)
.filter(
Product.vendor_id == vendor_id,
Product.is_featured == True,
Product.is_active == True,
)
.count()
)
# Staging statistics
# TODO: This is fragile - MarketplaceProduct uses vendor_name (string) not vendor_id
# Should add vendor_id foreign key to MarketplaceProduct for robust querying
# For now, matching by vendor name which could fail if names don't match exactly
staging_products = db.query(MarketplaceProduct).filter(
MarketplaceProduct.vendor_name == vendor.name
).count()
staging_products = (
db.query(MarketplaceProduct)
.filter(MarketplaceProduct.vendor_name == vendor.name)
.count()
)
# Inventory statistics
total_inventory = db.query(
func.sum(Inventory.quantity)
).filter(
Inventory.vendor_id == vendor_id
).scalar() or 0
total_inventory = (
db.query(func.sum(Inventory.quantity))
.filter(Inventory.vendor_id == vendor_id)
.scalar()
or 0
)
reserved_inventory = db.query(
func.sum(Inventory.reserved_quantity)
).filter(
Inventory.vendor_id == vendor_id
).scalar() or 0
reserved_inventory = (
db.query(func.sum(Inventory.reserved_quantity))
.filter(Inventory.vendor_id == vendor_id)
.scalar()
or 0
)
inventory_locations = db.query(
func.count(func.distinct(Inventory.location))
).filter(
Inventory.vendor_id == vendor_id
).scalar() or 0
inventory_locations = (
db.query(func.count(func.distinct(Inventory.location)))
.filter(Inventory.vendor_id == vendor_id)
.scalar()
or 0
)
# Import statistics
total_imports = db.query(MarketplaceImportJob).filter(
MarketplaceImportJob.vendor_id == vendor_id
).count()
total_imports = (
db.query(MarketplaceImportJob)
.filter(MarketplaceImportJob.vendor_id == vendor_id)
.count()
)
successful_imports = db.query(MarketplaceImportJob).filter(
MarketplaceImportJob.vendor_id == vendor_id,
MarketplaceImportJob.status == "completed"
).count()
successful_imports = (
db.query(MarketplaceImportJob)
.filter(
MarketplaceImportJob.vendor_id == vendor_id,
MarketplaceImportJob.status == "completed",
)
.count()
)
# Orders
total_orders = db.query(Order).filter(
Order.vendor_id == vendor_id
).count()
total_orders = db.query(Order).filter(Order.vendor_id == vendor_id).count()
# Customers
total_customers = db.query(Customer).filter(
Customer.vendor_id == vendor_id
).count()
total_customers = (
db.query(Customer).filter(Customer.vendor_id == vendor_id).count()
)
return {
"catalog": {
@@ -138,7 +148,11 @@ class StatsService:
"imports": {
"total_imports": total_imports,
"successful_imports": successful_imports,
"success_rate": (successful_imports / total_imports * 100) if total_imports > 0 else 0,
"success_rate": (
(successful_imports / total_imports * 100)
if total_imports > 0
else 0
),
},
"orders": {
"total_orders": total_orders,
@@ -151,16 +165,18 @@ class StatsService:
except VendorNotFoundException:
raise
except Exception as e:
logger.error(f"Failed to retrieve vendor statistics for vendor {vendor_id}: {str(e)}")
logger.error(
f"Failed to retrieve vendor statistics for vendor {vendor_id}: {str(e)}"
)
raise AdminOperationException(
operation="get_vendor_stats",
reason=f"Database query failed: {str(e)}",
target_type="vendor",
target_id=str(vendor_id)
target_id=str(vendor_id),
)
def get_vendor_analytics(
self, db: Session, vendor_id: int, period: str = "30d"
self, db: Session, vendor_id: int, period: str = "30d"
) -> Dict[str, Any]:
"""
Get a specific vendor analytics for a time period.
@@ -188,21 +204,28 @@ class StatsService:
start_date = datetime.utcnow() - timedelta(days=days)
# Import activity
recent_imports = db.query(MarketplaceImportJob).filter(
MarketplaceImportJob.vendor_id == vendor_id,
MarketplaceImportJob.created_at >= start_date
).count()
recent_imports = (
db.query(MarketplaceImportJob)
.filter(
MarketplaceImportJob.vendor_id == vendor_id,
MarketplaceImportJob.created_at >= start_date,
)
.count()
)
# Products added to catalog
products_added = db.query(Product).filter(
Product.vendor_id == vendor_id,
Product.created_at >= start_date
).count()
products_added = (
db.query(Product)
.filter(
Product.vendor_id == vendor_id, Product.created_at >= start_date
)
.count()
)
# Inventory changes
inventory_entries = db.query(Inventory).filter(
Inventory.vendor_id == vendor_id
).count()
inventory_entries = (
db.query(Inventory).filter(Inventory.vendor_id == vendor_id).count()
)
return {
"period": period,
@@ -221,12 +244,14 @@ class StatsService:
except VendorNotFoundException:
raise
except Exception as e:
logger.error(f"Failed to retrieve vendor analytics for vendor {vendor_id}: {str(e)}")
logger.error(
f"Failed to retrieve vendor analytics for vendor {vendor_id}: {str(e)}"
)
raise AdminOperationException(
operation="get_vendor_analytics",
reason=f"Database query failed: {str(e)}",
target_type="vendor",
target_id=str(vendor_id)
target_id=str(vendor_id),
)
def get_vendor_statistics(self, db: Session) -> dict:
@@ -234,7 +259,9 @@ class StatsService:
try:
total_vendors = db.query(Vendor).count()
active_vendors = db.query(Vendor).filter(Vendor.is_active == True).count()
verified_vendors = db.query(Vendor).filter(Vendor.is_verified == True).count()
verified_vendors = (
db.query(Vendor).filter(Vendor.is_verified == True).count()
)
inactive_vendors = total_vendors - active_vendors
return {
@@ -242,13 +269,14 @@ class StatsService:
"active_vendors": active_vendors,
"inactive_vendors": inactive_vendors,
"verified_vendors": verified_vendors,
"verification_rate": (verified_vendors / total_vendors * 100) if total_vendors > 0 else 0
"verification_rate": (
(verified_vendors / total_vendors * 100) if total_vendors > 0 else 0
),
}
except Exception as e:
logger.error(f"Failed to get vendor statistics: {str(e)}")
raise AdminOperationException(
operation="get_vendor_statistics",
reason="Database query failed"
operation="get_vendor_statistics", reason="Database query failed"
)
# ========================================================================
@@ -302,7 +330,7 @@ class StatsService:
logger.error(f"Failed to retrieve comprehensive statistics: {str(e)}")
raise AdminOperationException(
operation="get_comprehensive_stats",
reason=f"Database query failed: {str(e)}"
reason=f"Database query failed: {str(e)}",
)
def get_marketplace_breakdown_stats(self, db: Session) -> List[Dict[str, Any]]:
@@ -323,8 +351,12 @@ class StatsService:
db.query(
MarketplaceProduct.marketplace,
func.count(MarketplaceProduct.id).label("total_products"),
func.count(func.distinct(MarketplaceProduct.vendor_name)).label("unique_vendors"),
func.count(func.distinct(MarketplaceProduct.brand)).label("unique_brands"),
func.count(func.distinct(MarketplaceProduct.vendor_name)).label(
"unique_vendors"
),
func.count(func.distinct(MarketplaceProduct.brand)).label(
"unique_brands"
),
)
.filter(MarketplaceProduct.marketplace.isnot(None))
.group_by(MarketplaceProduct.marketplace)
@@ -342,10 +374,12 @@ class StatsService:
]
except Exception as e:
logger.error(f"Failed to retrieve marketplace breakdown statistics: {str(e)}")
logger.error(
f"Failed to retrieve marketplace breakdown statistics: {str(e)}"
)
raise AdminOperationException(
operation="get_marketplace_breakdown_stats",
reason=f"Database query failed: {str(e)}"
reason=f"Database query failed: {str(e)}",
)
def get_user_statistics(self, db: Session) -> Dict[str, Any]:
@@ -372,13 +406,14 @@ class StatsService:
"active_users": active_users,
"inactive_users": inactive_users,
"admin_users": admin_users,
"activation_rate": (active_users / total_users * 100) if total_users > 0 else 0
"activation_rate": (
(active_users / total_users * 100) if total_users > 0 else 0
),
}
except Exception as e:
logger.error(f"Failed to get user statistics: {str(e)}")
raise AdminOperationException(
operation="get_user_statistics",
reason="Database query failed"
operation="get_user_statistics", reason="Database query failed"
)
def get_import_statistics(self, db: Session) -> Dict[str, Any]:
@@ -396,18 +431,22 @@ class StatsService:
"""
try:
total = db.query(MarketplaceImportJob).count()
completed = db.query(MarketplaceImportJob).filter(
MarketplaceImportJob.status == "completed"
).count()
failed = db.query(MarketplaceImportJob).filter(
MarketplaceImportJob.status == "failed"
).count()
completed = (
db.query(MarketplaceImportJob)
.filter(MarketplaceImportJob.status == "completed")
.count()
)
failed = (
db.query(MarketplaceImportJob)
.filter(MarketplaceImportJob.status == "failed")
.count()
)
return {
"total_imports": total,
"completed_imports": completed,
"failed_imports": failed,
"success_rate": (completed / total * 100) if total > 0 else 0
"success_rate": (completed / total * 100) if total > 0 else 0,
}
except Exception as e:
logger.error(f"Failed to get import statistics: {str(e)}")
@@ -415,7 +454,7 @@ class StatsService:
"total_imports": 0,
"completed_imports": 0,
"failed_imports": 0,
"success_rate": 0
"success_rate": 0,
}
def get_order_statistics(self, db: Session) -> Dict[str, Any]:
@@ -431,11 +470,7 @@ class StatsService:
Note:
TODO: Implement when Order model is fully available
"""
return {
"total_orders": 0,
"pending_orders": 0,
"completed_orders": 0
}
return {"total_orders": 0, "pending_orders": 0, "completed_orders": 0}
def get_product_statistics(self, db: Session) -> Dict[str, Any]:
"""
@@ -450,11 +485,7 @@ class StatsService:
Note:
TODO: Implement when Product model is fully available
"""
return {
"total_products": 0,
"active_products": 0,
"out_of_stock": 0
}
return {"total_products": 0, "active_products": 0, "out_of_stock": 0}
# ========================================================================
# PRIVATE HELPER METHODS
@@ -491,8 +522,7 @@ class StatsService:
return (
db.query(MarketplaceProduct.brand)
.filter(
MarketplaceProduct.brand.isnot(None),
MarketplaceProduct.brand != ""
MarketplaceProduct.brand.isnot(None), MarketplaceProduct.brand != ""
)
.distinct()
.count()

View File

@@ -9,17 +9,15 @@ This module provides:
"""
import logging
from typing import List, Dict, Any
from datetime import datetime, timezone
from typing import Any, Dict, List
from sqlalchemy.orm import Session
from app.exceptions import (
ValidationException,
UnauthorizedVendorAccessException,
)
from models.database.vendor import VendorUser, Role
from app.exceptions import (UnauthorizedVendorAccessException,
ValidationException)
from models.database.user import User
from models.database.vendor import Role, VendorUser
logger = logging.getLogger(__name__)
@@ -28,7 +26,7 @@ class TeamService:
"""Service for team management operations."""
def get_team_members(
self, db: Session, vendor_id: int, current_user: User
self, db: Session, vendor_id: int, current_user: User
) -> List[Dict[str, Any]]:
"""
Get all team members for vendor.
@@ -42,23 +40,26 @@ class TeamService:
List of team members
"""
try:
vendor_users = db.query(VendorUser).filter(
VendorUser.vendor_id == vendor_id,
VendorUser.is_active == True
).all()
vendor_users = (
db.query(VendorUser)
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
.all()
)
members = []
for vu in vendor_users:
members.append({
"id": vu.user_id,
"email": vu.user.email,
"first_name": vu.user.first_name,
"last_name": vu.user.last_name,
"role": vu.role.name,
"role_id": vu.role_id,
"is_active": vu.is_active,
"joined_at": vu.created_at,
})
members.append(
{
"id": vu.user_id,
"email": vu.user.email,
"first_name": vu.user.first_name,
"last_name": vu.user.last_name,
"role": vu.role.name,
"role_id": vu.role_id,
"is_active": vu.is_active,
"joined_at": vu.created_at,
}
)
return members
@@ -67,7 +68,7 @@ class TeamService:
raise ValidationException("Failed to retrieve team members")
def invite_team_member(
self, db: Session, vendor_id: int, invitation_data: dict, current_user: User
self, db: Session, vendor_id: int, invitation_data: dict, current_user: User
) -> Dict[str, Any]:
"""
Invite a new team member.
@@ -95,12 +96,12 @@ class TeamService:
raise ValidationException("Failed to invite team member")
def update_team_member(
self,
db: Session,
vendor_id: int,
user_id: int,
update_data: dict,
current_user: User
self,
db: Session,
vendor_id: int,
user_id: int,
update_data: dict,
current_user: User,
) -> Dict[str, Any]:
"""
Update team member role or status.
@@ -116,10 +117,13 @@ class TeamService:
Updated member info
"""
try:
vendor_user = db.query(VendorUser).filter(
VendorUser.vendor_id == vendor_id,
VendorUser.user_id == user_id
).first()
vendor_user = (
db.query(VendorUser)
.filter(
VendorUser.vendor_id == vendor_id, VendorUser.user_id == user_id
)
.first()
)
if not vendor_user:
raise ValidationException("Team member not found")
@@ -146,7 +150,7 @@ class TeamService:
raise ValidationException("Failed to update team member")
def remove_team_member(
self, db: Session, vendor_id: int, user_id: int, current_user: User
self, db: Session, vendor_id: int, user_id: int, current_user: User
) -> bool:
"""
Remove team member from vendor.
@@ -161,10 +165,13 @@ class TeamService:
True if removed
"""
try:
vendor_user = db.query(VendorUser).filter(
VendorUser.vendor_id == vendor_id,
VendorUser.user_id == user_id
).first()
vendor_user = (
db.query(VendorUser)
.filter(
VendorUser.vendor_id == vendor_id, VendorUser.user_id == user_id
)
.first()
)
if not vendor_user:
raise ValidationException("Team member not found")

View File

@@ -12,30 +12,28 @@ This module provides classes and functions for:
import logging
import secrets
from typing import List, Tuple, Optional
from datetime import datetime, timezone
from typing import List, Optional, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import and_
from sqlalchemy.orm import Session
from app.exceptions import (
VendorNotFoundException,
VendorDomainNotFoundException,
VendorDomainAlreadyExistsException,
InvalidDomainFormatException,
ReservedDomainException,
DomainNotVerifiedException,
DomainVerificationFailedException,
DomainAlreadyVerifiedException,
MultiplePrimaryDomainsException,
DNSVerificationException,
MaxDomainsReachedException,
UnauthorizedDomainAccessException,
ValidationException,
)
from models.schema.vendor_domain import VendorDomainCreate, VendorDomainUpdate
from app.exceptions import (DNSVerificationException,
DomainAlreadyVerifiedException,
DomainNotVerifiedException,
DomainVerificationFailedException,
InvalidDomainFormatException,
MaxDomainsReachedException,
MultiplePrimaryDomainsException,
ReservedDomainException,
UnauthorizedDomainAccessException,
ValidationException,
VendorDomainAlreadyExistsException,
VendorDomainNotFoundException,
VendorNotFoundException)
from models.database.vendor import Vendor
from models.database.vendor_domain import VendorDomain
from models.schema.vendor_domain import VendorDomainCreate, VendorDomainUpdate
logger = logging.getLogger(__name__)
@@ -45,13 +43,19 @@ class VendorDomainService:
def __init__(self):
self.max_domains_per_vendor = 10 # Configure as needed
self.reserved_subdomains = ['www', 'admin', 'api', 'mail', 'smtp', 'ftp', 'cpanel', 'webmail']
self.reserved_subdomains = [
"www",
"admin",
"api",
"mail",
"smtp",
"ftp",
"cpanel",
"webmail",
]
def add_domain(
self,
db: Session,
vendor_id: int,
domain_data: VendorDomainCreate
self, db: Session, vendor_id: int, domain_data: VendorDomainCreate
) -> VendorDomain:
"""
Add a custom domain to vendor.
@@ -85,12 +89,14 @@ class VendorDomainService:
# Check if domain already exists
if self._domain_exists(db, normalized_domain):
existing_domain = db.query(VendorDomain).filter(
VendorDomain.domain == normalized_domain
).first()
existing_domain = (
db.query(VendorDomain)
.filter(VendorDomain.domain == normalized_domain)
.first()
)
raise VendorDomainAlreadyExistsException(
normalized_domain,
existing_domain.vendor_id if existing_domain else None
existing_domain.vendor_id if existing_domain else None,
)
# If setting as primary, unset other primary domains
@@ -104,8 +110,8 @@ class VendorDomainService:
is_primary=domain_data.is_primary,
verification_token=secrets.token_urlsafe(32),
is_verified=False, # Requires DNS verification
is_active=False, # Cannot be active until verified
ssl_status="pending"
is_active=False, # Cannot be active until verified
ssl_status="pending",
)
db.add(new_domain)
@@ -120,7 +126,7 @@ class VendorDomainService:
VendorDomainAlreadyExistsException,
MaxDomainsReachedException,
InvalidDomainFormatException,
ReservedDomainException
ReservedDomainException,
):
db.rollback()
raise
@@ -129,11 +135,7 @@ class VendorDomainService:
logger.error(f"Error adding domain: {str(e)}")
raise ValidationException("Failed to add domain")
def get_vendor_domains(
self,
db: Session,
vendor_id: int
) -> List[VendorDomain]:
def get_vendor_domains(self, db: Session, vendor_id: int) -> List[VendorDomain]:
"""
Get all domains for a vendor.
@@ -151,12 +153,14 @@ class VendorDomainService:
# Verify vendor exists
self._get_vendor_by_id_or_raise(db, vendor_id)
domains = db.query(VendorDomain).filter(
VendorDomain.vendor_id == vendor_id
).order_by(
VendorDomain.is_primary.desc(),
VendorDomain.created_at.desc()
).all()
domains = (
db.query(VendorDomain)
.filter(VendorDomain.vendor_id == vendor_id)
.order_by(
VendorDomain.is_primary.desc(), VendorDomain.created_at.desc()
)
.all()
)
return domains
@@ -166,11 +170,7 @@ class VendorDomainService:
logger.error(f"Error getting vendor domains: {str(e)}")
raise ValidationException("Failed to retrieve domains")
def get_domain_by_id(
self,
db: Session,
domain_id: int
) -> VendorDomain:
def get_domain_by_id(self, db: Session, domain_id: int) -> VendorDomain:
"""
Get domain by ID.
@@ -190,10 +190,7 @@ class VendorDomainService:
return domain
def update_domain(
self,
db: Session,
domain_id: int,
domain_update: VendorDomainUpdate
self, db: Session, domain_id: int, domain_update: VendorDomainUpdate
) -> VendorDomain:
"""
Update domain settings.
@@ -215,7 +212,9 @@ class VendorDomainService:
# If setting as primary, unset other primary domains
if domain_update.is_primary:
self._unset_primary_domains(db, domain.vendor_id, exclude_domain_id=domain_id)
self._unset_primary_domains(
db, domain.vendor_id, exclude_domain_id=domain_id
)
domain.is_primary = True
# If activating, check verification
@@ -240,11 +239,7 @@ class VendorDomainService:
logger.error(f"Error updating domain: {str(e)}")
raise ValidationException("Failed to update domain")
def delete_domain(
self,
db: Session,
domain_id: int
) -> str:
def delete_domain(self, db: Session, domain_id: int) -> str:
"""
Delete a custom domain.
@@ -277,11 +272,7 @@ class VendorDomainService:
logger.error(f"Error deleting domain: {str(e)}")
raise ValidationException("Failed to delete domain")
def verify_domain(
self,
db: Session,
domain_id: int
) -> Tuple[VendorDomain, str]:
def verify_domain(self, db: Session, domain_id: int) -> Tuple[VendorDomain, str]:
"""
Verify domain ownership via DNS TXT record.
@@ -313,8 +304,7 @@ class VendorDomainService:
# Query DNS TXT records
try:
txt_records = dns.resolver.resolve(
f"_wizamart-verify.{domain.domain}",
'TXT'
f"_wizamart-verify.{domain.domain}", "TXT"
)
# Check if verification token is present
@@ -332,42 +322,33 @@ class VendorDomainService:
# Token not found
raise DomainVerificationFailedException(
domain.domain,
"Verification token not found in DNS records"
domain.domain, "Verification token not found in DNS records"
)
except dns.resolver.NXDOMAIN:
raise DomainVerificationFailedException(
domain.domain,
f"DNS record _wizamart-verify.{domain.domain} not found"
f"DNS record _wizamart-verify.{domain.domain} not found",
)
except dns.resolver.NoAnswer:
raise DomainVerificationFailedException(
domain.domain,
"No TXT records found for verification"
domain.domain, "No TXT records found for verification"
)
except Exception as dns_error:
raise DNSVerificationException(
domain.domain,
str(dns_error)
)
raise DNSVerificationException(domain.domain, str(dns_error))
except (
VendorDomainNotFoundException,
DomainAlreadyVerifiedException,
DomainVerificationFailedException,
DNSVerificationException
DNSVerificationException,
):
raise
except Exception as e:
logger.error(f"Error verifying domain: {str(e)}")
raise ValidationException("Failed to verify domain")
def get_verification_instructions(
self,
db: Session,
domain_id: int
) -> dict:
def get_verification_instructions(self, db: Session, domain_id: int) -> dict:
"""
Get DNS verification instructions for domain.
@@ -390,20 +371,20 @@ class VendorDomainService:
"step1": "Go to your domain's DNS settings (at your domain registrar)",
"step2": "Add a new TXT record with the following values:",
"step3": "Wait for DNS propagation (5-15 minutes)",
"step4": "Click 'Verify Domain' button in admin panel"
"step4": "Click 'Verify Domain' button in admin panel",
},
"txt_record": {
"type": "TXT",
"name": "_wizamart-verify",
"value": domain.verification_token,
"ttl": 3600
"ttl": 3600,
},
"common_registrars": {
"Cloudflare": "https://dash.cloudflare.com",
"GoDaddy": "https://dcc.godaddy.com/manage/dns",
"Namecheap": "https://www.namecheap.com/myaccount/domain-list/",
"Google Domains": "https://domains.google.com"
}
"Google Domains": "https://domains.google.com",
},
}
# Private helper methods
@@ -416,36 +397,33 @@ class VendorDomainService:
def _check_domain_limit(self, db: Session, vendor_id: int) -> None:
"""Check if vendor has reached maximum domain limit."""
domain_count = db.query(VendorDomain).filter(
VendorDomain.vendor_id == vendor_id
).count()
domain_count = (
db.query(VendorDomain).filter(VendorDomain.vendor_id == vendor_id).count()
)
if domain_count >= self.max_domains_per_vendor:
raise MaxDomainsReachedException(vendor_id, self.max_domains_per_vendor)
def _domain_exists(self, db: Session, domain: str) -> bool:
"""Check if domain already exists in system."""
return db.query(VendorDomain).filter(
VendorDomain.domain == domain
).first() is not None
return (
db.query(VendorDomain).filter(VendorDomain.domain == domain).first()
is not None
)
def _validate_domain_format(self, domain: str) -> None:
"""Validate domain format and check for reserved subdomains."""
# Check for reserved subdomains
first_part = domain.split('.')[0]
first_part = domain.split(".")[0]
if first_part in self.reserved_subdomains:
raise ReservedDomainException(domain, first_part)
def _unset_primary_domains(
self,
db: Session,
vendor_id: int,
exclude_domain_id: Optional[int] = None
self, db: Session, vendor_id: int, exclude_domain_id: Optional[int] = None
) -> None:
"""Unset all primary domains for vendor."""
query = db.query(VendorDomain).filter(
VendorDomain.vendor_id == vendor_id,
VendorDomain.is_primary == True
VendorDomain.vendor_id == vendor_id, VendorDomain.is_primary == True
)
if exclude_domain_id:

View File

@@ -15,22 +15,19 @@ from typing import List, Optional, Tuple
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.exceptions import (
VendorNotFoundException,
VendorAlreadyExistsException,
UnauthorizedVendorAccessException,
InvalidVendorDataException,
MarketplaceProductNotFoundException,
ProductAlreadyExistsException,
MaxVendorsReachedException,
ValidationException,
)
from models.schema.vendor import VendorCreate
from models.schema.product import ProductCreate
from app.exceptions import (InvalidVendorDataException,
MarketplaceProductNotFoundException,
MaxVendorsReachedException,
ProductAlreadyExistsException,
UnauthorizedVendorAccessException,
ValidationException, VendorAlreadyExistsException,
VendorNotFoundException)
from models.database.marketplace_product import MarketplaceProduct
from models.database.vendor import Vendor
from models.database.product import Product
from models.database.user import User
from models.database.vendor import Vendor
from models.schema.product import ProductCreate
from models.schema.vendor import VendorCreate
logger = logging.getLogger(__name__)
@@ -39,7 +36,7 @@ class VendorService:
"""Service class for vendor operations following the application's service pattern."""
def create_vendor(
self, db: Session, vendor_data: VendorCreate, current_user: User
self, db: Session, vendor_data: VendorCreate, current_user: User
) -> Vendor:
"""
Create a new vendor.
@@ -47,7 +44,7 @@ class VendorService:
Args:
db: Database session
vendor_data: Vendor creation data
current_user: User creating the vendor
current_user: User creating the vendor
Returns:
Created vendor object
@@ -91,7 +88,11 @@ class VendorService:
)
return new_vendor
except (VendorAlreadyExistsException, MaxVendorsReachedException, InvalidVendorDataException):
except (
VendorAlreadyExistsException,
MaxVendorsReachedException,
InvalidVendorDataException,
):
db.rollback()
raise # Re-raise custom exceptions
except Exception as e:
@@ -100,13 +101,13 @@ class VendorService:
raise ValidationException("Failed to create vendor ")
def get_vendors(
self,
db: Session,
current_user: User,
skip: int = 0,
limit: int = 100,
active_only: bool = True,
verified_only: bool = False,
self,
db: Session,
current_user: User,
skip: int = 0,
limit: int = 100,
active_only: bool = True,
verified_only: bool = False,
) -> Tuple[List[Vendor], int]:
"""
Get vendors with filtering.
@@ -129,7 +130,10 @@ class VendorService:
if current_user.role != "admin":
query = query.filter(
(Vendor.is_active == True)
& ((Vendor.is_verified == True) | (Vendor.owner_user_id == current_user.id))
& (
(Vendor.is_verified == True)
| (Vendor.owner_user_id == current_user.id)
)
)
else:
# Admin can apply filters
@@ -147,14 +151,16 @@ class VendorService:
logger.error(f"Error getting vendors: {str(e)}")
raise ValidationException("Failed to retrieve vendors")
def get_vendor_by_code(self, db: Session, vendor_code: str, current_user: User) -> Vendor:
def get_vendor_by_code(
self, db: Session, vendor_code: str, current_user: User
) -> Vendor:
"""
Get vendor by vendor code with access control.
Args:
db: Database session
vendor_code: Vendor code to find
current_user: Current user requesting the vendor
current_user: Current user requesting the vendor
Returns:
Vendor object
@@ -170,14 +176,14 @@ class VendorService:
.first()
)
if not vendor :
if not vendor:
raise VendorNotFoundException(vendor_code)
# Check access permissions
if not self._can_access_vendor(vendor, current_user):
raise UnauthorizedVendorAccessException(vendor_code, current_user.id)
return vendor
return vendor
except (VendorNotFoundException, UnauthorizedVendorAccessException):
raise # Re-raise custom exceptions
@@ -186,7 +192,7 @@ class VendorService:
raise ValidationException("Failed to retrieve vendor ")
def add_product_to_catalog(
self, db: Session, vendor : Vendor, product: ProductCreate
self, db: Session, vendor: Vendor, product: ProductCreate
) -> Product:
"""
Add existing product to vendor catalog with vendor -specific settings.
@@ -201,15 +207,19 @@ class VendorService:
Raises:
MarketplaceProductNotFoundException: If product not found
ProductAlreadyExistsException: If product already in vendor
ProductAlreadyExistsException: If product already in vendor
"""
try:
# Check if product exists
marketplace_product = self._get_product_by_id_or_raise(db, product.marketplace_product_id)
marketplace_product = self._get_product_by_id_or_raise(
db, product.marketplace_product_id
)
# Check if product already in vendor
# Check if product already in vendor
if self._product_in_catalog(db, vendor.id, marketplace_product.id):
raise ProductAlreadyExistsException(vendor.vendor_code, product.marketplace_product_id)
raise ProductAlreadyExistsException(
vendor.vendor_code, product.marketplace_product_id
)
# Create vendor -product association
new_product = Product(
@@ -225,7 +235,9 @@ class VendorService:
# Load the product relationship
db.refresh(new_product)
logger.info(f"MarketplaceProduct {product.marketplace_product_id} added to vendor {vendor.vendor_code}")
logger.info(
f"MarketplaceProduct {product.marketplace_product_id} added to vendor {vendor.vendor_code}"
)
return new_product
except (MarketplaceProductNotFoundException, ProductAlreadyExistsException):
@@ -237,14 +249,14 @@ class VendorService:
raise ValidationException("Failed to add product to vendor ")
def get_products(
self,
db: Session,
vendor : Vendor,
current_user: User,
skip: int = 0,
limit: int = 100,
active_only: bool = True,
featured_only: bool = False,
self,
db: Session,
vendor: Vendor,
current_user: User,
skip: int = 0,
limit: int = 100,
active_only: bool = True,
featured_only: bool = False,
) -> Tuple[List[Product], int]:
"""
Get products in vendor catalog with filtering.
@@ -267,7 +279,9 @@ class VendorService:
try:
# Check access permissions
if not self._can_access_vendor(vendor, current_user):
raise UnauthorizedVendorAccessException(vendor.vendor_code, current_user.id)
raise UnauthorizedVendorAccessException(
vendor.vendor_code, current_user.id
)
# Query vendor products
query = db.query(Product).filter(Product.vendor_id == vendor.id)
@@ -292,17 +306,20 @@ class VendorService:
def _validate_vendor_data(self, vendor_data: VendorCreate) -> None:
"""Validate vendor creation data."""
if not vendor_data.vendor_code or not vendor_data.vendor_code.strip():
raise InvalidVendorDataException("Vendor code is required", field="vendor_code")
raise InvalidVendorDataException(
"Vendor code is required", field="vendor_code"
)
if not vendor_data.vendor_name or not vendor_data.vendor_name.strip():
raise InvalidVendorDataException("Vendor name is required", field="name")
# Validate vendor code format (alphanumeric, underscores, hyphens)
import re
if not re.match(r'^[A-Za-z0-9_-]+$', vendor_data.vendor_code):
if not re.match(r"^[A-Za-z0-9_-]+$", vendor_data.vendor_code):
raise InvalidVendorDataException(
"Vendor code can only contain letters, numbers, underscores, and hyphens",
field="vendor_code"
field="vendor_code",
)
def _check_vendor_limit(self, db: Session, user: User) -> None:
@@ -310,7 +327,9 @@ class VendorService:
if user.role == "admin":
return # Admins have no limit
user_vendor_count = db.query(Vendor).filter(Vendor.owner_user_id == user.id).count()
user_vendor_count = (
db.query(Vendor).filter(Vendor.owner_user_id == user.id).count()
)
max_vendors = 5 # Configure this as needed
if user_vendor_count >= max_vendors:
@@ -319,30 +338,40 @@ class VendorService:
def _vendor_code_exists(self, db: Session, vendor_code: str) -> bool:
"""Check if vendor code already exists (case-insensitive)."""
return (
db.query(Vendor)
.filter(func.upper(Vendor.vendor_code) == vendor_code.upper())
.first() is not None
db.query(Vendor)
.filter(func.upper(Vendor.vendor_code) == vendor_code.upper())
.first()
is not None
)
def _get_product_by_id_or_raise(self, db: Session, marketplace_product_id: str) -> MarketplaceProduct:
def _get_product_by_id_or_raise(
self, db: Session, marketplace_product_id: str
) -> MarketplaceProduct:
"""Get product by ID or raise exception."""
product = db.query(MarketplaceProduct).filter(MarketplaceProduct.marketplace_product_id == marketplace_product_id).first()
product = (
db.query(MarketplaceProduct)
.filter(MarketplaceProduct.marketplace_product_id == marketplace_product_id)
.first()
)
if not product:
raise MarketplaceProductNotFoundException(marketplace_product_id)
return product
def _product_in_catalog(self, db: Session, vendor_id: int, marketplace_product_id: int) -> bool:
def _product_in_catalog(
self, db: Session, vendor_id: int, marketplace_product_id: int
) -> bool:
"""Check if product is already in vendor."""
return (
db.query(Product)
.filter(
Product.vendor_id == vendor_id,
Product.marketplace_product_id == marketplace_product_id
)
.first() is not None
db.query(Product)
.filter(
Product.vendor_id == vendor_id,
Product.marketplace_product_id == marketplace_product_id,
)
.first()
is not None
)
def _can_access_vendor(self, vendor : Vendor, user: User) -> bool:
def _can_access_vendor(self, vendor: Vendor, user: User) -> bool:
"""Check if user can access vendor."""
# Admins and owners can always access
if user.role == "admin" or vendor.owner_user_id == user.id:
@@ -351,9 +380,10 @@ class VendorService:
# Others can only access active and verified vendors
return vendor.is_active and vendor.is_verified
def _is_vendor_owner(self, vendor : Vendor, user: User) -> bool:
def _is_vendor_owner(self, vendor: Vendor, user: User) -> bool:
"""Check if user is vendor owner."""
return vendor.owner_user_id == user.id
# Create service instance following the same pattern as other services
vendor_service = VendorService()

View File

@@ -11,23 +11,20 @@ Handles:
import logging
import secrets
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any
from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session
from app.core.permissions import get_preset_permissions
from app.exceptions import (
TeamMemberAlreadyExistsException,
InvalidInvitationTokenException,
TeamInvitationAlreadyAcceptedException,
MaxTeamMembersReachedException,
UserNotFoundException,
VendorNotFoundException,
CannotRemoveOwnerException,
)
from models.database.user import User
from models.database.vendor import Vendor, VendorUser, VendorUserType, Role
from app.exceptions import (CannotRemoveOwnerException,
InvalidInvitationTokenException,
MaxTeamMembersReachedException,
TeamInvitationAlreadyAcceptedException,
TeamMemberAlreadyExistsException,
UserNotFoundException, VendorNotFoundException)
from middleware.auth import AuthManager
from models.database.user import User
from models.database.vendor import Role, Vendor, VendorUser, VendorUserType
logger = logging.getLogger(__name__)
@@ -40,13 +37,13 @@ class VendorTeamService:
self.max_team_members = 50 # Configure as needed
def invite_team_member(
self,
db: Session,
vendor: Vendor,
inviter: User,
email: str,
role_name: str,
custom_permissions: Optional[List[str]] = None,
self,
db: Session,
vendor: Vendor,
inviter: User,
email: str,
role_name: str,
custom_permissions: Optional[List[str]] = None,
) -> Dict[str, Any]:
"""
Invite a new team member to a vendor.
@@ -69,10 +66,14 @@ class VendorTeamService:
"""
try:
# Check team size limit
current_team_size = db.query(VendorUser).filter(
VendorUser.vendor_id == vendor.id,
VendorUser.is_active == True,
).count()
current_team_size = (
db.query(VendorUser)
.filter(
VendorUser.vendor_id == vendor.id,
VendorUser.is_active == True,
)
.count()
)
if current_team_size >= self.max_team_members:
raise MaxTeamMembersReachedException(
@@ -85,22 +86,34 @@ class VendorTeamService:
if user:
# Check if already a member
existing_membership = db.query(VendorUser).filter(
VendorUser.vendor_id == vendor.id,
VendorUser.user_id == user.id,
).first()
existing_membership = (
db.query(VendorUser)
.filter(
VendorUser.vendor_id == vendor.id,
VendorUser.user_id == user.id,
)
.first()
)
if existing_membership:
if existing_membership.is_active:
raise TeamMemberAlreadyExistsException(email, vendor.vendor_code)
raise TeamMemberAlreadyExistsException(
email, vendor.vendor_code
)
# Reactivate old membership
existing_membership.is_active = False # Will be activated on acceptance
existing_membership.invitation_token = self._generate_invitation_token()
existing_membership.is_active = (
False # Will be activated on acceptance
)
existing_membership.invitation_token = (
self._generate_invitation_token()
)
existing_membership.invitation_sent_at = datetime.utcnow()
existing_membership.invitation_accepted_at = None
db.commit()
logger.info(f"Re-invited user {email} to vendor {vendor.vendor_code}")
logger.info(
f"Re-invited user {email} to vendor {vendor.vendor_code}"
)
return {
"invitation_token": existing_membership.invitation_token,
"email": email,
@@ -108,7 +121,7 @@ class VendorTeamService:
}
else:
# Create new user account (inactive until invitation accepted)
username = email.split('@')[0]
username = email.split("@")[0]
# Ensure unique username
base_username = username
counter = 1
@@ -179,12 +192,12 @@ class VendorTeamService:
raise
def accept_invitation(
self,
db: Session,
invitation_token: str,
password: str,
first_name: Optional[str] = None,
last_name: Optional[str] = None,
self,
db: Session,
invitation_token: str,
password: str,
first_name: Optional[str] = None,
last_name: Optional[str] = None,
) -> Dict[str, Any]:
"""
Accept a team invitation and activate account.
@@ -201,9 +214,13 @@ class VendorTeamService:
"""
try:
# Find invitation
vendor_user = db.query(VendorUser).filter(
VendorUser.invitation_token == invitation_token,
).first()
vendor_user = (
db.query(VendorUser)
.filter(
VendorUser.invitation_token == invitation_token,
)
.first()
)
if not vendor_user:
raise InvalidInvitationTokenException()
@@ -247,7 +264,10 @@ class VendorTeamService:
"role": vendor_user.role.name if vendor_user.role else "member",
}
except (InvalidInvitationTokenException, TeamInvitationAlreadyAcceptedException):
except (
InvalidInvitationTokenException,
TeamInvitationAlreadyAcceptedException,
):
raise
except Exception as e:
db.rollback()
@@ -255,10 +275,10 @@ class VendorTeamService:
raise
def remove_team_member(
self,
db: Session,
vendor: Vendor,
user_id: int,
self,
db: Session,
vendor: Vendor,
user_id: int,
) -> bool:
"""
Remove a team member from a vendor.
@@ -274,10 +294,14 @@ class VendorTeamService:
True if removed
"""
try:
vendor_user = db.query(VendorUser).filter(
VendorUser.vendor_id == vendor.id,
VendorUser.user_id == user_id,
).first()
vendor_user = (
db.query(VendorUser)
.filter(
VendorUser.vendor_id == vendor.id,
VendorUser.user_id == user_id,
)
.first()
)
if not vendor_user:
raise UserNotFoundException(str(user_id))
@@ -301,12 +325,12 @@ class VendorTeamService:
raise
def update_member_role(
self,
db: Session,
vendor: Vendor,
user_id: int,
new_role_name: str,
custom_permissions: Optional[List[str]] = None,
self,
db: Session,
vendor: Vendor,
user_id: int,
new_role_name: str,
custom_permissions: Optional[List[str]] = None,
) -> VendorUser:
"""
Update a team member's role.
@@ -322,10 +346,14 @@ class VendorTeamService:
Updated VendorUser
"""
try:
vendor_user = db.query(VendorUser).filter(
VendorUser.vendor_id == vendor.id,
VendorUser.user_id == user_id,
).first()
vendor_user = (
db.query(VendorUser)
.filter(
VendorUser.vendor_id == vendor.id,
VendorUser.user_id == user_id,
)
.first()
)
if not vendor_user:
raise UserNotFoundException(str(user_id))
@@ -360,10 +388,10 @@ class VendorTeamService:
raise
def get_team_members(
self,
db: Session,
vendor: Vendor,
include_inactive: bool = False,
self,
db: Session,
vendor: Vendor,
include_inactive: bool = False,
) -> List[Dict[str, Any]]:
"""
Get all team members for a vendor.
@@ -387,20 +415,22 @@ class VendorTeamService:
members = []
for vu in vendor_users:
members.append({
"id": vu.user.id,
"email": vu.user.email,
"username": vu.user.username,
"full_name": vu.user.full_name,
"user_type": vu.user_type,
"role": vu.role.name if vu.role else "owner",
"permissions": vu.get_all_permissions(),
"is_active": vu.is_active,
"is_owner": vu.is_owner,
"invitation_pending": vu.is_invitation_pending,
"invited_at": vu.invitation_sent_at,
"accepted_at": vu.invitation_accepted_at,
})
members.append(
{
"id": vu.user.id,
"email": vu.user.email,
"username": vu.user.username,
"full_name": vu.user.full_name,
"user_type": vu.user_type,
"role": vu.role.name if vu.role else "owner",
"permissions": vu.get_all_permissions(),
"is_active": vu.is_active,
"is_owner": vu.is_owner,
"invitation_pending": vu.is_invitation_pending,
"invited_at": vu.invitation_sent_at,
"accepted_at": vu.invitation_accepted_at,
}
)
return members
@@ -411,18 +441,22 @@ class VendorTeamService:
return secrets.token_urlsafe(32)
def _get_or_create_role(
self,
db: Session,
vendor: Vendor,
role_name: str,
custom_permissions: Optional[List[str]] = None,
self,
db: Session,
vendor: Vendor,
role_name: str,
custom_permissions: Optional[List[str]] = None,
) -> Role:
"""Get existing role or create new one with preset/custom permissions."""
# Try to find existing role with same name
role = db.query(Role).filter(
Role.vendor_id == vendor.id,
Role.name == role_name,
).first()
role = (
db.query(Role)
.filter(
Role.vendor_id == vendor.id,
Role.name == role_name,
)
.first()
)
if role and custom_permissions is None:
# Use existing role

View File

@@ -8,32 +8,24 @@ Handles theme CRUD operations, preset application, and validation.
import logging
import re
from typing import Optional, Dict, List
from typing import Dict, List, Optional
from sqlalchemy.orm import Session
from app.core.theme_presets import (THEME_PRESETS, apply_preset,
get_available_presets, get_preset_preview)
from app.exceptions.vendor import VendorNotFoundException
from app.exceptions.vendor_theme import (InvalidColorFormatException,
InvalidFontFamilyException,
InvalidThemeDataException,
ThemeOperationException,
ThemePresetAlreadyAppliedException,
ThemePresetNotFoundException,
ThemeValidationException,
VendorThemeNotFoundException)
from models.database.vendor import Vendor
from models.database.vendor_theme import VendorTheme
from models.schema.vendor_theme import (
VendorThemeUpdate,
ThemePresetPreview
)
from app.exceptions.vendor import VendorNotFoundException
from app.exceptions.vendor_theme import (
VendorThemeNotFoundException,
InvalidThemeDataException,
ThemePresetNotFoundException,
ThemeValidationException,
InvalidColorFormatException,
InvalidFontFamilyException,
ThemePresetAlreadyAppliedException,
ThemeOperationException
)
from app.core.theme_presets import (
apply_preset,
get_available_presets,
get_preset_preview,
THEME_PRESETS
)
from models.schema.vendor_theme import ThemePresetPreview, VendorThemeUpdate
logger = logging.getLogger(__name__)
@@ -71,9 +63,9 @@ class VendorThemeService:
Raises:
VendorNotFoundException: If vendor not found
"""
vendor = db.query(Vendor).filter(
Vendor.vendor_code == vendor_code.upper()
).first()
vendor = (
db.query(Vendor).filter(Vendor.vendor_code == vendor_code.upper()).first()
)
if not vendor:
self.logger.warning(f"Vendor not found: {vendor_code}")
@@ -105,12 +97,12 @@ class VendorThemeService:
vendor = self._get_vendor_by_code(db, vendor_code)
# Get theme
theme = db.query(VendorTheme).filter(
VendorTheme.vendor_id == vendor.id
).first()
theme = db.query(VendorTheme).filter(VendorTheme.vendor_id == vendor.id).first()
if not theme:
self.logger.info(f"No custom theme for vendor {vendor_code}, returning default")
self.logger.info(
f"No custom theme for vendor {vendor_code}, returning default"
)
return self._get_default_theme()
return theme.to_dict()
@@ -130,23 +122,16 @@ class VendorThemeService:
"accent": "#ec4899",
"background": "#ffffff",
"text": "#1f2937",
"border": "#e5e7eb"
},
"fonts": {
"heading": "Inter, sans-serif",
"body": "Inter, sans-serif"
"border": "#e5e7eb",
},
"fonts": {"heading": "Inter, sans-serif", "body": "Inter, sans-serif"},
"branding": {
"logo": None,
"logo_dark": None,
"favicon": None,
"banner": None
},
"layout": {
"style": "grid",
"header": "fixed",
"product_card": "modern"
"banner": None,
},
"layout": {"style": "grid", "header": "fixed", "product_card": "modern"},
"social_links": {},
"custom_css": None,
"css_variables": {
@@ -158,7 +143,7 @@ class VendorThemeService:
"--color-border": "#e5e7eb",
"--font-heading": "Inter, sans-serif",
"--font-body": "Inter, sans-serif",
}
},
}
# ============================================================================
@@ -166,10 +151,7 @@ class VendorThemeService:
# ============================================================================
def update_theme(
self,
db: Session,
vendor_code: str,
theme_data: VendorThemeUpdate
self, db: Session, vendor_code: str, theme_data: VendorThemeUpdate
) -> VendorTheme:
"""
Update or create theme for vendor.
@@ -194,9 +176,9 @@ class VendorThemeService:
vendor = self._get_vendor_by_code(db, vendor_code)
# Get or create theme
theme = db.query(VendorTheme).filter(
VendorTheme.vendor_id == vendor.id
).first()
theme = (
db.query(VendorTheme).filter(VendorTheme.vendor_id == vendor.id).first()
)
if not theme:
self.logger.info(f"Creating new theme for vendor {vendor_code}")
@@ -224,15 +206,11 @@ class VendorThemeService:
db.rollback()
self.logger.error(f"Failed to update theme for vendor {vendor_code}: {e}")
raise ThemeOperationException(
operation="update",
vendor_code=vendor_code,
reason=str(e)
operation="update", vendor_code=vendor_code, reason=str(e)
)
def _apply_theme_updates(
self,
theme: VendorTheme,
theme_data: VendorThemeUpdate
self, theme: VendorTheme, theme_data: VendorThemeUpdate
) -> None:
"""
Apply theme updates to theme object.
@@ -251,30 +229,30 @@ class VendorThemeService:
# Update fonts
if theme_data.fonts:
if theme_data.fonts.get('heading'):
theme.font_family_heading = theme_data.fonts['heading']
if theme_data.fonts.get('body'):
theme.font_family_body = theme_data.fonts['body']
if theme_data.fonts.get("heading"):
theme.font_family_heading = theme_data.fonts["heading"]
if theme_data.fonts.get("body"):
theme.font_family_body = theme_data.fonts["body"]
# Update branding
if theme_data.branding:
if theme_data.branding.get('logo') is not None:
theme.logo_url = theme_data.branding['logo']
if theme_data.branding.get('logo_dark') is not None:
theme.logo_dark_url = theme_data.branding['logo_dark']
if theme_data.branding.get('favicon') is not None:
theme.favicon_url = theme_data.branding['favicon']
if theme_data.branding.get('banner') is not None:
theme.banner_url = theme_data.branding['banner']
if theme_data.branding.get("logo") is not None:
theme.logo_url = theme_data.branding["logo"]
if theme_data.branding.get("logo_dark") is not None:
theme.logo_dark_url = theme_data.branding["logo_dark"]
if theme_data.branding.get("favicon") is not None:
theme.favicon_url = theme_data.branding["favicon"]
if theme_data.branding.get("banner") is not None:
theme.banner_url = theme_data.branding["banner"]
# Update layout
if theme_data.layout:
if theme_data.layout.get('style'):
theme.layout_style = theme_data.layout['style']
if theme_data.layout.get('header'):
theme.header_style = theme_data.layout['header']
if theme_data.layout.get('product_card'):
theme.product_card_style = theme_data.layout['product_card']
if theme_data.layout.get("style"):
theme.layout_style = theme_data.layout["style"]
if theme_data.layout.get("header"):
theme.header_style = theme_data.layout["header"]
if theme_data.layout.get("product_card"):
theme.product_card_style = theme_data.layout["product_card"]
# Update custom CSS
if theme_data.custom_css is not None:
@@ -289,10 +267,7 @@ class VendorThemeService:
# ============================================================================
def apply_theme_preset(
self,
db: Session,
vendor_code: str,
preset_name: str
self, db: Session, vendor_code: str, preset_name: str
) -> VendorTheme:
"""
Apply a theme preset to vendor.
@@ -322,9 +297,9 @@ class VendorThemeService:
vendor = self._get_vendor_by_code(db, vendor_code)
# Get or create theme
theme = db.query(VendorTheme).filter(
VendorTheme.vendor_id == vendor.id
).first()
theme = (
db.query(VendorTheme).filter(VendorTheme.vendor_id == vendor.id).first()
)
if not theme:
self.logger.info(f"Creating new theme for vendor {vendor_code}")
@@ -338,7 +313,9 @@ class VendorThemeService:
db.commit()
db.refresh(theme)
self.logger.info(f"Preset '{preset_name}' applied successfully to vendor {vendor_code}")
self.logger.info(
f"Preset '{preset_name}' applied successfully to vendor {vendor_code}"
)
return theme
except (VendorNotFoundException, ThemePresetNotFoundException):
@@ -349,9 +326,7 @@ class VendorThemeService:
db.rollback()
self.logger.error(f"Failed to apply preset to vendor {vendor_code}: {e}")
raise ThemeOperationException(
operation="apply_preset",
vendor_code=vendor_code,
reason=str(e)
operation="apply_preset", vendor_code=vendor_code, reason=str(e)
)
def get_available_presets(self) -> List[ThemePresetPreview]:
@@ -399,9 +374,9 @@ class VendorThemeService:
vendor = self._get_vendor_by_code(db, vendor_code)
# Get theme
theme = db.query(VendorTheme).filter(
VendorTheme.vendor_id == vendor.id
).first()
theme = (
db.query(VendorTheme).filter(VendorTheme.vendor_id == vendor.id).first()
)
if not theme:
raise VendorThemeNotFoundException(vendor_code)
@@ -423,9 +398,7 @@ class VendorThemeService:
db.rollback()
self.logger.error(f"Failed to delete theme for vendor {vendor_code}: {e}")
raise ThemeOperationException(
operation="delete",
vendor_code=vendor_code,
reason=str(e)
operation="delete", vendor_code=vendor_code, reason=str(e)
)
# ============================================================================
@@ -459,9 +432,9 @@ class VendorThemeService:
# Validate layout values
if theme_data.layout:
valid_layouts = {
'style': ['grid', 'list', 'masonry'],
'header': ['fixed', 'static', 'transparent'],
'product_card': ['modern', 'classic', 'minimal']
"style": ["grid", "list", "masonry"],
"header": ["fixed", "static", "transparent"],
"product_card": ["modern", "classic", "minimal"],
}
for layout_key, layout_value in theme_data.layout.items():
@@ -472,7 +445,7 @@ class VendorThemeService:
field=layout_key,
validation_errors={
layout_key: f"Must be one of: {', '.join(valid_layouts[layout_key])}"
}
},
)
def _is_valid_color(self, color: str) -> bool:
@@ -489,7 +462,7 @@ class VendorThemeService:
return False
# Check for hex color format (#RGB or #RRGGBB)
hex_pattern = r'^#([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$'
hex_pattern = r"^#([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$"
return bool(re.match(hex_pattern, color))
def _is_valid_font(self, font: str) -> bool: