import hashlib
import secrets
from datetime import datetime, timedelta

from sqlalchemy import Column, DateTime, ForeignKey, Integer, String
from sqlalchemy.orm import Session, relationship

from app.core.database import Base


class PasswordResetToken(Base):
    """Password reset token for customer accounts.

    Security:
    - Tokens are stored as SHA256 hashes, not plaintext
    - Tokens expire ``TOKEN_EXPIRY_HOURS`` hour(s) after creation
    - Only one active token per customer (unused tokens are deleted when a
      new one is requested)

    All timestamps written by this class are naive UTC values produced by
    ``datetime.utcnow``.
    """

    __tablename__ = "password_reset_tokens"

    # Token expiry in hours
    TOKEN_EXPIRY_HOURS = 1

    id = Column(Integer, primary_key=True, index=True)
    customer_id = Column(
        Integer,
        ForeignKey("customers.id", ondelete="CASCADE"),
        nullable=False,
    )
    # Hex SHA256 digest of the plaintext token (64 characters).
    token_hash = Column(String(64), nullable=False, index=True)
    expires_at = Column(DateTime, nullable=False)
    # NULL until the token has been consumed via ``mark_used``.
    used_at = Column(DateTime, nullable=True)
    created_at = Column(DateTime, default=datetime.utcnow, nullable=False)

    # Relationships
    customer = relationship("Customer")

    def __repr__(self) -> str:
        """Return a debug representation that never exposes the token hash."""
        return (
            f"<{type(self).__name__}(id={self.id}, "
            f"customer_id={self.customer_id})>"
        )

    @staticmethod
    def hash_token(token: str) -> str:
        """Return the hex SHA256 digest of ``token``."""
        return hashlib.sha256(token.encode()).hexdigest()

    @classmethod
    def create_for_customer(cls, db: Session, customer_id: int) -> str:
        """Create a new password reset token for a customer.

        Every unused token row belonging to the customer is deleted first,
        so at most one unused token exists per customer afterwards. The new
        row is added and flushed; this method does not commit.

        Args:
            db: Session used for the delete, the insert and the flush.
            customer_id: Primary key of the customer the token is issued to.

        Returns:
            The plaintext token (to be sent via email). Only its SHA256 hash
            is stored.
        """
        # Invalidate existing tokens for this customer
        db.query(cls).filter(
            cls.customer_id == customer_id,
            cls.used_at.is_(None),
        ).delete()

        # Generate new token
        plaintext_token = secrets.token_urlsafe(32)
        token_hash = cls.hash_token(plaintext_token)

        # Create token record
        token = cls(
            customer_id=customer_id,
            token_hash=token_hash,
            expires_at=datetime.utcnow() + timedelta(hours=cls.TOKEN_EXPIRY_HOURS),
        )
        db.add(token)
        db.flush()

        return plaintext_token

    @classmethod
    def find_valid_token(
        cls, db: Session, plaintext_token: str
    ) -> "PasswordResetToken | None":
        """Find a valid (not expired, not used) token.

        Args:
            db: Session to query.
            plaintext_token: Token as received from the user; it is hashed
                before being compared with the stored ``token_hash``.

        Returns:
            The first matching row whose ``expires_at`` is still in the
            future and whose ``used_at`` is NULL, or ``None`` if there is no
            such row.
        """
        token_hash = cls.hash_token(plaintext_token)
        return db.query(cls).filter(
            cls.token_hash == token_hash,
            cls.expires_at > datetime.utcnow(),
            cls.used_at.is_(None),
        ).first()

    def mark_used(self, db: Session) -> None:
        """Mark this token as used.

        Sets ``used_at`` to the current UTC time and flushes the session;
        this method does not commit.
        """
        self.used_at = datetime.utcnow()
        db.flush()