diff --git a/app/api/v1/platform/letzshop_vendors.py b/app/api/v1/platform/letzshop_vendors.py index 2640a117..f55df483 100644 --- a/app/api/v1/platform/letzshop_vendors.py +++ b/app/api/v1/platform/letzshop_vendors.py @@ -4,18 +4,20 @@ Letzshop vendor lookup API endpoints. Allows potential vendors to find themselves in the Letzshop marketplace and claim their shop during signup. + +All endpoints are public (no authentication required). """ import logging import re from typing import Annotated -from fastapi import APIRouter, Depends, HTTPException, Query -from pydantic import BaseModel, HttpUrl +from fastapi import APIRouter, Depends, Query +from pydantic import BaseModel from sqlalchemy.orm import Session from app.core.database import get_db -from models.database.vendor import Vendor +from app.services.platform_signup_service import platform_signup_service router = APIRouter() logger = logging.getLogger(__name__) @@ -101,27 +103,18 @@ def extract_slug_from_url(url_or_slug: str) -> str: return url_or_slug.lower() -def check_if_claimed(db: Session, letzshop_slug: str) -> bool: - """Check if a Letzshop vendor is already claimed.""" - return db.query(Vendor).filter( - Vendor.letzshop_vendor_slug == letzshop_slug, - Vendor.is_active == True, - ).first() is not None - - # ============================================================================= # Endpoints # ============================================================================= -@router.get("/letzshop-vendors", response_model=LetzshopVendorListResponse) +@router.get("/letzshop-vendors", response_model=LetzshopVendorListResponse) # public async def list_letzshop_vendors( search: Annotated[str | None, Query(description="Search by name")] = None, category: Annotated[str | None, Query(description="Filter by category")] = None, city: Annotated[str | None, Query(description="Filter by city")] = None, page: Annotated[int, Query(ge=1)] = 1, limit: Annotated[int, Query(ge=1, le=50)] = 20, - db: Session = Depends(get_db), ) -> LetzshopVendorListResponse: """ List Letzshop vendors (placeholder - will fetch from cache/API). @@ -146,7 +139,7 @@ async def list_letzshop_vendors( ) -@router.post("/letzshop-vendors/lookup", response_model=LetzshopLookupResponse) +@router.post("/letzshop-vendors/lookup", response_model=LetzshopLookupResponse) # public async def lookup_letzshop_vendor( request: LetzshopLookupRequest, db: Session = Depends(get_db), @@ -169,8 +162,8 @@ async def lookup_letzshop_vendor( error="Could not extract vendor slug from URL", ) - # Check if already claimed - is_claimed = check_if_claimed(db, slug) + # Check if already claimed (using service layer) + is_claimed = platform_signup_service.check_vendor_claimed(db, slug) # TODO: Fetch actual vendor info from Letzshop (Phase 4) # For now, return basic info based on the slug @@ -196,7 +189,7 @@ async def lookup_letzshop_vendor( ) -@router.get("/letzshop-vendors/{slug}", response_model=LetzshopVendorInfo) +@router.get("/letzshop-vendors/{slug}", response_model=LetzshopVendorInfo) # public async def get_letzshop_vendor( slug: str, db: Session = Depends(get_db), @@ -207,7 +200,9 @@ async def get_letzshop_vendor( Returns 404 if vendor not found. """ slug = slug.lower() - is_claimed = check_if_claimed(db, slug) + + # Check if claimed (using service layer) + is_claimed = platform_signup_service.check_vendor_claimed(db, slug) # TODO: Fetch actual vendor info from cache/API (Phase 4) # For now, return placeholder based on slug diff --git a/app/api/v1/platform/pricing.py b/app/api/v1/platform/pricing.py index b6893649..e62a5455 100644 --- a/app/api/v1/platform/pricing.py +++ b/app/api/v1/platform/pricing.py @@ -4,6 +4,8 @@ Public pricing API endpoints. Provides subscription tier and add-on product information for the marketing homepage and signup flow. + +All endpoints are public (no authentication required). """ from fastapi import APIRouter, Depends @@ -11,13 +13,9 @@ from pydantic import BaseModel from sqlalchemy.orm import Session from app.core.database import get_db -from models.database.subscription import ( - AddOnProduct, - BillingPeriod, - SubscriptionTier, - TIER_LIMITS, - TierCode, -) +from app.exceptions import ResourceNotFoundException +from app.services.platform_pricing_service import platform_pricing_service +from models.database.subscription import TierCode router = APIRouter() @@ -111,88 +109,13 @@ FEATURE_DESCRIPTIONS = { # ============================================================================= -# Endpoints +# Helper Functions # ============================================================================= -@router.get("/tiers", response_model=list[TierResponse]) -def get_tiers(db: Session = Depends(get_db)) -> list[TierResponse]: - """ - Get all public subscription tiers. - - Returns tiers from database if available, falls back to hardcoded TIER_LIMITS. - """ - # Try to get from database first - db_tiers = ( - db.query(SubscriptionTier) - .filter( - SubscriptionTier.is_active == True, - SubscriptionTier.is_public == True, - ) - .order_by(SubscriptionTier.display_order) - .all() - ) - - if db_tiers: - return [ - TierResponse( - code=tier.code, - name=tier.name, - description=tier.description, - price_monthly=tier.price_monthly_cents / 100, - price_annual=(tier.price_annual_cents / 100) if tier.price_annual_cents else None, - price_monthly_cents=tier.price_monthly_cents, - price_annual_cents=tier.price_annual_cents, - orders_per_month=tier.orders_per_month, - products_limit=tier.products_limit, - team_members=tier.team_members, - order_history_months=tier.order_history_months, - features=tier.features or [], - is_popular=tier.code == TierCode.PROFESSIONAL.value, - is_enterprise=tier.code == TierCode.ENTERPRISE.value, - ) - for tier in db_tiers - ] - - # Fallback to hardcoded tiers - tiers = [] - for tier_code, limits in TIER_LIMITS.items(): - tiers.append( - TierResponse( - code=tier_code.value, - name=limits["name"], - description=None, - price_monthly=limits["price_monthly_cents"] / 100, - price_annual=(limits["price_annual_cents"] / 100) if limits.get("price_annual_cents") else None, - price_monthly_cents=limits["price_monthly_cents"], - price_annual_cents=limits.get("price_annual_cents"), - orders_per_month=limits.get("orders_per_month"), - products_limit=limits.get("products_limit"), - team_members=limits.get("team_members"), - order_history_months=limits.get("order_history_months"), - features=limits.get("features", []), - is_popular=tier_code == TierCode.PROFESSIONAL, - is_enterprise=tier_code == TierCode.ENTERPRISE, - ) - ) - - return tiers - - -@router.get("/tiers/{tier_code}", response_model=TierResponse) -def get_tier(tier_code: str, db: Session = Depends(get_db)) -> TierResponse: - """Get a specific tier by code.""" - # Try database first - tier = ( - db.query(SubscriptionTier) - .filter( - SubscriptionTier.code == tier_code, - SubscriptionTier.is_active == True, - ) - .first() - ) - - if tier: +def _tier_to_response(tier, is_from_db: bool = True) -> TierResponse: + """Convert a tier (from DB or hardcoded) to TierResponse.""" + if is_from_db: return TierResponse( code=tier.code, name=tier.name, @@ -209,11 +132,10 @@ def get_tier(tier_code: str, db: Session = Depends(get_db)) -> TierResponse: is_popular=tier.code == TierCode.PROFESSIONAL.value, is_enterprise=tier.code == TierCode.ENTERPRISE.value, ) - - # Fallback to hardcoded - try: - tier_enum = TierCode(tier_code) - limits = TIER_LIMITS[tier_enum] + else: + # Hardcoded tier format + tier_enum = tier["tier_enum"] + limits = tier["limits"] return TierResponse( code=tier_enum.value, name=limits["name"], @@ -230,42 +152,85 @@ def get_tier(tier_code: str, db: Session = Depends(get_db)) -> TierResponse: is_popular=tier_enum == TierCode.PROFESSIONAL, is_enterprise=tier_enum == TierCode.ENTERPRISE, ) - except ValueError: - from fastapi import HTTPException - raise HTTPException(status_code=404, detail=f"Tier '{tier_code}' not found") -@router.get("/addons", response_model=list[AddOnResponse]) +def _addon_to_response(addon) -> AddOnResponse: + """Convert an AddOnProduct to AddOnResponse.""" + return AddOnResponse( + code=addon.code, + name=addon.name, + description=addon.description, + category=addon.category, + price=addon.price_cents / 100, + price_cents=addon.price_cents, + billing_period=addon.billing_period, + quantity_unit=addon.quantity_unit, + quantity_value=addon.quantity_value, + ) + + +# ============================================================================= +# Endpoints +# ============================================================================= + + +@router.get("/tiers", response_model=list[TierResponse]) # public +def get_tiers(db: Session = Depends(get_db)) -> list[TierResponse]: + """ + Get all public subscription tiers. + + Returns tiers from database if available, falls back to hardcoded TIER_LIMITS. + """ + # Try to get from database first + db_tiers = platform_pricing_service.get_public_tiers(db) + + if db_tiers: + return [_tier_to_response(tier, is_from_db=True) for tier in db_tiers] + + # Fallback to hardcoded tiers + from models.database.subscription import TIER_LIMITS + + tiers = [] + for tier_code in TIER_LIMITS: + tier_data = platform_pricing_service.get_tier_from_hardcoded(tier_code.value) + if tier_data: + tiers.append(_tier_to_response(tier_data, is_from_db=False)) + + return tiers + + +@router.get("/tiers/{tier_code}", response_model=TierResponse) # public +def get_tier(tier_code: str, db: Session = Depends(get_db)) -> TierResponse: + """Get a specific tier by code.""" + # Try database first + tier = platform_pricing_service.get_tier_by_code(db, tier_code) + + if tier: + return _tier_to_response(tier, is_from_db=True) + + # Fallback to hardcoded + tier_data = platform_pricing_service.get_tier_from_hardcoded(tier_code) + if tier_data: + return _tier_to_response(tier_data, is_from_db=False) + + raise ResourceNotFoundException( + resource_type="SubscriptionTier", + identifier=tier_code, + ) + + +@router.get("/addons", response_model=list[AddOnResponse]) # public def get_addons(db: Session = Depends(get_db)) -> list[AddOnResponse]: """ Get all available add-on products. Returns add-ons from database, or empty list if none configured. """ - addons = ( - db.query(AddOnProduct) - .filter(AddOnProduct.is_active == True) - .order_by(AddOnProduct.category, AddOnProduct.display_order) - .all() - ) - - return [ - AddOnResponse( - code=addon.code, - name=addon.name, - description=addon.description, - category=addon.category, - price=addon.price_cents / 100, - price_cents=addon.price_cents, - billing_period=addon.billing_period, - quantity_unit=addon.quantity_unit, - quantity_value=addon.quantity_value, - ) - for addon in addons - ] + addons = platform_pricing_service.get_active_addons(db) + return [_addon_to_response(addon) for addon in addons] -@router.get("/pricing", response_model=PricingResponse) +@router.get("/pricing", response_model=PricingResponse) # public def get_pricing(db: Session = Depends(get_db)) -> PricingResponse: """ Get complete pricing information (tiers + add-ons). diff --git a/app/api/v1/platform/signup.py b/app/api/v1/platform/signup.py index 7f4643f5..caff2b43 100644 --- a/app/api/v1/platform/signup.py +++ b/app/api/v1/platform/signup.py @@ -8,61 +8,23 @@ Handles the multi-step signup flow: 3. Create account 4. Setup payment (collect card via SetupIntent) 5. Complete signup (create subscription with trial) + +All endpoints are public (no authentication required). """ import logging -import secrets -from datetime import UTC, datetime, timedelta -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends from pydantic import BaseModel, EmailStr from sqlalchemy.orm import Session -from app.core.config import settings from app.core.database import get_db -from app.services.stripe_service import stripe_service -from models.database.subscription import ( - SubscriptionStatus, - TierCode, - VendorSubscription, -) -from models.database.vendor import Vendor, VendorUser, VendorUserType +from app.services.platform_signup_service import platform_signup_service router = APIRouter() logger = logging.getLogger(__name__) -# ============================================================================= -# In-memory signup session storage (for simplicity) -# In production, use Redis or database table -# ============================================================================= - -_signup_sessions: dict[str, dict] = {} - - -def create_session_id() -> str: - """Generate a secure session ID.""" - return secrets.token_urlsafe(32) - - -def get_session(session_id: str) -> dict | None: - """Get a signup session by ID.""" - return _signup_sessions.get(session_id) - - -def save_session(session_id: str, data: dict) -> None: - """Save signup session data.""" - _signup_sessions[session_id] = { - **data, - "updated_at": datetime.now(UTC).isoformat(), - } - - -def delete_session(session_id: str) -> None: - """Delete a signup session.""" - _signup_sessions.pop(session_id, None) - - # ============================================================================= # Request/Response Schemas # ============================================================================= @@ -156,45 +118,27 @@ class CompleteSignupResponse(BaseModel): # ============================================================================= -@router.post("/signup/start", response_model=SignupStartResponse) -async def start_signup( - request: SignupStartRequest, - db: Session = Depends(get_db), -) -> SignupStartResponse: +@router.post("/signup/start", response_model=SignupStartResponse) # public +async def start_signup(request: SignupStartRequest) -> SignupStartResponse: """ Start the signup process. Step 1: User selects a tier and billing period. Creates a signup session to track the flow. """ - # Validate tier code - try: - tier = TierCode(request.tier_code) - except ValueError: - raise HTTPException( - status_code=400, - detail=f"Invalid tier code: {request.tier_code}", - ) - - # Create session - session_id = create_session_id() - save_session(session_id, { - "step": "tier_selected", - "tier_code": tier.value, - "is_annual": request.is_annual, - "created_at": datetime.now(UTC).isoformat(), - }) - - logger.info(f"Started signup session {session_id} for tier {tier.value}") + session_id = platform_signup_service.create_session( + tier_code=request.tier_code, + is_annual=request.is_annual, + ) return SignupStartResponse( session_id=session_id, - tier_code=tier.value, + tier_code=request.tier_code, is_annual=request.is_annual, ) -@router.post("/signup/claim-vendor", response_model=ClaimVendorResponse) +@router.post("/signup/claim-vendor", response_model=ClaimVendorResponse) # public async def claim_letzshop_vendor( request: ClaimVendorRequest, db: Session = Depends(get_db), @@ -205,34 +149,12 @@ async def claim_letzshop_vendor( Step 2 (optional): User claims their Letzshop shop. This pre-fills vendor info during account creation. """ - session = get_session(request.session_id) - if not session: - raise HTTPException(status_code=404, detail="Signup session not found") - - # Check if vendor is already claimed - existing = db.query(Vendor).filter( - Vendor.letzshop_vendor_slug == request.letzshop_slug, - Vendor.is_active == True, - ).first() - - if existing: - raise HTTPException( - status_code=400, - detail="This Letzshop vendor is already claimed", - ) - - # Update session with vendor info - session["letzshop_slug"] = request.letzshop_slug - session["letzshop_vendor_id"] = request.letzshop_vendor_id - session["step"] = "vendor_claimed" - - # TODO: Fetch actual vendor name from Letzshop API - vendor_name = request.letzshop_slug.replace("-", " ").title() - session["vendor_name"] = vendor_name - - save_session(request.session_id, session) - - logger.info(f"Claimed vendor {request.letzshop_slug} for session {request.session_id}") + vendor_name = platform_signup_service.claim_vendor( + db=db, + session_id=request.session_id, + letzshop_slug=request.letzshop_slug, + letzshop_vendor_id=request.letzshop_vendor_id, + ) return ClaimVendorResponse( session_id=request.session_id, @@ -241,7 +163,7 @@ async def claim_letzshop_vendor( ) -@router.post("/signup/create-account", response_model=CreateAccountResponse) +@router.post("/signup/create-account", response_model=CreateAccountResponse) # public async def create_account( request: CreateAccountRequest, db: Session = Depends(get_db), @@ -252,203 +174,45 @@ async def create_account( Step 3: User provides account details. Creates User, Company, Vendor, and Stripe Customer. """ - session = get_session(request.session_id) - if not session: - raise HTTPException(status_code=404, detail="Signup session not found") + result = platform_signup_service.create_account( + db=db, + session_id=request.session_id, + email=request.email, + password=request.password, + first_name=request.first_name, + last_name=request.last_name, + company_name=request.company_name, + phone=request.phone, + ) - # Check if email already exists - from models.database.user import User - - existing_user = db.query(User).filter(User.email == request.email).first() - if existing_user: - raise HTTPException( - status_code=400, - detail="An account with this email already exists", - ) - - try: - # Create User first (needed for Company owner) - from middleware.auth import AuthManager - - auth_manager = AuthManager() - - # Generate username from email - username = request.email.split("@")[0] - base_username = username - counter = 1 - while db.query(User).filter(User.username == username).first(): - username = f"{base_username}_{counter}" - counter += 1 - - user = User( - email=request.email, - username=username, - hashed_password=auth_manager.hash_password(request.password), - first_name=request.first_name, - last_name=request.last_name, - role="vendor", - is_active=True, - ) - db.add(user) - db.flush() - - # Create Company (with owner) - from models.database.company import Company - - company = Company( - name=request.company_name, - owner_user_id=user.id, - contact_email=request.email, - contact_phone=request.phone, - ) - db.add(company) - db.flush() - - # Generate vendor code - vendor_code = request.company_name.upper().replace(" ", "_")[:20] - # Ensure unique - base_code = vendor_code - counter = 1 - while db.query(Vendor).filter(Vendor.vendor_code == vendor_code).first(): - vendor_code = f"{base_code}_{counter}" - counter += 1 - - # Generate subdomain - subdomain = request.company_name.lower().replace(" ", "-") - subdomain = "".join(c for c in subdomain if c.isalnum() or c == "-")[:50] - base_subdomain = subdomain - counter = 1 - while db.query(Vendor).filter(Vendor.subdomain == subdomain).first(): - subdomain = f"{base_subdomain}-{counter}" - counter += 1 - - # Create Vendor - vendor = Vendor( - company_id=company.id, - vendor_code=vendor_code, - subdomain=subdomain, - name=request.company_name, - contact_email=request.email, - contact_phone=request.phone, - is_active=True, - letzshop_vendor_slug=session.get("letzshop_slug"), - letzshop_vendor_id=session.get("letzshop_vendor_id"), - ) - db.add(vendor) - db.flush() - - # Create VendorUser (owner) - vendor_user = VendorUser( - vendor_id=vendor.id, - user_id=user.id, - user_type=VendorUserType.OWNER.value, - is_active=True, - ) - db.add(vendor_user) - - # Create Stripe Customer - stripe_customer_id = stripe_service.create_customer( - vendor=vendor, - email=request.email, - name=f"{request.first_name} {request.last_name}", - metadata={ - "company_name": request.company_name, - "tier": session.get("tier_code"), - }, - ) - - # Create VendorSubscription (in trial status, without Stripe subscription yet) - now = datetime.now(UTC) - trial_end = now + timedelta(days=settings.stripe_trial_days) - - subscription = VendorSubscription( - vendor_id=vendor.id, - tier=session.get("tier_code", TierCode.ESSENTIAL.value), - status=SubscriptionStatus.TRIAL.value, - period_start=now, - period_end=trial_end, - trial_ends_at=trial_end, - is_annual=session.get("is_annual", False), - stripe_customer_id=stripe_customer_id, - ) - db.add(subscription) - - db.commit() - - # Update session - session["user_id"] = user.id - session["vendor_id"] = vendor.id - session["vendor_code"] = vendor_code - session["stripe_customer_id"] = stripe_customer_id - session["step"] = "account_created" - save_session(request.session_id, session) - - logger.info( - f"Created account for {request.email}: " - f"user_id={user.id}, vendor_id={vendor.id}" - ) - - return CreateAccountResponse( - session_id=request.session_id, - user_id=user.id, - vendor_id=vendor.id, - stripe_customer_id=stripe_customer_id, - ) - - except Exception as e: - db.rollback() - logger.error(f"Error creating account: {e}") - raise HTTPException(status_code=500, detail="Failed to create account") + return CreateAccountResponse( + session_id=request.session_id, + user_id=result.user_id, + vendor_id=result.vendor_id, + stripe_customer_id=result.stripe_customer_id, + ) -@router.post("/signup/setup-payment", response_model=SetupPaymentResponse) -async def setup_payment( - request: SetupPaymentRequest, - db: Session = Depends(get_db), -) -> SetupPaymentResponse: +@router.post("/signup/setup-payment", response_model=SetupPaymentResponse) # public +async def setup_payment(request: SetupPaymentRequest) -> SetupPaymentResponse: """ Create Stripe SetupIntent for card collection. Step 4: Collect card details without charging. The card will be charged after the trial period ends. """ - session = get_session(request.session_id) - if not session: - raise HTTPException(status_code=404, detail="Signup session not found") - - if "stripe_customer_id" not in session: - raise HTTPException( - status_code=400, - detail="Account not created. Please complete step 3 first.", - ) - - stripe_customer_id = session["stripe_customer_id"] - - # Create SetupIntent - setup_intent = stripe_service.create_setup_intent( - customer_id=stripe_customer_id, - metadata={ - "session_id": request.session_id, - "vendor_id": str(session.get("vendor_id")), - "tier": session.get("tier_code"), - }, + client_secret, stripe_customer_id = platform_signup_service.setup_payment( + session_id=request.session_id, ) - # Update session - session["setup_intent_id"] = setup_intent.id - session["step"] = "payment_pending" - save_session(request.session_id, session) - - logger.info(f"Created SetupIntent {setup_intent.id} for session {request.session_id}") - return SetupPaymentResponse( session_id=request.session_id, - client_secret=setup_intent.client_secret, + client_secret=client_secret, stripe_customer_id=stripe_customer_id, ) -@router.post("/signup/complete", response_model=CompleteSignupResponse) +@router.post("/signup/complete", response_model=CompleteSignupResponse) # public async def complete_signup( request: CompleteSignupRequest, db: Session = Depends(get_db), @@ -458,89 +222,29 @@ async def complete_signup( Step 5: Verify SetupIntent, attach payment method, create subscription. """ - session = get_session(request.session_id) - if not session: - raise HTTPException(status_code=404, detail="Signup session not found") + result = platform_signup_service.complete_signup( + db=db, + session_id=request.session_id, + setup_intent_id=request.setup_intent_id, + ) - vendor_id = session.get("vendor_id") - stripe_customer_id = session.get("stripe_customer_id") - - if not vendor_id or not stripe_customer_id: - raise HTTPException( - status_code=400, - detail="Incomplete signup. Please start again.", - ) - - try: - # Retrieve SetupIntent to get payment method - setup_intent = stripe_service.get_setup_intent(request.setup_intent_id) - - if setup_intent.status != "succeeded": - raise HTTPException( - status_code=400, - detail="Card setup not completed. Please try again.", - ) - - payment_method_id = setup_intent.payment_method - - # Attach payment method to customer - stripe_service.attach_payment_method_to_customer( - customer_id=stripe_customer_id, - payment_method_id=payment_method_id, - set_as_default=True, - ) - - # Update subscription record with card collection time - subscription = ( - db.query(VendorSubscription) - .filter(VendorSubscription.vendor_id == vendor_id) - .first() - ) - - if subscription: - subscription.card_collected_at = datetime.now(UTC) - subscription.stripe_payment_method_id = payment_method_id - - # TODO: Create actual Stripe subscription with trial - # For now, just mark as trial with card collected - - db.commit() - - vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first() - vendor_code = vendor.vendor_code if vendor else session.get("vendor_code") - - # Clean up session - delete_session(request.session_id) - - trial_ends_at = subscription.trial_ends_at if subscription else datetime.now(UTC) + timedelta(days=30) - - logger.info(f"Completed signup for vendor {vendor_id}") - - return CompleteSignupResponse( - success=True, - vendor_code=vendor_code, - vendor_id=vendor_id, - redirect_url=f"/vendors/{vendor_code}/dashboard", - trial_ends_at=trial_ends_at.isoformat(), - ) - - except HTTPException: - raise - except Exception as e: - logger.error(f"Error completing signup: {e}") - raise HTTPException(status_code=500, detail="Failed to complete signup") + return CompleteSignupResponse( + success=result.success, + vendor_code=result.vendor_code, + vendor_id=result.vendor_id, + redirect_url=result.redirect_url, + trial_ends_at=result.trial_ends_at, + ) -@router.get("/signup/session/{session_id}") +@router.get("/signup/session/{session_id}") # public async def get_signup_session(session_id: str) -> dict: """ Get signup session status. Useful for resuming an incomplete signup. """ - session = get_session(session_id) - if not session: - raise HTTPException(status_code=404, detail="Session not found") + session = platform_signup_service.get_session_or_raise(session_id) # Return safe subset of session data return { diff --git a/app/services/platform_pricing_service.py b/app/services/platform_pricing_service.py new file mode 100644 index 00000000..5519c37f --- /dev/null +++ b/app/services/platform_pricing_service.py @@ -0,0 +1,94 @@ +# app/services/platform_pricing_service.py +""" +Platform pricing service. + +Handles database operations for subscription tiers and add-on products. +""" + +from sqlalchemy.orm import Session + +from models.database.subscription import ( + AddOnProduct, + SubscriptionTier, + TIER_LIMITS, + TierCode, +) + + +class PlatformPricingService: + """Service for handling pricing data operations.""" + + def get_public_tiers(self, db: Session) -> list[SubscriptionTier]: + """ + Get all public subscription tiers from the database. + + Returns: + List of active, public subscription tiers ordered by display_order + """ + return ( + db.query(SubscriptionTier) + .filter( + SubscriptionTier.is_active == True, + SubscriptionTier.is_public == True, + ) + .order_by(SubscriptionTier.display_order) + .all() + ) + + def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier | None: + """ + Get a specific tier by code from the database. + + Args: + db: Database session + tier_code: The tier code to look up + + Returns: + SubscriptionTier if found, None otherwise + """ + return ( + db.query(SubscriptionTier) + .filter( + SubscriptionTier.code == tier_code, + SubscriptionTier.is_active == True, + ) + .first() + ) + + def get_tier_from_hardcoded(self, tier_code: str) -> dict | None: + """ + Get tier limits from hardcoded TIER_LIMITS. + + Args: + tier_code: The tier code to look up + + Returns: + Dict with tier limits if valid code, None otherwise + """ + try: + tier_enum = TierCode(tier_code) + limits = TIER_LIMITS[tier_enum] + return { + "tier_enum": tier_enum, + "limits": limits, + } + except ValueError: + return None + + def get_active_addons(self, db: Session) -> list[AddOnProduct]: + """ + Get all active add-on products from the database. + + Returns: + List of active add-on products ordered by category and display_order + """ + return ( + db.query(AddOnProduct) + .filter(AddOnProduct.is_active == True) + .order_by(AddOnProduct.category, AddOnProduct.display_order) + .all() + ) + + +# Singleton instance +platform_pricing_service = PlatformPricingService() diff --git a/app/services/platform_signup_service.py b/app/services/platform_signup_service.py new file mode 100644 index 00000000..20dbdb3b --- /dev/null +++ b/app/services/platform_signup_service.py @@ -0,0 +1,561 @@ +# app/services/platform_signup_service.py +""" +Platform signup service. + +Handles all database operations for the platform signup flow: +- Session management +- Vendor claiming +- Account creation +- Subscription setup +""" + +import logging +import secrets +from datetime import UTC, datetime, timedelta +from dataclasses import dataclass + +from sqlalchemy.orm import Session + +from app.core.config import settings +from app.exceptions import ( + ConflictException, + ResourceNotFoundException, + ValidationException, +) +from app.services.stripe_service import stripe_service +from middleware.auth import AuthManager +from models.database.company import Company +from models.database.subscription import ( + SubscriptionStatus, + TierCode, + VendorSubscription, +) +from models.database.user import User +from models.database.vendor import Vendor, VendorUser, VendorUserType + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# In-memory signup session storage +# In production, use Redis or database table +# ============================================================================= + +_signup_sessions: dict[str, dict] = {} + + +def _create_session_id() -> str: + """Generate a secure session ID.""" + return secrets.token_urlsafe(32) + + +# ============================================================================= +# Data Classes +# ============================================================================= + + +@dataclass +class SignupSessionData: + """Data stored in a signup session.""" + + session_id: str + step: str + tier_code: str + is_annual: bool + created_at: str + updated_at: str | None = None + letzshop_slug: str | None = None + letzshop_vendor_id: str | None = None + vendor_name: str | None = None + user_id: int | None = None + vendor_id: int | None = None + vendor_code: str | None = None + stripe_customer_id: str | None = None + setup_intent_id: str | None = None + + +@dataclass +class AccountCreationResult: + """Result of account creation.""" + + user_id: int + vendor_id: int + vendor_code: str + stripe_customer_id: str + + +@dataclass +class SignupCompletionResult: + """Result of signup completion.""" + + success: bool + vendor_code: str + vendor_id: int + redirect_url: str + trial_ends_at: str + + +# ============================================================================= +# Platform Signup Service +# ============================================================================= + + +class PlatformSignupService: + """Service for handling platform signup operations.""" + + def __init__(self): + self.auth_manager = AuthManager() + + # ========================================================================= + # Session Management + # ========================================================================= + + def create_session(self, tier_code: str, is_annual: bool) -> str: + """ + Create a new signup session. + + Args: + tier_code: The subscription tier code + is_annual: Whether annual billing is selected + + Returns: + The session ID + + Raises: + ValidationException: If tier code is invalid + """ + # Validate tier code + try: + tier = TierCode(tier_code) + except ValueError: + raise ValidationException( + message=f"Invalid tier code: {tier_code}", + field="tier_code", + ) + + session_id = _create_session_id() + now = datetime.now(UTC).isoformat() + + _signup_sessions[session_id] = { + "step": "tier_selected", + "tier_code": tier.value, + "is_annual": is_annual, + "created_at": now, + "updated_at": now, + } + + logger.info(f"Created signup session {session_id} for tier {tier.value}") + return session_id + + def get_session(self, session_id: str) -> dict | None: + """Get a signup session by ID.""" + return _signup_sessions.get(session_id) + + def get_session_or_raise(self, session_id: str) -> dict: + """ + Get a signup session or raise an exception. + + Raises: + ResourceNotFoundException: If session not found + """ + session = self.get_session(session_id) + if not session: + raise ResourceNotFoundException( + resource_type="SignupSession", + identifier=session_id, + ) + return session + + def update_session(self, session_id: str, data: dict) -> None: + """Update signup session data.""" + session = self.get_session_or_raise(session_id) + session.update(data) + session["updated_at"] = datetime.now(UTC).isoformat() + _signup_sessions[session_id] = session + + def delete_session(self, session_id: str) -> None: + """Delete a signup session.""" + _signup_sessions.pop(session_id, None) + + # ========================================================================= + # Vendor Claiming + # ========================================================================= + + def check_vendor_claimed(self, db: Session, letzshop_slug: str) -> bool: + """Check if a Letzshop vendor is already claimed.""" + return ( + db.query(Vendor) + .filter( + Vendor.letzshop_vendor_slug == letzshop_slug, + Vendor.is_active == True, + ) + .first() + is not None + ) + + def claim_vendor( + self, + db: Session, + session_id: str, + letzshop_slug: str, + letzshop_vendor_id: str | None = None, + ) -> str: + """ + Claim a Letzshop vendor for signup. + + Args: + db: Database session + session_id: Signup session ID + letzshop_slug: Letzshop vendor slug + letzshop_vendor_id: Optional Letzshop vendor ID + + Returns: + Generated vendor name + + Raises: + ResourceNotFoundException: If session not found + ConflictException: If vendor already claimed + """ + session = self.get_session_or_raise(session_id) + + # Check if vendor is already claimed + if self.check_vendor_claimed(db, letzshop_slug): + raise ConflictException( + message="This Letzshop vendor is already claimed", + ) + + # Generate vendor name from slug + vendor_name = letzshop_slug.replace("-", " ").title() + + # Update session + self.update_session(session_id, { + "letzshop_slug": letzshop_slug, + "letzshop_vendor_id": letzshop_vendor_id, + "vendor_name": vendor_name, + "step": "vendor_claimed", + }) + + logger.info(f"Claimed vendor {letzshop_slug} for session {session_id}") + return vendor_name + + # ========================================================================= + # Account Creation + # ========================================================================= + + def check_email_exists(self, db: Session, email: str) -> bool: + """Check if an email already exists.""" + return db.query(User).filter(User.email == email).first() is not None + + def generate_unique_username(self, db: Session, email: str) -> str: + """Generate a unique username from email.""" + username = email.split("@")[0] + base_username = username + counter = 1 + while db.query(User).filter(User.username == username).first(): + username = f"{base_username}_{counter}" + counter += 1 + return username + + def generate_unique_vendor_code(self, db: Session, company_name: str) -> str: + """Generate a unique vendor code from company name.""" + vendor_code = company_name.upper().replace(" ", "_")[:20] + base_code = vendor_code + counter = 1 + while db.query(Vendor).filter(Vendor.vendor_code == vendor_code).first(): + vendor_code = f"{base_code}_{counter}" + counter += 1 + return vendor_code + + def generate_unique_subdomain(self, db: Session, company_name: str) -> str: + """Generate a unique subdomain from company name.""" + subdomain = company_name.lower().replace(" ", "-") + subdomain = "".join(c for c in subdomain if c.isalnum() or c == "-")[:50] + base_subdomain = subdomain + counter = 1 + while db.query(Vendor).filter(Vendor.subdomain == subdomain).first(): + subdomain = f"{base_subdomain}-{counter}" + counter += 1 + return subdomain + + def create_account( + self, + db: Session, + session_id: str, + email: str, + password: str, + first_name: str, + last_name: str, + company_name: str, + phone: str | None = None, + ) -> AccountCreationResult: + """ + Create user, company, vendor, and Stripe customer. + + Args: + db: Database session + session_id: Signup session ID + email: User email + password: User password + first_name: User first name + last_name: User last name + company_name: Company name + phone: Optional phone number + + Returns: + AccountCreationResult with IDs + + Raises: + ResourceNotFoundException: If session not found + ConflictException: If email already exists + """ + session = self.get_session_or_raise(session_id) + + # Check if email already exists + if self.check_email_exists(db, email): + raise ConflictException( + message="An account with this email already exists", + ) + + # Generate unique username + username = self.generate_unique_username(db, email) + + # Create User + user = User( + email=email, + username=username, + hashed_password=self.auth_manager.hash_password(password), + first_name=first_name, + last_name=last_name, + role="vendor", + is_active=True, + ) + db.add(user) + db.flush() + + # Create Company + company = Company( + name=company_name, + owner_user_id=user.id, + contact_email=email, + contact_phone=phone, + ) + db.add(company) + db.flush() + + # Generate unique vendor code and subdomain + vendor_code = self.generate_unique_vendor_code(db, company_name) + subdomain = self.generate_unique_subdomain(db, company_name) + + # Create Vendor + vendor = Vendor( + company_id=company.id, + vendor_code=vendor_code, + subdomain=subdomain, + name=company_name, + contact_email=email, + contact_phone=phone, + is_active=True, + letzshop_vendor_slug=session.get("letzshop_slug"), + letzshop_vendor_id=session.get("letzshop_vendor_id"), + ) + db.add(vendor) + db.flush() + + # Create VendorUser (owner) + vendor_user = VendorUser( + vendor_id=vendor.id, + user_id=user.id, + user_type=VendorUserType.OWNER.value, + is_active=True, + ) + db.add(vendor_user) + + # Create Stripe Customer + stripe_customer_id = stripe_service.create_customer( + vendor=vendor, + email=email, + name=f"{first_name} {last_name}", + metadata={ + "company_name": company_name, + "tier": session.get("tier_code"), + }, + ) + + # Create VendorSubscription (trial status) + now = datetime.now(UTC) + trial_end = now + timedelta(days=settings.stripe_trial_days) + + subscription = VendorSubscription( + vendor_id=vendor.id, + tier=session.get("tier_code", TierCode.ESSENTIAL.value), + status=SubscriptionStatus.TRIAL.value, + period_start=now, + period_end=trial_end, + trial_ends_at=trial_end, + is_annual=session.get("is_annual", False), + stripe_customer_id=stripe_customer_id, + ) + db.add(subscription) + + db.commit() # noqa: SVC-006 - Atomic account creation needs commit + + # Update session + self.update_session(session_id, { + "user_id": user.id, + "vendor_id": vendor.id, + "vendor_code": vendor_code, + "stripe_customer_id": stripe_customer_id, + "step": "account_created", + }) + + logger.info( + f"Created account for {email}: user_id={user.id}, vendor_id={vendor.id}" + ) + + return AccountCreationResult( + user_id=user.id, + vendor_id=vendor.id, + vendor_code=vendor_code, + stripe_customer_id=stripe_customer_id, + ) + + # ========================================================================= + # Payment Setup + # ========================================================================= + + def setup_payment(self, session_id: str) -> tuple[str, str]: + """ + Create Stripe SetupIntent for card collection. + + Args: + session_id: Signup session ID + + Returns: + Tuple of (client_secret, stripe_customer_id) + + Raises: + EntityNotFoundException: If session not found + ValidationException: If account not created yet + """ + session = self.get_session_or_raise(session_id) + + if "stripe_customer_id" not in session: + raise ValidationException( + message="Account not created. Please complete step 3 first.", + field="session_id", + ) + + stripe_customer_id = session["stripe_customer_id"] + + # Create SetupIntent + setup_intent = stripe_service.create_setup_intent( + customer_id=stripe_customer_id, + metadata={ + "session_id": session_id, + "vendor_id": str(session.get("vendor_id")), + "tier": session.get("tier_code"), + }, + ) + + # Update session + self.update_session(session_id, { + "setup_intent_id": setup_intent.id, + "step": "payment_pending", + }) + + logger.info(f"Created SetupIntent {setup_intent.id} for session {session_id}") + + return setup_intent.client_secret, stripe_customer_id + + # ========================================================================= + # Signup Completion + # ========================================================================= + + def complete_signup( + self, + db: Session, + session_id: str, + setup_intent_id: str, + ) -> SignupCompletionResult: + """ + Complete signup after card collection. + + Args: + db: Database session + session_id: Signup session ID + setup_intent_id: Stripe SetupIntent ID + + Returns: + SignupCompletionResult + + Raises: + EntityNotFoundException: If session not found + ValidationException: If signup incomplete or payment failed + """ + session = self.get_session_or_raise(session_id) + + vendor_id = session.get("vendor_id") + stripe_customer_id = session.get("stripe_customer_id") + + if not vendor_id or not stripe_customer_id: + raise ValidationException( + message="Incomplete signup. Please start again.", + field="session_id", + ) + + # Retrieve SetupIntent to get payment method + setup_intent = stripe_service.get_setup_intent(setup_intent_id) + + if setup_intent.status != "succeeded": + raise ValidationException( + message="Card setup not completed. Please try again.", + field="setup_intent_id", + ) + + payment_method_id = setup_intent.payment_method + + # Attach payment method to customer + stripe_service.attach_payment_method_to_customer( + customer_id=stripe_customer_id, + payment_method_id=payment_method_id, + set_as_default=True, + ) + + # Update subscription record + subscription = ( + db.query(VendorSubscription) + .filter(VendorSubscription.vendor_id == vendor_id) + .first() + ) + + if subscription: + subscription.card_collected_at = datetime.now(UTC) + subscription.stripe_payment_method_id = payment_method_id + db.commit() # noqa: SVC-006 - Finalize signup needs commit + + # Get vendor info + vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first() + vendor_code = vendor.vendor_code if vendor else session.get("vendor_code") + + trial_ends_at = ( + subscription.trial_ends_at + if subscription + else datetime.now(UTC) + timedelta(days=30) + ) + + # Clean up session + self.delete_session(session_id) + + logger.info(f"Completed signup for vendor {vendor_id}") + + return SignupCompletionResult( + success=True, + vendor_code=vendor_code, + vendor_id=vendor_id, + redirect_url=f"/vendor/{vendor_code}/dashboard", + trial_ends_at=trial_ends_at.isoformat(), + ) + + +# Singleton instance +platform_signup_service = PlatformSignupService() diff --git a/tests/integration/api/v1/platform/test_signup.py b/tests/integration/api/v1/platform/test_signup.py index 3421ad6c..b5a2e9e6 100644 --- a/tests/integration/api/v1/platform/test_signup.py +++ b/tests/integration/api/v1/platform/test_signup.py @@ -17,7 +17,7 @@ from models.database.vendor import Vendor @pytest.fixture def mock_stripe_service(): """Mock the Stripe service for tests.""" - with patch("app.api.v1.platform.signup.stripe_service") as mock: + with patch("app.services.platform_signup_service.stripe_service") as mock: mock.create_customer.return_value = "cus_test_123" mock.create_setup_intent.return_value = MagicMock( id="seti_test_123", @@ -147,7 +147,7 @@ class TestSignupStartAPI: json={"tier_code": "invalid_tier", "is_annual": False}, ) - assert response.status_code == 400 + assert response.status_code == 422 # ValidationException data = response.json() assert "invalid tier" in data["message"].lower() @@ -226,7 +226,7 @@ class TestClaimVendorAPI: }, ) - assert response.status_code == 400 + assert response.status_code == 409 # ConflictException data = response.json() assert "already claimed" in data["message"].lower() @@ -309,7 +309,7 @@ class TestCreateAccountAPI: }, ) - assert response.status_code == 400 + assert response.status_code == 409 # ConflictException data = response.json() assert "already exists" in data["message"].lower() @@ -414,7 +414,7 @@ class TestSetupPaymentAPI: json={"session_id": signup_session}, ) - assert response.status_code == 400 + assert response.status_code == 422 # ValidationException data = response.json() assert "account not created" in data["message"].lower() @@ -514,7 +514,7 @@ class TestCompleteSignupAPI: }, ) - assert response.status_code == 400 + assert response.status_code == 422 # ValidationException data = response.json() assert "not completed" in data["message"].lower()