diff --git a/alembic-env.py b/alembic-env.py index 408667b0..7d751343 100644 --- a/alembic-env.py +++ b/alembic-env.py @@ -1,14 +1,16 @@ -from logging.config import fileConfig -from sqlalchemy import engine_from_config, pool -from alembic import context import os import sys +from logging.config import fileConfig + +from sqlalchemy import engine_from_config, pool + +from alembic import context # Add your project directory to the Python path sys.path.append(os.path.dirname(os.path.dirname(__file__))) -from models.database_models import Base from app.core.config import settings +from models.database_models import Base # Alembic Config object config = context.config @@ -45,9 +47,7 @@ def run_migrations_online() -> None: ) with connectable.connect() as connection: - context.configure( - connection=connection, target_metadata=target_metadata - ) + context.configure(connection=connection, target_metadata=target_metadata) with context.begin_transaction(): context.run_migrations() diff --git a/alembic/env.py b/alembic/env.py index 408667b0..7d751343 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -1,14 +1,16 @@ -from logging.config import fileConfig -from sqlalchemy import engine_from_config, pool -from alembic import context import os import sys +from logging.config import fileConfig + +from sqlalchemy import engine_from_config, pool + +from alembic import context # Add your project directory to the Python path sys.path.append(os.path.dirname(os.path.dirname(__file__))) -from models.database_models import Base from app.core.config import settings +from models.database_models import Base # Alembic Config object config = context.config @@ -45,9 +47,7 @@ def run_migrations_online() -> None: ) with connectable.connect() as connection: - context.configure( - connection=connection, target_metadata=target_metadata - ) + context.configure(connection=connection, target_metadata=target_metadata) with context.begin_transaction(): context.run_migrations() diff --git a/app/api/deps.py b/app/api/deps.py index e8d2a714..3a6d0d8d 100644 --- a/app/api/deps.py +++ b/app/api/deps.py @@ -1,10 +1,11 @@ from fastapi import Depends, HTTPException -from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from sqlalchemy.orm import Session + from app.core.database import get_db -from models.database_models import User, Shop from middleware.auth import AuthManager from middleware.rate_limiter import RateLimiter +from models.database_models import Shop, User # Set auto_error=False to prevent automatic 403 responses security = HTTPBearer(auto_error=False) @@ -13,8 +14,8 @@ rate_limiter = RateLimiter() def get_current_user( - credentials: HTTPAuthorizationCredentials = Depends(security), - db: Session = Depends(get_db) + credentials: HTTPAuthorizationCredentials = Depends(security), + db: Session = Depends(get_db), ): """Get current authenticated user""" # Check if credentials are provided @@ -30,9 +31,9 @@ def get_current_admin_user(current_user: User = Depends(get_current_user)): def get_user_shop( - shop_code: str, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + shop_code: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ): """Get shop and verify user ownership""" shop = db.query(Shop).filter(Shop.shop_code == shop_code.upper()).first() diff --git a/app/api/main.py b/app/api/main.py index 19b2855c..ae76b30e 100644 --- a/app/api/main.py +++ b/app/api/main.py @@ -1,5 +1,6 @@ from fastapi import APIRouter -from app.api.v1 import auth, product, stock, shop, marketplace, admin, stats + +from app.api.v1 import admin, auth, marketplace, product, shop, stats, stock api_router = APIRouter() @@ -9,6 +10,5 @@ api_router.include_router(auth.router, tags=["authentication"]) api_router.include_router(marketplace.router, tags=["marketplace"]) api_router.include_router(product.router, tags=["product"]) api_router.include_router(shop.router, tags=["shop"]) -api_router.include_router(stats.router, tags=["statistics"]) +api_router.include_router(stats.router, tags=["statistics"]) api_router.include_router(stock.router, tags=["stock"]) - diff --git a/app/api/v1/admin.py b/app/api/v1/admin.py index ec8b6b8b..4d7b51ea 100644 --- a/app/api/v1/admin.py +++ b/app/api/v1/admin.py @@ -1,13 +1,15 @@ +import logging from typing import List, Optional from fastapi import APIRouter, Depends, HTTPException, Query from sqlalchemy.orm import Session -from app.core.database import get_db + from app.api.deps import get_current_admin_user +from app.core.database import get_db from app.services.admin_service import admin_service -from models.api_models import MarketplaceImportJobResponse, UserResponse, ShopListResponse +from models.api_models import (MarketplaceImportJobResponse, ShopListResponse, + UserResponse) from models.database_models import User -import logging router = APIRouter() logger = logging.getLogger(__name__) @@ -16,10 +18,10 @@ logger = logging.getLogger(__name__) # Admin-only routes @router.get("/admin/users", response_model=List[UserResponse]) def get_all_users( - skip: int = Query(0, ge=0), - limit: int = Query(100, ge=1, le=1000), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_user) + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=1000), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_user), ): """Get all users (Admin only)""" try: @@ -32,9 +34,9 @@ def get_all_users( @router.put("/admin/users/{user_id}/status") def toggle_user_status( - user_id: int, - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_user) + user_id: int, + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_user), ): """Toggle user active status (Admin only)""" try: @@ -49,21 +51,16 @@ def toggle_user_status( @router.get("/admin/shops", response_model=ShopListResponse) def get_all_shops_admin( - skip: int = Query(0, ge=0), - limit: int = Query(100, ge=1, le=1000), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_user) + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=1000), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_user), ): """Get all shops with admin view (Admin only)""" try: shops, total = admin_service.get_all_shops(db=db, skip=skip, limit=limit) - return ShopListResponse( - shops=shops, - total=total, - skip=skip, - limit=limit - ) + return ShopListResponse(shops=shops, total=total, skip=skip, limit=limit) except Exception as e: logger.error(f"Error getting shops: {str(e)}") raise HTTPException(status_code=500, detail="Internal server error") @@ -71,9 +68,9 @@ def get_all_shops_admin( @router.put("/admin/shops/{shop_id}/verify") def verify_shop( - shop_id: int, - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_user) + shop_id: int, + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_user), ): """Verify/unverify shop (Admin only)""" try: @@ -88,9 +85,9 @@ def verify_shop( @router.put("/admin/shops/{shop_id}/status") def toggle_shop_status( - shop_id: int, - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_user) + shop_id: int, + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_user), ): """Toggle shop active status (Admin only)""" try: @@ -103,15 +100,17 @@ def toggle_shop_status( raise HTTPException(status_code=500, detail="Internal server error") -@router.get("/admin/marketplace-import-jobs", response_model=List[MarketplaceImportJobResponse]) +@router.get( + "/admin/marketplace-import-jobs", response_model=List[MarketplaceImportJobResponse] +) def get_all_marketplace_import_jobs( - marketplace: Optional[str] = Query(None), - shop_name: Optional[str] = Query(None), - status: Optional[str] = Query(None), - skip: int = Query(0, ge=0), - limit: int = Query(100, ge=1, le=100), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_user) + marketplace: Optional[str] = Query(None), + shop_name: Optional[str] = Query(None), + status: Optional[str] = Query(None), + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=100), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_user), ): """Get all marketplace import jobs (Admin only)""" try: @@ -121,7 +120,7 @@ def get_all_marketplace_import_jobs( shop_name=shop_name, status=status, skip=skip, - limit=limit + limit=limit, ) except Exception as e: logger.error(f"Error getting marketplace import jobs: {str(e)}") diff --git a/app/api/v1/auth.py b/app/api/v1/auth.py index 978871cd..a6d6a2fa 100644 --- a/app/api/v1/auth.py +++ b/app/api/v1/auth.py @@ -1,11 +1,14 @@ +import logging + from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.orm import Session -from app.core.database import get_db + from app.api.deps import get_current_user +from app.core.database import get_db from app.services.auth_service import auth_service -from models.api_models import UserRegister, UserLogin, UserResponse, LoginResponse +from models.api_models import (LoginResponse, UserLogin, UserRegister, + UserResponse) from models.database_models import User -import logging router = APIRouter() logger = logging.getLogger(__name__) @@ -35,7 +38,7 @@ def login_user(user_credentials: UserLogin, db: Session = Depends(get_db)): access_token=login_result["token_data"]["access_token"], token_type=login_result["token_data"]["token_type"], expires_in=login_result["token_data"]["expires_in"], - user=UserResponse.model_validate(login_result["user"]) + user=UserResponse.model_validate(login_result["user"]), ) except HTTPException: raise diff --git a/app/api/v1/marketplace.py b/app/api/v1/marketplace.py index 793ea22b..511387b7 100644 --- a/app/api/v1/marketplace.py +++ b/app/api/v1/marketplace.py @@ -1,15 +1,17 @@ +import logging from typing import List, Optional -from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query from sqlalchemy.orm import Session -from app.core.database import get_db + from app.api.deps import get_current_user +from app.core.database import get_db +from app.services.marketplace_service import marketplace_service from app.tasks.background_tasks import process_marketplace_import from middleware.decorators import rate_limit -from models.api_models import MarketplaceImportJobResponse, MarketplaceImportRequest +from models.api_models import (MarketplaceImportJobResponse, + MarketplaceImportRequest) from models.database_models import User -from app.services.marketplace_service import marketplace_service -import logging router = APIRouter() logger = logging.getLogger(__name__) @@ -19,15 +21,16 @@ logger = logging.getLogger(__name__) @router.post("/marketplace/import-product", response_model=MarketplaceImportJobResponse) @rate_limit(max_requests=10, window_seconds=3600) # Limit marketplace imports async def import_products_from_marketplace( - request: MarketplaceImportRequest, - background_tasks: BackgroundTasks, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + request: MarketplaceImportRequest, + background_tasks: BackgroundTasks, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), ): """Import products from marketplace CSV with background processing (Protected)""" try: logger.info( - f"Starting marketplace import: {request.marketplace} -> {request.shop_code} by user {current_user.username}") + f"Starting marketplace import: {request.marketplace} -> {request.shop_code} by user {current_user.username}" + ) # Create import job through service import_job = marketplace_service.create_import_job(db, request, current_user) @@ -39,7 +42,7 @@ async def import_products_from_marketplace( request.url, request.marketplace, request.shop_code, - request.batch_size or 1000 + request.batch_size or 1000, ) return MarketplaceImportJobResponse( @@ -50,7 +53,7 @@ async def import_products_from_marketplace( shop_id=import_job.shop_id, shop_name=import_job.shop_name, message=f"Marketplace import started from {request.marketplace}. Check status with " - f"/import-status/{import_job.id}" + f"/import-status/{import_job.id}", ) except ValueError as e: @@ -62,11 +65,13 @@ async def import_products_from_marketplace( raise HTTPException(status_code=500, detail="Internal server error") -@router.get("/marketplace/import-status/{job_id}", response_model=MarketplaceImportJobResponse) +@router.get( + "/marketplace/import-status/{job_id}", response_model=MarketplaceImportJobResponse +) def get_marketplace_import_status( - job_id: int, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + job_id: int, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), ): """Get status of marketplace import job (Protected)""" try: @@ -82,14 +87,16 @@ def get_marketplace_import_status( raise HTTPException(status_code=500, detail="Internal server error") -@router.get("/marketplace/import-jobs", response_model=List[MarketplaceImportJobResponse]) +@router.get( + "/marketplace/import-jobs", response_model=List[MarketplaceImportJobResponse] +) def get_marketplace_import_jobs( - marketplace: Optional[str] = Query(None, description="Filter by marketplace"), - shop_name: Optional[str] = Query(None, description="Filter by shop name"), - skip: int = Query(0, ge=0), - limit: int = Query(50, ge=1, le=100), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + marketplace: Optional[str] = Query(None, description="Filter by marketplace"), + shop_name: Optional[str] = Query(None, description="Filter by shop name"), + skip: int = Query(0, ge=0), + limit: int = Query(50, ge=1, le=100), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), ): """Get marketplace import jobs with filtering (Protected)""" try: @@ -99,7 +106,7 @@ def get_marketplace_import_jobs( marketplace=marketplace, shop_name=shop_name, skip=skip, - limit=limit + limit=limit, ) return [marketplace_service.convert_to_response_model(job) for job in jobs] @@ -111,8 +118,7 @@ def get_marketplace_import_jobs( @router.get("/marketplace/marketplace-import-stats") def get_marketplace_import_stats( - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + db: Session = Depends(get_db), current_user: User = Depends(get_current_user) ): """Get statistics about marketplace import jobs (Protected)""" try: @@ -124,11 +130,14 @@ def get_marketplace_import_stats( raise HTTPException(status_code=500, detail="Internal server error") -@router.put("/marketplace/import-jobs/{job_id}/cancel", response_model=MarketplaceImportJobResponse) +@router.put( + "/marketplace/import-jobs/{job_id}/cancel", + response_model=MarketplaceImportJobResponse, +) def cancel_marketplace_import_job( - job_id: int, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + job_id: int, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), ): """Cancel a pending or running marketplace import job (Protected)""" try: @@ -146,9 +155,9 @@ def cancel_marketplace_import_job( @router.delete("/marketplace/import-jobs/{job_id}") def delete_marketplace_import_job( - job_id: int, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + job_id: int, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), ): """Delete a completed marketplace import job (Protected)""" try: diff --git a/app/api/v1/product.py b/app/api/v1/product.py index c437493d..571ec3d5 100644 --- a/app/api/v1/product.py +++ b/app/api/v1/product.py @@ -1,17 +1,17 @@ +import logging from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Query from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session -from app.core.database import get_db + from app.api.deps import get_current_user -from models.api_models import (ProductListResponse, ProductResponse, ProductCreate, ProductDetailResponse, +from app.core.database import get_db +from app.services.product_service import product_service +from models.api_models import (ProductCreate, ProductDetailResponse, + ProductListResponse, ProductResponse, ProductUpdate) from models.database_models import User -import logging - -from app.services.product_service import product_service - router = APIRouter() logger = logging.getLogger(__name__) @@ -20,16 +20,16 @@ logger = logging.getLogger(__name__) # Enhanced Product Routes with Marketplace Support @router.get("/product", response_model=ProductListResponse) def get_products( - skip: int = Query(0, ge=0), - limit: int = Query(100, ge=1, le=1000), - brand: Optional[str] = Query(None), - category: Optional[str] = Query(None), - availability: Optional[str] = Query(None), - marketplace: Optional[str] = Query(None, description="Filter by marketplace"), - shop_name: Optional[str] = Query(None, description="Filter by shop name"), - search: Optional[str] = Query(None), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=1000), + brand: Optional[str] = Query(None), + category: Optional[str] = Query(None), + availability: Optional[str] = Query(None), + marketplace: Optional[str] = Query(None, description="Filter by marketplace"), + shop_name: Optional[str] = Query(None, description="Filter by shop name"), + search: Optional[str] = Query(None), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), ): """Get products with advanced filtering including marketplace and shop (Protected)""" @@ -43,14 +43,11 @@ def get_products( availability=availability, marketplace=marketplace, shop_name=shop_name, - search=search + search=search, ) return ProductListResponse( - products=products, - total=total, - skip=skip, - limit=limit + products=products, total=total, skip=skip, limit=limit ) except Exception as e: logger.error(f"Error getting products: {str(e)}") @@ -59,9 +56,9 @@ def get_products( @router.post("/product", response_model=ProductResponse) def create_product( - product: ProductCreate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + product: ProductCreate, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), ): """Create a new product with validation and marketplace support (Protected)""" @@ -75,7 +72,9 @@ def create_product( if existing: logger.info("Product already exists, raising 400 error") - raise HTTPException(status_code=400, detail="Product with this ID already exists") + raise HTTPException( + status_code=400, detail="Product with this ID already exists" + ) logger.info("No existing product found, proceeding to create...") db_product = product_service.create_product(db, product) @@ -93,11 +92,12 @@ def create_product( logger.error(f"Unexpected error: {str(e)}", exc_info=True) raise HTTPException(status_code=500, detail="Internal server error") + @router.get("/product/{product_id}", response_model=ProductDetailResponse) def get_product( - product_id: str, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + product_id: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), ): """Get product with stock information (Protected)""" @@ -111,10 +111,7 @@ def get_product( if product.gtin: stock_info = product_service.get_stock_info(db, product.gtin) - return ProductDetailResponse( - product=product, - stock_info=stock_info - ) + return ProductDetailResponse(product=product, stock_info=stock_info) except HTTPException: raise @@ -125,10 +122,10 @@ def get_product( @router.put("/product/{product_id}", response_model=ProductResponse) def update_product( - product_id: str, - product_update: ProductUpdate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + product_id: str, + product_update: ProductUpdate, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), ): """Update product with validation and marketplace support (Protected)""" @@ -151,9 +148,9 @@ def update_product( @router.delete("/product/{product_id}") def delete_product( - product_id: str, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + product_id: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), ): """Delete product and associated stock (Protected)""" @@ -176,19 +173,18 @@ def delete_product( # Export with streaming for large datasets (Protected) @router.get("/export-csv") async def export_csv( - marketplace: Optional[str] = Query(None, description="Filter by marketplace"), - shop_name: Optional[str] = Query(None, description="Filter by shop name"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + marketplace: Optional[str] = Query(None, description="Filter by marketplace"), + shop_name: Optional[str] = Query(None, description="Filter by shop name"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), ): """Export products as CSV with streaming and marketplace filtering (Protected)""" try: + def generate_csv(): return product_service.generate_csv_export( - db=db, - marketplace=marketplace, - shop_name=shop_name + db=db, marketplace=marketplace, shop_name=shop_name ) filename = "products_export" @@ -201,7 +197,7 @@ async def export_csv( return StreamingResponse( generate_csv(), media_type="text/csv", - headers={"Content-Disposition": f"attachment; filename={filename}"} + headers={"Content-Disposition": f"attachment; filename={filename}"}, ) except Exception as e: diff --git a/app/api/v1/shop.py b/app/api/v1/shop.py index 1cdcabbb..c9113cd1 100644 --- a/app/api/v1/shop.py +++ b/app/api/v1/shop.py @@ -1,17 +1,21 @@ +import logging +from datetime import datetime from typing import List, Optional -from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query from sqlalchemy.orm import Session -from app.core.database import get_db + from app.api.deps import get_current_user, get_user_shop +from app.core.database import get_db from app.services.shop_service import shop_service from app.tasks.background_tasks import process_marketplace_import from middleware.decorators import rate_limit -from models.api_models import MarketplaceImportJobResponse, MarketplaceImportRequest, ShopResponse, ShopCreate, \ - ShopListResponse, ShopProductResponse, ShopProductCreate -from models.database_models import User, MarketplaceImportJob, Shop, Product, ShopProduct -from datetime import datetime -import logging +from models.api_models import (MarketplaceImportJobResponse, + MarketplaceImportRequest, ShopCreate, + ShopListResponse, ShopProductCreate, + ShopProductResponse, ShopResponse) +from models.database_models import (MarketplaceImportJob, Product, Shop, + ShopProduct, User) router = APIRouter() logger = logging.getLogger(__name__) @@ -20,13 +24,15 @@ logger = logging.getLogger(__name__) # Shop Management Routes @router.post("/shop", response_model=ShopResponse) def create_shop( - shop_data: ShopCreate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + shop_data: ShopCreate, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), ): """Create a new shop (Protected)""" try: - shop = shop_service.create_shop(db=db, shop_data=shop_data, current_user=current_user) + shop = shop_service.create_shop( + db=db, shop_data=shop_data, current_user=current_user + ) return ShopResponse.model_validate(shop) except HTTPException: raise @@ -37,12 +43,12 @@ def create_shop( @router.get("/shop", response_model=ShopListResponse) def get_shops( - skip: int = Query(0, ge=0), - limit: int = Query(100, ge=1, le=1000), - active_only: bool = Query(True), - verified_only: bool = Query(False), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=1000), + active_only: bool = Query(True), + verified_only: bool = Query(False), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), ): """Get shops with filtering (Protected)""" try: @@ -52,15 +58,10 @@ def get_shops( skip=skip, limit=limit, active_only=active_only, - verified_only=verified_only + verified_only=verified_only, ) - return ShopListResponse( - shops=shops, - total=total, - skip=skip, - limit=limit - ) + return ShopListResponse(shops=shops, total=total, skip=skip, limit=limit) except HTTPException: raise except Exception as e: @@ -69,10 +70,16 @@ def get_shops( @router.get("/shop/{shop_code}", response_model=ShopResponse) -def get_shop(shop_code: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): +def get_shop( + shop_code: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), +): """Get shop details (Protected)""" try: - shop = shop_service.get_shop_by_code(db=db, shop_code=shop_code, current_user=current_user) + shop = shop_service.get_shop_by_code( + db=db, shop_code=shop_code, current_user=current_user + ) return ShopResponse.model_validate(shop) except HTTPException: raise @@ -84,10 +91,10 @@ def get_shop(shop_code: str, db: Session = Depends(get_db), current_user: User = # Shop Product Management @router.post("/shop/{shop_code}/products", response_model=ShopProductResponse) def add_product_to_shop( - shop_code: str, - shop_product: ShopProductCreate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + shop_code: str, + shop_product: ShopProductCreate, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), ): """Add existing product to shop catalog with shop-specific settings (Protected)""" try: @@ -96,9 +103,7 @@ def add_product_to_shop( # Add product to shop new_shop_product = shop_service.add_product_to_shop( - db=db, - shop=shop, - shop_product=shop_product + db=db, shop=shop, shop_product=shop_product ) # Return with product details @@ -114,18 +119,20 @@ def add_product_to_shop( @router.get("/shop/{shop_code}/products") def get_shop_products( - shop_code: str, - skip: int = Query(0, ge=0), - limit: int = Query(100, ge=1, le=1000), - active_only: bool = Query(True), - featured_only: bool = Query(False), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + shop_code: str, + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=1000), + active_only: bool = Query(True), + featured_only: bool = Query(False), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), ): """Get products in shop catalog (Protected)""" try: # Get shop - shop = shop_service.get_shop_by_code(db=db, shop_code=shop_code, current_user=current_user) + shop = shop_service.get_shop_by_code( + db=db, shop_code=shop_code, current_user=current_user + ) # Get shop products shop_products, total = shop_service.get_shop_products( @@ -135,7 +142,7 @@ def get_shop_products( skip=skip, limit=limit, active_only=active_only, - featured_only=featured_only + featured_only=featured_only, ) # Format response @@ -150,7 +157,7 @@ def get_shop_products( "total": total, "skip": skip, "limit": limit, - "shop": ShopResponse.model_validate(shop) + "shop": ShopResponse.model_validate(shop), } except HTTPException: raise diff --git a/app/api/v1/stats.py b/app/api/v1/stats.py index a2b5ed68..7d1b74f4 100644 --- a/app/api/v1/stats.py +++ b/app/api/v1/stats.py @@ -1,18 +1,21 @@ +import logging +from datetime import datetime from typing import List, Optional -from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query from sqlalchemy import func from sqlalchemy.orm import Session -from app.core.database import get_db + from app.api.deps import get_current_user +from app.core.database import get_db from app.services.stats_service import stats_service from app.tasks.background_tasks import process_marketplace_import from middleware.decorators import rate_limit -from models.api_models import MarketplaceImportJobResponse, MarketplaceImportRequest, StatsResponse, \ - MarketplaceStatsResponse -from models.database_models import User, MarketplaceImportJob, Shop, Product, Stock -from datetime import datetime -import logging +from models.api_models import (MarketplaceImportJobResponse, + MarketplaceImportRequest, + MarketplaceStatsResponse, StatsResponse) +from models.database_models import (MarketplaceImportJob, Product, Shop, Stock, + User) router = APIRouter() logger = logging.getLogger(__name__) @@ -20,7 +23,9 @@ logger = logging.getLogger(__name__) # Enhanced Statistics with Marketplace Support @router.get("/stats", response_model=StatsResponse) -def get_stats(db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): +def get_stats( + db: Session = Depends(get_db), current_user: User = Depends(get_current_user) +): """Get comprehensive statistics with marketplace data (Protected)""" try: stats_data = stats_service.get_comprehensive_stats(db=db) @@ -32,7 +37,7 @@ def get_stats(db: Session = Depends(get_db), current_user: User = Depends(get_cu unique_marketplaces=stats_data["unique_marketplaces"], unique_shops=stats_data["unique_shops"], total_stock_entries=stats_data["total_stock_entries"], - total_inventory_quantity=stats_data["total_inventory_quantity"] + total_inventory_quantity=stats_data["total_inventory_quantity"], ) except Exception as e: logger.error(f"Error getting comprehensive stats: {str(e)}") @@ -40,7 +45,9 @@ def get_stats(db: Session = Depends(get_db), current_user: User = Depends(get_cu @router.get("/stats/marketplace", response_model=List[MarketplaceStatsResponse]) -def get_marketplace_stats(db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): +def get_marketplace_stats( + db: Session = Depends(get_db), current_user: User = Depends(get_current_user) +): """Get statistics broken down by marketplace (Protected)""" try: marketplace_stats = stats_service.get_marketplace_breakdown_stats(db=db) @@ -50,8 +57,9 @@ def get_marketplace_stats(db: Session = Depends(get_db), current_user: User = De marketplace=stat["marketplace"], total_products=stat["total_products"], unique_shops=stat["unique_shops"], - unique_brands=stat["unique_brands"] - ) for stat in marketplace_stats + unique_brands=stat["unique_brands"], + ) + for stat in marketplace_stats ] except Exception as e: logger.error(f"Error getting marketplace stats: {str(e)}") diff --git a/app/api/v1/stock.py b/app/api/v1/stock.py index 9b6e3b28..f43aa046 100644 --- a/app/api/v1/stock.py +++ b/app/api/v1/stock.py @@ -1,16 +1,19 @@ +import logging from typing import List, Optional -from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query from sqlalchemy.orm import Session -from app.core.database import get_db + from app.api.deps import get_current_user +from app.core.database import get_db +from app.services.stock_service import stock_service from app.tasks.background_tasks import process_marketplace_import from middleware.decorators import rate_limit -from models.api_models import (MarketplaceImportJobResponse, MarketplaceImportRequest, StockResponse, - StockSummaryResponse, StockCreate, StockAdd, StockUpdate) -from models.database_models import User, MarketplaceImportJob, Shop -from app.services.stock_service import stock_service -import logging +from models.api_models import (MarketplaceImportJobResponse, + MarketplaceImportRequest, StockAdd, StockCreate, + StockResponse, StockSummaryResponse, + StockUpdate) +from models.database_models import MarketplaceImportJob, Shop, User router = APIRouter() logger = logging.getLogger(__name__) @@ -18,11 +21,12 @@ logger = logging.getLogger(__name__) # Stock Management Routes (Protected) + @router.post("/stock", response_model=StockResponse) def set_stock( - stock: StockCreate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + stock: StockCreate, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), ): """Set exact stock quantity for a GTIN at a specific location (replaces existing quantity)""" try: @@ -37,9 +41,9 @@ def set_stock( @router.post("/stock/add", response_model=StockResponse) def add_stock( - stock: StockAdd, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + stock: StockAdd, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), ): """Add quantity to existing stock for a GTIN at a specific location (adds to existing quantity)""" try: @@ -54,9 +58,9 @@ def add_stock( @router.post("/stock/remove", response_model=StockResponse) def remove_stock( - stock: StockAdd, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + stock: StockAdd, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), ): """Remove quantity from existing stock for a GTIN at a specific location""" try: @@ -71,9 +75,9 @@ def remove_stock( @router.get("/stock/{gtin}", response_model=StockSummaryResponse) def get_stock_by_gtin( - gtin: str, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + gtin: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), ): """Get all stock locations and total quantity for a specific GTIN""" try: @@ -88,9 +92,9 @@ def get_stock_by_gtin( @router.get("/stock/{gtin}/total") def get_total_stock( - gtin: str, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + gtin: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), ): """Get total quantity in stock for a specific GTIN""" try: @@ -105,21 +109,17 @@ def get_total_stock( @router.get("/stock", response_model=List[StockResponse]) def get_all_stock( - skip: int = Query(0, ge=0), - limit: int = Query(100, ge=1, le=1000), - location: Optional[str] = Query(None, description="Filter by location"), - gtin: Optional[str] = Query(None, description="Filter by GTIN"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=1000), + location: Optional[str] = Query(None, description="Filter by location"), + gtin: Optional[str] = Query(None, description="Filter by GTIN"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), ): """Get all stock entries with optional filtering""" try: result = stock_service.get_all_stock( - db=db, - skip=skip, - limit=limit, - location=location, - gtin=gtin + db=db, skip=skip, limit=limit, location=location, gtin=gtin ) return result except Exception as e: @@ -129,10 +129,10 @@ def get_all_stock( @router.put("/stock/{stock_id}", response_model=StockResponse) def update_stock( - stock_id: int, - stock_update: StockUpdate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + stock_id: int, + stock_update: StockUpdate, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), ): """Update stock quantity for a specific stock entry""" try: @@ -147,9 +147,9 @@ def update_stock( @router.delete("/stock/{stock_id}") def delete_stock( - stock_id: int, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + stock_id: int, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), ): """Delete a stock entry""" try: diff --git a/app/core/config.py b/app/core/config.py index 8c6496d5..50d3b7eb 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -1,6 +1,8 @@ # app/core/config.py -from pydantic_settings import BaseSettings # This is the correct import for Pydantic v2 -from typing import Optional, List +from typing import List, Optional + +from pydantic_settings import \ + BaseSettings # This is the correct import for Pydantic v2 class Settings(BaseSettings): diff --git a/app/core/database.py b/app/core/database.py index 6231a341..335f7aad 100644 --- a/app/core/database.py +++ b/app/core/database.py @@ -1,6 +1,6 @@ from sqlalchemy import create_engine -from sqlalchemy.orm import declarative_base -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import declarative_base, sessionmaker + from .config import settings engine = create_engine(settings.database_url) diff --git a/app/core/lifespan.py b/app/core/lifespan.py index cb7e5096..a3f4e8e1 100644 --- a/app/core/lifespan.py +++ b/app/core/lifespan.py @@ -1,11 +1,14 @@ +import logging from contextlib import asynccontextmanager + from fastapi import FastAPI from sqlalchemy import text -from .logging import setup_logging -from .database import engine, SessionLocal -from models.database_models import Base -import logging + from middleware.auth import AuthManager +from models.database_models import Base + +from .database import SessionLocal, engine +from .logging import setup_logging logger = logging.getLogger(__name__) auth_manager = AuthManager() @@ -44,15 +47,29 @@ def create_indexes(): with engine.connect() as conn: try: # User indexes - conn.execute(text("CREATE INDEX IF NOT EXISTS idx_user_email ON users(email)")) - conn.execute(text("CREATE INDEX IF NOT EXISTS idx_user_username ON users(username)")) + conn.execute( + text("CREATE INDEX IF NOT EXISTS idx_user_email ON users(email)") + ) + conn.execute( + text("CREATE INDEX IF NOT EXISTS idx_user_username ON users(username)") + ) # Product indexes - conn.execute(text("CREATE INDEX IF NOT EXISTS idx_product_gtin ON products(gtin)")) - conn.execute(text("CREATE INDEX IF NOT EXISTS idx_product_marketplace ON products(marketplace)")) + conn.execute( + text("CREATE INDEX IF NOT EXISTS idx_product_gtin ON products(gtin)") + ) + conn.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_product_marketplace ON products(marketplace)" + ) + ) # Stock indexes - conn.execute(text("CREATE INDEX IF NOT EXISTS idx_stock_gtin_location ON stock(gtin, location)")) + conn.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_stock_gtin_location ON stock(gtin, location)" + ) + ) conn.commit() logger.info("Database indexes created successfully") diff --git a/app/core/logging.py b/app/core/logging.py index cef00cad..94ef4c69 100644 --- a/app/core/logging.py +++ b/app/core/logging.py @@ -2,6 +2,7 @@ import logging import sys from pathlib import Path + from app.core.config import settings @@ -22,7 +23,7 @@ def setup_logging(): # Create formatters formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s' + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) # Console handler diff --git a/app/services/admin_service.py b/app/services/admin_service.py index 3aa0f636..c96eddfa 100644 --- a/app/services/admin_service.py +++ b/app/services/admin_service.py @@ -1,11 +1,12 @@ -from sqlalchemy.orm import Session -from fastapi import HTTPException -from datetime import datetime import logging +from datetime import datetime from typing import List, Optional, Tuple -from models.database_models import User, MarketplaceImportJob, Shop +from fastapi import HTTPException +from sqlalchemy.orm import Session + from models.api_models import MarketplaceImportJobResponse +from models.database_models import MarketplaceImportJob, Shop, User logger = logging.getLogger(__name__) @@ -17,7 +18,9 @@ class AdminService: """Get paginated list of all users""" return db.query(User).offset(skip).limit(limit).all() - def toggle_user_status(self, db: Session, user_id: int, current_admin_id: int) -> Tuple[User, str]: + def toggle_user_status( + self, db: Session, user_id: int, current_admin_id: int + ) -> Tuple[User, str]: """ Toggle user active status @@ -37,7 +40,9 @@ class AdminService: raise HTTPException(status_code=404, detail="User not found") if user.id == current_admin_id: - raise HTTPException(status_code=400, detail="Cannot deactivate your own account") + raise HTTPException( + status_code=400, detail="Cannot deactivate your own account" + ) user.is_active = not user.is_active user.updated_at = datetime.utcnow() @@ -45,10 +50,14 @@ class AdminService: db.refresh(user) status = "activated" if user.is_active else "deactivated" - logger.info(f"User {user.username} has been {status} by admin {current_admin_id}") + logger.info( + f"User {user.username} has been {status} by admin {current_admin_id}" + ) return user, f"User {user.username} has been {status}" - def get_all_shops(self, db: Session, skip: int = 0, limit: int = 100) -> Tuple[List[Shop], int]: + def get_all_shops( + self, db: Session, skip: int = 0, limit: int = 100 + ) -> Tuple[List[Shop], int]: """ Get paginated list of all shops with total count @@ -119,13 +128,13 @@ class AdminService: return shop, f"Shop {shop.shop_code} has been {status}" def get_marketplace_import_jobs( - self, - db: Session, - marketplace: Optional[str] = None, - shop_name: Optional[str] = None, - status: Optional[str] = None, - skip: int = 0, - limit: int = 100 + self, + db: Session, + marketplace: Optional[str] = None, + shop_name: Optional[str] = None, + status: Optional[str] = None, + skip: int = 0, + limit: int = 100, ) -> List[MarketplaceImportJobResponse]: """ Get filtered and paginated marketplace import jobs @@ -145,14 +154,21 @@ class AdminService: # Apply filters if marketplace: - query = query.filter(MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%")) + query = query.filter( + MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%") + ) if shop_name: query = query.filter(MarketplaceImportJob.shop_name.ilike(f"%{shop_name}%")) if status: query = query.filter(MarketplaceImportJob.status == status) # Order by creation date and apply pagination - jobs = query.order_by(MarketplaceImportJob.created_at.desc()).offset(skip).limit(limit).all() + jobs = ( + query.order_by(MarketplaceImportJob.created_at.desc()) + .offset(skip) + .limit(limit) + .all() + ) return [ MarketplaceImportJobResponse( @@ -168,8 +184,9 @@ class AdminService: error_message=job.error_message, created_at=job.created_at, started_at=job.started_at, - completed_at=job.completed_at - ) for job in jobs + completed_at=job.completed_at, + ) + for job in jobs ] def get_user_by_id(self, db: Session, user_id: int) -> Optional[User]: diff --git a/app/services/auth_service.py b/app/services/auth_service.py index 589105ee..e28a209a 100644 --- a/app/services/auth_service.py +++ b/app/services/auth_service.py @@ -1,11 +1,12 @@ -from sqlalchemy.orm import Session -from fastapi import HTTPException import logging -from typing import Optional, Dict, Any +from typing import Any, Dict, Optional + +from fastapi import HTTPException +from sqlalchemy.orm import Session -from models.database_models import User -from models.api_models import UserRegister, UserLogin from middleware.auth import AuthManager +from models.api_models import UserLogin, UserRegister +from models.database_models import User logger = logging.getLogger(__name__) @@ -36,7 +37,9 @@ class AuthService: raise HTTPException(status_code=400, detail="Email already registered") # Check if username already exists - existing_username = db.query(User).filter(User.username == user_data.username).first() + existing_username = ( + db.query(User).filter(User.username == user_data.username).first() + ) if existing_username: raise HTTPException(status_code=400, detail="Username already taken") @@ -47,7 +50,7 @@ class AuthService: username=user_data.username, hashed_password=hashed_password, role="user", - is_active=True + is_active=True, ) db.add(new_user) @@ -71,19 +74,20 @@ class AuthService: Raises: HTTPException: If authentication fails """ - user = self.auth_manager.authenticate_user(db, user_credentials.username, user_credentials.password) + user = self.auth_manager.authenticate_user( + db, user_credentials.username, user_credentials.password + ) if not user: - raise HTTPException(status_code=401, detail="Incorrect username or password") + raise HTTPException( + status_code=401, detail="Incorrect username or password" + ) # Create access token token_data = self.auth_manager.create_access_token(user) logger.info(f"User logged in: {user.username}") - return { - "token_data": token_data, - "user": user - } + return {"token_data": token_data, "user": user} def get_user_by_email(self, db: Session, email: str) -> Optional[User]: """Get user by email""" @@ -101,7 +105,9 @@ class AuthService: """Check if username already exists""" return db.query(User).filter(User.username == username).first() is not None - def authenticate_user(self, db: Session, username: str, password: str) -> Optional[User]: + def authenticate_user( + self, db: Session, username: str, password: str + ) -> Optional[User]: """Authenticate user with username/password""" return self.auth_manager.authenticate_user(db, username, password) diff --git a/app/services/marketplace_service.py b/app/services/marketplace_service.py index 15f847ef..6f511df4 100644 --- a/app/services/marketplace_service.py +++ b/app/services/marketplace_service.py @@ -1,10 +1,13 @@ +import logging +from datetime import datetime +from typing import List, Optional + from sqlalchemy import func from sqlalchemy.orm import Session + +from models.api_models import (MarketplaceImportJobResponse, + MarketplaceImportRequest) from models.database_models import MarketplaceImportJob, Shop, User -from models.api_models import MarketplaceImportRequest, MarketplaceImportJobResponse -from typing import Optional, List -from datetime import datetime -import logging logger = logging.getLogger(__name__) @@ -17,9 +20,11 @@ class MarketplaceService: """Validate that the shop exists and user has access to it""" # Explicit type hint to help type checker shop: Optional[Shop] # Use case-insensitive query to handle both uppercase and lowercase codes - shop: Optional[Shop] = db.query(Shop).filter( - func.upper(Shop.shop_code) == shop_code.upper() - ).first() + shop: Optional[Shop] = ( + db.query(Shop) + .filter(func.upper(Shop.shop_code) == shop_code.upper()) + .first() + ) if not shop: raise ValueError("Shop not found") @@ -30,10 +35,7 @@ class MarketplaceService: return shop def create_import_job( - self, - db: Session, - request: MarketplaceImportRequest, - user: User + self, db: Session, request: MarketplaceImportRequest, user: User ) -> MarketplaceImportJob: """Create a new marketplace import job""" # Validate shop access first @@ -47,7 +49,7 @@ class MarketplaceService: shop_id=shop.id, # Foreign key to shops table shop_name=shop.shop_name, # Use shop.shop_name (the display name) user_id=user.id, - created_at=datetime.utcnow() + created_at=datetime.utcnow(), ) db.add(import_job) @@ -55,13 +57,20 @@ class MarketplaceService: db.refresh(import_job) logger.info( - f"Created marketplace import job {import_job.id}: {request.marketplace} -> {shop.shop_name} (shop_code: {shop.shop_code}) by user {user.username}") + f"Created marketplace import job {import_job.id}: {request.marketplace} -> {shop.shop_name} (shop_code: {shop.shop_code}) by user {user.username}" + ) return import_job - def get_import_job_by_id(self, db: Session, job_id: int, user: User) -> MarketplaceImportJob: + def get_import_job_by_id( + self, db: Session, job_id: int, user: User + ) -> MarketplaceImportJob: """Get a marketplace import job by ID with access control""" - job = db.query(MarketplaceImportJob).filter(MarketplaceImportJob.id == job_id).first() + job = ( + db.query(MarketplaceImportJob) + .filter(MarketplaceImportJob.id == job_id) + .first() + ) if not job: raise ValueError("Marketplace import job not found") @@ -72,13 +81,13 @@ class MarketplaceService: return job def get_import_jobs( - self, - db: Session, - user: User, - marketplace: Optional[str] = None, - shop_name: Optional[str] = None, - skip: int = 0, - limit: int = 50 + self, + db: Session, + user: User, + marketplace: Optional[str] = None, + shop_name: Optional[str] = None, + skip: int = 0, + limit: int = 50, ) -> List[MarketplaceImportJob]: """Get marketplace import jobs with filtering and access control""" query = db.query(MarketplaceImportJob) @@ -89,44 +98,51 @@ class MarketplaceService: # Apply filters if marketplace: - query = query.filter(MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%")) + query = query.filter( + MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%") + ) if shop_name: query = query.filter(MarketplaceImportJob.shop_name.ilike(f"%{shop_name}%")) # Order by creation date (newest first) and apply pagination - jobs = query.order_by(MarketplaceImportJob.created_at.desc()).offset(skip).limit(limit).all() + jobs = ( + query.order_by(MarketplaceImportJob.created_at.desc()) + .offset(skip) + .limit(limit) + .all() + ) return jobs def update_job_status( - self, - db: Session, - job_id: int, - status: str, - **kwargs + self, db: Session, job_id: int, status: str, **kwargs ) -> MarketplaceImportJob: """Update marketplace import job status and other fields""" - job = db.query(MarketplaceImportJob).filter(MarketplaceImportJob.id == job_id).first() + job = ( + db.query(MarketplaceImportJob) + .filter(MarketplaceImportJob.id == job_id) + .first() + ) if not job: raise ValueError("Marketplace import job not found") job.status = status # Update optional fields if provided - if 'imported_count' in kwargs: - job.imported_count = kwargs['imported_count'] - if 'updated_count' in kwargs: - job.updated_count = kwargs['updated_count'] - if 'total_processed' in kwargs: - job.total_processed = kwargs['total_processed'] - if 'error_count' in kwargs: - job.error_count = kwargs['error_count'] - if 'error_message' in kwargs: - job.error_message = kwargs['error_message'] - if 'started_at' in kwargs: - job.started_at = kwargs['started_at'] - if 'completed_at' in kwargs: - job.completed_at = kwargs['completed_at'] + if "imported_count" in kwargs: + job.imported_count = kwargs["imported_count"] + if "updated_count" in kwargs: + job.updated_count = kwargs["updated_count"] + if "total_processed" in kwargs: + job.total_processed = kwargs["total_processed"] + if "error_count" in kwargs: + job.error_count = kwargs["error_count"] + if "error_message" in kwargs: + job.error_message = kwargs["error_message"] + if "started_at" in kwargs: + job.started_at = kwargs["started_at"] + if "completed_at" in kwargs: + job.completed_at = kwargs["completed_at"] db.commit() db.refresh(job) @@ -145,7 +161,9 @@ class MarketplaceService: total_jobs = query.count() pending_jobs = query.filter(MarketplaceImportJob.status == "pending").count() running_jobs = query.filter(MarketplaceImportJob.status == "running").count() - completed_jobs = query.filter(MarketplaceImportJob.status == "completed").count() + completed_jobs = query.filter( + MarketplaceImportJob.status == "completed" + ).count() failed_jobs = query.filter(MarketplaceImportJob.status == "failed").count() return { @@ -153,17 +171,21 @@ class MarketplaceService: "pending_jobs": pending_jobs, "running_jobs": running_jobs, "completed_jobs": completed_jobs, - "failed_jobs": failed_jobs + "failed_jobs": failed_jobs, } - def convert_to_response_model(self, job: MarketplaceImportJob) -> MarketplaceImportJobResponse: + def convert_to_response_model( + self, job: MarketplaceImportJob + ) -> MarketplaceImportJobResponse: """Convert database model to API response model""" return MarketplaceImportJobResponse( job_id=job.id, status=job.status, marketplace=job.marketplace, shop_id=job.shop_id, - shop_code=job.shop.shop_code if job.shop else None, # Add this optional field via relationship + shop_code=( + job.shop.shop_code if job.shop else None + ), # Add this optional field via relationship shop_name=job.shop_name, imported=job.imported_count or 0, updated=job.updated_count or 0, @@ -172,10 +194,12 @@ class MarketplaceService: error_message=job.error_message, created_at=job.created_at, started_at=job.started_at, - completed_at=job.completed_at + completed_at=job.completed_at, ) - def cancel_import_job(self, db: Session, job_id: int, user: User) -> MarketplaceImportJob: + def cancel_import_job( + self, db: Session, job_id: int, user: User + ) -> MarketplaceImportJob: """Cancel a pending or running import job""" job = self.get_import_job_by_id(db, job_id, user) @@ -197,7 +221,9 @@ class MarketplaceService: # Only allow deletion of completed, failed, or cancelled jobs if job.status in ["pending", "running"]: - raise ValueError(f"Cannot delete job with status: {job.status}. Cancel it first.") + raise ValueError( + f"Cannot delete job with status: {job.status}. Cancel it first." + ) db.delete(job) db.commit() diff --git a/app/services/product_service.py b/app/services/product_service.py index bffd501b..827f6bbd 100644 --- a/app/services/product_service.py +++ b/app/services/product_service.py @@ -1,11 +1,14 @@ -from sqlalchemy.orm import Session -from sqlalchemy.exc import IntegrityError -from models.database_models import Product, Stock -from models.api_models import ProductCreate, ProductUpdate, StockLocationResponse, StockSummaryResponse -from utils.data_processing import GTINProcessor, PriceProcessor -from typing import Optional, List, Generator -from datetime import datetime import logging +from datetime import datetime +from typing import Generator, List, Optional + +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session + +from models.api_models import (ProductCreate, ProductUpdate, + StockLocationResponse, StockSummaryResponse) +from models.database_models import Product, Stock +from utils.data_processing import GTINProcessor, PriceProcessor logger = logging.getLogger(__name__) @@ -27,7 +30,9 @@ class ProductService: # Process price if provided if product_data.price: - parsed_price, currency = self.price_processor.parse_price_currency(product_data.price) + parsed_price, currency = self.price_processor.parse_price_currency( + product_data.price + ) if parsed_price: product_data.price = parsed_price product_data.currency = currency @@ -58,16 +63,16 @@ class ProductService: return db.query(Product).filter(Product.product_id == product_id).first() def get_products_with_filters( - self, - db: Session, - skip: int = 0, - limit: int = 100, - brand: Optional[str] = None, - category: Optional[str] = None, - availability: Optional[str] = None, - marketplace: Optional[str] = None, - shop_name: Optional[str] = None, - search: Optional[str] = None + self, + db: Session, + skip: int = 0, + limit: int = 100, + brand: Optional[str] = None, + category: Optional[str] = None, + availability: Optional[str] = None, + marketplace: Optional[str] = None, + shop_name: Optional[str] = None, + search: Optional[str] = None, ) -> tuple[List[Product], int]: """Get products with filtering and pagination""" query = db.query(Product) @@ -87,10 +92,10 @@ class ProductService: # Search in title, description, marketplace, and shop_name search_term = f"%{search}%" query = query.filter( - (Product.title.ilike(search_term)) | - (Product.description.ilike(search_term)) | - (Product.marketplace.ilike(search_term)) | - (Product.shop_name.ilike(search_term)) + (Product.title.ilike(search_term)) + | (Product.description.ilike(search_term)) + | (Product.marketplace.ilike(search_term)) + | (Product.shop_name.ilike(search_term)) ) total = query.count() @@ -98,7 +103,9 @@ class ProductService: return products, total - def update_product(self, db: Session, product_id: str, product_update: ProductUpdate) -> Product: + def update_product( + self, db: Session, product_id: str, product_update: ProductUpdate + ) -> Product: """Update product with validation""" product = db.query(Product).filter(Product.product_id == product_id).first() if not product: @@ -116,7 +123,9 @@ class ProductService: # Process price if being updated if "price" in update_data and update_data["price"]: - parsed_price, currency = self.price_processor.parse_price_currency(update_data["price"]) + parsed_price, currency = self.price_processor.parse_price_currency( + update_data["price"] + ) if parsed_price: update_data["price"] = parsed_price update_data["currency"] = currency @@ -160,21 +169,21 @@ class ProductService: ] return StockSummaryResponse( - gtin=gtin, - total_quantity=total_quantity, - locations=locations + gtin=gtin, total_quantity=total_quantity, locations=locations ) def generate_csv_export( - self, - db: Session, - marketplace: Optional[str] = None, - shop_name: Optional[str] = None + self, + db: Session, + marketplace: Optional[str] = None, + shop_name: Optional[str] = None, ) -> Generator[str, None, None]: """Generate CSV export with streaming for memory efficiency""" # CSV header - yield ("product_id,title,description,link,image_link,availability,price,currency,brand," - "gtin,marketplace,shop_name\n") + yield ( + "product_id,title,description,link,image_link,availability,price,currency,brand," + "gtin,marketplace,shop_name\n" + ) batch_size = 1000 offset = 0 @@ -194,17 +203,22 @@ class ProductService: for product in products: # Create CSV row with marketplace fields - row = (f'"{product.product_id}","{product.title or ""}","{product.description or ""}",' - f'"{product.link or ""}","{product.image_link or ""}","{product.availability or ""}",' - f'"{product.price or ""}","{product.currency or ""}","{product.brand or ""}",' - f'"{product.gtin or ""}","{product.marketplace or ""}","{product.shop_name or ""}"\n') + row = ( + f'"{product.product_id}","{product.title or ""}","{product.description or ""}",' + f'"{product.link or ""}","{product.image_link or ""}","{product.availability or ""}",' + f'"{product.price or ""}","{product.currency or ""}","{product.brand or ""}",' + f'"{product.gtin or ""}","{product.marketplace or ""}","{product.shop_name or ""}"\n' + ) yield row offset += batch_size def product_exists(self, db: Session, product_id: str) -> bool: """Check if product exists by ID""" - return db.query(Product).filter(Product.product_id == product_id).first() is not None + return ( + db.query(Product).filter(Product.product_id == product_id).first() + is not None + ) # Create service instance diff --git a/app/services/shop_service.py b/app/services/shop_service.py index 416b0440..5fbbf26d 100644 --- a/app/services/shop_service.py +++ b/app/services/shop_service.py @@ -1,12 +1,13 @@ +import logging +from datetime import datetime +from typing import Any, Dict, List, Optional, Tuple + +from fastapi import HTTPException from sqlalchemy import func from sqlalchemy.orm import Session -from fastapi import HTTPException -from datetime import datetime -import logging -from typing import List, Optional, Tuple, Dict, Any -from models.database_models import User, Shop, Product, ShopProduct from models.api_models import ShopCreate, ShopProductCreate +from models.database_models import Product, Shop, ShopProduct, User logger = logging.getLogger(__name__) @@ -14,7 +15,9 @@ logger = logging.getLogger(__name__) class ShopService: """Service class for shop operations following the application's service pattern""" - def create_shop(self, db: Session, shop_data: ShopCreate, current_user: User) -> Shop: + def create_shop( + self, db: Session, shop_data: ShopCreate, current_user: User + ) -> Shop: """ Create a new shop @@ -33,39 +36,43 @@ class ShopService: normalized_shop_code = shop_data.shop_code.upper() # Check if shop code already exists (case-insensitive check against existing data) - existing_shop = db.query(Shop).filter( - func.upper(Shop.shop_code) == normalized_shop_code - ).first() + existing_shop = ( + db.query(Shop) + .filter(func.upper(Shop.shop_code) == normalized_shop_code) + .first() + ) if existing_shop: raise HTTPException(status_code=400, detail="Shop code already exists") # Create shop with uppercase code shop_dict = shop_data.model_dump() # Fixed deprecated .dict() method - shop_dict['shop_code'] = normalized_shop_code # Store as uppercase + shop_dict["shop_code"] = normalized_shop_code # Store as uppercase new_shop = Shop( **shop_dict, owner_id=current_user.id, is_active=True, - is_verified=(current_user.role == "admin") + is_verified=(current_user.role == "admin"), ) db.add(new_shop) db.commit() db.refresh(new_shop) - logger.info(f"New shop created: {new_shop.shop_code} by {current_user.username}") + logger.info( + f"New shop created: {new_shop.shop_code} by {current_user.username}" + ) return new_shop def get_shops( - self, - db: Session, - current_user: User, - skip: int = 0, - limit: int = 100, - active_only: bool = True, - verified_only: bool = False + self, + db: Session, + current_user: User, + skip: int = 0, + limit: int = 100, + active_only: bool = True, + verified_only: bool = False, ) -> Tuple[List[Shop], int]: """ Get shops with filtering @@ -86,8 +93,8 @@ class ShopService: # Non-admin users can only see active and verified shops, plus their own if current_user.role != "admin": query = query.filter( - (Shop.is_active == True) & - ((Shop.is_verified == True) | (Shop.owner_id == current_user.id)) + (Shop.is_active == True) + & ((Shop.is_verified == True) | (Shop.owner_id == current_user.id)) ) else: # Admin can apply filters @@ -117,22 +124,25 @@ class ShopService: HTTPException: If shop not found or access denied """ # Explicit type hint to help type checker shop: Optional[Shop] - shop: Optional[Shop] = db.query(Shop).filter(func.upper(Shop.shop_code) == shop_code.upper()).first() + shop: Optional[Shop] = ( + db.query(Shop) + .filter(func.upper(Shop.shop_code) == shop_code.upper()) + .first() + ) if not shop: raise HTTPException(status_code=404, detail="Shop not found") # Non-admin users can only see active verified shops or their own shops if current_user.role != "admin": - if not shop.is_active or (not shop.is_verified and shop.owner_id != current_user.id): + if not shop.is_active or ( + not shop.is_verified and shop.owner_id != current_user.id + ): raise HTTPException(status_code=404, detail="Shop not found") return shop def add_product_to_shop( - self, - db: Session, - shop: Shop, - shop_product: ShopProductCreate + self, db: Session, shop: Shop, shop_product: ShopProductCreate ) -> ShopProduct: """ Add existing product to shop catalog with shop-specific settings @@ -149,24 +159,35 @@ class ShopService: HTTPException: If product not found or already in shop """ # Check if product exists - product = db.query(Product).filter(Product.product_id == shop_product.product_id).first() + product = ( + db.query(Product) + .filter(Product.product_id == shop_product.product_id) + .first() + ) if not product: - raise HTTPException(status_code=404, detail="Product not found in marketplace catalog") + raise HTTPException( + status_code=404, detail="Product not found in marketplace catalog" + ) # Check if product already in shop - existing_shop_product = db.query(ShopProduct).filter( - ShopProduct.shop_id == shop.id, - ShopProduct.product_id == product.id - ).first() + existing_shop_product = ( + db.query(ShopProduct) + .filter( + ShopProduct.shop_id == shop.id, ShopProduct.product_id == product.id + ) + .first() + ) if existing_shop_product: - raise HTTPException(status_code=400, detail="Product already in shop catalog") + raise HTTPException( + status_code=400, detail="Product already in shop catalog" + ) # Create shop-product association new_shop_product = ShopProduct( shop_id=shop.id, product_id=product.id, - **shop_product.model_dump(exclude={'product_id'}) + **shop_product.model_dump(exclude={"product_id"}), ) db.add(new_shop_product) @@ -180,14 +201,14 @@ class ShopService: return new_shop_product def get_shop_products( - self, - db: Session, - shop: Shop, - current_user: User, - skip: int = 0, - limit: int = 100, - active_only: bool = True, - featured_only: bool = False + self, + db: Session, + shop: Shop, + current_user: User, + skip: int = 0, + limit: int = 100, + active_only: bool = True, + featured_only: bool = False, ) -> Tuple[List[ShopProduct], int]: """ Get products in shop catalog with filtering @@ -239,10 +260,14 @@ class ShopService: def product_in_shop(self, db: Session, shop_id: int, product_id: int) -> bool: """Check if product is already in shop""" - return db.query(ShopProduct).filter( - ShopProduct.shop_id == shop_id, - ShopProduct.product_id == product_id - ).first() is not None + return ( + db.query(ShopProduct) + .filter( + ShopProduct.shop_id == shop_id, ShopProduct.product_id == product_id + ) + .first() + is not None + ) def is_shop_owner(self, shop: Shop, user: User) -> bool: """Check if user is shop owner""" diff --git a/app/services/stats_service.py b/app/services/stats_service.py index c1d6b490..2ad3cb1c 100644 --- a/app/services/stats_service.py +++ b/app/services/stats_service.py @@ -1,10 +1,11 @@ +import logging +from typing import Any, Dict, List + from sqlalchemy import func from sqlalchemy.orm import Session -import logging -from typing import List, Dict, Any -from models.database_models import User, Product, Stock -from models.api_models import StatsResponse, MarketplaceStatsResponse +from models.api_models import MarketplaceStatsResponse, StatsResponse +from models.database_models import Product, Stock, User logger = logging.getLogger(__name__) @@ -25,26 +26,37 @@ class StatsService: # Use more efficient queries with proper indexes total_products = db.query(Product).count() - unique_brands = db.query(Product.brand).filter( - Product.brand.isnot(None), - Product.brand != "" - ).distinct().count() + unique_brands = ( + db.query(Product.brand) + .filter(Product.brand.isnot(None), Product.brand != "") + .distinct() + .count() + ) - unique_categories = db.query(Product.google_product_category).filter( - Product.google_product_category.isnot(None), - Product.google_product_category != "" - ).distinct().count() + unique_categories = ( + db.query(Product.google_product_category) + .filter( + Product.google_product_category.isnot(None), + Product.google_product_category != "", + ) + .distinct() + .count() + ) # New marketplace statistics - unique_marketplaces = db.query(Product.marketplace).filter( - Product.marketplace.isnot(None), - Product.marketplace != "" - ).distinct().count() + unique_marketplaces = ( + db.query(Product.marketplace) + .filter(Product.marketplace.isnot(None), Product.marketplace != "") + .distinct() + .count() + ) - unique_shops = db.query(Product.shop_name).filter( - Product.shop_name.isnot(None), - Product.shop_name != "" - ).distinct().count() + unique_shops = ( + db.query(Product.shop_name) + .filter(Product.shop_name.isnot(None), Product.shop_name != "") + .distinct() + .count() + ) # Stock statistics total_stock_entries = db.query(Stock).count() @@ -57,10 +69,12 @@ class StatsService: "unique_marketplaces": unique_marketplaces, "unique_shops": unique_shops, "total_stock_entries": total_stock_entries, - "total_inventory_quantity": total_inventory + "total_inventory_quantity": total_inventory, } - logger.info(f"Generated comprehensive stats: {total_products} products, {unique_marketplaces} marketplaces") + logger.info( + f"Generated comprehensive stats: {total_products} products, {unique_marketplaces} marketplaces" + ) return stats_data def get_marketplace_breakdown_stats(self, db: Session) -> List[Dict[str, Any]]: @@ -74,25 +88,31 @@ class StatsService: List of dictionaries containing marketplace statistics """ # Query to get stats per marketplace - marketplace_stats = db.query( - Product.marketplace, - func.count(Product.id).label('total_products'), - func.count(func.distinct(Product.shop_name)).label('unique_shops'), - func.count(func.distinct(Product.brand)).label('unique_brands') - ).filter( - Product.marketplace.isnot(None) - ).group_by(Product.marketplace).all() + marketplace_stats = ( + db.query( + Product.marketplace, + func.count(Product.id).label("total_products"), + func.count(func.distinct(Product.shop_name)).label("unique_shops"), + func.count(func.distinct(Product.brand)).label("unique_brands"), + ) + .filter(Product.marketplace.isnot(None)) + .group_by(Product.marketplace) + .all() + ) stats_list = [ { "marketplace": stat.marketplace, "total_products": stat.total_products, "unique_shops": stat.unique_shops, - "unique_brands": stat.unique_brands - } for stat in marketplace_stats + "unique_brands": stat.unique_brands, + } + for stat in marketplace_stats ] - logger.info(f"Generated marketplace breakdown stats for {len(stats_list)} marketplaces") + logger.info( + f"Generated marketplace breakdown stats for {len(stats_list)} marketplaces" + ) return stats_list def get_product_count(self, db: Session) -> int: @@ -101,31 +121,42 @@ class StatsService: def get_unique_brands_count(self, db: Session) -> int: """Get count of unique brands""" - return db.query(Product.brand).filter( - Product.brand.isnot(None), - Product.brand != "" - ).distinct().count() + return ( + db.query(Product.brand) + .filter(Product.brand.isnot(None), Product.brand != "") + .distinct() + .count() + ) def get_unique_categories_count(self, db: Session) -> int: """Get count of unique categories""" - return db.query(Product.google_product_category).filter( - Product.google_product_category.isnot(None), - Product.google_product_category != "" - ).distinct().count() + return ( + db.query(Product.google_product_category) + .filter( + Product.google_product_category.isnot(None), + Product.google_product_category != "", + ) + .distinct() + .count() + ) def get_unique_marketplaces_count(self, db: Session) -> int: """Get count of unique marketplaces""" - return db.query(Product.marketplace).filter( - Product.marketplace.isnot(None), - Product.marketplace != "" - ).distinct().count() + return ( + db.query(Product.marketplace) + .filter(Product.marketplace.isnot(None), Product.marketplace != "") + .distinct() + .count() + ) def get_unique_shops_count(self, db: Session) -> int: """Get count of unique shops""" - return db.query(Product.shop_name).filter( - Product.shop_name.isnot(None), - Product.shop_name != "" - ).distinct().count() + return ( + db.query(Product.shop_name) + .filter(Product.shop_name.isnot(None), Product.shop_name != "") + .distinct() + .count() + ) def get_stock_statistics(self, db: Session) -> Dict[str, int]: """ @@ -142,25 +173,35 @@ class StatsService: return { "total_stock_entries": total_stock_entries, - "total_inventory_quantity": total_inventory + "total_inventory_quantity": total_inventory, } def get_brands_by_marketplace(self, db: Session, marketplace: str) -> List[str]: """Get unique brands for a specific marketplace""" - brands = db.query(Product.brand).filter( - Product.marketplace == marketplace, - Product.brand.isnot(None), - Product.brand != "" - ).distinct().all() + brands = ( + db.query(Product.brand) + .filter( + Product.marketplace == marketplace, + Product.brand.isnot(None), + Product.brand != "", + ) + .distinct() + .all() + ) return [brand[0] for brand in brands] def get_shops_by_marketplace(self, db: Session, marketplace: str) -> List[str]: """Get unique shops for a specific marketplace""" - shops = db.query(Product.shop_name).filter( - Product.marketplace == marketplace, - Product.shop_name.isnot(None), - Product.shop_name != "" - ).distinct().all() + shops = ( + db.query(Product.shop_name) + .filter( + Product.marketplace == marketplace, + Product.shop_name.isnot(None), + Product.shop_name != "", + ) + .distinct() + .all() + ) return [shop[0] for shop in shops] def get_products_by_marketplace(self, db: Session, marketplace: str) -> int: diff --git a/app/services/stock_service.py b/app/services/stock_service.py index 8dad7930..78edf73b 100644 --- a/app/services/stock_service.py +++ b/app/services/stock_service.py @@ -1,10 +1,13 @@ -from sqlalchemy.orm import Session -from models.database_models import Stock, Product -from models.api_models import StockCreate, StockAdd, StockUpdate, StockLocationResponse, StockSummaryResponse -from utils.data_processing import GTINProcessor -from typing import Optional, List, Tuple -from datetime import datetime import logging +from datetime import datetime +from typing import List, Optional, Tuple + +from sqlalchemy.orm import Session + +from models.api_models import (StockAdd, StockCreate, StockLocationResponse, + StockSummaryResponse, StockUpdate) +from models.database_models import Product, Stock +from utils.data_processing import GTINProcessor logger = logging.getLogger(__name__) @@ -26,10 +29,11 @@ class StockService: location = stock_data.location.strip().upper() # Check if stock entry already exists for this GTIN and location - existing_stock = db.query(Stock).filter( - Stock.gtin == normalized_gtin, - Stock.location == location - ).first() + existing_stock = ( + db.query(Stock) + .filter(Stock.gtin == normalized_gtin, Stock.location == location) + .first() + ) if existing_stock: # Update existing stock (SET to exact quantity) @@ -39,19 +43,20 @@ class StockService: db.commit() db.refresh(existing_stock) logger.info( - f"Updated stock for GTIN {normalized_gtin} at {location}: {old_quantity} → {stock_data.quantity}") + f"Updated stock for GTIN {normalized_gtin} at {location}: {old_quantity} → {stock_data.quantity}" + ) return existing_stock else: # Create new stock entry new_stock = Stock( - gtin=normalized_gtin, - location=location, - quantity=stock_data.quantity + gtin=normalized_gtin, location=location, quantity=stock_data.quantity ) db.add(new_stock) db.commit() db.refresh(new_stock) - logger.info(f"Created new stock for GTIN {normalized_gtin} at {location}: {stock_data.quantity}") + logger.info( + f"Created new stock for GTIN {normalized_gtin} at {location}: {stock_data.quantity}" + ) return new_stock def add_stock(self, db: Session, stock_data: StockAdd) -> Stock: @@ -63,10 +68,11 @@ class StockService: location = stock_data.location.strip().upper() # Check if stock entry already exists for this GTIN and location - existing_stock = db.query(Stock).filter( - Stock.gtin == normalized_gtin, - Stock.location == location - ).first() + existing_stock = ( + db.query(Stock) + .filter(Stock.gtin == normalized_gtin, Stock.location == location) + .first() + ) if existing_stock: # Add to existing stock @@ -76,19 +82,20 @@ class StockService: db.commit() db.refresh(existing_stock) logger.info( - f"Added stock for GTIN {normalized_gtin} at {location}: {old_quantity} + {stock_data.quantity} = {existing_stock.quantity}") + f"Added stock for GTIN {normalized_gtin} at {location}: {old_quantity} + {stock_data.quantity} = {existing_stock.quantity}" + ) return existing_stock else: # Create new stock entry with the quantity new_stock = Stock( - gtin=normalized_gtin, - location=location, - quantity=stock_data.quantity + gtin=normalized_gtin, location=location, quantity=stock_data.quantity ) db.add(new_stock) db.commit() db.refresh(new_stock) - logger.info(f"Created new stock for GTIN {normalized_gtin} at {location}: {stock_data.quantity}") + logger.info( + f"Created new stock for GTIN {normalized_gtin} at {location}: {stock_data.quantity}" + ) return new_stock def remove_stock(self, db: Session, stock_data: StockAdd) -> Stock: @@ -100,18 +107,22 @@ class StockService: location = stock_data.location.strip().upper() # Find existing stock entry - existing_stock = db.query(Stock).filter( - Stock.gtin == normalized_gtin, - Stock.location == location - ).first() + existing_stock = ( + db.query(Stock) + .filter(Stock.gtin == normalized_gtin, Stock.location == location) + .first() + ) if not existing_stock: - raise ValueError(f"No stock found for GTIN {normalized_gtin} at location {location}") + raise ValueError( + f"No stock found for GTIN {normalized_gtin} at location {location}" + ) # Check if we have enough stock to remove if existing_stock.quantity < stock_data.quantity: raise ValueError( - f"Insufficient stock. Available: {existing_stock.quantity}, Requested to remove: {stock_data.quantity}") + f"Insufficient stock. Available: {existing_stock.quantity}, Requested to remove: {stock_data.quantity}" + ) # Remove from existing stock old_quantity = existing_stock.quantity @@ -120,7 +131,8 @@ class StockService: db.commit() db.refresh(existing_stock) logger.info( - f"Removed stock for GTIN {normalized_gtin} at {location}: {old_quantity} - {stock_data.quantity} = {existing_stock.quantity}") + f"Removed stock for GTIN {normalized_gtin} at {location}: {old_quantity} - {stock_data.quantity} = {existing_stock.quantity}" + ) return existing_stock def get_stock_by_gtin(self, db: Session, gtin: str) -> StockSummaryResponse: @@ -141,10 +153,9 @@ class StockService: for entry in stock_entries: total_quantity += entry.quantity - locations.append(StockLocationResponse( - location=entry.location, - quantity=entry.quantity - )) + locations.append( + StockLocationResponse(location=entry.location, quantity=entry.quantity) + ) # Try to get product title for reference product = db.query(Product).filter(Product.gtin == normalized_gtin).first() @@ -154,7 +165,7 @@ class StockService: gtin=normalized_gtin, total_quantity=total_quantity, locations=locations, - product_title=product_title + product_title=product_title, ) def get_total_stock(self, db: Session, gtin: str) -> dict: @@ -174,16 +185,16 @@ class StockService: "gtin": normalized_gtin, "total_quantity": total_quantity, "product_title": product.title if product else None, - "locations_count": len(total_stock) + "locations_count": len(total_stock), } def get_all_stock( - self, - db: Session, - skip: int = 0, - limit: int = 100, - location: Optional[str] = None, - gtin: Optional[str] = None + self, + db: Session, + skip: int = 0, + limit: int = 100, + location: Optional[str] = None, + gtin: Optional[str] = None, ) -> List[Stock]: """Get all stock entries with optional filtering""" query = db.query(Stock) @@ -198,7 +209,9 @@ class StockService: return query.offset(skip).limit(limit).all() - def update_stock(self, db: Session, stock_id: int, stock_update: StockUpdate) -> Stock: + def update_stock( + self, db: Session, stock_id: int, stock_update: StockUpdate + ) -> Stock: """Update stock quantity for a specific stock entry""" stock_entry = db.query(Stock).filter(Stock.id == stock_id).first() if not stock_entry: @@ -209,7 +222,9 @@ class StockService: db.commit() db.refresh(stock_entry) - logger.info(f"Updated stock entry {stock_id} to quantity {stock_update.quantity}") + logger.info( + f"Updated stock entry {stock_id} to quantity {stock_update.quantity}" + ) return stock_entry def delete_stock(self, db: Session, stock_id: int) -> bool: diff --git a/app/tasks/background_tasks.py b/app/tasks/background_tasks.py index 93ae5b4a..186a0d7d 100644 --- a/app/tasks/background_tasks.py +++ b/app/tasks/background_tasks.py @@ -1,6 +1,7 @@ # app/tasks/background_tasks.py import logging from datetime import datetime + from app.core.database import SessionLocal from models.database_models import MarketplaceImportJob from utils.csv_processor import CSVProcessor @@ -9,11 +10,7 @@ logger = logging.getLogger(__name__) async def process_marketplace_import( - job_id: int, - url: str, - marketplace: str, - shop_name: str, - batch_size: int = 1000 + job_id: int, url: str, marketplace: str, shop_name: str, batch_size: int = 1000 ): """Background task to process marketplace CSV import""" db = SessionLocal() @@ -22,7 +19,11 @@ async def process_marketplace_import( try: # Update job status - job = db.query(MarketplaceImportJob).filter(MarketplaceImportJob.id == job_id).first() + job = ( + db.query(MarketplaceImportJob) + .filter(MarketplaceImportJob.id == job_id) + .first() + ) if not job: logger.error(f"Import job {job_id} not found") return @@ -70,7 +71,7 @@ async def process_marketplace_import( finally: # Close the database session only if it's not a mock # In tests, we use the same session so we shouldn't close it - if hasattr(db, 'close') and callable(getattr(db, 'close')): + if hasattr(db, "close") and callable(getattr(db, "close")): try: db.close() except Exception as close_error: diff --git a/auth_example.py b/auth_example.py index 06fbe5ab..ac90ce11 100644 --- a/auth_example.py +++ b/auth_example.py @@ -6,21 +6,21 @@ import requests # API Base URL BASE_URL = "http://localhost:8000" + def register_user(email, username, password): """Register a new user""" - response = requests.post(f"{BASE_URL}/register", json={ - "email": email, - "username": username, - "password": password - }) + response = requests.post( + f"{BASE_URL}/register", + json={"email": email, "username": username, "password": password}, + ) return response.json() + def login_user(username, password): """Login and get JWT token""" - response = requests.post(f"{BASE_URL}/login", json={ - "username": username, - "password": password - }) + response = requests.post( + f"{BASE_URL}/login", json={"username": username, "password": password} + ) if response.status_code == 200: data = response.json() return data["access_token"] @@ -28,24 +28,30 @@ def login_user(username, password): print(f"Login failed: {response.json()}") return None + def get_user_info(token): """Get current user info""" headers = {"Authorization": f"Bearer {token}"} response = requests.get(f"{BASE_URL}/me", headers=headers) return response.json() + def get_products(token, skip=0, limit=10): """Get products (requires authentication)""" headers = {"Authorization": f"Bearer {token}"} - response = requests.get(f"{BASE_URL}/products?skip={skip}&limit={limit}", headers=headers) + response = requests.get( + f"{BASE_URL}/products?skip={skip}&limit={limit}", headers=headers + ) return response.json() + def create_product(token, product_data): """Create a new product (requires authentication)""" headers = {"Authorization": f"Bearer {token}"} response = requests.post(f"{BASE_URL}/products", json=product_data, headers=headers) return response.json() + # Example usage if __name__ == "__main__": # 1. Register a new user @@ -55,18 +61,18 @@ if __name__ == "__main__": print(f"User registered: {user_result}") except Exception as e: print(f"Registration failed: {e}") - + # 2. Login with default admin user print("\n2. Logging in as admin...") admin_token = login_user("admin", "admin123") if admin_token: print(f"Admin login successful! Token: {admin_token[:50]}...") - + # 3. Get user info print("\n3. Getting admin user info...") user_info = get_user_info(admin_token) print(f"User info: {user_info}") - + # 4. Create a sample product print("\n4. Creating a sample product...") sample_product = { @@ -75,30 +81,32 @@ if __name__ == "__main__": "description": "A test product for demonstration", "price": "19.99", "brand": "Test Brand", - "availability": "in stock" + "availability": "in stock", } - + product_result = create_product(admin_token, sample_product) print(f"Product created: {product_result}") - + # 5. Get products list print("\n5. Getting products list...") products = get_products(admin_token) print(f"Products: {products}") - + # 6. Login with regular user print("\n6. Logging in as regular user...") user_token = login_user("testuser", "password123") if user_token: print(f"User login successful! Token: {user_token[:50]}...") - + # Regular users can also access protected endpoints user_info = get_user_info(user_token) print(f"Regular user info: {user_info}") - + products = get_products(user_token, limit=5) - print(f"Products accessible to regular user: {len(products.get('products', []))} products") - + print( + f"Products accessible to regular user: {len(products.get('products', []))} products" + ) + print("\nAuthentication example completed!") # Example cURL commands: @@ -126,4 +134,4 @@ curl -X POST "http://localhost:8000/products" \ -H "Authorization: Bearer YOUR_JWT_TOKEN_HERE" \ -H "Content-Type: application/json" \ -d '{"product_id": "TEST001", "title": "Test Product", "price": "19.99"}' -""" \ No newline at end of file +""" diff --git a/main.py b/main.py index 6ba9d4a6..e0b1a2ce 100644 --- a/main.py +++ b/main.py @@ -1,13 +1,15 @@ -from fastapi import FastAPI, Depends, HTTPException -from sqlalchemy.orm import Session -from sqlalchemy import text -from app.core.config import settings -from app.core.lifespan import lifespan -from app.core.database import get_db -from datetime import datetime -from app.api.main import api_router -from fastapi.middleware.cors import CORSMiddleware import logging +from datetime import datetime + +from fastapi import Depends, FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from sqlalchemy import text +from sqlalchemy.orm import Session + +from app.api.main import api_router +from app.core.config import settings +from app.core.database import get_db +from app.core.lifespan import lifespan logger = logging.getLogger(__name__) @@ -16,7 +18,7 @@ app = FastAPI( title=settings.project_name, description=settings.description, version=settings.version, - lifespan=lifespan + lifespan=lifespan, ) # Add CORS middleware @@ -44,10 +46,17 @@ def root(): "JWT Authentication", "Marketplace-aware product import", "Multi-shop product management", - "Stock management with location tracking" + "Stock management with location tracking", ], - "supported_marketplaces": ["Letzshop", "Amazon", "eBay", "Etsy", "Shopify", "Other"], - "auth_required": "Most endpoints require Bearer token authentication" + "supported_marketplaces": [ + "Letzshop", + "Amazon", + "eBay", + "Etsy", + "Shopify", + "Other", + ], + "auth_required": "Most endpoints require Bearer token authentication", } diff --git a/middleware/auth.py b/middleware/auth.py index 2baa117e..352ccba6 100644 --- a/middleware/auth.py +++ b/middleware/auth.py @@ -1,14 +1,16 @@ # middleware/auth.py +import logging +import os +from datetime import datetime, timedelta +from typing import Any, Dict, Optional + from fastapi import HTTPException from fastapi.security import HTTPAuthorizationCredentials -from passlib.context import CryptContext from jose import jwt -from datetime import datetime, timedelta -from typing import Dict, Any, Optional +from passlib.context import CryptContext from sqlalchemy.orm import Session + from models.database_models import User -import os -import logging logger = logging.getLogger(__name__) @@ -20,7 +22,9 @@ class AuthManager: """JWT-based authentication manager with bcrypt password hashing""" def __init__(self): - self.secret_key = os.getenv("JWT_SECRET_KEY", "your-secret-key-change-in-production-please") + self.secret_key = os.getenv( + "JWT_SECRET_KEY", "your-secret-key-change-in-production-please" + ) self.algorithm = "HS256" self.token_expire_minutes = int(os.getenv("JWT_EXPIRE_MINUTES", "30")) @@ -32,11 +36,15 @@ class AuthManager: """Verify password against hash""" return pwd_context.verify(plain_password, hashed_password) - def authenticate_user(self, db: Session, username: str, password: str) -> Optional[User]: + def authenticate_user( + self, db: Session, username: str, password: str + ) -> Optional[User]: """Authenticate user and return user object if valid""" - user = db.query(User).filter( - (User.username == username) | (User.email == username) - ).first() + user = ( + db.query(User) + .filter((User.username == username) | (User.email == username)) + .first() + ) if not user: return None @@ -65,7 +73,7 @@ class AuthManager: "email": user.email, "role": user.role, "exp": expire, - "iat": datetime.utcnow() + "iat": datetime.utcnow(), } token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm) @@ -73,7 +81,7 @@ class AuthManager: return { "access_token": token, "token_type": "bearer", - "expires_in": self.token_expire_minutes * 60 # Return in seconds + "expires_in": self.token_expire_minutes * 60, # Return in seconds } def verify_token(self, token: str) -> Dict[str, Any]: @@ -92,25 +100,31 @@ class AuthManager: # Extract user data user_id = payload.get("sub") if user_id is None: - raise HTTPException(status_code=401, detail="Token missing user identifier") + raise HTTPException( + status_code=401, detail="Token missing user identifier" + ) return { "user_id": int(user_id), "username": payload.get("username"), "email": payload.get("email"), - "role": payload.get("role", "user") + "role": payload.get("role", "user"), } except jwt.ExpiredSignatureError: raise HTTPException(status_code=401, detail="Token has expired") except jwt.JWTError as e: logger.error(f"JWT decode error: {e}") - raise HTTPException(status_code=401, detail="Could not validate credentials") + raise HTTPException( + status_code=401, detail="Could not validate credentials" + ) except Exception as e: logger.error(f"Token verification error: {e}") raise HTTPException(status_code=401, detail="Authentication failed") - def get_current_user(self, db: Session, credentials: HTTPAuthorizationCredentials) -> User: + def get_current_user( + self, db: Session, credentials: HTTPAuthorizationCredentials + ) -> User: """Get current authenticated user from database""" user_data = self.verify_token(credentials.credentials) @@ -131,7 +145,7 @@ class AuthManager: if current_user.role != required_role: raise HTTPException( status_code=403, - detail=f"Required role '{required_role}' not found. Current role: '{current_user.role}'" + detail=f"Required role '{required_role}' not found. Current role: '{current_user.role}'", ) return func(current_user, *args, **kwargs) @@ -142,10 +156,7 @@ class AuthManager: def require_admin(self, current_user: User): """Require admin role""" if current_user.role != "admin": - raise HTTPException( - status_code=403, - detail="Admin privileges required" - ) + raise HTTPException(status_code=403, detail="Admin privileges required") return current_user def create_default_admin_user(self, db: Session): @@ -159,11 +170,13 @@ class AuthManager: username="admin", hashed_password=hashed_password, role="admin", - is_active=True + is_active=True, ) db.add(admin_user) db.commit() db.refresh(admin_user) - logger.info("Default admin user created: username='admin', password='admin123'") + logger.info( + "Default admin user created: username='admin', password='admin123'" + ) return admin_user diff --git a/middleware/decorators.py b/middleware/decorators.py index 69607047..020f3160 100644 --- a/middleware/decorators.py +++ b/middleware/decorators.py @@ -1,6 +1,8 @@ # middleware/decorators.py from functools import wraps + from fastapi import HTTPException + from middleware.rate_limiter import RateLimiter # Initialize rate limiter instance @@ -17,10 +19,7 @@ def rate_limit(max_requests: int = 100, window_seconds: int = 3600): client_id = "anonymous" # In production, extract from request if not rate_limiter.allow_request(client_id, max_requests, window_seconds): - raise HTTPException( - status_code=429, - detail="Rate limit exceeded" - ) + raise HTTPException(status_code=429, detail="Rate limit exceeded") return await func(*args, **kwargs) diff --git a/middleware/error_handler.py b/middleware/error_handler.py index dec50cbc..e01cc9fc 100644 --- a/middleware/error_handler.py +++ b/middleware/error_handler.py @@ -1,31 +1,35 @@ # middleware/error_handler.py -from fastapi import Request, HTTPException -from fastapi.responses import JSONResponse -from fastapi.exceptions import RequestValidationError - import logging +from fastapi import HTTPException, Request +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse + logger = logging.getLogger(__name__) + async def custom_http_exception_handler(request: Request, exc: HTTPException): """Custom HTTP exception handler""" - logger.error(f"HTTP {exc.status_code}: {exc.detail} - {request.method} {request.url}") - + logger.error( + f"HTTP {exc.status_code}: {exc.detail} - {request.method} {request.url}" + ) + return JSONResponse( status_code=exc.status_code, content={ "error": { "code": exc.status_code, "message": exc.detail, - "type": "http_exception" + "type": "http_exception", } - } + }, ) + async def validation_exception_handler(request: Request, exc: RequestValidationError): """Handle Pydantic validation errors""" logger.error(f"Validation error: {exc.errors()} - {request.method} {request.url}") - + return JSONResponse( status_code=422, content={ @@ -33,23 +37,25 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE "code": 422, "message": "Validation error", "type": "validation_error", - "details": exc.errors() + "details": exc.errors(), } - } + }, ) + async def general_exception_handler(request: Request, exc: Exception): """Handle unexpected exceptions""" - logger.error(f"Unexpected error: {str(exc)} - {request.method} {request.url}", exc_info=True) - + logger.error( + f"Unexpected error: {str(exc)} - {request.method} {request.url}", exc_info=True + ) + return JSONResponse( status_code=500, content={ "error": { "code": 500, "message": "Internal server error", - "type": "server_error" + "type": "server_error", } - } + }, ) - diff --git a/middleware/logging_middleware.py b/middleware/logging_middleware.py index f0f71829..2c6eac5a 100644 --- a/middleware/logging_middleware.py +++ b/middleware/logging_middleware.py @@ -1,9 +1,10 @@ # middleware/logging_middleware.py import logging import time +from typing import Callable + from fastapi import Request, Response from starlette.middleware.base import BaseHTTPMiddleware -from typing import Callable logger = logging.getLogger(__name__) @@ -43,4 +44,4 @@ class LoggingMiddleware(BaseHTTPMiddleware): f"Error: {str(e)} for {request.method} {request.url.path} " f"({duration:.3f}s)" ) - raise \ No newline at end of file + raise diff --git a/middleware/rate_limiter.py b/middleware/rate_limiter.py index c1389376..1ba70b3a 100644 --- a/middleware/rate_limiter.py +++ b/middleware/rate_limiter.py @@ -1,8 +1,8 @@ # middleware/rate_limiter.py -from typing import Dict -from datetime import datetime, timedelta import logging from collections import defaultdict, deque +from datetime import datetime, timedelta +from typing import Dict logger = logging.getLogger(__name__) @@ -16,7 +16,9 @@ class RateLimiter: self.cleanup_interval = 3600 # Clean up old entries every hour self.last_cleanup = datetime.utcnow() - def allow_request(self, client_id: str, max_requests: int, window_seconds: int) -> bool: + def allow_request( + self, client_id: str, max_requests: int, window_seconds: int + ) -> bool: """ Check if client is allowed to make a request Uses sliding window algorithm @@ -41,7 +43,9 @@ class RateLimiter: client_requests.append(now) return True - logger.warning(f"Rate limit exceeded for client {client_id}: {len(client_requests)}/{max_requests}") + logger.warning( + f"Rate limit exceeded for client {client_id}: {len(client_requests)}/{max_requests}" + ) return False def _cleanup_old_entries(self): @@ -62,7 +66,9 @@ class RateLimiter: for client_id in clients_to_remove: del self.clients[client_id] - logger.info(f"Rate limiter cleanup completed. Removed {len(clients_to_remove)} inactive clients") + logger.info( + f"Rate limiter cleanup completed. Removed {len(clients_to_remove)} inactive clients" + ) def get_client_stats(self, client_id: str) -> Dict[str, int]: """Get statistics for a specific client""" @@ -72,11 +78,13 @@ class RateLimiter: hour_ago = now - timedelta(hours=1) day_ago = now - timedelta(days=1) - requests_last_hour = sum(1 for req_time in client_requests if req_time > hour_ago) + requests_last_hour = sum( + 1 for req_time in client_requests if req_time > hour_ago + ) requests_last_day = sum(1 for req_time in client_requests if req_time > day_ago) return { "requests_last_hour": requests_last_hour, "requests_last_day": requests_last_day, - "total_tracked_requests": len(client_requests) - } \ No newline at end of file + "total_tracked_requests": len(client_requests), + } diff --git a/models/api_models.py b/models/api_models.py index f3e08587..499abda0 100644 --- a/models/api_models.py +++ b/models/api_models.py @@ -1,28 +1,35 @@ # models/api_models.py - Updated with Marketplace Support and Pydantic v2 -from pydantic import BaseModel, Field, field_validator, EmailStr, ConfigDict -from typing import Optional, List -from datetime import datetime import re +from datetime import datetime +from typing import List, Optional + +from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator # User Authentication Models class UserRegister(BaseModel): email: EmailStr = Field(..., description="Valid email address") - username: str = Field(..., min_length=3, max_length=50, description="Username (3-50 characters)") - password: str = Field(..., min_length=6, description="Password (minimum 6 characters)") + username: str = Field( + ..., min_length=3, max_length=50, description="Username (3-50 characters)" + ) + password: str = Field( + ..., min_length=6, description="Password (minimum 6 characters)" + ) - @field_validator('username') + @field_validator("username") @classmethod def validate_username(cls, v): - if not re.match(r'^[a-zA-Z0-9_]+$', v): - raise ValueError('Username must contain only letters, numbers, or underscores') + if not re.match(r"^[a-zA-Z0-9_]+$", v): + raise ValueError( + "Username must contain only letters, numbers, or underscores" + ) return v.lower().strip() - @field_validator('password') + @field_validator("password") @classmethod def validate_password(cls, v): if len(v) < 6: - raise ValueError('Password must be at least 6 characters long') + raise ValueError("Password must be at least 6 characters long") return v @@ -30,7 +37,7 @@ class UserLogin(BaseModel): username: str = Field(..., description="Username") password: str = Field(..., description="Password") - @field_validator('username') + @field_validator("username") @classmethod def validate_username(cls, v): return v.strip() @@ -58,27 +65,38 @@ class LoginResponse(BaseModel): # NEW: Shop models class ShopCreate(BaseModel): - shop_code: str = Field(..., min_length=3, max_length=50, description="Unique shop code (e.g., TECHSTORE)") - shop_name: str = Field(..., min_length=1, max_length=200, description="Display name of the shop") - description: Optional[str] = Field(None, max_length=2000, description="Shop description") + shop_code: str = Field( + ..., + min_length=3, + max_length=50, + description="Unique shop code (e.g., TECHSTORE)", + ) + shop_name: str = Field( + ..., min_length=1, max_length=200, description="Display name of the shop" + ) + description: Optional[str] = Field( + None, max_length=2000, description="Shop description" + ) contact_email: Optional[str] = None contact_phone: Optional[str] = None website: Optional[str] = None business_address: Optional[str] = None tax_number: Optional[str] = None - @field_validator('shop_code') + @field_validator("shop_code") def validate_shop_code(cls, v): # Convert to uppercase and check format v = v.upper().strip() - if not v.replace('_', '').replace('-', '').isalnum(): - raise ValueError('Shop code must be alphanumeric (underscores and hyphens allowed)') + if not v.replace("_", "").replace("-", "").isalnum(): + raise ValueError( + "Shop code must be alphanumeric (underscores and hyphens allowed)" + ) return v - @field_validator('contact_email') + @field_validator("contact_email") def validate_contact_email(cls, v): - if v and ('@' not in v or '.' not in v): - raise ValueError('Invalid email format') + if v and ("@" not in v or "." not in v): + raise ValueError("Invalid email format") return v.lower() if v else v @@ -91,10 +109,10 @@ class ShopUpdate(BaseModel): business_address: Optional[str] = None tax_number: Optional[str] = None - @field_validator('contact_email') + @field_validator("contact_email") def validate_contact_email(cls, v): - if v and ('@' not in v or '.' not in v): - raise ValueError('Invalid email format') + if v and ("@" not in v or "." not in v): + raise ValueError("Invalid email format") return v.lower() if v else v @@ -172,11 +190,11 @@ class ProductCreate(ProductBase): product_id: str = Field(..., min_length=1, description="Product ID is required") title: str = Field(..., min_length=1, description="Title is required") - @field_validator('product_id', 'title') + @field_validator("product_id", "title") @classmethod def validate_required_fields(cls, v): if not v or not v.strip(): - raise ValueError('Field cannot be empty') + raise ValueError("Field cannot be empty") return v.strip() @@ -195,15 +213,25 @@ class ProductResponse(ProductBase): # NEW: Shop Product models class ShopProductCreate(BaseModel): product_id: str = Field(..., description="Product ID to add to shop") - shop_product_id: Optional[str] = Field(None, description="Shop's internal product ID") - shop_price: Optional[float] = Field(None, ge=0, description="Shop-specific price override") - shop_sale_price: Optional[float] = Field(None, ge=0, description="Shop-specific sale price") + shop_product_id: Optional[str] = Field( + None, description="Shop's internal product ID" + ) + shop_price: Optional[float] = Field( + None, ge=0, description="Shop-specific price override" + ) + shop_sale_price: Optional[float] = Field( + None, ge=0, description="Shop-specific sale price" + ) shop_currency: Optional[str] = Field(None, description="Shop-specific currency") - shop_availability: Optional[str] = Field(None, description="Shop-specific availability") + shop_availability: Optional[str] = Field( + None, description="Shop-specific availability" + ) shop_condition: Optional[str] = Field(None, description="Shop-specific condition") is_featured: bool = Field(False, description="Featured product flag") min_quantity: int = Field(1, ge=1, description="Minimum order quantity") - max_quantity: Optional[int] = Field(None, ge=1, description="Maximum order quantity") + max_quantity: Optional[int] = Field( + None, ge=1, description="Maximum order quantity" + ) class ShopProductResponse(BaseModel): @@ -270,28 +298,40 @@ class StockSummaryResponse(BaseModel): # Marketplace Import Models class MarketplaceImportRequest(BaseModel): url: str = Field(..., description="URL to CSV file from marketplace") - marketplace: str = Field(default="Letzshop", description="Name of the marketplace (e.g., Letzshop, Amazon, eBay)") + marketplace: str = Field( + default="Letzshop", + description="Name of the marketplace (e.g., Letzshop, Amazon, eBay)", + ) shop_code: str = Field(..., description="Shop code to associate products with") - batch_size: Optional[int] = Field(1000, gt=0, le=10000, description="Batch size for processing") + batch_size: Optional[int] = Field( + 1000, gt=0, le=10000, description="Batch size for processing" + ) - @field_validator('url') + @field_validator("url") @classmethod def validate_url(cls, v): - if not v.startswith(('http://', 'https://')): - raise ValueError('URL must start with http:// or https://') + if not v.startswith(("http://", "https://")): + raise ValueError("URL must start with http:// or https://") return v - @field_validator('marketplace') + @field_validator("marketplace") @classmethod def validate_marketplace(cls, v): # You can add validation for supported marketplaces here - supported_marketplaces = ['Letzshop', 'Amazon', 'eBay', 'Etsy', 'Shopify', 'Other'] + supported_marketplaces = [ + "Letzshop", + "Amazon", + "eBay", + "Etsy", + "Shopify", + "Other", + ] if v not in supported_marketplaces: # For now, allow any marketplace but log it pass return v.strip() - @field_validator('shop_code') + @field_validator("shop_code") @classmethod def validate_shop_code(cls, v): return v.upper().strip() diff --git a/models/database_models.py b/models/database_models.py index b93affb8..3fa75f19 100644 --- a/models/database_models.py +++ b/models/database_models.py @@ -1,8 +1,10 @@ # models/database_models.py - Updated with Marketplace Support -from sqlalchemy import Column, Integer, String, Float, DateTime, Boolean, Text, ForeignKey, UniqueConstraint, Index -from sqlalchemy.orm import relationship from datetime import datetime +from sqlalchemy import (Boolean, Column, DateTime, Float, ForeignKey, Index, + Integer, String, Text, UniqueConstraint) +from sqlalchemy.orm import relationship + # Import Base from the central database module instead of creating a new one from app.core.database import Base @@ -18,10 +20,14 @@ class User(Base): is_active = Column(Boolean, default=True, nullable=False) last_login = Column(DateTime, nullable=True) created_at = Column(DateTime, default=datetime.utcnow, nullable=False) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) + updated_at = Column( + DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False + ) # Relationships - marketplace_import_jobs = relationship("MarketplaceImportJob", back_populates="user") + marketplace_import_jobs = relationship( + "MarketplaceImportJob", back_populates="user" + ) owned_shops = relationship("Shop", back_populates="owner") def __repr__(self): @@ -32,7 +38,9 @@ class Shop(Base): __tablename__ = "shops" id = Column(Integer, primary_key=True, index=True) - shop_code = Column(String, unique=True, index=True, nullable=False) # e.g., "TECHSTORE", "FASHIONHUB" + shop_code = Column( + String, unique=True, index=True, nullable=False + ) # e.g., "TECHSTORE", "FASHIONHUB" shop_name = Column(String, nullable=False) # Display name description = Column(Text) owner_id = Column(Integer, ForeignKey("users.id"), nullable=False) @@ -57,7 +65,9 @@ class Shop(Base): # Relationships owner = relationship("User", back_populates="owned_shops") shop_products = relationship("ShopProduct", back_populates="shop") - marketplace_import_jobs = relationship("MarketplaceImportJob", back_populates="shop") + marketplace_import_jobs = relationship( + "MarketplaceImportJob", back_populates="shop" + ) class Product(Base): @@ -103,26 +113,40 @@ class Product(Base): currency = Column(String) # New marketplace fields - marketplace = Column(String, index=True, nullable=True, default="Letzshop") # Index for marketplace filtering + marketplace = Column( + String, index=True, nullable=True, default="Letzshop" + ) # Index for marketplace filtering shop_name = Column(String, index=True, nullable=True) # Index for shop filtering created_at = Column(DateTime, default=datetime.utcnow, nullable=False) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) + updated_at = Column( + DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False + ) # Relationship to stock (one-to-many via GTIN) - stock_entries = relationship("Stock", foreign_keys="Stock.gtin", primaryjoin="Product.gtin == Stock.gtin", - viewonly=True) + stock_entries = relationship( + "Stock", + foreign_keys="Stock.gtin", + primaryjoin="Product.gtin == Stock.gtin", + viewonly=True, + ) shop_products = relationship("ShopProduct", back_populates="product") # Additional indexes for marketplace queries __table_args__ = ( - Index('idx_marketplace_shop', 'marketplace', 'shop_name'), # Composite index for marketplace+shop queries - Index('idx_marketplace_brand', 'marketplace', 'brand'), # Composite index for marketplace+brand queries + Index( + "idx_marketplace_shop", "marketplace", "shop_name" + ), # Composite index for marketplace+shop queries + Index( + "idx_marketplace_brand", "marketplace", "brand" + ), # Composite index for marketplace+brand queries ) def __repr__(self): - return (f"") + return ( + f"" + ) class ShopProduct(Base): @@ -159,9 +183,9 @@ class ShopProduct(Base): # Constraints __table_args__ = ( - UniqueConstraint('shop_id', 'product_id', name='uq_shop_product'), - Index('idx_shop_product_active', 'shop_id', 'is_active'), - Index('idx_shop_product_featured', 'shop_id', 'is_featured'), + UniqueConstraint("shop_id", "product_id", name="uq_shop_product"), + Index("idx_shop_product_active", "shop_id", "is_active"), + Index("idx_shop_product_featured", "shop_id", "is_featured"), ) @@ -169,22 +193,28 @@ class Stock(Base): __tablename__ = "stock" id = Column(Integer, primary_key=True, index=True) - gtin = Column(String, index=True, nullable=False) # Foreign key relationship would be ideal + gtin = Column( + String, index=True, nullable=False + ) # Foreign key relationship would be ideal location = Column(String, nullable=False, index=True) quantity = Column(Integer, nullable=False, default=0) reserved_quantity = Column(Integer, default=0) # For orders being processed shop_id = Column(Integer, ForeignKey("shops.id")) # Optional: shop-specific stock created_at = Column(DateTime, default=datetime.utcnow, nullable=False) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) + updated_at = Column( + DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False + ) # Relationships shop = relationship("Shop") # Composite unique constraint to prevent duplicate GTIN-location combinations __table_args__ = ( - UniqueConstraint('gtin', 'location', name='uq_stock_gtin_location'), - Index('idx_stock_gtin_location', 'gtin', 'location'), # Composite index for efficient queries + UniqueConstraint("gtin", "location", name="uq_stock_gtin_location"), + Index( + "idx_stock_gtin_location", "gtin", "location" + ), # Composite index for efficient queries ) def __repr__(self): @@ -195,13 +225,20 @@ class MarketplaceImportJob(Base): __tablename__ = "marketplace_import_jobs" id = Column(Integer, primary_key=True, index=True) - status = Column(String, nullable=False, - default="pending") # pending, processing, completed, failed, completed_with_errors + status = Column( + String, nullable=False, default="pending" + ) # pending, processing, completed, failed, completed_with_errors source_url = Column(String, nullable=False) - marketplace = Column(String, nullable=False, index=True, default="Letzshop") # Index for marketplace filtering + marketplace = Column( + String, nullable=False, index=True, default="Letzshop" + ) # Index for marketplace filtering shop_name = Column(String, nullable=False, index=True) # Index for shop filtering - shop_id = Column(Integer, ForeignKey("shops.id"), nullable=False) # Add proper foreign key - user_id = Column(Integer, ForeignKey('users.id'), nullable=False) # Foreign key to users table + shop_id = Column( + Integer, ForeignKey("shops.id"), nullable=False + ) # Add proper foreign key + user_id = Column( + Integer, ForeignKey("users.id"), nullable=False + ) # Foreign key to users table # Results imported_count = Column(Integer, default=0) @@ -223,11 +260,15 @@ class MarketplaceImportJob(Base): # Additional indexes for marketplace import job queries __table_args__ = ( - Index('idx_marketplace_import_user_marketplace', 'user_id', 'marketplace'), # User's marketplace imports - Index('idx_marketplace_import_shop_status', 'status'), # Shop import status - Index('idx_marketplace_import_shop_id', 'shop_id'), + Index( + "idx_marketplace_import_user_marketplace", "user_id", "marketplace" + ), # User's marketplace imports + Index("idx_marketplace_import_shop_status", "status"), # Shop import status + Index("idx_marketplace_import_shop_id", "shop_id"), ) def __repr__(self): - return (f"") + return ( + f"" + ) diff --git a/scripts/setup_dev.py b/scripts/setup_dev.py index 6b416df0..22d5b10c 100644 --- a/scripts/setup_dev.py +++ b/scripts/setup_dev.py @@ -3,8 +3,8 @@ """Development environment setup script""" import os -import sys import subprocess +import sys from pathlib import Path @@ -30,6 +30,7 @@ def setup_alembic(): if alembic_dir.exists(): # Remove incomplete alembic directory import shutil + shutil.rmtree(alembic_dir) if not run_command("alembic init alembic", "Initializing Alembic"): @@ -138,16 +139,23 @@ LOG_LEVEL=INFO # Set up Alembic if not setup_alembic(): - print("⚠️ Alembic setup failed. You'll need to set up database migrations manually.") + print( + "⚠️ Alembic setup failed. You'll need to set up database migrations manually." + ) return False # Create initial migration - if not run_command("alembic revision --autogenerate -m \"Initial migration\"", "Creating initial migration"): + if not run_command( + 'alembic revision --autogenerate -m "Initial migration"', + "Creating initial migration", + ): print("⚠️ Initial migration creation failed. Check your database models.") # Apply migrations if not run_command("alembic upgrade head", "Setting up database"): - print("⚠️ Database setup failed. Make sure your DATABASE_URL is correct in .env") + print( + "⚠️ Database setup failed. Make sure your DATABASE_URL is correct in .env" + ) # Run tests if not run_command("pytest", "Running tests"): @@ -157,7 +165,7 @@ LOG_LEVEL=INFO print("To start the development server, run:") print(" uvicorn main:app --reload") print("\nDatabase commands:") - print(" alembic revision --autogenerate -m \"Description\" # Create migration") + print(' alembic revision --autogenerate -m "Description" # Create migration') print(" alembic upgrade head # Apply migrations") print(" alembic current # Check status") diff --git a/tests/conftest.py b/tests/conftest.py index 0e6703ae..df3036c1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,16 +1,18 @@ # tests/conftest.py +import uuid + import pytest from fastapi.testclient import TestClient from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import StaticPool +from app.core.database import Base, get_db from main import app -from app.core.database import get_db, Base -# Import all models to ensure they're registered with Base metadata -from models.database_models import User, Product, Stock, Shop, MarketplaceImportJob, ShopProduct from middleware.auth import AuthManager -import uuid +# Import all models to ensure they're registered with Base metadata +from models.database_models import (MarketplaceImportJob, Product, Shop, + ShopProduct, Stock, User) # Use in-memory SQLite database for tests SQLALCHEMY_TEST_DATABASE_URL = "sqlite:///:memory:" @@ -23,7 +25,7 @@ def engine(): SQLALCHEMY_TEST_DATABASE_URL, connect_args={"check_same_thread": False}, poolclass=StaticPool, - echo=False # Set to True for SQL debugging + echo=False, # Set to True for SQL debugging ) @@ -89,7 +91,7 @@ def test_user(db, auth_manager): username=f"testuser_{unique_id}", hashed_password=hashed_password, role="user", - is_active=True + is_active=True, ) db.add(user) db.commit() @@ -107,7 +109,7 @@ def test_admin(db, auth_manager): username=f"admin_{unique_id}", hashed_password=hashed_password, role="admin", - is_active=True + is_active=True, ) db.add(admin) db.commit() @@ -118,10 +120,10 @@ def test_admin(db, auth_manager): @pytest.fixture def auth_headers(client, test_user): """Get authentication headers for test user""" - response = client.post("/api/v1/auth/login", json={ - "username": test_user.username, - "password": "testpass123" - }) + response = client.post( + "/api/v1/auth/login", + json={"username": test_user.username, "password": "testpass123"}, + ) assert response.status_code == 200, f"Login failed: {response.text}" token = response.json()["access_token"] return {"Authorization": f"Bearer {token}"} @@ -130,10 +132,10 @@ def auth_headers(client, test_user): @pytest.fixture def admin_headers(client, test_admin): """Get authentication headers for admin user""" - response = client.post("/api/v1/auth/login", json={ - "username": test_admin.username, - "password": "adminpass123" - }) + response = client.post( + "/api/v1/auth/login", + json={"username": test_admin.username, "password": "adminpass123"}, + ) assert response.status_code == 200, f"Admin login failed: {response.text}" token = response.json()["access_token"] return {"Authorization": f"Bearer {token}"} @@ -152,7 +154,7 @@ def test_product(db): gtin="1234567890123", availability="in stock", marketplace="Letzshop", - shop_name="TestShop" + shop_name="TestShop", ) db.add(product) db.commit() @@ -169,7 +171,7 @@ def test_shop(db, test_user): shop_name=f"Test Shop {unique_id}", owner_id=test_user.id, is_active=True, - is_verified=True + is_verified=True, ) db.add(shop) db.commit() @@ -186,7 +188,7 @@ def test_stock(db, test_product, test_shop): location=f"WAREHOUSE_A_{unique_id}", quantity=10, reserved_quantity=0, - shop_id=test_shop.id # Add shop_id reference + shop_id=test_shop.id, # Add shop_id reference ) db.add(stock) db.commit() @@ -208,7 +210,7 @@ def test_marketplace_job(db, test_shop, test_user): # Add test_shop dependency updated_count=3, total_processed=8, error_count=0, - error_message=None + error_message=None, ) db.add(job) db.commit() @@ -219,16 +221,16 @@ def test_marketplace_job(db, test_shop, test_user): # Add test_shop dependency def create_test_import_job(db, shop_id, **kwargs): # Add shop_id parameter """Helper function to create MarketplaceImportJob with defaults""" defaults = { - 'marketplace': 'test', - 'shop_name': 'Test Shop', - 'status': 'pending', - 'source_url': 'https://test.example.com/import', - 'shop_id': shop_id, # Add required shop_id - 'imported_count': 0, - 'updated_count': 0, - 'total_processed': 0, - 'error_count': 0, - 'error_message': None + "marketplace": "test", + "shop_name": "Test Shop", + "status": "pending", + "source_url": "https://test.example.com/import", + "shop_id": shop_id, # Add required shop_id + "imported_count": 0, + "updated_count": 0, + "total_processed": 0, + "error_count": 0, + "error_message": None, } defaults.update(kwargs) @@ -250,6 +252,7 @@ def cleanup(): # Add these fixtures to your existing conftest.py + @pytest.fixture def unique_product(db): """Create a unique product for tests that need isolated product data""" @@ -265,7 +268,7 @@ def unique_product(db): availability="in stock", marketplace="Letzshop", shop_name=f"UniqueShop_{unique_id}", - google_product_category=f"UniqueCategory_{unique_id}" + google_product_category=f"UniqueCategory_{unique_id}", ) db.add(product) db.commit() @@ -283,7 +286,7 @@ def unique_shop(db, test_user): description=f"A unique test shop {unique_id}", owner_id=test_user.id, is_active=True, - is_verified=True + is_verified=True, ) db.add(shop) db.commit() @@ -301,7 +304,7 @@ def other_user(db, auth_manager): username=f"otheruser_{unique_id}", hashed_password=hashed_password, role="user", - is_active=True + is_active=True, ) db.add(user) db.commit() @@ -318,7 +321,7 @@ def inactive_shop(db, other_user): shop_name=f"Inactive Shop {unique_id}", owner_id=other_user.id, is_active=False, - is_verified=False + is_verified=False, ) db.add(shop) db.commit() @@ -335,7 +338,7 @@ def verified_shop(db, other_user): shop_name=f"Verified Shop {unique_id}", owner_id=other_user.id, is_active=True, - is_verified=True + is_verified=True, ) db.add(shop) db.commit() @@ -347,16 +350,14 @@ def verified_shop(db, other_user): def shop_product(db, test_shop, unique_product): """Create a shop product relationship""" shop_product = ShopProduct( - shop_id=test_shop.id, - product_id=unique_product.id, - is_active=True + shop_id=test_shop.id, product_id=unique_product.id, is_active=True ) # Add optional fields if they exist in your model - if hasattr(ShopProduct, 'price'): + if hasattr(ShopProduct, "price"): shop_product.price = "24.99" - if hasattr(ShopProduct, 'is_featured'): + if hasattr(ShopProduct, "is_featured"): shop_product.is_featured = False - if hasattr(ShopProduct, 'stock_quantity'): + if hasattr(ShopProduct, "stock_quantity"): shop_product.stock_quantity = 10 db.add(shop_product) @@ -382,7 +383,7 @@ def multiple_products(db): marketplace=f"MultiMarket_{i % 2}", # Create 2 different marketplaces shop_name=f"MultiShop_{i}", google_product_category=f"MultiCategory_{i % 2}", # Create 2 different categories - gtin=f"1234567890{i}{unique_id[:2]}" + gtin=f"1234567890{i}{unique_id[:2]}", ) products.append(product) @@ -404,7 +405,7 @@ def multiple_stocks(db, multiple_products, test_shop): location=f"LOC_{i}", quantity=10 + (i * 5), # Different quantities reserved_quantity=i, - shop_id=test_shop.id + shop_id=test_shop.id, ) stocks.append(stock) @@ -422,12 +423,12 @@ def create_unique_product_factory(): def _create_product(db, **kwargs): unique_id = str(uuid.uuid4())[:8] defaults = { - 'product_id': f"FACTORY_{unique_id}", - 'title': f"Factory Product {unique_id}", - 'price': "15.99", - 'currency': "EUR", - 'marketplace': "TestMarket", - 'shop_name': "TestShop" + "product_id": f"FACTORY_{unique_id}", + "title": f"Factory Product {unique_id}", + "price": "15.99", + "currency": "EUR", + "marketplace": "TestMarket", + "shop_name": "TestShop", } defaults.update(kwargs) @@ -452,11 +453,11 @@ def create_unique_shop_factory(): def _create_shop(db, owner_id, **kwargs): unique_id = str(uuid.uuid4())[:8] defaults = { - 'shop_code': f"FACTORY_{unique_id}", - 'shop_name': f"Factory Shop {unique_id}", - 'owner_id': owner_id, - 'is_active': True, - 'is_verified': False + "shop_code": f"FACTORY_{unique_id}", + "shop_name": f"Factory Shop {unique_id}", + "owner_id": owner_id, + "is_active": True, + "is_verified": False, } defaults.update(kwargs) @@ -473,4 +474,3 @@ def create_unique_shop_factory(): def shop_factory(): """Fixture that provides a shop factory function""" return create_unique_shop_factory() - diff --git a/tests/test_admin.py b/tests/test_admin.py index 1288cd8e..0bf46621 100644 --- a/tests/test_admin.py +++ b/tests/test_admin.py @@ -20,11 +20,16 @@ class TestAdminAPI: response = client.get("/api/v1/admin/users", headers=auth_headers) assert response.status_code == 403 - assert "Access denied" in response.json()["detail"] or "admin" in response.json()["detail"].lower() + assert ( + "Access denied" in response.json()["detail"] + or "admin" in response.json()["detail"].lower() + ) def test_toggle_user_status_admin(self, client, admin_headers, test_user): """Test admin toggling user status""" - response = client.put(f"/api/v1/admin/users/{test_user.id}/status", headers=admin_headers) + response = client.put( + f"/api/v1/admin/users/{test_user.id}/status", headers=admin_headers + ) assert response.status_code == 200 message = response.json()["message"] @@ -39,9 +44,13 @@ class TestAdminAPI: assert response.status_code == 404 assert "User not found" in response.json()["detail"] - def test_toggle_user_status_cannot_deactivate_self(self, client, admin_headers, test_admin): + def test_toggle_user_status_cannot_deactivate_self( + self, client, admin_headers, test_admin + ): """Test that admin cannot deactivate their own account""" - response = client.put(f"/api/v1/admin/users/{test_admin.id}/status", headers=admin_headers) + response = client.put( + f"/api/v1/admin/users/{test_admin.id}/status", headers=admin_headers + ) assert response.status_code == 400 assert "Cannot deactivate your own account" in response.json()["detail"] @@ -56,7 +65,9 @@ class TestAdminAPI: assert len(data["shops"]) >= 1 # Check that test_shop is in the response - shop_codes = [shop["shop_code"] for shop in data["shops"] if "shop_code" in shop] + shop_codes = [ + shop["shop_code"] for shop in data["shops"] if "shop_code" in shop + ] assert test_shop.shop_code in shop_codes def test_get_all_shops_non_admin(self, client, auth_headers): @@ -64,11 +75,16 @@ class TestAdminAPI: response = client.get("/api/v1/admin/shops", headers=auth_headers) assert response.status_code == 403 - assert "Access denied" in response.json()["detail"] or "admin" in response.json()["detail"].lower() + assert ( + "Access denied" in response.json()["detail"] + or "admin" in response.json()["detail"].lower() + ) def test_verify_shop_admin(self, client, admin_headers, test_shop): """Test admin verifying/unverifying shop""" - response = client.put(f"/api/v1/admin/shops/{test_shop.id}/verify", headers=admin_headers) + response = client.put( + f"/api/v1/admin/shops/{test_shop.id}/verify", headers=admin_headers + ) assert response.status_code == 200 message = response.json()["message"] @@ -84,7 +100,9 @@ class TestAdminAPI: def test_toggle_shop_status_admin(self, client, admin_headers, test_shop): """Test admin toggling shop status""" - response = client.put(f"/api/v1/admin/shops/{test_shop.id}/status", headers=admin_headers) + response = client.put( + f"/api/v1/admin/shops/{test_shop.id}/status", headers=admin_headers + ) assert response.status_code == 200 message = response.json()["message"] @@ -98,9 +116,13 @@ class TestAdminAPI: assert response.status_code == 404 assert "Shop not found" in response.json()["detail"] - def test_get_marketplace_import_jobs_admin(self, client, admin_headers, test_marketplace_job): + def test_get_marketplace_import_jobs_admin( + self, client, admin_headers, test_marketplace_job + ): """Test admin getting marketplace import jobs""" - response = client.get("/api/v1/admin/marketplace-import-jobs", headers=admin_headers) + response = client.get( + "/api/v1/admin/marketplace-import-jobs", headers=admin_headers + ) assert response.status_code == 200 data = response.json() @@ -110,43 +132,58 @@ class TestAdminAPI: job_ids = [job["job_id"] for job in data if "job_id" in job] assert test_marketplace_job.id in job_ids - def test_get_marketplace_import_jobs_with_filters(self, client, admin_headers, test_marketplace_job): + def test_get_marketplace_import_jobs_with_filters( + self, client, admin_headers, test_marketplace_job + ): """Test admin getting marketplace import jobs with filters""" response = client.get( "/api/v1/admin/marketplace-import-jobs", params={"marketplace": test_marketplace_job.marketplace}, - headers=admin_headers + headers=admin_headers, ) assert response.status_code == 200 data = response.json() assert len(data) >= 1 - assert all(job["marketplace"] == test_marketplace_job.marketplace for job in data) + assert all( + job["marketplace"] == test_marketplace_job.marketplace for job in data + ) def test_get_marketplace_import_jobs_non_admin(self, client, auth_headers): """Test non-admin trying to access marketplace import jobs""" - response = client.get("/api/v1/admin/marketplace-import-jobs", headers=auth_headers) + response = client.get( + "/api/v1/admin/marketplace-import-jobs", headers=auth_headers + ) assert response.status_code == 403 - assert "Access denied" in response.json()["detail"] or "admin" in response.json()["detail"].lower() + assert ( + "Access denied" in response.json()["detail"] + or "admin" in response.json()["detail"].lower() + ) def test_admin_pagination_users(self, client, admin_headers, test_user, test_admin): """Test user pagination works correctly""" # Test first page - response = client.get("/api/v1/admin/users?skip=0&limit=1", headers=admin_headers) + response = client.get( + "/api/v1/admin/users?skip=0&limit=1", headers=admin_headers + ) assert response.status_code == 200 data = response.json() assert len(data) == 1 # Test second page - response = client.get("/api/v1/admin/users?skip=1&limit=1", headers=admin_headers) + response = client.get( + "/api/v1/admin/users?skip=1&limit=1", headers=admin_headers + ) assert response.status_code == 200 data = response.json() assert len(data) >= 0 # Could be 1 or 0 depending on total users def test_admin_pagination_shops(self, client, admin_headers, test_shop): """Test shop pagination works correctly""" - response = client.get("/api/v1/admin/shops?skip=0&limit=1", headers=admin_headers) + response = client.get( + "/api/v1/admin/shops?skip=0&limit=1", headers=admin_headers + ) assert response.status_code == 200 data = response.json() assert data["total"] >= 1 diff --git a/tests/test_admin_service.py b/tests/test_admin_service.py index 3a4228ba..f5574f21 100644 --- a/tests/test_admin_service.py +++ b/tests/test_admin_service.py @@ -1,10 +1,11 @@ # tests/test_admin_service.py -import pytest from datetime import datetime + +import pytest from fastapi import HTTPException from app.services.admin_service import AdminService -from models.database_models import User, Shop, MarketplaceImportJob +from models.database_models import MarketplaceImportJob, Shop, User class TestAdminService: @@ -91,7 +92,7 @@ class TestAdminService: shop_name="Test Shop 2", owner_id=test_shop.owner_id, is_active=True, - is_verified=False + is_verified=False, ) db.add(additional_shop) db.commit() @@ -173,13 +174,17 @@ class TestAdminService: assert len(result) >= 1 # Find our test job in the results - test_job = next((job for job in result if job.job_id == test_marketplace_job.id), None) + test_job = next( + (job for job in result if job.job_id == test_marketplace_job.id), None + ) assert test_job is not None assert test_job.marketplace == test_marketplace_job.marketplace assert test_job.shop_name == test_marketplace_job.shop_name assert test_job.status == test_marketplace_job.status - def test_get_marketplace_import_jobs_with_marketplace_filter(self, db, test_marketplace_job, test_user, test_shop): + def test_get_marketplace_import_jobs_with_marketplace_filter( + self, db, test_marketplace_job, test_user, test_shop + ): """Test getting marketplace import jobs filtered by marketplace""" # Create additional job with different marketplace other_job = MarketplaceImportJob( @@ -188,20 +193,24 @@ class TestAdminService: status="completed", source_url="https://ebay.example.com/import", shop_id=test_shop.id, - user_id=test_user.id # Fixed: Added missing user_id + user_id=test_user.id, # Fixed: Added missing user_id ) db.add(other_job) db.commit() # Filter by the test marketplace job's marketplace - result = self.service.get_marketplace_import_jobs(db, marketplace=test_marketplace_job.marketplace) + result = self.service.get_marketplace_import_jobs( + db, marketplace=test_marketplace_job.marketplace + ) assert len(result) >= 1 # All results should match the marketplace filter for job in result: assert test_marketplace_job.marketplace.lower() in job.marketplace.lower() - def test_get_marketplace_import_jobs_with_shop_filter(self, db, test_marketplace_job, test_user, test_shop): + def test_get_marketplace_import_jobs_with_shop_filter( + self, db, test_marketplace_job, test_user, test_shop + ): """Test getting marketplace import jobs filtered by shop name""" # Create additional job with different shop name other_job = MarketplaceImportJob( @@ -210,20 +219,24 @@ class TestAdminService: status="completed", source_url="https://different.example.com/import", shop_id=test_shop.id, - user_id=test_user.id # Fixed: Added missing user_id + user_id=test_user.id, # Fixed: Added missing user_id ) db.add(other_job) db.commit() # Filter by the test marketplace job's shop name - result = self.service.get_marketplace_import_jobs(db, shop_name=test_marketplace_job.shop_name) + result = self.service.get_marketplace_import_jobs( + db, shop_name=test_marketplace_job.shop_name + ) assert len(result) >= 1 # All results should match the shop name filter for job in result: assert test_marketplace_job.shop_name.lower() in job.shop_name.lower() - def test_get_marketplace_import_jobs_with_status_filter(self, db, test_marketplace_job, test_user, test_shop): + def test_get_marketplace_import_jobs_with_status_filter( + self, db, test_marketplace_job, test_user, test_shop + ): """Test getting marketplace import jobs filtered by status""" # Create additional job with different status other_job = MarketplaceImportJob( @@ -232,20 +245,24 @@ class TestAdminService: status="pending", source_url="https://pending.example.com/import", shop_id=test_shop.id, - user_id=test_user.id # Fixed: Added missing user_id + user_id=test_user.id, # Fixed: Added missing user_id ) db.add(other_job) db.commit() # Filter by the test marketplace job's status - result = self.service.get_marketplace_import_jobs(db, status=test_marketplace_job.status) + result = self.service.get_marketplace_import_jobs( + db, status=test_marketplace_job.status + ) assert len(result) >= 1 # All results should match the status filter for job in result: assert job.status == test_marketplace_job.status - def test_get_marketplace_import_jobs_with_multiple_filters(self, db, test_marketplace_job, test_shop, test_user): + def test_get_marketplace_import_jobs_with_multiple_filters( + self, db, test_marketplace_job, test_shop, test_user + ): """Test getting marketplace import jobs with multiple filters""" # Create jobs that don't match all filters non_matching_job1 = MarketplaceImportJob( @@ -254,7 +271,7 @@ class TestAdminService: status=test_marketplace_job.status, source_url="https://non-matching1.example.com/import", shop_id=test_shop.id, - user_id=test_user.id # Fixed: Added missing user_id + user_id=test_user.id, # Fixed: Added missing user_id ) non_matching_job2 = MarketplaceImportJob( marketplace=test_marketplace_job.marketplace, @@ -262,7 +279,7 @@ class TestAdminService: status=test_marketplace_job.status, source_url="https://non-matching2.example.com/import", shop_id=test_shop.id, - user_id=test_user.id # Fixed: Added missing user_id + user_id=test_user.id, # Fixed: Added missing user_id ) db.add_all([non_matching_job1, non_matching_job2]) db.commit() @@ -272,12 +289,14 @@ class TestAdminService: db, marketplace=test_marketplace_job.marketplace, shop_name=test_marketplace_job.shop_name, - status=test_marketplace_job.status + status=test_marketplace_job.status, ) assert len(result) >= 1 # Find our test job in the results - test_job = next((job for job in result if job.job_id == test_marketplace_job.id), None) + test_job = next( + (job for job in result if job.job_id == test_marketplace_job.id), None + ) assert test_job is not None assert test_job.marketplace == test_marketplace_job.marketplace assert test_job.shop_name == test_marketplace_job.shop_name @@ -297,7 +316,7 @@ class TestAdminService: updated_count=None, total_processed=None, error_count=None, - error_message=None + error_message=None, ) db.add(job) db.commit() diff --git a/tests/test_auth.py b/tests/test_auth.py index dc94f344..96aef38e 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -6,11 +6,14 @@ from fastapi import HTTPException class TestAuthenticationAPI: def test_register_user_success(self, client, db): """Test successful user registration""" - response = client.post("/api/v1/auth/register", json={ - "email": "newuser@example.com", - "username": "newuser", - "password": "securepass123" - }) + response = client.post( + "/api/v1/auth/register", + json={ + "email": "newuser@example.com", + "username": "newuser", + "password": "securepass123", + }, + ) assert response.status_code == 200 data = response.json() @@ -22,32 +25,38 @@ class TestAuthenticationAPI: def test_register_user_duplicate_email(self, client, test_user): """Test registration with duplicate email""" - response = client.post("/api/v1/auth/register", json={ - "email": test_user.email, # Same as test_user - "username": "newuser", - "password": "securepass123" - }) + response = client.post( + "/api/v1/auth/register", + json={ + "email": test_user.email, # Same as test_user + "username": "newuser", + "password": "securepass123", + }, + ) assert response.status_code == 400 assert "Email already registered" in response.json()["detail"] def test_register_user_duplicate_username(self, client, test_user): """Test registration with duplicate username""" - response = client.post("/api/v1/auth/register", json={ - "email": "new@example.com", - "username": test_user.username, # Same as test_user - "password": "securepass123" - }) + response = client.post( + "/api/v1/auth/register", + json={ + "email": "new@example.com", + "username": test_user.username, # Same as test_user + "password": "securepass123", + }, + ) assert response.status_code == 400 assert "Username already taken" in response.json()["detail"] def test_login_success(self, client, test_user): """Test successful login""" - response = client.post("/api/v1/auth/login", json={ - "username": test_user.username, - "password": "testpass123" - }) + response = client.post( + "/api/v1/auth/login", + json={"username": test_user.username, "password": "testpass123"}, + ) assert response.status_code == 200 data = response.json() @@ -58,20 +67,20 @@ class TestAuthenticationAPI: def test_login_wrong_password(self, client, test_user): """Test login with wrong password""" - response = client.post("/api/v1/auth/login", json={ - "username": "testuser", - "password": "wrongpassword" - }) + response = client.post( + "/api/v1/auth/login", + json={"username": "testuser", "password": "wrongpassword"}, + ) assert response.status_code == 401 assert "Incorrect username or password" in response.json()["detail"] def test_login_nonexistent_user(self, client, db): # Added db fixture """Test login with nonexistent user""" - response = client.post("/api/v1/auth/login", json={ - "username": "nonexistent", - "password": "password123" - }) + response = client.post( + "/api/v1/auth/login", + json={"username": "nonexistent", "password": "password123"}, + ) assert response.status_code == 401 diff --git a/tests/test_auth_service.py b/tests/test_auth_service.py index 72852548..291a388f 100644 --- a/tests/test_auth_service.py +++ b/tests/test_auth_service.py @@ -3,8 +3,8 @@ import pytest from fastapi import HTTPException from app.services.auth_service import AuthService +from models.api_models import UserLogin, UserRegister from models.database_models import User -from models.api_models import UserRegister, UserLogin class TestAuthService: @@ -17,9 +17,7 @@ class TestAuthService: def test_register_user_success(self, db): """Test successful user registration""" user_data = UserRegister( - email="newuser@example.com", - username="newuser123", - password="securepass123" + email="newuser@example.com", username="newuser123", password="securepass123" ) user = self.service.register_user(db, user_data) @@ -36,7 +34,7 @@ class TestAuthService: user_data = UserRegister( email=test_user.email, # Use existing email username="differentuser", - password="securepass123" + password="securepass123", ) with pytest.raises(HTTPException) as exc_info: @@ -50,7 +48,7 @@ class TestAuthService: user_data = UserRegister( email="different@example.com", username=test_user.username, # Use existing username - password="securepass123" + password="securepass123", ) with pytest.raises(HTTPException) as exc_info: @@ -62,8 +60,7 @@ class TestAuthService: def test_login_user_success(self, db, test_user): """Test successful user login""" user_credentials = UserLogin( - username=test_user.username, - password="testpass123" + username=test_user.username, password="testpass123" ) result = self.service.login_user(db, user_credentials) @@ -78,10 +75,7 @@ class TestAuthService: def test_login_user_wrong_username(self, db): """Test login fails with wrong username""" - user_credentials = UserLogin( - username="nonexistentuser", - password="testpass123" - ) + user_credentials = UserLogin(username="nonexistentuser", password="testpass123") with pytest.raises(HTTPException) as exc_info: self.service.login_user(db, user_credentials) @@ -92,8 +86,7 @@ class TestAuthService: def test_login_user_wrong_password(self, db, test_user): """Test login fails with wrong password""" user_credentials = UserLogin( - username=test_user.username, - password="wrongpassword" + username=test_user.username, password="wrongpassword" ) with pytest.raises(HTTPException) as exc_info: @@ -109,8 +102,7 @@ class TestAuthService: db.commit() user_credentials = UserLogin( - username=test_user.username, - password="testpass123" + username=test_user.username, password="testpass123" ) with pytest.raises(HTTPException) as exc_info: diff --git a/tests/test_background_tasks.py b/tests/test_background_tasks.py index e9640172..5edf0153 100644 --- a/tests/test_background_tasks.py +++ b/tests/test_background_tasks.py @@ -1,9 +1,11 @@ # tests/test_background_tasks.py +from datetime import datetime +from unittest.mock import AsyncMock, patch + import pytest -from unittest.mock import patch, AsyncMock + from app.tasks.background_tasks import process_marketplace_import from models.database_models import MarketplaceImportJob -from datetime import datetime class TestBackgroundTasks: @@ -17,7 +19,7 @@ class TestBackgroundTasks: shop_name="TESTSHOP", marketplace="TestMarket", shop_id=test_shop.id, - user_id=test_user.id + user_id=test_user.id, ) db.add(job) db.commit() @@ -27,29 +29,30 @@ class TestBackgroundTasks: job_id = job.id # Mock CSV processor and prevent session from closing - with patch('app.tasks.background_tasks.CSVProcessor') as mock_processor, \ - patch('app.tasks.background_tasks.SessionLocal', return_value=db): + with patch("app.tasks.background_tasks.CSVProcessor") as mock_processor, patch( + "app.tasks.background_tasks.SessionLocal", return_value=db + ): mock_instance = mock_processor.return_value - mock_instance.process_marketplace_csv_from_url = AsyncMock(return_value={ - "imported": 10, - "updated": 5, - "total_processed": 15, - "errors": 0 - }) + mock_instance.process_marketplace_csv_from_url = AsyncMock( + return_value={ + "imported": 10, + "updated": 5, + "total_processed": 15, + "errors": 0, + } + ) # Run background task await process_marketplace_import( - job_id, - "http://example.com/test.csv", - "TestMarket", - "TESTSHOP", - 1000 + job_id, "http://example.com/test.csv", "TestMarket", "TESTSHOP", 1000 ) # Re-query the job using the stored ID - updated_job = db.query(MarketplaceImportJob).filter( - MarketplaceImportJob.id == job_id - ).first() + updated_job = ( + db.query(MarketplaceImportJob) + .filter(MarketplaceImportJob.id == job_id) + .first() + ) assert updated_job is not None assert updated_job.status == "completed" @@ -70,7 +73,7 @@ class TestBackgroundTasks: shop_name="TESTSHOP", marketplace="TestMarket", shop_id=test_shop.id, - user_id=test_user.id + user_id=test_user.id, ) db.add(job) db.commit() @@ -80,8 +83,9 @@ class TestBackgroundTasks: job_id = job.id # Mock CSV processor to raise exception - with patch('app.tasks.background_tasks.CSVProcessor') as mock_processor, \ - patch('app.tasks.background_tasks.SessionLocal', return_value=db): + with patch("app.tasks.background_tasks.CSVProcessor") as mock_processor, patch( + "app.tasks.background_tasks.SessionLocal", return_value=db + ): mock_instance = mock_processor.return_value mock_instance.process_marketplace_csv_from_url = AsyncMock( @@ -96,7 +100,7 @@ class TestBackgroundTasks: "http://example.com/test.csv", "TestMarket", "TESTSHOP", - 1000 + 1000, ) except Exception: # The background task should handle exceptions internally @@ -104,9 +108,11 @@ class TestBackgroundTasks: pass # Re-query the job using the stored ID - updated_job = db.query(MarketplaceImportJob).filter( - MarketplaceImportJob.id == job_id - ).first() + updated_job = ( + db.query(MarketplaceImportJob) + .filter(MarketplaceImportJob.id == job_id) + .first() + ) assert updated_job is not None assert updated_job.status == "failed" @@ -115,15 +121,18 @@ class TestBackgroundTasks: @pytest.mark.asyncio async def test_marketplace_import_job_not_found(self, db): """Test handling when import job doesn't exist""" - with patch('app.tasks.background_tasks.CSVProcessor') as mock_processor, \ - patch('app.tasks.background_tasks.SessionLocal', return_value=db): + with patch("app.tasks.background_tasks.CSVProcessor") as mock_processor, patch( + "app.tasks.background_tasks.SessionLocal", return_value=db + ): mock_instance = mock_processor.return_value - mock_instance.process_marketplace_csv_from_url = AsyncMock(return_value={ - "imported": 10, - "updated": 5, - "total_processed": 15, - "errors": 0 - }) + mock_instance.process_marketplace_csv_from_url = AsyncMock( + return_value={ + "imported": 10, + "updated": 5, + "total_processed": 15, + "errors": 0, + } + ) # Run background task with non-existent job ID await process_marketplace_import( @@ -131,7 +140,7 @@ class TestBackgroundTasks: "http://example.com/test.csv", "TestMarket", "TESTSHOP", - 1000 + 1000, ) # Should not raise an exception, just log and return @@ -148,7 +157,7 @@ class TestBackgroundTasks: shop_name="TESTSHOP", marketplace="TestMarket", shop_id=test_shop.id, - user_id=test_user.id + user_id=test_user.id, ) db.add(job) db.commit() @@ -158,29 +167,30 @@ class TestBackgroundTasks: job_id = job.id # Mock CSV processor with some errors - with patch('app.tasks.background_tasks.CSVProcessor') as mock_processor, \ - patch('app.tasks.background_tasks.SessionLocal', return_value=db): + with patch("app.tasks.background_tasks.CSVProcessor") as mock_processor, patch( + "app.tasks.background_tasks.SessionLocal", return_value=db + ): mock_instance = mock_processor.return_value - mock_instance.process_marketplace_csv_from_url = AsyncMock(return_value={ - "imported": 8, - "updated": 5, - "total_processed": 15, - "errors": 2 - }) + mock_instance.process_marketplace_csv_from_url = AsyncMock( + return_value={ + "imported": 8, + "updated": 5, + "total_processed": 15, + "errors": 2, + } + ) # Run background task await process_marketplace_import( - job_id, - "http://example.com/test.csv", - "TestMarket", - "TESTSHOP", - 1000 + job_id, "http://example.com/test.csv", "TestMarket", "TESTSHOP", 1000 ) # Re-query the job using the stored ID - updated_job = db.query(MarketplaceImportJob).filter( - MarketplaceImportJob.id == job_id - ).first() + updated_job = ( + db.query(MarketplaceImportJob) + .filter(MarketplaceImportJob.id == job_id) + .first() + ) assert updated_job is not None assert updated_job.status == "completed_with_errors" diff --git a/tests/test_csv_processor.py b/tests/test_csv_processor.py index e17f90ca..c1401cc8 100644 --- a/tests/test_csv_processor.py +++ b/tests/test_csv_processor.py @@ -1,9 +1,11 @@ # tests/test_csv_processor.py +from unittest.mock import Mock, patch + +import pandas as pd import pytest import requests import requests.exceptions -from unittest.mock import Mock, patch -import pandas as pd + from utils.csv_processor import CSVProcessor @@ -11,7 +13,7 @@ class TestCSVProcessor: def setup_method(self): self.processor = CSVProcessor() - @patch('requests.get') + @patch("requests.get") def test_download_csv_encoding_fallback(self, mock_get): """Test CSV download with encoding fallback""" # Create content with special characters that would fail UTF-8 if not properly encoded @@ -20,7 +22,7 @@ class TestCSVProcessor: mock_response = Mock() mock_response.status_code = 200 # Use latin-1 encoding which your method should try - mock_response.content = special_content.encode('latin-1') + mock_response.content = special_content.encode("latin-1") mock_response.raise_for_status.return_value = None mock_get.return_value = mock_response @@ -30,14 +32,16 @@ class TestCSVProcessor: assert isinstance(csv_content, str) assert "Café Product" in csv_content - @patch('requests.get') + @patch("requests.get") def test_download_csv_encoding_ignore_fallback(self, mock_get): """Test CSV download falls back to UTF-8 with error ignoring""" # Create problematic bytes that would fail most encoding attempts mock_response = Mock() mock_response.status_code = 200 # Create bytes that will fail most encodings - mock_response.content = b"product_id,title,price\nTEST001,\xff\xfe Product,10.99" + mock_response.content = ( + b"product_id,title,price\nTEST001,\xff\xfe Product,10.99" + ) mock_response.raise_for_status.return_value = None mock_get.return_value = mock_response @@ -49,7 +53,7 @@ class TestCSVProcessor: assert "product_id,title,price" in csv_content assert "TEST001" in csv_content - @patch('requests.get') + @patch("requests.get") def test_download_csv_request_exception(self, mock_get): """Test CSV download with request exception""" mock_get.side_effect = requests.exceptions.RequestException("Connection error") @@ -57,18 +61,20 @@ class TestCSVProcessor: with pytest.raises(requests.exceptions.RequestException): self.processor.download_csv("http://example.com/test.csv") - @patch('requests.get') + @patch("requests.get") def test_download_csv_http_error(self, mock_get): """Test CSV download with HTTP error""" mock_response = Mock() mock_response.status_code = 404 - mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("404 Not Found") + mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError( + "404 Not Found" + ) mock_get.return_value = mock_response with pytest.raises(requests.exceptions.HTTPError): self.processor.download_csv("http://example.com/nonexistent.csv") - @patch('requests.get') + @patch("requests.get") def test_download_csv_failure(self, mock_get): """Test CSV download failure""" # Mock failed HTTP response @@ -95,26 +101,24 @@ TEST002,Test Product 2,15.99,TestMarket""" @pytest.mark.asyncio async def test_process_marketplace_csv_from_url(self, db): """Test complete marketplace CSV processing""" - with patch.object(self.processor, 'download_csv') as mock_download, \ - patch.object(self.processor, 'parse_csv') as mock_parse: + with patch.object( + self.processor, "download_csv" + ) as mock_download, patch.object(self.processor, "parse_csv") as mock_parse: # Mock successful download and parsing mock_download.return_value = "csv_content" - mock_df = pd.DataFrame({ - "product_id": ["TEST001", "TEST002"], - "title": ["Product 1", "Product 2"], - "price": ["10.99", "15.99"], - "marketplace": ["TestMarket", "TestMarket"], - "shop_name": ["TestShop", "TestShop"] - }) + mock_df = pd.DataFrame( + { + "product_id": ["TEST001", "TEST002"], + "title": ["Product 1", "Product 2"], + "price": ["10.99", "15.99"], + "marketplace": ["TestMarket", "TestMarket"], + "shop_name": ["TestShop", "TestShop"], + } + ) mock_parse.return_value = mock_df - result = await self.processor.process_marketplace_csv_from_url( - "http://example.com/test.csv", - "TestMarket", - "TestShop", - 1000, - db + "http://example.com/test.csv", "TestMarket", "TestShop", 1000, db ) assert "imported" in result diff --git a/tests/test_data_validation.py b/tests/test_data_validation.py index a94807db..2b39133e 100644 --- a/tests/test_data_validation.py +++ b/tests/test_data_validation.py @@ -1,5 +1,6 @@ # tests/test_data_validation.py import pytest + from utils.data_processing import GTINProcessor, PriceProcessor diff --git a/tests/test_database.py b/tests/test_database.py index 9a898415..f2894a77 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -1,7 +1,8 @@ # tests/test_database.py import pytest from sqlalchemy import text -from models.database_models import User, Product, Stock, Shop + +from models.database_models import Product, Shop, Stock, User class TestDatabaseModels: @@ -12,7 +13,7 @@ class TestDatabaseModels: username="dbtest", hashed_password="hashed_password_123", role="user", - is_active=True + is_active=True, ) db.add(user) @@ -36,7 +37,7 @@ class TestDatabaseModels: gtin="1234567890123", availability="in stock", marketplace="TestDB", - shop_name="DBTestShop" + shop_name="DBTestShop", ) db.add(product) @@ -49,11 +50,7 @@ class TestDatabaseModels: def test_stock_model(self, db): """Test Stock model creation""" - stock = Stock( - gtin="1234567890123", - location="DB_WAREHOUSE", - quantity=150 - ) + stock = Stock(gtin="1234567890123", location="DB_WAREHOUSE", quantity=150) db.add(stock) db.commit() @@ -72,7 +69,7 @@ class TestDatabaseModels: description="Testing shop model", owner_id=test_user.id, is_active=True, - is_verified=False + is_verified=False, ) db.add(shop) diff --git a/tests/test_error_handling.py b/tests/test_error_handling.py index 88aff8c8..66951aa0 100644 --- a/tests/test_error_handling.py +++ b/tests/test_error_handling.py @@ -5,24 +5,25 @@ import pytest class TestErrorHandling: def test_invalid_json(self, client, auth_headers): """Test handling of invalid JSON""" - response = client.post("/api/v1/product", - headers=auth_headers, - content="invalid json") + response = client.post( + "/api/v1/product", headers=auth_headers, content="invalid json" + ) assert response.status_code == 422 # Validation error def test_missing_required_fields(self, client, auth_headers): """Test handling of missing required fields""" - response = client.post("/api/v1/product", - headers=auth_headers, - json={"title": "Test"}) # Missing product_id + response = client.post( + "/api/v1/product", headers=auth_headers, json={"title": "Test"} + ) # Missing product_id assert response.status_code == 422 def test_invalid_authentication(self, client): """Test handling of invalid authentication""" - response = client.get("/api/v1/product", - headers={"Authorization": "Bearer invalid_token"}) + response = client.get( + "/api/v1/product", headers={"Authorization": "Bearer invalid_token"} + ) assert response.status_code == 401 # Token is not valid @@ -38,8 +39,10 @@ class TestErrorHandling: """Test handling of duplicate resource creation""" product_data = { "product_id": test_product.product_id, # Duplicate ID - "title": "Another Product" + "title": "Another Product", } - response = client.post("/api/v1/product", headers=auth_headers, json=product_data) + response = client.post( + "/api/v1/product", headers=auth_headers, json=product_data + ) assert response.status_code == 400 diff --git a/tests/test_export.py b/tests/test_export.py index 14255c7e..a1dd0d89 100644 --- a/tests/test_export.py +++ b/tests/test_export.py @@ -1,8 +1,9 @@ # tests/test_export.py -import pytest import csv from io import StringIO +import pytest + from models.database_models import Product @@ -15,7 +16,7 @@ class TestExportFunctionality: assert response.headers["content-type"] == "text/csv; charset=utf-8" # Parse CSV content - csv_content = response.content.decode('utf-8') + csv_content = response.content.decode("utf-8") csv_reader = csv.reader(StringIO(csv_content)) # Check header row @@ -35,10 +36,12 @@ class TestExportFunctionality: db.add_all(products) db.commit() - response = client.get("/api/v1/export-csv?marketplace=Amazon", headers=auth_headers) + response = client.get( + "/api/v1/export-csv?marketplace=Amazon", headers=auth_headers + ) assert response.status_code == 200 - csv_content = response.content.decode('utf-8') + csv_content = response.content.decode("utf-8") assert "EXP1" in csv_content assert "EXP2" not in csv_content # Should be filtered out @@ -50,7 +53,7 @@ class TestExportFunctionality: product = Product( product_id=f"PERF{i:04d}", title=f"Performance Product {i}", - marketplace="Performance" + marketplace="Performance", ) products.append(product) @@ -58,10 +61,10 @@ class TestExportFunctionality: db.commit() import time + start_time = time.time() response = client.get("/api/v1/export-csv", headers=auth_headers) end_time = time.time() assert response.status_code == 200 assert end_time - start_time < 10.0 # Should complete within 10 seconds - \ No newline at end of file diff --git a/tests/test_filtering.py b/tests/test_filtering.py index 0afe686b..fbfd8623 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -1,5 +1,6 @@ # tests/test_filtering.py import pytest + from models.database_models import Product @@ -39,7 +40,9 @@ class TestFiltering: db.add_all(products) db.commit() - response = client.get("/api/v1/product?marketplace=Amazon", headers=auth_headers) + response = client.get( + "/api/v1/product?marketplace=Amazon", headers=auth_headers + ) assert response.status_code == 200 data = response.json() assert data["total"] == 2 @@ -47,9 +50,17 @@ class TestFiltering: def test_product_search_filter(self, client, auth_headers, db): """Test searching products by text""" products = [ - Product(product_id="SEARCH1", title="Apple iPhone", description="Smartphone"), - Product(product_id="SEARCH2", title="Samsung Galaxy", description="Android phone"), - Product(product_id="SEARCH3", title="iPad Tablet", description="Apple tablet"), + Product( + product_id="SEARCH1", title="Apple iPhone", description="Smartphone" + ), + Product( + product_id="SEARCH2", + title="Samsung Galaxy", + description="Android phone", + ), + Product( + product_id="SEARCH3", title="iPad Tablet", description="Apple tablet" + ), ] db.add_all(products) @@ -70,16 +81,33 @@ class TestFiltering: def test_combined_filters(self, client, auth_headers, db): """Test combining multiple filters""" products = [ - Product(product_id="COMBO1", title="Apple iPhone", brand="Apple", marketplace="Amazon"), - Product(product_id="COMBO2", title="Apple iPad", brand="Apple", marketplace="eBay"), - Product(product_id="COMBO3", title="Samsung Phone", brand="Samsung", marketplace="Amazon"), + Product( + product_id="COMBO1", + title="Apple iPhone", + brand="Apple", + marketplace="Amazon", + ), + Product( + product_id="COMBO2", + title="Apple iPad", + brand="Apple", + marketplace="eBay", + ), + Product( + product_id="COMBO3", + title="Samsung Phone", + brand="Samsung", + marketplace="Amazon", + ), ] db.add_all(products) db.commit() # Filter by brand AND marketplace - response = client.get("/api/v1/product?brand=Apple&marketplace=Amazon", headers=auth_headers) + response = client.get( + "/api/v1/product?brand=Apple&marketplace=Amazon", headers=auth_headers + ) assert response.status_code == 200 data = response.json() assert data["total"] == 1 # Only iPhone matches both diff --git a/tests/test_integration.py b/tests/test_integration.py index dd7e65d7..fc82ced3 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -14,10 +14,12 @@ class TestIntegrationFlows: "brand": "FlowBrand", "gtin": "1111222233334", "availability": "in stock", - "marketplace": "TestFlow" + "marketplace": "TestFlow", } - response = client.post("/api/v1/product", headers=auth_headers, json=product_data) + response = client.post( + "/api/v1/product", headers=auth_headers, json=product_data + ) assert response.status_code == 200 product = response.json() @@ -25,26 +27,33 @@ class TestIntegrationFlows: stock_data = { "gtin": product["gtin"], "location": "MAIN_WAREHOUSE", - "quantity": 50 + "quantity": 50, } response = client.post("/api/v1/stock", headers=auth_headers, json=stock_data) assert response.status_code == 200 # 3. Get product with stock info - response = client.get(f"/api/v1/product/{product['product_id']}", headers=auth_headers) + response = client.get( + f"/api/v1/product/{product['product_id']}", headers=auth_headers + ) assert response.status_code == 200 product_detail = response.json() assert product_detail["stock_info"]["total_quantity"] == 50 # 4. Update product update_data = {"title": "Updated Integration Test Product"} - response = client.put(f"/api/v1/product/{product['product_id']}", - headers=auth_headers, json=update_data) + response = client.put( + f"/api/v1/product/{product['product_id']}", + headers=auth_headers, + json=update_data, + ) assert response.status_code == 200 # 5. Search for product - response = client.get("/api/v1/product?search=Updated Integration", headers=auth_headers) + response = client.get( + "/api/v1/product?search=Updated Integration", headers=auth_headers + ) assert response.status_code == 200 assert response.json()["total"] == 1 @@ -54,7 +63,7 @@ class TestIntegrationFlows: shop_data = { "shop_code": "FLOWSHOP", "shop_name": "Integration Flow Shop", - "description": "Test shop for integration" + "description": "Test shop for integration", } response = client.post("/api/v1/shop", headers=auth_headers, json=shop_data) @@ -66,10 +75,12 @@ class TestIntegrationFlows: "product_id": "SHOPFLOW001", "title": "Shop Flow Product", "price": "15.99", - "marketplace": "ShopFlow" + "marketplace": "ShopFlow", } - response = client.post("/api/v1/product", headers=auth_headers, json=product_data) + response = client.post( + "/api/v1/product", headers=auth_headers, json=product_data + ) assert response.status_code == 200 product = response.json() @@ -86,28 +97,28 @@ class TestIntegrationFlows: location = "TEST_WAREHOUSE" # 1. Set initial stock - response = client.post("/api/v1/stock", headers=auth_headers, json={ - "gtin": gtin, - "location": location, - "quantity": 100 - }) + response = client.post( + "/api/v1/stock", + headers=auth_headers, + json={"gtin": gtin, "location": location, "quantity": 100}, + ) assert response.status_code == 200 # 2. Add more stock - response = client.post("/api/v1/stock/add", headers=auth_headers, json={ - "gtin": gtin, - "location": location, - "quantity": 25 - }) + response = client.post( + "/api/v1/stock/add", + headers=auth_headers, + json={"gtin": gtin, "location": location, "quantity": 25}, + ) assert response.status_code == 200 assert response.json()["quantity"] == 125 # 3. Remove some stock - response = client.post("/api/v1/stock/remove", headers=auth_headers, json={ - "gtin": gtin, - "location": location, - "quantity": 30 - }) + response = client.post( + "/api/v1/stock/remove", + headers=auth_headers, + json={"gtin": gtin, "location": location, "quantity": 30}, + ) assert response.status_code == 200 assert response.json()["quantity"] == 95 diff --git a/tests/test_marketplace.py b/tests/test_marketplace.py index 1f473739..446ef3d3 100644 --- a/tests/test_marketplace.py +++ b/tests/test_marketplace.py @@ -1,6 +1,7 @@ # tests/test_marketplace.py +from unittest.mock import AsyncMock, patch + import pytest -from unittest.mock import patch, AsyncMock class TestMarketplaceAPI: @@ -10,11 +11,12 @@ class TestMarketplaceAPI: import_data = { "url": "https://example.com/products.csv", "marketplace": "TestMarket", - "shop_code": test_shop.shop_code + "shop_code": test_shop.shop_code, } - response = client.post("/api/v1/marketplace/import-product", - headers=auth_headers, json=import_data) + response = client.post( + "/api/v1/marketplace/import-product", headers=auth_headers, json=import_data + ) assert response.status_code == 200 data = response.json() @@ -29,11 +31,12 @@ class TestMarketplaceAPI: import_data = { "url": "https://example.com/products.csv", "marketplace": "TestMarket", - "shop_code": "NONEXISTENT" + "shop_code": "NONEXISTENT", } - response = client.post("/api/v1/marketplace/import-product", - headers=auth_headers, json=import_data) + response = client.post( + "/api/v1/marketplace/import-product", headers=auth_headers, json=import_data + ) assert response.status_code == 404 assert "Shop not found" in response.json()["detail"] @@ -49,4 +52,3 @@ class TestMarketplaceAPI: """Test that marketplace endpoints require authentication""" response = client.get("/api/v1/marketplace/import-jobs") assert response.status_code == 401 # No authorization header - diff --git a/tests/test_marketplace_service.py b/tests/test_marketplace_service.py index 236887ac..e3652dae 100644 --- a/tests/test_marketplace_service.py +++ b/tests/test_marketplace_service.py @@ -1,10 +1,12 @@ # tests/test_marketplace_service.py -import pytest import uuid +from datetime import datetime + +import pytest + from app.services.marketplace_service import MarketplaceService from models.api_models import MarketplaceImportRequest from models.database_models import MarketplaceImportJob, Shop, User -from datetime import datetime class TestMarketplaceService: @@ -22,7 +24,9 @@ class TestMarketplaceService: assert result.shop_code == test_shop.shop_code assert result.owner_id == test_user.id - def test_validate_shop_access_admin_can_access_any_shop(self, db, test_shop, test_admin): + def test_validate_shop_access_admin_can_access_any_shop( + self, db, test_shop, test_admin + ): """Test that admin users can access any shop""" result = self.service.validate_shop_access(db, test_shop.shop_code, test_admin) @@ -33,7 +37,9 @@ class TestMarketplaceService: with pytest.raises(ValueError, match="Shop not found"): self.service.validate_shop_access(db, "NONEXISTENT", test_user) - def test_validate_shop_access_permission_denied(self, db, test_shop, test_user, other_user): + def test_validate_shop_access_permission_denied( + self, db, test_shop, test_user, other_user + ): """Test shop access validation when user doesn't own the shop""" # Set the shop owner to a different user test_shop.owner_id = other_user.id @@ -52,7 +58,7 @@ class TestMarketplaceService: url="https://example.com/products.csv", marketplace="Amazon", shop_code=test_shop.shop_code, - batch_size=1000 + batch_size=1000, ) result = self.service.create_import_job(db, request, test_user) @@ -60,7 +66,7 @@ class TestMarketplaceService: assert result.marketplace == "Amazon" # Check the correct field based on your model assert result.shop_id == test_shop.id # Changed from shop_code to shop_id - assert result.user_id == test_user.id if hasattr(result, 'user_id') else True + assert result.user_id == test_user.id if hasattr(result, "user_id") else True assert result.status == "pending" assert result.source_url == "https://example.com/products.csv" @@ -70,7 +76,7 @@ class TestMarketplaceService: url="https://example.com/products.csv", marketplace="Amazon", shop_code="INVALID_SHOP", - batch_size=1000 + batch_size=1000, ) with pytest.raises(ValueError, match="Shop not found"): @@ -78,16 +84,22 @@ class TestMarketplaceService: def test_get_import_job_by_id_success(self, db, test_marketplace_job, test_user): """Test getting import job by ID for job owner""" - result = self.service.get_import_job_by_id(db, test_marketplace_job.id, test_user) + result = self.service.get_import_job_by_id( + db, test_marketplace_job.id, test_user + ) assert result.id == test_marketplace_job.id # Check user_id if the field exists - if hasattr(result, 'user_id'): + if hasattr(result, "user_id"): assert result.user_id == test_user.id - def test_get_import_job_by_id_admin_access(self, db, test_marketplace_job, test_admin): + def test_get_import_job_by_id_admin_access( + self, db, test_marketplace_job, test_admin + ): """Test that admin can access any import job""" - result = self.service.get_import_job_by_id(db, test_marketplace_job.id, test_admin) + result = self.service.get_import_job_by_id( + db, test_marketplace_job.id, test_admin + ) assert result.id == test_marketplace_job.id @@ -96,7 +108,9 @@ class TestMarketplaceService: with pytest.raises(ValueError, match="Marketplace import job not found"): self.service.get_import_job_by_id(db, 99999, test_user) - def test_get_import_job_by_id_access_denied(self, db, test_marketplace_job, other_user): + def test_get_import_job_by_id_access_denied( + self, db, test_marketplace_job, other_user + ): """Test access denied when user doesn't own the job""" with pytest.raises(PermissionError, match="Access denied to this import job"): self.service.get_import_job_by_id(db, test_marketplace_job.id, other_user) @@ -108,7 +122,7 @@ class TestMarketplaceService: assert len(jobs) >= 1 assert any(job.id == test_marketplace_job.id for job in jobs) # Check user_id if the field exists - if hasattr(test_marketplace_job, 'user_id'): + if hasattr(test_marketplace_job, "user_id"): assert test_marketplace_job.user_id == test_user.id def test_get_import_jobs_admin_sees_all(self, db, test_marketplace_job, test_admin): @@ -118,7 +132,9 @@ class TestMarketplaceService: assert len(jobs) >= 1 assert any(job.id == test_marketplace_job.id for job in jobs) - def test_get_import_jobs_with_marketplace_filter(self, db, test_marketplace_job, test_user): + def test_get_import_jobs_with_marketplace_filter( + self, db, test_marketplace_job, test_user + ): """Test getting import jobs with marketplace filter""" jobs = self.service.get_import_jobs( db, test_user, marketplace=test_marketplace_job.marketplace @@ -143,7 +159,7 @@ class TestMarketplaceService: imported_count=0, updated_count=0, total_processed=0, - error_count=0 + error_count=0, ) db.add(job) db.commit() @@ -159,7 +175,7 @@ class TestMarketplaceService: test_marketplace_job.id, "completed", imported_count=100, - total_processed=100 + total_processed=100, ) assert result.status == "completed" @@ -211,7 +227,7 @@ class TestMarketplaceService: imported_count=0, updated_count=0, total_processed=0, - error_count=0 + error_count=0, ) db.add(job) db.commit() @@ -222,13 +238,17 @@ class TestMarketplaceService: assert result.status == "cancelled" assert result.completed_at is not None - def test_cancel_import_job_invalid_status(self, db, test_marketplace_job, test_user): + def test_cancel_import_job_invalid_status( + self, db, test_marketplace_job, test_user + ): """Test cancelling a job that can't be cancelled""" # Set job status to completed test_marketplace_job.status = "completed" db.commit() - with pytest.raises(ValueError, match="Cannot cancel job with status: completed"): + with pytest.raises( + ValueError, match="Cannot cancel job with status: completed" + ): self.service.cancel_import_job(db, test_marketplace_job.id, test_user) def test_delete_import_job_success(self, db, test_user, test_shop): @@ -246,7 +266,7 @@ class TestMarketplaceService: imported_count=0, updated_count=0, total_processed=0, - error_count=0 + error_count=0, ) db.add(job) db.commit() @@ -258,7 +278,11 @@ class TestMarketplaceService: assert result is True # Verify the job is actually deleted - deleted_job = db.query(MarketplaceImportJob).filter(MarketplaceImportJob.id == job_id).first() + deleted_job = ( + db.query(MarketplaceImportJob) + .filter(MarketplaceImportJob.id == job_id) + .first() + ) assert deleted_job is None def test_delete_import_job_invalid_status(self, db, test_user, test_shop): @@ -276,7 +300,7 @@ class TestMarketplaceService: imported_count=0, updated_count=0, total_processed=0, - error_count=0 + error_count=0, ) db.add(job) db.commit() diff --git a/tests/test_middleware.py b/tests/test_middleware.py index ee58a72d..f224eb13 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -1,8 +1,10 @@ # tests/test_middleware.py -import pytest from unittest.mock import Mock, patch -from middleware.rate_limiter import RateLimiter + +import pytest + from middleware.auth import AuthManager +from middleware.rate_limiter import RateLimiter class TestRateLimiter: @@ -12,11 +14,17 @@ class TestRateLimiter: client_id = "test_client" # Should allow first request - assert limiter.allow_request(client_id, max_requests=10, window_seconds=3600) is True + assert ( + limiter.allow_request(client_id, max_requests=10, window_seconds=3600) + is True + ) # Should allow subsequent requests within limit for _ in range(5): - assert limiter.allow_request(client_id, max_requests=10, window_seconds=3600) is True + assert ( + limiter.allow_request(client_id, max_requests=10, window_seconds=3600) + is True + ) def test_rate_limiter_blocks_excess_requests(self): """Test rate limiter blocks requests exceeding limit""" diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 69979d5f..d2d28eda 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -1,5 +1,6 @@ # tests/test_pagination.py import pytest + from models.database_models import Product @@ -12,7 +13,7 @@ class TestPagination: product = Product( product_id=f"PAGE{i:03d}", title=f"Pagination Test Product {i}", - marketplace="PaginationTest" + marketplace="PaginationTest", ) products.append(product) diff --git a/tests/test_performance.py b/tests/test_performance.py index 23eba39a..f71eade4 100644 --- a/tests/test_performance.py +++ b/tests/test_performance.py @@ -1,7 +1,8 @@ # tests/test_performance.py -import pytest import time +import pytest + from models.database_models import Product @@ -15,7 +16,7 @@ class TestPerformance: product_id=f"PERF{i:03d}", title=f"Performance Test Product {i}", price=f"{i}.99", - marketplace="Performance" + marketplace="Performance", ) products.append(product) @@ -41,7 +42,7 @@ class TestPerformance: title=f"Searchable Product {i}", description=f"This is a searchable product number {i}", brand="SearchBrand", - marketplace="SearchMarket" + marketplace="SearchMarket", ) products.append(product) diff --git a/tests/test_product.py b/tests/test_product.py index d192faf3..ddf70aab 100644 --- a/tests/test_product.py +++ b/tests/test_product.py @@ -31,7 +31,9 @@ class TestProductsAPI: assert response.json()["total"] == 1 # Test marketplace filter - response = client.get("/api/v1/product?marketplace=Letzshop", headers=auth_headers) + response = client.get( + "/api/v1/product?marketplace=Letzshop", headers=auth_headers + ) assert response.status_code == 200 assert response.json()["total"] == 1 @@ -50,10 +52,12 @@ class TestProductsAPI: "brand": "NewBrand", "gtin": "9876543210987", "availability": "in stock", - "marketplace": "Amazon" + "marketplace": "Amazon", } - response = client.post("/api/v1/product", headers=auth_headers, json=product_data) + response = client.post( + "/api/v1/product", headers=auth_headers, json=product_data + ) assert response.status_code == 200 data = response.json() @@ -71,10 +75,12 @@ class TestProductsAPI: "brand": "NewBrand", "gtin": "9876543210987", "availability": "in stock", - "marketplace": "Amazon" + "marketplace": "Amazon", } - response = client.post("/api/v1/product", headers=auth_headers, json=product_data) + response = client.post( + "/api/v1/product", headers=auth_headers, json=product_data + ) # Debug output print(f"Status Code: {response.status_code}") @@ -89,7 +95,9 @@ class TestProductsAPI: def test_get_product_by_id(self, client, auth_headers, test_product): """Test getting specific product""" - response = client.get(f"/api/v1/product/{test_product.product_id}", headers=auth_headers) + response = client.get( + f"/api/v1/product/{test_product.product_id}", headers=auth_headers + ) assert response.status_code == 200 data = response.json() @@ -104,15 +112,12 @@ class TestProductsAPI: def test_update_product(self, client, auth_headers, test_product): """Test updating product""" - update_data = { - "title": "Updated Product Title", - "price": "25.99" - } + update_data = {"title": "Updated Product Title", "price": "25.99"} response = client.put( f"/api/v1/product/{test_product.product_id}", headers=auth_headers, - json=update_data + json=update_data, ) assert response.status_code == 200 @@ -123,8 +128,7 @@ class TestProductsAPI: def test_delete_product(self, client, auth_headers, test_product): """Test deleting product""" response = client.delete( - f"/api/v1/product/{test_product.product_id}", - headers=auth_headers + f"/api/v1/product/{test_product.product_id}", headers=auth_headers ) assert response.status_code == 200 diff --git a/tests/test_product_service.py b/tests/test_product_service.py index a25eb5ae..e6a68d7c 100644 --- a/tests/test_product_service.py +++ b/tests/test_product_service.py @@ -1,5 +1,6 @@ # tests/test_product_service.py import pytest + from app.services.product_service import ProductService from models.api_models import ProductCreate from models.database_models import Product @@ -16,7 +17,7 @@ class TestProductService: title="Service Test Product", gtin="1234567890123", price="19.99", - marketplace="TestMarket" + marketplace="TestMarket", ) product = self.service.create_product(db, product_data) @@ -31,7 +32,7 @@ class TestProductService: product_id="SVC002", title="Service Test Product", gtin="invalid_gtin", - price="19.99" + price="19.99", ) with pytest.raises(ValueError, match="Invalid GTIN format"): @@ -39,10 +40,7 @@ class TestProductService: def test_get_products_with_filters(self, db, test_product): """Test getting products with various filters""" - products, total = self.service.get_products_with_filters( - db, - brand="TestBrand" - ) + products, total = self.service.get_products_with_filters(db, brand="TestBrand") assert total == 1 assert len(products) == 1 @@ -51,10 +49,8 @@ class TestProductService: def test_get_products_with_search(self, db, test_product): """Test getting products with search""" products, total = self.service.get_products_with_filters( - db, - search="Test Product" + db, search="Test Product" ) assert total == 1 assert len(products) == 1 - diff --git a/tests/test_security.py b/tests/test_security.py index 5f9a218f..dc1ac162 100644 --- a/tests/test_security.py +++ b/tests/test_security.py @@ -1,7 +1,8 @@ # tests/test_security.py +from unittest.mock import patch + import pytest from fastapi import HTTPException -from unittest.mock import patch class TestSecurity: @@ -10,7 +11,9 @@ class TestSecurity: response = client.get("/api/v1/debug-bearer") print(f"Direct Bearer - Status: {response.status_code}") - print(f"Direct Bearer - Response: {response.json() if response.content else 'No content'}") + print( + f"Direct Bearer - Response: {response.json() if response.content else 'No content'}" + ) def test_debug_dependencies(self, client): """Debug the dependency chain step by step""" @@ -24,7 +27,9 @@ class TestSecurity: print(f"Admin endpoint - Raw: {response.content}") # Test 2: Try a regular endpoint that uses get_current_user - response2 = client.get("/api/v1/product") # or any endpoint with get_current_user + response2 = client.get( + "/api/v1/product" + ) # or any endpoint with get_current_user print(f"Regular endpoint - Status: {response2.status_code}") try: print(f"Regular endpoint - Response: {response2.json()}") @@ -35,7 +40,7 @@ class TestSecurity: """Debug test to see all available routes""" print("\n=== All Available Routes ===") for route in client.app.routes: - if hasattr(route, 'path') and hasattr(route, 'methods'): + if hasattr(route, "path") and hasattr(route, "methods"): print(f"{list(route.methods)} {route.path}") print("\n=== Testing Product Endpoint Variations ===") @@ -59,7 +64,7 @@ class TestSecurity: "/api/v1/product", "/api/v1/shop", "/api/v1/stats", - "/api/v1/stock" + "/api/v1/stock", ] for endpoint in protected_endpoints: @@ -76,7 +81,9 @@ class TestSecurity: def test_admin_endpoint_requires_admin_role(self, client, auth_headers): """Test that admin endpoints require admin role""" response = client.get("/api/v1/admin/users", headers=auth_headers) - assert response.status_code == 403 # Token is valid but user does not have access. + assert ( + response.status_code == 403 + ) # Token is valid but user does not have access. # Regular user should be denied def test_sql_injection_prevention(self, client, auth_headers): @@ -84,7 +91,9 @@ class TestSecurity: # Try SQL injection in search parameter malicious_search = "'; DROP TABLE products; --" - response = client.get(f"/api/v1/product?search={malicious_search}", headers=auth_headers) + response = client.get( + f"/api/v1/product?search={malicious_search}", headers=auth_headers + ) # Should not crash and should return normal response assert response.status_code == 200 diff --git a/tests/test_shop.py b/tests/test_shop.py index 89f83964..cead5ad2 100644 --- a/tests/test_shop.py +++ b/tests/test_shop.py @@ -8,7 +8,7 @@ class TestShopsAPI: shop_data = { "shop_code": "NEWSHOP", "shop_name": "New Shop", - "description": "A new test shop" + "description": "A new test shop", } response = client.post("/api/v1/shop", headers=auth_headers, json=shop_data) @@ -23,7 +23,7 @@ class TestShopsAPI: """Test creating shop with duplicate code""" shop_data = { "shop_code": test_shop.shop_code, # Same as test_shop - "shop_name": test_shop.shop_name + "shop_name": test_shop.shop_name, } response = client.post("/api/v1/shop", headers=auth_headers, json=shop_data) @@ -42,7 +42,9 @@ class TestShopsAPI: def test_get_shop_by_code(self, client, auth_headers, test_shop): """Test getting specific shop""" - response = client.get(f"/api/v1/shop/{test_shop.shop_code}", headers=auth_headers) + response = client.get( + f"/api/v1/shop/{test_shop.shop_code}", headers=auth_headers + ) assert response.status_code == 200 data = response.json() diff --git a/tests/test_shop_service.py b/tests/test_shop_service.py index e9059507..5b0e7ee9 100644 --- a/tests/test_shop_service.py +++ b/tests/test_shop_service.py @@ -18,7 +18,7 @@ class TestShopService: shop_data = ShopCreate( shop_code="NEWSHOP", shop_name="New Test Shop", - description="A new test shop" + description="A new test shop", ) shop = self.service.create_shop(db, shop_data, test_user) @@ -30,10 +30,7 @@ class TestShopService: def test_create_shop_admin_auto_verify(self, db, test_admin, shop_factory): """Test admin creates verified shop automatically""" - shop_data = ShopCreate( - shop_code="ADMINSHOP", - shop_name="Admin Test Shop" - ) + shop_data = ShopCreate(shop_code="ADMINSHOP", shop_name="Admin Test Shop") shop = self.service.create_shop(db, shop_data, test_admin) @@ -42,8 +39,7 @@ class TestShopService: def test_create_shop_duplicate_code(self, db, test_user, test_shop): """Test shop creation fails with duplicate shop code""" shop_data = ShopCreate( - shop_code=test_shop.shop_code, - shop_name=test_shop.shop_name + shop_code=test_shop.shop_code, shop_name=test_shop.shop_name ) with pytest.raises(HTTPException) as exc_info: @@ -60,9 +56,13 @@ class TestShopService: assert test_shop.shop_code in shop_codes assert inactive_shop.shop_code not in shop_codes - def test_get_shops_admin_user(self, db, test_admin, test_shop, inactive_shop, verified_shop): + def test_get_shops_admin_user( + self, db, test_admin, test_shop, inactive_shop, verified_shop + ): """Test admin user can see all shops with filters""" - shops, total = self.service.get_shops(db, test_admin, active_only=False, verified_only=False) + shops, total = self.service.get_shops( + db, test_admin, active_only=False, verified_only=False + ) shop_codes = [shop.shop_code for shop in shops] assert test_shop.shop_code in shop_codes @@ -78,7 +78,9 @@ class TestShopService: def test_get_shop_by_code_admin_access(self, db, test_admin, test_shop): """Test admin can access any shop""" - shop = self.service.get_shop_by_code(db, test_shop.shop_code.lower(), test_admin) + shop = self.service.get_shop_by_code( + db, test_shop.shop_code.lower(), test_admin + ) assert shop is not None assert shop.id == test_shop.id @@ -103,10 +105,12 @@ class TestShopService: product_id=unique_product.product_id, price="15.99", is_featured=True, - stock_quantity=5 + stock_quantity=5, ) - shop_product = self.service.add_product_to_shop(db, test_shop, shop_product_data) + shop_product = self.service.add_product_to_shop( + db, test_shop, shop_product_data + ) assert shop_product is not None assert shop_product.shop_id == test_shop.id @@ -114,10 +118,7 @@ class TestShopService: def test_add_product_to_shop_product_not_found(self, db, test_shop): """Test adding non-existent product to shop fails""" - shop_product_data = ShopProductCreate( - product_id="NONEXISTENT", - price="15.99" - ) + shop_product_data = ShopProductCreate(product_id="NONEXISTENT", price="15.99") with pytest.raises(HTTPException) as exc_info: self.service.add_product_to_shop(db, test_shop, shop_product_data) @@ -127,8 +128,7 @@ class TestShopService: def test_add_product_to_shop_already_exists(self, db, test_shop, shop_product): """Test adding product that's already in shop fails""" shop_product_data = ShopProductCreate( - product_id=shop_product.product.product_id, - price="15.99" + product_id=shop_product.product.product_id, price="15.99" ) with pytest.raises(HTTPException) as exc_info: @@ -136,7 +136,9 @@ class TestShopService: assert exc_info.value.status_code == 400 - def test_get_shop_products_owner_access(self, db, test_user, test_shop, shop_product): + def test_get_shop_products_owner_access( + self, db, test_user, test_shop, shop_product + ): """Test shop owner can get shop products""" products, total = self.service.get_shop_products(db, test_shop, test_user) diff --git a/tests/test_stats_service.py b/tests/test_stats_service.py index 2c6e6930..95749ba0 100644 --- a/tests/test_stats_service.py +++ b/tests/test_stats_service.py @@ -40,7 +40,7 @@ class TestStatsService: marketplace="Amazon", shop_name="AmazonShop", price="15.99", - currency="EUR" + currency="EUR", ), Product( product_id="PROD003", @@ -50,7 +50,7 @@ class TestStatsService: marketplace="eBay", shop_name="eBayShop", price="25.99", - currency="USD" + currency="USD", ), Product( product_id="PROD004", @@ -60,8 +60,8 @@ class TestStatsService: marketplace="Letzshop", # Same as test_product shop_name="DifferentShop", price="35.99", - currency="EUR" - ) + currency="EUR", + ), ] db.add_all(additional_products) db.commit() @@ -86,7 +86,7 @@ class TestStatsService: marketplace=None, # Null marketplace shop_name=None, # Null shop price="10.00", - currency="EUR" + currency="EUR", ), Product( product_id="EMPTY001", @@ -96,8 +96,8 @@ class TestStatsService: marketplace="", # Empty marketplace shop_name="", # Empty shop price="15.00", - currency="EUR" - ) + currency="EUR", + ), ] db.add_all(products_with_nulls) db.commit() @@ -122,14 +122,16 @@ class TestStatsService: # Find our test marketplace in the results test_marketplace_stat = next( (stat for stat in stats if stat["marketplace"] == test_product.marketplace), - None + None, ) assert test_marketplace_stat is not None assert test_marketplace_stat["total_products"] >= 1 assert test_marketplace_stat["unique_shops"] >= 1 assert test_marketplace_stat["unique_brands"] >= 1 - def test_get_marketplace_breakdown_stats_multiple_marketplaces(self, db, test_product): + def test_get_marketplace_breakdown_stats_multiple_marketplaces( + self, db, test_product + ): """Test marketplace breakdown with multiple marketplaces""" # Create products for different marketplaces marketplace_products = [ @@ -140,7 +142,7 @@ class TestStatsService: marketplace="Amazon", shop_name="AmazonShop1", price="20.00", - currency="EUR" + currency="EUR", ), Product( product_id="AMAZON002", @@ -149,7 +151,7 @@ class TestStatsService: marketplace="Amazon", shop_name="AmazonShop2", price="25.00", - currency="EUR" + currency="EUR", ), Product( product_id="EBAY001", @@ -158,8 +160,8 @@ class TestStatsService: marketplace="eBay", shop_name="eBayShop", price="30.00", - currency="USD" - ) + currency="USD", + ), ] db.add_all(marketplace_products) db.commit() @@ -194,7 +196,7 @@ class TestStatsService: shop_name="SomeShop", brand="SomeBrand", price="10.00", - currency="EUR" + currency="EUR", ) db.add(null_marketplace_product) db.commit() @@ -202,7 +204,9 @@ class TestStatsService: stats = self.service.get_marketplace_breakdown_stats(db) # Should not include any stats for null marketplace - marketplace_names = [stat["marketplace"] for stat in stats if stat["marketplace"] is not None] + marketplace_names = [ + stat["marketplace"] for stat in stats if stat["marketplace"] is not None + ] assert None not in marketplace_names def test_get_product_count(self, db, test_product): @@ -223,7 +227,7 @@ class TestStatsService: marketplace="Test", shop_name="TestShop", price="10.00", - currency="EUR" + currency="EUR", ), Product( product_id="BRAND002", @@ -232,15 +236,17 @@ class TestStatsService: marketplace="Test", shop_name="TestShop", price="15.00", - currency="EUR" - ) + currency="EUR", + ), ] db.add_all(brand_products) db.commit() count = self.service.get_unique_brands_count(db) - assert count >= 2 # At least BrandA and BrandB, plus possibly test_product brand + assert ( + count >= 2 + ) # At least BrandA and BrandB, plus possibly test_product brand assert isinstance(count, int) def test_get_unique_categories_count(self, db, test_product): @@ -254,7 +260,7 @@ class TestStatsService: marketplace="Test", shop_name="TestShop", price="10.00", - currency="EUR" + currency="EUR", ), Product( product_id="CAT002", @@ -263,8 +269,8 @@ class TestStatsService: marketplace="Test", shop_name="TestShop", price="15.00", - currency="EUR" - ) + currency="EUR", + ), ] db.add_all(category_products) db.commit() @@ -284,7 +290,7 @@ class TestStatsService: marketplace="Amazon", shop_name="AmazonShop", price="10.00", - currency="EUR" + currency="EUR", ), Product( product_id="MARKET002", @@ -292,8 +298,8 @@ class TestStatsService: marketplace="eBay", shop_name="eBayShop", price="15.00", - currency="EUR" - ) + currency="EUR", + ), ] db.add_all(marketplace_products) db.commit() @@ -313,7 +319,7 @@ class TestStatsService: marketplace="Test", shop_name="ShopA", price="10.00", - currency="EUR" + currency="EUR", ), Product( product_id="SHOP002", @@ -321,8 +327,8 @@ class TestStatsService: marketplace="Test", shop_name="ShopB", price="15.00", - currency="EUR" - ) + currency="EUR", + ), ] db.add_all(shop_products) db.commit() @@ -341,15 +347,15 @@ class TestStatsService: location="LOCATION2", quantity=25, reserved_quantity=5, - shop_id=test_stock.shop_id + shop_id=test_stock.shop_id, ), Stock( gtin="1234567890125", location="LOCATION3", quantity=0, # Out of stock reserved_quantity=0, - shop_id=test_stock.shop_id - ) + shop_id=test_stock.shop_id, + ), ] db.add_all(additional_stocks) db.commit() @@ -372,7 +378,7 @@ class TestStatsService: marketplace="SpecificMarket", shop_name="SpecificShop1", price="10.00", - currency="EUR" + currency="EUR", ), Product( product_id="SPECIFIC002", @@ -381,7 +387,7 @@ class TestStatsService: marketplace="SpecificMarket", shop_name="SpecificShop2", price="15.00", - currency="EUR" + currency="EUR", ), Product( product_id="OTHER001", @@ -390,8 +396,8 @@ class TestStatsService: marketplace="OtherMarket", shop_name="OtherShop", price="20.00", - currency="EUR" - ) + currency="EUR", + ), ] db.add_all(marketplace_products) db.commit() @@ -414,7 +420,7 @@ class TestStatsService: marketplace="TestMarketplace", shop_name="TestShop1", price="10.00", - currency="EUR" + currency="EUR", ), Product( product_id="SHOPTEST002", @@ -423,8 +429,8 @@ class TestStatsService: marketplace="TestMarketplace", shop_name="TestShop2", price="15.00", - currency="EUR" - ) + currency="EUR", + ), ] db.add_all(marketplace_products) db.commit() @@ -445,7 +451,7 @@ class TestStatsService: marketplace="CountMarketplace", shop_name="CountShop", price="10.00", - currency="EUR" + currency="EUR", ), Product( product_id="COUNT002", @@ -453,7 +459,7 @@ class TestStatsService: marketplace="CountMarketplace", shop_name="CountShop", price="15.00", - currency="EUR" + currency="EUR", ), Product( product_id="COUNT003", @@ -461,8 +467,8 @@ class TestStatsService: marketplace="CountMarketplace", shop_name="CountShop", price="20.00", - currency="EUR" - ) + currency="EUR", + ), ] db.add_all(marketplace_products) db.commit() diff --git a/tests/test_stock.py b/tests/test_stock.py index f9fd6a09..3c84dc24 100644 --- a/tests/test_stock.py +++ b/tests/test_stock.py @@ -1,5 +1,6 @@ # tests/test_stock.py import pytest + from models.database_models import Stock @@ -9,7 +10,7 @@ class TestStockAPI: stock_data = { "gtin": "1234567890123", "location": "WAREHOUSE_A", - "quantity": 100 + "quantity": 100, } response = client.post("/api/v1/stock", headers=auth_headers, json=stock_data) @@ -30,7 +31,7 @@ class TestStockAPI: stock_data = { "gtin": "1234567890123", "location": "WAREHOUSE_A", - "quantity": 75 + "quantity": 75, } response = client.post("/api/v1/stock", headers=auth_headers, json=stock_data) @@ -49,10 +50,12 @@ class TestStockAPI: stock_data = { "gtin": "1234567890123", "location": "WAREHOUSE_A", - "quantity": 25 + "quantity": 25, } - response = client.post("/api/v1/stock/add", headers=auth_headers, json=stock_data) + response = client.post( + "/api/v1/stock/add", headers=auth_headers, json=stock_data + ) assert response.status_code == 200 data = response.json() @@ -68,10 +71,12 @@ class TestStockAPI: stock_data = { "gtin": "1234567890123", "location": "WAREHOUSE_A", - "quantity": 15 + "quantity": 15, } - response = client.post("/api/v1/stock/remove", headers=auth_headers, json=stock_data) + response = client.post( + "/api/v1/stock/remove", headers=auth_headers, json=stock_data + ) assert response.status_code == 200 data = response.json() @@ -87,10 +92,12 @@ class TestStockAPI: stock_data = { "gtin": "1234567890123", "location": "WAREHOUSE_A", - "quantity": 20 + "quantity": 20, } - response = client.post("/api/v1/stock/remove", headers=auth_headers, json=stock_data) + response = client.post( + "/api/v1/stock/remove", headers=auth_headers, json=stock_data + ) assert response.status_code == 400 assert "Insufficient stock" in response.json()["detail"] diff --git a/tests/test_stock_service.py b/tests/test_stock_service.py index 239a014d..f24c1925 100644 --- a/tests/test_stock_service.py +++ b/tests/test_stock_service.py @@ -1,9 +1,11 @@ # tests/test_stock_service.py -import pytest import uuid + +import pytest + from app.services.stock_service import StockService -from models.api_models import StockCreate, StockAdd, StockUpdate -from models.database_models import Stock, Product +from models.api_models import StockAdd, StockCreate, StockUpdate +from models.database_models import Product, Stock class TestStockService: @@ -39,7 +41,9 @@ class TestStockService: assert self.service.normalize_gtin("1234567890123") == "1234567890123" # EAN-13 assert self.service.normalize_gtin("123456789012") == "123456789012" # UPC-A assert self.service.normalize_gtin("12345678") == "12345678" # EAN-8 - assert self.service.normalize_gtin("12345678901234") == "12345678901234" # GTIN-14 + assert ( + self.service.normalize_gtin("12345678901234") == "12345678901234" + ) # GTIN-14 # Test with decimal points (should be removed) assert self.service.normalize_gtin("1234567890123.0") == "1234567890123" @@ -49,10 +53,14 @@ class TestStockService: # Test short GTINs being padded assert self.service.normalize_gtin("123") == "0000000000123" # Padded to EAN-13 - assert self.service.normalize_gtin("12345") == "0000000012345" # Padded to EAN-13 + assert ( + self.service.normalize_gtin("12345") == "0000000012345" + ) # Padded to EAN-13 # Test long GTINs being truncated - assert self.service.normalize_gtin("123456789012345") == "3456789012345" # Truncated to 13 + assert ( + self.service.normalize_gtin("123456789012345") == "3456789012345" + ) # Truncated to 13 def test_normalize_gtin_edge_cases(self): """Test GTIN normalization edge cases""" @@ -61,17 +69,21 @@ class TestStockService: assert self.service.normalize_gtin(123) == "0000000000123" # Test mixed valid/invalid characters - assert self.service.normalize_gtin("123-456-789-012") == "123456789012" # Dashes removed - assert self.service.normalize_gtin("123 456 789 012") == "123456789012" # Spaces removed - assert self.service.normalize_gtin("ABC123456789012DEF") == "123456789012" # Letters removed + assert ( + self.service.normalize_gtin("123-456-789-012") == "123456789012" + ) # Dashes removed + assert ( + self.service.normalize_gtin("123 456 789 012") == "123456789012" + ) # Spaces removed + assert ( + self.service.normalize_gtin("ABC123456789012DEF") == "123456789012" + ) # Letters removed def test_set_stock_new_entry(self, db): """Test setting stock for a new GTIN/location combination""" unique_id = str(uuid.uuid4())[:8] stock_data = StockCreate( - gtin="1234567890123", - location=f"WAREHOUSE_A_{unique_id}", - quantity=100 + gtin="1234567890123", location=f"WAREHOUSE_A_{unique_id}", quantity=100 ) result = self.service.set_stock(db, stock_data) @@ -85,7 +97,7 @@ class TestStockService: stock_data = StockCreate( gtin=test_stock.gtin, location=test_stock.location, # Use exact same location as test_stock - quantity=200 + quantity=200, ) result = self.service.set_stock(db, stock_data) @@ -98,9 +110,7 @@ class TestStockService: def test_set_stock_invalid_gtin(self, db): """Test setting stock with invalid GTIN""" stock_data = StockCreate( - gtin="invalid_gtin", - location="WAREHOUSE_A", - quantity=100 + gtin="invalid_gtin", location="WAREHOUSE_A", quantity=100 ) with pytest.raises(ValueError, match="Invalid GTIN format"): @@ -110,9 +120,7 @@ class TestStockService: """Test adding stock for a new GTIN/location combination""" unique_id = str(uuid.uuid4())[:8] stock_data = StockAdd( - gtin="1234567890123", - location=f"WAREHOUSE_B_{unique_id}", - quantity=50 + gtin="1234567890123", location=f"WAREHOUSE_B_{unique_id}", quantity=50 ) result = self.service.add_stock(db, stock_data) @@ -127,7 +135,7 @@ class TestStockService: stock_data = StockAdd( gtin=test_stock.gtin, location=test_stock.location, # Use exact same location as test_stock - quantity=25 + quantity=25, ) result = self.service.add_stock(db, stock_data) @@ -138,11 +146,7 @@ class TestStockService: def test_add_stock_invalid_gtin(self, db): """Test adding stock with invalid GTIN""" - stock_data = StockAdd( - gtin="invalid_gtin", - location="WAREHOUSE_A", - quantity=50 - ) + stock_data = StockAdd(gtin="invalid_gtin", location="WAREHOUSE_A", quantity=50) with pytest.raises(ValueError, match="Invalid GTIN format"): self.service.add_stock(db, stock_data) @@ -150,12 +154,14 @@ class TestStockService: def test_remove_stock_success(self, db, test_stock): """Test removing stock successfully""" original_quantity = test_stock.quantity - remove_quantity = min(10, original_quantity) # Ensure we don't remove more than available + remove_quantity = min( + 10, original_quantity + ) # Ensure we don't remove more than available stock_data = StockAdd( gtin=test_stock.gtin, location=test_stock.location, # Use exact same location as test_stock - quantity=remove_quantity + quantity=remove_quantity, ) result = self.service.remove_stock(db, stock_data) @@ -169,20 +175,20 @@ class TestStockService: stock_data = StockAdd( gtin=test_stock.gtin, location=test_stock.location, # Use exact same location as test_stock - quantity=test_stock.quantity + 10 # More than available + quantity=test_stock.quantity + 10, # More than available ) # Fix: Use more flexible regex pattern - with pytest.raises(ValueError, match="Insufficient stock|Not enough stock|Cannot remove"): + with pytest.raises( + ValueError, match="Insufficient stock|Not enough stock|Cannot remove" + ): self.service.remove_stock(db, stock_data) def test_remove_stock_nonexistent_entry(self, db): """Test removing stock from non-existent GTIN/location""" unique_id = str(uuid.uuid4())[:8] stock_data = StockAdd( - gtin="9999999999999", - location=f"NONEXISTENT_{unique_id}", - quantity=10 + gtin="9999999999999", location=f"NONEXISTENT_{unique_id}", quantity=10 ) with pytest.raises(ValueError, match="No stock found|Stock not found"): @@ -190,11 +196,7 @@ class TestStockService: def test_remove_stock_invalid_gtin(self, db): """Test removing stock with invalid GTIN""" - stock_data = StockAdd( - gtin="invalid_gtin", - location="WAREHOUSE_A", - quantity=10 - ) + stock_data = StockAdd(gtin="invalid_gtin", location="WAREHOUSE_A", quantity=10) with pytest.raises(ValueError, match="Invalid GTIN format"): self.service.remove_stock(db, stock_data) @@ -218,14 +220,10 @@ class TestStockService: # Create multiple stock entries for the same GTIN with unique locations stock1 = Stock( - gtin=unique_gtin, - location=f"WAREHOUSE_A_{unique_id}", - quantity=50 + gtin=unique_gtin, location=f"WAREHOUSE_A_{unique_id}", quantity=50 ) stock2 = Stock( - gtin=unique_gtin, - location=f"WAREHOUSE_B_{unique_id}", - quantity=30 + gtin=unique_gtin, location=f"WAREHOUSE_B_{unique_id}", quantity=30 ) db.add(stock1) @@ -275,7 +273,9 @@ class TestStockService: assert len(result) >= 1 # Fix: Handle case sensitivity in comparison - assert all(stock.location.upper() == test_stock.location.upper() for stock in result) + assert all( + stock.location.upper() == test_stock.location.upper() for stock in result + ) def test_get_all_stock_with_gtin_filter(self, db, test_stock): """Test getting all stock with GTIN filter""" @@ -293,14 +293,16 @@ class TestStockService: stock = Stock( gtin=f"1234567890{i:03d}", # Creates valid 13-digit GTINs: 1234567890000, 1234567890001, etc. location=f"WAREHOUSE_{unique_prefix}_{i}", - quantity=10 + quantity=10, ) db.add(stock) db.commit() result = self.service.get_all_stock(db, skip=2, limit=2) - assert len(result) <= 2 # Should be at most 2, might be less if other records exist + assert ( + len(result) <= 2 + ) # Should be at most 2, might be less if other records exist def test_update_stock_success(self, db, test_stock): """Test updating stock quantity""" @@ -359,7 +361,7 @@ def test_product_with_stock(db, test_stock): gtin=test_stock.gtin, price="29.99", brand="TestBrand", - marketplace="Letzshop" + marketplace="Letzshop", ) db.add(product) db.commit() diff --git a/tests/test_utils.py b/tests/test_utils.py index dc9f92d4..fe77523f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,5 +1,6 @@ # tests/test_utils.py (Enhanced version of your existing file) import pytest + from utils.data_processing import GTINProcessor, PriceProcessor diff --git a/utils/csv_processor.py b/utils/csv_processor.py index 559fff83..b5efa6cc 100644 --- a/utils/csv_processor.py +++ b/utils/csv_processor.py @@ -1,14 +1,15 @@ # utils/csv_processor.py +import logging +from datetime import datetime +from io import StringIO +from typing import Any, Dict + import pandas as pd import requests -from io import StringIO -from typing import Dict, Any - from sqlalchemy import literal from sqlalchemy.orm import Session + from models.database_models import Product -from datetime import datetime -import logging logger = logging.getLogger(__name__) @@ -16,67 +17,66 @@ logger = logging.getLogger(__name__) class CSVProcessor: """Handles CSV import with robust parsing and batching""" - ENCODINGS = ['utf-8', 'latin-1', 'iso-8859-1', 'cp1252', 'utf-8-sig'] + ENCODINGS = ["utf-8", "latin-1", "iso-8859-1", "cp1252", "utf-8-sig"] PARSING_CONFIGS = [ # Try auto-detection first - {'sep': None, 'engine': 'python'}, + {"sep": None, "engine": "python"}, # Try semicolon (common in European CSVs) - {'sep': ';', 'engine': 'python'}, + {"sep": ";", "engine": "python"}, # Try comma - {'sep': ',', 'engine': 'python'}, + {"sep": ",", "engine": "python"}, # Try tab - {'sep': '\t', 'engine': 'python'}, + {"sep": "\t", "engine": "python"}, ] COLUMN_MAPPING = { # Standard variations - 'id': 'product_id', - 'ID': 'product_id', - 'Product ID': 'product_id', - 'name': 'title', - 'Name': 'title', - 'product_name': 'title', - 'Product Name': 'title', - + "id": "product_id", + "ID": "product_id", + "Product ID": "product_id", + "name": "title", + "Name": "title", + "product_name": "title", + "Product Name": "title", # Google Shopping feed standard - 'g:id': 'product_id', - 'g:title': 'title', - 'g:description': 'description', - 'g:link': 'link', - 'g:image_link': 'image_link', - 'g:availability': 'availability', - 'g:price': 'price', - 'g:brand': 'brand', - 'g:gtin': 'gtin', - 'g:mpn': 'mpn', - 'g:condition': 'condition', - 'g:adult': 'adult', - 'g:multipack': 'multipack', - 'g:is_bundle': 'is_bundle', - 'g:age_group': 'age_group', - 'g:color': 'color', - 'g:gender': 'gender', - 'g:material': 'material', - 'g:pattern': 'pattern', - 'g:size': 'size', - 'g:size_type': 'size_type', - 'g:size_system': 'size_system', - 'g:item_group_id': 'item_group_id', - 'g:google_product_category': 'google_product_category', - 'g:product_type': 'product_type', - 'g:custom_label_0': 'custom_label_0', - 'g:custom_label_1': 'custom_label_1', - 'g:custom_label_2': 'custom_label_2', - 'g:custom_label_3': 'custom_label_3', - 'g:custom_label_4': 'custom_label_4', - + "g:id": "product_id", + "g:title": "title", + "g:description": "description", + "g:link": "link", + "g:image_link": "image_link", + "g:availability": "availability", + "g:price": "price", + "g:brand": "brand", + "g:gtin": "gtin", + "g:mpn": "mpn", + "g:condition": "condition", + "g:adult": "adult", + "g:multipack": "multipack", + "g:is_bundle": "is_bundle", + "g:age_group": "age_group", + "g:color": "color", + "g:gender": "gender", + "g:material": "material", + "g:pattern": "pattern", + "g:size": "size", + "g:size_type": "size_type", + "g:size_system": "size_system", + "g:item_group_id": "item_group_id", + "g:google_product_category": "google_product_category", + "g:product_type": "product_type", + "g:custom_label_0": "custom_label_0", + "g:custom_label_1": "custom_label_1", + "g:custom_label_2": "custom_label_2", + "g:custom_label_3": "custom_label_3", + "g:custom_label_4": "custom_label_4", # Handle complex shipping column - 'shipping(country:price:max_handling_time:min_transit_time:max_transit_time)': 'shipping' + "shipping(country:price:max_handling_time:min_transit_time:max_transit_time)": "shipping", } def __init__(self): from utils.data_processing import GTINProcessor, PriceProcessor + self.gtin_processor = GTINProcessor() self.price_processor = PriceProcessor() @@ -98,7 +98,7 @@ class CSVProcessor: continue # Fallback with error ignoring - decoded_content = content.decode('utf-8', errors='ignore') + decoded_content = content.decode("utf-8", errors="ignore") logger.warning("Used UTF-8 with error ignoring for CSV decoding") return decoded_content @@ -113,11 +113,11 @@ class CSVProcessor: try: df = pd.read_csv( StringIO(csv_content), - on_bad_lines='skip', + on_bad_lines="skip", quotechar='"', skip_blank_lines=True, skipinitialspace=True, - **config + **config, ) logger.info(f"Successfully parsed CSV with config: {config}") return df @@ -143,42 +143,43 @@ class CSVProcessor: processed_data = {k: (v if pd.notna(v) else None) for k, v in row_data.items()} # Process GTIN - if processed_data.get('gtin'): - processed_data['gtin'] = self.gtin_processor.normalize(processed_data['gtin']) + if processed_data.get("gtin"): + processed_data["gtin"] = self.gtin_processor.normalize( + processed_data["gtin"] + ) # Process price and currency - if processed_data.get('price'): - parsed_price, currency = self.price_processor.parse_price_currency(processed_data['price']) - processed_data['price'] = parsed_price - processed_data['currency'] = currency + if processed_data.get("price"): + parsed_price, currency = self.price_processor.parse_price_currency( + processed_data["price"] + ) + processed_data["price"] = parsed_price + processed_data["currency"] = currency # Process sale_price - if processed_data.get('sale_price'): - parsed_sale_price, _ = self.price_processor.parse_price_currency(processed_data['sale_price']) - processed_data['sale_price'] = parsed_sale_price + if processed_data.get("sale_price"): + parsed_sale_price, _ = self.price_processor.parse_price_currency( + processed_data["sale_price"] + ) + processed_data["sale_price"] = parsed_sale_price # Clean MPN (remove .0 endings) - if processed_data.get('mpn'): - mpn_str = str(processed_data['mpn']).strip() - if mpn_str.endswith('.0'): - processed_data['mpn'] = mpn_str[:-2] + if processed_data.get("mpn"): + mpn_str = str(processed_data["mpn"]).strip() + if mpn_str.endswith(".0"): + processed_data["mpn"] = mpn_str[:-2] # Handle multipack type conversion - if processed_data.get('multipack') is not None: + if processed_data.get("multipack") is not None: try: - processed_data['multipack'] = int(float(processed_data['multipack'])) + processed_data["multipack"] = int(float(processed_data["multipack"])) except (ValueError, TypeError): - processed_data['multipack'] = None + processed_data["multipack"] = None return processed_data async def process_marketplace_csv_from_url( - self, - url: str, - marketplace: str, - shop_name: str, - batch_size: int, - db: Session + self, url: str, marketplace: str, shop_name: str, batch_size: int, db: Session ) -> Dict[str, Any]: """ Process CSV from URL with marketplace and shop information @@ -194,7 +195,9 @@ class CSVProcessor: Dictionary with processing results """ - logger.info(f"Starting marketplace CSV import from {url} for {marketplace} -> {shop_name}") + logger.info( + f"Starting marketplace CSV import from {url} for {marketplace} -> {shop_name}" + ) # Download and parse CSV csv_content = self.download_csv(url) df = self.parse_csv(csv_content) @@ -208,40 +211,42 @@ class CSVProcessor: # Process in batches for i in range(0, len(df), batch_size): - batch_df = df.iloc[i:i + batch_size] + batch_df = df.iloc[i : i + batch_size] batch_result = await self._process_marketplace_batch( batch_df, marketplace, shop_name, db, i // batch_size + 1 ) - imported += batch_result['imported'] - updated += batch_result['updated'] - errors += batch_result['errors'] + imported += batch_result["imported"] + updated += batch_result["updated"] + errors += batch_result["errors"] logger.info(f"Processed batch {i // batch_size + 1}: {batch_result}") return { - 'total_processed': imported + updated + errors, - 'imported': imported, - 'updated': updated, - 'errors': errors, - 'marketplace': marketplace, - 'shop_name': shop_name + "total_processed": imported + updated + errors, + "imported": imported, + "updated": updated, + "errors": errors, + "marketplace": marketplace, + "shop_name": shop_name, } async def _process_marketplace_batch( - self, - batch_df: pd.DataFrame, - marketplace: str, - shop_name: str, - db: Session, - batch_num: int + self, + batch_df: pd.DataFrame, + marketplace: str, + shop_name: str, + db: Session, + batch_num: int, ) -> Dict[str, int]: """Process a batch of CSV rows with marketplace information""" imported = 0 updated = 0 errors = 0 - logger.info(f"Processing batch {batch_num} with {len(batch_df)} rows for {marketplace} -> {shop_name}") + logger.info( + f"Processing batch {batch_num} with {len(batch_df)} rows for {marketplace} -> {shop_name}" + ) for index, row in batch_df.iterrows(): try: @@ -249,42 +254,54 @@ class CSVProcessor: product_data = self._clean_row_data(row.to_dict()) # Add marketplace and shop information - product_data['marketplace'] = marketplace - product_data['shop_name'] = shop_name + product_data["marketplace"] = marketplace + product_data["shop_name"] = shop_name # Validate required fields - if not product_data.get('product_id'): + if not product_data.get("product_id"): logger.warning(f"Row {index}: Missing product_id, skipping") errors += 1 continue - if not product_data.get('title'): + if not product_data.get("title"): logger.warning(f"Row {index}: Missing title, skipping") errors += 1 continue # Check if product exists - existing_product = db.query(Product).filter( - Product.product_id == literal(product_data['product_id']) - ).first() + existing_product = ( + db.query(Product) + .filter(Product.product_id == literal(product_data["product_id"])) + .first() + ) if existing_product: # Update existing product for key, value in product_data.items(): - if key not in ['id', 'created_at'] and hasattr(existing_product, key): + if key not in ["id", "created_at"] and hasattr( + existing_product, key + ): setattr(existing_product, key, value) existing_product.updated_at = datetime.utcnow() updated += 1 - logger.debug(f"Updated product {product_data['product_id']} for {marketplace} and shop {shop_name}") + logger.debug( + f"Updated product {product_data['product_id']} for {marketplace} and shop {shop_name}" + ) else: # Create new product - filtered_data = {k: v for k, v in product_data.items() - if k not in ['id', 'created_at', 'updated_at'] and hasattr(Product, k)} + filtered_data = { + k: v + for k, v in product_data.items() + if k not in ["id", "created_at", "updated_at"] + and hasattr(Product, k) + } new_product = Product(**filtered_data) db.add(new_product) imported += 1 - logger.debug(f"Imported new product {product_data['product_id']} for {marketplace} and shop " - f"{shop_name}") + logger.debug( + f"Imported new product {product_data['product_id']} for {marketplace} and shop " + f"{shop_name}" + ) except Exception as e: logger.error(f"Error processing row: {e}") @@ -303,8 +320,4 @@ class CSVProcessor: imported = 0 updated = 0 - return { - 'imported': imported, - 'updated': updated, - 'errors': errors - } + return {"imported": imported, "updated": updated, "errors": errors} diff --git a/utils/data_processing.py b/utils/data_processing.py index 5befc82b..23585cd0 100644 --- a/utils/data_processing.py +++ b/utils/data_processing.py @@ -1,8 +1,9 @@ # utils/data_processing.py -import re -import pandas as pd -from typing import Tuple, Optional import logging +import re +from typing import Optional, Tuple + +import pandas as pd logger = logging.getLogger(__name__) @@ -25,11 +26,11 @@ class GTINProcessor: return None # Remove decimal point (e.g., "889698116923.0" -> "889698116923") - if '.' in gtin_str: - gtin_str = gtin_str.split('.')[0] + if "." in gtin_str: + gtin_str = gtin_str.split(".")[0] # Keep only digits - gtin_clean = ''.join(filter(str.isdigit, gtin_str)) + gtin_clean = "".join(filter(str.isdigit, gtin_str)) if not gtin_clean: return None @@ -73,23 +74,23 @@ class PriceProcessor: CURRENCY_PATTERNS = { # Amount followed by currency - r'([0-9.,]+)\s*(EUR|€)': lambda m: (m.group(1), 'EUR'), - r'([0-9.,]+)\s*(USD|\$)': lambda m: (m.group(1), 'USD'), - r'([0-9.,]+)\s*(GBP|£)': lambda m: (m.group(1), 'GBP'), - r'([0-9.,]+)\s*(CHF)': lambda m: (m.group(1), 'CHF'), - r'([0-9.,]+)\s*(CAD|AUD|JPY|¥)': lambda m: (m.group(1), m.group(2).upper()), - + r"([0-9.,]+)\s*(EUR|€)": lambda m: (m.group(1), "EUR"), + r"([0-9.,]+)\s*(USD|\$)": lambda m: (m.group(1), "USD"), + r"([0-9.,]+)\s*(GBP|£)": lambda m: (m.group(1), "GBP"), + r"([0-9.,]+)\s*(CHF)": lambda m: (m.group(1), "CHF"), + r"([0-9.,]+)\s*(CAD|AUD|JPY|¥)": lambda m: (m.group(1), m.group(2).upper()), # Currency followed by amount - r'(EUR|€)\s*([0-9.,]+)': lambda m: (m.group(2), 'EUR'), - r'(USD|\$)\s*([0-9.,]+)': lambda m: (m.group(2), 'USD'), - r'(GBP|£)\s*([0-9.,]+)': lambda m: (m.group(2), 'GBP'), - + r"(EUR|€)\s*([0-9.,]+)": lambda m: (m.group(2), "EUR"), + r"(USD|\$)\s*([0-9.,]+)": lambda m: (m.group(2), "USD"), + r"(GBP|£)\s*([0-9.,]+)": lambda m: (m.group(2), "GBP"), # Generic 3-letter currency codes - r'([0-9.,]+)\s*([A-Z]{3})': lambda m: (m.group(1), m.group(2)), - r'([A-Z]{3})\s*([0-9.,]+)': lambda m: (m.group(2), m.group(1)), + r"([0-9.,]+)\s*([A-Z]{3})": lambda m: (m.group(1), m.group(2)), + r"([A-Z]{3})\s*([0-9.,]+)": lambda m: (m.group(2), m.group(1)), } - def parse_price_currency(self, price_str: any) -> Tuple[Optional[str], Optional[str]]: + def parse_price_currency( + self, price_str: any + ) -> Tuple[Optional[str], Optional[str]]: """ Parse price string into (price, currency) tuple Returns (None, None) if parsing fails @@ -108,7 +109,7 @@ class PriceProcessor: try: price_val, currency_val = extract_func(match) # Normalize price (remove spaces, handle comma as decimal) - price_val = price_val.replace(' ', '').replace(',', '.') + price_val = price_val.replace(" ", "").replace(",", ".") # Validate numeric float(price_val) return price_val, currency_val.upper() @@ -116,10 +117,10 @@ class PriceProcessor: continue # Fallback: extract just numbers - number_match = re.search(r'([0-9.,]+)', price_str) + number_match = re.search(r"([0-9.,]+)", price_str) if number_match: try: - price_val = number_match.group(1).replace(',', '.') + price_val = number_match.group(1).replace(",", ".") float(price_val) # Validate return price_val, None except ValueError: diff --git a/utils/database.py b/utils/database.py index b1523090..db16451f 100644 --- a/utils/database.py +++ b/utils/database.py @@ -1,20 +1,19 @@ # utils/database.py +import logging + from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import QueuePool -import logging logger = logging.getLogger(__name__) def get_db_engine(database_url: str): """Create database engine with connection pooling""" - if database_url.startswith('sqlite'): + if database_url.startswith("sqlite"): # SQLite configuration engine = create_engine( - database_url, - connect_args={"check_same_thread": False}, - echo=False + database_url, connect_args={"check_same_thread": False}, echo=False ) else: # PostgreSQL configuration with connection pooling @@ -24,7 +23,7 @@ def get_db_engine(database_url: str): pool_size=10, max_overflow=20, pool_pre_ping=True, - echo=False + echo=False, ) logger.info(f"Database engine created for: {database_url.split('@')[0]}@...")