Add retry utilities, input validation, and comprehensive tests

- Implemented async retry logic with exponential backoff in `src/utils/retry.py`.
- Created input validation utilities for Discord bot in `src/utils/validators.py`.
- Refactored token pricing import in `src/utils/token_counter.py`.
- Added comprehensive test suite in `tests/test_comprehensive.py` covering various modules including pricing, validators, retry logic, and Discord utilities.
This commit is contained in:
2025-11-30 17:45:36 +07:00
parent e2b961e9c0
commit f17081b185
14 changed files with 2986 additions and 99 deletions

View File

@@ -88,3 +88,26 @@ TIMEZONE=UTC
# 168 = 1 week # 168 = 1 week
# -1 = Never expire (permanent storage) # -1 = Never expire (permanent storage)
FILE_EXPIRATION_HOURS=48 FILE_EXPIRATION_HOURS=48
# ============================================
# Monitoring & Observability (Optional)
# ============================================
# Sentry DSN for error tracking
# Get from: https://sentry.io/ (create a project and copy the DSN)
# Leave empty to disable Sentry error tracking
SENTRY_DSN=
# Environment name for Sentry (development, staging, production)
ENVIRONMENT=development
# Sentry sample rate (0.0 to 1.0) - percentage of errors to capture
# 1.0 = 100% of errors, 0.5 = 50% of errors
SENTRY_SAMPLE_RATE=1.0
# Sentry traces sample rate for performance monitoring (0.0 to 1.0)
# 0.1 = 10% of transactions, lower values recommended for high-traffic bots
SENTRY_TRACES_RATE=0.1
# Log level (DEBUG, INFO, WARNING, ERROR)
LOG_LEVEL=INFO

View File

@@ -1,19 +1,49 @@
discord.py # Discord Bot Core
openai discord.py>=2.3.0
motor openai>=1.40.0
pymongo[srv] python-dotenv>=1.0.0
dnspython>=2.0.0
pypdf # Database
beautifulsoup4 motor>=3.3.0
requests pymongo[srv]>=4.6.0
aiohttp dnspython>=2.5.0
# Web & HTTP
aiohttp>=3.9.0
requests>=2.31.0
beautifulsoup4>=4.12.0
# AI & ML
runware>=0.4.33 runware>=0.4.33
python-dotenv tiktoken>=0.7.0
matplotlib
pandas # Data Processing
openpyxl pandas>=2.1.0
seaborn numpy>=1.26.0
tzlocal openpyxl>=3.1.0
numpy
plotly # Visualization
tiktoken matplotlib>=3.8.0
seaborn>=0.13.0
plotly>=5.18.0
# Document Processing
pypdf>=4.0.0
Pillow>=10.0.0
# Scheduling & Time
APScheduler>=3.10.0
tzlocal>=5.2
# Testing
pytest>=8.0.0
pytest-asyncio>=0.23.0
pytest-cov>=4.1.0
pytest-mock>=3.12.0
# Code Quality
ruff>=0.3.0
# Monitoring & Logging (Optional)
# sentry-sdk>=1.40.0 # Uncomment for error monitoring
# python-json-logger>=2.0.0 # Uncomment for structured logging

View File

@@ -7,36 +7,67 @@ import asyncio
from typing import Optional, Dict, List, Any, Callable from typing import Optional, Dict, List, Any, Callable
from src.config.config import MODEL_OPTIONS, PDF_ALLOWED_MODELS, DEFAULT_MODEL from src.config.config import MODEL_OPTIONS, PDF_ALLOWED_MODELS, DEFAULT_MODEL
from src.config.pricing import MODEL_PRICING, calculate_cost, format_cost
from src.utils.image_utils import ImageGenerator from src.utils.image_utils import ImageGenerator
from src.utils.web_utils import google_custom_search, scrape_web_content from src.utils.web_utils import google_custom_search, scrape_web_content
from src.utils.pdf_utils import process_pdf, send_response from src.utils.pdf_utils import process_pdf, send_response
from src.utils.openai_utils import prepare_file_from_path from src.utils.openai_utils import prepare_file_from_path
from src.utils.token_counter import token_counter from src.utils.token_counter import token_counter
from src.utils.code_interpreter import delete_all_user_files from src.utils.code_interpreter import delete_all_user_files
from src.utils.discord_utils import create_info_embed, create_error_embed, create_success_embed
# Model pricing per 1M tokens (in USD)
MODEL_PRICING = {
"openai/gpt-4o": {"input": 5.00, "output": 20.00},
"openai/gpt-4o-mini": {"input": 0.60, "output": 2.40},
"openai/gpt-4.1": {"input": 2.00, "output": 8.00},
"openai/gpt-4.1-mini": {"input": 0.40, "output": 1.60},
"openai/gpt-4.1-nano": {"input": 0.10, "output": 0.40},
"openai/gpt-5": {"input": 1.25, "output": 10.00},
"openai/gpt-5-mini": {"input": 0.25, "output": 2.00},
"openai/gpt-5-nano": {"input": 0.05, "output": 0.40},
"openai/gpt-5-chat": {"input": 1.25, "output": 10.00},
"openai/o1-preview": {"input": 15.00, "output": 60.00},
"openai/o1-mini": {"input": 1.10, "output": 4.40},
"openai/o1": {"input": 15.00, "output": 60.00},
"openai/o3-mini": {"input": 1.10, "output": 4.40},
"openai/o3": {"input": 2.00, "output": 8.00},
"openai/o4-mini": {"input": 2.00, "output": 8.00}
}
# Dictionary to keep track of user requests and their cooldowns # Dictionary to keep track of user requests and their cooldowns
user_requests = {} user_requests: Dict[int, Dict[str, Any]] = {}
# Dictionary to store user tasks # Dictionary to store user tasks
user_tasks = {} user_tasks: Dict[int, List] = {}
# ============================================================
# Autocomplete Functions
# ============================================================
async def model_autocomplete(
    interaction: discord.Interaction,
    current: str,
) -> List[app_commands.Choice[str]]:
    """
    Suggest model names while the user types a ``model`` argument.

    Candidates are the entries of ``MODEL_OPTIONS`` that contain ``current``
    (case-insensitive substring match). When nothing matches, the complete
    list is offered instead of an empty dropdown.

    Args:
        interaction: The interaction that triggered the autocomplete (unused).
        current: Text the user has typed so far.

    Returns:
        Up to 25 choices (Discord limit), each using the model name as both
        label and value.
    """
    needle = current.lower()
    candidates = [name for name in MODEL_OPTIONS if needle in name.lower()]

    # Nothing matched: fall back to every known model
    pool = candidates or MODEL_OPTIONS

    suggestions = []
    for name in pool[:25]:
        suggestions.append(app_commands.Choice(name=name, value=name))
    return suggestions
async def image_model_autocomplete(
    interaction: discord.Interaction,
    current: str,
) -> List[app_commands.Choice[str]]:
    """
    Suggest image generation model names while the user types.

    Matching is a case-insensitive substring test against a fixed list of
    image model names; if nothing matches, the whole list is offered.

    Args:
        interaction: The interaction that triggered the autocomplete (unused).
        current: Text the user has typed so far.

    Returns:
        Up to 25 choices, each using the model name as both label and value.
    """
    image_models = ["flux", "flux-dev", "sdxl", "realistic", "anime", "dreamshaper"]
    needle = current.lower()

    candidates = [name for name in image_models if needle in name.lower()]
    # Empty result set: show every image model rather than nothing
    pool = candidates or image_models

    suggestions = []
    for name in pool[:25]:
        suggestions.append(app_commands.Choice(name=name, value=name))
    return suggestions
def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator: ImageGenerator): def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator: ImageGenerator):
""" """
@@ -112,7 +143,7 @@ def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator
@tree.command(name="choose_model", description="Select the AI model to use for responses.") @tree.command(name="choose_model", description="Select the AI model to use for responses.")
@check_blacklist() @check_blacklist()
async def choose_model(interaction: discord.Interaction): async def choose_model(interaction: discord.Interaction):
"""Lets users choose an AI model and saves it to the database.""" """Lets users choose an AI model using a dropdown menu."""
options = [discord.SelectOption(label=model, value=model) for model in MODEL_OPTIONS] options = [discord.SelectOption(label=model, value=model) for model in MODEL_OPTIONS]
select_menu = discord.ui.Select(placeholder="Choose a model", options=options) select_menu = discord.ui.Select(placeholder="Choose a model", options=options)
@@ -131,6 +162,43 @@ def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator
view.add_item(select_menu) view.add_item(select_menu)
await interaction.response.send_message("Choose a model:", view=view, ephemeral=True) await interaction.response.send_message("Choose a model:", view=view, ephemeral=True)
@tree.command(name="set_model", description="Set AI model directly with autocomplete suggestions.")
@app_commands.describe(model="The AI model to use (type to search)")
@app_commands.autocomplete(model=model_autocomplete)
@check_blacklist()
async def set_model(interaction: discord.Interaction, model: str):
    """
    Set the user's AI model directly from a typed (autocompleted) name.

    The name is validated against ``MODEL_OPTIONS``. An invalid name is
    answered with up to five close matches (case-insensitive substring match)
    or a pointer to ``/choose_model``; nothing is saved in that case. A valid
    name is stored through ``db_handler.save_user_model`` and confirmed with
    the model's per-1M-token pricing.

    Args:
        interaction: The slash-command interaction being answered.
        model: Model name supplied by the user.
    """
    user_id = interaction.user.id

    # Validate the model is in the allowed list
    if model not in MODEL_OPTIONS:
        # Find close matches for suggestions
        needle = model.lower()
        close_matches = [m for m in MODEL_OPTIONS if needle in m.lower()]
        if close_matches:
            suggestions = ", ".join(f"`{m}`" for m in close_matches[:5])
            message = f"❌ Invalid model `{model}`. Did you mean: {suggestions}?"
        else:
            message = f"❌ Invalid model `{model}`. Use `/choose_model` to see available options."
        await interaction.response.send_message(message, ephemeral=True)
        return

    # Save the model selection
    await db_handler.save_user_model(user_id, model)

    # MODEL_PRICING (imported from src.config.pricing) maps names to
    # ModelPricing objects, which expose prices as attributes. Subscripting
    # them like the old dict table raises TypeError, so read the attributes
    # and fall back to 0 for a model that has no pricing entry.
    pricing = MODEL_PRICING.get(model)
    input_price = pricing.input if pricing is not None else 0.0
    output_price = pricing.output if pricing is not None else 0.0

    await interaction.response.send_message(
        f"✅ Model set to `{model}`\n"
        f"💰 Pricing: ${input_price:.2f}/1M input, ${output_price:.2f}/1M output",
        ephemeral=True
    )
@tree.command(name="search", description="Search on Google and send results to AI model.") @tree.command(name="search", description="Search on Google and send results to AI model.")
@app_commands.describe(query="The search query") @app_commands.describe(query="The search query")
@check_blacklist() @check_blacklist()
@@ -494,16 +562,22 @@ def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator
async def help_command(interaction: discord.Interaction): async def help_command(interaction: discord.Interaction):
"""Sends a list of available commands to the user.""" """Sends a list of available commands to the user."""
help_message = ( help_message = (
"**Available commands:**\n" "**🤖 Available Commands:**\n\n"
"/choose_model - Select which AI model to use for responses (openai/gpt-4o, openai/gpt-4o-mini, openai/gpt-5, openai/gpt-5-nano, openai/gpt-5-mini, openai/gpt-5-chat, openai/o1-preview, openai/o1-mini).\n" "**Model Selection:**\n"
"/search `<query>` - Search Google and send results to the AI model.\n" "• `/choose_model` - Select AI model from a dropdown menu\n"
"/web `<url>` - Scrape a webpage and send the data to the AI model.\n" "• `/set_model <model>` - Set model directly with autocomplete\n\n"
"/generate `<prompt>` - Generate an image from a text prompt.\n" "**Search & Web:**\n"
"/toggle_tools - Toggle display of tool execution details (code, input, output).\n" "• `/search <query>` - Search Google and analyze results with AI\n"
"/reset - Reset your chat history and token usage statistics.\n" "• `/web <url>` - Scrape and analyze a webpage\n\n"
"/user_stat - Get information about your token usage, costs, and current model.\n" "**Image Generation:**\n"
"/prices - Display pricing information for all available AI models.\n" "• `/generate <prompt>` - Generate images from text\n\n"
"/help - Display this help message.\n" "**Settings & Stats:**\n"
"• `/toggle_tools` - Toggle tool execution details display\n"
"• `/user_stat` - View your token usage and costs\n"
"• `/prices` - Display model pricing information\n"
"• `/reset` - Clear your chat history and statistics\n\n"
"**Help:**\n"
"• `/help` - Display this help message\n"
) )
await interaction.response.send_message(help_message, ephemeral=True) await interaction.response.send_message(help_message, ephemeral=True)

100
src/config/pricing.py Normal file
View File

@@ -0,0 +1,100 @@
"""
Centralized pricing configuration for OpenAI models.
This module provides a single source of truth for model pricing,
eliminating duplication across the codebase.
"""
from typing import Dict, Optional
from dataclasses import dataclass
@dataclass
class ModelPricing:
    """
    Pricing information for a model (per 1M tokens in USD).

    Prices are available as attributes (``pricing.input``) and, for backward
    compatibility with call sites written against the earlier
    ``{"input": ..., "output": ...}`` dict tables, through read-only
    subscripts (``pricing["input"]`` / ``pricing["output"]``).
    """
    input: float   # USD per 1M input tokens
    output: float  # USD per 1M output tokens

    def __getitem__(self, key: str) -> float:
        """
        Return the price stored under ``"input"`` or ``"output"``.

        Raises:
            KeyError: If ``key`` is neither ``"input"`` nor ``"output"``.
        """
        if key == "input":
            return self.input
        if key == "output":
            return self.output
        raise KeyError(key)

    def calculate_cost(self, input_tokens: int, output_tokens: int) -> float:
        """
        Calculate total cost for given token counts.

        Args:
            input_tokens: Number of input tokens
            output_tokens: Number of output tokens

        Returns:
            Combined input and output cost in USD
        """
        input_cost = (input_tokens / 1_000_000) * self.input
        output_cost = (output_tokens / 1_000_000) * self.output
        return input_cost + output_cost
