Refactoring code for modular approach

This commit is contained in:
2025-09-11 20:59:40 +02:00
parent fca389cff4
commit 900229d452
17 changed files with 850 additions and 125 deletions

7
.env
View File

@@ -1,8 +1,9 @@
# .env.example # .env.example
PROJECT_NAME: str = "Ecommerce Backend API with Marketplace Support" # Project information
DESCRIPTION: str = "Advanced product management system with JWT authentication" PROJECT_NAME=Ecommerce Backend API with Marketplace Support
VERSION: str = "0.0.1" DESCRIPTION=Advanced product management system with JWT authentication
VERSION=0.0.1
# Database Configuration # Database Configuration
# DATABASE_URL=postgresql://username:password@localhost:5432/ecommerce_db # DATABASE_URL=postgresql://username:password@localhost:5432/ecommerce_db

View File

@@ -1,13 +1,13 @@
from fastapi import APIRouter from fastapi import APIRouter
from app.api.v1 import auth, products, stock, shops, marketplace, admin, stats from app.api.v1 import auth, product, stock, shop, marketplace, admin, stats
api_router = APIRouter() api_router = APIRouter()
# Include all route modules # Include all route modules
api_router.include_router(auth.router, prefix="/auth", tags=["authentication"]) api_router.include_router(auth.router, prefix="/auth", tags=["authentication"])
api_router.include_router(products.router, prefix="/products", tags=["products"]) api_router.include_router(product.router, prefix="/product", tags=["product"])
api_router.include_router(stock.router, prefix="/stock", tags=["stock"]) api_router.include_router(stock.router, prefix="/stock", tags=["stock"])
api_router.include_router(shops.router, prefix="/shops", tags=["shops"]) api_router.include_router(shop.router, prefix="/shop", tags=["shop"])
api_router.include_router(marketplace.router, prefix="/marketplace", tags=["marketplace"]) api_router.include_router(marketplace.router, prefix="/marketplace", tags=["marketplace"])
api_router.include_router(admin.router, prefix="/admin", tags=["admin"]) api_router.include_router(admin.router, prefix="/admin", tags=["admin"])
api_router.include_router(stats.router, prefix="/stats", tags=["statistics"]) api_router.include_router(stats.router, prefix="/stats", tags=["statistics"])

View File

