code quality run

This commit is contained in:
2025-09-13 21:58:54 +02:00
parent 0dfd885847
commit 3eb18ef91e
63 changed files with 1802 additions and 1289 deletions

View File

@@ -1,11 +1,12 @@
from sqlalchemy.orm import Session
from fastapi import HTTPException
from datetime import datetime
import logging
from datetime import datetime
from typing import List, Optional, Tuple
from models.database_models import User, MarketplaceImportJob, Shop
from fastapi import HTTPException
from sqlalchemy.orm import Session
from models.api_models import MarketplaceImportJobResponse
from models.database_models import MarketplaceImportJob, Shop, User
logger = logging.getLogger(__name__)
@@ -17,7 +18,9 @@ class AdminService:
"""Get paginated list of all users"""
return db.query(User).offset(skip).limit(limit).all()
def toggle_user_status(self, db: Session, user_id: int, current_admin_id: int) -> Tuple[User, str]:
def toggle_user_status(
self, db: Session, user_id: int, current_admin_id: int
) -> Tuple[User, str]:
"""
Toggle user active status
@@ -37,7 +40,9 @@ class AdminService:
raise HTTPException(status_code=404, detail="User not found")
if user.id == current_admin_id:
raise HTTPException(status_code=400, detail="Cannot deactivate your own account")
raise HTTPException(
status_code=400, detail="Cannot deactivate your own account"
)
user.is_active = not user.is_active
user.updated_at = datetime.utcnow()
@@ -45,10 +50,14 @@ class AdminService:
db.refresh(user)
status = "activated" if user.is_active else "deactivated"
logger.info(f"User {user.username} has been {status} by admin {current_admin_id}")
logger.info(
f"User {user.username} has been {status} by admin {current_admin_id}"
)
return user, f"User {user.username} has been {status}"
def get_all_shops(self, db: Session, skip: int = 0, limit: int = 100) -> Tuple[List[Shop], int]:
def get_all_shops(
self, db: Session, skip: int = 0, limit: int = 100
) -> Tuple[List[Shop], int]:
"""
Get paginated list of all shops with total count
@@ -119,13 +128,13 @@ class AdminService:
return shop, f"Shop {shop.shop_code} has been {status}"
def get_marketplace_import_jobs(
self,
db: Session,
marketplace: Optional[str] = None,
shop_name: Optional[str] = None,
status: Optional[str] = None,
skip: int = 0,
limit: int = 100
self,
db: Session,
marketplace: Optional[str] = None,
shop_name: Optional[str] = None,
status: Optional[str] = None,
skip: int = 0,
limit: int = 100,
) -> List[MarketplaceImportJobResponse]:
"""
Get filtered and paginated marketplace import jobs
@@ -145,14 +154,21 @@ class AdminService:
# Apply filters
if marketplace:
query = query.filter(MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%"))
query = query.filter(
MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%")
)
if shop_name:
query = query.filter(MarketplaceImportJob.shop_name.ilike(f"%{shop_name}%"))
if status:
query = query.filter(MarketplaceImportJob.status == status)
# Order by creation date and apply pagination
jobs = query.order_by(MarketplaceImportJob.created_at.desc()).offset(skip).limit(limit).all()
jobs = (
query.order_by(MarketplaceImportJob.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return [
MarketplaceImportJobResponse(
@@ -168,8 +184,9 @@ class AdminService:
error_message=job.error_message,
created_at=job.created_at,
started_at=job.started_at,
completed_at=job.completed_at
) for job in jobs
completed_at=job.completed_at,
)
for job in jobs
]
def get_user_by_id(self, db: Session, user_id: int) -> Optional[User]:

View File

@@ -1,11 +1,12 @@
from sqlalchemy.orm import Session
from fastapi import HTTPException
import logging
from typing import Optional, Dict, Any
from typing import Any, Dict, Optional
from fastapi import HTTPException
from sqlalchemy.orm import Session
from models.database_models import User
from models.api_models import UserRegister, UserLogin
from middleware.auth import AuthManager
from models.api_models import UserLogin, UserRegister
from models.database_models import User
logger = logging.getLogger(__name__)
@@ -36,7 +37,9 @@ class AuthService:
raise HTTPException(status_code=400, detail="Email already registered")
# Check if username already exists
existing_username = db.query(User).filter(User.username == user_data.username).first()
existing_username = (
db.query(User).filter(User.username == user_data.username).first()
)
if existing_username:
raise HTTPException(status_code=400, detail="Username already taken")
@@ -47,7 +50,7 @@ class AuthService:
username=user_data.username,
hashed_password=hashed_password,
role="user",
is_active=True
is_active=True,
)
db.add(new_user)
@@ -71,19 +74,20 @@ class AuthService:
Raises:
HTTPException: If authentication fails
"""
user = self.auth_manager.authenticate_user(db, user_credentials.username, user_credentials.password)
user = self.auth_manager.authenticate_user(
db, user_credentials.username, user_credentials.password
)
if not user:
raise HTTPException(status_code=401, detail="Incorrect username or password")
raise HTTPException(
status_code=401, detail="Incorrect username or password"
)
# Create access token
token_data = self.auth_manager.create_access_token(user)
logger.info(f"User logged in: {user.username}")
return {
"token_data": token_data,
"user": user
}
return {"token_data": token_data, "user": user}
def get_user_by_email(self, db: Session, email: str) -> Optional[User]:
"""Get user by email"""
@@ -101,7 +105,9 @@ class AuthService:
"""Check if username already exists"""
return db.query(User).filter(User.username == username).first() is not None
def authenticate_user(self, db: Session, username: str, password: str) -> Optional[User]:
def authenticate_user(
self, db: Session, username: str, password: str
) -> Optional[User]:
"""Authenticate user with username/password"""
return self.auth_manager.authenticate_user(db, username, password)

