test updates to take into account exception management

This commit is contained in:
2025-09-27 13:47:36 +02:00
parent 3e720212d9
commit 6b9817f179
38 changed files with 2951 additions and 871 deletions

View File

@@ -27,6 +27,34 @@ router = APIRouter()
logger = logging.getLogger(__name__)
@router.get("/product/export-csv")
async def export_csv(
marketplace: Optional[str] = Query(None, description="Filter by marketplace"),
shop_name: Optional[str] = Query(None, description="Filter by shop name"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Export products as CSV with streaming and marketplace filtering (Protected)."""
def generate_csv():
return product_service.generate_csv_export(
db=db, marketplace=marketplace, shop_name=shop_name
)
filename = "products_export"
if marketplace:
filename += f"_{marketplace}"
if shop_name:
filename += f"_{shop_name}"
filename += ".csv"
return StreamingResponse(
generate_csv(),
media_type="text/csv",
headers={"Content-Disposition": f"attachment; filename={filename}"},
)
@router.get("/product", response_model=ProductListResponse)
def get_products(
skip: int = Query(0, ge=0),
@@ -112,30 +140,3 @@ def delete_product(
product_service.delete_product(db, product_id)
return {"message": "Product and associated stock deleted successfully"}
@router.get("/product/export-csv")
async def export_csv(
marketplace: Optional[str] = Query(None, description="Filter by marketplace"),
shop_name: Optional[str] = Query(None, description="Filter by shop name"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Export products as CSV with streaming and marketplace filtering (Protected)."""
def generate_csv():
return product_service.generate_csv_export(
db=db, marketplace=marketplace, shop_name=shop_name
)
filename = "products_export"
if marketplace:
filename += f"_{marketplace}"
if shop_name:
filename += f"_{shop_name}"
filename += ".csv"
return StreamingResponse(
generate_csv(),
media_type="text/csv",
headers={"Content-Disposition": f"attachment; filename={filename}"},
)

View File

@@ -90,7 +90,10 @@ def setup_exception_handlers(app):
for error in exc.errors():
clean_error = {}
for key, value in error.items():
if key == 'ctx' and isinstance(value, dict):
if key == 'input' and isinstance(value, bytes):
# Convert bytes to string representation for JSON serialization
clean_error[key] = f"<bytes: {len(value)} bytes>"
elif key == 'ctx' and isinstance(value, dict):
# Handle the 'ctx' field that contains ValueError objects
clean_ctx = {}
for ctx_key, ctx_value in value.items():
@@ -99,6 +102,9 @@ def setup_exception_handlers(app):
else:
clean_ctx[ctx_key] = ctx_value
clean_error[key] = clean_ctx
elif isinstance(value, bytes):
# Handle any other bytes objects
clean_error[key] = f"<bytes: {len(value)} bytes>"
else:
clean_error[key] = value
clean_errors.append(clean_error)
@@ -138,6 +144,23 @@ def setup_exception_handlers(app):
}
)
@app.exception_handler(404)
async def not_found_handler(request: Request, exc):
"""Handle all 404 errors with consistent format."""
logger.warning(f"404 Not Found: {request.method} {request.url}")
return JSONResponse(
status_code=404,
content={
"error_code": "ENDPOINT_NOT_FOUND",
"message": f"Endpoint not found: {request.url.path}",
"status_code": 404,
"details": {
"path": request.url.path,
"method": request.method
}
}
)
# Utility functions for common exception scenarios
def raise_not_found(resource_type: str, identifier: str) -> None:

View File

@@ -337,27 +337,6 @@ class AdminService:
completed_at=job.completed_at,
)
# Legacy methods for backward compatibility (mark as deprecated)
def get_user_by_id(self, db: Session, user_id: int) -> Optional[User]:
"""Get user by ID. DEPRECATED: Use _get_user_by_id_or_raise instead."""
logger.warning("get_user_by_id is deprecated, use proper exception handling")
return db.query(User).filter(User.id == user_id).first()
def get_shop_by_id(self, db: Session, shop_id: int) -> Optional[Shop]:
"""Get shop by ID. DEPRECATED: Use _get_shop_by_id_or_raise instead."""
logger.warning("get_shop_by_id is deprecated, use proper exception handling")
return db.query(Shop).filter(Shop.id == shop_id).first()
def user_exists(self, db: Session, user_id: int) -> bool:
"""Check if user exists by ID. DEPRECATED: Use proper exception handling."""
logger.warning("user_exists is deprecated, use proper exception handling")
return db.query(User).filter(User.id == user_id).first() is not None
def shop_exists(self, db: Session, shop_id: int) -> bool:
"""Check if shop exists by ID. DEPRECATED: Use proper exception handling."""
logger.warning("shop_exists is deprecated, use proper exception handling")
return db.query(Shop).filter(Shop.id == shop_id).first() is not None
# Create service instance following the same pattern as product_service
admin_service = AdminService()

