feat: enhance messaging system with improved API and tests

- Refactor messaging API endpoints for admin, shop, and vendor
- Add message-specific exceptions (ConversationNotFoundException, etc.)
- Enhance messaging service with additional helper methods
- Add comprehensive test fixtures for messaging
- Add integration tests for admin and vendor messaging APIs
- Add unit tests for messaging and attachment services

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2025-12-21 21:01:14 +01:00
parent 3bfe0ad3f8
commit 0098093287
11 changed files with 2229 additions and 136 deletions

View File

@@ -12,12 +12,19 @@ Provides endpoints for:
import logging
from typing import Any
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.api.deps import get_current_admin_api
from app.core.database import get_db
from app.exceptions import (
ConversationClosedException,
ConversationNotFoundException,
InvalidConversationTypeException,
InvalidRecipientTypeException,
MessageAttachmentException,
)
from app.services.message_attachment_service import message_attachment_service
from app.services.messaging_service import messaging_service
from models.database.message import ConversationType, ParticipantType
@@ -231,73 +238,45 @@ def get_recipients(
current_admin: User = Depends(get_current_admin_api),
) -> RecipientListResponse:
"""Get list of available recipients for compose modal."""
from models.database.customer import Customer
from models.database.vendor import VendorUser
recipients = []
if recipient_type == ParticipantType.VENDOR:
# List vendor users (for admin_vendor conversations)
query = (
db.query(User, VendorUser)
.join(VendorUser, User.id == VendorUser.user_id)
.filter(User.is_active == True) # noqa: E712
recipient_data, total = messaging_service.get_vendor_recipients(
db=db,
vendor_id=vendor_id,
search=search,
skip=skip,
limit=limit,
)
if vendor_id:
query = query.filter(VendorUser.vendor_id == vendor_id)
if search:
search_pattern = f"%{search}%"
query = query.filter(
(User.username.ilike(search_pattern))
| (User.email.ilike(search_pattern))
| (User.first_name.ilike(search_pattern))
| (User.last_name.ilike(search_pattern))
recipients = [
RecipientOption(
id=r["id"],
type=r["type"],
name=r["name"],
email=r["email"],
vendor_id=r["vendor_id"],
vendor_name=r.get("vendor_name"),
)
total = query.count()
results = query.offset(skip).limit(limit).all()
for user, vendor_user in results:
name = f"{user.first_name or ''} {user.last_name or ''}".strip() or user.username
recipients.append(
RecipientOption(
id=user.id,
type=ParticipantType.VENDOR,
name=name,
email=user.email,
vendor_id=vendor_user.vendor_id,
vendor_name=vendor_user.vendor.name if vendor_user.vendor else None,
)
)
for r in recipient_data
]
elif recipient_type == ParticipantType.CUSTOMER:
# List customers (for admin_customer conversations)
query = db.query(Customer).filter(Customer.is_active == True) # noqa: E712
if vendor_id:
query = query.filter(Customer.vendor_id == vendor_id)
if search:
search_pattern = f"%{search}%"
query = query.filter(
(Customer.email.ilike(search_pattern))
| (Customer.first_name.ilike(search_pattern))
| (Customer.last_name.ilike(search_pattern))
)
total = query.count()
results = query.offset(skip).limit(limit).all()
for customer in results:
name = f"{customer.first_name or ''} {customer.last_name or ''}".strip()
recipients.append(
RecipientOption(
id=customer.id,
type=ParticipantType.CUSTOMER,
name=name or customer.email,
email=customer.email,
vendor_id=customer.vendor_id,
)
recipient_data, total = messaging_service.get_customer_recipients(
db=db,
vendor_id=vendor_id,
search=search,
skip=skip,
limit=limit,
)
recipients = [
RecipientOption(
id=r["id"],
type=r["type"],
name=r["name"],
email=r["email"],
vendor_id=r["vendor_id"],
)
for r in recipient_data
]
else:
recipients = []
total = 0
return RecipientListResponse(recipients=recipients, total=total)
@@ -320,9 +299,9 @@ def create_conversation(
ConversationType.ADMIN_VENDOR,
ConversationType.ADMIN_CUSTOMER,
]:
raise HTTPException(
status_code=400,
detail="Admin can only create admin_vendor or admin_customer conversations",
raise InvalidConversationTypeException(
message="Admin can only create admin_vendor or admin_customer conversations",
allowed_types=["admin_vendor", "admin_customer"],
)
# Validate recipient type matches conversation type
@@ -330,17 +309,17 @@ def create_conversation(
data.conversation_type == ConversationType.ADMIN_VENDOR
and data.recipient_type != ParticipantType.VENDOR
):
raise HTTPException(
status_code=400,
detail="admin_vendor conversations require a vendor recipient",
raise InvalidRecipientTypeException(
conversation_type="admin_vendor",
expected_recipient_type="vendor",
)
if (
data.conversation_type == ConversationType.ADMIN_CUSTOMER
and data.recipient_type != ParticipantType.CUSTOMER
):
raise HTTPException(
status_code=400,
detail="admin_customer conversations require a customer recipient",
raise InvalidRecipientTypeException(
conversation_type="admin_customer",
expected_recipient_type="customer",
)
# Create conversation
@@ -460,7 +439,7 @@ def get_conversation(
)
if not conversation:
raise HTTPException(status_code=404, detail="Conversation not found")
raise ConversationNotFoundException(str(conversation_id))
# Mark as read if requested
if mark_read:
@@ -498,12 +477,10 @@ async def send_message(
)
if not conversation:
raise HTTPException(status_code=404, detail="Conversation not found")
raise ConversationNotFoundException(str(conversation_id))
if conversation.is_closed:
raise HTTPException(
status_code=400, detail="Cannot send messages to a closed conversation"
)
raise ConversationClosedException(conversation_id)
# Process attachments
attachments = []
@@ -514,7 +491,7 @@ async def send_message(
)
attachments.append(att_data)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
raise MessageAttachmentException(str(e))
# Send message
message = messaging_service.send_message(
@@ -556,7 +533,7 @@ def close_conversation(
)
if not conversation:
raise HTTPException(status_code=404, detail="Conversation not found")
raise ConversationNotFoundException(str(conversation_id))
db.commit()
logger.info(
@@ -585,7 +562,7 @@ def reopen_conversation(
)
if not conversation:
raise HTTPException(status_code=404, detail="Conversation not found")
raise ConversationNotFoundException(str(conversation_id))
db.commit()
logger.info(

View File

@@ -21,7 +21,12 @@ from sqlalchemy.orm import Session
from app.api.deps import get_current_customer_api
from app.core.database import get_db
from app.exceptions import ConversationNotFoundException, VendorNotFoundException
from app.exceptions import (
AttachmentNotFoundException,
ConversationClosedException,
ConversationNotFoundException,
VendorNotFoundException,
)
from app.services.message_attachment_service import message_attachment_service
from app.services.messaging_service import messaging_service
from models.database.customer import Customer
@@ -292,12 +297,7 @@ async def send_message(
# Check if conversation is closed
if conversation.is_closed:
from fastapi import HTTPException
raise HTTPException(
status_code=400,
detail="Cannot send messages to a closed conversation",
)
raise ConversationClosedException(conversation_id)
# Process attachments
attachment_data = []
@@ -405,7 +405,6 @@ async def download_attachment(
Validates that customer has access to the conversation.
"""
from fastapi import HTTPException
from fastapi.responses import FileResponse
vendor = getattr(request.state, "vendor", None)
@@ -433,7 +432,7 @@ async def download_attachment(
)
if not attachment:
raise HTTPException(status_code=404, detail="Attachment not found")
raise AttachmentNotFoundException(attachment_id)
return FileResponse(
path=attachment.file_path,
@@ -455,7 +454,6 @@ async def get_attachment_thumbnail(
Validates that customer has access to the conversation.
"""
from fastapi import HTTPException
from fastapi.responses import FileResponse
vendor = getattr(request.state, "vendor", None)
@@ -483,7 +481,7 @@ async def get_attachment_thumbnail(
)
if not attachment or not attachment.thumbnail_path:
raise HTTPException(status_code=404, detail="Thumbnail not found")
raise AttachmentNotFoundException(f"{attachment_id}/thumbnail")
return FileResponse(
path=attachment.thumbnail_path,

View File

@@ -14,12 +14,19 @@ Uses get_current_vendor_api dependency which guarantees token_vendor_id is prese
import logging
from typing import Any
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api
from app.core.database import get_db
from app.exceptions import (
ConversationClosedException,
ConversationNotFoundException,
InvalidConversationTypeException,
InvalidRecipientTypeException,
MessageAttachmentException,
)
from app.services.message_attachment_service import message_attachment_service
from app.services.messaging_service import messaging_service
from models.database.message import ConversationType, ParticipantType
@@ -230,41 +237,30 @@ def get_recipients(
current_user: User = Depends(get_current_vendor_api),
) -> RecipientListResponse:
"""Get list of available recipients for compose modal."""
from models.database.customer import Customer
vendor_id = current_user.token_vendor_id
recipients = []
if recipient_type == ParticipantType.CUSTOMER:
# List customers for this vendor (for vendor_customer conversations)
query = db.query(Customer).filter(
Customer.vendor_id == vendor_id,
Customer.is_active == True, # noqa: E712
recipient_data, total = messaging_service.get_customer_recipients(
db=db,
vendor_id=vendor_id,
search=search,
skip=skip,
limit=limit,
)
if search:
search_pattern = f"%{search}%"
query = query.filter(
(Customer.email.ilike(search_pattern))
| (Customer.first_name.ilike(search_pattern))
| (Customer.last_name.ilike(search_pattern))
)
total = query.count()
results = query.offset(skip).limit(limit).all()
for customer in results:
name = f"{customer.first_name or ''} {customer.last_name or ''}".strip()
recipients.append(
RecipientOption(
id=customer.id,
type=ParticipantType.CUSTOMER,
name=name or customer.email,
email=customer.email,
vendor_id=customer.vendor_id,
)
recipients = [
RecipientOption(
id=r["id"],
type=r["type"],
name=r["name"],
email=r["email"],
vendor_id=r["vendor_id"],
)
for r in recipient_data
]
else:
# Vendors can't start conversations with admins - admins initiate those
recipients = []
total = 0
return RecipientListResponse(recipients=recipients, total=total)
@@ -286,15 +282,15 @@ def create_conversation(
# Vendors can only create vendor_customer conversations
if data.conversation_type != ConversationType.VENDOR_CUSTOMER:
raise HTTPException(
status_code=400,
detail="Vendors can only create vendor_customer conversations",
raise InvalidConversationTypeException(
message="Vendors can only create vendor_customer conversations",
allowed_types=["vendor_customer"],
)
if data.recipient_type != ParticipantType.CUSTOMER:
raise HTTPException(
status_code=400,
detail="vendor_customer conversations require a customer recipient",
raise InvalidRecipientTypeException(
conversation_type="vendor_customer",
expected_recipient_type="customer",
)
# Create conversation
@@ -416,11 +412,11 @@ def get_conversation(
)
if not conversation:
raise HTTPException(status_code=404, detail="Conversation not found")
raise ConversationNotFoundException(str(conversation_id))
# Verify vendor context
if conversation.vendor_id and conversation.vendor_id != vendor_id:
raise HTTPException(status_code=404, detail="Conversation not found")
raise ConversationNotFoundException(str(conversation_id))
# Mark as read if requested
if mark_read:
@@ -460,16 +456,14 @@ async def send_message(
)
if not conversation:
raise HTTPException(status_code=404, detail="Conversation not found")
raise ConversationNotFoundException(str(conversation_id))
# Verify vendor context
if conversation.vendor_id and conversation.vendor_id != vendor_id:
raise HTTPException(status_code=404, detail="Conversation not found")
raise ConversationNotFoundException(str(conversation_id))
if conversation.is_closed:
raise HTTPException(
status_code=400, detail="Cannot send messages to a closed conversation"
)
raise ConversationClosedException(conversation_id)
# Process attachments
attachments = []
@@ -480,7 +474,7 @@ async def send_message(
)
attachments.append(att_data)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
raise MessageAttachmentException(str(e))
# Send message
message = messaging_service.send_message(
@@ -525,10 +519,10 @@ def close_conversation(
)
if not conversation:
raise HTTPException(status_code=404, detail="Conversation not found")
raise ConversationNotFoundException(str(conversation_id))
if conversation.vendor_id and conversation.vendor_id != vendor_id:
raise HTTPException(status_code=404, detail="Conversation not found")
raise ConversationNotFoundException(str(conversation_id))
conversation = messaging_service.close_conversation(
db=db,
@@ -567,10 +561,10 @@ def reopen_conversation(
)
if not conversation:
raise HTTPException(status_code=404, detail="Conversation not found")
raise ConversationNotFoundException(str(conversation_id))
if conversation.vendor_id and conversation.vendor_id != vendor_id:
raise HTTPException(status_code=404, detail="Conversation not found")
raise ConversationNotFoundException(str(conversation_id))
conversation = messaging_service.reopen_conversation(
db=db,

View File

@@ -103,8 +103,11 @@ from .inventory import (
# Message exceptions
from .message import (
AttachmentNotFoundException,
ConversationClosedException,
ConversationNotFoundException,
InvalidConversationTypeException,
InvalidRecipientTypeException,
MessageAttachmentException,
MessageNotFoundException,
UnauthorizedConversationAccessException,
@@ -389,4 +392,7 @@ __all__ = [
"ConversationClosedException",
"MessageAttachmentException",
"UnauthorizedConversationAccessException",
"InvalidConversationTypeException",
"InvalidRecipientTypeException",
"AttachmentNotFoundException",
]

View File

@@ -61,3 +61,40 @@ class UnauthorizedConversationAccessException(BusinessLogicException):
error_code="CONVERSATION_ACCESS_DENIED",
details={"conversation_id": conversation_id},
)
class InvalidConversationTypeException(ValidationException):
"""Raised when conversation type is not valid for the operation."""
def __init__(self, message: str, allowed_types: list[str] | None = None):
super().__init__(
message=message,
error_code="INVALID_CONVERSATION_TYPE",
details={"allowed_types": allowed_types} if allowed_types else None,
)
class InvalidRecipientTypeException(ValidationException):
"""Raised when recipient type doesn't match conversation type."""
def __init__(self, conversation_type: str, expected_recipient_type: str):
super().__init__(
message=f"{conversation_type} conversations require a {expected_recipient_type} recipient",
error_code="INVALID_RECIPIENT_TYPE",
details={
"conversation_type": conversation_type,
"expected_recipient_type": expected_recipient_type,
},
)
class AttachmentNotFoundException(ResourceNotFoundException):
"""Raised when an attachment is not found."""
def __init__(self, attachment_id: int | str):
super().__init__(
resource_type="Attachment",
identifier=str(attachment_id),
message=f"Attachment '{attachment_id}' not found",
error_code="ATTACHMENT_NOT_FOUND",
)

View File

@@ -567,6 +567,118 @@ class MessagingService:
db.flush()
return result > 0
# =========================================================================
# RECIPIENT QUERIES
# =========================================================================
def get_vendor_recipients(
self,
db: Session,
vendor_id: int | None = None,
search: str | None = None,
skip: int = 0,
limit: int = 50,
) -> tuple[list[dict], int]:
"""
Get list of vendor users as potential recipients.
Args:
db: Database session
vendor_id: Optional vendor ID filter
search: Search term for name/email
skip: Pagination offset
limit: Max results
Returns:
Tuple of (recipients list, total count)
"""
from models.database.vendor import VendorUser
query = (
db.query(User, VendorUser)
.join(VendorUser, User.id == VendorUser.user_id)
.filter(User.is_active == True) # noqa: E712
)
if vendor_id:
query = query.filter(VendorUser.vendor_id == vendor_id)
if search:
search_pattern = f"%{search}%"
query = query.filter(
(User.username.ilike(search_pattern))
| (User.email.ilike(search_pattern))
| (User.first_name.ilike(search_pattern))
| (User.last_name.ilike(search_pattern))
)
total = query.count()
results = query.offset(skip).limit(limit).all()
recipients = []
for user, vendor_user in results:
name = f"{user.first_name or ''} {user.last_name or ''}".strip() or user.username
recipients.append({
"id": user.id,
"type": ParticipantType.VENDOR,
"name": name,
"email": user.email,
"vendor_id": vendor_user.vendor_id,
"vendor_name": vendor_user.vendor.name if vendor_user.vendor else None,
})
return recipients, total
def get_customer_recipients(
self,
db: Session,
vendor_id: int | None = None,
search: str | None = None,
skip: int = 0,
limit: int = 50,
) -> tuple[list[dict], int]:
"""
Get list of customers as potential recipients.
Args:
db: Database session
vendor_id: Optional vendor ID filter (required for vendor users)
search: Search term for name/email
skip: Pagination offset
limit: Max results
Returns:
Tuple of (recipients list, total count)
"""
query = db.query(Customer).filter(Customer.is_active == True) # noqa: E712
if vendor_id:
query = query.filter(Customer.vendor_id == vendor_id)
if search:
search_pattern = f"%{search}%"
query = query.filter(
(Customer.email.ilike(search_pattern))
| (Customer.first_name.ilike(search_pattern))
| (Customer.last_name.ilike(search_pattern))
)
total = query.count()
results = query.offset(skip).limit(limit).all()
recipients = []
for customer in results:
name = f"{customer.first_name or ''} {customer.last_name or ''}".strip()
recipients.append({
"id": customer.id,
"type": ParticipantType.CUSTOMER,
"name": name or customer.email,
"email": customer.email,
"vendor_id": customer.vendor_id,
})
return recipients, total
# Singleton instance
messaging_service = MessagingService()

350
tests/fixtures/message_fixtures.py vendored Normal file
View File

@@ -0,0 +1,350 @@
# tests/fixtures/message_fixtures.py
"""
Messaging-related test fixtures.
Note: Fixtures should NOT use db.expunge() as it breaks lazy loading.
See tests/conftest.py for details on fixture best practices.
"""
import pytest
from models.database.message import (
Conversation,
ConversationParticipant,
ConversationType,
Message,
MessageAttachment,
ParticipantType,
)
@pytest.fixture
def test_conversation_admin_vendor(db, test_admin, test_vendor_user, test_vendor):
"""Create a test conversation between admin and vendor user."""
conversation = Conversation(
conversation_type=ConversationType.ADMIN_VENDOR,
subject="Test Admin-Vendor Conversation",
vendor_id=test_vendor.id,
)
db.add(conversation)
db.commit()
db.refresh(conversation)
# Add admin participant
admin_participant = ConversationParticipant(
conversation_id=conversation.id,
participant_type=ParticipantType.ADMIN,
participant_id=test_admin.id,
)
db.add(admin_participant)
# Add vendor participant
vendor_participant = ConversationParticipant(
conversation_id=conversation.id,
participant_type=ParticipantType.VENDOR,
participant_id=test_vendor_user.id,
vendor_id=test_vendor.id,
)
db.add(vendor_participant)
db.commit()
db.refresh(conversation)
return conversation
@pytest.fixture
def test_conversation_vendor_customer(db, test_vendor_user, test_customer, test_vendor):
"""Create a test conversation between vendor and customer."""
conversation = Conversation(
conversation_type=ConversationType.VENDOR_CUSTOMER,
subject="Test Vendor-Customer Conversation",
vendor_id=test_vendor.id,
)
db.add(conversation)
db.commit()
db.refresh(conversation)
# Add vendor participant
vendor_participant = ConversationParticipant(
conversation_id=conversation.id,
participant_type=ParticipantType.VENDOR,
participant_id=test_vendor_user.id,
vendor_id=test_vendor.id,
)
db.add(vendor_participant)
# Add customer participant
customer_participant = ConversationParticipant(
conversation_id=conversation.id,
participant_type=ParticipantType.CUSTOMER,
participant_id=test_customer.id,
)
db.add(customer_participant)
db.commit()
db.refresh(conversation)
return conversation
@pytest.fixture
def test_conversation_admin_customer(db, test_admin, test_customer, test_vendor):
"""Create a test conversation between admin and customer."""
conversation = Conversation(
conversation_type=ConversationType.ADMIN_CUSTOMER,
subject="Test Admin-Customer Conversation",
vendor_id=test_vendor.id,
)
db.add(conversation)
db.commit()
db.refresh(conversation)
# Add admin participant
admin_participant = ConversationParticipant(
conversation_id=conversation.id,
participant_type=ParticipantType.ADMIN,
participant_id=test_admin.id,
)
db.add(admin_participant)
# Add customer participant
customer_participant = ConversationParticipant(
conversation_id=conversation.id,
participant_type=ParticipantType.CUSTOMER,
participant_id=test_customer.id,
)
db.add(customer_participant)
db.commit()
db.refresh(conversation)
return conversation
@pytest.fixture
def test_message(db, test_conversation_admin_vendor, test_admin):
"""Create a test message in a conversation."""
message = Message(
conversation_id=test_conversation_admin_vendor.id,
sender_type=ParticipantType.ADMIN,
sender_id=test_admin.id,
content="This is a test message from admin.",
)
db.add(message)
# Update conversation stats
test_conversation_admin_vendor.message_count = 1
test_conversation_admin_vendor.last_message_at = message.created_at
db.commit()
db.refresh(message)
return message
@pytest.fixture
def test_message_with_attachment(db, test_conversation_admin_vendor, test_admin):
"""Create a test message with an attachment."""
message = Message(
conversation_id=test_conversation_admin_vendor.id,
sender_type=ParticipantType.ADMIN,
sender_id=test_admin.id,
content="This message has an attachment.",
)
db.add(message)
db.commit()
db.refresh(message)
attachment = MessageAttachment(
message_id=message.id,
filename="test_file_abc123.pdf",
original_filename="test_document.pdf",
file_path="/uploads/messages/2025/01/1/test_file_abc123.pdf",
file_size=12345,
mime_type="application/pdf",
is_image=False,
)
db.add(attachment)
db.commit()
db.refresh(message)
return message
@pytest.fixture
def closed_conversation(db, test_admin, test_vendor_user, test_vendor):
"""Create a closed conversation."""
from datetime import datetime, timezone
conversation = Conversation(
conversation_type=ConversationType.ADMIN_VENDOR,
subject="Closed Conversation",
vendor_id=test_vendor.id,
is_closed=True,
closed_at=datetime.now(timezone.utc),
closed_by_type=ParticipantType.ADMIN,
closed_by_id=test_admin.id,
)
db.add(conversation)
db.commit()
db.refresh(conversation)
# Add participants
admin_participant = ConversationParticipant(
conversation_id=conversation.id,
participant_type=ParticipantType.ADMIN,
participant_id=test_admin.id,
)
db.add(admin_participant)
vendor_participant = ConversationParticipant(
conversation_id=conversation.id,
participant_type=ParticipantType.VENDOR,
participant_id=test_vendor_user.id,
vendor_id=test_vendor.id,
)
db.add(vendor_participant)
db.commit()
db.refresh(conversation)
return conversation
@pytest.fixture
def multiple_conversations(db, test_admin, test_vendor_user, test_customer, test_vendor):
"""Create multiple conversations of different types."""
conversations = []
# Create 3 admin-vendor conversations
for i in range(3):
conv = Conversation(
conversation_type=ConversationType.ADMIN_VENDOR,
subject=f"Admin-Vendor Conversation {i+1}",
vendor_id=test_vendor.id,
)
db.add(conv)
db.commit()
db.refresh(conv)
db.add(
ConversationParticipant(
conversation_id=conv.id,
participant_type=ParticipantType.ADMIN,
participant_id=test_admin.id,
)
)
db.add(
ConversationParticipant(
conversation_id=conv.id,
participant_type=ParticipantType.VENDOR,
participant_id=test_vendor_user.id,
vendor_id=test_vendor.id,
)
)
conversations.append(conv)
# Create 2 vendor-customer conversations
for i in range(2):
conv = Conversation(
conversation_type=ConversationType.VENDOR_CUSTOMER,
subject=f"Vendor-Customer Conversation {i+1}",
vendor_id=test_vendor.id,
)
db.add(conv)
db.commit()
db.refresh(conv)
db.add(
ConversationParticipant(
conversation_id=conv.id,
participant_type=ParticipantType.VENDOR,
participant_id=test_vendor_user.id,
vendor_id=test_vendor.id,
)
)
db.add(
ConversationParticipant(
conversation_id=conv.id,
participant_type=ParticipantType.CUSTOMER,
participant_id=test_customer.id,
)
)
conversations.append(conv)
db.commit()
# Refresh all
for conv in conversations:
db.refresh(conv)
return conversations
@pytest.fixture
def vendor_api_conversation(db, test_admin, test_vendor_user, test_vendor_with_vendor_user):
"""Create a conversation for vendor API tests (uses vendor from vendor_user_headers)."""
conversation = Conversation(
conversation_type=ConversationType.ADMIN_VENDOR,
subject="Vendor API Test Conversation",
vendor_id=test_vendor_with_vendor_user.id,
)
db.add(conversation)
db.commit()
db.refresh(conversation)
# Add admin participant
admin_participant = ConversationParticipant(
conversation_id=conversation.id,
participant_type=ParticipantType.ADMIN,
participant_id=test_admin.id,
)
db.add(admin_participant)
# Add vendor participant (uses test_vendor_user which vendor_user_headers uses)
vendor_participant = ConversationParticipant(
conversation_id=conversation.id,
participant_type=ParticipantType.VENDOR,
participant_id=test_vendor_user.id,
vendor_id=test_vendor_with_vendor_user.id,
)
db.add(vendor_participant)
db.commit()
db.refresh(conversation)
return conversation
@pytest.fixture
def vendor_api_closed_conversation(db, test_admin, test_vendor_user, test_vendor_with_vendor_user):
"""Create a closed conversation for vendor API tests."""
from datetime import datetime, timezone
conversation = Conversation(
conversation_type=ConversationType.ADMIN_VENDOR,
subject="Vendor API Closed Conversation",
vendor_id=test_vendor_with_vendor_user.id,
is_closed=True,
closed_at=datetime.now(timezone.utc),
closed_by_type=ParticipantType.ADMIN,
closed_by_id=test_admin.id,
)
db.add(conversation)
db.commit()
db.refresh(conversation)
# Add participants
admin_participant = ConversationParticipant(
conversation_id=conversation.id,
participant_type=ParticipantType.ADMIN,
participant_id=test_admin.id,
)
db.add(admin_participant)
vendor_participant = ConversationParticipant(
conversation_id=conversation.id,
participant_type=ParticipantType.VENDOR,
participant_id=test_vendor_user.id,
vendor_id=test_vendor_with_vendor_user.id,
)
db.add(vendor_participant)
db.commit()
db.refresh(conversation)
return conversation

View File

@@ -0,0 +1,389 @@
# tests/integration/api/v1/admin/test_messages.py
"""
Integration tests for admin messaging endpoints.
Tests the /api/v1/admin/messages/* endpoints.
"""
import pytest
from models.database.message import ConversationType, ParticipantType
@pytest.mark.integration
@pytest.mark.api
@pytest.mark.admin
class TestAdminMessagesListAPI:
"""Tests for admin message list endpoints."""
def test_list_conversations_empty(self, client, admin_headers):
"""Test listing conversations when none exist."""
response = client.get("/api/v1/admin/messages", headers=admin_headers)
assert response.status_code == 200
data = response.json()
assert "conversations" in data
assert "total" in data
assert "total_unread" in data
assert data["conversations"] == []
assert data["total"] == 0
def test_list_conversations_requires_auth(self, client):
"""Test that listing requires authentication."""
response = client.get("/api/v1/admin/messages")
assert response.status_code == 401
def test_list_conversations_requires_admin(self, client, auth_headers):
"""Test that listing requires admin role."""
response = client.get("/api/v1/admin/messages", headers=auth_headers)
assert response.status_code == 403
def test_list_conversations_with_data(
self, client, admin_headers, test_conversation_admin_vendor
):
"""Test listing conversations with existing data."""
response = client.get("/api/v1/admin/messages", headers=admin_headers)
assert response.status_code == 200
data = response.json()
assert data["total"] >= 1
assert len(data["conversations"]) >= 1
# Check conversation structure
conv = data["conversations"][0]
assert "id" in conv
assert "conversation_type" in conv
assert "subject" in conv
assert "is_closed" in conv
def test_list_conversations_filter_by_type(
self, client, admin_headers, test_conversation_admin_vendor
):
"""Test filtering conversations by type."""
response = client.get(
"/api/v1/admin/messages",
params={"conversation_type": "admin_vendor"},
headers=admin_headers,
)
assert response.status_code == 200
data = response.json()
for conv in data["conversations"]:
assert conv["conversation_type"] == "admin_vendor"
def test_list_conversations_filter_closed(
self, client, admin_headers, closed_conversation
):
"""Test filtering closed conversations."""
response = client.get(
"/api/v1/admin/messages",
params={"is_closed": True},
headers=admin_headers,
)
assert response.status_code == 200
data = response.json()
for conv in data["conversations"]:
assert conv["is_closed"] is True
@pytest.mark.integration
@pytest.mark.api
@pytest.mark.admin
class TestAdminMessagesUnreadCountAPI:
"""Tests for unread count endpoint."""
def test_get_unread_count(self, client, admin_headers):
"""Test getting unread count."""
response = client.get("/api/v1/admin/messages/unread-count", headers=admin_headers)
assert response.status_code == 200
data = response.json()
assert "total_unread" in data
assert isinstance(data["total_unread"], int)
def test_get_unread_count_with_unread(
self, client, admin_headers, test_message
):
"""Test unread count with unread messages."""
response = client.get("/api/v1/admin/messages/unread-count", headers=admin_headers)
assert response.status_code == 200
data = response.json()
# The test_message is sent by admin, so no unread count for admin
assert data["total_unread"] >= 0
@pytest.mark.integration
@pytest.mark.api
@pytest.mark.admin
class TestAdminMessagesRecipientsAPI:
"""Tests for recipients endpoint."""
def test_get_vendor_recipients(
self, client, admin_headers, test_vendor_with_vendor_user
):
"""Test getting vendor recipients."""
response = client.get(
"/api/v1/admin/messages/recipients",
params={"recipient_type": "vendor"},
headers=admin_headers,
)
assert response.status_code == 200
data = response.json()
assert "recipients" in data
assert "total" in data
def test_get_customer_recipients(self, client, admin_headers, test_customer):
"""Test getting customer recipients."""
response = client.get(
"/api/v1/admin/messages/recipients",
params={"recipient_type": "customer"},
headers=admin_headers,
)
assert response.status_code == 200
data = response.json()
assert "recipients" in data
assert "total" in data
def test_get_recipients_requires_type(self, client, admin_headers):
"""Test that recipient_type is required."""
response = client.get("/api/v1/admin/messages/recipients", headers=admin_headers)
assert response.status_code == 422 # Validation error
@pytest.mark.integration
@pytest.mark.api
@pytest.mark.admin
class TestAdminMessagesCreateAPI:
"""Tests for conversation creation."""
def test_create_conversation_admin_vendor(
self, client, admin_headers, test_vendor_user, test_vendor_with_vendor_user
):
"""Test creating admin-vendor conversation."""
response = client.post(
"/api/v1/admin/messages",
json={
"conversation_type": "admin_vendor",
"subject": "Test Conversation",
"recipient_type": "vendor",
"recipient_id": test_vendor_user.id,
"vendor_id": test_vendor_with_vendor_user.id,
"initial_message": "Hello vendor!",
},
headers=admin_headers,
)
assert response.status_code == 200
data = response.json()
assert data["subject"] == "Test Conversation"
assert data["conversation_type"] == "admin_vendor"
assert len(data["messages"]) == 1
assert data["messages"][0]["content"] == "Hello vendor!"
def test_create_conversation_admin_customer(
self, client, admin_headers, test_customer, test_vendor
):
"""Test creating admin-customer conversation."""
response = client.post(
"/api/v1/admin/messages",
json={
"conversation_type": "admin_customer",
"subject": "Customer Support",
"recipient_type": "customer",
"recipient_id": test_customer.id,
"vendor_id": test_vendor.id,
"initial_message": "How can I help you?",
},
headers=admin_headers,
)
assert response.status_code == 200
data = response.json()
assert data["conversation_type"] == "admin_customer"
def test_create_conversation_wrong_recipient_type(
self, client, admin_headers, test_vendor_user, test_vendor
):
"""Test error when recipient type doesn't match conversation type."""
response = client.post(
"/api/v1/admin/messages",
json={
"conversation_type": "admin_vendor",
"subject": "Test",
"recipient_type": "customer", # Wrong type
"recipient_id": 1,
"vendor_id": test_vendor.id,
},
headers=admin_headers,
)
assert response.status_code == 400
def test_create_conversation_invalid_type(
self, client, admin_headers, test_vendor_user, test_vendor
):
"""Test error when admin tries to create vendor_customer conversation."""
response = client.post(
"/api/v1/admin/messages",
json={
"conversation_type": "vendor_customer",
"subject": "Test",
"recipient_type": "customer",
"recipient_id": 1,
"vendor_id": test_vendor.id,
},
headers=admin_headers,
)
assert response.status_code == 400
@pytest.mark.integration
@pytest.mark.api
@pytest.mark.admin
class TestAdminMessagesDetailAPI:
"""Tests for conversation detail."""
def test_get_conversation_detail(
self, client, admin_headers, test_conversation_admin_vendor
):
"""Test getting conversation detail."""
response = client.get(
f"/api/v1/admin/messages/{test_conversation_admin_vendor.id}",
headers=admin_headers,
)
assert response.status_code == 200
data = response.json()
assert data["id"] == test_conversation_admin_vendor.id
assert "participants" in data
assert "messages" in data
def test_get_conversation_not_found(self, client, admin_headers):
"""Test getting nonexistent conversation."""
response = client.get("/api/v1/admin/messages/99999", headers=admin_headers)
assert response.status_code == 404
def test_get_conversation_marks_read(
self, client, admin_headers, test_conversation_admin_vendor
):
"""Test that getting detail marks as read by default."""
response = client.get(
f"/api/v1/admin/messages/{test_conversation_admin_vendor.id}",
headers=admin_headers,
)
assert response.status_code == 200
data = response.json()
assert data["unread_count"] == 0
@pytest.mark.integration
@pytest.mark.api
@pytest.mark.admin
class TestAdminMessagesSendAPI:
"""Tests for sending messages."""
def test_send_message(self, client, admin_headers, test_conversation_admin_vendor):
"""Test sending a message."""
response = client.post(
f"/api/v1/admin/messages/{test_conversation_admin_vendor.id}/messages",
data={"content": "Test message content"},
headers=admin_headers,
)
assert response.status_code == 200
data = response.json()
assert data["content"] == "Test message content"
assert data["sender_type"] == "admin"
def test_send_message_to_closed(self, client, admin_headers, closed_conversation):
"""Test cannot send to closed conversation."""
response = client.post(
f"/api/v1/admin/messages/{closed_conversation.id}/messages",
data={"content": "Test message"},
headers=admin_headers,
)
assert response.status_code == 400
def test_send_message_not_found(self, client, admin_headers):
"""Test sending to nonexistent conversation."""
response = client.post(
"/api/v1/admin/messages/99999/messages",
data={"content": "Test message"},
headers=admin_headers,
)
assert response.status_code == 404
@pytest.mark.integration
@pytest.mark.api
@pytest.mark.admin
class TestAdminMessagesActionsAPI:
"""Tests for conversation actions."""
def test_close_conversation(
self, client, admin_headers, test_conversation_admin_vendor
):
"""Test closing a conversation."""
response = client.post(
f"/api/v1/admin/messages/{test_conversation_admin_vendor.id}/close",
headers=admin_headers,
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert "closed" in data["message"].lower()
def test_close_conversation_not_found(self, client, admin_headers):
"""Test closing nonexistent conversation."""
response = client.post(
"/api/v1/admin/messages/99999/close",
headers=admin_headers,
)
assert response.status_code == 404
def test_reopen_conversation(self, client, admin_headers, closed_conversation):
"""Test reopening a closed conversation."""
response = client.post(
f"/api/v1/admin/messages/{closed_conversation.id}/reopen",
headers=admin_headers,
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert "reopen" in data["message"].lower()
def test_mark_read(self, client, admin_headers, test_conversation_admin_vendor):
"""Test marking conversation as read."""
response = client.put(
f"/api/v1/admin/messages/{test_conversation_admin_vendor.id}/read",
headers=admin_headers,
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["unread_count"] == 0
def test_update_preferences(
self, client, admin_headers, test_conversation_admin_vendor
):
"""Test updating notification preferences."""
response = client.put(
f"/api/v1/admin/messages/{test_conversation_admin_vendor.id}/preferences",
json={"email_notifications": False, "muted": True},
headers=admin_headers,
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True

View File

@@ -0,0 +1,256 @@
# tests/integration/api/v1/vendor/test_messages.py
"""
Integration tests for vendor messaging endpoints.
Tests the /api/v1/vendor/messages/* endpoints.
"""
import pytest
from models.database.message import ConversationType, ParticipantType
@pytest.mark.integration
@pytest.mark.api
@pytest.mark.vendor
class TestVendorMessagesListAPI:
"""Tests for vendor message list endpoints."""
def test_list_conversations_empty(self, client, vendor_user_headers):
"""Test listing conversations when none exist."""
response = client.get("/api/v1/vendor/messages", headers=vendor_user_headers)
assert response.status_code == 200
data = response.json()
assert "conversations" in data
assert "total" in data
assert "total_unread" in data
assert data["total"] == 0
def test_list_conversations_requires_auth(self, client):
"""Test that listing requires authentication."""
response = client.get("/api/v1/vendor/messages")
assert response.status_code == 401
def test_list_conversations_requires_vendor(self, client, admin_headers):
"""Test that admin cannot use vendor endpoint."""
response = client.get("/api/v1/vendor/messages", headers=admin_headers)
# Admin doesn't have vendor context
assert response.status_code == 403
def test_list_conversations_with_data(
self, client, vendor_user_headers, vendor_api_conversation
):
"""Test listing conversations with existing data."""
response = client.get("/api/v1/vendor/messages", headers=vendor_user_headers)
assert response.status_code == 200
data = response.json()
assert data["total"] >= 1
def test_list_conversations_filter_by_type(
self, client, vendor_user_headers, vendor_api_conversation
):
"""Test filtering conversations by type."""
response = client.get(
"/api/v1/vendor/messages",
params={"conversation_type": "admin_vendor"},
headers=vendor_user_headers,
)
assert response.status_code == 200
data = response.json()
for conv in data["conversations"]:
assert conv["conversation_type"] == "admin_vendor"
@pytest.mark.integration
@pytest.mark.api
@pytest.mark.vendor
class TestVendorMessagesUnreadCountAPI:
"""Tests for unread count endpoint."""
def test_get_unread_count(self, client, vendor_user_headers):
"""Test getting unread count."""
response = client.get(
"/api/v1/vendor/messages/unread-count", headers=vendor_user_headers
)
assert response.status_code == 200
data = response.json()
assert "total_unread" in data
assert isinstance(data["total_unread"], int)
@pytest.mark.integration
@pytest.mark.api
@pytest.mark.vendor
class TestVendorMessagesRecipientsAPI:
"""Tests for recipients endpoint."""
def test_get_customer_recipients(
self, client, vendor_user_headers, test_customer
):
"""Test getting customer recipients."""
response = client.get(
"/api/v1/vendor/messages/recipients",
params={"recipient_type": "customer"},
headers=vendor_user_headers,
)
assert response.status_code == 200
data = response.json()
assert "recipients" in data
assert "total" in data
def test_get_recipients_requires_type(self, client, vendor_user_headers):
"""Test that recipient_type is required."""
response = client.get(
"/api/v1/vendor/messages/recipients", headers=vendor_user_headers
)
assert response.status_code == 422
@pytest.mark.integration
@pytest.mark.api
@pytest.mark.vendor
class TestVendorMessagesCreateAPI:
"""Tests for conversation creation."""
def test_create_conversation_vendor_customer(
self, client, vendor_user_headers, test_customer, test_vendor
):
"""Test creating vendor-customer conversation."""
response = client.post(
"/api/v1/vendor/messages",
json={
"conversation_type": "vendor_customer",
"subject": "Customer Support",
"recipient_type": "customer",
"recipient_id": test_customer.id,
"vendor_id": test_vendor.id,
"initial_message": "Hello customer!",
},
headers=vendor_user_headers,
)
assert response.status_code == 200
data = response.json()
assert data["subject"] == "Customer Support"
assert data["conversation_type"] == "vendor_customer"
def test_create_conversation_admin_vendor_not_allowed(
self, client, vendor_user_headers, test_admin, test_vendor
):
"""Test vendor cannot initiate admin_vendor conversation."""
response = client.post(
"/api/v1/vendor/messages",
json={
"conversation_type": "admin_vendor",
"subject": "Question for Admin",
"recipient_type": "admin",
"recipient_id": test_admin.id,
"vendor_id": test_vendor.id,
},
headers=vendor_user_headers,
)
assert response.status_code == 400
@pytest.mark.integration
@pytest.mark.api
@pytest.mark.vendor
class TestVendorMessagesDetailAPI:
"""Tests for conversation detail."""
def test_get_conversation_detail(
self, client, vendor_user_headers, vendor_api_conversation
):
"""Test getting conversation detail."""
response = client.get(
f"/api/v1/vendor/messages/{vendor_api_conversation.id}",
headers=vendor_user_headers,
)
assert response.status_code == 200
data = response.json()
assert data["id"] == vendor_api_conversation.id
assert "participants" in data
assert "messages" in data
def test_get_conversation_not_found(self, client, vendor_user_headers):
"""Test getting nonexistent conversation."""
response = client.get(
"/api/v1/vendor/messages/99999", headers=vendor_user_headers
)
assert response.status_code == 404
@pytest.mark.integration
@pytest.mark.api
@pytest.mark.vendor
class TestVendorMessagesSendAPI:
"""Tests for sending messages."""
def test_send_message(
self, client, vendor_user_headers, vendor_api_conversation
):
"""Test sending a message."""
response = client.post(
f"/api/v1/vendor/messages/{vendor_api_conversation.id}/messages",
data={"content": "Reply from vendor"},
headers=vendor_user_headers,
)
assert response.status_code == 200
data = response.json()
assert data["content"] == "Reply from vendor"
assert data["sender_type"] == "vendor"
def test_send_message_to_closed(
self, client, vendor_user_headers, vendor_api_closed_conversation
):
"""Test cannot send to closed conversation."""
response = client.post(
f"/api/v1/vendor/messages/{vendor_api_closed_conversation.id}/messages",
data={"content": "Test message"},
headers=vendor_user_headers,
)
assert response.status_code == 400
@pytest.mark.integration
@pytest.mark.api
@pytest.mark.vendor
class TestVendorMessagesActionsAPI:
"""Tests for conversation actions."""
def test_mark_read(
self, client, vendor_user_headers, vendor_api_conversation
):
"""Test marking conversation as read."""
response = client.put(
f"/api/v1/vendor/messages/{vendor_api_conversation.id}/read",
headers=vendor_user_headers,
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["unread_count"] == 0
def test_update_preferences(
self, client, vendor_user_headers, vendor_api_conversation
):
"""Test updating notification preferences."""
response = client.put(
f"/api/v1/vendor/messages/{vendor_api_conversation.id}/preferences",
json={"email_notifications": True, "muted": False},
headers=vendor_user_headers,
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True

View File

@@ -0,0 +1,387 @@
# tests/unit/services/test_message_attachment_service.py
"""
Unit tests for MessageAttachmentService.
"""
import os
import tempfile
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import UploadFile
from app.services.message_attachment_service import (
ALLOWED_MIME_TYPES,
DEFAULT_MAX_FILE_SIZE_MB,
IMAGE_MIME_TYPES,
MessageAttachmentService,
)
@pytest.fixture
def attachment_service():
"""Create a MessageAttachmentService instance with temp storage."""
with tempfile.TemporaryDirectory() as tmpdir:
yield MessageAttachmentService(storage_base=tmpdir)
@pytest.fixture
def mock_upload_file():
"""Create a mock UploadFile."""
def _create_upload_file(
content: bytes = b"test content",
filename: str = "test.txt",
content_type: str = "text/plain",
):
file = MagicMock(spec=UploadFile)
file.filename = filename
file.content_type = content_type
file.read = AsyncMock(return_value=content)
return file
return _create_upload_file
@pytest.mark.unit
class TestMessageAttachmentServiceValidation:
"""Tests for file validation methods."""
def test_validate_file_type_allowed_image(self, attachment_service):
"""Test image MIME types are allowed."""
for mime_type in IMAGE_MIME_TYPES:
assert attachment_service.validate_file_type(mime_type) is True
def test_validate_file_type_allowed_documents(self, attachment_service):
"""Test document MIME types are allowed."""
document_types = [
"application/pdf",
"application/msword",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
]
for mime_type in document_types:
assert attachment_service.validate_file_type(mime_type) is True
def test_validate_file_type_allowed_others(self, attachment_service):
"""Test other allowed MIME types."""
other_types = ["application/zip", "text/plain", "text/csv"]
for mime_type in other_types:
assert attachment_service.validate_file_type(mime_type) is True
def test_validate_file_type_not_allowed(self, attachment_service):
"""Test disallowed MIME types."""
disallowed_types = [
"application/javascript",
"application/x-executable",
"text/html",
"video/mp4",
"audio/mpeg",
]
for mime_type in disallowed_types:
assert attachment_service.validate_file_type(mime_type) is False
def test_is_image_true(self, attachment_service):
"""Test image detection for actual images."""
for mime_type in IMAGE_MIME_TYPES:
assert attachment_service.is_image(mime_type) is True
def test_is_image_false(self, attachment_service):
"""Test image detection for non-images."""
non_images = ["application/pdf", "text/plain", "application/zip"]
for mime_type in non_images:
assert attachment_service.is_image(mime_type) is False
@pytest.mark.unit
class TestMessageAttachmentServiceMaxFileSize:
"""Tests for max file size retrieval."""
def test_get_max_file_size_from_settings(self, db, attachment_service):
"""Test retrieving max file size from platform settings."""
with patch(
"app.services.message_attachment_service.admin_settings_service"
) as mock_settings:
mock_settings.get_setting_value.return_value = 15
max_size = attachment_service.get_max_file_size_bytes(db)
assert max_size == 15 * 1024 * 1024 # 15 MB in bytes
def test_get_max_file_size_default(self, db, attachment_service):
"""Test default max file size when setting not found."""
with patch(
"app.services.message_attachment_service.admin_settings_service"
) as mock_settings:
mock_settings.get_setting_value.return_value = DEFAULT_MAX_FILE_SIZE_MB
max_size = attachment_service.get_max_file_size_bytes(db)
assert max_size == DEFAULT_MAX_FILE_SIZE_MB * 1024 * 1024
def test_get_max_file_size_invalid_value(self, db, attachment_service):
"""Test handling of invalid setting value."""
with patch(
"app.services.message_attachment_service.admin_settings_service"
) as mock_settings:
mock_settings.get_setting_value.return_value = "invalid"
max_size = attachment_service.get_max_file_size_bytes(db)
assert max_size == DEFAULT_MAX_FILE_SIZE_MB * 1024 * 1024
@pytest.mark.unit
class TestMessageAttachmentServiceValidateAndStore:
"""Tests for validate_and_store method."""
@pytest.mark.asyncio
async def test_validate_and_store_success(
self, db, attachment_service, mock_upload_file
):
"""Test successful file storage."""
file = mock_upload_file(
content=b"test file content",
filename="document.pdf",
content_type="application/pdf",
)
with patch(
"app.services.message_attachment_service.admin_settings_service"
) as mock_settings:
mock_settings.get_setting_value.return_value = 10
result = await attachment_service.validate_and_store(
db=db,
file=file,
conversation_id=1,
)
assert result["original_filename"] == "document.pdf"
assert result["mime_type"] == "application/pdf"
assert result["file_size"] == len(b"test file content")
assert result["is_image"] is False
assert result["filename"].endswith(".pdf")
assert os.path.exists(result["file_path"])
@pytest.mark.asyncio
async def test_validate_and_store_image(
self, db, attachment_service, mock_upload_file
):
"""Test storage of image file."""
# Create a minimal valid PNG
png_header = (
b"\x89PNG\r\n\x1a\n" # PNG signature
+ b"\x00\x00\x00\rIHDR" # IHDR chunk header
+ b"\x00\x00\x00\x01" # width = 1
+ b"\x00\x00\x00\x01" # height = 1
+ b"\x08\x02" # bit depth = 8, color type = RGB
+ b"\x00\x00\x00" # compression, filter, interlace
)
file = mock_upload_file(
content=png_header,
filename="image.png",
content_type="image/png",
)
with patch(
"app.services.message_attachment_service.admin_settings_service"
) as mock_settings:
mock_settings.get_setting_value.return_value = 10
result = await attachment_service.validate_and_store(
db=db,
file=file,
conversation_id=1,
)
assert result["original_filename"] == "image.png"
assert result["mime_type"] == "image/png"
assert result["is_image"] is True
assert result["filename"].endswith(".png")
@pytest.mark.asyncio
async def test_validate_and_store_invalid_type(
self, db, attachment_service, mock_upload_file
):
"""Test rejection of invalid file type."""
file = mock_upload_file(
content=b"<script>alert('xss')</script>",
filename="script.js",
content_type="application/javascript",
)
with patch(
"app.services.message_attachment_service.admin_settings_service"
) as mock_settings:
mock_settings.get_setting_value.return_value = 10
with pytest.raises(ValueError, match="File type.*not allowed"):
await attachment_service.validate_and_store(
db=db,
file=file,
conversation_id=1,
)
@pytest.mark.asyncio
async def test_validate_and_store_file_too_large(
self, db, attachment_service, mock_upload_file
):
"""Test rejection of oversized file."""
# Create content larger than max size
large_content = b"x" * (11 * 1024 * 1024) # 11 MB
file = mock_upload_file(
content=large_content,
filename="large.pdf",
content_type="application/pdf",
)
with patch(
"app.services.message_attachment_service.admin_settings_service"
) as mock_settings:
mock_settings.get_setting_value.return_value = 10 # 10 MB limit
with pytest.raises(ValueError, match="exceeds maximum allowed size"):
await attachment_service.validate_and_store(
db=db,
file=file,
conversation_id=1,
)
@pytest.mark.asyncio
async def test_validate_and_store_no_filename(
self, db, attachment_service, mock_upload_file
):
"""Test handling of file without filename."""
file = mock_upload_file(
content=b"test content",
filename=None,
content_type="text/plain",
)
file.filename = None # Ensure it's None
with patch(
"app.services.message_attachment_service.admin_settings_service"
) as mock_settings:
mock_settings.get_setting_value.return_value = 10
result = await attachment_service.validate_and_store(
db=db,
file=file,
conversation_id=1,
)
assert result["original_filename"] == "attachment"
@pytest.mark.asyncio
async def test_validate_and_store_no_content_type(
self, db, attachment_service, mock_upload_file
):
"""Test handling of file without content type (falls back to octet-stream)."""
file = mock_upload_file(
content=b"test content",
filename="file.bin",
content_type=None,
)
file.content_type = None
with patch(
"app.services.message_attachment_service.admin_settings_service"
) as mock_settings:
mock_settings.get_setting_value.return_value = 10
# Should reject application/octet-stream as not allowed
with pytest.raises(ValueError, match="File type.*not allowed"):
await attachment_service.validate_and_store(
db=db,
file=file,
conversation_id=1,
)
@pytest.mark.unit
class TestMessageAttachmentServiceFileOperations:
"""Tests for file operation methods."""
def test_delete_attachment_success(self, attachment_service):
"""Test successful attachment deletion."""
# Create a temp file
with tempfile.NamedTemporaryFile(delete=False) as f:
f.write(b"test content")
file_path = f.name
assert os.path.exists(file_path)
result = attachment_service.delete_attachment(file_path)
assert result is True
assert not os.path.exists(file_path)
def test_delete_attachment_with_thumbnail(self, attachment_service):
"""Test deletion of attachment with thumbnail."""
# Create temp files
with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as f:
f.write(b"image content")
file_path = f.name
with tempfile.NamedTemporaryFile(delete=False, suffix="_thumb.png") as f:
f.write(b"thumbnail content")
thumb_path = f.name
result = attachment_service.delete_attachment(file_path, thumb_path)
assert result is True
assert not os.path.exists(file_path)
assert not os.path.exists(thumb_path)
def test_delete_attachment_file_not_exists(self, attachment_service):
"""Test deletion when file doesn't exist."""
result = attachment_service.delete_attachment("/nonexistent/file.pdf")
assert result is True # No error, just returns True
def test_get_download_url(self, attachment_service):
"""Test download URL generation."""
url = attachment_service.get_download_url("uploads/messages/2025/01/1/abc.pdf")
assert url == "/static/uploads/messages/2025/01/1/abc.pdf"
def test_get_file_content_success(self, attachment_service):
"""Test reading file content."""
test_content = b"test file content"
with tempfile.NamedTemporaryFile(delete=False) as f:
f.write(test_content)
file_path = f.name
try:
result = attachment_service.get_file_content(file_path)
assert result == test_content
finally:
os.unlink(file_path)
def test_get_file_content_not_found(self, attachment_service):
"""Test reading non-existent file."""
result = attachment_service.get_file_content("/nonexistent/file.pdf")
assert result is None
@pytest.mark.unit
class TestMessageAttachmentServiceThumbnail:
"""Tests for thumbnail creation."""
def test_create_thumbnail_pil_not_installed(self, attachment_service):
"""Test graceful handling when PIL is not available."""
with patch.dict("sys.modules", {"PIL": None}):
# This should not raise an error, just return empty dict
result = attachment_service._create_thumbnail(
b"fake image content", "/tmp/test.png"
)
# When PIL import fails, it returns empty dict
assert isinstance(result, dict)
def test_create_thumbnail_invalid_image(self, attachment_service):
"""Test handling of invalid image data."""
with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as f:
f.write(b"not an image")
file_path = f.name
try:
result = attachment_service._create_thumbnail(b"not an image", file_path)
# Should return empty dict on error
assert isinstance(result, dict)
finally:
os.unlink(file_path)

View File

@@ -0,0 +1,587 @@
# tests/unit/services/test_messaging_service.py
"""Unit tests for MessagingService."""
import pytest
from app.services.messaging_service import MessagingService
from models.database.message import (
Conversation,
ConversationParticipant,
ConversationType,
Message,
ParticipantType,
)
@pytest.fixture
def messaging_service():
"""Create a MessagingService instance."""
return MessagingService()
@pytest.mark.unit
class TestMessagingServiceCreateConversation:
"""Test conversation creation."""
def test_create_conversation_admin_vendor(
self, db, messaging_service, test_admin, test_vendor_user, test_vendor
):
"""Test creating an admin-vendor conversation."""
conversation = messaging_service.create_conversation(
db=db,
conversation_type=ConversationType.ADMIN_VENDOR,
subject="Test Subject",
initiator_type=ParticipantType.ADMIN,
initiator_id=test_admin.id,
recipient_type=ParticipantType.VENDOR,
recipient_id=test_vendor_user.id,
vendor_id=test_vendor.id,
)
db.commit()
assert conversation.id is not None
assert conversation.conversation_type == ConversationType.ADMIN_VENDOR
assert conversation.subject == "Test Subject"
assert conversation.vendor_id == test_vendor.id
assert conversation.is_closed is False
assert len(conversation.participants) == 2
def test_create_conversation_vendor_customer(
self, db, messaging_service, test_vendor_user, test_customer, test_vendor
):
"""Test creating a vendor-customer conversation."""
conversation = messaging_service.create_conversation(
db=db,
conversation_type=ConversationType.VENDOR_CUSTOMER,
subject="Customer Support",
initiator_type=ParticipantType.VENDOR,
initiator_id=test_vendor_user.id,
recipient_type=ParticipantType.CUSTOMER,
recipient_id=test_customer.id,
vendor_id=test_vendor.id,
)
db.commit()
assert conversation.id is not None
assert conversation.conversation_type == ConversationType.VENDOR_CUSTOMER
assert len(conversation.participants) == 2
# Verify participants
participant_types = [p.participant_type for p in conversation.participants]
assert ParticipantType.VENDOR in participant_types
assert ParticipantType.CUSTOMER in participant_types
def test_create_conversation_admin_customer(
self, db, messaging_service, test_admin, test_customer, test_vendor
):
"""Test creating an admin-customer conversation."""
conversation = messaging_service.create_conversation(
db=db,
conversation_type=ConversationType.ADMIN_CUSTOMER,
subject="Platform Support",
initiator_type=ParticipantType.ADMIN,
initiator_id=test_admin.id,
recipient_type=ParticipantType.CUSTOMER,
recipient_id=test_customer.id,
vendor_id=test_vendor.id,
)
db.commit()
assert conversation.conversation_type == ConversationType.ADMIN_CUSTOMER
assert len(conversation.participants) == 2
def test_create_conversation_with_initial_message(
self, db, messaging_service, test_admin, test_vendor_user, test_vendor
):
"""Test creating a conversation with an initial message."""
conversation = messaging_service.create_conversation(
db=db,
conversation_type=ConversationType.ADMIN_VENDOR,
subject="With Message",
initiator_type=ParticipantType.ADMIN,
initiator_id=test_admin.id,
recipient_type=ParticipantType.VENDOR,
recipient_id=test_vendor_user.id,
vendor_id=test_vendor.id,
initial_message="Hello, this is the first message!",
)
db.commit()
db.refresh(conversation)
assert conversation.message_count == 1
assert len(conversation.messages) == 1
assert conversation.messages[0].content == "Hello, this is the first message!"
def test_create_vendor_customer_without_vendor_id_fails(
self, db, messaging_service, test_vendor_user, test_customer
):
"""Test that vendor_customer conversation requires vendor_id."""
with pytest.raises(ValueError) as exc_info:
messaging_service.create_conversation(
db=db,
conversation_type=ConversationType.VENDOR_CUSTOMER,
subject="No Vendor",
initiator_type=ParticipantType.VENDOR,
initiator_id=test_vendor_user.id,
recipient_type=ParticipantType.CUSTOMER,
recipient_id=test_customer.id,
vendor_id=None,
)
assert "vendor_id required" in str(exc_info.value)
@pytest.mark.unit
class TestMessagingServiceGetConversation:
"""Test conversation retrieval."""
def test_get_conversation_success(
self, db, messaging_service, test_conversation_admin_vendor, test_admin
):
"""Test getting a conversation by ID."""
conversation = messaging_service.get_conversation(
db=db,
conversation_id=test_conversation_admin_vendor.id,
participant_type=ParticipantType.ADMIN,
participant_id=test_admin.id,
)
assert conversation is not None
assert conversation.id == test_conversation_admin_vendor.id
assert conversation.subject == "Test Admin-Vendor Conversation"
def test_get_conversation_not_found(self, db, messaging_service, test_admin):
"""Test getting a non-existent conversation."""
conversation = messaging_service.get_conversation(
db=db,
conversation_id=99999,
participant_type=ParticipantType.ADMIN,
participant_id=test_admin.id,
)
assert conversation is None
def test_get_conversation_unauthorized(
self, db, messaging_service, test_conversation_admin_vendor, test_customer
):
"""Test getting a conversation without access."""
# Customer is not a participant in admin-vendor conversation
conversation = messaging_service.get_conversation(
db=db,
conversation_id=test_conversation_admin_vendor.id,
participant_type=ParticipantType.CUSTOMER,
participant_id=test_customer.id,
)
assert conversation is None
@pytest.mark.unit
class TestMessagingServiceListConversations:
"""Test conversation listing."""
def test_list_conversations_success(
self, db, messaging_service, multiple_conversations, test_admin
):
"""Test listing conversations for a participant."""
conversations, total, total_unread = messaging_service.list_conversations(
db=db,
participant_type=ParticipantType.ADMIN,
participant_id=test_admin.id,
)
# Admin should see all admin-vendor conversations (3 of them)
assert total == 3
assert len(conversations) == 3
def test_list_conversations_with_type_filter(
self, db, messaging_service, multiple_conversations, test_vendor_user, test_vendor
):
"""Test filtering conversations by type."""
# Vendor should see admin-vendor (3) + vendor-customer (2) = 5
# Filter to vendor-customer only
conversations, total, _ = messaging_service.list_conversations(
db=db,
participant_type=ParticipantType.VENDOR,
participant_id=test_vendor_user.id,
vendor_id=test_vendor.id,
conversation_type=ConversationType.VENDOR_CUSTOMER,
)
assert total == 2
for conv in conversations:
assert conv.conversation_type == ConversationType.VENDOR_CUSTOMER
def test_list_conversations_pagination(
self, db, messaging_service, multiple_conversations, test_admin
):
"""Test pagination of conversations."""
# First page
conversations, total, _ = messaging_service.list_conversations(
db=db,
participant_type=ParticipantType.ADMIN,
participant_id=test_admin.id,
skip=0,
limit=2,
)
assert total == 3
assert len(conversations) == 2
# Second page
conversations, total, _ = messaging_service.list_conversations(
db=db,
participant_type=ParticipantType.ADMIN,
participant_id=test_admin.id,
skip=2,
limit=2,
)
assert total == 3
assert len(conversations) == 1
def test_list_conversations_with_closed_filter(
self, db, messaging_service, test_conversation_admin_vendor, closed_conversation, test_admin
):
"""Test filtering by open/closed status."""
# Only open
conversations, total, _ = messaging_service.list_conversations(
db=db,
participant_type=ParticipantType.ADMIN,
participant_id=test_admin.id,
is_closed=False,
)
assert total == 1
assert all(not conv.is_closed for conv in conversations)
# Only closed
conversations, total, _ = messaging_service.list_conversations(
db=db,
participant_type=ParticipantType.ADMIN,
participant_id=test_admin.id,
is_closed=True,
)
assert total == 1
assert all(conv.is_closed for conv in conversations)
@pytest.mark.unit
class TestMessagingServiceSendMessage:
"""Test message sending."""
def test_send_message_success(
self, db, messaging_service, test_conversation_admin_vendor, test_admin
):
"""Test sending a message."""
message = messaging_service.send_message(
db=db,
conversation_id=test_conversation_admin_vendor.id,
sender_type=ParticipantType.ADMIN,
sender_id=test_admin.id,
content="Hello, this is a test message!",
)
db.commit()
assert message.id is not None
assert message.content == "Hello, this is a test message!"
assert message.sender_type == ParticipantType.ADMIN
assert message.sender_id == test_admin.id
assert message.conversation_id == test_conversation_admin_vendor.id
# Verify conversation was updated
db.refresh(test_conversation_admin_vendor)
assert test_conversation_admin_vendor.message_count == 1
assert test_conversation_admin_vendor.last_message_at is not None
def test_send_message_with_attachments(
self, db, messaging_service, test_conversation_admin_vendor, test_admin
):
"""Test sending a message with attachments."""
attachments = [
{
"filename": "doc1.pdf",
"original_filename": "document.pdf",
"file_path": "/uploads/messages/2025/01/1/doc1.pdf",
"file_size": 12345,
"mime_type": "application/pdf",
"is_image": False,
}
]
message = messaging_service.send_message(
db=db,
conversation_id=test_conversation_admin_vendor.id,
sender_type=ParticipantType.ADMIN,
sender_id=test_admin.id,
content="See attached document.",
attachments=attachments,
)
db.commit()
db.refresh(message)
assert len(message.attachments) == 1
assert message.attachments[0].original_filename == "document.pdf"
def test_send_message_updates_unread_count(
self, db, messaging_service, test_conversation_admin_vendor, test_admin, test_vendor_user
):
"""Test that sending a message updates unread count for other participants."""
# Send message as admin
messaging_service.send_message(
db=db,
conversation_id=test_conversation_admin_vendor.id,
sender_type=ParticipantType.ADMIN,
sender_id=test_admin.id,
content="Test message",
)
db.commit()
# Check that vendor user has unread count increased
vendor_participant = (
db.query(ConversationParticipant)
.filter(
ConversationParticipant.conversation_id == test_conversation_admin_vendor.id,
ConversationParticipant.participant_type == ParticipantType.VENDOR,
ConversationParticipant.participant_id == test_vendor_user.id,
)
.first()
)
assert vendor_participant.unread_count == 1
# Admin's unread count should be 0
admin_participant = (
db.query(ConversationParticipant)
.filter(
ConversationParticipant.conversation_id == test_conversation_admin_vendor.id,
ConversationParticipant.participant_type == ParticipantType.ADMIN,
ConversationParticipant.participant_id == test_admin.id,
)
.first()
)
assert admin_participant.unread_count == 0
def test_send_system_message(
self, db, messaging_service, test_conversation_admin_vendor, test_admin
):
"""Test sending a system message."""
message = messaging_service.send_message(
db=db,
conversation_id=test_conversation_admin_vendor.id,
sender_type=ParticipantType.ADMIN,
sender_id=test_admin.id,
content="Conversation closed",
is_system_message=True,
)
db.commit()
assert message.is_system_message is True
@pytest.mark.unit
class TestMessagingServiceMarkRead:
"""Test marking conversations as read."""
def test_mark_conversation_read(
self, db, messaging_service, test_conversation_admin_vendor, test_admin, test_vendor_user
):
"""Test marking a conversation as read."""
# Send a message to create unread count
messaging_service.send_message(
db=db,
conversation_id=test_conversation_admin_vendor.id,
sender_type=ParticipantType.ADMIN,
sender_id=test_admin.id,
content="Test message",
)
db.commit()
# Mark as read for vendor
result = messaging_service.mark_conversation_read(
db=db,
conversation_id=test_conversation_admin_vendor.id,
reader_type=ParticipantType.VENDOR,
reader_id=test_vendor_user.id,
)
db.commit()
assert result is True
# Verify unread count is reset
vendor_participant = (
db.query(ConversationParticipant)
.filter(
ConversationParticipant.conversation_id == test_conversation_admin_vendor.id,
ConversationParticipant.participant_type == ParticipantType.VENDOR,
)
.first()
)
assert vendor_participant.unread_count == 0
assert vendor_participant.last_read_at is not None
@pytest.mark.unit
class TestMessagingServiceUnreadCount:
"""Test unread count retrieval."""
def test_get_unread_count(
self, db, messaging_service, multiple_conversations, test_admin, test_vendor_user
):
"""Test getting total unread count for a participant."""
# Send messages in multiple conversations (first 2 are admin-vendor)
for conv in multiple_conversations[:2]:
messaging_service.send_message(
db=db,
conversation_id=conv.id,
sender_type=ParticipantType.VENDOR,
sender_id=test_vendor_user.id,
content="Test message",
)
db.commit()
# Admin should have 2 unread messages
unread_count = messaging_service.get_unread_count(
db=db,
participant_type=ParticipantType.ADMIN,
participant_id=test_admin.id,
)
assert unread_count == 2
def test_get_unread_count_zero(self, db, messaging_service, test_admin):
"""Test unread count when no messages."""
unread_count = messaging_service.get_unread_count(
db=db,
participant_type=ParticipantType.ADMIN,
participant_id=test_admin.id,
)
assert unread_count == 0
@pytest.mark.unit
class TestMessagingServiceCloseReopen:
"""Test conversation close/reopen."""
def test_close_conversation(
self, db, messaging_service, test_conversation_admin_vendor, test_admin
):
"""Test closing a conversation."""
conversation = messaging_service.close_conversation(
db=db,
conversation_id=test_conversation_admin_vendor.id,
closer_type=ParticipantType.ADMIN,
closer_id=test_admin.id,
)
db.commit()
assert conversation is not None
assert conversation.is_closed is True
assert conversation.closed_at is not None
assert conversation.closed_by_type == ParticipantType.ADMIN
assert conversation.closed_by_id == test_admin.id
# Should have system message
db.refresh(conversation)
assert any(m.is_system_message and "closed" in m.content for m in conversation.messages)
def test_reopen_conversation(
self, db, messaging_service, closed_conversation, test_admin
):
"""Test reopening a closed conversation."""
conversation = messaging_service.reopen_conversation(
db=db,
conversation_id=closed_conversation.id,
opener_type=ParticipantType.ADMIN,
opener_id=test_admin.id,
)
db.commit()
assert conversation is not None
assert conversation.is_closed is False
assert conversation.closed_at is None
assert conversation.closed_by_type is None
assert conversation.closed_by_id is None
@pytest.mark.unit
class TestMessagingServiceParticipantInfo:
"""Test participant info retrieval."""
def test_get_participant_info_admin(self, db, messaging_service, test_admin):
"""Test getting admin participant info."""
info = messaging_service.get_participant_info(
db=db,
participant_type=ParticipantType.ADMIN,
participant_id=test_admin.id,
)
assert info is not None
assert info["id"] == test_admin.id
assert info["type"] == "admin"
assert "email" in info
def test_get_participant_info_customer(self, db, messaging_service, test_customer):
"""Test getting customer participant info."""
info = messaging_service.get_participant_info(
db=db,
participant_type=ParticipantType.CUSTOMER,
participant_id=test_customer.id,
)
assert info is not None
assert info["id"] == test_customer.id
assert info["type"] == "customer"
assert info["name"] == "John Doe"
def test_get_participant_info_not_found(self, db, messaging_service):
"""Test getting info for non-existent participant."""
info = messaging_service.get_participant_info(
db=db,
participant_type=ParticipantType.ADMIN,
participant_id=99999,
)
assert info is None
@pytest.mark.unit
class TestMessagingServiceNotificationPreferences:
"""Test notification preference updates."""
def test_update_notification_preferences(
self, db, messaging_service, test_conversation_admin_vendor, test_admin
):
"""Test updating notification preferences."""
result = messaging_service.update_notification_preferences(
db=db,
conversation_id=test_conversation_admin_vendor.id,
participant_type=ParticipantType.ADMIN,
participant_id=test_admin.id,
email_notifications=False,
muted=True,
)
db.commit()
assert result is True
# Verify preferences updated
participant = (
db.query(ConversationParticipant)
.filter(
ConversationParticipant.conversation_id == test_conversation_admin_vendor.id,
ConversationParticipant.participant_type == ParticipantType.ADMIN,
)
.first()
)
assert participant.email_notifications is False
assert participant.muted is True
def test_update_notification_preferences_no_changes(
self, db, messaging_service, test_conversation_admin_vendor, test_admin
):
"""Test updating with no changes."""
result = messaging_service.update_notification_preferences(
db=db,
conversation_id=test_conversation_admin_vendor.id,
participant_type=ParticipantType.ADMIN,
participant_id=test_admin.id,
)
assert result is False