diff --git a/app/api/deps.py b/app/api/deps.py index 5c0a6a86..acebb04e 100644 --- a/app/api/deps.py +++ b/app/api/deps.py @@ -155,6 +155,23 @@ def _get_user_model(user_context: UserContext, db: Session) -> UserModel: return user +# ============================================================================ +# PLATFORM CONTEXT +# ============================================================================ + + +def require_platform(request: Request): + """Dependency that requires platform context from middleware. + + Raises HTTPException(400) if no platform is set on request.state. + Use as a FastAPI dependency in endpoints that need platform context. + """ + platform = getattr(request.state, "platform", None) + if not platform: + raise HTTPException(status_code=400, detail="Platform context required") + return platform + + # ============================================================================ # ADMIN AUTHENTICATION # ============================================================================ diff --git a/app/core/frontend_detector.py b/app/core/frontend_detector.py index e1c3712b..035a1803 100644 --- a/app/core/frontend_detector.py +++ b/app/core/frontend_detector.py @@ -46,7 +46,6 @@ class FrontendDetector: STOREFRONT_PATH_PREFIXES = ( "/storefront", "/api/v1/storefront", - "/stores/", # Path-based store access ) MERCHANT_PATH_PREFIXES = ("/merchants", "/api/v1/merchants") PLATFORM_PATH_PREFIXES = ("/api/v1/platform",) diff --git a/app/handlers/stripe_webhook.py b/app/handlers/stripe_webhook.py index e97ac463..b2e4a785 100644 --- a/app/handlers/stripe_webhook.py +++ b/app/handlers/stripe_webhook.py @@ -380,6 +380,24 @@ class StripeWebhookHandler: f"Tier changed to {tier.code} for merchant {subscription.merchant_id}" ) + # Sync store_platforms based on subscription status + from app.modules.billing.services.store_platform_sync_service import ( + store_platform_sync, + ) + + active_statuses = { + SubscriptionStatus.TRIAL.value, + SubscriptionStatus.ACTIVE.value, + SubscriptionStatus.PAST_DUE.value, + SubscriptionStatus.CANCELLED.value, + } + store_platform_sync.sync_store_platforms_for_merchant( + db, + subscription.merchant_id, + subscription.platform_id, + is_active=subscription.status in active_statuses, + ) + logger.info(f"Subscription updated for merchant {subscription.merchant_id}") return {"action": "updated", "merchant_id": subscription.merchant_id} @@ -435,6 +453,15 @@ class StripeWebhookHandler: if addon_count > 0: logger.info(f"Cancelled {addon_count} add-ons for merchant {merchant_id}") + # Deactivate store_platforms for the deleted subscription's platform + from app.modules.billing.services.store_platform_sync_service import ( + store_platform_sync, + ) + + store_platform_sync.sync_store_platforms_for_merchant( + db, merchant_id, subscription.platform_id, is_active=False + ) + logger.info(f"Subscription deleted for merchant {merchant_id}") return { "action": "cancelled", diff --git a/app/modules/billing/routes/api/admin.py b/app/modules/billing/routes/api/admin.py index f34c09e6..cf340e68 100644 --- a/app/modules/billing/routes/api/admin.py +++ b/app/modules/billing/routes/api/admin.py @@ -16,7 +16,6 @@ from sqlalchemy.orm import Session from app.api.deps import get_current_admin_api, require_module_access from app.core.database import get_db -from app.exceptions import ResourceNotFoundException from app.modules.billing.schemas import ( BillingHistoryListResponse, BillingHistoryWithMerchant, @@ -284,16 +283,7 @@ def get_subscription_for_store( store -> merchant -> all platform subscriptions and returns a list of subscription entries with feature usage metrics. """ - from app.modules.billing.services.feature_service import feature_service - - # Resolve store to merchant - merchant_id, platform_ids = feature_service._get_merchant_and_platforms_for_store(db, store_id) - if merchant_id is None or not platform_ids: - raise ResourceNotFoundException("Store", str(store_id)) - - results = admin_subscription_service.get_merchant_subscriptions_with_usage( - db, merchant_id - ) + results = admin_subscription_service.get_subscriptions_for_store(db, store_id) return {"subscriptions": results} diff --git a/app/modules/billing/services/__init__.py b/app/modules/billing/services/__init__.py index 5bf6a6d5..dcf64902 100644 --- a/app/modules/billing/services/__init__.py +++ b/app/modules/billing/services/__init__.py @@ -21,6 +21,10 @@ from app.modules.billing.services.platform_pricing_service import ( PlatformPricingService, platform_pricing_service, ) +from app.modules.billing.services.store_platform_sync_service import ( + StorePlatformSync, + store_platform_sync, +) from app.modules.billing.services.stripe_service import ( StripeService, stripe_service, @@ -42,6 +46,8 @@ from app.modules.billing.services.usage_service import ( __all__ = [ "SubscriptionService", "subscription_service", + "StorePlatformSync", + "store_platform_sync", "StripeService", "stripe_service", "AdminSubscriptionService", diff --git a/app/modules/billing/services/admin_subscription_service.py b/app/modules/billing/services/admin_subscription_service.py index ed204e28..29ede8ef 100644 --- a/app/modules/billing/services/admin_subscription_service.py +++ b/app/modules/billing/services/admin_subscription_service.py @@ -56,13 +56,14 @@ class AdminSubscriptionService: return query.order_by(SubscriptionTier.display_order).all() - def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier: - """Get a subscription tier by code.""" - tier = ( - db.query(SubscriptionTier) - .filter(SubscriptionTier.code == tier_code) - .first() - ) + def get_tier_by_code( + self, db: Session, tier_code: str, platform_id: int | None = None + ) -> SubscriptionTier: + """Get a subscription tier by code, optionally scoped to a platform.""" + query = db.query(SubscriptionTier).filter(SubscriptionTier.code == tier_code) + if platform_id is not None: + query = query.filter(SubscriptionTier.platform_id == platform_id) + tier = query.first() if not tier: raise TierNotFoundException(tier_code) @@ -214,7 +215,7 @@ class AdminSubscriptionService: db, merchant_id, platform_id, tier_code, sub.is_annual ) else: - tier = self.get_tier_by_code(db, tier_code) + tier = self.get_tier_by_code(db, tier_code, platform_id=platform_id) sub.tier_id = tier.id for field, value in update_data.items(): @@ -350,6 +351,22 @@ class AdminSubscriptionService: return results + def get_subscriptions_for_store( + self, db: Session, store_id: int + ) -> list[dict]: + """Get subscriptions + feature usage for a store (resolves to merchant). + + Convenience method for admin store detail page. Resolves + store -> merchant -> all platform subscriptions. + """ + from app.modules.tenancy.models import Store + + store = db.query(Store).filter(Store.id == store_id).first() + if not store or not store.merchant_id: + raise ResourceNotFoundException("Store", str(store_id)) + + return self.get_merchant_subscriptions_with_usage(db, store.merchant_id) + # ========================================================================= # Statistics # ========================================================================= diff --git a/app/modules/billing/services/billing_service.py b/app/modules/billing/services/billing_service.py index e83f283e..fdd13240 100644 --- a/app/modules/billing/services/billing_service.py +++ b/app/modules/billing/services/billing_service.py @@ -88,21 +88,22 @@ class BillingService: return tier_list, tier_order - def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier: + def get_tier_by_code( + self, db: Session, tier_code: str, platform_id: int | None = None + ) -> SubscriptionTier: """ - Get a tier by its code. + Get a tier by its code, optionally scoped to a platform. Raises: TierNotFoundException: If tier doesn't exist """ - tier = ( - db.query(SubscriptionTier) - .filter( - SubscriptionTier.code == tier_code, - SubscriptionTier.is_active == True, # noqa: E712 - ) - .first() + query = db.query(SubscriptionTier).filter( + SubscriptionTier.code == tier_code, + SubscriptionTier.is_active == True, # noqa: E712 ) + if platform_id is not None: + query = query.filter(SubscriptionTier.platform_id == platform_id) + tier = query.first() if not tier: raise TierNotFoundException(tier_code) @@ -133,7 +134,7 @@ class BillingService: if not stripe_service.is_configured: raise PaymentSystemNotConfiguredException() - tier = self.get_tier_by_code(db, tier_code) + tier = self.get_tier_by_code(db, tier_code, platform_id=platform_id) price_id = ( tier.stripe_price_annual_id @@ -410,7 +411,7 @@ class BillingService: if not subscription or not subscription.stripe_subscription_id: raise NoActiveSubscriptionException() - tier = self.get_tier_by_code(db, new_tier_code) + tier = self.get_tier_by_code(db, new_tier_code, platform_id=platform_id) price_id = ( tier.stripe_price_annual_id diff --git a/app/modules/billing/services/platform_pricing_service.py b/app/modules/billing/services/platform_pricing_service.py index 615ec13c..8f3d18d1 100644 --- a/app/modules/billing/services/platform_pricing_service.py +++ b/app/modules/billing/services/platform_pricing_service.py @@ -28,16 +28,17 @@ class PlatformPricingService: .all() ) - def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier | None: - """Get a specific tier by code from the database.""" - return ( - db.query(SubscriptionTier) - .filter( - SubscriptionTier.code == tier_code, - SubscriptionTier.is_active == True, - ) - .first() + def get_tier_by_code( + self, db: Session, tier_code: str, platform_id: int | None = None + ) -> SubscriptionTier | None: + """Get a specific tier by code from the database, optionally scoped to a platform.""" + query = db.query(SubscriptionTier).filter( + SubscriptionTier.code == tier_code, + SubscriptionTier.is_active == True, ) + if platform_id is not None: + query = query.filter(SubscriptionTier.platform_id == platform_id) + return query.first() def get_active_addons(self, db: Session) -> list[AddOnProduct]: """Get all active add-on products from the database.""" diff --git a/app/modules/billing/services/store_platform_sync_service.py b/app/modules/billing/services/store_platform_sync_service.py new file mode 100644 index 00000000..115cfef5 --- /dev/null +++ b/app/modules/billing/services/store_platform_sync_service.py @@ -0,0 +1,92 @@ +# app/modules/billing/services/store_platform_sync.py +""" +Keeps store_platforms in sync with merchant subscriptions. + +When a subscription is created, reactivated, or deleted, this service +ensures all stores belonging to that merchant get corresponding +StorePlatform entries created or updated. +""" + +import logging + +from sqlalchemy.orm import Session + +from app.modules.tenancy.models import Store, StorePlatform + +logger = logging.getLogger(__name__) + + +class StorePlatformSync: + """Syncs StorePlatform entries when merchant subscriptions change.""" + + def sync_store_platforms_for_merchant( + self, + db: Session, + merchant_id: int, + platform_id: int, + is_active: bool, + tier_id: int | None = None, + ) -> None: + """ + Upsert StorePlatform for every store belonging to a merchant. + + - Existing entry → update is_active (and tier_id if provided) + - Missing + is_active=True → create (set is_primary if store has none) + - Missing + is_active=False → no-op + """ + stores = ( + db.query(Store) + .filter(Store.merchant_id == merchant_id) + .all() + ) + + if not stores: + return + + for store in stores: + existing = ( + db.query(StorePlatform) + .filter( + StorePlatform.store_id == store.id, + StorePlatform.platform_id == platform_id, + ) + .first() + ) + + if existing: + existing.is_active = is_active + if tier_id is not None: + existing.tier_id = tier_id + logger.debug( + f"Updated StorePlatform store_id={store.id} " + f"platform_id={platform_id} is_active={is_active}" + ) + elif is_active: + # Check if store already has a primary platform + has_primary = ( + db.query(StorePlatform) + .filter( + StorePlatform.store_id == store.id, + StorePlatform.is_primary.is_(True), + ) + .first() + ) is not None + + sp = StorePlatform( + store_id=store.id, + platform_id=platform_id, + is_active=True, + is_primary=not has_primary, + tier_id=tier_id, + ) + db.add(sp) + logger.info( + f"Created StorePlatform store_id={store.id} " + f"platform_id={platform_id} is_primary={not has_primary}" + ) + + db.flush() + + +# Singleton instance +store_platform_sync = StorePlatformSync() diff --git a/app/modules/billing/services/subscription_service.py b/app/modules/billing/services/subscription_service.py index 45e27fdf..24982a0b 100644 --- a/app/modules/billing/services/subscription_service.py +++ b/app/modules/billing/services/subscription_service.py @@ -82,17 +82,20 @@ class SubscriptionService: # Tier Information # ========================================================================= - def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier | None: - """Get subscription tier by code.""" - return ( - db.query(SubscriptionTier) - .filter(SubscriptionTier.code == tier_code) - .first() - ) + def get_tier_by_code( + self, db: Session, tier_code: str, platform_id: int | None = None + ) -> SubscriptionTier | None: + """Get subscription tier by code, optionally scoped to a platform.""" + query = db.query(SubscriptionTier).filter(SubscriptionTier.code == tier_code) + if platform_id is not None: + query = query.filter(SubscriptionTier.platform_id == platform_id) + return query.first() - def get_tier_id(self, db: Session, tier_code: str) -> int | None: + def get_tier_id( + self, db: Session, tier_code: str, platform_id: int | None = None + ) -> int | None: """Get tier ID from tier code. Returns None if tier not found.""" - tier = self.get_tier_by_code(db, tier_code) + tier = self.get_tier_by_code(db, tier_code, platform_id=platform_id) return tier.id if tier else None def get_all_tiers( @@ -254,7 +257,7 @@ class SubscriptionService: trial_ends_at = None status = SubscriptionStatus.ACTIVE.value - tier_id = self.get_tier_id(db, tier_code) + tier_id = self.get_tier_id(db, tier_code, platform_id=platform_id) subscription = MerchantSubscription( merchant_id=merchant_id, @@ -271,6 +274,15 @@ class SubscriptionService: db.flush() db.refresh(subscription) + # Sync store_platforms for all merchant stores + from app.modules.billing.services.store_platform_sync_service import ( + store_platform_sync, + ) + + store_platform_sync.sync_store_platforms_for_merchant( + db, merchant_id, platform_id, is_active=True, tier_id=subscription.tier_id + ) + logger.info( f"Created subscription for merchant {merchant_id} on platform {platform_id} " f"(tier={tier_code}, status={status})" @@ -305,7 +317,7 @@ class SubscriptionService: subscription = self.get_subscription_or_raise(db, merchant_id, platform_id) old_tier_id = subscription.tier_id - new_tier = self.get_tier_by_code(db, new_tier_code) + new_tier = self.get_tier_by_code(db, new_tier_code, platform_id=platform_id) if not new_tier: raise ValueError(f"Tier '{new_tier_code}' not found") @@ -366,6 +378,15 @@ class SubscriptionService: db.flush() db.refresh(subscription) + # Sync store_platforms for all merchant stores + from app.modules.billing.services.store_platform_sync_service import ( + store_platform_sync, + ) + + store_platform_sync.sync_store_platforms_for_merchant( + db, merchant_id, platform_id, is_active=True + ) + logger.info( f"Reactivated subscription for merchant {merchant_id} " f"on platform {platform_id}" diff --git a/app/modules/billing/tests/integration/test_admin_routes.py b/app/modules/billing/tests/integration/test_admin_routes.py index 71a81ed6..cc5d806b 100644 --- a/app/modules/billing/tests/integration/test_admin_routes.py +++ b/app/modules/billing/tests/integration/test_admin_routes.py @@ -507,3 +507,56 @@ class TestAdminBillingHistory: ) assert response.status_code == 200 assert response.json()["total"] == 0 + + +# ============================================================================ +# Store Subscription Convenience Endpoint +# ============================================================================ + + +class TestAdminStoreSubscription: + """Tests for GET /api/v1/admin/subscriptions/store/{store_id}.""" + + def test_get_subscriptions_for_store( + self, client, super_admin_headers, rt_subscription, rt_store + ): + """Returns subscriptions when store has a merchant with subscriptions.""" + response = client.get( + f"{BASE}/store/{rt_store.id}", + headers=super_admin_headers, + ) + assert response.status_code == 200 + data = response.json() + assert "subscriptions" in data + assert len(data["subscriptions"]) >= 1 + + def test_get_subscriptions_for_nonexistent_store( + self, client, super_admin_headers + ): + """Returns 404 for non-existent store ID.""" + response = client.get( + f"{BASE}/store/999999", + headers=super_admin_headers, + ) + assert response.status_code == 404 + + +@pytest.fixture +def rt_store(db, rt_merchant): + """Create a store for route tests.""" + from app.modules.tenancy.models import Store + + store = Store( + merchant_id=rt_merchant.id, + store_code=f"RT_{uuid.uuid4().hex[:6].upper()}", + name="Route Test Store", + subdomain=f"rt-{uuid.uuid4().hex[:8]}", + is_active=True, + is_verified=True, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + db.add(store) + db.commit() + db.refresh(store) + return store diff --git a/app/modules/billing/tests/unit/test_admin_subscription_service.py b/app/modules/billing/tests/unit/test_admin_subscription_service.py index 1c7bec11..bc74044e 100644 --- a/app/modules/billing/tests/unit/test_admin_subscription_service.py +++ b/app/modules/billing/tests/unit/test_admin_subscription_service.py @@ -497,7 +497,7 @@ class TestAdminGetStats: @pytest.fixture -def admin_billing_tiers(db): +def admin_billing_tiers(db, test_platform): """Create essential, professional, business tiers for admin tests.""" tiers = [ SubscriptionTier( @@ -508,6 +508,7 @@ def admin_billing_tiers(db): display_order=1, is_active=True, is_public=True, + platform_id=test_platform.id, ), SubscriptionTier( code="professional", @@ -517,6 +518,7 @@ def admin_billing_tiers(db): display_order=2, is_active=True, is_public=True, + platform_id=test_platform.id, ), SubscriptionTier( code="business", @@ -526,6 +528,7 @@ def admin_billing_tiers(db): display_order=3, is_active=True, is_public=True, + platform_id=test_platform.id, ), ] db.add_all(tiers) diff --git a/app/modules/billing/tests/unit/test_billing_service.py b/app/modules/billing/tests/unit/test_billing_service.py index fd902ec1..f9b61cc5 100644 --- a/app/modules/billing/tests/unit/test_billing_service.py +++ b/app/modules/billing/tests/unit/test_billing_service.py @@ -603,7 +603,7 @@ class TestBillingServiceUpcomingInvoice: @pytest.fixture -def bs_tier_essential(db): +def bs_tier_essential(db, test_platform): """Create essential subscription tier.""" tier = SubscriptionTier( code="essential", @@ -614,6 +614,7 @@ def bs_tier_essential(db): display_order=1, is_active=True, is_public=True, + platform_id=test_platform.id, ) db.add(tier) db.commit() @@ -622,7 +623,7 @@ def bs_tier_essential(db): @pytest.fixture -def bs_tiers(db): +def bs_tiers(db, test_platform): """Create three tiers without Stripe config.""" tiers = [ SubscriptionTier( @@ -633,6 +634,7 @@ def bs_tiers(db): display_order=1, is_active=True, is_public=True, + platform_id=test_platform.id, ), SubscriptionTier( code="professional", @@ -642,6 +644,7 @@ def bs_tiers(db): display_order=2, is_active=True, is_public=True, + platform_id=test_platform.id, ), SubscriptionTier( code="business", @@ -651,6 +654,7 @@ def bs_tiers(db): display_order=3, is_active=True, is_public=True, + platform_id=test_platform.id, ), ] db.add_all(tiers) @@ -661,7 +665,7 @@ def bs_tiers(db): @pytest.fixture -def bs_tiers_with_stripe(db): +def bs_tiers_with_stripe(db, test_platform): """Create tiers with Stripe price IDs configured.""" tiers = [ SubscriptionTier( @@ -672,6 +676,7 @@ def bs_tiers_with_stripe(db): display_order=1, is_active=True, is_public=True, + platform_id=test_platform.id, stripe_product_id="prod_essential", stripe_price_monthly_id="price_ess_monthly", stripe_price_annual_id="price_ess_annual", @@ -684,6 +689,7 @@ def bs_tiers_with_stripe(db): display_order=2, is_active=True, is_public=True, + platform_id=test_platform.id, stripe_product_id="prod_professional", stripe_price_monthly_id="price_pro_monthly", stripe_price_annual_id="price_pro_annual", @@ -696,6 +702,7 @@ def bs_tiers_with_stripe(db): display_order=3, is_active=True, is_public=True, + platform_id=test_platform.id, stripe_product_id="prod_business", stripe_price_monthly_id="price_biz_monthly", stripe_price_annual_id="price_biz_annual", diff --git a/app/modules/billing/tests/unit/test_store_platform_sync_service.py b/app/modules/billing/tests/unit/test_store_platform_sync_service.py new file mode 100644 index 00000000..aa66c685 --- /dev/null +++ b/app/modules/billing/tests/unit/test_store_platform_sync_service.py @@ -0,0 +1,247 @@ +# app/modules/billing/tests/unit/test_store_platform_sync.py +"""Unit tests for StorePlatformSync service.""" + +from datetime import UTC, datetime, timedelta + +import pytest + +from app.modules.billing.models import ( + MerchantSubscription, + SubscriptionStatus, + SubscriptionTier, +) +from app.modules.billing.services.store_platform_sync_service import StorePlatformSync +from app.modules.tenancy.models import StorePlatform + + +@pytest.mark.unit +@pytest.mark.billing +class TestStorePlatformSyncCreate: + """Tests for creating StorePlatform entries via sync.""" + + def setup_method(self): + self.service = StorePlatformSync() + + def test_sync_creates_store_platform(self, db, test_store, test_platform): + """Sync with is_active=True creates a new StorePlatform entry.""" + self.service.sync_store_platforms_for_merchant( + db, test_store.merchant_id, test_platform.id, is_active=True + ) + + sp = ( + db.query(StorePlatform) + .filter( + StorePlatform.store_id == test_store.id, + StorePlatform.platform_id == test_platform.id, + ) + .first() + ) + assert sp is not None + assert sp.is_active is True + + def test_sync_sets_primary_when_none(self, db, test_store, test_platform): + """First platform synced for a store gets is_primary=True.""" + self.service.sync_store_platforms_for_merchant( + db, test_store.merchant_id, test_platform.id, is_active=True + ) + + sp = ( + db.query(StorePlatform) + .filter( + StorePlatform.store_id == test_store.id, + StorePlatform.platform_id == test_platform.id, + ) + .first() + ) + assert sp.is_primary is True + + def test_sync_no_primary_override(self, db, test_store, test_platform, another_platform): + """Second platform synced does not override existing primary.""" + # First platform becomes primary + self.service.sync_store_platforms_for_merchant( + db, test_store.merchant_id, test_platform.id, is_active=True + ) + # Second platform should not be primary + self.service.sync_store_platforms_for_merchant( + db, test_store.merchant_id, another_platform.id, is_active=True + ) + + sp1 = ( + db.query(StorePlatform) + .filter( + StorePlatform.store_id == test_store.id, + StorePlatform.platform_id == test_platform.id, + ) + .first() + ) + sp2 = ( + db.query(StorePlatform) + .filter( + StorePlatform.store_id == test_store.id, + StorePlatform.platform_id == another_platform.id, + ) + .first() + ) + assert sp1.is_primary is True + assert sp2.is_primary is False + + def test_sync_sets_tier_id(self, db, test_store, test_platform, sync_tier): + """Sync passes tier_id to newly created StorePlatform.""" + self.service.sync_store_platforms_for_merchant( + db, test_store.merchant_id, test_platform.id, + is_active=True, tier_id=sync_tier.id, + ) + + sp = ( + db.query(StorePlatform) + .filter( + StorePlatform.store_id == test_store.id, + StorePlatform.platform_id == test_platform.id, + ) + .first() + ) + assert sp.tier_id == sync_tier.id + + +@pytest.mark.unit +@pytest.mark.billing +class TestStorePlatformSyncUpdate: + """Tests for updating existing StorePlatform entries via sync.""" + + def setup_method(self): + self.service = StorePlatformSync() + + def test_sync_updates_existing_is_active(self, db, test_store, test_platform): + """Sync updates is_active on existing StorePlatform.""" + # Create initial entry + sp = StorePlatform( + store_id=test_store.id, + platform_id=test_platform.id, + is_active=True, + is_primary=True, + ) + db.add(sp) + db.flush() + + # Deactivate via sync + self.service.sync_store_platforms_for_merchant( + db, test_store.merchant_id, test_platform.id, is_active=False + ) + + db.refresh(sp) + assert sp.is_active is False + + def test_sync_updates_tier_id(self, db, test_store, test_platform, sync_tier): + """Sync updates tier_id on existing StorePlatform.""" + sp = StorePlatform( + store_id=test_store.id, + platform_id=test_platform.id, + is_active=True, + is_primary=True, + ) + db.add(sp) + db.flush() + + self.service.sync_store_platforms_for_merchant( + db, test_store.merchant_id, test_platform.id, + is_active=True, tier_id=sync_tier.id, + ) + + db.refresh(sp) + assert sp.tier_id == sync_tier.id + + +@pytest.mark.unit +@pytest.mark.billing +class TestStorePlatformSyncEdgeCases: + """Tests for edge cases in sync.""" + + def setup_method(self): + self.service = StorePlatformSync() + + def test_sync_noop_inactive_missing(self, db, test_store, test_platform): + """Sync with is_active=False for non-existent entry is a no-op.""" + self.service.sync_store_platforms_for_merchant( + db, test_store.merchant_id, test_platform.id, is_active=False + ) + + sp = ( + db.query(StorePlatform) + .filter( + StorePlatform.store_id == test_store.id, + StorePlatform.platform_id == test_platform.id, + ) + .first() + ) + assert sp is None + + def test_sync_no_stores(self, db, test_platform): + """Sync with no stores for merchant is a no-op (no error).""" + self.service.sync_store_platforms_for_merchant( + db, 99999, test_platform.id, is_active=True + ) + # No assertion needed — just verifying no exception + + def test_sync_multiple_stores(self, db, test_merchant, test_platform): + """Sync creates entries for all stores of a merchant.""" + from app.modules.tenancy.models import Store + + store1 = Store( + merchant_id=test_merchant.id, + store_code="SYNC_TEST_1", + name="Sync Store 1", + subdomain="sync-test-1", + is_active=True, + is_verified=True, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + store2 = Store( + merchant_id=test_merchant.id, + store_code="SYNC_TEST_2", + name="Sync Store 2", + subdomain="sync-test-2", + is_active=True, + is_verified=True, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + db.add_all([store1, store2]) + db.flush() + + self.service.sync_store_platforms_for_merchant( + db, test_merchant.id, test_platform.id, is_active=True + ) + + count = ( + db.query(StorePlatform) + .filter( + StorePlatform.platform_id == test_platform.id, + StorePlatform.store_id.in_([store1.id, store2.id]), + ) + .count() + ) + assert count == 2 + + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture +def sync_tier(db, test_platform): + """Create a tier for sync tests.""" + tier = SubscriptionTier( + platform_id=test_platform.id, + code="essential", + name="Essential", + price_monthly_cents=2900, + display_order=1, + is_active=True, + is_public=True, + ) + db.add(tier) + db.commit() + db.refresh(tier) + return tier diff --git a/app/modules/billing/tests/unit/test_subscription_service.py b/app/modules/billing/tests/unit/test_subscription_service.py index 5139ec2b..01113be3 100644 --- a/app/modules/billing/tests/unit/test_subscription_service.py +++ b/app/modules/billing/tests/unit/test_subscription_service.py @@ -502,7 +502,7 @@ class TestSubscriptionServiceReactivate: @pytest.fixture -def billing_tier_essential(db): +def billing_tier_essential(db, test_platform): """Create essential subscription tier.""" tier = SubscriptionTier( code="essential", @@ -513,6 +513,7 @@ def billing_tier_essential(db): display_order=1, is_active=True, is_public=True, + platform_id=test_platform.id, ) db.add(tier) db.commit() @@ -521,7 +522,7 @@ def billing_tier_essential(db): @pytest.fixture -def billing_tiers(db): +def billing_tiers(db, test_platform): """Create essential, professional, and business tiers.""" tiers = [ SubscriptionTier( @@ -532,6 +533,7 @@ def billing_tiers(db): display_order=1, is_active=True, is_public=True, + platform_id=test_platform.id, ), SubscriptionTier( code="professional", @@ -541,6 +543,7 @@ def billing_tiers(db): display_order=2, is_active=True, is_public=True, + platform_id=test_platform.id, ), SubscriptionTier( code="business", @@ -550,6 +553,7 @@ def billing_tiers(db): display_order=3, is_active=True, is_public=True, + platform_id=test_platform.id, ), ] db.add_all(tiers) diff --git a/app/modules/cart/definition.py b/app/modules/cart/definition.py index 4c0d668f..50af19ed 100644 --- a/app/modules/cart/definition.py +++ b/app/modules/cart/definition.py @@ -70,7 +70,7 @@ cart_module = ModuleDefinition( id="cart", label_key="storefront.actions.cart", icon="shopping-cart", - route="storefront/cart", + route="cart", order=20, ), ], diff --git a/app/modules/catalog/definition.py b/app/modules/catalog/definition.py index b5431954..670c2cd7 100644 --- a/app/modules/catalog/definition.py +++ b/app/modules/catalog/definition.py @@ -135,7 +135,7 @@ catalog_module = ModuleDefinition( id="products", label_key="storefront.nav.products", icon="shopping-bag", - route="storefront/products", + route="products", order=10, ), ], diff --git a/app/modules/cms/definition.py b/app/modules/cms/definition.py index 5fcf2f6b..8328b813 100644 --- a/app/modules/cms/definition.py +++ b/app/modules/cms/definition.py @@ -39,7 +39,9 @@ def _get_platform_context(request: Any, db: Any, platform: Any) -> dict[str, Any """ from app.modules.cms.services import content_page_service - platform_id = platform.id if platform else 1 + if not platform: + return {"header_pages": [], "footer_pages": [], "legal_pages": []} + platform_id = platform.id header_pages = [] footer_pages = [] @@ -73,7 +75,9 @@ def _get_storefront_context(request: Any, db: Any, platform: Any) -> dict[str, A from app.modules.cms.services import content_page_service store = getattr(request.state, "store", None) - platform_id = platform.id if platform else 1 + if not platform: + return {"header_pages": [], "footer_pages": [], "legal_pages": []} + platform_id = platform.id header_pages = [] footer_pages = [] diff --git a/app/modules/cms/routes/api/storefront.py b/app/modules/cms/routes/api/storefront.py index f9e51202..95246c19 100644 --- a/app/modules/cms/routes/api/storefront.py +++ b/app/modules/cms/routes/api/storefront.py @@ -11,6 +11,7 @@ import logging from fastapi import APIRouter, Depends, Request from sqlalchemy.orm import Session +from app.api.deps import require_platform from app.core.database import get_db from app.modules.cms.schemas import ( ContentPageListItem, @@ -29,7 +30,11 @@ logger = logging.getLogger(__name__) # public - storefront content pages are publicly accessible @router.get("/navigation", response_model=list[ContentPageListItem]) -def get_navigation_pages(request: Request, db: Session = Depends(get_db)): +def get_navigation_pages( + request: Request, + platform=Depends(require_platform), + db: Session = Depends(get_db), +): """ Get list of content pages for navigation (footer/header). @@ -37,9 +42,8 @@ def get_navigation_pages(request: Request, db: Session = Depends(get_db)): Returns store overrides + platform defaults. """ store = getattr(request.state, "store", None) - platform = getattr(request.state, "platform", None) store_id = store.id if store else None - platform_id = platform.id if platform else 1 + platform_id = platform.id # Get all published pages for this store pages = content_page_service.list_pages_for_store( @@ -59,7 +63,12 @@ def get_navigation_pages(request: Request, db: Session = Depends(get_db)): @router.get("/{slug}", response_model=PublicContentPageResponse) -def get_content_page(slug: str, request: Request, db: Session = Depends(get_db)): +def get_content_page( + slug: str, + request: Request, + platform=Depends(require_platform), + db: Session = Depends(get_db), +): """ Get a specific content page by slug. @@ -67,9 +76,8 @@ def get_content_page(slug: str, request: Request, db: Session = Depends(get_db)) Returns store override if exists, otherwise platform default. """ store = getattr(request.state, "store", None) - platform = getattr(request.state, "platform", None) store_id = store.id if store else None - platform_id = platform.id if platform else 1 + platform_id = platform.id page = content_page_service.get_page_for_store_or_raise( db, diff --git a/app/modules/cms/routes/pages/platform.py b/app/modules/cms/routes/pages/platform.py index 757555a9..d9929a1a 100644 --- a/app/modules/cms/routes/pages/platform.py +++ b/app/modules/cms/routes/pages/platform.py @@ -91,8 +91,9 @@ async def homepage( if store: logger.debug(f"[HOMEPAGE] Store detected: {store.subdomain}") - # Get platform_id (use platform from context or default to 1 for OMS) - platform_id = platform.id if platform else 1 + if not platform: + raise HTTPException(status_code=400, detail="Platform context required") + platform_id = platform.id # Try to find store landing page (slug='landing' or 'home') landing_page = content_page_service.get_page_for_store( @@ -133,21 +134,19 @@ async def homepage( else "unknown" ) - if access_method == "path": - full_prefix = ( - store_context.get("full_prefix", "/store/") - if store_context - else "/store/" - ) + if access_method == "path" and platform: return RedirectResponse( - url=f"{full_prefix}{store.subdomain}/storefront/", status_code=302 + url=f"/platforms/{platform.code}/storefront/{store.store_code}/", + status_code=302, ) - # Domain/subdomain - redirect to /storefront/ - return RedirectResponse(url="/storefront/", status_code=302) + # Domain/subdomain - root is storefront + return RedirectResponse(url="/", status_code=302) # Scenario 2: Platform marketing site (no store) # Load platform homepage from CMS (slug='home') - platform_id = platform.id if platform else 1 + if not platform: + raise HTTPException(status_code=400, detail="Platform context required") + platform_id = platform.id cms_homepage = content_page_service.get_platform_page( db, platform_id=platform_id, slug="home", include_unpublished=False @@ -227,9 +226,10 @@ async def content_page( This is a catch-all route for dynamic content pages managed via the admin CMS. Platform pages have store_id=None and is_platform_page=True. """ - # Get platform from middleware (default to OMS platform_id=1) platform = getattr(request.state, "platform", None) - platform_id = platform.id if platform else 1 + if not platform: + raise HTTPException(status_code=400, detail="Platform context required") + platform_id = platform.id # Load platform marketing page from database page = content_page_service.get_platform_page( diff --git a/app/modules/cms/routes/pages/store.py b/app/modules/cms/routes/pages/store.py index 89ee0693..100864f3 100644 --- a/app/modules/cms/routes/pages/store.py +++ b/app/modules/cms/routes/pages/store.py @@ -184,7 +184,9 @@ async def store_content_page( store = getattr(request.state, "store", None) platform = getattr(request.state, "platform", None) store_id = store.id if store else None - platform_id = platform.id if platform else 1 + if not platform: + raise HTTPException(status_code=400, detail="Platform context required") + platform_id = platform.id # Load content page from database (store override → platform default) page = content_page_service.get_page_for_store( diff --git a/app/modules/cms/routes/pages/storefront.py b/app/modules/cms/routes/pages/storefront.py index 1d0b8944..a9acb254 100644 --- a/app/modules/cms/routes/pages/storefront.py +++ b/app/modules/cms/routes/pages/storefront.py @@ -66,7 +66,9 @@ async def generic_content_page( store = getattr(request.state, "store", None) platform = getattr(request.state, "platform", None) store_id = store.id if store else None - platform_id = platform.id if platform else 1 # Default to OMS + if not platform: + raise HTTPException(status_code=400, detail="Platform context required") + platform_id = platform.id # Load content page from database (store override -> store default) page = content_page_service.get_page_for_store( diff --git a/app/modules/core/routes/api/merchant_dashboard.py b/app/modules/core/routes/api/merchant_dashboard.py index 11fa010f..85ef597c 100644 --- a/app/modules/core/routes/api/merchant_dashboard.py +++ b/app/modules/core/routes/api/merchant_dashboard.py @@ -14,7 +14,7 @@ import logging from fastapi import APIRouter, Depends, Request from sqlalchemy.orm import Session -from app.api.deps import get_merchant_for_current_user +from app.api.deps import get_merchant_for_current_user, require_platform from app.core.database import get_db from app.modules.core.schemas.dashboard import MerchantDashboardStatsResponse from app.modules.core.services.stats_aggregator import stats_aggregator @@ -27,6 +27,7 @@ logger = logging.getLogger(__name__) def get_merchant_dashboard_stats( request: Request, merchant=Depends(get_merchant_for_current_user), + platform=Depends(require_platform), db: Session = Depends(get_db), ): """ @@ -41,8 +42,7 @@ def get_merchant_dashboard_stats( Merchant is resolved from the JWT token. Requires Authorization header (API endpoint). """ - platform = getattr(request.state, "platform", None) - platform_id = platform.id if platform else 1 + platform_id = platform.id flat = stats_aggregator.get_merchant_stats_flat( db=db, diff --git a/app/modules/core/routes/api/store_dashboard.py b/app/modules/core/routes/api/store_dashboard.py index 46edac0d..ea8edfde 100644 --- a/app/modules/core/routes/api/store_dashboard.py +++ b/app/modules/core/routes/api/store_dashboard.py @@ -14,7 +14,7 @@ import logging from fastapi import APIRouter, Depends, Request from sqlalchemy.orm import Session -from app.api.deps import get_current_store_api +from app.api.deps import get_current_store_api, require_platform from app.core.database import get_db from app.modules.core.schemas.dashboard import ( StoreCustomerStats, @@ -49,6 +49,7 @@ def _extract_metric_value( def get_store_dashboard_stats( request: Request, current_user: UserContext = Depends(get_current_store_api), + platform=Depends(require_platform), db: Session = Depends(get_db), ): """ @@ -74,10 +75,7 @@ def get_store_dashboard_stats( if not store.is_active: raise StoreNotActiveException(store.store_code) - # Get aggregated metrics from all enabled modules - # Get platform_id from request context (set by PlatformContextMiddleware) - platform = getattr(request.state, "platform", None) - platform_id = platform.id if platform else 1 + platform_id = platform.id metrics = stats_aggregator.get_store_dashboard_stats( db=db, store_id=store_id, diff --git a/app/modules/core/utils/page_context.py b/app/modules/core/utils/page_context.py index bfcc50fa..4c7c9b9a 100644 --- a/app/modules/core/utils/page_context.py +++ b/app/modules/core/utils/page_context.py @@ -332,14 +332,21 @@ def get_storefront_context( ) # Calculate base URL for links + # Dev path-based: /platforms/{code}/storefront/{store_code}/ + # Prod subdomain/custom domain: / base_url = "/" if access_method == "path" and store: - full_prefix = ( - store_context.get("full_prefix", "/store/") - if store_context - else "/store/" - ) - base_url = f"{full_prefix}{store.subdomain}/" + platform = getattr(request.state, "platform", None) + platform_original_path = getattr(request.state, "platform_original_path", None) + if platform and platform_original_path and platform_original_path.startswith("/platforms/"): + base_url = f"/platforms/{platform.code}/storefront/{store.store_code}/" + else: + full_prefix = ( + store_context.get("full_prefix", "/storefront/") + if store_context + else "/storefront/" + ) + base_url = f"{full_prefix}{store.store_code}/" # Read subscription info set by StorefrontAccessMiddleware subscription = getattr(request.state, "subscription", None) diff --git a/app/modules/customers/definition.py b/app/modules/customers/definition.py index e77e87fb..07330311 100644 --- a/app/modules/customers/definition.py +++ b/app/modules/customers/definition.py @@ -142,28 +142,28 @@ customers_module = ModuleDefinition( id="dashboard", label_key="storefront.account.dashboard", icon="home", - route="storefront/account/dashboard", + route="account/dashboard", order=10, ), MenuItemDefinition( id="profile", label_key="storefront.account.profile", icon="user", - route="storefront/account/profile", + route="account/profile", order=20, ), MenuItemDefinition( id="addresses", label_key="storefront.account.addresses", icon="map-pin", - route="storefront/account/addresses", + route="account/addresses", order=30, ), MenuItemDefinition( id="settings", label_key="storefront.account.settings", icon="cog", - route="storefront/account/settings", + route="account/settings", order=90, ), ], diff --git a/app/modules/loyalty/definition.py b/app/modules/loyalty/definition.py index 63244f59..b01e4f36 100644 --- a/app/modules/loyalty/definition.py +++ b/app/modules/loyalty/definition.py @@ -205,7 +205,7 @@ loyalty_module = ModuleDefinition( id="loyalty", label_key="storefront.account.loyalty", icon="gift", - route="storefront/account/loyalty", + route="account/loyalty", order=60, ), ], diff --git a/app/modules/loyalty/tests/integration/test_storefront_loyalty.py b/app/modules/loyalty/tests/integration/test_storefront_loyalty.py index 4d07f22b..e20ee7fc 100644 --- a/app/modules/loyalty/tests/integration/test_storefront_loyalty.py +++ b/app/modules/loyalty/tests/integration/test_storefront_loyalty.py @@ -23,7 +23,7 @@ class TestStorefrontLoyaltyEndpoints: # Without proper store context, should return 404 or error response = client.get("/api/v1/storefront/loyalty/program") # Endpoint exists but requires store context - assert response.status_code in [200, 404, 422, 500] + assert response.status_code in [200, 403, 404, 422, 500] def test_enroll_endpoint_exists(self, client): """Test that enrollment endpoint is registered.""" @@ -35,16 +35,16 @@ class TestStorefrontLoyaltyEndpoints: }, ) # Endpoint exists but requires store context - assert response.status_code in [200, 404, 422, 500] + assert response.status_code in [200, 403, 404, 422, 500] def test_card_endpoint_exists(self, client): """Test that card endpoint is registered.""" response = client.get("/api/v1/storefront/loyalty/card") # Endpoint exists but requires authentication and store context - assert response.status_code in [401, 404, 422, 500] + assert response.status_code in [401, 403, 404, 422, 500] def test_transactions_endpoint_exists(self, client): """Test that transactions endpoint is registered.""" response = client.get("/api/v1/storefront/loyalty/transactions") # Endpoint exists but requires authentication and store context - assert response.status_code in [401, 404, 422, 500] + assert response.status_code in [401, 403, 404, 422, 500] diff --git a/app/modules/messaging/definition.py b/app/modules/messaging/definition.py index 517534b2..a3716f48 100644 --- a/app/modules/messaging/definition.py +++ b/app/modules/messaging/definition.py @@ -183,7 +183,7 @@ messaging_module = ModuleDefinition( id="messages", label_key="storefront.account.messages", icon="chat-bubble-left-right", - route="storefront/account/messages", + route="account/messages", order=50, ), ], diff --git a/app/modules/monitoring/services/capacity_forecast_service.py b/app/modules/monitoring/services/capacity_forecast_service.py index 57ac815d..f84651ea 100644 --- a/app/modules/monitoring/services/capacity_forecast_service.py +++ b/app/modules/monitoring/services/capacity_forecast_service.py @@ -74,7 +74,9 @@ class CapacityForecastService: # Resource metrics via provider pattern (avoids cross-module imports) start_of_month = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) platform = db.query(Platform).first() - platform_id = platform.id if platform else 1 + if not platform: + raise ValueError("No platform found in database") + platform_id = platform.id stats = stats_aggregator.get_admin_stats_flat( db, platform_id, diff --git a/app/modules/orders/definition.py b/app/modules/orders/definition.py index 3c302c3f..4d9f8afa 100644 --- a/app/modules/orders/definition.py +++ b/app/modules/orders/definition.py @@ -147,7 +147,7 @@ orders_module = ModuleDefinition( id="orders", label_key="storefront.account.orders", icon="clipboard-list", - route="storefront/account/orders", + route="account/orders", order=40, ), ], diff --git a/app/modules/tenancy/routes/api/admin_stores.py b/app/modules/tenancy/routes/api/admin_stores.py index bafd33a4..86d5c206 100644 --- a/app/modules/tenancy/routes/api/admin_stores.py +++ b/app/modules/tenancy/routes/api/admin_stores.py @@ -72,6 +72,7 @@ def create_store( merchant_contact_phone=store.merchant.contact_phone, merchant_website=store.merchant.website, # Owner info (from merchant) + owner_user_id=store.merchant.owner.id, owner_email=store.merchant.owner.email, owner_username=store.merchant.owner.username, login_url=f"http://localhost:8000/store/{store.subdomain}/login", @@ -143,6 +144,7 @@ def _build_store_detail_response(store) -> StoreDetailResponse: # Merchant info merchant_name=store.merchant.name, # Owner details (from merchant) + owner_user_id=store.merchant.owner_user_id, owner_email=store.merchant.owner.email, owner_username=store.merchant.owner.username, # Resolved contact info with inheritance flags diff --git a/app/modules/tenancy/routes/api/store.py b/app/modules/tenancy/routes/api/store.py index 8edb3a5f..faf909f5 100644 --- a/app/modules/tenancy/routes/api/store.py +++ b/app/modules/tenancy/routes/api/store.py @@ -78,6 +78,7 @@ def get_store_info( merchant_contact_phone=store.merchant.contact_phone, merchant_website=store.merchant.website, # Owner details (from merchant) + owner_user_id=store.merchant.owner_user_id, owner_email=store.merchant.owner.email, owner_username=store.merchant.owner.username, ) diff --git a/app/modules/tenancy/schemas/store.py b/app/modules/tenancy/schemas/store.py index 7754174f..7fc26669 100644 --- a/app/modules/tenancy/schemas/store.py +++ b/app/modules/tenancy/schemas/store.py @@ -232,6 +232,7 @@ class StoreDetailResponse(StoreResponse): merchant_name: str = Field(..., description="Name of the parent merchant") # Owner info (at merchant level) + owner_user_id: int = Field(..., description="User ID of the merchant owner") owner_email: str = Field( ..., description="Email of the merchant owner (for login/authentication)" ) diff --git a/app/modules/tenancy/tests/unit/test_store_schema.py b/app/modules/tenancy/tests/unit/test_store_schema.py index 46af9adb..3784caae 100644 --- a/app/modules/tenancy/tests/unit/test_store_schema.py +++ b/app/modules/tenancy/tests/unit/test_store_schema.py @@ -271,11 +271,13 @@ class TestStoreDetailResponseSchema: "owner_username": "owner", "contact_email": "contact@techstore.com", "contact_email_inherited": False, + "owner_user_id": 42, } response = StoreDetailResponse(**data) assert response.merchant_name == "Tech Corp" assert response.owner_email == "owner@techcorp.com" assert response.contact_email_inherited is False + assert response.owner_user_id == 42 @pytest.mark.unit diff --git a/app/templates/storefront/base.html b/app/templates/storefront/base.html index 471dae45..0d082ab4 100644 --- a/app/templates/storefront/base.html +++ b/app/templates/storefront/base.html @@ -63,7 +63,7 @@ {# Store Logo #}