From 71153a1ff52d49dfc67307045725849f6d8a4897 Mon Sep 17 00:00:00 2001 From: Samir Boulahtit Date: Tue, 9 Sep 2025 21:27:58 +0200 Subject: [PATCH] Refactoring code for modular approach --- .idea/Letzshop-Import-v2.iml | 2 +- .idea/misc.xml | 2 +- app/__init__.py | 0 app/api/__init__.py | 0 app/api/deps.py | 40 + app/api/main.py | 13 + app/api/v1/__init__.py | 0 app/api/v1/admin.py | 153 ++++ app/api/v1/auth.py | 70 ++ app/api/v1/marketplace.py | 146 ++++ app/api/v1/products.py | 261 ++++++ app/api/v1/shops.py | 188 +++++ app/api/v1/stats.py | 84 ++ app/api/v1/stock.py | 315 +++++++ app/core/__init__.py | 0 app/core/config.py | 26 + app/core/database.py | 21 + app/core/lifespan.py | 48 ++ app/core/logging.py | 42 + app/services/__init__.py | 0 app/services/import_service.py | 0 app/services/product_service.py | 79 ++ app/services/stock_service.py | 0 app/tasks/__init__.py | 0 app/tasks/background_tasks.py | 63 ++ comprehensive_readme.md | 493 +++++++++++ config/settings.py | 1 + main.py | 1358 +------------------------------ middleware/decorators.py | 29 + tests/Makefile | 57 ++ tests/__init__.py | 2 + tests/conftest.py | 194 +++++ tests/pytest.ini | 21 + tests/requirements_test.txt | 8 + tests/test_admin.py | 34 + tests/test_auth.py | 119 +++ tests/test_background_tasks.py | 83 ++ tests/test_csv_processor.py | 90 ++ tests/test_data_validation.py | 46 ++ tests/test_database.py | 98 +++ tests/test_error_handling.py | 45 + tests/test_export.py | 65 ++ tests/test_filtering.py | 85 ++ tests/test_integration.py | 117 +++ tests/test_marketplace.py | 52 ++ tests/test_middleware.py | 63 ++ tests/test_pagination.py | 56 ++ tests/test_performance.py | 56 ++ tests/test_products.py | 122 +++ tests/test_security.py | 61 ++ tests/test_services.py | 60 ++ tests/test_shops.py | 55 ++ tests/test_stats.py | 33 + tests/test_stock.py | 147 ++++ tests/test_utils.py | 77 +- 55 files changed, 3928 insertions(+), 1352 deletions(-) create mode 100644 app/__init__.py create mode 100644 app/api/__init__.py create mode 100644 app/api/deps.py create mode 100644 app/api/main.py create mode 100644 app/api/v1/__init__.py create mode 100644 app/api/v1/admin.py create mode 100644 app/api/v1/auth.py create mode 100644 app/api/v1/marketplace.py create mode 100644 app/api/v1/products.py create mode 100644 app/api/v1/shops.py create mode 100644 app/api/v1/stats.py create mode 100644 app/api/v1/stock.py create mode 100644 app/core/__init__.py create mode 100644 app/core/config.py create mode 100644 app/core/database.py create mode 100644 app/core/lifespan.py create mode 100644 app/core/logging.py create mode 100644 app/services/__init__.py create mode 100644 app/services/import_service.py create mode 100644 app/services/product_service.py create mode 100644 app/services/stock_service.py create mode 100644 app/tasks/__init__.py create mode 100644 app/tasks/background_tasks.py create mode 100644 comprehensive_readme.md create mode 100644 middleware/decorators.py create mode 100644 tests/Makefile create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/pytest.ini create mode 100644 tests/requirements_test.txt create mode 100644 tests/test_admin.py create mode 100644 tests/test_auth.py create mode 100644 tests/test_background_tasks.py create mode 100644 tests/test_csv_processor.py create mode 100644 tests/test_data_validation.py create mode 100644 tests/test_database.py create mode 100644 tests/test_error_handling.py create mode 100644 tests/test_export.py create mode 100644 tests/test_filtering.py create mode 100644 tests/test_integration.py create mode 100644 tests/test_marketplace.py create mode 100644 tests/test_middleware.py create mode 100644 tests/test_pagination.py create mode 100644 tests/test_performance.py create mode 100644 tests/test_products.py create mode 100644 tests/test_security.py create mode 100644 tests/test_services.py create mode 100644 tests/test_shops.py create mode 100644 tests/test_stats.py create mode 100644 tests/test_stock.py diff --git a/.idea/Letzshop-Import-v2.iml b/.idea/Letzshop-Import-v2.iml index 72ead01c..fc0b0f27 100644 --- a/.idea/Letzshop-Import-v2.iml +++ b/.idea/Letzshop-Import-v2.iml @@ -4,7 +4,7 @@ - + diff --git a/.idea/misc.xml b/.idea/misc.xml index 5082f5ca..174cf50e 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -3,5 +3,5 @@ - + \ No newline at end of file diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/api/__init__.py b/app/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/api/deps.py b/app/api/deps.py new file mode 100644 index 00000000..8fd32d3c --- /dev/null +++ b/app/api/deps.py @@ -0,0 +1,40 @@ +from fastapi import Depends, HTTPException +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from sqlalchemy.orm import Session +from app.core.database import get_db +from models.database_models import User, Shop +from middleware.auth import AuthManager +from middleware.rate_limiter import RateLimiter + +security = HTTPBearer() +auth_manager = AuthManager() +rate_limiter = RateLimiter() + + +def get_current_user( + credentials: HTTPAuthorizationCredentials = Depends(security), + db: Session = Depends(get_db) +): + """Get current authenticated user""" + return auth_manager.get_current_user(db, credentials) + + +def get_current_admin_user(current_user: User = Depends(get_current_user)): + """Require admin user""" + return auth_manager.require_admin(current_user) + + +def get_user_shop( + shop_code: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """Get shop and verify user ownership""" + shop = db.query(Shop).filter(Shop.shop_code == shop_code.upper()).first() + if not shop: + raise HTTPException(status_code=404, detail="Shop not found") + + if current_user.role != "admin" and shop.owner_id != current_user.id: + raise HTTPException(status_code=403, detail="Access denied to this shop") + + return shop diff --git a/app/api/main.py b/app/api/main.py new file mode 100644 index 00000000..0cbc0f7f --- /dev/null +++ b/app/api/main.py @@ -0,0 +1,13 @@ +from fastapi import APIRouter +from app.api.v1 import auth, products, stock, shops, marketplace, admin, stats + +api_router = APIRouter() + +# Include all route modules +api_router.include_router(auth.router, prefix="/auth", tags=["authentication"]) +api_router.include_router(products.router, prefix="/products", tags=["products"]) +api_router.include_router(stock.router, prefix="/stock", tags=["stock"]) +api_router.include_router(shops.router, prefix="/shops", tags=["shops"]) +api_router.include_router(marketplace.router, prefix="/marketplace", tags=["marketplace"]) +api_router.include_router(admin.router, prefix="/admin", tags=["admin"]) +api_router.include_router(stats.router, prefix="/stats", tags=["statistics"]) diff --git a/app/api/v1/__init__.py b/app/api/v1/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/api/v1/admin.py b/app/api/v1/admin.py new file mode 100644 index 00000000..1ccf3efa --- /dev/null +++ b/app/api/v1/admin.py @@ -0,0 +1,153 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks +from sqlalchemy.orm import Session +from app.core.database import get_db +from app.api.deps import get_current_user +from app.tasks.background_tasks import process_marketplace_import +from middleware.decorators import rate_limit +from models.api_models import MarketplaceImportJobResponse, MarketplaceImportRequest +from models.database_models import User, MarketplaceImportJob, Shop +from datetime import datetime +import logging + +router = APIRouter() +logger = logging.getLogger(__name__) + + +# Admin-only routes +@router.get("/admin/users", response_model=List[UserResponse]) +def get_all_users( + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=1000), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_user) +): + """Get all users (Admin only)""" + users = db.query(User).offset(skip).limit(limit).all() + return [UserResponse.model_validate(user) for user in users] + + +@router.put("/admin/users/{user_id}/status") +def toggle_user_status( + user_id: int, + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_user) +): + """Toggle user active status (Admin only)""" + user = db.query(User).filter(User.id == user_id).first() + if not user: + raise HTTPException(status_code=404, detail="User not found") + + if user.id == current_admin.id: + raise HTTPException(status_code=400, detail="Cannot deactivate your own account") + + user.is_active = not user.is_active + user.updated_at = datetime.utcnow() + db.commit() + db.refresh(user) + + status = "activated" if user.is_active else "deactivated" + return {"message": f"User {user.username} has been {status}"} + + +@router.get("/admin/shops", response_model=ShopListResponse) +def get_all_shops_admin( + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=1000), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_user) +): + """Get all shops with admin view (Admin only)""" + total = db.query(Shop).count() + shops = db.query(Shop).offset(skip).limit(limit).all() + + return ShopListResponse( + shops=shops, + total=total, + skip=skip, + limit=limit + ) + + +@router.put("/admin/shops/{shop_id}/verify") +def verify_shop( + shop_id: int, + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_user) +): + """Verify/unverify shop (Admin only)""" + shop = db.query(Shop).filter(Shop.id == shop_id).first() + if not shop: + raise HTTPException(status_code=404, detail="Shop not found") + + shop.is_verified = not shop.is_verified + shop.updated_at = datetime.utcnow() + db.commit() + db.refresh(shop) + + status = "verified" if shop.is_verified else "unverified" + return {"message": f"Shop {shop.shop_code} has been {status}"} + + +@router.put("/admin/shops/{shop_id}/status") +def toggle_shop_status( + shop_id: int, + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_user) +): + """Toggle shop active status (Admin only)""" + shop = db.query(Shop).filter(Shop.id == shop_id).first() + if not shop: + raise HTTPException(status_code=404, detail="Shop not found") + + shop.is_active = not shop.is_active + shop.updated_at = datetime.utcnow() + db.commit() + db.refresh(shop) + + status = "activated" if shop.is_active else "deactivated" + return {"message": f"Shop {shop.shop_code} has been {status}"} + + +@router.get("/admin/marketplace-import-jobs", response_model=List[MarketplaceImportJobResponse]) +def get_all_marketplace_import_jobs( + marketplace: Optional[str] = Query(None), + shop_name: Optional[str] = Query(None), + status: Optional[str] = Query(None), + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=100), + db: Session = Depends(get_db), + current_admin: User = Depends(get_current_admin_user) +): + """Get all marketplace import jobs (Admin only)""" + + query = db.query(MarketplaceImportJob) + + # Apply filters + if marketplace: + query = query.filter(MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%")) + if shop_name: + query = query.filter(MarketplaceImportJob.shop_name.ilike(f"%{shop_name}%")) + if status: + query = query.filter(MarketplaceImportJob.status == status) + + # Order by creation date and apply pagination + jobs = query.order_by(MarketplaceImportJob.created_at.desc()).offset(skip).limit(limit).all() + + return [ + MarketplaceImportJobResponse( + job_id=job.id, + status=job.status, + marketplace=job.marketplace, + shop_name=job.shop_name, + imported=job.imported_count or 0, + updated=job.updated_count or 0, + total_processed=job.total_processed or 0, + error_count=job.error_count or 0, + error_message=job.error_message, + created_at=job.created_at, + started_at=job.started_at, + completed_at=job.completed_at + ) for job in jobs + ] diff --git a/app/api/v1/auth.py b/app/api/v1/auth.py new file mode 100644 index 00000000..40623365 --- /dev/null +++ b/app/api/v1/auth.py @@ -0,0 +1,70 @@ +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.orm import Session +from app.core.database import get_db +from app.api.deps import get_current_user +from models.api_models import UserRegister, UserLogin, UserResponse, LoginResponse +from models.database_models import User +from middleware.auth import AuthManager +import logging + +router = APIRouter() +auth_manager = AuthManager() +logger = logging.getLogger(__name__) + + +# Authentication Routes +@router.post("/register", response_model=UserResponse) +def register_user(user_data: UserRegister, db: Session = Depends(get_db)): + """Register a new user""" + # Check if email already exists + existing_email = db.query(User).filter(User.email == user_data.email).first() + if existing_email: + raise HTTPException(status_code=400, detail="Email already registered") + + # Check if username already exists + existing_username = db.query(User).filter(User.username == user_data.username).first() + if existing_username: + raise HTTPException(status_code=400, detail="Username already taken") + + # Hash password and create user + hashed_password = auth_manager.hash_password(user_data.password) + new_user = User( + email=user_data.email, + username=user_data.username, + hashed_password=hashed_password, + role="user", + is_active=True + ) + + db.add(new_user) + db.commit() + db.refresh(new_user) + + logger.info(f"New user registered: {new_user.username}") + return new_user + + +@router.post("/login", response_model=LoginResponse) +def login_user(user_credentials: UserLogin, db: Session = Depends(get_db)): + """Login user and return JWT token""" + user = auth_manager.authenticate_user(db, user_credentials.username, user_credentials.password) + if not user: + raise HTTPException(status_code=401, detail="Incorrect username or password") + + # Create access token + token_data = auth_manager.create_access_token(user) + + logger.info(f"User logged in: {user.username}") + + return LoginResponse( + access_token=token_data["access_token"], + token_type=token_data["token_type"], + expires_in=token_data["expires_in"], + user=UserResponse.model_validate(user) + ) + + +@router.get("/me", response_model=UserResponse) +def get_current_user_info(current_user: User = Depends(get_current_user)): + """Get current user information""" + return UserResponse.model_validate(current_user) diff --git a/app/api/v1/marketplace.py b/app/api/v1/marketplace.py new file mode 100644 index 00000000..f9849ba9 --- /dev/null +++ b/app/api/v1/marketplace.py @@ -0,0 +1,146 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks +from sqlalchemy.orm import Session +from app.core.database import get_db +from app.api.deps import get_current_user +from app.tasks.background_tasks import process_marketplace_import +from middleware.decorators import rate_limit +from models.api_models import MarketplaceImportJobResponse, MarketplaceImportRequest +from models.database_models import User, MarketplaceImportJob, Shop +from datetime import datetime +import logging + +router = APIRouter() +logger = logging.getLogger(__name__) + + +# Marketplace Import Routes (Protected) +@router.post("/import-from-marketplace", response_model=MarketplaceImportJobResponse) +@rate_limit(max_requests=10, window_seconds=3600) # Limit marketplace imports +async def import_products_from_marketplace( + request: MarketplaceImportRequest, + background_tasks: BackgroundTasks, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """Import products from marketplace CSV with background processing (Protected)""" + + logger.info( + f"Starting marketplace import: {request.marketplace} -> {request.shop_code} by user {current_user.username}") + + # Verify shop exists and user has access + shop = db.query(Shop).filter(Shop.shop_code == request.shop_code).first() + if not shop: + raise HTTPException(status_code=404, detail="Shop not found") + + # Check permissions: admin can import for any shop, others only for their own + if current_user.role != "admin" and shop.owner_id != current_user.id: + raise HTTPException(status_code=403, detail="Access denied to this shop") + + # Create marketplace import job record + import_job = MarketplaceImportJob( + status="pending", + source_url=request.url, + marketplace=request.marketplace, + shop_code=request.shop_code, + user_id=current_user.id, + created_at=datetime.utcnow() + ) + db.add(import_job) + db.commit() + db.refresh(import_job) + + # Process in background + background_tasks.add_task( + process_marketplace_import, + import_job.id, + request.url, + request.marketplace, + request.shop_code, + request.batch_size or 1000 + ) + + return MarketplaceImportJobResponse( + job_id=import_job.id, + status="pending", + marketplace=request.marketplace, + shop_code=request.shop_code, + message=f"Marketplace import started from {request.marketplace}. Check status with " + f"/marketplace-import-status/{import_job.id}" + ) + + +@router.get("/marketplace-import-status/{job_id}", response_model=MarketplaceImportJobResponse) +def get_marketplace_import_status( + job_id: int, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """Get status of marketplace import job (Protected)""" + job = db.query(MarketplaceImportJob).filter(MarketplaceImportJob.id == job_id).first() + if not job: + raise HTTPException(status_code=404, detail="Marketplace import job not found") + + # Users can only see their own jobs, admins can see all + if current_user.role != "admin" and job.user_id != current_user.id: + raise HTTPException(status_code=403, detail="Access denied to this import job") + + return MarketplaceImportJobResponse( + job_id=job.id, + status=job.status, + marketplace=job.marketplace, + shop_name=job.shop_name, + imported=job.imported_count or 0, + updated=job.updated_count or 0, + total_processed=job.total_processed or 0, + error_count=job.error_count or 0, + error_message=job.error_message, + created_at=job.created_at, + started_at=job.started_at, + completed_at=job.completed_at + ) + + +@router.get("/marketplace-import-jobs", response_model=List[MarketplaceImportJobResponse]) +def get_marketplace_import_jobs( + marketplace: Optional[str] = Query(None, description="Filter by marketplace"), + shop_name: Optional[str] = Query(None, description="Filter by shop name"), + skip: int = Query(0, ge=0), + limit: int = Query(50, ge=1, le=100), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """Get marketplace import jobs with filtering (Protected)""" + + query = db.query(MarketplaceImportJob) + + # Users can only see their own jobs, admins can see all + if current_user.role != "admin": + query = query.filter(MarketplaceImportJob.user_id == current_user.id) + + # Apply filters + if marketplace: + query = query.filter(MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%")) + if shop_name: + query = query.filter(MarketplaceImportJob.shop_name.ilike(f"%{shop_name}%")) + + # Order by creation date (newest first) and apply pagination + jobs = query.order_by(MarketplaceImportJob.created_at.desc()).offset(skip).limit(limit).all() + + return [ + MarketplaceImportJobResponse( + job_id=job.id, + status=job.status, + marketplace=job.marketplace, + shop_name=job.shop_name, + imported=job.imported_count or 0, + updated=job.updated_count or 0, + total_processed=job.total_processed or 0, + error_count=job.error_count or 0, + error_message=job.error_message, + created_at=job.created_at, + started_at=job.started_at, + completed_at=job.completed_at + ) for job in jobs + ] diff --git a/app/api/v1/products.py b/app/api/v1/products.py new file mode 100644 index 00000000..d2147d4f --- /dev/null +++ b/app/api/v1/products.py @@ -0,0 +1,261 @@ +from typing import Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi.responses import StreamingResponse +from sqlalchemy.orm import Session +from app.core.database import get_db +from app.api.deps import get_current_user +from models.api_models import (ProductListResponse, ProductResponse, ProductCreate, ProductDetailResponse, + StockLocationResponse, StockSummaryResponse, ProductUpdate) +from models.database_models import User, Product, Stock +from datetime import datetime +import logging + +from utils.data_processing import GTINProcessor, PriceProcessor + +router = APIRouter() +logger = logging.getLogger(__name__) + +# Initialize processors +gtin_processor = GTINProcessor() +price_processor = PriceProcessor() + + +# Enhanced Product Routes with Marketplace Support +@router.get("/products", response_model=ProductListResponse) +def get_products( + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=1000), + brand: Optional[str] = Query(None), + category: Optional[str] = Query(None), + availability: Optional[str] = Query(None), + marketplace: Optional[str] = Query(None, description="Filter by marketplace"), + shop_name: Optional[str] = Query(None, description="Filter by shop name"), + search: Optional[str] = Query(None), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """Get products with advanced filtering including marketplace and shop (Protected)""" + + query = db.query(Product) + + # Apply filters + if brand: + query = query.filter(Product.brand.ilike(f"%{brand}%")) + if category: + query = query.filter(Product.google_product_category.ilike(f"%{category}%")) + if availability: + query = query.filter(Product.availability == availability) + if marketplace: + query = query.filter(Product.marketplace.ilike(f"%{marketplace}%")) + if shop_name: + query = query.filter(Product.shop_name.ilike(f"%{shop_name}%")) + if search: + # Search in title, description, and marketplace + search_term = f"%{search}%" + query = query.filter( + (Product.title.ilike(search_term)) | + (Product.description.ilike(search_term)) | + (Product.marketplace.ilike(search_term)) | + (Product.shop_name.ilike(search_term)) + ) + + # Get total count for pagination + total = query.count() + + # Apply pagination + products = query.offset(skip).limit(limit).all() + + return ProductListResponse( + products=products, + total=total, + skip=skip, + limit=limit + ) + + +@router.post("/products", response_model=ProductResponse) +def create_product( + product: ProductCreate, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """Create a new product with validation and marketplace support (Protected)""" + + # Check if product_id already exists + existing = db.query(Product).filter(Product.product_id == product.product_id).first() + if existing: + raise HTTPException(status_code=400, detail="Product with this ID already exists") + + # Process and validate GTIN if provided + if product.gtin: + normalized_gtin = gtin_processor.normalize(product.gtin) + if not normalized_gtin: + raise HTTPException(status_code=400, detail="Invalid GTIN format") + product.gtin = normalized_gtin + + # Process price if provided + if product.price: + parsed_price, currency = price_processor.parse_price_currency(product.price) + if parsed_price: + product.price = parsed_price + product.currency = currency + + # Set default marketplace if not provided + if not product.marketplace: + product.marketplace = "Letzshop" + + db_product = Product(**product.dict()) + db.add(db_product) + db.commit() + db.refresh(db_product) + + logger.info( + f"Created product {db_product.product_id} for marketplace {db_product.marketplace}, " + f"shop {db_product.shop_name}") + return db_product + + +@router.get("/products/{product_id}", response_model=ProductDetailResponse) +def get_product(product_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): + """Get product with stock information (Protected)""" + + product = db.query(Product).filter(Product.product_id == product_id).first() + if not product: + raise HTTPException(status_code=404, detail="Product not found") + + # Get stock information if GTIN exists + stock_info = None + if product.gtin: + stock_entries = db.query(Stock).filter(Stock.gtin == product.gtin).all() + if stock_entries: + total_quantity = sum(entry.quantity for entry in stock_entries) + locations = [ + StockLocationResponse(location=entry.location, quantity=entry.quantity) + for entry in stock_entries + ] + stock_info = StockSummaryResponse( + gtin=product.gtin, + total_quantity=total_quantity, + locations=locations + ) + + return ProductDetailResponse( + product=product, + stock_info=stock_info + ) + + +@router.put("/products/{product_id}", response_model=ProductResponse) +def update_product( + product_id: str, + product_update: ProductUpdate, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """Update product with validation and marketplace support (Protected)""" + + product = db.query(Product).filter(Product.product_id == product_id).first() + if not product: + raise HTTPException(status_code=404, detail="Product not found") + + # Update fields + update_data = product_update.dict(exclude_unset=True) + + # Validate GTIN if being updated + if "gtin" in update_data and update_data["gtin"]: + normalized_gtin = gtin_processor.normalize(update_data["gtin"]) + if not normalized_gtin: + raise HTTPException(status_code=400, detail="Invalid GTIN format") + update_data["gtin"] = normalized_gtin + + # Process price if being updated + if "price" in update_data and update_data["price"]: + parsed_price, currency = price_processor.parse_price_currency(update_data["price"]) + if parsed_price: + update_data["price"] = parsed_price + update_data["currency"] = currency + + for key, value in update_data.items(): + setattr(product, key, value) + + product.updated_at = datetime.utcnow() + db.commit() + db.refresh(product) + + return product + + +@router.delete("/products/{product_id}") +def delete_product( + product_id: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """Delete product and associated stock (Protected)""" + + product = db.query(Product).filter(Product.product_id == product_id).first() + if not product: + raise HTTPException(status_code=404, detail="Product not found") + + # Delete associated stock entries if GTIN exists + if product.gtin: + db.query(Stock).filter(Stock.gtin == product.gtin).delete() + + db.delete(product) + db.commit() + + return {"message": "Product and associated stock deleted successfully"} + +# Export with streaming for large datasets (Protected) +@router.get("/export-csv") +async def export_csv( + marketplace: Optional[str] = Query(None, description="Filter by marketplace"), + shop_name: Optional[str] = Query(None, description="Filter by shop name"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """Export products as CSV with streaming and marketplace filtering (Protected)""" + + def generate_csv(): + # Stream CSV generation for memory efficiency + yield "product_id,title,description,link,image_link,availability,price,currency,brand,gtin,marketplace,shop_name\n" + + batch_size = 1000 + offset = 0 + + while True: + query = db.query(Product) + + # Apply marketplace filters + if marketplace: + query = query.filter(Product.marketplace.ilike(f"%{marketplace}%")) + if shop_name: + query = query.filter(Product.shop_name.ilike(f"%{shop_name}%")) + + products = query.offset(offset).limit(batch_size).all() + if not products: + break + + for product in products: + # Create CSV row with marketplace fields + row = (f'"{product.product_id}","{product.title or ""}","{product.description or ""}",' + f'"{product.link or ""}","{product.image_link or ""}","{product.availability or ""}",' + f'"{product.price or ""}","{product.currency or ""}","{product.brand or ""}",' + f'"{product.gtin or ""}","{product.marketplace or ""}","{product.shop_name or ""}"\n') + yield row + + offset += batch_size + + filename = "products_export" + if marketplace: + filename += f"_{marketplace}" + if shop_name: + filename += f"_{shop_name}" + filename += ".csv" + + return StreamingResponse( + generate_csv(), + media_type="text/csv", + headers={"Content-Disposition": f"attachment; filename={filename}"} + ) diff --git a/app/api/v1/shops.py b/app/api/v1/shops.py new file mode 100644 index 00000000..4e5a990d --- /dev/null +++ b/app/api/v1/shops.py @@ -0,0 +1,188 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks +from sqlalchemy.orm import Session +from app.core.database import get_db +from app.api.deps import get_current_user +from app.tasks.background_tasks import process_marketplace_import +from middleware.decorators import rate_limit +from models.api_models import MarketplaceImportJobResponse, MarketplaceImportRequest +from models.database_models import User, MarketplaceImportJob, Shop +from datetime import datetime +import logging + +router = APIRouter() +logger = logging.getLogger(__name__) + + +# Shop Management Routes +@router.post("/shops", response_model=ShopResponse) +def create_shop( + shop_data: ShopCreate, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """Create a new shop (Protected)""" + # Check if shop code already exists + existing_shop = db.query(Shop).filter(Shop.shop_code == shop_data.shop_code).first() + if existing_shop: + raise HTTPException(status_code=400, detail="Shop code already exists") + + # Create shop + new_shop = Shop( + **shop_data.dict(), + owner_id=current_user.id, + is_active=True, + is_verified=(current_user.role == "admin") # Auto-verify if admin creates shop + ) + + db.add(new_shop) + db.commit() + db.refresh(new_shop) + + logger.info(f"New shop created: {new_shop.shop_code} by {current_user.username}") + return new_shop + + +@router.get("/shops", response_model=ShopListResponse) +def get_shops( + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=1000), + active_only: bool = Query(True), + verified_only: bool = Query(False), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """Get shops with filtering (Protected)""" + query = db.query(Shop) + + # Non-admin users can only see active and verified shops, plus their own + if current_user.role != "admin": + query = query.filter( + (Shop.is_active == True) & + ((Shop.is_verified == True) | (Shop.owner_id == current_user.id)) + ) + else: + # Admin can apply filters + if active_only: + query = query.filter(Shop.is_active == True) + if verified_only: + query = query.filter(Shop.is_verified == True) + + total = query.count() + shops = query.offset(skip).limit(limit).all() + + return ShopListResponse( + shops=shops, + total=total, + skip=skip, + limit=limit + ) + + +@router.get("/shops/{shop_code}", response_model=ShopResponse) +def get_shop(shop_code: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): + """Get shop details (Protected)""" + shop = db.query(Shop).filter(Shop.shop_code == shop_code.upper()).first() + if not shop: + raise HTTPException(status_code=404, detail="Shop not found") + + # Non-admin users can only see active verified shops or their own shops + if current_user.role != "admin": + if not shop.is_active or (not shop.is_verified and shop.owner_id != current_user.id): + raise HTTPException(status_code=404, detail="Shop not found") + + return shop + + +# Shop Product Management +@router.post("/shops/{shop_code}/products", response_model=ShopProductResponse) +def add_product_to_shop( + shop_code: str, + shop_product: ShopProductCreate, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """Add existing product to shop catalog with shop-specific settings (Protected)""" + + # Get and verify shop + shop = get_user_shop(shop_code, current_user, db) + + # Check if product exists + product = db.query(Product).filter(Product.product_id == shop_product.product_id).first() + if not product: + raise HTTPException(status_code=404, detail="Product not found in marketplace catalog") + + # Check if product already in shop + existing_shop_product = db.query(ShopProduct).filter( + ShopProduct.shop_id == shop.id, + ShopProduct.product_id == product.id + ).first() + + if existing_shop_product: + raise HTTPException(status_code=400, detail="Product already in shop catalog") + + # Create shop-product association + new_shop_product = ShopProduct( + shop_id=shop.id, + product_id=product.id, + **shop_product.dict(exclude={'product_id'}) + ) + + db.add(new_shop_product) + db.commit() + db.refresh(new_shop_product) + + # Return with product details + response = ShopProductResponse.model_validate(new_shop_product) + response.product = product + return response + + +@router.get("/shops/{shop_code}/products") +def get_shop_products( + shop_code: str, + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=1000), + active_only: bool = Query(True), + featured_only: bool = Query(False), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """Get products in shop catalog (Protected)""" + + # Get shop (public can view active/verified shops) + shop = db.query(Shop).filter(Shop.shop_code == shop_code.upper()).first() + if not shop: + raise HTTPException(status_code=404, detail="Shop not found") + + # Non-owners can only see active verified shops + if current_user.role != "admin" and shop.owner_id != current_user.id: + if not shop.is_active or not shop.is_verified: + raise HTTPException(status_code=404, detail="Shop not found") + + # Query shop products + query = db.query(ShopProduct).filter(ShopProduct.shop_id == shop.id) + + if active_only: + query = query.filter(ShopProduct.is_active == True) + if featured_only: + query = query.filter(ShopProduct.is_featured == True) + + total = query.count() + shop_products = query.offset(skip).limit(limit).all() + + # Format response + products = [] + for sp in shop_products: + product_response = ShopProductResponse.model_validate(sp) + product_response.product = sp.product + products.append(product_response) + + return { + "products": products, + "total": total, + "skip": skip, + "limit": limit, + "shop": ShopResponse.model_validate(shop) + } diff --git a/app/api/v1/stats.py b/app/api/v1/stats.py new file mode 100644 index 00000000..5bc96e15 --- /dev/null +++ b/app/api/v1/stats.py @@ -0,0 +1,84 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks +from sqlalchemy.orm import Session +from app.core.database import get_db +from app.api.deps import get_current_user +from app.tasks.background_tasks import process_marketplace_import +from middleware.decorators import rate_limit +from models.api_models import MarketplaceImportJobResponse, MarketplaceImportRequest +from models.database_models import User, MarketplaceImportJob, Shop +from datetime import datetime +import logging + +router = APIRouter() +logger = logging.getLogger(__name__) + + +# Enhanced Statistics with Marketplace Support +@router.get("/stats", response_model=StatsResponse) +def get_stats(db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): + """Get comprehensive statistics with marketplace data (Protected)""" + + # Use more efficient queries with proper indexes + total_products = db.query(Product).count() + + unique_brands = db.query(Product.brand).filter( + Product.brand.isnot(None), + Product.brand != "" + ).distinct().count() + + unique_categories = db.query(Product.google_product_category).filter( + Product.google_product_category.isnot(None), + Product.google_product_category != "" + ).distinct().count() + + # New marketplace statistics + unique_marketplaces = db.query(Product.marketplace).filter( + Product.marketplace.isnot(None), + Product.marketplace != "" + ).distinct().count() + + unique_shops = db.query(Product.shop_name).filter( + Product.shop_name.isnot(None), + Product.shop_name != "" + ).distinct().count() + + # Stock statistics + total_stock_entries = db.query(Stock).count() + total_inventory = db.query(func.sum(Stock.quantity)).scalar() or 0 + + return StatsResponse( + total_products=total_products, + unique_brands=unique_brands, + unique_categories=unique_categories, + unique_marketplaces=unique_marketplaces, + unique_shops=unique_shops, + total_stock_entries=total_stock_entries, + total_inventory_quantity=total_inventory + ) + + +@router.get("/marketplace-stats", response_model=List[MarketplaceStatsResponse]) +def get_marketplace_stats(db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): + """Get statistics broken down by marketplace (Protected)""" + + # Query to get stats per marketplace + marketplace_stats = db.query( + Product.marketplace, + func.count(Product.id).label('total_products'), + func.count(func.distinct(Product.shop_name)).label('unique_shops'), + func.count(func.distinct(Product.brand)).label('unique_brands') + ).filter( + Product.marketplace.isnot(None) + ).group_by(Product.marketplace).all() + + return [ + MarketplaceStatsResponse( + marketplace=stat.marketplace, + total_products=stat.total_products, + unique_shops=stat.unique_shops, + unique_brands=stat.unique_brands + ) for stat in marketplace_stats + ] + diff --git a/app/api/v1/stock.py b/app/api/v1/stock.py new file mode 100644 index 00000000..d60247e0 --- /dev/null +++ b/app/api/v1/stock.py @@ -0,0 +1,315 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks +from sqlalchemy.orm import Session +from app.core.database import get_db +from app.api.deps import get_current_user +from app.tasks.background_tasks import process_marketplace_import +from middleware.decorators import rate_limit +from models.api_models import MarketplaceImportJobResponse, MarketplaceImportRequest, StockResponse, \ + StockSummaryResponse +from models.database_models import User, MarketplaceImportJob, Shop +from datetime import datetime +import logging + +router = APIRouter() +logger = logging.getLogger(__name__) + + +# Stock Management Routes (Protected) + +@router.post("/stock", response_model=StockResponse) +def set_stock(stock: StockCreate, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): + """Set exact stock quantity for a GTIN at a specific location (replaces existing quantity)""" + + # Normalize GTIN + def normalize_gtin(gtin_value): + if not gtin_value: + return None + gtin_str = str(gtin_value).strip() + if '.' in gtin_str: + gtin_str = gtin_str.split('.')[0] + gtin_clean = ''.join(filter(str.isdigit, gtin_str)) + if len(gtin_clean) in [8, 12, 13, 14]: + return gtin_clean.zfill(13) if len(gtin_clean) == 13 else gtin_clean.zfill(12) + return gtin_clean if gtin_clean else None + + normalized_gtin = normalize_gtin(stock.gtin) + if not normalized_gtin: + raise HTTPException(status_code=400, detail="Invalid GTIN format") + + # Check if stock entry already exists for this GTIN and location + existing_stock = db.query(Stock).filter( + Stock.gtin == normalized_gtin, + Stock.location == stock.location.strip().upper() + ).first() + + if existing_stock: + # Update existing stock (SET to exact quantity) + old_quantity = existing_stock.quantity + existing_stock.quantity = stock.quantity + existing_stock.updated_at = datetime.utcnow() + db.commit() + db.refresh(existing_stock) + logger.info(f"Updated stock for GTIN {normalized_gtin} at {stock.location}: {old_quantity} → {stock.quantity}") + return existing_stock + else: + # Create new stock entry + new_stock = Stock( + gtin=normalized_gtin, + location=stock.location.strip().upper(), + quantity=stock.quantity + ) + db.add(new_stock) + db.commit() + db.refresh(new_stock) + logger.info(f"Created new stock for GTIN {normalized_gtin} at {stock.location}: {stock.quantity}") + return new_stock + + +@router.post("/stock/add", response_model=StockResponse) +def add_stock(stock: StockAdd, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): + """Add quantity to existing stock for a GTIN at a specific location (adds to existing quantity)""" + + # Normalize GTIN + def normalize_gtin(gtin_value): + if not gtin_value: + return None + gtin_str = str(gtin_value).strip() + if '.' in gtin_str: + gtin_str = gtin_str.split('.')[0] + gtin_clean = ''.join(filter(str.isdigit, gtin_str)) + if len(gtin_clean) in [8, 12, 13, 14]: + return gtin_clean.zfill(13) if len(gtin_clean) == 13 else gtin_clean.zfill(12) + return gtin_clean if gtin_clean else None + + normalized_gtin = normalize_gtin(stock.gtin) + if not normalized_gtin: + raise HTTPException(status_code=400, detail="Invalid GTIN format") + + # Check if stock entry already exists for this GTIN and location + existing_stock = db.query(Stock).filter( + Stock.gtin == normalized_gtin, + Stock.location == stock.location.strip().upper() + ).first() + + if existing_stock: + # Add to existing stock + old_quantity = existing_stock.quantity + existing_stock.quantity += stock.quantity + existing_stock.updated_at = datetime.utcnow() + db.commit() + db.refresh(existing_stock) + logger.info( + f"Added stock for GTIN {normalized_gtin} at {stock.location}: {old_quantity} + {stock.quantity} = {existing_stock.quantity}") + return existing_stock + else: + # Create new stock entry with the quantity + new_stock = Stock( + gtin=normalized_gtin, + location=stock.location.strip().upper(), + quantity=stock.quantity + ) + db.add(new_stock) + db.commit() + db.refresh(new_stock) + logger.info(f"Created new stock for GTIN {normalized_gtin} at {stock.location}: {stock.quantity}") + return new_stock + + +@router.post("/stock/remove", response_model=StockResponse) +def remove_stock(stock: StockAdd, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): + """Remove quantity from existing stock for a GTIN at a specific location""" + + # Normalize GTIN + def normalize_gtin(gtin_value): + if not gtin_value: + return None + gtin_str = str(gtin_value).strip() + if '.' in gtin_str: + gtin_str = gtin_str.split('.')[0] + gtin_clean = ''.join(filter(str.isdigit, gtin_str)) + if len(gtin_clean) in [8, 12, 13, 14]: + return gtin_clean.zfill(13) if len(gtin_clean) == 13 else gtin_clean.zfill(12) + return gtin_clean if gtin_clean else None + + normalized_gtin = normalize_gtin(stock.gtin) + if not normalized_gtin: + raise HTTPException(status_code=400, detail="Invalid GTIN format") + + # Find existing stock entry + existing_stock = db.query(Stock).filter( + Stock.gtin == normalized_gtin, + Stock.location == stock.location.strip().upper() + ).first() + + if not existing_stock: + raise HTTPException( + status_code=404, + detail=f"No stock found for GTIN {normalized_gtin} at location {stock.location}" + ) + + # Check if we have enough stock to remove + if existing_stock.quantity < stock.quantity: + raise HTTPException( + status_code=400, + detail=f"Insufficient stock. Available: {existing_stock.quantity}, Requested to remove: {stock.quantity}" + ) + + # Remove from existing stock + old_quantity = existing_stock.quantity + existing_stock.quantity -= stock.quantity + existing_stock.updated_at = datetime.utcnow() + db.commit() + db.refresh(existing_stock) + logger.info( + f"Removed stock for GTIN {normalized_gtin} at {stock.location}: {old_quantity} - {stock.quantity} = {existing_stock.quantity}") + return existing_stock + + +@router.get("/stock/{gtin}", response_model=StockSummaryResponse) +def get_stock_by_gtin(gtin: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): + """Get all stock locations and total quantity for a specific GTIN""" + + # Normalize GTIN + def normalize_gtin(gtin_value): + if not gtin_value: + return None + gtin_str = str(gtin_value).strip() + if '.' in gtin_str: + gtin_str = gtin_str.split('.')[0] + gtin_clean = ''.join(filter(str.isdigit, gtin_str)) + if len(gtin_clean) in [8, 12, 13, 14]: + return gtin_clean.zfill(13) if len(gtin_clean) == 13 else gtin_clean.zfill(12) + return gtin_clean if gtin_clean else None + + normalized_gtin = normalize_gtin(gtin) + if not normalized_gtin: + raise HTTPException(status_code=400, detail="Invalid GTIN format") + + # Get all stock entries for this GTIN + stock_entries = db.query(Stock).filter(Stock.gtin == normalized_gtin).all() + + if not stock_entries: + raise HTTPException(status_code=404, detail=f"No stock found for GTIN: {gtin}") + + # Calculate total quantity and build locations list + total_quantity = 0 + locations = [] + + for entry in stock_entries: + total_quantity += entry.quantity + locations.append(StockLocationResponse( + location=entry.location, + quantity=entry.quantity + )) + + # Try to get product title for reference + product = db.query(Product).filter(Product.gtin == normalized_gtin).first() + product_title = product.title if product else None + + return StockSummaryResponse( + gtin=normalized_gtin, + total_quantity=total_quantity, + locations=locations, + product_title=product_title + ) + + +@router.get("/stock/{gtin}/total") +def get_total_stock(gtin: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): + """Get total quantity in stock for a specific GTIN""" + + # Normalize GTIN + def normalize_gtin(gtin_value): + if not gtin_value: + return None + gtin_str = str(gtin_value).strip() + if '.' in gtin_str: + gtin_str = gtin_str.split('.')[0] + gtin_clean = ''.join(filter(str.isdigit, gtin_str)) + if len(gtin_clean) in [8, 12, 13, 14]: + return gtin_clean.zfill(13) if len(gtin_clean) == 13 else gtin_clean.zfill(12) + return gtin_clean if gtin_clean else None + + normalized_gtin = normalize_gtin(gtin) + if not normalized_gtin: + raise HTTPException(status_code=400, detail="Invalid GTIN format") + + # Calculate total stock + total_stock = db.query(Stock).filter(Stock.gtin == normalized_gtin).all() + total_quantity = sum(entry.quantity for entry in total_stock) + + # Get product info for context + product = db.query(Product).filter(Product.gtin == normalized_gtin).first() + + return { + "gtin": normalized_gtin, + "total_quantity": total_quantity, + "product_title": product.title if product else None, + "locations_count": len(total_stock) + } + + +@router.get("/stock", response_model=List[StockResponse]) +def get_all_stock( + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=1000), + location: Optional[str] = Query(None, description="Filter by location"), + gtin: Optional[str] = Query(None, description="Filter by GTIN"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """Get all stock entries with optional filtering""" + query = db.query(Stock) + + if location: + query = query.filter(Stock.location.ilike(f"%{location}%")) + + if gtin: + # Normalize GTIN for search + def normalize_gtin(gtin_value): + if not gtin_value: + return None + gtin_str = str(gtin_value).strip() + if '.' in gtin_str: + gtin_str = gtin_str.split('.')[0] + gtin_clean = ''.join(filter(str.isdigit, gtin_str)) + if len(gtin_clean) in [8, 12, 13, 14]: + return gtin_clean.zfill(13) if len(gtin_clean) == 13 else gtin_clean.zfill(12) + return gtin_clean if gtin_clean else None + + normalized_gtin = normalize_gtin(gtin) + if normalized_gtin: + query = query.filter(Stock.gtin == normalized_gtin) + + stock_entries = query.offset(skip).limit(limit).all() + return stock_entries + + +@router.put("/stock/{stock_id}", response_model=StockResponse) +def update_stock(stock_id: int, stock_update: StockUpdate, db: Session = Depends(get_db), + current_user: User = Depends(get_current_user)): + """Update stock quantity for a specific stock entry""" + stock_entry = db.query(Stock).filter(Stock.id == stock_id).first() + if not stock_entry: + raise HTTPException(status_code=404, detail="Stock entry not found") + + stock_entry.quantity = stock_update.quantity + stock_entry.updated_at = datetime.utcnow() + db.commit() + db.refresh(stock_entry) + return stock_entry + + +@router.delete("/stock/{stock_id}") +def delete_stock(stock_id: int, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): + """Delete a stock entry""" + stock_entry = db.query(Stock).filter(Stock.id == stock_id).first() + if not stock_entry: + raise HTTPException(status_code=404, detail="Stock entry not found") + + db.delete(stock_entry) + db.commit() + return {"message": "Stock entry deleted successfully"} + diff --git a/app/core/__init__.py b/app/core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/core/config.py b/app/core/config.py new file mode 100644 index 00000000..aa7d08d1 --- /dev/null +++ b/app/core/config.py @@ -0,0 +1,26 @@ +from pydantic_settings import BaseSettings +from typing import List +import os + + +class Settings(BaseSettings): + PROJECT_NAME: str = "Ecommerce Backend API with Marketplace Support" + DESCRIPTION: str = "Advanced product management system with JWT authentication" + VERSION: str = "2.2.0" + + DATABASE_URL: str = os.getenv("DATABASE_URL", "postgresql://user:password@localhost/ecommerce") + SECRET_KEY: str = os.getenv("SECRET_KEY", "your-secret-key-change-in-production") + ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 + + ALLOWED_HOSTS: List[str] = ["*"] # Configure for production + + # Rate limiting + RATE_LIMIT_REQUESTS: int = 100 + RATE_LIMIT_WINDOW: int = 3600 + + class Config: + env_file = ".env" + extra = "ignore" + + +settings = Settings() diff --git a/app/core/database.py b/app/core/database.py new file mode 100644 index 00000000..85a32b28 --- /dev/null +++ b/app/core/database.py @@ -0,0 +1,21 @@ +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker, Session +from .config import settings + +engine = create_engine(settings.DATABASE_URL) +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +Base = declarative_base() + + +# Database dependency with connection pooling +def get_db(): + db = SessionLocal() + try: + yield db + except Exception as e: + db.rollback() + raise + finally: + db.close() diff --git a/app/core/lifespan.py b/app/core/lifespan.py new file mode 100644 index 00000000..76c4953a --- /dev/null +++ b/app/core/lifespan.py @@ -0,0 +1,48 @@ +from contextlib import asynccontextmanager +from fastapi import FastAPI +from sqlalchemy import text +from .logging import setup_logging +from .database import engine +from models.database_models import Base +import logging + +logger = logging.getLogger(__name__) + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Application lifespan events""" + # Startup + logger = setup_logging() # Configure logging first + logger.info("Starting up ecommerce API") + + # Create tables + Base.metadata.create_all(bind=engine) + + # Create indexes + create_indexes() + + yield + + # Shutdown + logger.info("Shutting down ecommerce API") + + +def create_indexes(): + """Create database indexes""" + with engine.connect() as conn: + try: + # User indexes + conn.execute(text("CREATE INDEX IF NOT EXISTS idx_user_email ON users(email)")) + conn.execute(text("CREATE INDEX IF NOT EXISTS idx_user_username ON users(username)")) + + # Product indexes + conn.execute(text("CREATE INDEX IF NOT EXISTS idx_product_gtin ON products(gtin)")) + conn.execute(text("CREATE INDEX IF NOT EXISTS idx_product_marketplace ON products(marketplace)")) + + # Stock indexes + conn.execute(text("CREATE INDEX IF NOT EXISTS idx_stock_gtin_location ON stock(gtin, location)")) + + conn.commit() + logger.info("Database indexes created successfully") + except Exception as e: + logger.warning(f"Index creation warning: {e}") diff --git a/app/core/logging.py b/app/core/logging.py new file mode 100644 index 00000000..29cc8ea0 --- /dev/null +++ b/app/core/logging.py @@ -0,0 +1,42 @@ +# app/core/logging.py +import logging +import sys +from pathlib import Path +from app.core.config import settings + + +def setup_logging(): + """Configure application logging with file and console handlers""" + + # Create logs directory if it doesn't exist + log_file = Path(settings.LOG_FILE) + log_file.parent.mkdir(parents=True, exist_ok=True) + + # Configure root logger + logger = logging.getLogger() + logger.setLevel(getattr(logging, settings.LOG_LEVEL.upper())) + + # Remove existing handlers + for handler in logger.handlers[:]: + logger.removeHandler(handler) + + # Create formatters + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s' + ) + + # Console handler + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + + # File handler + file_handler = logging.FileHandler(log_file) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + # Configure specific loggers + logging.getLogger("uvicorn.access").setLevel(logging.WARNING) + logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING) + + return logging.getLogger(__name__) diff --git a/app/services/__init__.py b/app/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/services/import_service.py b/app/services/import_service.py new file mode 100644 index 00000000..e69de29b diff --git a/app/services/product_service.py b/app/services/product_service.py new file mode 100644 index 00000000..8a6b9f67 --- /dev/null +++ b/app/services/product_service.py @@ -0,0 +1,79 @@ +from sqlalchemy.orm import Session +from models.database_models import Product +from models.api_models import ProductCreate +from utils.data_processing import GTINProcessor, PriceProcessor +from typing import Optional, List +import logging + +logger = logging.getLogger(__name__) + + +class ProductService: + def __init__(self): + self.gtin_processor = GTINProcessor() + self.price_processor = PriceProcessor() + + def create_product(self, db: Session, product_data: ProductCreate) -> Product: + """Create a new product with validation""" + # Process and validate GTIN if provided + if product_data.gtin: + normalized_gtin = self.gtin_processor.normalize(product_data.gtin) + if not normalized_gtin: + raise ValueError("Invalid GTIN format") + product_data.gtin = normalized_gtin + + # Process price if provided + if product_data.price: + parsed_price, currency = self.price_processor.parse_price_currency(product_data.price) + if parsed_price: + product_data.price = parsed_price + product_data.currency = currency + + # Set default marketplace if not provided + if not product_data.marketplace: + product_data.marketplace = "Letzshop" + + db_product = Product(**product_data.dict()) + db.add(db_product) + db.commit() + db.refresh(db_product) + + logger.info(f"Created product {db_product.product_id}") + return db_product + + def get_products_with_filters( + self, + db: Session, + skip: int = 0, + limit: int = 100, + brand: Optional[str] = None, + category: Optional[str] = None, + marketplace: Optional[str] = None, + search: Optional[str] = None + ) -> tuple[List[Product], int]: + """Get products with filtering and pagination""" + query = db.query(Product) + + # Apply filters + if brand: + query = query.filter(Product.brand.ilike(f"%{brand}%")) + if category: + query = query.filter(Product.google_product_category.ilike(f"%{category}%")) + if marketplace: + query = query.filter(Product.marketplace.ilike(f"%{marketplace}%")) + if search: + search_term = f"%{search}%" + query = query.filter( + (Product.title.ilike(search_term)) | + (Product.description.ilike(search_term)) | + (Product.marketplace.ilike(search_term)) + ) + + total = query.count() + products = query.offset(skip).limit(limit).all() + + return products, total + + +# Create service instance +product_service = ProductService() diff --git a/app/services/stock_service.py b/app/services/stock_service.py new file mode 100644 index 00000000..e69de29b diff --git a/app/tasks/__init__.py b/app/tasks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/tasks/background_tasks.py b/app/tasks/background_tasks.py new file mode 100644 index 00000000..708f6f3b --- /dev/null +++ b/app/tasks/background_tasks.py @@ -0,0 +1,63 @@ +from sqlalchemy.orm import Session +from app.core.database import SessionLocal +from models.database_models import MarketplaceImportJob +from utils.csv_processor import CSVProcessor +from datetime import datetime +import logging + +logger = logging.getLogger(__name__) + + +async def process_marketplace_import( + job_id: int, + url: str, + marketplace: str, + shop_name: str, + batch_size: int = 1000 +): + """Background task to process marketplace CSV import""" + db = SessionLocal() + csv_processor = CSVProcessor() + + try: + # Update job status + job = db.query(MarketplaceImportJob).filter(MarketplaceImportJob.id == job_id).first() + if not job: + logger.error(f"Import job {job_id} not found") + return + + job.status = "processing" + job.started_at = datetime.utcnow() + db.commit() + + logger.info(f"Processing import: Job {job_id}, Marketplace: {marketplace}") + + # Process CSV + result = await csv_processor.process_marketplace_csv_from_url( + url, marketplace, shop_name, batch_size, db + ) + + # Update job with results + job.status = "completed" + job.completed_at = datetime.utcnow() + job.imported_count = result["imported"] + job.updated_count = result["updated"] + job.error_count = result.get("errors", 0) + job.total_processed = result["total_processed"] + + if result.get("errors", 0) > 0: + job.status = "completed_with_errors" + job.error_message = f"{result['errors']} rows had errors" + + db.commit() + logger.info(f"Import job {job_id} completed successfully") + + except Exception as e: + logger.error(f"Import job {job_id} failed: {e}") + job.status = "failed" + job.completed_at = datetime.utcnow() + job.error_message = str(e) + db.commit() + + finally: + db.close() diff --git a/comprehensive_readme.md b/comprehensive_readme.md new file mode 100644 index 00000000..885cc78d --- /dev/null +++ b/comprehensive_readme.md @@ -0,0 +1,493 @@ +# Ecommerce Backend API with Marketplace Support + +A comprehensive FastAPI-based product management system with JWT authentication, marketplace-aware CSV import/export, multi-shop support, and advanced stock management capabilities. + +## Features + +- **JWT Authentication** - Secure user registration, login, and role-based access control +- **Marketplace Integration** - Support for multiple marketplaces (Letzshop, Amazon, eBay, Etsy, Shopify, etc.) +- **Multi-Shop Management** - Shop creation, ownership validation, and product catalog management +- **Advanced Product Management** - GTIN validation, price processing, and comprehensive filtering +- **Stock Management** - Multi-location inventory tracking with add/remove/set operations +- **CSV Import/Export** - Background processing of marketplace CSV files with progress tracking +- **Rate Limiting** - Built-in request rate limiting for API protection +- **Admin Panel** - Administrative functions for user and shop management +- **Statistics & Analytics** - Comprehensive reporting on products, marketplaces, and inventory + +## Technology Stack + +- **FastAPI** - Modern, fast web framework for building APIs +- **SQLAlchemy** - SQL toolkit and Object-Relational Mapping (ORM) +- **PostgreSQL** - Primary database (SQLite supported for development) +- **JWT** - JSON Web Tokens for secure authentication +- **Pydantic** - Data validation using Python type annotations +- **Pandas** - Data processing for CSV operations +- **bcrypt** - Password hashing +- **Pytest** - Comprehensive testing framework + +## Directory Structure + +``` +letzshop_api/ +├── main.py # FastAPI application entry point +├── app/ +│ ├── core/ +│ │ ├── config.py # Configuration settings +│ │ ├── database.py # Database setup +│ │ └── lifespan.py # App lifecycle management +│ ├── api/ +│ │ ├── deps.py # Common dependencies +│ │ ├── main.py # API router setup +│ │ └── v1/ # API version 1 routes +│ │ ├── auth.py # Authentication endpoints +│ │ ├── products.py # Product management +│ │ ├── stock.py # Stock operations +│ │ ├── shops.py # Shop management +│ │ ├── marketplace.py # Marketplace imports +│ │ ├── admin.py # Admin functions +│ │ └── stats.py # Statistics +│ ├── services/ # Business logic layer +│ └── tasks/ # Background task processing +├── models/ +│ ├── database_models.py # SQLAlchemy ORM models +│ └── api_models.py # Pydantic API models +├── utils/ +│ ├── data_processing.py # GTIN and price processing +│ ├── csv_processor.py # CSV import/export +│ └── database.py # Database utilities +├── middleware/ +│ ├── auth.py # JWT authentication +│ ├── rate_limiter.py # Rate limiting +│ ├── error_handler.py # Error handling +│ └── logging_middleware.py # Request logging +├── tests/ # Comprehensive test suite +└── requirements.txt # Dependencies +``` + +## Quick Start + +### 1. Installation + +```bash +# Clone the repository +git clone +cd letzshop_api + +# Create virtual environment +python -m venv venv +source venv/bin/activate # On Windows: venv\Scripts\activate + +# Install dependencies +pip install -r requirements.txt +``` + +### 2. Environment Configuration + +Create a `.env` file in the project root: + +```env +# Database Configuration +DATABASE_URL=postgresql://user:password@localhost/ecommerce +# For development, you can use SQLite: +# DATABASE_URL=sqlite:///./ecommerce.db + +# JWT Configuration +SECRET_KEY=your-super-secret-key-change-in-production +ACCESS_TOKEN_EXPIRE_MINUTES=30 + +# API Configuration +ALLOWED_HOSTS=["*"] +RATE_LIMIT_REQUESTS=100 +RATE_LIMIT_WINDOW=3600 + +# Application Settings +PROJECT_NAME="Ecommerce Backend API" +VERSION="2.2.0" +DEBUG=True +``` + +### 3. Database Setup + +```bash +# The application will automatically create tables on startup +# For production, consider using Alembic for migrations + +# Install PostgreSQL (if using PostgreSQL) +# Create database +createdb ecommerce + +# Run the application (tables will be created automatically) +python main.py +``` + +### 4. Running the Application + +```bash +# Development server +python main.py + +# Or using uvicorn directly +uvicorn main:app --reload --host 0.0.0.0 --port 8000 +``` + +The API will be available at: +- **API Documentation**: http://localhost:8000/docs (Swagger UI) +- **Alternative Docs**: http://localhost:8000/redoc +- **Health Check**: http://localhost:8000/health + +## API Usage + +### Authentication + +#### Register a new user +```bash +curl -X POST "http://localhost:8000/api/v1/auth/register" \ + -H "Content-Type: application/json" \ + -d '{ + "email": "user@example.com", + "username": "testuser", + "password": "securepassword123" + }' +``` + +#### Login +```bash +curl -X POST "http://localhost:8000/api/v1/auth/login" \ + -H "Content-Type: application/json" \ + -d '{ + "username": "testuser", + "password": "securepassword123" + }' +``` + +#### Use JWT Token +```bash +# Get token from login response and use in subsequent requests +curl -X GET "http://localhost:8000/api/v1/products" \ + -H "Authorization: Bearer YOUR_JWT_TOKEN" +``` + +### Product Management + +#### Create a product +```bash +curl -X POST "http://localhost:8000/api/v1/products" \ + -H "Authorization: Bearer YOUR_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{ + "product_id": "PROD001", + "title": "Amazing Product", + "description": "An amazing product description", + "price": "29.99", + "currency": "EUR", + "brand": "BrandName", + "gtin": "1234567890123", + "availability": "in stock", + "marketplace": "Letzshop", + "shop_name": "MyShop" + }' +``` + +#### Get products with filtering +```bash +# Get all products +curl -X GET "http://localhost:8000/api/v1/products" \ + -H "Authorization: Bearer YOUR_TOKEN" + +# Filter by marketplace +curl -X GET "http://localhost:8000/api/v1/products?marketplace=Amazon&limit=50" \ + -H "Authorization: Bearer YOUR_TOKEN" + +# Search products +curl -X GET "http://localhost:8000/api/v1/products?search=Amazing&brand=BrandName" \ + -H "Authorization: Bearer YOUR_TOKEN" +``` + +### Stock Management + +#### Set stock quantity +```bash +curl -X POST "http://localhost:8000/api/v1/stock" \ + -H "Authorization: Bearer YOUR_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{ + "gtin": "1234567890123", + "location": "WAREHOUSE_A", + "quantity": 100 + }' +``` + +#### Add stock +```bash +curl -X POST "http://localhost:8000/api/v1/stock/add" \ + -H "Authorization: Bearer YOUR_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{ + "gtin": "1234567890123", + "location": "WAREHOUSE_A", + "quantity": 25 + }' +``` + +#### Check stock levels +```bash +curl -X GET "http://localhost:8000/api/v1/stock/1234567890123" \ + -H "Authorization: Bearer YOUR_TOKEN" +``` + +### Marketplace Import + +#### Import products from CSV +```bash +curl -X POST "http://localhost:8000/api/v1/marketplace/import-from-marketplace" \ + -H "Authorization: Bearer YOUR_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{ + "url": "https://example.com/products.csv", + "marketplace": "Amazon", + "shop_code": "MYSHOP", + "batch_size": 1000 + }' +``` + +#### Check import status +```bash +curl -X GET "http://localhost:8000/api/v1/marketplace/marketplace-import-status/1" \ + -H "Authorization: Bearer YOUR_TOKEN" +``` + +### Export Data + +#### Export products to CSV +```bash +# Export all products +curl -X GET "http://localhost:8000/api/v1/export-csv" \ + -H "Authorization: Bearer YOUR_TOKEN" \ + -o products_export.csv + +# Export with filters +curl -X GET "http://localhost:8000/api/v1/export-csv?marketplace=Amazon&shop_name=MyShop" \ + -H "Authorization: Bearer YOUR_TOKEN" \ + -o amazon_products.csv +``` + +## CSV Import Format + +The system supports CSV imports with the following headers: + +### Required Fields +- `product_id` - Unique product identifier +- `title` - Product title + +### Optional Fields +- `description` - Product description +- `link` - Product URL +- `image_link` - Product image URL +- `availability` - Stock availability (in stock, out of stock, preorder) +- `price` - Product price +- `currency` - Price currency (EUR, USD, etc.) +- `brand` - Product brand +- `gtin` - Global Trade Item Number (EAN/UPC) +- `google_product_category` - Product category +- `marketplace` - Source marketplace +- `shop_name` - Shop/seller name + +### Example CSV +```csv +product_id,title,description,price,currency,brand,gtin,marketplace,shop_name +PROD001,"Amazing Widget","The best widget ever",29.99,EUR,WidgetCorp,1234567890123,Letzshop,MyShop +PROD002,"Super Gadget","A fantastic gadget",19.99,EUR,GadgetInc,9876543210987,Amazon,TechStore +``` + +## Testing + +### Run Tests +```bash +# Install test dependencies +pip install -r tests/requirements_test.txt + +# Run all tests +pytest tests/ -v + +# Run with coverage report +pytest tests/ --cov=app --cov=models --cov=utils --cov-report=html + +# Run specific test categories +pytest tests/ -m unit -v # Unit tests only +pytest tests/ -m integration -v # Integration tests only +pytest tests/ -m "not slow" -v # Fast tests only + +# Run specific test files +pytest tests/test_auth.py -v # Authentication tests +pytest tests/test_products.py -v # Product tests +pytest tests/test_stock.py -v # Stock management tests +``` + +### Test Coverage +The test suite includes: +- **Unit Tests** - Individual component testing +- **Integration Tests** - Full workflow testing +- **Security Tests** - Authentication, authorization, input validation +- **Performance Tests** - Load testing with large datasets +- **Error Handling Tests** - Edge cases and error conditions + +## API Reference + +### Authentication Endpoints +- `POST /api/v1/auth/register` - Register new user +- `POST /api/v1/auth/login` - Login user +- `GET /api/v1/auth/me` - Get current user info + +### Product Endpoints +- `GET /api/v1/products` - List products with filtering +- `POST /api/v1/products` - Create new product +- `GET /api/v1/products/{product_id}` - Get specific product +- `PUT /api/v1/products/{product_id}` - Update product +- `DELETE /api/v1/products/{product_id}` - Delete product + +### Stock Endpoints +- `POST /api/v1/stock` - Set stock quantity +- `POST /api/v1/stock/add` - Add to stock +- `POST /api/v1/stock/remove` - Remove from stock +- `GET /api/v1/stock/{gtin}` - Get stock by GTIN +- `GET /api/v1/stock/{gtin}/total` - Get total stock +- `GET /api/v1/stock` - List all stock entries + +### Shop Endpoints +- `POST /api/v1/shops` - Create new shop +- `GET /api/v1/shops` - List shops +- `GET /api/v1/shops/{shop_code}` - Get specific shop + +### Marketplace Endpoints +- `POST /api/v1/marketplace/import-from-marketplace` - Start CSV import +- `GET /api/v1/marketplace/marketplace-import-status/{job_id}` - Check import status +- `GET /api/v1/marketplace/marketplace-import-jobs` - List import jobs + +### Statistics Endpoints +- `GET /api/v1/stats` - Get general statistics +- `GET /api/v1/stats/marketplace-stats` - Get marketplace statistics + +### Admin Endpoints (Admin only) +- `GET /api/v1/admin/users` - List all users +- `PUT /api/v1/admin/users/{user_id}/status` - Toggle user status +- `GET /api/v1/admin/shops` - List all shops +- `PUT /api/v1/admin/shops/{shop_id}/verify` - Verify/unverify shop + +## Database Schema + +### Core Tables +- **users** - User accounts and authentication +- **products** - Product catalog with marketplace info +- **stock** - Inventory tracking by location and GTIN +- **shops** - Shop/seller information +- **shop_products** - Shop-specific product settings +- **marketplace_import_jobs** - Background import job tracking + +### Key Relationships +- Users own shops (one-to-many) +- Products belong to marketplaces and shops +- Stock entries are linked to products via GTIN +- Import jobs track user-initiated imports + +## Security Features + +- **JWT Authentication** - Secure token-based authentication +- **Password Hashing** - bcrypt for secure password storage +- **Role-Based Access** - User and admin role separation +- **Rate Limiting** - Protection against API abuse +- **Input Validation** - Comprehensive data validation +- **SQL Injection Protection** - Parameterized queries +- **CORS Configuration** - Cross-origin request handling + +## Performance Optimizations + +- **Database Indexing** - Strategic indexes on key columns +- **Pagination** - Efficient data retrieval with skip/limit +- **Streaming Responses** - Memory-efficient CSV exports +- **Background Processing** - Async import job handling +- **Connection Pooling** - Efficient database connections +- **Query Optimization** - Optimized database queries + +## Deployment + +### Docker Deployment +```dockerfile +FROM python:3.11 + +WORKDIR /app + +COPY requirements.txt . +RUN pip install -r requirements.txt + +COPY . . + +EXPOSE 8000 + +CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"] +``` + +### Production Considerations +- Use PostgreSQL for production database +- Set strong SECRET_KEY in environment +- Configure proper CORS settings +- Enable HTTPS +- Set up monitoring and logging +- Use a reverse proxy (nginx) +- Configure database connection pooling +- Set up backup strategies + +## Contributing + +1. Fork the repository +2. Create a feature branch +3. Make your changes +4. Add tests for new functionality +5. Ensure all tests pass +6. Submit a pull request + +### Development Setup +```bash +# Install development dependencies +pip install -r requirements_dev.txt + +# Run pre-commit hooks +pre-commit install + +# Run linting +flake8 app/ models/ utils/ +black app/ models/ utils/ + +# Run type checking +mypy app/ +``` + +## Support & Documentation + +- **API Documentation**: http://localhost:8000/docs +- **Health Check**: http://localhost:8000/health +- **Version Info**: http://localhost:8000/ + +For issues and feature requests, please create an issue in the repository. + +## License + +[Specify your license here] + +## Changelog + +### v2.2.0 +- Added marketplace-aware product import +- Implemented multi-shop support +- Enhanced stock management with location tracking +- Added comprehensive test suite +- Improved authentication and authorization + +### v2.1.0 +- Added JWT authentication +- Implemented role-based access control +- Added CSV import/export functionality + +### v2.0.0 +- Complete rewrite with FastAPI +- Added PostgreSQL support +- Implemented comprehensive API documentation \ No newline at end of file diff --git a/config/settings.py b/config/settings.py index 998df2db..9b2ce534 100644 --- a/config/settings.py +++ b/config/settings.py @@ -10,6 +10,7 @@ class Settings(BaseSettings): # JWT jwt_secret_key: str = "change-this-in-production" jwt_expire_hours: int = 24 + jwt_expire_minutes: int = 30 # API api_host: str = "0.0.0.0" diff --git a/main.py b/main.py index 53838a14..27ce2b4e 100644 --- a/main.py +++ b/main.py @@ -1,265 +1,45 @@ -from fastapi import FastAPI, HTTPException, Query, Depends, BackgroundTasks -from fastapi.responses import StreamingResponse +from fastapi import FastAPI, Depends, HTTPException +from sqlalchemy.orm import Session +from sqlalchemy import text +from app.core.config import settings +from app.core.lifespan import lifespan +from app.core.database import get_db +from datetime import datetime +from app.api.main import api_router from fastapi.middleware.cors import CORSMiddleware -from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials -from pydantic import BaseModel, Field, validator -from typing import Optional, List, Dict, Any -from datetime import datetime, timedelta -from sqlalchemy import create_engine, Column, Integer, String, DateTime, text, ForeignKey, Index, func -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker, Session, relationship -from contextlib import asynccontextmanager -import pandas as pd -import requests -from io import StringIO, BytesIO import logging -import asyncio -import time -from functools import wraps -import os -from dotenv import load_dotenv -# Import utility modules -from utils.data_processing import GTINProcessor, PriceProcessor -from utils.csv_processor import CSVProcessor -from utils.database import get_db_engine, get_session_local -from models.database_models import Base, Product, Stock, User, MarketplaceImportJob, Shop, ShopProduct -from models.api_models import * -from middleware.rate_limiter import RateLimiter -from middleware.auth import AuthManager -# Load environment variables -load_dotenv() - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) logger = logging.getLogger(__name__) -# Initialize processors -gtin_processor = GTINProcessor() -price_processor = PriceProcessor() -csv_processor = CSVProcessor() - -# Database setup -DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://user:password@localhost/ecommerce") -engine = get_db_engine(DATABASE_URL) -SessionLocal = get_session_local(engine) - -# Rate limiter and auth manager -rate_limiter = RateLimiter() -auth_manager = AuthManager() - - -@asynccontextmanager -async def lifespan(app: FastAPI): - """Application lifespan events""" - # Startup - logger.info("Starting up ecommerce API with marketplace import support") - - # Create tables - Base.metadata.create_all(bind=engine) - - # Create default admin user - db = SessionLocal() - try: - auth_manager.create_default_admin_user(db) - except Exception as e: - logger.error(f"Failed to create default admin user: {e}") - finally: - db.close() - - # Add indexes - with engine.connect() as conn: - try: - # User indexes - conn.execute(text("CREATE INDEX IF NOT EXISTS idx_user_email ON users(email)")) - conn.execute(text("CREATE INDEX IF NOT EXISTS idx_user_username ON users(username)")) - conn.execute(text("CREATE INDEX IF NOT EXISTS idx_user_role ON users(role)")) - - # Product indexes (including new marketplace indexes) - conn.execute(text("CREATE INDEX IF NOT EXISTS idx_product_gtin ON products(gtin)")) - conn.execute(text("CREATE INDEX IF NOT EXISTS idx_product_brand ON products(brand)")) - conn.execute(text("CREATE INDEX IF NOT EXISTS idx_product_category ON products(google_product_category)")) - conn.execute(text("CREATE INDEX IF NOT EXISTS idx_product_availability ON products(availability)")) - conn.execute(text("CREATE INDEX IF NOT EXISTS idx_product_marketplace ON products(marketplace)")) - conn.execute(text("CREATE INDEX IF NOT EXISTS idx_product_shop_name ON products(shop_name)")) - conn.execute( - text("CREATE INDEX IF NOT EXISTS idx_product_marketplace_shop ON products(marketplace, shop_name)")) - - # Stock indexes - conn.execute(text("CREATE INDEX IF NOT EXISTS idx_stock_gtin_location ON stock(gtin, location)")) - conn.execute(text("CREATE INDEX IF NOT EXISTS idx_stock_location ON stock(location)")) - - # Marketplace import job indexes - conn.execute(text( - "CREATE INDEX IF NOT EXISTS idx_marketplace_import_marketplace ON marketplace_import_jobs(marketplace)")) - conn.execute(text( - "CREATE INDEX IF NOT EXISTS idx_marketplace_import_shop_name ON marketplace_import_jobs(shop_name)")) - conn.execute( - text("CREATE INDEX IF NOT EXISTS idx_marketplace_import_user_id ON marketplace_import_jobs(user_id)")) - - conn.commit() - logger.info("Database indexes created successfully") - except Exception as e: - logger.warning(f"Index creation warning: {e}") - - yield - - # Shutdown - logger.info("Shutting down ecommerce API") - - # FastAPI app with lifespan app = FastAPI( - title="Ecommerce Backend API with Marketplace Support", - description="Advanced product management system with JWT authentication, marketplace-aware CSV " - "import/export and stock management", - version="2.2.0", + title=settings.PROJECT_NAME, + description=settings.DESCRIPTION, + version=settings.VERSION, lifespan=lifespan ) # Add CORS middleware app.add_middleware( CORSMiddleware, - allow_origins=["*"], # Configure appropriately for production + allow_origins=settings.ALLOWED_HOSTS, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) -# Security -security = HTTPBearer() - - -# Database dependency with connection pooling -def get_db(): - db = SessionLocal() - try: - yield db - except Exception as e: - db.rollback() - logger.error(f"Database error: {e}") - raise - finally: - db.close() - - -# Authentication dependencies -def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security), db: Session = Depends(get_db)): - """Get current authenticated user""" - return auth_manager.get_current_user(db, credentials) - - -def get_current_admin_user(current_user: User = Depends(get_current_user)): - """Require admin user""" - return auth_manager.require_admin(current_user) - - -def get_user_shop(shop_code: str, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)): - """Get shop and verify user ownership (or admin)""" - shop = db.query(Shop).filter(Shop.shop_code == shop_code.upper()).first() - if not shop: - raise HTTPException(status_code=404, detail="Shop not found") - - # Admin can access any shop, owners can access their own shops - if current_user.role != "admin" and shop.owner_id != current_user.id: - raise HTTPException(status_code=403, detail="Access denied to this shop") - - return shop - - -# Rate limiting decorator -def rate_limit(max_requests: int = 100, window_seconds: int = 3600): - def decorator(func): - @wraps(func) - async def wrapper(*args, **kwargs): - # Extract client IP or user ID for rate limiting - client_id = "anonymous" # In production, extract from request - - if not rate_limiter.allow_request(client_id, max_requests, window_seconds): - raise HTTPException( - status_code=429, - detail="Rate limit exceeded" - ) - - return await func(*args, **kwargs) - - return wrapper - - return decorator - - -# Authentication Routes -@app.post("/register", response_model=UserResponse) -def register_user(user_data: UserRegister, db: Session = Depends(get_db)): - """Register a new user""" - - # Check if email already exists - existing_email = db.query(User).filter(User.email == user_data.email).first() - if existing_email: - raise HTTPException(status_code=400, detail="Email already registered") - - # Check if username already exists - existing_username = db.query(User).filter(User.username == user_data.username).first() - if existing_username: - raise HTTPException(status_code=400, detail="Username already taken") - - # Hash password and create user - hashed_password = auth_manager.hash_password(user_data.password) - new_user = User( - email=user_data.email, - username=user_data.username, - hashed_password=hashed_password, - role="user", # Default role - is_active=True - ) - - db.add(new_user) - db.commit() - db.refresh(new_user) - - logger.info(f"New user registered: {new_user.username}") - return new_user - - -@app.post("/login", response_model=LoginResponse) -def login_user(user_credentials: UserLogin, db: Session = Depends(get_db)): - """Login user and return JWT token""" - - user = auth_manager.authenticate_user(db, user_credentials.username, user_credentials.password) - if not user: - raise HTTPException( - status_code=401, - detail="Incorrect username or password" - ) - - # Create access token - token_data = auth_manager.create_access_token(user) - - logger.info(f"User logged in: {user.username}") - - return LoginResponse( - access_token=token_data["access_token"], - token_type=token_data["token_type"], - expires_in=token_data["expires_in"], - user=UserResponse.model_validate(user) - ) - - -@app.get("/me", response_model=UserResponse) -def get_current_user_info(current_user: User = Depends(get_current_user)): - """Get current user information""" - return UserResponse.model_validate(current_user) +# Include API router +app.include_router(api_router, prefix="/api/v1") # Public Routes (no authentication required) +# Core application endpoints (Public Routes, no authentication required) @app.get("/") def root(): return { - "message": "Ecommerce Backend API v2.2 with Marketplace Support", + "message": f"{settings.PROJECT_NAME} v{settings.VERSION}", "status": "operational", + "docs": "/docs", "features": [ "JWT Authentication", "Marketplace-aware product import", @@ -283,1109 +63,7 @@ def health_check(db: Session = Depends(get_db)): raise HTTPException(status_code=503, detail="Service unhealthy") -# Marketplace Import Routes (Protected) -@app.post("/import-from-marketplace", response_model=MarketplaceImportJobResponse) -@rate_limit(max_requests=10, window_seconds=3600) # Limit marketplace imports -async def import_products_from_marketplace( - request: MarketplaceImportRequest, - background_tasks: BackgroundTasks, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """Import products from marketplace CSV with background processing (Protected)""" - - logger.info( - f"Starting marketplace import: {request.marketplace} -> {request.shop_code} by user {current_user.username}") - - # Verify shop exists and user has access - shop = db.query(Shop).filter(Shop.shop_code == request.shop_code).first() - if not shop: - raise HTTPException(status_code=404, detail="Shop not found") - - # Check permissions: admin can import for any shop, others only for their own - if current_user.role != "admin" and shop.owner_id != current_user.id: - raise HTTPException(status_code=403, detail="Access denied to this shop") - - # Create marketplace import job record - import_job = MarketplaceImportJob( - status="pending", - source_url=request.url, - marketplace=request.marketplace, - shop_code=request.shop_code, - user_id=current_user.id, - created_at=datetime.utcnow() - ) - db.add(import_job) - db.commit() - db.refresh(import_job) - - # Process in background - background_tasks.add_task( - process_marketplace_import, - import_job.id, - request.url, - request.marketplace, - request.shop_code, - request.batch_size or 1000 - ) - - return MarketplaceImportJobResponse( - job_id=import_job.id, - status="pending", - marketplace=request.marketplace, - shop_code=request.shop_code, - message=f"Marketplace import started from {request.marketplace}. Check status with /marketplace-import-status/{import_job.id}" - ) - - -async def process_marketplace_import(job_id: int, url: str, marketplace: str, shop_name: str, - batch_size: int = 1000): - """Background task to process marketplace CSV import with batching""" - db = SessionLocal() - - try: - # Update job status - job = db.query(MarketplaceImportJob).filter(MarketplaceImportJob.id == job_id).first() - if not job: - logger.error(f"Marketplace import job {job_id} not found") - return - - job.status = "processing" - job.started_at = datetime.utcnow() - db.commit() - - logger.info(f"Processing marketplace import: Job {job_id}, Marketplace: {marketplace}, Shop: {shop_name}") - - # Process CSV with marketplace and shop information - result = await csv_processor.process_marketplace_csv_from_url( - url, marketplace, shop_name, batch_size, db - ) - - # Update job with results - job.status = "completed" - job.completed_at = datetime.utcnow() - job.imported_count = result["imported"] - job.updated_count = result["updated"] - job.error_count = result.get("errors", 0) - job.total_processed = result["total_processed"] - - if result.get("errors", 0) > 0: - job.status = "completed_with_errors" - job.error_message = f"{result['errors']} rows had errors" - - db.commit() - logger.info( - f"Marketplace import job {job_id} completed successfully - Imported: {result['imported']}, " - f"Updated: {result['updated']}") - - except Exception as e: - logger.error(f"Marketplace import job {job_id} failed: {e}") - job.status = "failed" - job.completed_at = datetime.utcnow() - job.error_message = str(e) - db.commit() - - finally: - db.close() - - -@app.get("/marketplace-import-status/{job_id}", response_model=MarketplaceImportJobResponse) -def get_marketplace_import_status( - job_id: int, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """Get status of marketplace import job (Protected)""" - job = db.query(MarketplaceImportJob).filter(MarketplaceImportJob.id == job_id).first() - if not job: - raise HTTPException(status_code=404, detail="Marketplace import job not found") - - # Users can only see their own jobs, admins can see all - if current_user.role != "admin" and job.user_id != current_user.id: - raise HTTPException(status_code=403, detail="Access denied to this import job") - - return MarketplaceImportJobResponse( - job_id=job.id, - status=job.status, - marketplace=job.marketplace, - shop_name=job.shop_name, - imported=job.imported_count or 0, - updated=job.updated_count or 0, - total_processed=job.total_processed or 0, - error_count=job.error_count or 0, - error_message=job.error_message, - created_at=job.created_at, - started_at=job.started_at, - completed_at=job.completed_at - ) - - -@app.get("/marketplace-import-jobs", response_model=List[MarketplaceImportJobResponse]) -def get_marketplace_import_jobs( - marketplace: Optional[str] = Query(None, description="Filter by marketplace"), - shop_name: Optional[str] = Query(None, description="Filter by shop name"), - skip: int = Query(0, ge=0), - limit: int = Query(50, ge=1, le=100), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """Get marketplace import jobs with filtering (Protected)""" - - query = db.query(MarketplaceImportJob) - - # Users can only see their own jobs, admins can see all - if current_user.role != "admin": - query = query.filter(MarketplaceImportJob.user_id == current_user.id) - - # Apply filters - if marketplace: - query = query.filter(MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%")) - if shop_name: - query = query.filter(MarketplaceImportJob.shop_name.ilike(f"%{shop_name}%")) - - # Order by creation date (newest first) and apply pagination - jobs = query.order_by(MarketplaceImportJob.created_at.desc()).offset(skip).limit(limit).all() - - return [ - MarketplaceImportJobResponse( - job_id=job.id, - status=job.status, - marketplace=job.marketplace, - shop_name=job.shop_name, - imported=job.imported_count or 0, - updated=job.updated_count or 0, - total_processed=job.total_processed or 0, - error_count=job.error_count or 0, - error_message=job.error_message, - created_at=job.created_at, - started_at=job.started_at, - completed_at=job.completed_at - ) for job in jobs - ] - - -# Enhanced Product Routes with Marketplace Support -@app.get("/products", response_model=ProductListResponse) -def get_products( - skip: int = Query(0, ge=0), - limit: int = Query(100, ge=1, le=1000), - brand: Optional[str] = Query(None), - category: Optional[str] = Query(None), - availability: Optional[str] = Query(None), - marketplace: Optional[str] = Query(None, description="Filter by marketplace"), - shop_name: Optional[str] = Query(None, description="Filter by shop name"), - search: Optional[str] = Query(None), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """Get products with advanced filtering including marketplace and shop (Protected)""" - - query = db.query(Product) - - # Apply filters - if brand: - query = query.filter(Product.brand.ilike(f"%{brand}%")) - if category: - query = query.filter(Product.google_product_category.ilike(f"%{category}%")) - if availability: - query = query.filter(Product.availability == availability) - if marketplace: - query = query.filter(Product.marketplace.ilike(f"%{marketplace}%")) - if shop_name: - query = query.filter(Product.shop_name.ilike(f"%{shop_name}%")) - if search: - # Search in title, description, and marketplace - search_term = f"%{search}%" - query = query.filter( - (Product.title.ilike(search_term)) | - (Product.description.ilike(search_term)) | - (Product.marketplace.ilike(search_term)) | - (Product.shop_name.ilike(search_term)) - ) - - # Get total count for pagination - total = query.count() - - # Apply pagination - products = query.offset(skip).limit(limit).all() - - return ProductListResponse( - products=products, - total=total, - skip=skip, - limit=limit - ) - - -@app.post("/products", response_model=ProductResponse) -def create_product( - product: ProductCreate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """Create a new product with validation and marketplace support (Protected)""" - - # Check if product_id already exists - existing = db.query(Product).filter(Product.product_id == product.product_id).first() - if existing: - raise HTTPException(status_code=400, detail="Product with this ID already exists") - - # Process and validate GTIN if provided - if product.gtin: - normalized_gtin = gtin_processor.normalize(product.gtin) - if not normalized_gtin: - raise HTTPException(status_code=400, detail="Invalid GTIN format") - product.gtin = normalized_gtin - - # Process price if provided - if product.price: - parsed_price, currency = price_processor.parse_price_currency(product.price) - if parsed_price: - product.price = parsed_price - product.currency = currency - - # Set default marketplace if not provided - if not product.marketplace: - product.marketplace = "Letzshop" - - db_product = Product(**product.dict()) - db.add(db_product) - db.commit() - db.refresh(db_product) - - logger.info( - f"Created product {db_product.product_id} for marketplace {db_product.marketplace}, shop {db_product.shop_name}") - return db_product - - -@app.get("/products/{product_id}", response_model=ProductDetailResponse) -def get_product(product_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): - """Get product with stock information (Protected)""" - - product = db.query(Product).filter(Product.product_id == product_id).first() - if not product: - raise HTTPException(status_code=404, detail="Product not found") - - # Get stock information if GTIN exists - stock_info = None - if product.gtin: - stock_entries = db.query(Stock).filter(Stock.gtin == product.gtin).all() - if stock_entries: - total_quantity = sum(entry.quantity for entry in stock_entries) - locations = [ - StockLocationResponse(location=entry.location, quantity=entry.quantity) - for entry in stock_entries - ] - stock_info = StockSummaryResponse( - gtin=product.gtin, - total_quantity=total_quantity, - locations=locations - ) - - return ProductDetailResponse( - product=product, - stock_info=stock_info - ) - - -@app.put("/products/{product_id}", response_model=ProductResponse) -def update_product( - product_id: str, - product_update: ProductUpdate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """Update product with validation and marketplace support (Protected)""" - - product = db.query(Product).filter(Product.product_id == product_id).first() - if not product: - raise HTTPException(status_code=404, detail="Product not found") - - # Update fields - update_data = product_update.dict(exclude_unset=True) - - # Validate GTIN if being updated - if "gtin" in update_data and update_data["gtin"]: - normalized_gtin = gtin_processor.normalize(update_data["gtin"]) - if not normalized_gtin: - raise HTTPException(status_code=400, detail="Invalid GTIN format") - update_data["gtin"] = normalized_gtin - - # Process price if being updated - if "price" in update_data and update_data["price"]: - parsed_price, currency = price_processor.parse_price_currency(update_data["price"]) - if parsed_price: - update_data["price"] = parsed_price - update_data["currency"] = currency - - for key, value in update_data.items(): - setattr(product, key, value) - - product.updated_at = datetime.utcnow() - db.commit() - db.refresh(product) - - return product - - -@app.delete("/products/{product_id}") -def delete_product( - product_id: str, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """Delete product and associated stock (Protected)""" - - product = db.query(Product).filter(Product.product_id == product_id).first() - if not product: - raise HTTPException(status_code=404, detail="Product not found") - - # Delete associated stock entries if GTIN exists - if product.gtin: - db.query(Stock).filter(Stock.gtin == product.gtin).delete() - - db.delete(product) - db.commit() - - return {"message": "Product and associated stock deleted successfully"} - - -# Stock Management Routes (Protected) -# Stock Management Routes - -@app.post("/stock", response_model=StockResponse) -def set_stock(stock: StockCreate, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): - """Set exact stock quantity for a GTIN at a specific location (replaces existing quantity)""" - - # Normalize GTIN - def normalize_gtin(gtin_value): - if not gtin_value: - return None - gtin_str = str(gtin_value).strip() - if '.' in gtin_str: - gtin_str = gtin_str.split('.')[0] - gtin_clean = ''.join(filter(str.isdigit, gtin_str)) - if len(gtin_clean) in [8, 12, 13, 14]: - return gtin_clean.zfill(13) if len(gtin_clean) == 13 else gtin_clean.zfill(12) - return gtin_clean if gtin_clean else None - - normalized_gtin = normalize_gtin(stock.gtin) - if not normalized_gtin: - raise HTTPException(status_code=400, detail="Invalid GTIN format") - - # Check if stock entry already exists for this GTIN and location - existing_stock = db.query(Stock).filter( - Stock.gtin == normalized_gtin, - Stock.location == stock.location.strip().upper() - ).first() - - if existing_stock: - # Update existing stock (SET to exact quantity) - old_quantity = existing_stock.quantity - existing_stock.quantity = stock.quantity - existing_stock.updated_at = datetime.utcnow() - db.commit() - db.refresh(existing_stock) - logger.info(f"Updated stock for GTIN {normalized_gtin} at {stock.location}: {old_quantity} → {stock.quantity}") - return existing_stock - else: - # Create new stock entry - new_stock = Stock( - gtin=normalized_gtin, - location=stock.location.strip().upper(), - quantity=stock.quantity - ) - db.add(new_stock) - db.commit() - db.refresh(new_stock) - logger.info(f"Created new stock for GTIN {normalized_gtin} at {stock.location}: {stock.quantity}") - return new_stock - - -@app.post("/stock/add", response_model=StockResponse) -def add_stock(stock: StockAdd, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): - """Add quantity to existing stock for a GTIN at a specific location (adds to existing quantity)""" - - # Normalize GTIN - def normalize_gtin(gtin_value): - if not gtin_value: - return None - gtin_str = str(gtin_value).strip() - if '.' in gtin_str: - gtin_str = gtin_str.split('.')[0] - gtin_clean = ''.join(filter(str.isdigit, gtin_str)) - if len(gtin_clean) in [8, 12, 13, 14]: - return gtin_clean.zfill(13) if len(gtin_clean) == 13 else gtin_clean.zfill(12) - return gtin_clean if gtin_clean else None - - normalized_gtin = normalize_gtin(stock.gtin) - if not normalized_gtin: - raise HTTPException(status_code=400, detail="Invalid GTIN format") - - # Check if stock entry already exists for this GTIN and location - existing_stock = db.query(Stock).filter( - Stock.gtin == normalized_gtin, - Stock.location == stock.location.strip().upper() - ).first() - - if existing_stock: - # Add to existing stock - old_quantity = existing_stock.quantity - existing_stock.quantity += stock.quantity - existing_stock.updated_at = datetime.utcnow() - db.commit() - db.refresh(existing_stock) - logger.info( - f"Added stock for GTIN {normalized_gtin} at {stock.location}: {old_quantity} + {stock.quantity} = {existing_stock.quantity}") - return existing_stock - else: - # Create new stock entry with the quantity - new_stock = Stock( - gtin=normalized_gtin, - location=stock.location.strip().upper(), - quantity=stock.quantity - ) - db.add(new_stock) - db.commit() - db.refresh(new_stock) - logger.info(f"Created new stock for GTIN {normalized_gtin} at {stock.location}: {stock.quantity}") - return new_stock - - -@app.post("/stock/remove", response_model=StockResponse) -def remove_stock(stock: StockAdd, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): - """Remove quantity from existing stock for a GTIN at a specific location""" - - # Normalize GTIN - def normalize_gtin(gtin_value): - if not gtin_value: - return None - gtin_str = str(gtin_value).strip() - if '.' in gtin_str: - gtin_str = gtin_str.split('.')[0] - gtin_clean = ''.join(filter(str.isdigit, gtin_str)) - if len(gtin_clean) in [8, 12, 13, 14]: - return gtin_clean.zfill(13) if len(gtin_clean) == 13 else gtin_clean.zfill(12) - return gtin_clean if gtin_clean else None - - normalized_gtin = normalize_gtin(stock.gtin) - if not normalized_gtin: - raise HTTPException(status_code=400, detail="Invalid GTIN format") - - # Find existing stock entry - existing_stock = db.query(Stock).filter( - Stock.gtin == normalized_gtin, - Stock.location == stock.location.strip().upper() - ).first() - - if not existing_stock: - raise HTTPException( - status_code=404, - detail=f"No stock found for GTIN {normalized_gtin} at location {stock.location}" - ) - - # Check if we have enough stock to remove - if existing_stock.quantity < stock.quantity: - raise HTTPException( - status_code=400, - detail=f"Insufficient stock. Available: {existing_stock.quantity}, Requested to remove: {stock.quantity}" - ) - - # Remove from existing stock - old_quantity = existing_stock.quantity - existing_stock.quantity -= stock.quantity - existing_stock.updated_at = datetime.utcnow() - db.commit() - db.refresh(existing_stock) - logger.info( - f"Removed stock for GTIN {normalized_gtin} at {stock.location}: {old_quantity} - {stock.quantity} = {existing_stock.quantity}") - return existing_stock - - -@app.get("/stock/{gtin}", response_model=StockSummaryResponse) -def get_stock_by_gtin(gtin: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): - """Get all stock locations and total quantity for a specific GTIN""" - - # Normalize GTIN - def normalize_gtin(gtin_value): - if not gtin_value: - return None - gtin_str = str(gtin_value).strip() - if '.' in gtin_str: - gtin_str = gtin_str.split('.')[0] - gtin_clean = ''.join(filter(str.isdigit, gtin_str)) - if len(gtin_clean) in [8, 12, 13, 14]: - return gtin_clean.zfill(13) if len(gtin_clean) == 13 else gtin_clean.zfill(12) - return gtin_clean if gtin_clean else None - - normalized_gtin = normalize_gtin(gtin) - if not normalized_gtin: - raise HTTPException(status_code=400, detail="Invalid GTIN format") - - # Get all stock entries for this GTIN - stock_entries = db.query(Stock).filter(Stock.gtin == normalized_gtin).all() - - if not stock_entries: - raise HTTPException(status_code=404, detail=f"No stock found for GTIN: {gtin}") - - # Calculate total quantity and build locations list - total_quantity = 0 - locations = [] - - for entry in stock_entries: - total_quantity += entry.quantity - locations.append(StockLocationResponse( - location=entry.location, - quantity=entry.quantity - )) - - # Try to get product title for reference - product = db.query(Product).filter(Product.gtin == normalized_gtin).first() - product_title = product.title if product else None - - return StockSummaryResponse( - gtin=normalized_gtin, - total_quantity=total_quantity, - locations=locations, - product_title=product_title - ) - - -@app.get("/stock/{gtin}/total") -def get_total_stock(gtin: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): - """Get total quantity in stock for a specific GTIN""" - - # Normalize GTIN - def normalize_gtin(gtin_value): - if not gtin_value: - return None - gtin_str = str(gtin_value).strip() - if '.' in gtin_str: - gtin_str = gtin_str.split('.')[0] - gtin_clean = ''.join(filter(str.isdigit, gtin_str)) - if len(gtin_clean) in [8, 12, 13, 14]: - return gtin_clean.zfill(13) if len(gtin_clean) == 13 else gtin_clean.zfill(12) - return gtin_clean if gtin_clean else None - - normalized_gtin = normalize_gtin(gtin) - if not normalized_gtin: - raise HTTPException(status_code=400, detail="Invalid GTIN format") - - # Calculate total stock - total_stock = db.query(Stock).filter(Stock.gtin == normalized_gtin).all() - total_quantity = sum(entry.quantity for entry in total_stock) - - # Get product info for context - product = db.query(Product).filter(Product.gtin == normalized_gtin).first() - - return { - "gtin": normalized_gtin, - "total_quantity": total_quantity, - "product_title": product.title if product else None, - "locations_count": len(total_stock) - } - - -@app.get("/stock", response_model=List[StockResponse]) -def get_all_stock( - skip: int = Query(0, ge=0), - limit: int = Query(100, ge=1, le=1000), - location: Optional[str] = Query(None, description="Filter by location"), - gtin: Optional[str] = Query(None, description="Filter by GTIN"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """Get all stock entries with optional filtering""" - query = db.query(Stock) - - if location: - query = query.filter(Stock.location.ilike(f"%{location}%")) - - if gtin: - # Normalize GTIN for search - def normalize_gtin(gtin_value): - if not gtin_value: - return None - gtin_str = str(gtin_value).strip() - if '.' in gtin_str: - gtin_str = gtin_str.split('.')[0] - gtin_clean = ''.join(filter(str.isdigit, gtin_str)) - if len(gtin_clean) in [8, 12, 13, 14]: - return gtin_clean.zfill(13) if len(gtin_clean) == 13 else gtin_clean.zfill(12) - return gtin_clean if gtin_clean else None - - normalized_gtin = normalize_gtin(gtin) - if normalized_gtin: - query = query.filter(Stock.gtin == normalized_gtin) - - stock_entries = query.offset(skip).limit(limit).all() - return stock_entries - - -@app.put("/stock/{stock_id}", response_model=StockResponse) -def update_stock(stock_id: int, stock_update: StockUpdate, db: Session = Depends(get_db), - current_user: User = Depends(get_current_user)): - """Update stock quantity for a specific stock entry""" - stock_entry = db.query(Stock).filter(Stock.id == stock_id).first() - if not stock_entry: - raise HTTPException(status_code=404, detail="Stock entry not found") - - stock_entry.quantity = stock_update.quantity - stock_entry.updated_at = datetime.utcnow() - db.commit() - db.refresh(stock_entry) - return stock_entry - - -@app.delete("/stock/{stock_id}") -def delete_stock(stock_id: int, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): - """Delete a stock entry""" - stock_entry = db.query(Stock).filter(Stock.id == stock_id).first() - if not stock_entry: - raise HTTPException(status_code=404, detail="Stock entry not found") - - db.delete(stock_entry) - db.commit() - return {"message": "Stock entry deleted successfully"} - - -# Shop Management Routes -@app.post("/shops", response_model=ShopResponse) -def create_shop( - shop_data: ShopCreate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """Create a new shop (Protected)""" - # Check if shop code already exists - existing_shop = db.query(Shop).filter(Shop.shop_code == shop_data.shop_code).first() - if existing_shop: - raise HTTPException(status_code=400, detail="Shop code already exists") - - # Create shop - new_shop = Shop( - **shop_data.dict(), - owner_id=current_user.id, - is_active=True, - is_verified=(current_user.role == "admin") # Auto-verify if admin creates shop - ) - - db.add(new_shop) - db.commit() - db.refresh(new_shop) - - logger.info(f"New shop created: {new_shop.shop_code} by {current_user.username}") - return new_shop - - -@app.get("/shops", response_model=ShopListResponse) -def get_shops( - skip: int = Query(0, ge=0), - limit: int = Query(100, ge=1, le=1000), - active_only: bool = Query(True), - verified_only: bool = Query(False), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """Get shops with filtering (Protected)""" - query = db.query(Shop) - - # Non-admin users can only see active and verified shops, plus their own - if current_user.role != "admin": - query = query.filter( - (Shop.is_active == True) & - ((Shop.is_verified == True) | (Shop.owner_id == current_user.id)) - ) - else: - # Admin can apply filters - if active_only: - query = query.filter(Shop.is_active == True) - if verified_only: - query = query.filter(Shop.is_verified == True) - - total = query.count() - shops = query.offset(skip).limit(limit).all() - - return ShopListResponse( - shops=shops, - total=total, - skip=skip, - limit=limit - ) - - -@app.get("/shops/{shop_code}", response_model=ShopResponse) -def get_shop(shop_code: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): - """Get shop details (Protected)""" - shop = db.query(Shop).filter(Shop.shop_code == shop_code.upper()).first() - if not shop: - raise HTTPException(status_code=404, detail="Shop not found") - - # Non-admin users can only see active verified shops or their own shops - if current_user.role != "admin": - if not shop.is_active or (not shop.is_verified and shop.owner_id != current_user.id): - raise HTTPException(status_code=404, detail="Shop not found") - - return shop - - -# Shop Product Management -@app.post("/shops/{shop_code}/products", response_model=ShopProductResponse) -def add_product_to_shop( - shop_code: str, - shop_product: ShopProductCreate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """Add existing product to shop catalog with shop-specific settings (Protected)""" - - # Get and verify shop - shop = get_user_shop(shop_code, current_user, db) - - # Check if product exists - product = db.query(Product).filter(Product.product_id == shop_product.product_id).first() - if not product: - raise HTTPException(status_code=404, detail="Product not found in marketplace catalog") - - # Check if product already in shop - existing_shop_product = db.query(ShopProduct).filter( - ShopProduct.shop_id == shop.id, - ShopProduct.product_id == product.id - ).first() - - if existing_shop_product: - raise HTTPException(status_code=400, detail="Product already in shop catalog") - - # Create shop-product association - new_shop_product = ShopProduct( - shop_id=shop.id, - product_id=product.id, - **shop_product.dict(exclude={'product_id'}) - ) - - db.add(new_shop_product) - db.commit() - db.refresh(new_shop_product) - - # Return with product details - response = ShopProductResponse.model_validate(new_shop_product) - response.product = product - return response - - -@app.get("/shops/{shop_code}/products") -def get_shop_products( - shop_code: str, - skip: int = Query(0, ge=0), - limit: int = Query(100, ge=1, le=1000), - active_only: bool = Query(True), - featured_only: bool = Query(False), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """Get products in shop catalog (Protected)""" - - # Get shop (public can view active/verified shops) - shop = db.query(Shop).filter(Shop.shop_code == shop_code.upper()).first() - if not shop: - raise HTTPException(status_code=404, detail="Shop not found") - - # Non-owners can only see active verified shops - if current_user.role != "admin" and shop.owner_id != current_user.id: - if not shop.is_active or not shop.is_verified: - raise HTTPException(status_code=404, detail="Shop not found") - - # Query shop products - query = db.query(ShopProduct).filter(ShopProduct.shop_id == shop.id) - - if active_only: - query = query.filter(ShopProduct.is_active == True) - if featured_only: - query = query.filter(ShopProduct.is_featured == True) - - total = query.count() - shop_products = query.offset(skip).limit(limit).all() - - # Format response - products = [] - for sp in shop_products: - product_response = ShopProductResponse.model_validate(sp) - product_response.product = sp.product - products.append(product_response) - - return { - "products": products, - "total": total, - "skip": skip, - "limit": limit, - "shop": ShopResponse.model_validate(shop) - } - - -# Enhanced Statistics with Marketplace Support -@app.get("/stats", response_model=StatsResponse) -def get_stats(db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): - """Get comprehensive statistics with marketplace data (Protected)""" - - # Use more efficient queries with proper indexes - total_products = db.query(Product).count() - - unique_brands = db.query(Product.brand).filter( - Product.brand.isnot(None), - Product.brand != "" - ).distinct().count() - - unique_categories = db.query(Product.google_product_category).filter( - Product.google_product_category.isnot(None), - Product.google_product_category != "" - ).distinct().count() - - # New marketplace statistics - unique_marketplaces = db.query(Product.marketplace).filter( - Product.marketplace.isnot(None), - Product.marketplace != "" - ).distinct().count() - - unique_shops = db.query(Product.shop_name).filter( - Product.shop_name.isnot(None), - Product.shop_name != "" - ).distinct().count() - - # Stock statistics - total_stock_entries = db.query(Stock).count() - total_inventory = db.query(func.sum(Stock.quantity)).scalar() or 0 - - return StatsResponse( - total_products=total_products, - unique_brands=unique_brands, - unique_categories=unique_categories, - unique_marketplaces=unique_marketplaces, - unique_shops=unique_shops, - total_stock_entries=total_stock_entries, - total_inventory_quantity=total_inventory - ) - - -@app.get("/marketplace-stats", response_model=List[MarketplaceStatsResponse]) -def get_marketplace_stats(db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): - """Get statistics broken down by marketplace (Protected)""" - - # Query to get stats per marketplace - marketplace_stats = db.query( - Product.marketplace, - func.count(Product.id).label('total_products'), - func.count(func.distinct(Product.shop_name)).label('unique_shops'), - func.count(func.distinct(Product.brand)).label('unique_brands') - ).filter( - Product.marketplace.isnot(None) - ).group_by(Product.marketplace).all() - - return [ - MarketplaceStatsResponse( - marketplace=stat.marketplace, - total_products=stat.total_products, - unique_shops=stat.unique_shops, - unique_brands=stat.unique_brands - ) for stat in marketplace_stats - ] - - -# Export with streaming for large datasets (Protected) -@app.get("/export-csv") -async def export_csv( - marketplace: Optional[str] = Query(None, description="Filter by marketplace"), - shop_name: Optional[str] = Query(None, description="Filter by shop name"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """Export products as CSV with streaming and marketplace filtering (Protected)""" - - def generate_csv(): - # Stream CSV generation for memory efficiency - yield "product_id,title,description,link,image_link,availability,price,currency,brand,gtin,marketplace,shop_name\n" - - batch_size = 1000 - offset = 0 - - while True: - query = db.query(Product) - - # Apply marketplace filters - if marketplace: - query = query.filter(Product.marketplace.ilike(f"%{marketplace}%")) - if shop_name: - query = query.filter(Product.shop_name.ilike(f"%{shop_name}%")) - - products = query.offset(offset).limit(batch_size).all() - if not products: - break - - for product in products: - # Create CSV row with marketplace fields - row = f'"{product.product_id}","{product.title or ""}","{product.description or ""}","{product.link or ""}","{product.image_link or ""}","{product.availability or ""}","{product.price or ""}","{product.currency or ""}","{product.brand or ""}","{product.gtin or ""}","{product.marketplace or ""}","{product.shop_name or ""}"\n' - yield row - - offset += batch_size - - filename = "products_export" - if marketplace: - filename += f"_{marketplace}" - if shop_name: - filename += f"_{shop_name}" - filename += ".csv" - - return StreamingResponse( - generate_csv(), - media_type="text/csv", - headers={"Content-Disposition": f"attachment; filename={filename}"} - ) - - -# Admin-only routes -@app.get("/admin/users", response_model=List[UserResponse]) -def get_all_users( - skip: int = Query(0, ge=0), - limit: int = Query(100, ge=1, le=1000), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_user) -): - """Get all users (Admin only)""" - users = db.query(User).offset(skip).limit(limit).all() - return [UserResponse.model_validate(user) for user in users] - - -@app.put("/admin/users/{user_id}/status") -def toggle_user_status( - user_id: int, - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_user) -): - """Toggle user active status (Admin only)""" - user = db.query(User).filter(User.id == user_id).first() - if not user: - raise HTTPException(status_code=404, detail="User not found") - - if user.id == current_admin.id: - raise HTTPException(status_code=400, detail="Cannot deactivate your own account") - - user.is_active = not user.is_active - user.updated_at = datetime.utcnow() - db.commit() - db.refresh(user) - - status = "activated" if user.is_active else "deactivated" - return {"message": f"User {user.username} has been {status}"} - - -@app.get("/admin/shops", response_model=ShopListResponse) -def get_all_shops_admin( - skip: int = Query(0, ge=0), - limit: int = Query(100, ge=1, le=1000), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_user) -): - """Get all shops with admin view (Admin only)""" - total = db.query(Shop).count() - shops = db.query(Shop).offset(skip).limit(limit).all() - - return ShopListResponse( - shops=shops, - total=total, - skip=skip, - limit=limit - ) - - -@app.put("/admin/shops/{shop_id}/verify") -def verify_shop( - shop_id: int, - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_user) -): - """Verify/unverify shop (Admin only)""" - shop = db.query(Shop).filter(Shop.id == shop_id).first() - if not shop: - raise HTTPException(status_code=404, detail="Shop not found") - - shop.is_verified = not shop.is_verified - shop.updated_at = datetime.utcnow() - db.commit() - db.refresh(shop) - - status = "verified" if shop.is_verified else "unverified" - return {"message": f"Shop {shop.shop_code} has been {status}"} - - -@app.put("/admin/shops/{shop_id}/status") -def toggle_shop_status( - shop_id: int, - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_user) -): - """Toggle shop active status (Admin only)""" - shop = db.query(Shop).filter(Shop.id == shop_id).first() - if not shop: - raise HTTPException(status_code=404, detail="Shop not found") - - shop.is_active = not shop.is_active - shop.updated_at = datetime.utcnow() - db.commit() - db.refresh(shop) - - status = "activated" if shop.is_active else "deactivated" - return {"message": f"Shop {shop.shop_code} has been {status}"} - - -@app.get("/admin/marketplace-import-jobs", response_model=List[MarketplaceImportJobResponse]) -def get_all_marketplace_import_jobs( - marketplace: Optional[str] = Query(None), - shop_name: Optional[str] = Query(None), - status: Optional[str] = Query(None), - skip: int = Query(0, ge=0), - limit: int = Query(100, ge=1, le=100), - db: Session = Depends(get_db), - current_admin: User = Depends(get_current_admin_user) -): - """Get all marketplace import jobs (Admin only)""" - - query = db.query(MarketplaceImportJob) - - # Apply filters - if marketplace: - query = query.filter(MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%")) - if shop_name: - query = query.filter(MarketplaceImportJob.shop_name.ilike(f"%{shop_name}%")) - if status: - query = query.filter(MarketplaceImportJob.status == status) - - # Order by creation date and apply pagination - jobs = query.order_by(MarketplaceImportJob.created_at.desc()).offset(skip).limit(limit).all() - - return [ - MarketplaceImportJobResponse( - job_id=job.id, - status=job.status, - marketplace=job.marketplace, - shop_name=job.shop_name, - imported=job.imported_count or 0, - updated=job.updated_count or 0, - total_processed=job.total_processed or 0, - error_count=job.error_count or 0, - error_message=job.error_message, - created_at=job.created_at, - started_at=job.started_at, - completed_at=job.completed_at - ) for job in jobs - ] - - if __name__ == "__main__": import uvicorn - uvicorn.run( - "main:app", - host="0.0.0.0", - port=8000, - reload=True, - log_level="info" - ) + uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True) diff --git a/middleware/decorators.py b/middleware/decorators.py new file mode 100644 index 00000000..69607047 --- /dev/null +++ b/middleware/decorators.py @@ -0,0 +1,29 @@ +# middleware/decorators.py +from functools import wraps +from fastapi import HTTPException +from middleware.rate_limiter import RateLimiter + +# Initialize rate limiter instance +rate_limiter = RateLimiter() + + +def rate_limit(max_requests: int = 100, window_seconds: int = 3600): + """Rate limiting decorator for FastAPI endpoints""" + + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + # Extract client IP or user ID for rate limiting + client_id = "anonymous" # In production, extract from request + + if not rate_limiter.allow_request(client_id, max_requests, window_seconds): + raise HTTPException( + status_code=429, + detail="Rate limit exceeded" + ) + + return await func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/tests/Makefile b/tests/Makefile new file mode 100644 index 00000000..bfe59532 --- /dev/null +++ b/tests/Makefile @@ -0,0 +1,57 @@ +# Makefile for running tests +# tests/Makefile +.PHONY: test test-unit test-integration test-coverage test-fast test-slow + +# Run all tests +test: + pytest tests/ -v + +# Run only unit tests +test-unit: + pytest tests/ -v -m unit + +# Run only integration tests +test-integration: + pytest tests/ -v -m integration + +# Run tests with coverage report +test-coverage: + pytest tests/ --cov=app --cov=models --cov=utils --cov=middleware --cov-report=html --cov-report=term-missing + +# Run fast tests only (exclude slow ones) +test-fast: + pytest tests/ -v -m "not slow" + +# Run slow tests only +test-slow: + pytest tests/ -v -m slow + +# Run specific test file +test-auth: + pytest tests/test_auth.py -v + +test-products: + pytest tests/test_products.py -v + +test-stock: + pytest tests/test_stock.py -v + +# Clean up test artifacts +clean: + rm -rf htmlcov/ + rm -rf .pytest_cache/ + rm -rf .coverage + find . -type d -name "__pycache__" -exec rm -rf {} + + find . -name "*.pyc" -delete + +# Install test dependencies +install-test-deps: + pip install -r tests/requirements_test.txtvalidate_csv_headers(valid_df) == True + + # Invalid headers (missing required fields) + invalid_df = pd.DataFrame({ + "id": ["TEST001"], # Wrong column name + "name": ["Test"] + }) + + assert self.processor._ diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..ac0aaf1b --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,2 @@ +# tests/__init__.py +# This file makes the tests directory a Python package diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..6766b091 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,194 @@ +# tests/conftest.py +import pytest +import tempfile +import os +from fastapi.testclient import TestClient +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import StaticPool + +from main import app +from app.core.database import get_db, Base +from models.database_models import User, Product, Stock, Shop +from middleware.auth import AuthManager + +# Use in-memory SQLite database for tests +SQLALCHEMY_TEST_DATABASE_URL = "sqlite:///:memory:" + + +@pytest.fixture(scope="session") +def engine(): + """Create test database engine""" + return create_engine( + SQLALCHEMY_TEST_DATABASE_URL, + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + echo=False # Set to True for SQL debugging + ) + + +@pytest.fixture(scope="session") +def testing_session_local(engine): + """Create session factory for tests""" + return sessionmaker(autocommit=False, autoflush=False, bind=engine) + + +@pytest.fixture(scope="function") +def db(engine, testing_session_local): + """Create a fresh database for each test""" + # Create all tables + Base.metadata.create_all(bind=engine) + + # Create session + db = testing_session_local() + + # Override the dependency + def override_get_db(): + try: + yield db + finally: + pass # Don't close here, we'll close in cleanup + + app.dependency_overrides[get_db] = override_get_db + + try: + yield db + finally: + db.rollback() # Rollback any uncommitted changes + db.close() + # Clean up the dependency override + if get_db in app.dependency_overrides: + del app.dependency_overrides[get_db] + # Drop all tables for next test + Base.metadata.drop_all(bind=engine) + + +@pytest.fixture(scope="function") +def client(db): + """Create a test client with database dependency override""" + return TestClient(app) + + +@pytest.fixture(scope="session") +def auth_manager(): + """Create auth manager instance (session scope since it's stateless)""" + return AuthManager() + + +@pytest.fixture +def test_user(db, auth_manager): + """Create a test user""" + hashed_password = auth_manager.hash_password("testpass123") + user = User( + email="test@example.com", + username="testuser", + hashed_password=hashed_password, + role="user", + is_active=True + ) + db.add(user) + db.commit() + db.refresh(user) + return user + + +@pytest.fixture +def test_admin(db, auth_manager): + """Create a test admin user""" + hashed_password = auth_manager.hash_password("adminpass123") + admin = User( + email="admin@example.com", + username="admin", + hashed_password=hashed_password, + role="admin", + is_active=True + ) + db.add(admin) + db.commit() + db.refresh(admin) + return admin + + +@pytest.fixture +def auth_headers(client, test_user): + """Get authentication headers for test user""" + response = client.post("/api/v1/auth/login", json={ + "username": "testuser", + "password": "testpass123" + }) + assert response.status_code == 200, f"Login failed: {response.text}" + token = response.json()["access_token"] + return {"Authorization": f"Bearer {token}"} + + +@pytest.fixture +def admin_headers(client, test_admin): + """Get authentication headers for admin user""" + response = client.post("/api/v1/auth/login", json={ + "username": "admin", + "password": "adminpass123" + }) + assert response.status_code == 200, f"Admin login failed: {response.text}" + token = response.json()["access_token"] + return {"Authorization": f"Bearer {token}"} + + +@pytest.fixture +def test_product(db): + """Create a test product""" + product = Product( + product_id="TEST001", + title="Test Product", + description="A test product", + price="10.99", + currency="EUR", + brand="TestBrand", + gtin="1234567890123", + availability="in stock", + marketplace="Letzshop", + shop_name="TestShop" + ) + db.add(product) + db.commit() + db.refresh(product) + return product + + +@pytest.fixture +def test_shop(db, test_user): + """Create a test shop""" + shop = Shop( + shop_code="TESTSHOP", + shop_name="Test Shop", + owner_id=test_user.id, + is_active=True, + is_verified=True + ) + db.add(shop) + db.commit() + db.refresh(shop) + return shop + + +@pytest.fixture +def test_stock(db, test_product, test_shop): + """Create test stock entry""" + stock = Stock( + product_id=test_product.product_id, + shop_code=test_shop.shop_code, + quantity=10, + reserved_quantity=0 + ) + db.add(stock) + db.commit() + db.refresh(stock) + return stock + + +# Cleanup fixture to ensure clean state +@pytest.fixture(autouse=True) +def cleanup(): + """Automatically clean up after each test""" + yield + # Clear any remaining dependency overrides + app.dependency_overrides.clear() diff --git a/tests/pytest.ini b/tests/pytest.ini new file mode 100644 index 00000000..4bb15305 --- /dev/null +++ b/tests/pytest.ini @@ -0,0 +1,21 @@ +# tests/pytest.ini +[tool:pytest] +testpaths = tests +python_files = test_*.py +python_classes = Test* +python_functions = test_* +addopts = + -v + --tb=short + --strict-markers + --disable-warnings + --color=yes +markers = + slow: marks tests as slow (deselect with '-m "not slow"') + integration: marks tests as integration tests + unit: marks tests as unit tests + auth: marks tests related to authentication + products: marks tests related to products + stock: marks tests related to stock management + shops: marks tests related to shop management + admin: marks tests related to admin functionality diff --git a/tests/requirements_test.txt b/tests/requirements_test.txt new file mode 100644 index 00000000..6cd8711f --- /dev/null +++ b/tests/requirements_test.txt @@ -0,0 +1,8 @@ +# tests/requirements_test.txt +# Testing dependencies +pytest>=7.4.0 +pytest-cov>=4.1.0 +pytest-asyncio>=0.21.0 +pytest-mock>=3.11.0 +httpx>=0.24.0 +faker>=19.0.0 diff --git a/tests/test_admin.py b/tests/test_admin.py new file mode 100644 index 00000000..fb3d14a2 --- /dev/null +++ b/tests/test_admin.py @@ -0,0 +1,34 @@ +# tests/test_admin.py +import pytest + + +class TestAdminAPI: + def test_get_all_users_admin(self, client, admin_headers, test_user): + """Test admin getting all users""" + response = client.get("/api/v1/admin/users", headers=admin_headers) + + assert response.status_code == 200 + data = response.json() + assert len(data) >= 2 # test_user + admin user + + def test_get_all_users_non_admin(self, client, auth_headers): + """Test non-admin trying to access admin endpoint""" + response = client.get("/api/v1/admin/users", headers=auth_headers) + + assert response.status_code == 403 + assert "Access denied" in response.json()["detail"] or "admin" in response.json()["detail"].lower() + + def test_toggle_user_status_admin(self, client, admin_headers, test_user): + """Test admin toggling user status""" + response = client.put(f"/api/v1/admin/users/{test_user.id}/status", headers=admin_headers) + + assert response.status_code == 200 + assert "deactivated" in response.json()["message"] or "activated" in response.json()["message"] + + def test_get_all_shops_admin(self, client, admin_headers, test_shop): + """Test admin getting all shops""" + response = client.get("/api/v1/admin/shops", headers=admin_headers) + + assert response.status_code == 200 + data = response.json() + assert data["total"] >= 1 diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 00000000..ffa2cd22 --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,119 @@ +# tests/test_auth.py +import pytest +from fastapi import HTTPException + + +class TestAuthenticationAPI: + def test_register_user_success(self, client, db): + """Test successful user registration""" + response = client.post("/api/v1/auth/register", json={ + "email": "newuser@example.com", + "username": "newuser", + "password": "securepass123" + }) + + assert response.status_code == 200 + data = response.json() + assert data["email"] == "newuser@example.com" + assert data["username"] == "newuser" + assert data["role"] == "user" + assert data["is_active"] == True + assert "hashed_password" not in data + + def test_register_user_duplicate_email(self, client, test_user): + """Test registration with duplicate email""" + response = client.post("/api/v1/auth/register", json={ + "email": "test@example.com", # Same as test_user + "username": "newuser", + "password": "securepass123" + }) + + assert response.status_code == 400 + assert "Email already registered" in response.json()["detail"] + + def test_register_user_duplicate_username(self, client, test_user): + """Test registration with duplicate username""" + response = client.post("/api/v1/auth/register", json={ + "email": "new@example.com", + "username": "testuser", # Same as test_user + "password": "securepass123" + }) + + assert response.status_code == 400 + assert "Username already taken" in response.json()["detail"] + + def test_login_success(self, client, test_user): + """Test successful login""" + response = client.post("/api/v1/auth/login", json={ + "username": "testuser", + "password": "testpass123" + }) + + assert response.status_code == 200 + data = response.json() + assert "access_token" in data + assert data["token_type"] == "bearer" + assert "expires_in" in data + assert data["user"]["username"] == "testuser" + + def test_login_wrong_password(self, client, test_user): + """Test login with wrong password""" + response = client.post("/api/v1/auth/login", json={ + "username": "testuser", + "password": "wrongpassword" + }) + + assert response.status_code == 401 + assert "Incorrect username or password" in response.json()["detail"] + + def test_login_nonexistent_user(self, client): + """Test login with nonexistent user""" + response = client.post("/api/v1/auth/login", json={ + "username": "nonexistent", + "password": "password123" + }) + + assert response.status_code == 401 + + def test_get_current_user_info(self, client, auth_headers): + """Test getting current user info""" + response = client.get("/api/v1/auth/me", headers=auth_headers) + + assert response.status_code == 200 + data = response.json() + assert data["username"] == "testuser" + assert data["email"] == "test@example.com" + + def test_get_current_user_no_auth(self, client): + """Test getting current user without authentication""" + response = client.get("/api/v1/auth/me") + + assert response.status_code == 403 # No authorization header + + +class TestAuthManager: + def test_hash_password(self, auth_manager): + """Test password hashing""" + password = "testpassword123" + hashed = auth_manager.hash_password(password) + + assert hashed != password + assert len(hashed) > 20 # bcrypt hashes are long + + def test_verify_password(self, auth_manager): + """Test password verification""" + password = "testpassword123" + hashed = auth_manager.hash_password(password) + + assert auth_manager.verify_password(password, hashed) == True + assert auth_manager.verify_password("wrongpassword", hashed) == False + + def test_create_access_token(self, auth_manager, test_user): + """Test JWT token creation""" + token_data = auth_manager.create_access_token(test_user) + + assert "access_token" in token_data + assert token_data["token_type"] == "bearer" + assert "expires_in" in token_data + assert isinstance(token_data["expires_in"], int) + diff --git a/tests/test_background_tasks.py b/tests/test_background_tasks.py new file mode 100644 index 00000000..f1adea68 --- /dev/null +++ b/tests/test_background_tasks.py @@ -0,0 +1,83 @@ +# tests/test_background_tasks.py +import pytest +from unittest.mock import patch, AsyncMock +from app.tasks.background_tasks import process_marketplace_import +from models.database_models import MarketplaceImportJob + + +class TestBackgroundTasks: + @pytest.mark.asyncio + async def test_marketplace_import_success(self, db): + """Test successful marketplace import background task""" + # Create import job + job = MarketplaceImportJob( + status="pending", + source_url="http://example.com/test.csv", + marketplace="TestMarket", + shop_code="TESTSHOP", + user_id=1 + ) + db.add(job) + db.commit() + db.refresh(job) + + # Mock CSV processor + with patch('app.tasks.background_tasks.CSVProcessor') as mock_processor: + mock_instance = mock_processor.return_value + mock_instance.process_marketplace_csv_from_url = AsyncMock(return_value={ + "imported": 10, + "updated": 5, + "total_processed": 15, + "errors": 0 + }) + + # Run background task + await process_marketplace_import( + job.id, + "http://example.com/test.csv", + "TestMarket", + "TESTSHOP", + 1000 + ) + + # Verify job was updated + db.refresh(job) + assert job.status == "completed" + assert job.imported_count == 10 + assert job.updated_count == 5 + + @pytest.mark.asyncio + async def test_marketplace_import_failure(self, db): + """Test marketplace import failure handling""" + # Create import job + job = MarketplaceImportJob( + status="pending", + source_url="http://example.com/test.csv", + marketplace="TestMarket", + shop_code="TESTSHOP", + user_id=1 + ) + db.add(job) + db.commit() + db.refresh(job) + + # Mock CSV processor to raise exception + with patch('app.tasks.background_tasks.CSVProcessor') as mock_processor: + mock_instance = mock_processor.return_value + mock_instance.process_marketplace_csv_from_url = AsyncMock( + side_effect=Exception("Import failed") + ) + + # Run background task + await process_marketplace_import( + job.id, + "http://example.com/test.csv", + "TestMarket", + "TESTSHOP", + 1000 + ) + + # Verify job failure was recorded + db.refresh(job) + assert job.status == "failed" + assert "Import failed" in job.error_message diff --git a/tests/test_csv_processor.py b/tests/test_csv_processor.py new file mode 100644 index 00000000..6294fde5 --- /dev/null +++ b/tests/test_csv_processor.py @@ -0,0 +1,90 @@ +# tests/test_csv_processor.py +import pytest +from unittest.mock import Mock, patch, AsyncMock +from io import StringIO +import pandas as pd +from utils.csv_processor import CSVProcessor + + +class TestCSVProcessor: + def setup_method(self): + self.processor = CSVProcessor() + + @patch('requests.get') + def test_download_csv_success(self, mock_get): + """Test successful CSV download""" + # Mock successful HTTP response + mock_response = Mock() + mock_response.status_code = 200 + mock_response.text = "product_id,title,price\nTEST001,Test Product,10.99" + mock_get.return_value = mock_response + + csv_content = self.processor._download_csv("http://example.com/test.csv") + + assert "product_id,title,price" in csv_content + assert "TEST001,Test Product,10.99" in csv_content + + @patch('requests.get') + def test_download_csv_failure(self, mock_get): + """Test CSV download failure""" + # Mock failed HTTP response + mock_response = Mock() + mock_response.status_code = 404 + mock_get.return_value = mock_response + + with pytest.raises(Exception): + self.processor._download_csv("http://example.com/nonexistent.csv") + + def test_parse_csv_content(self): + """Test CSV content parsing""" + csv_content = """product_id,title,price,marketplace +TEST001,Test Product 1,10.99,TestMarket +TEST002,Test Product 2,15.99,TestMarket""" + + df = self.processor._parse_csv_content(csv_content) + + assert len(df) == 2 + assert "product_id" in df.columns + assert df.iloc[0]["product_id"] == "TEST001" + assert df.iloc[1]["price"] == "15.99" + + def test_validate_csv_headers(self): + """Test CSV header validation""" + # Valid headers + valid_df = pd.DataFrame({ + "product_id": ["TEST001"], + "title": ["Test"], + "price": ["10.99"] + }) + + assert self.processor._validate_csv_headers(invalid_df) == False + + @pytest.mark.asyncio + async def test_process_marketplace_csv_from_url(self, db): + """Test complete marketplace CSV processing""" + with patch.object(self.processor, '_download_csv') as mock_download, \ + patch.object(self.processor, '_parse_csv_content') as mock_parse, \ + patch.object(self.processor, '_validate_csv_headers') as mock_validate: + # Mock successful download and parsing + mock_download.return_value = "csv_content" + mock_df = pd.DataFrame({ + "product_id": ["TEST001", "TEST002"], + "title": ["Product 1", "Product 2"], + "price": ["10.99", "15.99"], + "marketplace": ["TestMarket", "TestMarket"], + "shop_name": ["TestShop", "TestShop"] + }) + mock_parse.return_value = mock_df + mock_validate.return_value = True + + result = await self.processor.process_marketplace_csv_from_url( + "http://example.com/test.csv", + "TestMarket", + "TestShop", + 1000, + db + ) + + assert "imported" in result + assert "updated" in result + assert "total_processed" in result diff --git a/tests/test_data_validation.py b/tests/test_data_validation.py new file mode 100644 index 00000000..a94807db --- /dev/null +++ b/tests/test_data_validation.py @@ -0,0 +1,46 @@ +# tests/test_data_validation.py +import pytest +from utils.data_processing import GTINProcessor, PriceProcessor + + +class TestDataValidation: + def test_gtin_normalization_edge_cases(self): + """Test GTIN normalization with edge cases""" + processor = GTINProcessor() + + # Test with leading zeros + assert processor.normalize("000123456789") == "000123456789" + + # Test with spaces + assert processor.normalize("123 456 789 012") == "123456789012" + + # Test with dashes + assert processor.normalize("123-456-789-012") == "123456789012" + + # Test very long numbers + long_number = "1234567890123456789" + normalized = processor.normalize(long_number) + assert len(normalized) <= 14 # Should be truncated + + def test_price_parsing_edge_cases(self): + """Test price parsing with edge cases""" + processor = PriceProcessor() + + # Test with multiple decimal places + price, currency = processor.parse_price_currency("12.999 EUR") + assert price == "12.999" + + # Test with no currency + price, currency = processor.parse_price_currency("15.50") + assert price == "15.50" + + # Test with unusual formatting + price, currency = processor.parse_price_currency("EUR 25,50") + assert currency == "EUR" + assert price == "25.50" # Comma should be converted to dot + + def test_input_sanitization(self): + """Test input sanitization""" + # These tests would verify that inputs are properly sanitized + # to prevent SQL injection, XSS, etc. + pass # Implementation would depend on your sanitization logic diff --git a/tests/test_database.py b/tests/test_database.py new file mode 100644 index 00000000..9a898415 --- /dev/null +++ b/tests/test_database.py @@ -0,0 +1,98 @@ +# tests/test_database.py +import pytest +from sqlalchemy import text +from models.database_models import User, Product, Stock, Shop + + +class TestDatabaseModels: + def test_user_model(self, db): + """Test User model creation and relationships""" + user = User( + email="db_test@example.com", + username="dbtest", + hashed_password="hashed_password_123", + role="user", + is_active=True + ) + + db.add(user) + db.commit() + db.refresh(user) + + assert user.id is not None + assert user.email == "db_test@example.com" + assert user.created_at is not None + assert user.updated_at is not None + + def test_product_model(self, db): + """Test Product model creation""" + product = Product( + product_id="DB_TEST_001", + title="Database Test Product", + description="Testing product model", + price="25.99", + currency="USD", + brand="DBTest", + gtin="1234567890123", + availability="in stock", + marketplace="TestDB", + shop_name="DBTestShop" + ) + + db.add(product) + db.commit() + db.refresh(product) + + assert product.id is not None + assert product.product_id == "DB_TEST_001" + assert product.created_at is not None + + def test_stock_model(self, db): + """Test Stock model creation""" + stock = Stock( + gtin="1234567890123", + location="DB_WAREHOUSE", + quantity=150 + ) + + db.add(stock) + db.commit() + db.refresh(stock) + + assert stock.id is not None + assert stock.gtin == "1234567890123" + assert stock.location == "DB_WAREHOUSE" + assert stock.quantity == 150 + + def test_shop_model_with_owner(self, db, test_user): + """Test Shop model with owner relationship""" + shop = Shop( + shop_code="DBTEST", + shop_name="Database Test Shop", + description="Testing shop model", + owner_id=test_user.id, + is_active=True, + is_verified=False + ) + + db.add(shop) + db.commit() + db.refresh(shop) + + assert shop.id is not None + assert shop.shop_code == "DBTEST" + assert shop.owner_id == test_user.id + assert shop.owner.username == test_user.username + + def test_database_constraints(self, db): + """Test database constraints and unique indexes""" + # Test unique product_id constraint + product1 = Product(product_id="UNIQUE_001", title="Product 1") + db.add(product1) + db.commit() + + # This should raise an integrity error + with pytest.raises(Exception): # Could be IntegrityError or similar + product2 = Product(product_id="UNIQUE_001", title="Product 2") + db.add(product2) + db.commit() diff --git a/tests/test_error_handling.py b/tests/test_error_handling.py new file mode 100644 index 00000000..2a33968f --- /dev/null +++ b/tests/test_error_handling.py @@ -0,0 +1,45 @@ +# tests/test_error_handling.py +import pytest + + +class TestErrorHandling: + def test_invalid_json(self, client, auth_headers): + """Test handling of invalid JSON""" + response = client.post("/api/v1/products", + headers=auth_headers, + data="invalid json") + + assert response.status_code == 422 # Validation error + + def test_missing_required_fields(self, client, auth_headers): + """Test handling of missing required fields""" + response = client.post("/api/v1/products", + headers=auth_headers, + json={"title": "Test"}) # Missing product_id + + assert response.status_code == 422 + + def test_invalid_authentication(self, client): + """Test handling of invalid authentication""" + response = client.get("/api/v1/products", + headers={"Authorization": "Bearer invalid_token"}) + + assert response.status_code == 403 + + def test_nonexistent_resource(self, client, auth_headers): + """Test handling of nonexistent resource access""" + response = client.get("/api/v1/products/NONEXISTENT", headers=auth_headers) + assert response.status_code == 404 + + response = client.get("/api/v1/shops/NONEXISTENT", headers=auth_headers) + assert response.status_code == 404 + + def test_duplicate_resource_creation(self, client, auth_headers, test_product): + """Test handling of duplicate resource creation""" + product_data = { + "product_id": test_product.product_id, # Duplicate ID + "title": "Another Product" + } + + response = client.post("/api/v1/products", headers=auth_headers, json=product_data) + assert response.status_code == 400 diff --git a/tests/test_export.py b/tests/test_export.py new file mode 100644 index 00000000..e4bd434e --- /dev/null +++ b/tests/test_export.py @@ -0,0 +1,65 @@ +# tests/test_export.py +import pytest +import csv +from io import StringIO + + +class TestExportFunctionality: + def test_csv_export_basic(self, client, auth_headers, test_product): + """Test basic CSV export functionality""" + response = client.get("/api/v1/export-csv", headers=auth_headers) + + assert response.status_code == 200 + assert response.headers["content-type"] == "text/csv; charset=utf-8" + + # Parse CSV content + csv_content = response.content.decode('utf-8') + csv_reader = csv.reader(StringIO(csv_content)) + + # Check header row + header = next(csv_reader) + expected_fields = ["product_id", "title", "description", "price", "marketplace"] + for field in expected_fields: + assert field in header + + def test_csv_export_with_marketplace_filter(self, client, auth_headers, db): + """Test CSV export with marketplace filtering""" + # Create products in different marketplaces + products = [ + Product(product_id="EXP1", title="Product 1", marketplace="Amazon"), + Product(product_id="EXP2", title="Product 2", marketplace="eBay"), + ] + + db.add_all(products) + db.commit() + + response = client.get("/api/v1/export-csv?marketplace=Amazon", headers=auth_headers) + assert response.status_code == 200 + + csv_content = response.content.decode('utf-8') + assert "EXP1" in csv_content + assert "EXP2" not in csv_content # Should be filtered out + + def test_csv_export_performance(self, client, auth_headers, db): + """Test CSV export performance with many products""" + # Create many products + products = [] + for i in range(1000): + product = Product( + product_id=f"PERF{i:04d}", + title=f"Performance Product {i}", + marketplace="Performance" + ) + products.append(product) + + db.add_all(products) + db.commit() + + import time + start_time = time.time() + response = client.get("/api/v1/export-csv", headers=auth_headers) + end_time = time.time() + + assert response.status_code == 200 + assert end_time - start_time < 10.0 # Should complete within 10 seconds + \ No newline at end of file diff --git a/tests/test_filtering.py b/tests/test_filtering.py new file mode 100644 index 00000000..9d2d5791 --- /dev/null +++ b/tests/test_filtering.py @@ -0,0 +1,85 @@ +# tests/test_filtering.py +import pytest +from models.database_models import Product + + +class TestFiltering: + def test_product_brand_filter(self, client, auth_headers, db): + """Test filtering products by brand""" + # Create products with different brands + products = [ + Product(product_id="BRAND1", title="Product 1", brand="BrandA"), + Product(product_id="BRAND2", title="Product 2", brand="BrandB"), + Product(product_id="BRAND3", title="Product 3", brand="BrandA"), + ] + + db.add_all(products) + db.commit() + + # Filter by BrandA + response = client.get("/api/v1/products?brand=BrandA", headers=auth_headers) + assert response.status_code == 200 + data = response.json() + assert data["total"] == 2 + + # Filter by BrandB + response = client.get("/api/v1/products?brand=BrandB", headers=auth_headers) + assert response.status_code == 200 + data = response.json() + assert data["total"] == 1 + + def test_product_marketplace_filter(self, client, auth_headers, db): + """Test filtering products by marketplace""" + products = [ + Product(product_id="MKT1", title="Product 1", marketplace="Amazon"), + Product(product_id="MKT2", title="Product 2", marketplace="eBay"), + Product(product_id="MKT3", title="Product 3", marketplace="Amazon"), + ] + + db.add_all(products) + db.commit() + + response = client.get("/api/v1/products?marketplace=Amazon", headers=auth_headers) + assert response.status_code == 200 + data = response.json() + assert data["total"] == 2 + + def test_product_search_filter(self, client, auth_headers, db): + """Test searching products by text""" + products = [ + Product(product_id="SEARCH1", title="Apple iPhone", description="Smartphone"), + Product(product_id="SEARCH2", title="Samsung Galaxy", description="Android phone"), + Product(product_id="SEARCH3", title="iPad Tablet", description="Apple tablet"), + ] + + db.add_all(products) + db.commit() + + # Search for "Apple" + response = client.get("/api/v1/products?search=Apple", headers=auth_headers) + assert response.status_code == 200 + data = response.json() + assert data["total"] == 2 # iPhone and iPad + + # Search for "phone" + response = client.get("/api/v1/products?search=phone", headers=auth_headers) + assert response.status_code == 200 + data = response.json() + assert data["total"] == 2 # iPhone and Galaxy + + def test_combined_filters(self, client, auth_headers, db): + """Test combining multiple filters""" + products = [ + Product(product_id="COMBO1", title="Apple iPhone", brand="Apple", marketplace="Amazon"), + Product(product_id="COMBO2", title="Apple iPad", brand="Apple", marketplace="eBay"), + Product(product_id="COMBO3", title="Samsung Phone", brand="Samsung", marketplace="Amazon"), + ] + + db.add_all(products) + db.commit() + + # Filter by brand AND marketplace + response = client.get("/api/v1/products?brand=Apple&marketplace=Amazon", headers=auth_headers) + assert response.status_code == 200 + data = response.json() + assert data["total"] == 1 # Only iPhone matches both diff --git a/tests/test_integration.py b/tests/test_integration.py new file mode 100644 index 00000000..8ac4783c --- /dev/null +++ b/tests/test_integration.py @@ -0,0 +1,117 @@ +# tests/test_integration.py +import pytest + + +class TestIntegrationFlows: + def test_full_product_workflow(self, client, auth_headers): + """Test complete product creation and management workflow""" + # 1. Create a product + product_data = { + "product_id": "FLOW001", + "title": "Integration Test Product", + "description": "Testing full workflow", + "price": "29.99", + "brand": "FlowBrand", + "gtin": "1111222233334", + "availability": "in stock", + "marketplace": "TestFlow" + } + + response = client.post("/api/v1/products", headers=auth_headers, json=product_data) + assert response.status_code == 200 + product = response.json() + + # 2. Add stock for the product + stock_data = { + "gtin": product["gtin"], + "location": "MAIN_WAREHOUSE", + "quantity": 50 + } + + response = client.post("/api/v1/stock", headers=auth_headers, json=stock_data) + assert response.status_code == 200 + + # 3. Get product with stock info + response = client.get(f"/api/v1/products/{product['product_id']}", headers=auth_headers) + assert response.status_code == 200 + product_detail = response.json() + assert product_detail["stock_info"]["total_quantity"] == 50 + + # 4. Update product + update_data = {"title": "Updated Integration Test Product"} + response = client.put(f"/api/v1/products/{product['product_id']}", + headers=auth_headers, json=update_data) + assert response.status_code == 200 + + # 5. Search for product + response = client.get("/api/v1/products?search=Updated Integration", headers=auth_headers) + assert response.status_code == 200 + assert response.json()["total"] == 1 + + def test_shop_product_workflow(self, client, auth_headers): + """Test shop creation and product management workflow""" + # 1. Create a shop + shop_data = { + "shop_code": "FLOWSHOP", + "shop_name": "Integration Flow Shop", + "description": "Test shop for integration" + } + + response = client.post("/api/v1/shops", headers=auth_headers, json=shop_data) + assert response.status_code == 200 + shop = response.json() + + # 2. Create a product + product_data = { + "product_id": "SHOPFLOW001", + "title": "Shop Flow Product", + "price": "15.99", + "marketplace": "ShopFlow" + } + + response = client.post("/api/v1/products", headers=auth_headers, json=product_data) + assert response.status_code == 200 + product = response.json() + + # 3. Add product to shop (if endpoint exists) + # This would test the shop-product association + + # 4. Get shop details + response = client.get(f"/api/v1/shops/{shop['shop_code']}", headers=auth_headers) + assert response.status_code == 200 + + def test_stock_operations_workflow(self, client, auth_headers): + """Test complete stock management workflow""" + gtin = "9999888877776" + location = "TEST_WAREHOUSE" + + # 1. Set initial stock + response = client.post("/api/v1/stock", headers=auth_headers, json={ + "gtin": gtin, + "location": location, + "quantity": 100 + }) + assert response.status_code == 200 + + # 2. Add more stock + response = client.post("/api/v1/stock/add", headers=auth_headers, json={ + "gtin": gtin, + "location": location, + "quantity": 25 + }) + assert response.status_code == 200 + assert response.json()["quantity"] == 125 + + # 3. Remove some stock + response = client.post("/api/v1/stock/remove", headers=auth_headers, json={ + "gtin": gtin, + "location": location, + "quantity": 30 + }) + assert response.status_code == 200 + assert response.json()["quantity"] == 95 + + # 4. Check total stock + response = client.get(f"/api/v1/stock/{gtin}/total", headers=auth_headers) + assert response.status_code == 200 + assert response.json()["total_quantity"] == 95 diff --git a/tests/test_marketplace.py b/tests/test_marketplace.py new file mode 100644 index 00000000..c5209921 --- /dev/null +++ b/tests/test_marketplace.py @@ -0,0 +1,52 @@ +# tests/test_marketplace.py +import pytest +from unittest.mock import patch, AsyncMock + + +class TestMarketplaceAPI: + @patch('utils.csv_processor.CSVProcessor.process_marketplace_csv_from_url') + def test_import_from_marketplace(self, mock_process, client, auth_headers, test_shop): + """Test marketplace import endpoint""" + mock_process.return_value = AsyncMock() + + import_data = { + "url": "https://example.com/products.csv", + "marketplace": "TestMarket", + "shop_code": test_shop.shop_code + } + + response = client.post("/api/v1/marketplace/import-from-marketplace", + headers=auth_headers, json=import_data) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "pending" + assert data["marketplace"] == "TestMarket" + assert "job_id" in data + + def test_import_from_marketplace_invalid_shop(self, client, auth_headers): + """Test marketplace import with invalid shop""" + import_data = { + "url": "https://example.com/products.csv", + "marketplace": "TestMarket", + "shop_code": "NONEXISTENT" + } + + response = client.post("/api/v1/marketplace/import-from-marketplace", + headers=auth_headers, json=import_data) + + assert response.status_code == 404 + assert "Shop not found" in response.json()["detail"] + + def test_get_marketplace_import_jobs(self, client, auth_headers): + """Test getting marketplace import jobs""" + response = client.get("/api/v1/marketplace/marketplace-import-jobs", headers=auth_headers) + + assert response.status_code == 200 + assert isinstance(response.json(), list) + + def test_marketplace_requires_auth(self, client): + """Test that marketplace endpoints require authentication""" + response = client.get("/api/v1/marketplace/marketplace-import-jobs") + assert response.status_code == 403 + diff --git a/tests/test_middleware.py b/tests/test_middleware.py new file mode 100644 index 00000000..b8467af8 --- /dev/null +++ b/tests/test_middleware.py @@ -0,0 +1,63 @@ +# tests/test_middleware.py +import pytest +from unittest.mock import Mock, patch +from middleware.rate_limiter import RateLimiter +from middleware.auth import AuthManager + + +class TestRateLimiter: + def test_rate_limiter_allows_requests(self): + """Test rate limiter allows requests within limit""" + limiter = RateLimiter() + client_id = "test_client" + + # Should allow first request + assert limiter.allow_request(client_id, max_requests=10, window_seconds=3600) == True + + # Should allow subsequent requests within limit + for _ in range(5): + assert limiter.allow_request(client_id, max_requests=10, window_seconds=3600) == True + + def test_rate_limiter_blocks_excess_requests(self): + """Test rate limiter blocks requests exceeding limit""" + limiter = RateLimiter() + client_id = "test_client_blocked" + max_requests = 3 + + # Use up the allowed requests + for _ in range(max_requests): + assert limiter.allow_request(client_id, max_requests, 3600) == True + + # Next request should be blocked + assert limiter.allow_request(client_id, max_requests, 3600) == False + + +class TestAuthManager: + def test_password_hashing_and_verification(self): + """Test password hashing and verification""" + auth_manager = AuthManager() + password = "test_password_123" + + # Hash password + hashed = auth_manager.hash_password(password) + + # Verify correct password + assert auth_manager.verify_password(password, hashed) == True + + # Verify incorrect password + assert auth_manager.verify_password("wrong_password", hashed) == False + + def test_jwt_token_creation_and_validation(self, test_user): + """Test JWT token creation and validation""" + auth_manager = AuthManager() + + # Create token + token_data = auth_manager.create_access_token(test_user) + + assert "access_token" in token_data + assert token_data["token_type"] == "bearer" + assert isinstance(token_data["expires_in"], int) + + # Token should be a string + assert isinstance(token_data["access_token"], str) + assert len(token_data["access_token"]) > 50 # JWT tokens are long diff --git a/tests/test_pagination.py b/tests/test_pagination.py new file mode 100644 index 00000000..4849d9af --- /dev/null +++ b/tests/test_pagination.py @@ -0,0 +1,56 @@ +# tests/test_pagination.py +import pytest +from models.database_models import Product + + +class TestPagination: + def test_product_pagination(self, client, auth_headers, db): + """Test pagination for product listing""" + # Create multiple products + products = [] + for i in range(25): + product = Product( + product_id=f"PAGE{i:03d}", + title=f"Pagination Test Product {i}", + marketplace="PaginationTest" + ) + products.append(product) + + db.add_all(products) + db.commit() + + # Test first page + response = client.get("/api/v1/products?limit=10&skip=0", headers=auth_headers) + assert response.status_code == 200 + data = response.json() + assert len(data["products"]) == 10 + assert data["total"] == 25 + assert data["skip"] == 0 + assert data["limit"] == 10 + + # Test second page + response = client.get("/api/v1/products?limit=10&skip=10", headers=auth_headers) + assert response.status_code == 200 + data = response.json() + assert len(data["products"]) == 10 + assert data["skip"] == 10 + + # Test last page + response = client.get("/api/v1/products?limit=10&skip=20", headers=auth_headers) + assert response.status_code == 200 + data = response.json() + assert len(data["products"]) == 5 # Only 5 remaining + + def test_pagination_boundaries(self, client, auth_headers): + """Test pagination boundary conditions""" + # Test negative skip + response = client.get("/api/v1/products?skip=-1", headers=auth_headers) + assert response.status_code == 422 # Validation error + + # Test zero limit + response = client.get("/api/v1/products?limit=0", headers=auth_headers) + assert response.status_code == 422 # Validation error + + # Test excessive limit + response = client.get("/api/v1/products?limit=10000", headers=auth_headers) + assert response.status_code == 422 # Should be limited diff --git a/tests/test_performance.py b/tests/test_performance.py new file mode 100644 index 00000000..6b31544a --- /dev/null +++ b/tests/test_performance.py @@ -0,0 +1,56 @@ +# tests/test_performance.py +import pytest +import time + + +class TestPerformance: + def test_product_list_performance(self, client, auth_headers, db): + """Test performance of product listing with many products""" + # Create multiple products + products = [] + for i in range(100): + product = Product( + product_id=f"PERF{i:03d}", + title=f"Performance Test Product {i}", + price=f"{i}.99", + marketplace="Performance" + ) + products.append(product) + + db.add_all(products) + db.commit() + + # Time the request + start_time = time.time() + response = client.get("/api/v1/products?limit=100", headers=auth_headers) + end_time = time.time() + + assert response.status_code == 200 + assert len(response.json()["products"]) == 100 + assert end_time - start_time < 2.0 # Should complete within 2 seconds + + def test_search_performance(self, client, auth_headers, db): + """Test search performance""" + # Create products with searchable content + products = [] + for i in range(50): + product = Product( + product_id=f"SEARCH{i:03d}", + title=f"Searchable Product {i}", + description=f"This is a searchable product number {i}", + brand="SearchBrand", + marketplace="SearchMarket" + ) + products.append(product) + + db.add_all(products) + db.commit() + + # Time search request + start_time = time.time() + response = client.get("/api/v1/products?search=Searchable", headers=auth_headers) + end_time = time.time() + + assert response.status_code == 200 + assert response.json()["total"] == 50 + assert end_time - start_time < 1.0 # Search should be fast diff --git a/tests/test_products.py b/tests/test_products.py new file mode 100644 index 00000000..dc9082a1 --- /dev/null +++ b/tests/test_products.py @@ -0,0 +1,122 @@ +# tests/test_products.py +import pytest + + +class TestProductsAPI: + def test_get_products_empty(self, client, auth_headers): + """Test getting products when none exist""" + response = client.get("/api/v1/products", headers=auth_headers) + + assert response.status_code == 200 + data = response.json() + assert data["products"] == [] + assert data["total"] == 0 + + def test_get_products_with_data(self, client, auth_headers, test_product): + """Test getting products with data""" + response = client.get("/api/v1/products", headers=auth_headers) + + assert response.status_code == 200 + data = response.json() + assert len(data["products"]) == 1 + assert data["total"] == 1 + assert data["products"][0]["product_id"] == "TEST001" + + def test_get_products_with_filters(self, client, auth_headers, test_product): + """Test filtering products""" + # Test brand filter + response = client.get("/api/v1/products?brand=TestBrand", headers=auth_headers) + assert response.status_code == 200 + assert response.json()["total"] == 1 + + # Test marketplace filter + response = client.get("/api/v1/products?marketplace=Letzshop", headers=auth_headers) + assert response.status_code == 200 + assert response.json()["total"] == 1 + + # Test search + response = client.get("/api/v1/products?search=Test", headers=auth_headers) + assert response.status_code == 200 + assert response.json()["total"] == 1 + + def test_create_product(self, client, auth_headers): + """Test creating a new product""" + product_data = { + "product_id": "NEW001", + "title": "New Product", + "description": "A new product", + "price": "15.99", + "brand": "NewBrand", + "gtin": "9876543210987", + "availability": "in stock", + "marketplace": "Amazon" + } + + response = client.post("/api/v1/products", headers=auth_headers, json=product_data) + + assert response.status_code == 200 + data = response.json() + assert data["product_id"] == "NEW001" + assert data["title"] == "New Product" + assert data["marketplace"] == "Amazon" + + def test_create_product_duplicate_id(self, client, auth_headers, test_product): + """Test creating product with duplicate ID""" + product_data = { + "product_id": "TEST001", # Same as test_product + "title": "Another Product", + "price": "20.00" + } + + response = client.post("/api/v1/products", headers=auth_headers, json=product_data) + + assert response.status_code == 400 + assert "already exists" in response.json()["detail"] + + def test_get_product_by_id(self, client, auth_headers, test_product): + """Test getting specific product""" + response = client.get(f"/api/v1/products/{test_product.product_id}", headers=auth_headers) + + assert response.status_code == 200 + data = response.json() + assert data["product"]["product_id"] == test_product.product_id + assert data["product"]["title"] == test_product.title + + def test_get_nonexistent_product(self, client, auth_headers): + """Test getting nonexistent product""" + response = client.get("/api/v1/products/NONEXISTENT", headers=auth_headers) + + assert response.status_code == 404 + + def test_update_product(self, client, auth_headers, test_product): + """Test updating product""" + update_data = { + "title": "Updated Product Title", + "price": "25.99" + } + + response = client.put( + f"/api/v1/products/{test_product.product_id}", + headers=auth_headers, + json=update_data + ) + + assert response.status_code == 200 + data = response.json() + assert data["title"] == "Updated Product Title" + assert data["price"] == "25.99" + + def test_delete_product(self, client, auth_headers, test_product): + """Test deleting product""" + response = client.delete( + f"/api/v1/products/{test_product.product_id}", + headers=auth_headers + ) + + assert response.status_code == 200 + assert "deleted successfully" in response.json()["message"] + + def test_products_require_auth(self, client): + """Test that product endpoints require authentication""" + response = client.get("/api/v1/products") + assert response.status_code == 403 diff --git a/tests/test_security.py b/tests/test_security.py new file mode 100644 index 00000000..827dbeb7 --- /dev/null +++ b/tests/test_security.py @@ -0,0 +1,61 @@ +# tests/test_security.py +import pytest +from fastapi import HTTPException +from unittest.mock import patch + + +class TestSecurity: + def test_protected_endpoint_without_auth(self, client): + """Test that protected endpoints reject unauthenticated requests""" + protected_endpoints = [ + "/api/v1/products", + "/api/v1/stock", + "/api/v1/shops", + "/api/v1/stats", + "/api/v1/admin/users" + ] + + for endpoint in protected_endpoints: + response = client.get(endpoint) + assert response.status_code == 403 + + def test_protected_endpoint_with_invalid_token(self, client): + """Test protected endpoints with invalid token""" + headers = {"Authorization": "Bearer invalid_token_here"} + + response = client.get("/api/v1/products", headers=headers) + assert response.status_code == 403 + + def test_admin_endpoint_requires_admin_role(self, client, auth_headers): + """Test that admin endpoints require admin role""" + response = client.get("/api/v1/admin/users", headers=auth_headers) + assert response.status_code == 403 # Regular user should be denied + + def test_sql_injection_prevention(self, client, auth_headers): + """Test SQL injection prevention in search parameters""" + # Try SQL injection in search parameter + malicious_search = "'; DROP TABLE products; --" + + response = client.get(f"/api/v1/products?search={malicious_search}", headers=auth_headers) + + # Should not crash and should return normal response + assert response.status_code == 200 + # Database should still be intact (no products dropped) + + def test_input_validation(self, client, auth_headers): + """Test input validation and sanitization""" + # Test XSS attempt in product creation + xss_payload = "" + + product_data = { + "product_id": "XSS_TEST", + "title": xss_payload, + "description": xss_payload + } + + response = client.post("/api/v1/products", headers=auth_headers, json=product_data) + + if response.status_code == 200: + # If creation succeeds, content should be escaped/sanitized + data = response.json() + assert "