code quality run

This commit is contained in:
2025-09-13 21:58:54 +02:00
parent 0dfd885847
commit 3eb18ef91e
63 changed files with 1802 additions and 1289 deletions

View File

@@ -1,14 +1,16 @@
from logging.config import fileConfig
from sqlalchemy import engine_from_config, pool
from alembic import context
import os import os
import sys import sys
from logging.config import fileConfig
from sqlalchemy import engine_from_config, pool
from alembic import context
# Add your project directory to the Python path # Add your project directory to the Python path
sys.path.append(os.path.dirname(os.path.dirname(__file__))) sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from models.database_models import Base
from app.core.config import settings from app.core.config import settings
from models.database_models import Base
# Alembic Config object # Alembic Config object
config = context.config config = context.config
@@ -45,9 +47,7 @@ def run_migrations_online() -> None:
) )
with connectable.connect() as connection: with connectable.connect() as connection:
context.configure( context.configure(connection=connection, target_metadata=target_metadata)
connection=connection, target_metadata=target_metadata
)
with context.begin_transaction(): with context.begin_transaction():
context.run_migrations() context.run_migrations()

View File

@@ -1,14 +1,16 @@
from logging.config import fileConfig
from sqlalchemy import engine_from_config, pool
from alembic import context
import os import os
import sys import sys
from logging.config import fileConfig
from sqlalchemy import engine_from_config, pool
from alembic import context
# Add your project directory to the Python path # Add your project directory to the Python path
sys.path.append(os.path.dirname(os.path.dirname(__file__))) sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from models.database_models import Base
from app.core.config import settings from app.core.config import settings
from models.database_models import Base
# Alembic Config object # Alembic Config object
config = context.config config = context.config
@@ -45,9 +47,7 @@ def run_migrations_online() -> None:
) )
with connectable.connect() as connection: with connectable.connect() as connection:
context.configure( context.configure(connection=connection, target_metadata=target_metadata)
connection=connection, target_metadata=target_metadata
)
with context.begin_transaction(): with context.begin_transaction():
context.run_migrations() context.run_migrations()

View File

@@ -1,10 +1,11 @@
from fastapi import Depends, HTTPException from fastapi import Depends, HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.database import get_db from app.core.database import get_db
from models.database_models import User, Shop
from middleware.auth import AuthManager from middleware.auth import AuthManager
from middleware.rate_limiter import RateLimiter from middleware.rate_limiter import RateLimiter
from models.database_models import Shop, User
# Set auto_error=False to prevent automatic 403 responses # Set auto_error=False to prevent automatic 403 responses
security = HTTPBearer(auto_error=False) security = HTTPBearer(auto_error=False)
@@ -13,8 +14,8 @@ rate_limiter = RateLimiter()
def get_current_user( def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security), credentials: HTTPAuthorizationCredentials = Depends(security),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
"""Get current authenticated user""" """Get current authenticated user"""
# Check if credentials are provided # Check if credentials are provided
@@ -30,9 +31,9 @@ def get_current_admin_user(current_user: User = Depends(get_current_user)):
def get_user_shop( def get_user_shop(
shop_code: str, shop_code: str,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
"""Get shop and verify user ownership""" """Get shop and verify user ownership"""
shop = db.query(Shop).filter(Shop.shop_code == shop_code.upper()).first() shop = db.query(Shop).filter(Shop.shop_code == shop_code.upper()).first()

View File

@@ -1,5 +1,6 @@
from fastapi import APIRouter from fastapi import APIRouter
from app.api.v1 import auth, product, stock, shop, marketplace, admin, stats
from app.api.v1 import admin, auth, marketplace, product, shop, stats, stock
api_router = APIRouter() api_router = APIRouter()
@@ -9,6 +10,5 @@ api_router.include_router(auth.router, tags=["authentication"])
api_router.include_router(marketplace.router, tags=["marketplace"]) api_router.include_router(marketplace.router, tags=["marketplace"])
api_router.include_router(product.router, tags=["product"]) api_router.include_router(product.router, tags=["product"])
api_router.include_router(shop.router, tags=["shop"]) api_router.include_router(shop.router, tags=["shop"])
api_router.include_router(stats.router, tags=["statistics"]) api_router.include_router(stats.router, tags=["statistics"])
api_router.include_router(stock.router, tags=["stock"]) api_router.include_router(stock.router, tags=["stock"])

View File

@@ -1,13 +1,15 @@
import logging
from typing import List, Optional from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.database import get_db
from app.api.deps import get_current_admin_user from app.api.deps import get_current_admin_user
from app.core.database import get_db
from app.services.admin_service import admin_service from app.services.admin_service import admin_service
from models.api_models import MarketplaceImportJobResponse, UserResponse, ShopListResponse from models.api_models import (MarketplaceImportJobResponse, ShopListResponse,
UserResponse)
from models.database_models import User from models.database_models import User
import logging
router = APIRouter() router = APIRouter()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -16,10 +18,10 @@ logger = logging.getLogger(__name__)
# Admin-only routes # Admin-only routes
@router.get("/admin/users", response_model=List[UserResponse]) @router.get("/admin/users", response_model=List[UserResponse])
def get_all_users( def get_all_users(
skip: int = Query(0, ge=0), skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000), limit: int = Query(100, ge=1, le=1000),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_admin: User = Depends(get_current_admin_user) current_admin: User = Depends(get_current_admin_user),
): ):
"""Get all users (Admin only)""" """Get all users (Admin only)"""
try: try:
@@ -32,9 +34,9 @@ def get_all_users(
@router.put("/admin/users/{user_id}/status") @router.put("/admin/users/{user_id}/status")
def toggle_user_status( def toggle_user_status(
user_id: int, user_id: int,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_admin: User = Depends(get_current_admin_user) current_admin: User = Depends(get_current_admin_user),
): ):
"""Toggle user active status (Admin only)""" """Toggle user active status (Admin only)"""
try: try:
@@ -49,21 +51,16 @@ def toggle_user_status(
@router.get("/admin/shops", response_model=ShopListResponse) @router.get("/admin/shops", response_model=ShopListResponse)
def get_all_shops_admin( def get_all_shops_admin(
skip: int = Query(0, ge=0), skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000), limit: int = Query(100, ge=1, le=1000),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_admin: User = Depends(get_current_admin_user) current_admin: User = Depends(get_current_admin_user),
): ):
"""Get all shops with admin view (Admin only)""" """Get all shops with admin view (Admin only)"""
try: try:
shops, total = admin_service.get_all_shops(db=db, skip=skip, limit=limit) shops, total = admin_service.get_all_shops(db=db, skip=skip, limit=limit)
return ShopListResponse( return ShopListResponse(shops=shops, total=total, skip=skip, limit=limit)
shops=shops,
total=total,
skip=skip,
limit=limit
)
except Exception as e: except Exception as e:
logger.error(f"Error getting shops: {str(e)}") logger.error(f"Error getting shops: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error") raise HTTPException(status_code=500, detail="Internal server error")
@@ -71,9 +68,9 @@ def get_all_shops_admin(
@router.put("/admin/shops/{shop_id}/verify") @router.put("/admin/shops/{shop_id}/verify")
def verify_shop( def verify_shop(
shop_id: int, shop_id: int,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_admin: User = Depends(get_current_admin_user) current_admin: User = Depends(get_current_admin_user),
): ):
"""Verify/unverify shop (Admin only)""" """Verify/unverify shop (Admin only)"""
try: try:
@@ -88,9 +85,9 @@ def verify_shop(
@router.put("/admin/shops/{shop_id}/status") @router.put("/admin/shops/{shop_id}/status")
def toggle_shop_status( def toggle_shop_status(
shop_id: int, shop_id: int,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_admin: User = Depends(get_current_admin_user) current_admin: User = Depends(get_current_admin_user),
): ):
"""Toggle shop active status (Admin only)""" """Toggle shop active status (Admin only)"""
try: try:
@@ -103,15 +100,17 @@ def toggle_shop_status(
raise HTTPException(status_code=500, detail="Internal server error") raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/admin/marketplace-import-jobs", response_model=List[MarketplaceImportJobResponse]) @router.get(
"/admin/marketplace-import-jobs", response_model=List[MarketplaceImportJobResponse]
)
def get_all_marketplace_import_jobs( def get_all_marketplace_import_jobs(
marketplace: Optional[str] = Query(None), marketplace: Optional[str] = Query(None),
shop_name: Optional[str] = Query(None), shop_name: Optional[str] = Query(None),
status: Optional[str] = Query(None), status: Optional[str] = Query(None),
skip: int = Query(0, ge=0), skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=100), limit: int = Query(100, ge=1, le=100),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_admin: User = Depends(get_current_admin_user) current_admin: User = Depends(get_current_admin_user),
): ):
"""Get all marketplace import jobs (Admin only)""" """Get all marketplace import jobs (Admin only)"""
try: try:
@@ -121,7 +120,7 @@ def get_all_marketplace_import_jobs(
shop_name=shop_name, shop_name=shop_name,
status=status, status=status,
skip=skip, skip=skip,
limit=limit limit=limit,
) )
except Exception as e: except Exception as e:
logger.error(f"Error getting marketplace import jobs: {str(e)}") logger.error(f"Error getting marketplace import jobs: {str(e)}")

View File

@@ -1,11 +1,14 @@
import logging
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.database import get_db
from app.api.deps import get_current_user from app.api.deps import get_current_user
from app.core.database import get_db
from app.services.auth_service import auth_service from app.services.auth_service import auth_service
from models.api_models import UserRegister, UserLogin, UserResponse, LoginResponse from models.api_models import (LoginResponse, UserLogin, UserRegister,
UserResponse)
from models.database_models import User from models.database_models import User
import logging
router = APIRouter() router = APIRouter()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -35,7 +38,7 @@ def login_user(user_credentials: UserLogin, db: Session = Depends(get_db)):
access_token=login_result["token_data"]["access_token"], access_token=login_result["token_data"]["access_token"],
token_type=login_result["token_data"]["token_type"], token_type=login_result["token_data"]["token_type"],
expires_in=login_result["token_data"]["expires_in"], expires_in=login_result["token_data"]["expires_in"],
user=UserResponse.model_validate(login_result["user"]) user=UserResponse.model_validate(login_result["user"]),
) )
except HTTPException: except HTTPException:
raise raise

View File