# Model pricing per 1M tokens (in USD)
# Centralized location - update prices here only
# Keys are the provider-prefixed model names passed to get_model_pricing() and
# calculate_cost(); a model missing from this table is costed at 0.0 there.
# NOTE(review): prices are hard-coded snapshots - confirm against the provider's
# current price list before relying on them for billing.
MODEL_PRICING: Dict[str, ModelPricing] = {
# GPT-4o Family
"openai/gpt-4o": ModelPricing(input=5.00, output=20.00),
"openai/gpt-4o-mini": ModelPricing(input=0.60, output=2.40),
# GPT-4.1 Family
"openai/gpt-4.1": ModelPricing(input=2.00, output=8.00),
"openai/gpt-4.1-mini": ModelPricing(input=0.40, output=1.60),
"openai/gpt-4.1-nano": ModelPricing(input=0.10, output=0.40),
# GPT-5 Family
"openai/gpt-5": ModelPricing(input=1.25, output=10.00),
"openai/gpt-5-mini": ModelPricing(input=0.25, output=2.00),
"openai/gpt-5-nano": ModelPricing(input=0.05, output=0.40),
"openai/gpt-5-chat": ModelPricing(input=1.25, output=10.00),
# o1 Family (Reasoning models)
"openai/o1-preview": ModelPricing(input=15.00, output=60.00),
"openai/o1-mini": ModelPricing(input=1.10, output=4.40),
"openai/o1": ModelPricing(input=15.00, output=60.00),
# o3 Family
"openai/o3-mini": ModelPricing(input=1.10, output=4.40),
"openai/o3": ModelPricing(input=2.00, output=8.00),
# o4 Family
"openai/o4-mini": ModelPricing(input=2.00, output=8.00),
}
def get_model_pricing(model: str) -> Optional[ModelPricing]:
    """
    Look up the pricing entry registered for a model.

    Args:
        model: The model name (e.g., "openai/gpt-4o")

    Returns:
        The matching ModelPricing, or None when the model has no entry
    """
    try:
        return MODEL_PRICING[model]
    except KeyError:
        # Unknown model: callers decide how to treat the missing price
        return None
def calculate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
    """
    Compute the USD cost of a request for the given model.

    Args:
        model: The model name
        input_tokens: Number of input tokens
        output_tokens: Number of output tokens

    Returns:
        Total cost in USD; 0.0 when the model has no pricing entry
    """
    pricing = get_model_pricing(model)
    if not pricing:
        # No pricing registered for this model
        return 0.0
    return pricing.calculate_cost(input_tokens, output_tokens)
def get_all_models() -> list:
    """Return the names of every model that has a pricing entry."""
    return [model_name for model_name in MODEL_PRICING]
def format_cost(cost: float) -> str:
    """
    Render a USD amount with precision that scales with its magnitude.

    Amounts below $0.01 use six decimals, amounts below $1.00 use four,
    and everything else uses two.
    """
    if cost < 0.01:
        decimals = 6
    elif cost < 1.00:
        decimals = 4
    else:
        decimals = 2
    return f"${cost:.{decimals}f}"

View File