View File

@@ -1,10 +1,13 @@
import logging
from datetime import datetime
from typing import List, Optional
from sqlalchemy import func
from sqlalchemy.orm import Session
from models.api_models import (MarketplaceImportJobResponse,
MarketplaceImportRequest)
from models.database_models import MarketplaceImportJob, Shop, User
from models.api_models import MarketplaceImportRequest, MarketplaceImportJobResponse
from typing import Optional, List
from datetime import datetime
import logging
logger = logging.getLogger(__name__)
@@ -17,9 +20,11 @@ class MarketplaceService:
"""Validate that the shop exists and user has access to it"""
# Explicit type hint to help type checker shop: Optional[Shop]
# Use case-insensitive query to handle both uppercase and lowercase codes
shop: Optional[Shop] = db.query(Shop).filter(
func.upper(Shop.shop_code) == shop_code.upper()
).first()
shop: Optional[Shop] = (
db.query(Shop)
.filter(func.upper(Shop.shop_code) == shop_code.upper())
.first()
)
if not shop:
raise ValueError("Shop not found")
@@ -30,10 +35,7 @@ class MarketplaceService:
return shop
def create_import_job(
self,
db: Session,
request: MarketplaceImportRequest,
user: User
self, db: Session, request: MarketplaceImportRequest, user: User
) -> MarketplaceImportJob:
"""Create a new marketplace import job"""
# Validate shop access first
@@ -47,7 +49,7 @@ class MarketplaceService:
shop_id=shop.id, # Foreign key to shops table
shop_name=shop.shop_name, # Use shop.shop_name (the display name)
user_id=user.id,
created_at=datetime.utcnow()
created_at=datetime.utcnow(),
)
db.add(import_job)
@@ -55,13 +57,20 @@ class MarketplaceService:
db.refresh(import_job)
logger.info(
f"Created marketplace import job {import_job.id}: {request.marketplace} -> {shop.shop_name} (shop_code: {shop.shop_code}) by user {user.username}")
f"Created marketplace import job {import_job.id}: {request.marketplace} -> {shop.shop_name} (shop_code: {shop.shop_code}) by user {user.username}"
)
return import_job
def get_import_job_by_id(self, db: Session, job_id: int, user: User) -> MarketplaceImportJob:
def get_import_job_by_id(
self, db: Session, job_id: int, user: User
) -> MarketplaceImportJob:
"""Get a marketplace import job by ID with access control"""
job = db.query(MarketplaceImportJob).filter(MarketplaceImportJob.id == job_id).first()
job = (
db.query(MarketplaceImportJob)
.filter(MarketplaceImportJob.id == job_id)
.first()
)
if not job:
raise ValueError("Marketplace import job not found")
@@ -72,13 +81,13 @@ class MarketplaceService:
return job
def get_import_jobs(
self,
db: Session,
user: User,
marketplace: Optional[str] = None,
shop_name: Optional[str] = None,
skip: int = 0,
limit: int = 50
self,
db: Session,
user: User,
marketplace: Optional[str] = None,
shop_name: Optional[str] = None,
skip: int = 0,
limit: int = 50,
) -> List[MarketplaceImportJob]:
"""Get marketplace import jobs with filtering and access control"""
query = db.query(MarketplaceImportJob)
@@ -89,44 +98,51 @@ class MarketplaceService:
# Apply filters
if marketplace:
query = query.filter(MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%"))
query = query.filter(
MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%")
)
if shop_name:
query = query.filter(MarketplaceImportJob.shop_name.ilike(f"%{shop_name}%"))
# Order by creation date (newest first) and apply pagination
jobs = query.order_by(MarketplaceImportJob.created_at.desc()).offset(skip).limit(limit).all()
jobs = (
query.order_by(MarketplaceImportJob.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return jobs
def update_job_status(
self,
db: Session,
job_id: int,
status: str,
**kwargs
self, db: Session, job_id: int, status: str, **kwargs
) -> MarketplaceImportJob:
"""Update marketplace import job status and other fields"""
job = db.query(MarketplaceImportJob).filter(MarketplaceImportJob.id == job_id).first()
job = (
db.query(MarketplaceImportJob)
.filter(MarketplaceImportJob.id == job_id)
.first()
)
if not job:
raise ValueError("Marketplace import job not found")
job.status = status
# Update optional fields if provided
if 'imported_count' in kwargs:
job.imported_count = kwargs['imported_count']
if 'updated_count' in kwargs:
job.updated_count = kwargs['updated_count']
if 'total_processed' in kwargs:
job.total_processed = kwargs['total_processed']
if 'error_count' in kwargs:
job.error_count = kwargs['error_count']
if 'error_message' in kwargs:
job.error_message = kwargs['error_message']
if 'started_at' in kwargs:
job.started_at = kwargs['started_at']
if 'completed_at' in kwargs:
job.completed_at = kwargs['completed_at']
if "imported_count" in kwargs:
job.imported_count = kwargs["imported_count"]
if "updated_count" in kwargs:
job.updated_count = kwargs["updated_count"]
if "total_processed" in kwargs:
job.total_processed = kwargs["total_processed"]
if "error_count" in kwargs:
job.error_count = kwargs["error_count"]
if "error_message" in kwargs:
job.error_message = kwargs["error_message"]
if "started_at" in kwargs:
job.started_at = kwargs["started_at"]
if "completed_at" in kwargs:
job.completed_at = kwargs["completed_at"]
db.commit()
db.refresh(job)
@@ -145,7 +161,9 @@ class MarketplaceService:
total_jobs = query.count()
pending_jobs = query.filter(MarketplaceImportJob.status == "pending").count()
running_jobs = query.filter(MarketplaceImportJob.status == "running").count()
completed_jobs = query.filter(MarketplaceImportJob.status == "completed").count()
completed_jobs = query.filter(
MarketplaceImportJob.status == "completed"
).count()
failed_jobs = query.filter(MarketplaceImportJob.status == "failed").count()
return {
@@ -153,17 +171,21 @@ class MarketplaceService:
"pending_jobs": pending_jobs,
"running_jobs": running_jobs,
"completed_jobs": completed_jobs,
"failed_jobs": failed_jobs
"failed_jobs": failed_jobs,
}
def convert_to_response_model(self, job: MarketplaceImportJob) -> MarketplaceImportJobResponse:
def convert_to_response_model(
self, job: MarketplaceImportJob
) -> MarketplaceImportJobResponse:
"""Convert database model to API response model"""
return MarketplaceImportJobResponse(
job_id=job.id,
status=job.status,
marketplace=job.marketplace,
shop_id=job.shop_id,
shop_code=job.shop.shop_code if job.shop else None, # Add this optional field via relationship
shop_code=(
job.shop.shop_code if job.shop else None
), # Add this optional field via relationship
shop_name=job.shop_name,
imported=job.imported_count or 0,
updated=job.updated_count or 0,
@@ -172,10 +194,12 @@ class MarketplaceService:
error_message=job.error_message,
created_at=job.created_at,
started_at=job.started_at,
completed_at=job.completed_at
completed_at=job.completed_at,
)
def cancel_import_job(self, db: Session, job_id: int, user: User) -> MarketplaceImportJob:
def cancel_import_job(
self, db: Session, job_id: int, user: User
) -> MarketplaceImportJob:
"""Cancel a pending or running import job"""
job = self.get_import_job_by_id(db, job_id, user)
@@ -197,7 +221,9 @@ class MarketplaceService:
# Only allow deletion of completed, failed, or cancelled jobs
if job.status in ["pending", "running"]:
raise ValueError(f"Cannot delete job with status: {job.status}. Cancel it first.")
raise ValueError(
f"Cannot delete job with status: {job.status}. Cancel it first."
)
db.delete(job)
db.commit()

View File

@@ -1,11 +1,14 @@
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError
from models.database_models import Product, Stock
from models.api_models import ProductCreate, ProductUpdate, StockLocationResponse, StockSummaryResponse
from utils.data_processing import GTINProcessor, PriceProcessor
from typing import Optional, List, Generator
from datetime import datetime
import logging
from datetime import datetime
from typing import Generator, List, Optional
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from models.api_models import (ProductCreate, ProductUpdate,
StockLocationResponse, StockSummaryResponse)
from models.database_models import Product, Stock
from utils.data_processing import GTINProcessor, PriceProcessor
logger = logging.getLogger(__name__)
@@ -27,7 +30,9 @@ class ProductService:
# Process price if provided
if product_data.price:
parsed_price, currency = self.price_processor.parse_price_currency(product_data.price)
parsed_price, currency = self.price_processor.parse_price_currency(
product_data.price
)
if parsed_price:
product_data.price = parsed_price
product_data.currency = currency
@@ -58,16 +63,16 @@ class ProductService:
return db.query(Product).filter(Product.product_id == product_id).first()
def get_products_with_filters(
self,
db: Session,
skip: int = 0,
limit: int = 100,
brand: Optional[str] = None,
category: Optional[str] = None,
availability: Optional[str] = None,
marketplace: Optional[str] = None,
shop_name: Optional[str] = None,
search: Optional[str] = None
self,
db: Session,
skip: int = 0,
limit: int = 100,
brand: Optional[str] = None,
category: Optional[str] = None,
availability: Optional[str] = None,
marketplace: Optional[str] = None,
shop_name: Optional[str] = None,
search: Optional[str] = None,
) -> tuple[List[Product], int]:
"""Get products with filtering and pagination"""
query = db.query(Product)
@@ -87,10 +92,10 @@ class ProductService:
# Search in title, description, marketplace, and shop_name
search_term = f"%{search}%"
query = query.filter(
(Product.title.ilike(search_term)) |
(Product.description.ilike(search_term)) |
(Product.marketplace.ilike(search_term)) |
(Product.shop_name.ilike(search_term))
(Product.title.ilike(search_term))
| (Product.description.ilike(search_term))
| (Product.marketplace.ilike(search_term))
| (Product.shop_name.ilike(search_term))
)
total = query.count()
@@ -98,7 +103,9 @@ class ProductService:
return products, total
def update_product(self, db: Session, product_id: str, product_update: ProductUpdate) -> Product:
def update_product(
self, db: Session, product_id: str, product_update: ProductUpdate
) -> Product:
"""Update product with validation"""
product = db.query(Product).filter(Product.product_id == product_id).first()
if not product:
@@ -116,7 +123,9 @@ class ProductService:
# Process price if being updated
if "price" in update_data and update_data["price"]:
parsed_price, currency = self.price_processor.parse_price_currency(update_data["price"])
parsed_price, currency = self.price_processor.parse_price_currency(
update_data["price"]
)
if parsed_price:
update_data["price"] = parsed_price
update_data["currency"] = currency
@@ -160,21 +169,21 @@ class ProductService:
]
return StockSummaryResponse(
gtin=gtin,
total_quantity=total_quantity,
locations=locations
gtin=gtin, total_quantity=total_quantity, locations=locations
)
def generate_csv_export(
self,
db: Session,
marketplace: Optional[str] = None,
shop_name: Optional[str] = None
self,
db: Session,
marketplace: Optional[str] = None,
shop_name: Optional[str] = None,
) -> Generator[str, None, None]:
"""Generate CSV export with streaming for memory efficiency"""
# CSV header
yield ("product_id,title,description,link,image_link,availability,price,currency,brand,"
"gtin,marketplace,shop_name\n")
yield (
"product_id,title,description,link,image_link,availability,price,currency,brand,"
"gtin,marketplace,shop_name\n"
)
batch_size = 1000
offset = 0
@@ -194,17 +203,22 @@ class ProductService:
for product in products:
# Create CSV row with marketplace fields
row = (f'"{product.product_id}","{product.title or ""}","{product.description or ""}",'
f'"{product.link or ""}","{product.image_link or ""}","{product.availability or ""}",'
f'"{product.price or ""}","{product.currency or ""}","{product.brand or ""}",'
f'"{product.gtin or ""}","{product.marketplace or ""}","{product.shop_name or ""}"\n')
row = (
f'"{product.product_id}","{product.title or ""}","{product.description or ""}",'
f'"{product.link or ""}","{product.image_link or ""}","{product.availability or ""}",'
f'"{product.price or ""}","{product.currency or ""}","{product.brand or ""}",'
f'"{product.gtin or ""}","{product.marketplace or ""}","{product.shop_name or ""}"\n'
)
yield row
offset += batch_size
def product_exists(self, db: Session, product_id: str) -> bool:
"""Check if product exists by ID"""
return db.query(Product).filter(Product.product_id == product_id).first() is not None
return (
db.query(Product).filter(Product.product_id == product_id).first()
is not None
)
# Create service instance

View File

@@ -1,12 +1,13 @@
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from fastapi import HTTPException
from sqlalchemy import func
from sqlalchemy.orm import Session
from fastapi import HTTPException
from datetime import datetime
import logging
from typing import List, Optional, Tuple, Dict, Any
from models.database_models import User, Shop, Product, ShopProduct
from models.api_models import ShopCreate, ShopProductCreate
from models.database_models import Product, Shop, ShopProduct, User
logger = logging.getLogger(__name__)
@@ -14,7 +15,9 @@ logger = logging.getLogger(__name__)
class ShopService:
"""Service class for shop operations following the application's service pattern"""
def create_shop(self, db: Session, shop_data: ShopCreate, current_user: User) -> Shop:
def create_shop(
self, db: Session, shop_data: ShopCreate, current_user: User
) -> Shop:
"""
Create a new shop
@@ -33,39 +36,43 @@ class ShopService:
normalized_shop_code = shop_data.shop_code.upper()
# Check if shop code already exists (case-insensitive check against existing data)
existing_shop = db.query(Shop).filter(
func.upper(Shop.shop_code) == normalized_shop_code
).first()
existing_shop = (
db.query(Shop)
.filter(func.upper(Shop.shop_code) == normalized_shop_code)
.first()
)
if existing_shop:
raise HTTPException(status_code=400, detail="Shop code already exists")
# Create shop with uppercase code
shop_dict = shop_data.model_dump() # Fixed deprecated .dict() method
shop_dict['shop_code'] = normalized_shop_code # Store as uppercase
shop_dict["shop_code"] = normalized_shop_code # Store as uppercase
new_shop = Shop(
**shop_dict,
owner_id=current_user.id,
is_active=True,
is_verified=(current_user.role == "admin")
is_verified=(current_user.role == "admin"),
)
db.add(new_shop)
db.commit()
db.refresh(new_shop)
logger.info(f"New shop created: {new_shop.shop_code} by {current_user.username}")
logger.info(
f"New shop created: {new_shop.shop_code} by {current_user.username}"
)
return new_shop
def get_shops(
self,
db: Session,
current_user: User,
skip: int = 0,
limit: int = 100,
active_only: bool = True,
verified_only: bool = False
self,
db: Session,
current_user: User,
skip: int = 0,
limit: int = 100,
active_only: bool = True,
verified_only: bool = False,
) -> Tuple[List[Shop], int]:
"""
Get shops with filtering
@@ -86,8 +93,8 @@ class ShopService:
# Non-admin users can only see active and verified shops, plus their own
if current_user.role != "admin":
query = query.filter(
(Shop.is_active == True) &
((Shop.is_verified == True) | (Shop.owner_id == current_user.id))
(Shop.is_active == True)
& ((Shop.is_verified == True) | (Shop.owner_id == current_user.id))
)
else:
# Admin can apply filters
@@ -117,22 +124,25 @@ class ShopService:
HTTPException: If shop not found or access denied
"""
# Explicit type hint to help type checker shop: Optional[Shop]
shop: Optional[Shop] = db.query(Shop).filter(func.upper(Shop.shop_code) == shop_code.upper()).first()
shop: Optional[Shop] = (
db.query(Shop)
.filter(func.upper(Shop.shop_code) == shop_code.upper())
.first()
)
if not shop:
raise HTTPException(status_code=404, detail="Shop not found")
# Non-admin users can only see active verified shops or their own shops
if current_user.role != "admin":
if not shop.is_active or (not shop.is_verified and shop.owner_id != current_user.id):
if not shop.is_active or (
not shop.is_verified and shop.owner_id != current_user.id
):
raise HTTPException(status_code=404, detail="Shop not found")
return shop
def add_product_to_shop(
self,
db: Session,
shop: Shop,
shop_product: ShopProductCreate
self, db: Session, shop: Shop, shop_product: ShopProductCreate
) -> ShopProduct:
"""
Add existing product to shop catalog with shop-specific settings
@@ -149,24 +159,35 @@ class ShopService:
HTTPException: If product not found or already in shop
"""
# Check if product exists
product = db.query(Product).filter(Product.product_id == shop_product.product_id).first()
product = (
db.query(Product)
.filter(Product.product_id == shop_product.product_id)
.first()
)
if not product:
raise HTTPException(status_code=404, detail="Product not found in marketplace catalog")
raise HTTPException(
status_code=404, detail="Product not found in marketplace catalog"
)
# Check if product already in shop
existing_shop_product = db.query(ShopProduct).filter(
ShopProduct.shop_id == shop.id,
ShopProduct.product_id == product.id
).first()
existing_shop_product = (
db.query(ShopProduct)
.filter(
ShopProduct.shop_id == shop.id, ShopProduct.product_id == product.id
)
.first()
)
if existing_shop_product:
raise HTTPException(status_code=400, detail="Product already in shop catalog")
raise HTTPException(
status_code=400, detail="Product already in shop catalog"
)
# Create shop-product association
new_shop_product = ShopProduct(
shop_id=shop.id,
product_id=product.id,
**shop_product.model_dump(exclude={'product_id'})
**shop_product.model_dump(exclude={"product_id"}),
)
db.add(new_shop_product)
@@ -180,14 +201,14 @@ class ShopService:
return new_shop_product
def get_shop_products(
self,
db: Session,
shop: Shop,
current_user: User,
skip: int = 0,
limit: int = 100,
active_only: bool = True,
featured_only: bool = False
self,
db: Session,
shop: Shop,
current_user: User,
skip: int = 0,
limit: int = 100,
active_only: bool = True,
featured_only: bool = False,
) -> Tuple[List[ShopProduct], int]:
"""
Get products in shop catalog with filtering
@@ -239,10 +260,14 @@ class ShopService:
def product_in_shop(self, db: Session, shop_id: int, product_id: int) -> bool:
"""Check if product is already in shop"""
return db.query(ShopProduct).filter(
ShopProduct.shop_id == shop_id,
ShopProduct.product_id == product_id
).first() is not None
return (
db.query(ShopProduct)
.filter(
ShopProduct.shop_id == shop_id, ShopProduct.product_id == product_id
)
.first()
is not None
)
def is_shop_owner(self, shop: Shop, user: User) -> bool:
"""Check if user is shop owner"""

View File

@@ -1,10 +1,11 @@
import logging
from typing import Any, Dict, List
from sqlalchemy import func
from sqlalchemy.orm import Session
import logging
from typing import List, Dict, Any
from models.database_models import User, Product, Stock
from models.api_models import StatsResponse, MarketplaceStatsResponse
from models.api_models import MarketplaceStatsResponse, StatsResponse
from models.database_models import Product, Stock, User
logger = logging.getLogger(__name__)
@@ -25,26 +26,37 @@ class StatsService:
# Use more efficient queries with proper indexes
total_products = db.query(Product).count()
unique_brands = db.query(Product.brand).filter(
Product.brand.isnot(None),
Product.brand != ""
).distinct().count()
unique_brands = (
db.query(Product.brand)
.filter(Product.brand.isnot(None), Product.brand != "")
.distinct()
.count()
)
unique_categories = db.query(Product.google_product_category).filter(
Product.google_product_category.isnot(None),
Product.google_product_category != ""
).distinct().count()
unique_categories = (
db.query(Product.google_product_category)
.filter(
Product.google_product_category.isnot(None),
Product.google_product_category != "",
)
.distinct()
.count()
)
# New marketplace statistics
unique_marketplaces = db.query(Product.marketplace).filter(
Product.marketplace.isnot(None),
Product.marketplace != ""
).distinct().count()
unique_marketplaces = (
db.query(Product.marketplace)
.filter(Product.marketplace.isnot(None), Product.marketplace != "")
.distinct()
.count()
)
unique_shops = db.query(Product.shop_name).filter(
Product.shop_name.isnot(None),
Product.shop_name != ""
).distinct().count()
unique_shops = (
db.query(Product.shop_name)
.filter(Product.shop_name.isnot(None), Product.shop_name != "")
.distinct()
.count()
)
# Stock statistics
total_stock_entries = db.query(Stock).count()
@@ -57,10 +69,12 @@ class StatsService:
"unique_marketplaces": unique_marketplaces,
"unique_shops": unique_shops,
"total_stock_entries": total_stock_entries,
"total_inventory_quantity": total_inventory
"total_inventory_quantity": total_inventory,
}
logger.info(f"Generated comprehensive stats: {total_products} products, {unique_marketplaces} marketplaces")
logger.info(
f"Generated comprehensive stats: {total_products} products, {unique_marketplaces} marketplaces"
)
return stats_data
def get_marketplace_breakdown_stats(self, db: Session) -> List[Dict[str, Any]]:
@@ -74,25 +88,31 @@ class StatsService:
List of dictionaries containing marketplace statistics
"""
# Query to get stats per marketplace
marketplace_stats = db.query(
Product.marketplace,
func.count(Product.id).label('total_products'),
func.count(func.distinct(Product.shop_name)).label('unique_shops'),
func.count(func.distinct(Product.brand)).label('unique_brands')
).filter(
Product.marketplace.isnot(None)
).group_by(Product.marketplace).all()
marketplace_stats = (
db.query(
Product.marketplace,
func.count(Product.id).label("total_products"),
func.count(func.distinct(Product.shop_name)).label("unique_shops"),
func.count(func.distinct(Product.brand)).label("unique_brands"),
)
.filter(Product.marketplace.isnot(None))
.group_by(Product.marketplace)
.all()
)
stats_list = [
{
"marketplace": stat.marketplace,
"total_products": stat.total_products,
"unique_shops": stat.unique_shops,
"unique_brands": stat.unique_brands
} for stat in marketplace_stats
"unique_brands": stat.unique_brands,
}
for stat in marketplace_stats
]
logger.info(f"Generated marketplace breakdown stats for {len(stats_list)} marketplaces")
logger.info(
f"Generated marketplace breakdown stats for {len(stats_list)} marketplaces"
)
return stats_list
def get_product_count(self, db: Session) -> int:
@@ -101,31 +121,42 @@ class StatsService:
def get_unique_brands_count(self, db: Session) -> int:
"""Get count of unique brands"""
return db.query(Product.brand).filter(
Product.brand.isnot(None),
Product.brand != ""
).distinct().count()
return (
db.query(Product.brand)
.filter(Product.brand.isnot(None), Product.brand != "")
.distinct()
.count()
)
def get_unique_categories_count(self, db: Session) -> int:
"""Get count of unique categories"""
return db.query(Product.google_product_category).filter(
Product.google_product_category.isnot(None),
Product.google_product_category != ""
).distinct().count()
return (
db.query(Product.google_product_category)
.filter(
Product.google_product_category.isnot(None),
Product.google_product_category != "",
)
.distinct()
.count()
)
def get_unique_marketplaces_count(self, db: Session) -> int:
"""Get count of unique marketplaces"""
return db.query(Product.marketplace).filter(
Product.marketplace.isnot(None),
Product.marketplace != ""
).distinct().count()
return (
db.query(Product.marketplace)
.filter(Product.marketplace.isnot(None), Product.marketplace != "")
.distinct()
.count()
)
def get_unique_shops_count(self, db: Session) -> int:
"""Get count of unique shops"""
return db.query(Product.shop_name).filter(
Product.shop_name.isnot(None),
Product.shop_name != ""
).distinct().count()
return (
db.query(Product.shop_name)
.filter(Product.shop_name.isnot(None), Product.shop_name != "")
.distinct()
.count()
)
def get_stock_statistics(self, db: Session) -> Dict[str, int]:
"""
@@ -142,25 +173,35 @@ class StatsService:
return {
"total_stock_entries": total_stock_entries,
"total_inventory_quantity": total_inventory
"total_inventory_quantity": total_inventory,
}
def get_brands_by_marketplace(self, db: Session, marketplace: str) -> List[str]:
"""Get unique brands for a specific marketplace"""
brands = db.query(Product.brand).filter(
Product.marketplace == marketplace,
Product.brand.isnot(None),
Product.brand != ""
).distinct().all()
brands = (
db.query(Product.brand)
.filter(
Product.marketplace == marketplace,
Product.brand.isnot(None),
Product.brand != "",
)
.distinct()
.all()
)
return [brand[0] for brand in brands]
def get_shops_by_marketplace(self, db: Session, marketplace: str) -> List[str]:
"""Get unique shops for a specific marketplace"""
shops = db.query(Product.shop_name).filter(
Product.marketplace == marketplace,
Product.shop_name.isnot(None),
Product.shop_name != ""
).distinct().all()
shops = (
db.query(Product.shop_name)
.filter(
Product.marketplace == marketplace,
Product.shop_name.isnot(None),
Product.shop_name != "",
)
.distinct()
.all()
)
return [shop[0] for shop in shops]
def get_products_by_marketplace(self, db: Session, marketplace: str) -> int:

View File

@@ -1,10 +1,13 @@
from sqlalchemy.orm import Session
from models.database_models import Stock, Product
from models.api_models import StockCreate, StockAdd, StockUpdate, StockLocationResponse, StockSummaryResponse
from utils.data_processing import GTINProcessor
from typing import Optional, List, Tuple
from datetime import datetime
import logging
from datetime import datetime
from typing import List, Optional, Tuple
from sqlalchemy.orm import Session
from models.api_models import (StockAdd, StockCreate, StockLocationResponse,
StockSummaryResponse, StockUpdate)
from models.database_models import Product, Stock
from utils.data_processing import GTINProcessor
logger = logging.getLogger(__name__)
@@ -26,10 +29,11 @@ class StockService:
location = stock_data.location.strip().upper()
# Check if stock entry already exists for this GTIN and location
existing_stock = db.query(Stock).filter(
Stock.gtin == normalized_gtin,
Stock.location == location
).first()
existing_stock = (
db.query(Stock)
.filter(Stock.gtin == normalized_gtin, Stock.location == location)
.first()
)
if existing_stock:
# Update existing stock (SET to exact quantity)
@@ -39,19 +43,20 @@ class StockService:
db.commit()
db.refresh(existing_stock)
logger.info(
f"Updated stock for GTIN {normalized_gtin} at {location}: {old_quantity}{stock_data.quantity}")
f"Updated stock for GTIN {normalized_gtin} at {location}: {old_quantity}{stock_data.quantity}"
)
return existing_stock
else:
# Create new stock entry
new_stock = Stock(
gtin=normalized_gtin,
location=location,
quantity=stock_data.quantity
gtin=normalized_gtin, location=location, quantity=stock_data.quantity
)
db.add(new_stock)
db.commit()
db.refresh(new_stock)
logger.info(f"Created new stock for GTIN {normalized_gtin} at {location}: {stock_data.quantity}")
logger.info(
f"Created new stock for GTIN {normalized_gtin} at {location}: {stock_data.quantity}"
)
return new_stock
def add_stock(self, db: Session, stock_data: StockAdd) -> Stock:
@@ -63,10 +68,11 @@ class StockService:
location = stock_data.location.strip().upper()
# Check if stock entry already exists for this GTIN and location
existing_stock = db.query(Stock).filter(
Stock.gtin == normalized_gtin,
Stock.location == location
).first()
existing_stock = (
db.query(Stock)
.filter(Stock.gtin == normalized_gtin, Stock.location == location)
.first()
)
if existing_stock:
# Add to existing stock
@@ -76,19 +82,20 @@ class StockService:
db.commit()
db.refresh(existing_stock)
logger.info(
f"Added stock for GTIN {normalized_gtin} at {location}: {old_quantity} + {stock_data.quantity} = {existing_stock.quantity}")
f"Added stock for GTIN {normalized_gtin} at {location}: {old_quantity} + {stock_data.quantity} = {existing_stock.quantity}"
)
return existing_stock
else:
# Create new stock entry with the quantity
new_stock = Stock(
gtin=normalized_gtin,
location=location,
quantity=stock_data.quantity
gtin=normalized_gtin, location=location, quantity=stock_data.quantity
)
db.add(new_stock)
db.commit()
db.refresh(new_stock)
logger.info(f"Created new stock for GTIN {normalized_gtin} at {location}: {stock_data.quantity}")
logger.info(
f"Created new stock for GTIN {normalized_gtin} at {location}: {stock_data.quantity}"
)
return new_stock
def remove_stock(self, db: Session, stock_data: StockAdd) -> Stock:
@@ -100,18 +107,22 @@ class StockService:
location = stock_data.location.strip().upper()
# Find existing stock entry
existing_stock = db.query(Stock).filter(
Stock.gtin == normalized_gtin,
Stock.location == location
).first()
existing_stock = (
db.query(Stock)
.filter(Stock.gtin == normalized_gtin, Stock.location == location)
.first()
)
if not existing_stock:
raise ValueError(f"No stock found for GTIN {normalized_gtin} at location {location}")
raise ValueError(
f"No stock found for GTIN {normalized_gtin} at location {location}"
)
# Check if we have enough stock to remove
if existing_stock.quantity < stock_data.quantity:
raise ValueError(
f"Insufficient stock. Available: {existing_stock.quantity}, Requested to remove: {stock_data.quantity}")
f"Insufficient stock. Available: {existing_stock.quantity}, Requested to remove: {stock_data.quantity}"
)
# Remove from existing stock
old_quantity = existing_stock.quantity
@@ -120,7 +131,8 @@ class StockService:
db.commit()
db.refresh(existing_stock)
logger.info(
f"Removed stock for GTIN {normalized_gtin} at {location}: {old_quantity} - {stock_data.quantity} = {existing_stock.quantity}")
f"Removed stock for GTIN {normalized_gtin} at {location}: {old_quantity} - {stock_data.quantity} = {existing_stock.quantity}"
)
return existing_stock
def get_stock_by_gtin(self, db: Session, gtin: str) -> StockSummaryResponse:
@@ -141,10 +153,9 @@ class StockService:
for entry in stock_entries:
total_quantity += entry.quantity
locations.append(StockLocationResponse(
location=entry.location,
quantity=entry.quantity
))
locations.append(
StockLocationResponse(location=entry.location, quantity=entry.quantity)
)
# Try to get product title for reference
product = db.query(Product).filter(Product.gtin == normalized_gtin).first()
@@ -154,7 +165,7 @@ class StockService:
gtin=normalized_gtin,
total_quantity=total_quantity,
locations=locations,
product_title=product_title
product_title=product_title,
)
def get_total_stock(self, db: Session, gtin: str) -> dict:
@@ -174,16 +185,16 @@ class StockService:
"gtin": normalized_gtin,
"total_quantity": total_quantity,
"product_title": product.title if product else None,
"locations_count": len(total_stock)
"locations_count": len(total_stock),
}
def get_all_stock(
self,
db: Session,
skip: int = 0,
limit: int = 100,
location: Optional[str] = None,
gtin: Optional[str] = None
self,
db: Session,
skip: int = 0,
limit: int = 100,
location: Optional[str] = None,
gtin: Optional[str] = None,
) -> List[Stock]:
"""Get all stock entries with optional filtering"""
query = db.query(Stock)
@@ -198,7 +209,9 @@ class StockService:
return query.offset(skip).limit(limit).all()
def update_stock(self, db: Session, stock_id: int, stock_update: StockUpdate) -> Stock:
def update_stock(
self, db: Session, stock_id: int, stock_update: StockUpdate
) -> Stock:
"""Update stock quantity for a specific stock entry"""
stock_entry = db.query(Stock).filter(Stock.id == stock_id).first()
if not stock_entry:
@@ -209,7 +222,9 @@ class StockService:
db.commit()
db.refresh(stock_entry)
logger.info(f"Updated stock entry {stock_id} to quantity {stock_update.quantity}")
logger.info(
f"Updated stock entry {stock_id} to quantity {stock_update.quantity}"
)
return stock_entry
def delete_stock(self, db: Session, stock_id: int) -> bool: