# app/utils/vat.py """ VAT calculation utilities for the OMS. Provides centralized VAT logic used by both order_service and invoice_service to ensure consistency between order tax calculation and invoice VAT. VAT Logic: - Luxembourg domestic: 17% (standard), 8% (reduced), 3% (super-reduced), 14% (intermediate) - EU cross-border B2C with OSS: Use destination country VAT rate - EU cross-border B2C without OSS: Use Luxembourg VAT rate (origin principle) - EU B2B with valid VAT number: Reverse charge (0% VAT) - Non-EU: VAT exempt (0%) """ from dataclasses import dataclass from decimal import Decimal from enum import Enum class VATRegime(str, Enum): """VAT regime for order/invoice calculation.""" DOMESTIC = "domestic" # Same country as seller OSS = "oss" # EU cross-border with OSS registration REVERSE_CHARGE = "reverse_charge" # B2B with valid VAT number ORIGIN = "origin" # Cross-border without OSS (use origin VAT) EXEMPT = "exempt" # VAT exempt (non-EU) @dataclass class VATResult: """Result of VAT determination.""" regime: VATRegime rate: Decimal destination_country: str | None label: str | None # EU VAT rates by country code (2024 standard rates) EU_VAT_RATES: dict[str, Decimal] = { "AT": Decimal("20.00"), # Austria "BE": Decimal("21.00"), # Belgium "BG": Decimal("20.00"), # Bulgaria "HR": Decimal("25.00"), # Croatia "CY": Decimal("19.00"), # Cyprus "CZ": Decimal("21.00"), # Czech Republic "DK": Decimal("25.00"), # Denmark "EE": Decimal("22.00"), # Estonia "FI": Decimal("24.00"), # Finland "FR": Decimal("20.00"), # France "DE": Decimal("19.00"), # Germany "GR": Decimal("24.00"), # Greece "HU": Decimal("27.00"), # Hungary "IE": Decimal("23.00"), # Ireland "IT": Decimal("22.00"), # Italy "LV": Decimal("21.00"), # Latvia "LT": Decimal("21.00"), # Lithuania "LU": Decimal("17.00"), # Luxembourg (standard) "MT": Decimal("18.00"), # Malta "NL": Decimal("21.00"), # Netherlands "PL": Decimal("23.00"), # Poland "PT": Decimal("23.00"), # Portugal "RO": Decimal("19.00"), # Romania "SK": Decimal("20.00"), # Slovakia "SI": Decimal("22.00"), # Slovenia "ES": Decimal("21.00"), # Spain "SE": Decimal("25.00"), # Sweden } # Luxembourg specific VAT rates LU_VAT_RATES = { "standard": Decimal("17.00"), "intermediate": Decimal("14.00"), "reduced": Decimal("8.00"), "super_reduced": Decimal("3.00"), } # Country names for labels COUNTRY_NAMES: dict[str, str] = { "AT": "Austria", "BE": "Belgium", "BG": "Bulgaria", "HR": "Croatia", "CY": "Cyprus", "CZ": "Czech Republic", "DK": "Denmark", "EE": "Estonia", "FI": "Finland", "FR": "France", "DE": "Germany", "GR": "Greece", "HU": "Hungary", "IE": "Ireland", "IT": "Italy", "LV": "Latvia", "LT": "Lithuania", "LU": "Luxembourg", "MT": "Malta", "NL": "Netherlands", "PL": "Poland", "PT": "Portugal", "RO": "Romania", "SK": "Slovakia", "SI": "Slovenia", "ES": "Spain", "SE": "Sweden", } def get_vat_rate_for_country(country_iso: str) -> Decimal: """ Get standard VAT rate for EU country. Args: country_iso: ISO 2-letter country code Returns: VAT rate as Decimal (0.00 for non-EU countries) """ return EU_VAT_RATES.get(country_iso.upper(), Decimal("0.00")) def get_vat_rate_label(country_iso: str, vat_rate: Decimal) -> str: """ Get human-readable VAT rate label. Args: country_iso: ISO 2-letter country code vat_rate: VAT rate as Decimal Returns: Human-readable label (e.g., "Luxembourg VAT 17%") """ country_name = COUNTRY_NAMES.get(country_iso.upper(), country_iso) return f"{country_name} VAT {vat_rate}%" def is_eu_country(country_iso: str) -> bool: """Check if country is in the EU.""" return country_iso.upper() in EU_VAT_RATES def determine_vat_regime( seller_country: str, buyer_country: str, buyer_vat_number: str | None = None, seller_oss_registered: bool = False, ) -> VATResult: """ Determine VAT regime and rate for an order/invoice. VAT Decision Logic: 1. Same country = domestic VAT 2. B2B with valid VAT number = reverse charge (0%) 3. Cross-border + OSS registered = destination country VAT 4. Cross-border + no OSS = origin country VAT 5. Non-EU = VAT exempt (0%) Args: seller_country: Seller's country (ISO 2-letter code) buyer_country: Buyer's country (ISO 2-letter code) buyer_vat_number: Buyer's VAT number (for B2B detection) seller_oss_registered: Whether seller is registered for OSS Returns: VATResult with regime, rate, destination country, and label """ seller_country = seller_country.upper() if seller_country else "LU" buyer_country = buyer_country.upper() if buyer_country else "LU" # Same country = domestic VAT if seller_country == buyer_country: vat_rate = get_vat_rate_for_country(seller_country) label = get_vat_rate_label(seller_country, vat_rate) if vat_rate > 0 else None return VATResult( regime=VATRegime.DOMESTIC, rate=vat_rate, destination_country=None, label=label, ) # Different EU countries if is_eu_country(buyer_country): # B2B with valid VAT number = reverse charge if buyer_vat_number: return VATResult( regime=VATRegime.REVERSE_CHARGE, rate=Decimal("0.00"), destination_country=buyer_country, label="Reverse charge", ) # B2C cross-border if seller_oss_registered: # OSS: use destination country VAT vat_rate = get_vat_rate_for_country(buyer_country) label = get_vat_rate_label(buyer_country, vat_rate) return VATResult( regime=VATRegime.OSS, rate=vat_rate, destination_country=buyer_country, label=label, ) # No OSS: use origin country VAT vat_rate = get_vat_rate_for_country(seller_country) label = get_vat_rate_label(seller_country, vat_rate) return VATResult( regime=VATRegime.ORIGIN, rate=vat_rate, destination_country=buyer_country, label=label, ) # Non-EU = VAT exempt return VATResult( regime=VATRegime.EXEMPT, rate=Decimal("0.00"), destination_country=buyer_country, label="VAT exempt", ) def calculate_vat_amount(subtotal_cents: int, vat_rate: Decimal) -> int: """ Calculate VAT amount from subtotal. Args: subtotal_cents: Subtotal in cents vat_rate: VAT rate as percentage (e.g., 17.00 for 17%) Returns: VAT amount in cents """ if vat_rate <= 0: return 0 # Calculate: tax = subtotal * (rate / 100) subtotal_decimal = Decimal(str(subtotal_cents)) tax_decimal = subtotal_decimal * (vat_rate / Decimal("100")) # Round to nearest cent return int(round(tax_decimal))