@@ -160,7 +160,12 @@ class DatabaseHandler:
return [] return []
def _filter_expired_images(self, history: List[Dict[str, Any]]) -> List[Dict[str, Any]]: def _filter_expired_images(self, history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Filter out image links that are older than 23 hours""" """
Filter out image links that are older than 23 hours.
Properly handles timezone-aware and timezone-naive datetime comparisons
to prevent issues with ISO string parsing.
"""
current_time = datetime.now() current_time = datetime.now()
expiration_time = current_time - timedelta(hours=23) expiration_time = current_time - timedelta(hours=23)
@@ -183,11 +188,27 @@ class DatabaseHandler:
# Check image items for timestamp # Check image items for timestamp
elif item.get('type') == 'image_url': elif item.get('type') == 'image_url':
# If there's no timestamp or timestamp is newer than expiration time, keep it # If there's no timestamp or timestamp is newer than expiration time, keep it
timestamp = item.get('timestamp') timestamp_str = item.get('timestamp')
if not timestamp or datetime.fromisoformat(timestamp) > expiration_time: if not timestamp_str:
# No timestamp, keep the image
filtered_content.append(item) filtered_content.append(item)
else: else:
logging.info(f"Filtering out expired image URL (added at {timestamp})") try:
# Parse the ISO timestamp, handling both timezone-aware and naive
timestamp = datetime.fromisoformat(timestamp_str.replace('Z', '+00:00'))
# Make comparison timezone-naive for consistency
if timestamp.tzinfo is not None:
timestamp = timestamp.replace(tzinfo=None)
if timestamp > expiration_time:
filtered_content.append(item)
else:
logging.debug(f"Filtering out expired image URL (added at {timestamp_str})")
except (ValueError, AttributeError) as e:
# If we can't parse the timestamp, keep the image to be safe
logging.warning(f"Could not parse image timestamp '{timestamp_str}': {e}")
filtered_content.append(item)
# Update the message with filtered content # Update the message with filtered content
if filtered_content: if filtered_content:

View File

@@ -5,7 +5,7 @@ import logging
import time import time
import functools import functools
import concurrent.futures import concurrent.futures
from typing import Dict, Any, List from typing import Dict, Any, List, Optional
import io import io
import aiohttp import aiohttp
import os import os
@@ -19,32 +19,16 @@ from src.utils.pdf_utils import process_pdf, send_response
from src.utils.code_utils import extract_code_blocks from src.utils.code_utils import extract_code_blocks
from src.utils.reminder_utils import ReminderManager from src.utils.reminder_utils import ReminderManager
from src.config.config import PDF_ALLOWED_MODELS, MODEL_TOKEN_LIMITS, DEFAULT_TOKEN_LIMIT, DEFAULT_MODEL from src.config.config import PDF_ALLOWED_MODELS, MODEL_TOKEN_LIMITS, DEFAULT_TOKEN_LIMIT, DEFAULT_MODEL
from src.config.pricing import MODEL_PRICING, calculate_cost, format_cost
from src.utils.validators import validate_message_content, validate_prompt, sanitize_for_logging
from src.utils.discord_utils import send_long_message, create_error_embed, create_progress_embed
# Global task and rate limiting tracking # Global task and rate limiting tracking
user_tasks = {} user_tasks: Dict[int, Dict] = {}
user_last_request = {} user_last_request: Dict[int, List[float]] = {}
RATE_LIMIT_WINDOW = 5 # seconds RATE_LIMIT_WINDOW = 5 # seconds
MAX_REQUESTS = 3 # max requests per window MAX_REQUESTS = 3 # max requests per window
# Model pricing per 1M tokens (in USD)
MODEL_PRICING = {
"openai/gpt-4o": {"input": 5.00, "output": 20.00},
"openai/gpt-4o-mini": {"input": 0.60, "output": 2.40},
"openai/gpt-4.1": {"input": 2.00, "output": 8.00},
"openai/gpt-4.1-mini": {"input": 0.40, "output": 1.60},
"openai/gpt-4.1-nano": {"input": 0.10, "output": 0.40},
"openai/gpt-5": {"input": 1.25, "output": 10.00},
"openai/gpt-5-mini": {"input": 0.25, "output": 2.00},
"openai/gpt-5-nano": {"input": 0.05, "output": 0.40},
"openai/gpt-5-chat": {"input": 1.25, "output": 10.00},
"openai/o1-preview": {"input": 15.00, "output": 60.00},
"openai/o1-mini": {"input": 1.10, "output": 4.40},
"openai/o1": {"input": 15.00, "output": 60.00},
"openai/o3-mini": {"input": 1.10, "output": 4.40},
"openai/o3": {"input": 2.00, "output": 8.00},
"openai/o4-mini": {"input": 2.00, "output": 8.00}
}
# File extensions that should be treated as text files # File extensions that should be treated as text files
TEXT_FILE_EXTENSIONS = [ TEXT_FILE_EXTENSIONS = [
'.txt', '.md', '.csv', '.json', '.xml', '.html', '.htm', '.css', '.txt', '.md', '.csv', '.json', '.xml', '.html', '.htm', '.css',
@@ -1598,13 +1582,11 @@ print("\\n=== Correlation Analysis ===")
output_tokens = getattr(response.usage, 'completion_tokens', 0) output_tokens = getattr(response.usage, 'completion_tokens', 0)
# Calculate cost based on model pricing # Calculate cost based on model pricing
if model in MODEL_PRICING: pricing = MODEL_PRICING.get(model)
pricing = MODEL_PRICING[model] if pricing:
input_cost = (input_tokens / 1_000_000) * pricing["input"] total_cost = pricing.calculate_cost(input_tokens, output_tokens)
output_cost = (output_tokens / 1_000_000) * pricing["output"]
total_cost = input_cost + output_cost
logging.info(f"API call - Model: {model}, Input tokens: {input_tokens}, Output tokens: {output_tokens}, Cost: ${total_cost:.6f}") logging.info(f"API call - Model: {model}, Input tokens: {input_tokens}, Output tokens: {output_tokens}, Cost: {format_cost(total_cost)}")
# Save token usage and cost to database # Save token usage and cost to database
await self.db.save_token_usage(user_id, model, input_tokens, output_tokens, total_cost) await self.db.save_token_usage(user_id, model, input_tokens, output_tokens, total_cost)
@@ -1704,14 +1686,12 @@ print("\\n=== Correlation Analysis ===")
output_tokens += follow_up_output_tokens output_tokens += follow_up_output_tokens
# Calculate additional cost # Calculate additional cost
if model in MODEL_PRICING: pricing = MODEL_PRICING.get(model)
pricing = MODEL_PRICING[model] if pricing:
additional_input_cost = (follow_up_input_tokens / 1_000_000) * pricing["input"] additional_cost = pricing.calculate_cost(follow_up_input_tokens, follow_up_output_tokens)
additional_output_cost = (follow_up_output_tokens / 1_000_000) * pricing["output"]
additional_cost = additional_input_cost + additional_output_cost
total_cost += additional_cost total_cost += additional_cost
logging.info(f"Follow-up API call - Model: {model}, Input tokens: {follow_up_input_tokens}, Output tokens: {follow_up_output_tokens}, Additional cost: ${additional_cost:.6f}") logging.info(f"Follow-up API call - Model: {model}, Input tokens: {follow_up_input_tokens}, Output tokens: {follow_up_output_tokens}, Additional cost: {format_cost(additional_cost)}")
# Save additional token usage and cost to database # Save additional token usage and cost to database
await self.db.save_token_usage(user_id, model, follow_up_input_tokens, follow_up_output_tokens, additional_cost) await self.db.save_token_usage(user_id, model, follow_up_input_tokens, follow_up_output_tokens, additional_cost)
@@ -1791,7 +1771,7 @@ print("\\n=== Correlation Analysis ===")
# Log processing time and cost for performance monitoring # Log processing time and cost for performance monitoring
processing_time = time.time() - start_time processing_time = time.time() - start_time
logging.info(f"Message processed in {processing_time:.2f} seconds (User: {user_id}, Model: {model}, Cost: ${total_cost:.6f})") logging.info(f"Message processed in {processing_time:.2f} seconds (User: {user_id}, Model: {model}, Cost: {format_cost(total_cost)})")
except asyncio.CancelledError: except asyncio.CancelledError:
# Handle cancellation cleanly # Handle cancellation cleanly

358
src/utils/cache.py Normal file
View File

@@ -0,0 +1,358 @@
"""
Simple caching utilities for API responses and frequently accessed data.
This module provides an in-memory LRU cache with optional TTL (time-to-live)
support, designed for caching API responses and reducing redundant calls.
"""
import asyncio
import time
import logging
from typing import Any, Dict, Optional, Callable, TypeVar, Generic
from collections import OrderedDict
from dataclasses import dataclass, field
from functools import wraps
logger = logging.getLogger(__name__)

T = TypeVar('T')


@dataclass
class CacheEntry(Generic[T]):
    """A single cache entry with value and expiration time."""
    value: T
    expires_at: float  # absolute time.time() value after which the entry is stale
    created_at: float = field(default_factory=time.time)
    hits: int = 0  # number of successful get() calls served by this entry


class LRUCache(Generic[T]):
    """
    Asyncio-safe LRU (Least Recently Used) cache with TTL support.

    Every operation that touches the underlying mapping is serialized with an
    ``asyncio.Lock``. That protects against interleaving coroutines on one
    event loop; it does NOT make the cache safe to share between OS threads.

    Features:
    - Configurable max size with automatic eviction
    - Per-entry TTL (time-to-live)
    - Automatic cleanup of expired entries
    - Hit/miss statistics tracking

    Usage:
        cache = LRUCache(max_size=1000, default_ttl=300)  # 5 min TTL
        await cache.set("key", "value")
        value = await cache.get("key")  # Returns value or None if expired
    """

    def __init__(
        self,
        max_size: int = 1000,
        default_ttl: float = 300.0,  # 5 minutes default
        cleanup_interval: float = 60.0
    ):
        """
        Initialize the LRU cache.

        Args:
            max_size: Maximum number of entries. A value of 0 or less
                disables caching (``set`` stores nothing).
            default_ttl: Default TTL in seconds
            cleanup_interval: How often to run cleanup (seconds)
        """
        self._cache: OrderedDict[str, CacheEntry[T]] = OrderedDict()
        self._max_size = max_size
        self._default_ttl = default_ttl
        self._cleanup_interval = cleanup_interval
        self._lock = asyncio.Lock()

        # Statistics
        self._hits = 0
        self._misses = 0

        # Background cleanup task
        self._cleanup_task: Optional[asyncio.Task] = None

    async def start(self) -> None:
        """Start the background cleanup task (no-op if already running)."""
        if self._cleanup_task is None:
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())
            logger.debug("Cache cleanup task started")

    async def stop(self) -> None:
        """Cancel the background cleanup task and wait for it to finish."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                # Expected: the task was cancelled just above
                pass
            self._cleanup_task = None
            logger.debug("Cache cleanup task stopped")

    async def _cleanup_loop(self) -> None:
        """Background task to periodically clean up expired entries."""
        while True:
            await asyncio.sleep(self._cleanup_interval)
            await self._cleanup_expired()

    async def _cleanup_expired(self) -> int:
        """Remove expired entries. Returns count of removed entries."""
        now = time.time()

        async with self._lock:
            # Collect first: the mapping must not change size while iterating
            expired_keys = [
                key for key, entry in self._cache.items()
                if entry.expires_at <= now
            ]
            for key in expired_keys:
                del self._cache[key]

        removed = len(expired_keys)
        if removed > 0:
            logger.debug("Cache cleanup: removed %d expired entries", removed)

        return removed

    async def get(self, key: str) -> Optional[T]:
        """
        Get a value from the cache.

        A hit marks the entry as most recently used and updates the hit
        counters; an absent or expired key counts as a miss.

        Args:
            key: Cache key

        Returns:
            Cached value or None if not found/expired
        """
        async with self._lock:
            entry = self._cache.get(key)
            if entry is None:
                self._misses += 1
                return None

            # Check if expired
            if entry.expires_at <= time.time():
                del self._cache[key]
                self._misses += 1
                return None

            # Move to end (most recently used)
            self._cache.move_to_end(key)
            entry.hits += 1
            self._hits += 1

            return entry.value

    async def set(
        self,
        key: str,
        value: T,
        ttl: Optional[float] = None
    ) -> None:
        """
        Set a value in the cache.

        Overwriting an existing key replaces its entry without evicting any
        other key. Otherwise, when the cache is full, the least recently used
        entries are evicted to make room.

        Args:
            key: Cache key
            value: Value to cache
            ttl: Optional TTL override (uses default if not provided)
        """
        ttl = ttl if ttl is not None else self._default_ttl
        expires_at = time.time() + ttl

        async with self._lock:
            if self._max_size <= 0:
                # No capacity: nothing can be stored, and there is nothing to
                # evict (popping from an empty mapping would raise).
                return

            # Drop any previous entry for this key first so that an overwrite
            # never counts against capacity and never evicts a different key.
            self._cache.pop(key, None)

            # Remove oldest entries if at capacity
            while len(self._cache) >= self._max_size:
                oldest_key, _ = self._cache.popitem(last=False)
                logger.debug("Cache evicted oldest entry: %s", oldest_key)

            # A fresh insertion lands at the most-recently-used end
            self._cache[key] = CacheEntry(
                value=value,
                expires_at=expires_at
            )

    async def delete(self, key: str) -> bool:
        """
        Delete a key from the cache.

        Args:
            key: Cache key

        Returns:
            True if key was found and deleted
        """
        async with self._lock:
            if key in self._cache:
                del self._cache[key]
                return True
            return False

    async def clear(self) -> int:
        """
        Clear all entries from the cache.

        Returns:
            Number of entries cleared
        """
        async with self._lock:
            count = len(self._cache)
            self._cache.clear()
            return count

    async def has(self, key: str) -> bool:
        """
        Check if a key exists and is not expired.

        Unlike ``get``, this is a pure membership probe: it does not change
        hit/miss statistics or recency, and it reports True for a key whose
        stored value is ``None``. An expired entry found here is removed.
        """
        async with self._lock:
            entry = self._cache.get(key)
            if entry is None:
                return False
            if entry.expires_at <= time.time():
                del self._cache[key]
                return False
            return True

    def stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dict with size, max_size, hits, misses, hit_rate (formatted
            percentage string) and default_ttl
        """
        total = self._hits + self._misses
        hit_rate = (self._hits / total * 100) if total > 0 else 0.0

        return {
            "size": len(self._cache),
            "max_size": self._max_size,
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": f"{hit_rate:.2f}%",
            "default_ttl": self._default_ttl
        }
# Global cache instances for different purposes.
# Both start as None and are created lazily on the first call to
# get_api_cache() / get_user_cache() respectively.
_api_response_cache: Optional[LRUCache[Dict[str, Any]]] = None
_user_preference_cache: Optional[LRUCache[Dict[str, Any]]] = None
async def get_api_cache() -> LRUCache[Dict[str, Any]]:
    """
    Return the shared API response cache, creating it on first use.

    The cache holds up to 500 entries with a 5 minute default TTL, and its
    background cleanup task is started when it is created.
    """
    global _api_response_cache

    if _api_response_cache is not None:
        return _api_response_cache

    # Publish the instance before awaiting so that a concurrent caller sees
    # it instead of building a second cache.
    instance = LRUCache(max_size=500, default_ttl=300.0)
    _api_response_cache = instance
    await instance.start()
    return _api_response_cache
async def get_user_cache() -> LRUCache[Dict[str, Any]]:
    """
    Return the shared user preference cache, creating it on first use.

    The cache holds up to 1000 entries with a 10 minute default TTL, and its
    background cleanup task is started when it is created.
    """
    global _user_preference_cache

    if _user_preference_cache is not None:
        return _user_preference_cache

    # Publish the instance before awaiting so that a concurrent caller sees
    # it instead of building a second cache.
    instance = LRUCache(max_size=1000, default_ttl=600.0)
    _user_preference_cache = instance
    await instance.start()
    return _user_preference_cache
def cached(
    cache_key_func: Callable[..., str],
    ttl: Optional[float] = None,
    cache_getter: Callable = get_api_cache
):
    """
    Decorator that memoizes the result of an async function in a cache.

    The wrapped coroutine's arguments are passed to ``cache_key_func`` to
    build the key. A cached value of ``None`` is indistinguishable from a
    miss, so functions returning ``None`` are executed on every call.

    Args:
        cache_key_func: Function to generate cache key from args
        ttl: Optional TTL override
        cache_getter: Async function returning the cache instance to use

    Usage:
        @cached(
            cache_key_func=lambda user_id: f"user:{user_id}",
            ttl=300
        )
        async def get_user_data(user_id: int) -> dict:
            # Expensive operation
            return await fetch_from_api(user_id)
    """
    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            store = await cache_getter()
            key = cache_key_func(*args, **kwargs)

            hit = await store.get(key)
            if hit is None:
                # Miss: run the wrapped coroutine and remember its result
                fresh = await func(*args, **kwargs)
                await store.set(key, fresh, ttl=ttl)
                logger.debug(f"Cached result for key: {key}")
                return fresh

            logger.debug(f"Cache hit for key: {key}")
            return hit

        return wrapper
    return decorator
def invalidate_on_update(
    cache_key_func: Callable[..., str],
    cache_getter: Callable = get_api_cache
):
    """
    Decorator that drops a cache entry after an async update function runs.

    The wrapped coroutine is awaited first; only if it returns normally is
    the key computed from its arguments deleted from the cache.

    Args:
        cache_key_func: Function to generate cache key to invalidate
        cache_getter: Async function returning the cache instance to use

    Usage:
        @invalidate_on_update(
            cache_key_func=lambda user_id, **_: f"user:{user_id}"
        )
        async def update_user_data(user_id: int, data: dict) -> None:
            await save_to_db(user_id, data)
    """
    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            outcome = await func(*args, **kwargs)

            # The update succeeded, so the cached copy is now stale
            store = await cache_getter()
            stale_key = cache_key_func(*args, **kwargs)
            await store.delete(stale_key)
            logger.debug(f"Invalidated cache for key: {stale_key}")

            return outcome

        return wrapper
    return decorator
# Convenience functions for common caching patterns
async def cache_user_model(user_id: int, model: str) -> None:
    """Store the user's selected model in the user preference cache."""
    payload = {"model": model}
    user_cache = await get_user_cache()
    await user_cache.set(f"user_model:{user_id}", payload)
async def get_cached_user_model(user_id: int) -> Optional[str]:
    """Return the user's cached model selection, or None when nothing is cached."""
    user_cache = await get_user_cache()
    entry = await user_cache.get(f"user_model:{user_id}")
    if not entry:
        return None
    return entry["model"]
async def invalidate_user_cache(user_id: int) -> None:
    """Remove the known per-user entries (model, history, stats) from the user cache."""
    user_cache = await get_user_cache()
    # Only these fixed key patterns are cleared
    stale_keys = (
        f"user_model:{user_id}",
        f"user_history:{user_id}",
        f"user_stats:{user_id}",
    )
    for stale_key in stale_keys:
        await user_cache.delete(stale_key)

View File

@@ -71,19 +71,40 @@ APPROVED_PACKAGES = {
'more-itertools', 'toolz', 'cytoolz', 'funcy' 'more-itertools', 'toolz', 'cytoolz', 'funcy'
} }
# Blocked patterns # Blocked patterns - Comprehensive security checks
# Note: We allow open() for writing to enable saving plots and outputs # Note: We allow open() for writing to enable saving plots and outputs
# The sandboxed environment restricts file access to safe directories # The sandboxed environment restricts file access to safe directories
BLOCKED_PATTERNS = [ BLOCKED_PATTERNS = [
# Dangerous system modules # ==================== DANGEROUS SYSTEM MODULES ====================
# OS module (except path)
r'import\s+os\b(?!\s*\.path)', r'import\s+os\b(?!\s*\.path)',
r'from\s+os\s+import\s+(?!path)', r'from\s+os\s+import\s+(?!path)',
# File system modules
r'import\s+shutil\b', r'import\s+shutil\b',
r'from\s+shutil\s+import', r'from\s+shutil\s+import',
r'import\s+pathlib\b(?!\s*\.)', # Allow pathlib usage but monitor
# Subprocess and execution modules
r'import\s+subprocess\b', r'import\s+subprocess\b',
r'from\s+subprocess\s+import', r'from\s+subprocess\s+import',
r'import\s+sys\b(?!\s*\.(?:path|version|platform))', r'import\s+multiprocessing\b',
r'from\s+sys\s+import', r'from\s+multiprocessing\s+import',
r'import\s+threading\b',
r'from\s+threading\s+import',
r'import\s+concurrent\b',
r'from\s+concurrent\s+import',
# System access modules
r'import\s+sys\b(?!\s*\.(?:path|version|platform|stdout|stderr))',
r'from\s+sys\s+import\s+(?!path|version|platform|stdout|stderr)',
r'import\s+platform\b',
r'from\s+platform\s+import',
r'import\s+ctypes\b',
r'from\s+ctypes\s+import',
r'import\s+_[a-z]+', # Block private C modules
# ==================== NETWORK MODULES ====================
r'import\s+socket\b', r'import\s+socket\b',
r'from\s+socket\s+import', r'from\s+socket\s+import',
r'import\s+urllib\b', r'import\s+urllib\b',
@@ -92,19 +113,98 @@ BLOCKED_PATTERNS = [
r'from\s+requests\s+import', r'from\s+requests\s+import',
r'import\s+aiohttp\b', r'import\s+aiohttp\b',
r'from\s+aiohttp\s+import', r'from\s+aiohttp\s+import',
# Dangerous code execution r'import\s+httpx\b',
r'from\s+httpx\s+import',
r'import\s+http\.client\b',
r'from\s+http\.client\s+import',
r'import\s+ftplib\b',
r'from\s+ftplib\s+import',
r'import\s+smtplib\b',
r'from\s+smtplib\s+import',
r'import\s+telnetlib\b',
r'from\s+telnetlib\s+import',
r'import\s+ssl\b',
r'from\s+ssl\s+import',
r'import\s+paramiko\b',
r'from\s+paramiko\s+import',
# ==================== DANGEROUS CODE EXECUTION ====================
r'__import__\s*\(', r'__import__\s*\(',
r'\beval\s*\(', r'\beval\s*\(',
r'\bexec\s*\(', r'\bexec\s*\(',
r'\bcompile\s*\(', r'\bcompile\s*\(',
r'\bglobals\s*\(', r'\bglobals\s*\(',
r'\blocals\s*\(', r'\blocals\s*\(',
# File system operations (dangerous) r'\bgetattr\s*\([^,]+,\s*[\'"]__', # Block getattr for dunder methods
r'\bsetattr\s*\([^,]+,\s*[\'"]__', # Block setattr for dunder methods
r'\bdelattr\s*\([^,]+,\s*[\'"]__', # Block delattr for dunder methods
r'\.\_\_\w+\_\_', # Block dunder method access
# ==================== FILE SYSTEM OPERATIONS ====================
r'\.unlink\s*\(', r'\.unlink\s*\(',
r'\.rmdir\s*\(', r'\.rmdir\s*\(',
r'\.remove\s*\(', r'\.remove\s*\(',
r'\.chmod\s*\(', r'\.chmod\s*\(',
r'\.chown\s*\(', r'\.chown\s*\(',
r'\.rmtree\s*\(',
r'\.rename\s*\(',
r'\.replace\s*\(',
r'\.makedirs\s*\(', # Allow mkdir but block makedirs outside sandbox
r'Path\s*\(\s*[\'"]\/(?!tmp)', # Block absolute paths outside /tmp
r'open\s*\(\s*[\'"]\/(?!tmp)', # Block file access outside /tmp
# ==================== PICKLE AND SERIALIZATION ====================
r'pickle\.loads?\s*\(',
r'cPickle\.loads?\s*\(',
r'marshal\.loads?\s*\(',
r'shelve\.open\s*\(',
# ==================== PROCESS MANIPULATION ====================
r'os\.system\s*\(',
r'os\.popen\s*\(',
r'os\.spawn',
r'os\.exec',
r'os\.fork\s*\(',
r'os\.kill\s*\(',
r'os\.killpg\s*\(',
# ==================== ENVIRONMENT ACCESS ====================
r'os\.environ',
r'os\.getenv\s*\(',
r'os\.putenv\s*\(',
# ==================== DANGEROUS BUILTINS ====================
r'__builtins__',
r'__loader__',
r'__spec__',
# ==================== CODE OBJECT MANIPULATION ====================
r'\.f_code',
r'\.f_globals',
r'\.f_locals',
r'\.gi_frame',
r'\.co_code',
r'types\.CodeType',
r'types\.FunctionType',
# ==================== IMPORT SYSTEM MANIPULATION ====================
r'import\s+importlib\b',
r'from\s+importlib\s+import',
r'sys\.modules',
r'sys\.path\.(?:append|insert|extend)',
# ==================== MEMORY OPERATIONS ====================
r'gc\.',
r'sys\.getsizeof',
r'sys\.getrefcount',
r'id\s*\(', # Block id() which can leak memory addresses
]
# Advisory patterns: a match is logged as a warning but does not block the code
WARNING_PATTERNS = [
    # "while True" loop header
    (
        r'while\s+True',
        "Infinite loop detected - ensure break condition exists",
    ),
    # for-loop over range() whose first literal has six or more digits
    (
        r'for\s+\w+\s+in\s+range\s*\(\s*\d{6,}',
        "Very large loop detected",
    ),
    # The word "recursion" appearing anywhere in the submitted text
    (
        r'recursion',
        "Recursion detected - ensure base case exists",
    ),
]
@@ -772,10 +872,54 @@ class CodeExecutor:
logger.warning(f"Cleanup failed: {e}") logger.warning(f"Cleanup failed: {e}")
def validate_code_security(self, code: str) -> Tuple[bool, str]:
    """
    Screen submitted code for unsafe constructs.

    Three passes are made, in this order:
    - every regex in BLOCKED_PATTERNS is searched case-insensitively and the
      first match rejects the code
    - every regex in WARNING_PATTERNS is searched and matches are only logged
    - the code is parsed with ``ast`` and rejected when it contains an
      attribute whose name starts with a double underscore, or a direct call
      to eval / exec / compile / __import__

    Args:
        code: The Python code to validate

    Returns:
        Tuple of (is_safe, message)
    """
    # Pass 1: hard blocks - stop at the first offending pattern
    offending = next(
        (rx for rx in BLOCKED_PATTERNS if re.search(rx, code, re.IGNORECASE)),
        None,
    )
    if offending is not None:
        logger.warning(f"Blocked code pattern detected: {offending[:50]}...")
        return False, "Security violation: Unsafe operation detected"

    # Pass 2: advisory patterns - logged, never blocking
    for rx, warning_msg in WARNING_PATTERNS:
        if re.search(rx, code, re.IGNORECASE):
            logger.warning(f"Code warning: {warning_msg}")

    # Pass 3: structural checks on the parsed syntax tree
    forbidden_calls = ('eval', 'exec', 'compile', '__import__')
    try:
        for node in ast.walk(ast.parse(code)):
            if isinstance(node, ast.Attribute) and node.attr.startswith('__'):
                logger.warning(f"Dunder attribute access detected: {node.attr}")
                return False, "Security violation: Private attribute access not allowed"

            is_direct_call = isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
            if is_direct_call and node.func.id in forbidden_calls:
                return False, f"Security violation: {node.func.id}() is not allowed"
    except SyntaxError:
        # Unparseable code is not rejected here; it is left to fail at execution
        pass
    except Exception as e:
        logger.warning(f"Error during AST validation: {e}")

    return True, "Code passed security validation"
def _extract_imports_from_code(self, code: str) -> List[str]: def _extract_imports_from_code(self, code: str) -> List[str]:

417
src/utils/discord_utils.py Normal file
View File

@@ -0,0 +1,417 @@
"""
Discord response utilities for sending messages with proper handling.
This module provides utilities for sending messages to Discord with
proper length handling, error recovery, and formatting.
"""
import discord
import asyncio
import logging
import io
from typing import Optional, List, Union
from dataclasses import dataclass
# Discord message limits
MAX_MESSAGE_LENGTH = 2000
MAX_EMBED_DESCRIPTION = 4096
MAX_EMBED_FIELD_VALUE = 1024
MAX_EMBED_FIELDS = 25
MAX_FILE_SIZE = 8 * 1024 * 1024 # 8MB for non-nitro
@dataclass
class MessageChunk:
    """One piece of a longer message, sized to fit within Discord limits."""

    # Text carried by this piece
    content: str
    # Flag marking the piece as a code block
    is_code_block: bool = False
    # Language tag for the piece, if any
    language: Optional[str] = None
def split_message(
    content: str,
    max_length: int = MAX_MESSAGE_LENGTH,
    split_on: List[str] = None
) -> List[str]:
    """
    Split a long message into chunks that fit within Discord limits.

    Each chunk is cut at the last occurrence of the first delimiter (tried in
    the order given) that lies in the second half of the allowed window, so
    chunks are not cut needlessly short. When no delimiter qualifies the text
    is hard-cut at ``max_length``. Delimiters stay attached to the end of the
    chunk they terminate, so joining the chunks reproduces ``content``.

    Args:
        content: The message content to split
        max_length: Maximum length per chunk (must be at least 1)
        split_on: Preferred split points, best first
            (default: paragraph break, newline, sentence end, space)

    Returns:
        List of message chunks

    Raises:
        ValueError: If ``max_length`` is smaller than 1
    """
    # A non-positive limit can never consume any text, so the loop below
    # would never terminate.
    if max_length < 1:
        raise ValueError("max_length must be at least 1")

    if len(content) <= max_length:
        return [content]

    if split_on is None:
        split_on = ['\n\n', '\n', '. ', ' ']

    chunks = []
    remaining = content
    while len(remaining) > max_length:
        # Default to a hard cut; replaced below if a good delimiter is found
        split_index = max_length
        for delimiter in split_on:
            # rfind is bounded by max_length, so the delimiter (and therefore
            # the resulting chunk) always fits inside the limit
            last_index = remaining.rfind(delimiter, 0, max_length)
            if last_index > max_length // 2:  # Don't split too early
                split_index = last_index + len(delimiter)
                break

        chunks.append(remaining[:split_index])
        remaining = remaining[split_index:]

    if remaining:
        chunks.append(remaining)
    return chunks
def split_code_block(
    code: str,
    language: str = "",
    max_length: int = MAX_MESSAGE_LENGTH
) -> List[str]:
    """
    Split code into properly formatted code block chunks.

    Lines are packed greedily into chunks and every chunk is wrapped in its
    own fenced code block. A single line that is longer than the per-chunk
    budget is hard-wrapped into budget-sized pieces so that no chunk exceeds
    ``max_length``; this inserts line breaks into that line.

    Args:
        code: The code content
        language: The language for syntax highlighting
        max_length: Maximum length per chunk

    Returns:
        List of formatted code block strings
    """
    # Budget left for code once the fence markers and a safety buffer of 20
    # characters are subtracted
    marker_length = len(f"```{language}\n") + len("```")
    effective_max = max_length - marker_length - 20

    chunks = []
    current_chunk = []
    current_length = 0

    for raw_line in code.split('\n'):
        # Break up lines that could never fit in one chunk on their own
        if effective_max > 0 and len(raw_line) > effective_max:
            pieces = [
                raw_line[start:start + effective_max]
                for start in range(0, len(raw_line), effective_max)
            ]
        else:
            pieces = [raw_line]

        for line in pieces:
            line_length = len(line) + 1  # +1 for newline
            if current_length + line_length > effective_max and current_chunk:
                # Current chunk is full: emit it and start a new one
                chunk_code = '\n'.join(current_chunk)
                chunks.append(f"```{language}\n{chunk_code}\n```")
                current_chunk = [line]
                current_length = line_length
            else:
                current_chunk.append(line)
                current_length += line_length

    # Emit whatever is left over
    if current_chunk:
        chunk_code = '\n'.join(current_chunk)
        chunks.append(f"```{language}\n{chunk_code}\n```")

    return chunks
async def send_long_message(
    channel: discord.abc.Messageable,
    content: str,
    max_length: int = MAX_MESSAGE_LENGTH,
    delay: float = 0.5
) -> List[discord.Message]:
    """
    Deliver text of any length as a sequence of Discord messages.

    The text is cut with split_message() and the pieces are sent in order.
    A piece that fails with discord.HTTPException is logged; if the error
    text contains "too long" the piece is re-sent as a .txt attachment,
    otherwise it is dropped and sending moves on to the next piece.

    Args:
        channel: The channel to send to
        content: The message content
        max_length: Maximum length per message
        delay: Pause in seconds between consecutive messages

    Returns:
        List of the messages that were sent
    """
    parts = split_message(content, max_length)
    total = len(parts)
    sent = []

    for position, part in enumerate(parts, start=1):
        try:
            sent.append(await channel.send(part))
            # Pause between sends, but not after the final piece
            if position < total:
                await asyncio.sleep(delay)
        except discord.HTTPException as err:
            logging.error(f"Failed to send message chunk {position}: {err}")
            if "too long" not in str(err).lower():
                continue
            # Fall back to uploading the piece as a text attachment
            attachment = discord.File(
                io.StringIO(part),
                filename=f"message_part_{position}.txt"
            )
            sent.append(await channel.send(file=attachment))

    return sent
async def send_code_response(
    channel: discord.abc.Messageable,
    code: str,
    language: str = "python",
    title: Optional[str] = None
) -> List[discord.Message]:
    """
    Post code to a channel, inline when short and as an attachment when long.

    Code longer than MAX_MESSAGE_LENGTH - 100 characters is uploaded as a
    file named ``code.<language>`` (``code.txt`` when no language is given).
    Shorter code is sent as fenced code blocks produced by split_code_block().

    Args:
        channel: The channel to send to
        code: The code content
        language: Programming language for highlighting
        title: Optional text sent as its own message before the code

    Returns:
        List of sent messages
    """
    sent = []

    if title:
        sent.append(await channel.send(title))

    fits_inline = len(code) <= MAX_MESSAGE_LENGTH - 100
    if not fits_inline:
        attachment = discord.File(
            io.StringIO(code),
            filename=f"code.{language}" if language else "code.txt"
        )
        sent.append(await channel.send("📎 Code attached as file:", file=attachment))
        return sent

    for block in split_code_block(code, language):
        sent.append(await channel.send(block))
        # Short pause after every block
        await asyncio.sleep(0.3)

    return sent
def create_error_embed(
    title: str,
    description: str,
    error_type: str = "Error"
) -> discord.Embed:
    """
    Build the red embed used for reporting errors.

    Args:
        title: Error title
        description: Error description, cut to MAX_EMBED_DESCRIPTION characters
        error_type: Category label shown in the embed footer

    Returns:
        Discord Embed object
    """
    clipped = description[:MAX_EMBED_DESCRIPTION]
    error_embed = discord.Embed(
        color=discord.Color.red(),
        title=f"{title}",
        description=clipped
    )
    error_embed.set_footer(text=f"Error Type: {error_type}")
    return error_embed
def create_success_embed(
    title: str,
    description: str = ""
) -> discord.Embed:
    """
    Build the green embed used for reporting success.

    Args:
        title: Success title
        description: Success description, cut to MAX_EMBED_DESCRIPTION
            characters; an empty description is passed to the embed as None

    Returns:
        Discord Embed object
    """
    if description:
        body = description[:MAX_EMBED_DESCRIPTION]
    else:
        body = None

    return discord.Embed(
        color=discord.Color.green(),
        title=f"{title}",
        description=body
    )
def create_info_embed(
    title: str,
    description: str = "",
    fields: List[tuple] = None
) -> discord.Embed:
    """
    Build the blue informational embed, optionally with fields.

    Args:
        title: Info title
        description: Info description, cut to MAX_EMBED_DESCRIPTION
            characters; an empty description is passed to the embed as None
        fields: List of (name, value, inline) tuples. Only the first
            MAX_EMBED_FIELDS are used; names are cut to 256 characters and
            values are stringified and cut to MAX_EMBED_FIELD_VALUE

    Returns:
        Discord Embed object
    """
    body = description[:MAX_EMBED_DESCRIPTION] if description else None
    info_embed = discord.Embed(
        color=discord.Color.blue(),
        title=f" {title}",
        description=body
    )

    if not fields:
        return info_embed

    for field_name, field_value, is_inline in fields[:MAX_EMBED_FIELDS]:
        info_embed.add_field(
            name=field_name[:256],
            value=str(field_value)[:MAX_EMBED_FIELD_VALUE],
            inline=is_inline
        )
    return info_embed
def create_progress_embed(
    title: str,
    description: str,
    progress: float = 0.0
) -> discord.Embed:
    """
    Create a progress indicator embed.

    The embed description is the given text followed by a 20-cell text bar
    and a whole-number percentage.

    Args:
        title: Progress title
        description: Progress description
        progress: Progress value 0.0 to 1.0; values outside that range are
            clamped to it

    Returns:
        Discord Embed object
    """
    # Clamp so out-of-range values cannot over-fill or under-fill the bar
    progress = min(max(progress, 0.0), 1.0)

    # Create progress bar: filled cells followed by empty cells
    bar_length = 20
    filled = int(bar_length * progress)
    bar = "█" * filled + "░" * (bar_length - filled)
    percentage = int(progress * 100)

    embed = discord.Embed(
        title=f"{title}",
        description=f"{description}\n\n`{bar}` {percentage}%",
        color=discord.Color.orange()
    )
    return embed
async def edit_or_send(
    message: Optional[discord.Message],
    channel: discord.abc.Messageable,
    content: str = None,
    embed: discord.Embed = None
) -> discord.Message:
    """
    Update a message in place, falling back to posting a new one.

    With no message given, a new one is sent to the channel. If the edit (or
    that first send) fails with discord.HTTPException, one further send to
    the channel is attempted and its result returned.

    Args:
        message: Message to edit (or None to send new)
        channel: Channel used for new messages
        content: Message content
        embed: Message embed

    Returns:
        The edited or new message
    """
    try:
        if not message:
            return await channel.send(content=content, embed=embed)
        await message.edit(content=content, embed=embed)
    except discord.HTTPException:
        return await channel.send(content=content, embed=embed)
    return message
class ProgressMessage:
    """
    Async context manager around a single, repeatedly edited progress embed.

    Entering the context sends the initial embed; update(), complete() and
    error() then edit that same message. Failed edits are ignored.

    Usage:
        async with ProgressMessage(channel, "Processing") as progress:
            for i in range(100):
                await progress.update(i / 100, f"Step {i}")
    """

    def __init__(
        self,
        channel: discord.abc.Messageable,
        title: str,
        description: str = "Starting..."
    ):
        self.channel = channel
        self.title = title
        self.description = description
        # Set by __aenter__ once the initial embed has been sent
        self.message: Optional[discord.Message] = None
        # time.monotonic() reading taken at the last accepted update
        self._last_update = 0.0
        self._update_interval = 2.0  # Minimum seconds between updates

    async def __aenter__(self):
        initial = create_progress_embed(self.title, self.description, 0.0)
        self.message = await self.channel.send(embed=initial)
        return self

    async def __aexit__(self, *args):
        # Nothing to finalize; exceptions from the body propagate
        return None

    async def _show(self, embed):
        """Replace the message's embed, ignoring edit failures."""
        try:
            await self.message.edit(embed=embed)
        except discord.HTTPException:
            pass

    async def update(self, progress: float, description: str = None):
        """Redraw the progress bar; calls inside the update interval are dropped."""
        import time

        now = time.monotonic()
        too_soon = now - self._last_update < self._update_interval
        if too_soon:
            return
        self._last_update = now

        if description:
            self.description = description
        await self._show(
            create_progress_embed(self.title, self.description, progress)
        )

    async def complete(self, message: str = "Complete!"):
        """Swap the progress embed for a success embed."""
        await self._show(create_success_embed(self.title, message))

    async def error(self, message: str):
        """Swap the progress embed for an error embed."""
        await self._show(create_error_embed(self.title, message))

446
src/utils/monitoring.py Normal file
View File

@@ -0,0 +1,446 @@
"""
Monitoring and observability utilities.
This module provides structured logging, error tracking with Sentry,
and performance monitoring for the Discord bot.
"""
import os
import logging
import time
import asyncio
from typing import Any, Dict, Optional, Callable
from functools import wraps
from contextlib import contextmanager, asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
# Try to import Sentry
try:
import sentry_sdk
from sentry_sdk.integrations.asyncio import AsyncioIntegration
SENTRY_AVAILABLE = True
except ImportError:
SENTRY_AVAILABLE = False
sentry_sdk = None
logger = logging.getLogger(__name__)
# ============================================================
# Configuration
# ============================================================
@dataclass
class MonitoringConfig:
    """Settings consumed when monitoring is initialised."""

    # Sentry DSN; Sentry is only initialised when this is set
    sentry_dsn: Optional[str] = None
    # Environment name reported to Sentry
    environment: str = "development"
    # Sentry event sample rate (1.0 = 100% of events)
    sample_rate: float = 1.0
    # Sentry transaction sample rate (0.1 = 10% of transactions)
    traces_sample_rate: float = 0.1
    # Name of the logging level to apply
    log_level: str = "INFO"
    # Whether the structured log formatter is used
    structured_logging: bool = True
def _env_float(name: str, default: float) -> float:
    """Read a float from the environment, using *default* when unset, blank or malformed."""
    raw = os.environ.get(name)
    if raw is None or not raw.strip():
        return default
    try:
        return float(raw)
    except ValueError:
        logger.warning(f"Invalid value for {name}: {raw!r}; using default {default}")
        return default


def setup_monitoring(config: Optional[MonitoringConfig] = None) -> None:
    """
    Initialize monitoring with optional Sentry integration.

    Logging is always configured. Sentry is initialised only when the SDK is
    importable and a DSN is configured; otherwise the bot runs without it.

    Args:
        config: Monitoring configuration. When omitted it is built from the
            SENTRY_DSN, ENVIRONMENT, SENTRY_SAMPLE_RATE, SENTRY_TRACES_RATE
            and LOG_LEVEL environment variables.
    """
    if config is None:
        config = MonitoringConfig(
            sentry_dsn=os.environ.get("SENTRY_DSN"),
            environment=os.environ.get("ENVIRONMENT", "development"),
            # A blank or malformed rate must not abort startup, so these
            # fall back to their defaults instead of raising
            sample_rate=_env_float("SENTRY_SAMPLE_RATE", 1.0),
            traces_sample_rate=_env_float("SENTRY_TRACES_RATE", 0.1),
            log_level=os.environ.get("LOG_LEVEL", "INFO"),
        )

    # Setup logging
    setup_structured_logging(
        level=config.log_level,
        structured=config.structured_logging
    )

    # Setup Sentry if available and configured
    if SENTRY_AVAILABLE and config.sentry_dsn:
        sentry_sdk.init(
            dsn=config.sentry_dsn,
            environment=config.environment,
            sample_rate=config.sample_rate,
            traces_sample_rate=config.traces_sample_rate,
            integrations=[AsyncioIntegration()],
            before_send=before_send_filter,
        )
        logger.info(f"Sentry initialized for environment: {config.environment}")
    else:
        if config.sentry_dsn and not SENTRY_AVAILABLE:
            logger.warning("Sentry DSN provided but sentry_sdk not installed")
        logger.info("Running without Sentry error tracking")
def before_send_filter(event: Dict, hint: Dict) -> Optional[Dict]:
    """
    Decide whether an event is forwarded to Sentry.

    Args:
        event: The event about to be sent
        hint: Extra data for the event; an "exc_info" entry is read as an
            (exception type, exception value, traceback) triple

    Returns:
        None to drop the event when the exception class is named NotFound,
        Forbidden or RateLimited; otherwise the event unchanged
    """
    # Expected, non-critical Discord errors (404, 403, rate limit)
    ignored_names = ("NotFound", "Forbidden", "RateLimited")

    if "exc_info" not in hint:
        return event

    exc_type, _exc_value, _tb = hint["exc_info"]
    if exc_type.__name__ in ignored_names:
        return None
    return event
# ============================================================
# Structured Logging
# ============================================================
class StructuredFormatter(logging.Formatter):
    """Formatter that renders each record as space-separated key=value pairs."""

    # Optional record attributes copied into the output when present
    # (for example when supplied through ``extra=``)
    _OPTIONAL_FIELDS = ("user_id", "guild_id", "command", "duration_ms", "model")

    def format(self, record: logging.LogRecord) -> str:
        """
        Render *record* as ``key=repr(value)`` pairs.

        The output always starts with timestamp (current UTC time in ISO
        format), level, logger and message, followed by any optional fields
        found on the record and, when present, the formatted exception.
        """
        fields = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }

        for attr in self._OPTIONAL_FIELDS:
            if hasattr(record, attr):
                fields[attr] = getattr(record, attr)

        if record.exc_info:
            fields["exception"] = self.formatException(record.exc_info)

        # key=value pairs keep the line easy to parse
        return " ".join(f"{key}={value!r}" for key, value in fields.items())
def setup_structured_logging(
    level: str = "INFO",
    structured: bool = True
) -> None:
    """
    Setup logging configuration.

    Installs a single stream handler on the root logger, replacing any
    handlers already attached to it, and applies the level to both the
    handler and the root logger.

    Args:
        level: Log level name (DEBUG, INFO, WARNING, ERROR), matched
            case-insensitively. Anything that is not a logging level
            falls back to INFO.
        structured: Use structured formatting
    """
    # Only accept integer level constants: an arbitrary attribute of the
    # logging module (e.g. BASIC_FORMAT) would make setLevel() raise.
    resolved = getattr(logging, level.upper(), None)
    log_level = resolved if isinstance(resolved, int) else logging.INFO

    # Create handler
    handler = logging.StreamHandler()
    handler.setLevel(log_level)
    if structured:
        handler.setFormatter(StructuredFormatter())
    else:
        handler.setFormatter(logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
        ))

    # Configure root logger
    root_logger = logging.getLogger()
    root_logger.setLevel(log_level)
    root_logger.handlers = [handler]
