# app/modules/prospecting/services/prospect_service.py
"""
Prospect CRUD service.

Manages creation, retrieval, update, and deletion of prospects.
Supports both digital (domain scan) and offline (manual capture) channels.
"""

import json
import logging

from sqlalchemy import func, or_
from sqlalchemy.orm import Session, joinedload

from app.modules.prospecting.exceptions import (
    DuplicateDomainException,
    ProspectNotFoundException,
)
from app.modules.prospecting.models import (
    Prospect,
    ProspectChannel,
    ProspectContact,
    ProspectScore,
    ProspectStatus,
)

logger = logging.getLogger(__name__)


class ProspectService:
    """Service for prospect CRUD operations."""

    # ------------------------------------------------------------------
    # Reads
    # ------------------------------------------------------------------

    def get_by_id(self, db: Session, prospect_id: int) -> Prospect:
        """Return the prospect with ``prospect_id``, eagerly loading its
        tech profile, performance profile, score, and contacts.

        Raises:
            ProspectNotFoundException: if no prospect has that id.
        """
        prospect = (
            db.query(Prospect)
            .options(
                joinedload(Prospect.tech_profile),
                joinedload(Prospect.performance_profile),
                joinedload(Prospect.score),
                joinedload(Prospect.contacts),
            )
            .filter(Prospect.id == prospect_id)
            .first()
        )
        if not prospect:
            raise ProspectNotFoundException(str(prospect_id))
        return prospect

    def get_by_domain(self, db: Session, domain_name: str) -> Prospect | None:
        """Return the prospect whose ``domain_name`` equals *domain_name*
        exactly (no normalization is applied here), or ``None``."""
        return db.query(Prospect).filter(Prospect.domain_name == domain_name).first()

    def get_all(
        self,
        db: Session,
        *,
        page: int = 1,
        per_page: int = 20,
        search: str | None = None,
        channel: str | None = None,
        status: str | None = None,
        tier: str | None = None,
        city: str | None = None,
        has_email: bool | None = None,
        has_phone: bool | None = None,
    ) -> tuple[list[Prospect], int]:
        """Return one page of prospects (newest first) and the total match count.

        Args:
            page: 1-based page number; values below 1 are treated as 1.
            per_page: page size.
            search: case-insensitive substring matched against domain or
                business name.
            channel: exact channel filter.
            status: exact status filter.
            tier: filters on the joined ``ProspectScore.lead_tier``.
            city: case-insensitive substring match on city.
            has_email: accepted but currently NOT applied as a filter.
            has_phone: accepted but currently NOT applied as a filter.

        NOTE(review): ``has_email`` / ``has_phone`` are part of the signature
        but were never wired into the query. Implement them once the
        ``ProspectContact.contact_type`` values are confirmed against the model.
        """
        # A page below 1 would yield a negative OFFSET, which databases reject.
        page = max(page, 1)

        query = db.query(Prospect).options(
            joinedload(Prospect.score),
            joinedload(Prospect.contacts),
        )

        if search:
            query = query.filter(
                or_(
                    Prospect.domain_name.ilike(f"%{search}%"),
                    Prospect.business_name.ilike(f"%{search}%"),
                )
            )
        if channel:
            query = query.filter(Prospect.channel == channel)
        if status:
            query = query.filter(Prospect.status == status)
        if city:
            query = query.filter(Prospect.city.ilike(f"%{city}%"))
        if tier:
            # Explicit join for filtering; the joinedload above uses its own alias.
            query = query.join(ProspectScore).filter(ProspectScore.lead_tier == tier)

        total = query.count()
        prospects = (
            query.order_by(Prospect.created_at.desc())
            .offset((page - 1) * per_page)
            .limit(per_page)
            .all()
        )
        return prospects, total

    # ------------------------------------------------------------------
    # Writes
    # ------------------------------------------------------------------

    @staticmethod
    def _normalize_domain(domain: str) -> str:
        """Normalize a user-supplied domain for storage and duplicate checks.

        Strips surrounding whitespace, a leading ``https://`` / ``http://``,
        a leading ``www.``, and trailing slashes, and lowercases the result
        (DNS names are case-insensitive). Lowercasing keeps ``create``,
        ``create_bulk`` and ``update`` consistent, e.g.
        ``"HTTPS://www.Example.com/"`` -> ``"example.com"``.
        """
        domain = domain.strip().lower()
        for scheme in ("https://", "http://"):
            domain = domain.removeprefix(scheme)
        domain = domain.removeprefix("www.")
        return domain.rstrip("/")

    def create(self, db: Session, data: dict, captured_by_user_id: int | None = None) -> Prospect:
        """Create a prospect (plus optional inline contacts) and flush it.

        For the digital channel, the domain is normalized and checked for
        duplicates. *data* is not mutated.

        Raises:
            DuplicateDomainException: if a digital prospect's normalized
                domain already exists.
        """
        channel = data.get("channel", "digital")
        domain_name = data.get("domain_name")

        if channel == "digital" and domain_name:
            domain_name = self._normalize_domain(domain_name)
            if self.get_by_domain(db, domain_name):
                raise DuplicateDomainException(domain_name)

        tags = data.get("tags")
        if isinstance(tags, list):
            tags = json.dumps(tags)

        prospect = Prospect(
            channel=ProspectChannel(channel),
            business_name=data.get("business_name"),
            domain_name=domain_name,
            status=ProspectStatus.PENDING,
            source=data.get("source", "domain_scan" if channel == "digital" else "manual"),
            address=data.get("address"),
            city=data.get("city"),
            postal_code=data.get("postal_code"),
            country=data.get("country", "LU"),
            notes=data.get("notes"),
            tags=tags,
            captured_by_user_id=captured_by_user_id,
            location_lat=data.get("location_lat"),
            location_lng=data.get("location_lng"),
        )
        db.add(prospect)
        # Flush so prospect.id is assigned before contacts reference it.
        db.flush()

        # `or []` also tolerates an explicit "contacts": None.
        contact_objects = [
            ProspectContact(
                prospect_id=prospect.id,
                contact_type=c["contact_type"],
                value=c["value"],
                label=c.get("label"),
                is_primary=c.get("is_primary", False),
            )
            for c in data.get("contacts") or []
        ]
        if contact_objects:
            db.add_all(contact_objects)
            db.flush()

        logger.info("Created prospect: %s (channel=%s)", prospect.display_name, channel)
        return prospect

    def create_bulk(self, db: Session, domain_names: list[str], source: str = "csv_import") -> tuple[int, int]:
        """Create digital prospects for *domain_names*.

        Names that normalize to an empty string are ignored. A name is
        skipped if it already exists in the database or appeared earlier
        in the same input.

        Returns:
            ``(created, skipped)`` counts.
        """
        skipped = 0
        seen: set[str] = set()
        new_prospects: list[Prospect] = []

        for raw_name in domain_names:
            name = self._normalize_domain(raw_name)
            if not name:
                continue
            # The pending prospects are not in the session yet, so the DB
            # lookup alone cannot catch repeats within this batch.
            if name in seen or self.get_by_domain(db, name):
                skipped += 1
                continue
            seen.add(name)
            new_prospects.append(Prospect(
                channel=ProspectChannel.DIGITAL,
                domain_name=name,
                source=source,
            ))

        if new_prospects:
            db.add_all(new_prospects)
            db.flush()

        created = len(new_prospects)
        logger.info("Bulk import: %d created, %d skipped", created, skipped)
        return created, skipped

    def update(self, db: Session, prospect_id: int, data: dict) -> Prospect:
        """Apply non-None fields from *data* to an existing prospect and flush.

        ``tags`` is applied whenever the key is present, including ``None``.

        Raises:
            ProspectNotFoundException: if the prospect does not exist.
            DuplicateDomainException: if a digital prospect's new normalized
                domain already belongs to another prospect.
        """
        prospect = self.get_by_id(db, prospect_id)

        if data.get("domain_name") is not None:
            domain_name = self._normalize_domain(data["domain_name"])
            # Mirror create(): duplicate checks apply to the digital channel only.
            if prospect.channel == ProspectChannel.DIGITAL and domain_name != prospect.domain_name:
                existing = self.get_by_domain(db, domain_name)
                if existing is not None and existing.id != prospect.id:
                    raise DuplicateDomainException(domain_name)
            prospect.domain_name = domain_name

        for field in ["business_name", "status", "source", "address", "city", "postal_code", "notes"]:
            if field in data and data[field] is not None:
                setattr(prospect, field, data[field])

        if "tags" in data:
            tags = data["tags"]
            if isinstance(tags, list):
                tags = json.dumps(tags)
            prospect.tags = tags

        db.flush()
        return prospect

    def delete(self, db: Session, prospect_id: int) -> bool:
        """Delete the prospect and flush.

        Always returns True. Raises ProspectNotFoundException if the
        prospect is missing.
        """
        prospect = self.get_by_id(db, prospect_id)
        db.delete(prospect)
        db.flush()
        logger.info("Deleted prospect: %d", prospect_id)
        return True

    # ------------------------------------------------------------------
    # Work queues for the scan pipeline
    # ------------------------------------------------------------------

    def get_pending_http_check(self, db: Session, limit: int = 100) -> list[Prospect]:
        """Digital prospects with a domain whose HTTP check has never run."""
        return (
            db.query(Prospect)
            .filter(
                Prospect.channel == ProspectChannel.DIGITAL,
                Prospect.domain_name.isnot(None),
                Prospect.last_http_check_at.is_(None),
            )
            .limit(limit)
            .all()
        )

    @staticmethod
    def _pending_website_scan(db: Session, last_run_column, limit: int) -> list[Prospect]:
        """Prospects with a website for which *last_run_column* is still NULL."""
        return (
            db.query(Prospect)
            .filter(
                Prospect.has_website.is_(True),
                last_run_column.is_(None),
            )
            .limit(limit)
            .all()
        )

    def get_pending_tech_scan(self, db: Session, limit: int = 100) -> list[Prospect]:
        """Prospects with a website that have not had a tech scan."""
        return self._pending_website_scan(db, Prospect.last_tech_scan_at, limit)

    def get_pending_performance_scan(self, db: Session, limit: int = 100) -> list[Prospect]:
        """Prospects with a website that have not had a performance scan."""
        return self._pending_website_scan(db, Prospect.last_perf_scan_at, limit)

    def get_pending_contact_scrape(self, db: Session, limit: int = 100) -> list[Prospect]:
        """Prospects with a website that have not had a contact scrape."""
        return self._pending_website_scan(db, Prospect.last_contact_scrape_at, limit)

    def get_pending_content_scrape(self, db: Session, limit: int = 100) -> list[Prospect]:
        """Prospects with a website that have not had a content scrape."""
        return self._pending_website_scan(db, Prospect.last_content_scrape_at, limit)

    def get_pending_security_audit(self, db: Session, limit: int = 50) -> list[Prospect]:
        """Prospects with a website that have not had a security audit."""
        return self._pending_website_scan(db, Prospect.last_security_audit_at, limit)

    # ------------------------------------------------------------------
    # Aggregates
    # ------------------------------------------------------------------

    def count_by_status(self, db: Session) -> dict[str, int]:
        """Map each status (enum ``.value`` if available, else ``str``) to its count."""
        results = db.query(Prospect.status, func.count(Prospect.id)).group_by(Prospect.status).all()  # noqa: SVC-005 - prospecting is platform-scoped, not store-scoped
        return {status.value if hasattr(status, "value") else str(status): count for status, count in results}

    def count_by_channel(self, db: Session) -> dict[str, int]:
        """Map each channel (enum ``.value`` if available, else ``str``) to its count."""
        results = db.query(Prospect.channel, func.count(Prospect.id)).group_by(Prospect.channel).all()
        return {channel.value if hasattr(channel, "value") else str(channel): count for channel, count in results}


prospect_service = ProspectService()