@@ -1,15 +1,17 @@
import logging
from typing import List, Optional from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.database import get_db
from app.api.deps import get_current_user from app.api.deps import get_current_user
from app.core.database import get_db
from app.services.marketplace_service import marketplace_service
from app.tasks.background_tasks import process_marketplace_import from app.tasks.background_tasks import process_marketplace_import
from middleware.decorators import rate_limit from middleware.decorators import rate_limit
from models.api_models import MarketplaceImportJobResponse, MarketplaceImportRequest from models.api_models import (MarketplaceImportJobResponse,
MarketplaceImportRequest)
from models.database_models import User from models.database_models import User
from app.services.marketplace_service import marketplace_service
import logging
router = APIRouter() router = APIRouter()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -19,15 +21,16 @@ logger = logging.getLogger(__name__)
@router.post("/marketplace/import-product", response_model=MarketplaceImportJobResponse) @router.post("/marketplace/import-product", response_model=MarketplaceImportJobResponse)
@rate_limit(max_requests=10, window_seconds=3600) # Limit marketplace imports @rate_limit(max_requests=10, window_seconds=3600) # Limit marketplace imports
async def import_products_from_marketplace( async def import_products_from_marketplace(
request: MarketplaceImportRequest, request: MarketplaceImportRequest,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
"""Import products from marketplace CSV with background processing (Protected)""" """Import products from marketplace CSV with background processing (Protected)"""
try: try:
logger.info( logger.info(
f"Starting marketplace import: {request.marketplace} -> {request.shop_code} by user {current_user.username}") f"Starting marketplace import: {request.marketplace} -> {request.shop_code} by user {current_user.username}"
)
# Create import job through service # Create import job through service
import_job = marketplace_service.create_import_job(db, request, current_user) import_job = marketplace_service.create_import_job(db, request, current_user)
@@ -39,7 +42,7 @@ async def import_products_from_marketplace(
request.url, request.url,
request.marketplace, request.marketplace,
request.shop_code, request.shop_code,
request.batch_size or 1000 request.batch_size or 1000,
) )
return MarketplaceImportJobResponse( return MarketplaceImportJobResponse(
@@ -50,7 +53,7 @@ async def import_products_from_marketplace(
shop_id=import_job.shop_id, shop_id=import_job.shop_id,
shop_name=import_job.shop_name, shop_name=import_job.shop_name,
message=f"Marketplace import started from {request.marketplace}. Check status with " message=f"Marketplace import started from {request.marketplace}. Check status with "
f"/import-status/{import_job.id}" f"/import-status/{import_job.id}",
) )
except ValueError as e: except ValueError as e:
@@ -62,11 +65,13 @@ async def import_products_from_marketplace(
raise HTTPException(status_code=500, detail="Internal server error") raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/marketplace/import-status/{job_id}", response_model=MarketplaceImportJobResponse) @router.get(
"/marketplace/import-status/{job_id}", response_model=MarketplaceImportJobResponse
)
def get_marketplace_import_status( def get_marketplace_import_status(
job_id: int, job_id: int,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
"""Get status of marketplace import job (Protected)""" """Get status of marketplace import job (Protected)"""
try: try:
@@ -82,14 +87,16 @@ def get_marketplace_import_status(
raise HTTPException(status_code=500, detail="Internal server error") raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/marketplace/import-jobs", response_model=List[MarketplaceImportJobResponse]) @router.get(
"/marketplace/import-jobs", response_model=List[MarketplaceImportJobResponse]
)
def get_marketplace_import_jobs( def get_marketplace_import_jobs(
marketplace: Optional[str] = Query(None, description="Filter by marketplace"), marketplace: Optional[str] = Query(None, description="Filter by marketplace"),
shop_name: Optional[str] = Query(None, description="Filter by shop name"), shop_name: Optional[str] = Query(None, description="Filter by shop name"),
skip: int = Query(0, ge=0), skip: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=100), limit: int = Query(50, ge=1, le=100),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
"""Get marketplace import jobs with filtering (Protected)""" """Get marketplace import jobs with filtering (Protected)"""
try: try:
@@ -99,7 +106,7 @@ def get_marketplace_import_jobs(
marketplace=marketplace, marketplace=marketplace,
shop_name=shop_name, shop_name=shop_name,
skip=skip, skip=skip,
limit=limit limit=limit,
) )
return [marketplace_service.convert_to_response_model(job) for job in jobs] return [marketplace_service.convert_to_response_model(job) for job in jobs]
@@ -111,8 +118,7 @@ def get_marketplace_import_jobs(
@router.get("/marketplace/marketplace-import-stats") @router.get("/marketplace/marketplace-import-stats")
def get_marketplace_import_stats( def get_marketplace_import_stats(
db: Session = Depends(get_db), db: Session = Depends(get_db), current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user)
): ):
"""Get statistics about marketplace import jobs (Protected)""" """Get statistics about marketplace import jobs (Protected)"""
try: try:
@@ -124,11 +130,14 @@ def get_marketplace_import_stats(
raise HTTPException(status_code=500, detail="Internal server error") raise HTTPException(status_code=500, detail="Internal server error")
@router.put("/marketplace/import-jobs/{job_id}/cancel", response_model=MarketplaceImportJobResponse) @router.put(
"/marketplace/import-jobs/{job_id}/cancel",
response_model=MarketplaceImportJobResponse,
)
def cancel_marketplace_import_job( def cancel_marketplace_import_job(
job_id: int, job_id: int,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
"""Cancel a pending or running marketplace import job (Protected)""" """Cancel a pending or running marketplace import job (Protected)"""
try: try:
@@ -146,9 +155,9 @@ def cancel_marketplace_import_job(
@router.delete("/marketplace/import-jobs/{job_id}") @router.delete("/marketplace/import-jobs/{job_id}")
def delete_marketplace_import_job( def delete_marketplace_import_job(
job_id: int, job_id: int,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
"""Delete a completed marketplace import job (Protected)""" """Delete a completed marketplace import job (Protected)"""
try: try:

View File

@@ -1,17 +1,17 @@
import logging
from typing import Optional from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.database import get_db
from app.api.deps import get_current_user from app.api.deps import get_current_user
from models.api_models import (ProductListResponse, ProductResponse, ProductCreate, ProductDetailResponse, from app.core.database import get_db
from app.services.product_service import product_service
from models.api_models import (ProductCreate, ProductDetailResponse,
ProductListResponse, ProductResponse,
ProductUpdate) ProductUpdate)
from models.database_models import User from models.database_models import User
import logging
from app.services.product_service import product_service
router = APIRouter() router = APIRouter()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -20,16 +20,16 @@ logger = logging.getLogger(__name__)
# Enhanced Product Routes with Marketplace Support # Enhanced Product Routes with Marketplace Support
@router.get("/product", response_model=ProductListResponse) @router.get("/product", response_model=ProductListResponse)
def get_products( def get_products(
skip: int = Query(0, ge=0), skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000), limit: int = Query(100, ge=1, le=1000),
brand: Optional[str] = Query(None), brand: Optional[str] = Query(None),
category: Optional[str] = Query(None), category: Optional[str] = Query(None),
availability: Optional[str] = Query(None), availability: Optional[str] = Query(None),
marketplace: Optional[str] = Query(None, description="Filter by marketplace"), marketplace: Optional[str] = Query(None, description="Filter by marketplace"),
shop_name: Optional[str] = Query(None, description="Filter by shop name"), shop_name: Optional[str] = Query(None, description="Filter by shop name"),
search: Optional[str] = Query(None), search: Optional[str] = Query(None),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
"""Get products with advanced filtering including marketplace and shop (Protected)""" """Get products with advanced filtering including marketplace and shop (Protected)"""
@@ -43,14 +43,11 @@ def get_products(
availability=availability, availability=availability,
marketplace=marketplace, marketplace=marketplace,
shop_name=shop_name, shop_name=shop_name,
search=search search=search,
) )
return ProductListResponse( return ProductListResponse(
products=products, products=products, total=total, skip=skip, limit=limit
total=total,
skip=skip,
limit=limit
) )
except Exception as e: except Exception as e:
logger.error(f"Error getting products: {str(e)}") logger.error(f"Error getting products: {str(e)}")
@@ -59,9 +56,9 @@ def get_products(
@router.post("/product", response_model=ProductResponse) @router.post("/product", response_model=ProductResponse)
def create_product( def create_product(
product: ProductCreate, product: ProductCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
"""Create a new product with validation and marketplace support (Protected)""" """Create a new product with validation and marketplace support (Protected)"""
@@ -75,7 +72,9 @@ def create_product(
if existing: if existing:
logger.info("Product already exists, raising 400 error") logger.info("Product already exists, raising 400 error")
raise HTTPException(status_code=400, detail="Product with this ID already exists") raise HTTPException(
status_code=400, detail="Product with this ID already exists"
)
logger.info("No existing product found, proceeding to create...") logger.info("No existing product found, proceeding to create...")
db_product = product_service.create_product(db, product) db_product = product_service.create_product(db, product)
@@ -93,11 +92,12 @@ def create_product(
logger.error(f"Unexpected error: {str(e)}", exc_info=True) logger.error(f"Unexpected error: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error") raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/product/{product_id}", response_model=ProductDetailResponse) @router.get("/product/{product_id}", response_model=ProductDetailResponse)
def get_product( def get_product(
product_id: str, product_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
"""Get product with stock information (Protected)""" """Get product with stock information (Protected)"""
@@ -111,10 +111,7 @@ def get_product(
if product.gtin: if product.gtin:
stock_info = product_service.get_stock_info(db, product.gtin) stock_info = product_service.get_stock_info(db, product.gtin)
return ProductDetailResponse( return ProductDetailResponse(product=product, stock_info=stock_info)
product=product,
stock_info=stock_info
)
except HTTPException: except HTTPException:
raise raise
@@ -125,10 +122,10 @@ def get_product(
@router.put("/product/{product_id}", response_model=ProductResponse) @router.put("/product/{product_id}", response_model=ProductResponse)
def update_product( def update_product(
product_id: str, product_id: str,
product_update: ProductUpdate, product_update: ProductUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
"""Update product with validation and marketplace support (Protected)""" """Update product with validation and marketplace support (Protected)"""
@@ -151,9 +148,9 @@ def update_product(
@router.delete("/product/{product_id}") @router.delete("/product/{product_id}")
def delete_product( def delete_product(
product_id: str, product_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
"""Delete product and associated stock (Protected)""" """Delete product and associated stock (Protected)"""
@@ -176,19 +173,18 @@ def delete_product(
# Export with streaming for large datasets (Protected) # Export with streaming for large datasets (Protected)
@router.get("/export-csv") @router.get("/export-csv")
async def export_csv( async def export_csv(
marketplace: Optional[str] = Query(None, description="Filter by marketplace"), marketplace: Optional[str] = Query(None, description="Filter by marketplace"),
shop_name: Optional[str] = Query(None, description="Filter by shop name"), shop_name: Optional[str] = Query(None, description="Filter by shop name"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
"""Export products as CSV with streaming and marketplace filtering (Protected)""" """Export products as CSV with streaming and marketplace filtering (Protected)"""
try: try:
def generate_csv(): def generate_csv():
return product_service.generate_csv_export( return product_service.generate_csv_export(
db=db, db=db, marketplace=marketplace, shop_name=shop_name
marketplace=marketplace,
shop_name=shop_name
) )
filename = "products_export" filename = "products_export"
@@ -201,7 +197,7 @@ async def export_csv(
return StreamingResponse( return StreamingResponse(
generate_csv(), generate_csv(),
media_type="text/csv", media_type="text/csv",
headers={"Content-Disposition": f"attachment; filename={filename}"} headers={"Content-Disposition": f"attachment; filename={filename}"},
) )
except Exception as e: except Exception as e:

View File

@@ -1,17 +1,21 @@
import logging
from datetime import datetime
from typing import List, Optional from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.database import get_db
from app.api.deps import get_current_user, get_user_shop from app.api.deps import get_current_user, get_user_shop
from app.core.database import get_db
from app.services.shop_service import shop_service from app.services.shop_service import shop_service
from app.tasks.background_tasks import process_marketplace_import from app.tasks.background_tasks import process_marketplace_import
from middleware.decorators import rate_limit from middleware.decorators import rate_limit
from models.api_models import MarketplaceImportJobResponse, MarketplaceImportRequest, ShopResponse, ShopCreate, \ from models.api_models import (MarketplaceImportJobResponse,
ShopListResponse, ShopProductResponse, ShopProductCreate MarketplaceImportRequest, ShopCreate,
from models.database_models import User, MarketplaceImportJob, Shop, Product, ShopProduct ShopListResponse, ShopProductCreate,
from datetime import datetime ShopProductResponse, ShopResponse)
import logging from models.database_models import (MarketplaceImportJob, Product, Shop,
ShopProduct, User)
router = APIRouter() router = APIRouter()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -20,13 +24,15 @@ logger = logging.getLogger(__name__)
# Shop Management Routes # Shop Management Routes
@router.post("/shop", response_model=ShopResponse) @router.post("/shop", response_model=ShopResponse)
def create_shop( def create_shop(
shop_data: ShopCreate, shop_data: ShopCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
"""Create a new shop (Protected)""" """Create a new shop (Protected)"""
try: try:
shop = shop_service.create_shop(db=db, shop_data=shop_data, current_user=current_user) shop = shop_service.create_shop(
db=db, shop_data=shop_data, current_user=current_user
)
return ShopResponse.model_validate(shop) return ShopResponse.model_validate(shop)
except HTTPException: except HTTPException:
raise raise
@@ -37,12 +43,12 @@ def create_shop(
@router.get("/shop", response_model=ShopListResponse) @router.get("/shop", response_model=ShopListResponse)
def get_shops( def get_shops(
skip: int = Query(0, ge=0), skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000), limit: int = Query(100, ge=1, le=1000),
active_only: bool = Query(True), active_only: bool = Query(True),
verified_only: bool = Query(False), verified_only: bool = Query(False),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
"""Get shops with filtering (Protected)""" """Get shops with filtering (Protected)"""
try: try:
@@ -52,15 +58,10 @@ def get_shops(
skip=skip, skip=skip,
limit=limit, limit=limit,
active_only=active_only, active_only=active_only,
verified_only=verified_only verified_only=verified_only,
) )
return ShopListResponse( return ShopListResponse(shops=shops, total=total, skip=skip, limit=limit)
shops=shops,
total=total,
skip=skip,
limit=limit
)
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
@@ -69,10 +70,16 @@ def get_shops(
@router.get("/shop/{shop_code}", response_model=ShopResponse) @router.get("/shop/{shop_code}", response_model=ShopResponse)
def get_shop(shop_code: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): def get_shop(
shop_code: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Get shop details (Protected)""" """Get shop details (Protected)"""
try: try:
shop = shop_service.get_shop_by_code(db=db, shop_code=shop_code, current_user=current_user) shop = shop_service.get_shop_by_code(
db=db, shop_code=shop_code, current_user=current_user
)
return ShopResponse.model_validate(shop) return ShopResponse.model_validate(shop)
except HTTPException: except HTTPException:
raise raise
@@ -84,10 +91,10 @@ def get_shop(shop_code: str, db: Session = Depends(get_db), current_user: User =
# Shop Product Management # Shop Product Management
@router.post("/shop/{shop_code}/products", response_model=ShopProductResponse) @router.post("/shop/{shop_code}/products", response_model=ShopProductResponse)
def add_product_to_shop( def add_product_to_shop(
shop_code: str, shop_code: str,
shop_product: ShopProductCreate, shop_product: ShopProductCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
"""Add existing product to shop catalog with shop-specific settings (Protected)""" """Add existing product to shop catalog with shop-specific settings (Protected)"""
try: try:
@@ -96,9 +103,7 @@ def add_product_to_shop(
# Add product to shop # Add product to shop
new_shop_product = shop_service.add_product_to_shop( new_shop_product = shop_service.add_product_to_shop(
db=db, db=db, shop=shop, shop_product=shop_product
shop=shop,
shop_product=shop_product
) )
# Return with product details # Return with product details
@@ -114,18 +119,20 @@ def add_product_to_shop(
@router.get("/shop/{shop_code}/products") @router.get("/shop/{shop_code}/products")
def get_shop_products( def get_shop_products(
shop_code: str, shop_code: str,
skip: int = Query(0, ge=0), skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000), limit: int = Query(100, ge=1, le=1000),
active_only: bool = Query(True), active_only: bool = Query(True),
featured_only: bool = Query(False), featured_only: bool = Query(False),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
"""Get products in shop catalog (Protected)""" """Get products in shop catalog (Protected)"""
try: try:
# Get shop # Get shop
shop = shop_service.get_shop_by_code(db=db, shop_code=shop_code, current_user=current_user) shop = shop_service.get_shop_by_code(
db=db, shop_code=shop_code, current_user=current_user
)
# Get shop products # Get shop products
shop_products, total = shop_service.get_shop_products( shop_products, total = shop_service.get_shop_products(
@@ -135,7 +142,7 @@ def get_shop_products(
skip=skip, skip=skip,
limit=limit, limit=limit,
active_only=active_only, active_only=active_only,
featured_only=featured_only featured_only=featured_only,
) )
# Format response # Format response
@@ -150,7 +157,7 @@ def get_shop_products(
"total": total, "total": total,
"skip": skip, "skip": skip,
"limit": limit, "limit": limit,
"shop": ShopResponse.model_validate(shop) "shop": ShopResponse.model_validate(shop),
} }
except HTTPException: except HTTPException:
raise raise

View File

@@ -1,18 +1,21 @@
import logging
from datetime import datetime
from typing import List, Optional from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.database import get_db
from app.api.deps import get_current_user from app.api.deps import get_current_user
from app.core.database import get_db
from app.services.stats_service import stats_service from app.services.stats_service import stats_service
from app.tasks.background_tasks import process_marketplace_import from app.tasks.background_tasks import process_marketplace_import
from middleware.decorators import rate_limit from middleware.decorators import rate_limit
from models.api_models import MarketplaceImportJobResponse, MarketplaceImportRequest, StatsResponse, \ from models.api_models import (MarketplaceImportJobResponse,
MarketplaceStatsResponse MarketplaceImportRequest,
from models.database_models import User, MarketplaceImportJob, Shop, Product, Stock MarketplaceStatsResponse, StatsResponse)
from datetime import datetime from models.database_models import (MarketplaceImportJob, Product, Shop, Stock,
import logging User)
router = APIRouter() router = APIRouter()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -20,7 +23,9 @@ logger = logging.getLogger(__name__)
# Enhanced Statistics with Marketplace Support # Enhanced Statistics with Marketplace Support
@router.get("/stats", response_model=StatsResponse) @router.get("/stats", response_model=StatsResponse)
def get_stats(db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): def get_stats(
db: Session = Depends(get_db), current_user: User = Depends(get_current_user)
):
"""Get comprehensive statistics with marketplace data (Protected)""" """Get comprehensive statistics with marketplace data (Protected)"""
try: try:
stats_data = stats_service.get_comprehensive_stats(db=db) stats_data = stats_service.get_comprehensive_stats(db=db)
@@ -32,7 +37,7 @@ def get_stats(db: Session = Depends(get_db), current_user: User = Depends(get_cu
unique_marketplaces=stats_data["unique_marketplaces"], unique_marketplaces=stats_data["unique_marketplaces"],
unique_shops=stats_data["unique_shops"], unique_shops=stats_data["unique_shops"],
total_stock_entries=stats_data["total_stock_entries"], total_stock_entries=stats_data["total_stock_entries"],
total_inventory_quantity=stats_data["total_inventory_quantity"] total_inventory_quantity=stats_data["total_inventory_quantity"],
) )
except Exception as e: except Exception as e:
logger.error(f"Error getting comprehensive stats: {str(e)}") logger.error(f"Error getting comprehensive stats: {str(e)}")
@@ -40,7 +45,9 @@ def get_stats(db: Session = Depends(get_db), current_user: User = Depends(get_cu
@router.get("/stats/marketplace", response_model=List[MarketplaceStatsResponse]) @router.get("/stats/marketplace", response_model=List[MarketplaceStatsResponse])
def get_marketplace_stats(db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): def get_marketplace_stats(
db: Session = Depends(get_db), current_user: User = Depends(get_current_user)
):
"""Get statistics broken down by marketplace (Protected)""" """Get statistics broken down by marketplace (Protected)"""
try: try:
marketplace_stats = stats_service.get_marketplace_breakdown_stats(db=db) marketplace_stats = stats_service.get_marketplace_breakdown_stats(db=db)
@@ -50,8 +57,9 @@ def get_marketplace_stats(db: Session = Depends(get_db), current_user: User = De
marketplace=stat["marketplace"], marketplace=stat["marketplace"],
total_products=stat["total_products"], total_products=stat["total_products"],
unique_shops=stat["unique_shops"], unique_shops=stat["unique_shops"],
unique_brands=stat["unique_brands"] unique_brands=stat["unique_brands"],
) for stat in marketplace_stats )
for stat in marketplace_stats
] ]
except Exception as e: except Exception as e:
logger.error(f"Error getting marketplace stats: {str(e)}") logger.error(f"Error getting marketplace stats: {str(e)}")

View File

@@ -1,16 +1,19 @@
import logging
from typing import List, Optional from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.database import get_db
from app.api.deps import get_current_user from app.api.deps import get_current_user
from app.core.database import get_db
from app.services.stock_service import stock_service
from app.tasks.background_tasks import process_marketplace_import from app.tasks.background_tasks import process_marketplace_import
from middleware.decorators import rate_limit from middleware.decorators import rate_limit
from models.api_models import (MarketplaceImportJobResponse, MarketplaceImportRequest, StockResponse, from models.api_models import (MarketplaceImportJobResponse,
StockSummaryResponse, StockCreate, StockAdd, StockUpdate) MarketplaceImportRequest, StockAdd, StockCreate,
from models.database_models import User, MarketplaceImportJob, Shop StockResponse, StockSummaryResponse,
from app.services.stock_service import stock_service StockUpdate)
import logging from models.database_models import MarketplaceImportJob, Shop, User
router = APIRouter() router = APIRouter()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -18,11 +21,12 @@ logger = logging.getLogger(__name__)
# Stock Management Routes (Protected) # Stock Management Routes (Protected)
@router.post("/stock", response_model=StockResponse) @router.post("/stock", response_model=StockResponse)
def set_stock( def set_stock(
stock: StockCreate, stock: StockCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
"""Set exact stock quantity for a GTIN at a specific location (replaces existing quantity)""" """Set exact stock quantity for a GTIN at a specific location (replaces existing quantity)"""
try: try:
@@ -37,9 +41,9 @@ def set_stock(
@router.post("/stock/add", response_model=StockResponse) @router.post("/stock/add", response_model=StockResponse)
def add_stock( def add_stock(
stock: StockAdd, stock: StockAdd,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
"""Add quantity to existing stock for a GTIN at a specific location (adds to existing quantity)""" """Add quantity to existing stock for a GTIN at a specific location (adds to existing quantity)"""
try: try:
@@ -54,9 +58,9 @@ def add_stock(
@router.post("/stock/remove", response_model=StockResponse) @router.post("/stock/remove", response_model=StockResponse)
def remove_stock( def remove_stock(
stock: StockAdd, stock: StockAdd,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
"""Remove quantity from existing stock for a GTIN at a specific location""" """Remove quantity from existing stock for a GTIN at a specific location"""
try: try:
@@ -71,9 +75,9 @@ def remove_stock(
@router.get("/stock/{gtin}", response_model=StockSummaryResponse) @router.get("/stock/{gtin}", response_model=StockSummaryResponse)
def get_stock_by_gtin( def get_stock_by_gtin(
gtin: str, gtin: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
"""Get all stock locations and total quantity for a specific GTIN""" """Get all stock locations and total quantity for a specific GTIN"""
try: try:
@@ -88,9 +92,9 @@ def get_stock_by_gtin(
@router.get("/stock/{gtin}/total") @router.get("/stock/{gtin}/total")
def get_total_stock( def get_total_stock(
gtin: str, gtin: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
"""Get total quantity in stock for a specific GTIN""" """Get total quantity in stock for a specific GTIN"""
try: try:
@@ -105,21 +109,17 @@ def get_total_stock(
@router.get("/stock", response_model=List[StockResponse]) @router.get("/stock", response_model=List[StockResponse])
def get_all_stock( def get_all_stock(
skip: int = Query(0, ge=0), skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000), limit: int = Query(100, ge=1, le=1000),
location: Optional[str] = Query(None, description="Filter by location"), location: Optional[str] = Query(None, description="Filter by location"),
gtin: Optional[str] = Query(None, description="Filter by GTIN"), gtin: Optional[str] = Query(None, description="Filter by GTIN"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
"""Get all stock entries with optional filtering""" """Get all stock entries with optional filtering"""
try: try:
result = stock_service.get_all_stock( result = stock_service.get_all_stock(
db=db, db=db, skip=skip, limit=limit, location=location, gtin=gtin
skip=skip,
limit=limit,
location=location,
gtin=gtin
) )
return result return result
except Exception as e: except Exception as e:
@@ -129,10 +129,10 @@ def get_all_stock(
@router.put("/stock/{stock_id}", response_model=StockResponse) @router.put("/stock/{stock_id}", response_model=StockResponse)
def update_stock( def update_stock(
stock_id: int, stock_id: int,
stock_update: StockUpdate, stock_update: StockUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
"""Update stock quantity for a specific stock entry""" """Update stock quantity for a specific stock entry"""
try: try:
@@ -147,9 +147,9 @@ def update_stock(
@router.delete("/stock/{stock_id}") @router.delete("/stock/{stock_id}")
def delete_stock( def delete_stock(
stock_id: int, stock_id: int,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
"""Delete a stock entry""" """Delete a stock entry"""
try: try:

View File

@@ -1,6 +1,8 @@
# app/core/config.py # app/core/config.py
from pydantic_settings import BaseSettings # This is the correct import for Pydantic v2 from typing import List, Optional
from typing import Optional, List
from pydantic_settings import \
BaseSettings # This is the correct import for Pydantic v2
class Settings(BaseSettings): class Settings(BaseSettings):

View File

@@ -1,6 +1,6 @@
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.orm import sessionmaker
from .config import settings from .config import settings
engine = create_engine(settings.database_url) engine = create_engine(settings.database_url)

View File

@@ -1,11 +1,14 @@
import logging
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import FastAPI from fastapi import FastAPI
from sqlalchemy import text from sqlalchemy import text
from .logging import setup_logging
from .database import engine, SessionLocal
from models.database_models import Base
import logging
from middleware.auth import AuthManager from middleware.auth import AuthManager
from models.database_models import Base
from .database import SessionLocal, engine
from .logging import setup_logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
auth_manager = AuthManager() auth_manager = AuthManager()
@@ -44,15 +47,29 @@ def create_indexes():
with engine.connect() as conn: with engine.connect() as conn:
try: try:
# User indexes # User indexes
conn.execute(text("CREATE INDEX IF NOT EXISTS idx_user_email ON users(email)")) conn.execute(
conn.execute(text("CREATE INDEX IF NOT EXISTS idx_user_username ON users(username)")) text("CREATE INDEX IF NOT EXISTS idx_user_email ON users(email)")
)
conn.execute(
text("CREATE INDEX IF NOT EXISTS idx_user_username ON users(username)")
)
# Product indexes # Product indexes
conn.execute(text("CREATE INDEX IF NOT EXISTS idx_product_gtin ON products(gtin)")) conn.execute(
conn.execute(text("CREATE INDEX IF NOT EXISTS idx_product_marketplace ON products(marketplace)")) text("CREATE INDEX IF NOT EXISTS idx_product_gtin ON products(gtin)")
)
conn.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_product_marketplace ON products(marketplace)"
)
)
# Stock indexes # Stock indexes
conn.execute(text("CREATE INDEX IF NOT EXISTS idx_stock_gtin_location ON stock(gtin, location)")) conn.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_stock_gtin_location ON stock(gtin, location)"
)
)
conn.commit() conn.commit()
logger.info("Database indexes created successfully") logger.info("Database indexes created successfully")

View File

@@ -2,6 +2,7 @@
import logging import logging
import sys import sys
from pathlib import Path from pathlib import Path
from app.core.config import settings from app.core.config import settings
@@ -22,7 +23,7 @@ def setup_logging():
# Create formatters # Create formatters
formatter = logging.Formatter( formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s' "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
) )
# Console handler # Console handler

View File

@@ -1,11 +1,12 @@
from sqlalchemy.orm import Session
from fastapi import HTTPException
from datetime import datetime
import logging import logging
from datetime import datetime
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from models.database_models import User, MarketplaceImportJob, Shop from fastapi import HTTPException
from sqlalchemy.orm import Session
from models.api_models import MarketplaceImportJobResponse from models.api_models import MarketplaceImportJobResponse
from models.database_models import MarketplaceImportJob, Shop, User
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -17,7 +18,9 @@ class AdminService:
"""Get paginated list of all users""" """Get paginated list of all users"""
return db.query(User).offset(skip).limit(limit).all() return db.query(User).offset(skip).limit(limit).all()
def toggle_user_status(self, db: Session, user_id: int, current_admin_id: int) -> Tuple[User, str]: def toggle_user_status(
self, db: Session, user_id: int, current_admin_id: int
) -> Tuple[User, str]:
""" """
Toggle user active status Toggle user active status
@@ -37,7 +40,9 @@ class AdminService:
raise HTTPException(status_code=404, detail="User not found") raise HTTPException(status_code=404, detail="User not found")
if user.id == current_admin_id: if user.id == current_admin_id:
raise HTTPException(status_code=400, detail="Cannot deactivate your own account") raise HTTPException(
status_code=400, detail="Cannot deactivate your own account"
)
user.is_active = not user.is_active user.is_active = not user.is_active
user.updated_at = datetime.utcnow() user.updated_at = datetime.utcnow()
@@ -45,10 +50,14 @@ class AdminService:
db.refresh(user) db.refresh(user)
status = "activated" if user.is_active else "deactivated" status = "activated" if user.is_active else "deactivated"
logger.info(f"User {user.username} has been {status} by admin {current_admin_id}") logger.info(
f"User {user.username} has been {status} by admin {current_admin_id}"
)
return user, f"User {user.username} has been {status}" return user, f"User {user.username} has been {status}"
def get_all_shops(self, db: Session, skip: int = 0, limit: int = 100) -> Tuple[List[Shop], int]: def get_all_shops(
self, db: Session, skip: int = 0, limit: int = 100
) -> Tuple[List[Shop], int]:
""" """
Get paginated list of all shops with total count Get paginated list of all shops with total count
@@ -119,13 +128,13 @@ class AdminService:
return shop, f"Shop {shop.shop_code} has been {status}" return shop, f"Shop {shop.shop_code} has been {status}"
def get_marketplace_import_jobs( def get_marketplace_import_jobs(
self, self,
db: Session, db: Session,
marketplace: Optional[str] = None, marketplace: Optional[str] = None,
shop_name: Optional[str] = None, shop_name: Optional[str] = None,
status: Optional[str] = None, status: Optional[str] = None,
skip: int = 0, skip: int = 0,
limit: int = 100 limit: int = 100,
) -> List[MarketplaceImportJobResponse]: ) -> List[MarketplaceImportJobResponse]:
""" """
Get filtered and paginated marketplace import jobs Get filtered and paginated marketplace import jobs
@@ -145,14 +154,21 @@ class AdminService:
# Apply filters # Apply filters
if marketplace: if marketplace:
query = query.filter(MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%")) query = query.filter(
MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%")
)
if shop_name: if shop_name:
query = query.filter(MarketplaceImportJob.shop_name.ilike(f"%{shop_name}%")) query = query.filter(MarketplaceImportJob.shop_name.ilike(f"%{shop_name}%"))
if status: if status:
query = query.filter(MarketplaceImportJob.status == status) query = query.filter(MarketplaceImportJob.status == status)
# Order by creation date and apply pagination # Order by creation date and apply pagination
jobs = query.order_by(MarketplaceImportJob.created_at.desc()).offset(skip).limit(limit).all() jobs = (
query.order_by(MarketplaceImportJob.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return [ return [
MarketplaceImportJobResponse( MarketplaceImportJobResponse(
@@ -168,8 +184,9 @@ class AdminService:
error_message=job.error_message, error_message=job.error_message,
created_at=job.created_at, created_at=job.created_at,
started_at=job.started_at, started_at=job.started_at,
completed_at=job.completed_at completed_at=job.completed_at,
) for job in jobs )
for job in jobs
] ]
def get_user_by_id(self, db: Session, user_id: int) -> Optional[User]: def get_user_by_id(self, db: Session, user_id: int) -> Optional[User]:

View File

@@ -1,11 +1,12 @@
from sqlalchemy.orm import Session
from fastapi import HTTPException
import logging import logging
from typing import Optional, Dict, Any from typing import Any, Dict, Optional
from fastapi import HTTPException
from sqlalchemy.orm import Session
from models.database_models import User
from models.api_models import UserRegister, UserLogin
from middleware.auth import AuthManager from middleware.auth import AuthManager
from models.api_models import UserLogin, UserRegister
from models.database_models import User
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -36,7 +37,9 @@ class AuthService:
raise HTTPException(status_code=400, detail="Email already registered") raise HTTPException(status_code=400, detail="Email already registered")
# Check if username already exists # Check if username already exists
existing_username = db.query(User).filter(User.username == user_data.username).first() existing_username = (
db.query(User).filter(User.username == user_data.username).first()
)
if existing_username: if existing_username:
raise HTTPException(status_code=400, detail="Username already taken") raise HTTPException(status_code=400, detail="Username already taken")
@@ -47,7 +50,7 @@ class AuthService:
username=user_data.username, username=user_data.username,
hashed_password=hashed_password, hashed_password=hashed_password,
role="user", role="user",
is_active=True is_active=True,
) )
db.add(new_user) db.add(new_user)
@@ -71,19 +74,20 @@ class AuthService:
Raises: Raises:
HTTPException: If authentication fails HTTPException: If authentication fails
""" """
user = self.auth_manager.authenticate_user(db, user_credentials.username, user_credentials.password) user = self.auth_manager.authenticate_user(
db, user_credentials.username, user_credentials.password
)
if not user: if not user:
raise HTTPException(status_code=401, detail="Incorrect username or password") raise HTTPException(
status_code=401, detail="Incorrect username or password"
)
# Create access token # Create access token
token_data = self.auth_manager.create_access_token(user) token_data = self.auth_manager.create_access_token(user)
logger.info(f"User logged in: {user.username}") logger.info(f"User logged in: {user.username}")
return { return {"token_data": token_data, "user": user}
"token_data": token_data,
"user": user
}
def get_user_by_email(self, db: Session, email: str) -> Optional[User]: def get_user_by_email(self, db: Session, email: str) -> Optional[User]:
"""Get user by email""" """Get user by email"""
@@ -101,7 +105,9 @@ class AuthService:
"""Check if username already exists""" """Check if username already exists"""
return db.query(User).filter(User.username == username).first() is not None return db.query(User).filter(User.username == username).first() is not None
def authenticate_user(self, db: Session, username: str, password: str) -> Optional[User]: def authenticate_user(
self, db: Session, username: str, password: str
) -> Optional[User]:
"""Authenticate user with username/password""" """Authenticate user with username/password"""
return self.auth_manager.authenticate_user(db, username, password) return self.auth_manager.authenticate_user(db, username, password)

View File

@@ -1,10 +1,13 @@
import logging
from datetime import datetime
from typing import List, Optional
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from models.api_models import (MarketplaceImportJobResponse,
MarketplaceImportRequest)
from models.database_models import MarketplaceImportJob, Shop, User from models.database_models import MarketplaceImportJob, Shop, User
from models.api_models import MarketplaceImportRequest, MarketplaceImportJobResponse
from typing import Optional, List
from datetime import datetime
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -17,9 +20,11 @@ class MarketplaceService:
"""Validate that the shop exists and user has access to it""" """Validate that the shop exists and user has access to it"""
# Explicit type hint to help type checker shop: Optional[Shop] # Explicit type hint to help type checker shop: Optional[Shop]
# Use case-insensitive query to handle both uppercase and lowercase codes # Use case-insensitive query to handle both uppercase and lowercase codes
shop: Optional[Shop] = db.query(Shop).filter( shop: Optional[Shop] = (
func.upper(Shop.shop_code) == shop_code.upper() db.query(Shop)
).first() .filter(func.upper(Shop.shop_code) == shop_code.upper())
.first()
)
if not shop: if not shop:
raise ValueError("Shop not found") raise ValueError("Shop not found")
@@ -30,10 +35,7 @@ class MarketplaceService:
return shop return shop
def create_import_job( def create_import_job(
self, self, db: Session, request: MarketplaceImportRequest, user: User
db: Session,
request: MarketplaceImportRequest,
user: User
) -> MarketplaceImportJob: ) -> MarketplaceImportJob:
"""Create a new marketplace import job""" """Create a new marketplace import job"""
# Validate shop access first # Validate shop access first
@@ -47,7 +49,7 @@ class MarketplaceService:
shop_id=shop.id, # Foreign key to shops table shop_id=shop.id, # Foreign key to shops table
shop_name=shop.shop_name, # Use shop.shop_name (the display name) shop_name=shop.shop_name, # Use shop.shop_name (the display name)
user_id=user.id, user_id=user.id,
created_at=datetime.utcnow() created_at=datetime.utcnow(),
) )
db.add(import_job) db.add(import_job)
@@ -55,13 +57,20 @@ class MarketplaceService:
db.refresh(import_job) db.refresh(import_job)
logger.info( logger.info(
f"Created marketplace import job {import_job.id}: {request.marketplace} -> {shop.shop_name} (shop_code: {shop.shop_code}) by user {user.username}") f"Created marketplace import job {import_job.id}: {request.marketplace} -> {shop.shop_name} (shop_code: {shop.shop_code}) by user {user.username}"
)
return import_job return import_job
def get_import_job_by_id(self, db: Session, job_id: int, user: User) -> MarketplaceImportJob: def get_import_job_by_id(
self, db: Session, job_id: int, user: User
) -> MarketplaceImportJob:
"""Get a marketplace import job by ID with access control""" """Get a marketplace import job by ID with access control"""
job = db.query(MarketplaceImportJob).filter(MarketplaceImportJob.id == job_id).first() job = (
db.query(MarketplaceImportJob)
.filter(MarketplaceImportJob.id == job_id)
.first()
)
if not job: if not job:
raise ValueError("Marketplace import job not found") raise ValueError("Marketplace import job not found")
@@ -72,13 +81,13 @@ class MarketplaceService:
return job return job
def get_import_jobs( def get_import_jobs(
self, self,
db: Session, db: Session,
user: User, user: User,
marketplace: Optional[str] = None, marketplace: Optional[str] = None,
shop_name: Optional[str] = None, shop_name: Optional[str] = None,
skip: int = 0, skip: int = 0,
limit: int = 50 limit: int = 50,
) -> List[MarketplaceImportJob]: ) -> List[MarketplaceImportJob]:
"""Get marketplace import jobs with filtering and access control""" """Get marketplace import jobs with filtering and access control"""
query = db.query(MarketplaceImportJob) query = db.query(MarketplaceImportJob)
@@ -89,44 +98,51 @@ class MarketplaceService:
# Apply filters # Apply filters
if marketplace: if marketplace:
query = query.filter(MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%")) query = query.filter(
MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%")
)
if shop_name: if shop_name:
query = query.filter(MarketplaceImportJob.shop_name.ilike(f"%{shop_name}%")) query = query.filter(MarketplaceImportJob.shop_name.ilike(f"%{shop_name}%"))
# Order by creation date (newest first) and apply pagination # Order by creation date (newest first) and apply pagination
jobs = query.order_by(MarketplaceImportJob.created_at.desc()).offset(skip).limit(limit).all() jobs = (
query.order_by(MarketplaceImportJob.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return jobs return jobs
def update_job_status( def update_job_status(
self, self, db: Session, job_id: int, status: str, **kwargs
db: Session,
job_id: int,
status: str,
**kwargs
) -> MarketplaceImportJob: ) -> MarketplaceImportJob:
"""Update marketplace import job status and other fields""" """Update marketplace import job status and other fields"""
job = db.query(MarketplaceImportJob).filter(MarketplaceImportJob.id == job_id).first() job = (
db.query(MarketplaceImportJob)
.filter(MarketplaceImportJob.id == job_id)
.first()
)
if not job: if not job:
raise ValueError("Marketplace import job not found") raise ValueError("Marketplace import job not found")
job.status = status job.status = status
# Update optional fields if provided # Update optional fields if provided
if 'imported_count' in kwargs: if "imported_count" in kwargs:
job.imported_count = kwargs['imported_count'] job.imported_count = kwargs["imported_count"]
if 'updated_count' in kwargs: if "updated_count" in kwargs:
job.updated_count = kwargs['updated_count'] job.updated_count = kwargs["updated_count"]
if 'total_processed' in kwargs: if "total_processed" in kwargs:
job.total_processed = kwargs['total_processed'] job.total_processed = kwargs["total_processed"]
if 'error_count' in kwargs: if "error_count" in kwargs:
job.error_count = kwargs['error_count'] job.error_count = kwargs["error_count"]
if 'error_message' in kwargs: if "error_message" in kwargs:
job.error_message = kwargs['error_message'] job.error_message = kwargs["error_message"]
if 'started_at' in kwargs: if "started_at" in kwargs:
job.started_at = kwargs['started_at'] job.started_at = kwargs["started_at"]
if 'completed_at' in kwargs: if "completed_at" in kwargs:
job.completed_at = kwargs['completed_at'] job.completed_at = kwargs["completed_at"]
db.commit() db.commit()
db.refresh(job) db.refresh(job)
@@ -145,7 +161,9 @@ class MarketplaceService:
total_jobs = query.count() total_jobs = query.count()
pending_jobs = query.filter(MarketplaceImportJob.status == "pending").count() pending_jobs = query.filter(MarketplaceImportJob.status == "pending").count()
running_jobs = query.filter(MarketplaceImportJob.status == "running").count() running_jobs = query.filter(MarketplaceImportJob.status == "running").count()
completed_jobs = query.filter(MarketplaceImportJob.status == "completed").count() completed_jobs = query.filter(
MarketplaceImportJob.status == "completed"
).count()
failed_jobs = query.filter(MarketplaceImportJob.status == "failed").count() failed_jobs = query.filter(MarketplaceImportJob.status == "failed").count()
return { return {
@@ -153,17 +171,21 @@ class MarketplaceService:
"pending_jobs": pending_jobs, "pending_jobs": pending_jobs,
"running_jobs": running_jobs, "running_jobs": running_jobs,
"completed_jobs": completed_jobs, "completed_jobs": completed_jobs,
"failed_jobs": failed_jobs "failed_jobs": failed_jobs,
} }
def convert_to_response_model(self, job: MarketplaceImportJob) -> MarketplaceImportJobResponse: def convert_to_response_model(
self, job: MarketplaceImportJob
) -> MarketplaceImportJobResponse:
"""Convert database model to API response model""" """Convert database model to API response model"""
return MarketplaceImportJobResponse( return MarketplaceImportJobResponse(
job_id=job.id, job_id=job.id,
status=job.status, status=job.status,
marketplace=job.marketplace, marketplace=job.marketplace,
shop_id=job.shop_id, shop_id=job.shop_id,
shop_code=job.shop.shop_code if job.shop else None, # Add this optional field via relationship shop_code=(
job.shop.shop_code if job.shop else None
), # Add this optional field via relationship
shop_name=job.shop_name, shop_name=job.shop_name,
imported=job.imported_count or 0, imported=job.imported_count or 0,
updated=job.updated_count or 0, updated=job.updated_count or 0,
@@ -172,10 +194,12 @@ class MarketplaceService:
error_message=job.error_message, error_message=job.error_message,
created_at=job.created_at, created_at=job.created_at,
started_at=job.started_at, started_at=job.started_at,
completed_at=job.completed_at completed_at=job.completed_at,
) )
def cancel_import_job(self, db: Session, job_id: int, user: User) -> MarketplaceImportJob: def cancel_import_job(
self, db: Session, job_id: int, user: User
) -> MarketplaceImportJob:
"""Cancel a pending or running import job""" """Cancel a pending or running import job"""
job = self.get_import_job_by_id(db, job_id, user) job = self.get_import_job_by_id(db, job_id, user)
@@ -197,7 +221,9 @@ class MarketplaceService:
# Only allow deletion of completed, failed, or cancelled jobs # Only allow deletion of completed, failed, or cancelled jobs
if job.status in ["pending", "running"]: if job.status in ["pending", "running"]:
raise ValueError(f"Cannot delete job with status: {job.status}. Cancel it first.") raise ValueError(
f"Cannot delete job with status: {job.status}. Cancel it first."
)
db.delete(job) db.delete(job)
db.commit() db.commit()

View File

@@ -1,11 +1,14 @@
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError
from models.database_models import Product, Stock
from models.api_models import ProductCreate, ProductUpdate, StockLocationResponse, StockSummaryResponse
from utils.data_processing import GTINProcessor, PriceProcessor
from typing import Optional, List, Generator
from datetime import datetime
import logging import logging
from datetime import datetime
from typing import Generator, List, Optional
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from models.api_models import (ProductCreate, ProductUpdate,
StockLocationResponse, StockSummaryResponse)
from models.database_models import Product, Stock
from utils.data_processing import GTINProcessor, PriceProcessor
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -27,7 +30,9 @@ class ProductService:
# Process price if provided # Process price if provided
if product_data.price: if product_data.price:
parsed_price, currency = self.price_processor.parse_price_currency(product_data.price) parsed_price, currency = self.price_processor.parse_price_currency(
product_data.price
)
if parsed_price: if parsed_price:
product_data.price = parsed_price product_data.price = parsed_price
product_data.currency = currency product_data.currency = currency
@@ -58,16 +63,16 @@ class ProductService:
return db.query(Product).filter(Product.product_id == product_id).first() return db.query(Product).filter(Product.product_id == product_id).first()
def get_products_with_filters( def get_products_with_filters(
self, self,
db: Session, db: Session,
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
brand: Optional[str] = None, brand: Optional[str] = None,
category: Optional[str] = None, category: Optional[str] = None,
availability: Optional[str] = None, availability: Optional[str] = None,
marketplace: Optional[str] = None, marketplace: Optional[str] = None,
shop_name: Optional[str] = None, shop_name: Optional[str] = None,
search: Optional[str] = None search: Optional[str] = None,
) -> tuple[List[Product], int]: ) -> tuple[List[Product], int]:
"""Get products with filtering and pagination""" """Get products with filtering and pagination"""
query = db.query(Product) query = db.query(Product)
@@ -87,10 +92,10 @@ class ProductService:
# Search in title, description, marketplace, and shop_name # Search in title, description, marketplace, and shop_name
search_term = f"%{search}%" search_term = f"%{search}%"
query = query.filter( query = query.filter(
(Product.title.ilike(search_term)) | (Product.title.ilike(search_term))
(Product.description.ilike(search_term)) | | (Product.description.ilike(search_term))
(Product.marketplace.ilike(search_term)) | | (Product.marketplace.ilike(search_term))
(Product.shop_name.ilike(search_term)) | (Product.shop_name.ilike(search_term))
) )
total = query.count() total = query.count()
@@ -98,7 +103,9 @@ class ProductService:
return products, total return products, total
def update_product(self, db: Session, product_id: str, product_update: ProductUpdate) -> Product: def update_product(
self, db: Session, product_id: str, product_update: ProductUpdate
) -> Product:
"""Update product with validation""" """Update product with validation"""
product = db.query(Product).filter(Product.product_id == product_id).first() product = db.query(Product).filter(Product.product_id == product_id).first()
if not product: if not product:
@@ -116,7 +123,9 @@ class ProductService:
# Process price if being updated # Process price if being updated
if "price" in update_data and update_data["price"]: if "price" in update_data and update_data["price"]:
parsed_price, currency = self.price_processor.parse_price_currency(update_data["price"]) parsed_price, currency = self.price_processor.parse_price_currency(
update_data["price"]
)
if parsed_price: if parsed_price:
update_data["price"] = parsed_price update_data["price"] = parsed_price
update_data["currency"] = currency update_data["currency"] = currency
@@ -160,21 +169,21 @@ class ProductService:
] ]
return StockSummaryResponse( return StockSummaryResponse(
gtin=gtin, gtin=gtin, total_quantity=total_quantity, locations=locations
total_quantity=total_quantity,
locations=locations
) )
def generate_csv_export( def generate_csv_export(
self, self,
db: Session, db: Session,
marketplace: Optional[str] = None, marketplace: Optional[str] = None,
shop_name: Optional[str] = None shop_name: Optional[str] = None,
) -> Generator[str, None, None]: ) -> Generator[str, None, None]:
"""Generate CSV export with streaming for memory efficiency""" """Generate CSV export with streaming for memory efficiency"""
# CSV header # CSV header
yield ("product_id,title,description,link,image_link,availability,price,currency,brand," yield (
"gtin,marketplace,shop_name\n") "product_id,title,description,link,image_link,availability,price,currency,brand,"
"gtin,marketplace,shop_name\n"
)
batch_size = 1000 batch_size = 1000
offset = 0 offset = 0
@@ -194,17 +203,22 @@ class ProductService:
for product in products: for product in products:
# Create CSV row with marketplace fields # Create CSV row with marketplace fields
row = (f'"{product.product_id}","{product.title or ""}","{product.description or ""}",' row = (
f'"{product.link or ""}","{product.image_link or ""}","{product.availability or ""}",' f'"{product.product_id}","{product.title or ""}","{product.description or ""}",'
f'"{product.price or ""}","{product.currency or ""}","{product.brand or ""}",' f'"{product.link or ""}","{product.image_link or ""}","{product.availability or ""}",'
f'"{product.gtin or ""}","{product.marketplace or ""}","{product.shop_name or ""}"\n') f'"{product.price or ""}","{product.currency or ""}","{product.brand or ""}",'
f'"{product.gtin or ""}","{product.marketplace or ""}","{product.shop_name or ""}"\n'
)
yield row yield row
offset += batch_size offset += batch_size
def product_exists(self, db: Session, product_id: str) -> bool: def product_exists(self, db: Session, product_id: str) -> bool:
"""Check if product exists by ID""" """Check if product exists by ID"""
return db.query(Product).filter(Product.product_id == product_id).first() is not None return (
db.query(Product).filter(Product.product_id == product_id).first()
is not None
)
# Create service instance # Create service instance

View File

@@ -1,12 +1,13 @@
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from fastapi import HTTPException
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from fastapi import HTTPException
from datetime import datetime
import logging
from typing import List, Optional, Tuple, Dict, Any
from models.database_models import User, Shop, Product, ShopProduct
from models.api_models import ShopCreate, ShopProductCreate from models.api_models import ShopCreate, ShopProductCreate
from models.database_models import Product, Shop, ShopProduct, User
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -14,7 +15,9 @@ logger = logging.getLogger(__name__)
class ShopService: class ShopService:
"""Service class for shop operations following the application's service pattern""" """Service class for shop operations following the application's service pattern"""
def create_shop(self, db: Session, shop_data: ShopCreate, current_user: User) -> Shop: def create_shop(
self, db: Session, shop_data: ShopCreate, current_user: User
) -> Shop:
""" """
Create a new shop Create a new shop
@@ -33,39 +36,43 @@ class ShopService:
normalized_shop_code = shop_data.shop_code.upper() normalized_shop_code = shop_data.shop_code.upper()
# Check if shop code already exists (case-insensitive check against existing data) # Check if shop code already exists (case-insensitive check against existing data)
existing_shop = db.query(Shop).filter( existing_shop = (
func.upper(Shop.shop_code) == normalized_shop_code db.query(Shop)
).first() .filter(func.upper(Shop.shop_code) == normalized_shop_code)
.first()
)
if existing_shop: if existing_shop:
raise HTTPException(status_code=400, detail="Shop code already exists") raise HTTPException(status_code=400, detail="Shop code already exists")
# Create shop with uppercase code # Create shop with uppercase code
shop_dict = shop_data.model_dump() # Fixed deprecated .dict() method shop_dict = shop_data.model_dump() # Fixed deprecated .dict() method
shop_dict['shop_code'] = normalized_shop_code # Store as uppercase shop_dict["shop_code"] = normalized_shop_code # Store as uppercase
new_shop = Shop( new_shop = Shop(
**shop_dict, **shop_dict,
owner_id=current_user.id, owner_id=current_user.id,
is_active=True, is_active=True,
is_verified=(current_user.role == "admin") is_verified=(current_user.role == "admin"),
) )
db.add(new_shop) db.add(new_shop)
db.commit() db.commit()
db.refresh(new_shop) db.refresh(new_shop)
logger.info(f"New shop created: {new_shop.shop_code} by {current_user.username}") logger.info(
f"New shop created: {new_shop.shop_code} by {current_user.username}"
)
return new_shop return new_shop
def get_shops( def get_shops(
self, self,
db: Session, db: Session,
current_user: User, current_user: User,
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
active_only: bool = True, active_only: bool = True,
verified_only: bool = False verified_only: bool = False,
) -> Tuple[List[Shop], int]: ) -> Tuple[List[Shop], int]:
""" """
Get shops with filtering Get shops with filtering
@@ -86,8 +93,8 @@ class ShopService:
# Non-admin users can only see active and verified shops, plus their own # Non-admin users can only see active and verified shops, plus their own
if current_user.role != "admin": if current_user.role != "admin":
query = query.filter( query = query.filter(
(Shop.is_active == True) & (Shop.is_active == True)
((Shop.is_verified == True) | (Shop.owner_id == current_user.id)) & ((Shop.is_verified == True) | (Shop.owner_id == current_user.id))
) )
else: else:
# Admin can apply filters # Admin can apply filters
@@ -117,22 +124,25 @@ class ShopService:
HTTPException: If shop not found or access denied HTTPException: If shop not found or access denied
""" """
# Explicit type hint to help type checker shop: Optional[Shop] # Explicit type hint to help type checker shop: Optional[Shop]
shop: Optional[Shop] = db.query(Shop).filter(func.upper(Shop.shop_code) == shop_code.upper()).first() shop: Optional[Shop] = (
db.query(Shop)
.filter(func.upper(Shop.shop_code) == shop_code.upper())
.first()
)
if not shop: if not shop:
raise HTTPException(status_code=404, detail="Shop not found") raise HTTPException(status_code=404, detail="Shop not found")
# Non-admin users can only see active verified shops or their own shops # Non-admin users can only see active verified shops or their own shops
if current_user.role != "admin": if current_user.role != "admin":
if not shop.is_active or (not shop.is_verified and shop.owner_id != current_user.id): if not shop.is_active or (
not shop.is_verified and shop.owner_id != current_user.id
):
raise HTTPException(status_code=404, detail="Shop not found") raise HTTPException(status_code=404, detail="Shop not found")
return shop return shop
def add_product_to_shop( def add_product_to_shop(
self, self, db: Session, shop: Shop, shop_product: ShopProductCreate
db: Session,
shop: Shop,
shop_product: ShopProductCreate
) -> ShopProduct: ) -> ShopProduct:
""" """
Add existing product to shop catalog with shop-specific settings Add existing product to shop catalog with shop-specific settings
@@ -149,24 +159,35 @@ class ShopService:
HTTPException: If product not found or already in shop HTTPException: If product not found or already in shop
""" """
# Check if product exists # Check if product exists
product = db.query(Product).filter(Product.product_id == shop_product.product_id).first() product = (
db.query(Product)
.filter(Product.product_id == shop_product.product_id)
.first()
)
if not product: if not product:
raise HTTPException(status_code=404, detail="Product not found in marketplace catalog") raise HTTPException(
status_code=404, detail="Product not found in marketplace catalog"
)
# Check if product already in shop # Check if product already in shop
existing_shop_product = db.query(ShopProduct).filter( existing_shop_product = (
ShopProduct.shop_id == shop.id, db.query(ShopProduct)
ShopProduct.product_id == product.id .filter(
).first() ShopProduct.shop_id == shop.id, ShopProduct.product_id == product.id
)
.first()
)
if existing_shop_product: if existing_shop_product:
raise HTTPException(status_code=400, detail="Product already in shop catalog") raise HTTPException(
status_code=400, detail="Product already in shop catalog"
)
# Create shop-product association # Create shop-product association
new_shop_product = ShopProduct( new_shop_product = ShopProduct(
shop_id=shop.id, shop_id=shop.id,
product_id=product.id, product_id=product.id,
**shop_product.model_dump(exclude={'product_id'}) **shop_product.model_dump(exclude={"product_id"}),
) )
db.add(new_shop_product) db.add(new_shop_product)
@@ -180,14 +201,14 @@ class ShopService:
return new_shop_product return new_shop_product
def get_shop_products( def get_shop_products(
self, self,
db: Session, db: Session,
shop: Shop, shop: Shop,
current_user: User, current_user: User,
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
active_only: bool = True, active_only: bool = True,
featured_only: bool = False featured_only: bool = False,
) -> Tuple[List[ShopProduct], int]: ) -> Tuple[List[ShopProduct], int]:
""" """
Get products in shop catalog with filtering Get products in shop catalog with filtering
@@ -239,10 +260,14 @@ class ShopService:
def product_in_shop(self, db: Session, shop_id: int, product_id: int) -> bool: def product_in_shop(self, db: Session, shop_id: int, product_id: int) -> bool:
"""Check if product is already in shop""" """Check if product is already in shop"""
return db.query(ShopProduct).filter( return (
ShopProduct.shop_id == shop_id, db.query(ShopProduct)
ShopProduct.product_id == product_id .filter(
).first() is not None ShopProduct.shop_id == shop_id, ShopProduct.product_id == product_id
)
.first()
is not None
)
def is_shop_owner(self, shop: Shop, user: User) -> bool: def is_shop_owner(self, shop: Shop, user: User) -> bool:
"""Check if user is shop owner""" """Check if user is shop owner"""

View File

@@ -1,10 +1,11 @@
import logging
from typing import Any, Dict, List
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
import logging
from typing import List, Dict, Any
from models.database_models import User, Product, Stock from models.api_models import MarketplaceStatsResponse, StatsResponse
from models.api_models import StatsResponse, MarketplaceStatsResponse from models.database_models import Product, Stock, User
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -25,26 +26,37 @@ class StatsService:
# Use more efficient queries with proper indexes # Use more efficient queries with proper indexes
total_products = db.query(Product).count() total_products = db.query(Product).count()
unique_brands = db.query(Product.brand).filter( unique_brands = (
Product.brand.isnot(None), db.query(Product.brand)
Product.brand != "" .filter(Product.brand.isnot(None), Product.brand != "")
).distinct().count() .distinct()
.count()
)
unique_categories = db.query(Product.google_product_category).filter( unique_categories = (
Product.google_product_category.isnot(None), db.query(Product.google_product_category)
Product.google_product_category != "" .filter(
).distinct().count() Product.google_product_category.isnot(None),
Product.google_product_category != "",
)
.distinct()
.count()
)
# New marketplace statistics # New marketplace statistics
unique_marketplaces = db.query(Product.marketplace).filter( unique_marketplaces = (
Product.marketplace.isnot(None), db.query(Product.marketplace)
Product.marketplace != "" .filter(Product.marketplace.isnot(None), Product.marketplace != "")
).distinct().count() .distinct()
.count()
)
unique_shops = db.query(Product.shop_name).filter( unique_shops = (
Product.shop_name.isnot(None), db.query(Product.shop_name)
Product.shop_name != "" .filter(Product.shop_name.isnot(None), Product.shop_name != "")
).distinct().count() .distinct()
.count()
)
# Stock statistics # Stock statistics
total_stock_entries = db.query(Stock).count() total_stock_entries = db.query(Stock).count()
@@ -57,10 +69,12 @@ class StatsService:
"unique_marketplaces": unique_marketplaces, "unique_marketplaces": unique_marketplaces,
"unique_shops": unique_shops, "unique_shops": unique_shops,
"total_stock_entries": total_stock_entries, "total_stock_entries": total_stock_entries,
"total_inventory_quantity": total_inventory "total_inventory_quantity": total_inventory,
} }
logger.info(f"Generated comprehensive stats: {total_products} products, {unique_marketplaces} marketplaces") logger.info(
f"Generated comprehensive stats: {total_products} products, {unique_marketplaces} marketplaces"
)
return stats_data return stats_data
def get_marketplace_breakdown_stats(self, db: Session) -> List[Dict[str, Any]]: def get_marketplace_breakdown_stats(self, db: Session) -> List[Dict[str, Any]]:
@@ -74,25 +88,31 @@ class StatsService:
List of dictionaries containing marketplace statistics List of dictionaries containing marketplace statistics
""" """
# Query to get stats per marketplace # Query to get stats per marketplace
marketplace_stats = db.query( marketplace_stats = (
Product.marketplace, db.query(
func.count(Product.id).label('total_products'), Product.marketplace,
func.count(func.distinct(Product.shop_name)).label('unique_shops'), func.count(Product.id).label("total_products"),
func.count(func.distinct(Product.brand)).label('unique_brands') func.count(func.distinct(Product.shop_name)).label("unique_shops"),
).filter( func.count(func.distinct(Product.brand)).label("unique_brands"),
Product.marketplace.isnot(None) )
).group_by(Product.marketplace).all() .filter(Product.marketplace.isnot(None))
.group_by(Product.marketplace)
.all()
)
stats_list = [ stats_list = [
{ {
"marketplace": stat.marketplace, "marketplace": stat.marketplace,
"total_products": stat.total_products, "total_products": stat.total_products,
"unique_shops": stat.unique_shops, "unique_shops": stat.unique_shops,
"unique_brands": stat.unique_brands "unique_brands": stat.unique_brands,
} for stat in marketplace_stats }
for stat in marketplace_stats
] ]
logger.info(f"Generated marketplace breakdown stats for {len(stats_list)} marketplaces") logger.info(
f"Generated marketplace breakdown stats for {len(stats_list)} marketplaces"
)
return stats_list return stats_list
def get_product_count(self, db: Session) -> int: def get_product_count(self, db: Session) -> int:
@@ -101,31 +121,42 @@ class StatsService:
def get_unique_brands_count(self, db: Session) -> int: def get_unique_brands_count(self, db: Session) -> int:
"""Get count of unique brands""" """Get count of unique brands"""
return db.query(Product.brand).filter( return (
Product.brand.isnot(None), db.query(Product.brand)
Product.brand != "" .filter(Product.brand.isnot(None), Product.brand != "")
).distinct().count() .distinct()
.count()
)
def get_unique_categories_count(self, db: Session) -> int: def get_unique_categories_count(self, db: Session) -> int:
"""Get count of unique categories""" """Get count of unique categories"""
return db.query(Product.google_product_category).filter( return (
Product.google_product_category.isnot(None), db.query(Product.google_product_category)
Product.google_product_category != "" .filter(
).distinct().count() Product.google_product_category.isnot(None),
Product.google_product_category != "",
)
.distinct()
.count()
)
def get_unique_marketplaces_count(self, db: Session) -> int: def get_unique_marketplaces_count(self, db: Session) -> int:
"""Get count of unique marketplaces""" """Get count of unique marketplaces"""
return db.query(Product.marketplace).filter( return (
Product.marketplace.isnot(None), db.query(Product.marketplace)
Product.marketplace != "" .filter(Product.marketplace.isnot(None), Product.marketplace != "")
).distinct().count() .distinct()
.count()
)
def get_unique_shops_count(self, db: Session) -> int: def get_unique_shops_count(self, db: Session) -> int:
"""Get count of unique shops""" """Get count of unique shops"""
return db.query(Product.shop_name).filter( return (
Product.shop_name.isnot(None), db.query(Product.shop_name)
Product.shop_name != "" .filter(Product.shop_name.isnot(None), Product.shop_name != "")
).distinct().count() .distinct()
.count()
)
def get_stock_statistics(self, db: Session) -> Dict[str, int]: def get_stock_statistics(self, db: Session) -> Dict[str, int]:
""" """
@@ -142,25 +173,35 @@ class StatsService:
return { return {
"total_stock_entries": total_stock_entries, "total_stock_entries": total_stock_entries,
"total_inventory_quantity": total_inventory "total_inventory_quantity": total_inventory,
} }
def get_brands_by_marketplace(self, db: Session, marketplace: str) -> List[str]: def get_brands_by_marketplace(self, db: Session, marketplace: str) -> List[str]:
"""Get unique brands for a specific marketplace""" """Get unique brands for a specific marketplace"""
brands = db.query(Product.brand).filter( brands = (
Product.marketplace == marketplace, db.query(Product.brand)
Product.brand.isnot(None), .filter(
Product.brand != "" Product.marketplace == marketplace,
).distinct().all() Product.brand.isnot(None),
Product.brand != "",
)
.distinct()
.all()
)
return [brand[0] for brand in brands] return [brand[0] for brand in brands]
def get_shops_by_marketplace(self, db: Session, marketplace: str) -> List[str]: def get_shops_by_marketplace(self, db: Session, marketplace: str) -> List[str]:
"""Get unique shops for a specific marketplace""" """Get unique shops for a specific marketplace"""
shops = db.query(Product.shop_name).filter( shops = (
Product.marketplace == marketplace, db.query(Product.shop_name)
Product.shop_name.isnot(None), .filter(
Product.shop_name != "" Product.marketplace == marketplace,
).distinct().all() Product.shop_name.isnot(None),
Product.shop_name != "",
)
.distinct()
.all()
)
return [shop[0] for shop in shops] return [shop[0] for shop in shops]
def get_products_by_marketplace(self, db: Session, marketplace: str) -> int: def get_products_by_marketplace(self, db: Session, marketplace: str) -> int:

View File

@@ -1,10 +1,13 @@
from sqlalchemy.orm import Session
from models.database_models import Stock, Product
from models.api_models import StockCreate, StockAdd, StockUpdate, StockLocationResponse, StockSummaryResponse
from utils.data_processing import GTINProcessor
from typing import Optional, List, Tuple
from datetime import datetime
import logging import logging
from datetime import datetime
from typing import List, Optional, Tuple
from sqlalchemy.orm import Session
from models.api_models import (StockAdd, StockCreate, StockLocationResponse,
StockSummaryResponse, StockUpdate)
from models.database_models import Product, Stock
from utils.data_processing import GTINProcessor
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -26,10 +29,11 @@ class StockService:
location = stock_data.location.strip().upper() location = stock_data.location.strip().upper()
# Check if stock entry already exists for this GTIN and location # Check if stock entry already exists for this GTIN and location
existing_stock = db.query(Stock).filter( existing_stock = (
Stock.gtin == normalized_gtin, db.query(Stock)
Stock.location == location .filter(Stock.gtin == normalized_gtin, Stock.location == location)
).first() .first()
)
if existing_stock: if existing_stock:
# Update existing stock (SET to exact quantity) # Update existing stock (SET to exact quantity)
@@ -39,19 +43,20 @@ class StockService:
db.commit() db.commit()
db.refresh(existing_stock) db.refresh(existing_stock)
logger.info( logger.info(
f"Updated stock for GTIN {normalized_gtin} at {location}: {old_quantity}{stock_data.quantity}") f"Updated stock for GTIN {normalized_gtin} at {location}: {old_quantity}{stock_data.quantity}"
)
return existing_stock return existing_stock
else: else:
# Create new stock entry # Create new stock entry
new_stock = Stock( new_stock = Stock(
gtin=normalized_gtin, gtin=normalized_gtin, location=location, quantity=stock_data.quantity
location=location,
quantity=stock_data.quantity
) )
db.add(new_stock) db.add(new_stock)
db.commit() db.commit()
db.refresh(new_stock) db.refresh(new_stock)
logger.info(f"Created new stock for GTIN {normalized_gtin} at {location}: {stock_data.quantity}") logger.info(
f"Created new stock for GTIN {normalized_gtin} at {location}: {stock_data.quantity}"
)
return new_stock return new_stock
def add_stock(self, db: Session, stock_data: StockAdd) -> Stock: def add_stock(self, db: Session, stock_data: StockAdd) -> Stock:
@@ -63,10 +68,11 @@ class StockService:
location = stock_data.location.strip().upper() location = stock_data.location.strip().upper()
# Check if stock entry already exists for this GTIN and location # Check if stock entry already exists for this GTIN and location
existing_stock = db.query(Stock).filter( existing_stock = (
Stock.gtin == normalized_gtin, db.query(Stock)
Stock.location == location .filter(Stock.gtin == normalized_gtin, Stock.location == location)
).first() .first()
)
if existing_stock: if existing_stock:
# Add to existing stock # Add to existing stock
@@ -76,19 +82,20 @@ class StockService:
db.commit() db.commit()
db.refresh(existing_stock) db.refresh(existing_stock)
logger.info( logger.info(
f"Added stock for GTIN {normalized_gtin} at {location}: {old_quantity} + {stock_data.quantity} = {existing_stock.quantity}") f"Added stock for GTIN {normalized_gtin} at {location}: {old_quantity} + {stock_data.quantity} = {existing_stock.quantity}"
)
return existing_stock return existing_stock
else: else:
# Create new stock entry with the quantity # Create new stock entry with the quantity
new_stock = Stock( new_stock = Stock(
gtin=normalized_gtin, gtin=normalized_gtin, location=location, quantity=stock_data.quantity
location=location,
quantity=stock_data.quantity
) )
db.add(new_stock) db.add(new_stock)
db.commit() db.commit()
db.refresh(new_stock) db.refresh(new_stock)
logger.info(f"Created new stock for GTIN {normalized_gtin} at {location}: {stock_data.quantity}") logger.info(
f"Created new stock for GTIN {normalized_gtin} at {location}: {stock_data.quantity}"
)
return new_stock return new_stock
def remove_stock(self, db: Session, stock_data: StockAdd) -> Stock: def remove_stock(self, db: Session, stock_data: StockAdd) -> Stock:
@@ -100,18 +107,22 @@ class StockService:
location = stock_data.location.strip().upper() location = stock_data.location.strip().upper()
# Find existing stock entry # Find existing stock entry
existing_stock = db.query(Stock).filter( existing_stock = (
Stock.gtin == normalized_gtin, db.query(Stock)
Stock.location == location .filter(Stock.gtin == normalized_gtin, Stock.location == location)
).first() .first()
)
if not existing_stock: if not existing_stock:
raise ValueError(f"No stock found for GTIN {normalized_gtin} at location {location}") raise ValueError(
f"No stock found for GTIN {normalized_gtin} at location {location}"
)
# Check if we have enough stock to remove # Check if we have enough stock to remove
if existing_stock.quantity < stock_data.quantity: if existing_stock.quantity < stock_data.quantity:
raise ValueError( raise ValueError(
f"Insufficient stock. Available: {existing_stock.quantity}, Requested to remove: {stock_data.quantity}") f"Insufficient stock. Available: {existing_stock.quantity}, Requested to remove: {stock_data.quantity}"
)
# Remove from existing stock # Remove from existing stock
old_quantity = existing_stock.quantity old_quantity = existing_stock.quantity
@@ -120,7 +131,8 @@ class StockService:
db.commit() db.commit()
db.refresh(existing_stock) db.refresh(existing_stock)
logger.info( logger.info(
f"Removed stock for GTIN {normalized_gtin} at {location}: {old_quantity} - {stock_data.quantity} = {existing_stock.quantity}") f"Removed stock for GTIN {normalized_gtin} at {location}: {old_quantity} - {stock_data.quantity} = {existing_stock.quantity}"
)
return existing_stock return existing_stock
def get_stock_by_gtin(self, db: Session, gtin: str) -> StockSummaryResponse: def get_stock_by_gtin(self, db: Session, gtin: str) -> StockSummaryResponse:
@@ -141,10 +153,9 @@ class StockService:
for entry in stock_entries: for entry in stock_entries:
total_quantity += entry.quantity total_quantity += entry.quantity
locations.append(StockLocationResponse( locations.append(
location=entry.location, StockLocationResponse(location=entry.location, quantity=entry.quantity)
quantity=entry.quantity )
))
# Try to get product title for reference # Try to get product title for reference
product = db.query(Product).filter(Product.gtin == normalized_gtin).first() product = db.query(Product).filter(Product.gtin == normalized_gtin).first()
@@ -154,7 +165,7 @@ class StockService:
gtin=normalized_gtin, gtin=normalized_gtin,
total_quantity=total_quantity, total_quantity=total_quantity,
locations=locations, locations=locations,
product_title=product_title product_title=product_title,
) )
def get_total_stock(self, db: Session, gtin: str) -> dict: def get_total_stock(self, db: Session, gtin: str) -> dict:
@@ -174,16 +185,16 @@ class StockService:
"gtin": normalized_gtin, "gtin": normalized_gtin,
"total_quantity": total_quantity, "total_quantity": total_quantity,
"product_title": product.title if product else None, "product_title": product.title if product else None,
"locations_count": len(total_stock) "locations_count": len(total_stock),
} }
def get_all_stock( def get_all_stock(
self, self,
db: Session, db: Session,
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
location: Optional[str] = None, location: Optional[str] = None,
gtin: Optional[str] = None gtin: Optional[str] = None,
) -> List[Stock]: ) -> List[Stock]:
"""Get all stock entries with optional filtering""" """Get all stock entries with optional filtering"""
query = db.query(Stock) query = db.query(Stock)
@@ -198,7 +209,9 @@ class StockService:
return query.offset(skip).limit(limit).all() return query.offset(skip).limit(limit).all()
def update_stock(self, db: Session, stock_id: int, stock_update: StockUpdate) -> Stock: def update_stock(
self, db: Session, stock_id: int, stock_update: StockUpdate
) -> Stock:
"""Update stock quantity for a specific stock entry""" """Update stock quantity for a specific stock entry"""
stock_entry = db.query(Stock).filter(Stock.id == stock_id).first() stock_entry = db.query(Stock).filter(Stock.id == stock_id).first()
if not stock_entry: if not stock_entry:
@@ -209,7 +222,9 @@ class StockService:
db.commit() db.commit()
db.refresh(stock_entry) db.refresh(stock_entry)
logger.info(f"Updated stock entry {stock_id} to quantity {stock_update.quantity}") logger.info(
f"Updated stock entry {stock_id} to quantity {stock_update.quantity}"
)
return stock_entry return stock_entry
def delete_stock(self, db: Session, stock_id: int) -> bool: def delete_stock(self, db: Session, stock_id: int) -> bool:

View File

@@ -1,6 +1,7 @@
# app/tasks/background_tasks.py # app/tasks/background_tasks.py
import logging import logging
from datetime import datetime from datetime import datetime
from app.core.database import SessionLocal from app.core.database import SessionLocal
from models.database_models import MarketplaceImportJob from models.database_models import MarketplaceImportJob
from utils.csv_processor import CSVProcessor from utils.csv_processor import CSVProcessor
@@ -9,11 +10,7 @@ logger = logging.getLogger(__name__)
async def process_marketplace_import( async def process_marketplace_import(
job_id: int, job_id: int, url: str, marketplace: str, shop_name: str, batch_size: int = 1000
url: str,
marketplace: str,
shop_name: str,
batch_size: int = 1000
): ):
"""Background task to process marketplace CSV import""" """Background task to process marketplace CSV import"""
db = SessionLocal() db = SessionLocal()
@@ -22,7 +19,11 @@ async def process_marketplace_import(
try: try:
# Update job status # Update job status
job = db.query(MarketplaceImportJob).filter(MarketplaceImportJob.id == job_id).first() job = (
db.query(MarketplaceImportJob)
.filter(MarketplaceImportJob.id == job_id)
.first()
)
if not job: if not job:
logger.error(f"Import job {job_id} not found") logger.error(f"Import job {job_id} not found")
return return
@@ -70,7 +71,7 @@ async def process_marketplace_import(
finally: finally:
# Close the database session only if it's not a mock # Close the database session only if it's not a mock
# In tests, we use the same session so we shouldn't close it # In tests, we use the same session so we shouldn't close it
if hasattr(db, 'close') and callable(getattr(db, 'close')): if hasattr(db, "close") and callable(getattr(db, "close")):
try: try:
db.close() db.close()
except Exception as close_error: except Exception as close_error:

View File

@@ -6,21 +6,21 @@ import requests
# API Base URL # API Base URL
BASE_URL = "http://localhost:8000" BASE_URL = "http://localhost:8000"
def register_user(email, username, password): def register_user(email, username, password):
"""Register a new user""" """Register a new user"""
response = requests.post(f"{BASE_URL}/register", json={ response = requests.post(
"email": email, f"{BASE_URL}/register",
"username": username, json={"email": email, "username": username, "password": password},
"password": password )
})
return response.json() return response.json()
def login_user(username, password): def login_user(username, password):
"""Login and get JWT token""" """Login and get JWT token"""
response = requests.post(f"{BASE_URL}/login", json={ response = requests.post(
"username": username, f"{BASE_URL}/login", json={"username": username, "password": password}
"password": password )
})
if response.status_code == 200: if response.status_code == 200:
data = response.json() data = response.json()
return data["access_token"] return data["access_token"]
@@ -28,24 +28,30 @@ def login_user(username, password):
print(f"Login failed: {response.json()}") print(f"Login failed: {response.json()}")
return None return None
def get_user_info(token): def get_user_info(token):
"""Get current user info""" """Get current user info"""
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
response = requests.get(f"{BASE_URL}/me", headers=headers) response = requests.get(f"{BASE_URL}/me", headers=headers)
return response.json() return response.json()
def get_products(token, skip=0, limit=10): def get_products(token, skip=0, limit=10):
"""Get products (requires authentication)""" """Get products (requires authentication)"""
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
response = requests.get(f"{BASE_URL}/products?skip={skip}&limit={limit}", headers=headers) response = requests.get(
f"{BASE_URL}/products?skip={skip}&limit={limit}", headers=headers
)
return response.json() return response.json()
def create_product(token, product_data): def create_product(token, product_data):
"""Create a new product (requires authentication)""" """Create a new product (requires authentication)"""
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
response = requests.post(f"{BASE_URL}/products", json=product_data, headers=headers) response = requests.post(f"{BASE_URL}/products", json=product_data, headers=headers)
return response.json() return response.json()
# Example usage # Example usage
if __name__ == "__main__": if __name__ == "__main__":
# 1. Register a new user # 1. Register a new user
@@ -55,18 +61,18 @@ if __name__ == "__main__":
print(f"User registered: {user_result}") print(f"User registered: {user_result}")
except Exception as e: except Exception as e:
print(f"Registration failed: {e}") print(f"Registration failed: {e}")
# 2. Login with default admin user # 2. Login with default admin user
print("\n2. Logging in as admin...") print("\n2. Logging in as admin...")
admin_token = login_user("admin", "admin123") admin_token = login_user("admin", "admin123")
if admin_token: if admin_token:
print(f"Admin login successful! Token: {admin_token[:50]}...") print(f"Admin login successful! Token: {admin_token[:50]}...")
# 3. Get user info # 3. Get user info
print("\n3. Getting admin user info...") print("\n3. Getting admin user info...")
user_info = get_user_info(admin_token) user_info = get_user_info(admin_token)
print(f"User info: {user_info}") print(f"User info: {user_info}")
# 4. Create a sample product # 4. Create a sample product
print("\n4. Creating a sample product...") print("\n4. Creating a sample product...")
sample_product = { sample_product = {
@@ -75,30 +81,32 @@ if __name__ == "__main__":
"description": "A test product for demonstration", "description": "A test product for demonstration",
"price": "19.99", "price": "19.99",
"brand": "Test Brand", "brand": "Test Brand",
"availability": "in stock" "availability": "in stock",
} }
product_result = create_product(admin_token, sample_product) product_result = create_product(admin_token, sample_product)
print(f"Product created: {product_result}") print(f"Product created: {product_result}")
# 5. Get products list # 5. Get products list
print("\n5. Getting products list...") print("\n5. Getting products list...")
products = get_products(admin_token) products = get_products(admin_token)
print(f"Products: {products}") print(f"Products: {products}")
# 6. Login with regular user # 6. Login with regular user
print("\n6. Logging in as regular user...") print("\n6. Logging in as regular user...")
user_token = login_user("testuser", "password123") user_token = login_user("testuser", "password123")
if user_token: if user_token:
print(f"User login successful! Token: {user_token[:50]}...") print(f"User login successful! Token: {user_token[:50]}...")
# Regular users can also access protected endpoints # Regular users can also access protected endpoints
user_info = get_user_info(user_token) user_info = get_user_info(user_token)
print(f"Regular user info: {user_info}") print(f"Regular user info: {user_info}")
products = get_products(user_token, limit=5) products = get_products(user_token, limit=5)
print(f"Products accessible to regular user: {len(products.get('products', []))} products") print(
f"Products accessible to regular user: {len(products.get('products', []))} products"
)
print("\nAuthentication example completed!") print("\nAuthentication example completed!")
# Example cURL commands: # Example cURL commands:
@@ -126,4 +134,4 @@ curl -X POST "http://localhost:8000/products" \
-H "Authorization: Bearer YOUR_JWT_TOKEN_HERE" \ -H "Authorization: Bearer YOUR_JWT_TOKEN_HERE" \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{"product_id": "TEST001", "title": "Test Product", "price": "19.99"}' -d '{"product_id": "TEST001", "title": "Test Product", "price": "19.99"}'
""" """

35
main.py
View File

@@ -1,13 +1,15 @@
from fastapi import FastAPI, Depends, HTTPException
from sqlalchemy.orm import Session
from sqlalchemy import text
from app.core.config import settings
from app.core.lifespan import lifespan
from app.core.database import get_db
from datetime import datetime
from app.api.main import api_router
from fastapi.middleware.cors import CORSMiddleware
import logging import logging
from datetime import datetime
from fastapi import Depends, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import text
from sqlalchemy.orm import Session
from app.api.main import api_router
from app.core.config import settings
from app.core.database import get_db
from app.core.lifespan import lifespan
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -16,7 +18,7 @@ app = FastAPI(
title=settings.project_name, title=settings.project_name,
description=settings.description, description=settings.description,
version=settings.version, version=settings.version,
lifespan=lifespan lifespan=lifespan,
) )
# Add CORS middleware # Add CORS middleware
@@ -44,10 +46,17 @@ def root():
"JWT Authentication", "JWT Authentication",
"Marketplace-aware product import", "Marketplace-aware product import",
"Multi-shop product management", "Multi-shop product management",
"Stock management with location tracking" "Stock management with location tracking",
], ],
"supported_marketplaces": ["Letzshop", "Amazon", "eBay", "Etsy", "Shopify", "Other"], "supported_marketplaces": [
"auth_required": "Most endpoints require Bearer token authentication" "Letzshop",
"Amazon",
"eBay",
"Etsy",
"Shopify",
"Other",
],
"auth_required": "Most endpoints require Bearer token authentication",
} }

View File

@@ -1,14 +1,16 @@
# middleware/auth.py # middleware/auth.py
import logging
import os
from datetime import datetime, timedelta
from typing import Any, Dict, Optional
from fastapi import HTTPException from fastapi import HTTPException
from fastapi.security import HTTPAuthorizationCredentials from fastapi.security import HTTPAuthorizationCredentials
from passlib.context import CryptContext
from jose import jwt from jose import jwt
from datetime import datetime, timedelta from passlib.context import CryptContext
from typing import Dict, Any, Optional
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from models.database_models import User from models.database_models import User
import os
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -20,7 +22,9 @@ class AuthManager:
"""JWT-based authentication manager with bcrypt password hashing""" """JWT-based authentication manager with bcrypt password hashing"""
def __init__(self): def __init__(self):
self.secret_key = os.getenv("JWT_SECRET_KEY", "your-secret-key-change-in-production-please") self.secret_key = os.getenv(
"JWT_SECRET_KEY", "your-secret-key-change-in-production-please"
)
self.algorithm = "HS256" self.algorithm = "HS256"
self.token_expire_minutes = int(os.getenv("JWT_EXPIRE_MINUTES", "30")) self.token_expire_minutes = int(os.getenv("JWT_EXPIRE_MINUTES", "30"))
@@ -32,11 +36,15 @@ class AuthManager:
"""Verify password against hash""" """Verify password against hash"""
return pwd_context.verify(plain_password, hashed_password) return pwd_context.verify(plain_password, hashed_password)
def authenticate_user(self, db: Session, username: str, password: str) -> Optional[User]: def authenticate_user(
self, db: Session, username: str, password: str
) -> Optional[User]:
"""Authenticate user and return user object if valid""" """Authenticate user and return user object if valid"""
user = db.query(User).filter( user = (
(User.username == username) | (User.email == username) db.query(User)
).first() .filter((User.username == username) | (User.email == username))
.first()
)
if not user: if not user:
return None return None
@@ -65,7 +73,7 @@ class AuthManager:
"email": user.email, "email": user.email,
"role": user.role, "role": user.role,
"exp": expire, "exp": expire,
"iat": datetime.utcnow() "iat": datetime.utcnow(),
} }
token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm) token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
@@ -73,7 +81,7 @@ class AuthManager:
return { return {
"access_token": token, "access_token": token,
"token_type": "bearer", "token_type": "bearer",
"expires_in": self.token_expire_minutes * 60 # Return in seconds "expires_in": self.token_expire_minutes * 60, # Return in seconds
} }
def verify_token(self, token: str) -> Dict[str, Any]: def verify_token(self, token: str) -> Dict[str, Any]:
@@ -92,25 +100,31 @@ class AuthManager:
# Extract user data # Extract user data
user_id = payload.get("sub") user_id = payload.get("sub")
if user_id is None: if user_id is None:
raise HTTPException(status_code=401, detail="Token missing user identifier") raise HTTPException(
status_code=401, detail="Token missing user identifier"
)
return { return {
"user_id": int(user_id), "user_id": int(user_id),
"username": payload.get("username"), "username": payload.get("username"),
"email": payload.get("email"), "email": payload.get("email"),
"role": payload.get("role", "user") "role": payload.get("role", "user"),
} }
except jwt.ExpiredSignatureError: except jwt.ExpiredSignatureError:
raise HTTPException(status_code=401, detail="Token has expired") raise HTTPException(status_code=401, detail="Token has expired")
except jwt.JWTError as e: except jwt.JWTError as e:
logger.error(f"JWT decode error: {e}") logger.error(f"JWT decode error: {e}")
raise HTTPException(status_code=401, detail="Could not validate credentials") raise HTTPException(
status_code=401, detail="Could not validate credentials"
)
except Exception as e: except Exception as e:
logger.error(f"Token verification error: {e}") logger.error(f"Token verification error: {e}")
raise HTTPException(status_code=401, detail="Authentication failed") raise HTTPException(status_code=401, detail="Authentication failed")
def get_current_user(self, db: Session, credentials: HTTPAuthorizationCredentials) -> User: def get_current_user(
self, db: Session, credentials: HTTPAuthorizationCredentials
) -> User:
"""Get current authenticated user from database""" """Get current authenticated user from database"""
user_data = self.verify_token(credentials.credentials) user_data = self.verify_token(credentials.credentials)
@@ -131,7 +145,7 @@ class AuthManager:
if current_user.role != required_role: if current_user.role != required_role:
raise HTTPException( raise HTTPException(
status_code=403, status_code=403,
detail=f"Required role '{required_role}' not found. Current role: '{current_user.role}'" detail=f"Required role '{required_role}' not found. Current role: '{current_user.role}'",
) )
return func(current_user, *args, **kwargs) return func(current_user, *args, **kwargs)
@@ -142,10 +156,7 @@ class AuthManager:
def require_admin(self, current_user: User): def require_admin(self, current_user: User):
"""Require admin role""" """Require admin role"""
if current_user.role != "admin": if current_user.role != "admin":
raise HTTPException( raise HTTPException(status_code=403, detail="Admin privileges required")
status_code=403,
detail="Admin privileges required"
)
return current_user return current_user
def create_default_admin_user(self, db: Session): def create_default_admin_user(self, db: Session):
@@ -159,11 +170,13 @@ class AuthManager:
username="admin", username="admin",
hashed_password=hashed_password, hashed_password=hashed_password,
role="admin", role="admin",
is_active=True is_active=True,
) )
db.add(admin_user) db.add(admin_user)
db.commit() db.commit()
db.refresh(admin_user) db.refresh(admin_user)
logger.info("Default admin user created: username='admin', password='admin123'") logger.info(
"Default admin user created: username='admin', password='admin123'"
)
return admin_user return admin_user

View File

@@ -1,6 +1,8 @@
# middleware/decorators.py # middleware/decorators.py
from functools import wraps from functools import wraps
from fastapi import HTTPException from fastapi import HTTPException
from middleware.rate_limiter import RateLimiter from middleware.rate_limiter import RateLimiter
# Initialize rate limiter instance # Initialize rate limiter instance
@@ -17,10 +19,7 @@ def rate_limit(max_requests: int = 100, window_seconds: int = 3600):
client_id = "anonymous" # In production, extract from request client_id = "anonymous" # In production, extract from request
if not rate_limiter.allow_request(client_id, max_requests, window_seconds): if not rate_limiter.allow_request(client_id, max_requests, window_seconds):
raise HTTPException( raise HTTPException(status_code=429, detail="Rate limit exceeded")
status_code=429,
detail="Rate limit exceeded"
)
return await func(*args, **kwargs) return await func(*args, **kwargs)

View File

@@ -1,31 +1,35 @@
# middleware/error_handler.py # middleware/error_handler.py
from fastapi import Request, HTTPException
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
import logging import logging
from fastapi import HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
async def custom_http_exception_handler(request: Request, exc: HTTPException): async def custom_http_exception_handler(request: Request, exc: HTTPException):
"""Custom HTTP exception handler""" """Custom HTTP exception handler"""
logger.error(f"HTTP {exc.status_code}: {exc.detail} - {request.method} {request.url}") logger.error(
f"HTTP {exc.status_code}: {exc.detail} - {request.method} {request.url}"
)
return JSONResponse( return JSONResponse(
status_code=exc.status_code, status_code=exc.status_code,
content={ content={
"error": { "error": {
"code": exc.status_code, "code": exc.status_code,
"message": exc.detail, "message": exc.detail,
"type": "http_exception" "type": "http_exception",
} }
} },
) )
async def validation_exception_handler(request: Request, exc: RequestValidationError): async def validation_exception_handler(request: Request, exc: RequestValidationError):
"""Handle Pydantic validation errors""" """Handle Pydantic validation errors"""
logger.error(f"Validation error: {exc.errors()} - {request.method} {request.url}") logger.error(f"Validation error: {exc.errors()} - {request.method} {request.url}")
return JSONResponse( return JSONResponse(
status_code=422, status_code=422,
content={ content={
@@ -33,23 +37,25 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
"code": 422, "code": 422,
"message": "Validation error", "message": "Validation error",
"type": "validation_error", "type": "validation_error",
"details": exc.errors() "details": exc.errors(),
} }
} },
) )
async def general_exception_handler(request: Request, exc: Exception): async def general_exception_handler(request: Request, exc: Exception):
"""Handle unexpected exceptions""" """Handle unexpected exceptions"""
logger.error(f"Unexpected error: {str(exc)} - {request.method} {request.url}", exc_info=True) logger.error(
f"Unexpected error: {str(exc)} - {request.method} {request.url}", exc_info=True
)
return JSONResponse( return JSONResponse(
status_code=500, status_code=500,
content={ content={
"error": { "error": {
"code": 500, "code": 500,
"message": "Internal server error", "message": "Internal server error",
"type": "server_error" "type": "server_error",
} }
} },
) )

View File

@@ -1,9 +1,10 @@
# middleware/logging_middleware.py # middleware/logging_middleware.py
import logging import logging
import time import time
from typing import Callable
from fastapi import Request, Response from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from typing import Callable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -43,4 +44,4 @@ class LoggingMiddleware(BaseHTTPMiddleware):
f"Error: {str(e)} for {request.method} {request.url.path} " f"Error: {str(e)} for {request.method} {request.url.path} "
f"({duration:.3f}s)" f"({duration:.3f}s)"
) )
raise raise

View File

@@ -1,8 +1,8 @@
# middleware/rate_limiter.py # middleware/rate_limiter.py
from typing import Dict
from datetime import datetime, timedelta
import logging import logging
from collections import defaultdict, deque from collections import defaultdict, deque
from datetime import datetime, timedelta
from typing import Dict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -16,7 +16,9 @@ class RateLimiter:
self.cleanup_interval = 3600 # Clean up old entries every hour self.cleanup_interval = 3600 # Clean up old entries every hour
self.last_cleanup = datetime.utcnow() self.last_cleanup = datetime.utcnow()
def allow_request(self, client_id: str, max_requests: int, window_seconds: int) -> bool: def allow_request(
self, client_id: str, max_requests: int, window_seconds: int
) -> bool:
""" """
Check if client is allowed to make a request Check if client is allowed to make a request
Uses sliding window algorithm Uses sliding window algorithm
@@ -41,7 +43,9 @@ class RateLimiter:
client_requests.append(now) client_requests.append(now)
return True return True
logger.warning(f"Rate limit exceeded for client {client_id}: {len(client_requests)}/{max_requests}") logger.warning(
f"Rate limit exceeded for client {client_id}: {len(client_requests)}/{max_requests}"
)
return False return False
def _cleanup_old_entries(self): def _cleanup_old_entries(self):
@@ -62,7 +66,9 @@ class RateLimiter:
for client_id in clients_to_remove: for client_id in clients_to_remove:
del self.clients[client_id] del self.clients[client_id]
logger.info(f"Rate limiter cleanup completed. Removed {len(clients_to_remove)} inactive clients") logger.info(
f"Rate limiter cleanup completed. Removed {len(clients_to_remove)} inactive clients"
)
def get_client_stats(self, client_id: str) -> Dict[str, int]: def get_client_stats(self, client_id: str) -> Dict[str, int]:
"""Get statistics for a specific client""" """Get statistics for a specific client"""
@@ -72,11 +78,13 @@ class RateLimiter:
hour_ago = now - timedelta(hours=1) hour_ago = now - timedelta(hours=1)
day_ago = now - timedelta(days=1) day_ago = now - timedelta(days=1)
requests_last_hour = sum(1 for req_time in client_requests if req_time > hour_ago) requests_last_hour = sum(
1 for req_time in client_requests if req_time > hour_ago
)
requests_last_day = sum(1 for req_time in client_requests if req_time > day_ago) requests_last_day = sum(1 for req_time in client_requests if req_time > day_ago)
return { return {
"requests_last_hour": requests_last_hour, "requests_last_hour": requests_last_hour,
"requests_last_day": requests_last_day, "requests_last_day": requests_last_day,
"total_tracked_requests": len(client_requests) "total_tracked_requests": len(client_requests),
} }

View File

@@ -1,28 +1,35 @@
# models/api_models.py - Updated with Marketplace Support and Pydantic v2 # models/api_models.py - Updated with Marketplace Support and Pydantic v2
from pydantic import BaseModel, Field, field_validator, EmailStr, ConfigDict
from typing import Optional, List
from datetime import datetime
import re import re
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator
# User Authentication Models # User Authentication Models
class UserRegister(BaseModel): class UserRegister(BaseModel):
email: EmailStr = Field(..., description="Valid email address") email: EmailStr = Field(..., description="Valid email address")
username: str = Field(..., min_length=3, max_length=50, description="Username (3-50 characters)") username: str = Field(
password: str = Field(..., min_length=6, description="Password (minimum 6 characters)") ..., min_length=3, max_length=50, description="Username (3-50 characters)"
)
password: str = Field(
..., min_length=6, description="Password (minimum 6 characters)"
)
@field_validator('username') @field_validator("username")
@classmethod @classmethod
def validate_username(cls, v): def validate_username(cls, v):
if not re.match(r'^[a-zA-Z0-9_]+$', v): if not re.match(r"^[a-zA-Z0-9_]+$", v):
raise ValueError('Username must contain only letters, numbers, or underscores') raise ValueError(
"Username must contain only letters, numbers, or underscores"
)
return v.lower().strip() return v.lower().strip()
@field_validator('password') @field_validator("password")
@classmethod @classmethod
def validate_password(cls, v): def validate_password(cls, v):
if len(v) < 6: if len(v) < 6:
raise ValueError('Password must be at least 6 characters long') raise ValueError("Password must be at least 6 characters long")
return v return v
@@ -30,7 +37,7 @@ class UserLogin(BaseModel):
username: str = Field(..., description="Username") username: str = Field(..., description="Username")
password: str = Field(..., description="Password") password: str = Field(..., description="Password")
@field_validator('username') @field_validator("username")
@classmethod @classmethod
def validate_username(cls, v): def validate_username(cls, v):
return v.strip() return v.strip()
@@ -58,27 +65,38 @@ class LoginResponse(BaseModel):
# NEW: Shop models # NEW: Shop models
class ShopCreate(BaseModel): class ShopCreate(BaseModel):
shop_code: str = Field(..., min_length=3, max_length=50, description="Unique shop code (e.g., TECHSTORE)") shop_code: str = Field(
shop_name: str = Field(..., min_length=1, max_length=200, description="Display name of the shop") ...,
description: Optional[str] = Field(None, max_length=2000, description="Shop description") min_length=3,
max_length=50,
description="Unique shop code (e.g., TECHSTORE)",
)
shop_name: str = Field(
..., min_length=1, max_length=200, description="Display name of the shop"
)
description: Optional[str] = Field(
None, max_length=2000, description="Shop description"
)
contact_email: Optional[str] = None contact_email: Optional[str] = None
contact_phone: Optional[str] = None contact_phone: Optional[str] = None
website: Optional[str] = None website: Optional[str] = None
business_address: Optional[str] = None business_address: Optional[str] = None
tax_number: Optional[str] = None tax_number: Optional[str] = None
@field_validator('shop_code') @field_validator("shop_code")
def validate_shop_code(cls, v): def validate_shop_code(cls, v):
# Convert to uppercase and check format # Convert to uppercase and check format
v = v.upper().strip() v = v.upper().strip()
if not v.replace('_', '').replace('-', '').isalnum(): if not v.replace("_", "").replace("-", "").isalnum():
raise ValueError('Shop code must be alphanumeric (underscores and hyphens allowed)') raise ValueError(
"Shop code must be alphanumeric (underscores and hyphens allowed)"
)
return v return v
@field_validator('contact_email') @field_validator("contact_email")
def validate_contact_email(cls, v): def validate_contact_email(cls, v):
if v and ('@' not in v or '.' not in v): if v and ("@" not in v or "." not in v):
raise ValueError('Invalid email format') raise ValueError("Invalid email format")
return v.lower() if v else v return v.lower() if v else v
@@ -91,10 +109,10 @@ class ShopUpdate(BaseModel):
business_address: Optional[str] = None business_address: Optional[str] = None
tax_number: Optional[str] = None tax_number: Optional[str] = None
@field_validator('contact_email') @field_validator("contact_email")
def validate_contact_email(cls, v): def validate_contact_email(cls, v):
if v and ('@' not in v or '.' not in v): if v and ("@" not in v or "." not in v):
raise ValueError('Invalid email format') raise ValueError("Invalid email format")
return v.lower() if v else v return v.lower() if v else v
@@ -172,11 +190,11 @@ class ProductCreate(ProductBase):
product_id: str = Field(..., min_length=1, description="Product ID is required") product_id: str = Field(..., min_length=1, description="Product ID is required")
title: str = Field(..., min_length=1, description="Title is required") title: str = Field(..., min_length=1, description="Title is required")
@field_validator('product_id', 'title') @field_validator("product_id", "title")
@classmethod @classmethod
def validate_required_fields(cls, v): def validate_required_fields(cls, v):
if not v or not v.strip(): if not v or not v.strip():
raise ValueError('Field cannot be empty') raise ValueError("Field cannot be empty")
return v.strip() return v.strip()
@@ -195,15 +213,25 @@ class ProductResponse(ProductBase):
# NEW: Shop Product models # NEW: Shop Product models
class ShopProductCreate(BaseModel): class ShopProductCreate(BaseModel):
product_id: str = Field(..., description="Product ID to add to shop") product_id: str = Field(..., description="Product ID to add to shop")
shop_product_id: Optional[str] = Field(None, description="Shop's internal product ID") shop_product_id: Optional[str] = Field(
shop_price: Optional[float] = Field(None, ge=0, description="Shop-specific price override") None, description="Shop's internal product ID"
shop_sale_price: Optional[float] = Field(None, ge=0, description="Shop-specific sale price") )
shop_price: Optional[float] = Field(
None, ge=0, description="Shop-specific price override"
)
shop_sale_price: Optional[float] = Field(
None, ge=0, description="Shop-specific sale price"
)
shop_currency: Optional[str] = Field(None, description="Shop-specific currency") shop_currency: Optional[str] = Field(None, description="Shop-specific currency")
shop_availability: Optional[str] = Field(None, description="Shop-specific availability") shop_availability: Optional[str] = Field(
None, description="Shop-specific availability"
)
shop_condition: Optional[str] = Field(None, description="Shop-specific condition") shop_condition: Optional[str] = Field(None, description="Shop-specific condition")
is_featured: bool = Field(False, description="Featured product flag") is_featured: bool = Field(False, description="Featured product flag")
min_quantity: int = Field(1, ge=1, description="Minimum order quantity") min_quantity: int = Field(1, ge=1, description="Minimum order quantity")
max_quantity: Optional[int] = Field(None, ge=1, description="Maximum order quantity") max_quantity: Optional[int] = Field(
None, ge=1, description="Maximum order quantity"
)
class ShopProductResponse(BaseModel): class ShopProductResponse(BaseModel):
@@ -270,28 +298,40 @@ class StockSummaryResponse(BaseModel):
# Marketplace Import Models # Marketplace Import Models
class MarketplaceImportRequest(BaseModel): class MarketplaceImportRequest(BaseModel):
url: str = Field(..., description="URL to CSV file from marketplace") url: str = Field(..., description="URL to CSV file from marketplace")
marketplace: str = Field(default="Letzshop", description="Name of the marketplace (e.g., Letzshop, Amazon, eBay)") marketplace: str = Field(
default="Letzshop",
description="Name of the marketplace (e.g., Letzshop, Amazon, eBay)",
)
shop_code: str = Field(..., description="Shop code to associate products with") shop_code: str = Field(..., description="Shop code to associate products with")
batch_size: Optional[int] = Field(1000, gt=0, le=10000, description="Batch size for processing") batch_size: Optional[int] = Field(
1000, gt=0, le=10000, description="Batch size for processing"
)
@field_validator('url') @field_validator("url")
@classmethod @classmethod
def validate_url(cls, v): def validate_url(cls, v):
if not v.startswith(('http://', 'https://')): if not v.startswith(("http://", "https://")):
raise ValueError('URL must start with http:// or https://') raise ValueError("URL must start with http:// or https://")
return v return v
@field_validator('marketplace') @field_validator("marketplace")
@classmethod @classmethod
def validate_marketplace(cls, v): def validate_marketplace(cls, v):
# You can add validation for supported marketplaces here # You can add validation for supported marketplaces here
supported_marketplaces = ['Letzshop', 'Amazon', 'eBay', 'Etsy', 'Shopify', 'Other'] supported_marketplaces = [
"Letzshop",
"Amazon",
"eBay",
"Etsy",
"Shopify",
"Other",
]
if v not in supported_marketplaces: if v not in supported_marketplaces:
# For now, allow any marketplace but log it # For now, allow any marketplace but log it
pass pass
return v.strip() return v.strip()
@field_validator('shop_code') @field_validator("shop_code")
@classmethod @classmethod
def validate_shop_code(cls, v): def validate_shop_code(cls, v):
return v.upper().strip() return v.upper().strip()

View File

@@ -1,8 +1,10 @@
# models/database_models.py - Updated with Marketplace Support # models/database_models.py - Updated with Marketplace Support
from sqlalchemy import Column, Integer, String, Float, DateTime, Boolean, Text, ForeignKey, UniqueConstraint, Index
from sqlalchemy.orm import relationship
from datetime import datetime from datetime import datetime
from sqlalchemy import (Boolean, Column, DateTime, Float, ForeignKey, Index,
Integer, String, Text, UniqueConstraint)
from sqlalchemy.orm import relationship
# Import Base from the central database module instead of creating a new one # Import Base from the central database module instead of creating a new one
from app.core.database import Base from app.core.database import Base
@@ -18,10 +20,14 @@ class User(Base):
is_active = Column(Boolean, default=True, nullable=False) is_active = Column(Boolean, default=True, nullable=False)
last_login = Column(DateTime, nullable=True) last_login = Column(DateTime, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow, nullable=False) created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) updated_at = Column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False
)
# Relationships # Relationships
marketplace_import_jobs = relationship("MarketplaceImportJob", back_populates="user") marketplace_import_jobs = relationship(
"MarketplaceImportJob", back_populates="user"
)
owned_shops = relationship("Shop", back_populates="owner") owned_shops = relationship("Shop", back_populates="owner")
def __repr__(self): def __repr__(self):
@@ -32,7 +38,9 @@ class Shop(Base):
__tablename__ = "shops" __tablename__ = "shops"
id = Column(Integer, primary_key=True, index=True) id = Column(Integer, primary_key=True, index=True)
shop_code = Column(String, unique=True, index=True, nullable=False) # e.g., "TECHSTORE", "FASHIONHUB" shop_code = Column(
String, unique=True, index=True, nullable=False
) # e.g., "TECHSTORE", "FASHIONHUB"
shop_name = Column(String, nullable=False) # Display name shop_name = Column(String, nullable=False) # Display name
description = Column(Text) description = Column(Text)
owner_id = Column(Integer, ForeignKey("users.id"), nullable=False) owner_id = Column(Integer, ForeignKey("users.id"), nullable=False)
@@ -57,7 +65,9 @@ class Shop(Base):
# Relationships # Relationships
owner = relationship("User", back_populates="owned_shops") owner = relationship("User", back_populates="owned_shops")
shop_products = relationship("ShopProduct", back_populates="shop") shop_products = relationship("ShopProduct", back_populates="shop")
marketplace_import_jobs = relationship("MarketplaceImportJob", back_populates="shop") marketplace_import_jobs = relationship(
"MarketplaceImportJob", back_populates="shop"
)
class Product(Base): class Product(Base):
@@ -103,26 +113,40 @@ class Product(Base):
currency = Column(String) currency = Column(String)
# New marketplace fields # New marketplace fields
marketplace = Column(String, index=True, nullable=True, default="Letzshop") # Index for marketplace filtering marketplace = Column(
String, index=True, nullable=True, default="Letzshop"
) # Index for marketplace filtering
shop_name = Column(String, index=True, nullable=True) # Index for shop filtering shop_name = Column(String, index=True, nullable=True) # Index for shop filtering
created_at = Column(DateTime, default=datetime.utcnow, nullable=False) created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) updated_at = Column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False
)
# Relationship to stock (one-to-many via GTIN) # Relationship to stock (one-to-many via GTIN)
stock_entries = relationship("Stock", foreign_keys="Stock.gtin", primaryjoin="Product.gtin == Stock.gtin", stock_entries = relationship(
viewonly=True) "Stock",
foreign_keys="Stock.gtin",
primaryjoin="Product.gtin == Stock.gtin",
viewonly=True,
)
shop_products = relationship("ShopProduct", back_populates="product") shop_products = relationship("ShopProduct", back_populates="product")
# Additional indexes for marketplace queries # Additional indexes for marketplace queries
__table_args__ = ( __table_args__ = (
Index('idx_marketplace_shop', 'marketplace', 'shop_name'), # Composite index for marketplace+shop queries Index(
Index('idx_marketplace_brand', 'marketplace', 'brand'), # Composite index for marketplace+brand queries "idx_marketplace_shop", "marketplace", "shop_name"
), # Composite index for marketplace+shop queries
Index(
"idx_marketplace_brand", "marketplace", "brand"
), # Composite index for marketplace+brand queries
) )
def __repr__(self): def __repr__(self):
return (f"<Product(product_id='{self.product_id}', title='{self.title}', marketplace='{self.marketplace}', " return (
f"shop='{self.shop_name}')>") f"<Product(product_id='{self.product_id}', title='{self.title}', marketplace='{self.marketplace}', "
f"shop='{self.shop_name}')>"
)
class ShopProduct(Base): class ShopProduct(Base):
@@ -159,9 +183,9 @@ class ShopProduct(Base):
# Constraints # Constraints
__table_args__ = ( __table_args__ = (
UniqueConstraint('shop_id', 'product_id', name='uq_shop_product'), UniqueConstraint("shop_id", "product_id", name="uq_shop_product"),
Index('idx_shop_product_active', 'shop_id', 'is_active'), Index("idx_shop_product_active", "shop_id", "is_active"),
Index('idx_shop_product_featured', 'shop_id', 'is_featured'), Index("idx_shop_product_featured", "shop_id", "is_featured"),
) )
@@ -169,22 +193,28 @@ class Stock(Base):
__tablename__ = "stock" __tablename__ = "stock"
id = Column(Integer, primary_key=True, index=True) id = Column(Integer, primary_key=True, index=True)
gtin = Column(String, index=True, nullable=False) # Foreign key relationship would be ideal gtin = Column(
String, index=True, nullable=False
) # Foreign key relationship would be ideal
location = Column(String, nullable=False, index=True) location = Column(String, nullable=False, index=True)
quantity = Column(Integer, nullable=False, default=0) quantity = Column(Integer, nullable=False, default=0)
reserved_quantity = Column(Integer, default=0) # For orders being processed reserved_quantity = Column(Integer, default=0) # For orders being processed
shop_id = Column(Integer, ForeignKey("shops.id")) # Optional: shop-specific stock shop_id = Column(Integer, ForeignKey("shops.id")) # Optional: shop-specific stock
created_at = Column(DateTime, default=datetime.utcnow, nullable=False) created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) updated_at = Column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False
)
# Relationships # Relationships
shop = relationship("Shop") shop = relationship("Shop")
# Composite unique constraint to prevent duplicate GTIN-location combinations # Composite unique constraint to prevent duplicate GTIN-location combinations
__table_args__ = ( __table_args__ = (
UniqueConstraint('gtin', 'location', name='uq_stock_gtin_location'), UniqueConstraint("gtin", "location", name="uq_stock_gtin_location"),
Index('idx_stock_gtin_location', 'gtin', 'location'), # Composite index for efficient queries Index(
"idx_stock_gtin_location", "gtin", "location"
), # Composite index for efficient queries
) )
def __repr__(self): def __repr__(self):
@@ -195,13 +225,20 @@ class MarketplaceImportJob(Base):
__tablename__ = "marketplace_import_jobs" __tablename__ = "marketplace_import_jobs"
id = Column(Integer, primary_key=True, index=True) id = Column(Integer, primary_key=True, index=True)
status = Column(String, nullable=False, status = Column(
default="pending") # pending, processing, completed, failed, completed_with_errors String, nullable=False, default="pending"
) # pending, processing, completed, failed, completed_with_errors
source_url = Column(String, nullable=False) source_url = Column(String, nullable=False)
marketplace = Column(String, nullable=False, index=True, default="Letzshop") # Index for marketplace filtering marketplace = Column(
String, nullable=False, index=True, default="Letzshop"
) # Index for marketplace filtering
shop_name = Column(String, nullable=False, index=True) # Index for shop filtering shop_name = Column(String, nullable=False, index=True) # Index for shop filtering
shop_id = Column(Integer, ForeignKey("shops.id"), nullable=False) # Add proper foreign key shop_id = Column(
user_id = Column(Integer, ForeignKey('users.id'), nullable=False) # Foreign key to users table Integer, ForeignKey("shops.id"), nullable=False
) # Add proper foreign key
user_id = Column(
Integer, ForeignKey("users.id"), nullable=False
) # Foreign key to users table
# Results # Results
imported_count = Column(Integer, default=0) imported_count = Column(Integer, default=0)
@@ -223,11 +260,15 @@ class MarketplaceImportJob(Base):
# Additional indexes for marketplace import job queries # Additional indexes for marketplace import job queries
__table_args__ = ( __table_args__ = (
Index('idx_marketplace_import_user_marketplace', 'user_id', 'marketplace'), # User's marketplace imports Index(
Index('idx_marketplace_import_shop_status', 'status'), # Shop import status "idx_marketplace_import_user_marketplace", "user_id", "marketplace"
Index('idx_marketplace_import_shop_id', 'shop_id'), ), # User's marketplace imports
Index("idx_marketplace_import_shop_status", "status"), # Shop import status
Index("idx_marketplace_import_shop_id", "shop_id"),
) )
def __repr__(self): def __repr__(self):
return (f"<MarketplaceImportJob(id={self.id}, marketplace='{self.marketplace}', shop='{self.shop_name}', " return (
f"status='{self.status}', imported={self.imported_count})>") f"<MarketplaceImportJob(id={self.id}, marketplace='{self.marketplace}', shop='{self.shop_name}', "
f"status='{self.status}', imported={self.imported_count})>"
)

View File

@@ -3,8 +3,8 @@
"""Development environment setup script""" """Development environment setup script"""
import os import os
import sys
import subprocess import subprocess
import sys
from pathlib import Path from pathlib import Path
@@ -30,6 +30,7 @@ def setup_alembic():
if alembic_dir.exists(): if alembic_dir.exists():
# Remove incomplete alembic directory # Remove incomplete alembic directory
import shutil import shutil
shutil.rmtree(alembic_dir) shutil.rmtree(alembic_dir)
if not run_command("alembic init alembic", "Initializing Alembic"): if not run_command("alembic init alembic", "Initializing Alembic"):
@@ -138,16 +139,23 @@ LOG_LEVEL=INFO
# Set up Alembic # Set up Alembic
if not setup_alembic(): if not setup_alembic():
print("⚠️ Alembic setup failed. You'll need to set up database migrations manually.") print(
"⚠️ Alembic setup failed. You'll need to set up database migrations manually."
)
return False return False
# Create initial migration # Create initial migration
if not run_command("alembic revision --autogenerate -m \"Initial migration\"", "Creating initial migration"): if not run_command(
'alembic revision --autogenerate -m "Initial migration"',
"Creating initial migration",
):
print("⚠️ Initial migration creation failed. Check your database models.") print("⚠️ Initial migration creation failed. Check your database models.")
# Apply migrations # Apply migrations
if not run_command("alembic upgrade head", "Setting up database"): if not run_command("alembic upgrade head", "Setting up database"):
print("⚠️ Database setup failed. Make sure your DATABASE_URL is correct in .env") print(
"⚠️ Database setup failed. Make sure your DATABASE_URL is correct in .env"
)
# Run tests # Run tests
if not run_command("pytest", "Running tests"): if not run_command("pytest", "Running tests"):
@@ -157,7 +165,7 @@ LOG_LEVEL=INFO
print("To start the development server, run:") print("To start the development server, run:")
print(" uvicorn main:app --reload") print(" uvicorn main:app --reload")
print("\nDatabase commands:") print("\nDatabase commands:")
print(" alembic revision --autogenerate -m \"Description\" # Create migration") print(' alembic revision --autogenerate -m "Description" # Create migration')
print(" alembic upgrade head # Apply migrations") print(" alembic upgrade head # Apply migrations")
print(" alembic current # Check status") print(" alembic current # Check status")

View File

@@ -1,16 +1,18 @@
# tests/conftest.py # tests/conftest.py
import uuid
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool from sqlalchemy.pool import StaticPool
from app.core.database import Base, get_db
from main import app from main import app
from app.core.database import get_db, Base
# Import all models to ensure they're registered with Base metadata
from models.database_models import User, Product, Stock, Shop, MarketplaceImportJob, ShopProduct
from middleware.auth import AuthManager from middleware.auth import AuthManager
import uuid # Import all models to ensure they're registered with Base metadata
from models.database_models import (MarketplaceImportJob, Product, Shop,
ShopProduct, Stock, User)
# Use in-memory SQLite database for tests # Use in-memory SQLite database for tests
SQLALCHEMY_TEST_DATABASE_URL = "sqlite:///:memory:" SQLALCHEMY_TEST_DATABASE_URL = "sqlite:///:memory:"
@@ -23,7 +25,7 @@ def engine():
SQLALCHEMY_TEST_DATABASE_URL, SQLALCHEMY_TEST_DATABASE_URL,
connect_args={"check_same_thread": False}, connect_args={"check_same_thread": False},
poolclass=StaticPool, poolclass=StaticPool,
echo=False # Set to True for SQL debugging echo=False, # Set to True for SQL debugging
) )
@@ -89,7 +91,7 @@ def test_user(db, auth_manager):
username=f"testuser_{unique_id}", username=f"testuser_{unique_id}",
hashed_password=hashed_password, hashed_password=hashed_password,
role="user", role="user",
is_active=True is_active=True,
) )
db.add(user) db.add(user)
db.commit() db.commit()
@@ -107,7 +109,7 @@ def test_admin(db, auth_manager):
username=f"admin_{unique_id}", username=f"admin_{unique_id}",
hashed_password=hashed_password, hashed_password=hashed_password,
role="admin", role="admin",
is_active=True is_active=True,
) )
db.add(admin) db.add(admin)
db.commit() db.commit()
@@ -118,10 +120,10 @@ def test_admin(db, auth_manager):
@pytest.fixture @pytest.fixture
def auth_headers(client, test_user): def auth_headers(client, test_user):
"""Get authentication headers for test user""" """Get authentication headers for test user"""
response = client.post("/api/v1/auth/login", json={ response = client.post(
"username": test_user.username, "/api/v1/auth/login",
"password": "testpass123" json={"username": test_user.username, "password": "testpass123"},
}) )
assert response.status_code == 200, f"Login failed: {response.text}" assert response.status_code == 200, f"Login failed: {response.text}"
token = response.json()["access_token"] token = response.json()["access_token"]
return {"Authorization": f"Bearer {token}"} return {"Authorization": f"Bearer {token}"}
@@ -130,10 +132,10 @@ def auth_headers(client, test_user):
@pytest.fixture @pytest.fixture
def admin_headers(client, test_admin): def admin_headers(client, test_admin):
"""Get authentication headers for admin user""" """Get authentication headers for admin user"""
response = client.post("/api/v1/auth/login", json={ response = client.post(
"username": test_admin.username, "/api/v1/auth/login",
"password": "adminpass123" json={"username": test_admin.username, "password": "adminpass123"},
}) )
assert response.status_code == 200, f"Admin login failed: {response.text}" assert response.status_code == 200, f"Admin login failed: {response.text}"
token = response.json()["access_token"] token = response.json()["access_token"]
return {"Authorization": f"Bearer {token}"} return {"Authorization": f"Bearer {token}"}
@@ -152,7 +154,7 @@ def test_product(db):
gtin="1234567890123", gtin="1234567890123",
availability="in stock", availability="in stock",
marketplace="Letzshop", marketplace="Letzshop",
shop_name="TestShop" shop_name="TestShop",
) )
db.add(product) db.add(product)
db.commit() db.commit()
@@ -169,7 +171,7 @@ def test_shop(db, test_user):
shop_name=f"Test Shop {unique_id}", shop_name=f"Test Shop {unique_id}",
owner_id=test_user.id, owner_id=test_user.id,
is_active=True, is_active=True,
is_verified=True is_verified=True,
) )
db.add(shop) db.add(shop)
db.commit() db.commit()
@@ -186,7 +188,7 @@ def test_stock(db, test_product, test_shop):
location=f"WAREHOUSE_A_{unique_id}", location=f"WAREHOUSE_A_{unique_id}",
quantity=10, quantity=10,
reserved_quantity=0, reserved_quantity=0,
shop_id=test_shop.id # Add shop_id reference shop_id=test_shop.id, # Add shop_id reference
) )
db.add(stock) db.add(stock)
db.commit() db.commit()
@@ -208,7 +210,7 @@ def test_marketplace_job(db, test_shop, test_user): # Add test_shop dependency
updated_count=3, updated_count=3,
total_processed=8, total_processed=8,
error_count=0, error_count=0,
error_message=None error_message=None,
) )
db.add(job) db.add(job)
db.commit() db.commit()
@@ -219,16 +221,16 @@ def test_marketplace_job(db, test_shop, test_user): # Add test_shop dependency
def create_test_import_job(db, shop_id, **kwargs): # Add shop_id parameter def create_test_import_job(db, shop_id, **kwargs): # Add shop_id parameter
"""Helper function to create MarketplaceImportJob with defaults""" """Helper function to create MarketplaceImportJob with defaults"""
defaults = { defaults = {
'marketplace': 'test', "marketplace": "test",
'shop_name': 'Test Shop', "shop_name": "Test Shop",
'status': 'pending', "status": "pending",
'source_url': 'https://test.example.com/import', "source_url": "https://test.example.com/import",
'shop_id': shop_id, # Add required shop_id "shop_id": shop_id, # Add required shop_id
'imported_count': 0, "imported_count": 0,
'updated_count': 0, "updated_count": 0,
'total_processed': 0, "total_processed": 0,
'error_count': 0, "error_count": 0,
'error_message': None "error_message": None,
} }
defaults.update(kwargs) defaults.update(kwargs)
@@ -250,6 +252,7 @@ def cleanup():
# Add these fixtures to your existing conftest.py # Add these fixtures to your existing conftest.py
@pytest.fixture @pytest.fixture
def unique_product(db): def unique_product(db):
"""Create a unique product for tests that need isolated product data""" """Create a unique product for tests that need isolated product data"""
@@ -265,7 +268,7 @@ def unique_product(db):
availability="in stock", availability="in stock",
marketplace="Letzshop", marketplace="Letzshop",
shop_name=f"UniqueShop_{unique_id}", shop_name=f"UniqueShop_{unique_id}",
google_product_category=f"UniqueCategory_{unique_id}" google_product_category=f"UniqueCategory_{unique_id}",
) )
db.add(product) db.add(product)
db.commit() db.commit()
@@ -283,7 +286,7 @@ def unique_shop(db, test_user):
description=f"A unique test shop {unique_id}", description=f"A unique test shop {unique_id}",
owner_id=test_user.id, owner_id=test_user.id,
is_active=True, is_active=True,
is_verified=True is_verified=True,
) )
db.add(shop) db.add(shop)
db.commit() db.commit()
@@ -301,7 +304,7 @@ def other_user(db, auth_manager):
username=f"otheruser_{unique_id}", username=f"otheruser_{unique_id}",
hashed_password=hashed_password, hashed_password=hashed_password,
role="user", role="user",
is_active=True is_active=True,
) )
db.add(user) db.add(user)
db.commit() db.commit()
@@ -318,7 +321,7 @@ def inactive_shop(db, other_user):
shop_name=f"Inactive Shop {unique_id}", shop_name=f"Inactive Shop {unique_id}",
owner_id=other_user.id, owner_id=other_user.id,
is_active=False, is_active=False,
is_verified=False is_verified=False,
) )
db.add(shop) db.add(shop)
db.commit() db.commit()
@@ -335,7 +338,7 @@ def verified_shop(db, other_user):
shop_name=f"Verified Shop {unique_id}", shop_name=f"Verified Shop {unique_id}",
owner_id=other_user.id, owner_id=other_user.id,
is_active=True, is_active=True,
is_verified=True is_verified=True,
) )
db.add(shop) db.add(shop)
db.commit() db.commit()
@@ -347,16 +350,14 @@ def verified_shop(db, other_user):
def shop_product(db, test_shop, unique_product): def shop_product(db, test_shop, unique_product):
"""Create a shop product relationship""" """Create a shop product relationship"""
shop_product = ShopProduct( shop_product = ShopProduct(
shop_id=test_shop.id, shop_id=test_shop.id, product_id=unique_product.id, is_active=True
product_id=unique_product.id,
is_active=True
) )
# Add optional fields if they exist in your model # Add optional fields if they exist in your model
if hasattr(ShopProduct, 'price'): if hasattr(ShopProduct, "price"):
shop_product.price = "24.99" shop_product.price = "24.99"
if hasattr(ShopProduct, 'is_featured'): if hasattr(ShopProduct, "is_featured"):
shop_product.is_featured = False shop_product.is_featured = False
if hasattr(ShopProduct, 'stock_quantity'): if hasattr(ShopProduct, "stock_quantity"):
shop_product.stock_quantity = 10 shop_product.stock_quantity = 10
db.add(shop_product) db.add(shop_product)
@@ -382,7 +383,7 @@ def multiple_products(db):
marketplace=f"MultiMarket_{i % 2}", # Create 2 different marketplaces marketplace=f"MultiMarket_{i % 2}", # Create 2 different marketplaces
shop_name=f"MultiShop_{i}", shop_name=f"MultiShop_{i}",
google_product_category=f"MultiCategory_{i % 2}", # Create 2 different categories google_product_category=f"MultiCategory_{i % 2}", # Create 2 different categories
gtin=f"1234567890{i}{unique_id[:2]}" gtin=f"1234567890{i}{unique_id[:2]}",
) )
products.append(product) products.append(product)
@@ -404,7 +405,7 @@ def multiple_stocks(db, multiple_products, test_shop):
location=f"LOC_{i}", location=f"LOC_{i}",
quantity=10 + (i * 5), # Different quantities quantity=10 + (i * 5), # Different quantities
reserved_quantity=i, reserved_quantity=i,
shop_id=test_shop.id shop_id=test_shop.id,
) )
stocks.append(stock) stocks.append(stock)
@@ -422,12 +423,12 @@ def create_unique_product_factory():
def _create_product(db, **kwargs): def _create_product(db, **kwargs):
unique_id = str(uuid.uuid4())[:8] unique_id = str(uuid.uuid4())[:8]
defaults = { defaults = {
'product_id': f"FACTORY_{unique_id}", "product_id": f"FACTORY_{unique_id}",
'title': f"Factory Product {unique_id}", "title": f"Factory Product {unique_id}",
'price': "15.99", "price": "15.99",
'currency': "EUR", "currency": "EUR",
'marketplace': "TestMarket", "marketplace": "TestMarket",
'shop_name': "TestShop" "shop_name": "TestShop",
} }
defaults.update(kwargs) defaults.update(kwargs)
@@ -452,11 +453,11 @@ def create_unique_shop_factory():
def _create_shop(db, owner_id, **kwargs): def _create_shop(db, owner_id, **kwargs):
unique_id = str(uuid.uuid4())[:8] unique_id = str(uuid.uuid4())[:8]
defaults = { defaults = {
'shop_code': f"FACTORY_{unique_id}", "shop_code": f"FACTORY_{unique_id}",
'shop_name': f"Factory Shop {unique_id}", "shop_name": f"Factory Shop {unique_id}",
'owner_id': owner_id, "owner_id": owner_id,
'is_active': True, "is_active": True,
'is_verified': False "is_verified": False,
} }
defaults.update(kwargs) defaults.update(kwargs)
@@ -473,4 +474,3 @@ def create_unique_shop_factory():
def shop_factory(): def shop_factory():
"""Fixture that provides a shop factory function""" """Fixture that provides a shop factory function"""
return create_unique_shop_factory() return create_unique_shop_factory()

View File

@@ -20,11 +20,16 @@ class TestAdminAPI:
response = client.get("/api/v1/admin/users", headers=auth_headers) response = client.get("/api/v1/admin/users", headers=auth_headers)
assert response.status_code == 403 assert response.status_code == 403
assert "Access denied" in response.json()["detail"] or "admin" in response.json()["detail"].lower() assert (
"Access denied" in response.json()["detail"]
or "admin" in response.json()["detail"].lower()
)
def test_toggle_user_status_admin(self, client, admin_headers, test_user): def test_toggle_user_status_admin(self, client, admin_headers, test_user):
"""Test admin toggling user status""" """Test admin toggling user status"""
response = client.put(f"/api/v1/admin/users/{test_user.id}/status", headers=admin_headers) response = client.put(
f"/api/v1/admin/users/{test_user.id}/status", headers=admin_headers
)
assert response.status_code == 200 assert response.status_code == 200
message = response.json()["message"] message = response.json()["message"]
@@ -39,9 +44,13 @@ class TestAdminAPI:
assert response.status_code == 404 assert response.status_code == 404
assert "User not found" in response.json()["detail"] assert "User not found" in response.json()["detail"]
def test_toggle_user_status_cannot_deactivate_self(self, client, admin_headers, test_admin): def test_toggle_user_status_cannot_deactivate_self(
self, client, admin_headers, test_admin
):
"""Test that admin cannot deactivate their own account""" """Test that admin cannot deactivate their own account"""
response = client.put(f"/api/v1/admin/users/{test_admin.id}/status", headers=admin_headers) response = client.put(
f"/api/v1/admin/users/{test_admin.id}/status", headers=admin_headers
)
assert response.status_code == 400 assert response.status_code == 400
assert "Cannot deactivate your own account" in response.json()["detail"] assert "Cannot deactivate your own account" in response.json()["detail"]
@@ -56,7 +65,9 @@ class TestAdminAPI:
assert len(data["shops"]) >= 1 assert len(data["shops"]) >= 1
# Check that test_shop is in the response # Check that test_shop is in the response
shop_codes = [shop["shop_code"] for shop in data["shops"] if "shop_code" in shop] shop_codes = [
shop["shop_code"] for shop in data["shops"] if "shop_code" in shop
]
assert test_shop.shop_code in shop_codes assert test_shop.shop_code in shop_codes
def test_get_all_shops_non_admin(self, client, auth_headers): def test_get_all_shops_non_admin(self, client, auth_headers):
@@ -64,11 +75,16 @@ class TestAdminAPI:
response = client.get("/api/v1/admin/shops", headers=auth_headers) response = client.get("/api/v1/admin/shops", headers=auth_headers)
assert response.status_code == 403 assert response.status_code == 403
assert "Access denied" in response.json()["detail"] or "admin" in response.json()["detail"].lower() assert (
"Access denied" in response.json()["detail"]
or "admin" in response.json()["detail"].lower()
)
def test_verify_shop_admin(self, client, admin_headers, test_shop): def test_verify_shop_admin(self, client, admin_headers, test_shop):
"""Test admin verifying/unverifying shop""" """Test admin verifying/unverifying shop"""
response = client.put(f"/api/v1/admin/shops/{test_shop.id}/verify", headers=admin_headers) response = client.put(
f"/api/v1/admin/shops/{test_shop.id}/verify", headers=admin_headers
)
assert response.status_code == 200 assert response.status_code == 200
message = response.json()["message"] message = response.json()["message"]
@@ -84,7 +100,9 @@ class TestAdminAPI:
def test_toggle_shop_status_admin(self, client, admin_headers, test_shop): def test_toggle_shop_status_admin(self, client, admin_headers, test_shop):
"""Test admin toggling shop status""" """Test admin toggling shop status"""
response = client.put(f"/api/v1/admin/shops/{test_shop.id}/status", headers=admin_headers) response = client.put(
f"/api/v1/admin/shops/{test_shop.id}/status", headers=admin_headers
)
assert response.status_code == 200 assert response.status_code == 200
message = response.json()["message"] message = response.json()["message"]
@@ -98,9 +116,13 @@ class TestAdminAPI:
assert response.status_code == 404 assert response.status_code == 404
assert "Shop not found" in response.json()["detail"] assert "Shop not found" in response.json()["detail"]
def test_get_marketplace_import_jobs_admin(self, client, admin_headers, test_marketplace_job): def test_get_marketplace_import_jobs_admin(
self, client, admin_headers, test_marketplace_job
):
"""Test admin getting marketplace import jobs""" """Test admin getting marketplace import jobs"""
response = client.get("/api/v1/admin/marketplace-import-jobs", headers=admin_headers) response = client.get(
"/api/v1/admin/marketplace-import-jobs", headers=admin_headers
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@@ -110,43 +132,58 @@ class TestAdminAPI:
job_ids = [job["job_id"] for job in data if "job_id" in job] job_ids = [job["job_id"] for job in data if "job_id" in job]
assert test_marketplace_job.id in job_ids assert test_marketplace_job.id in job_ids
def test_get_marketplace_import_jobs_with_filters(self, client, admin_headers, test_marketplace_job): def test_get_marketplace_import_jobs_with_filters(
self, client, admin_headers, test_marketplace_job
):
"""Test admin getting marketplace import jobs with filters""" """Test admin getting marketplace import jobs with filters"""
response = client.get( response = client.get(
"/api/v1/admin/marketplace-import-jobs", "/api/v1/admin/marketplace-import-jobs",
params={"marketplace": test_marketplace_job.marketplace}, params={"marketplace": test_marketplace_job.marketplace},
headers=admin_headers headers=admin_headers,
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data) >= 1 assert len(data) >= 1
assert all(job["marketplace"] == test_marketplace_job.marketplace for job in data) assert all(
job["marketplace"] == test_marketplace_job.marketplace for job in data
)
def test_get_marketplace_import_jobs_non_admin(self, client, auth_headers): def test_get_marketplace_import_jobs_non_admin(self, client, auth_headers):
"""Test non-admin trying to access marketplace import jobs""" """Test non-admin trying to access marketplace import jobs"""
response = client.get("/api/v1/admin/marketplace-import-jobs", headers=auth_headers) response = client.get(
"/api/v1/admin/marketplace-import-jobs", headers=auth_headers
)
assert response.status_code == 403 assert response.status_code == 403
assert "Access denied" in response.json()["detail"] or "admin" in response.json()["detail"].lower() assert (
"Access denied" in response.json()["detail"]
or "admin" in response.json()["detail"].lower()
)
def test_admin_pagination_users(self, client, admin_headers, test_user, test_admin): def test_admin_pagination_users(self, client, admin_headers, test_user, test_admin):
"""Test user pagination works correctly""" """Test user pagination works correctly"""
# Test first page # Test first page
response = client.get("/api/v1/admin/users?skip=0&limit=1", headers=admin_headers) response = client.get(
"/api/v1/admin/users?skip=0&limit=1", headers=admin_headers
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data) == 1 assert len(data) == 1
# Test second page # Test second page
response = client.get("/api/v1/admin/users?skip=1&limit=1", headers=admin_headers) response = client.get(
"/api/v1/admin/users?skip=1&limit=1", headers=admin_headers
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data) >= 0 # Could be 1 or 0 depending on total users assert len(data) >= 0 # Could be 1 or 0 depending on total users
def test_admin_pagination_shops(self, client, admin_headers, test_shop): def test_admin_pagination_shops(self, client, admin_headers, test_shop):
"""Test shop pagination works correctly""" """Test shop pagination works correctly"""
response = client.get("/api/v1/admin/shops?skip=0&limit=1", headers=admin_headers) response = client.get(
"/api/v1/admin/shops?skip=0&limit=1", headers=admin_headers
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["total"] >= 1 assert data["total"] >= 1

View File

@@ -1,10 +1,11 @@
# tests/test_admin_service.py # tests/test_admin_service.py
import pytest
from datetime import datetime from datetime import datetime
import pytest
from fastapi import HTTPException from fastapi import HTTPException
from app.services.admin_service import AdminService from app.services.admin_service import AdminService
from models.database_models import User, Shop, MarketplaceImportJob from models.database_models import MarketplaceImportJob, Shop, User
class TestAdminService: class TestAdminService:
@@ -91,7 +92,7 @@ class TestAdminService:
shop_name="Test Shop 2", shop_name="Test Shop 2",
owner_id=test_shop.owner_id, owner_id=test_shop.owner_id,
is_active=True, is_active=True,
is_verified=False is_verified=False,
) )
db.add(additional_shop) db.add(additional_shop)
db.commit() db.commit()
@@ -173,13 +174,17 @@ class TestAdminService:
assert len(result) >= 1 assert len(result) >= 1
# Find our test job in the results # Find our test job in the results
test_job = next((job for job in result if job.job_id == test_marketplace_job.id), None) test_job = next(
(job for job in result if job.job_id == test_marketplace_job.id), None
)
assert test_job is not None assert test_job is not None
assert test_job.marketplace == test_marketplace_job.marketplace assert test_job.marketplace == test_marketplace_job.marketplace
assert test_job.shop_name == test_marketplace_job.shop_name assert test_job.shop_name == test_marketplace_job.shop_name
assert test_job.status == test_marketplace_job.status assert test_job.status == test_marketplace_job.status
def test_get_marketplace_import_jobs_with_marketplace_filter(self, db, test_marketplace_job, test_user, test_shop): def test_get_marketplace_import_jobs_with_marketplace_filter(
self, db, test_marketplace_job, test_user, test_shop
):
"""Test getting marketplace import jobs filtered by marketplace""" """Test getting marketplace import jobs filtered by marketplace"""
# Create additional job with different marketplace # Create additional job with different marketplace
other_job = MarketplaceImportJob( other_job = MarketplaceImportJob(
@@ -188,20 +193,24 @@ class TestAdminService:
status="completed", status="completed",
source_url="https://ebay.example.com/import", source_url="https://ebay.example.com/import",
shop_id=test_shop.id, shop_id=test_shop.id,
user_id=test_user.id # Fixed: Added missing user_id user_id=test_user.id, # Fixed: Added missing user_id
) )
db.add(other_job) db.add(other_job)
db.commit() db.commit()
# Filter by the test marketplace job's marketplace # Filter by the test marketplace job's marketplace
result = self.service.get_marketplace_import_jobs(db, marketplace=test_marketplace_job.marketplace) result = self.service.get_marketplace_import_jobs(
db, marketplace=test_marketplace_job.marketplace
)
assert len(result) >= 1 assert len(result) >= 1
# All results should match the marketplace filter # All results should match the marketplace filter
for job in result: for job in result:
assert test_marketplace_job.marketplace.lower() in job.marketplace.lower() assert test_marketplace_job.marketplace.lower() in job.marketplace.lower()
def test_get_marketplace_import_jobs_with_shop_filter(self, db, test_marketplace_job, test_user, test_shop): def test_get_marketplace_import_jobs_with_shop_filter(
self, db, test_marketplace_job, test_user, test_shop
):
"""Test getting marketplace import jobs filtered by shop name""" """Test getting marketplace import jobs filtered by shop name"""
# Create additional job with different shop name # Create additional job with different shop name
other_job = MarketplaceImportJob( other_job = MarketplaceImportJob(
@@ -210,20 +219,24 @@ class TestAdminService:
status="completed", status="completed",
source_url="https://different.example.com/import", source_url="https://different.example.com/import",
shop_id=test_shop.id, shop_id=test_shop.id,
user_id=test_user.id # Fixed: Added missing user_id user_id=test_user.id, # Fixed: Added missing user_id
) )
db.add(other_job) db.add(other_job)
db.commit() db.commit()
# Filter by the test marketplace job's shop name # Filter by the test marketplace job's shop name
result = self.service.get_marketplace_import_jobs(db, shop_name=test_marketplace_job.shop_name) result = self.service.get_marketplace_import_jobs(
db, shop_name=test_marketplace_job.shop_name
)
assert len(result) >= 1 assert len(result) >= 1
# All results should match the shop name filter # All results should match the shop name filter
for job in result: for job in result:
assert test_marketplace_job.shop_name.lower() in job.shop_name.lower() assert test_marketplace_job.shop_name.lower() in job.shop_name.lower()
def test_get_marketplace_import_jobs_with_status_filter(self, db, test_marketplace_job, test_user, test_shop): def test_get_marketplace_import_jobs_with_status_filter(
self, db, test_marketplace_job, test_user, test_shop
):
"""Test getting marketplace import jobs filtered by status""" """Test getting marketplace import jobs filtered by status"""
# Create additional job with different status # Create additional job with different status
other_job = MarketplaceImportJob( other_job = MarketplaceImportJob(
@@ -232,20 +245,24 @@ class TestAdminService:
status="pending", status="pending",
source_url="https://pending.example.com/import", source_url="https://pending.example.com/import",
shop_id=test_shop.id, shop_id=test_shop.id,
user_id=test_user.id # Fixed: Added missing user_id user_id=test_user.id, # Fixed: Added missing user_id
) )
db.add(other_job) db.add(other_job)
db.commit() db.commit()
# Filter by the test marketplace job's status # Filter by the test marketplace job's status
result = self.service.get_marketplace_import_jobs(db, status=test_marketplace_job.status) result = self.service.get_marketplace_import_jobs(
db, status=test_marketplace_job.status
)
assert len(result) >= 1 assert len(result) >= 1
# All results should match the status filter # All results should match the status filter
for job in result: for job in result:
assert job.status == test_marketplace_job.status assert job.status == test_marketplace_job.status
def test_get_marketplace_import_jobs_with_multiple_filters(self, db, test_marketplace_job, test_shop, test_user): def test_get_marketplace_import_jobs_with_multiple_filters(
self, db, test_marketplace_job, test_shop, test_user
):
"""Test getting marketplace import jobs with multiple filters""" """Test getting marketplace import jobs with multiple filters"""
# Create jobs that don't match all filters # Create jobs that don't match all filters
non_matching_job1 = MarketplaceImportJob( non_matching_job1 = MarketplaceImportJob(
@@ -254,7 +271,7 @@ class TestAdminService:
status=test_marketplace_job.status, status=test_marketplace_job.status,
source_url="https://non-matching1.example.com/import", source_url="https://non-matching1.example.com/import",
shop_id=test_shop.id, shop_id=test_shop.id,
user_id=test_user.id # Fixed: Added missing user_id user_id=test_user.id, # Fixed: Added missing user_id
) )
non_matching_job2 = MarketplaceImportJob( non_matching_job2 = MarketplaceImportJob(
marketplace=test_marketplace_job.marketplace, marketplace=test_marketplace_job.marketplace,
@@ -262,7 +279,7 @@ class TestAdminService:
status=test_marketplace_job.status, status=test_marketplace_job.status,
source_url="https://non-matching2.example.com/import", source_url="https://non-matching2.example.com/import",
shop_id=test_shop.id, shop_id=test_shop.id,
user_id=test_user.id # Fixed: Added missing user_id user_id=test_user.id, # Fixed: Added missing user_id
) )
db.add_all([non_matching_job1, non_matching_job2]) db.add_all([non_matching_job1, non_matching_job2])
db.commit() db.commit()
@@ -272,12 +289,14 @@ class TestAdminService:
db, db,
marketplace=test_marketplace_job.marketplace, marketplace=test_marketplace_job.marketplace,
shop_name=test_marketplace_job.shop_name, shop_name=test_marketplace_job.shop_name,
status=test_marketplace_job.status status=test_marketplace_job.status,
) )
assert len(result) >= 1 assert len(result) >= 1
# Find our test job in the results # Find our test job in the results
test_job = next((job for job in result if job.job_id == test_marketplace_job.id), None) test_job = next(
(job for job in result if job.job_id == test_marketplace_job.id), None
)
assert test_job is not None assert test_job is not None
assert test_job.marketplace == test_marketplace_job.marketplace assert test_job.marketplace == test_marketplace_job.marketplace
assert test_job.shop_name == test_marketplace_job.shop_name assert test_job.shop_name == test_marketplace_job.shop_name
@@ -297,7 +316,7 @@ class TestAdminService:
updated_count=None, updated_count=None,
total_processed=None, total_processed=None,
error_count=None, error_count=None,
error_message=None error_message=None,
) )
db.add(job) db.add(job)
db.commit() db.commit()

View File

@@ -6,11 +6,14 @@ from fastapi import HTTPException
class TestAuthenticationAPI: class TestAuthenticationAPI:
def test_register_user_success(self, client, db): def test_register_user_success(self, client, db):
"""Test successful user registration""" """Test successful user registration"""
response = client.post("/api/v1/auth/register", json={ response = client.post(
"email": "newuser@example.com", "/api/v1/auth/register",
"username": "newuser", json={
"password": "securepass123" "email": "newuser@example.com",
}) "username": "newuser",
"password": "securepass123",
},
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@@ -22,32 +25,38 @@ class TestAuthenticationAPI:
def test_register_user_duplicate_email(self, client, test_user): def test_register_user_duplicate_email(self, client, test_user):
"""Test registration with duplicate email""" """Test registration with duplicate email"""
response = client.post("/api/v1/auth/register", json={ response = client.post(
"email": test_user.email, # Same as test_user "/api/v1/auth/register",
"username": "newuser", json={
"password": "securepass123" "email": test_user.email, # Same as test_user
}) "username": "newuser",
"password": "securepass123",
},
)
assert response.status_code == 400 assert response.status_code == 400
assert "Email already registered" in response.json()["detail"] assert "Email already registered" in response.json()["detail"]
def test_register_user_duplicate_username(self, client, test_user): def test_register_user_duplicate_username(self, client, test_user):
"""Test registration with duplicate username""" """Test registration with duplicate username"""
response = client.post("/api/v1/auth/register", json={ response = client.post(
"email": "new@example.com", "/api/v1/auth/register",
"username": test_user.username, # Same as test_user json={
"password": "securepass123" "email": "new@example.com",
}) "username": test_user.username, # Same as test_user
"password": "securepass123",
},
)
assert response.status_code == 400 assert response.status_code == 400
assert "Username already taken" in response.json()["detail"] assert "Username already taken" in response.json()["detail"]
def test_login_success(self, client, test_user): def test_login_success(self, client, test_user):
"""Test successful login""" """Test successful login"""
response = client.post("/api/v1/auth/login", json={ response = client.post(
"username": test_user.username, "/api/v1/auth/login",
"password": "testpass123" json={"username": test_user.username, "password": "testpass123"},
}) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@@ -58,20 +67,20 @@ class TestAuthenticationAPI:
def test_login_wrong_password(self, client, test_user): def test_login_wrong_password(self, client, test_user):
"""Test login with wrong password""" """Test login with wrong password"""
response = client.post("/api/v1/auth/login", json={ response = client.post(
"username": "testuser", "/api/v1/auth/login",
"password": "wrongpassword" json={"username": "testuser", "password": "wrongpassword"},
}) )
assert response.status_code == 401 assert response.status_code == 401
assert "Incorrect username or password" in response.json()["detail"] assert "Incorrect username or password" in response.json()["detail"]
def test_login_nonexistent_user(self, client, db): # Added db fixture def test_login_nonexistent_user(self, client, db): # Added db fixture
"""Test login with nonexistent user""" """Test login with nonexistent user"""
response = client.post("/api/v1/auth/login", json={ response = client.post(
"username": "nonexistent", "/api/v1/auth/login",
"password": "password123" json={"username": "nonexistent", "password": "password123"},
}) )
assert response.status_code == 401 assert response.status_code == 401

View File

@@ -3,8 +3,8 @@ import pytest
from fastapi import HTTPException from fastapi import HTTPException
from app.services.auth_service import AuthService from app.services.auth_service import AuthService
from models.api_models import UserLogin, UserRegister
from models.database_models import User from models.database_models import User
from models.api_models import UserRegister, UserLogin
class TestAuthService: class TestAuthService:
@@ -17,9 +17,7 @@ class TestAuthService:
def test_register_user_success(self, db): def test_register_user_success(self, db):
"""Test successful user registration""" """Test successful user registration"""
user_data = UserRegister( user_data = UserRegister(
email="newuser@example.com", email="newuser@example.com", username="newuser123", password="securepass123"
username="newuser123",
password="securepass123"
) )
user = self.service.register_user(db, user_data) user = self.service.register_user(db, user_data)
@@ -36,7 +34,7 @@ class TestAuthService:
user_data = UserRegister( user_data = UserRegister(
email=test_user.email, # Use existing email email=test_user.email, # Use existing email
username="differentuser", username="differentuser",
password="securepass123" password="securepass123",
) )
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
@@ -50,7 +48,7 @@ class TestAuthService:
user_data = UserRegister( user_data = UserRegister(
email="different@example.com", email="different@example.com",
username=test_user.username, # Use existing username username=test_user.username, # Use existing username
password="securepass123" password="securepass123",
) )
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
@@ -62,8 +60,7 @@ class TestAuthService:
def test_login_user_success(self, db, test_user): def test_login_user_success(self, db, test_user):
"""Test successful user login""" """Test successful user login"""
user_credentials = UserLogin( user_credentials = UserLogin(
username=test_user.username, username=test_user.username, password="testpass123"
password="testpass123"
) )
result = self.service.login_user(db, user_credentials) result = self.service.login_user(db, user_credentials)
@@ -78,10 +75,7 @@ class TestAuthService:
def test_login_user_wrong_username(self, db): def test_login_user_wrong_username(self, db):
"""Test login fails with wrong username""" """Test login fails with wrong username"""
user_credentials = UserLogin( user_credentials = UserLogin(username="nonexistentuser", password="testpass123")
username="nonexistentuser",
password="testpass123"
)
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
self.service.login_user(db, user_credentials) self.service.login_user(db, user_credentials)
@@ -92,8 +86,7 @@ class TestAuthService:
def test_login_user_wrong_password(self, db, test_user): def test_login_user_wrong_password(self, db, test_user):
"""Test login fails with wrong password""" """Test login fails with wrong password"""
user_credentials = UserLogin( user_credentials = UserLogin(
username=test_user.username, username=test_user.username, password="wrongpassword"
password="wrongpassword"
) )
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
@@ -109,8 +102,7 @@ class TestAuthService:
db.commit() db.commit()
user_credentials = UserLogin( user_credentials = UserLogin(
username=test_user.username, username=test_user.username, password="testpass123"
password="testpass123"
) )
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:

View File

@@ -1,9 +1,11 @@
# tests/test_background_tasks.py # tests/test_background_tasks.py
from datetime import datetime
from unittest.mock import AsyncMock, patch
import pytest import pytest
from unittest.mock import patch, AsyncMock
from app.tasks.background_tasks import process_marketplace_import from app.tasks.background_tasks import process_marketplace_import
from models.database_models import MarketplaceImportJob from models.database_models import MarketplaceImportJob
from datetime import datetime
class TestBackgroundTasks: class TestBackgroundTasks:
@@ -17,7 +19,7 @@ class TestBackgroundTasks:
shop_name="TESTSHOP", shop_name="TESTSHOP",
marketplace="TestMarket", marketplace="TestMarket",
shop_id=test_shop.id, shop_id=test_shop.id,
user_id=test_user.id user_id=test_user.id,
) )
db.add(job) db.add(job)
db.commit() db.commit()
@@ -27,29 +29,30 @@ class TestBackgroundTasks:
job_id = job.id job_id = job.id
# Mock CSV processor and prevent session from closing # Mock CSV processor and prevent session from closing
with patch('app.tasks.background_tasks.CSVProcessor') as mock_processor, \ with patch("app.tasks.background_tasks.CSVProcessor") as mock_processor, patch(
patch('app.tasks.background_tasks.SessionLocal', return_value=db): "app.tasks.background_tasks.SessionLocal", return_value=db
):
mock_instance = mock_processor.return_value mock_instance = mock_processor.return_value
mock_instance.process_marketplace_csv_from_url = AsyncMock(return_value={ mock_instance.process_marketplace_csv_from_url = AsyncMock(
"imported": 10, return_value={
"updated": 5, "imported": 10,
"total_processed": 15, "updated": 5,
"errors": 0 "total_processed": 15,
}) "errors": 0,
}
)
# Run background task # Run background task
await process_marketplace_import( await process_marketplace_import(
job_id, job_id, "http://example.com/test.csv", "TestMarket", "TESTSHOP", 1000
"http://example.com/test.csv",
"TestMarket",
"TESTSHOP",
1000
) )
# Re-query the job using the stored ID # Re-query the job using the stored ID
updated_job = db.query(MarketplaceImportJob).filter( updated_job = (
MarketplaceImportJob.id == job_id db.query(MarketplaceImportJob)
).first() .filter(MarketplaceImportJob.id == job_id)
.first()
)
assert updated_job is not None assert updated_job is not None
assert updated_job.status == "completed" assert updated_job.status == "completed"
@@ -70,7 +73,7 @@ class TestBackgroundTasks:
shop_name="TESTSHOP", shop_name="TESTSHOP",
marketplace="TestMarket", marketplace="TestMarket",
shop_id=test_shop.id, shop_id=test_shop.id,
user_id=test_user.id user_id=test_user.id,
) )
db.add(job) db.add(job)
db.commit() db.commit()
@@ -80,8 +83,9 @@ class TestBackgroundTasks:
job_id = job.id job_id = job.id
# Mock CSV processor to raise exception # Mock CSV processor to raise exception
with patch('app.tasks.background_tasks.CSVProcessor') as mock_processor, \ with patch("app.tasks.background_tasks.CSVProcessor") as mock_processor, patch(
patch('app.tasks.background_tasks.SessionLocal', return_value=db): "app.tasks.background_tasks.SessionLocal", return_value=db
):
mock_instance = mock_processor.return_value mock_instance = mock_processor.return_value
mock_instance.process_marketplace_csv_from_url = AsyncMock( mock_instance.process_marketplace_csv_from_url = AsyncMock(
@@ -96,7 +100,7 @@ class TestBackgroundTasks:
"http://example.com/test.csv", "http://example.com/test.csv",
"TestMarket", "TestMarket",
"TESTSHOP", "TESTSHOP",
1000 1000,
) )
except Exception: except Exception:
# The background task should handle exceptions internally # The background task should handle exceptions internally
@@ -104,9 +108,11 @@ class TestBackgroundTasks:
pass pass
# Re-query the job using the stored ID # Re-query the job using the stored ID
updated_job = db.query(MarketplaceImportJob).filter( updated_job = (
MarketplaceImportJob.id == job_id db.query(MarketplaceImportJob)
).first() .filter(MarketplaceImportJob.id == job_id)
.first()
)
assert updated_job is not None assert updated_job is not None
assert updated_job.status == "failed" assert updated_job.status == "failed"
@@ -115,15 +121,18 @@ class TestBackgroundTasks:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_marketplace_import_job_not_found(self, db): async def test_marketplace_import_job_not_found(self, db):
"""Test handling when import job doesn't exist""" """Test handling when import job doesn't exist"""
with patch('app.tasks.background_tasks.CSVProcessor') as mock_processor, \ with patch("app.tasks.background_tasks.CSVProcessor") as mock_processor, patch(
patch('app.tasks.background_tasks.SessionLocal', return_value=db): "app.tasks.background_tasks.SessionLocal", return_value=db
):
mock_instance = mock_processor.return_value mock_instance = mock_processor.return_value
mock_instance.process_marketplace_csv_from_url = AsyncMock(return_value={ mock_instance.process_marketplace_csv_from_url = AsyncMock(
"imported": 10, return_value={
"updated": 5, "imported": 10,
"total_processed": 15, "updated": 5,
"errors": 0 "total_processed": 15,
}) "errors": 0,
}
)
# Run background task with non-existent job ID # Run background task with non-existent job ID
await process_marketplace_import( await process_marketplace_import(
@@ -131,7 +140,7 @@ class TestBackgroundTasks:
"http://example.com/test.csv", "http://example.com/test.csv",
"TestMarket", "TestMarket",
"TESTSHOP", "TESTSHOP",
1000 1000,
) )
# Should not raise an exception, just log and return # Should not raise an exception, just log and return
@@ -148,7 +157,7 @@ class TestBackgroundTasks:
shop_name="TESTSHOP", shop_name="TESTSHOP",
marketplace="TestMarket", marketplace="TestMarket",
shop_id=test_shop.id, shop_id=test_shop.id,
user_id=test_user.id user_id=test_user.id,
) )
db.add(job) db.add(job)
db.commit() db.commit()
@@ -158,29 +167,30 @@ class TestBackgroundTasks:
job_id = job.id job_id = job.id
# Mock CSV processor with some errors # Mock CSV processor with some errors
with patch('app.tasks.background_tasks.CSVProcessor') as mock_processor, \ with patch("app.tasks.background_tasks.CSVProcessor") as mock_processor, patch(
patch('app.tasks.background_tasks.SessionLocal', return_value=db): "app.tasks.background_tasks.SessionLocal", return_value=db
):
mock_instance = mock_processor.return_value mock_instance = mock_processor.return_value
mock_instance.process_marketplace_csv_from_url = AsyncMock(return_value={ mock_instance.process_marketplace_csv_from_url = AsyncMock(
"imported": 8, return_value={
"updated": 5, "imported": 8,
"total_processed": 15, "updated": 5,
"errors": 2 "total_processed": 15,
}) "errors": 2,
}
)
# Run background task # Run background task
await process_marketplace_import( await process_marketplace_import(
job_id, job_id, "http://example.com/test.csv", "TestMarket", "TESTSHOP", 1000
"http://example.com/test.csv",
"TestMarket",
"TESTSHOP",
1000
) )
# Re-query the job using the stored ID # Re-query the job using the stored ID
updated_job = db.query(MarketplaceImportJob).filter( updated_job = (
MarketplaceImportJob.id == job_id db.query(MarketplaceImportJob)
).first() .filter(MarketplaceImportJob.id == job_id)
.first()
)
assert updated_job is not None assert updated_job is not None
assert updated_job.status == "completed_with_errors" assert updated_job.status == "completed_with_errors"

View File

@@ -1,9 +1,11 @@
# tests/test_csv_processor.py # tests/test_csv_processor.py
from unittest.mock import Mock, patch
import pandas as pd
import pytest import pytest
import requests import requests
import requests.exceptions import requests.exceptions
from unittest.mock import Mock, patch
import pandas as pd
from utils.csv_processor import CSVProcessor from utils.csv_processor import CSVProcessor
@@ -11,7 +13,7 @@ class TestCSVProcessor:
def setup_method(self): def setup_method(self):
self.processor = CSVProcessor() self.processor = CSVProcessor()
@patch('requests.get') @patch("requests.get")
def test_download_csv_encoding_fallback(self, mock_get): def test_download_csv_encoding_fallback(self, mock_get):
"""Test CSV download with encoding fallback""" """Test CSV download with encoding fallback"""
# Create content with special characters that would fail UTF-8 if not properly encoded # Create content with special characters that would fail UTF-8 if not properly encoded
@@ -20,7 +22,7 @@ class TestCSVProcessor:
mock_response = Mock() mock_response = Mock()
mock_response.status_code = 200 mock_response.status_code = 200
# Use latin-1 encoding which your method should try # Use latin-1 encoding which your method should try
mock_response.content = special_content.encode('latin-1') mock_response.content = special_content.encode("latin-1")
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response mock_get.return_value = mock_response
@@ -30,14 +32,16 @@ class TestCSVProcessor:
assert isinstance(csv_content, str) assert isinstance(csv_content, str)
assert "Café Product" in csv_content assert "Café Product" in csv_content
@patch('requests.get') @patch("requests.get")
def test_download_csv_encoding_ignore_fallback(self, mock_get): def test_download_csv_encoding_ignore_fallback(self, mock_get):
"""Test CSV download falls back to UTF-8 with error ignoring""" """Test CSV download falls back to UTF-8 with error ignoring"""
# Create problematic bytes that would fail most encoding attempts # Create problematic bytes that would fail most encoding attempts
mock_response = Mock() mock_response = Mock()
mock_response.status_code = 200 mock_response.status_code = 200
# Create bytes that will fail most encodings # Create bytes that will fail most encodings
mock_response.content = b"product_id,title,price\nTEST001,\xff\xfe Product,10.99" mock_response.content = (
b"product_id,title,price\nTEST001,\xff\xfe Product,10.99"
)
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response mock_get.return_value = mock_response
@@ -49,7 +53,7 @@ class TestCSVProcessor:
assert "product_id,title,price" in csv_content assert "product_id,title,price" in csv_content
assert "TEST001" in csv_content assert "TEST001" in csv_content
@patch('requests.get') @patch("requests.get")
def test_download_csv_request_exception(self, mock_get): def test_download_csv_request_exception(self, mock_get):
"""Test CSV download with request exception""" """Test CSV download with request exception"""
mock_get.side_effect = requests.exceptions.RequestException("Connection error") mock_get.side_effect = requests.exceptions.RequestException("Connection error")
@@ -57,18 +61,20 @@ class TestCSVProcessor:
with pytest.raises(requests.exceptions.RequestException): with pytest.raises(requests.exceptions.RequestException):
self.processor.download_csv("http://example.com/test.csv") self.processor.download_csv("http://example.com/test.csv")
@patch('requests.get') @patch("requests.get")
def test_download_csv_http_error(self, mock_get): def test_download_csv_http_error(self, mock_get):
"""Test CSV download with HTTP error""" """Test CSV download with HTTP error"""
mock_response = Mock() mock_response = Mock()
mock_response.status_code = 404 mock_response.status_code = 404
mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("404 Not Found") mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError(
"404 Not Found"
)
mock_get.return_value = mock_response mock_get.return_value = mock_response
with pytest.raises(requests.exceptions.HTTPError): with pytest.raises(requests.exceptions.HTTPError):
self.processor.download_csv("http://example.com/nonexistent.csv") self.processor.download_csv("http://example.com/nonexistent.csv")
@patch('requests.get') @patch("requests.get")
def test_download_csv_failure(self, mock_get): def test_download_csv_failure(self, mock_get):
"""Test CSV download failure""" """Test CSV download failure"""
# Mock failed HTTP response # Mock failed HTTP response
@@ -95,26 +101,24 @@ TEST002,Test Product 2,15.99,TestMarket"""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_process_marketplace_csv_from_url(self, db): async def test_process_marketplace_csv_from_url(self, db):
"""Test complete marketplace CSV processing""" """Test complete marketplace CSV processing"""
with patch.object(self.processor, 'download_csv') as mock_download, \ with patch.object(
patch.object(self.processor, 'parse_csv') as mock_parse: self.processor, "download_csv"
) as mock_download, patch.object(self.processor, "parse_csv") as mock_parse:
# Mock successful download and parsing # Mock successful download and parsing
mock_download.return_value = "csv_content" mock_download.return_value = "csv_content"
mock_df = pd.DataFrame({ mock_df = pd.DataFrame(
"product_id": ["TEST001", "TEST002"], {
"title": ["Product 1", "Product 2"], "product_id": ["TEST001", "TEST002"],
"price": ["10.99", "15.99"], "title": ["Product 1", "Product 2"],
"marketplace": ["TestMarket", "TestMarket"], "price": ["10.99", "15.99"],
"shop_name": ["TestShop", "TestShop"] "marketplace": ["TestMarket", "TestMarket"],
}) "shop_name": ["TestShop", "TestShop"],
}
)
mock_parse.return_value = mock_df mock_parse.return_value = mock_df
result = await self.processor.process_marketplace_csv_from_url( result = await self.processor.process_marketplace_csv_from_url(
"http://example.com/test.csv", "http://example.com/test.csv", "TestMarket", "TestShop", 1000, db
"TestMarket",
"TestShop",
1000,
db
) )
assert "imported" in result assert "imported" in result

View File

@@ -1,5 +1,6 @@
# tests/test_data_validation.py # tests/test_data_validation.py
import pytest import pytest
from utils.data_processing import GTINProcessor, PriceProcessor from utils.data_processing import GTINProcessor, PriceProcessor

View File

@@ -1,7 +1,8 @@
# tests/test_database.py # tests/test_database.py
import pytest import pytest
from sqlalchemy import text from sqlalchemy import text
from models.database_models import User, Product, Stock, Shop
from models.database_models import Product, Shop, Stock, User
class TestDatabaseModels: class TestDatabaseModels:
@@ -12,7 +13,7 @@ class TestDatabaseModels:
username="dbtest", username="dbtest",
hashed_password="hashed_password_123", hashed_password="hashed_password_123",
role="user", role="user",
is_active=True is_active=True,
) )
db.add(user) db.add(user)
@@ -36,7 +37,7 @@ class TestDatabaseModels:
gtin="1234567890123", gtin="1234567890123",
availability="in stock", availability="in stock",
marketplace="TestDB", marketplace="TestDB",
shop_name="DBTestShop" shop_name="DBTestShop",
) )
db.add(product) db.add(product)
@@ -49,11 +50,7 @@ class TestDatabaseModels:
def test_stock_model(self, db): def test_stock_model(self, db):
"""Test Stock model creation""" """Test Stock model creation"""
stock = Stock( stock = Stock(gtin="1234567890123", location="DB_WAREHOUSE", quantity=150)
gtin="1234567890123",
location="DB_WAREHOUSE",
quantity=150
)
db.add(stock) db.add(stock)
db.commit() db.commit()
@@ -72,7 +69,7 @@ class TestDatabaseModels:
description="Testing shop model", description="Testing shop model",
owner_id=test_user.id, owner_id=test_user.id,
is_active=True, is_active=True,
is_verified=False is_verified=False,
) )
db.add(shop) db.add(shop)

View File

@@ -5,24 +5,25 @@ import pytest
class TestErrorHandling: class TestErrorHandling:
def test_invalid_json(self, client, auth_headers): def test_invalid_json(self, client, auth_headers):
"""Test handling of invalid JSON""" """Test handling of invalid JSON"""
response = client.post("/api/v1/product", response = client.post(
headers=auth_headers, "/api/v1/product", headers=auth_headers, content="invalid json"
content="invalid json") )
assert response.status_code == 422 # Validation error assert response.status_code == 422 # Validation error
def test_missing_required_fields(self, client, auth_headers): def test_missing_required_fields(self, client, auth_headers):
"""Test handling of missing required fields""" """Test handling of missing required fields"""
response = client.post("/api/v1/product", response = client.post(
headers=auth_headers, "/api/v1/product", headers=auth_headers, json={"title": "Test"}
json={"title": "Test"}) # Missing product_id ) # Missing product_id
assert response.status_code == 422 assert response.status_code == 422
def test_invalid_authentication(self, client): def test_invalid_authentication(self, client):
"""Test handling of invalid authentication""" """Test handling of invalid authentication"""
response = client.get("/api/v1/product", response = client.get(
headers={"Authorization": "Bearer invalid_token"}) "/api/v1/product", headers={"Authorization": "Bearer invalid_token"}
)
assert response.status_code == 401 # Token is not valid assert response.status_code == 401 # Token is not valid
@@ -38,8 +39,10 @@ class TestErrorHandling:
"""Test handling of duplicate resource creation""" """Test handling of duplicate resource creation"""
product_data = { product_data = {
"product_id": test_product.product_id, # Duplicate ID "product_id": test_product.product_id, # Duplicate ID
"title": "Another Product" "title": "Another Product",
} }
response = client.post("/api/v1/product", headers=auth_headers, json=product_data) response = client.post(
"/api/v1/product", headers=auth_headers, json=product_data
)
assert response.status_code == 400 assert response.status_code == 400

View File

@@ -1,8 +1,9 @@
# tests/test_export.py # tests/test_export.py
import pytest
import csv import csv
from io import StringIO from io import StringIO
import pytest
from models.database_models import Product from models.database_models import Product
@@ -15,7 +16,7 @@ class TestExportFunctionality:
assert response.headers["content-type"] == "text/csv; charset=utf-8" assert response.headers["content-type"] == "text/csv; charset=utf-8"
# Parse CSV content # Parse CSV content
csv_content = response.content.decode('utf-8') csv_content = response.content.decode("utf-8")
csv_reader = csv.reader(StringIO(csv_content)) csv_reader = csv.reader(StringIO(csv_content))
# Check header row # Check header row
@@ -35,10 +36,12 @@ class TestExportFunctionality:
db.add_all(products) db.add_all(products)
db.commit() db.commit()
response = client.get("/api/v1/export-csv?marketplace=Amazon", headers=auth_headers) response = client.get(
"/api/v1/export-csv?marketplace=Amazon", headers=auth_headers
)
assert response.status_code == 200 assert response.status_code == 200
csv_content = response.content.decode('utf-8') csv_content = response.content.decode("utf-8")
assert "EXP1" in csv_content assert "EXP1" in csv_content
assert "EXP2" not in csv_content # Should be filtered out assert "EXP2" not in csv_content # Should be filtered out
@@ -50,7 +53,7 @@ class TestExportFunctionality:
product = Product( product = Product(
product_id=f"PERF{i:04d}", product_id=f"PERF{i:04d}",
title=f"Performance Product {i}", title=f"Performance Product {i}",
marketplace="Performance" marketplace="Performance",
) )
products.append(product) products.append(product)
@@ -58,10 +61,10 @@ class TestExportFunctionality:
db.commit() db.commit()
import time import time
start_time = time.time() start_time = time.time()
response = client.get("/api/v1/export-csv", headers=auth_headers) response = client.get("/api/v1/export-csv", headers=auth_headers)
end_time = time.time() end_time = time.time()
assert response.status_code == 200 assert response.status_code == 200
assert end_time - start_time < 10.0 # Should complete within 10 seconds assert end_time - start_time < 10.0 # Should complete within 10 seconds

View File

@@ -1,5 +1,6 @@
# tests/test_filtering.py # tests/test_filtering.py
import pytest import pytest
from models.database_models import Product from models.database_models import Product
@@ -39,7 +40,9 @@ class TestFiltering:
db.add_all(products) db.add_all(products)
db.commit() db.commit()
response = client.get("/api/v1/product?marketplace=Amazon", headers=auth_headers) response = client.get(
"/api/v1/product?marketplace=Amazon", headers=auth_headers
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["total"] == 2 assert data["total"] == 2
@@ -47,9 +50,17 @@ class TestFiltering:
def test_product_search_filter(self, client, auth_headers, db): def test_product_search_filter(self, client, auth_headers, db):
"""Test searching products by text""" """Test searching products by text"""
products = [ products = [
Product(product_id="SEARCH1", title="Apple iPhone", description="Smartphone"), Product(
Product(product_id="SEARCH2", title="Samsung Galaxy", description="Android phone"), product_id="SEARCH1", title="Apple iPhone", description="Smartphone"
Product(product_id="SEARCH3", title="iPad Tablet", description="Apple tablet"), ),
Product(
product_id="SEARCH2",
title="Samsung Galaxy",
description="Android phone",
),
Product(
product_id="SEARCH3", title="iPad Tablet", description="Apple tablet"
),
] ]
db.add_all(products) db.add_all(products)
@@ -70,16 +81,33 @@ class TestFiltering:
def test_combined_filters(self, client, auth_headers, db): def test_combined_filters(self, client, auth_headers, db):
"""Test combining multiple filters""" """Test combining multiple filters"""
products = [ products = [
Product(product_id="COMBO1", title="Apple iPhone", brand="Apple", marketplace="Amazon"), Product(
Product(product_id="COMBO2", title="Apple iPad", brand="Apple", marketplace="eBay"), product_id="COMBO1",
Product(product_id="COMBO3", title="Samsung Phone", brand="Samsung", marketplace="Amazon"), title="Apple iPhone",
brand="Apple",
marketplace="Amazon",
),
Product(
product_id="COMBO2",
title="Apple iPad",
brand="Apple",
marketplace="eBay",
),
Product(
product_id="COMBO3",
title="Samsung Phone",
brand="Samsung",
marketplace="Amazon",
),
] ]
db.add_all(products) db.add_all(products)
db.commit() db.commit()
# Filter by brand AND marketplace # Filter by brand AND marketplace
response = client.get("/api/v1/product?brand=Apple&marketplace=Amazon", headers=auth_headers) response = client.get(
"/api/v1/product?brand=Apple&marketplace=Amazon", headers=auth_headers
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["total"] == 1 # Only iPhone matches both assert data["total"] == 1 # Only iPhone matches both

View File

@@ -14,10 +14,12 @@ class TestIntegrationFlows:
"brand": "FlowBrand", "brand": "FlowBrand",
"gtin": "1111222233334", "gtin": "1111222233334",
"availability": "in stock", "availability": "in stock",
"marketplace": "TestFlow" "marketplace": "TestFlow",
} }
response = client.post("/api/v1/product", headers=auth_headers, json=product_data) response = client.post(
"/api/v1/product", headers=auth_headers, json=product_data
)
assert response.status_code == 200 assert response.status_code == 200
product = response.json() product = response.json()
@@ -25,26 +27,33 @@ class TestIntegrationFlows:
stock_data = { stock_data = {
"gtin": product["gtin"], "gtin": product["gtin"],
"location": "MAIN_WAREHOUSE", "location": "MAIN_WAREHOUSE",
"quantity": 50 "quantity": 50,
} }
response = client.post("/api/v1/stock", headers=auth_headers, json=stock_data) response = client.post("/api/v1/stock", headers=auth_headers, json=stock_data)
assert response.status_code == 200 assert response.status_code == 200
# 3. Get product with stock info # 3. Get product with stock info
response = client.get(f"/api/v1/product/{product['product_id']}", headers=auth_headers) response = client.get(
f"/api/v1/product/{product['product_id']}", headers=auth_headers
)
assert response.status_code == 200 assert response.status_code == 200
product_detail = response.json() product_detail = response.json()
assert product_detail["stock_info"]["total_quantity"] == 50 assert product_detail["stock_info"]["total_quantity"] == 50
# 4. Update product # 4. Update product
update_data = {"title": "Updated Integration Test Product"} update_data = {"title": "Updated Integration Test Product"}
response = client.put(f"/api/v1/product/{product['product_id']}", response = client.put(
headers=auth_headers, json=update_data) f"/api/v1/product/{product['product_id']}",
headers=auth_headers,
json=update_data,
)
assert response.status_code == 200 assert response.status_code == 200
# 5. Search for product # 5. Search for product
response = client.get("/api/v1/product?search=Updated Integration", headers=auth_headers) response = client.get(
"/api/v1/product?search=Updated Integration", headers=auth_headers
)
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["total"] == 1 assert response.json()["total"] == 1
@@ -54,7 +63,7 @@ class TestIntegrationFlows:
shop_data = { shop_data = {
"shop_code": "FLOWSHOP", "shop_code": "FLOWSHOP",
"shop_name": "Integration Flow Shop", "shop_name": "Integration Flow Shop",
"description": "Test shop for integration" "description": "Test shop for integration",
} }
response = client.post("/api/v1/shop", headers=auth_headers, json=shop_data) response = client.post("/api/v1/shop", headers=auth_headers, json=shop_data)
@@ -66,10 +75,12 @@ class TestIntegrationFlows:
"product_id": "SHOPFLOW001", "product_id": "SHOPFLOW001",
"title": "Shop Flow Product", "title": "Shop Flow Product",
"price": "15.99", "price": "15.99",
"marketplace": "ShopFlow" "marketplace": "ShopFlow",
} }
response = client.post("/api/v1/product", headers=auth_headers, json=product_data) response = client.post(
"/api/v1/product", headers=auth_headers, json=product_data
)
assert response.status_code == 200 assert response.status_code == 200
product = response.json() product = response.json()
@@ -86,28 +97,28 @@ class TestIntegrationFlows:
location = "TEST_WAREHOUSE" location = "TEST_WAREHOUSE"
# 1. Set initial stock # 1. Set initial stock
response = client.post("/api/v1/stock", headers=auth_headers, json={ response = client.post(
"gtin": gtin, "/api/v1/stock",
"location": location, headers=auth_headers,
"quantity": 100 json={"gtin": gtin, "location": location, "quantity": 100},
}) )
assert response.status_code == 200 assert response.status_code == 200
# 2. Add more stock # 2. Add more stock
response = client.post("/api/v1/stock/add", headers=auth_headers, json={ response = client.post(
"gtin": gtin, "/api/v1/stock/add",
"location": location, headers=auth_headers,
"quantity": 25 json={"gtin": gtin, "location": location, "quantity": 25},
}) )
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["quantity"] == 125 assert response.json()["quantity"] == 125
# 3. Remove some stock # 3. Remove some stock
response = client.post("/api/v1/stock/remove", headers=auth_headers, json={ response = client.post(
"gtin": gtin, "/api/v1/stock/remove",
"location": location, headers=auth_headers,
"quantity": 30 json={"gtin": gtin, "location": location, "quantity": 30},
}) )
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["quantity"] == 95 assert response.json()["quantity"] == 95

View File

@@ -1,6 +1,7 @@
# tests/test_marketplace.py # tests/test_marketplace.py
from unittest.mock import AsyncMock, patch
import pytest import pytest
from unittest.mock import patch, AsyncMock
class TestMarketplaceAPI: class TestMarketplaceAPI:
@@ -10,11 +11,12 @@ class TestMarketplaceAPI:
import_data = { import_data = {
"url": "https://example.com/products.csv", "url": "https://example.com/products.csv",
"marketplace": "TestMarket", "marketplace": "TestMarket",
"shop_code": test_shop.shop_code "shop_code": test_shop.shop_code,
} }
response = client.post("/api/v1/marketplace/import-product", response = client.post(
headers=auth_headers, json=import_data) "/api/v1/marketplace/import-product", headers=auth_headers, json=import_data
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@@ -29,11 +31,12 @@ class TestMarketplaceAPI:
import_data = { import_data = {
"url": "https://example.com/products.csv", "url": "https://example.com/products.csv",
"marketplace": "TestMarket", "marketplace": "TestMarket",
"shop_code": "NONEXISTENT" "shop_code": "NONEXISTENT",
} }
response = client.post("/api/v1/marketplace/import-product", response = client.post(
headers=auth_headers, json=import_data) "/api/v1/marketplace/import-product", headers=auth_headers, json=import_data
)
assert response.status_code == 404 assert response.status_code == 404
assert "Shop not found" in response.json()["detail"] assert "Shop not found" in response.json()["detail"]
@@ -49,4 +52,3 @@ class TestMarketplaceAPI:
"""Test that marketplace endpoints require authentication""" """Test that marketplace endpoints require authentication"""
response = client.get("/api/v1/marketplace/import-jobs") response = client.get("/api/v1/marketplace/import-jobs")
assert response.status_code == 401 # No authorization header assert response.status_code == 401 # No authorization header

View File

@@ -1,10 +1,12 @@
# tests/test_marketplace_service.py # tests/test_marketplace_service.py
import pytest
import uuid import uuid
from datetime import datetime
import pytest
from app.services.marketplace_service import MarketplaceService from app.services.marketplace_service import MarketplaceService
from models.api_models import MarketplaceImportRequest from models.api_models import MarketplaceImportRequest
from models.database_models import MarketplaceImportJob, Shop, User from models.database_models import MarketplaceImportJob, Shop, User
from datetime import datetime
class TestMarketplaceService: class TestMarketplaceService:
@@ -22,7 +24,9 @@ class TestMarketplaceService:
assert result.shop_code == test_shop.shop_code assert result.shop_code == test_shop.shop_code
assert result.owner_id == test_user.id assert result.owner_id == test_user.id
def test_validate_shop_access_admin_can_access_any_shop(self, db, test_shop, test_admin): def test_validate_shop_access_admin_can_access_any_shop(
self, db, test_shop, test_admin
):
"""Test that admin users can access any shop""" """Test that admin users can access any shop"""
result = self.service.validate_shop_access(db, test_shop.shop_code, test_admin) result = self.service.validate_shop_access(db, test_shop.shop_code, test_admin)
@@ -33,7 +37,9 @@ class TestMarketplaceService:
with pytest.raises(ValueError, match="Shop not found"): with pytest.raises(ValueError, match="Shop not found"):
self.service.validate_shop_access(db, "NONEXISTENT", test_user) self.service.validate_shop_access(db, "NONEXISTENT", test_user)
def test_validate_shop_access_permission_denied(self, db, test_shop, test_user, other_user): def test_validate_shop_access_permission_denied(
self, db, test_shop, test_user, other_user
):
"""Test shop access validation when user doesn't own the shop""" """Test shop access validation when user doesn't own the shop"""
# Set the shop owner to a different user # Set the shop owner to a different user
test_shop.owner_id = other_user.id test_shop.owner_id = other_user.id
@@ -52,7 +58,7 @@ class TestMarketplaceService:
url="https://example.com/products.csv", url="https://example.com/products.csv",
marketplace="Amazon", marketplace="Amazon",
shop_code=test_shop.shop_code, shop_code=test_shop.shop_code,
batch_size=1000 batch_size=1000,
) )
result = self.service.create_import_job(db, request, test_user) result = self.service.create_import_job(db, request, test_user)
@@ -60,7 +66,7 @@ class TestMarketplaceService:
assert result.marketplace == "Amazon" assert result.marketplace == "Amazon"
# Check the correct field based on your model # Check the correct field based on your model
assert result.shop_id == test_shop.id # Changed from shop_code to shop_id assert result.shop_id == test_shop.id # Changed from shop_code to shop_id
assert result.user_id == test_user.id if hasattr(result, 'user_id') else True assert result.user_id == test_user.id if hasattr(result, "user_id") else True
assert result.status == "pending" assert result.status == "pending"
assert result.source_url == "https://example.com/products.csv" assert result.source_url == "https://example.com/products.csv"
@@ -70,7 +76,7 @@ class TestMarketplaceService:
url="https://example.com/products.csv", url="https://example.com/products.csv",
marketplace="Amazon", marketplace="Amazon",
shop_code="INVALID_SHOP", shop_code="INVALID_SHOP",
batch_size=1000 batch_size=1000,
) )
with pytest.raises(ValueError, match="Shop not found"): with pytest.raises(ValueError, match="Shop not found"):
@@ -78,16 +84,22 @@ class TestMarketplaceService:
def test_get_import_job_by_id_success(self, db, test_marketplace_job, test_user): def test_get_import_job_by_id_success(self, db, test_marketplace_job, test_user):
"""Test getting import job by ID for job owner""" """Test getting import job by ID for job owner"""
result = self.service.get_import_job_by_id(db, test_marketplace_job.id, test_user) result = self.service.get_import_job_by_id(
db, test_marketplace_job.id, test_user
)
assert result.id == test_marketplace_job.id assert result.id == test_marketplace_job.id
# Check user_id if the field exists # Check user_id if the field exists
if hasattr(result, 'user_id'): if hasattr(result, "user_id"):
assert result.user_id == test_user.id assert result.user_id == test_user.id
def test_get_import_job_by_id_admin_access(self, db, test_marketplace_job, test_admin): def test_get_import_job_by_id_admin_access(
self, db, test_marketplace_job, test_admin
):
"""Test that admin can access any import job""" """Test that admin can access any import job"""
result = self.service.get_import_job_by_id(db, test_marketplace_job.id, test_admin) result = self.service.get_import_job_by_id(
db, test_marketplace_job.id, test_admin
)
assert result.id == test_marketplace_job.id assert result.id == test_marketplace_job.id
@@ -96,7 +108,9 @@ class TestMarketplaceService:
with pytest.raises(ValueError, match="Marketplace import job not found"): with pytest.raises(ValueError, match="Marketplace import job not found"):
self.service.get_import_job_by_id(db, 99999, test_user) self.service.get_import_job_by_id(db, 99999, test_user)
def test_get_import_job_by_id_access_denied(self, db, test_marketplace_job, other_user): def test_get_import_job_by_id_access_denied(
self, db, test_marketplace_job, other_user
):
"""Test access denied when user doesn't own the job""" """Test access denied when user doesn't own the job"""
with pytest.raises(PermissionError, match="Access denied to this import job"): with pytest.raises(PermissionError, match="Access denied to this import job"):
self.service.get_import_job_by_id(db, test_marketplace_job.id, other_user) self.service.get_import_job_by_id(db, test_marketplace_job.id, other_user)
@@ -108,7 +122,7 @@ class TestMarketplaceService:
assert len(jobs) >= 1 assert len(jobs) >= 1
assert any(job.id == test_marketplace_job.id for job in jobs) assert any(job.id == test_marketplace_job.id for job in jobs)
# Check user_id if the field exists # Check user_id if the field exists
if hasattr(test_marketplace_job, 'user_id'): if hasattr(test_marketplace_job, "user_id"):
assert test_marketplace_job.user_id == test_user.id assert test_marketplace_job.user_id == test_user.id
def test_get_import_jobs_admin_sees_all(self, db, test_marketplace_job, test_admin): def test_get_import_jobs_admin_sees_all(self, db, test_marketplace_job, test_admin):
@@ -118,7 +132,9 @@ class TestMarketplaceService:
assert len(jobs) >= 1 assert len(jobs) >= 1
assert any(job.id == test_marketplace_job.id for job in jobs) assert any(job.id == test_marketplace_job.id for job in jobs)
def test_get_import_jobs_with_marketplace_filter(self, db, test_marketplace_job, test_user): def test_get_import_jobs_with_marketplace_filter(
self, db, test_marketplace_job, test_user
):
"""Test getting import jobs with marketplace filter""" """Test getting import jobs with marketplace filter"""
jobs = self.service.get_import_jobs( jobs = self.service.get_import_jobs(
db, test_user, marketplace=test_marketplace_job.marketplace db, test_user, marketplace=test_marketplace_job.marketplace
@@ -143,7 +159,7 @@ class TestMarketplaceService:
imported_count=0, imported_count=0,
updated_count=0, updated_count=0,
total_processed=0, total_processed=0,
error_count=0 error_count=0,
) )
db.add(job) db.add(job)
db.commit() db.commit()
@@ -159,7 +175,7 @@ class TestMarketplaceService:
test_marketplace_job.id, test_marketplace_job.id,
"completed", "completed",
imported_count=100, imported_count=100,
total_processed=100 total_processed=100,
) )
assert result.status == "completed" assert result.status == "completed"
@@ -211,7 +227,7 @@ class TestMarketplaceService:
imported_count=0, imported_count=0,
updated_count=0, updated_count=0,
total_processed=0, total_processed=0,
error_count=0 error_count=0,
) )
db.add(job) db.add(job)
db.commit() db.commit()
@@ -222,13 +238,17 @@ class TestMarketplaceService:
assert result.status == "cancelled" assert result.status == "cancelled"
assert result.completed_at is not None assert result.completed_at is not None
def test_cancel_import_job_invalid_status(self, db, test_marketplace_job, test_user): def test_cancel_import_job_invalid_status(
self, db, test_marketplace_job, test_user
):
"""Test cancelling a job that can't be cancelled""" """Test cancelling a job that can't be cancelled"""
# Set job status to completed # Set job status to completed
test_marketplace_job.status = "completed" test_marketplace_job.status = "completed"
db.commit() db.commit()
with pytest.raises(ValueError, match="Cannot cancel job with status: completed"): with pytest.raises(
ValueError, match="Cannot cancel job with status: completed"
):
self.service.cancel_import_job(db, test_marketplace_job.id, test_user) self.service.cancel_import_job(db, test_marketplace_job.id, test_user)
def test_delete_import_job_success(self, db, test_user, test_shop): def test_delete_import_job_success(self, db, test_user, test_shop):
@@ -246,7 +266,7 @@ class TestMarketplaceService:
imported_count=0, imported_count=0,
updated_count=0, updated_count=0,
total_processed=0, total_processed=0,
error_count=0 error_count=0,
) )
db.add(job) db.add(job)
db.commit() db.commit()
@@ -258,7 +278,11 @@ class TestMarketplaceService:
assert result is True assert result is True
# Verify the job is actually deleted # Verify the job is actually deleted
deleted_job = db.query(MarketplaceImportJob).filter(MarketplaceImportJob.id == job_id).first() deleted_job = (
db.query(MarketplaceImportJob)
.filter(MarketplaceImportJob.id == job_id)
.first()
)
assert deleted_job is None assert deleted_job is None
def test_delete_import_job_invalid_status(self, db, test_user, test_shop): def test_delete_import_job_invalid_status(self, db, test_user, test_shop):
@@ -276,7 +300,7 @@ class TestMarketplaceService:
imported_count=0, imported_count=0,
updated_count=0, updated_count=0,
total_processed=0, total_processed=0,
error_count=0 error_count=0,
) )
db.add(job) db.add(job)
db.commit() db.commit()

View File

@@ -1,8 +1,10 @@
# tests/test_middleware.py # tests/test_middleware.py
import pytest
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from middleware.rate_limiter import RateLimiter
import pytest
from middleware.auth import AuthManager from middleware.auth import AuthManager
from middleware.rate_limiter import RateLimiter
class TestRateLimiter: class TestRateLimiter:
@@ -12,11 +14,17 @@ class TestRateLimiter:
client_id = "test_client" client_id = "test_client"
# Should allow first request # Should allow first request
assert limiter.allow_request(client_id, max_requests=10, window_seconds=3600) is True assert (
limiter.allow_request(client_id, max_requests=10, window_seconds=3600)
is True
)
# Should allow subsequent requests within limit # Should allow subsequent requests within limit
for _ in range(5): for _ in range(5):
assert limiter.allow_request(client_id, max_requests=10, window_seconds=3600) is True assert (
limiter.allow_request(client_id, max_requests=10, window_seconds=3600)
is True
)
def test_rate_limiter_blocks_excess_requests(self): def test_rate_limiter_blocks_excess_requests(self):
"""Test rate limiter blocks requests exceeding limit""" """Test rate limiter blocks requests exceeding limit"""

View File

@@ -1,5 +1,6 @@
# tests/test_pagination.py # tests/test_pagination.py
import pytest import pytest
from models.database_models import Product from models.database_models import Product
@@ -12,7 +13,7 @@ class TestPagination:
product = Product( product = Product(
product_id=f"PAGE{i:03d}", product_id=f"PAGE{i:03d}",
title=f"Pagination Test Product {i}", title=f"Pagination Test Product {i}",
marketplace="PaginationTest" marketplace="PaginationTest",
) )
products.append(product) products.append(product)

View File

@@ -1,7 +1,8 @@
# tests/test_performance.py # tests/test_performance.py
import pytest
import time import time
import pytest
from models.database_models import Product from models.database_models import Product
@@ -15,7 +16,7 @@ class TestPerformance:
product_id=f"PERF{i:03d}", product_id=f"PERF{i:03d}",
title=f"Performance Test Product {i}", title=f"Performance Test Product {i}",
price=f"{i}.99", price=f"{i}.99",
marketplace="Performance" marketplace="Performance",
) )
products.append(product) products.append(product)
@@ -41,7 +42,7 @@ class TestPerformance:
title=f"Searchable Product {i}", title=f"Searchable Product {i}",
description=f"This is a searchable product number {i}", description=f"This is a searchable product number {i}",
brand="SearchBrand", brand="SearchBrand",
marketplace="SearchMarket" marketplace="SearchMarket",
) )
products.append(product) products.append(product)

View File

@@ -31,7 +31,9 @@ class TestProductsAPI:
assert response.json()["total"] == 1 assert response.json()["total"] == 1
# Test marketplace filter # Test marketplace filter
response = client.get("/api/v1/product?marketplace=Letzshop", headers=auth_headers) response = client.get(
"/api/v1/product?marketplace=Letzshop", headers=auth_headers
)
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["total"] == 1 assert response.json()["total"] == 1
@@ -50,10 +52,12 @@ class TestProductsAPI:
"brand": "NewBrand", "brand": "NewBrand",
"gtin": "9876543210987", "gtin": "9876543210987",
"availability": "in stock", "availability": "in stock",
"marketplace": "Amazon" "marketplace": "Amazon",
} }
response = client.post("/api/v1/product", headers=auth_headers, json=product_data) response = client.post(
"/api/v1/product", headers=auth_headers, json=product_data
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@@ -71,10 +75,12 @@ class TestProductsAPI:
"brand": "NewBrand", "brand": "NewBrand",
"gtin": "9876543210987", "gtin": "9876543210987",
"availability": "in stock", "availability": "in stock",
"marketplace": "Amazon" "marketplace": "Amazon",
} }
response = client.post("/api/v1/product", headers=auth_headers, json=product_data) response = client.post(
"/api/v1/product", headers=auth_headers, json=product_data
)
# Debug output # Debug output
print(f"Status Code: {response.status_code}") print(f"Status Code: {response.status_code}")
@@ -89,7 +95,9 @@ class TestProductsAPI:
def test_get_product_by_id(self, client, auth_headers, test_product): def test_get_product_by_id(self, client, auth_headers, test_product):
"""Test getting specific product""" """Test getting specific product"""
response = client.get(f"/api/v1/product/{test_product.product_id}", headers=auth_headers) response = client.get(
f"/api/v1/product/{test_product.product_id}", headers=auth_headers
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@@ -104,15 +112,12 @@ class TestProductsAPI:
def test_update_product(self, client, auth_headers, test_product): def test_update_product(self, client, auth_headers, test_product):
"""Test updating product""" """Test updating product"""
update_data = { update_data = {"title": "Updated Product Title", "price": "25.99"}
"title": "Updated Product Title",
"price": "25.99"
}
response = client.put( response = client.put(
f"/api/v1/product/{test_product.product_id}", f"/api/v1/product/{test_product.product_id}",
headers=auth_headers, headers=auth_headers,
json=update_data json=update_data,
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -123,8 +128,7 @@ class TestProductsAPI:
def test_delete_product(self, client, auth_headers, test_product): def test_delete_product(self, client, auth_headers, test_product):
"""Test deleting product""" """Test deleting product"""
response = client.delete( response = client.delete(
f"/api/v1/product/{test_product.product_id}", f"/api/v1/product/{test_product.product_id}", headers=auth_headers
headers=auth_headers
) )
assert response.status_code == 200 assert response.status_code == 200

View File

@@ -1,5 +1,6 @@
# tests/test_product_service.py # tests/test_product_service.py
import pytest import pytest
from app.services.product_service import ProductService from app.services.product_service import ProductService
from models.api_models import ProductCreate from models.api_models import ProductCreate
from models.database_models import Product from models.database_models import Product
@@ -16,7 +17,7 @@ class TestProductService:
title="Service Test Product", title="Service Test Product",
gtin="1234567890123", gtin="1234567890123",
price="19.99", price="19.99",
marketplace="TestMarket" marketplace="TestMarket",
) )
product = self.service.create_product(db, product_data) product = self.service.create_product(db, product_data)
@@ -31,7 +32,7 @@ class TestProductService:
product_id="SVC002", product_id="SVC002",
title="Service Test Product", title="Service Test Product",
gtin="invalid_gtin", gtin="invalid_gtin",
price="19.99" price="19.99",
) )
with pytest.raises(ValueError, match="Invalid GTIN format"): with pytest.raises(ValueError, match="Invalid GTIN format"):
@@ -39,10 +40,7 @@ class TestProductService:
def test_get_products_with_filters(self, db, test_product): def test_get_products_with_filters(self, db, test_product):
"""Test getting products with various filters""" """Test getting products with various filters"""
products, total = self.service.get_products_with_filters( products, total = self.service.get_products_with_filters(db, brand="TestBrand")
db,
brand="TestBrand"
)
assert total == 1 assert total == 1
assert len(products) == 1 assert len(products) == 1
@@ -51,10 +49,8 @@ class TestProductService:
def test_get_products_with_search(self, db, test_product): def test_get_products_with_search(self, db, test_product):
"""Test getting products with search""" """Test getting products with search"""
products, total = self.service.get_products_with_filters( products, total = self.service.get_products_with_filters(
db, db, search="Test Product"
search="Test Product"
) )
assert total == 1 assert total == 1
assert len(products) == 1 assert len(products) == 1

View File

@@ -1,7 +1,8 @@
# tests/test_security.py # tests/test_security.py
from unittest.mock import patch
import pytest import pytest
from fastapi import HTTPException from fastapi import HTTPException
from unittest.mock import patch
class TestSecurity: class TestSecurity:
@@ -10,7 +11,9 @@ class TestSecurity:
response = client.get("/api/v1/debug-bearer") response = client.get("/api/v1/debug-bearer")
print(f"Direct Bearer - Status: {response.status_code}") print(f"Direct Bearer - Status: {response.status_code}")
print(f"Direct Bearer - Response: {response.json() if response.content else 'No content'}") print(
f"Direct Bearer - Response: {response.json() if response.content else 'No content'}"
)
def test_debug_dependencies(self, client): def test_debug_dependencies(self, client):
"""Debug the dependency chain step by step""" """Debug the dependency chain step by step"""
@@ -24,7 +27,9 @@ class TestSecurity:
print(f"Admin endpoint - Raw: {response.content}") print(f"Admin endpoint - Raw: {response.content}")
# Test 2: Try a regular endpoint that uses get_current_user # Test 2: Try a regular endpoint that uses get_current_user
response2 = client.get("/api/v1/product") # or any endpoint with get_current_user response2 = client.get(
"/api/v1/product"
) # or any endpoint with get_current_user
print(f"Regular endpoint - Status: {response2.status_code}") print(f"Regular endpoint - Status: {response2.status_code}")
try: try:
print(f"Regular endpoint - Response: {response2.json()}") print(f"Regular endpoint - Response: {response2.json()}")
@@ -35,7 +40,7 @@ class TestSecurity:
"""Debug test to see all available routes""" """Debug test to see all available routes"""
print("\n=== All Available Routes ===") print("\n=== All Available Routes ===")
for route in client.app.routes: for route in client.app.routes:
if hasattr(route, 'path') and hasattr(route, 'methods'): if hasattr(route, "path") and hasattr(route, "methods"):
print(f"{list(route.methods)} {route.path}") print(f"{list(route.methods)} {route.path}")
print("\n=== Testing Product Endpoint Variations ===") print("\n=== Testing Product Endpoint Variations ===")
@@ -59,7 +64,7 @@ class TestSecurity:
"/api/v1/product", "/api/v1/product",
"/api/v1/shop", "/api/v1/shop",
"/api/v1/stats", "/api/v1/stats",
"/api/v1/stock" "/api/v1/stock",
] ]
for endpoint in protected_endpoints: for endpoint in protected_endpoints:
@@ -76,7 +81,9 @@ class TestSecurity:
def test_admin_endpoint_requires_admin_role(self, client, auth_headers): def test_admin_endpoint_requires_admin_role(self, client, auth_headers):
"""Test that admin endpoints require admin role""" """Test that admin endpoints require admin role"""
response = client.get("/api/v1/admin/users", headers=auth_headers) response = client.get("/api/v1/admin/users", headers=auth_headers)
assert response.status_code == 403 # Token is valid but user does not have access. assert (
response.status_code == 403
) # Token is valid but user does not have access.
# Regular user should be denied # Regular user should be denied
def test_sql_injection_prevention(self, client, auth_headers): def test_sql_injection_prevention(self, client, auth_headers):
@@ -84,7 +91,9 @@ class TestSecurity:
# Try SQL injection in search parameter # Try SQL injection in search parameter
malicious_search = "'; DROP TABLE products; --" malicious_search = "'; DROP TABLE products; --"
response = client.get(f"/api/v1/product?search={malicious_search}", headers=auth_headers) response = client.get(
f"/api/v1/product?search={malicious_search}", headers=auth_headers
)
# Should not crash and should return normal response # Should not crash and should return normal response
assert response.status_code == 200 assert response.status_code == 200

View File

@@ -8,7 +8,7 @@ class TestShopsAPI:
shop_data = { shop_data = {
"shop_code": "NEWSHOP", "shop_code": "NEWSHOP",
"shop_name": "New Shop", "shop_name": "New Shop",
"description": "A new test shop" "description": "A new test shop",
} }
response = client.post("/api/v1/shop", headers=auth_headers, json=shop_data) response = client.post("/api/v1/shop", headers=auth_headers, json=shop_data)
@@ -23,7 +23,7 @@ class TestShopsAPI:
"""Test creating shop with duplicate code""" """Test creating shop with duplicate code"""
shop_data = { shop_data = {
"shop_code": test_shop.shop_code, # Same as test_shop "shop_code": test_shop.shop_code, # Same as test_shop
"shop_name": test_shop.shop_name "shop_name": test_shop.shop_name,
} }
response = client.post("/api/v1/shop", headers=auth_headers, json=shop_data) response = client.post("/api/v1/shop", headers=auth_headers, json=shop_data)
@@ -42,7 +42,9 @@ class TestShopsAPI:
def test_get_shop_by_code(self, client, auth_headers, test_shop): def test_get_shop_by_code(self, client, auth_headers, test_shop):
"""Test getting specific shop""" """Test getting specific shop"""
response = client.get(f"/api/v1/shop/{test_shop.shop_code}", headers=auth_headers) response = client.get(
f"/api/v1/shop/{test_shop.shop_code}", headers=auth_headers
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()

View File

@@ -18,7 +18,7 @@ class TestShopService:
shop_data = ShopCreate( shop_data = ShopCreate(
shop_code="NEWSHOP", shop_code="NEWSHOP",
shop_name="New Test Shop", shop_name="New Test Shop",
description="A new test shop" description="A new test shop",
) )
shop = self.service.create_shop(db, shop_data, test_user) shop = self.service.create_shop(db, shop_data, test_user)
@@ -30,10 +30,7 @@ class TestShopService:
def test_create_shop_admin_auto_verify(self, db, test_admin, shop_factory): def test_create_shop_admin_auto_verify(self, db, test_admin, shop_factory):
"""Test admin creates verified shop automatically""" """Test admin creates verified shop automatically"""
shop_data = ShopCreate( shop_data = ShopCreate(shop_code="ADMINSHOP", shop_name="Admin Test Shop")
shop_code="ADMINSHOP",
shop_name="Admin Test Shop"
)
shop = self.service.create_shop(db, shop_data, test_admin) shop = self.service.create_shop(db, shop_data, test_admin)
@@ -42,8 +39,7 @@ class TestShopService:
def test_create_shop_duplicate_code(self, db, test_user, test_shop): def test_create_shop_duplicate_code(self, db, test_user, test_shop):
"""Test shop creation fails with duplicate shop code""" """Test shop creation fails with duplicate shop code"""
shop_data = ShopCreate( shop_data = ShopCreate(
shop_code=test_shop.shop_code, shop_code=test_shop.shop_code, shop_name=test_shop.shop_name
shop_name=test_shop.shop_name
) )
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
@@ -60,9 +56,13 @@ class TestShopService:
assert test_shop.shop_code in shop_codes assert test_shop.shop_code in shop_codes
assert inactive_shop.shop_code not in shop_codes assert inactive_shop.shop_code not in shop_codes
def test_get_shops_admin_user(self, db, test_admin, test_shop, inactive_shop, verified_shop): def test_get_shops_admin_user(
self, db, test_admin, test_shop, inactive_shop, verified_shop
):
"""Test admin user can see all shops with filters""" """Test admin user can see all shops with filters"""
shops, total = self.service.get_shops(db, test_admin, active_only=False, verified_only=False) shops, total = self.service.get_shops(
db, test_admin, active_only=False, verified_only=False
)
shop_codes = [shop.shop_code for shop in shops] shop_codes = [shop.shop_code for shop in shops]
assert test_shop.shop_code in shop_codes assert test_shop.shop_code in shop_codes
@@ -78,7 +78,9 @@ class TestShopService:
def test_get_shop_by_code_admin_access(self, db, test_admin, test_shop): def test_get_shop_by_code_admin_access(self, db, test_admin, test_shop):
"""Test admin can access any shop""" """Test admin can access any shop"""
shop = self.service.get_shop_by_code(db, test_shop.shop_code.lower(), test_admin) shop = self.service.get_shop_by_code(
db, test_shop.shop_code.lower(), test_admin
)
assert shop is not None assert shop is not None
assert shop.id == test_shop.id assert shop.id == test_shop.id
@@ -103,10 +105,12 @@ class TestShopService:
product_id=unique_product.product_id, product_id=unique_product.product_id,
price="15.99", price="15.99",
is_featured=True, is_featured=True,
stock_quantity=5 stock_quantity=5,
) )
shop_product = self.service.add_product_to_shop(db, test_shop, shop_product_data) shop_product = self.service.add_product_to_shop(
db, test_shop, shop_product_data
)
assert shop_product is not None assert shop_product is not None
assert shop_product.shop_id == test_shop.id assert shop_product.shop_id == test_shop.id
@@ -114,10 +118,7 @@ class TestShopService:
def test_add_product_to_shop_product_not_found(self, db, test_shop): def test_add_product_to_shop_product_not_found(self, db, test_shop):
"""Test adding non-existent product to shop fails""" """Test adding non-existent product to shop fails"""
shop_product_data = ShopProductCreate( shop_product_data = ShopProductCreate(product_id="NONEXISTENT", price="15.99")
product_id="NONEXISTENT",
price="15.99"
)
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
self.service.add_product_to_shop(db, test_shop, shop_product_data) self.service.add_product_to_shop(db, test_shop, shop_product_data)
@@ -127,8 +128,7 @@ class TestShopService:
def test_add_product_to_shop_already_exists(self, db, test_shop, shop_product): def test_add_product_to_shop_already_exists(self, db, test_shop, shop_product):
"""Test adding product that's already in shop fails""" """Test adding product that's already in shop fails"""
shop_product_data = ShopProductCreate( shop_product_data = ShopProductCreate(
product_id=shop_product.product.product_id, product_id=shop_product.product.product_id, price="15.99"
price="15.99"
) )
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
@@ -136,7 +136,9 @@ class TestShopService:
assert exc_info.value.status_code == 400 assert exc_info.value.status_code == 400
def test_get_shop_products_owner_access(self, db, test_user, test_shop, shop_product): def test_get_shop_products_owner_access(
self, db, test_user, test_shop, shop_product
):
"""Test shop owner can get shop products""" """Test shop owner can get shop products"""
products, total = self.service.get_shop_products(db, test_shop, test_user) products, total = self.service.get_shop_products(db, test_shop, test_user)

View File

@@ -40,7 +40,7 @@ class TestStatsService:
marketplace="Amazon", marketplace="Amazon",
shop_name="AmazonShop", shop_name="AmazonShop",
price="15.99", price="15.99",
currency="EUR" currency="EUR",
), ),
Product( Product(
product_id="PROD003", product_id="PROD003",
@@ -50,7 +50,7 @@ class TestStatsService:
marketplace="eBay", marketplace="eBay",
shop_name="eBayShop", shop_name="eBayShop",
price="25.99", price="25.99",
currency="USD" currency="USD",
), ),
Product( Product(
product_id="PROD004", product_id="PROD004",
@@ -60,8 +60,8 @@ class TestStatsService:
marketplace="Letzshop", # Same as test_product marketplace="Letzshop", # Same as test_product
shop_name="DifferentShop", shop_name="DifferentShop",
price="35.99", price="35.99",
currency="EUR" currency="EUR",
) ),
] ]
db.add_all(additional_products) db.add_all(additional_products)
db.commit() db.commit()
@@ -86,7 +86,7 @@ class TestStatsService:
marketplace=None, # Null marketplace marketplace=None, # Null marketplace
shop_name=None, # Null shop shop_name=None, # Null shop
price="10.00", price="10.00",
currency="EUR" currency="EUR",
), ),
Product( Product(
product_id="EMPTY001", product_id="EMPTY001",
@@ -96,8 +96,8 @@ class TestStatsService:
marketplace="", # Empty marketplace marketplace="", # Empty marketplace
shop_name="", # Empty shop shop_name="", # Empty shop
price="15.00", price="15.00",
currency="EUR" currency="EUR",
) ),
] ]
db.add_all(products_with_nulls) db.add_all(products_with_nulls)
db.commit() db.commit()
@@ -122,14 +122,16 @@ class TestStatsService:
# Find our test marketplace in the results # Find our test marketplace in the results
test_marketplace_stat = next( test_marketplace_stat = next(
(stat for stat in stats if stat["marketplace"] == test_product.marketplace), (stat for stat in stats if stat["marketplace"] == test_product.marketplace),
None None,
) )
assert test_marketplace_stat is not None assert test_marketplace_stat is not None
assert test_marketplace_stat["total_products"] >= 1 assert test_marketplace_stat["total_products"] >= 1
assert test_marketplace_stat["unique_shops"] >= 1 assert test_marketplace_stat["unique_shops"] >= 1
assert test_marketplace_stat["unique_brands"] >= 1 assert test_marketplace_stat["unique_brands"] >= 1
def test_get_marketplace_breakdown_stats_multiple_marketplaces(self, db, test_product): def test_get_marketplace_breakdown_stats_multiple_marketplaces(
self, db, test_product
):
"""Test marketplace breakdown with multiple marketplaces""" """Test marketplace breakdown with multiple marketplaces"""
# Create products for different marketplaces # Create products for different marketplaces
marketplace_products = [ marketplace_products = [
@@ -140,7 +142,7 @@ class TestStatsService:
marketplace="Amazon", marketplace="Amazon",
shop_name="AmazonShop1", shop_name="AmazonShop1",
price="20.00", price="20.00",
currency="EUR" currency="EUR",
), ),
Product( Product(
product_id="AMAZON002", product_id="AMAZON002",
@@ -149,7 +151,7 @@ class TestStatsService:
marketplace="Amazon", marketplace="Amazon",
shop_name="AmazonShop2", shop_name="AmazonShop2",
price="25.00", price="25.00",
currency="EUR" currency="EUR",
), ),
Product( Product(
product_id="EBAY001", product_id="EBAY001",
@@ -158,8 +160,8 @@ class TestStatsService:
marketplace="eBay", marketplace="eBay",
shop_name="eBayShop", shop_name="eBayShop",
price="30.00", price="30.00",
currency="USD" currency="USD",
) ),
] ]
db.add_all(marketplace_products) db.add_all(marketplace_products)
db.commit() db.commit()
@@ -194,7 +196,7 @@ class TestStatsService:
shop_name="SomeShop", shop_name="SomeShop",
brand="SomeBrand", brand="SomeBrand",
price="10.00", price="10.00",
currency="EUR" currency="EUR",
) )
db.add(null_marketplace_product) db.add(null_marketplace_product)
db.commit() db.commit()
@@ -202,7 +204,9 @@ class TestStatsService:
stats = self.service.get_marketplace_breakdown_stats(db) stats = self.service.get_marketplace_breakdown_stats(db)
# Should not include any stats for null marketplace # Should not include any stats for null marketplace
marketplace_names = [stat["marketplace"] for stat in stats if stat["marketplace"] is not None] marketplace_names = [
stat["marketplace"] for stat in stats if stat["marketplace"] is not None
]
assert None not in marketplace_names assert None not in marketplace_names
def test_get_product_count(self, db, test_product): def test_get_product_count(self, db, test_product):
@@ -223,7 +227,7 @@ class TestStatsService:
marketplace="Test", marketplace="Test",
shop_name="TestShop", shop_name="TestShop",
price="10.00", price="10.00",
currency="EUR" currency="EUR",
), ),
Product( Product(
product_id="BRAND002", product_id="BRAND002",
@@ -232,15 +236,17 @@ class TestStatsService:
marketplace="Test", marketplace="Test",
shop_name="TestShop", shop_name="TestShop",
price="15.00", price="15.00",
currency="EUR" currency="EUR",
) ),
] ]
db.add_all(brand_products) db.add_all(brand_products)
db.commit() db.commit()
count = self.service.get_unique_brands_count(db) count = self.service.get_unique_brands_count(db)
assert count >= 2 # At least BrandA and BrandB, plus possibly test_product brand assert (
count >= 2
) # At least BrandA and BrandB, plus possibly test_product brand
assert isinstance(count, int) assert isinstance(count, int)
def test_get_unique_categories_count(self, db, test_product): def test_get_unique_categories_count(self, db, test_product):
@@ -254,7 +260,7 @@ class TestStatsService:
marketplace="Test", marketplace="Test",
shop_name="TestShop", shop_name="TestShop",
price="10.00", price="10.00",
currency="EUR" currency="EUR",
), ),
Product( Product(
product_id="CAT002", product_id="CAT002",
@@ -263,8 +269,8 @@ class TestStatsService:
marketplace="Test", marketplace="Test",
shop_name="TestShop", shop_name="TestShop",
price="15.00", price="15.00",
currency="EUR" currency="EUR",
) ),
] ]
db.add_all(category_products) db.add_all(category_products)
db.commit() db.commit()
@@ -284,7 +290,7 @@ class TestStatsService:
marketplace="Amazon", marketplace="Amazon",
shop_name="AmazonShop", shop_name="AmazonShop",
price="10.00", price="10.00",
currency="EUR" currency="EUR",
), ),
Product( Product(
product_id="MARKET002", product_id="MARKET002",
@@ -292,8 +298,8 @@ class TestStatsService:
marketplace="eBay", marketplace="eBay",
shop_name="eBayShop", shop_name="eBayShop",
price="15.00", price="15.00",
currency="EUR" currency="EUR",
) ),
] ]
db.add_all(marketplace_products) db.add_all(marketplace_products)
db.commit() db.commit()
@@ -313,7 +319,7 @@ class TestStatsService:
marketplace="Test", marketplace="Test",
shop_name="ShopA", shop_name="ShopA",
price="10.00", price="10.00",
currency="EUR" currency="EUR",
), ),
Product( Product(
product_id="SHOP002", product_id="SHOP002",
@@ -321,8 +327,8 @@ class TestStatsService:
marketplace="Test", marketplace="Test",
shop_name="ShopB", shop_name="ShopB",
price="15.00", price="15.00",
currency="EUR" currency="EUR",
) ),
] ]
db.add_all(shop_products) db.add_all(shop_products)
db.commit() db.commit()
@@ -341,15 +347,15 @@ class TestStatsService:
location="LOCATION2", location="LOCATION2",
quantity=25, quantity=25,
reserved_quantity=5, reserved_quantity=5,
shop_id=test_stock.shop_id shop_id=test_stock.shop_id,
), ),
Stock( Stock(
gtin="1234567890125", gtin="1234567890125",
location="LOCATION3", location="LOCATION3",
quantity=0, # Out of stock quantity=0, # Out of stock
reserved_quantity=0, reserved_quantity=0,
shop_id=test_stock.shop_id shop_id=test_stock.shop_id,
) ),
] ]
db.add_all(additional_stocks) db.add_all(additional_stocks)
db.commit() db.commit()
@@ -372,7 +378,7 @@ class TestStatsService:
marketplace="SpecificMarket", marketplace="SpecificMarket",
shop_name="SpecificShop1", shop_name="SpecificShop1",
price="10.00", price="10.00",
currency="EUR" currency="EUR",
), ),
Product( Product(
product_id="SPECIFIC002", product_id="SPECIFIC002",
@@ -381,7 +387,7 @@ class TestStatsService:
marketplace="SpecificMarket", marketplace="SpecificMarket",
shop_name="SpecificShop2", shop_name="SpecificShop2",
price="15.00", price="15.00",
currency="EUR" currency="EUR",
), ),
Product( Product(
product_id="OTHER001", product_id="OTHER001",
@@ -390,8 +396,8 @@ class TestStatsService:
marketplace="OtherMarket", marketplace="OtherMarket",
shop_name="OtherShop", shop_name="OtherShop",
price="20.00", price="20.00",
currency="EUR" currency="EUR",
) ),
] ]
db.add_all(marketplace_products) db.add_all(marketplace_products)
db.commit() db.commit()
@@ -414,7 +420,7 @@ class TestStatsService:
marketplace="TestMarketplace", marketplace="TestMarketplace",
shop_name="TestShop1", shop_name="TestShop1",
price="10.00", price="10.00",
currency="EUR" currency="EUR",
), ),
Product( Product(
product_id="SHOPTEST002", product_id="SHOPTEST002",
@@ -423,8 +429,8 @@ class TestStatsService:
marketplace="TestMarketplace", marketplace="TestMarketplace",
shop_name="TestShop2", shop_name="TestShop2",
price="15.00", price="15.00",
currency="EUR" currency="EUR",
) ),
] ]
db.add_all(marketplace_products) db.add_all(marketplace_products)
db.commit() db.commit()
@@ -445,7 +451,7 @@ class TestStatsService:
marketplace="CountMarketplace", marketplace="CountMarketplace",
shop_name="CountShop", shop_name="CountShop",
price="10.00", price="10.00",
currency="EUR" currency="EUR",
), ),
Product( Product(
product_id="COUNT002", product_id="COUNT002",
@@ -453,7 +459,7 @@ class TestStatsService:
marketplace="CountMarketplace", marketplace="CountMarketplace",
shop_name="CountShop", shop_name="CountShop",
price="15.00", price="15.00",
currency="EUR" currency="EUR",
), ),
Product( Product(
product_id="COUNT003", product_id="COUNT003",
@@ -461,8 +467,8 @@ class TestStatsService:
marketplace="CountMarketplace", marketplace="CountMarketplace",
shop_name="CountShop", shop_name="CountShop",
price="20.00", price="20.00",
currency="EUR" currency="EUR",
) ),
] ]
db.add_all(marketplace_products) db.add_all(marketplace_products)
db.commit() db.commit()

