code quality run
This commit is contained in:
@@ -1,14 +1,16 @@
|
|||||||
from logging.config import fileConfig
|
|
||||||
from sqlalchemy import engine_from_config, pool
|
|
||||||
from alembic import context
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
from logging.config import fileConfig
|
||||||
|
|
||||||
|
from sqlalchemy import engine_from_config, pool
|
||||||
|
|
||||||
|
from alembic import context
|
||||||
|
|
||||||
# Add your project directory to the Python path
|
# Add your project directory to the Python path
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||||
|
|
||||||
from models.database_models import Base
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
from models.database_models import Base
|
||||||
|
|
||||||
# Alembic Config object
|
# Alembic Config object
|
||||||
config = context.config
|
config = context.config
|
||||||
@@ -45,9 +47,7 @@ def run_migrations_online() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
with connectable.connect() as connection:
|
with connectable.connect() as connection:
|
||||||
context.configure(
|
context.configure(connection=connection, target_metadata=target_metadata)
|
||||||
connection=connection, target_metadata=target_metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
with context.begin_transaction():
|
with context.begin_transaction():
|
||||||
context.run_migrations()
|
context.run_migrations()
|
||||||
|
|||||||
@@ -1,14 +1,16 @@
|
|||||||
from logging.config import fileConfig
|
|
||||||
from sqlalchemy import engine_from_config, pool
|
|
||||||
from alembic import context
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
from logging.config import fileConfig
|
||||||
|
|
||||||
|
from sqlalchemy import engine_from_config, pool
|
||||||
|
|
||||||
|
from alembic import context
|
||||||
|
|
||||||
# Add your project directory to the Python path
|
# Add your project directory to the Python path
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||||
|
|
||||||
from models.database_models import Base
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
from models.database_models import Base
|
||||||
|
|
||||||
# Alembic Config object
|
# Alembic Config object
|
||||||
config = context.config
|
config = context.config
|
||||||
@@ -45,9 +47,7 @@ def run_migrations_online() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
with connectable.connect() as connection:
|
with connectable.connect() as connection:
|
||||||
context.configure(
|
context.configure(connection=connection, target_metadata=target_metadata)
|
||||||
connection=connection, target_metadata=target_metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
with context.begin_transaction():
|
with context.begin_transaction():
|
||||||
context.run_migrations()
|
context.run_migrations()
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
from fastapi import Depends, HTTPException
|
from fastapi import Depends, HTTPException
|
||||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from models.database_models import User, Shop
|
|
||||||
from middleware.auth import AuthManager
|
from middleware.auth import AuthManager
|
||||||
from middleware.rate_limiter import RateLimiter
|
from middleware.rate_limiter import RateLimiter
|
||||||
|
from models.database_models import Shop, User
|
||||||
|
|
||||||
# Set auto_error=False to prevent automatic 403 responses
|
# Set auto_error=False to prevent automatic 403 responses
|
||||||
security = HTTPBearer(auto_error=False)
|
security = HTTPBearer(auto_error=False)
|
||||||
@@ -13,8 +14,8 @@ rate_limiter = RateLimiter()
|
|||||||
|
|
||||||
|
|
||||||
def get_current_user(
|
def get_current_user(
|
||||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""Get current authenticated user"""
|
"""Get current authenticated user"""
|
||||||
# Check if credentials are provided
|
# Check if credentials are provided
|
||||||
@@ -30,9 +31,9 @@ def get_current_admin_user(current_user: User = Depends(get_current_user)):
|
|||||||
|
|
||||||
|
|
||||||
def get_user_shop(
|
def get_user_shop(
|
||||||
shop_code: str,
|
shop_code: str,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""Get shop and verify user ownership"""
|
"""Get shop and verify user ownership"""
|
||||||
shop = db.query(Shop).filter(Shop.shop_code == shop_code.upper()).first()
|
shop = db.query(Shop).filter(Shop.shop_code == shop_code.upper()).first()
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from app.api.v1 import auth, product, stock, shop, marketplace, admin, stats
|
|
||||||
|
from app.api.v1 import admin, auth, marketplace, product, shop, stats, stock
|
||||||
|
|
||||||
api_router = APIRouter()
|
api_router = APIRouter()
|
||||||
|
|
||||||
@@ -9,6 +10,5 @@ api_router.include_router(auth.router, tags=["authentication"])
|
|||||||
api_router.include_router(marketplace.router, tags=["marketplace"])
|
api_router.include_router(marketplace.router, tags=["marketplace"])
|
||||||
api_router.include_router(product.router, tags=["product"])
|
api_router.include_router(product.router, tags=["product"])
|
||||||
api_router.include_router(shop.router, tags=["shop"])
|
api_router.include_router(shop.router, tags=["shop"])
|
||||||
api_router.include_router(stats.router, tags=["statistics"])
|
api_router.include_router(stats.router, tags=["statistics"])
|
||||||
api_router.include_router(stock.router, tags=["stock"])
|
api_router.include_router(stock.router, tags=["stock"])
|
||||||
|
|
||||||
|
|||||||
@@ -1,13 +1,15 @@
|
|||||||
|
import logging
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from app.core.database import get_db
|
|
||||||
from app.api.deps import get_current_admin_user
|
from app.api.deps import get_current_admin_user
|
||||||
|
from app.core.database import get_db
|
||||||
from app.services.admin_service import admin_service
|
from app.services.admin_service import admin_service
|
||||||
from models.api_models import MarketplaceImportJobResponse, UserResponse, ShopListResponse
|
from models.api_models import (MarketplaceImportJobResponse, ShopListResponse,
|
||||||
|
UserResponse)
|
||||||
from models.database_models import User
|
from models.database_models import User
|
||||||
import logging
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -16,10 +18,10 @@ logger = logging.getLogger(__name__)
|
|||||||
# Admin-only routes
|
# Admin-only routes
|
||||||
@router.get("/admin/users", response_model=List[UserResponse])
|
@router.get("/admin/users", response_model=List[UserResponse])
|
||||||
def get_all_users(
|
def get_all_users(
|
||||||
skip: int = Query(0, ge=0),
|
skip: int = Query(0, ge=0),
|
||||||
limit: int = Query(100, ge=1, le=1000),
|
limit: int = Query(100, ge=1, le=1000),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_admin: User = Depends(get_current_admin_user)
|
current_admin: User = Depends(get_current_admin_user),
|
||||||
):
|
):
|
||||||
"""Get all users (Admin only)"""
|
"""Get all users (Admin only)"""
|
||||||
try:
|
try:
|
||||||
@@ -32,9 +34,9 @@ def get_all_users(
|
|||||||
|
|
||||||
@router.put("/admin/users/{user_id}/status")
|
@router.put("/admin/users/{user_id}/status")
|
||||||
def toggle_user_status(
|
def toggle_user_status(
|
||||||
user_id: int,
|
user_id: int,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_admin: User = Depends(get_current_admin_user)
|
current_admin: User = Depends(get_current_admin_user),
|
||||||
):
|
):
|
||||||
"""Toggle user active status (Admin only)"""
|
"""Toggle user active status (Admin only)"""
|
||||||
try:
|
try:
|
||||||
@@ -49,21 +51,16 @@ def toggle_user_status(
|
|||||||
|
|
||||||
@router.get("/admin/shops", response_model=ShopListResponse)
|
@router.get("/admin/shops", response_model=ShopListResponse)
|
||||||
def get_all_shops_admin(
|
def get_all_shops_admin(
|
||||||
skip: int = Query(0, ge=0),
|
skip: int = Query(0, ge=0),
|
||||||
limit: int = Query(100, ge=1, le=1000),
|
limit: int = Query(100, ge=1, le=1000),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_admin: User = Depends(get_current_admin_user)
|
current_admin: User = Depends(get_current_admin_user),
|
||||||
):
|
):
|
||||||
"""Get all shops with admin view (Admin only)"""
|
"""Get all shops with admin view (Admin only)"""
|
||||||
try:
|
try:
|
||||||
shops, total = admin_service.get_all_shops(db=db, skip=skip, limit=limit)
|
shops, total = admin_service.get_all_shops(db=db, skip=skip, limit=limit)
|
||||||
|
|
||||||
return ShopListResponse(
|
return ShopListResponse(shops=shops, total=total, skip=skip, limit=limit)
|
||||||
shops=shops,
|
|
||||||
total=total,
|
|
||||||
skip=skip,
|
|
||||||
limit=limit
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting shops: {str(e)}")
|
logger.error(f"Error getting shops: {str(e)}")
|
||||||
raise HTTPException(status_code=500, detail="Internal server error")
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
@@ -71,9 +68,9 @@ def get_all_shops_admin(
|
|||||||
|
|
||||||
@router.put("/admin/shops/{shop_id}/verify")
|
@router.put("/admin/shops/{shop_id}/verify")
|
||||||
def verify_shop(
|
def verify_shop(
|
||||||
shop_id: int,
|
shop_id: int,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_admin: User = Depends(get_current_admin_user)
|
current_admin: User = Depends(get_current_admin_user),
|
||||||
):
|
):
|
||||||
"""Verify/unverify shop (Admin only)"""
|
"""Verify/unverify shop (Admin only)"""
|
||||||
try:
|
try:
|
||||||
@@ -88,9 +85,9 @@ def verify_shop(
|
|||||||
|
|
||||||
@router.put("/admin/shops/{shop_id}/status")
|
@router.put("/admin/shops/{shop_id}/status")
|
||||||
def toggle_shop_status(
|
def toggle_shop_status(
|
||||||
shop_id: int,
|
shop_id: int,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_admin: User = Depends(get_current_admin_user)
|
current_admin: User = Depends(get_current_admin_user),
|
||||||
):
|
):
|
||||||
"""Toggle shop active status (Admin only)"""
|
"""Toggle shop active status (Admin only)"""
|
||||||
try:
|
try:
|
||||||
@@ -103,15 +100,17 @@ def toggle_shop_status(
|
|||||||
raise HTTPException(status_code=500, detail="Internal server error")
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/admin/marketplace-import-jobs", response_model=List[MarketplaceImportJobResponse])
|
@router.get(
|
||||||
|
"/admin/marketplace-import-jobs", response_model=List[MarketplaceImportJobResponse]
|
||||||
|
)
|
||||||
def get_all_marketplace_import_jobs(
|
def get_all_marketplace_import_jobs(
|
||||||
marketplace: Optional[str] = Query(None),
|
marketplace: Optional[str] = Query(None),
|
||||||
shop_name: Optional[str] = Query(None),
|
shop_name: Optional[str] = Query(None),
|
||||||
status: Optional[str] = Query(None),
|
status: Optional[str] = Query(None),
|
||||||
skip: int = Query(0, ge=0),
|
skip: int = Query(0, ge=0),
|
||||||
limit: int = Query(100, ge=1, le=100),
|
limit: int = Query(100, ge=1, le=100),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_admin: User = Depends(get_current_admin_user)
|
current_admin: User = Depends(get_current_admin_user),
|
||||||
):
|
):
|
||||||
"""Get all marketplace import jobs (Admin only)"""
|
"""Get all marketplace import jobs (Admin only)"""
|
||||||
try:
|
try:
|
||||||
@@ -121,7 +120,7 @@ def get_all_marketplace_import_jobs(
|
|||||||
shop_name=shop_name,
|
shop_name=shop_name,
|
||||||
status=status,
|
status=status,
|
||||||
skip=skip,
|
skip=skip,
|
||||||
limit=limit
|
limit=limit,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting marketplace import jobs: {str(e)}")
|
logger.error(f"Error getting marketplace import jobs: {str(e)}")
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from app.core.database import get_db
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
|
from app.core.database import get_db
|
||||||
from app.services.auth_service import auth_service
|
from app.services.auth_service import auth_service
|
||||||
from models.api_models import UserRegister, UserLogin, UserResponse, LoginResponse
|
from models.api_models import (LoginResponse, UserLogin, UserRegister,
|
||||||
|
UserResponse)
|
||||||
from models.database_models import User
|
from models.database_models import User
|
||||||
import logging
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -35,7 +38,7 @@ def login_user(user_credentials: UserLogin, db: Session = Depends(get_db)):
|
|||||||
access_token=login_result["token_data"]["access_token"],
|
access_token=login_result["token_data"]["access_token"],
|
||||||
token_type=login_result["token_data"]["token_type"],
|
token_type=login_result["token_data"]["token_type"],
|
||||||
expires_in=login_result["token_data"]["expires_in"],
|
expires_in=login_result["token_data"]["expires_in"],
|
||||||
user=UserResponse.model_validate(login_result["user"])
|
user=UserResponse.model_validate(login_result["user"]),
|
||||||
)
|
)
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -1,15 +1,17 @@
|
|||||||
|
import logging
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
|
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from app.core.database import get_db
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
|
from app.core.database import get_db
|
||||||
|
from app.services.marketplace_service import marketplace_service
|
||||||
from app.tasks.background_tasks import process_marketplace_import
|
from app.tasks.background_tasks import process_marketplace_import
|
||||||
from middleware.decorators import rate_limit
|
from middleware.decorators import rate_limit
|
||||||
from models.api_models import MarketplaceImportJobResponse, MarketplaceImportRequest
|
from models.api_models import (MarketplaceImportJobResponse,
|
||||||
|
MarketplaceImportRequest)
|
||||||
from models.database_models import User
|
from models.database_models import User
|
||||||
from app.services.marketplace_service import marketplace_service
|
|
||||||
import logging
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -19,15 +21,16 @@ logger = logging.getLogger(__name__)
|
|||||||
@router.post("/marketplace/import-product", response_model=MarketplaceImportJobResponse)
|
@router.post("/marketplace/import-product", response_model=MarketplaceImportJobResponse)
|
||||||
@rate_limit(max_requests=10, window_seconds=3600) # Limit marketplace imports
|
@rate_limit(max_requests=10, window_seconds=3600) # Limit marketplace imports
|
||||||
async def import_products_from_marketplace(
|
async def import_products_from_marketplace(
|
||||||
request: MarketplaceImportRequest,
|
request: MarketplaceImportRequest,
|
||||||
background_tasks: BackgroundTasks,
|
background_tasks: BackgroundTasks,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Import products from marketplace CSV with background processing (Protected)"""
|
"""Import products from marketplace CSV with background processing (Protected)"""
|
||||||
try:
|
try:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Starting marketplace import: {request.marketplace} -> {request.shop_code} by user {current_user.username}")
|
f"Starting marketplace import: {request.marketplace} -> {request.shop_code} by user {current_user.username}"
|
||||||
|
)
|
||||||
|
|
||||||
# Create import job through service
|
# Create import job through service
|
||||||
import_job = marketplace_service.create_import_job(db, request, current_user)
|
import_job = marketplace_service.create_import_job(db, request, current_user)
|
||||||
@@ -39,7 +42,7 @@ async def import_products_from_marketplace(
|
|||||||
request.url,
|
request.url,
|
||||||
request.marketplace,
|
request.marketplace,
|
||||||
request.shop_code,
|
request.shop_code,
|
||||||
request.batch_size or 1000
|
request.batch_size or 1000,
|
||||||
)
|
)
|
||||||
|
|
||||||
return MarketplaceImportJobResponse(
|
return MarketplaceImportJobResponse(
|
||||||
@@ -50,7 +53,7 @@ async def import_products_from_marketplace(
|
|||||||
shop_id=import_job.shop_id,
|
shop_id=import_job.shop_id,
|
||||||
shop_name=import_job.shop_name,
|
shop_name=import_job.shop_name,
|
||||||
message=f"Marketplace import started from {request.marketplace}. Check status with "
|
message=f"Marketplace import started from {request.marketplace}. Check status with "
|
||||||
f"/import-status/{import_job.id}"
|
f"/import-status/{import_job.id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
@@ -62,11 +65,13 @@ async def import_products_from_marketplace(
|
|||||||
raise HTTPException(status_code=500, detail="Internal server error")
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/marketplace/import-status/{job_id}", response_model=MarketplaceImportJobResponse)
|
@router.get(
|
||||||
|
"/marketplace/import-status/{job_id}", response_model=MarketplaceImportJobResponse
|
||||||
|
)
|
||||||
def get_marketplace_import_status(
|
def get_marketplace_import_status(
|
||||||
job_id: int,
|
job_id: int,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Get status of marketplace import job (Protected)"""
|
"""Get status of marketplace import job (Protected)"""
|
||||||
try:
|
try:
|
||||||
@@ -82,14 +87,16 @@ def get_marketplace_import_status(
|
|||||||
raise HTTPException(status_code=500, detail="Internal server error")
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/marketplace/import-jobs", response_model=List[MarketplaceImportJobResponse])
|
@router.get(
|
||||||
|
"/marketplace/import-jobs", response_model=List[MarketplaceImportJobResponse]
|
||||||
|
)
|
||||||
def get_marketplace_import_jobs(
|
def get_marketplace_import_jobs(
|
||||||
marketplace: Optional[str] = Query(None, description="Filter by marketplace"),
|
marketplace: Optional[str] = Query(None, description="Filter by marketplace"),
|
||||||
shop_name: Optional[str] = Query(None, description="Filter by shop name"),
|
shop_name: Optional[str] = Query(None, description="Filter by shop name"),
|
||||||
skip: int = Query(0, ge=0),
|
skip: int = Query(0, ge=0),
|
||||||
limit: int = Query(50, ge=1, le=100),
|
limit: int = Query(50, ge=1, le=100),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Get marketplace import jobs with filtering (Protected)"""
|
"""Get marketplace import jobs with filtering (Protected)"""
|
||||||
try:
|
try:
|
||||||
@@ -99,7 +106,7 @@ def get_marketplace_import_jobs(
|
|||||||
marketplace=marketplace,
|
marketplace=marketplace,
|
||||||
shop_name=shop_name,
|
shop_name=shop_name,
|
||||||
skip=skip,
|
skip=skip,
|
||||||
limit=limit
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
return [marketplace_service.convert_to_response_model(job) for job in jobs]
|
return [marketplace_service.convert_to_response_model(job) for job in jobs]
|
||||||
@@ -111,8 +118,7 @@ def get_marketplace_import_jobs(
|
|||||||
|
|
||||||
@router.get("/marketplace/marketplace-import-stats")
|
@router.get("/marketplace/marketplace-import-stats")
|
||||||
def get_marketplace_import_stats(
|
def get_marketplace_import_stats(
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db), current_user: User = Depends(get_current_user)
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
):
|
||||||
"""Get statistics about marketplace import jobs (Protected)"""
|
"""Get statistics about marketplace import jobs (Protected)"""
|
||||||
try:
|
try:
|
||||||
@@ -124,11 +130,14 @@ def get_marketplace_import_stats(
|
|||||||
raise HTTPException(status_code=500, detail="Internal server error")
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
|
|
||||||
|
|
||||||
@router.put("/marketplace/import-jobs/{job_id}/cancel", response_model=MarketplaceImportJobResponse)
|
@router.put(
|
||||||
|
"/marketplace/import-jobs/{job_id}/cancel",
|
||||||
|
response_model=MarketplaceImportJobResponse,
|
||||||
|
)
|
||||||
def cancel_marketplace_import_job(
|
def cancel_marketplace_import_job(
|
||||||
job_id: int,
|
job_id: int,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Cancel a pending or running marketplace import job (Protected)"""
|
"""Cancel a pending or running marketplace import job (Protected)"""
|
||||||
try:
|
try:
|
||||||
@@ -146,9 +155,9 @@ def cancel_marketplace_import_job(
|
|||||||
|
|
||||||
@router.delete("/marketplace/import-jobs/{job_id}")
|
@router.delete("/marketplace/import-jobs/{job_id}")
|
||||||
def delete_marketplace_import_job(
|
def delete_marketplace_import_job(
|
||||||
job_id: int,
|
job_id: int,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Delete a completed marketplace import job (Protected)"""
|
"""Delete a completed marketplace import job (Protected)"""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,17 +1,17 @@
|
|||||||
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from app.core.database import get_db
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from models.api_models import (ProductListResponse, ProductResponse, ProductCreate, ProductDetailResponse,
|
from app.core.database import get_db
|
||||||
|
from app.services.product_service import product_service
|
||||||
|
from models.api_models import (ProductCreate, ProductDetailResponse,
|
||||||
|
ProductListResponse, ProductResponse,
|
||||||
ProductUpdate)
|
ProductUpdate)
|
||||||
from models.database_models import User
|
from models.database_models import User
|
||||||
import logging
|
|
||||||
|
|
||||||
from app.services.product_service import product_service
|
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -20,16 +20,16 @@ logger = logging.getLogger(__name__)
|
|||||||
# Enhanced Product Routes with Marketplace Support
|
# Enhanced Product Routes with Marketplace Support
|
||||||
@router.get("/product", response_model=ProductListResponse)
|
@router.get("/product", response_model=ProductListResponse)
|
||||||
def get_products(
|
def get_products(
|
||||||
skip: int = Query(0, ge=0),
|
skip: int = Query(0, ge=0),
|
||||||
limit: int = Query(100, ge=1, le=1000),
|
limit: int = Query(100, ge=1, le=1000),
|
||||||
brand: Optional[str] = Query(None),
|
brand: Optional[str] = Query(None),
|
||||||
category: Optional[str] = Query(None),
|
category: Optional[str] = Query(None),
|
||||||
availability: Optional[str] = Query(None),
|
availability: Optional[str] = Query(None),
|
||||||
marketplace: Optional[str] = Query(None, description="Filter by marketplace"),
|
marketplace: Optional[str] = Query(None, description="Filter by marketplace"),
|
||||||
shop_name: Optional[str] = Query(None, description="Filter by shop name"),
|
shop_name: Optional[str] = Query(None, description="Filter by shop name"),
|
||||||
search: Optional[str] = Query(None),
|
search: Optional[str] = Query(None),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Get products with advanced filtering including marketplace and shop (Protected)"""
|
"""Get products with advanced filtering including marketplace and shop (Protected)"""
|
||||||
|
|
||||||
@@ -43,14 +43,11 @@ def get_products(
|
|||||||
availability=availability,
|
availability=availability,
|
||||||
marketplace=marketplace,
|
marketplace=marketplace,
|
||||||
shop_name=shop_name,
|
shop_name=shop_name,
|
||||||
search=search
|
search=search,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ProductListResponse(
|
return ProductListResponse(
|
||||||
products=products,
|
products=products, total=total, skip=skip, limit=limit
|
||||||
total=total,
|
|
||||||
skip=skip,
|
|
||||||
limit=limit
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting products: {str(e)}")
|
logger.error(f"Error getting products: {str(e)}")
|
||||||
@@ -59,9 +56,9 @@ def get_products(
|
|||||||
|
|
||||||
@router.post("/product", response_model=ProductResponse)
|
@router.post("/product", response_model=ProductResponse)
|
||||||
def create_product(
|
def create_product(
|
||||||
product: ProductCreate,
|
product: ProductCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Create a new product with validation and marketplace support (Protected)"""
|
"""Create a new product with validation and marketplace support (Protected)"""
|
||||||
|
|
||||||
@@ -75,7 +72,9 @@ def create_product(
|
|||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
logger.info("Product already exists, raising 400 error")
|
logger.info("Product already exists, raising 400 error")
|
||||||
raise HTTPException(status_code=400, detail="Product with this ID already exists")
|
raise HTTPException(
|
||||||
|
status_code=400, detail="Product with this ID already exists"
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("No existing product found, proceeding to create...")
|
logger.info("No existing product found, proceeding to create...")
|
||||||
db_product = product_service.create_product(db, product)
|
db_product = product_service.create_product(db, product)
|
||||||
@@ -93,11 +92,12 @@ def create_product(
|
|||||||
logger.error(f"Unexpected error: {str(e)}", exc_info=True)
|
logger.error(f"Unexpected error: {str(e)}", exc_info=True)
|
||||||
raise HTTPException(status_code=500, detail="Internal server error")
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/product/{product_id}", response_model=ProductDetailResponse)
|
@router.get("/product/{product_id}", response_model=ProductDetailResponse)
|
||||||
def get_product(
|
def get_product(
|
||||||
product_id: str,
|
product_id: str,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Get product with stock information (Protected)"""
|
"""Get product with stock information (Protected)"""
|
||||||
|
|
||||||
@@ -111,10 +111,7 @@ def get_product(
|
|||||||
if product.gtin:
|
if product.gtin:
|
||||||
stock_info = product_service.get_stock_info(db, product.gtin)
|
stock_info = product_service.get_stock_info(db, product.gtin)
|
||||||
|
|
||||||
return ProductDetailResponse(
|
return ProductDetailResponse(product=product, stock_info=stock_info)
|
||||||
product=product,
|
|
||||||
stock_info=stock_info
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -125,10 +122,10 @@ def get_product(
|
|||||||
|
|
||||||
@router.put("/product/{product_id}", response_model=ProductResponse)
|
@router.put("/product/{product_id}", response_model=ProductResponse)
|
||||||
def update_product(
|
def update_product(
|
||||||
product_id: str,
|
product_id: str,
|
||||||
product_update: ProductUpdate,
|
product_update: ProductUpdate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Update product with validation and marketplace support (Protected)"""
|
"""Update product with validation and marketplace support (Protected)"""
|
||||||
|
|
||||||
@@ -151,9 +148,9 @@ def update_product(
|
|||||||
|
|
||||||
@router.delete("/product/{product_id}")
|
@router.delete("/product/{product_id}")
|
||||||
def delete_product(
|
def delete_product(
|
||||||
product_id: str,
|
product_id: str,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Delete product and associated stock (Protected)"""
|
"""Delete product and associated stock (Protected)"""
|
||||||
|
|
||||||
@@ -176,19 +173,18 @@ def delete_product(
|
|||||||
# Export with streaming for large datasets (Protected)
|
# Export with streaming for large datasets (Protected)
|
||||||
@router.get("/export-csv")
|
@router.get("/export-csv")
|
||||||
async def export_csv(
|
async def export_csv(
|
||||||
marketplace: Optional[str] = Query(None, description="Filter by marketplace"),
|
marketplace: Optional[str] = Query(None, description="Filter by marketplace"),
|
||||||
shop_name: Optional[str] = Query(None, description="Filter by shop name"),
|
shop_name: Optional[str] = Query(None, description="Filter by shop name"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Export products as CSV with streaming and marketplace filtering (Protected)"""
|
"""Export products as CSV with streaming and marketplace filtering (Protected)"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
def generate_csv():
|
def generate_csv():
|
||||||
return product_service.generate_csv_export(
|
return product_service.generate_csv_export(
|
||||||
db=db,
|
db=db, marketplace=marketplace, shop_name=shop_name
|
||||||
marketplace=marketplace,
|
|
||||||
shop_name=shop_name
|
|
||||||
)
|
)
|
||||||
|
|
||||||
filename = "products_export"
|
filename = "products_export"
|
||||||
@@ -201,7 +197,7 @@ async def export_csv(
|
|||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
generate_csv(),
|
generate_csv(),
|
||||||
media_type="text/csv",
|
media_type="text/csv",
|
||||||
headers={"Content-Disposition": f"attachment; filename={filename}"}
|
headers={"Content-Disposition": f"attachment; filename={filename}"},
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -1,17 +1,21 @@
|
|||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
|
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from app.core.database import get_db
|
|
||||||
from app.api.deps import get_current_user, get_user_shop
|
from app.api.deps import get_current_user, get_user_shop
|
||||||
|
from app.core.database import get_db
|
||||||
from app.services.shop_service import shop_service
|
from app.services.shop_service import shop_service
|
||||||
from app.tasks.background_tasks import process_marketplace_import
|
from app.tasks.background_tasks import process_marketplace_import
|
||||||
from middleware.decorators import rate_limit
|
from middleware.decorators import rate_limit
|
||||||
from models.api_models import MarketplaceImportJobResponse, MarketplaceImportRequest, ShopResponse, ShopCreate, \
|
from models.api_models import (MarketplaceImportJobResponse,
|
||||||
ShopListResponse, ShopProductResponse, ShopProductCreate
|
MarketplaceImportRequest, ShopCreate,
|
||||||
from models.database_models import User, MarketplaceImportJob, Shop, Product, ShopProduct
|
ShopListResponse, ShopProductCreate,
|
||||||
from datetime import datetime
|
ShopProductResponse, ShopResponse)
|
||||||
import logging
|
from models.database_models import (MarketplaceImportJob, Product, Shop,
|
||||||
|
ShopProduct, User)
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -20,13 +24,15 @@ logger = logging.getLogger(__name__)
|
|||||||
# Shop Management Routes
|
# Shop Management Routes
|
||||||
@router.post("/shop", response_model=ShopResponse)
|
@router.post("/shop", response_model=ShopResponse)
|
||||||
def create_shop(
|
def create_shop(
|
||||||
shop_data: ShopCreate,
|
shop_data: ShopCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Create a new shop (Protected)"""
|
"""Create a new shop (Protected)"""
|
||||||
try:
|
try:
|
||||||
shop = shop_service.create_shop(db=db, shop_data=shop_data, current_user=current_user)
|
shop = shop_service.create_shop(
|
||||||
|
db=db, shop_data=shop_data, current_user=current_user
|
||||||
|
)
|
||||||
return ShopResponse.model_validate(shop)
|
return ShopResponse.model_validate(shop)
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -37,12 +43,12 @@ def create_shop(
|
|||||||
|
|
||||||
@router.get("/shop", response_model=ShopListResponse)
|
@router.get("/shop", response_model=ShopListResponse)
|
||||||
def get_shops(
|
def get_shops(
|
||||||
skip: int = Query(0, ge=0),
|
skip: int = Query(0, ge=0),
|
||||||
limit: int = Query(100, ge=1, le=1000),
|
limit: int = Query(100, ge=1, le=1000),
|
||||||
active_only: bool = Query(True),
|
active_only: bool = Query(True),
|
||||||
verified_only: bool = Query(False),
|
verified_only: bool = Query(False),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Get shops with filtering (Protected)"""
|
"""Get shops with filtering (Protected)"""
|
||||||
try:
|
try:
|
||||||
@@ -52,15 +58,10 @@ def get_shops(
|
|||||||
skip=skip,
|
skip=skip,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
active_only=active_only,
|
active_only=active_only,
|
||||||
verified_only=verified_only
|
verified_only=verified_only,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ShopListResponse(
|
return ShopListResponse(shops=shops, total=total, skip=skip, limit=limit)
|
||||||
shops=shops,
|
|
||||||
total=total,
|
|
||||||
skip=skip,
|
|
||||||
limit=limit
|
|
||||||
)
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -69,10 +70,16 @@ def get_shops(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/shop/{shop_code}", response_model=ShopResponse)
|
@router.get("/shop/{shop_code}", response_model=ShopResponse)
|
||||||
def get_shop(shop_code: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
|
def get_shop(
|
||||||
|
shop_code: str,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
):
|
||||||
"""Get shop details (Protected)"""
|
"""Get shop details (Protected)"""
|
||||||
try:
|
try:
|
||||||
shop = shop_service.get_shop_by_code(db=db, shop_code=shop_code, current_user=current_user)
|
shop = shop_service.get_shop_by_code(
|
||||||
|
db=db, shop_code=shop_code, current_user=current_user
|
||||||
|
)
|
||||||
return ShopResponse.model_validate(shop)
|
return ShopResponse.model_validate(shop)
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -84,10 +91,10 @@ def get_shop(shop_code: str, db: Session = Depends(get_db), current_user: User =
|
|||||||
# Shop Product Management
|
# Shop Product Management
|
||||||
@router.post("/shop/{shop_code}/products", response_model=ShopProductResponse)
|
@router.post("/shop/{shop_code}/products", response_model=ShopProductResponse)
|
||||||
def add_product_to_shop(
|
def add_product_to_shop(
|
||||||
shop_code: str,
|
shop_code: str,
|
||||||
shop_product: ShopProductCreate,
|
shop_product: ShopProductCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Add existing product to shop catalog with shop-specific settings (Protected)"""
|
"""Add existing product to shop catalog with shop-specific settings (Protected)"""
|
||||||
try:
|
try:
|
||||||
@@ -96,9 +103,7 @@ def add_product_to_shop(
|
|||||||
|
|
||||||
# Add product to shop
|
# Add product to shop
|
||||||
new_shop_product = shop_service.add_product_to_shop(
|
new_shop_product = shop_service.add_product_to_shop(
|
||||||
db=db,
|
db=db, shop=shop, shop_product=shop_product
|
||||||
shop=shop,
|
|
||||||
shop_product=shop_product
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return with product details
|
# Return with product details
|
||||||
@@ -114,18 +119,20 @@ def add_product_to_shop(
|
|||||||
|
|
||||||
@router.get("/shop/{shop_code}/products")
|
@router.get("/shop/{shop_code}/products")
|
||||||
def get_shop_products(
|
def get_shop_products(
|
||||||
shop_code: str,
|
shop_code: str,
|
||||||
skip: int = Query(0, ge=0),
|
skip: int = Query(0, ge=0),
|
||||||
limit: int = Query(100, ge=1, le=1000),
|
limit: int = Query(100, ge=1, le=1000),
|
||||||
active_only: bool = Query(True),
|
active_only: bool = Query(True),
|
||||||
featured_only: bool = Query(False),
|
featured_only: bool = Query(False),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Get products in shop catalog (Protected)"""
|
"""Get products in shop catalog (Protected)"""
|
||||||
try:
|
try:
|
||||||
# Get shop
|
# Get shop
|
||||||
shop = shop_service.get_shop_by_code(db=db, shop_code=shop_code, current_user=current_user)
|
shop = shop_service.get_shop_by_code(
|
||||||
|
db=db, shop_code=shop_code, current_user=current_user
|
||||||
|
)
|
||||||
|
|
||||||
# Get shop products
|
# Get shop products
|
||||||
shop_products, total = shop_service.get_shop_products(
|
shop_products, total = shop_service.get_shop_products(
|
||||||
@@ -135,7 +142,7 @@ def get_shop_products(
|
|||||||
skip=skip,
|
skip=skip,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
active_only=active_only,
|
active_only=active_only,
|
||||||
featured_only=featured_only
|
featured_only=featured_only,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Format response
|
# Format response
|
||||||
@@ -150,7 +157,7 @@ def get_shop_products(
|
|||||||
"total": total,
|
"total": total,
|
||||||
"skip": skip,
|
"skip": skip,
|
||||||
"limit": limit,
|
"limit": limit,
|
||||||
"shop": ShopResponse.model_validate(shop)
|
"shop": ShopResponse.model_validate(shop),
|
||||||
}
|
}
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -1,18 +1,21 @@
|
|||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
|
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from app.core.database import get_db
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
|
from app.core.database import get_db
|
||||||
from app.services.stats_service import stats_service
|
from app.services.stats_service import stats_service
|
||||||
from app.tasks.background_tasks import process_marketplace_import
|
from app.tasks.background_tasks import process_marketplace_import
|
||||||
from middleware.decorators import rate_limit
|
from middleware.decorators import rate_limit
|
||||||
from models.api_models import MarketplaceImportJobResponse, MarketplaceImportRequest, StatsResponse, \
|
from models.api_models import (MarketplaceImportJobResponse,
|
||||||
MarketplaceStatsResponse
|
MarketplaceImportRequest,
|
||||||
from models.database_models import User, MarketplaceImportJob, Shop, Product, Stock
|
MarketplaceStatsResponse, StatsResponse)
|
||||||
from datetime import datetime
|
from models.database_models import (MarketplaceImportJob, Product, Shop, Stock,
|
||||||
import logging
|
User)
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -20,7 +23,9 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
# Enhanced Statistics with Marketplace Support
|
# Enhanced Statistics with Marketplace Support
|
||||||
@router.get("/stats", response_model=StatsResponse)
|
@router.get("/stats", response_model=StatsResponse)
|
||||||
def get_stats(db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
|
def get_stats(
|
||||||
|
db: Session = Depends(get_db), current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
"""Get comprehensive statistics with marketplace data (Protected)"""
|
"""Get comprehensive statistics with marketplace data (Protected)"""
|
||||||
try:
|
try:
|
||||||
stats_data = stats_service.get_comprehensive_stats(db=db)
|
stats_data = stats_service.get_comprehensive_stats(db=db)
|
||||||
@@ -32,7 +37,7 @@ def get_stats(db: Session = Depends(get_db), current_user: User = Depends(get_cu
|
|||||||
unique_marketplaces=stats_data["unique_marketplaces"],
|
unique_marketplaces=stats_data["unique_marketplaces"],
|
||||||
unique_shops=stats_data["unique_shops"],
|
unique_shops=stats_data["unique_shops"],
|
||||||
total_stock_entries=stats_data["total_stock_entries"],
|
total_stock_entries=stats_data["total_stock_entries"],
|
||||||
total_inventory_quantity=stats_data["total_inventory_quantity"]
|
total_inventory_quantity=stats_data["total_inventory_quantity"],
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting comprehensive stats: {str(e)}")
|
logger.error(f"Error getting comprehensive stats: {str(e)}")
|
||||||
@@ -40,7 +45,9 @@ def get_stats(db: Session = Depends(get_db), current_user: User = Depends(get_cu
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/stats/marketplace", response_model=List[MarketplaceStatsResponse])
|
@router.get("/stats/marketplace", response_model=List[MarketplaceStatsResponse])
|
||||||
def get_marketplace_stats(db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
|
def get_marketplace_stats(
|
||||||
|
db: Session = Depends(get_db), current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
"""Get statistics broken down by marketplace (Protected)"""
|
"""Get statistics broken down by marketplace (Protected)"""
|
||||||
try:
|
try:
|
||||||
marketplace_stats = stats_service.get_marketplace_breakdown_stats(db=db)
|
marketplace_stats = stats_service.get_marketplace_breakdown_stats(db=db)
|
||||||
@@ -50,8 +57,9 @@ def get_marketplace_stats(db: Session = Depends(get_db), current_user: User = De
|
|||||||
marketplace=stat["marketplace"],
|
marketplace=stat["marketplace"],
|
||||||
total_products=stat["total_products"],
|
total_products=stat["total_products"],
|
||||||
unique_shops=stat["unique_shops"],
|
unique_shops=stat["unique_shops"],
|
||||||
unique_brands=stat["unique_brands"]
|
unique_brands=stat["unique_brands"],
|
||||||
) for stat in marketplace_stats
|
)
|
||||||
|
for stat in marketplace_stats
|
||||||
]
|
]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting marketplace stats: {str(e)}")
|
logger.error(f"Error getting marketplace stats: {str(e)}")
|
||||||
|
|||||||
@@ -1,16 +1,19 @@
|
|||||||
|
import logging
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
|
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from app.core.database import get_db
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
|
from app.core.database import get_db
|
||||||
|
from app.services.stock_service import stock_service
|
||||||
from app.tasks.background_tasks import process_marketplace_import
|
from app.tasks.background_tasks import process_marketplace_import
|
||||||
from middleware.decorators import rate_limit
|
from middleware.decorators import rate_limit
|
||||||
from models.api_models import (MarketplaceImportJobResponse, MarketplaceImportRequest, StockResponse,
|
from models.api_models import (MarketplaceImportJobResponse,
|
||||||
StockSummaryResponse, StockCreate, StockAdd, StockUpdate)
|
MarketplaceImportRequest, StockAdd, StockCreate,
|
||||||
from models.database_models import User, MarketplaceImportJob, Shop
|
StockResponse, StockSummaryResponse,
|
||||||
from app.services.stock_service import stock_service
|
StockUpdate)
|
||||||
import logging
|
from models.database_models import MarketplaceImportJob, Shop, User
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -18,11 +21,12 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
# Stock Management Routes (Protected)
|
# Stock Management Routes (Protected)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/stock", response_model=StockResponse)
|
@router.post("/stock", response_model=StockResponse)
|
||||||
def set_stock(
|
def set_stock(
|
||||||
stock: StockCreate,
|
stock: StockCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Set exact stock quantity for a GTIN at a specific location (replaces existing quantity)"""
|
"""Set exact stock quantity for a GTIN at a specific location (replaces existing quantity)"""
|
||||||
try:
|
try:
|
||||||
@@ -37,9 +41,9 @@ def set_stock(
|
|||||||
|
|
||||||
@router.post("/stock/add", response_model=StockResponse)
|
@router.post("/stock/add", response_model=StockResponse)
|
||||||
def add_stock(
|
def add_stock(
|
||||||
stock: StockAdd,
|
stock: StockAdd,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Add quantity to existing stock for a GTIN at a specific location (adds to existing quantity)"""
|
"""Add quantity to existing stock for a GTIN at a specific location (adds to existing quantity)"""
|
||||||
try:
|
try:
|
||||||
@@ -54,9 +58,9 @@ def add_stock(
|
|||||||
|
|
||||||
@router.post("/stock/remove", response_model=StockResponse)
|
@router.post("/stock/remove", response_model=StockResponse)
|
||||||
def remove_stock(
|
def remove_stock(
|
||||||
stock: StockAdd,
|
stock: StockAdd,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Remove quantity from existing stock for a GTIN at a specific location"""
|
"""Remove quantity from existing stock for a GTIN at a specific location"""
|
||||||
try:
|
try:
|
||||||
@@ -71,9 +75,9 @@ def remove_stock(
|
|||||||
|
|
||||||
@router.get("/stock/{gtin}", response_model=StockSummaryResponse)
|
@router.get("/stock/{gtin}", response_model=StockSummaryResponse)
|
||||||
def get_stock_by_gtin(
|
def get_stock_by_gtin(
|
||||||
gtin: str,
|
gtin: str,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Get all stock locations and total quantity for a specific GTIN"""
|
"""Get all stock locations and total quantity for a specific GTIN"""
|
||||||
try:
|
try:
|
||||||
@@ -88,9 +92,9 @@ def get_stock_by_gtin(
|
|||||||
|
|
||||||
@router.get("/stock/{gtin}/total")
|
@router.get("/stock/{gtin}/total")
|
||||||
def get_total_stock(
|
def get_total_stock(
|
||||||
gtin: str,
|
gtin: str,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Get total quantity in stock for a specific GTIN"""
|
"""Get total quantity in stock for a specific GTIN"""
|
||||||
try:
|
try:
|
||||||
@@ -105,21 +109,17 @@ def get_total_stock(
|
|||||||
|
|
||||||
@router.get("/stock", response_model=List[StockResponse])
|
@router.get("/stock", response_model=List[StockResponse])
|
||||||
def get_all_stock(
|
def get_all_stock(
|
||||||
skip: int = Query(0, ge=0),
|
skip: int = Query(0, ge=0),
|
||||||
limit: int = Query(100, ge=1, le=1000),
|
limit: int = Query(100, ge=1, le=1000),
|
||||||
location: Optional[str] = Query(None, description="Filter by location"),
|
location: Optional[str] = Query(None, description="Filter by location"),
|
||||||
gtin: Optional[str] = Query(None, description="Filter by GTIN"),
|
gtin: Optional[str] = Query(None, description="Filter by GTIN"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Get all stock entries with optional filtering"""
|
"""Get all stock entries with optional filtering"""
|
||||||
try:
|
try:
|
||||||
result = stock_service.get_all_stock(
|
result = stock_service.get_all_stock(
|
||||||
db=db,
|
db=db, skip=skip, limit=limit, location=location, gtin=gtin
|
||||||
skip=skip,
|
|
||||||
limit=limit,
|
|
||||||
location=location,
|
|
||||||
gtin=gtin
|
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -129,10 +129,10 @@ def get_all_stock(
|
|||||||
|
|
||||||
@router.put("/stock/{stock_id}", response_model=StockResponse)
|
@router.put("/stock/{stock_id}", response_model=StockResponse)
|
||||||
def update_stock(
|
def update_stock(
|
||||||
stock_id: int,
|
stock_id: int,
|
||||||
stock_update: StockUpdate,
|
stock_update: StockUpdate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Update stock quantity for a specific stock entry"""
|
"""Update stock quantity for a specific stock entry"""
|
||||||
try:
|
try:
|
||||||
@@ -147,9 +147,9 @@ def update_stock(
|
|||||||
|
|
||||||
@router.delete("/stock/{stock_id}")
|
@router.delete("/stock/{stock_id}")
|
||||||
def delete_stock(
|
def delete_stock(
|
||||||
stock_id: int,
|
stock_id: int,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Delete a stock entry"""
|
"""Delete a stock entry"""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
# app/core/config.py
|
# app/core/config.py
|
||||||
from pydantic_settings import BaseSettings # This is the correct import for Pydantic v2
|
from typing import List, Optional
|
||||||
from typing import Optional, List
|
|
||||||
|
from pydantic_settings import \
|
||||||
|
BaseSettings # This is the correct import for Pydantic v2
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy.orm import declarative_base
|
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
from .config import settings
|
from .config import settings
|
||||||
|
|
||||||
engine = create_engine(settings.database_url)
|
engine = create_engine(settings.database_url)
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
|
import logging
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
from .logging import setup_logging
|
|
||||||
from .database import engine, SessionLocal
|
|
||||||
from models.database_models import Base
|
|
||||||
import logging
|
|
||||||
from middleware.auth import AuthManager
|
from middleware.auth import AuthManager
|
||||||
|
from models.database_models import Base
|
||||||
|
|
||||||
|
from .database import SessionLocal, engine
|
||||||
|
from .logging import setup_logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
auth_manager = AuthManager()
|
auth_manager = AuthManager()
|
||||||
@@ -44,15 +47,29 @@ def create_indexes():
|
|||||||
with engine.connect() as conn:
|
with engine.connect() as conn:
|
||||||
try:
|
try:
|
||||||
# User indexes
|
# User indexes
|
||||||
conn.execute(text("CREATE INDEX IF NOT EXISTS idx_user_email ON users(email)"))
|
conn.execute(
|
||||||
conn.execute(text("CREATE INDEX IF NOT EXISTS idx_user_username ON users(username)"))
|
text("CREATE INDEX IF NOT EXISTS idx_user_email ON users(email)")
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
text("CREATE INDEX IF NOT EXISTS idx_user_username ON users(username)")
|
||||||
|
)
|
||||||
|
|
||||||
# Product indexes
|
# Product indexes
|
||||||
conn.execute(text("CREATE INDEX IF NOT EXISTS idx_product_gtin ON products(gtin)"))
|
conn.execute(
|
||||||
conn.execute(text("CREATE INDEX IF NOT EXISTS idx_product_marketplace ON products(marketplace)"))
|
text("CREATE INDEX IF NOT EXISTS idx_product_gtin ON products(gtin)")
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
text(
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_product_marketplace ON products(marketplace)"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Stock indexes
|
# Stock indexes
|
||||||
conn.execute(text("CREATE INDEX IF NOT EXISTS idx_stock_gtin_location ON stock(gtin, location)"))
|
conn.execute(
|
||||||
|
text(
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_stock_gtin_location ON stock(gtin, location)"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
logger.info("Database indexes created successfully")
|
logger.info("Database indexes created successfully")
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
|
||||||
|
|
||||||
@@ -22,7 +23,7 @@ def setup_logging():
|
|||||||
|
|
||||||
# Create formatters
|
# Create formatters
|
||||||
formatter = logging.Formatter(
|
formatter = logging.Formatter(
|
||||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Console handler
|
# Console handler
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
from sqlalchemy.orm import Session
|
|
||||||
from fastapi import HTTPException
|
|
||||||
from datetime import datetime
|
|
||||||
import logging
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from models.database_models import User, MarketplaceImportJob, Shop
|
from fastapi import HTTPException
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from models.api_models import MarketplaceImportJobResponse
|
from models.api_models import MarketplaceImportJobResponse
|
||||||
|
from models.database_models import MarketplaceImportJob, Shop, User
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -17,7 +18,9 @@ class AdminService:
|
|||||||
"""Get paginated list of all users"""
|
"""Get paginated list of all users"""
|
||||||
return db.query(User).offset(skip).limit(limit).all()
|
return db.query(User).offset(skip).limit(limit).all()
|
||||||
|
|
||||||
def toggle_user_status(self, db: Session, user_id: int, current_admin_id: int) -> Tuple[User, str]:
|
def toggle_user_status(
|
||||||
|
self, db: Session, user_id: int, current_admin_id: int
|
||||||
|
) -> Tuple[User, str]:
|
||||||
"""
|
"""
|
||||||
Toggle user active status
|
Toggle user active status
|
||||||
|
|
||||||
@@ -37,7 +40,9 @@ class AdminService:
|
|||||||
raise HTTPException(status_code=404, detail="User not found")
|
raise HTTPException(status_code=404, detail="User not found")
|
||||||
|
|
||||||
if user.id == current_admin_id:
|
if user.id == current_admin_id:
|
||||||
raise HTTPException(status_code=400, detail="Cannot deactivate your own account")
|
raise HTTPException(
|
||||||
|
status_code=400, detail="Cannot deactivate your own account"
|
||||||
|
)
|
||||||
|
|
||||||
user.is_active = not user.is_active
|
user.is_active = not user.is_active
|
||||||
user.updated_at = datetime.utcnow()
|
user.updated_at = datetime.utcnow()
|
||||||
@@ -45,10 +50,14 @@ class AdminService:
|
|||||||
db.refresh(user)
|
db.refresh(user)
|
||||||
|
|
||||||
status = "activated" if user.is_active else "deactivated"
|
status = "activated" if user.is_active else "deactivated"
|
||||||
logger.info(f"User {user.username} has been {status} by admin {current_admin_id}")
|
logger.info(
|
||||||
|
f"User {user.username} has been {status} by admin {current_admin_id}"
|
||||||
|
)
|
||||||
return user, f"User {user.username} has been {status}"
|
return user, f"User {user.username} has been {status}"
|
||||||
|
|
||||||
def get_all_shops(self, db: Session, skip: int = 0, limit: int = 100) -> Tuple[List[Shop], int]:
|
def get_all_shops(
|
||||||
|
self, db: Session, skip: int = 0, limit: int = 100
|
||||||
|
) -> Tuple[List[Shop], int]:
|
||||||
"""
|
"""
|
||||||
Get paginated list of all shops with total count
|
Get paginated list of all shops with total count
|
||||||
|
|
||||||
@@ -119,13 +128,13 @@ class AdminService:
|
|||||||
return shop, f"Shop {shop.shop_code} has been {status}"
|
return shop, f"Shop {shop.shop_code} has been {status}"
|
||||||
|
|
||||||
def get_marketplace_import_jobs(
|
def get_marketplace_import_jobs(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: Session,
|
||||||
marketplace: Optional[str] = None,
|
marketplace: Optional[str] = None,
|
||||||
shop_name: Optional[str] = None,
|
shop_name: Optional[str] = None,
|
||||||
status: Optional[str] = None,
|
status: Optional[str] = None,
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100
|
limit: int = 100,
|
||||||
) -> List[MarketplaceImportJobResponse]:
|
) -> List[MarketplaceImportJobResponse]:
|
||||||
"""
|
"""
|
||||||
Get filtered and paginated marketplace import jobs
|
Get filtered and paginated marketplace import jobs
|
||||||
@@ -145,14 +154,21 @@ class AdminService:
|
|||||||
|
|
||||||
# Apply filters
|
# Apply filters
|
||||||
if marketplace:
|
if marketplace:
|
||||||
query = query.filter(MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%"))
|
query = query.filter(
|
||||||
|
MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%")
|
||||||
|
)
|
||||||
if shop_name:
|
if shop_name:
|
||||||
query = query.filter(MarketplaceImportJob.shop_name.ilike(f"%{shop_name}%"))
|
query = query.filter(MarketplaceImportJob.shop_name.ilike(f"%{shop_name}%"))
|
||||||
if status:
|
if status:
|
||||||
query = query.filter(MarketplaceImportJob.status == status)
|
query = query.filter(MarketplaceImportJob.status == status)
|
||||||
|
|
||||||
# Order by creation date and apply pagination
|
# Order by creation date and apply pagination
|
||||||
jobs = query.order_by(MarketplaceImportJob.created_at.desc()).offset(skip).limit(limit).all()
|
jobs = (
|
||||||
|
query.order_by(MarketplaceImportJob.created_at.desc())
|
||||||
|
.offset(skip)
|
||||||
|
.limit(limit)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
MarketplaceImportJobResponse(
|
MarketplaceImportJobResponse(
|
||||||
@@ -168,8 +184,9 @@ class AdminService:
|
|||||||
error_message=job.error_message,
|
error_message=job.error_message,
|
||||||
created_at=job.created_at,
|
created_at=job.created_at,
|
||||||
started_at=job.started_at,
|
started_at=job.started_at,
|
||||||
completed_at=job.completed_at
|
completed_at=job.completed_at,
|
||||||
) for job in jobs
|
)
|
||||||
|
for job in jobs
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_user_by_id(self, db: Session, user_id: int) -> Optional[User]:
|
def get_user_by_id(self, db: Session, user_id: int) -> Optional[User]:
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
from sqlalchemy.orm import Session
|
|
||||||
from fastapi import HTTPException
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, Dict, Any
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from models.database_models import User
|
|
||||||
from models.api_models import UserRegister, UserLogin
|
|
||||||
from middleware.auth import AuthManager
|
from middleware.auth import AuthManager
|
||||||
|
from models.api_models import UserLogin, UserRegister
|
||||||
|
from models.database_models import User
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -36,7 +37,9 @@ class AuthService:
|
|||||||
raise HTTPException(status_code=400, detail="Email already registered")
|
raise HTTPException(status_code=400, detail="Email already registered")
|
||||||
|
|
||||||
# Check if username already exists
|
# Check if username already exists
|
||||||
existing_username = db.query(User).filter(User.username == user_data.username).first()
|
existing_username = (
|
||||||
|
db.query(User).filter(User.username == user_data.username).first()
|
||||||
|
)
|
||||||
if existing_username:
|
if existing_username:
|
||||||
raise HTTPException(status_code=400, detail="Username already taken")
|
raise HTTPException(status_code=400, detail="Username already taken")
|
||||||
|
|
||||||
@@ -47,7 +50,7 @@ class AuthService:
|
|||||||
username=user_data.username,
|
username=user_data.username,
|
||||||
hashed_password=hashed_password,
|
hashed_password=hashed_password,
|
||||||
role="user",
|
role="user",
|
||||||
is_active=True
|
is_active=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
db.add(new_user)
|
db.add(new_user)
|
||||||
@@ -71,19 +74,20 @@ class AuthService:
|
|||||||
Raises:
|
Raises:
|
||||||
HTTPException: If authentication fails
|
HTTPException: If authentication fails
|
||||||
"""
|
"""
|
||||||
user = self.auth_manager.authenticate_user(db, user_credentials.username, user_credentials.password)
|
user = self.auth_manager.authenticate_user(
|
||||||
|
db, user_credentials.username, user_credentials.password
|
||||||
|
)
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(status_code=401, detail="Incorrect username or password")
|
raise HTTPException(
|
||||||
|
status_code=401, detail="Incorrect username or password"
|
||||||
|
)
|
||||||
|
|
||||||
# Create access token
|
# Create access token
|
||||||
token_data = self.auth_manager.create_access_token(user)
|
token_data = self.auth_manager.create_access_token(user)
|
||||||
|
|
||||||
logger.info(f"User logged in: {user.username}")
|
logger.info(f"User logged in: {user.username}")
|
||||||
|
|
||||||
return {
|
return {"token_data": token_data, "user": user}
|
||||||
"token_data": token_data,
|
|
||||||
"user": user
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_user_by_email(self, db: Session, email: str) -> Optional[User]:
|
def get_user_by_email(self, db: Session, email: str) -> Optional[User]:
|
||||||
"""Get user by email"""
|
"""Get user by email"""
|
||||||
@@ -101,7 +105,9 @@ class AuthService:
|
|||||||
"""Check if username already exists"""
|
"""Check if username already exists"""
|
||||||
return db.query(User).filter(User.username == username).first() is not None
|
return db.query(User).filter(User.username == username).first() is not None
|
||||||
|
|
||||||
def authenticate_user(self, db: Session, username: str, password: str) -> Optional[User]:
|
def authenticate_user(
|
||||||
|
self, db: Session, username: str, password: str
|
||||||
|
) -> Optional[User]:
|
||||||
"""Authenticate user with username/password"""
|
"""Authenticate user with username/password"""
|
||||||
return self.auth_manager.authenticate_user(db, username, password)
|
return self.auth_manager.authenticate_user(db, username, password)
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from models.api_models import (MarketplaceImportJobResponse,
|
||||||
|
MarketplaceImportRequest)
|
||||||
from models.database_models import MarketplaceImportJob, Shop, User
|
from models.database_models import MarketplaceImportJob, Shop, User
|
||||||
from models.api_models import MarketplaceImportRequest, MarketplaceImportJobResponse
|
|
||||||
from typing import Optional, List
|
|
||||||
from datetime import datetime
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -17,9 +20,11 @@ class MarketplaceService:
|
|||||||
"""Validate that the shop exists and user has access to it"""
|
"""Validate that the shop exists and user has access to it"""
|
||||||
# Explicit type hint to help type checker shop: Optional[Shop]
|
# Explicit type hint to help type checker shop: Optional[Shop]
|
||||||
# Use case-insensitive query to handle both uppercase and lowercase codes
|
# Use case-insensitive query to handle both uppercase and lowercase codes
|
||||||
shop: Optional[Shop] = db.query(Shop).filter(
|
shop: Optional[Shop] = (
|
||||||
func.upper(Shop.shop_code) == shop_code.upper()
|
db.query(Shop)
|
||||||
).first()
|
.filter(func.upper(Shop.shop_code) == shop_code.upper())
|
||||||
|
.first()
|
||||||
|
)
|
||||||
if not shop:
|
if not shop:
|
||||||
raise ValueError("Shop not found")
|
raise ValueError("Shop not found")
|
||||||
|
|
||||||
@@ -30,10 +35,7 @@ class MarketplaceService:
|
|||||||
return shop
|
return shop
|
||||||
|
|
||||||
def create_import_job(
|
def create_import_job(
|
||||||
self,
|
self, db: Session, request: MarketplaceImportRequest, user: User
|
||||||
db: Session,
|
|
||||||
request: MarketplaceImportRequest,
|
|
||||||
user: User
|
|
||||||
) -> MarketplaceImportJob:
|
) -> MarketplaceImportJob:
|
||||||
"""Create a new marketplace import job"""
|
"""Create a new marketplace import job"""
|
||||||
# Validate shop access first
|
# Validate shop access first
|
||||||
@@ -47,7 +49,7 @@ class MarketplaceService:
|
|||||||
shop_id=shop.id, # Foreign key to shops table
|
shop_id=shop.id, # Foreign key to shops table
|
||||||
shop_name=shop.shop_name, # Use shop.shop_name (the display name)
|
shop_name=shop.shop_name, # Use shop.shop_name (the display name)
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
created_at=datetime.utcnow()
|
created_at=datetime.utcnow(),
|
||||||
)
|
)
|
||||||
|
|
||||||
db.add(import_job)
|
db.add(import_job)
|
||||||
@@ -55,13 +57,20 @@ class MarketplaceService:
|
|||||||
db.refresh(import_job)
|
db.refresh(import_job)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Created marketplace import job {import_job.id}: {request.marketplace} -> {shop.shop_name} (shop_code: {shop.shop_code}) by user {user.username}")
|
f"Created marketplace import job {import_job.id}: {request.marketplace} -> {shop.shop_name} (shop_code: {shop.shop_code}) by user {user.username}"
|
||||||
|
)
|
||||||
|
|
||||||
return import_job
|
return import_job
|
||||||
|
|
||||||
def get_import_job_by_id(self, db: Session, job_id: int, user: User) -> MarketplaceImportJob:
|
def get_import_job_by_id(
|
||||||
|
self, db: Session, job_id: int, user: User
|
||||||
|
) -> MarketplaceImportJob:
|
||||||
"""Get a marketplace import job by ID with access control"""
|
"""Get a marketplace import job by ID with access control"""
|
||||||
job = db.query(MarketplaceImportJob).filter(MarketplaceImportJob.id == job_id).first()
|
job = (
|
||||||
|
db.query(MarketplaceImportJob)
|
||||||
|
.filter(MarketplaceImportJob.id == job_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
if not job:
|
if not job:
|
||||||
raise ValueError("Marketplace import job not found")
|
raise ValueError("Marketplace import job not found")
|
||||||
|
|
||||||
@@ -72,13 +81,13 @@ class MarketplaceService:
|
|||||||
return job
|
return job
|
||||||
|
|
||||||
def get_import_jobs(
|
def get_import_jobs(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: Session,
|
||||||
user: User,
|
user: User,
|
||||||
marketplace: Optional[str] = None,
|
marketplace: Optional[str] = None,
|
||||||
shop_name: Optional[str] = None,
|
shop_name: Optional[str] = None,
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 50
|
limit: int = 50,
|
||||||
) -> List[MarketplaceImportJob]:
|
) -> List[MarketplaceImportJob]:
|
||||||
"""Get marketplace import jobs with filtering and access control"""
|
"""Get marketplace import jobs with filtering and access control"""
|
||||||
query = db.query(MarketplaceImportJob)
|
query = db.query(MarketplaceImportJob)
|
||||||
@@ -89,44 +98,51 @@ class MarketplaceService:
|
|||||||
|
|
||||||
# Apply filters
|
# Apply filters
|
||||||
if marketplace:
|
if marketplace:
|
||||||
query = query.filter(MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%"))
|
query = query.filter(
|
||||||
|
MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%")
|
||||||
|
)
|
||||||
if shop_name:
|
if shop_name:
|
||||||
query = query.filter(MarketplaceImportJob.shop_name.ilike(f"%{shop_name}%"))
|
query = query.filter(MarketplaceImportJob.shop_name.ilike(f"%{shop_name}%"))
|
||||||
|
|
||||||
# Order by creation date (newest first) and apply pagination
|
# Order by creation date (newest first) and apply pagination
|
||||||
jobs = query.order_by(MarketplaceImportJob.created_at.desc()).offset(skip).limit(limit).all()
|
jobs = (
|
||||||
|
query.order_by(MarketplaceImportJob.created_at.desc())
|
||||||
|
.offset(skip)
|
||||||
|
.limit(limit)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
return jobs
|
return jobs
|
||||||
|
|
||||||
def update_job_status(
|
def update_job_status(
|
||||||
self,
|
self, db: Session, job_id: int, status: str, **kwargs
|
||||||
db: Session,
|
|
||||||
job_id: int,
|
|
||||||
status: str,
|
|
||||||
**kwargs
|
|
||||||
) -> MarketplaceImportJob:
|
) -> MarketplaceImportJob:
|
||||||
"""Update marketplace import job status and other fields"""
|
"""Update marketplace import job status and other fields"""
|
||||||
job = db.query(MarketplaceImportJob).filter(MarketplaceImportJob.id == job_id).first()
|
job = (
|
||||||
|
db.query(MarketplaceImportJob)
|
||||||
|
.filter(MarketplaceImportJob.id == job_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
if not job:
|
if not job:
|
||||||
raise ValueError("Marketplace import job not found")
|
raise ValueError("Marketplace import job not found")
|
||||||
|
|
||||||
job.status = status
|
job.status = status
|
||||||
|
|
||||||
# Update optional fields if provided
|
# Update optional fields if provided
|
||||||
if 'imported_count' in kwargs:
|
if "imported_count" in kwargs:
|
||||||
job.imported_count = kwargs['imported_count']
|
job.imported_count = kwargs["imported_count"]
|
||||||
if 'updated_count' in kwargs:
|
if "updated_count" in kwargs:
|
||||||
job.updated_count = kwargs['updated_count']
|
job.updated_count = kwargs["updated_count"]
|
||||||
if 'total_processed' in kwargs:
|
if "total_processed" in kwargs:
|
||||||
job.total_processed = kwargs['total_processed']
|
job.total_processed = kwargs["total_processed"]
|
||||||
if 'error_count' in kwargs:
|
if "error_count" in kwargs:
|
||||||
job.error_count = kwargs['error_count']
|
job.error_count = kwargs["error_count"]
|
||||||
if 'error_message' in kwargs:
|
if "error_message" in kwargs:
|
||||||
job.error_message = kwargs['error_message']
|
job.error_message = kwargs["error_message"]
|
||||||
if 'started_at' in kwargs:
|
if "started_at" in kwargs:
|
||||||
job.started_at = kwargs['started_at']
|
job.started_at = kwargs["started_at"]
|
||||||
if 'completed_at' in kwargs:
|
if "completed_at" in kwargs:
|
||||||
job.completed_at = kwargs['completed_at']
|
job.completed_at = kwargs["completed_at"]
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(job)
|
db.refresh(job)
|
||||||
@@ -145,7 +161,9 @@ class MarketplaceService:
|
|||||||
total_jobs = query.count()
|
total_jobs = query.count()
|
||||||
pending_jobs = query.filter(MarketplaceImportJob.status == "pending").count()
|
pending_jobs = query.filter(MarketplaceImportJob.status == "pending").count()
|
||||||
running_jobs = query.filter(MarketplaceImportJob.status == "running").count()
|
running_jobs = query.filter(MarketplaceImportJob.status == "running").count()
|
||||||
completed_jobs = query.filter(MarketplaceImportJob.status == "completed").count()
|
completed_jobs = query.filter(
|
||||||
|
MarketplaceImportJob.status == "completed"
|
||||||
|
).count()
|
||||||
failed_jobs = query.filter(MarketplaceImportJob.status == "failed").count()
|
failed_jobs = query.filter(MarketplaceImportJob.status == "failed").count()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -153,17 +171,21 @@ class MarketplaceService:
|
|||||||
"pending_jobs": pending_jobs,
|
"pending_jobs": pending_jobs,
|
||||||
"running_jobs": running_jobs,
|
"running_jobs": running_jobs,
|
||||||
"completed_jobs": completed_jobs,
|
"completed_jobs": completed_jobs,
|
||||||
"failed_jobs": failed_jobs
|
"failed_jobs": failed_jobs,
|
||||||
}
|
}
|
||||||
|
|
||||||
def convert_to_response_model(self, job: MarketplaceImportJob) -> MarketplaceImportJobResponse:
|
def convert_to_response_model(
|
||||||
|
self, job: MarketplaceImportJob
|
||||||
|
) -> MarketplaceImportJobResponse:
|
||||||
"""Convert database model to API response model"""
|
"""Convert database model to API response model"""
|
||||||
return MarketplaceImportJobResponse(
|
return MarketplaceImportJobResponse(
|
||||||
job_id=job.id,
|
job_id=job.id,
|
||||||
status=job.status,
|
status=job.status,
|
||||||
marketplace=job.marketplace,
|
marketplace=job.marketplace,
|
||||||
shop_id=job.shop_id,
|
shop_id=job.shop_id,
|
||||||
shop_code=job.shop.shop_code if job.shop else None, # Add this optional field via relationship
|
shop_code=(
|
||||||
|
job.shop.shop_code if job.shop else None
|
||||||
|
), # Add this optional field via relationship
|
||||||
shop_name=job.shop_name,
|
shop_name=job.shop_name,
|
||||||
imported=job.imported_count or 0,
|
imported=job.imported_count or 0,
|
||||||
updated=job.updated_count or 0,
|
updated=job.updated_count or 0,
|
||||||
@@ -172,10 +194,12 @@ class MarketplaceService:
|
|||||||
error_message=job.error_message,
|
error_message=job.error_message,
|
||||||
created_at=job.created_at,
|
created_at=job.created_at,
|
||||||
started_at=job.started_at,
|
started_at=job.started_at,
|
||||||
completed_at=job.completed_at
|
completed_at=job.completed_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
def cancel_import_job(self, db: Session, job_id: int, user: User) -> MarketplaceImportJob:
|
def cancel_import_job(
|
||||||
|
self, db: Session, job_id: int, user: User
|
||||||
|
) -> MarketplaceImportJob:
|
||||||
"""Cancel a pending or running import job"""
|
"""Cancel a pending or running import job"""
|
||||||
job = self.get_import_job_by_id(db, job_id, user)
|
job = self.get_import_job_by_id(db, job_id, user)
|
||||||
|
|
||||||
@@ -197,7 +221,9 @@ class MarketplaceService:
|
|||||||
|
|
||||||
# Only allow deletion of completed, failed, or cancelled jobs
|
# Only allow deletion of completed, failed, or cancelled jobs
|
||||||
if job.status in ["pending", "running"]:
|
if job.status in ["pending", "running"]:
|
||||||
raise ValueError(f"Cannot delete job with status: {job.status}. Cancel it first.")
|
raise ValueError(
|
||||||
|
f"Cannot delete job with status: {job.status}. Cancel it first."
|
||||||
|
)
|
||||||
|
|
||||||
db.delete(job)
|
db.delete(job)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
from sqlalchemy.orm import Session
|
|
||||||
from sqlalchemy.exc import IntegrityError
|
|
||||||
from models.database_models import Product, Stock
|
|
||||||
from models.api_models import ProductCreate, ProductUpdate, StockLocationResponse, StockSummaryResponse
|
|
||||||
from utils.data_processing import GTINProcessor, PriceProcessor
|
|
||||||
from typing import Optional, List, Generator
|
|
||||||
from datetime import datetime
|
|
||||||
import logging
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Generator, List, Optional
|
||||||
|
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from models.api_models import (ProductCreate, ProductUpdate,
|
||||||
|
StockLocationResponse, StockSummaryResponse)
|
||||||
|
from models.database_models import Product, Stock
|
||||||
|
from utils.data_processing import GTINProcessor, PriceProcessor
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -27,7 +30,9 @@ class ProductService:
|
|||||||
|
|
||||||
# Process price if provided
|
# Process price if provided
|
||||||
if product_data.price:
|
if product_data.price:
|
||||||
parsed_price, currency = self.price_processor.parse_price_currency(product_data.price)
|
parsed_price, currency = self.price_processor.parse_price_currency(
|
||||||
|
product_data.price
|
||||||
|
)
|
||||||
if parsed_price:
|
if parsed_price:
|
||||||
product_data.price = parsed_price
|
product_data.price = parsed_price
|
||||||
product_data.currency = currency
|
product_data.currency = currency
|
||||||
@@ -58,16 +63,16 @@ class ProductService:
|
|||||||
return db.query(Product).filter(Product.product_id == product_id).first()
|
return db.query(Product).filter(Product.product_id == product_id).first()
|
||||||
|
|
||||||
def get_products_with_filters(
|
def get_products_with_filters(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: Session,
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
brand: Optional[str] = None,
|
brand: Optional[str] = None,
|
||||||
category: Optional[str] = None,
|
category: Optional[str] = None,
|
||||||
availability: Optional[str] = None,
|
availability: Optional[str] = None,
|
||||||
marketplace: Optional[str] = None,
|
marketplace: Optional[str] = None,
|
||||||
shop_name: Optional[str] = None,
|
shop_name: Optional[str] = None,
|
||||||
search: Optional[str] = None
|
search: Optional[str] = None,
|
||||||
) -> tuple[List[Product], int]:
|
) -> tuple[List[Product], int]:
|
||||||
"""Get products with filtering and pagination"""
|
"""Get products with filtering and pagination"""
|
||||||
query = db.query(Product)
|
query = db.query(Product)
|
||||||
@@ -87,10 +92,10 @@ class ProductService:
|
|||||||
# Search in title, description, marketplace, and shop_name
|
# Search in title, description, marketplace, and shop_name
|
||||||
search_term = f"%{search}%"
|
search_term = f"%{search}%"
|
||||||
query = query.filter(
|
query = query.filter(
|
||||||
(Product.title.ilike(search_term)) |
|
(Product.title.ilike(search_term))
|
||||||
(Product.description.ilike(search_term)) |
|
| (Product.description.ilike(search_term))
|
||||||
(Product.marketplace.ilike(search_term)) |
|
| (Product.marketplace.ilike(search_term))
|
||||||
(Product.shop_name.ilike(search_term))
|
| (Product.shop_name.ilike(search_term))
|
||||||
)
|
)
|
||||||
|
|
||||||
total = query.count()
|
total = query.count()
|
||||||
@@ -98,7 +103,9 @@ class ProductService:
|
|||||||
|
|
||||||
return products, total
|
return products, total
|
||||||
|
|
||||||
def update_product(self, db: Session, product_id: str, product_update: ProductUpdate) -> Product:
|
def update_product(
|
||||||
|
self, db: Session, product_id: str, product_update: ProductUpdate
|
||||||
|
) -> Product:
|
||||||
"""Update product with validation"""
|
"""Update product with validation"""
|
||||||
product = db.query(Product).filter(Product.product_id == product_id).first()
|
product = db.query(Product).filter(Product.product_id == product_id).first()
|
||||||
if not product:
|
if not product:
|
||||||
@@ -116,7 +123,9 @@ class ProductService:
|
|||||||
|
|
||||||
# Process price if being updated
|
# Process price if being updated
|
||||||
if "price" in update_data and update_data["price"]:
|
if "price" in update_data and update_data["price"]:
|
||||||
parsed_price, currency = self.price_processor.parse_price_currency(update_data["price"])
|
parsed_price, currency = self.price_processor.parse_price_currency(
|
||||||
|
update_data["price"]
|
||||||
|
)
|
||||||
if parsed_price:
|
if parsed_price:
|
||||||
update_data["price"] = parsed_price
|
update_data["price"] = parsed_price
|
||||||
update_data["currency"] = currency
|
update_data["currency"] = currency
|
||||||
@@ -160,21 +169,21 @@ class ProductService:
|
|||||||
]
|
]
|
||||||
|
|
||||||
return StockSummaryResponse(
|
return StockSummaryResponse(
|
||||||
gtin=gtin,
|
gtin=gtin, total_quantity=total_quantity, locations=locations
|
||||||
total_quantity=total_quantity,
|
|
||||||
locations=locations
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_csv_export(
|
def generate_csv_export(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: Session,
|
||||||
marketplace: Optional[str] = None,
|
marketplace: Optional[str] = None,
|
||||||
shop_name: Optional[str] = None
|
shop_name: Optional[str] = None,
|
||||||
) -> Generator[str, None, None]:
|
) -> Generator[str, None, None]:
|
||||||
"""Generate CSV export with streaming for memory efficiency"""
|
"""Generate CSV export with streaming for memory efficiency"""
|
||||||
# CSV header
|
# CSV header
|
||||||
yield ("product_id,title,description,link,image_link,availability,price,currency,brand,"
|
yield (
|
||||||
"gtin,marketplace,shop_name\n")
|
"product_id,title,description,link,image_link,availability,price,currency,brand,"
|
||||||
|
"gtin,marketplace,shop_name\n"
|
||||||
|
)
|
||||||
|
|
||||||
batch_size = 1000
|
batch_size = 1000
|
||||||
offset = 0
|
offset = 0
|
||||||
@@ -194,17 +203,22 @@ class ProductService:
|
|||||||
|
|
||||||
for product in products:
|
for product in products:
|
||||||
# Create CSV row with marketplace fields
|
# Create CSV row with marketplace fields
|
||||||
row = (f'"{product.product_id}","{product.title or ""}","{product.description or ""}",'
|
row = (
|
||||||
f'"{product.link or ""}","{product.image_link or ""}","{product.availability or ""}",'
|
f'"{product.product_id}","{product.title or ""}","{product.description or ""}",'
|
||||||
f'"{product.price or ""}","{product.currency or ""}","{product.brand or ""}",'
|
f'"{product.link or ""}","{product.image_link or ""}","{product.availability or ""}",'
|
||||||
f'"{product.gtin or ""}","{product.marketplace or ""}","{product.shop_name or ""}"\n')
|
f'"{product.price or ""}","{product.currency or ""}","{product.brand or ""}",'
|
||||||
|
f'"{product.gtin or ""}","{product.marketplace or ""}","{product.shop_name or ""}"\n'
|
||||||
|
)
|
||||||
yield row
|
yield row
|
||||||
|
|
||||||
offset += batch_size
|
offset += batch_size
|
||||||
|
|
||||||
def product_exists(self, db: Session, product_id: str) -> bool:
|
def product_exists(self, db: Session, product_id: str) -> bool:
|
||||||
"""Check if product exists by ID"""
|
"""Check if product exists by ID"""
|
||||||
return db.query(Product).filter(Product.product_id == product_id).first() is not None
|
return (
|
||||||
|
db.query(Product).filter(Product.product_id == product_id).first()
|
||||||
|
is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Create service instance
|
# Create service instance
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from fastapi import HTTPException
|
|
||||||
from datetime import datetime
|
|
||||||
import logging
|
|
||||||
from typing import List, Optional, Tuple, Dict, Any
|
|
||||||
|
|
||||||
from models.database_models import User, Shop, Product, ShopProduct
|
|
||||||
from models.api_models import ShopCreate, ShopProductCreate
|
from models.api_models import ShopCreate, ShopProductCreate
|
||||||
|
from models.database_models import Product, Shop, ShopProduct, User
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -14,7 +15,9 @@ logger = logging.getLogger(__name__)
|
|||||||
class ShopService:
|
class ShopService:
|
||||||
"""Service class for shop operations following the application's service pattern"""
|
"""Service class for shop operations following the application's service pattern"""
|
||||||
|
|
||||||
def create_shop(self, db: Session, shop_data: ShopCreate, current_user: User) -> Shop:
|
def create_shop(
|
||||||
|
self, db: Session, shop_data: ShopCreate, current_user: User
|
||||||
|
) -> Shop:
|
||||||
"""
|
"""
|
||||||
Create a new shop
|
Create a new shop
|
||||||
|
|
||||||
@@ -33,39 +36,43 @@ class ShopService:
|
|||||||
normalized_shop_code = shop_data.shop_code.upper()
|
normalized_shop_code = shop_data.shop_code.upper()
|
||||||
|
|
||||||
# Check if shop code already exists (case-insensitive check against existing data)
|
# Check if shop code already exists (case-insensitive check against existing data)
|
||||||
existing_shop = db.query(Shop).filter(
|
existing_shop = (
|
||||||
func.upper(Shop.shop_code) == normalized_shop_code
|
db.query(Shop)
|
||||||
).first()
|
.filter(func.upper(Shop.shop_code) == normalized_shop_code)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
if existing_shop:
|
if existing_shop:
|
||||||
raise HTTPException(status_code=400, detail="Shop code already exists")
|
raise HTTPException(status_code=400, detail="Shop code already exists")
|
||||||
|
|
||||||
# Create shop with uppercase code
|
# Create shop with uppercase code
|
||||||
shop_dict = shop_data.model_dump() # Fixed deprecated .dict() method
|
shop_dict = shop_data.model_dump() # Fixed deprecated .dict() method
|
||||||
shop_dict['shop_code'] = normalized_shop_code # Store as uppercase
|
shop_dict["shop_code"] = normalized_shop_code # Store as uppercase
|
||||||
|
|
||||||
new_shop = Shop(
|
new_shop = Shop(
|
||||||
**shop_dict,
|
**shop_dict,
|
||||||
owner_id=current_user.id,
|
owner_id=current_user.id,
|
||||||
is_active=True,
|
is_active=True,
|
||||||
is_verified=(current_user.role == "admin")
|
is_verified=(current_user.role == "admin"),
|
||||||
)
|
)
|
||||||
|
|
||||||
db.add(new_shop)
|
db.add(new_shop)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(new_shop)
|
db.refresh(new_shop)
|
||||||
|
|
||||||
logger.info(f"New shop created: {new_shop.shop_code} by {current_user.username}")
|
logger.info(
|
||||||
|
f"New shop created: {new_shop.shop_code} by {current_user.username}"
|
||||||
|
)
|
||||||
return new_shop
|
return new_shop
|
||||||
|
|
||||||
def get_shops(
|
def get_shops(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: Session,
|
||||||
current_user: User,
|
current_user: User,
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
active_only: bool = True,
|
active_only: bool = True,
|
||||||
verified_only: bool = False
|
verified_only: bool = False,
|
||||||
) -> Tuple[List[Shop], int]:
|
) -> Tuple[List[Shop], int]:
|
||||||
"""
|
"""
|
||||||
Get shops with filtering
|
Get shops with filtering
|
||||||
@@ -86,8 +93,8 @@ class ShopService:
|
|||||||
# Non-admin users can only see active and verified shops, plus their own
|
# Non-admin users can only see active and verified shops, plus their own
|
||||||
if current_user.role != "admin":
|
if current_user.role != "admin":
|
||||||
query = query.filter(
|
query = query.filter(
|
||||||
(Shop.is_active == True) &
|
(Shop.is_active == True)
|
||||||
((Shop.is_verified == True) | (Shop.owner_id == current_user.id))
|
& ((Shop.is_verified == True) | (Shop.owner_id == current_user.id))
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Admin can apply filters
|
# Admin can apply filters
|
||||||
@@ -117,22 +124,25 @@ class ShopService:
|
|||||||
HTTPException: If shop not found or access denied
|
HTTPException: If shop not found or access denied
|
||||||
"""
|
"""
|
||||||
# Explicit type hint to help type checker shop: Optional[Shop]
|
# Explicit type hint to help type checker shop: Optional[Shop]
|
||||||
shop: Optional[Shop] = db.query(Shop).filter(func.upper(Shop.shop_code) == shop_code.upper()).first()
|
shop: Optional[Shop] = (
|
||||||
|
db.query(Shop)
|
||||||
|
.filter(func.upper(Shop.shop_code) == shop_code.upper())
|
||||||
|
.first()
|
||||||
|
)
|
||||||
if not shop:
|
if not shop:
|
||||||
raise HTTPException(status_code=404, detail="Shop not found")
|
raise HTTPException(status_code=404, detail="Shop not found")
|
||||||
|
|
||||||
# Non-admin users can only see active verified shops or their own shops
|
# Non-admin users can only see active verified shops or their own shops
|
||||||
if current_user.role != "admin":
|
if current_user.role != "admin":
|
||||||
if not shop.is_active or (not shop.is_verified and shop.owner_id != current_user.id):
|
if not shop.is_active or (
|
||||||
|
not shop.is_verified and shop.owner_id != current_user.id
|
||||||
|
):
|
||||||
raise HTTPException(status_code=404, detail="Shop not found")
|
raise HTTPException(status_code=404, detail="Shop not found")
|
||||||
|
|
||||||
return shop
|
return shop
|
||||||
|
|
||||||
def add_product_to_shop(
|
def add_product_to_shop(
|
||||||
self,
|
self, db: Session, shop: Shop, shop_product: ShopProductCreate
|
||||||
db: Session,
|
|
||||||
shop: Shop,
|
|
||||||
shop_product: ShopProductCreate
|
|
||||||
) -> ShopProduct:
|
) -> ShopProduct:
|
||||||
"""
|
"""
|
||||||
Add existing product to shop catalog with shop-specific settings
|
Add existing product to shop catalog with shop-specific settings
|
||||||
@@ -149,24 +159,35 @@ class ShopService:
|
|||||||
HTTPException: If product not found or already in shop
|
HTTPException: If product not found or already in shop
|
||||||
"""
|
"""
|
||||||
# Check if product exists
|
# Check if product exists
|
||||||
product = db.query(Product).filter(Product.product_id == shop_product.product_id).first()
|
product = (
|
||||||
|
db.query(Product)
|
||||||
|
.filter(Product.product_id == shop_product.product_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
if not product:
|
if not product:
|
||||||
raise HTTPException(status_code=404, detail="Product not found in marketplace catalog")
|
raise HTTPException(
|
||||||
|
status_code=404, detail="Product not found in marketplace catalog"
|
||||||
|
)
|
||||||
|
|
||||||
# Check if product already in shop
|
# Check if product already in shop
|
||||||
existing_shop_product = db.query(ShopProduct).filter(
|
existing_shop_product = (
|
||||||
ShopProduct.shop_id == shop.id,
|
db.query(ShopProduct)
|
||||||
ShopProduct.product_id == product.id
|
.filter(
|
||||||
).first()
|
ShopProduct.shop_id == shop.id, ShopProduct.product_id == product.id
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
if existing_shop_product:
|
if existing_shop_product:
|
||||||
raise HTTPException(status_code=400, detail="Product already in shop catalog")
|
raise HTTPException(
|
||||||
|
status_code=400, detail="Product already in shop catalog"
|
||||||
|
)
|
||||||
|
|
||||||
# Create shop-product association
|
# Create shop-product association
|
||||||
new_shop_product = ShopProduct(
|
new_shop_product = ShopProduct(
|
||||||
shop_id=shop.id,
|
shop_id=shop.id,
|
||||||
product_id=product.id,
|
product_id=product.id,
|
||||||
**shop_product.model_dump(exclude={'product_id'})
|
**shop_product.model_dump(exclude={"product_id"}),
|
||||||
)
|
)
|
||||||
|
|
||||||
db.add(new_shop_product)
|
db.add(new_shop_product)
|
||||||
@@ -180,14 +201,14 @@ class ShopService:
|
|||||||
return new_shop_product
|
return new_shop_product
|
||||||
|
|
||||||
def get_shop_products(
|
def get_shop_products(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: Session,
|
||||||
shop: Shop,
|
shop: Shop,
|
||||||
current_user: User,
|
current_user: User,
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
active_only: bool = True,
|
active_only: bool = True,
|
||||||
featured_only: bool = False
|
featured_only: bool = False,
|
||||||
) -> Tuple[List[ShopProduct], int]:
|
) -> Tuple[List[ShopProduct], int]:
|
||||||
"""
|
"""
|
||||||
Get products in shop catalog with filtering
|
Get products in shop catalog with filtering
|
||||||
@@ -239,10 +260,14 @@ class ShopService:
|
|||||||
|
|
||||||
def product_in_shop(self, db: Session, shop_id: int, product_id: int) -> bool:
|
def product_in_shop(self, db: Session, shop_id: int, product_id: int) -> bool:
|
||||||
"""Check if product is already in shop"""
|
"""Check if product is already in shop"""
|
||||||
return db.query(ShopProduct).filter(
|
return (
|
||||||
ShopProduct.shop_id == shop_id,
|
db.query(ShopProduct)
|
||||||
ShopProduct.product_id == product_id
|
.filter(
|
||||||
).first() is not None
|
ShopProduct.shop_id == shop_id, ShopProduct.product_id == product_id
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
is not None
|
||||||
|
)
|
||||||
|
|
||||||
def is_shop_owner(self, shop: Shop, user: User) -> bool:
|
def is_shop_owner(self, shop: Shop, user: User) -> bool:
|
||||||
"""Check if user is shop owner"""
|
"""Check if user is shop owner"""
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
import logging
|
|
||||||
from typing import List, Dict, Any
|
|
||||||
|
|
||||||
from models.database_models import User, Product, Stock
|
from models.api_models import MarketplaceStatsResponse, StatsResponse
|
||||||
from models.api_models import StatsResponse, MarketplaceStatsResponse
|
from models.database_models import Product, Stock, User
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -25,26 +26,37 @@ class StatsService:
|
|||||||
# Use more efficient queries with proper indexes
|
# Use more efficient queries with proper indexes
|
||||||
total_products = db.query(Product).count()
|
total_products = db.query(Product).count()
|
||||||
|
|
||||||
unique_brands = db.query(Product.brand).filter(
|
unique_brands = (
|
||||||
Product.brand.isnot(None),
|
db.query(Product.brand)
|
||||||
Product.brand != ""
|
.filter(Product.brand.isnot(None), Product.brand != "")
|
||||||
).distinct().count()
|
.distinct()
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
|
||||||
unique_categories = db.query(Product.google_product_category).filter(
|
unique_categories = (
|
||||||
Product.google_product_category.isnot(None),
|
db.query(Product.google_product_category)
|
||||||
Product.google_product_category != ""
|
.filter(
|
||||||
).distinct().count()
|
Product.google_product_category.isnot(None),
|
||||||
|
Product.google_product_category != "",
|
||||||
|
)
|
||||||
|
.distinct()
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
|
||||||
# New marketplace statistics
|
# New marketplace statistics
|
||||||
unique_marketplaces = db.query(Product.marketplace).filter(
|
unique_marketplaces = (
|
||||||
Product.marketplace.isnot(None),
|
db.query(Product.marketplace)
|
||||||
Product.marketplace != ""
|
.filter(Product.marketplace.isnot(None), Product.marketplace != "")
|
||||||
).distinct().count()
|
.distinct()
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
|
||||||
unique_shops = db.query(Product.shop_name).filter(
|
unique_shops = (
|
||||||
Product.shop_name.isnot(None),
|
db.query(Product.shop_name)
|
||||||
Product.shop_name != ""
|
.filter(Product.shop_name.isnot(None), Product.shop_name != "")
|
||||||
).distinct().count()
|
.distinct()
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
|
||||||
# Stock statistics
|
# Stock statistics
|
||||||
total_stock_entries = db.query(Stock).count()
|
total_stock_entries = db.query(Stock).count()
|
||||||
@@ -57,10 +69,12 @@ class StatsService:
|
|||||||
"unique_marketplaces": unique_marketplaces,
|
"unique_marketplaces": unique_marketplaces,
|
||||||
"unique_shops": unique_shops,
|
"unique_shops": unique_shops,
|
||||||
"total_stock_entries": total_stock_entries,
|
"total_stock_entries": total_stock_entries,
|
||||||
"total_inventory_quantity": total_inventory
|
"total_inventory_quantity": total_inventory,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f"Generated comprehensive stats: {total_products} products, {unique_marketplaces} marketplaces")
|
logger.info(
|
||||||
|
f"Generated comprehensive stats: {total_products} products, {unique_marketplaces} marketplaces"
|
||||||
|
)
|
||||||
return stats_data
|
return stats_data
|
||||||
|
|
||||||
def get_marketplace_breakdown_stats(self, db: Session) -> List[Dict[str, Any]]:
|
def get_marketplace_breakdown_stats(self, db: Session) -> List[Dict[str, Any]]:
|
||||||
@@ -74,25 +88,31 @@ class StatsService:
|
|||||||
List of dictionaries containing marketplace statistics
|
List of dictionaries containing marketplace statistics
|
||||||
"""
|
"""
|
||||||
# Query to get stats per marketplace
|
# Query to get stats per marketplace
|
||||||
marketplace_stats = db.query(
|
marketplace_stats = (
|
||||||
Product.marketplace,
|
db.query(
|
||||||
func.count(Product.id).label('total_products'),
|
Product.marketplace,
|
||||||
func.count(func.distinct(Product.shop_name)).label('unique_shops'),
|
func.count(Product.id).label("total_products"),
|
||||||
func.count(func.distinct(Product.brand)).label('unique_brands')
|
func.count(func.distinct(Product.shop_name)).label("unique_shops"),
|
||||||
).filter(
|
func.count(func.distinct(Product.brand)).label("unique_brands"),
|
||||||
Product.marketplace.isnot(None)
|
)
|
||||||
).group_by(Product.marketplace).all()
|
.filter(Product.marketplace.isnot(None))
|
||||||
|
.group_by(Product.marketplace)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
stats_list = [
|
stats_list = [
|
||||||
{
|
{
|
||||||
"marketplace": stat.marketplace,
|
"marketplace": stat.marketplace,
|
||||||
"total_products": stat.total_products,
|
"total_products": stat.total_products,
|
||||||
"unique_shops": stat.unique_shops,
|
"unique_shops": stat.unique_shops,
|
||||||
"unique_brands": stat.unique_brands
|
"unique_brands": stat.unique_brands,
|
||||||
} for stat in marketplace_stats
|
}
|
||||||
|
for stat in marketplace_stats
|
||||||
]
|
]
|
||||||
|
|
||||||
logger.info(f"Generated marketplace breakdown stats for {len(stats_list)} marketplaces")
|
logger.info(
|
||||||
|
f"Generated marketplace breakdown stats for {len(stats_list)} marketplaces"
|
||||||
|
)
|
||||||
return stats_list
|
return stats_list
|
||||||
|
|
||||||
def get_product_count(self, db: Session) -> int:
|
def get_product_count(self, db: Session) -> int:
|
||||||
@@ -101,31 +121,42 @@ class StatsService:
|
|||||||
|
|
||||||
def get_unique_brands_count(self, db: Session) -> int:
|
def get_unique_brands_count(self, db: Session) -> int:
|
||||||
"""Get count of unique brands"""
|
"""Get count of unique brands"""
|
||||||
return db.query(Product.brand).filter(
|
return (
|
||||||
Product.brand.isnot(None),
|
db.query(Product.brand)
|
||||||
Product.brand != ""
|
.filter(Product.brand.isnot(None), Product.brand != "")
|
||||||
).distinct().count()
|
.distinct()
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
|
||||||
def get_unique_categories_count(self, db: Session) -> int:
|
def get_unique_categories_count(self, db: Session) -> int:
|
||||||
"""Get count of unique categories"""
|
"""Get count of unique categories"""
|
||||||
return db.query(Product.google_product_category).filter(
|
return (
|
||||||
Product.google_product_category.isnot(None),
|
db.query(Product.google_product_category)
|
||||||
Product.google_product_category != ""
|
.filter(
|
||||||
).distinct().count()
|
Product.google_product_category.isnot(None),
|
||||||
|
Product.google_product_category != "",
|
||||||
|
)
|
||||||
|
.distinct()
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
|
||||||
def get_unique_marketplaces_count(self, db: Session) -> int:
|
def get_unique_marketplaces_count(self, db: Session) -> int:
|
||||||
"""Get count of unique marketplaces"""
|
"""Get count of unique marketplaces"""
|
||||||
return db.query(Product.marketplace).filter(
|
return (
|
||||||
Product.marketplace.isnot(None),
|
db.query(Product.marketplace)
|
||||||
Product.marketplace != ""
|
.filter(Product.marketplace.isnot(None), Product.marketplace != "")
|
||||||
).distinct().count()
|
.distinct()
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
|
||||||
def get_unique_shops_count(self, db: Session) -> int:
|
def get_unique_shops_count(self, db: Session) -> int:
|
||||||
"""Get count of unique shops"""
|
"""Get count of unique shops"""
|
||||||
return db.query(Product.shop_name).filter(
|
return (
|
||||||
Product.shop_name.isnot(None),
|
db.query(Product.shop_name)
|
||||||
Product.shop_name != ""
|
.filter(Product.shop_name.isnot(None), Product.shop_name != "")
|
||||||
).distinct().count()
|
.distinct()
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
|
||||||
def get_stock_statistics(self, db: Session) -> Dict[str, int]:
|
def get_stock_statistics(self, db: Session) -> Dict[str, int]:
|
||||||
"""
|
"""
|
||||||
@@ -142,25 +173,35 @@ class StatsService:
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
"total_stock_entries": total_stock_entries,
|
"total_stock_entries": total_stock_entries,
|
||||||
"total_inventory_quantity": total_inventory
|
"total_inventory_quantity": total_inventory,
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_brands_by_marketplace(self, db: Session, marketplace: str) -> List[str]:
|
def get_brands_by_marketplace(self, db: Session, marketplace: str) -> List[str]:
|
||||||
"""Get unique brands for a specific marketplace"""
|
"""Get unique brands for a specific marketplace"""
|
||||||
brands = db.query(Product.brand).filter(
|
brands = (
|
||||||
Product.marketplace == marketplace,
|
db.query(Product.brand)
|
||||||
Product.brand.isnot(None),
|
.filter(
|
||||||
Product.brand != ""
|
Product.marketplace == marketplace,
|
||||||
).distinct().all()
|
Product.brand.isnot(None),
|
||||||
|
Product.brand != "",
|
||||||
|
)
|
||||||
|
.distinct()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
return [brand[0] for brand in brands]
|
return [brand[0] for brand in brands]
|
||||||
|
|
||||||
def get_shops_by_marketplace(self, db: Session, marketplace: str) -> List[str]:
|
def get_shops_by_marketplace(self, db: Session, marketplace: str) -> List[str]:
|
||||||
"""Get unique shops for a specific marketplace"""
|
"""Get unique shops for a specific marketplace"""
|
||||||
shops = db.query(Product.shop_name).filter(
|
shops = (
|
||||||
Product.marketplace == marketplace,
|
db.query(Product.shop_name)
|
||||||
Product.shop_name.isnot(None),
|
.filter(
|
||||||
Product.shop_name != ""
|
Product.marketplace == marketplace,
|
||||||
).distinct().all()
|
Product.shop_name.isnot(None),
|
||||||
|
Product.shop_name != "",
|
||||||
|
)
|
||||||
|
.distinct()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
return [shop[0] for shop in shops]
|
return [shop[0] for shop in shops]
|
||||||
|
|
||||||
def get_products_by_marketplace(self, db: Session, marketplace: str) -> int:
|
def get_products_by_marketplace(self, db: Session, marketplace: str) -> int:
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
from sqlalchemy.orm import Session
|
|
||||||
from models.database_models import Stock, Product
|
|
||||||
from models.api_models import StockCreate, StockAdd, StockUpdate, StockLocationResponse, StockSummaryResponse
|
|
||||||
from utils.data_processing import GTINProcessor
|
|
||||||
from typing import Optional, List, Tuple
|
|
||||||
from datetime import datetime
|
|
||||||
import logging
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from models.api_models import (StockAdd, StockCreate, StockLocationResponse,
|
||||||
|
StockSummaryResponse, StockUpdate)
|
||||||
|
from models.database_models import Product, Stock
|
||||||
|
from utils.data_processing import GTINProcessor
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -26,10 +29,11 @@ class StockService:
|
|||||||
location = stock_data.location.strip().upper()
|
location = stock_data.location.strip().upper()
|
||||||
|
|
||||||
# Check if stock entry already exists for this GTIN and location
|
# Check if stock entry already exists for this GTIN and location
|
||||||
existing_stock = db.query(Stock).filter(
|
existing_stock = (
|
||||||
Stock.gtin == normalized_gtin,
|
db.query(Stock)
|
||||||
Stock.location == location
|
.filter(Stock.gtin == normalized_gtin, Stock.location == location)
|
||||||
).first()
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
if existing_stock:
|
if existing_stock:
|
||||||
# Update existing stock (SET to exact quantity)
|
# Update existing stock (SET to exact quantity)
|
||||||
@@ -39,19 +43,20 @@ class StockService:
|
|||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(existing_stock)
|
db.refresh(existing_stock)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Updated stock for GTIN {normalized_gtin} at {location}: {old_quantity} → {stock_data.quantity}")
|
f"Updated stock for GTIN {normalized_gtin} at {location}: {old_quantity} → {stock_data.quantity}"
|
||||||
|
)
|
||||||
return existing_stock
|
return existing_stock
|
||||||
else:
|
else:
|
||||||
# Create new stock entry
|
# Create new stock entry
|
||||||
new_stock = Stock(
|
new_stock = Stock(
|
||||||
gtin=normalized_gtin,
|
gtin=normalized_gtin, location=location, quantity=stock_data.quantity
|
||||||
location=location,
|
|
||||||
quantity=stock_data.quantity
|
|
||||||
)
|
)
|
||||||
db.add(new_stock)
|
db.add(new_stock)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(new_stock)
|
db.refresh(new_stock)
|
||||||
logger.info(f"Created new stock for GTIN {normalized_gtin} at {location}: {stock_data.quantity}")
|
logger.info(
|
||||||
|
f"Created new stock for GTIN {normalized_gtin} at {location}: {stock_data.quantity}"
|
||||||
|
)
|
||||||
return new_stock
|
return new_stock
|
||||||
|
|
||||||
def add_stock(self, db: Session, stock_data: StockAdd) -> Stock:
|
def add_stock(self, db: Session, stock_data: StockAdd) -> Stock:
|
||||||
@@ -63,10 +68,11 @@ class StockService:
|
|||||||
location = stock_data.location.strip().upper()
|
location = stock_data.location.strip().upper()
|
||||||
|
|
||||||
# Check if stock entry already exists for this GTIN and location
|
# Check if stock entry already exists for this GTIN and location
|
||||||
existing_stock = db.query(Stock).filter(
|
existing_stock = (
|
||||||
Stock.gtin == normalized_gtin,
|
db.query(Stock)
|
||||||
Stock.location == location
|
.filter(Stock.gtin == normalized_gtin, Stock.location == location)
|
||||||
).first()
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
if existing_stock:
|
if existing_stock:
|
||||||
# Add to existing stock
|
# Add to existing stock
|
||||||
@@ -76,19 +82,20 @@ class StockService:
|
|||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(existing_stock)
|
db.refresh(existing_stock)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Added stock for GTIN {normalized_gtin} at {location}: {old_quantity} + {stock_data.quantity} = {existing_stock.quantity}")
|
f"Added stock for GTIN {normalized_gtin} at {location}: {old_quantity} + {stock_data.quantity} = {existing_stock.quantity}"
|
||||||
|
)
|
||||||
return existing_stock
|
return existing_stock
|
||||||
else:
|
else:
|
||||||
# Create new stock entry with the quantity
|
# Create new stock entry with the quantity
|
||||||
new_stock = Stock(
|
new_stock = Stock(
|
||||||
gtin=normalized_gtin,
|
gtin=normalized_gtin, location=location, quantity=stock_data.quantity
|
||||||
location=location,
|
|
||||||
quantity=stock_data.quantity
|
|
||||||
)
|
)
|
||||||
db.add(new_stock)
|
db.add(new_stock)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(new_stock)
|
db.refresh(new_stock)
|
||||||
logger.info(f"Created new stock for GTIN {normalized_gtin} at {location}: {stock_data.quantity}")
|
logger.info(
|
||||||
|
f"Created new stock for GTIN {normalized_gtin} at {location}: {stock_data.quantity}"
|
||||||
|
)
|
||||||
return new_stock
|
return new_stock
|
||||||
|
|
||||||
def remove_stock(self, db: Session, stock_data: StockAdd) -> Stock:
|
def remove_stock(self, db: Session, stock_data: StockAdd) -> Stock:
|
||||||
@@ -100,18 +107,22 @@ class StockService:
|
|||||||
location = stock_data.location.strip().upper()
|
location = stock_data.location.strip().upper()
|
||||||
|
|
||||||
# Find existing stock entry
|
# Find existing stock entry
|
||||||
existing_stock = db.query(Stock).filter(
|
existing_stock = (
|
||||||
Stock.gtin == normalized_gtin,
|
db.query(Stock)
|
||||||
Stock.location == location
|
.filter(Stock.gtin == normalized_gtin, Stock.location == location)
|
||||||
).first()
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
if not existing_stock:
|
if not existing_stock:
|
||||||
raise ValueError(f"No stock found for GTIN {normalized_gtin} at location {location}")
|
raise ValueError(
|
||||||
|
f"No stock found for GTIN {normalized_gtin} at location {location}"
|
||||||
|
)
|
||||||
|
|
||||||
# Check if we have enough stock to remove
|
# Check if we have enough stock to remove
|
||||||
if existing_stock.quantity < stock_data.quantity:
|
if existing_stock.quantity < stock_data.quantity:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Insufficient stock. Available: {existing_stock.quantity}, Requested to remove: {stock_data.quantity}")
|
f"Insufficient stock. Available: {existing_stock.quantity}, Requested to remove: {stock_data.quantity}"
|
||||||
|
)
|
||||||
|
|
||||||
# Remove from existing stock
|
# Remove from existing stock
|
||||||
old_quantity = existing_stock.quantity
|
old_quantity = existing_stock.quantity
|
||||||
@@ -120,7 +131,8 @@ class StockService:
|
|||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(existing_stock)
|
db.refresh(existing_stock)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Removed stock for GTIN {normalized_gtin} at {location}: {old_quantity} - {stock_data.quantity} = {existing_stock.quantity}")
|
f"Removed stock for GTIN {normalized_gtin} at {location}: {old_quantity} - {stock_data.quantity} = {existing_stock.quantity}"
|
||||||
|
)
|
||||||
return existing_stock
|
return existing_stock
|
||||||
|
|
||||||
def get_stock_by_gtin(self, db: Session, gtin: str) -> StockSummaryResponse:
|
def get_stock_by_gtin(self, db: Session, gtin: str) -> StockSummaryResponse:
|
||||||
@@ -141,10 +153,9 @@ class StockService:
|
|||||||
|
|
||||||
for entry in stock_entries:
|
for entry in stock_entries:
|
||||||
total_quantity += entry.quantity
|
total_quantity += entry.quantity
|
||||||
locations.append(StockLocationResponse(
|
locations.append(
|
||||||
location=entry.location,
|
StockLocationResponse(location=entry.location, quantity=entry.quantity)
|
||||||
quantity=entry.quantity
|
)
|
||||||
))
|
|
||||||
|
|
||||||
# Try to get product title for reference
|
# Try to get product title for reference
|
||||||
product = db.query(Product).filter(Product.gtin == normalized_gtin).first()
|
product = db.query(Product).filter(Product.gtin == normalized_gtin).first()
|
||||||
@@ -154,7 +165,7 @@ class StockService:
|
|||||||
gtin=normalized_gtin,
|
gtin=normalized_gtin,
|
||||||
total_quantity=total_quantity,
|
total_quantity=total_quantity,
|
||||||
locations=locations,
|
locations=locations,
|
||||||
product_title=product_title
|
product_title=product_title,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_total_stock(self, db: Session, gtin: str) -> dict:
|
def get_total_stock(self, db: Session, gtin: str) -> dict:
|
||||||
@@ -174,16 +185,16 @@ class StockService:
|
|||||||
"gtin": normalized_gtin,
|
"gtin": normalized_gtin,
|
||||||
"total_quantity": total_quantity,
|
"total_quantity": total_quantity,
|
||||||
"product_title": product.title if product else None,
|
"product_title": product.title if product else None,
|
||||||
"locations_count": len(total_stock)
|
"locations_count": len(total_stock),
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_all_stock(
|
def get_all_stock(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: Session,
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
location: Optional[str] = None,
|
location: Optional[str] = None,
|
||||||
gtin: Optional[str] = None
|
gtin: Optional[str] = None,
|
||||||
) -> List[Stock]:
|
) -> List[Stock]:
|
||||||
"""Get all stock entries with optional filtering"""
|
"""Get all stock entries with optional filtering"""
|
||||||
query = db.query(Stock)
|
query = db.query(Stock)
|
||||||
@@ -198,7 +209,9 @@ class StockService:
|
|||||||
|
|
||||||
return query.offset(skip).limit(limit).all()
|
return query.offset(skip).limit(limit).all()
|
||||||
|
|
||||||
def update_stock(self, db: Session, stock_id: int, stock_update: StockUpdate) -> Stock:
|
def update_stock(
|
||||||
|
self, db: Session, stock_id: int, stock_update: StockUpdate
|
||||||
|
) -> Stock:
|
||||||
"""Update stock quantity for a specific stock entry"""
|
"""Update stock quantity for a specific stock entry"""
|
||||||
stock_entry = db.query(Stock).filter(Stock.id == stock_id).first()
|
stock_entry = db.query(Stock).filter(Stock.id == stock_id).first()
|
||||||
if not stock_entry:
|
if not stock_entry:
|
||||||
@@ -209,7 +222,9 @@ class StockService:
|
|||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(stock_entry)
|
db.refresh(stock_entry)
|
||||||
|
|
||||||
logger.info(f"Updated stock entry {stock_id} to quantity {stock_update.quantity}")
|
logger.info(
|
||||||
|
f"Updated stock entry {stock_id} to quantity {stock_update.quantity}"
|
||||||
|
)
|
||||||
return stock_entry
|
return stock_entry
|
||||||
|
|
||||||
def delete_stock(self, db: Session, stock_id: int) -> bool:
|
def delete_stock(self, db: Session, stock_id: int) -> bool:
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# app/tasks/background_tasks.py
|
# app/tasks/background_tasks.py
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from app.core.database import SessionLocal
|
from app.core.database import SessionLocal
|
||||||
from models.database_models import MarketplaceImportJob
|
from models.database_models import MarketplaceImportJob
|
||||||
from utils.csv_processor import CSVProcessor
|
from utils.csv_processor import CSVProcessor
|
||||||
@@ -9,11 +10,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
async def process_marketplace_import(
|
async def process_marketplace_import(
|
||||||
job_id: int,
|
job_id: int, url: str, marketplace: str, shop_name: str, batch_size: int = 1000
|
||||||
url: str,
|
|
||||||
marketplace: str,
|
|
||||||
shop_name: str,
|
|
||||||
batch_size: int = 1000
|
|
||||||
):
|
):
|
||||||
"""Background task to process marketplace CSV import"""
|
"""Background task to process marketplace CSV import"""
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
@@ -22,7 +19,11 @@ async def process_marketplace_import(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Update job status
|
# Update job status
|
||||||
job = db.query(MarketplaceImportJob).filter(MarketplaceImportJob.id == job_id).first()
|
job = (
|
||||||
|
db.query(MarketplaceImportJob)
|
||||||
|
.filter(MarketplaceImportJob.id == job_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
if not job:
|
if not job:
|
||||||
logger.error(f"Import job {job_id} not found")
|
logger.error(f"Import job {job_id} not found")
|
||||||
return
|
return
|
||||||
@@ -70,7 +71,7 @@ async def process_marketplace_import(
|
|||||||
finally:
|
finally:
|
||||||
# Close the database session only if it's not a mock
|
# Close the database session only if it's not a mock
|
||||||
# In tests, we use the same session so we shouldn't close it
|
# In tests, we use the same session so we shouldn't close it
|
||||||
if hasattr(db, 'close') and callable(getattr(db, 'close')):
|
if hasattr(db, "close") and callable(getattr(db, "close")):
|
||||||
try:
|
try:
|
||||||
db.close()
|
db.close()
|
||||||
except Exception as close_error:
|
except Exception as close_error:
|
||||||
|
|||||||
@@ -6,21 +6,21 @@ import requests
|
|||||||
# API Base URL
|
# API Base URL
|
||||||
BASE_URL = "http://localhost:8000"
|
BASE_URL = "http://localhost:8000"
|
||||||
|
|
||||||
|
|
||||||
def register_user(email, username, password):
|
def register_user(email, username, password):
|
||||||
"""Register a new user"""
|
"""Register a new user"""
|
||||||
response = requests.post(f"{BASE_URL}/register", json={
|
response = requests.post(
|
||||||
"email": email,
|
f"{BASE_URL}/register",
|
||||||
"username": username,
|
json={"email": email, "username": username, "password": password},
|
||||||
"password": password
|
)
|
||||||
})
|
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
def login_user(username, password):
|
def login_user(username, password):
|
||||||
"""Login and get JWT token"""
|
"""Login and get JWT token"""
|
||||||
response = requests.post(f"{BASE_URL}/login", json={
|
response = requests.post(
|
||||||
"username": username,
|
f"{BASE_URL}/login", json={"username": username, "password": password}
|
||||||
"password": password
|
)
|
||||||
})
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
data = response.json()
|
data = response.json()
|
||||||
return data["access_token"]
|
return data["access_token"]
|
||||||
@@ -28,24 +28,30 @@ def login_user(username, password):
|
|||||||
print(f"Login failed: {response.json()}")
|
print(f"Login failed: {response.json()}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_user_info(token):
|
def get_user_info(token):
|
||||||
"""Get current user info"""
|
"""Get current user info"""
|
||||||
headers = {"Authorization": f"Bearer {token}"}
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
response = requests.get(f"{BASE_URL}/me", headers=headers)
|
response = requests.get(f"{BASE_URL}/me", headers=headers)
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
def get_products(token, skip=0, limit=10):
|
def get_products(token, skip=0, limit=10):
|
||||||
"""Get products (requires authentication)"""
|
"""Get products (requires authentication)"""
|
||||||
headers = {"Authorization": f"Bearer {token}"}
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
response = requests.get(f"{BASE_URL}/products?skip={skip}&limit={limit}", headers=headers)
|
response = requests.get(
|
||||||
|
f"{BASE_URL}/products?skip={skip}&limit={limit}", headers=headers
|
||||||
|
)
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
def create_product(token, product_data):
|
def create_product(token, product_data):
|
||||||
"""Create a new product (requires authentication)"""
|
"""Create a new product (requires authentication)"""
|
||||||
headers = {"Authorization": f"Bearer {token}"}
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
response = requests.post(f"{BASE_URL}/products", json=product_data, headers=headers)
|
response = requests.post(f"{BASE_URL}/products", json=product_data, headers=headers)
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
# Example usage
|
# Example usage
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 1. Register a new user
|
# 1. Register a new user
|
||||||
@@ -75,7 +81,7 @@ if __name__ == "__main__":
|
|||||||
"description": "A test product for demonstration",
|
"description": "A test product for demonstration",
|
||||||
"price": "19.99",
|
"price": "19.99",
|
||||||
"brand": "Test Brand",
|
"brand": "Test Brand",
|
||||||
"availability": "in stock"
|
"availability": "in stock",
|
||||||
}
|
}
|
||||||
|
|
||||||
product_result = create_product(admin_token, sample_product)
|
product_result = create_product(admin_token, sample_product)
|
||||||
@@ -97,7 +103,9 @@ if __name__ == "__main__":
|
|||||||
print(f"Regular user info: {user_info}")
|
print(f"Regular user info: {user_info}")
|
||||||
|
|
||||||
products = get_products(user_token, limit=5)
|
products = get_products(user_token, limit=5)
|
||||||
print(f"Products accessible to regular user: {len(products.get('products', []))} products")
|
print(
|
||||||
|
f"Products accessible to regular user: {len(products.get('products', []))} products"
|
||||||
|
)
|
||||||
|
|
||||||
print("\nAuthentication example completed!")
|
print("\nAuthentication example completed!")
|
||||||
|
|
||||||
|
|||||||
35
main.py
35
main.py
@@ -1,13 +1,15 @@
|
|||||||
from fastapi import FastAPI, Depends, HTTPException
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from sqlalchemy import text
|
|
||||||
from app.core.config import settings
|
|
||||||
from app.core.lifespan import lifespan
|
|
||||||
from app.core.database import get_db
|
|
||||||
from datetime import datetime
|
|
||||||
from app.api.main import api_router
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
import logging
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from fastapi import Depends, FastAPI, HTTPException
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.api.main import api_router
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.core.database import get_db
|
||||||
|
from app.core.lifespan import lifespan
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -16,7 +18,7 @@ app = FastAPI(
|
|||||||
title=settings.project_name,
|
title=settings.project_name,
|
||||||
description=settings.description,
|
description=settings.description,
|
||||||
version=settings.version,
|
version=settings.version,
|
||||||
lifespan=lifespan
|
lifespan=lifespan,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add CORS middleware
|
# Add CORS middleware
|
||||||
@@ -44,10 +46,17 @@ def root():
|
|||||||
"JWT Authentication",
|
"JWT Authentication",
|
||||||
"Marketplace-aware product import",
|
"Marketplace-aware product import",
|
||||||
"Multi-shop product management",
|
"Multi-shop product management",
|
||||||
"Stock management with location tracking"
|
"Stock management with location tracking",
|
||||||
],
|
],
|
||||||
"supported_marketplaces": ["Letzshop", "Amazon", "eBay", "Etsy", "Shopify", "Other"],
|
"supported_marketplaces": [
|
||||||
"auth_required": "Most endpoints require Bearer token authentication"
|
"Letzshop",
|
||||||
|
"Amazon",
|
||||||
|
"eBay",
|
||||||
|
"Etsy",
|
||||||
|
"Shopify",
|
||||||
|
"Other",
|
||||||
|
],
|
||||||
|
"auth_required": "Most endpoints require Bearer token authentication",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,16 @@
|
|||||||
# middleware/auth.py
|
# middleware/auth.py
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from fastapi.security import HTTPAuthorizationCredentials
|
from fastapi.security import HTTPAuthorizationCredentials
|
||||||
from passlib.context import CryptContext
|
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
from datetime import datetime, timedelta
|
from passlib.context import CryptContext
|
||||||
from typing import Dict, Any, Optional
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from models.database_models import User
|
from models.database_models import User
|
||||||
import os
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -20,7 +22,9 @@ class AuthManager:
|
|||||||
"""JWT-based authentication manager with bcrypt password hashing"""
|
"""JWT-based authentication manager with bcrypt password hashing"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.secret_key = os.getenv("JWT_SECRET_KEY", "your-secret-key-change-in-production-please")
|
self.secret_key = os.getenv(
|
||||||
|
"JWT_SECRET_KEY", "your-secret-key-change-in-production-please"
|
||||||
|
)
|
||||||
self.algorithm = "HS256"
|
self.algorithm = "HS256"
|
||||||
self.token_expire_minutes = int(os.getenv("JWT_EXPIRE_MINUTES", "30"))
|
self.token_expire_minutes = int(os.getenv("JWT_EXPIRE_MINUTES", "30"))
|
||||||
|
|
||||||
@@ -32,11 +36,15 @@ class AuthManager:
|
|||||||
"""Verify password against hash"""
|
"""Verify password against hash"""
|
||||||
return pwd_context.verify(plain_password, hashed_password)
|
return pwd_context.verify(plain_password, hashed_password)
|
||||||
|
|
||||||
def authenticate_user(self, db: Session, username: str, password: str) -> Optional[User]:
|
def authenticate_user(
|
||||||
|
self, db: Session, username: str, password: str
|
||||||
|
) -> Optional[User]:
|
||||||
"""Authenticate user and return user object if valid"""
|
"""Authenticate user and return user object if valid"""
|
||||||
user = db.query(User).filter(
|
user = (
|
||||||
(User.username == username) | (User.email == username)
|
db.query(User)
|
||||||
).first()
|
.filter((User.username == username) | (User.email == username))
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
return None
|
return None
|
||||||
@@ -65,7 +73,7 @@ class AuthManager:
|
|||||||
"email": user.email,
|
"email": user.email,
|
||||||
"role": user.role,
|
"role": user.role,
|
||||||
"exp": expire,
|
"exp": expire,
|
||||||
"iat": datetime.utcnow()
|
"iat": datetime.utcnow(),
|
||||||
}
|
}
|
||||||
|
|
||||||
token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
|
token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
|
||||||
@@ -73,7 +81,7 @@ class AuthManager:
|
|||||||
return {
|
return {
|
||||||
"access_token": token,
|
"access_token": token,
|
||||||
"token_type": "bearer",
|
"token_type": "bearer",
|
||||||
"expires_in": self.token_expire_minutes * 60 # Return in seconds
|
"expires_in": self.token_expire_minutes * 60, # Return in seconds
|
||||||
}
|
}
|
||||||
|
|
||||||
def verify_token(self, token: str) -> Dict[str, Any]:
|
def verify_token(self, token: str) -> Dict[str, Any]:
|
||||||
@@ -92,25 +100,31 @@ class AuthManager:
|
|||||||
# Extract user data
|
# Extract user data
|
||||||
user_id = payload.get("sub")
|
user_id = payload.get("sub")
|
||||||
if user_id is None:
|
if user_id is None:
|
||||||
raise HTTPException(status_code=401, detail="Token missing user identifier")
|
raise HTTPException(
|
||||||
|
status_code=401, detail="Token missing user identifier"
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"user_id": int(user_id),
|
"user_id": int(user_id),
|
||||||
"username": payload.get("username"),
|
"username": payload.get("username"),
|
||||||
"email": payload.get("email"),
|
"email": payload.get("email"),
|
||||||
"role": payload.get("role", "user")
|
"role": payload.get("role", "user"),
|
||||||
}
|
}
|
||||||
|
|
||||||
except jwt.ExpiredSignatureError:
|
except jwt.ExpiredSignatureError:
|
||||||
raise HTTPException(status_code=401, detail="Token has expired")
|
raise HTTPException(status_code=401, detail="Token has expired")
|
||||||
except jwt.JWTError as e:
|
except jwt.JWTError as e:
|
||||||
logger.error(f"JWT decode error: {e}")
|
logger.error(f"JWT decode error: {e}")
|
||||||
raise HTTPException(status_code=401, detail="Could not validate credentials")
|
raise HTTPException(
|
||||||
|
status_code=401, detail="Could not validate credentials"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Token verification error: {e}")
|
logger.error(f"Token verification error: {e}")
|
||||||
raise HTTPException(status_code=401, detail="Authentication failed")
|
raise HTTPException(status_code=401, detail="Authentication failed")
|
||||||
|
|
||||||
def get_current_user(self, db: Session, credentials: HTTPAuthorizationCredentials) -> User:
|
def get_current_user(
|
||||||
|
self, db: Session, credentials: HTTPAuthorizationCredentials
|
||||||
|
) -> User:
|
||||||
"""Get current authenticated user from database"""
|
"""Get current authenticated user from database"""
|
||||||
user_data = self.verify_token(credentials.credentials)
|
user_data = self.verify_token(credentials.credentials)
|
||||||
|
|
||||||
@@ -131,7 +145,7 @@ class AuthManager:
|
|||||||
if current_user.role != required_role:
|
if current_user.role != required_role:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=403,
|
status_code=403,
|
||||||
detail=f"Required role '{required_role}' not found. Current role: '{current_user.role}'"
|
detail=f"Required role '{required_role}' not found. Current role: '{current_user.role}'",
|
||||||
)
|
)
|
||||||
return func(current_user, *args, **kwargs)
|
return func(current_user, *args, **kwargs)
|
||||||
|
|
||||||
@@ -142,10 +156,7 @@ class AuthManager:
|
|||||||
def require_admin(self, current_user: User):
|
def require_admin(self, current_user: User):
|
||||||
"""Require admin role"""
|
"""Require admin role"""
|
||||||
if current_user.role != "admin":
|
if current_user.role != "admin":
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=403, detail="Admin privileges required")
|
||||||
status_code=403,
|
|
||||||
detail="Admin privileges required"
|
|
||||||
)
|
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
def create_default_admin_user(self, db: Session):
|
def create_default_admin_user(self, db: Session):
|
||||||
@@ -159,11 +170,13 @@ class AuthManager:
|
|||||||
username="admin",
|
username="admin",
|
||||||
hashed_password=hashed_password,
|
hashed_password=hashed_password,
|
||||||
role="admin",
|
role="admin",
|
||||||
is_active=True
|
is_active=True,
|
||||||
)
|
)
|
||||||
db.add(admin_user)
|
db.add(admin_user)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(admin_user)
|
db.refresh(admin_user)
|
||||||
logger.info("Default admin user created: username='admin', password='admin123'")
|
logger.info(
|
||||||
|
"Default admin user created: username='admin', password='admin123'"
|
||||||
|
)
|
||||||
|
|
||||||
return admin_user
|
return admin_user
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
# middleware/decorators.py
|
# middleware/decorators.py
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
|
||||||
from middleware.rate_limiter import RateLimiter
|
from middleware.rate_limiter import RateLimiter
|
||||||
|
|
||||||
# Initialize rate limiter instance
|
# Initialize rate limiter instance
|
||||||
@@ -17,10 +19,7 @@ def rate_limit(max_requests: int = 100, window_seconds: int = 3600):
|
|||||||
client_id = "anonymous" # In production, extract from request
|
client_id = "anonymous" # In production, extract from request
|
||||||
|
|
||||||
if not rate_limiter.allow_request(client_id, max_requests, window_seconds):
|
if not rate_limiter.allow_request(client_id, max_requests, window_seconds):
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=429, detail="Rate limit exceeded")
|
||||||
status_code=429,
|
|
||||||
detail="Rate limit exceeded"
|
|
||||||
)
|
|
||||||
|
|
||||||
return await func(*args, **kwargs)
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,18 @@
|
|||||||
# middleware/error_handler.py
|
# middleware/error_handler.py
|
||||||
from fastapi import Request, HTTPException
|
|
||||||
from fastapi.responses import JSONResponse
|
|
||||||
from fastapi.exceptions import RequestValidationError
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Request
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def custom_http_exception_handler(request: Request, exc: HTTPException):
|
async def custom_http_exception_handler(request: Request, exc: HTTPException):
|
||||||
"""Custom HTTP exception handler"""
|
"""Custom HTTP exception handler"""
|
||||||
logger.error(f"HTTP {exc.status_code}: {exc.detail} - {request.method} {request.url}")
|
logger.error(
|
||||||
|
f"HTTP {exc.status_code}: {exc.detail} - {request.method} {request.url}"
|
||||||
|
)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=exc.status_code,
|
status_code=exc.status_code,
|
||||||
@@ -17,11 +20,12 @@ async def custom_http_exception_handler(request: Request, exc: HTTPException):
|
|||||||
"error": {
|
"error": {
|
||||||
"code": exc.status_code,
|
"code": exc.status_code,
|
||||||
"message": exc.detail,
|
"message": exc.detail,
|
||||||
"type": "http_exception"
|
"type": "http_exception",
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||||
"""Handle Pydantic validation errors"""
|
"""Handle Pydantic validation errors"""
|
||||||
logger.error(f"Validation error: {exc.errors()} - {request.method} {request.url}")
|
logger.error(f"Validation error: {exc.errors()} - {request.method} {request.url}")
|
||||||
@@ -33,14 +37,17 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
|
|||||||
"code": 422,
|
"code": 422,
|
||||||
"message": "Validation error",
|
"message": "Validation error",
|
||||||
"type": "validation_error",
|
"type": "validation_error",
|
||||||
"details": exc.errors()
|
"details": exc.errors(),
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def general_exception_handler(request: Request, exc: Exception):
|
async def general_exception_handler(request: Request, exc: Exception):
|
||||||
"""Handle unexpected exceptions"""
|
"""Handle unexpected exceptions"""
|
||||||
logger.error(f"Unexpected error: {str(exc)} - {request.method} {request.url}", exc_info=True)
|
logger.error(
|
||||||
|
f"Unexpected error: {str(exc)} - {request.method} {request.url}", exc_info=True
|
||||||
|
)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
@@ -48,8 +55,7 @@ async def general_exception_handler(request: Request, exc: Exception):
|
|||||||
"error": {
|
"error": {
|
||||||
"code": 500,
|
"code": 500,
|
||||||
"message": "Internal server error",
|
"message": "Internal server error",
|
||||||
"type": "server_error"
|
"type": "server_error",
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
# middleware/logging_middleware.py
|
# middleware/logging_middleware.py
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
from fastapi import Request, Response
|
from fastapi import Request, Response
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
# middleware/rate_limiter.py
|
# middleware/rate_limiter.py
|
||||||
from typing import Dict
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict, deque
|
from collections import defaultdict, deque
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -16,7 +16,9 @@ class RateLimiter:
|
|||||||
self.cleanup_interval = 3600 # Clean up old entries every hour
|
self.cleanup_interval = 3600 # Clean up old entries every hour
|
||||||
self.last_cleanup = datetime.utcnow()
|
self.last_cleanup = datetime.utcnow()
|
||||||
|
|
||||||
def allow_request(self, client_id: str, max_requests: int, window_seconds: int) -> bool:
|
def allow_request(
|
||||||
|
self, client_id: str, max_requests: int, window_seconds: int
|
||||||
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if client is allowed to make a request
|
Check if client is allowed to make a request
|
||||||
Uses sliding window algorithm
|
Uses sliding window algorithm
|
||||||
@@ -41,7 +43,9 @@ class RateLimiter:
|
|||||||
client_requests.append(now)
|
client_requests.append(now)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
logger.warning(f"Rate limit exceeded for client {client_id}: {len(client_requests)}/{max_requests}")
|
logger.warning(
|
||||||
|
f"Rate limit exceeded for client {client_id}: {len(client_requests)}/{max_requests}"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _cleanup_old_entries(self):
|
def _cleanup_old_entries(self):
|
||||||
@@ -62,7 +66,9 @@ class RateLimiter:
|
|||||||
for client_id in clients_to_remove:
|
for client_id in clients_to_remove:
|
||||||
del self.clients[client_id]
|
del self.clients[client_id]
|
||||||
|
|
||||||
logger.info(f"Rate limiter cleanup completed. Removed {len(clients_to_remove)} inactive clients")
|
logger.info(
|
||||||
|
f"Rate limiter cleanup completed. Removed {len(clients_to_remove)} inactive clients"
|
||||||
|
)
|
||||||
|
|
||||||
def get_client_stats(self, client_id: str) -> Dict[str, int]:
|
def get_client_stats(self, client_id: str) -> Dict[str, int]:
|
||||||
"""Get statistics for a specific client"""
|
"""Get statistics for a specific client"""
|
||||||
@@ -72,11 +78,13 @@ class RateLimiter:
|
|||||||
hour_ago = now - timedelta(hours=1)
|
hour_ago = now - timedelta(hours=1)
|
||||||
day_ago = now - timedelta(days=1)
|
day_ago = now - timedelta(days=1)
|
||||||
|
|
||||||
requests_last_hour = sum(1 for req_time in client_requests if req_time > hour_ago)
|
requests_last_hour = sum(
|
||||||
|
1 for req_time in client_requests if req_time > hour_ago
|
||||||
|
)
|
||||||
requests_last_day = sum(1 for req_time in client_requests if req_time > day_ago)
|
requests_last_day = sum(1 for req_time in client_requests if req_time > day_ago)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"requests_last_hour": requests_last_hour,
|
"requests_last_hour": requests_last_hour,
|
||||||
"requests_last_day": requests_last_day,
|
"requests_last_day": requests_last_day,
|
||||||
"total_tracked_requests": len(client_requests)
|
"total_tracked_requests": len(client_requests),
|
||||||
}
|
}
|
||||||
@@ -1,28 +1,35 @@
|
|||||||
# models/api_models.py - Updated with Marketplace Support and Pydantic v2
|
# models/api_models.py - Updated with Marketplace Support and Pydantic v2
|
||||||
from pydantic import BaseModel, Field, field_validator, EmailStr, ConfigDict
|
|
||||||
from typing import Optional, List
|
|
||||||
from datetime import datetime
|
|
||||||
import re
|
import re
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator
|
||||||
|
|
||||||
|
|
||||||
# User Authentication Models
|
# User Authentication Models
|
||||||
class UserRegister(BaseModel):
|
class UserRegister(BaseModel):
|
||||||
email: EmailStr = Field(..., description="Valid email address")
|
email: EmailStr = Field(..., description="Valid email address")
|
||||||
username: str = Field(..., min_length=3, max_length=50, description="Username (3-50 characters)")
|
username: str = Field(
|
||||||
password: str = Field(..., min_length=6, description="Password (minimum 6 characters)")
|
..., min_length=3, max_length=50, description="Username (3-50 characters)"
|
||||||
|
)
|
||||||
|
password: str = Field(
|
||||||
|
..., min_length=6, description="Password (minimum 6 characters)"
|
||||||
|
)
|
||||||
|
|
||||||
@field_validator('username')
|
@field_validator("username")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_username(cls, v):
|
def validate_username(cls, v):
|
||||||
if not re.match(r'^[a-zA-Z0-9_]+$', v):
|
if not re.match(r"^[a-zA-Z0-9_]+$", v):
|
||||||
raise ValueError('Username must contain only letters, numbers, or underscores')
|
raise ValueError(
|
||||||
|
"Username must contain only letters, numbers, or underscores"
|
||||||
|
)
|
||||||
return v.lower().strip()
|
return v.lower().strip()
|
||||||
|
|
||||||
@field_validator('password')
|
@field_validator("password")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_password(cls, v):
|
def validate_password(cls, v):
|
||||||
if len(v) < 6:
|
if len(v) < 6:
|
||||||
raise ValueError('Password must be at least 6 characters long')
|
raise ValueError("Password must be at least 6 characters long")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
@@ -30,7 +37,7 @@ class UserLogin(BaseModel):
|
|||||||
username: str = Field(..., description="Username")
|
username: str = Field(..., description="Username")
|
||||||
password: str = Field(..., description="Password")
|
password: str = Field(..., description="Password")
|
||||||
|
|
||||||
@field_validator('username')
|
@field_validator("username")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_username(cls, v):
|
def validate_username(cls, v):
|
||||||
return v.strip()
|
return v.strip()
|
||||||
@@ -58,27 +65,38 @@ class LoginResponse(BaseModel):
|
|||||||
|
|
||||||
# NEW: Shop models
|
# NEW: Shop models
|
||||||
class ShopCreate(BaseModel):
|
class ShopCreate(BaseModel):
|
||||||
shop_code: str = Field(..., min_length=3, max_length=50, description="Unique shop code (e.g., TECHSTORE)")
|
shop_code: str = Field(
|
||||||
shop_name: str = Field(..., min_length=1, max_length=200, description="Display name of the shop")
|
...,
|
||||||
description: Optional[str] = Field(None, max_length=2000, description="Shop description")
|
min_length=3,
|
||||||
|
max_length=50,
|
||||||
|
description="Unique shop code (e.g., TECHSTORE)",
|
||||||
|
)
|
||||||
|
shop_name: str = Field(
|
||||||
|
..., min_length=1, max_length=200, description="Display name of the shop"
|
||||||
|
)
|
||||||
|
description: Optional[str] = Field(
|
||||||
|
None, max_length=2000, description="Shop description"
|
||||||
|
)
|
||||||
contact_email: Optional[str] = None
|
contact_email: Optional[str] = None
|
||||||
contact_phone: Optional[str] = None
|
contact_phone: Optional[str] = None
|
||||||
website: Optional[str] = None
|
website: Optional[str] = None
|
||||||
business_address: Optional[str] = None
|
business_address: Optional[str] = None
|
||||||
tax_number: Optional[str] = None
|
tax_number: Optional[str] = None
|
||||||
|
|
||||||
@field_validator('shop_code')
|
@field_validator("shop_code")
|
||||||
def validate_shop_code(cls, v):
|
def validate_shop_code(cls, v):
|
||||||
# Convert to uppercase and check format
|
# Convert to uppercase and check format
|
||||||
v = v.upper().strip()
|
v = v.upper().strip()
|
||||||
if not v.replace('_', '').replace('-', '').isalnum():
|
if not v.replace("_", "").replace("-", "").isalnum():
|
||||||
raise ValueError('Shop code must be alphanumeric (underscores and hyphens allowed)')
|
raise ValueError(
|
||||||
|
"Shop code must be alphanumeric (underscores and hyphens allowed)"
|
||||||
|
)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@field_validator('contact_email')
|
@field_validator("contact_email")
|
||||||
def validate_contact_email(cls, v):
|
def validate_contact_email(cls, v):
|
||||||
if v and ('@' not in v or '.' not in v):
|
if v and ("@" not in v or "." not in v):
|
||||||
raise ValueError('Invalid email format')
|
raise ValueError("Invalid email format")
|
||||||
return v.lower() if v else v
|
return v.lower() if v else v
|
||||||
|
|
||||||
|
|
||||||
@@ -91,10 +109,10 @@ class ShopUpdate(BaseModel):
|
|||||||
business_address: Optional[str] = None
|
business_address: Optional[str] = None
|
||||||
tax_number: Optional[str] = None
|
tax_number: Optional[str] = None
|
||||||
|
|
||||||
@field_validator('contact_email')
|
@field_validator("contact_email")
|
||||||
def validate_contact_email(cls, v):
|
def validate_contact_email(cls, v):
|
||||||
if v and ('@' not in v or '.' not in v):
|
if v and ("@" not in v or "." not in v):
|
||||||
raise ValueError('Invalid email format')
|
raise ValueError("Invalid email format")
|
||||||
return v.lower() if v else v
|
return v.lower() if v else v
|
||||||
|
|
||||||
|
|
||||||
@@ -172,11 +190,11 @@ class ProductCreate(ProductBase):
|
|||||||
product_id: str = Field(..., min_length=1, description="Product ID is required")
|
product_id: str = Field(..., min_length=1, description="Product ID is required")
|
||||||
title: str = Field(..., min_length=1, description="Title is required")
|
title: str = Field(..., min_length=1, description="Title is required")
|
||||||
|
|
||||||
@field_validator('product_id', 'title')
|
@field_validator("product_id", "title")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_required_fields(cls, v):
|
def validate_required_fields(cls, v):
|
||||||
if not v or not v.strip():
|
if not v or not v.strip():
|
||||||
raise ValueError('Field cannot be empty')
|
raise ValueError("Field cannot be empty")
|
||||||
return v.strip()
|
return v.strip()
|
||||||
|
|
||||||
|
|
||||||
@@ -195,15 +213,25 @@ class ProductResponse(ProductBase):
|
|||||||
# NEW: Shop Product models
|
# NEW: Shop Product models
|
||||||
class ShopProductCreate(BaseModel):
|
class ShopProductCreate(BaseModel):
|
||||||
product_id: str = Field(..., description="Product ID to add to shop")
|
product_id: str = Field(..., description="Product ID to add to shop")
|
||||||
shop_product_id: Optional[str] = Field(None, description="Shop's internal product ID")
|
shop_product_id: Optional[str] = Field(
|
||||||
shop_price: Optional[float] = Field(None, ge=0, description="Shop-specific price override")
|
None, description="Shop's internal product ID"
|
||||||
shop_sale_price: Optional[float] = Field(None, ge=0, description="Shop-specific sale price")
|
)
|
||||||
|
shop_price: Optional[float] = Field(
|
||||||
|
None, ge=0, description="Shop-specific price override"
|
||||||
|
)
|
||||||
|
shop_sale_price: Optional[float] = Field(
|
||||||
|
None, ge=0, description="Shop-specific sale price"
|
||||||
|
)
|
||||||
shop_currency: Optional[str] = Field(None, description="Shop-specific currency")
|
shop_currency: Optional[str] = Field(None, description="Shop-specific currency")
|
||||||
shop_availability: Optional[str] = Field(None, description="Shop-specific availability")
|
shop_availability: Optional[str] = Field(
|
||||||
|
None, description="Shop-specific availability"
|
||||||
|
)
|
||||||
shop_condition: Optional[str] = Field(None, description="Shop-specific condition")
|
shop_condition: Optional[str] = Field(None, description="Shop-specific condition")
|
||||||
is_featured: bool = Field(False, description="Featured product flag")
|
is_featured: bool = Field(False, description="Featured product flag")
|
||||||
min_quantity: int = Field(1, ge=1, description="Minimum order quantity")
|
min_quantity: int = Field(1, ge=1, description="Minimum order quantity")
|
||||||
max_quantity: Optional[int] = Field(None, ge=1, description="Maximum order quantity")
|
max_quantity: Optional[int] = Field(
|
||||||
|
None, ge=1, description="Maximum order quantity"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ShopProductResponse(BaseModel):
|
class ShopProductResponse(BaseModel):
|
||||||
@@ -270,28 +298,40 @@ class StockSummaryResponse(BaseModel):
|
|||||||
# Marketplace Import Models
|
# Marketplace Import Models
|
||||||
class MarketplaceImportRequest(BaseModel):
|
class MarketplaceImportRequest(BaseModel):
|
||||||
url: str = Field(..., description="URL to CSV file from marketplace")
|
url: str = Field(..., description="URL to CSV file from marketplace")
|
||||||
marketplace: str = Field(default="Letzshop", description="Name of the marketplace (e.g., Letzshop, Amazon, eBay)")
|
marketplace: str = Field(
|
||||||
|
default="Letzshop",
|
||||||
|
description="Name of the marketplace (e.g., Letzshop, Amazon, eBay)",
|
||||||
|
)
|
||||||
shop_code: str = Field(..., description="Shop code to associate products with")
|
shop_code: str = Field(..., description="Shop code to associate products with")
|
||||||
batch_size: Optional[int] = Field(1000, gt=0, le=10000, description="Batch size for processing")
|
batch_size: Optional[int] = Field(
|
||||||
|
1000, gt=0, le=10000, description="Batch size for processing"
|
||||||
|
)
|
||||||
|
|
||||||
@field_validator('url')
|
@field_validator("url")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_url(cls, v):
|
def validate_url(cls, v):
|
||||||
if not v.startswith(('http://', 'https://')):
|
if not v.startswith(("http://", "https://")):
|
||||||
raise ValueError('URL must start with http:// or https://')
|
raise ValueError("URL must start with http:// or https://")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@field_validator('marketplace')
|
@field_validator("marketplace")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_marketplace(cls, v):
|
def validate_marketplace(cls, v):
|
||||||
# You can add validation for supported marketplaces here
|
# You can add validation for supported marketplaces here
|
||||||
supported_marketplaces = ['Letzshop', 'Amazon', 'eBay', 'Etsy', 'Shopify', 'Other']
|
supported_marketplaces = [
|
||||||
|
"Letzshop",
|
||||||
|
"Amazon",
|
||||||
|
"eBay",
|
||||||
|
"Etsy",
|
||||||
|
"Shopify",
|
||||||
|
"Other",
|
||||||
|
]
|
||||||
if v not in supported_marketplaces:
|
if v not in supported_marketplaces:
|
||||||
# For now, allow any marketplace but log it
|
# For now, allow any marketplace but log it
|
||||||
pass
|
pass
|
||||||
return v.strip()
|
return v.strip()
|
||||||
|
|
||||||
@field_validator('shop_code')
|
@field_validator("shop_code")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_shop_code(cls, v):
|
def validate_shop_code(cls, v):
|
||||||
return v.upper().strip()
|
return v.upper().strip()
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
# models/database_models.py - Updated with Marketplace Support
|
# models/database_models.py - Updated with Marketplace Support
|
||||||
from sqlalchemy import Column, Integer, String, Float, DateTime, Boolean, Text, ForeignKey, UniqueConstraint, Index
|
|
||||||
from sqlalchemy.orm import relationship
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import (Boolean, Column, DateTime, Float, ForeignKey, Index,
|
||||||
|
Integer, String, Text, UniqueConstraint)
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
# Import Base from the central database module instead of creating a new one
|
# Import Base from the central database module instead of creating a new one
|
||||||
from app.core.database import Base
|
from app.core.database import Base
|
||||||
|
|
||||||
@@ -18,10 +20,14 @@ class User(Base):
|
|||||||
is_active = Column(Boolean, default=True, nullable=False)
|
is_active = Column(Boolean, default=True, nullable=False)
|
||||||
last_login = Column(DateTime, nullable=True)
|
last_login = Column(DateTime, nullable=True)
|
||||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
|
updated_at = Column(
|
||||||
|
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False
|
||||||
|
)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
marketplace_import_jobs = relationship("MarketplaceImportJob", back_populates="user")
|
marketplace_import_jobs = relationship(
|
||||||
|
"MarketplaceImportJob", back_populates="user"
|
||||||
|
)
|
||||||
owned_shops = relationship("Shop", back_populates="owner")
|
owned_shops = relationship("Shop", back_populates="owner")
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
@@ -32,7 +38,9 @@ class Shop(Base):
|
|||||||
__tablename__ = "shops"
|
__tablename__ = "shops"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, index=True)
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
shop_code = Column(String, unique=True, index=True, nullable=False) # e.g., "TECHSTORE", "FASHIONHUB"
|
shop_code = Column(
|
||||||
|
String, unique=True, index=True, nullable=False
|
||||||
|
) # e.g., "TECHSTORE", "FASHIONHUB"
|
||||||
shop_name = Column(String, nullable=False) # Display name
|
shop_name = Column(String, nullable=False) # Display name
|
||||||
description = Column(Text)
|
description = Column(Text)
|
||||||
owner_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
owner_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||||
@@ -57,7 +65,9 @@ class Shop(Base):
|
|||||||
# Relationships
|
# Relationships
|
||||||
owner = relationship("User", back_populates="owned_shops")
|
owner = relationship("User", back_populates="owned_shops")
|
||||||
shop_products = relationship("ShopProduct", back_populates="shop")
|
shop_products = relationship("ShopProduct", back_populates="shop")
|
||||||
marketplace_import_jobs = relationship("MarketplaceImportJob", back_populates="shop")
|
marketplace_import_jobs = relationship(
|
||||||
|
"MarketplaceImportJob", back_populates="shop"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Product(Base):
|
class Product(Base):
|
||||||
@@ -103,26 +113,40 @@ class Product(Base):
|
|||||||
currency = Column(String)
|
currency = Column(String)
|
||||||
|
|
||||||
# New marketplace fields
|
# New marketplace fields
|
||||||
marketplace = Column(String, index=True, nullable=True, default="Letzshop") # Index for marketplace filtering
|
marketplace = Column(
|
||||||
|
String, index=True, nullable=True, default="Letzshop"
|
||||||
|
) # Index for marketplace filtering
|
||||||
shop_name = Column(String, index=True, nullable=True) # Index for shop filtering
|
shop_name = Column(String, index=True, nullable=True) # Index for shop filtering
|
||||||
|
|
||||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
|
updated_at = Column(
|
||||||
|
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False
|
||||||
|
)
|
||||||
|
|
||||||
# Relationship to stock (one-to-many via GTIN)
|
# Relationship to stock (one-to-many via GTIN)
|
||||||
stock_entries = relationship("Stock", foreign_keys="Stock.gtin", primaryjoin="Product.gtin == Stock.gtin",
|
stock_entries = relationship(
|
||||||
viewonly=True)
|
"Stock",
|
||||||
|
foreign_keys="Stock.gtin",
|
||||||
|
primaryjoin="Product.gtin == Stock.gtin",
|
||||||
|
viewonly=True,
|
||||||
|
)
|
||||||
shop_products = relationship("ShopProduct", back_populates="product")
|
shop_products = relationship("ShopProduct", back_populates="product")
|
||||||
|
|
||||||
# Additional indexes for marketplace queries
|
# Additional indexes for marketplace queries
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index('idx_marketplace_shop', 'marketplace', 'shop_name'), # Composite index for marketplace+shop queries
|
Index(
|
||||||
Index('idx_marketplace_brand', 'marketplace', 'brand'), # Composite index for marketplace+brand queries
|
"idx_marketplace_shop", "marketplace", "shop_name"
|
||||||
|
), # Composite index for marketplace+shop queries
|
||||||
|
Index(
|
||||||
|
"idx_marketplace_brand", "marketplace", "brand"
|
||||||
|
), # Composite index for marketplace+brand queries
|
||||||
)
|
)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return (f"<Product(product_id='{self.product_id}', title='{self.title}', marketplace='{self.marketplace}', "
|
return (
|
||||||
f"shop='{self.shop_name}')>")
|
f"<Product(product_id='{self.product_id}', title='{self.title}', marketplace='{self.marketplace}', "
|
||||||
|
f"shop='{self.shop_name}')>"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ShopProduct(Base):
|
class ShopProduct(Base):
|
||||||
@@ -159,9 +183,9 @@ class ShopProduct(Base):
|
|||||||
|
|
||||||
# Constraints
|
# Constraints
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
UniqueConstraint('shop_id', 'product_id', name='uq_shop_product'),
|
UniqueConstraint("shop_id", "product_id", name="uq_shop_product"),
|
||||||
Index('idx_shop_product_active', 'shop_id', 'is_active'),
|
Index("idx_shop_product_active", "shop_id", "is_active"),
|
||||||
Index('idx_shop_product_featured', 'shop_id', 'is_featured'),
|
Index("idx_shop_product_featured", "shop_id", "is_featured"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -169,22 +193,28 @@ class Stock(Base):
|
|||||||
__tablename__ = "stock"
|
__tablename__ = "stock"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, index=True)
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
gtin = Column(String, index=True, nullable=False) # Foreign key relationship would be ideal
|
gtin = Column(
|
||||||
|
String, index=True, nullable=False
|
||||||
|
) # Foreign key relationship would be ideal
|
||||||
location = Column(String, nullable=False, index=True)
|
location = Column(String, nullable=False, index=True)
|
||||||
quantity = Column(Integer, nullable=False, default=0)
|
quantity = Column(Integer, nullable=False, default=0)
|
||||||
reserved_quantity = Column(Integer, default=0) # For orders being processed
|
reserved_quantity = Column(Integer, default=0) # For orders being processed
|
||||||
shop_id = Column(Integer, ForeignKey("shops.id")) # Optional: shop-specific stock
|
shop_id = Column(Integer, ForeignKey("shops.id")) # Optional: shop-specific stock
|
||||||
|
|
||||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
|
updated_at = Column(
|
||||||
|
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False
|
||||||
|
)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
shop = relationship("Shop")
|
shop = relationship("Shop")
|
||||||
|
|
||||||
# Composite unique constraint to prevent duplicate GTIN-location combinations
|
# Composite unique constraint to prevent duplicate GTIN-location combinations
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
UniqueConstraint('gtin', 'location', name='uq_stock_gtin_location'),
|
UniqueConstraint("gtin", "location", name="uq_stock_gtin_location"),
|
||||||
Index('idx_stock_gtin_location', 'gtin', 'location'), # Composite index for efficient queries
|
Index(
|
||||||
|
"idx_stock_gtin_location", "gtin", "location"
|
||||||
|
), # Composite index for efficient queries
|
||||||
)
|
)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
@@ -195,13 +225,20 @@ class MarketplaceImportJob(Base):
|
|||||||
__tablename__ = "marketplace_import_jobs"
|
__tablename__ = "marketplace_import_jobs"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, index=True)
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
status = Column(String, nullable=False,
|
status = Column(
|
||||||
default="pending") # pending, processing, completed, failed, completed_with_errors
|
String, nullable=False, default="pending"
|
||||||
|
) # pending, processing, completed, failed, completed_with_errors
|
||||||
source_url = Column(String, nullable=False)
|
source_url = Column(String, nullable=False)
|
||||||
marketplace = Column(String, nullable=False, index=True, default="Letzshop") # Index for marketplace filtering
|
marketplace = Column(
|
||||||
|
String, nullable=False, index=True, default="Letzshop"
|
||||||
|
) # Index for marketplace filtering
|
||||||
shop_name = Column(String, nullable=False, index=True) # Index for shop filtering
|
shop_name = Column(String, nullable=False, index=True) # Index for shop filtering
|
||||||
shop_id = Column(Integer, ForeignKey("shops.id"), nullable=False) # Add proper foreign key
|
shop_id = Column(
|
||||||
user_id = Column(Integer, ForeignKey('users.id'), nullable=False) # Foreign key to users table
|
Integer, ForeignKey("shops.id"), nullable=False
|
||||||
|
) # Add proper foreign key
|
||||||
|
user_id = Column(
|
||||||
|
Integer, ForeignKey("users.id"), nullable=False
|
||||||
|
) # Foreign key to users table
|
||||||
|
|
||||||
# Results
|
# Results
|
||||||
imported_count = Column(Integer, default=0)
|
imported_count = Column(Integer, default=0)
|
||||||
@@ -223,11 +260,15 @@ class MarketplaceImportJob(Base):
|
|||||||
|
|
||||||
# Additional indexes for marketplace import job queries
|
# Additional indexes for marketplace import job queries
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index('idx_marketplace_import_user_marketplace', 'user_id', 'marketplace'), # User's marketplace imports
|
Index(
|
||||||
Index('idx_marketplace_import_shop_status', 'status'), # Shop import status
|
"idx_marketplace_import_user_marketplace", "user_id", "marketplace"
|
||||||
Index('idx_marketplace_import_shop_id', 'shop_id'),
|
), # User's marketplace imports
|
||||||
|
Index("idx_marketplace_import_shop_status", "status"), # Shop import status
|
||||||
|
Index("idx_marketplace_import_shop_id", "shop_id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return (f"<MarketplaceImportJob(id={self.id}, marketplace='{self.marketplace}', shop='{self.shop_name}', "
|
return (
|
||||||
f"status='{self.status}', imported={self.imported_count})>")
|
f"<MarketplaceImportJob(id={self.id}, marketplace='{self.marketplace}', shop='{self.shop_name}', "
|
||||||
|
f"status='{self.status}', imported={self.imported_count})>"
|
||||||
|
)
|
||||||
|
|||||||
@@ -3,8 +3,8 @@
|
|||||||
"""Development environment setup script"""
|
"""Development environment setup script"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
@@ -30,6 +30,7 @@ def setup_alembic():
|
|||||||
if alembic_dir.exists():
|
if alembic_dir.exists():
|
||||||
# Remove incomplete alembic directory
|
# Remove incomplete alembic directory
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
shutil.rmtree(alembic_dir)
|
shutil.rmtree(alembic_dir)
|
||||||
|
|
||||||
if not run_command("alembic init alembic", "Initializing Alembic"):
|
if not run_command("alembic init alembic", "Initializing Alembic"):
|
||||||
@@ -138,16 +139,23 @@ LOG_LEVEL=INFO
|
|||||||
|
|
||||||
# Set up Alembic
|
# Set up Alembic
|
||||||
if not setup_alembic():
|
if not setup_alembic():
|
||||||
print("⚠️ Alembic setup failed. You'll need to set up database migrations manually.")
|
print(
|
||||||
|
"⚠️ Alembic setup failed. You'll need to set up database migrations manually."
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Create initial migration
|
# Create initial migration
|
||||||
if not run_command("alembic revision --autogenerate -m \"Initial migration\"", "Creating initial migration"):
|
if not run_command(
|
||||||
|
'alembic revision --autogenerate -m "Initial migration"',
|
||||||
|
"Creating initial migration",
|
||||||
|
):
|
||||||
print("⚠️ Initial migration creation failed. Check your database models.")
|
print("⚠️ Initial migration creation failed. Check your database models.")
|
||||||
|
|
||||||
# Apply migrations
|
# Apply migrations
|
||||||
if not run_command("alembic upgrade head", "Setting up database"):
|
if not run_command("alembic upgrade head", "Setting up database"):
|
||||||
print("⚠️ Database setup failed. Make sure your DATABASE_URL is correct in .env")
|
print(
|
||||||
|
"⚠️ Database setup failed. Make sure your DATABASE_URL is correct in .env"
|
||||||
|
)
|
||||||
|
|
||||||
# Run tests
|
# Run tests
|
||||||
if not run_command("pytest", "Running tests"):
|
if not run_command("pytest", "Running tests"):
|
||||||
@@ -157,7 +165,7 @@ LOG_LEVEL=INFO
|
|||||||
print("To start the development server, run:")
|
print("To start the development server, run:")
|
||||||
print(" uvicorn main:app --reload")
|
print(" uvicorn main:app --reload")
|
||||||
print("\nDatabase commands:")
|
print("\nDatabase commands:")
|
||||||
print(" alembic revision --autogenerate -m \"Description\" # Create migration")
|
print(' alembic revision --autogenerate -m "Description" # Create migration')
|
||||||
print(" alembic upgrade head # Apply migrations")
|
print(" alembic upgrade head # Apply migrations")
|
||||||
print(" alembic current # Check status")
|
print(" alembic current # Check status")
|
||||||
|
|
||||||
|
|||||||
@@ -1,16 +1,18 @@
|
|||||||
# tests/conftest.py
|
# tests/conftest.py
|
||||||
|
import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from sqlalchemy.pool import StaticPool
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
|
from app.core.database import Base, get_db
|
||||||
from main import app
|
from main import app
|
||||||
from app.core.database import get_db, Base
|
|
||||||
# Import all models to ensure they're registered with Base metadata
|
|
||||||
from models.database_models import User, Product, Stock, Shop, MarketplaceImportJob, ShopProduct
|
|
||||||
from middleware.auth import AuthManager
|
from middleware.auth import AuthManager
|
||||||
import uuid
|
# Import all models to ensure they're registered with Base metadata
|
||||||
|
from models.database_models import (MarketplaceImportJob, Product, Shop,
|
||||||
|
ShopProduct, Stock, User)
|
||||||
|
|
||||||
# Use in-memory SQLite database for tests
|
# Use in-memory SQLite database for tests
|
||||||
SQLALCHEMY_TEST_DATABASE_URL = "sqlite:///:memory:"
|
SQLALCHEMY_TEST_DATABASE_URL = "sqlite:///:memory:"
|
||||||
@@ -23,7 +25,7 @@ def engine():
|
|||||||
SQLALCHEMY_TEST_DATABASE_URL,
|
SQLALCHEMY_TEST_DATABASE_URL,
|
||||||
connect_args={"check_same_thread": False},
|
connect_args={"check_same_thread": False},
|
||||||
poolclass=StaticPool,
|
poolclass=StaticPool,
|
||||||
echo=False # Set to True for SQL debugging
|
echo=False, # Set to True for SQL debugging
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -89,7 +91,7 @@ def test_user(db, auth_manager):
|
|||||||
username=f"testuser_{unique_id}",
|
username=f"testuser_{unique_id}",
|
||||||
hashed_password=hashed_password,
|
hashed_password=hashed_password,
|
||||||
role="user",
|
role="user",
|
||||||
is_active=True
|
is_active=True,
|
||||||
)
|
)
|
||||||
db.add(user)
|
db.add(user)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -107,7 +109,7 @@ def test_admin(db, auth_manager):
|
|||||||
username=f"admin_{unique_id}",
|
username=f"admin_{unique_id}",
|
||||||
hashed_password=hashed_password,
|
hashed_password=hashed_password,
|
||||||
role="admin",
|
role="admin",
|
||||||
is_active=True
|
is_active=True,
|
||||||
)
|
)
|
||||||
db.add(admin)
|
db.add(admin)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -118,10 +120,10 @@ def test_admin(db, auth_manager):
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def auth_headers(client, test_user):
|
def auth_headers(client, test_user):
|
||||||
"""Get authentication headers for test user"""
|
"""Get authentication headers for test user"""
|
||||||
response = client.post("/api/v1/auth/login", json={
|
response = client.post(
|
||||||
"username": test_user.username,
|
"/api/v1/auth/login",
|
||||||
"password": "testpass123"
|
json={"username": test_user.username, "password": "testpass123"},
|
||||||
})
|
)
|
||||||
assert response.status_code == 200, f"Login failed: {response.text}"
|
assert response.status_code == 200, f"Login failed: {response.text}"
|
||||||
token = response.json()["access_token"]
|
token = response.json()["access_token"]
|
||||||
return {"Authorization": f"Bearer {token}"}
|
return {"Authorization": f"Bearer {token}"}
|
||||||
@@ -130,10 +132,10 @@ def auth_headers(client, test_user):
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def admin_headers(client, test_admin):
|
def admin_headers(client, test_admin):
|
||||||
"""Get authentication headers for admin user"""
|
"""Get authentication headers for admin user"""
|
||||||
response = client.post("/api/v1/auth/login", json={
|
response = client.post(
|
||||||
"username": test_admin.username,
|
"/api/v1/auth/login",
|
||||||
"password": "adminpass123"
|
json={"username": test_admin.username, "password": "adminpass123"},
|
||||||
})
|
)
|
||||||
assert response.status_code == 200, f"Admin login failed: {response.text}"
|
assert response.status_code == 200, f"Admin login failed: {response.text}"
|
||||||
token = response.json()["access_token"]
|
token = response.json()["access_token"]
|
||||||
return {"Authorization": f"Bearer {token}"}
|
return {"Authorization": f"Bearer {token}"}
|
||||||
@@ -152,7 +154,7 @@ def test_product(db):
|
|||||||
gtin="1234567890123",
|
gtin="1234567890123",
|
||||||
availability="in stock",
|
availability="in stock",
|
||||||
marketplace="Letzshop",
|
marketplace="Letzshop",
|
||||||
shop_name="TestShop"
|
shop_name="TestShop",
|
||||||
)
|
)
|
||||||
db.add(product)
|
db.add(product)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -169,7 +171,7 @@ def test_shop(db, test_user):
|
|||||||
shop_name=f"Test Shop {unique_id}",
|
shop_name=f"Test Shop {unique_id}",
|
||||||
owner_id=test_user.id,
|
owner_id=test_user.id,
|
||||||
is_active=True,
|
is_active=True,
|
||||||
is_verified=True
|
is_verified=True,
|
||||||
)
|
)
|
||||||
db.add(shop)
|
db.add(shop)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -186,7 +188,7 @@ def test_stock(db, test_product, test_shop):
|
|||||||
location=f"WAREHOUSE_A_{unique_id}",
|
location=f"WAREHOUSE_A_{unique_id}",
|
||||||
quantity=10,
|
quantity=10,
|
||||||
reserved_quantity=0,
|
reserved_quantity=0,
|
||||||
shop_id=test_shop.id # Add shop_id reference
|
shop_id=test_shop.id, # Add shop_id reference
|
||||||
)
|
)
|
||||||
db.add(stock)
|
db.add(stock)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -208,7 +210,7 @@ def test_marketplace_job(db, test_shop, test_user): # Add test_shop dependency
|
|||||||
updated_count=3,
|
updated_count=3,
|
||||||
total_processed=8,
|
total_processed=8,
|
||||||
error_count=0,
|
error_count=0,
|
||||||
error_message=None
|
error_message=None,
|
||||||
)
|
)
|
||||||
db.add(job)
|
db.add(job)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -219,16 +221,16 @@ def test_marketplace_job(db, test_shop, test_user): # Add test_shop dependency
|
|||||||
def create_test_import_job(db, shop_id, **kwargs): # Add shop_id parameter
|
def create_test_import_job(db, shop_id, **kwargs): # Add shop_id parameter
|
||||||
"""Helper function to create MarketplaceImportJob with defaults"""
|
"""Helper function to create MarketplaceImportJob with defaults"""
|
||||||
defaults = {
|
defaults = {
|
||||||
'marketplace': 'test',
|
"marketplace": "test",
|
||||||
'shop_name': 'Test Shop',
|
"shop_name": "Test Shop",
|
||||||
'status': 'pending',
|
"status": "pending",
|
||||||
'source_url': 'https://test.example.com/import',
|
"source_url": "https://test.example.com/import",
|
||||||
'shop_id': shop_id, # Add required shop_id
|
"shop_id": shop_id, # Add required shop_id
|
||||||
'imported_count': 0,
|
"imported_count": 0,
|
||||||
'updated_count': 0,
|
"updated_count": 0,
|
||||||
'total_processed': 0,
|
"total_processed": 0,
|
||||||
'error_count': 0,
|
"error_count": 0,
|
||||||
'error_message': None
|
"error_message": None,
|
||||||
}
|
}
|
||||||
defaults.update(kwargs)
|
defaults.update(kwargs)
|
||||||
|
|
||||||
@@ -250,6 +252,7 @@ def cleanup():
|
|||||||
|
|
||||||
# Add these fixtures to your existing conftest.py
|
# Add these fixtures to your existing conftest.py
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def unique_product(db):
|
def unique_product(db):
|
||||||
"""Create a unique product for tests that need isolated product data"""
|
"""Create a unique product for tests that need isolated product data"""
|
||||||
@@ -265,7 +268,7 @@ def unique_product(db):
|
|||||||
availability="in stock",
|
availability="in stock",
|
||||||
marketplace="Letzshop",
|
marketplace="Letzshop",
|
||||||
shop_name=f"UniqueShop_{unique_id}",
|
shop_name=f"UniqueShop_{unique_id}",
|
||||||
google_product_category=f"UniqueCategory_{unique_id}"
|
google_product_category=f"UniqueCategory_{unique_id}",
|
||||||
)
|
)
|
||||||
db.add(product)
|
db.add(product)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -283,7 +286,7 @@ def unique_shop(db, test_user):
|
|||||||
description=f"A unique test shop {unique_id}",
|
description=f"A unique test shop {unique_id}",
|
||||||
owner_id=test_user.id,
|
owner_id=test_user.id,
|
||||||
is_active=True,
|
is_active=True,
|
||||||
is_verified=True
|
is_verified=True,
|
||||||
)
|
)
|
||||||
db.add(shop)
|
db.add(shop)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -301,7 +304,7 @@ def other_user(db, auth_manager):
|
|||||||
username=f"otheruser_{unique_id}",
|
username=f"otheruser_{unique_id}",
|
||||||
hashed_password=hashed_password,
|
hashed_password=hashed_password,
|
||||||
role="user",
|
role="user",
|
||||||
is_active=True
|
is_active=True,
|
||||||
)
|
)
|
||||||
db.add(user)
|
db.add(user)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -318,7 +321,7 @@ def inactive_shop(db, other_user):
|
|||||||
shop_name=f"Inactive Shop {unique_id}",
|
shop_name=f"Inactive Shop {unique_id}",
|
||||||
owner_id=other_user.id,
|
owner_id=other_user.id,
|
||||||
is_active=False,
|
is_active=False,
|
||||||
is_verified=False
|
is_verified=False,
|
||||||
)
|
)
|
||||||
db.add(shop)
|
db.add(shop)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -335,7 +338,7 @@ def verified_shop(db, other_user):
|
|||||||
shop_name=f"Verified Shop {unique_id}",
|
shop_name=f"Verified Shop {unique_id}",
|
||||||
owner_id=other_user.id,
|
owner_id=other_user.id,
|
||||||
is_active=True,
|
is_active=True,
|
||||||
is_verified=True
|
is_verified=True,
|
||||||
)
|
)
|
||||||
db.add(shop)
|
db.add(shop)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -347,16 +350,14 @@ def verified_shop(db, other_user):
|
|||||||
def shop_product(db, test_shop, unique_product):
|
def shop_product(db, test_shop, unique_product):
|
||||||
"""Create a shop product relationship"""
|
"""Create a shop product relationship"""
|
||||||
shop_product = ShopProduct(
|
shop_product = ShopProduct(
|
||||||
shop_id=test_shop.id,
|
shop_id=test_shop.id, product_id=unique_product.id, is_active=True
|
||||||
product_id=unique_product.id,
|
|
||||||
is_active=True
|
|
||||||
)
|
)
|
||||||
# Add optional fields if they exist in your model
|
# Add optional fields if they exist in your model
|
||||||
if hasattr(ShopProduct, 'price'):
|
if hasattr(ShopProduct, "price"):
|
||||||
shop_product.price = "24.99"
|
shop_product.price = "24.99"
|
||||||
if hasattr(ShopProduct, 'is_featured'):
|
if hasattr(ShopProduct, "is_featured"):
|
||||||
shop_product.is_featured = False
|
shop_product.is_featured = False
|
||||||
if hasattr(ShopProduct, 'stock_quantity'):
|
if hasattr(ShopProduct, "stock_quantity"):
|
||||||
shop_product.stock_quantity = 10
|
shop_product.stock_quantity = 10
|
||||||
|
|
||||||
db.add(shop_product)
|
db.add(shop_product)
|
||||||
@@ -382,7 +383,7 @@ def multiple_products(db):
|
|||||||
marketplace=f"MultiMarket_{i % 2}", # Create 2 different marketplaces
|
marketplace=f"MultiMarket_{i % 2}", # Create 2 different marketplaces
|
||||||
shop_name=f"MultiShop_{i}",
|
shop_name=f"MultiShop_{i}",
|
||||||
google_product_category=f"MultiCategory_{i % 2}", # Create 2 different categories
|
google_product_category=f"MultiCategory_{i % 2}", # Create 2 different categories
|
||||||
gtin=f"1234567890{i}{unique_id[:2]}"
|
gtin=f"1234567890{i}{unique_id[:2]}",
|
||||||
)
|
)
|
||||||
products.append(product)
|
products.append(product)
|
||||||
|
|
||||||
@@ -404,7 +405,7 @@ def multiple_stocks(db, multiple_products, test_shop):
|
|||||||
location=f"LOC_{i}",
|
location=f"LOC_{i}",
|
||||||
quantity=10 + (i * 5), # Different quantities
|
quantity=10 + (i * 5), # Different quantities
|
||||||
reserved_quantity=i,
|
reserved_quantity=i,
|
||||||
shop_id=test_shop.id
|
shop_id=test_shop.id,
|
||||||
)
|
)
|
||||||
stocks.append(stock)
|
stocks.append(stock)
|
||||||
|
|
||||||
@@ -422,12 +423,12 @@ def create_unique_product_factory():
|
|||||||
def _create_product(db, **kwargs):
|
def _create_product(db, **kwargs):
|
||||||
unique_id = str(uuid.uuid4())[:8]
|
unique_id = str(uuid.uuid4())[:8]
|
||||||
defaults = {
|
defaults = {
|
||||||
'product_id': f"FACTORY_{unique_id}",
|
"product_id": f"FACTORY_{unique_id}",
|
||||||
'title': f"Factory Product {unique_id}",
|
"title": f"Factory Product {unique_id}",
|
||||||
'price': "15.99",
|
"price": "15.99",
|
||||||
'currency': "EUR",
|
"currency": "EUR",
|
||||||
'marketplace': "TestMarket",
|
"marketplace": "TestMarket",
|
||||||
'shop_name': "TestShop"
|
"shop_name": "TestShop",
|
||||||
}
|
}
|
||||||
defaults.update(kwargs)
|
defaults.update(kwargs)
|
||||||
|
|
||||||
@@ -452,11 +453,11 @@ def create_unique_shop_factory():
|
|||||||
def _create_shop(db, owner_id, **kwargs):
|
def _create_shop(db, owner_id, **kwargs):
|
||||||
unique_id = str(uuid.uuid4())[:8]
|
unique_id = str(uuid.uuid4())[:8]
|
||||||
defaults = {
|
defaults = {
|
||||||
'shop_code': f"FACTORY_{unique_id}",
|
"shop_code": f"FACTORY_{unique_id}",
|
||||||
'shop_name': f"Factory Shop {unique_id}",
|
"shop_name": f"Factory Shop {unique_id}",
|
||||||
'owner_id': owner_id,
|
"owner_id": owner_id,
|
||||||
'is_active': True,
|
"is_active": True,
|
||||||
'is_verified': False
|
"is_verified": False,
|
||||||
}
|
}
|
||||||
defaults.update(kwargs)
|
defaults.update(kwargs)
|
||||||
|
|
||||||
@@ -473,4 +474,3 @@ def create_unique_shop_factory():
|
|||||||
def shop_factory():
|
def shop_factory():
|
||||||
"""Fixture that provides a shop factory function"""
|
"""Fixture that provides a shop factory function"""
|
||||||
return create_unique_shop_factory()
|
return create_unique_shop_factory()
|
||||||
|
|
||||||
|
|||||||
@@ -20,11 +20,16 @@ class TestAdminAPI:
|
|||||||
response = client.get("/api/v1/admin/users", headers=auth_headers)
|
response = client.get("/api/v1/admin/users", headers=auth_headers)
|
||||||
|
|
||||||
assert response.status_code == 403
|
assert response.status_code == 403
|
||||||
assert "Access denied" in response.json()["detail"] or "admin" in response.json()["detail"].lower()
|
assert (
|
||||||
|
"Access denied" in response.json()["detail"]
|
||||||
|
or "admin" in response.json()["detail"].lower()
|
||||||
|
)
|
||||||
|
|
||||||
def test_toggle_user_status_admin(self, client, admin_headers, test_user):
|
def test_toggle_user_status_admin(self, client, admin_headers, test_user):
|
||||||
"""Test admin toggling user status"""
|
"""Test admin toggling user status"""
|
||||||
response = client.put(f"/api/v1/admin/users/{test_user.id}/status", headers=admin_headers)
|
response = client.put(
|
||||||
|
f"/api/v1/admin/users/{test_user.id}/status", headers=admin_headers
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
message = response.json()["message"]
|
message = response.json()["message"]
|
||||||
@@ -39,9 +44,13 @@ class TestAdminAPI:
|
|||||||
assert response.status_code == 404
|
assert response.status_code == 404
|
||||||
assert "User not found" in response.json()["detail"]
|
assert "User not found" in response.json()["detail"]
|
||||||
|
|
||||||
def test_toggle_user_status_cannot_deactivate_self(self, client, admin_headers, test_admin):
|
def test_toggle_user_status_cannot_deactivate_self(
|
||||||
|
self, client, admin_headers, test_admin
|
||||||
|
):
|
||||||
"""Test that admin cannot deactivate their own account"""
|
"""Test that admin cannot deactivate their own account"""
|
||||||
response = client.put(f"/api/v1/admin/users/{test_admin.id}/status", headers=admin_headers)
|
response = client.put(
|
||||||
|
f"/api/v1/admin/users/{test_admin.id}/status", headers=admin_headers
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
assert "Cannot deactivate your own account" in response.json()["detail"]
|
assert "Cannot deactivate your own account" in response.json()["detail"]
|
||||||
@@ -56,7 +65,9 @@ class TestAdminAPI:
|
|||||||
assert len(data["shops"]) >= 1
|
assert len(data["shops"]) >= 1
|
||||||
|
|
||||||
# Check that test_shop is in the response
|
# Check that test_shop is in the response
|
||||||
shop_codes = [shop["shop_code"] for shop in data["shops"] if "shop_code" in shop]
|
shop_codes = [
|
||||||
|
shop["shop_code"] for shop in data["shops"] if "shop_code" in shop
|
||||||
|
]
|
||||||
assert test_shop.shop_code in shop_codes
|
assert test_shop.shop_code in shop_codes
|
||||||
|
|
||||||
def test_get_all_shops_non_admin(self, client, auth_headers):
|
def test_get_all_shops_non_admin(self, client, auth_headers):
|
||||||
@@ -64,11 +75,16 @@ class TestAdminAPI:
|
|||||||
response = client.get("/api/v1/admin/shops", headers=auth_headers)
|
response = client.get("/api/v1/admin/shops", headers=auth_headers)
|
||||||
|
|
||||||
assert response.status_code == 403
|
assert response.status_code == 403
|
||||||
assert "Access denied" in response.json()["detail"] or "admin" in response.json()["detail"].lower()
|
assert (
|
||||||
|
"Access denied" in response.json()["detail"]
|
||||||
|
or "admin" in response.json()["detail"].lower()
|
||||||
|
)
|
||||||
|
|
||||||
def test_verify_shop_admin(self, client, admin_headers, test_shop):
|
def test_verify_shop_admin(self, client, admin_headers, test_shop):
|
||||||
"""Test admin verifying/unverifying shop"""
|
"""Test admin verifying/unverifying shop"""
|
||||||
response = client.put(f"/api/v1/admin/shops/{test_shop.id}/verify", headers=admin_headers)
|
response = client.put(
|
||||||
|
f"/api/v1/admin/shops/{test_shop.id}/verify", headers=admin_headers
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
message = response.json()["message"]
|
message = response.json()["message"]
|
||||||
@@ -84,7 +100,9 @@ class TestAdminAPI:
|
|||||||
|
|
||||||
def test_toggle_shop_status_admin(self, client, admin_headers, test_shop):
|
def test_toggle_shop_status_admin(self, client, admin_headers, test_shop):
|
||||||
"""Test admin toggling shop status"""
|
"""Test admin toggling shop status"""
|
||||||
response = client.put(f"/api/v1/admin/shops/{test_shop.id}/status", headers=admin_headers)
|
response = client.put(
|
||||||
|
f"/api/v1/admin/shops/{test_shop.id}/status", headers=admin_headers
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
message = response.json()["message"]
|
message = response.json()["message"]
|
||||||
@@ -98,9 +116,13 @@ class TestAdminAPI:
|
|||||||
assert response.status_code == 404
|
assert response.status_code == 404
|
||||||
assert "Shop not found" in response.json()["detail"]
|
assert "Shop not found" in response.json()["detail"]
|
||||||
|
|
||||||
def test_get_marketplace_import_jobs_admin(self, client, admin_headers, test_marketplace_job):
|
def test_get_marketplace_import_jobs_admin(
|
||||||
|
self, client, admin_headers, test_marketplace_job
|
||||||
|
):
|
||||||
"""Test admin getting marketplace import jobs"""
|
"""Test admin getting marketplace import jobs"""
|
||||||
response = client.get("/api/v1/admin/marketplace-import-jobs", headers=admin_headers)
|
response = client.get(
|
||||||
|
"/api/v1/admin/marketplace-import-jobs", headers=admin_headers
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
@@ -110,43 +132,58 @@ class TestAdminAPI:
|
|||||||
job_ids = [job["job_id"] for job in data if "job_id" in job]
|
job_ids = [job["job_id"] for job in data if "job_id" in job]
|
||||||
assert test_marketplace_job.id in job_ids
|
assert test_marketplace_job.id in job_ids
|
||||||
|
|
||||||
def test_get_marketplace_import_jobs_with_filters(self, client, admin_headers, test_marketplace_job):
|
def test_get_marketplace_import_jobs_with_filters(
|
||||||
|
self, client, admin_headers, test_marketplace_job
|
||||||
|
):
|
||||||
"""Test admin getting marketplace import jobs with filters"""
|
"""Test admin getting marketplace import jobs with filters"""
|
||||||
response = client.get(
|
response = client.get(
|
||||||
"/api/v1/admin/marketplace-import-jobs",
|
"/api/v1/admin/marketplace-import-jobs",
|
||||||
params={"marketplace": test_marketplace_job.marketplace},
|
params={"marketplace": test_marketplace_job.marketplace},
|
||||||
headers=admin_headers
|
headers=admin_headers,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert len(data) >= 1
|
assert len(data) >= 1
|
||||||
assert all(job["marketplace"] == test_marketplace_job.marketplace for job in data)
|
assert all(
|
||||||
|
job["marketplace"] == test_marketplace_job.marketplace for job in data
|
||||||
|
)
|
||||||
|
|
||||||
def test_get_marketplace_import_jobs_non_admin(self, client, auth_headers):
|
def test_get_marketplace_import_jobs_non_admin(self, client, auth_headers):
|
||||||
"""Test non-admin trying to access marketplace import jobs"""
|
"""Test non-admin trying to access marketplace import jobs"""
|
||||||
response = client.get("/api/v1/admin/marketplace-import-jobs", headers=auth_headers)
|
response = client.get(
|
||||||
|
"/api/v1/admin/marketplace-import-jobs", headers=auth_headers
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 403
|
assert response.status_code == 403
|
||||||
assert "Access denied" in response.json()["detail"] or "admin" in response.json()["detail"].lower()
|
assert (
|
||||||
|
"Access denied" in response.json()["detail"]
|
||||||
|
or "admin" in response.json()["detail"].lower()
|
||||||
|
)
|
||||||
|
|
||||||
def test_admin_pagination_users(self, client, admin_headers, test_user, test_admin):
|
def test_admin_pagination_users(self, client, admin_headers, test_user, test_admin):
|
||||||
"""Test user pagination works correctly"""
|
"""Test user pagination works correctly"""
|
||||||
# Test first page
|
# Test first page
|
||||||
response = client.get("/api/v1/admin/users?skip=0&limit=1", headers=admin_headers)
|
response = client.get(
|
||||||
|
"/api/v1/admin/users?skip=0&limit=1", headers=admin_headers
|
||||||
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert len(data) == 1
|
assert len(data) == 1
|
||||||
|
|
||||||
# Test second page
|
# Test second page
|
||||||
response = client.get("/api/v1/admin/users?skip=1&limit=1", headers=admin_headers)
|
response = client.get(
|
||||||
|
"/api/v1/admin/users?skip=1&limit=1", headers=admin_headers
|
||||||
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert len(data) >= 0 # Could be 1 or 0 depending on total users
|
assert len(data) >= 0 # Could be 1 or 0 depending on total users
|
||||||
|
|
||||||
def test_admin_pagination_shops(self, client, admin_headers, test_shop):
|
def test_admin_pagination_shops(self, client, admin_headers, test_shop):
|
||||||
"""Test shop pagination works correctly"""
|
"""Test shop pagination works correctly"""
|
||||||
response = client.get("/api/v1/admin/shops?skip=0&limit=1", headers=admin_headers)
|
response = client.get(
|
||||||
|
"/api/v1/admin/shops?skip=0&limit=1", headers=admin_headers
|
||||||
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["total"] >= 1
|
assert data["total"] >= 1
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
# tests/test_admin_service.py
|
# tests/test_admin_service.py
|
||||||
import pytest
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
import pytest
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
|
||||||
from app.services.admin_service import AdminService
|
from app.services.admin_service import AdminService
|
||||||
from models.database_models import User, Shop, MarketplaceImportJob
|
from models.database_models import MarketplaceImportJob, Shop, User
|
||||||
|
|
||||||
|
|
||||||
class TestAdminService:
|
class TestAdminService:
|
||||||
@@ -91,7 +92,7 @@ class TestAdminService:
|
|||||||
shop_name="Test Shop 2",
|
shop_name="Test Shop 2",
|
||||||
owner_id=test_shop.owner_id,
|
owner_id=test_shop.owner_id,
|
||||||
is_active=True,
|
is_active=True,
|
||||||
is_verified=False
|
is_verified=False,
|
||||||
)
|
)
|
||||||
db.add(additional_shop)
|
db.add(additional_shop)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -173,13 +174,17 @@ class TestAdminService:
|
|||||||
|
|
||||||
assert len(result) >= 1
|
assert len(result) >= 1
|
||||||
# Find our test job in the results
|
# Find our test job in the results
|
||||||
test_job = next((job for job in result if job.job_id == test_marketplace_job.id), None)
|
test_job = next(
|
||||||
|
(job for job in result if job.job_id == test_marketplace_job.id), None
|
||||||
|
)
|
||||||
assert test_job is not None
|
assert test_job is not None
|
||||||
assert test_job.marketplace == test_marketplace_job.marketplace
|
assert test_job.marketplace == test_marketplace_job.marketplace
|
||||||
assert test_job.shop_name == test_marketplace_job.shop_name
|
assert test_job.shop_name == test_marketplace_job.shop_name
|
||||||
assert test_job.status == test_marketplace_job.status
|
assert test_job.status == test_marketplace_job.status
|
||||||
|
|
||||||
def test_get_marketplace_import_jobs_with_marketplace_filter(self, db, test_marketplace_job, test_user, test_shop):
|
def test_get_marketplace_import_jobs_with_marketplace_filter(
|
||||||
|
self, db, test_marketplace_job, test_user, test_shop
|
||||||
|
):
|
||||||
"""Test getting marketplace import jobs filtered by marketplace"""
|
"""Test getting marketplace import jobs filtered by marketplace"""
|
||||||
# Create additional job with different marketplace
|
# Create additional job with different marketplace
|
||||||
other_job = MarketplaceImportJob(
|
other_job = MarketplaceImportJob(
|
||||||
@@ -188,20 +193,24 @@ class TestAdminService:
|
|||||||
status="completed",
|
status="completed",
|
||||||
source_url="https://ebay.example.com/import",
|
source_url="https://ebay.example.com/import",
|
||||||
shop_id=test_shop.id,
|
shop_id=test_shop.id,
|
||||||
user_id=test_user.id # Fixed: Added missing user_id
|
user_id=test_user.id, # Fixed: Added missing user_id
|
||||||
)
|
)
|
||||||
db.add(other_job)
|
db.add(other_job)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
# Filter by the test marketplace job's marketplace
|
# Filter by the test marketplace job's marketplace
|
||||||
result = self.service.get_marketplace_import_jobs(db, marketplace=test_marketplace_job.marketplace)
|
result = self.service.get_marketplace_import_jobs(
|
||||||
|
db, marketplace=test_marketplace_job.marketplace
|
||||||
|
)
|
||||||
|
|
||||||
assert len(result) >= 1
|
assert len(result) >= 1
|
||||||
# All results should match the marketplace filter
|
# All results should match the marketplace filter
|
||||||
for job in result:
|
for job in result:
|
||||||
assert test_marketplace_job.marketplace.lower() in job.marketplace.lower()
|
assert test_marketplace_job.marketplace.lower() in job.marketplace.lower()
|
||||||
|
|
||||||
def test_get_marketplace_import_jobs_with_shop_filter(self, db, test_marketplace_job, test_user, test_shop):
|
def test_get_marketplace_import_jobs_with_shop_filter(
|
||||||
|
self, db, test_marketplace_job, test_user, test_shop
|
||||||
|
):
|
||||||
"""Test getting marketplace import jobs filtered by shop name"""
|
"""Test getting marketplace import jobs filtered by shop name"""
|
||||||
# Create additional job with different shop name
|
# Create additional job with different shop name
|
||||||
other_job = MarketplaceImportJob(
|
other_job = MarketplaceImportJob(
|
||||||
@@ -210,20 +219,24 @@ class TestAdminService:
|
|||||||
status="completed",
|
status="completed",
|
||||||
source_url="https://different.example.com/import",
|
source_url="https://different.example.com/import",
|
||||||
shop_id=test_shop.id,
|
shop_id=test_shop.id,
|
||||||
user_id=test_user.id # Fixed: Added missing user_id
|
user_id=test_user.id, # Fixed: Added missing user_id
|
||||||
)
|
)
|
||||||
db.add(other_job)
|
db.add(other_job)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
# Filter by the test marketplace job's shop name
|
# Filter by the test marketplace job's shop name
|
||||||
result = self.service.get_marketplace_import_jobs(db, shop_name=test_marketplace_job.shop_name)
|
result = self.service.get_marketplace_import_jobs(
|
||||||
|
db, shop_name=test_marketplace_job.shop_name
|
||||||
|
)
|
||||||
|
|
||||||
assert len(result) >= 1
|
assert len(result) >= 1
|
||||||
# All results should match the shop name filter
|
# All results should match the shop name filter
|
||||||
for job in result:
|
for job in result:
|
||||||
assert test_marketplace_job.shop_name.lower() in job.shop_name.lower()
|
assert test_marketplace_job.shop_name.lower() in job.shop_name.lower()
|
||||||
|
|
||||||
def test_get_marketplace_import_jobs_with_status_filter(self, db, test_marketplace_job, test_user, test_shop):
|
def test_get_marketplace_import_jobs_with_status_filter(
|
||||||
|
self, db, test_marketplace_job, test_user, test_shop
|
||||||
|
):
|
||||||
"""Test getting marketplace import jobs filtered by status"""
|
"""Test getting marketplace import jobs filtered by status"""
|
||||||
# Create additional job with different status
|
# Create additional job with different status
|
||||||
other_job = MarketplaceImportJob(
|
other_job = MarketplaceImportJob(
|
||||||
@@ -232,20 +245,24 @@ class TestAdminService:
|
|||||||
status="pending",
|
status="pending",
|
||||||
source_url="https://pending.example.com/import",
|
source_url="https://pending.example.com/import",
|
||||||
shop_id=test_shop.id,
|
shop_id=test_shop.id,
|
||||||
user_id=test_user.id # Fixed: Added missing user_id
|
user_id=test_user.id, # Fixed: Added missing user_id
|
||||||
)
|
)
|
||||||
db.add(other_job)
|
db.add(other_job)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
# Filter by the test marketplace job's status
|
# Filter by the test marketplace job's status
|
||||||
result = self.service.get_marketplace_import_jobs(db, status=test_marketplace_job.status)
|
result = self.service.get_marketplace_import_jobs(
|
||||||
|
db, status=test_marketplace_job.status
|
||||||
|
)
|
||||||
|
|
||||||
assert len(result) >= 1
|
assert len(result) >= 1
|
||||||
# All results should match the status filter
|
# All results should match the status filter
|
||||||
for job in result:
|
for job in result:
|
||||||
assert job.status == test_marketplace_job.status
|
assert job.status == test_marketplace_job.status
|
||||||
|
|
||||||
def test_get_marketplace_import_jobs_with_multiple_filters(self, db, test_marketplace_job, test_shop, test_user):
|
def test_get_marketplace_import_jobs_with_multiple_filters(
|
||||||
|
self, db, test_marketplace_job, test_shop, test_user
|
||||||
|
):
|
||||||
"""Test getting marketplace import jobs with multiple filters"""
|
"""Test getting marketplace import jobs with multiple filters"""
|
||||||
# Create jobs that don't match all filters
|
# Create jobs that don't match all filters
|
||||||
non_matching_job1 = MarketplaceImportJob(
|
non_matching_job1 = MarketplaceImportJob(
|
||||||
@@ -254,7 +271,7 @@ class TestAdminService:
|
|||||||
status=test_marketplace_job.status,
|
status=test_marketplace_job.status,
|
||||||
source_url="https://non-matching1.example.com/import",
|
source_url="https://non-matching1.example.com/import",
|
||||||
shop_id=test_shop.id,
|
shop_id=test_shop.id,
|
||||||
user_id=test_user.id # Fixed: Added missing user_id
|
user_id=test_user.id, # Fixed: Added missing user_id
|
||||||
)
|
)
|
||||||
non_matching_job2 = MarketplaceImportJob(
|
non_matching_job2 = MarketplaceImportJob(
|
||||||
marketplace=test_marketplace_job.marketplace,
|
marketplace=test_marketplace_job.marketplace,
|
||||||
@@ -262,7 +279,7 @@ class TestAdminService:
|
|||||||
status=test_marketplace_job.status,
|
status=test_marketplace_job.status,
|
||||||
source_url="https://non-matching2.example.com/import",
|
source_url="https://non-matching2.example.com/import",
|
||||||
shop_id=test_shop.id,
|
shop_id=test_shop.id,
|
||||||
user_id=test_user.id # Fixed: Added missing user_id
|
user_id=test_user.id, # Fixed: Added missing user_id
|
||||||
)
|
)
|
||||||
db.add_all([non_matching_job1, non_matching_job2])
|
db.add_all([non_matching_job1, non_matching_job2])
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -272,12 +289,14 @@ class TestAdminService:
|
|||||||
db,
|
db,
|
||||||
marketplace=test_marketplace_job.marketplace,
|
marketplace=test_marketplace_job.marketplace,
|
||||||
shop_name=test_marketplace_job.shop_name,
|
shop_name=test_marketplace_job.shop_name,
|
||||||
status=test_marketplace_job.status
|
status=test_marketplace_job.status,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(result) >= 1
|
assert len(result) >= 1
|
||||||
# Find our test job in the results
|
# Find our test job in the results
|
||||||
test_job = next((job for job in result if job.job_id == test_marketplace_job.id), None)
|
test_job = next(
|
||||||
|
(job for job in result if job.job_id == test_marketplace_job.id), None
|
||||||
|
)
|
||||||
assert test_job is not None
|
assert test_job is not None
|
||||||
assert test_job.marketplace == test_marketplace_job.marketplace
|
assert test_job.marketplace == test_marketplace_job.marketplace
|
||||||
assert test_job.shop_name == test_marketplace_job.shop_name
|
assert test_job.shop_name == test_marketplace_job.shop_name
|
||||||
@@ -297,7 +316,7 @@ class TestAdminService:
|
|||||||
updated_count=None,
|
updated_count=None,
|
||||||
total_processed=None,
|
total_processed=None,
|
||||||
error_count=None,
|
error_count=None,
|
||||||
error_message=None
|
error_message=None,
|
||||||
)
|
)
|
||||||
db.add(job)
|
db.add(job)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|||||||
@@ -6,11 +6,14 @@ from fastapi import HTTPException
|
|||||||
class TestAuthenticationAPI:
|
class TestAuthenticationAPI:
|
||||||
def test_register_user_success(self, client, db):
|
def test_register_user_success(self, client, db):
|
||||||
"""Test successful user registration"""
|
"""Test successful user registration"""
|
||||||
response = client.post("/api/v1/auth/register", json={
|
response = client.post(
|
||||||
"email": "newuser@example.com",
|
"/api/v1/auth/register",
|
||||||
"username": "newuser",
|
json={
|
||||||
"password": "securepass123"
|
"email": "newuser@example.com",
|
||||||
})
|
"username": "newuser",
|
||||||
|
"password": "securepass123",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
@@ -22,32 +25,38 @@ class TestAuthenticationAPI:
|
|||||||
|
|
||||||
def test_register_user_duplicate_email(self, client, test_user):
|
def test_register_user_duplicate_email(self, client, test_user):
|
||||||
"""Test registration with duplicate email"""
|
"""Test registration with duplicate email"""
|
||||||
response = client.post("/api/v1/auth/register", json={
|
response = client.post(
|
||||||
"email": test_user.email, # Same as test_user
|
"/api/v1/auth/register",
|
||||||
"username": "newuser",
|
json={
|
||||||
"password": "securepass123"
|
"email": test_user.email, # Same as test_user
|
||||||
})
|
"username": "newuser",
|
||||||
|
"password": "securepass123",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
assert "Email already registered" in response.json()["detail"]
|
assert "Email already registered" in response.json()["detail"]
|
||||||
|
|
||||||
def test_register_user_duplicate_username(self, client, test_user):
|
def test_register_user_duplicate_username(self, client, test_user):
|
||||||
"""Test registration with duplicate username"""
|
"""Test registration with duplicate username"""
|
||||||
response = client.post("/api/v1/auth/register", json={
|
response = client.post(
|
||||||
"email": "new@example.com",
|
"/api/v1/auth/register",
|
||||||
"username": test_user.username, # Same as test_user
|
json={
|
||||||
"password": "securepass123"
|
"email": "new@example.com",
|
||||||
})
|
"username": test_user.username, # Same as test_user
|
||||||
|
"password": "securepass123",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
assert "Username already taken" in response.json()["detail"]
|
assert "Username already taken" in response.json()["detail"]
|
||||||
|
|
||||||
def test_login_success(self, client, test_user):
|
def test_login_success(self, client, test_user):
|
||||||
"""Test successful login"""
|
"""Test successful login"""
|
||||||
response = client.post("/api/v1/auth/login", json={
|
response = client.post(
|
||||||
"username": test_user.username,
|
"/api/v1/auth/login",
|
||||||
"password": "testpass123"
|
json={"username": test_user.username, "password": "testpass123"},
|
||||||
})
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
@@ -58,20 +67,20 @@ class TestAuthenticationAPI:
|
|||||||
|
|
||||||
def test_login_wrong_password(self, client, test_user):
|
def test_login_wrong_password(self, client, test_user):
|
||||||
"""Test login with wrong password"""
|
"""Test login with wrong password"""
|
||||||
response = client.post("/api/v1/auth/login", json={
|
response = client.post(
|
||||||
"username": "testuser",
|
"/api/v1/auth/login",
|
||||||
"password": "wrongpassword"
|
json={"username": "testuser", "password": "wrongpassword"},
|
||||||
})
|
)
|
||||||
|
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
assert "Incorrect username or password" in response.json()["detail"]
|
assert "Incorrect username or password" in response.json()["detail"]
|
||||||
|
|
||||||
def test_login_nonexistent_user(self, client, db): # Added db fixture
|
def test_login_nonexistent_user(self, client, db): # Added db fixture
|
||||||
"""Test login with nonexistent user"""
|
"""Test login with nonexistent user"""
|
||||||
response = client.post("/api/v1/auth/login", json={
|
response = client.post(
|
||||||
"username": "nonexistent",
|
"/api/v1/auth/login",
|
||||||
"password": "password123"
|
json={"username": "nonexistent", "password": "password123"},
|
||||||
})
|
)
|
||||||
|
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
|||||||
@@ -3,8 +3,8 @@ import pytest
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
|
||||||
from app.services.auth_service import AuthService
|
from app.services.auth_service import AuthService
|
||||||
|
from models.api_models import UserLogin, UserRegister
|
||||||
from models.database_models import User
|
from models.database_models import User
|
||||||
from models.api_models import UserRegister, UserLogin
|
|
||||||
|
|
||||||
|
|
||||||
class TestAuthService:
|
class TestAuthService:
|
||||||
@@ -17,9 +17,7 @@ class TestAuthService:
|
|||||||
def test_register_user_success(self, db):
|
def test_register_user_success(self, db):
|
||||||
"""Test successful user registration"""
|
"""Test successful user registration"""
|
||||||
user_data = UserRegister(
|
user_data = UserRegister(
|
||||||
email="newuser@example.com",
|
email="newuser@example.com", username="newuser123", password="securepass123"
|
||||||
username="newuser123",
|
|
||||||
password="securepass123"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
user = self.service.register_user(db, user_data)
|
user = self.service.register_user(db, user_data)
|
||||||
@@ -36,7 +34,7 @@ class TestAuthService:
|
|||||||
user_data = UserRegister(
|
user_data = UserRegister(
|
||||||
email=test_user.email, # Use existing email
|
email=test_user.email, # Use existing email
|
||||||
username="differentuser",
|
username="differentuser",
|
||||||
password="securepass123"
|
password="securepass123",
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(HTTPException) as exc_info:
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
@@ -50,7 +48,7 @@ class TestAuthService:
|
|||||||
user_data = UserRegister(
|
user_data = UserRegister(
|
||||||
email="different@example.com",
|
email="different@example.com",
|
||||||
username=test_user.username, # Use existing username
|
username=test_user.username, # Use existing username
|
||||||
password="securepass123"
|
password="securepass123",
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(HTTPException) as exc_info:
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
@@ -62,8 +60,7 @@ class TestAuthService:
|
|||||||
def test_login_user_success(self, db, test_user):
|
def test_login_user_success(self, db, test_user):
|
||||||
"""Test successful user login"""
|
"""Test successful user login"""
|
||||||
user_credentials = UserLogin(
|
user_credentials = UserLogin(
|
||||||
username=test_user.username,
|
username=test_user.username, password="testpass123"
|
||||||
password="testpass123"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
result = self.service.login_user(db, user_credentials)
|
result = self.service.login_user(db, user_credentials)
|
||||||
@@ -78,10 +75,7 @@ class TestAuthService:
|
|||||||
|
|
||||||
def test_login_user_wrong_username(self, db):
|
def test_login_user_wrong_username(self, db):
|
||||||
"""Test login fails with wrong username"""
|
"""Test login fails with wrong username"""
|
||||||
user_credentials = UserLogin(
|
user_credentials = UserLogin(username="nonexistentuser", password="testpass123")
|
||||||
username="nonexistentuser",
|
|
||||||
password="testpass123"
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(HTTPException) as exc_info:
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
self.service.login_user(db, user_credentials)
|
self.service.login_user(db, user_credentials)
|
||||||
@@ -92,8 +86,7 @@ class TestAuthService:
|
|||||||
def test_login_user_wrong_password(self, db, test_user):
|
def test_login_user_wrong_password(self, db, test_user):
|
||||||
"""Test login fails with wrong password"""
|
"""Test login fails with wrong password"""
|
||||||
user_credentials = UserLogin(
|
user_credentials = UserLogin(
|
||||||
username=test_user.username,
|
username=test_user.username, password="wrongpassword"
|
||||||
password="wrongpassword"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(HTTPException) as exc_info:
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
@@ -109,8 +102,7 @@ class TestAuthService:
|
|||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
user_credentials = UserLogin(
|
user_credentials = UserLogin(
|
||||||
username=test_user.username,
|
username=test_user.username, password="testpass123"
|
||||||
password="testpass123"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(HTTPException) as exc_info:
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
# tests/test_background_tasks.py
|
# tests/test_background_tasks.py
|
||||||
|
from datetime import datetime
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import patch, AsyncMock
|
|
||||||
from app.tasks.background_tasks import process_marketplace_import
|
from app.tasks.background_tasks import process_marketplace_import
|
||||||
from models.database_models import MarketplaceImportJob
|
from models.database_models import MarketplaceImportJob
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
|
|
||||||
class TestBackgroundTasks:
|
class TestBackgroundTasks:
|
||||||
@@ -17,7 +19,7 @@ class TestBackgroundTasks:
|
|||||||
shop_name="TESTSHOP",
|
shop_name="TESTSHOP",
|
||||||
marketplace="TestMarket",
|
marketplace="TestMarket",
|
||||||
shop_id=test_shop.id,
|
shop_id=test_shop.id,
|
||||||
user_id=test_user.id
|
user_id=test_user.id,
|
||||||
)
|
)
|
||||||
db.add(job)
|
db.add(job)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -27,29 +29,30 @@ class TestBackgroundTasks:
|
|||||||
job_id = job.id
|
job_id = job.id
|
||||||
|
|
||||||
# Mock CSV processor and prevent session from closing
|
# Mock CSV processor and prevent session from closing
|
||||||
with patch('app.tasks.background_tasks.CSVProcessor') as mock_processor, \
|
with patch("app.tasks.background_tasks.CSVProcessor") as mock_processor, patch(
|
||||||
patch('app.tasks.background_tasks.SessionLocal', return_value=db):
|
"app.tasks.background_tasks.SessionLocal", return_value=db
|
||||||
|
):
|
||||||
mock_instance = mock_processor.return_value
|
mock_instance = mock_processor.return_value
|
||||||
mock_instance.process_marketplace_csv_from_url = AsyncMock(return_value={
|
mock_instance.process_marketplace_csv_from_url = AsyncMock(
|
||||||
"imported": 10,
|
return_value={
|
||||||
"updated": 5,
|
"imported": 10,
|
||||||
"total_processed": 15,
|
"updated": 5,
|
||||||
"errors": 0
|
"total_processed": 15,
|
||||||
})
|
"errors": 0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Run background task
|
# Run background task
|
||||||
await process_marketplace_import(
|
await process_marketplace_import(
|
||||||
job_id,
|
job_id, "http://example.com/test.csv", "TestMarket", "TESTSHOP", 1000
|
||||||
"http://example.com/test.csv",
|
|
||||||
"TestMarket",
|
|
||||||
"TESTSHOP",
|
|
||||||
1000
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Re-query the job using the stored ID
|
# Re-query the job using the stored ID
|
||||||
updated_job = db.query(MarketplaceImportJob).filter(
|
updated_job = (
|
||||||
MarketplaceImportJob.id == job_id
|
db.query(MarketplaceImportJob)
|
||||||
).first()
|
.filter(MarketplaceImportJob.id == job_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
assert updated_job is not None
|
assert updated_job is not None
|
||||||
assert updated_job.status == "completed"
|
assert updated_job.status == "completed"
|
||||||
@@ -70,7 +73,7 @@ class TestBackgroundTasks:
|
|||||||
shop_name="TESTSHOP",
|
shop_name="TESTSHOP",
|
||||||
marketplace="TestMarket",
|
marketplace="TestMarket",
|
||||||
shop_id=test_shop.id,
|
shop_id=test_shop.id,
|
||||||
user_id=test_user.id
|
user_id=test_user.id,
|
||||||
)
|
)
|
||||||
db.add(job)
|
db.add(job)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -80,8 +83,9 @@ class TestBackgroundTasks:
|
|||||||
job_id = job.id
|
job_id = job.id
|
||||||
|
|
||||||
# Mock CSV processor to raise exception
|
# Mock CSV processor to raise exception
|
||||||
with patch('app.tasks.background_tasks.CSVProcessor') as mock_processor, \
|
with patch("app.tasks.background_tasks.CSVProcessor") as mock_processor, patch(
|
||||||
patch('app.tasks.background_tasks.SessionLocal', return_value=db):
|
"app.tasks.background_tasks.SessionLocal", return_value=db
|
||||||
|
):
|
||||||
|
|
||||||
mock_instance = mock_processor.return_value
|
mock_instance = mock_processor.return_value
|
||||||
mock_instance.process_marketplace_csv_from_url = AsyncMock(
|
mock_instance.process_marketplace_csv_from_url = AsyncMock(
|
||||||
@@ -96,7 +100,7 @@ class TestBackgroundTasks:
|
|||||||
"http://example.com/test.csv",
|
"http://example.com/test.csv",
|
||||||
"TestMarket",
|
"TestMarket",
|
||||||
"TESTSHOP",
|
"TESTSHOP",
|
||||||
1000
|
1000,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
# The background task should handle exceptions internally
|
# The background task should handle exceptions internally
|
||||||
@@ -104,9 +108,11 @@ class TestBackgroundTasks:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
# Re-query the job using the stored ID
|
# Re-query the job using the stored ID
|
||||||
updated_job = db.query(MarketplaceImportJob).filter(
|
updated_job = (
|
||||||
MarketplaceImportJob.id == job_id
|
db.query(MarketplaceImportJob)
|
||||||
).first()
|
.filter(MarketplaceImportJob.id == job_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
assert updated_job is not None
|
assert updated_job is not None
|
||||||
assert updated_job.status == "failed"
|
assert updated_job.status == "failed"
|
||||||
@@ -115,15 +121,18 @@ class TestBackgroundTasks:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_marketplace_import_job_not_found(self, db):
|
async def test_marketplace_import_job_not_found(self, db):
|
||||||
"""Test handling when import job doesn't exist"""
|
"""Test handling when import job doesn't exist"""
|
||||||
with patch('app.tasks.background_tasks.CSVProcessor') as mock_processor, \
|
with patch("app.tasks.background_tasks.CSVProcessor") as mock_processor, patch(
|
||||||
patch('app.tasks.background_tasks.SessionLocal', return_value=db):
|
"app.tasks.background_tasks.SessionLocal", return_value=db
|
||||||
|
):
|
||||||
mock_instance = mock_processor.return_value
|
mock_instance = mock_processor.return_value
|
||||||
mock_instance.process_marketplace_csv_from_url = AsyncMock(return_value={
|
mock_instance.process_marketplace_csv_from_url = AsyncMock(
|
||||||
"imported": 10,
|
return_value={
|
||||||
"updated": 5,
|
"imported": 10,
|
||||||
"total_processed": 15,
|
"updated": 5,
|
||||||
"errors": 0
|
"total_processed": 15,
|
||||||
})
|
"errors": 0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Run background task with non-existent job ID
|
# Run background task with non-existent job ID
|
||||||
await process_marketplace_import(
|
await process_marketplace_import(
|
||||||
@@ -131,7 +140,7 @@ class TestBackgroundTasks:
|
|||||||
"http://example.com/test.csv",
|
"http://example.com/test.csv",
|
||||||
"TestMarket",
|
"TestMarket",
|
||||||
"TESTSHOP",
|
"TESTSHOP",
|
||||||
1000
|
1000,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Should not raise an exception, just log and return
|
# Should not raise an exception, just log and return
|
||||||
@@ -148,7 +157,7 @@ class TestBackgroundTasks:
|
|||||||
shop_name="TESTSHOP",
|
shop_name="TESTSHOP",
|
||||||
marketplace="TestMarket",
|
marketplace="TestMarket",
|
||||||
shop_id=test_shop.id,
|
shop_id=test_shop.id,
|
||||||
user_id=test_user.id
|
user_id=test_user.id,
|
||||||
)
|
)
|
||||||
db.add(job)
|
db.add(job)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -158,29 +167,30 @@ class TestBackgroundTasks:
|
|||||||
job_id = job.id
|
job_id = job.id
|
||||||
|
|
||||||
# Mock CSV processor with some errors
|
# Mock CSV processor with some errors
|
||||||
with patch('app.tasks.background_tasks.CSVProcessor') as mock_processor, \
|
with patch("app.tasks.background_tasks.CSVProcessor") as mock_processor, patch(
|
||||||
patch('app.tasks.background_tasks.SessionLocal', return_value=db):
|
"app.tasks.background_tasks.SessionLocal", return_value=db
|
||||||
|
):
|
||||||
mock_instance = mock_processor.return_value
|
mock_instance = mock_processor.return_value
|
||||||
mock_instance.process_marketplace_csv_from_url = AsyncMock(return_value={
|
mock_instance.process_marketplace_csv_from_url = AsyncMock(
|
||||||
"imported": 8,
|
return_value={
|
||||||
"updated": 5,
|
"imported": 8,
|
||||||
"total_processed": 15,
|
"updated": 5,
|
||||||
"errors": 2
|
"total_processed": 15,
|
||||||
})
|
"errors": 2,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Run background task
|
# Run background task
|
||||||
await process_marketplace_import(
|
await process_marketplace_import(
|
||||||
job_id,
|
job_id, "http://example.com/test.csv", "TestMarket", "TESTSHOP", 1000
|
||||||
"http://example.com/test.csv",
|
|
||||||
"TestMarket",
|
|
||||||
"TESTSHOP",
|
|
||||||
1000
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Re-query the job using the stored ID
|
# Re-query the job using the stored ID
|
||||||
updated_job = db.query(MarketplaceImportJob).filter(
|
updated_job = (
|
||||||
MarketplaceImportJob.id == job_id
|
db.query(MarketplaceImportJob)
|
||||||
).first()
|
.filter(MarketplaceImportJob.id == job_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
assert updated_job is not None
|
assert updated_job is not None
|
||||||
assert updated_job.status == "completed_with_errors"
|
assert updated_job.status == "completed_with_errors"
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
# tests/test_csv_processor.py
|
# tests/test_csv_processor.py
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
import requests.exceptions
|
import requests.exceptions
|
||||||
from unittest.mock import Mock, patch
|
|
||||||
import pandas as pd
|
|
||||||
from utils.csv_processor import CSVProcessor
|
from utils.csv_processor import CSVProcessor
|
||||||
|
|
||||||
|
|
||||||
@@ -11,7 +13,7 @@ class TestCSVProcessor:
|
|||||||
def setup_method(self):
|
def setup_method(self):
|
||||||
self.processor = CSVProcessor()
|
self.processor = CSVProcessor()
|
||||||
|
|
||||||
@patch('requests.get')
|
@patch("requests.get")
|
||||||
def test_download_csv_encoding_fallback(self, mock_get):
|
def test_download_csv_encoding_fallback(self, mock_get):
|
||||||
"""Test CSV download with encoding fallback"""
|
"""Test CSV download with encoding fallback"""
|
||||||
# Create content with special characters that would fail UTF-8 if not properly encoded
|
# Create content with special characters that would fail UTF-8 if not properly encoded
|
||||||
@@ -20,7 +22,7 @@ class TestCSVProcessor:
|
|||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
# Use latin-1 encoding which your method should try
|
# Use latin-1 encoding which your method should try
|
||||||
mock_response.content = special_content.encode('latin-1')
|
mock_response.content = special_content.encode("latin-1")
|
||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_get.return_value = mock_response
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
@@ -30,14 +32,16 @@ class TestCSVProcessor:
|
|||||||
assert isinstance(csv_content, str)
|
assert isinstance(csv_content, str)
|
||||||
assert "Café Product" in csv_content
|
assert "Café Product" in csv_content
|
||||||
|
|
||||||
@patch('requests.get')
|
@patch("requests.get")
|
||||||
def test_download_csv_encoding_ignore_fallback(self, mock_get):
|
def test_download_csv_encoding_ignore_fallback(self, mock_get):
|
||||||
"""Test CSV download falls back to UTF-8 with error ignoring"""
|
"""Test CSV download falls back to UTF-8 with error ignoring"""
|
||||||
# Create problematic bytes that would fail most encoding attempts
|
# Create problematic bytes that would fail most encoding attempts
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
# Create bytes that will fail most encodings
|
# Create bytes that will fail most encodings
|
||||||
mock_response.content = b"product_id,title,price\nTEST001,\xff\xfe Product,10.99"
|
mock_response.content = (
|
||||||
|
b"product_id,title,price\nTEST001,\xff\xfe Product,10.99"
|
||||||
|
)
|
||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_get.return_value = mock_response
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
@@ -49,7 +53,7 @@ class TestCSVProcessor:
|
|||||||
assert "product_id,title,price" in csv_content
|
assert "product_id,title,price" in csv_content
|
||||||
assert "TEST001" in csv_content
|
assert "TEST001" in csv_content
|
||||||
|
|
||||||
@patch('requests.get')
|
@patch("requests.get")
|
||||||
def test_download_csv_request_exception(self, mock_get):
|
def test_download_csv_request_exception(self, mock_get):
|
||||||
"""Test CSV download with request exception"""
|
"""Test CSV download with request exception"""
|
||||||
mock_get.side_effect = requests.exceptions.RequestException("Connection error")
|
mock_get.side_effect = requests.exceptions.RequestException("Connection error")
|
||||||
@@ -57,18 +61,20 @@ class TestCSVProcessor:
|
|||||||
with pytest.raises(requests.exceptions.RequestException):
|
with pytest.raises(requests.exceptions.RequestException):
|
||||||
self.processor.download_csv("http://example.com/test.csv")
|
self.processor.download_csv("http://example.com/test.csv")
|
||||||
|
|
||||||
@patch('requests.get')
|
@patch("requests.get")
|
||||||
def test_download_csv_http_error(self, mock_get):
|
def test_download_csv_http_error(self, mock_get):
|
||||||
"""Test CSV download with HTTP error"""
|
"""Test CSV download with HTTP error"""
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.status_code = 404
|
mock_response.status_code = 404
|
||||||
mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("404 Not Found")
|
mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError(
|
||||||
|
"404 Not Found"
|
||||||
|
)
|
||||||
mock_get.return_value = mock_response
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
with pytest.raises(requests.exceptions.HTTPError):
|
with pytest.raises(requests.exceptions.HTTPError):
|
||||||
self.processor.download_csv("http://example.com/nonexistent.csv")
|
self.processor.download_csv("http://example.com/nonexistent.csv")
|
||||||
|
|
||||||
@patch('requests.get')
|
@patch("requests.get")
|
||||||
def test_download_csv_failure(self, mock_get):
|
def test_download_csv_failure(self, mock_get):
|
||||||
"""Test CSV download failure"""
|
"""Test CSV download failure"""
|
||||||
# Mock failed HTTP response
|
# Mock failed HTTP response
|
||||||
@@ -95,26 +101,24 @@ TEST002,Test Product 2,15.99,TestMarket"""
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_process_marketplace_csv_from_url(self, db):
|
async def test_process_marketplace_csv_from_url(self, db):
|
||||||
"""Test complete marketplace CSV processing"""
|
"""Test complete marketplace CSV processing"""
|
||||||
with patch.object(self.processor, 'download_csv') as mock_download, \
|
with patch.object(
|
||||||
patch.object(self.processor, 'parse_csv') as mock_parse:
|
self.processor, "download_csv"
|
||||||
|
) as mock_download, patch.object(self.processor, "parse_csv") as mock_parse:
|
||||||
# Mock successful download and parsing
|
# Mock successful download and parsing
|
||||||
mock_download.return_value = "csv_content"
|
mock_download.return_value = "csv_content"
|
||||||
mock_df = pd.DataFrame({
|
mock_df = pd.DataFrame(
|
||||||
"product_id": ["TEST001", "TEST002"],
|
{
|
||||||
"title": ["Product 1", "Product 2"],
|
"product_id": ["TEST001", "TEST002"],
|
||||||
"price": ["10.99", "15.99"],
|
"title": ["Product 1", "Product 2"],
|
||||||
"marketplace": ["TestMarket", "TestMarket"],
|
"price": ["10.99", "15.99"],
|
||||||
"shop_name": ["TestShop", "TestShop"]
|
"marketplace": ["TestMarket", "TestMarket"],
|
||||||
})
|
"shop_name": ["TestShop", "TestShop"],
|
||||||
|
}
|
||||||
|
)
|
||||||
mock_parse.return_value = mock_df
|
mock_parse.return_value = mock_df
|
||||||
|
|
||||||
|
|
||||||
result = await self.processor.process_marketplace_csv_from_url(
|
result = await self.processor.process_marketplace_csv_from_url(
|
||||||
"http://example.com/test.csv",
|
"http://example.com/test.csv", "TestMarket", "TestShop", 1000, db
|
||||||
"TestMarket",
|
|
||||||
"TestShop",
|
|
||||||
1000,
|
|
||||||
db
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert "imported" in result
|
assert "imported" in result
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
# tests/test_data_validation.py
|
# tests/test_data_validation.py
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from utils.data_processing import GTINProcessor, PriceProcessor
|
from utils.data_processing import GTINProcessor, PriceProcessor
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
# tests/test_database.py
|
# tests/test_database.py
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
from models.database_models import User, Product, Stock, Shop
|
|
||||||
|
from models.database_models import Product, Shop, Stock, User
|
||||||
|
|
||||||
|
|
||||||
class TestDatabaseModels:
|
class TestDatabaseModels:
|
||||||
@@ -12,7 +13,7 @@ class TestDatabaseModels:
|
|||||||
username="dbtest",
|
username="dbtest",
|
||||||
hashed_password="hashed_password_123",
|
hashed_password="hashed_password_123",
|
||||||
role="user",
|
role="user",
|
||||||
is_active=True
|
is_active=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
db.add(user)
|
db.add(user)
|
||||||
@@ -36,7 +37,7 @@ class TestDatabaseModels:
|
|||||||
gtin="1234567890123",
|
gtin="1234567890123",
|
||||||
availability="in stock",
|
availability="in stock",
|
||||||
marketplace="TestDB",
|
marketplace="TestDB",
|
||||||
shop_name="DBTestShop"
|
shop_name="DBTestShop",
|
||||||
)
|
)
|
||||||
|
|
||||||
db.add(product)
|
db.add(product)
|
||||||
@@ -49,11 +50,7 @@ class TestDatabaseModels:
|
|||||||
|
|
||||||
def test_stock_model(self, db):
|
def test_stock_model(self, db):
|
||||||
"""Test Stock model creation"""
|
"""Test Stock model creation"""
|
||||||
stock = Stock(
|
stock = Stock(gtin="1234567890123", location="DB_WAREHOUSE", quantity=150)
|
||||||
gtin="1234567890123",
|
|
||||||
location="DB_WAREHOUSE",
|
|
||||||
quantity=150
|
|
||||||
)
|
|
||||||
|
|
||||||
db.add(stock)
|
db.add(stock)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -72,7 +69,7 @@ class TestDatabaseModels:
|
|||||||
description="Testing shop model",
|
description="Testing shop model",
|
||||||
owner_id=test_user.id,
|
owner_id=test_user.id,
|
||||||
is_active=True,
|
is_active=True,
|
||||||
is_verified=False
|
is_verified=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
db.add(shop)
|
db.add(shop)
|
||||||
|
|||||||
@@ -5,24 +5,25 @@ import pytest
|
|||||||
class TestErrorHandling:
|
class TestErrorHandling:
|
||||||
def test_invalid_json(self, client, auth_headers):
|
def test_invalid_json(self, client, auth_headers):
|
||||||
"""Test handling of invalid JSON"""
|
"""Test handling of invalid JSON"""
|
||||||
response = client.post("/api/v1/product",
|
response = client.post(
|
||||||
headers=auth_headers,
|
"/api/v1/product", headers=auth_headers, content="invalid json"
|
||||||
content="invalid json")
|
)
|
||||||
|
|
||||||
assert response.status_code == 422 # Validation error
|
assert response.status_code == 422 # Validation error
|
||||||
|
|
||||||
def test_missing_required_fields(self, client, auth_headers):
|
def test_missing_required_fields(self, client, auth_headers):
|
||||||
"""Test handling of missing required fields"""
|
"""Test handling of missing required fields"""
|
||||||
response = client.post("/api/v1/product",
|
response = client.post(
|
||||||
headers=auth_headers,
|
"/api/v1/product", headers=auth_headers, json={"title": "Test"}
|
||||||
json={"title": "Test"}) # Missing product_id
|
) # Missing product_id
|
||||||
|
|
||||||
assert response.status_code == 422
|
assert response.status_code == 422
|
||||||
|
|
||||||
def test_invalid_authentication(self, client):
|
def test_invalid_authentication(self, client):
|
||||||
"""Test handling of invalid authentication"""
|
"""Test handling of invalid authentication"""
|
||||||
response = client.get("/api/v1/product",
|
response = client.get(
|
||||||
headers={"Authorization": "Bearer invalid_token"})
|
"/api/v1/product", headers={"Authorization": "Bearer invalid_token"}
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 401 # Token is not valid
|
assert response.status_code == 401 # Token is not valid
|
||||||
|
|
||||||
@@ -38,8 +39,10 @@ class TestErrorHandling:
|
|||||||
"""Test handling of duplicate resource creation"""
|
"""Test handling of duplicate resource creation"""
|
||||||
product_data = {
|
product_data = {
|
||||||
"product_id": test_product.product_id, # Duplicate ID
|
"product_id": test_product.product_id, # Duplicate ID
|
||||||
"title": "Another Product"
|
"title": "Another Product",
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.post("/api/v1/product", headers=auth_headers, json=product_data)
|
response = client.post(
|
||||||
|
"/api/v1/product", headers=auth_headers, json=product_data
|
||||||
|
)
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
# tests/test_export.py
|
# tests/test_export.py
|
||||||
import pytest
|
|
||||||
import csv
|
import csv
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from models.database_models import Product
|
from models.database_models import Product
|
||||||
|
|
||||||
|
|
||||||
@@ -15,7 +16,7 @@ class TestExportFunctionality:
|
|||||||
assert response.headers["content-type"] == "text/csv; charset=utf-8"
|
assert response.headers["content-type"] == "text/csv; charset=utf-8"
|
||||||
|
|
||||||
# Parse CSV content
|
# Parse CSV content
|
||||||
csv_content = response.content.decode('utf-8')
|
csv_content = response.content.decode("utf-8")
|
||||||
csv_reader = csv.reader(StringIO(csv_content))
|
csv_reader = csv.reader(StringIO(csv_content))
|
||||||
|
|
||||||
# Check header row
|
# Check header row
|
||||||
@@ -35,10 +36,12 @@ class TestExportFunctionality:
|
|||||||
db.add_all(products)
|
db.add_all(products)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
response = client.get("/api/v1/export-csv?marketplace=Amazon", headers=auth_headers)
|
response = client.get(
|
||||||
|
"/api/v1/export-csv?marketplace=Amazon", headers=auth_headers
|
||||||
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
csv_content = response.content.decode('utf-8')
|
csv_content = response.content.decode("utf-8")
|
||||||
assert "EXP1" in csv_content
|
assert "EXP1" in csv_content
|
||||||
assert "EXP2" not in csv_content # Should be filtered out
|
assert "EXP2" not in csv_content # Should be filtered out
|
||||||
|
|
||||||
@@ -50,7 +53,7 @@ class TestExportFunctionality:
|
|||||||
product = Product(
|
product = Product(
|
||||||
product_id=f"PERF{i:04d}",
|
product_id=f"PERF{i:04d}",
|
||||||
title=f"Performance Product {i}",
|
title=f"Performance Product {i}",
|
||||||
marketplace="Performance"
|
marketplace="Performance",
|
||||||
)
|
)
|
||||||
products.append(product)
|
products.append(product)
|
||||||
|
|
||||||
@@ -58,10 +61,10 @@ class TestExportFunctionality:
|
|||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
response = client.get("/api/v1/export-csv", headers=auth_headers)
|
response = client.get("/api/v1/export-csv", headers=auth_headers)
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert end_time - start_time < 10.0 # Should complete within 10 seconds
|
assert end_time - start_time < 10.0 # Should complete within 10 seconds
|
||||||
|
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
# tests/test_filtering.py
|
# tests/test_filtering.py
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from models.database_models import Product
|
from models.database_models import Product
|
||||||
|
|
||||||
|
|
||||||
@@ -39,7 +40,9 @@ class TestFiltering:
|
|||||||
db.add_all(products)
|
db.add_all(products)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
response = client.get("/api/v1/product?marketplace=Amazon", headers=auth_headers)
|
response = client.get(
|
||||||
|
"/api/v1/product?marketplace=Amazon", headers=auth_headers
|
||||||
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["total"] == 2
|
assert data["total"] == 2
|
||||||
@@ -47,9 +50,17 @@ class TestFiltering:
|
|||||||
def test_product_search_filter(self, client, auth_headers, db):
|
def test_product_search_filter(self, client, auth_headers, db):
|
||||||
"""Test searching products by text"""
|
"""Test searching products by text"""
|
||||||
products = [
|
products = [
|
||||||
Product(product_id="SEARCH1", title="Apple iPhone", description="Smartphone"),
|
Product(
|
||||||
Product(product_id="SEARCH2", title="Samsung Galaxy", description="Android phone"),
|
product_id="SEARCH1", title="Apple iPhone", description="Smartphone"
|
||||||
Product(product_id="SEARCH3", title="iPad Tablet", description="Apple tablet"),
|
),
|
||||||
|
Product(
|
||||||
|
product_id="SEARCH2",
|
||||||
|
title="Samsung Galaxy",
|
||||||
|
description="Android phone",
|
||||||
|
),
|
||||||
|
Product(
|
||||||
|
product_id="SEARCH3", title="iPad Tablet", description="Apple tablet"
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
db.add_all(products)
|
db.add_all(products)
|
||||||
@@ -70,16 +81,33 @@ class TestFiltering:
|
|||||||
def test_combined_filters(self, client, auth_headers, db):
|
def test_combined_filters(self, client, auth_headers, db):
|
||||||
"""Test combining multiple filters"""
|
"""Test combining multiple filters"""
|
||||||
products = [
|
products = [
|
||||||
Product(product_id="COMBO1", title="Apple iPhone", brand="Apple", marketplace="Amazon"),
|
Product(
|
||||||
Product(product_id="COMBO2", title="Apple iPad", brand="Apple", marketplace="eBay"),
|
product_id="COMBO1",
|
||||||
Product(product_id="COMBO3", title="Samsung Phone", brand="Samsung", marketplace="Amazon"),
|
title="Apple iPhone",
|
||||||
|
brand="Apple",
|
||||||
|
marketplace="Amazon",
|
||||||
|
),
|
||||||
|
Product(
|
||||||
|
product_id="COMBO2",
|
||||||
|
title="Apple iPad",
|
||||||
|
brand="Apple",
|
||||||
|
marketplace="eBay",
|
||||||
|
),
|
||||||
|
Product(
|
||||||
|
product_id="COMBO3",
|
||||||
|
title="Samsung Phone",
|
||||||
|
brand="Samsung",
|
||||||
|
marketplace="Amazon",
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
db.add_all(products)
|
db.add_all(products)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
# Filter by brand AND marketplace
|
# Filter by brand AND marketplace
|
||||||
response = client.get("/api/v1/product?brand=Apple&marketplace=Amazon", headers=auth_headers)
|
response = client.get(
|
||||||
|
"/api/v1/product?brand=Apple&marketplace=Amazon", headers=auth_headers
|
||||||
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["total"] == 1 # Only iPhone matches both
|
assert data["total"] == 1 # Only iPhone matches both
|
||||||
|
|||||||
@@ -14,10 +14,12 @@ class TestIntegrationFlows:
|
|||||||
"brand": "FlowBrand",
|
"brand": "FlowBrand",
|
||||||
"gtin": "1111222233334",
|
"gtin": "1111222233334",
|
||||||
"availability": "in stock",
|
"availability": "in stock",
|
||||||
"marketplace": "TestFlow"
|
"marketplace": "TestFlow",
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.post("/api/v1/product", headers=auth_headers, json=product_data)
|
response = client.post(
|
||||||
|
"/api/v1/product", headers=auth_headers, json=product_data
|
||||||
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
product = response.json()
|
product = response.json()
|
||||||
|
|
||||||
@@ -25,26 +27,33 @@ class TestIntegrationFlows:
|
|||||||
stock_data = {
|
stock_data = {
|
||||||
"gtin": product["gtin"],
|
"gtin": product["gtin"],
|
||||||
"location": "MAIN_WAREHOUSE",
|
"location": "MAIN_WAREHOUSE",
|
||||||
"quantity": 50
|
"quantity": 50,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.post("/api/v1/stock", headers=auth_headers, json=stock_data)
|
response = client.post("/api/v1/stock", headers=auth_headers, json=stock_data)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
# 3. Get product with stock info
|
# 3. Get product with stock info
|
||||||
response = client.get(f"/api/v1/product/{product['product_id']}", headers=auth_headers)
|
response = client.get(
|
||||||
|
f"/api/v1/product/{product['product_id']}", headers=auth_headers
|
||||||
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
product_detail = response.json()
|
product_detail = response.json()
|
||||||
assert product_detail["stock_info"]["total_quantity"] == 50
|
assert product_detail["stock_info"]["total_quantity"] == 50
|
||||||
|
|
||||||
# 4. Update product
|
# 4. Update product
|
||||||
update_data = {"title": "Updated Integration Test Product"}
|
update_data = {"title": "Updated Integration Test Product"}
|
||||||
response = client.put(f"/api/v1/product/{product['product_id']}",
|
response = client.put(
|
||||||
headers=auth_headers, json=update_data)
|
f"/api/v1/product/{product['product_id']}",
|
||||||
|
headers=auth_headers,
|
||||||
|
json=update_data,
|
||||||
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
# 5. Search for product
|
# 5. Search for product
|
||||||
response = client.get("/api/v1/product?search=Updated Integration", headers=auth_headers)
|
response = client.get(
|
||||||
|
"/api/v1/product?search=Updated Integration", headers=auth_headers
|
||||||
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["total"] == 1
|
assert response.json()["total"] == 1
|
||||||
|
|
||||||
@@ -54,7 +63,7 @@ class TestIntegrationFlows:
|
|||||||
shop_data = {
|
shop_data = {
|
||||||
"shop_code": "FLOWSHOP",
|
"shop_code": "FLOWSHOP",
|
||||||
"shop_name": "Integration Flow Shop",
|
"shop_name": "Integration Flow Shop",
|
||||||
"description": "Test shop for integration"
|
"description": "Test shop for integration",
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.post("/api/v1/shop", headers=auth_headers, json=shop_data)
|
response = client.post("/api/v1/shop", headers=auth_headers, json=shop_data)
|
||||||
@@ -66,10 +75,12 @@ class TestIntegrationFlows:
|
|||||||
"product_id": "SHOPFLOW001",
|
"product_id": "SHOPFLOW001",
|
||||||
"title": "Shop Flow Product",
|
"title": "Shop Flow Product",
|
||||||
"price": "15.99",
|
"price": "15.99",
|
||||||
"marketplace": "ShopFlow"
|
"marketplace": "ShopFlow",
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.post("/api/v1/product", headers=auth_headers, json=product_data)
|
response = client.post(
|
||||||
|
"/api/v1/product", headers=auth_headers, json=product_data
|
||||||
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
product = response.json()
|
product = response.json()
|
||||||
|
|
||||||
@@ -86,28 +97,28 @@ class TestIntegrationFlows:
|
|||||||
location = "TEST_WAREHOUSE"
|
location = "TEST_WAREHOUSE"
|
||||||
|
|
||||||
# 1. Set initial stock
|
# 1. Set initial stock
|
||||||
response = client.post("/api/v1/stock", headers=auth_headers, json={
|
response = client.post(
|
||||||
"gtin": gtin,
|
"/api/v1/stock",
|
||||||
"location": location,
|
headers=auth_headers,
|
||||||
"quantity": 100
|
json={"gtin": gtin, "location": location, "quantity": 100},
|
||||||
})
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
# 2. Add more stock
|
# 2. Add more stock
|
||||||
response = client.post("/api/v1/stock/add", headers=auth_headers, json={
|
response = client.post(
|
||||||
"gtin": gtin,
|
"/api/v1/stock/add",
|
||||||
"location": location,
|
headers=auth_headers,
|
||||||
"quantity": 25
|
json={"gtin": gtin, "location": location, "quantity": 25},
|
||||||
})
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["quantity"] == 125
|
assert response.json()["quantity"] == 125
|
||||||
|
|
||||||
# 3. Remove some stock
|
# 3. Remove some stock
|
||||||
response = client.post("/api/v1/stock/remove", headers=auth_headers, json={
|
response = client.post(
|
||||||
"gtin": gtin,
|
"/api/v1/stock/remove",
|
||||||
"location": location,
|
headers=auth_headers,
|
||||||
"quantity": 30
|
json={"gtin": gtin, "location": location, "quantity": 30},
|
||||||
})
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["quantity"] == 95
|
assert response.json()["quantity"] == 95
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# tests/test_marketplace.py
|
# tests/test_marketplace.py
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import patch, AsyncMock
|
|
||||||
|
|
||||||
|
|
||||||
class TestMarketplaceAPI:
|
class TestMarketplaceAPI:
|
||||||
@@ -10,11 +11,12 @@ class TestMarketplaceAPI:
|
|||||||
import_data = {
|
import_data = {
|
||||||
"url": "https://example.com/products.csv",
|
"url": "https://example.com/products.csv",
|
||||||
"marketplace": "TestMarket",
|
"marketplace": "TestMarket",
|
||||||
"shop_code": test_shop.shop_code
|
"shop_code": test_shop.shop_code,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.post("/api/v1/marketplace/import-product",
|
response = client.post(
|
||||||
headers=auth_headers, json=import_data)
|
"/api/v1/marketplace/import-product", headers=auth_headers, json=import_data
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
@@ -29,11 +31,12 @@ class TestMarketplaceAPI:
|
|||||||
import_data = {
|
import_data = {
|
||||||
"url": "https://example.com/products.csv",
|
"url": "https://example.com/products.csv",
|
||||||
"marketplace": "TestMarket",
|
"marketplace": "TestMarket",
|
||||||
"shop_code": "NONEXISTENT"
|
"shop_code": "NONEXISTENT",
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.post("/api/v1/marketplace/import-product",
|
response = client.post(
|
||||||
headers=auth_headers, json=import_data)
|
"/api/v1/marketplace/import-product", headers=auth_headers, json=import_data
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 404
|
assert response.status_code == 404
|
||||||
assert "Shop not found" in response.json()["detail"]
|
assert "Shop not found" in response.json()["detail"]
|
||||||
@@ -49,4 +52,3 @@ class TestMarketplaceAPI:
|
|||||||
"""Test that marketplace endpoints require authentication"""
|
"""Test that marketplace endpoints require authentication"""
|
||||||
response = client.get("/api/v1/marketplace/import-jobs")
|
response = client.get("/api/v1/marketplace/import-jobs")
|
||||||
assert response.status_code == 401 # No authorization header
|
assert response.status_code == 401 # No authorization header
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
# tests/test_marketplace_service.py
|
# tests/test_marketplace_service.py
|
||||||
import pytest
|
|
||||||
import uuid
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from app.services.marketplace_service import MarketplaceService
|
from app.services.marketplace_service import MarketplaceService
|
||||||
from models.api_models import MarketplaceImportRequest
|
from models.api_models import MarketplaceImportRequest
|
||||||
from models.database_models import MarketplaceImportJob, Shop, User
|
from models.database_models import MarketplaceImportJob, Shop, User
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
|
|
||||||
class TestMarketplaceService:
|
class TestMarketplaceService:
|
||||||
@@ -22,7 +24,9 @@ class TestMarketplaceService:
|
|||||||
assert result.shop_code == test_shop.shop_code
|
assert result.shop_code == test_shop.shop_code
|
||||||
assert result.owner_id == test_user.id
|
assert result.owner_id == test_user.id
|
||||||
|
|
||||||
def test_validate_shop_access_admin_can_access_any_shop(self, db, test_shop, test_admin):
|
def test_validate_shop_access_admin_can_access_any_shop(
|
||||||
|
self, db, test_shop, test_admin
|
||||||
|
):
|
||||||
"""Test that admin users can access any shop"""
|
"""Test that admin users can access any shop"""
|
||||||
result = self.service.validate_shop_access(db, test_shop.shop_code, test_admin)
|
result = self.service.validate_shop_access(db, test_shop.shop_code, test_admin)
|
||||||
|
|
||||||
@@ -33,7 +37,9 @@ class TestMarketplaceService:
|
|||||||
with pytest.raises(ValueError, match="Shop not found"):
|
with pytest.raises(ValueError, match="Shop not found"):
|
||||||
self.service.validate_shop_access(db, "NONEXISTENT", test_user)
|
self.service.validate_shop_access(db, "NONEXISTENT", test_user)
|
||||||
|
|
||||||
def test_validate_shop_access_permission_denied(self, db, test_shop, test_user, other_user):
|
def test_validate_shop_access_permission_denied(
|
||||||
|
self, db, test_shop, test_user, other_user
|
||||||
|
):
|
||||||
"""Test shop access validation when user doesn't own the shop"""
|
"""Test shop access validation when user doesn't own the shop"""
|
||||||
# Set the shop owner to a different user
|
# Set the shop owner to a different user
|
||||||
test_shop.owner_id = other_user.id
|
test_shop.owner_id = other_user.id
|
||||||
@@ -52,7 +58,7 @@ class TestMarketplaceService:
|
|||||||
url="https://example.com/products.csv",
|
url="https://example.com/products.csv",
|
||||||
marketplace="Amazon",
|
marketplace="Amazon",
|
||||||
shop_code=test_shop.shop_code,
|
shop_code=test_shop.shop_code,
|
||||||
batch_size=1000
|
batch_size=1000,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = self.service.create_import_job(db, request, test_user)
|
result = self.service.create_import_job(db, request, test_user)
|
||||||
@@ -60,7 +66,7 @@ class TestMarketplaceService:
|
|||||||
assert result.marketplace == "Amazon"
|
assert result.marketplace == "Amazon"
|
||||||
# Check the correct field based on your model
|
# Check the correct field based on your model
|
||||||
assert result.shop_id == test_shop.id # Changed from shop_code to shop_id
|
assert result.shop_id == test_shop.id # Changed from shop_code to shop_id
|
||||||
assert result.user_id == test_user.id if hasattr(result, 'user_id') else True
|
assert result.user_id == test_user.id if hasattr(result, "user_id") else True
|
||||||
assert result.status == "pending"
|
assert result.status == "pending"
|
||||||
assert result.source_url == "https://example.com/products.csv"
|
assert result.source_url == "https://example.com/products.csv"
|
||||||
|
|
||||||
@@ -70,7 +76,7 @@ class TestMarketplaceService:
|
|||||||
url="https://example.com/products.csv",
|
url="https://example.com/products.csv",
|
||||||
marketplace="Amazon",
|
marketplace="Amazon",
|
||||||
shop_code="INVALID_SHOP",
|
shop_code="INVALID_SHOP",
|
||||||
batch_size=1000
|
batch_size=1000,
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Shop not found"):
|
with pytest.raises(ValueError, match="Shop not found"):
|
||||||
@@ -78,16 +84,22 @@ class TestMarketplaceService:
|
|||||||
|
|
||||||
def test_get_import_job_by_id_success(self, db, test_marketplace_job, test_user):
|
def test_get_import_job_by_id_success(self, db, test_marketplace_job, test_user):
|
||||||
"""Test getting import job by ID for job owner"""
|
"""Test getting import job by ID for job owner"""
|
||||||
result = self.service.get_import_job_by_id(db, test_marketplace_job.id, test_user)
|
result = self.service.get_import_job_by_id(
|
||||||
|
db, test_marketplace_job.id, test_user
|
||||||
|
)
|
||||||
|
|
||||||
assert result.id == test_marketplace_job.id
|
assert result.id == test_marketplace_job.id
|
||||||
# Check user_id if the field exists
|
# Check user_id if the field exists
|
||||||
if hasattr(result, 'user_id'):
|
if hasattr(result, "user_id"):
|
||||||
assert result.user_id == test_user.id
|
assert result.user_id == test_user.id
|
||||||
|
|
||||||
def test_get_import_job_by_id_admin_access(self, db, test_marketplace_job, test_admin):
|
def test_get_import_job_by_id_admin_access(
|
||||||
|
self, db, test_marketplace_job, test_admin
|
||||||
|
):
|
||||||
"""Test that admin can access any import job"""
|
"""Test that admin can access any import job"""
|
||||||
result = self.service.get_import_job_by_id(db, test_marketplace_job.id, test_admin)
|
result = self.service.get_import_job_by_id(
|
||||||
|
db, test_marketplace_job.id, test_admin
|
||||||
|
)
|
||||||
|
|
||||||
assert result.id == test_marketplace_job.id
|
assert result.id == test_marketplace_job.id
|
||||||
|
|
||||||
@@ -96,7 +108,9 @@ class TestMarketplaceService:
|
|||||||
with pytest.raises(ValueError, match="Marketplace import job not found"):
|
with pytest.raises(ValueError, match="Marketplace import job not found"):
|
||||||
self.service.get_import_job_by_id(db, 99999, test_user)
|
self.service.get_import_job_by_id(db, 99999, test_user)
|
||||||
|
|
||||||
def test_get_import_job_by_id_access_denied(self, db, test_marketplace_job, other_user):
|
def test_get_import_job_by_id_access_denied(
|
||||||
|
self, db, test_marketplace_job, other_user
|
||||||
|
):
|
||||||
"""Test access denied when user doesn't own the job"""
|
"""Test access denied when user doesn't own the job"""
|
||||||
with pytest.raises(PermissionError, match="Access denied to this import job"):
|
with pytest.raises(PermissionError, match="Access denied to this import job"):
|
||||||
self.service.get_import_job_by_id(db, test_marketplace_job.id, other_user)
|
self.service.get_import_job_by_id(db, test_marketplace_job.id, other_user)
|
||||||
@@ -108,7 +122,7 @@ class TestMarketplaceService:
|
|||||||
assert len(jobs) >= 1
|
assert len(jobs) >= 1
|
||||||
assert any(job.id == test_marketplace_job.id for job in jobs)
|
assert any(job.id == test_marketplace_job.id for job in jobs)
|
||||||
# Check user_id if the field exists
|
# Check user_id if the field exists
|
||||||
if hasattr(test_marketplace_job, 'user_id'):
|
if hasattr(test_marketplace_job, "user_id"):
|
||||||
assert test_marketplace_job.user_id == test_user.id
|
assert test_marketplace_job.user_id == test_user.id
|
||||||
|
|
||||||
def test_get_import_jobs_admin_sees_all(self, db, test_marketplace_job, test_admin):
|
def test_get_import_jobs_admin_sees_all(self, db, test_marketplace_job, test_admin):
|
||||||
@@ -118,7 +132,9 @@ class TestMarketplaceService:
|
|||||||
assert len(jobs) >= 1
|
assert len(jobs) >= 1
|
||||||
assert any(job.id == test_marketplace_job.id for job in jobs)
|
assert any(job.id == test_marketplace_job.id for job in jobs)
|
||||||
|
|
||||||
def test_get_import_jobs_with_marketplace_filter(self, db, test_marketplace_job, test_user):
|
def test_get_import_jobs_with_marketplace_filter(
|
||||||
|
self, db, test_marketplace_job, test_user
|
||||||
|
):
|
||||||
"""Test getting import jobs with marketplace filter"""
|
"""Test getting import jobs with marketplace filter"""
|
||||||
jobs = self.service.get_import_jobs(
|
jobs = self.service.get_import_jobs(
|
||||||
db, test_user, marketplace=test_marketplace_job.marketplace
|
db, test_user, marketplace=test_marketplace_job.marketplace
|
||||||
@@ -143,7 +159,7 @@ class TestMarketplaceService:
|
|||||||
imported_count=0,
|
imported_count=0,
|
||||||
updated_count=0,
|
updated_count=0,
|
||||||
total_processed=0,
|
total_processed=0,
|
||||||
error_count=0
|
error_count=0,
|
||||||
)
|
)
|
||||||
db.add(job)
|
db.add(job)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -159,7 +175,7 @@ class TestMarketplaceService:
|
|||||||
test_marketplace_job.id,
|
test_marketplace_job.id,
|
||||||
"completed",
|
"completed",
|
||||||
imported_count=100,
|
imported_count=100,
|
||||||
total_processed=100
|
total_processed=100,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result.status == "completed"
|
assert result.status == "completed"
|
||||||
@@ -211,7 +227,7 @@ class TestMarketplaceService:
|
|||||||
imported_count=0,
|
imported_count=0,
|
||||||
updated_count=0,
|
updated_count=0,
|
||||||
total_processed=0,
|
total_processed=0,
|
||||||
error_count=0
|
error_count=0,
|
||||||
)
|
)
|
||||||
db.add(job)
|
db.add(job)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -222,13 +238,17 @@ class TestMarketplaceService:
|
|||||||
assert result.status == "cancelled"
|
assert result.status == "cancelled"
|
||||||
assert result.completed_at is not None
|
assert result.completed_at is not None
|
||||||
|
|
||||||
def test_cancel_import_job_invalid_status(self, db, test_marketplace_job, test_user):
|
def test_cancel_import_job_invalid_status(
|
||||||
|
self, db, test_marketplace_job, test_user
|
||||||
|
):
|
||||||
"""Test cancelling a job that can't be cancelled"""
|
"""Test cancelling a job that can't be cancelled"""
|
||||||
# Set job status to completed
|
# Set job status to completed
|
||||||
test_marketplace_job.status = "completed"
|
test_marketplace_job.status = "completed"
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Cannot cancel job with status: completed"):
|
with pytest.raises(
|
||||||
|
ValueError, match="Cannot cancel job with status: completed"
|
||||||
|
):
|
||||||
self.service.cancel_import_job(db, test_marketplace_job.id, test_user)
|
self.service.cancel_import_job(db, test_marketplace_job.id, test_user)
|
||||||
|
|
||||||
def test_delete_import_job_success(self, db, test_user, test_shop):
|
def test_delete_import_job_success(self, db, test_user, test_shop):
|
||||||
@@ -246,7 +266,7 @@ class TestMarketplaceService:
|
|||||||
imported_count=0,
|
imported_count=0,
|
||||||
updated_count=0,
|
updated_count=0,
|
||||||
total_processed=0,
|
total_processed=0,
|
||||||
error_count=0
|
error_count=0,
|
||||||
)
|
)
|
||||||
db.add(job)
|
db.add(job)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -258,7 +278,11 @@ class TestMarketplaceService:
|
|||||||
assert result is True
|
assert result is True
|
||||||
|
|
||||||
# Verify the job is actually deleted
|
# Verify the job is actually deleted
|
||||||
deleted_job = db.query(MarketplaceImportJob).filter(MarketplaceImportJob.id == job_id).first()
|
deleted_job = (
|
||||||
|
db.query(MarketplaceImportJob)
|
||||||
|
.filter(MarketplaceImportJob.id == job_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
assert deleted_job is None
|
assert deleted_job is None
|
||||||
|
|
||||||
def test_delete_import_job_invalid_status(self, db, test_user, test_shop):
|
def test_delete_import_job_invalid_status(self, db, test_user, test_shop):
|
||||||
@@ -276,7 +300,7 @@ class TestMarketplaceService:
|
|||||||
imported_count=0,
|
imported_count=0,
|
||||||
updated_count=0,
|
updated_count=0,
|
||||||
total_processed=0,
|
total_processed=0,
|
||||||
error_count=0
|
error_count=0,
|
||||||
)
|
)
|
||||||
db.add(job)
|
db.add(job)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
# tests/test_middleware.py
|
# tests/test_middleware.py
|
||||||
import pytest
|
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
from middleware.rate_limiter import RateLimiter
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from middleware.auth import AuthManager
|
from middleware.auth import AuthManager
|
||||||
|
from middleware.rate_limiter import RateLimiter
|
||||||
|
|
||||||
|
|
||||||
class TestRateLimiter:
|
class TestRateLimiter:
|
||||||
@@ -12,11 +14,17 @@ class TestRateLimiter:
|
|||||||
client_id = "test_client"
|
client_id = "test_client"
|
||||||
|
|
||||||
# Should allow first request
|
# Should allow first request
|
||||||
assert limiter.allow_request(client_id, max_requests=10, window_seconds=3600) is True
|
assert (
|
||||||
|
limiter.allow_request(client_id, max_requests=10, window_seconds=3600)
|
||||||
|
is True
|
||||||
|
)
|
||||||
|
|
||||||
# Should allow subsequent requests within limit
|
# Should allow subsequent requests within limit
|
||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
assert limiter.allow_request(client_id, max_requests=10, window_seconds=3600) is True
|
assert (
|
||||||
|
limiter.allow_request(client_id, max_requests=10, window_seconds=3600)
|
||||||
|
is True
|
||||||
|
)
|
||||||
|
|
||||||
def test_rate_limiter_blocks_excess_requests(self):
|
def test_rate_limiter_blocks_excess_requests(self):
|
||||||
"""Test rate limiter blocks requests exceeding limit"""
|
"""Test rate limiter blocks requests exceeding limit"""
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
# tests/test_pagination.py
|
# tests/test_pagination.py
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from models.database_models import Product
|
from models.database_models import Product
|
||||||
|
|
||||||
|
|
||||||
@@ -12,7 +13,7 @@ class TestPagination:
|
|||||||
product = Product(
|
product = Product(
|
||||||
product_id=f"PAGE{i:03d}",
|
product_id=f"PAGE{i:03d}",
|
||||||
title=f"Pagination Test Product {i}",
|
title=f"Pagination Test Product {i}",
|
||||||
marketplace="PaginationTest"
|
marketplace="PaginationTest",
|
||||||
)
|
)
|
||||||
products.append(product)
|
products.append(product)
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
# tests/test_performance.py
|
# tests/test_performance.py
|
||||||
import pytest
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from models.database_models import Product
|
from models.database_models import Product
|
||||||
|
|
||||||
|
|
||||||
@@ -15,7 +16,7 @@ class TestPerformance:
|
|||||||
product_id=f"PERF{i:03d}",
|
product_id=f"PERF{i:03d}",
|
||||||
title=f"Performance Test Product {i}",
|
title=f"Performance Test Product {i}",
|
||||||
price=f"{i}.99",
|
price=f"{i}.99",
|
||||||
marketplace="Performance"
|
marketplace="Performance",
|
||||||
)
|
)
|
||||||
products.append(product)
|
products.append(product)
|
||||||
|
|
||||||
@@ -41,7 +42,7 @@ class TestPerformance:
|
|||||||
title=f"Searchable Product {i}",
|
title=f"Searchable Product {i}",
|
||||||
description=f"This is a searchable product number {i}",
|
description=f"This is a searchable product number {i}",
|
||||||
brand="SearchBrand",
|
brand="SearchBrand",
|
||||||
marketplace="SearchMarket"
|
marketplace="SearchMarket",
|
||||||
)
|
)
|
||||||
products.append(product)
|
products.append(product)
|
||||||
|
|
||||||
|
|||||||
@@ -31,7 +31,9 @@ class TestProductsAPI:
|
|||||||
assert response.json()["total"] == 1
|
assert response.json()["total"] == 1
|
||||||
|
|
||||||
# Test marketplace filter
|
# Test marketplace filter
|
||||||
response = client.get("/api/v1/product?marketplace=Letzshop", headers=auth_headers)
|
response = client.get(
|
||||||
|
"/api/v1/product?marketplace=Letzshop", headers=auth_headers
|
||||||
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["total"] == 1
|
assert response.json()["total"] == 1
|
||||||
|
|
||||||
@@ -50,10 +52,12 @@ class TestProductsAPI:
|
|||||||
"brand": "NewBrand",
|
"brand": "NewBrand",
|
||||||
"gtin": "9876543210987",
|
"gtin": "9876543210987",
|
||||||
"availability": "in stock",
|
"availability": "in stock",
|
||||||
"marketplace": "Amazon"
|
"marketplace": "Amazon",
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.post("/api/v1/product", headers=auth_headers, json=product_data)
|
response = client.post(
|
||||||
|
"/api/v1/product", headers=auth_headers, json=product_data
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
@@ -71,10 +75,12 @@ class TestProductsAPI:
|
|||||||
"brand": "NewBrand",
|
"brand": "NewBrand",
|
||||||
"gtin": "9876543210987",
|
"gtin": "9876543210987",
|
||||||
"availability": "in stock",
|
"availability": "in stock",
|
||||||
"marketplace": "Amazon"
|
"marketplace": "Amazon",
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.post("/api/v1/product", headers=auth_headers, json=product_data)
|
response = client.post(
|
||||||
|
"/api/v1/product", headers=auth_headers, json=product_data
|
||||||
|
)
|
||||||
|
|
||||||
# Debug output
|
# Debug output
|
||||||
print(f"Status Code: {response.status_code}")
|
print(f"Status Code: {response.status_code}")
|
||||||
@@ -89,7 +95,9 @@ class TestProductsAPI:
|
|||||||
|
|
||||||
def test_get_product_by_id(self, client, auth_headers, test_product):
|
def test_get_product_by_id(self, client, auth_headers, test_product):
|
||||||
"""Test getting specific product"""
|
"""Test getting specific product"""
|
||||||
response = client.get(f"/api/v1/product/{test_product.product_id}", headers=auth_headers)
|
response = client.get(
|
||||||
|
f"/api/v1/product/{test_product.product_id}", headers=auth_headers
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
@@ -104,15 +112,12 @@ class TestProductsAPI:
|
|||||||
|
|
||||||
def test_update_product(self, client, auth_headers, test_product):
|
def test_update_product(self, client, auth_headers, test_product):
|
||||||
"""Test updating product"""
|
"""Test updating product"""
|
||||||
update_data = {
|
update_data = {"title": "Updated Product Title", "price": "25.99"}
|
||||||
"title": "Updated Product Title",
|
|
||||||
"price": "25.99"
|
|
||||||
}
|
|
||||||
|
|
||||||
response = client.put(
|
response = client.put(
|
||||||
f"/api/v1/product/{test_product.product_id}",
|
f"/api/v1/product/{test_product.product_id}",
|
||||||
headers=auth_headers,
|
headers=auth_headers,
|
||||||
json=update_data
|
json=update_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -123,8 +128,7 @@ class TestProductsAPI:
|
|||||||
def test_delete_product(self, client, auth_headers, test_product):
|
def test_delete_product(self, client, auth_headers, test_product):
|
||||||
"""Test deleting product"""
|
"""Test deleting product"""
|
||||||
response = client.delete(
|
response = client.delete(
|
||||||
f"/api/v1/product/{test_product.product_id}",
|
f"/api/v1/product/{test_product.product_id}", headers=auth_headers
|
||||||
headers=auth_headers
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
# tests/test_product_service.py
|
# tests/test_product_service.py
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.services.product_service import ProductService
|
from app.services.product_service import ProductService
|
||||||
from models.api_models import ProductCreate
|
from models.api_models import ProductCreate
|
||||||
from models.database_models import Product
|
from models.database_models import Product
|
||||||
@@ -16,7 +17,7 @@ class TestProductService:
|
|||||||
title="Service Test Product",
|
title="Service Test Product",
|
||||||
gtin="1234567890123",
|
gtin="1234567890123",
|
||||||
price="19.99",
|
price="19.99",
|
||||||
marketplace="TestMarket"
|
marketplace="TestMarket",
|
||||||
)
|
)
|
||||||
|
|
||||||
product = self.service.create_product(db, product_data)
|
product = self.service.create_product(db, product_data)
|
||||||
@@ -31,7 +32,7 @@ class TestProductService:
|
|||||||
product_id="SVC002",
|
product_id="SVC002",
|
||||||
title="Service Test Product",
|
title="Service Test Product",
|
||||||
gtin="invalid_gtin",
|
gtin="invalid_gtin",
|
||||||
price="19.99"
|
price="19.99",
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Invalid GTIN format"):
|
with pytest.raises(ValueError, match="Invalid GTIN format"):
|
||||||
@@ -39,10 +40,7 @@ class TestProductService:
|
|||||||
|
|
||||||
def test_get_products_with_filters(self, db, test_product):
|
def test_get_products_with_filters(self, db, test_product):
|
||||||
"""Test getting products with various filters"""
|
"""Test getting products with various filters"""
|
||||||
products, total = self.service.get_products_with_filters(
|
products, total = self.service.get_products_with_filters(db, brand="TestBrand")
|
||||||
db,
|
|
||||||
brand="TestBrand"
|
|
||||||
)
|
|
||||||
|
|
||||||
assert total == 1
|
assert total == 1
|
||||||
assert len(products) == 1
|
assert len(products) == 1
|
||||||
@@ -51,10 +49,8 @@ class TestProductService:
|
|||||||
def test_get_products_with_search(self, db, test_product):
|
def test_get_products_with_search(self, db, test_product):
|
||||||
"""Test getting products with search"""
|
"""Test getting products with search"""
|
||||||
products, total = self.service.get_products_with_filters(
|
products, total = self.service.get_products_with_filters(
|
||||||
db,
|
db, search="Test Product"
|
||||||
search="Test Product"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert total == 1
|
assert total == 1
|
||||||
assert len(products) == 1
|
assert len(products) == 1
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
# tests/test_security.py
|
# tests/test_security.py
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
|
|
||||||
class TestSecurity:
|
class TestSecurity:
|
||||||
@@ -10,7 +11,9 @@ class TestSecurity:
|
|||||||
|
|
||||||
response = client.get("/api/v1/debug-bearer")
|
response = client.get("/api/v1/debug-bearer")
|
||||||
print(f"Direct Bearer - Status: {response.status_code}")
|
print(f"Direct Bearer - Status: {response.status_code}")
|
||||||
print(f"Direct Bearer - Response: {response.json() if response.content else 'No content'}")
|
print(
|
||||||
|
f"Direct Bearer - Response: {response.json() if response.content else 'No content'}"
|
||||||
|
)
|
||||||
|
|
||||||
def test_debug_dependencies(self, client):
|
def test_debug_dependencies(self, client):
|
||||||
"""Debug the dependency chain step by step"""
|
"""Debug the dependency chain step by step"""
|
||||||
@@ -24,7 +27,9 @@ class TestSecurity:
|
|||||||
print(f"Admin endpoint - Raw: {response.content}")
|
print(f"Admin endpoint - Raw: {response.content}")
|
||||||
|
|
||||||
# Test 2: Try a regular endpoint that uses get_current_user
|
# Test 2: Try a regular endpoint that uses get_current_user
|
||||||
response2 = client.get("/api/v1/product") # or any endpoint with get_current_user
|
response2 = client.get(
|
||||||
|
"/api/v1/product"
|
||||||
|
) # or any endpoint with get_current_user
|
||||||
print(f"Regular endpoint - Status: {response2.status_code}")
|
print(f"Regular endpoint - Status: {response2.status_code}")
|
||||||
try:
|
try:
|
||||||
print(f"Regular endpoint - Response: {response2.json()}")
|
print(f"Regular endpoint - Response: {response2.json()}")
|
||||||
@@ -35,7 +40,7 @@ class TestSecurity:
|
|||||||
"""Debug test to see all available routes"""
|
"""Debug test to see all available routes"""
|
||||||
print("\n=== All Available Routes ===")
|
print("\n=== All Available Routes ===")
|
||||||
for route in client.app.routes:
|
for route in client.app.routes:
|
||||||
if hasattr(route, 'path') and hasattr(route, 'methods'):
|
if hasattr(route, "path") and hasattr(route, "methods"):
|
||||||
print(f"{list(route.methods)} {route.path}")
|
print(f"{list(route.methods)} {route.path}")
|
||||||
|
|
||||||
print("\n=== Testing Product Endpoint Variations ===")
|
print("\n=== Testing Product Endpoint Variations ===")
|
||||||
@@ -59,7 +64,7 @@ class TestSecurity:
|
|||||||
"/api/v1/product",
|
"/api/v1/product",
|
||||||
"/api/v1/shop",
|
"/api/v1/shop",
|
||||||
"/api/v1/stats",
|
"/api/v1/stats",
|
||||||
"/api/v1/stock"
|
"/api/v1/stock",
|
||||||
]
|
]
|
||||||
|
|
||||||
for endpoint in protected_endpoints:
|
for endpoint in protected_endpoints:
|
||||||
@@ -76,7 +81,9 @@ class TestSecurity:
|
|||||||
def test_admin_endpoint_requires_admin_role(self, client, auth_headers):
|
def test_admin_endpoint_requires_admin_role(self, client, auth_headers):
|
||||||
"""Test that admin endpoints require admin role"""
|
"""Test that admin endpoints require admin role"""
|
||||||
response = client.get("/api/v1/admin/users", headers=auth_headers)
|
response = client.get("/api/v1/admin/users", headers=auth_headers)
|
||||||
assert response.status_code == 403 # Token is valid but user does not have access.
|
assert (
|
||||||
|
response.status_code == 403
|
||||||
|
) # Token is valid but user does not have access.
|
||||||
# Regular user should be denied
|
# Regular user should be denied
|
||||||
|
|
||||||
def test_sql_injection_prevention(self, client, auth_headers):
|
def test_sql_injection_prevention(self, client, auth_headers):
|
||||||
@@ -84,7 +91,9 @@ class TestSecurity:
|
|||||||
# Try SQL injection in search parameter
|
# Try SQL injection in search parameter
|
||||||
malicious_search = "'; DROP TABLE products; --"
|
malicious_search = "'; DROP TABLE products; --"
|
||||||
|
|
||||||
response = client.get(f"/api/v1/product?search={malicious_search}", headers=auth_headers)
|
response = client.get(
|
||||||
|
f"/api/v1/product?search={malicious_search}", headers=auth_headers
|
||||||
|
)
|
||||||
|
|
||||||
# Should not crash and should return normal response
|
# Should not crash and should return normal response
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ class TestShopsAPI:
|
|||||||
shop_data = {
|
shop_data = {
|
||||||
"shop_code": "NEWSHOP",
|
"shop_code": "NEWSHOP",
|
||||||
"shop_name": "New Shop",
|
"shop_name": "New Shop",
|
||||||
"description": "A new test shop"
|
"description": "A new test shop",
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.post("/api/v1/shop", headers=auth_headers, json=shop_data)
|
response = client.post("/api/v1/shop", headers=auth_headers, json=shop_data)
|
||||||
@@ -23,7 +23,7 @@ class TestShopsAPI:
|
|||||||
"""Test creating shop with duplicate code"""
|
"""Test creating shop with duplicate code"""
|
||||||
shop_data = {
|
shop_data = {
|
||||||
"shop_code": test_shop.shop_code, # Same as test_shop
|
"shop_code": test_shop.shop_code, # Same as test_shop
|
||||||
"shop_name": test_shop.shop_name
|
"shop_name": test_shop.shop_name,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.post("/api/v1/shop", headers=auth_headers, json=shop_data)
|
response = client.post("/api/v1/shop", headers=auth_headers, json=shop_data)
|
||||||
@@ -42,7 +42,9 @@ class TestShopsAPI:
|
|||||||
|
|
||||||
def test_get_shop_by_code(self, client, auth_headers, test_shop):
|
def test_get_shop_by_code(self, client, auth_headers, test_shop):
|
||||||
"""Test getting specific shop"""
|
"""Test getting specific shop"""
|
||||||
response = client.get(f"/api/v1/shop/{test_shop.shop_code}", headers=auth_headers)
|
response = client.get(
|
||||||
|
f"/api/v1/shop/{test_shop.shop_code}", headers=auth_headers
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ class TestShopService:
|
|||||||
shop_data = ShopCreate(
|
shop_data = ShopCreate(
|
||||||
shop_code="NEWSHOP",
|
shop_code="NEWSHOP",
|
||||||
shop_name="New Test Shop",
|
shop_name="New Test Shop",
|
||||||
description="A new test shop"
|
description="A new test shop",
|
||||||
)
|
)
|
||||||
|
|
||||||
shop = self.service.create_shop(db, shop_data, test_user)
|
shop = self.service.create_shop(db, shop_data, test_user)
|
||||||
@@ -30,10 +30,7 @@ class TestShopService:
|
|||||||
|
|
||||||
def test_create_shop_admin_auto_verify(self, db, test_admin, shop_factory):
|
def test_create_shop_admin_auto_verify(self, db, test_admin, shop_factory):
|
||||||
"""Test admin creates verified shop automatically"""
|
"""Test admin creates verified shop automatically"""
|
||||||
shop_data = ShopCreate(
|
shop_data = ShopCreate(shop_code="ADMINSHOP", shop_name="Admin Test Shop")
|
||||||
shop_code="ADMINSHOP",
|
|
||||||
shop_name="Admin Test Shop"
|
|
||||||
)
|
|
||||||
|
|
||||||
shop = self.service.create_shop(db, shop_data, test_admin)
|
shop = self.service.create_shop(db, shop_data, test_admin)
|
||||||
|
|
||||||
@@ -42,8 +39,7 @@ class TestShopService:
|
|||||||
def test_create_shop_duplicate_code(self, db, test_user, test_shop):
|
def test_create_shop_duplicate_code(self, db, test_user, test_shop):
|
||||||
"""Test shop creation fails with duplicate shop code"""
|
"""Test shop creation fails with duplicate shop code"""
|
||||||
shop_data = ShopCreate(
|
shop_data = ShopCreate(
|
||||||
shop_code=test_shop.shop_code,
|
shop_code=test_shop.shop_code, shop_name=test_shop.shop_name
|
||||||
shop_name=test_shop.shop_name
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(HTTPException) as exc_info:
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
@@ -60,9 +56,13 @@ class TestShopService:
|
|||||||
assert test_shop.shop_code in shop_codes
|
assert test_shop.shop_code in shop_codes
|
||||||
assert inactive_shop.shop_code not in shop_codes
|
assert inactive_shop.shop_code not in shop_codes
|
||||||
|
|
||||||
def test_get_shops_admin_user(self, db, test_admin, test_shop, inactive_shop, verified_shop):
|
def test_get_shops_admin_user(
|
||||||
|
self, db, test_admin, test_shop, inactive_shop, verified_shop
|
||||||
|
):
|
||||||
"""Test admin user can see all shops with filters"""
|
"""Test admin user can see all shops with filters"""
|
||||||
shops, total = self.service.get_shops(db, test_admin, active_only=False, verified_only=False)
|
shops, total = self.service.get_shops(
|
||||||
|
db, test_admin, active_only=False, verified_only=False
|
||||||
|
)
|
||||||
|
|
||||||
shop_codes = [shop.shop_code for shop in shops]
|
shop_codes = [shop.shop_code for shop in shops]
|
||||||
assert test_shop.shop_code in shop_codes
|
assert test_shop.shop_code in shop_codes
|
||||||
@@ -78,7 +78,9 @@ class TestShopService:
|
|||||||
|
|
||||||
def test_get_shop_by_code_admin_access(self, db, test_admin, test_shop):
|
def test_get_shop_by_code_admin_access(self, db, test_admin, test_shop):
|
||||||
"""Test admin can access any shop"""
|
"""Test admin can access any shop"""
|
||||||
shop = self.service.get_shop_by_code(db, test_shop.shop_code.lower(), test_admin)
|
shop = self.service.get_shop_by_code(
|
||||||
|
db, test_shop.shop_code.lower(), test_admin
|
||||||
|
)
|
||||||
|
|
||||||
assert shop is not None
|
assert shop is not None
|
||||||
assert shop.id == test_shop.id
|
assert shop.id == test_shop.id
|
||||||
@@ -103,10 +105,12 @@ class TestShopService:
|
|||||||
product_id=unique_product.product_id,
|
product_id=unique_product.product_id,
|
||||||
price="15.99",
|
price="15.99",
|
||||||
is_featured=True,
|
is_featured=True,
|
||||||
stock_quantity=5
|
stock_quantity=5,
|
||||||
)
|
)
|
||||||
|
|
||||||
shop_product = self.service.add_product_to_shop(db, test_shop, shop_product_data)
|
shop_product = self.service.add_product_to_shop(
|
||||||
|
db, test_shop, shop_product_data
|
||||||
|
)
|
||||||
|
|
||||||
assert shop_product is not None
|
assert shop_product is not None
|
||||||
assert shop_product.shop_id == test_shop.id
|
assert shop_product.shop_id == test_shop.id
|
||||||
@@ -114,10 +118,7 @@ class TestShopService:
|
|||||||
|
|
||||||
def test_add_product_to_shop_product_not_found(self, db, test_shop):
|
def test_add_product_to_shop_product_not_found(self, db, test_shop):
|
||||||
"""Test adding non-existent product to shop fails"""
|
"""Test adding non-existent product to shop fails"""
|
||||||
shop_product_data = ShopProductCreate(
|
shop_product_data = ShopProductCreate(product_id="NONEXISTENT", price="15.99")
|
||||||
product_id="NONEXISTENT",
|
|
||||||
price="15.99"
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(HTTPException) as exc_info:
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
self.service.add_product_to_shop(db, test_shop, shop_product_data)
|
self.service.add_product_to_shop(db, test_shop, shop_product_data)
|
||||||
@@ -127,8 +128,7 @@ class TestShopService:
|
|||||||
def test_add_product_to_shop_already_exists(self, db, test_shop, shop_product):
|
def test_add_product_to_shop_already_exists(self, db, test_shop, shop_product):
|
||||||
"""Test adding product that's already in shop fails"""
|
"""Test adding product that's already in shop fails"""
|
||||||
shop_product_data = ShopProductCreate(
|
shop_product_data = ShopProductCreate(
|
||||||
product_id=shop_product.product.product_id,
|
product_id=shop_product.product.product_id, price="15.99"
|
||||||
price="15.99"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(HTTPException) as exc_info:
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
@@ -136,7 +136,9 @@ class TestShopService:
|
|||||||
|
|
||||||
assert exc_info.value.status_code == 400
|
assert exc_info.value.status_code == 400
|
||||||
|
|
||||||
def test_get_shop_products_owner_access(self, db, test_user, test_shop, shop_product):
|
def test_get_shop_products_owner_access(
|
||||||
|
self, db, test_user, test_shop, shop_product
|
||||||
|
):
|
||||||
"""Test shop owner can get shop products"""
|
"""Test shop owner can get shop products"""
|
||||||
products, total = self.service.get_shop_products(db, test_shop, test_user)
|
products, total = self.service.get_shop_products(db, test_shop, test_user)
|
||||||
|
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ class TestStatsService:
|
|||||||
marketplace="Amazon",
|
marketplace="Amazon",
|
||||||
shop_name="AmazonShop",
|
shop_name="AmazonShop",
|
||||||
price="15.99",
|
price="15.99",
|
||||||
currency="EUR"
|
currency="EUR",
|
||||||
),
|
),
|
||||||
Product(
|
Product(
|
||||||
product_id="PROD003",
|
product_id="PROD003",
|
||||||
@@ -50,7 +50,7 @@ class TestStatsService:
|
|||||||
marketplace="eBay",
|
marketplace="eBay",
|
||||||
shop_name="eBayShop",
|
shop_name="eBayShop",
|
||||||
price="25.99",
|
price="25.99",
|
||||||
currency="USD"
|
currency="USD",
|
||||||
),
|
),
|
||||||
Product(
|
Product(
|
||||||
product_id="PROD004",
|
product_id="PROD004",
|
||||||
@@ -60,8 +60,8 @@ class TestStatsService:
|
|||||||
marketplace="Letzshop", # Same as test_product
|
marketplace="Letzshop", # Same as test_product
|
||||||
shop_name="DifferentShop",
|
shop_name="DifferentShop",
|
||||||
price="35.99",
|
price="35.99",
|
||||||
currency="EUR"
|
currency="EUR",
|
||||||
)
|
),
|
||||||
]
|
]
|
||||||
db.add_all(additional_products)
|
db.add_all(additional_products)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -86,7 +86,7 @@ class TestStatsService:
|
|||||||
marketplace=None, # Null marketplace
|
marketplace=None, # Null marketplace
|
||||||
shop_name=None, # Null shop
|
shop_name=None, # Null shop
|
||||||
price="10.00",
|
price="10.00",
|
||||||
currency="EUR"
|
currency="EUR",
|
||||||
),
|
),
|
||||||
Product(
|
Product(
|
||||||
product_id="EMPTY001",
|
product_id="EMPTY001",
|
||||||
@@ -96,8 +96,8 @@ class TestStatsService:
|
|||||||
marketplace="", # Empty marketplace
|
marketplace="", # Empty marketplace
|
||||||
shop_name="", # Empty shop
|
shop_name="", # Empty shop
|
||||||
price="15.00",
|
price="15.00",
|
||||||
currency="EUR"
|
currency="EUR",
|
||||||
)
|
),
|
||||||
]
|
]
|
||||||
db.add_all(products_with_nulls)
|
db.add_all(products_with_nulls)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -122,14 +122,16 @@ class TestStatsService:
|
|||||||
# Find our test marketplace in the results
|
# Find our test marketplace in the results
|
||||||
test_marketplace_stat = next(
|
test_marketplace_stat = next(
|
||||||
(stat for stat in stats if stat["marketplace"] == test_product.marketplace),
|
(stat for stat in stats if stat["marketplace"] == test_product.marketplace),
|
||||||
None
|
None,
|
||||||
)
|
)
|
||||||
assert test_marketplace_stat is not None
|
assert test_marketplace_stat is not None
|
||||||
assert test_marketplace_stat["total_products"] >= 1
|
assert test_marketplace_stat["total_products"] >= 1
|
||||||
assert test_marketplace_stat["unique_shops"] >= 1
|
assert test_marketplace_stat["unique_shops"] >= 1
|
||||||
assert test_marketplace_stat["unique_brands"] >= 1
|
assert test_marketplace_stat["unique_brands"] >= 1
|
||||||
|
|
||||||
def test_get_marketplace_breakdown_stats_multiple_marketplaces(self, db, test_product):
|
def test_get_marketplace_breakdown_stats_multiple_marketplaces(
|
||||||
|
self, db, test_product
|
||||||
|
):
|
||||||
"""Test marketplace breakdown with multiple marketplaces"""
|
"""Test marketplace breakdown with multiple marketplaces"""
|
||||||
# Create products for different marketplaces
|
# Create products for different marketplaces
|
||||||
marketplace_products = [
|
marketplace_products = [
|
||||||
@@ -140,7 +142,7 @@ class TestStatsService:
|
|||||||
marketplace="Amazon",
|
marketplace="Amazon",
|
||||||
shop_name="AmazonShop1",
|
shop_name="AmazonShop1",
|
||||||
price="20.00",
|
price="20.00",
|
||||||
currency="EUR"
|
currency="EUR",
|
||||||
),
|
),
|
||||||
Product(
|
Product(
|
||||||
product_id="AMAZON002",
|
product_id="AMAZON002",
|
||||||
@@ -149,7 +151,7 @@ class TestStatsService:
|
|||||||
marketplace="Amazon",
|
marketplace="Amazon",
|
||||||
shop_name="AmazonShop2",
|
shop_name="AmazonShop2",
|
||||||
price="25.00",
|
price="25.00",
|
||||||
currency="EUR"
|
currency="EUR",
|
||||||
),
|
),
|
||||||
Product(
|
Product(
|
||||||
product_id="EBAY001",
|
product_id="EBAY001",
|
||||||
@@ -158,8 +160,8 @@ class TestStatsService:
|
|||||||
marketplace="eBay",
|
marketplace="eBay",
|
||||||
shop_name="eBayShop",
|
shop_name="eBayShop",
|
||||||
price="30.00",
|
price="30.00",
|
||||||
currency="USD"
|
currency="USD",
|
||||||
)
|
),
|
||||||
]
|
]
|
||||||
db.add_all(marketplace_products)
|
db.add_all(marketplace_products)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -194,7 +196,7 @@ class TestStatsService:
|
|||||||
shop_name="SomeShop",
|
shop_name="SomeShop",
|
||||||
brand="SomeBrand",
|
brand="SomeBrand",
|
||||||
price="10.00",
|
price="10.00",
|
||||||
currency="EUR"
|
currency="EUR",
|
||||||
)
|
)
|
||||||
db.add(null_marketplace_product)
|
db.add(null_marketplace_product)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -202,7 +204,9 @@ class TestStatsService:
|
|||||||
stats = self.service.get_marketplace_breakdown_stats(db)
|
stats = self.service.get_marketplace_breakdown_stats(db)
|
||||||
|
|
||||||
# Should not include any stats for null marketplace
|
# Should not include any stats for null marketplace
|
||||||
marketplace_names = [stat["marketplace"] for stat in stats if stat["marketplace"] is not None]
|
marketplace_names = [
|
||||||
|
stat["marketplace"] for stat in stats if stat["marketplace"] is not None
|
||||||
|
]
|
||||||
assert None not in marketplace_names
|
assert None not in marketplace_names
|
||||||
|
|
||||||
def test_get_product_count(self, db, test_product):
|
def test_get_product_count(self, db, test_product):
|
||||||
@@ -223,7 +227,7 @@ class TestStatsService:
|
|||||||
marketplace="Test",
|
marketplace="Test",
|
||||||
shop_name="TestShop",
|
shop_name="TestShop",
|
||||||
price="10.00",
|
price="10.00",
|
||||||
currency="EUR"
|
currency="EUR",
|
||||||
),
|
),
|
||||||
Product(
|
Product(
|
||||||
product_id="BRAND002",
|
product_id="BRAND002",
|
||||||
@@ -232,15 +236,17 @@ class TestStatsService:
|
|||||||
marketplace="Test",
|
marketplace="Test",
|
||||||
shop_name="TestShop",
|
shop_name="TestShop",
|
||||||
price="15.00",
|
price="15.00",
|
||||||
currency="EUR"
|
currency="EUR",
|
||||||
)
|
),
|
||||||
]
|
]
|
||||||
db.add_all(brand_products)
|
db.add_all(brand_products)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
count = self.service.get_unique_brands_count(db)
|
count = self.service.get_unique_brands_count(db)
|
||||||
|
|
||||||
assert count >= 2 # At least BrandA and BrandB, plus possibly test_product brand
|
assert (
|
||||||
|
count >= 2
|
||||||
|
) # At least BrandA and BrandB, plus possibly test_product brand
|
||||||
assert isinstance(count, int)
|
assert isinstance(count, int)
|
||||||
|
|
||||||
def test_get_unique_categories_count(self, db, test_product):
|
def test_get_unique_categories_count(self, db, test_product):
|
||||||
@@ -254,7 +260,7 @@ class TestStatsService:
|
|||||||
marketplace="Test",
|
marketplace="Test",
|
||||||
shop_name="TestShop",
|
shop_name="TestShop",
|
||||||
price="10.00",
|
price="10.00",
|
||||||
currency="EUR"
|
currency="EUR",
|
||||||
),
|
),
|
||||||
Product(
|
Product(
|
||||||
product_id="CAT002",
|
product_id="CAT002",
|
||||||
@@ -263,8 +269,8 @@ class TestStatsService:
|
|||||||
marketplace="Test",
|
marketplace="Test",
|
||||||
shop_name="TestShop",
|
shop_name="TestShop",
|
||||||
price="15.00",
|
price="15.00",
|
||||||
currency="EUR"
|
currency="EUR",
|
||||||
)
|
),
|
||||||
]
|
]
|
||||||
db.add_all(category_products)
|
db.add_all(category_products)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -284,7 +290,7 @@ class TestStatsService:
|
|||||||
marketplace="Amazon",
|
marketplace="Amazon",
|
||||||
shop_name="AmazonShop",
|
shop_name="AmazonShop",
|
||||||
price="10.00",
|
price="10.00",
|
||||||
currency="EUR"
|
currency="EUR",
|
||||||
),
|
),
|
||||||
Product(
|
Product(
|
||||||
product_id="MARKET002",
|
product_id="MARKET002",
|
||||||
@@ -292,8 +298,8 @@ class TestStatsService:
|
|||||||
marketplace="eBay",
|
marketplace="eBay",
|
||||||
shop_name="eBayShop",
|
shop_name="eBayShop",
|
||||||
price="15.00",
|
price="15.00",
|
||||||
currency="EUR"
|
currency="EUR",
|
||||||
)
|
),
|
||||||
]
|
]
|
||||||
db.add_all(marketplace_products)
|
db.add_all(marketplace_products)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -313,7 +319,7 @@ class TestStatsService:
|
|||||||
marketplace="Test",
|
marketplace="Test",
|
||||||
shop_name="ShopA",
|
shop_name="ShopA",
|
||||||
price="10.00",
|
price="10.00",
|
||||||
currency="EUR"
|
currency="EUR",
|
||||||
),
|
),
|
||||||
Product(
|
Product(
|
||||||
product_id="SHOP002",
|
product_id="SHOP002",
|
||||||
@@ -321,8 +327,8 @@ class TestStatsService:
|
|||||||
marketplace="Test",
|
marketplace="Test",
|
||||||
shop_name="ShopB",
|
shop_name="ShopB",
|
||||||
price="15.00",
|
price="15.00",
|
||||||
currency="EUR"
|
currency="EUR",
|
||||||
)
|
),
|
||||||
]
|
]
|
||||||
db.add_all(shop_products)
|
db.add_all(shop_products)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -341,15 +347,15 @@ class TestStatsService:
|
|||||||
location="LOCATION2",
|
location="LOCATION2",
|
||||||
quantity=25,
|
quantity=25,
|
||||||
reserved_quantity=5,
|
reserved_quantity=5,
|
||||||
shop_id=test_stock.shop_id
|
shop_id=test_stock.shop_id,
|
||||||
),
|
),
|
||||||
Stock(
|
Stock(
|
||||||
gtin="1234567890125",
|
gtin="1234567890125",
|
||||||
location="LOCATION3",
|
location="LOCATION3",
|
||||||
quantity=0, # Out of stock
|
quantity=0, # Out of stock
|
||||||
reserved_quantity=0,
|
reserved_quantity=0,
|
||||||
shop_id=test_stock.shop_id
|
shop_id=test_stock.shop_id,
|
||||||
)
|
),
|
||||||
]
|
]
|
||||||
db.add_all(additional_stocks)
|
db.add_all(additional_stocks)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -372,7 +378,7 @@ class TestStatsService:
|
|||||||
marketplace="SpecificMarket",
|
marketplace="SpecificMarket",
|
||||||
shop_name="SpecificShop1",
|
shop_name="SpecificShop1",
|
||||||
price="10.00",
|
price="10.00",
|
||||||
currency="EUR"
|
currency="EUR",
|
||||||
),
|
),
|
||||||
Product(
|
Product(
|
||||||
product_id="SPECIFIC002",
|
product_id="SPECIFIC002",
|
||||||
@@ -381,7 +387,7 @@ class TestStatsService:
|
|||||||
marketplace="SpecificMarket",
|
marketplace="SpecificMarket",
|
||||||
shop_name="SpecificShop2",
|
shop_name="SpecificShop2",
|
||||||
price="15.00",
|
price="15.00",
|
||||||
currency="EUR"
|
currency="EUR",
|
||||||
),
|
),
|
||||||
Product(
|
Product(
|
||||||
product_id="OTHER001",
|
product_id="OTHER001",
|
||||||
@@ -390,8 +396,8 @@ class TestStatsService:
|
|||||||
marketplace="OtherMarket",
|
marketplace="OtherMarket",
|
||||||
shop_name="OtherShop",
|
shop_name="OtherShop",
|
||||||
price="20.00",
|
price="20.00",
|
||||||
currency="EUR"
|
currency="EUR",
|
||||||
)
|
),
|
||||||
]
|
]
|
||||||
db.add_all(marketplace_products)
|
db.add_all(marketplace_products)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -414,7 +420,7 @@ class TestStatsService:
|
|||||||
marketplace="TestMarketplace",
|
marketplace="TestMarketplace",
|
||||||
shop_name="TestShop1",
|
shop_name="TestShop1",
|
||||||
price="10.00",
|
price="10.00",
|
||||||
currency="EUR"
|
currency="EUR",
|
||||||
),
|
),
|
||||||
Product(
|
Product(
|
||||||
product_id="SHOPTEST002",
|
product_id="SHOPTEST002",
|
||||||
@@ -423,8 +429,8 @@ class TestStatsService:
|
|||||||
marketplace="TestMarketplace",
|
marketplace="TestMarketplace",
|
||||||
shop_name="TestShop2",
|
shop_name="TestShop2",
|
||||||
price="15.00",
|
price="15.00",
|
||||||
currency="EUR"
|
currency="EUR",
|
||||||
)
|
),
|
||||||
]
|
]
|
||||||
db.add_all(marketplace_products)
|
db.add_all(marketplace_products)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -445,7 +451,7 @@ class TestStatsService:
|
|||||||
marketplace="CountMarketplace",
|
marketplace="CountMarketplace",
|
||||||
shop_name="CountShop",
|
shop_name="CountShop",
|
||||||
price="10.00",
|
price="10.00",
|
||||||
currency="EUR"
|
currency="EUR",
|
||||||
),
|
),
|
||||||
Product(
|
Product(
|
||||||
product_id="COUNT002",
|
product_id="COUNT002",
|
||||||
@@ -453,7 +459,7 @@ class TestStatsService:
|
|||||||
marketplace="CountMarketplace",
|
marketplace="CountMarketplace",
|
||||||
shop_name="CountShop",
|
shop_name="CountShop",
|
||||||
price="15.00",
|
price="15.00",
|
||||||
currency="EUR"
|
currency="EUR",
|
||||||
),
|
),
|
||||||
Product(
|
Product(
|
||||||
product_id="COUNT003",
|
product_id="COUNT003",
|
||||||
@@ -461,8 +467,8 @@ class TestStatsService:
|
|||||||
marketplace="CountMarketplace",
|
marketplace="CountMarketplace",
|
||||||
shop_name="CountShop",
|
shop_name="CountShop",
|
||||||
price="20.00",
|
price="20.00",
|
||||||
currency="EUR"
|
currency="EUR",
|
||||||
)
|
),
|
||||||
]
|
]
|
||||||
db.add_all(marketplace_products)
|
db.add_all(marketplace_products)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
# tests/test_stock.py
|
# tests/test_stock.py
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from models.database_models import Stock
|
from models.database_models import Stock
|
||||||
|
|
||||||
|
|
||||||
@@ -9,7 +10,7 @@ class TestStockAPI:
|
|||||||
stock_data = {
|
stock_data = {
|
||||||
"gtin": "1234567890123",
|
"gtin": "1234567890123",
|
||||||
"location": "WAREHOUSE_A",
|
"location": "WAREHOUSE_A",
|
||||||
"quantity": 100
|
"quantity": 100,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.post("/api/v1/stock", headers=auth_headers, json=stock_data)
|
response = client.post("/api/v1/stock", headers=auth_headers, json=stock_data)
|
||||||
@@ -30,7 +31,7 @@ class TestStockAPI:
|
|||||||
stock_data = {
|
stock_data = {
|
||||||
"gtin": "1234567890123",
|
"gtin": "1234567890123",
|
||||||
"location": "WAREHOUSE_A",
|
"location": "WAREHOUSE_A",
|
||||||
"quantity": 75
|
"quantity": 75,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.post("/api/v1/stock", headers=auth_headers, json=stock_data)
|
response = client.post("/api/v1/stock", headers=auth_headers, json=stock_data)
|
||||||
@@ -49,10 +50,12 @@ class TestStockAPI:
|
|||||||
stock_data = {
|
stock_data = {
|
||||||
"gtin": "1234567890123",
|
"gtin": "1234567890123",
|
||||||
"location": "WAREHOUSE_A",
|
"location": "WAREHOUSE_A",
|
||||||
"quantity": 25
|
"quantity": 25,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.post("/api/v1/stock/add", headers=auth_headers, json=stock_data)
|
response = client.post(
|
||||||
|
"/api/v1/stock/add", headers=auth_headers, json=stock_data
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
@@ -68,10 +71,12 @@ class TestStockAPI:
|
|||||||
stock_data = {
|
stock_data = {
|
||||||
"gtin": "1234567890123",
|
"gtin": "1234567890123",
|
||||||
"location": "WAREHOUSE_A",
|
"location": "WAREHOUSE_A",
|
||||||
"quantity": 15
|
"quantity": 15,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.post("/api/v1/stock/remove", headers=auth_headers, json=stock_data)
|
response = client.post(
|
||||||
|
"/api/v1/stock/remove", headers=auth_headers, json=stock_data
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
@@ -87,10 +92,12 @@ class TestStockAPI:
|
|||||||
stock_data = {
|
stock_data = {
|
||||||
"gtin": "1234567890123",
|
"gtin": "1234567890123",
|
||||||
"location": "WAREHOUSE_A",
|
"location": "WAREHOUSE_A",
|
||||||
"quantity": 20
|
"quantity": 20,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.post("/api/v1/stock/remove", headers=auth_headers, json=stock_data)
|
response = client.post(
|
||||||
|
"/api/v1/stock/remove", headers=auth_headers, json=stock_data
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
assert "Insufficient stock" in response.json()["detail"]
|
assert "Insufficient stock" in response.json()["detail"]
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
# tests/test_stock_service.py
|
# tests/test_stock_service.py
|
||||||
import pytest
|
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from app.services.stock_service import StockService
|
from app.services.stock_service import StockService
|
||||||
from models.api_models import StockCreate, StockAdd, StockUpdate
|
from models.api_models import StockAdd, StockCreate, StockUpdate
|
||||||
from models.database_models import Stock, Product
|
from models.database_models import Product, Stock
|
||||||
|
|
||||||
|
|
||||||
class TestStockService:
|
class TestStockService:
|
||||||
@@ -39,7 +41,9 @@ class TestStockService:
|
|||||||
assert self.service.normalize_gtin("1234567890123") == "1234567890123" # EAN-13
|
assert self.service.normalize_gtin("1234567890123") == "1234567890123" # EAN-13
|
||||||
assert self.service.normalize_gtin("123456789012") == "123456789012" # UPC-A
|
assert self.service.normalize_gtin("123456789012") == "123456789012" # UPC-A
|
||||||
assert self.service.normalize_gtin("12345678") == "12345678" # EAN-8
|
assert self.service.normalize_gtin("12345678") == "12345678" # EAN-8
|
||||||
assert self.service.normalize_gtin("12345678901234") == "12345678901234" # GTIN-14
|
assert (
|
||||||
|
self.service.normalize_gtin("12345678901234") == "12345678901234"
|
||||||
|
) # GTIN-14
|
||||||
|
|
||||||
# Test with decimal points (should be removed)
|
# Test with decimal points (should be removed)
|
||||||
assert self.service.normalize_gtin("1234567890123.0") == "1234567890123"
|
assert self.service.normalize_gtin("1234567890123.0") == "1234567890123"
|
||||||
@@ -49,10 +53,14 @@ class TestStockService:
|
|||||||
|
|
||||||
# Test short GTINs being padded
|
# Test short GTINs being padded
|
||||||
assert self.service.normalize_gtin("123") == "0000000000123" # Padded to EAN-13
|
assert self.service.normalize_gtin("123") == "0000000000123" # Padded to EAN-13
|
||||||
assert self.service.normalize_gtin("12345") == "0000000012345" # Padded to EAN-13
|
assert (
|
||||||
|
self.service.normalize_gtin("12345") == "0000000012345"
|
||||||
|
) # Padded to EAN-13
|
||||||
|
|
||||||
# Test long GTINs being truncated
|
# Test long GTINs being truncated
|
||||||
assert self.service.normalize_gtin("123456789012345") == "3456789012345" # Truncated to 13
|
assert (
|
||||||
|
self.service.normalize_gtin("123456789012345") == "3456789012345"
|
||||||
|
) # Truncated to 13
|
||||||
|
|
||||||
def test_normalize_gtin_edge_cases(self):
|
def test_normalize_gtin_edge_cases(self):
|
||||||
"""Test GTIN normalization edge cases"""
|
"""Test GTIN normalization edge cases"""
|
||||||
@@ -61,17 +69,21 @@ class TestStockService:
|
|||||||
assert self.service.normalize_gtin(123) == "0000000000123"
|
assert self.service.normalize_gtin(123) == "0000000000123"
|
||||||
|
|
||||||
# Test mixed valid/invalid characters
|
# Test mixed valid/invalid characters
|
||||||
assert self.service.normalize_gtin("123-456-789-012") == "123456789012" # Dashes removed
|
assert (
|
||||||
assert self.service.normalize_gtin("123 456 789 012") == "123456789012" # Spaces removed
|
self.service.normalize_gtin("123-456-789-012") == "123456789012"
|
||||||
assert self.service.normalize_gtin("ABC123456789012DEF") == "123456789012" # Letters removed
|
) # Dashes removed
|
||||||
|
assert (
|
||||||
|
self.service.normalize_gtin("123 456 789 012") == "123456789012"
|
||||||
|
) # Spaces removed
|
||||||
|
assert (
|
||||||
|
self.service.normalize_gtin("ABC123456789012DEF") == "123456789012"
|
||||||
|
) # Letters removed
|
||||||
|
|
||||||
def test_set_stock_new_entry(self, db):
|
def test_set_stock_new_entry(self, db):
|
||||||
"""Test setting stock for a new GTIN/location combination"""
|
"""Test setting stock for a new GTIN/location combination"""
|
||||||
unique_id = str(uuid.uuid4())[:8]
|
unique_id = str(uuid.uuid4())[:8]
|
||||||
stock_data = StockCreate(
|
stock_data = StockCreate(
|
||||||
gtin="1234567890123",
|
gtin="1234567890123", location=f"WAREHOUSE_A_{unique_id}", quantity=100
|
||||||
location=f"WAREHOUSE_A_{unique_id}",
|
|
||||||
quantity=100
|
|
||||||
)
|
)
|
||||||
|
|
||||||
result = self.service.set_stock(db, stock_data)
|
result = self.service.set_stock(db, stock_data)
|
||||||
@@ -85,7 +97,7 @@ class TestStockService:
|
|||||||
stock_data = StockCreate(
|
stock_data = StockCreate(
|
||||||
gtin=test_stock.gtin,
|
gtin=test_stock.gtin,
|
||||||
location=test_stock.location, # Use exact same location as test_stock
|
location=test_stock.location, # Use exact same location as test_stock
|
||||||
quantity=200
|
quantity=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = self.service.set_stock(db, stock_data)
|
result = self.service.set_stock(db, stock_data)
|
||||||
@@ -98,9 +110,7 @@ class TestStockService:
|
|||||||
def test_set_stock_invalid_gtin(self, db):
|
def test_set_stock_invalid_gtin(self, db):
|
||||||
"""Test setting stock with invalid GTIN"""
|
"""Test setting stock with invalid GTIN"""
|
||||||
stock_data = StockCreate(
|
stock_data = StockCreate(
|
||||||
gtin="invalid_gtin",
|
gtin="invalid_gtin", location="WAREHOUSE_A", quantity=100
|
||||||
location="WAREHOUSE_A",
|
|
||||||
quantity=100
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Invalid GTIN format"):
|
with pytest.raises(ValueError, match="Invalid GTIN format"):
|
||||||
@@ -110,9 +120,7 @@ class TestStockService:
|
|||||||
"""Test adding stock for a new GTIN/location combination"""
|
"""Test adding stock for a new GTIN/location combination"""
|
||||||
unique_id = str(uuid.uuid4())[:8]
|
unique_id = str(uuid.uuid4())[:8]
|
||||||
stock_data = StockAdd(
|
stock_data = StockAdd(
|
||||||
gtin="1234567890123",
|
gtin="1234567890123", location=f"WAREHOUSE_B_{unique_id}", quantity=50
|
||||||
location=f"WAREHOUSE_B_{unique_id}",
|
|
||||||
quantity=50
|
|
||||||
)
|
)
|
||||||
|
|
||||||
result = self.service.add_stock(db, stock_data)
|
result = self.service.add_stock(db, stock_data)
|
||||||
@@ -127,7 +135,7 @@ class TestStockService:
|
|||||||
stock_data = StockAdd(
|
stock_data = StockAdd(
|
||||||
gtin=test_stock.gtin,
|
gtin=test_stock.gtin,
|
||||||
location=test_stock.location, # Use exact same location as test_stock
|
location=test_stock.location, # Use exact same location as test_stock
|
||||||
quantity=25
|
quantity=25,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = self.service.add_stock(db, stock_data)
|
result = self.service.add_stock(db, stock_data)
|
||||||
@@ -138,11 +146,7 @@ class TestStockService:
|
|||||||
|
|
||||||
def test_add_stock_invalid_gtin(self, db):
|
def test_add_stock_invalid_gtin(self, db):
|
||||||
"""Test adding stock with invalid GTIN"""
|
"""Test adding stock with invalid GTIN"""
|
||||||
stock_data = StockAdd(
|
stock_data = StockAdd(gtin="invalid_gtin", location="WAREHOUSE_A", quantity=50)
|
||||||
gtin="invalid_gtin",
|
|
||||||
location="WAREHOUSE_A",
|
|
||||||
quantity=50
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Invalid GTIN format"):
|
with pytest.raises(ValueError, match="Invalid GTIN format"):
|
||||||
self.service.add_stock(db, stock_data)
|
self.service.add_stock(db, stock_data)
|
||||||
@@ -150,12 +154,14 @@ class TestStockService:
|
|||||||
def test_remove_stock_success(self, db, test_stock):
|
def test_remove_stock_success(self, db, test_stock):
|
||||||
"""Test removing stock successfully"""
|
"""Test removing stock successfully"""
|
||||||
original_quantity = test_stock.quantity
|
original_quantity = test_stock.quantity
|
||||||
remove_quantity = min(10, original_quantity) # Ensure we don't remove more than available
|
remove_quantity = min(
|
||||||
|
10, original_quantity
|
||||||
|
) # Ensure we don't remove more than available
|
||||||
|
|
||||||
stock_data = StockAdd(
|
stock_data = StockAdd(
|
||||||
gtin=test_stock.gtin,
|
gtin=test_stock.gtin,
|
||||||
location=test_stock.location, # Use exact same location as test_stock
|
location=test_stock.location, # Use exact same location as test_stock
|
||||||
quantity=remove_quantity
|
quantity=remove_quantity,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = self.service.remove_stock(db, stock_data)
|
result = self.service.remove_stock(db, stock_data)
|
||||||
@@ -169,20 +175,20 @@ class TestStockService:
|
|||||||
stock_data = StockAdd(
|
stock_data = StockAdd(
|
||||||
gtin=test_stock.gtin,
|
gtin=test_stock.gtin,
|
||||||
location=test_stock.location, # Use exact same location as test_stock
|
location=test_stock.location, # Use exact same location as test_stock
|
||||||
quantity=test_stock.quantity + 10 # More than available
|
quantity=test_stock.quantity + 10, # More than available
|
||||||
)
|
)
|
||||||
|
|
||||||
# Fix: Use more flexible regex pattern
|
# Fix: Use more flexible regex pattern
|
||||||
with pytest.raises(ValueError, match="Insufficient stock|Not enough stock|Cannot remove"):
|
with pytest.raises(
|
||||||
|
ValueError, match="Insufficient stock|Not enough stock|Cannot remove"
|
||||||
|
):
|
||||||
self.service.remove_stock(db, stock_data)
|
self.service.remove_stock(db, stock_data)
|
||||||
|
|
||||||
def test_remove_stock_nonexistent_entry(self, db):
|
def test_remove_stock_nonexistent_entry(self, db):
|
||||||
"""Test removing stock from non-existent GTIN/location"""
|
"""Test removing stock from non-existent GTIN/location"""
|
||||||
unique_id = str(uuid.uuid4())[:8]
|
unique_id = str(uuid.uuid4())[:8]
|
||||||
stock_data = StockAdd(
|
stock_data = StockAdd(
|
||||||
gtin="9999999999999",
|
gtin="9999999999999", location=f"NONEXISTENT_{unique_id}", quantity=10
|
||||||
location=f"NONEXISTENT_{unique_id}",
|
|
||||||
quantity=10
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="No stock found|Stock not found"):
|
with pytest.raises(ValueError, match="No stock found|Stock not found"):
|
||||||
@@ -190,11 +196,7 @@ class TestStockService:
|
|||||||
|
|
||||||
def test_remove_stock_invalid_gtin(self, db):
|
def test_remove_stock_invalid_gtin(self, db):
|
||||||
"""Test removing stock with invalid GTIN"""
|
"""Test removing stock with invalid GTIN"""
|
||||||
stock_data = StockAdd(
|
stock_data = StockAdd(gtin="invalid_gtin", location="WAREHOUSE_A", quantity=10)
|
||||||
gtin="invalid_gtin",
|
|
||||||
location="WAREHOUSE_A",
|
|
||||||
quantity=10
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Invalid GTIN format"):
|
with pytest.raises(ValueError, match="Invalid GTIN format"):
|
||||||
self.service.remove_stock(db, stock_data)
|
self.service.remove_stock(db, stock_data)
|
||||||
@@ -218,14 +220,10 @@ class TestStockService:
|
|||||||
|
|
||||||
# Create multiple stock entries for the same GTIN with unique locations
|
# Create multiple stock entries for the same GTIN with unique locations
|
||||||
stock1 = Stock(
|
stock1 = Stock(
|
||||||
gtin=unique_gtin,
|
gtin=unique_gtin, location=f"WAREHOUSE_A_{unique_id}", quantity=50
|
||||||
location=f"WAREHOUSE_A_{unique_id}",
|
|
||||||
quantity=50
|
|
||||||
)
|
)
|
||||||
stock2 = Stock(
|
stock2 = Stock(
|
||||||
gtin=unique_gtin,
|
gtin=unique_gtin, location=f"WAREHOUSE_B_{unique_id}", quantity=30
|
||||||
location=f"WAREHOUSE_B_{unique_id}",
|
|
||||||
quantity=30
|
|
||||||
)
|
)
|
||||||
|
|
||||||
db.add(stock1)
|
db.add(stock1)
|
||||||
@@ -275,7 +273,9 @@ class TestStockService:
|
|||||||
|
|
||||||
assert len(result) >= 1
|
assert len(result) >= 1
|
||||||
# Fix: Handle case sensitivity in comparison
|
# Fix: Handle case sensitivity in comparison
|
||||||
assert all(stock.location.upper() == test_stock.location.upper() for stock in result)
|
assert all(
|
||||||
|
stock.location.upper() == test_stock.location.upper() for stock in result
|
||||||
|
)
|
||||||
|
|
||||||
def test_get_all_stock_with_gtin_filter(self, db, test_stock):
|
def test_get_all_stock_with_gtin_filter(self, db, test_stock):
|
||||||
"""Test getting all stock with GTIN filter"""
|
"""Test getting all stock with GTIN filter"""
|
||||||
@@ -293,14 +293,16 @@ class TestStockService:
|
|||||||
stock = Stock(
|
stock = Stock(
|
||||||
gtin=f"1234567890{i:03d}", # Creates valid 13-digit GTINs: 1234567890000, 1234567890001, etc.
|
gtin=f"1234567890{i:03d}", # Creates valid 13-digit GTINs: 1234567890000, 1234567890001, etc.
|
||||||
location=f"WAREHOUSE_{unique_prefix}_{i}",
|
location=f"WAREHOUSE_{unique_prefix}_{i}",
|
||||||
quantity=10
|
quantity=10,
|
||||||
)
|
)
|
||||||
db.add(stock)
|
db.add(stock)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
result = self.service.get_all_stock(db, skip=2, limit=2)
|
result = self.service.get_all_stock(db, skip=2, limit=2)
|
||||||
|
|
||||||
assert len(result) <= 2 # Should be at most 2, might be less if other records exist
|
assert (
|
||||||
|
len(result) <= 2
|
||||||
|
) # Should be at most 2, might be less if other records exist
|
||||||
|
|
||||||
def test_update_stock_success(self, db, test_stock):
|
def test_update_stock_success(self, db, test_stock):
|
||||||
"""Test updating stock quantity"""
|
"""Test updating stock quantity"""
|
||||||
@@ -359,7 +361,7 @@ def test_product_with_stock(db, test_stock):
|
|||||||
gtin=test_stock.gtin,
|
gtin=test_stock.gtin,
|
||||||
price="29.99",
|
price="29.99",
|
||||||
brand="TestBrand",
|
brand="TestBrand",
|
||||||
marketplace="Letzshop"
|
marketplace="Letzshop",
|
||||||
)
|
)
|
||||||
db.add(product)
|
db.add(product)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
# tests/test_utils.py (Enhanced version of your existing file)
|
# tests/test_utils.py (Enhanced version of your existing file)
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from utils.data_processing import GTINProcessor, PriceProcessor
|
from utils.data_processing import GTINProcessor, PriceProcessor
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,15 @@
|
|||||||
# utils/csv_processor.py
|
# utils/csv_processor.py
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from io import StringIO
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import requests
|
import requests
|
||||||
from io import StringIO
|
|
||||||
from typing import Dict, Any
|
|
||||||
|
|
||||||
from sqlalchemy import literal
|
from sqlalchemy import literal
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from models.database_models import Product
|
from models.database_models import Product
|
||||||
from datetime import datetime
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -16,67 +17,66 @@ logger = logging.getLogger(__name__)
|
|||||||
class CSVProcessor:
|
class CSVProcessor:
|
||||||
"""Handles CSV import with robust parsing and batching"""
|
"""Handles CSV import with robust parsing and batching"""
|
||||||
|
|
||||||
ENCODINGS = ['utf-8', 'latin-1', 'iso-8859-1', 'cp1252', 'utf-8-sig']
|
ENCODINGS = ["utf-8", "latin-1", "iso-8859-1", "cp1252", "utf-8-sig"]
|
||||||
|
|
||||||
PARSING_CONFIGS = [
|
PARSING_CONFIGS = [
|
||||||
# Try auto-detection first
|
# Try auto-detection first
|
||||||
{'sep': None, 'engine': 'python'},
|
{"sep": None, "engine": "python"},
|
||||||
# Try semicolon (common in European CSVs)
|
# Try semicolon (common in European CSVs)
|
||||||
{'sep': ';', 'engine': 'python'},
|
{"sep": ";", "engine": "python"},
|
||||||
# Try comma
|
# Try comma
|
||||||
{'sep': ',', 'engine': 'python'},
|
{"sep": ",", "engine": "python"},
|
||||||
# Try tab
|
# Try tab
|
||||||
{'sep': '\t', 'engine': 'python'},
|
{"sep": "\t", "engine": "python"},
|
||||||
]
|
]
|
||||||
|
|
||||||
COLUMN_MAPPING = {
|
COLUMN_MAPPING = {
|
||||||
# Standard variations
|
# Standard variations
|
||||||
'id': 'product_id',
|
"id": "product_id",
|
||||||
'ID': 'product_id',
|
"ID": "product_id",
|
||||||
'Product ID': 'product_id',
|
"Product ID": "product_id",
|
||||||
'name': 'title',
|
"name": "title",
|
||||||
'Name': 'title',
|
"Name": "title",
|
||||||
'product_name': 'title',
|
"product_name": "title",
|
||||||
'Product Name': 'title',
|
"Product Name": "title",
|
||||||
|
|
||||||
# Google Shopping feed standard
|
# Google Shopping feed standard
|
||||||
'g:id': 'product_id',
|
"g:id": "product_id",
|
||||||
'g:title': 'title',
|
"g:title": "title",
|
||||||
'g:description': 'description',
|
"g:description": "description",
|
||||||
'g:link': 'link',
|
"g:link": "link",
|
||||||
'g:image_link': 'image_link',
|
"g:image_link": "image_link",
|
||||||
'g:availability': 'availability',
|
"g:availability": "availability",
|
||||||
'g:price': 'price',
|
"g:price": "price",
|
||||||
'g:brand': 'brand',
|
"g:brand": "brand",
|
||||||
'g:gtin': 'gtin',
|
"g:gtin": "gtin",
|
||||||
'g:mpn': 'mpn',
|
"g:mpn": "mpn",
|
||||||
'g:condition': 'condition',
|
"g:condition": "condition",
|
||||||
'g:adult': 'adult',
|
"g:adult": "adult",
|
||||||
'g:multipack': 'multipack',
|
"g:multipack": "multipack",
|
||||||
'g:is_bundle': 'is_bundle',
|
"g:is_bundle": "is_bundle",
|
||||||
'g:age_group': 'age_group',
|
"g:age_group": "age_group",
|
||||||
'g:color': 'color',
|
"g:color": "color",
|
||||||
'g:gender': 'gender',
|
"g:gender": "gender",
|
||||||
'g:material': 'material',
|
"g:material": "material",
|
||||||
'g:pattern': 'pattern',
|
"g:pattern": "pattern",
|
||||||
'g:size': 'size',
|
"g:size": "size",
|
||||||
'g:size_type': 'size_type',
|
"g:size_type": "size_type",
|
||||||
'g:size_system': 'size_system',
|
"g:size_system": "size_system",
|
||||||
'g:item_group_id': 'item_group_id',
|
"g:item_group_id": "item_group_id",
|
||||||
'g:google_product_category': 'google_product_category',
|
"g:google_product_category": "google_product_category",
|
||||||
'g:product_type': 'product_type',
|
"g:product_type": "product_type",
|
||||||
'g:custom_label_0': 'custom_label_0',
|
"g:custom_label_0": "custom_label_0",
|
||||||
'g:custom_label_1': 'custom_label_1',
|
"g:custom_label_1": "custom_label_1",
|
||||||
'g:custom_label_2': 'custom_label_2',
|
"g:custom_label_2": "custom_label_2",
|
||||||
'g:custom_label_3': 'custom_label_3',
|
"g:custom_label_3": "custom_label_3",
|
||||||
'g:custom_label_4': 'custom_label_4',
|
"g:custom_label_4": "custom_label_4",
|
||||||
|
|
||||||
# Handle complex shipping column
|
# Handle complex shipping column
|
||||||
'shipping(country:price:max_handling_time:min_transit_time:max_transit_time)': 'shipping'
|
"shipping(country:price:max_handling_time:min_transit_time:max_transit_time)": "shipping",
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
from utils.data_processing import GTINProcessor, PriceProcessor
|
from utils.data_processing import GTINProcessor, PriceProcessor
|
||||||
|
|
||||||
self.gtin_processor = GTINProcessor()
|
self.gtin_processor = GTINProcessor()
|
||||||
self.price_processor = PriceProcessor()
|
self.price_processor = PriceProcessor()
|
||||||
|
|
||||||
@@ -98,7 +98,7 @@ class CSVProcessor:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# Fallback with error ignoring
|
# Fallback with error ignoring
|
||||||
decoded_content = content.decode('utf-8', errors='ignore')
|
decoded_content = content.decode("utf-8", errors="ignore")
|
||||||
logger.warning("Used UTF-8 with error ignoring for CSV decoding")
|
logger.warning("Used UTF-8 with error ignoring for CSV decoding")
|
||||||
return decoded_content
|
return decoded_content
|
||||||
|
|
||||||
@@ -113,11 +113,11 @@ class CSVProcessor:
|
|||||||
try:
|
try:
|
||||||
df = pd.read_csv(
|
df = pd.read_csv(
|
||||||
StringIO(csv_content),
|
StringIO(csv_content),
|
||||||
on_bad_lines='skip',
|
on_bad_lines="skip",
|
||||||
quotechar='"',
|
quotechar='"',
|
||||||
skip_blank_lines=True,
|
skip_blank_lines=True,
|
||||||
skipinitialspace=True,
|
skipinitialspace=True,
|
||||||
**config
|
**config,
|
||||||
)
|
)
|
||||||
logger.info(f"Successfully parsed CSV with config: {config}")
|
logger.info(f"Successfully parsed CSV with config: {config}")
|
||||||
return df
|
return df
|
||||||
@@ -143,42 +143,43 @@ class CSVProcessor:
|
|||||||
processed_data = {k: (v if pd.notna(v) else None) for k, v in row_data.items()}
|
processed_data = {k: (v if pd.notna(v) else None) for k, v in row_data.items()}
|
||||||
|
|
||||||
# Process GTIN
|
# Process GTIN
|
||||||
if processed_data.get('gtin'):
|
if processed_data.get("gtin"):
|
||||||
processed_data['gtin'] = self.gtin_processor.normalize(processed_data['gtin'])
|
processed_data["gtin"] = self.gtin_processor.normalize(
|
||||||
|
processed_data["gtin"]
|
||||||
|
)
|
||||||
|
|
||||||
# Process price and currency
|
# Process price and currency
|
||||||
if processed_data.get('price'):
|
if processed_data.get("price"):
|
||||||
parsed_price, currency = self.price_processor.parse_price_currency(processed_data['price'])
|
parsed_price, currency = self.price_processor.parse_price_currency(
|
||||||
processed_data['price'] = parsed_price
|
processed_data["price"]
|
||||||
processed_data['currency'] = currency
|
)
|
||||||
|
processed_data["price"] = parsed_price
|
||||||
|
processed_data["currency"] = currency
|
||||||
|
|
||||||
# Process sale_price
|
# Process sale_price
|
||||||
if processed_data.get('sale_price'):
|
if processed_data.get("sale_price"):
|
||||||
parsed_sale_price, _ = self.price_processor.parse_price_currency(processed_data['sale_price'])
|
parsed_sale_price, _ = self.price_processor.parse_price_currency(
|
||||||
processed_data['sale_price'] = parsed_sale_price
|
processed_data["sale_price"]
|
||||||
|
)
|
||||||
|
processed_data["sale_price"] = parsed_sale_price
|
||||||
|
|
||||||
# Clean MPN (remove .0 endings)
|
# Clean MPN (remove .0 endings)
|
||||||
if processed_data.get('mpn'):
|
if processed_data.get("mpn"):
|
||||||
mpn_str = str(processed_data['mpn']).strip()
|
mpn_str = str(processed_data["mpn"]).strip()
|
||||||
if mpn_str.endswith('.0'):
|
if mpn_str.endswith(".0"):
|
||||||
processed_data['mpn'] = mpn_str[:-2]
|
processed_data["mpn"] = mpn_str[:-2]
|
||||||
|
|
||||||
# Handle multipack type conversion
|
# Handle multipack type conversion
|
||||||
if processed_data.get('multipack') is not None:
|
if processed_data.get("multipack") is not None:
|
||||||
try:
|
try:
|
||||||
processed_data['multipack'] = int(float(processed_data['multipack']))
|
processed_data["multipack"] = int(float(processed_data["multipack"]))
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
processed_data['multipack'] = None
|
processed_data["multipack"] = None
|
||||||
|
|
||||||
return processed_data
|
return processed_data
|
||||||
|
|
||||||
async def process_marketplace_csv_from_url(
|
async def process_marketplace_csv_from_url(
|
||||||
self,
|
self, url: str, marketplace: str, shop_name: str, batch_size: int, db: Session
|
||||||
url: str,
|
|
||||||
marketplace: str,
|
|
||||||
shop_name: str,
|
|
||||||
batch_size: int,
|
|
||||||
db: Session
|
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Process CSV from URL with marketplace and shop information
|
Process CSV from URL with marketplace and shop information
|
||||||
@@ -194,7 +195,9 @@ class CSVProcessor:
|
|||||||
Dictionary with processing results
|
Dictionary with processing results
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logger.info(f"Starting marketplace CSV import from {url} for {marketplace} -> {shop_name}")
|
logger.info(
|
||||||
|
f"Starting marketplace CSV import from {url} for {marketplace} -> {shop_name}"
|
||||||
|
)
|
||||||
# Download and parse CSV
|
# Download and parse CSV
|
||||||
csv_content = self.download_csv(url)
|
csv_content = self.download_csv(url)
|
||||||
df = self.parse_csv(csv_content)
|
df = self.parse_csv(csv_content)
|
||||||
@@ -208,40 +211,42 @@ class CSVProcessor:
|
|||||||
|
|
||||||
# Process in batches
|
# Process in batches
|
||||||
for i in range(0, len(df), batch_size):
|
for i in range(0, len(df), batch_size):
|
||||||
batch_df = df.iloc[i:i + batch_size]
|
batch_df = df.iloc[i : i + batch_size]
|
||||||
batch_result = await self._process_marketplace_batch(
|
batch_result = await self._process_marketplace_batch(
|
||||||
batch_df, marketplace, shop_name, db, i // batch_size + 1
|
batch_df, marketplace, shop_name, db, i // batch_size + 1
|
||||||
)
|
)
|
||||||
|
|
||||||
imported += batch_result['imported']
|
imported += batch_result["imported"]
|
||||||
updated += batch_result['updated']
|
updated += batch_result["updated"]
|
||||||
errors += batch_result['errors']
|
errors += batch_result["errors"]
|
||||||
|
|
||||||
logger.info(f"Processed batch {i // batch_size + 1}: {batch_result}")
|
logger.info(f"Processed batch {i // batch_size + 1}: {batch_result}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'total_processed': imported + updated + errors,
|
"total_processed": imported + updated + errors,
|
||||||
'imported': imported,
|
"imported": imported,
|
||||||
'updated': updated,
|
"updated": updated,
|
||||||
'errors': errors,
|
"errors": errors,
|
||||||
'marketplace': marketplace,
|
"marketplace": marketplace,
|
||||||
'shop_name': shop_name
|
"shop_name": shop_name,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _process_marketplace_batch(
|
async def _process_marketplace_batch(
|
||||||
self,
|
self,
|
||||||
batch_df: pd.DataFrame,
|
batch_df: pd.DataFrame,
|
||||||
marketplace: str,
|
marketplace: str,
|
||||||
shop_name: str,
|
shop_name: str,
|
||||||
db: Session,
|
db: Session,
|
||||||
batch_num: int
|
batch_num: int,
|
||||||
) -> Dict[str, int]:
|
) -> Dict[str, int]:
|
||||||
"""Process a batch of CSV rows with marketplace information"""
|
"""Process a batch of CSV rows with marketplace information"""
|
||||||
imported = 0
|
imported = 0
|
||||||
updated = 0
|
updated = 0
|
||||||
errors = 0
|
errors = 0
|
||||||
|
|
||||||
logger.info(f"Processing batch {batch_num} with {len(batch_df)} rows for {marketplace} -> {shop_name}")
|
logger.info(
|
||||||
|
f"Processing batch {batch_num} with {len(batch_df)} rows for {marketplace} -> {shop_name}"
|
||||||
|
)
|
||||||
|
|
||||||
for index, row in batch_df.iterrows():
|
for index, row in batch_df.iterrows():
|
||||||
try:
|
try:
|
||||||
@@ -249,42 +254,54 @@ class CSVProcessor:
|
|||||||
product_data = self._clean_row_data(row.to_dict())
|
product_data = self._clean_row_data(row.to_dict())
|
||||||
|
|
||||||
# Add marketplace and shop information
|
# Add marketplace and shop information
|
||||||
product_data['marketplace'] = marketplace
|
product_data["marketplace"] = marketplace
|
||||||
product_data['shop_name'] = shop_name
|
product_data["shop_name"] = shop_name
|
||||||
|
|
||||||
# Validate required fields
|
# Validate required fields
|
||||||
if not product_data.get('product_id'):
|
if not product_data.get("product_id"):
|
||||||
logger.warning(f"Row {index}: Missing product_id, skipping")
|
logger.warning(f"Row {index}: Missing product_id, skipping")
|
||||||
errors += 1
|
errors += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not product_data.get('title'):
|
if not product_data.get("title"):
|
||||||
logger.warning(f"Row {index}: Missing title, skipping")
|
logger.warning(f"Row {index}: Missing title, skipping")
|
||||||
errors += 1
|
errors += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Check if product exists
|
# Check if product exists
|
||||||
existing_product = db.query(Product).filter(
|
existing_product = (
|
||||||
Product.product_id == literal(product_data['product_id'])
|
db.query(Product)
|
||||||
).first()
|
.filter(Product.product_id == literal(product_data["product_id"]))
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
if existing_product:
|
if existing_product:
|
||||||
# Update existing product
|
# Update existing product
|
||||||
for key, value in product_data.items():
|
for key, value in product_data.items():
|
||||||
if key not in ['id', 'created_at'] and hasattr(existing_product, key):
|
if key not in ["id", "created_at"] and hasattr(
|
||||||
|
existing_product, key
|
||||||
|
):
|
||||||
setattr(existing_product, key, value)
|
setattr(existing_product, key, value)
|
||||||
existing_product.updated_at = datetime.utcnow()
|
existing_product.updated_at = datetime.utcnow()
|
||||||
updated += 1
|
updated += 1
|
||||||
logger.debug(f"Updated product {product_data['product_id']} for {marketplace} and shop {shop_name}")
|
logger.debug(
|
||||||
|
f"Updated product {product_data['product_id']} for {marketplace} and shop {shop_name}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Create new product
|
# Create new product
|
||||||
filtered_data = {k: v for k, v in product_data.items()
|
filtered_data = {
|
||||||
if k not in ['id', 'created_at', 'updated_at'] and hasattr(Product, k)}
|
k: v
|
||||||
|
for k, v in product_data.items()
|
||||||
|
if k not in ["id", "created_at", "updated_at"]
|
||||||
|
and hasattr(Product, k)
|
||||||
|
}
|
||||||
new_product = Product(**filtered_data)
|
new_product = Product(**filtered_data)
|
||||||
db.add(new_product)
|
db.add(new_product)
|
||||||
imported += 1
|
imported += 1
|
||||||
logger.debug(f"Imported new product {product_data['product_id']} for {marketplace} and shop "
|
logger.debug(
|
||||||
f"{shop_name}")
|
f"Imported new product {product_data['product_id']} for {marketplace} and shop "
|
||||||
|
f"{shop_name}"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing row: {e}")
|
logger.error(f"Error processing row: {e}")
|
||||||
@@ -303,8 +320,4 @@ class CSVProcessor:
|
|||||||
imported = 0
|
imported = 0
|
||||||
updated = 0
|
updated = 0
|
||||||
|
|
||||||
return {
|
return {"imported": imported, "updated": updated, "errors": errors}
|
||||||
'imported': imported,
|
|
||||||
'updated': updated,
|
|
||||||
'errors': errors
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
# utils/data_processing.py
|
# utils/data_processing.py
|
||||||
import re
|
|
||||||
import pandas as pd
|
|
||||||
from typing import Tuple, Optional
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -25,11 +26,11 @@ class GTINProcessor:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# Remove decimal point (e.g., "889698116923.0" -> "889698116923")
|
# Remove decimal point (e.g., "889698116923.0" -> "889698116923")
|
||||||
if '.' in gtin_str:
|
if "." in gtin_str:
|
||||||
gtin_str = gtin_str.split('.')[0]
|
gtin_str = gtin_str.split(".")[0]
|
||||||
|
|
||||||
# Keep only digits
|
# Keep only digits
|
||||||
gtin_clean = ''.join(filter(str.isdigit, gtin_str))
|
gtin_clean = "".join(filter(str.isdigit, gtin_str))
|
||||||
|
|
||||||
if not gtin_clean:
|
if not gtin_clean:
|
||||||
return None
|
return None
|
||||||
@@ -73,23 +74,23 @@ class PriceProcessor:
|
|||||||
|
|
||||||
CURRENCY_PATTERNS = {
|
CURRENCY_PATTERNS = {
|
||||||
# Amount followed by currency
|
# Amount followed by currency
|
||||||
r'([0-9.,]+)\s*(EUR|€)': lambda m: (m.group(1), 'EUR'),
|
r"([0-9.,]+)\s*(EUR|€)": lambda m: (m.group(1), "EUR"),
|
||||||
r'([0-9.,]+)\s*(USD|\$)': lambda m: (m.group(1), 'USD'),
|
r"([0-9.,]+)\s*(USD|\$)": lambda m: (m.group(1), "USD"),
|
||||||
r'([0-9.,]+)\s*(GBP|£)': lambda m: (m.group(1), 'GBP'),
|
r"([0-9.,]+)\s*(GBP|£)": lambda m: (m.group(1), "GBP"),
|
||||||
r'([0-9.,]+)\s*(CHF)': lambda m: (m.group(1), 'CHF'),
|
r"([0-9.,]+)\s*(CHF)": lambda m: (m.group(1), "CHF"),
|
||||||
r'([0-9.,]+)\s*(CAD|AUD|JPY|¥)': lambda m: (m.group(1), m.group(2).upper()),
|
r"([0-9.,]+)\s*(CAD|AUD|JPY|¥)": lambda m: (m.group(1), m.group(2).upper()),
|
||||||
|
|
||||||
# Currency followed by amount
|
# Currency followed by amount
|
||||||
r'(EUR|€)\s*([0-9.,]+)': lambda m: (m.group(2), 'EUR'),
|
r"(EUR|€)\s*([0-9.,]+)": lambda m: (m.group(2), "EUR"),
|
||||||
r'(USD|\$)\s*([0-9.,]+)': lambda m: (m.group(2), 'USD'),
|
r"(USD|\$)\s*([0-9.,]+)": lambda m: (m.group(2), "USD"),
|
||||||
r'(GBP|£)\s*([0-9.,]+)': lambda m: (m.group(2), 'GBP'),
|
r"(GBP|£)\s*([0-9.,]+)": lambda m: (m.group(2), "GBP"),
|
||||||
|
|
||||||
# Generic 3-letter currency codes
|
# Generic 3-letter currency codes
|
||||||
r'([0-9.,]+)\s*([A-Z]{3})': lambda m: (m.group(1), m.group(2)),
|
r"([0-9.,]+)\s*([A-Z]{3})": lambda m: (m.group(1), m.group(2)),
|
||||||
r'([A-Z]{3})\s*([0-9.,]+)': lambda m: (m.group(2), m.group(1)),
|
r"([A-Z]{3})\s*([0-9.,]+)": lambda m: (m.group(2), m.group(1)),
|
||||||
}
|
}
|
||||||
|
|
||||||
def parse_price_currency(self, price_str: any) -> Tuple[Optional[str], Optional[str]]:
|
def parse_price_currency(
|
||||||
|
self, price_str: any
|
||||||
|
) -> Tuple[Optional[str], Optional[str]]:
|
||||||
"""
|
"""
|
||||||
Parse price string into (price, currency) tuple
|
Parse price string into (price, currency) tuple
|
||||||
Returns (None, None) if parsing fails
|
Returns (None, None) if parsing fails
|
||||||
@@ -108,7 +109,7 @@ class PriceProcessor:
|
|||||||
try:
|
try:
|
||||||
price_val, currency_val = extract_func(match)
|
price_val, currency_val = extract_func(match)
|
||||||
# Normalize price (remove spaces, handle comma as decimal)
|
# Normalize price (remove spaces, handle comma as decimal)
|
||||||
price_val = price_val.replace(' ', '').replace(',', '.')
|
price_val = price_val.replace(" ", "").replace(",", ".")
|
||||||
# Validate numeric
|
# Validate numeric
|
||||||
float(price_val)
|
float(price_val)
|
||||||
return price_val, currency_val.upper()
|
return price_val, currency_val.upper()
|
||||||
@@ -116,10 +117,10 @@ class PriceProcessor:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# Fallback: extract just numbers
|
# Fallback: extract just numbers
|
||||||
number_match = re.search(r'([0-9.,]+)', price_str)
|
number_match = re.search(r"([0-9.,]+)", price_str)
|
||||||
if number_match:
|
if number_match:
|
||||||
try:
|
try:
|
||||||
price_val = number_match.group(1).replace(',', '.')
|
price_val = number_match.group(1).replace(",", ".")
|
||||||
float(price_val) # Validate
|
float(price_val) # Validate
|
||||||
return price_val, None
|
return price_val, None
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
|||||||
@@ -1,20 +1,19 @@
|
|||||||
# utils/database.py
|
# utils/database.py
|
||||||
|
import logging
|
||||||
|
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from sqlalchemy.pool import QueuePool
|
from sqlalchemy.pool import QueuePool
|
||||||
import logging
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_db_engine(database_url: str):
|
def get_db_engine(database_url: str):
|
||||||
"""Create database engine with connection pooling"""
|
"""Create database engine with connection pooling"""
|
||||||
if database_url.startswith('sqlite'):
|
if database_url.startswith("sqlite"):
|
||||||
# SQLite configuration
|
# SQLite configuration
|
||||||
engine = create_engine(
|
engine = create_engine(
|
||||||
database_url,
|
database_url, connect_args={"check_same_thread": False}, echo=False
|
||||||
connect_args={"check_same_thread": False},
|
|
||||||
echo=False
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# PostgreSQL configuration with connection pooling
|
# PostgreSQL configuration with connection pooling
|
||||||
@@ -24,7 +23,7 @@ def get_db_engine(database_url: str):
|
|||||||
pool_size=10,
|
pool_size=10,
|
||||||
max_overflow=20,
|
max_overflow=20,
|
||||||
pool_pre_ping=True,
|
pool_pre_ping=True,
|
||||||
echo=False
|
echo=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Database engine created for: {database_url.split('@')[0]}@...")
|
logger.info(f"Database engine created for: {database_url.split('@')[0]}@...")
|
||||||
|
|||||||
Reference in New Issue
Block a user