@@ -1,14 +1,13 @@
from typing import List, Optional from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.database import get_db from app.core.database import get_db
from app.api.deps import get_current_user, get_current_admin_user from app.api.deps import get_current_admin_user
from app.tasks.background_tasks import process_marketplace_import from app.services.admin_service import admin_service
from middleware.decorators import rate_limit from middleware.decorators import rate_limit
from models.api_models import MarketplaceImportJobResponse, MarketplaceImportRequest, UserResponse, ShopListResponse from models.api_models import MarketplaceImportJobResponse, UserResponse, ShopListResponse
from models.database_models import User, MarketplaceImportJob, Shop from models.database_models import User
from datetime import datetime
import logging import logging
router = APIRouter() router = APIRouter()
@@ -24,8 +23,12 @@ def get_all_users(
current_admin: User = Depends(get_current_admin_user) current_admin: User = Depends(get_current_admin_user)
): ):
"""Get all users (Admin only)""" """Get all users (Admin only)"""
users = db.query(User).offset(skip).limit(limit).all() try:
users = admin_service.get_all_users(db=db, skip=skip, limit=limit)
return [UserResponse.model_validate(user) for user in users] return [UserResponse.model_validate(user) for user in users]
except Exception as e:
logger.error(f"Error getting users: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.put("/admin/users/{user_id}/status") @router.put("/admin/users/{user_id}/status")
@@ -35,20 +38,14 @@ def toggle_user_status(
current_admin: User = Depends(get_current_admin_user) current_admin: User = Depends(get_current_admin_user)
): ):
"""Toggle user active status (Admin only)""" """Toggle user active status (Admin only)"""
user = db.query(User).filter(User.id == user_id).first() try:
if not user: user, message = admin_service.toggle_user_status(db, user_id, current_admin.id)
raise HTTPException(status_code=404, detail="User not found") return {"message": message}
except HTTPException:
if user.id == current_admin.id: raise
raise HTTPException(status_code=400, detail="Cannot deactivate your own account") except Exception as e:
logger.error(f"Error toggling user {user_id} status: {str(e)}")
user.is_active = not user.is_active raise HTTPException(status_code=500, detail="Internal server error")
user.updated_at = datetime.utcnow()
db.commit()
db.refresh(user)
status = "activated" if user.is_active else "deactivated"
return {"message": f"User {user.username} has been {status}"}
@router.get("/admin/shops", response_model=ShopListResponse) @router.get("/admin/shops", response_model=ShopListResponse)
@@ -59,8 +56,8 @@ def get_all_shops_admin(
current_admin: User = Depends(get_current_admin_user) current_admin: User = Depends(get_current_admin_user)
): ):
"""Get all shops with admin view (Admin only)""" """Get all shops with admin view (Admin only)"""
total = db.query(Shop).count() try:
shops = db.query(Shop).offset(skip).limit(limit).all() shops, total = admin_service.get_all_shops(db=db, skip=skip, limit=limit)
return ShopListResponse( return ShopListResponse(
shops=shops, shops=shops,
@@ -68,6 +65,9 @@ def get_all_shops_admin(
skip=skip, skip=skip,
limit=limit limit=limit
) )
except Exception as e:
logger.error(f"Error getting shops: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.put("/admin/shops/{shop_id}/verify") @router.put("/admin/shops/{shop_id}/verify")
@@ -77,17 +77,14 @@ def verify_shop(
current_admin: User = Depends(get_current_admin_user) current_admin: User = Depends(get_current_admin_user)
): ):
"""Verify/unverify shop (Admin only)""" """Verify/unverify shop (Admin only)"""
shop = db.query(Shop).filter(Shop.id == shop_id).first() try:
if not shop: shop, message = admin_service.verify_shop(db, shop_id)
raise HTTPException(status_code=404, detail="Shop not found") return {"message": message}
except HTTPException:
shop.is_verified = not shop.is_verified raise
shop.updated_at = datetime.utcnow() except Exception as e:
db.commit() logger.error(f"Error verifying shop {shop_id}: {str(e)}")
db.refresh(shop) raise HTTPException(status_code=500, detail="Internal server error")
status = "verified" if shop.is_verified else "unverified"
return {"message": f"Shop {shop.shop_code} has been {status}"}
@router.put("/admin/shops/{shop_id}/status") @router.put("/admin/shops/{shop_id}/status")
@@ -97,17 +94,14 @@ def toggle_shop_status(
current_admin: User = Depends(get_current_admin_user) current_admin: User = Depends(get_current_admin_user)
): ):
"""Toggle shop active status (Admin only)""" """Toggle shop active status (Admin only)"""
shop = db.query(Shop).filter(Shop.id == shop_id).first() try:
if not shop: shop, message = admin_service.toggle_shop_status(db, shop_id)
raise HTTPException(status_code=404, detail="Shop not found") return {"message": message}
except HTTPException:
shop.is_active = not shop.is_active raise
shop.updated_at = datetime.utcnow() except Exception as e:
db.commit() logger.error(f"Error toggling shop {shop_id} status: {str(e)}")
db.refresh(shop) raise HTTPException(status_code=500, detail="Internal server error")
status = "activated" if shop.is_active else "deactivated"
return {"message": f"Shop {shop.shop_code} has been {status}"}
@router.get("/admin/marketplace-import-jobs", response_model=List[MarketplaceImportJobResponse]) @router.get("/admin/marketplace-import-jobs", response_model=List[MarketplaceImportJobResponse])
@@ -121,33 +115,15 @@ def get_all_marketplace_import_jobs(
current_admin: User = Depends(get_current_admin_user) current_admin: User = Depends(get_current_admin_user)
): ):
"""Get all marketplace import jobs (Admin only)""" """Get all marketplace import jobs (Admin only)"""
try:
query = db.query(MarketplaceImportJob) return admin_service.get_marketplace_import_jobs(
db=db,
# Apply filters marketplace=marketplace,
if marketplace: shop_name=shop_name,
query = query.filter(MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%")) status=status,
if shop_name: skip=skip,
query = query.filter(MarketplaceImportJob.shop_name.ilike(f"%{shop_name}%")) limit=limit
if status: )
query = query.filter(MarketplaceImportJob.status == status) except Exception as e:
logger.error(f"Error getting marketplace import jobs: {str(e)}")
# Order by creation date and apply pagination raise HTTPException(status_code=500, detail="Internal server error")
jobs = query.order_by(MarketplaceImportJob.created_at.desc()).offset(skip).limit(limit).all()
return [
MarketplaceImportJobResponse(
job_id=job.id,
status=job.status,
marketplace=job.marketplace,
shop_name=job.shop_name,
imported=job.imported_count or 0,
updated=job.updated_count or 0,
total_processed=job.total_processed or 0,
error_count=job.error_count or 0,
error_message=job.error_message,
created_at=job.created_at,
started_at=job.started_at,
completed_at=job.completed_at
) for job in jobs
]

View File

@@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
# Enhanced Product Routes with Marketplace Support # Enhanced Product Routes with Marketplace Support
@router.get("/products", response_model=ProductListResponse) @router.get("/product", response_model=ProductListResponse)
def get_products( def get_products(
skip: int = Query(0, ge=0), skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000), limit: int = Query(100, ge=1, le=1000),
@@ -58,7 +58,7 @@ def get_products(
raise HTTPException(status_code=500, detail="Internal server error") raise HTTPException(status_code=500, detail="Internal server error")
@router.post("/products", response_model=ProductResponse) @router.post("/product", response_model=ProductResponse)
def create_product( def create_product(
product: ProductCreate, product: ProductCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
@@ -86,7 +86,7 @@ def create_product(
raise HTTPException(status_code=500, detail="Internal server error") raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/products/{product_id}", response_model=ProductDetailResponse) @router.get("/product/{product_id}", response_model=ProductDetailResponse)
def get_product( def get_product(
product_id: str, product_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
@@ -116,7 +116,7 @@ def get_product(
raise HTTPException(status_code=500, detail="Internal server error") raise HTTPException(status_code=500, detail="Internal server error")
@router.put("/products/{product_id}", response_model=ProductResponse) @router.put("/product/{product_id}", response_model=ProductResponse)
def update_product( def update_product(
product_id: str, product_id: str,
product_update: ProductUpdate, product_update: ProductUpdate,
@@ -142,7 +142,7 @@ def update_product(
raise HTTPException(status_code=500, detail="Internal server error") raise HTTPException(status_code=500, detail="Internal server error")
@router.delete("/products/{product_id}") @router.delete("/product/{product_id}")
def delete_product( def delete_product(
product_id: str, product_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),

View File

@@ -2,18 +2,21 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI from fastapi import FastAPI
from sqlalchemy import text from sqlalchemy import text
from .logging import setup_logging from .logging import setup_logging
from .database import engine from .database import engine, SessionLocal
from models.database_models import Base from models.database_models import Base
import logging import logging
from middleware.auth import AuthManager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
auth_manager = AuthManager()
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
"""Application lifespan events""" """Application lifespan events"""
# Startup # Startup
logger = setup_logging() # Configure logging first app_logger = setup_logging() # Configure logging first
logger.info("Starting up ecommerce API") app_logger.info("Starting up ecommerce API")
# Create tables # Create tables
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
@@ -21,10 +24,19 @@ async def lifespan(app: FastAPI):
# Create indexes # Create indexes
create_indexes() create_indexes()
# Create default admin user
db = SessionLocal()
try:
auth_manager.create_default_admin_user(db)
except Exception as e:
logger.error(f"Failed to create default admin user: {e}")
finally:
db.close()
yield yield
# Shutdown # Shutdown
logger.info("Shutting down ecommerce API") app_logger.info("Shutting down ecommerce API")
def create_indexes(): def create_indexes():

View File

@@ -0,0 +1,193 @@
from sqlalchemy.orm import Session
from sqlalchemy import func
from fastapi import HTTPException
from datetime import datetime
import logging
from typing import List, Optional, Tuple
from models.database_models import User, MarketplaceImportJob, Shop
from models.api_models import MarketplaceImportJobResponse
logger = logging.getLogger(__name__)
class AdminService:
"""Service class for admin operations following the application's service pattern"""
def get_all_users(self, db: Session, skip: int = 0, limit: int = 100) -> List[User]:
"""Get paginated list of all users"""
return db.query(User).offset(skip).limit(limit).all()
def toggle_user_status(self, db: Session, user_id: int, current_admin_id: int) -> Tuple[User, str]:
"""
Toggle user active status
Args:
db: Database session
user_id: ID of user to toggle
current_admin_id: ID of the admin performing the action
Returns:
Tuple of (updated_user, status_message)
Raises:
HTTPException: If user not found or trying to deactivate own account
"""
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=404, detail="User not found")
if user.id == current_admin_id:
raise HTTPException(status_code=400, detail="Cannot deactivate your own account")
user.is_active = not user.is_active
user.updated_at = datetime.utcnow()
db.commit()
db.refresh(user)
status = "activated" if user.is_active else "deactivated"
logger.info(f"User {user.username} has been {status} by admin {current_admin_id}")
return user, f"User {user.username} has been {status}"
def get_all_shops(self, db: Session, skip: int = 0, limit: int = 100) -> Tuple[List[Shop], int]:
"""
Get paginated list of all shops with total count
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
Returns:
Tuple of (shops_list, total_count)
"""
total = db.query(Shop).count()
shops = db.query(Shop).offset(skip).limit(limit).all()
return shops, total
def verify_shop(self, db: Session, shop_id: int) -> Tuple[Shop, str]:
"""
Toggle shop verification status
Args:
db: Database session
shop_id: ID of shop to verify/unverify
Returns:
Tuple of (updated_shop, status_message)
Raises:
HTTPException: If shop not found
"""
shop = db.query(Shop).filter(Shop.id == shop_id).first()
if not shop:
raise HTTPException(status_code=404, detail="Shop not found")
shop.is_verified = not shop.is_verified
shop.updated_at = datetime.utcnow()
db.commit()
db.refresh(shop)
status = "verified" if shop.is_verified else "unverified"
logger.info(f"Shop {shop.shop_code} has been {status}")
return shop, f"Shop {shop.shop_code} has been {status}"
def toggle_shop_status(self, db: Session, shop_id: int) -> Tuple[Shop, str]:
"""
Toggle shop active status
Args:
db: Database session
shop_id: ID of shop to activate/deactivate
Returns:
Tuple of (updated_shop, status_message)
Raises:
HTTPException: If shop not found
"""
shop = db.query(Shop).filter(Shop.id == shop_id).first()
if not shop:
raise HTTPException(status_code=404, detail="Shop not found")
shop.is_active = not shop.is_active
shop.updated_at = datetime.utcnow()
db.commit()
db.refresh(shop)
status = "activated" if shop.is_active else "deactivated"
logger.info(f"Shop {shop.shop_code} has been {status}")
return shop, f"Shop {shop.shop_code} has been {status}"
def get_marketplace_import_jobs(
self,
db: Session,
marketplace: Optional[str] = None,
shop_name: Optional[str] = None,
status: Optional[str] = None,
skip: int = 0,
limit: int = 100
) -> List[MarketplaceImportJobResponse]:
"""
Get filtered and paginated marketplace import jobs
Args:
db: Database session
marketplace: Filter by marketplace name (case-insensitive partial match)
shop_name: Filter by shop name (case-insensitive partial match)
status: Filter by exact status
skip: Number of records to skip
limit: Maximum number of records to return
Returns:
List of MarketplaceImportJobResponse objects
"""
query = db.query(MarketplaceImportJob)
# Apply filters
if marketplace:
query = query.filter(MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%"))
if shop_name:
query = query.filter(MarketplaceImportJob.shop_name.ilike(f"%{shop_name}%"))
if status:
query = query.filter(MarketplaceImportJob.status == status)
# Order by creation date and apply pagination
jobs = query.order_by(MarketplaceImportJob.created_at.desc()).offset(skip).limit(limit).all()
return [
MarketplaceImportJobResponse(
job_id=job.id,
status=job.status,
marketplace=job.marketplace,
shop_name=job.shop_name,
imported=job.imported_count or 0,
updated=job.updated_count or 0,
total_processed=job.total_processed or 0,
error_count=job.error_count or 0,
error_message=job.error_message,
created_at=job.created_at,
started_at=job.started_at,
completed_at=job.completed_at
) for job in jobs
]
def get_user_by_id(self, db: Session, user_id: int) -> Optional[User]:
"""Get user by ID"""
return db.query(User).filter(User.id == user_id).first()
def get_shop_by_id(self, db: Session, shop_id: int) -> Optional[Shop]:
"""Get shop by ID"""
return db.query(Shop).filter(Shop.id == shop_id).first()
def user_exists(self, db: Session, user_id: int) -> bool:
"""Check if user exists by ID"""
return db.query(User).filter(User.id == user_id).first() is not None
def shop_exists(self, db: Session, shop_id: int) -> bool:
"""Check if shop exists by ID"""
return db.query(Shop).filter(Shop.id == shop_id).first() is not None
# Create service instance following the same pattern as product_service
admin_service = AdminService()

View File

@@ -319,7 +319,7 @@ pytest tests/ -m "not slow" -v # Fast tests only
# Run specific test files # Run specific test files
pytest tests/test_auth.py -v # Authentication tests pytest tests/test_auth.py -v # Authentication tests
pytest tests/test_products.py -v # Product tests pytest tests/test_product.py -v # Product tests
pytest tests/test_stock.py -v # Stock management tests pytest tests/test_stock.py -v # Stock management tests
``` ```

View File

@@ -2,6 +2,7 @@
from pydantic import BaseModel, Field, field_validator, EmailStr, ConfigDict from pydantic import BaseModel, Field, field_validator, EmailStr, ConfigDict
from typing import Optional, List from typing import Optional, List
from datetime import datetime from datetime import datetime
import re
# User Authentication Models # User Authentication Models
@@ -13,8 +14,8 @@ class UserRegister(BaseModel):
@field_validator('username') @field_validator('username')
@classmethod @classmethod
def validate_username(cls, v): def validate_username(cls, v):
if not v.isalnum(): if not re.match(r'^[a-zA-Z0-9_]+$', v):
raise ValueError('Username must contain only alphanumeric characters') raise ValueError('Username must contain only letters, numbers, or underscores')
return v.lower().strip() return v.lower().strip()
@field_validator('password') @field_validator('password')

View File

@@ -206,7 +206,7 @@ product_service = ProductService()
### Step 2: Refactor Router ### Step 2: Refactor Router
```python ```python
# products.py # product.py
from product_service import product_service from product_service import product_service
@router.post("/products", response_model=ProductResponse) @router.post("/products", response_model=ProductResponse)

View File

@@ -12,6 +12,7 @@ from app.core.database import get_db, Base
# Import all models to ensure they're registered with Base metadata # Import all models to ensure they're registered with Base metadata
from models.database_models import User, Product, Stock, Shop, MarketplaceImportJob, ShopProduct from models.database_models import User, Product, Stock, Shop, MarketplaceImportJob, ShopProduct
from middleware.auth import AuthManager from middleware.auth import AuthManager
import uuid
# Use in-memory SQLite database for tests # Use in-memory SQLite database for tests
SQLALCHEMY_TEST_DATABASE_URL = "sqlite:///:memory:" SQLALCHEMY_TEST_DATABASE_URL = "sqlite:///:memory:"
@@ -81,11 +82,12 @@ def auth_manager():
@pytest.fixture @pytest.fixture
def test_user(db, auth_manager): def test_user(db, auth_manager):
"""Create a test user""" """Create a test user with unique username"""
unique_id = str(uuid.uuid4())[:8] # Short unique identifier
hashed_password = auth_manager.hash_password("testpass123") hashed_password = auth_manager.hash_password("testpass123")
user = User( user = User(
email="test@example.com", email=f"test_{unique_id}@example.com",
username="testuser", username=f"testuser_{unique_id}",
hashed_password=hashed_password, hashed_password=hashed_password,
role="user", role="user",
is_active=True is_active=True
@@ -98,11 +100,12 @@ def test_user(db, auth_manager):
@pytest.fixture @pytest.fixture
def test_admin(db, auth_manager): def test_admin(db, auth_manager):
"""Create a test admin user""" """Create a test admin user with unique username"""
unique_id = str(uuid.uuid4())[:8] # Short unique identifier
hashed_password = auth_manager.hash_password("adminpass123") hashed_password = auth_manager.hash_password("adminpass123")
admin = User( admin = User(
email="admin@example.com", email=f"admin_{unique_id}@example.com",
username="admin", username=f"admin_{unique_id}",
hashed_password=hashed_password, hashed_password=hashed_password,
role="admin", role="admin",
is_active=True is_active=True
@@ -117,7 +120,7 @@ def test_admin(db, auth_manager):
def auth_headers(client, test_user): def auth_headers(client, test_user):
"""Get authentication headers for test user""" """Get authentication headers for test user"""
response = client.post("/api/v1/auth/login", json={ response = client.post("/api/v1/auth/login", json={
"username": "testuser", "username": test_user.username,
"password": "testpass123" "password": "testpass123"
}) })
assert response.status_code == 200, f"Login failed: {response.text}" assert response.status_code == 200, f"Login failed: {response.text}"
@@ -129,7 +132,7 @@ def auth_headers(client, test_user):
def admin_headers(client, test_admin): def admin_headers(client, test_admin):
"""Get authentication headers for admin user""" """Get authentication headers for admin user"""
response = client.post("/api/v1/auth/login", json={ response = client.post("/api/v1/auth/login", json={
"username": "admin", "username": test_admin.username,
"password": "adminpass123" "password": "adminpass123"
}) })
assert response.status_code == 200, f"Admin login failed: {response.text}" assert response.status_code == 200, f"Admin login failed: {response.text}"
@@ -160,10 +163,11 @@ def test_product(db):
@pytest.fixture @pytest.fixture
def test_shop(db, test_user): def test_shop(db, test_user):
"""Create a test shop""" """Create a test shop with unique shop code"""
unique_id = str(uuid.uuid4())[:8] # Short unique identifier
shop = Shop( shop = Shop(
shop_code="TESTSHOP", shop_code=f"TESTSHOP_{unique_id}",
shop_name="Test Shop", shop_name=f"Test Shop {unique_id}",
owner_id=test_user.id, owner_id=test_user.id,
is_active=True, is_active=True,
is_verified=True is_verified=True
@@ -190,6 +194,50 @@ def test_stock(db, test_product, test_shop):
return stock return stock
@pytest.fixture
def test_marketplace_job(db, test_shop): # Add test_shop dependency
"""Create a test marketplace import job"""
job = MarketplaceImportJob(
marketplace="amazon",
shop_name="Test Import Shop",
status="completed",
source_url="https://test-marketplace.example.com/import",
shop_id=test_shop.id, # Add required shop_id
imported_count=5,
updated_count=3,
total_processed=8,
error_count=0,
error_message=None
)
db.add(job)
db.commit()
db.refresh(job)
return job
def create_test_import_job(db, shop_id, **kwargs): # Add shop_id parameter
"""Helper function to create MarketplaceImportJob with defaults"""
defaults = {
'marketplace': 'test',
'shop_name': 'Test Shop',
'status': 'pending',
'source_url': 'https://test.example.com/import',
'shop_id': shop_id, # Add required shop_id
'imported_count': 0,
'updated_count': 0,
'total_processed': 0,
'error_count': 0,
'error_message': None
}
defaults.update(kwargs)
job = MarketplaceImportJob(**defaults)
db.add(job)
db.commit()
db.refresh(job)
return job
# Cleanup fixture to ensure clean state # Cleanup fixture to ensure clean state
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def cleanup(): def cleanup():

View File

@@ -11,6 +11,10 @@ class TestAdminAPI:
data = response.json() data = response.json()
assert len(data) >= 2 # test_user + admin user assert len(data) >= 2 # test_user + admin user
# Check that test_user is in the response
user_ids = [user["id"] for user in data if "id" in user]
assert test_user.id in user_ids
def test_get_all_users_non_admin(self, client, auth_headers): def test_get_all_users_non_admin(self, client, auth_headers):
"""Test non-admin trying to access admin endpoint""" """Test non-admin trying to access admin endpoint"""
response = client.get("/api/v1/admin/users", headers=auth_headers) response = client.get("/api/v1/admin/users", headers=auth_headers)
@@ -23,7 +27,24 @@ class TestAdminAPI:
response = client.put(f"/api/v1/admin/users/{test_user.id}/status", headers=admin_headers) response = client.put(f"/api/v1/admin/users/{test_user.id}/status", headers=admin_headers)
assert response.status_code == 200 assert response.status_code == 200
assert "deactivated" in response.json()["message"] or "activated" in response.json()["message"] message = response.json()["message"]
assert "deactivated" in message or "activated" in message
# Verify the username is in the message
assert test_user.username in message
def test_toggle_user_status_user_not_found(self, client, admin_headers):
"""Test admin toggling status for non-existent user"""
response = client.put("/api/v1/admin/users/99999/status", headers=admin_headers)
assert response.status_code == 404
assert "User not found" in response.json()["detail"]
def test_toggle_user_status_cannot_deactivate_self(self, client, admin_headers, test_admin):
"""Test that admin cannot deactivate their own account"""
response = client.put(f"/api/v1/admin/users/{test_admin.id}/status", headers=admin_headers)
assert response.status_code == 400
assert "Cannot deactivate your own account" in response.json()["detail"]
def test_get_all_shops_admin(self, client, admin_headers, test_shop): def test_get_all_shops_admin(self, client, admin_headers, test_shop):
"""Test admin getting all shops""" """Test admin getting all shops"""
@@ -32,3 +53,115 @@ class TestAdminAPI:
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["total"] >= 1 assert data["total"] >= 1
assert len(data["shops"]) >= 1
# Check that test_shop is in the response
shop_codes = [shop["shop_code"] for shop in data["shops"] if "shop_code" in shop]
assert test_shop.shop_code in shop_codes
def test_get_all_shops_non_admin(self, client, auth_headers):
"""Test non-admin trying to access admin shop endpoint"""
response = client.get("/api/v1/admin/shops", headers=auth_headers)
assert response.status_code == 403
assert "Access denied" in response.json()["detail"] or "admin" in response.json()["detail"].lower()
def test_verify_shop_admin(self, client, admin_headers, test_shop):
"""Test admin verifying/unverifying shop"""
response = client.put(f"/api/v1/admin/shops/{test_shop.id}/verify", headers=admin_headers)
assert response.status_code == 200
message = response.json()["message"]
assert "verified" in message or "unverified" in message
assert test_shop.shop_code in message
def test_verify_shop_not_found(self, client, admin_headers):
"""Test admin verifying non-existent shop"""
response = client.put("/api/v1/admin/shops/99999/verify", headers=admin_headers)
assert response.status_code == 404
assert "Shop not found" in response.json()["detail"]
def test_toggle_shop_status_admin(self, client, admin_headers, test_shop):
"""Test admin toggling shop status"""
response = client.put(f"/api/v1/admin/shops/{test_shop.id}/status", headers=admin_headers)
assert response.status_code == 200
message = response.json()["message"]
assert "activated" in message or "deactivated" in message
assert test_shop.shop_code in message
def test_toggle_shop_status_not_found(self, client, admin_headers):
"""Test admin toggling status for non-existent shop"""
response = client.put("/api/v1/admin/shops/99999/status", headers=admin_headers)
assert response.status_code == 404
assert "Shop not found" in response.json()["detail"]
def test_get_marketplace_import_jobs_admin(self, client, admin_headers, test_marketplace_job):
"""Test admin getting marketplace import jobs"""
response = client.get("/api/v1/admin/marketplace-import-jobs", headers=admin_headers)
assert response.status_code == 200
data = response.json()
assert len(data) >= 1
# Check that test_marketplace_job is in the response
job_ids = [job["job_id"] for job in data if "job_id" in job]
assert test_marketplace_job.id in job_ids
def test_get_marketplace_import_jobs_with_filters(self, client, admin_headers, test_marketplace_job):
"""Test admin getting marketplace import jobs with filters"""
response = client.get(
"/api/v1/admin/marketplace-import-jobs",
params={"marketplace": test_marketplace_job.marketplace},
headers=admin_headers
)
assert response.status_code == 200
data = response.json()
assert len(data) >= 1
assert all(job["marketplace"] == test_marketplace_job.marketplace for job in data)
def test_get_marketplace_import_jobs_non_admin(self, client, auth_headers):
"""Test non-admin trying to access marketplace import jobs"""
response = client.get("/api/v1/admin/marketplace-import-jobs", headers=auth_headers)
assert response.status_code == 403
assert "Access denied" in response.json()["detail"] or "admin" in response.json()["detail"].lower()
def test_admin_endpoints_require_authentication(self, client):
"""Test that admin endpoints require authentication"""
endpoints = [
"/api/v1/admin/users",
"/api/v1/admin/shops",
"/api/v1/admin/marketplace-import-jobs"
]
for endpoint in endpoints:
response = client.get(endpoint)
assert response.status_code == 401 # Unauthorized
def test_admin_pagination_users(self, client, admin_headers, test_user, test_admin):
"""Test user pagination works correctly"""
# Test first page
response = client.get("/api/v1/admin/users?skip=0&limit=1", headers=admin_headers)
assert response.status_code == 200
data = response.json()
assert len(data) == 1
# Test second page
response = client.get("/api/v1/admin/users?skip=1&limit=1", headers=admin_headers)
assert response.status_code == 200
data = response.json()
assert len(data) >= 0 # Could be 1 or 0 depending on total users
def test_admin_pagination_shops(self, client, admin_headers, test_shop):
"""Test shop pagination works correctly"""
response = client.get("/api/v1/admin/shops?skip=0&limit=1", headers=admin_headers)
assert response.status_code == 200
data = response.json()
assert data["total"] >= 1
assert len(data["shops"]) >= 0
assert "skip" in data
assert "limit" in data

359
tests/test_admin_service.py Normal file
View File

@@ -0,0 +1,359 @@
# tests/test_admin_service.py
import pytest
from datetime import datetime
from fastapi import HTTPException
from app.services.admin_service import AdminService
from models.database_models import User, Shop, MarketplaceImportJob
class TestAdminService:
"""Test suite for AdminService following the application's testing patterns"""
def setup_method(self):
"""Setup method following the same pattern as product service tests"""
self.service = AdminService()
def test_get_all_users(self, db, test_user, test_admin):
"""Test getting all users with pagination"""
users = self.service.get_all_users(db, skip=0, limit=10)
assert len(users) >= 2 # test_user + test_admin
user_ids = [user.id for user in users]
assert test_user.id in user_ids
assert test_admin.id in user_ids
def test_get_all_users_with_pagination(self, db, test_user, test_admin):
"""Test user pagination works correctly"""
users = self.service.get_all_users(db, skip=0, limit=1)
assert len(users) == 1
users_second_page = self.service.get_all_users(db, skip=1, limit=1)
assert len(users_second_page) == 1
assert users[0].id != users_second_page[0].id
def test_toggle_user_status_deactivate(self, db, test_user, test_admin):
"""Test deactivating a user"""
assert test_user.is_active is True
user, message = self.service.toggle_user_status(db, test_user.id, test_admin.id)
assert user.id == test_user.id
assert user.is_active is False
assert f"{user.username} has been deactivated" in message
def test_toggle_user_status_activate(self, db, test_user, test_admin):
"""Test activating a user"""
# First deactivate the user
test_user.is_active = False
db.commit()
user, message = self.service.toggle_user_status(db, test_user.id, test_admin.id)
assert user.id == test_user.id
assert user.is_active is True
assert f"{user.username} has been activated" in message
def test_toggle_user_status_user_not_found(self, db, test_admin):
"""Test toggle user status when user not found"""
with pytest.raises(HTTPException) as exc_info:
self.service.toggle_user_status(db, 99999, test_admin.id)
assert exc_info.value.status_code == 404
assert "User not found" in str(exc_info.value.detail)
def test_toggle_user_status_cannot_deactivate_self(self, db, test_admin):
"""Test that admin cannot deactivate their own account"""
with pytest.raises(HTTPException) as exc_info:
self.service.toggle_user_status(db, test_admin.id, test_admin.id)
assert exc_info.value.status_code == 400
assert "Cannot deactivate your own account" in str(exc_info.value.detail)
def test_get_all_shops(self, db, test_shop):
"""Test getting all shops with total count"""
shops, total = self.service.get_all_shops(db, skip=0, limit=10)
assert total >= 1
assert len(shops) >= 1
shop_codes = [shop.shop_code for shop in shops]
assert test_shop.shop_code in shop_codes
def test_get_all_shops_with_pagination(self, db, test_shop):
"""Test shop pagination works correctly"""
# Create additional shop for pagination test using the helper function
from conftest import create_test_import_job # If you added the helper function
# Or create directly with proper fields
additional_shop = Shop(
shop_code=f"{test_shop.shop_code}_2",
shop_name="Test Shop 2",
owner_id=test_shop.owner_id,
is_active=True,
is_verified=False
)
db.add(additional_shop)
db.commit()
shops_page_1 = self.service.get_all_shops(db, skip=0, limit=1)
assert len(shops_page_1[0]) == 1
shops_page_2 = self.service.get_all_shops(db, skip=1, limit=1)
assert len(shops_page_2[0]) == 1
# Ensure different shops on different pages
assert shops_page_1[0][0].id != shops_page_2[0][0].id
def test_verify_shop_mark_verified(self, db, test_shop):
"""Test marking shop as verified"""
# Ensure shop starts unverified
test_shop.is_verified = False
db.commit()
shop, message = self.service.verify_shop(db, test_shop.id)
assert shop.id == test_shop.id
assert shop.is_verified is True
assert f"{shop.shop_code} has been verified" in message
def test_verify_shop_mark_unverified(self, db, test_shop):
"""Test marking shop as unverified"""
# Ensure shop starts verified
test_shop.is_verified = True
db.commit()
shop, message = self.service.verify_shop(db, test_shop.id)
assert shop.id == test_shop.id
assert shop.is_verified is False
assert f"{shop.shop_code} has been unverified" in message
def test_verify_shop_not_found(self, db):
"""Test verify shop when shop not found"""
with pytest.raises(HTTPException) as exc_info:
self.service.verify_shop(db, 99999)
assert exc_info.value.status_code == 404
assert "Shop not found" in str(exc_info.value.detail)
def test_toggle_shop_status_deactivate(self, db, test_shop):
"""Test deactivating a shop"""
assert test_shop.is_active is True
shop, message = self.service.toggle_shop_status(db, test_shop.id)
assert shop.id == test_shop.id
assert shop.is_active is False
assert f"{shop.shop_code} has been deactivated" in message
def test_toggle_shop_status_activate(self, db, test_shop):
"""Test activating a shop"""
# First deactivate the shop
test_shop.is_active = False
db.commit()
shop, message = self.service.toggle_shop_status(db, test_shop.id)
assert shop.id == test_shop.id
assert shop.is_active is True
assert f"{shop.shop_code} has been activated" in message
def test_toggle_shop_status_not_found(self, db):
"""Test toggle shop status when shop not found"""
with pytest.raises(HTTPException) as exc_info:
self.service.toggle_shop_status(db, 99999)
assert exc_info.value.status_code == 404
assert "Shop not found" in str(exc_info.value.detail)
def test_get_marketplace_import_jobs_no_filters(self, db, test_marketplace_job):
"""Test getting marketplace import jobs without filters using fixture"""
result = self.service.get_marketplace_import_jobs(db, skip=0, limit=10)
assert len(result) >= 1
# Find our test job in the results
test_job = next((job for job in result if job.job_id == test_marketplace_job.id), None)
assert test_job is not None
assert test_job.marketplace == test_marketplace_job.marketplace
assert test_job.shop_name == test_marketplace_job.shop_name
assert test_job.status == test_marketplace_job.status
def test_get_marketplace_import_jobs_with_marketplace_filter(self, db, test_marketplace_job):
"""Test getting marketplace import jobs filtered by marketplace"""
# Create additional job with different marketplace
other_job = MarketplaceImportJob(
marketplace="ebay",
shop_name="eBay Shop",
status="completed",
source_url="https://ebay.example.com/import"
)
db.add(other_job)
db.commit()
# Filter by the test marketplace job's marketplace
result = self.service.get_marketplace_import_jobs(db, marketplace=test_marketplace_job.marketplace)
assert len(result) >= 1
# All results should match the marketplace filter
for job in result:
assert test_marketplace_job.marketplace.lower() in job.marketplace.lower()
def test_get_marketplace_import_jobs_with_shop_filter(self, db, test_marketplace_job):
"""Test getting marketplace import jobs filtered by shop name"""
# Create additional job with different shop name
other_job = MarketplaceImportJob(
marketplace="amazon",
shop_name="Different Shop Name",
status="completed",
source_url="https://different.example.com/import"
)
db.add(other_job)
db.commit()
# Filter by the test marketplace job's shop name
result = self.service.get_marketplace_import_jobs(db, shop_name=test_marketplace_job.shop_name)
assert len(result) >= 1
# All results should match the shop name filter
for job in result:
assert test_marketplace_job.shop_name.lower() in job.shop_name.lower()
def test_get_marketplace_import_jobs_with_status_filter(self, db, test_marketplace_job):
"""Test getting marketplace import jobs filtered by status"""
# Create additional job with different status
other_job = MarketplaceImportJob(
marketplace="amazon",
shop_name="Test Shop",
status="pending",
source_url="https://pending.example.com/import"
)
db.add(other_job)
db.commit()
# Filter by the test marketplace job's status
result = self.service.get_marketplace_import_jobs(db, status=test_marketplace_job.status)
assert len(result) >= 1
# All results should match the status filter
for job in result:
assert job.status == test_marketplace_job.status
def test_get_marketplace_import_jobs_with_multiple_filters(self, db, test_marketplace_job, test_shop):
"""Test getting marketplace import jobs with multiple filters"""
# Create jobs that don't match all filters
non_matching_job1 = MarketplaceImportJob(
marketplace="ebay", # Different marketplace
shop_name=test_marketplace_job.shop_name,
status=test_marketplace_job.status,
source_url="https://non-matching1.example.com/import",
shop_id=test_shop.id # Add required shop_id
)
non_matching_job2 = MarketplaceImportJob(
marketplace=test_marketplace_job.marketplace,
shop_name="Different Shop", # Different shop name
status=test_marketplace_job.status,
source_url="https://non-matching2.example.com/import",
shop_id=test_shop.id # Add required shop_id
)
db.add_all([non_matching_job1, non_matching_job2])
db.commit()
# Apply all three filters matching the test job
result = self.service.get_marketplace_import_jobs(
db,
marketplace=test_marketplace_job.marketplace,
shop_name=test_marketplace_job.shop_name,
status=test_marketplace_job.status
)
assert len(result) >= 1
# Find our test job in the results
test_job = next((job for job in result if job.job_id == test_marketplace_job.id), None)
assert test_job is not None
assert test_job.marketplace == test_marketplace_job.marketplace
assert test_job.shop_name == test_marketplace_job.shop_name
assert test_job.status == test_marketplace_job.status
def test_get_marketplace_import_jobs_null_values(self, db):
"""Test that marketplace import jobs handle null values correctly"""
# Create job with null values but required fields
job = MarketplaceImportJob(
marketplace="test",
shop_name="Test Shop",
status="pending",
source_url="https://test.example.com/import",
imported_count=None,
updated_count=None,
total_processed=None,
error_count=None,
error_message=None
)
db.add(job)
db.commit()
result = self.service.get_marketplace_import_jobs(db)
assert len(result) >= 1
# Find the job with null values
null_job = next((j for j in result if j.job_id == job.id), None)
assert null_job is not None
assert null_job.imported == 0 # None converted to 0
assert null_job.updated == 0
assert null_job.total_processed == 0
assert null_job.error_count == 0
assert null_job.error_message is None
def test_get_user_by_id(self, db, test_user):
"""Test getting user by ID using fixture"""
user = self.service.get_user_by_id(db, test_user.id)
assert user is not None
assert user.id == test_user.id
assert user.email == test_user.email
assert user.username == test_user.username
def test_get_user_by_id_not_found(self, db):
"""Test getting user by ID when user doesn't exist"""
user = self.service.get_user_by_id(db, 99999)
assert user is None
def test_get_shop_by_id(self, db, test_shop):
"""Test getting shop by ID using fixture"""
shop = self.service.get_shop_by_id(db, test_shop.id)
assert shop is not None
assert shop.id == test_shop.id
assert shop.shop_code == test_shop.shop_code
assert shop.shop_name == test_shop.shop_name
def test_get_shop_by_id_not_found(self, db):
"""Test getting shop by ID when shop doesn't exist"""
shop = self.service.get_shop_by_id(db, 99999)
assert shop is None
def test_user_exists_true(self, db, test_user):
"""Test user_exists returns True when user exists"""
exists = self.service.user_exists(db, test_user.id)
assert exists is True
def test_user_exists_false(self, db):
"""Test user_exists returns False when user doesn't exist"""
exists = self.service.user_exists(db, 99999)
assert exists is False
def test_shop_exists_true(self, db, test_shop):
"""Test shop_exists returns True when shop exists"""
exists = self.service.shop_exists(db, test_shop.id)
assert exists is True
def test_shop_exists_false(self, db):
"""Test shop_exists returns False when shop doesn't exist"""
exists = self.service.shop_exists(db, 99999)
assert exists is False

View File

@@ -23,7 +23,7 @@ class TestAuthenticationAPI:
def test_register_user_duplicate_email(self, client, test_user): def test_register_user_duplicate_email(self, client, test_user):
"""Test registration with duplicate email""" """Test registration with duplicate email"""
response = client.post("/api/v1/auth/register", json={ response = client.post("/api/v1/auth/register", json={
"email": "test@example.com", # Same as test_user "email": test_user.email, # Same as test_user
"username": "newuser", "username": "newuser",
"password": "securepass123" "password": "securepass123"
}) })
@@ -35,7 +35,7 @@ class TestAuthenticationAPI:
"""Test registration with duplicate username""" """Test registration with duplicate username"""
response = client.post("/api/v1/auth/register", json={ response = client.post("/api/v1/auth/register", json={
"email": "new@example.com", "email": "new@example.com",
"username": "testuser", # Same as test_user "username": test_user.username, # Same as test_user
"password": "securepass123" "password": "securepass123"
}) })
@@ -45,7 +45,7 @@ class TestAuthenticationAPI:
def test_login_success(self, client, test_user): def test_login_success(self, client, test_user):
"""Test successful login""" """Test successful login"""
response = client.post("/api/v1/auth/login", json={ response = client.post("/api/v1/auth/login", json={
"username": "testuser", "username": test_user.username,
"password": "testpass123" "password": "testpass123"
}) })
@@ -54,7 +54,7 @@ class TestAuthenticationAPI:
assert "access_token" in data assert "access_token" in data
assert data["token_type"] == "bearer" assert data["token_type"] == "bearer"
assert "expires_in" in data assert "expires_in" in data
assert data["user"]["username"] == "testuser" assert data["user"]["username"] == test_user.username
def test_login_wrong_password(self, client, test_user): def test_login_wrong_password(self, client, test_user):
"""Test login with wrong password""" """Test login with wrong password"""
@@ -75,14 +75,14 @@ class TestAuthenticationAPI:
assert response.status_code == 401 assert response.status_code == 401
def test_get_current_user_info(self, client, auth_headers): def test_get_current_user_info(self, client, auth_headers, test_user):
"""Test getting current user info""" """Test getting current user info"""
response = client.get("/api/v1/auth/me", headers=auth_headers) response = client.get("/api/v1/auth/me", headers=auth_headers)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["username"] == "testuser" assert data["username"] == test_user.username
assert data["email"] == "test@example.com" assert data["email"] == test_user.email
def test_get_current_user_no_auth(self, client): def test_get_current_user_no_auth(self, client):
"""Test getting current user without authentication""" """Test getting current user without authentication"""

View File

@@ -2,6 +2,8 @@
import pytest import pytest
import time import time
from models.database_models import Product
class TestPerformance: class TestPerformance:
def test_product_list_performance(self, client, auth_headers, db): def test_product_list_performance(self, client, auth_headers, db):
@@ -22,7 +24,7 @@ class TestPerformance:
# Time the request # Time the request
start_time = time.time() start_time = time.time()
response = client.get("/api/v1/products?limit=100", headers=auth_headers) response = client.get("/api/v1/product?limit=100", headers=auth_headers)
end_time = time.time() end_time = time.time()
assert response.status_code == 200 assert response.status_code == 200
@@ -48,7 +50,7 @@ class TestPerformance:
# Time search request # Time search request
start_time = time.time() start_time = time.time()
response = client.get("/api/v1/products?search=Searchable", headers=auth_headers) response = client.get("/api/v1/product?search=Searchable", headers=auth_headers)
end_time = time.time() end_time = time.time()
assert response.status_code == 200 assert response.status_code == 200

View File

@@ -1,11 +1,11 @@
# tests/test_products.py # tests/test_product.py
import pytest import pytest
class TestProductsAPI: class TestProductsAPI:
def test_get_products_empty(self, client, auth_headers): def test_get_products_empty(self, client, auth_headers):
"""Test getting products when none exist""" """Test getting products when none exist"""
response = client.get("/api/v1/products", headers=auth_headers) response = client.get("/api/v1/product", headers=auth_headers)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()

View File

@@ -1,4 +1,4 @@
# tests/test_shops.py # tests/test_shop.py
import pytest import pytest