def get_logger(name: str) -> logging.Logger:
    """Return the standard-library logger registered under *name*."""
    named_logger = logging.getLogger(name)
    return named_logger
# ============================================================
# Error Tracking
# ============================================================
def capture_exception(
    exception: Exception,
    context: Optional[Dict[str, Any]] = None
) -> Optional[str]:
    """
    Capture and report an exception.

    The exception is always logged at ERROR level with its traceback. When
    Sentry is available and has an active client, it is also sent there with
    the context attached as extras.

    Args:
        exception: The exception to capture
        context: Additional context to attach

    Returns:
        Event ID if sent to Sentry, None otherwise
    """
    # Pass the exception explicitly: this function may be called outside an
    # ``except`` block, where logger.exception() would find no active
    # exception and log no traceback.
    logger.error(f"Captured exception: {exception}", exc_info=exception)

    if SENTRY_AVAILABLE and sentry_sdk.Hub.current.client:
        with sentry_sdk.push_scope() as scope:
            if context:
                for key, value in context.items():
                    scope.set_extra(key, value)
            return sentry_sdk.capture_exception(exception)

    return None
def capture_message(
    message: str,
    level: str = "info",
    context: Optional[Dict[str, Any]] = None
) -> Optional[str]:
    """
    Log a message and, when Sentry is active, report it there as well.

    Args:
        message: The message to capture
        level: Severity level (debug, info, warning, error, fatal); it
            selects the logger method of the same name, defaulting to info
            when the logger has no such attribute
        context: Additional context attached to the Sentry scope as extras

    Returns:
        Event ID if sent to Sentry, None otherwise
    """
    emit = getattr(logger, level, logger.info)
    emit(message)

    sentry_active = SENTRY_AVAILABLE and sentry_sdk.Hub.current.client
    if not sentry_active:
        return None

    with sentry_sdk.push_scope() as scope:
        for key, value in (context or {}).items():
            scope.set_extra(key, value)
        return sentry_sdk.capture_message(message, level=level)