View File

@@ -1,5 +1,6 @@
# tests/test_stock.py # tests/test_stock.py
import pytest import pytest
from models.database_models import Stock from models.database_models import Stock
@@ -9,7 +10,7 @@ class TestStockAPI:
stock_data = { stock_data = {
"gtin": "1234567890123", "gtin": "1234567890123",
"location": "WAREHOUSE_A", "location": "WAREHOUSE_A",
"quantity": 100 "quantity": 100,
} }
response = client.post("/api/v1/stock", headers=auth_headers, json=stock_data) response = client.post("/api/v1/stock", headers=auth_headers, json=stock_data)
@@ -30,7 +31,7 @@ class TestStockAPI:
stock_data = { stock_data = {
"gtin": "1234567890123", "gtin": "1234567890123",
"location": "WAREHOUSE_A", "location": "WAREHOUSE_A",
"quantity": 75 "quantity": 75,
} }
response = client.post("/api/v1/stock", headers=auth_headers, json=stock_data) response = client.post("/api/v1/stock", headers=auth_headers, json=stock_data)
@@ -49,10 +50,12 @@ class TestStockAPI:
stock_data = { stock_data = {
"gtin": "1234567890123", "gtin": "1234567890123",
"location": "WAREHOUSE_A", "location": "WAREHOUSE_A",
"quantity": 25 "quantity": 25,
} }
response = client.post("/api/v1/stock/add", headers=auth_headers, json=stock_data) response = client.post(
"/api/v1/stock/add", headers=auth_headers, json=stock_data
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@@ -68,10 +71,12 @@ class TestStockAPI:
stock_data = { stock_data = {
"gtin": "1234567890123", "gtin": "1234567890123",
"location": "WAREHOUSE_A", "location": "WAREHOUSE_A",
"quantity": 15 "quantity": 15,
} }
response = client.post("/api/v1/stock/remove", headers=auth_headers, json=stock_data) response = client.post(
"/api/v1/stock/remove", headers=auth_headers, json=stock_data
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@@ -87,10 +92,12 @@ class TestStockAPI:
stock_data = { stock_data = {
"gtin": "1234567890123", "gtin": "1234567890123",
"location": "WAREHOUSE_A", "location": "WAREHOUSE_A",
"quantity": 20 "quantity": 20,
} }
response = client.post("/api/v1/stock/remove", headers=auth_headers, json=stock_data) response = client.post(
"/api/v1/stock/remove", headers=auth_headers, json=stock_data
)
assert response.status_code == 400 assert response.status_code == 400
assert "Insufficient stock" in response.json()["detail"] assert "Insufficient stock" in response.json()["detail"]

View File

@@ -1,9 +1,11 @@
# tests/test_stock_service.py # tests/test_stock_service.py
import pytest
import uuid import uuid
import pytest
from app.services.stock_service import StockService from app.services.stock_service import StockService
from models.api_models import StockCreate, StockAdd, StockUpdate from models.api_models import StockAdd, StockCreate, StockUpdate
from models.database_models import Stock, Product from models.database_models import Product, Stock
class TestStockService: class TestStockService:
@@ -39,7 +41,9 @@ class TestStockService:
assert self.service.normalize_gtin("1234567890123") == "1234567890123" # EAN-13 assert self.service.normalize_gtin("1234567890123") == "1234567890123" # EAN-13
assert self.service.normalize_gtin("123456789012") == "123456789012" # UPC-A assert self.service.normalize_gtin("123456789012") == "123456789012" # UPC-A
assert self.service.normalize_gtin("12345678") == "12345678" # EAN-8 assert self.service.normalize_gtin("12345678") == "12345678" # EAN-8
assert self.service.normalize_gtin("12345678901234") == "12345678901234" # GTIN-14 assert (
self.service.normalize_gtin("12345678901234") == "12345678901234"
) # GTIN-14
# Test with decimal points (should be removed) # Test with decimal points (should be removed)
assert self.service.normalize_gtin("1234567890123.0") == "1234567890123" assert self.service.normalize_gtin("1234567890123.0") == "1234567890123"
@@ -49,10 +53,14 @@ class TestStockService:
# Test short GTINs being padded # Test short GTINs being padded
assert self.service.normalize_gtin("123") == "0000000000123" # Padded to EAN-13 assert self.service.normalize_gtin("123") == "0000000000123" # Padded to EAN-13
assert self.service.normalize_gtin("12345") == "0000000012345" # Padded to EAN-13 assert (
self.service.normalize_gtin("12345") == "0000000012345"
) # Padded to EAN-13
# Test long GTINs being truncated # Test long GTINs being truncated
assert self.service.normalize_gtin("123456789012345") == "3456789012345" # Truncated to 13 assert (
self.service.normalize_gtin("123456789012345") == "3456789012345"
) # Truncated to 13
def test_normalize_gtin_edge_cases(self): def test_normalize_gtin_edge_cases(self):
"""Test GTIN normalization edge cases""" """Test GTIN normalization edge cases"""
@@ -61,17 +69,21 @@ class TestStockService:
assert self.service.normalize_gtin(123) == "0000000000123" assert self.service.normalize_gtin(123) == "0000000000123"
# Test mixed valid/invalid characters # Test mixed valid/invalid characters
assert self.service.normalize_gtin("123-456-789-012") == "123456789012" # Dashes removed assert (
assert self.service.normalize_gtin("123 456 789 012") == "123456789012" # Spaces removed self.service.normalize_gtin("123-456-789-012") == "123456789012"
assert self.service.normalize_gtin("ABC123456789012DEF") == "123456789012" # Letters removed ) # Dashes removed
assert (
self.service.normalize_gtin("123 456 789 012") == "123456789012"
) # Spaces removed
assert (
self.service.normalize_gtin("ABC123456789012DEF") == "123456789012"
) # Letters removed
def test_set_stock_new_entry(self, db): def test_set_stock_new_entry(self, db):
"""Test setting stock for a new GTIN/location combination""" """Test setting stock for a new GTIN/location combination"""
unique_id = str(uuid.uuid4())[:8] unique_id = str(uuid.uuid4())[:8]
stock_data = StockCreate( stock_data = StockCreate(
gtin="1234567890123", gtin="1234567890123", location=f"WAREHOUSE_A_{unique_id}", quantity=100
location=f"WAREHOUSE_A_{unique_id}",
quantity=100
) )
result = self.service.set_stock(db, stock_data) result = self.service.set_stock(db, stock_data)
@@ -85,7 +97,7 @@ class TestStockService:
stock_data = StockCreate( stock_data = StockCreate(
gtin=test_stock.gtin, gtin=test_stock.gtin,
location=test_stock.location, # Use exact same location as test_stock location=test_stock.location, # Use exact same location as test_stock
quantity=200 quantity=200,
) )
result = self.service.set_stock(db, stock_data) result = self.service.set_stock(db, stock_data)
@@ -98,9 +110,7 @@ class TestStockService:
def test_set_stock_invalid_gtin(self, db): def test_set_stock_invalid_gtin(self, db):
"""Test setting stock with invalid GTIN""" """Test setting stock with invalid GTIN"""
stock_data = StockCreate( stock_data = StockCreate(
gtin="invalid_gtin", gtin="invalid_gtin", location="WAREHOUSE_A", quantity=100
location="WAREHOUSE_A",
quantity=100
) )
with pytest.raises(ValueError, match="Invalid GTIN format"): with pytest.raises(ValueError, match="Invalid GTIN format"):
@@ -110,9 +120,7 @@ class TestStockService:
"""Test adding stock for a new GTIN/location combination""" """Test adding stock for a new GTIN/location combination"""
unique_id = str(uuid.uuid4())[:8] unique_id = str(uuid.uuid4())[:8]
stock_data = StockAdd( stock_data = StockAdd(
gtin="1234567890123", gtin="1234567890123", location=f"WAREHOUSE_B_{unique_id}", quantity=50
location=f"WAREHOUSE_B_{unique_id}",
quantity=50
) )
result = self.service.add_stock(db, stock_data) result = self.service.add_stock(db, stock_data)
@@ -127,7 +135,7 @@ class TestStockService:
stock_data = StockAdd( stock_data = StockAdd(
gtin=test_stock.gtin, gtin=test_stock.gtin,
location=test_stock.location, # Use exact same location as test_stock location=test_stock.location, # Use exact same location as test_stock
quantity=25 quantity=25,
) )
result = self.service.add_stock(db, stock_data) result = self.service.add_stock(db, stock_data)
@@ -138,11 +146,7 @@ class TestStockService:
def test_add_stock_invalid_gtin(self, db): def test_add_stock_invalid_gtin(self, db):
"""Test adding stock with invalid GTIN""" """Test adding stock with invalid GTIN"""
stock_data = StockAdd( stock_data = StockAdd(gtin="invalid_gtin", location="WAREHOUSE_A", quantity=50)
gtin="invalid_gtin",
location="WAREHOUSE_A",
quantity=50
)
with pytest.raises(ValueError, match="Invalid GTIN format"): with pytest.raises(ValueError, match="Invalid GTIN format"):
self.service.add_stock(db, stock_data) self.service.add_stock(db, stock_data)
@@ -150,12 +154,14 @@ class TestStockService:
def test_remove_stock_success(self, db, test_stock): def test_remove_stock_success(self, db, test_stock):
"""Test removing stock successfully""" """Test removing stock successfully"""
original_quantity = test_stock.quantity original_quantity = test_stock.quantity
remove_quantity = min(10, original_quantity) # Ensure we don't remove more than available remove_quantity = min(
10, original_quantity
) # Ensure we don't remove more than available
stock_data = StockAdd( stock_data = StockAdd(
gtin=test_stock.gtin, gtin=test_stock.gtin,
location=test_stock.location, # Use exact same location as test_stock location=test_stock.location, # Use exact same location as test_stock
quantity=remove_quantity quantity=remove_quantity,
) )
result = self.service.remove_stock(db, stock_data) result = self.service.remove_stock(db, stock_data)
@@ -169,20 +175,20 @@ class TestStockService:
stock_data = StockAdd( stock_data = StockAdd(
gtin=test_stock.gtin, gtin=test_stock.gtin,
location=test_stock.location, # Use exact same location as test_stock location=test_stock.location, # Use exact same location as test_stock
quantity=test_stock.quantity + 10 # More than available quantity=test_stock.quantity + 10, # More than available
) )
# Fix: Use more flexible regex pattern # Fix: Use more flexible regex pattern
with pytest.raises(ValueError, match="Insufficient stock|Not enough stock|Cannot remove"): with pytest.raises(
ValueError, match="Insufficient stock|Not enough stock|Cannot remove"
):
self.service.remove_stock(db, stock_data) self.service.remove_stock(db, stock_data)
def test_remove_stock_nonexistent_entry(self, db): def test_remove_stock_nonexistent_entry(self, db):
"""Test removing stock from non-existent GTIN/location""" """Test removing stock from non-existent GTIN/location"""
unique_id = str(uuid.uuid4())[:8] unique_id = str(uuid.uuid4())[:8]
stock_data = StockAdd( stock_data = StockAdd(
gtin="9999999999999", gtin="9999999999999", location=f"NONEXISTENT_{unique_id}", quantity=10
location=f"NONEXISTENT_{unique_id}",
quantity=10
) )
with pytest.raises(ValueError, match="No stock found|Stock not found"): with pytest.raises(ValueError, match="No stock found|Stock not found"):
@@ -190,11 +196,7 @@ class TestStockService:
def test_remove_stock_invalid_gtin(self, db): def test_remove_stock_invalid_gtin(self, db):
"""Test removing stock with invalid GTIN""" """Test removing stock with invalid GTIN"""
stock_data = StockAdd( stock_data = StockAdd(gtin="invalid_gtin", location="WAREHOUSE_A", quantity=10)
gtin="invalid_gtin",
location="WAREHOUSE_A",
quantity=10
)
with pytest.raises(ValueError, match="Invalid GTIN format"): with pytest.raises(ValueError, match="Invalid GTIN format"):
self.service.remove_stock(db, stock_data) self.service.remove_stock(db, stock_data)
@@ -218,14 +220,10 @@ class TestStockService:
# Create multiple stock entries for the same GTIN with unique locations # Create multiple stock entries for the same GTIN with unique locations
stock1 = Stock( stock1 = Stock(
gtin=unique_gtin, gtin=unique_gtin, location=f"WAREHOUSE_A_{unique_id}", quantity=50
location=f"WAREHOUSE_A_{unique_id}",
quantity=50
) )
stock2 = Stock( stock2 = Stock(
gtin=unique_gtin, gtin=unique_gtin, location=f"WAREHOUSE_B_{unique_id}", quantity=30
location=f"WAREHOUSE_B_{unique_id}",
quantity=30
) )
db.add(stock1) db.add(stock1)
@@ -275,7 +273,9 @@ class TestStockService:
assert len(result) >= 1 assert len(result) >= 1
# Fix: Handle case sensitivity in comparison # Fix: Handle case sensitivity in comparison
assert all(stock.location.upper() == test_stock.location.upper() for stock in result) assert all(
stock.location.upper() == test_stock.location.upper() for stock in result
)
def test_get_all_stock_with_gtin_filter(self, db, test_stock): def test_get_all_stock_with_gtin_filter(self, db, test_stock):
"""Test getting all stock with GTIN filter""" """Test getting all stock with GTIN filter"""
@@ -293,14 +293,16 @@ class TestStockService:
stock = Stock( stock = Stock(
gtin=f"1234567890{i:03d}", # Creates valid 13-digit GTINs: 1234567890000, 1234567890001, etc. gtin=f"1234567890{i:03d}", # Creates valid 13-digit GTINs: 1234567890000, 1234567890001, etc.
location=f"WAREHOUSE_{unique_prefix}_{i}", location=f"WAREHOUSE_{unique_prefix}_{i}",
quantity=10 quantity=10,
) )
db.add(stock) db.add(stock)
db.commit() db.commit()
result = self.service.get_all_stock(db, skip=2, limit=2) result = self.service.get_all_stock(db, skip=2, limit=2)
assert len(result) <= 2 # Should be at most 2, might be less if other records exist assert (
len(result) <= 2
) # Should be at most 2, might be less if other records exist
def test_update_stock_success(self, db, test_stock): def test_update_stock_success(self, db, test_stock):
"""Test updating stock quantity""" """Test updating stock quantity"""
@@ -359,7 +361,7 @@ def test_product_with_stock(db, test_stock):
gtin=test_stock.gtin, gtin=test_stock.gtin,
price="29.99", price="29.99",
brand="TestBrand", brand="TestBrand",
marketplace="Letzshop" marketplace="Letzshop",
) )
db.add(product) db.add(product)
db.commit() db.commit()

