test updates to take into account exception management
This commit is contained in:
@@ -27,6 +27,34 @@ router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.get("/product/export-csv")
|
||||
async def export_csv(
|
||||
marketplace: Optional[str] = Query(None, description="Filter by marketplace"),
|
||||
shop_name: Optional[str] = Query(None, description="Filter by shop name"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Export products as CSV with streaming and marketplace filtering (Protected)."""
|
||||
|
||||
def generate_csv():
|
||||
return product_service.generate_csv_export(
|
||||
db=db, marketplace=marketplace, shop_name=shop_name
|
||||
)
|
||||
|
||||
filename = "products_export"
|
||||
if marketplace:
|
||||
filename += f"_{marketplace}"
|
||||
if shop_name:
|
||||
filename += f"_{shop_name}"
|
||||
filename += ".csv"
|
||||
|
||||
return StreamingResponse(
|
||||
generate_csv(),
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": f"attachment; filename={filename}"},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/product", response_model=ProductListResponse)
|
||||
def get_products(
|
||||
skip: int = Query(0, ge=0),
|
||||
@@ -112,30 +140,3 @@ def delete_product(
|
||||
product_service.delete_product(db, product_id)
|
||||
return {"message": "Product and associated stock deleted successfully"}
|
||||
|
||||
|
||||
@router.get("/product/export-csv")
|
||||
async def export_csv(
|
||||
marketplace: Optional[str] = Query(None, description="Filter by marketplace"),
|
||||
shop_name: Optional[str] = Query(None, description="Filter by shop name"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Export products as CSV with streaming and marketplace filtering (Protected)."""
|
||||
|
||||
def generate_csv():
|
||||
return product_service.generate_csv_export(
|
||||
db=db, marketplace=marketplace, shop_name=shop_name
|
||||
)
|
||||
|
||||
filename = "products_export"
|
||||
if marketplace:
|
||||
filename += f"_{marketplace}"
|
||||
if shop_name:
|
||||
filename += f"_{shop_name}"
|
||||
filename += ".csv"
|
||||
|
||||
return StreamingResponse(
|
||||
generate_csv(),
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": f"attachment; filename={filename}"},
|
||||
)
|
||||
|
||||
@@ -90,7 +90,10 @@ def setup_exception_handlers(app):
|
||||
for error in exc.errors():
|
||||
clean_error = {}
|
||||
for key, value in error.items():
|
||||
if key == 'ctx' and isinstance(value, dict):
|
||||
if key == 'input' and isinstance(value, bytes):
|
||||
# Convert bytes to string representation for JSON serialization
|
||||
clean_error[key] = f"<bytes: {len(value)} bytes>"
|
||||
elif key == 'ctx' and isinstance(value, dict):
|
||||
# Handle the 'ctx' field that contains ValueError objects
|
||||
clean_ctx = {}
|
||||
for ctx_key, ctx_value in value.items():
|
||||
@@ -99,6 +102,9 @@ def setup_exception_handlers(app):
|
||||
else:
|
||||
clean_ctx[ctx_key] = ctx_value
|
||||
clean_error[key] = clean_ctx
|
||||
elif isinstance(value, bytes):
|
||||
# Handle any other bytes objects
|
||||
clean_error[key] = f"<bytes: {len(value)} bytes>"
|
||||
else:
|
||||
clean_error[key] = value
|
||||
clean_errors.append(clean_error)
|
||||
@@ -138,6 +144,23 @@ def setup_exception_handlers(app):
|
||||
}
|
||||
)
|
||||
|
||||
@app.exception_handler(404)
|
||||
async def not_found_handler(request: Request, exc):
|
||||
"""Handle all 404 errors with consistent format."""
|
||||
logger.warning(f"404 Not Found: {request.method} {request.url}")
|
||||
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={
|
||||
"error_code": "ENDPOINT_NOT_FOUND",
|
||||
"message": f"Endpoint not found: {request.url.path}",
|
||||
"status_code": 404,
|
||||
"details": {
|
||||
"path": request.url.path,
|
||||
"method": request.method
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Utility functions for common exception scenarios
|
||||
def raise_not_found(resource_type: str, identifier: str) -> None:
|
||||
|
||||
@@ -337,27 +337,6 @@ class AdminService:
|
||||
completed_at=job.completed_at,
|
||||
)
|
||||
|
||||
# Legacy methods for backward compatibility (mark as deprecated)
|
||||
def get_user_by_id(self, db: Session, user_id: int) -> Optional[User]:
|
||||
"""Get user by ID. DEPRECATED: Use _get_user_by_id_or_raise instead."""
|
||||
logger.warning("get_user_by_id is deprecated, use proper exception handling")
|
||||
return db.query(User).filter(User.id == user_id).first()
|
||||
|
||||
def get_shop_by_id(self, db: Session, shop_id: int) -> Optional[Shop]:
|
||||
"""Get shop by ID. DEPRECATED: Use _get_shop_by_id_or_raise instead."""
|
||||
logger.warning("get_shop_by_id is deprecated, use proper exception handling")
|
||||
return db.query(Shop).filter(Shop.id == shop_id).first()
|
||||
|
||||
def user_exists(self, db: Session, user_id: int) -> bool:
|
||||
"""Check if user exists by ID. DEPRECATED: Use proper exception handling."""
|
||||
logger.warning("user_exists is deprecated, use proper exception handling")
|
||||
return db.query(User).filter(User.id == user_id).first() is not None
|
||||
|
||||
def shop_exists(self, db: Session, shop_id: int) -> bool:
|
||||
"""Check if shop exists by ID. DEPRECATED: Use proper exception handling."""
|
||||
logger.warning("shop_exists is deprecated, use proper exception handling")
|
||||
return db.query(Shop).filter(Shop.id == shop_id).first() is not None
|
||||
|
||||
|
||||
# Create service instance following the same pattern as product_service
|
||||
admin_service = AdminService()
|
||||
|
||||
@@ -170,16 +170,6 @@ class AuthService:
|
||||
"""Check if username already exists."""
|
||||
return db.query(User).filter(User.username == username).first() is not None
|
||||
|
||||
# Legacy methods for backward compatibility (deprecated)
|
||||
def email_exists(self, db: Session, email: str) -> bool:
|
||||
"""Check if email already exists. DEPRECATED: Use proper exception handling."""
|
||||
logger.warning("email_exists is deprecated, use proper exception handling")
|
||||
return self._email_exists(db, email)
|
||||
|
||||
def username_exists(self, db: Session, username: str) -> bool:
|
||||
"""Check if username already exists. DEPRECATED: Use proper exception handling."""
|
||||
logger.warning("username_exists is deprecated, use proper exception handling")
|
||||
return self._username_exists(db, username)
|
||||
|
||||
|
||||
# Create service instance following the same pattern as other services
|
||||
|
||||
@@ -8,9 +8,10 @@ This module provides classes and functions for:
|
||||
- Stock information integration
|
||||
- CSV export functionality
|
||||
"""
|
||||
|
||||
import csv
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from io import StringIO
|
||||
from typing import Generator, List, Optional, Tuple
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
@@ -41,21 +42,7 @@ class ProductService:
|
||||
self.price_processor = PriceProcessor()
|
||||
|
||||
def create_product(self, db: Session, product_data: ProductCreate) -> Product:
|
||||
"""
|
||||
Create a new product with validation.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
product_data: Product creation data
|
||||
|
||||
Returns:
|
||||
Created Product object
|
||||
|
||||
Raises:
|
||||
ProductAlreadyExistsException: If product with ID already exists
|
||||
InvalidProductDataException: If product data is invalid
|
||||
ProductValidationException: If validation fails
|
||||
"""
|
||||
"""Create a new product with validation."""
|
||||
try:
|
||||
# Process and validate GTIN if provided
|
||||
if product_data.gtin:
|
||||
@@ -73,8 +60,9 @@ class ProductService:
|
||||
if parsed_price:
|
||||
product_data.price = parsed_price
|
||||
product_data.currency = currency
|
||||
except Exception as e:
|
||||
raise InvalidProductDataException(f"Invalid price format: {str(e)}", field="price")
|
||||
except ValueError as e:
|
||||
# Convert ValueError to domain-specific exception
|
||||
raise InvalidProductDataException(str(e), field="price")
|
||||
|
||||
# Set default marketplace if not provided
|
||||
if not product_data.marketplace:
|
||||
@@ -199,25 +187,8 @@ class ProductService:
|
||||
logger.error(f"Error getting products with filters: {str(e)}")
|
||||
raise ValidationException("Failed to retrieve products")
|
||||
|
||||
def update_product(
|
||||
self, db: Session, product_id: str, product_update: ProductUpdate
|
||||
) -> Product:
|
||||
"""
|
||||
Update product with validation.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
product_id: Product ID to update
|
||||
product_update: Update data
|
||||
|
||||
Returns:
|
||||
Updated Product object
|
||||
|
||||
Raises:
|
||||
ProductNotFoundException: If product doesn't exist
|
||||
InvalidProductDataException: If update data is invalid
|
||||
ProductValidationException: If validation fails
|
||||
"""
|
||||
def update_product(self, db: Session, product_id: str, product_update: ProductUpdate) -> Product:
|
||||
"""Update product with validation."""
|
||||
try:
|
||||
product = self.get_product_by_id_or_raise(db, product_id)
|
||||
|
||||
@@ -240,8 +211,9 @@ class ProductService:
|
||||
if parsed_price:
|
||||
update_data["price"] = parsed_price
|
||||
update_data["currency"] = currency
|
||||
except Exception as e:
|
||||
raise InvalidProductDataException(f"Invalid price format: {str(e)}", field="price")
|
||||
except ValueError as e:
|
||||
# Convert ValueError to domain-specific exception
|
||||
raise InvalidProductDataException(str(e), field="price")
|
||||
|
||||
# Validate required fields if being updated
|
||||
if "title" in update_data and (not update_data["title"] or not update_data["title"].strip()):
|
||||
@@ -329,6 +301,11 @@ class ProductService:
|
||||
logger.error(f"Error getting stock info for GTIN {gtin}: {str(e)}")
|
||||
return None
|
||||
|
||||
import csv
|
||||
from io import StringIO
|
||||
from typing import Generator, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
def generate_csv_export(
|
||||
self,
|
||||
db: Session,
|
||||
@@ -336,7 +313,7 @@ class ProductService:
|
||||
shop_name: Optional[str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Generate CSV export with streaming for memory efficiency.
|
||||
Generate CSV export with streaming for memory efficiency and proper CSV escaping.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
@@ -344,14 +321,25 @@ class ProductService:
|
||||
shop_name: Optional shop name filter
|
||||
|
||||
Yields:
|
||||
CSV content as strings
|
||||
CSV content as strings with proper escaping
|
||||
"""
|
||||
try:
|
||||
# CSV header
|
||||
yield (
|
||||
"product_id,title,description,link,image_link,availability,price,currency,brand,"
|
||||
"gtin,marketplace,shop_name\n"
|
||||
)
|
||||
# Create a StringIO buffer for CSV writing
|
||||
output = StringIO()
|
||||
writer = csv.writer(output, quoting=csv.QUOTE_MINIMAL)
|
||||
|
||||
# Write header row
|
||||
headers = [
|
||||
"product_id", "title", "description", "link", "image_link",
|
||||
"availability", "price", "currency", "brand", "gtin",
|
||||
"marketplace", "shop_name"
|
||||
]
|
||||
writer.writerow(headers)
|
||||
yield output.getvalue()
|
||||
|
||||
# Clear buffer for reuse
|
||||
output.seek(0)
|
||||
output.truncate(0)
|
||||
|
||||
batch_size = 1000
|
||||
offset = 0
|
||||
@@ -370,14 +358,28 @@ class ProductService:
|
||||
break
|
||||
|
||||
for product in products:
|
||||
# Create CSV row with marketplace fields
|
||||
row = (
|
||||
f'"{product.product_id}","{product.title or ""}","{product.description or ""}",'
|
||||
f'"{product.link or ""}","{product.image_link or ""}","{product.availability or ""}",'
|
||||
f'"{product.price or ""}","{product.currency or ""}","{product.brand or ""}",'
|
||||
f'"{product.gtin or ""}","{product.marketplace or ""}","{product.shop_name or ""}"\n'
|
||||
)
|
||||
yield row
|
||||
# Create CSV row with proper escaping
|
||||
row_data = [
|
||||
product.product_id or "",
|
||||
product.title or "",
|
||||
product.description or "",
|
||||
product.link or "",
|
||||
product.image_link or "",
|
||||
product.availability or "",
|
||||
product.price or "",
|
||||
product.currency or "",
|
||||
product.brand or "",
|
||||
product.gtin or "",
|
||||
product.marketplace or "",
|
||||
product.shop_name or "",
|
||||
]
|
||||
|
||||
writer.writerow(row_data)
|
||||
yield output.getvalue()
|
||||
|
||||
# Clear buffer for next row
|
||||
output.seek(0)
|
||||
output.truncate(0)
|
||||
|
||||
offset += batch_size
|
||||
|
||||
|
||||
@@ -353,45 +353,5 @@ class ShopService:
|
||||
"""Check if user is shop owner."""
|
||||
return shop.owner_id == user.id
|
||||
|
||||
# Legacy methods for backward compatibility (deprecated)
|
||||
def get_shop_by_id(self, db: Session, shop_id: int) -> Optional[Shop]:
|
||||
"""Get shop by ID. DEPRECATED: Use proper exception handling."""
|
||||
logger.warning("get_shop_by_id is deprecated, use proper exception handling")
|
||||
try:
|
||||
return db.query(Shop).filter(Shop.id == shop_id).first()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting shop by ID: {str(e)}")
|
||||
return None
|
||||
|
||||
def shop_code_exists(self, db: Session, shop_code: str) -> bool:
|
||||
"""Check if shop code exists. DEPRECATED: Use proper exception handling."""
|
||||
logger.warning("shop_code_exists is deprecated, use proper exception handling")
|
||||
return self._shop_code_exists(db, shop_code)
|
||||
|
||||
def get_product_by_id(self, db: Session, product_id: str) -> Optional[Product]:
|
||||
"""Get product by ID. DEPRECATED: Use proper exception handling."""
|
||||
logger.warning("get_product_by_id is deprecated, use proper exception handling")
|
||||
try:
|
||||
return db.query(Product).filter(Product.product_id == product_id).first()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting product by ID: {str(e)}")
|
||||
return None
|
||||
|
||||
def product_in_shop(self, db: Session, shop_id: int, product_id: int) -> bool:
|
||||
"""Check if product in shop. DEPRECATED: Use proper exception handling."""
|
||||
logger.warning("product_in_shop is deprecated, use proper exception handling")
|
||||
return self._product_in_shop(db, shop_id, product_id)
|
||||
|
||||
def is_shop_owner(self, shop: Shop, user: User) -> bool:
|
||||
"""Check if user is shop owner. DEPRECATED: Use _is_shop_owner."""
|
||||
logger.warning("is_shop_owner is deprecated, use _is_shop_owner")
|
||||
return self._is_shop_owner(shop, user)
|
||||
|
||||
def can_view_shop(self, shop: Shop, user: User) -> bool:
|
||||
"""Check if user can view shop. DEPRECATED: Use _can_access_shop."""
|
||||
logger.warning("can_view_shop is deprecated, use _can_access_shop")
|
||||
return self._can_access_shop(shop, user)
|
||||
|
||||
|
||||
# Create service instance following the same pattern as other services
|
||||
shop_service = ShopService()
|
||||
|
||||
@@ -294,47 +294,5 @@ class StatsService:
|
||||
"""Get product count for a specific marketplace."""
|
||||
return db.query(Product).filter(Product.marketplace == marketplace).count()
|
||||
|
||||
# Legacy methods for backward compatibility (deprecated)
|
||||
def get_product_count(self, db: Session) -> int:
|
||||
"""Get total product count. DEPRECATED: Use _get_product_count."""
|
||||
logger.warning("get_product_count is deprecated, use _get_product_count")
|
||||
return self._get_product_count(db)
|
||||
|
||||
def get_unique_brands_count(self, db: Session) -> int:
|
||||
"""Get unique brands count. DEPRECATED: Use _get_unique_brands_count."""
|
||||
logger.warning("get_unique_brands_count is deprecated, use _get_unique_brands_count")
|
||||
return self._get_unique_brands_count(db)
|
||||
|
||||
def get_unique_categories_count(self, db: Session) -> int:
|
||||
"""Get unique categories count. DEPRECATED: Use _get_unique_categories_count."""
|
||||
logger.warning("get_unique_categories_count is deprecated, use _get_unique_categories_count")
|
||||
return self._get_unique_categories_count(db)
|
||||
|
||||
def get_unique_marketplaces_count(self, db: Session) -> int:
|
||||
"""Get unique marketplaces count. DEPRECATED: Use _get_unique_marketplaces_count."""
|
||||
logger.warning("get_unique_marketplaces_count is deprecated, use _get_unique_marketplaces_count")
|
||||
return self._get_unique_marketplaces_count(db)
|
||||
|
||||
def get_unique_shops_count(self, db: Session) -> int:
|
||||
"""Get unique shops count. DEPRECATED: Use _get_unique_shops_count."""
|
||||
logger.warning("get_unique_shops_count is deprecated, use _get_unique_shops_count")
|
||||
return self._get_unique_shops_count(db)
|
||||
|
||||
def get_brands_by_marketplace(self, db: Session, marketplace: str) -> List[str]:
|
||||
"""Get brands by marketplace. DEPRECATED: Use _get_brands_by_marketplace."""
|
||||
logger.warning("get_brands_by_marketplace is deprecated, use _get_brands_by_marketplace")
|
||||
return self._get_brands_by_marketplace(db, marketplace)
|
||||
|
||||
def get_shops_by_marketplace(self, db: Session, marketplace: str) -> List[str]:
|
||||
"""Get shops by marketplace. DEPRECATED: Use _get_shops_by_marketplace."""
|
||||
logger.warning("get_shops_by_marketplace is deprecated, use _get_shops_by_marketplace")
|
||||
return self._get_shops_by_marketplace(db, marketplace)
|
||||
|
||||
def get_products_by_marketplace(self, db: Session, marketplace: str) -> int:
|
||||
"""Get products by marketplace. DEPRECATED: Use _get_products_by_marketplace_count."""
|
||||
logger.warning("get_products_by_marketplace is deprecated, use _get_products_by_marketplace_count")
|
||||
return self._get_products_by_marketplace_count(db, marketplace)
|
||||
|
||||
|
||||
# Create service instance following the same pattern as other services
|
||||
stats_service = StatsService()
|
||||
|
||||
@@ -217,7 +217,8 @@ class StockService:
|
||||
)
|
||||
return existing_stock
|
||||
|
||||
except (StockNotFoundException, InsufficientStockException, InvalidQuantityException, NegativeStockException):
|
||||
except (StockValidationException, StockNotFoundException, InsufficientStockException, InvalidQuantityException,
|
||||
NegativeStockException):
|
||||
db.rollback()
|
||||
raise # Re-raise custom exceptions
|
||||
except Exception as e:
|
||||
@@ -564,21 +565,6 @@ class StockService:
|
||||
raise StockNotFoundException(str(stock_id))
|
||||
return stock_entry
|
||||
|
||||
# Legacy methods for backward compatibility (deprecated)
|
||||
def normalize_gtin(self, gtin_value) -> Optional[str]:
|
||||
"""Normalize GTIN format. DEPRECATED: Use _normalize_gtin."""
|
||||
logger.warning("normalize_gtin is deprecated, use _normalize_gtin")
|
||||
return self._normalize_gtin(gtin_value)
|
||||
|
||||
def get_stock_by_id(self, db: Session, stock_id: int) -> Optional[Stock]:
|
||||
"""Get stock by ID. DEPRECATED: Use proper exception handling."""
|
||||
logger.warning("get_stock_by_id is deprecated, use proper exception handling")
|
||||
try:
|
||||
return db.query(Stock).filter(Stock.id == stock_id).first()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting stock by ID: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
# Create service instance
|
||||
stock_service = StockService()
|
||||
|
||||
@@ -103,7 +103,8 @@ class PriceProcessor:
|
||||
"""
|
||||
Parse price string into (price, currency) tuple.
|
||||
|
||||
Returns (None, None) if parsing fails
|
||||
Raises ValueError if parsing fails for non-empty input.
|
||||
Returns (None, None) for empty/null input.
|
||||
"""
|
||||
if not price_str or pd.isna(price_str):
|
||||
return None, None
|
||||
@@ -136,5 +137,6 @@ class PriceProcessor:
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
logger.warning(f"Could not parse price: '{price_str}'")
|
||||
return price_str, None
|
||||
# If we get here, parsing failed completely
|
||||
logger.error(f"Could not parse price: '{price_str}'")
|
||||
raise ValueError(f"Invalid price format: '{price_str}'")
|
||||
|
||||
Reference in New Issue
Block a user