def set_user_context(
    user_id: int,
    username: Optional[str] = None,
    guild_id: Optional[int] = None
) -> None:
    """
    Attach the acting Discord user to subsequent Sentry events.

    Does nothing when Sentry is unavailable or has no active client.

    Args:
        user_id: Discord user ID (sent as a string)
        username: Discord username
        guild_id: Discord guild ID; when truthy it is recorded as the
            "guild_id" tag
    """
    sentry_active = SENTRY_AVAILABLE and sentry_sdk.Hub.current.client
    if not sentry_active:
        return

    user_info = {
        "id": str(user_id),
        "username": username,
    }
    sentry_sdk.set_user(user_info)

    if guild_id:
        sentry_sdk.set_tag("guild_id", str(guild_id))
# ============================================================
# Performance Monitoring
# ============================================================
@dataclass
class PerformanceMetrics:
    """Container for performance metrics of one timed operation."""

    # Operation name
    name: str
    # time.perf_counter() reading taken when the instance is created
    start_time: float = field(default_factory=time.perf_counter)
    # time.perf_counter() reading taken by finish(); None while still running
    end_time: Optional[float] = None
    success: bool = True
    error: Optional[str] = None
    # Extra key/value pairs merged into to_dict()
    metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def duration_ms(self) -> float:
        """Get duration in milliseconds (up to now if not yet finished)."""
        # Compare against None explicitly: an end_time of 0.0 is a valid
        # reading and must not be mistaken for "still running".
        end = self.end_time if self.end_time is not None else time.perf_counter()
        return (end - self.start_time) * 1000

    def finish(self, success: bool = True, error: Optional[str] = None) -> None:
        """Mark the operation as finished and record its outcome."""
        self.end_time = time.perf_counter()
        self.success = success
        self.error = error

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for logging; metadata keys are merged last."""
        return {
            "name": self.name,
            "duration_ms": round(self.duration_ms, 2),
            "success": self.success,
            "error": self.error,
            **self.metadata
        }
@contextmanager
def measure_sync(name: str, **metadata):
    """
    Time a synchronous block of code and log the result.

    Yields the PerformanceMetrics object for the block. An exception raised
    inside the block marks the metrics as failed and is re-raised. A
    "Performance: <name>" line carrying the duration and the metadata as
    ``extra`` fields is logged at INFO level either way.

    Usage:
        with measure_sync("database_query", table="users"):
            result = db.query(...)
    """
    timing = PerformanceMetrics(name=name, metadata=metadata)
    try:
        yield timing
    except Exception as exc:
        timing.finish(success=False, error=str(exc))
        raise
    else:
        timing.finish(success=True)
    finally:
        log_fields = {"duration_ms": timing.duration_ms, **timing.metadata}
        logger.info(f"Performance: {timing.name}", extra=log_fields)
@asynccontextmanager
async def measure_async(name: str, **metadata):
    """
    Time an async block of code, log the result and trace it in Sentry.

    Yields the PerformanceMetrics object for the block. When Sentry is
    available and has an active client, a transaction with op "task" is
    started and finished with status "ok" or "internal_error". An exception
    raised inside the block marks the metrics as failed and is re-raised. A
    "Performance: <name>" line is logged at INFO level either way.

    Usage:
        async with measure_async("api_call", endpoint="chat"):
            result = await api.call(...)
    """
    timing = PerformanceMetrics(name=name, metadata=metadata)

    # Open a Sentry transaction only when a client is active
    sentry_active = SENTRY_AVAILABLE and sentry_sdk.Hub.current.client
    transaction = (
        sentry_sdk.start_transaction(op="task", name=name)
        if sentry_active
        else None
    )

    try:
        yield timing
    except Exception as exc:
        timing.finish(success=False, error=str(exc))
        raise
    else:
        timing.finish(success=True)
    finally:
        if transaction:
            outcome = "ok" if timing.success else "internal_error"
            transaction.set_status(outcome)
            transaction.finish()

        log_fields = {"duration_ms": timing.duration_ms, **timing.metadata}
        logger.info(f"Performance: {timing.name}", extra=log_fields)
def track_performance(name: Optional[str] = None):
    """
    Decorator factory that runs an async function inside measure_async().

    Args:
        name: Operation name (defaults to the decorated function's name)

    Usage:
        @track_performance("process_message")
        async def handle_message(message):
            ...
    """
    def decorate(target: Callable):
        # Resolve the label once, at decoration time
        label = name or target.__name__

        @wraps(target)
        async def timed(*args, **kwargs):
            async with measure_async(label):
                return await target(*args, **kwargs)

        return timed

    return decorate
# ============================================================
# Health Check
# ============================================================
@dataclass
class HealthStatus:
    """Aggregated result of a set of health checks."""

    # Overall verdict; turned False by any failing check
    healthy: bool
    # Per-check results keyed by check name
    checks: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    # UTC creation time in ISO format
    timestamp: str = field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat()
    )

    def add_check(
        self,
        name: str,
        healthy: bool,
        message: str = "",
        details: Optional[Dict] = None
    ) -> None:
        """Record one check result; a failing check marks the whole status unhealthy."""
        extra = details or {}
        entry = {"healthy": healthy, "message": message, **extra}
        self.checks[name] = entry

        if healthy:
            return
        self.healthy = False
async def check_health(
    db_handler=None,
    openai_client=None
) -> HealthStatus:
    """
    Probe the bot's external dependencies and summarise the outcome.

    A dependency is only probed when its handle is passed in. The database
    probe awaits ``db_handler.client.admin.command('ping')`` with a 5 second
    timeout; the OpenAI probe awaits ``openai_client.models.list()`` with a
    10 second timeout. Any exception marks that check as failed.

    Args:
        db_handler: Database handler to check
        openai_client: OpenAI client to check

    Returns:
        HealthStatus with check results
    """
    status = HealthStatus(healthy=True)

    async def probe(check_name, start_call, timeout, ok_message, error_label):
        """Await one dependency call and record the result on *status*."""
        try:
            await asyncio.wait_for(start_call(), timeout=timeout)
        except Exception as exc:
            status.add_check(check_name, False, f"{error_label} error: {exc}")
        else:
            status.add_check(check_name, True, ok_message)

    if db_handler:
        await probe(
            "database",
            lambda: db_handler.client.admin.command('ping'),
            5.0,
            "MongoDB connected",
            "MongoDB",
        )

    if openai_client:
        await probe(
            "openai",
            lambda: openai_client.models.list(),
            10.0,
            "OpenAI API accessible",
            "OpenAI",
        )

    return status

280
src/utils/retry.py Normal file
View File

@@ -0,0 +1,280 @@
"""
Retry utilities with exponential backoff for API calls.
This module provides robust retry logic for external API calls
to handle transient failures gracefully.
"""
import asyncio
import logging
import random
from typing import TypeVar, Callable, Optional, Any, Type, Tuple
from functools import wraps
T = TypeVar('T')
# Default configuration
DEFAULT_MAX_RETRIES = 3
DEFAULT_BASE_DELAY = 1.0 # seconds
DEFAULT_MAX_DELAY = 60.0 # seconds
DEFAULT_EXPONENTIAL_BASE = 2
class RetryError(Exception):
    """Signals that every retry attempt failed; keeps the final underlying error."""

    def __init__(self, message: str, last_exception: Optional[Exception] = None):
        Exception.__init__(self, message)
        # The exception raised by the last attempt, if any
        self.last_exception = last_exception
async def async_retry_with_backoff(
    func: Callable,
    *args,
    max_retries: int = DEFAULT_MAX_RETRIES,
    base_delay: float = DEFAULT_BASE_DELAY,
    max_delay: float = DEFAULT_MAX_DELAY,
    exponential_base: float = DEFAULT_EXPONENTIAL_BASE,
    retryable_exceptions: Tuple[Type[Exception], ...] = (Exception,),
    jitter: bool = True,
    on_retry: Optional[Callable[[int, Exception], None]] = None,
    **kwargs
) -> Any:
    """
    Execute an async function with exponential backoff retry.

    The function is attempted once and then retried up to ``max_retries``
    times. The wait before retry number k (1-based) is
    ``base_delay * exponential_base ** (k - 1)``, capped at ``max_delay`` and,
    with jitter enabled, scaled by a random factor between 0.5 and 1.5.
    Exceptions not listed in ``retryable_exceptions`` propagate immediately.

    Args:
        func: The async function to execute
        *args: Positional arguments for the function
        max_retries: Maximum number of retry attempts (negative values are
            treated as 0, i.e. a single attempt)
        base_delay: Initial delay between retries in seconds
        max_delay: Maximum delay between retries
        exponential_base: Base for exponential backoff calculation
        retryable_exceptions: Tuple of exception types that should trigger retry
        jitter: Whether to add randomness to delay
        on_retry: Optional callback called on each retry with (attempt, exception);
            exceptions it raises are logged and ignored
        **kwargs: Keyword arguments for the function

    Returns:
        The return value of the function

    Raises:
        RetryError: When all retries are exhausted; the final error is
            available as ``last_exception`` and as ``__cause__``
    """
    # functools.partial objects and callable instances have no __name__;
    # looking it up unguarded would raise inside the except block below and
    # hide the real failure.
    func_name = getattr(func, "__name__", None) or repr(func)

    # Guarantee at least one attempt
    max_retries = max(max_retries, 0)

    last_exception = None
    for attempt in range(max_retries + 1):
        try:
            return await func(*args, **kwargs)
        except retryable_exceptions as e:
            last_exception = e

            if attempt == max_retries:
                logging.error(f"All {max_retries} retries exhausted for {func_name}: {e}")
                raise RetryError(
                    f"Failed after {max_retries} retries: {str(e)}",
                    last_exception=e
                ) from e

            # Calculate delay with exponential backoff
            delay = min(base_delay * (exponential_base ** attempt), max_delay)

            # Add jitter to prevent thundering herd
            if jitter:
                delay = delay * (0.5 + random.random())

            logging.warning(
                f"Retry {attempt + 1}/{max_retries} for {func_name} "
                f"after {delay:.2f}s delay. Error: {e}"
            )

            if on_retry:
                # A faulty callback must not abort the retry sequence
                try:
                    on_retry(attempt + 1, e)
                except Exception as callback_error:
                    logging.warning(f"on_retry callback failed: {callback_error}")

            await asyncio.sleep(delay)

    # Should not reach here, but just in case
    raise RetryError("Unexpected retry loop exit", last_exception=last_exception)
def retry_decorator(
    max_retries: int = DEFAULT_MAX_RETRIES,
    base_delay: float = DEFAULT_BASE_DELAY,
    max_delay: float = DEFAULT_MAX_DELAY,
    retryable_exceptions: Tuple[Type[Exception], ...] = (Exception,),
    jitter: bool = True,
    exponential_base: float = DEFAULT_EXPONENTIAL_BASE,
    on_retry: Optional[Callable[[int, Exception], None]] = None
):
    """
    Decorator for adding retry logic to async functions.

    Every call to the decorated function is routed through
    async_retry_with_backoff() with the settings given here.

    Args:
        max_retries: Maximum number of retry attempts
        base_delay: Initial delay between retries in seconds
        max_delay: Maximum delay between retries
        retryable_exceptions: Tuple of exception types that should trigger retry
        jitter: Whether to add randomness to delay
        exponential_base: Base for exponential backoff calculation
        on_retry: Optional callback called on each retry with (attempt, exception)

    Usage:
        @retry_decorator(max_retries=3, base_delay=1.0)
        async def my_api_call():
            ...
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(*args, **kwargs):
            return await async_retry_with_backoff(
                func,
                *args,
                max_retries=max_retries,
                base_delay=base_delay,
                max_delay=max_delay,
                exponential_base=exponential_base,
                retryable_exceptions=retryable_exceptions,
                jitter=jitter,
                on_retry=on_retry,
                **kwargs
            )
        return wrapper
    return decorator
# Common exception sets for different APIs
_TIMEOUT_AND_CONNECTION = (TimeoutError, ConnectionError)

# Add specific OpenAI exceptions as needed
OPENAI_RETRYABLE_EXCEPTIONS = _TIMEOUT_AND_CONNECTION

# Add specific Discord exceptions as needed
DISCORD_RETRYABLE_EXCEPTIONS = _TIMEOUT_AND_CONNECTION

# Generic HTTP calls additionally list connection resets explicitly
HTTP_RETRYABLE_EXCEPTIONS = _TIMEOUT_AND_CONNECTION + (ConnectionResetError,)
class RateLimiter:
    """
    Simple rate limiter for API calls.

    Entering the context waits until at least ``1 / calls_per_second``
    seconds have passed since the previous entry. Entries are serialized by
    an internal lock, which is held while waiting.

    Usage:
        limiter = RateLimiter(calls_per_second=1)
        async with limiter:
            await make_api_call()
    """

    def __init__(self, calls_per_second: float = 1.0):
        """
        Args:
            calls_per_second: Maximum entry rate; must be greater than zero

        Raises:
            ValueError: If ``calls_per_second`` is zero or negative
        """
        # Zero would divide by zero and a negative rate would yield a
        # negative interval, silently disabling the limiter.
        if calls_per_second <= 0:
            raise ValueError("calls_per_second must be greater than 0")

        self.min_interval = 1.0 / calls_per_second
        # time.monotonic() reading taken at the most recent entry
        self.last_call = 0.0
        self._lock = asyncio.Lock()

    async def __aenter__(self):
        import time

        async with self._lock:
            time_since_last = time.monotonic() - self.last_call
            if time_since_last < self.min_interval:
                await asyncio.sleep(self.min_interval - time_since_last)
            # Stamp after any wait so the next entry is spaced from now
            self.last_call = time.monotonic()
        return self

    async def __aexit__(self, *args):
        # Nothing to release; exceptions from the body propagate
        pass
class CircuitBreaker:
"""
Circuit breaker pattern for preventing cascade failures.
States:
- CLOSED: Normal operation, requests pass through
- OPEN: Too many failures, requests are rejected immediately
- HALF_OPEN: Testing if service recovered
"""
CLOSED = "closed"
OPEN = "open"
HALF_OPEN = "half_open"
def __init__(
self,
failure_threshold: int = 5,
recovery_timeout: float = 60.0,
half_open_requests: int = 3
):
self.failure_threshold = failure_threshold
self.recovery_timeout = recovery_timeout
self.half_open_requests = half_open_requests
self.state = self.CLOSED
self.failure_count = 0
self.last_failure_time = 0.0
self.half_open_successes = 0
self._lock = asyncio.Lock()
async def call(self, func: Callable, *args, **kwargs) -> Any:
"""
Execute a function through the circuit breaker.
Args:
func: The async function to execute
*args: Positional arguments
**kwargs: Keyword arguments
Returns:
The function result
Raises:
Exception: If circuit is open or function fails
"""
async with self._lock:
await self._check_state()
if self.state == self.OPEN:
raise Exception("Circuit breaker is OPEN - service unavailable")
try:
result = await func(*args, **kwargs)
await self._on_success()
return result
except Exception as e:
await self._on_failure()
raise
async def _check_state(self):
"""Check and potentially update circuit state."""
import time
if self.state == self.OPEN:
if time.monotonic() - self.last_failure_time >= self.recovery_timeout:
logging.info("Circuit breaker transitioning to HALF_OPEN")
self.state = self.HALF_OPEN
self.half_open_successes = 0
async def _on_success(self):
"""Handle successful call."""
async with self._lock:
if self.state == self.HALF_OPEN:
self.half_open_successes += 1
if self.half_open_successes >= self.half_open_requests:
logging.info("Circuit breaker transitioning to CLOSED")
self.state = self.CLOSED
self.failure_count = 0
elif self.state == self.CLOSED:
self.failure_count = 0
async def _on_failure(self):
"""Handle failed call."""
import time
async with self._lock:
self.failure_count += 1
self.last_failure_time = time.monotonic()
if self.state == self.HALF_OPEN:
logging.warning("Circuit breaker transitioning to OPEN (half-open failure)")
self.state = self.OPEN
elif self.failure_count >= self.failure_threshold:
logging.warning(f"Circuit breaker transitioning to OPEN ({self.failure_count} failures)")
self.state = self.OPEN

View File

@@ -304,8 +304,8 @@ class TokenCounter:
Returns: Returns:
Estimated cost in USD Estimated cost in USD
""" """
# Import here to avoid circular dependency # Import from centralized pricing module
from src.commands.commands import MODEL_PRICING from src.config.pricing import MODEL_PRICING
if model not in MODEL_PRICING: if model not in MODEL_PRICING:
model = "openai/gpt-4o" # Default fallback model = "openai/gpt-4o" # Default fallback

287
src/utils/validators.py Normal file
View File

@@ -0,0 +1,287 @@
"""
Input validation utilities for the Discord bot.
This module provides centralized validation for user inputs,
enhancing security and reducing code duplication.
"""
import re
import logging
from typing import Optional, Tuple, List
from dataclasses import dataclass
# Maximum allowed sizes for various inputs.
# The text limits are compared against len() of a str, i.e. they count
# characters; MAX_FILE_SIZE is a byte count.
MAX_MESSAGE_LENGTH = 4000 # Discord's limit is 2000, but we process longer
MAX_PROMPT_LENGTH = 32000 # Reasonable limit for AI prompts
MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
MAX_FILENAME_LENGTH = 255
MAX_URL_LENGTH = 2048
MAX_CODE_LENGTH = 100000 # 100,000 characters (roughly 100KB) of code
@dataclass
class ValidationResult:
    """
    Outcome of a single validation check.

    Attributes:
        is_valid: Whether the input passed validation
        error_message: Human-readable reason for a failure, if any
        sanitized_value: Cleaned-up form of the input, when the validator
            produces one
    """

    is_valid: bool
    error_message: Optional[str] = None
    sanitized_value: Optional[str] = None
def validate_message_content(content: str) -> ValidationResult:
    """
    Check a message's length and strip control characters from it.

    Args:
        content: The message content to validate

    Returns:
        ValidationResult carrying the sanitized content on success, or an
        error message when the content exceeds MAX_MESSAGE_LENGTH
    """
    # Empty (or missing) content is accepted and normalised to "".
    if not content:
        return ValidationResult(is_valid=True, sanitized_value="")

    too_long = len(content) > MAX_MESSAGE_LENGTH
    if too_long:
        message = f"Message too long. Maximum {MAX_MESSAGE_LENGTH} characters allowed."
        return ValidationResult(is_valid=False, error_message=message)

    # Drop NUL and other control characters, keeping tab, LF and CR.
    control_chars = r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]'
    cleaned = re.sub(control_chars, '', content)
    return ValidationResult(is_valid=True, sanitized_value=cleaned)