View File

@@ -1,5 +1,6 @@
# tests/test_utils.py (Enhanced version of your existing file) # tests/test_utils.py (Enhanced version of your existing file)
import pytest import pytest
from utils.data_processing import GTINProcessor, PriceProcessor from utils.data_processing import GTINProcessor, PriceProcessor

View File

@@ -1,14 +1,15 @@
# utils/csv_processor.py # utils/csv_processor.py
import logging
from datetime import datetime
from io import StringIO
from typing import Any, Dict
import pandas as pd import pandas as pd
import requests import requests
from io import StringIO
from typing import Dict, Any
from sqlalchemy import literal from sqlalchemy import literal
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from models.database_models import Product from models.database_models import Product
from datetime import datetime
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -16,67 +17,66 @@ logger = logging.getLogger(__name__)
class CSVProcessor: class CSVProcessor:
"""Handles CSV import with robust parsing and batching""" """Handles CSV import with robust parsing and batching"""
ENCODINGS = ['utf-8', 'latin-1', 'iso-8859-1', 'cp1252', 'utf-8-sig'] ENCODINGS = ["utf-8", "latin-1", "iso-8859-1", "cp1252", "utf-8-sig"]
PARSING_CONFIGS = [ PARSING_CONFIGS = [
# Try auto-detection first # Try auto-detection first
{'sep': None, 'engine': 'python'}, {"sep": None, "engine": "python"},
# Try semicolon (common in European CSVs) # Try semicolon (common in European CSVs)
{'sep': ';', 'engine': 'python'}, {"sep": ";", "engine": "python"},
# Try comma # Try comma
{'sep': ',', 'engine': 'python'}, {"sep": ",", "engine": "python"},
# Try tab # Try tab
{'sep': '\t', 'engine': 'python'}, {"sep": "\t", "engine": "python"},
] ]
COLUMN_MAPPING = { COLUMN_MAPPING = {
# Standard variations # Standard variations
'id': 'product_id', "id": "product_id",
'ID': 'product_id', "ID": "product_id",
'Product ID': 'product_id', "Product ID": "product_id",
'name': 'title', "name": "title",
'Name': 'title', "Name": "title",
'product_name': 'title', "product_name": "title",
'Product Name': 'title', "Product Name": "title",
# Google Shopping feed standard # Google Shopping feed standard
'g:id': 'product_id', "g:id": "product_id",
'g:title': 'title', "g:title": "title",
'g:description': 'description', "g:description": "description",
'g:link': 'link', "g:link": "link",
'g:image_link': 'image_link', "g:image_link": "image_link",
'g:availability': 'availability', "g:availability": "availability",
'g:price': 'price', "g:price": "price",
'g:brand': 'brand', "g:brand": "brand",
'g:gtin': 'gtin', "g:gtin": "gtin",
'g:mpn': 'mpn', "g:mpn": "mpn",
'g:condition': 'condition', "g:condition": "condition",
'g:adult': 'adult', "g:adult": "adult",
'g:multipack': 'multipack', "g:multipack": "multipack",
'g:is_bundle': 'is_bundle', "g:is_bundle": "is_bundle",
'g:age_group': 'age_group', "g:age_group": "age_group",
'g:color': 'color', "g:color": "color",
'g:gender': 'gender', "g:gender": "gender",
'g:material': 'material', "g:material": "material",
'g:pattern': 'pattern', "g:pattern": "pattern",
'g:size': 'size', "g:size": "size",
'g:size_type': 'size_type', "g:size_type": "size_type",
'g:size_system': 'size_system', "g:size_system": "size_system",
'g:item_group_id': 'item_group_id', "g:item_group_id": "item_group_id",
'g:google_product_category': 'google_product_category', "g:google_product_category": "google_product_category",
'g:product_type': 'product_type', "g:product_type": "product_type",
'g:custom_label_0': 'custom_label_0', "g:custom_label_0": "custom_label_0",
'g:custom_label_1': 'custom_label_1', "g:custom_label_1": "custom_label_1",
'g:custom_label_2': 'custom_label_2', "g:custom_label_2": "custom_label_2",
'g:custom_label_3': 'custom_label_3', "g:custom_label_3": "custom_label_3",
'g:custom_label_4': 'custom_label_4', "g:custom_label_4": "custom_label_4",
# Handle complex shipping column # Handle complex shipping column
'shipping(country:price:max_handling_time:min_transit_time:max_transit_time)': 'shipping' "shipping(country:price:max_handling_time:min_transit_time:max_transit_time)": "shipping",
} }
def __init__(self): def __init__(self):
from utils.data_processing import GTINProcessor, PriceProcessor from utils.data_processing import GTINProcessor, PriceProcessor
self.gtin_processor = GTINProcessor() self.gtin_processor = GTINProcessor()
self.price_processor = PriceProcessor() self.price_processor = PriceProcessor()
@@ -98,7 +98,7 @@ class CSVProcessor:
continue continue
# Fallback with error ignoring # Fallback with error ignoring
decoded_content = content.decode('utf-8', errors='ignore') decoded_content = content.decode("utf-8", errors="ignore")
logger.warning("Used UTF-8 with error ignoring for CSV decoding") logger.warning("Used UTF-8 with error ignoring for CSV decoding")
return decoded_content return decoded_content
@@ -113,11 +113,11 @@ class CSVProcessor:
try: try:
df = pd.read_csv( df = pd.read_csv(
StringIO(csv_content), StringIO(csv_content),
on_bad_lines='skip', on_bad_lines="skip",
quotechar='"', quotechar='"',
skip_blank_lines=True, skip_blank_lines=True,
skipinitialspace=True, skipinitialspace=True,
**config **config,
) )
logger.info(f"Successfully parsed CSV with config: {config}") logger.info(f"Successfully parsed CSV with config: {config}")
return df return df
@@ -143,42 +143,43 @@ class CSVProcessor:
processed_data = {k: (v if pd.notna(v) else None) for k, v in row_data.items()} processed_data = {k: (v if pd.notna(v) else None) for k, v in row_data.items()}
# Process GTIN # Process GTIN
if processed_data.get('gtin'): if processed_data.get("gtin"):
processed_data['gtin'] = self.gtin_processor.normalize(processed_data['gtin']) processed_data["gtin"] = self.gtin_processor.normalize(
processed_data["gtin"]
)
# Process price and currency # Process price and currency
if processed_data.get('price'): if processed_data.get("price"):
parsed_price, currency = self.price_processor.parse_price_currency(processed_data['price']) parsed_price, currency = self.price_processor.parse_price_currency(
processed_data['price'] = parsed_price processed_data["price"]
processed_data['currency'] = currency )
processed_data["price"] = parsed_price
processed_data["currency"] = currency
# Process sale_price # Process sale_price
if processed_data.get('sale_price'): if processed_data.get("sale_price"):
parsed_sale_price, _ = self.price_processor.parse_price_currency(processed_data['sale_price']) parsed_sale_price, _ = self.price_processor.parse_price_currency(
processed_data['sale_price'] = parsed_sale_price processed_data["sale_price"]
)
processed_data["sale_price"] = parsed_sale_price
# Clean MPN (remove .0 endings) # Clean MPN (remove .0 endings)
if processed_data.get('mpn'): if processed_data.get("mpn"):
mpn_str = str(processed_data['mpn']).strip() mpn_str = str(processed_data["mpn"]).strip()
if mpn_str.endswith('.0'): if mpn_str.endswith(".0"):
processed_data['mpn'] = mpn_str[:-2] processed_data["mpn"] = mpn_str[:-2]
# Handle multipack type conversion # Handle multipack type conversion
if processed_data.get('multipack') is not None: if processed_data.get("multipack") is not None:
try: try:
processed_data['multipack'] = int(float(processed_data['multipack'])) processed_data["multipack"] = int(float(processed_data["multipack"]))
except (ValueError, TypeError): except (ValueError, TypeError):
processed_data['multipack'] = None processed_data["multipack"] = None
return processed_data return processed_data
async def process_marketplace_csv_from_url( async def process_marketplace_csv_from_url(
self, self, url: str, marketplace: str, shop_name: str, batch_size: int, db: Session
url: str,
marketplace: str,
shop_name: str,
batch_size: int,
db: Session
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Process CSV from URL with marketplace and shop information Process CSV from URL with marketplace and shop information
@@ -194,7 +195,9 @@ class CSVProcessor:
Dictionary with processing results Dictionary with processing results
""" """
logger.info(f"Starting marketplace CSV import from {url} for {marketplace} -> {shop_name}") logger.info(
f"Starting marketplace CSV import from {url} for {marketplace} -> {shop_name}"
)
# Download and parse CSV # Download and parse CSV
csv_content = self.download_csv(url) csv_content = self.download_csv(url)
df = self.parse_csv(csv_content) df = self.parse_csv(csv_content)
@@ -208,40 +211,42 @@ class CSVProcessor:
# Process in batches # Process in batches
for i in range(0, len(df), batch_size): for i in range(0, len(df), batch_size):
batch_df = df.iloc[i:i + batch_size] batch_df = df.iloc[i : i + batch_size]
batch_result = await self._process_marketplace_batch( batch_result = await self._process_marketplace_batch(
batch_df, marketplace, shop_name, db, i // batch_size + 1 batch_df, marketplace, shop_name, db, i // batch_size + 1
) )
imported += batch_result['imported'] imported += batch_result["imported"]
updated += batch_result['updated'] updated += batch_result["updated"]
errors += batch_result['errors'] errors += batch_result["errors"]
logger.info(f"Processed batch {i // batch_size + 1}: {batch_result}") logger.info(f"Processed batch {i // batch_size + 1}: {batch_result}")
return { return {
'total_processed': imported + updated + errors, "total_processed": imported + updated + errors,
'imported': imported, "imported": imported,
'updated': updated, "updated": updated,
'errors': errors, "errors": errors,
'marketplace': marketplace, "marketplace": marketplace,
'shop_name': shop_name "shop_name": shop_name,
} }
async def _process_marketplace_batch( async def _process_marketplace_batch(
self, self,
batch_df: pd.DataFrame, batch_df: pd.DataFrame,
marketplace: str, marketplace: str,
shop_name: str, shop_name: str,
db: Session, db: Session,
batch_num: int batch_num: int,
) -> Dict[str, int]: ) -> Dict[str, int]:
"""Process a batch of CSV rows with marketplace information""" """Process a batch of CSV rows with marketplace information"""
imported = 0 imported = 0
updated = 0 updated = 0
errors = 0 errors = 0
logger.info(f"Processing batch {batch_num} with {len(batch_df)} rows for {marketplace} -> {shop_name}") logger.info(
f"Processing batch {batch_num} with {len(batch_df)} rows for {marketplace} -> {shop_name}"
)
for index, row in batch_df.iterrows(): for index, row in batch_df.iterrows():
try: try:
@@ -249,42 +254,54 @@ class CSVProcessor:
product_data = self._clean_row_data(row.to_dict()) product_data = self._clean_row_data(row.to_dict())
# Add marketplace and shop information # Add marketplace and shop information
product_data['marketplace'] = marketplace product_data["marketplace"] = marketplace
product_data['shop_name'] = shop_name product_data["shop_name"] = shop_name
# Validate required fields # Validate required fields
if not product_data.get('product_id'): if not product_data.get("product_id"):
logger.warning(f"Row {index}: Missing product_id, skipping") logger.warning(f"Row {index}: Missing product_id, skipping")
errors += 1 errors += 1
continue continue
if not product_data.get('title'): if not product_data.get("title"):
logger.warning(f"Row {index}: Missing title, skipping") logger.warning(f"Row {index}: Missing title, skipping")
errors += 1 errors += 1
continue continue
# Check if product exists # Check if product exists
existing_product = db.query(Product).filter( existing_product = (
Product.product_id == literal(product_data['product_id']) db.query(Product)
).first() .filter(Product.product_id == literal(product_data["product_id"]))
.first()
)
if existing_product: if existing_product:
# Update existing product # Update existing product
for key, value in product_data.items(): for key, value in product_data.items():
if key not in ['id', 'created_at'] and hasattr(existing_product, key): if key not in ["id", "created_at"] and hasattr(
existing_product, key
):
setattr(existing_product, key, value) setattr(existing_product, key, value)
existing_product.updated_at = datetime.utcnow() existing_product.updated_at = datetime.utcnow()
updated += 1 updated += 1
logger.debug(f"Updated product {product_data['product_id']} for {marketplace} and shop {shop_name}") logger.debug(
f"Updated product {product_data['product_id']} for {marketplace} and shop {shop_name}"
)
else: else:
# Create new product # Create new product
filtered_data = {k: v for k, v in product_data.items() filtered_data = {
if k not in ['id', 'created_at', 'updated_at'] and hasattr(Product, k)} k: v
for k, v in product_data.items()
if k not in ["id", "created_at", "updated_at"]
and hasattr(Product, k)
}
new_product = Product(**filtered_data) new_product = Product(**filtered_data)
db.add(new_product) db.add(new_product)
imported += 1 imported += 1
logger.debug(f"Imported new product {product_data['product_id']} for {marketplace} and shop " logger.debug(
f"{shop_name}") f"Imported new product {product_data['product_id']} for {marketplace} and shop "
f"{shop_name}"
)
except Exception as e: except Exception as e:
logger.error(f"Error processing row: {e}") logger.error(f"Error processing row: {e}")
@@ -303,8 +320,4 @@ class CSVProcessor:
imported = 0 imported = 0
updated = 0 updated = 0
return { return {"imported": imported, "updated": updated, "errors": errors}
'imported': imported,
'updated': updated,
'errors': errors
}

View File

@@ -1,8 +1,9 @@
# utils/data_processing.py # utils/data_processing.py
import re
import pandas as pd
from typing import Tuple, Optional
import logging import logging
import re
from typing import Optional, Tuple
import pandas as pd
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -25,11 +26,11 @@ class GTINProcessor:
return None return None
# Remove decimal point (e.g., "889698116923.0" -> "889698116923") # Remove decimal point (e.g., "889698116923.0" -> "889698116923")
if '.' in gtin_str: if "." in gtin_str:
gtin_str = gtin_str.split('.')[0] gtin_str = gtin_str.split(".")[0]
# Keep only digits # Keep only digits
gtin_clean = ''.join(filter(str.isdigit, gtin_str)) gtin_clean = "".join(filter(str.isdigit, gtin_str))
if not gtin_clean: if not gtin_clean:
return None return None
@@ -73,23 +74,23 @@ class PriceProcessor:
CURRENCY_PATTERNS = { CURRENCY_PATTERNS = {
# Amount followed by currency # Amount followed by currency
r'([0-9.,]+)\s*(EUR|€)': lambda m: (m.group(1), 'EUR'), r"([0-9.,]+)\s*(EUR|€)": lambda m: (m.group(1), "EUR"),
r'([0-9.,]+)\s*(USD|\$)': lambda m: (m.group(1), 'USD'), r"([0-9.,]+)\s*(USD|\$)": lambda m: (m.group(1), "USD"),
r'([0-9.,]+)\s*(GBP|£)': lambda m: (m.group(1), 'GBP'), r"([0-9.,]+)\s*(GBP|£)": lambda m: (m.group(1), "GBP"),
r'([0-9.,]+)\s*(CHF)': lambda m: (m.group(1), 'CHF'), r"([0-9.,]+)\s*(CHF)": lambda m: (m.group(1), "CHF"),
r'([0-9.,]+)\s*(CAD|AUD|JPY|¥)': lambda m: (m.group(1), m.group(2).upper()), r"([0-9.,]+)\s*(CAD|AUD|JPY|¥)": lambda m: (m.group(1), m.group(2).upper()),
# Currency followed by amount # Currency followed by amount
r'(EUR|€)\s*([0-9.,]+)': lambda m: (m.group(2), 'EUR'), r"(EUR|€)\s*([0-9.,]+)": lambda m: (m.group(2), "EUR"),
r'(USD|\$)\s*([0-9.,]+)': lambda m: (m.group(2), 'USD'), r"(USD|\$)\s*([0-9.,]+)": lambda m: (m.group(2), "USD"),
r'(GBP|£)\s*([0-9.,]+)': lambda m: (m.group(2), 'GBP'), r"(GBP|£)\s*([0-9.,]+)": lambda m: (m.group(2), "GBP"),
# Generic 3-letter currency codes # Generic 3-letter currency codes
r'([0-9.,]+)\s*([A-Z]{3})': lambda m: (m.group(1), m.group(2)), r"([0-9.,]+)\s*([A-Z]{3})": lambda m: (m.group(1), m.group(2)),
r'([A-Z]{3})\s*([0-9.,]+)': lambda m: (m.group(2), m.group(1)), r"([A-Z]{3})\s*([0-9.,]+)": lambda m: (m.group(2), m.group(1)),
} }
def parse_price_currency(self, price_str: any) -> Tuple[Optional[str], Optional[str]]: def parse_price_currency(
self, price_str: any
) -> Tuple[Optional[str], Optional[str]]:
""" """
Parse price string into (price, currency) tuple Parse price string into (price, currency) tuple
Returns (None, None) if parsing fails Returns (None, None) if parsing fails
@@ -108,7 +109,7 @@ class PriceProcessor:
try: try:
price_val, currency_val = extract_func(match) price_val, currency_val = extract_func(match)
# Normalize price (remove spaces, handle comma as decimal) # Normalize price (remove spaces, handle comma as decimal)
price_val = price_val.replace(' ', '').replace(',', '.') price_val = price_val.replace(" ", "").replace(",", ".")
# Validate numeric # Validate numeric
float(price_val) float(price_val)
return price_val, currency_val.upper() return price_val, currency_val.upper()
@@ -116,10 +117,10 @@ class PriceProcessor:
continue continue
# Fallback: extract just numbers # Fallback: extract just numbers
number_match = re.search(r'([0-9.,]+)', price_str) number_match = re.search(r"([0-9.,]+)", price_str)
if number_match: if number_match:
try: try:
price_val = number_match.group(1).replace(',', '.') price_val = number_match.group(1).replace(",", ".")
float(price_val) # Validate float(price_val) # Validate
return price_val, None return price_val, None
except ValueError: except ValueError:

View File

@@ -1,20 +1,19 @@
# utils/database.py # utils/database.py
import logging
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import QueuePool from sqlalchemy.pool import QueuePool
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_db_engine(database_url: str): def get_db_engine(database_url: str):
"""Create database engine with connection pooling""" """Create database engine with connection pooling"""
if database_url.startswith('sqlite'): if database_url.startswith("sqlite"):
# SQLite configuration # SQLite configuration
engine = create_engine( engine = create_engine(
database_url, database_url, connect_args={"check_same_thread": False}, echo=False
connect_args={"check_same_thread": False},
echo=False
) )
else: else:
# PostgreSQL configuration with connection pooling # PostgreSQL configuration with connection pooling
@@ -24,7 +23,7 @@ def get_db_engine(database_url: str):
pool_size=10, pool_size=10,
max_overflow=20, max_overflow=20,
pool_pre_ping=True, pool_pre_ping=True,
echo=False echo=False,
) )
logger.info(f"Database engine created for: {database_url.split('@')[0]}@...") logger.info(f"Database engine created for: {database_url.split('@')[0]}@...")