Files
orion/app/services/company_service.py
Samir Boulahtit 14054bfd6d fix: add unique() to company query with joinedload
SQLAlchemy Error Fix:
- Add .unique() when using joinedload(Company.vendors)
- When eagerly loading collection relationships with joinedload, SQLAlchemy can return duplicate rows
- The unique() method deduplicates results and is required for joined collection loads

Error was:
InvalidRequestError: The unique() method must be invoked on this Result, as it contains results that include joined eager loads against collections

This is a standard SQLAlchemy pattern for handling one-to-many relationships with eager loading.
2025-12-01 22:10:08 +01:00

268 lines
8.0 KiB
Python

# app/services/company_service.py
"""
Company service for managing company operations.
This service handles CRUD operations for companies and company-vendor relationships.
"""
import logging
import secrets
import string
from typing import List, Optional
from sqlalchemy import func, select
from sqlalchemy.orm import Session, joinedload
from app.exceptions import CompanyNotFoundException
from models.database.company import Company
from models.database.user import User
from models.schema.company import CompanyCreate, CompanyUpdate
# Logger named after this module (e.g. "app.services.company_service").
logger = logging.getLogger(name=__name__)
class CompanyService:
    """Service for managing companies.

    Handles company CRUD operations and company-owner/vendor relationships.
    Every mutating method only ``flush()``-es the given session; committing
    or rolling back the transaction is left to the caller.
    """

    # Character set used for generated temporary passwords.
    _TEMP_PASSWORD_ALPHABET = string.ascii_letters + string.digits + "!@#$%^&*"

    def __init__(self):
        """Initialize company service (it holds no per-instance state)."""
        pass

    def create_company_with_owner(
        self, db: Session, company_data: CompanyCreate
    ) -> tuple[Company, User, str | None]:
        """
        Create a new company with an owner user account.

        The owner is looked up by ``company_data.owner_email``. An existing
        user is reused as-is; otherwise a new active, email-verified user with
        role ``"user"`` and a random temporary password is created.

        Args:
            db: Database session (flushed, not committed)
            company_data: Company creation data

        Returns:
            Tuple of (company, owner_user, temporary_password). The temporary
            password is ``None`` when an existing user was reused as owner.
        """
        existing_user = db.execute(
            select(User).where(User.email == company_data.owner_email)
        ).scalar_one_or_none()

        temp_password: str | None
        if existing_user is not None:
            # Use existing user as owner; their credentials are left untouched.
            owner_user = existing_user
            temp_password = None
            logger.info(
                "Using existing user %s as company owner", owner_user.email
            )
        else:
            # Import AuthManager for password hashing (same pattern as
            # admin_service). Kept function-local, presumably to avoid an
            # import cycle -- confirm before hoisting it to module level.
            from middleware.auth import AuthManager

            temp_password = self._generate_temp_password()
            owner_user = User(
                # NOTE(review): the username is the e-mail local part. If
                # User.username is unique, owners such as a@x.com and a@y.com
                # would collide on flush -- confirm against the model.
                username=company_data.owner_email.split("@")[0],
                email=company_data.owner_email,
                hashed_password=AuthManager().hash_password(temp_password),
                role="user",
                is_active=True,
                is_email_verified=True,
            )
            db.add(owner_user)
            db.flush()  # Assigns owner_user.id, used as owner_user_id below.
            logger.info("Created new owner user: %s", owner_user.email)

        company = Company(
            name=company_data.name,
            description=company_data.description,
            owner_user_id=owner_user.id,
            contact_email=company_data.contact_email,
            contact_phone=company_data.contact_phone,
            website=company_data.website,
            business_address=company_data.business_address,
            tax_number=company_data.tax_number,
            is_active=True,
            is_verified=False,  # New companies always start unverified.
        )
        db.add(company)
        db.flush()  # Assigns company.id for the log line and the caller.
        logger.info("Created company: %s (ID: %s)", company.name, company.id)
        return company, owner_user, temp_password

    def get_company_by_id(self, db: Session, company_id: int) -> Company:
        """
        Get company by ID, with its vendors eagerly loaded.

        Args:
            db: Database session
            company_id: Company ID

        Returns:
            Company object

        Raises:
            CompanyNotFoundException: If company not found
        """
        company = (
            db.execute(
                select(Company)
                .where(Company.id == company_id)
                .options(joinedload(Company.vendors))
            )
            # joinedload against a collection yields one row per vendor;
            # SQLAlchemy requires unique() before fetching, even for a
            # single-object fetch like scalar_one_or_none().
            .unique()
            .scalar_one_or_none()
        )
        if company is None:
            raise CompanyNotFoundException(company_id)
        return company

    def get_companies(
        self,
        db: Session,
        skip: int = 0,
        limit: int = 100,
        search: str | None = None,
        is_active: bool | None = None,
        is_verified: bool | None = None,
    ) -> tuple[List[Company], int]:
        """
        Get paginated list of companies with optional filters.

        Results are ordered by name, then by ID, so pagination is stable even
        when several companies share a name.

        Args:
            db: Database session
            skip: Number of records to skip
            limit: Maximum number of records to return
            search: Case-insensitive substring of the company name; matched
                literally (``%`` and ``_`` are not treated as wildcards)
            is_active: Filter by active status
            is_verified: Filter by verified status

        Returns:
            Tuple of (companies list, total count of matching companies)
        """
        base_query = select(Company)
        if search:
            # Escape LIKE metacharacters so the user's term is matched as-is.
            escaped = (
                search.replace("\\", "\\\\")
                .replace("%", "\\%")
                .replace("_", "\\_")
            )
            base_query = base_query.where(
                Company.name.ilike(f"%{escaped}%", escape="\\")
            )
        if is_active is not None:
            base_query = base_query.where(Company.is_active == is_active)
        if is_verified is not None:
            base_query = base_query.where(Company.is_verified == is_verified)

        # Count before adding eager-load options/pagination so the count
        # subquery is just the filtered company SELECT.
        total = db.execute(
            select(func.count()).select_from(base_query.subquery())
        ).scalar_one()

        query = (
            base_query.options(joinedload(Company.vendors))
            # Company.id breaks ties between equal names -> deterministic pages.
            .order_by(Company.name, Company.id)
            .offset(skip)
            .limit(limit)
        )
        # Use unique() when using joinedload with collections to avoid
        # duplicate Company objects (one row per vendor).
        companies = list(db.execute(query).scalars().unique().all())
        return companies, total

    def update_company(
        self, db: Session, company_id: int, company_data: CompanyUpdate
    ) -> Company:
        """
        Update company information.

        Only fields explicitly set on ``company_data`` (pydantic
        ``exclude_unset=True``) are copied onto the company.

        Args:
            db: Database session
            company_id: Company ID
            company_data: Updated company data

        Returns:
            Updated company

        Raises:
            CompanyNotFoundException: If company not found
        """
        company = self.get_company_by_id(db, company_id)
        update_data = company_data.model_dump(exclude_unset=True)
        for field, value in update_data.items():
            setattr(company, field, value)
        db.flush()
        logger.info("Updated company ID %s", company_id)
        return company

    def delete_company(self, db: Session, company_id: int) -> None:
        """
        Delete a company and all associated vendors.

        Args:
            db: Database session
            company_id: Company ID

        Raises:
            CompanyNotFoundException: If company not found
        """
        company = self.get_company_by_id(db, company_id)
        # Vendor removal relies on the Company.vendors relationship cascade
        # (assumed cascade="all, delete-orphan" -- verify in
        # models.database.company); vendors are already loaded above.
        db.delete(company)
        db.flush()
        logger.info("Deleted company ID %s and associated vendors", company_id)

    def toggle_verification(
        self, db: Session, company_id: int, is_verified: bool
    ) -> Company:
        """
        Set company verification status to the given value.

        Args:
            db: Database session
            company_id: Company ID
            is_verified: New verification status

        Returns:
            Updated company

        Raises:
            CompanyNotFoundException: If company not found
        """
        company = self.get_company_by_id(db, company_id)
        company.is_verified = is_verified
        db.flush()
        logger.info(
            "Company ID %s verification set to %s", company_id, is_verified
        )
        return company

    def toggle_active(self, db: Session, company_id: int, is_active: bool) -> Company:
        """
        Set company active status to the given value.

        Args:
            db: Database session
            company_id: Company ID
            is_active: New active status

        Returns:
            Updated company

        Raises:
            CompanyNotFoundException: If company not found
        """
        company = self.get_company_by_id(db, company_id)
        company.is_active = is_active
        db.flush()
        logger.info(
            "Company ID %s active status set to %s", company_id, is_active
        )
        return company

    def _generate_temp_password(self, length: int = 12) -> str:
        """Generate a random temporary password using the ``secrets`` CSPRNG."""
        alphabet = self._TEMP_PASSWORD_ALPHABET
        return "".join(secrets.choice(alphabet) for _ in range(length))
# Shared module-level instance, following the same pattern as other services.
# CompanyService keeps no per-instance state, so a single instance is reused.
company_service: CompanyService = CompanyService()