def validate_prompt(prompt: str) -> ValidationResult:
    """
    Validate AI prompt content.

    Null bytes are stripped before the emptiness check so that a prompt
    consisting solely of null bytes (optionally mixed with whitespace) is
    rejected rather than accepted with an empty sanitized value.

    Args:
        prompt: The prompt to validate

    Returns:
        ValidationResult with validation status; on success
        ``sanitized_value`` is the prompt with null bytes removed
    """
    # Remove null bytes up front; everything below judges the cleaned text.
    sanitized = prompt.replace('\x00', '') if prompt else ''

    if not sanitized.strip():
        return ValidationResult(
            is_valid=False,
            error_message="Prompt cannot be empty."
        )

    # The limit applies to the prompt as submitted, as before.
    if len(prompt) > MAX_PROMPT_LENGTH:
        return ValidationResult(
            is_valid=False,
            error_message=f"Prompt too long. Maximum {MAX_PROMPT_LENGTH} characters allowed."
        )

    return ValidationResult(is_valid=True, sanitized_value=sanitized)
def validate_url(url: str) -> ValidationResult:
    """
    Validate an http(s) URL.

    Accepts a DNS host name, ``localhost`` or a dotted-quad address, with an
    optional port and path/query. Top-level domains may be up to 63 letters
    (the DNS label limit) or a punycode ``xn--`` label, so TLDs such as
    ``.photography`` are accepted.

    Args:
        url: The URL to validate

    Returns:
        ValidationResult with validation status; on success
        ``sanitized_value`` is the URL unchanged
    """
    if not url:
        return ValidationResult(
            is_valid=False,
            error_message="URL cannot be empty."
        )

    if len(url) > MAX_URL_LENGTH:
        return ValidationResult(
            is_valid=False,
            error_message=f"URL too long. Maximum {MAX_URL_LENGTH} characters allowed."
        )

    # Basic URL pattern check
    url_pattern = re.compile(
        r'^https?://'  # http:// or https://
        r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+'  # domain labels
        r'(?:[A-Z]{2,63}|XN--[A-Z0-9]{1,59})\.?|'  # TLD (alphabetic or punycode)
        r'localhost|'  # localhost
        r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'  # or IP
        r'(?::\d+)?'  # optional port
        # \Z rather than $: "$" would also match just before a trailing
        # newline and let "https://example.com\n" through.
        r'(?:/?|[/?]\S+)\Z', re.IGNORECASE
    )

    if not url_pattern.match(url):
        return ValidationResult(
            is_valid=False,
            error_message="Invalid URL format."
        )

    # Reject URLs that embed a dangerous scheme anywhere (e.g. in a query
    # parameter); this is a substring check on the lower-cased URL.
    dangerous_schemes = ['javascript:', 'data:', 'file:', 'vbscript:']
    url_lower = url.lower()
    if any(scheme in url_lower for scheme in dangerous_schemes):
        return ValidationResult(
            is_valid=False,
            error_message="URL contains potentially dangerous content."
        )

    return ValidationResult(is_valid=True, sanitized_value=url)