View File

@@ -170,16 +170,6 @@ class AuthService:
"""Check if username already exists."""
return db.query(User).filter(User.username == username).first() is not None
# Legacy methods for backward compatibility (deprecated)
def email_exists(self, db: Session, email: str) -> bool:
"""Check if email already exists. DEPRECATED: Use proper exception handling."""
logger.warning("email_exists is deprecated, use proper exception handling")
return self._email_exists(db, email)
def username_exists(self, db: Session, username: str) -> bool:
"""Check if username already exists. DEPRECATED: Use proper exception handling."""
logger.warning("username_exists is deprecated, use proper exception handling")
return self._username_exists(db, username)
# Create service instance following the same pattern as other services

View File

@@ -8,9 +8,10 @@ This module provides classes and functions for:
- Stock information integration
- CSV export functionality
"""
import csv
import logging
from datetime import datetime, timezone
from io import StringIO
from typing import Generator, List, Optional, Tuple
from sqlalchemy.exc import IntegrityError
@@ -41,21 +42,7 @@ class ProductService:
self.price_processor = PriceProcessor()
def create_product(self, db: Session, product_data: ProductCreate) -> Product:
"""
Create a new product with validation.
Args:
db: Database session
product_data: Product creation data
Returns:
Created Product object
Raises:
ProductAlreadyExistsException: If product with ID already exists
InvalidProductDataException: If product data is invalid
ProductValidationException: If validation fails
"""
"""Create a new product with validation."""
try:
# Process and validate GTIN if provided
if product_data.gtin:
@@ -73,8 +60,9 @@ class ProductService:
if parsed_price:
product_data.price = parsed_price
product_data.currency = currency
except Exception as e:
raise InvalidProductDataException(f"Invalid price format: {str(e)}", field="price")
except ValueError as e:
# Convert ValueError to domain-specific exception
raise InvalidProductDataException(str(e), field="price")
# Set default marketplace if not provided
if not product_data.marketplace:
@@ -199,25 +187,8 @@ class ProductService:
logger.error(f"Error getting products with filters: {str(e)}")
raise ValidationException("Failed to retrieve products")
def update_product(
self, db: Session, product_id: str, product_update: ProductUpdate
) -> Product:
"""
Update product with validation.
Args:
db: Database session
product_id: Product ID to update
product_update: Update data
Returns:
Updated Product object
Raises:
ProductNotFoundException: If product doesn't exist
InvalidProductDataException: If update data is invalid
ProductValidationException: If validation fails
"""
def update_product(self, db: Session, product_id: str, product_update: ProductUpdate) -> Product:
"""Update product with validation."""
try:
product = self.get_product_by_id_or_raise(db, product_id)
@@ -240,8 +211,9 @@ class ProductService:
if parsed_price:
update_data["price"] = parsed_price
update_data["currency"] = currency
except Exception as e:
raise InvalidProductDataException(f"Invalid price format: {str(e)}", field="price")
except ValueError as e:
# Convert ValueError to domain-specific exception
raise InvalidProductDataException(str(e), field="price")
# Validate required fields if being updated
if "title" in update_data and (not update_data["title"] or not update_data["title"].strip()):
@@ -329,6 +301,11 @@ class ProductService:
logger.error(f"Error getting stock info for GTIN {gtin}: {str(e)}")
return None
import csv
from io import StringIO
from typing import Generator, Optional
from sqlalchemy.orm import Session
def generate_csv_export(
self,
db: Session,
@@ -336,7 +313,7 @@ class ProductService:
shop_name: Optional[str] = None,
) -> Generator[str, None, None]:
"""
Generate CSV export with streaming for memory efficiency.
Generate CSV export with streaming for memory efficiency and proper CSV escaping.
Args:
db: Database session
@@ -344,14 +321,25 @@ class ProductService:
shop_name: Optional shop name filter
Yields:
CSV content as strings
CSV content as strings with proper escaping
"""
try:
# CSV header
yield (
"product_id,title,description,link,image_link,availability,price,currency,brand,"
"gtin,marketplace,shop_name\n"
)
# Create a StringIO buffer for CSV writing
output = StringIO()
writer = csv.writer(output, quoting=csv.QUOTE_MINIMAL)
# Write header row
headers = [
"product_id", "title", "description", "link", "image_link",
"availability", "price", "currency", "brand", "gtin",
"marketplace", "shop_name"
]
writer.writerow(headers)
yield output.getvalue()
# Clear buffer for reuse
output.seek(0)
output.truncate(0)
batch_size = 1000
offset = 0
@@ -370,14 +358,28 @@ class ProductService:
break
for product in products:
# Create CSV row with marketplace fields
row = (
f'"{product.product_id}","{product.title or ""}","{product.description or ""}",'
f'"{product.link or ""}","{product.image_link or ""}","{product.availability or ""}",'
f'"{product.price or ""}","{product.currency or ""}","{product.brand or ""}",'
f'"{product.gtin or ""}","{product.marketplace or ""}","{product.shop_name or ""}"\n'
)
yield row
# Create CSV row with proper escaping
row_data = [
product.product_id or "",
product.title or "",
product.description or "",
product.link or "",
product.image_link or "",
product.availability or "",
product.price or "",
product.currency or "",
product.brand or "",
product.gtin or "",
product.marketplace or "",
product.shop_name or "",
]
writer.writerow(row_data)
yield output.getvalue()
# Clear buffer for next row
output.seek(0)
output.truncate(0)
offset += batch_size

View File

@@ -353,45 +353,5 @@ class ShopService:
"""Check if user is shop owner."""
return shop.owner_id == user.id
# Legacy methods for backward compatibility (deprecated)
def get_shop_by_id(self, db: Session, shop_id: int) -> Optional[Shop]:
"""Get shop by ID. DEPRECATED: Use proper exception handling."""
logger.warning("get_shop_by_id is deprecated, use proper exception handling")
try:
return db.query(Shop).filter(Shop.id == shop_id).first()
except Exception as e:
logger.error(f"Error getting shop by ID: {str(e)}")
return None
def shop_code_exists(self, db: Session, shop_code: str) -> bool:
"""Check if shop code exists. DEPRECATED: Use proper exception handling."""
logger.warning("shop_code_exists is deprecated, use proper exception handling")
return self._shop_code_exists(db, shop_code)
def get_product_by_id(self, db: Session, product_id: str) -> Optional[Product]:
"""Get product by ID. DEPRECATED: Use proper exception handling."""
logger.warning("get_product_by_id is deprecated, use proper exception handling")
try:
return db.query(Product).filter(Product.product_id == product_id).first()
except Exception as e:
logger.error(f"Error getting product by ID: {str(e)}")
return None
def product_in_shop(self, db: Session, shop_id: int, product_id: int) -> bool:
"""Check if product in shop. DEPRECATED: Use proper exception handling."""
logger.warning("product_in_shop is deprecated, use proper exception handling")
return self._product_in_shop(db, shop_id, product_id)
def is_shop_owner(self, shop: Shop, user: User) -> bool:
"""Check if user is shop owner. DEPRECATED: Use _is_shop_owner."""
logger.warning("is_shop_owner is deprecated, use _is_shop_owner")
return self._is_shop_owner(shop, user)
def can_view_shop(self, shop: Shop, user: User) -> bool:
"""Check if user can view shop. DEPRECATED: Use _can_access_shop."""
logger.warning("can_view_shop is deprecated, use _can_access_shop")
return self._can_access_shop(shop, user)
# Create service instance following the same pattern as other services
shop_service = ShopService()

View File

@@ -294,47 +294,5 @@ class StatsService:
"""Get product count for a specific marketplace."""
return db.query(Product).filter(Product.marketplace == marketplace).count()
# Legacy methods for backward compatibility (deprecated)
def get_product_count(self, db: Session) -> int:
"""Get total product count. DEPRECATED: Use _get_product_count."""
logger.warning("get_product_count is deprecated, use _get_product_count")
return self._get_product_count(db)
def get_unique_brands_count(self, db: Session) -> int:
"""Get unique brands count. DEPRECATED: Use _get_unique_brands_count."""
logger.warning("get_unique_brands_count is deprecated, use _get_unique_brands_count")
return self._get_unique_brands_count(db)
def get_unique_categories_count(self, db: Session) -> int:
"""Get unique categories count. DEPRECATED: Use _get_unique_categories_count."""
logger.warning("get_unique_categories_count is deprecated, use _get_unique_categories_count")
return self._get_unique_categories_count(db)
def get_unique_marketplaces_count(self, db: Session) -> int:
"""Get unique marketplaces count. DEPRECATED: Use _get_unique_marketplaces_count."""
logger.warning("get_unique_marketplaces_count is deprecated, use _get_unique_marketplaces_count")
return self._get_unique_marketplaces_count(db)
def get_unique_shops_count(self, db: Session) -> int:
"""Get unique shops count. DEPRECATED: Use _get_unique_shops_count."""
logger.warning("get_unique_shops_count is deprecated, use _get_unique_shops_count")
return self._get_unique_shops_count(db)
def get_brands_by_marketplace(self, db: Session, marketplace: str) -> List[str]:
"""Get brands by marketplace. DEPRECATED: Use _get_brands_by_marketplace."""
logger.warning("get_brands_by_marketplace is deprecated, use _get_brands_by_marketplace")
return self._get_brands_by_marketplace(db, marketplace)
def get_shops_by_marketplace(self, db: Session, marketplace: str) -> List[str]:
"""Get shops by marketplace. DEPRECATED: Use _get_shops_by_marketplace."""
logger.warning("get_shops_by_marketplace is deprecated, use _get_shops_by_marketplace")
return self._get_shops_by_marketplace(db, marketplace)
def get_products_by_marketplace(self, db: Session, marketplace: str) -> int:
"""Get products by marketplace. DEPRECATED: Use _get_products_by_marketplace_count."""
logger.warning("get_products_by_marketplace is deprecated, use _get_products_by_marketplace_count")
return self._get_products_by_marketplace_count(db, marketplace)
# Create service instance following the same pattern as other services
stats_service = StatsService()

View File

@@ -217,7 +217,8 @@ class StockService:
)
return existing_stock
except (StockNotFoundException, InsufficientStockException, InvalidQuantityException, NegativeStockException):
except (StockValidationException, StockNotFoundException, InsufficientStockException, InvalidQuantityException,
NegativeStockException):
db.rollback()
raise # Re-raise custom exceptions
except Exception as e:
@@ -564,21 +565,6 @@ class StockService:
raise StockNotFoundException(str(stock_id))
return stock_entry
# Legacy methods for backward compatibility (deprecated)
def normalize_gtin(self, gtin_value) -> Optional[str]:
"""Normalize GTIN format. DEPRECATED: Use _normalize_gtin."""
logger.warning("normalize_gtin is deprecated, use _normalize_gtin")
return self._normalize_gtin(gtin_value)
def get_stock_by_id(self, db: Session, stock_id: int) -> Optional[Stock]:
"""Get stock by ID. DEPRECATED: Use proper exception handling."""
logger.warning("get_stock_by_id is deprecated, use proper exception handling")
try:
return db.query(Stock).filter(Stock.id == stock_id).first()
except Exception as e:
logger.error(f"Error getting stock by ID: {str(e)}")
return None
# Create service instance
stock_service = StockService()

View File

@@ -103,7 +103,8 @@ class PriceProcessor:
"""
Parse price string into (price, currency) tuple.
Returns (None, None) if parsing fails
Raises ValueError if parsing fails for non-empty input.
Returns (None, None) for empty/null input.
"""
if not price_str or pd.isna(price_str):
return None, None
@@ -136,5 +137,6 @@ class PriceProcessor:
except ValueError:
pass
logger.warning(f"Could not parse price: '{price_str}'")
return price_str, None
# If we get here, parsing failed completely
logger.error(f"Could not parse price: '{price_str}'")
raise ValueError(f"Invalid price format: '{price_str}'")