def validate_filename(filename: str) -> ValidationResult:
    """
    Validate and sanitize a filename.

    Path separators and reserved/control characters are stripped first and
    ``..`` sequences afterwards. Removing ``..`` first would let inputs such
    as ``"./."`` or ``".<."`` collapse into ``".."`` once the character in
    between is dropped.

    Args:
        filename: The filename to validate

    Returns:
        ValidationResult with validation status and sanitized filename
    """
    if not filename:
        return ValidationResult(
            is_valid=False,
            error_message="Filename cannot be empty."
        )

    if len(filename) > MAX_FILENAME_LENGTH:
        return ValidationResult(
            is_valid=False,
            error_message=f"Filename too long. Maximum {MAX_FILENAME_LENGTH} characters allowed."
        )

    # Remove path separators and dangerous characters
    sanitized = re.sub(r'[/\\<>:"|?*\x00-\x1f]', '', filename)

    # Remove path traversal attempts. This runs last so that no later
    # removal can join two dots back together; a single pass reduces every
    # run of dots to at most one dot.
    sanitized = sanitized.replace('..', '')

    # Ensure it's not empty after sanitization
    if not sanitized:
        return ValidationResult(
            is_valid=False,
            error_message="Filename contains only invalid characters."
        )

    return ValidationResult(is_valid=True, sanitized_value=sanitized)
def validate_file_size(size: int) -> ValidationResult:
    """
    Check that a file size is positive and within MAX_FILE_SIZE.

    Args:
        size: The file size in bytes

    Returns:
        ValidationResult with validation status
    """
    if size <= 0:
        reason = "File size must be greater than 0."
        return ValidationResult(is_valid=False, error_message=reason)

    if size > MAX_FILE_SIZE:
        # Report the limit in whole megabytes.
        limit_mb = MAX_FILE_SIZE / (1024 * 1024)
        reason = f"File too large. Maximum {limit_mb:.0f}MB allowed."
        return ValidationResult(is_valid=False, error_message=reason)

    return ValidationResult(is_valid=True)
def validate_code(code: str) -> ValidationResult:
    """
    Check that submitted code is non-blank and within MAX_CODE_LENGTH.

    Args:
        code: The code to validate

    Returns:
        ValidationResult with validation status; on success
        ``sanitized_value`` is the code unchanged
    """
    is_blank = not code or not code.strip()
    if is_blank:
        return ValidationResult(is_valid=False, error_message="Code cannot be empty.")

    if len(code) > MAX_CODE_LENGTH:
        reason = f"Code too long. Maximum {MAX_CODE_LENGTH} characters allowed."
        return ValidationResult(is_valid=False, error_message=reason)

    return ValidationResult(is_valid=True, sanitized_value=code)
def validate_user_id(user_id) -> ValidationResult:
    """
    Validate a Discord user ID.

    Args:
        user_id: The user ID to validate (anything ``int()`` accepts)

    Returns:
        ValidationResult with validation status. Values that cannot be
        converted to an integer, including infinite floats, yield an
        invalid result rather than an exception.
    """
    try:
        uid = int(user_id)
        if uid <= 0:
            return ValidationResult(
                is_valid=False,
                error_message="Invalid user ID."
            )

        # Discord IDs are 17-19 digits
        digit_count = len(str(uid))
        if digit_count < 17 or digit_count > 19:
            return ValidationResult(
                is_valid=False,
                error_message="Invalid user ID format."
            )

        return ValidationResult(is_valid=True)
    # OverflowError: int(float('inf')) raises it instead of ValueError.
    except (ValueError, TypeError, OverflowError):
        return ValidationResult(
            is_valid=False,
            error_message="User ID must be a valid integer."
        )
def sanitize_for_logging(text: str, max_length: int = 200) -> str:
    """
    Sanitize text for safe logging (redact secrets, then truncate).

    Redaction deliberately errs on the side of over-matching: a false
    positive only costs some log detail, a false negative leaks a secret.

    Args:
        text: The text to sanitize
        max_length: Maximum length for logged text (a truncation marker is
            appended beyond this length)

    Returns:
        Sanitized text safe for logging; "" for empty input
    """
    if not text:
        return ""

    # Remove potential secrets/tokens (common patterns). Order matters:
    # patterns are applied top to bottom on the already-redacted text.
    patterns = [
        # Discord tokens go first so the broader key pattern below cannot
        # consume part of one and leave the remainder unredacted.
        (r'([A-Za-z0-9_-]{24}\.[A-Za-z0-9_-]{6}\.[A-Za-z0-9_-]{27})', '[DISCORD_TOKEN]'),
        # "-" and "_" are allowed so project-scoped keys (sk-proj-...) match.
        (r'(sk-[a-zA-Z0-9_-]{20,})', '[OPENAI_KEY]'),
        (r'(xoxb-[a-zA-Z0-9-]+)', '[SLACK_TOKEN]'),
        (r'(mongodb\+srv://[^@]+@)', 'mongodb+srv://[REDACTED]@'),
        # Credentials in plain mongodb:// URIs (userinfo has no "/" or spaces).
        (r'(mongodb://[^@\s/]+@)', 'mongodb://[REDACTED]@'),
        # Full token68 alphabet, so the dot-separated parts of a JWT are
        # redacted as a whole instead of only the first segment.
        (r'(Bearer\s+[A-Za-z0-9._~+/=-]+)', 'Bearer [TOKEN]'),
        (r'(password["\']?\s*[:=]\s*["\']?)[^"\'\s]+', r'\1[REDACTED]'),
    ]

    sanitized = text
    for pattern, replacement in patterns:
        sanitized = re.sub(pattern, replacement, sanitized, flags=re.IGNORECASE)

    # Truncate after redaction so a secret is never cut in half and missed
    if len(sanitized) > max_length:
        sanitized = sanitized[:max_length] + '...[truncated]'

    return sanitized

727
tests/test_comprehensive.py Normal file
View File

@@ -0,0 +1,727 @@
"""
Comprehensive test suite for the ChatGPT Discord Bot.
This module contains unit tests and integration tests for all major components.
Uses pytest with pytest-asyncio for async test support.
"""
import asyncio
import pytest
import os
import sys
import json
from unittest.mock import MagicMock, patch, AsyncMock
from datetime import datetime, timedelta
from typing import Dict, Any
# Add parent directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# ============================================================
# Test Fixtures
# ============================================================
@pytest.fixture
def mock_db_handler():
    """Build a MagicMock database handler whose async methods are AsyncMocks."""
    handler = MagicMock()

    # Async methods with a canned return value.
    canned_results = {
        "get_history": [],
        "get_user_model": "openai/gpt-4o",
        "is_admin": False,
        "is_user_whitelisted": True,
        "is_user_blacklisted": False,
        "get_user_tool_display": False,
        "get_user_files": [],
    }
    for method_name, result in canned_results.items():
        setattr(handler, method_name, AsyncMock(return_value=result))

    # Async methods called only for their side effect.
    for method_name in ("save_history", "save_user_model", "save_token_usage"):
        setattr(handler, method_name, AsyncMock())

    return handler
@pytest.fixture
def mock_openai_client():
    """Build a MagicMock OpenAI client returning one canned chat completion."""
    # Token accounting attached to the canned response.
    usage = MagicMock()
    usage.prompt_tokens = 100
    usage.completion_tokens = 50

    # The single choice carried by the response.
    choice = MagicMock()
    choice.message.content = "Test response"
    choice.finish_reason = "stop"

    response = MagicMock()
    response.choices = [choice]
    response.usage = usage

    client = MagicMock()
    client.chat.completions.create = AsyncMock(return_value=response)
    return client
@pytest.fixture
def mock_discord_message():
    """
    Create a mock Discord message.

    ``channel.typing()`` returns an object usable with ``async with``; each
    call yields the same mock context manager.
    """
    mock = MagicMock()
    mock.author.id = 123456789
    mock.author.name = "TestUser"
    mock.content = "Hello, bot!"
    mock.channel.send = AsyncMock()

    # typing() must hand back an async context manager. The previous value,
    # AsyncMock().__aenter__(), was a bare coroutine object: unusable in
    # "async with" and a source of "never awaited" warnings.
    typing_cm = MagicMock()
    typing_cm.__aenter__ = AsyncMock(return_value=None)
    typing_cm.__aexit__ = AsyncMock(return_value=False)  # do not swallow errors
    mock.channel.typing = MagicMock(return_value=typing_cm)

    mock.attachments = []
    mock.reference = None
    mock.guild = MagicMock()
    return mock
# ============================================================
# Pricing Module Tests
# ============================================================
class TestPricingModule:
    """Checks for the centralized pricing configuration."""

    def test_model_pricing_exists(self):
        """Every model the bot relies on must have a pricing entry."""
        from src.config.pricing import MODEL_PRICING

        required = (
            "openai/gpt-4o",
            "openai/gpt-4o-mini",
            "openai/gpt-4.1",
            "openai/gpt-5",
            "openai/o1",
        )
        for model in required:
            assert model in MODEL_PRICING, f"Missing pricing for {model}"

    def test_calculate_cost(self):
        """Cost calculation matches the published per-million-token rates."""
        from src.config.pricing import calculate_cost

        # GPT-4o: $5.00 input, $20.00 output per 1M tokens
        per_million = calculate_cost("openai/gpt-4o", 1_000_000, 1_000_000)
        assert per_million == 25.00  # $5 + $20

        # Smaller amounts scale linearly
        per_thousand = calculate_cost("openai/gpt-4o", 1000, 1000)
        assert per_thousand == pytest.approx(0.025, rel=1e-6)  # $0.005 + $0.020

    def test_calculate_cost_unknown_model(self):
        """Unknown models cost nothing."""
        from src.config.pricing import calculate_cost

        assert calculate_cost("unknown/model", 1000, 1000) == 0.0

    def test_format_cost(self):
        """Costs are rendered with precision appropriate to their size."""
        from src.config.pricing import format_cost

        expectations = [
            (0.000001, "$0.000001"),
            (0.005, "$0.005000"),  # 6 decimal places for small amounts
            (1.50, "$1.50"),
            (100.00, "$100.00"),
        ]
        for amount, rendered in expectations:
            assert format_cost(amount) == rendered
# ============================================================
# Validator Module Tests
# ============================================================
class TestValidators:
    """Checks for the input validation helpers."""

    def test_validate_message_content(self):
        """Ordinary, empty and NUL-containing messages are all accepted."""
        from src.utils.validators import validate_message_content

        # Valid content passes through untouched
        plain = validate_message_content("Hello, world!")
        assert plain.is_valid
        assert plain.sanitized_value == "Hello, world!"

        # Empty content is valid
        assert validate_message_content("").is_valid

        # Null bytes are stripped rather than rejected
        with_nul = validate_message_content("Hello\x00World")
        assert with_nul.is_valid
        assert "\x00" not in with_nul.sanitized_value

    def test_validate_message_too_long(self):
        """Messages over the limit are rejected with a 'too long' error."""
        from src.utils.validators import validate_message_content, MAX_MESSAGE_LENGTH

        oversized = "x" * (MAX_MESSAGE_LENGTH + 1)
        outcome = validate_message_content(oversized)
        assert not outcome.is_valid
        assert "too long" in outcome.error_message.lower()

    def test_validate_url(self):
        """Well-formed http(s) URLs pass; everything else is rejected."""
        from src.utils.validators import validate_url

        accepted = (
            "https://example.com",
            "http://localhost:8080/path",
            "https://api.example.com/v1/data?q=test",
        )
        for url in accepted:
            assert validate_url(url).is_valid

        rejected = (
            "",
            "not-a-url",
            "javascript:alert(1)",
            "file:///etc/passwd",
        )
        for url in rejected:
            assert not validate_url(url).is_valid

    def test_validate_filename(self):
        """Filenames are sanitized; only empty ones are rejected outright."""
        from src.utils.validators import validate_filename

        # Valid filename
        clean = validate_filename("test_file.txt")
        assert clean.is_valid
        assert clean.sanitized_value == "test_file.txt"

        # Path traversal attempt
        traversal = validate_filename("../../../etc/passwd")
        assert traversal.is_valid  # Sanitized, not rejected
        assert ".." not in traversal.sanitized_value
        assert "/" not in traversal.sanitized_value

        # Empty filename
        assert not validate_filename("").is_valid

    def test_sanitize_for_logging(self):
        """Secrets are redacted and long text is truncated."""
        from src.utils.validators import sanitize_for_logging

        # OpenAI key redaction
        redacted_key = sanitize_for_logging("API key is sk-abcdefghijklmnopqrstuvwxyz123456")
        assert "sk-" not in redacted_key
        assert "[OPENAI_KEY]" in redacted_key

        # MongoDB URI redaction
        redacted_uri = sanitize_for_logging("mongodb+srv://user:password@cluster.mongodb.net/db")
        assert "password" not in redacted_uri
        assert "[REDACTED]" in redacted_uri

        # Truncation
        shortened = sanitize_for_logging("x" * 500, max_length=100)
        assert len(shortened) < 150  # Account for truncation marker
# ============================================================
# Retry Module Tests
# ============================================================
class TestRetryModule:
    """Checks for the async retry helper."""

    @pytest.mark.asyncio
    async def test_retry_success_first_try(self):
        """A function that succeeds immediately is called exactly once."""
        from src.utils.retry import async_retry_with_backoff

        invocations = []

        async def success_func():
            invocations.append(1)
            return "success"

        outcome = await async_retry_with_backoff(success_func, max_retries=3)

        assert outcome == "success"
        assert len(invocations) == 1

    @pytest.mark.asyncio
    async def test_retry_eventual_success(self):
        """Transient failures are retried until the call succeeds."""
        from src.utils.retry import async_retry_with_backoff

        invocations = []

        async def eventual_success():
            invocations.append(1)
            # Fail on the first two attempts only.
            if len(invocations) < 3:
                raise ConnectionError("Temporary failure")
            return "success"

        outcome = await async_retry_with_backoff(
            eventual_success,
            max_retries=5,
            base_delay=0.01,  # Fast for testing
            retryable_exceptions=(ConnectionError,)
        )

        assert outcome == "success"
        assert len(invocations) == 3

    @pytest.mark.asyncio
    async def test_retry_exhausted(self):
        """RetryError surfaces once every attempt has failed."""
        from src.utils.retry import async_retry_with_backoff, RetryError

        async def always_fail():
            raise ConnectionError("Always fails")

        retry_options = dict(
            max_retries=2,
            base_delay=0.01,
            retryable_exceptions=(ConnectionError,),
        )
        with pytest.raises(RetryError):
            await async_retry_with_backoff(always_fail, **retry_options)
# ============================================================
# Discord Utils Tests
# ============================================================
class TestDiscordUtils:
    """Checks for the Discord message/embed helpers."""

    def test_split_message_short(self):
        """A short message comes back as a single unchanged chunk."""
        from src.utils.discord_utils import split_message

        text = "This is a short message."
        parts = split_message(text)

        assert len(parts) == 1
        assert parts[0] == text

    def test_split_message_long(self):
        """A long message is broken into chunks within the length limit."""
        from src.utils.discord_utils import split_message

        # Create a message longer than 2000 characters
        text = "Hello world. " * 200
        parts = split_message(text, max_length=2000)

        assert len(parts) > 1
        for part in parts:
            assert len(part) <= 2000

    def test_split_code_block(self):
        """Each code chunk is individually fenced and within the limit."""
        from src.utils.discord_utils import split_code_block

        source = "\n".join(f"line {i}" for i in range(100))
        parts = split_code_block(source, "python", max_length=500)

        assert len(parts) > 1
        for part in parts:
            assert part.startswith("```python\n")
            assert part.endswith("\n```")
            assert len(part) <= 500

    def test_create_error_embed(self):
        """Error embeds carry the title and are red."""
        from src.utils.discord_utils import create_error_embed
        import discord

        result = create_error_embed("Test Error", "Something went wrong", "ValidationError")

        assert isinstance(result, discord.Embed)
        assert "Test Error" in result.title
        assert result.color == discord.Color.red()

    def test_create_success_embed(self):
        """Success embeds carry the title and are green."""
        from src.utils.discord_utils import create_success_embed
        import discord

        result = create_success_embed("Success!", "Operation completed")

        assert isinstance(result, discord.Embed)
        assert "Success!" in result.title
        assert result.color == discord.Color.green()
# ============================================================
# Code Interpreter Security Tests
# ============================================================
class TestCodeInterpreterSecurity:
    """Checks for the code interpreter's security filters."""

    @staticmethod
    def _matches_blocked_pattern(snippet, patterns):
        """Return True when any pattern matches the snippet (case-insensitive)."""
        import re

        for pattern in patterns:
            if re.search(pattern, snippet, re.IGNORECASE):
                return True
        return False

    def test_blocked_imports(self):
        """Dangerous imports and dynamic execution are blocked."""
        from src.utils.code_interpreter import BLOCKED_PATTERNS

        dangerous_snippets = (
            "import os",
            "import subprocess",
            "from os import system",
            "import socket",
            "import requests",
            "__import__('os')",
            "eval('print(1)')",
            "exec('import os')",
        )
        for code in dangerous_snippets:
            blocked = self._matches_blocked_pattern(code, BLOCKED_PATTERNS)
            assert blocked, f"Should block: {code}"

    def test_allowed_imports(self):
        """Common data-science imports are not blocked."""
        from src.utils.code_interpreter import BLOCKED_PATTERNS

        safe_snippets = (
            "import pandas as pd",
            "import numpy as np",
            "import matplotlib.pyplot as plt",
            "from sklearn.model_selection import train_test_split",
            "import os.path",  # os.path is allowed
        )
        for code in safe_snippets:
            blocked = self._matches_blocked_pattern(code, BLOCKED_PATTERNS)
            assert not blocked, f"Should allow: {code}"

    def test_file_type_detection(self):
        """File extensions map to the expected type labels."""
        from src.utils.code_interpreter import FileManager

        manager = FileManager()
        expected_types = [
            ("data.csv", "csv"),
            ("data.xlsx", "excel"),
            ("config.json", "json"),
            ("image.png", "image"),
            ("script.py", "python"),
            ("unknown.xyz", "binary"),
        ]
        for filename, label in expected_types:
            assert manager._detect_file_type(filename) == label
# ============================================================
# OpenAI Utils Tests
# ============================================================
class TestOpenAIUtils:
    """Checks for the OpenAI helper functions."""

    def test_count_tokens(self):
        """Token counting yields a positive integer for non-empty text."""
        from src.utils.openai_utils import count_tokens

        total = count_tokens("Hello, world!")

        assert total > 0
        assert isinstance(total, int)

    def test_trim_content_to_token_limit(self):
        """Content is trimmed only when it exceeds the token budget."""
        from src.utils.openai_utils import trim_content_to_token_limit

        # Short content should not be trimmed
        brief = "Hello, world!"
        assert trim_content_to_token_limit(brief, max_tokens=100) == brief

        # Long content should be trimmed
        lengthy = "Hello " * 10000
        shortened = trim_content_to_token_limit(lengthy, max_tokens=100)
        assert len(shortened) < len(lengthy)

    def test_prepare_messages_for_api(self):
        """Well-formed messages are all kept, with recognised roles."""
        from src.utils.openai_utils import prepare_messages_for_api

        conversation = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
            {"role": "user", "content": "How are you?"},
        ]

        result = prepare_messages_for_api(conversation)

        assert len(result) == 3
        known_roles = ["user", "assistant", "system"]
        assert all(entry.get("role") in known_roles for entry in result)

    def test_prepare_messages_filters_none_content(self):
        """Messages whose content is None are dropped."""
        from src.utils.openai_utils import prepare_messages_for_api

        conversation = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": None},
            {"role": "user", "content": "World"},
        ]

        result = prepare_messages_for_api(conversation)

        assert len(result) == 2
# ============================================================
# Database Handler Tests (with mocking)
# ============================================================
class TestDatabaseHandlerMocked:
    """Checks for DatabaseHandler logic with the MongoDB client patched out."""

    @staticmethod
    def _image_message(timestamp):
        """Build a one-message history holding a text part and a stamped image part."""
        return [
            {"role": "user", "content": [
                {"type": "text", "text": "Check this image"},
                {"type": "image_url", "image_url": {"url": "https://example.com/img.jpg"}, "timestamp": timestamp}
            ]}
        ]

    def test_filter_expired_images_no_images(self):
        """Text-only history passes through unchanged."""
        from src.database.db_handler import DatabaseHandler

        with patch('motor.motor_asyncio.AsyncIOMotorClient'):
            db = DatabaseHandler("mongodb://localhost")
            conversation = [
                {"role": "user", "content": "Hello"},
                {"role": "assistant", "content": "Hi there!"},
            ]

            kept = db._filter_expired_images(conversation)

            assert len(kept) == 2
            assert kept[0]["content"] == "Hello"

    def test_filter_expired_images_recent_image(self):
        """An image stamped just now is retained."""
        from src.database.db_handler import DatabaseHandler

        with patch('motor.motor_asyncio.AsyncIOMotorClient'):
            db = DatabaseHandler("mongodb://localhost")
            stamp = datetime.now().isoformat()

            kept = db._filter_expired_images(self._image_message(stamp))

            assert len(kept) == 1
            assert len(kept[0]["content"]) == 2  # Both items kept

    def test_filter_expired_images_old_image(self):
        """An image stamped 24 hours ago is dropped; the text remains."""
        from src.database.db_handler import DatabaseHandler

        with patch('motor.motor_asyncio.AsyncIOMotorClient'):
            db = DatabaseHandler("mongodb://localhost")
            stamp = (datetime.now() - timedelta(hours=24)).isoformat()

            kept = db._filter_expired_images(self._image_message(stamp))

            assert len(kept) == 1
            assert len(kept[0]["content"]) == 1  # Only text kept
# ============================================================
# ============================================================
# Cache Module Tests
# ============================================================
class TestLRUCache:
    """Checks for the async LRU cache."""

    @pytest.mark.asyncio
    async def test_cache_set_and_get(self):
        """A stored value can be read back."""
        from src.utils.cache import LRUCache

        store = LRUCache(max_size=100, default_ttl=60.0)
        await store.set("key1", "value1")

        assert await store.get("key1") == "value1"

    @pytest.mark.asyncio
    async def test_cache_expiration(self):
        """Entries disappear once their TTL has passed."""
        from src.utils.cache import LRUCache

        store = LRUCache(max_size=100, default_ttl=0.1)  # 100ms TTL
        await store.set("key1", "value1")

        # Present immediately after insertion
        assert await store.get("key1") == "value1"

        # Outlive the TTL
        await asyncio.sleep(0.15)

        # Gone after expiry
        assert await store.get("key1") is None

    @pytest.mark.asyncio
    async def test_cache_lru_eviction(self):
        """The least recently used entry is evicted when the cache is full."""
        from src.utils.cache import LRUCache

        store = LRUCache(max_size=3, default_ttl=60.0)
        for index in (1, 2, 3):
            await store.set(f"key{index}", f"value{index}")

        # Touch key1 so that key2 becomes the least recently used
        await store.get("key1")

        # Inserting a fourth key forces an eviction
        await store.set("key4", "value4")

        assert await store.get("key1") == "value1"  # Should exist
        assert await store.get("key2") is None  # Should be evicted
        assert await store.get("key3") == "value3"  # Should exist
        assert await store.get("key4") == "value4"  # Should exist

    @pytest.mark.asyncio
    async def test_cache_stats(self):
        """Hits, misses and size are tracked."""
        from src.utils.cache import LRUCache

        store = LRUCache(max_size=100, default_ttl=60.0)
        await store.set("key1", "value1")

        await store.get("key1")  # Hit
        await store.get("key2")  # Miss
        await store.get("key1")  # Hit

        counters = store.stats()
        assert counters["hits"] == 2
        assert counters["misses"] == 1
        assert counters["size"] == 1

    @pytest.mark.asyncio
    async def test_cache_clear(self):
        """Clearing reports the number of removed entries and empties the cache."""
        from src.utils.cache import LRUCache

        store = LRUCache(max_size=100, default_ttl=60.0)
        await store.set("key1", "value1")
        await store.set("key2", "value2")

        removed = await store.clear()

        assert removed == 2
        assert await store.get("key1") is None
        assert await store.get("key2") is None
# ============================================================
# Monitoring Module Tests
# ============================================================
class TestMonitoring:
    """Checks for the monitoring helpers."""

    def test_performance_metrics(self):
        """A finished metric records success and a plausible duration."""
        from src.utils.monitoring import PerformanceMetrics
        import time

        timing = PerformanceMetrics(name="test_operation")
        time.sleep(0.01)  # Small delay
        timing.finish(success=True)

        assert timing.success
        assert timing.duration_ms > 0
        assert timing.duration_ms < 1000  # Should be fast

    def test_measure_sync_context_manager(self):
        """The sync context manager times the block and keeps metadata."""
        from src.utils.monitoring import measure_sync
        import time

        with measure_sync("test_op", custom_field="value") as timing:
            time.sleep(0.01)

        assert timing.duration_ms > 0
        assert timing.metadata["custom_field"] == "value"

    @pytest.mark.asyncio
    async def test_measure_async_context_manager(self):
        """The async context manager times the block and marks success."""
        from src.utils.monitoring import measure_async

        async with measure_async("async_op") as timing:
            await asyncio.sleep(0.01)

        assert timing.duration_ms > 0
        assert timing.success

    @pytest.mark.asyncio
    async def test_track_performance_decorator(self):
        """The decorator is transparent to the wrapped coroutine."""
        from src.utils.monitoring import track_performance

        invocations = []

        @track_performance("tracked_function")
        async def tracked_func():
            invocations.append(1)
            return "result"

        outcome = await tracked_func()

        assert outcome == "result"
        assert len(invocations) == 1

    def test_health_status(self):
        """One failing check makes the overall status unhealthy."""
        from src.utils.monitoring import HealthStatus

        report = HealthStatus(healthy=True)
        report.add_check("database", True, "Connected")
        report.add_check("api", False, "Timeout")

        assert not report.healthy  # Should be unhealthy due to API check
        assert report.checks["database"]["healthy"]
        assert not report.checks["api"]["healthy"]
# ============================================================
# Integration Tests (require environment setup)
# ============================================================
@pytest.mark.integration
class TestIntegration:
    """Integration tests that require actual services."""

    @pytest.mark.asyncio
    async def test_database_connection(self):
        """Test actual database connection (skip if no MongoDB)."""
        from dotenv import load_dotenv
        load_dotenv()

        mongodb_uri = os.getenv("MONGODB_URI")
        if not mongodb_uri:
            pytest.skip("MONGODB_URI not set")

        from src.database.db_handler import DatabaseHandler

        handler = DatabaseHandler(mongodb_uri)
        try:
            connected = await handler.ensure_connected()
            assert connected
        finally:
            # Always release the connection, even when the connect call or
            # the assertion above fails.
            await handler.close()
# ============================================================
# Run tests
# ============================================================
if __name__ == "__main__":
    # pytest.main() returns the exit code; pass it on so that running this
    # file directly exits non-zero when tests fail.
    sys.exit(pytest.main([__file__, "-v", "--tb=short"]))