From f17081b185657a7ac650fa5107dc125ad2c6bf28 Mon Sep 17 00:00:00 2001 From: cauvang32 Date: Sun, 30 Nov 2025 17:45:36 +0700 Subject: [PATCH] Add retry utilities, input validation, and comprehensive tests - Implemented async retry logic with exponential backoff in `src/utils/retry.py`. - Created input validation utilities for Discord bot in `src/utils/validators.py`. - Refactored token pricing import in `src/utils/token_counter.py`. - Added comprehensive test suite in `tests/test_comprehensive.py` covering various modules including pricing, validators, retry logic, and Discord utilities. --- .env.example | 23 ++ requirements.txt | 66 ++- src/commands/commands.py | 138 +++++-- src/config/pricing.py | 100 +++++ src/database/db_handler.py | 29 +- src/module/message_handler.py | 50 +-- src/utils/cache.py | 358 +++++++++++++++++ src/utils/code_interpreter.py | 160 +++++++- src/utils/discord_utils.py | 417 +++++++++++++++++++ src/utils/monitoring.py | 446 +++++++++++++++++++++ src/utils/retry.py | 280 +++++++++++++ src/utils/token_counter.py | 4 +- src/utils/validators.py | 287 ++++++++++++++ tests/test_comprehensive.py | 727 ++++++++++++++++++++++++++++++++++ 14 files changed, 2986 insertions(+), 99 deletions(-) create mode 100644 src/config/pricing.py create mode 100644 src/utils/cache.py create mode 100644 src/utils/discord_utils.py create mode 100644 src/utils/monitoring.py create mode 100644 src/utils/retry.py create mode 100644 src/utils/validators.py create mode 100644 tests/test_comprehensive.py diff --git a/.env.example b/.env.example index 7dcaf20..ef6baa2 100644 --- a/.env.example +++ b/.env.example @@ -88,3 +88,26 @@ TIMEZONE=UTC # 168 = 1 week # -1 = Never expire (permanent storage) FILE_EXPIRATION_HOURS=48 + +# ============================================ +# Monitoring & Observability (Optional) +# ============================================ + +# Sentry DSN for error tracking +# Get from: https://sentry.io/ (create a project and copy the DSN) +# Leave empty to disable Sentry error tracking +SENTRY_DSN= + +# Environment name for Sentry (development, staging, production) +ENVIRONMENT=development + +# Sentry sample rate (0.0 to 1.0) - percentage of errors to capture +# 1.0 = 100% of errors, 0.5 = 50% of errors +SENTRY_SAMPLE_RATE=1.0 + +# Sentry traces sample rate for performance monitoring (0.0 to 1.0) +# 0.1 = 10% of transactions, lower values recommended for high-traffic bots +SENTRY_TRACES_RATE=0.1 + +# Log level (DEBUG, INFO, WARNING, ERROR) +LOG_LEVEL=INFO diff --git a/requirements.txt b/requirements.txt index f687597..832b380 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,19 +1,49 @@ -discord.py -openai -motor -pymongo[srv] -dnspython>=2.0.0 -pypdf -beautifulsoup4 -requests -aiohttp +# Discord Bot Core +discord.py>=2.3.0 +openai>=1.40.0 +python-dotenv>=1.0.0 + +# Database +motor>=3.3.0 +pymongo[srv]>=4.6.0 +dnspython>=2.5.0 + +# Web & HTTP +aiohttp>=3.9.0 +requests>=2.31.0 +beautifulsoup4>=4.12.0 + +# AI & ML runware>=0.4.33 -python-dotenv -matplotlib -pandas -openpyxl -seaborn -tzlocal -numpy -plotly -tiktoken \ No newline at end of file +tiktoken>=0.7.0 + +# Data Processing +pandas>=2.1.0 +numpy>=1.26.0 +openpyxl>=3.1.0 + +# Visualization +matplotlib>=3.8.0 +seaborn>=0.13.0 +plotly>=5.18.0 + +# Document Processing +pypdf>=4.0.0 +Pillow>=10.0.0 + +# Scheduling & Time +APScheduler>=3.10.0 +tzlocal>=5.2 + +# Testing +pytest>=8.0.0 +pytest-asyncio>=0.23.0 +pytest-cov>=4.1.0 +pytest-mock>=3.12.0 + +# Code Quality +ruff>=0.3.0 + +# Monitoring & Logging (Optional) +# sentry-sdk>=1.40.0 # Uncomment for error monitoring +# python-json-logger>=2.0.0 # Uncomment for structured logging \ No newline at end of file diff --git a/src/commands/commands.py b/src/commands/commands.py index 77da99d..48eb410 100644 --- a/src/commands/commands.py +++ b/src/commands/commands.py @@ -7,36 +7,67 @@ import asyncio from typing import Optional, Dict, List, Any, Callable from src.config.config import MODEL_OPTIONS, PDF_ALLOWED_MODELS, DEFAULT_MODEL +from src.config.pricing import MODEL_PRICING, calculate_cost, format_cost from src.utils.image_utils import ImageGenerator from src.utils.web_utils import google_custom_search, scrape_web_content from src.utils.pdf_utils import process_pdf, send_response from src.utils.openai_utils import prepare_file_from_path from src.utils.token_counter import token_counter from src.utils.code_interpreter import delete_all_user_files - -# Model pricing per 1M tokens (in USD) -MODEL_PRICING = { - "openai/gpt-4o": {"input": 5.00, "output": 20.00}, - "openai/gpt-4o-mini": {"input": 0.60, "output": 2.40}, - "openai/gpt-4.1": {"input": 2.00, "output": 8.00}, - "openai/gpt-4.1-mini": {"input": 0.40, "output": 1.60}, - "openai/gpt-4.1-nano": {"input": 0.10, "output": 0.40}, - "openai/gpt-5": {"input": 1.25, "output": 10.00}, - "openai/gpt-5-mini": {"input": 0.25, "output": 2.00}, - "openai/gpt-5-nano": {"input": 0.05, "output": 0.40}, - "openai/gpt-5-chat": {"input": 1.25, "output": 10.00}, - "openai/o1-preview": {"input": 15.00, "output": 60.00}, - "openai/o1-mini": {"input": 1.10, "output": 4.40}, - "openai/o1": {"input": 15.00, "output": 60.00}, - "openai/o3-mini": {"input": 1.10, "output": 4.40}, - "openai/o3": {"input": 2.00, "output": 8.00}, - "openai/o4-mini": {"input": 2.00, "output": 8.00} -} +from src.utils.discord_utils import create_info_embed, create_error_embed, create_success_embed # Dictionary to keep track of user requests and their cooldowns -user_requests = {} +user_requests: Dict[int, Dict[str, Any]] = {} # Dictionary to store user tasks -user_tasks = {} +user_tasks: Dict[int, List] = {} + + +# ============================================================ +# Autocomplete Functions +# ============================================================ + +async def model_autocomplete( + interaction: discord.Interaction, + current: str, +) -> List[app_commands.Choice[str]]: + """ + Autocomplete function for model selection. + Provides filtered model suggestions based on user input. + """ + # Filter models based on current input + matches = [ + model for model in MODEL_OPTIONS + if current.lower() in model.lower() + ] + + # If no matches, show all models + if not matches: + matches = MODEL_OPTIONS + + # Return up to 25 choices (Discord limit) + return [ + app_commands.Choice(name=model, value=model) + for model in matches[:25] + ] + + +async def image_model_autocomplete( + interaction: discord.Interaction, + current: str, +) -> List[app_commands.Choice[str]]: + """ + Autocomplete function for image generation model selection. + """ + image_models = ["flux", "flux-dev", "sdxl", "realistic", "anime", "dreamshaper"] + matches = [m for m in image_models if current.lower() in m.lower()] + + if not matches: + matches = image_models + + return [ + app_commands.Choice(name=model, value=model) + for model in matches[:25] + ] def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator: ImageGenerator): """ @@ -112,7 +143,7 @@ def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator @tree.command(name="choose_model", description="Select the AI model to use for responses.") @check_blacklist() async def choose_model(interaction: discord.Interaction): - """Lets users choose an AI model and saves it to the database.""" + """Lets users choose an AI model using a dropdown menu.""" options = [discord.SelectOption(label=model, value=model) for model in MODEL_OPTIONS] select_menu = discord.ui.Select(placeholder="Choose a model", options=options) @@ -131,6 +162,43 @@ def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator view.add_item(select_menu) await interaction.response.send_message("Choose a model:", view=view, ephemeral=True) + @tree.command(name="set_model", description="Set AI model directly with autocomplete suggestions.") + @app_commands.describe(model="The AI model to use (type to search)") + @app_commands.autocomplete(model=model_autocomplete) + @check_blacklist() + async def set_model(interaction: discord.Interaction, model: str): + """Sets the AI model directly using autocomplete.""" + user_id = interaction.user.id + + # Validate the model is in the allowed list + if model not in MODEL_OPTIONS: + # Find close matches for suggestions + close_matches = [m for m in MODEL_OPTIONS if model.lower() in m.lower()] + if close_matches: + suggestions = ", ".join(f"`{m}`" for m in close_matches[:5]) + await interaction.response.send_message( + f"❌ Invalid model `{model}`. Did you mean: {suggestions}?", + ephemeral=True + ) + else: + await interaction.response.send_message( + f"❌ Invalid model `{model}`. Use `/choose_model` to see available options.", + ephemeral=True + ) + return + + # Save the model selection + await db_handler.save_user_model(user_id, model) + + # Get pricing info for the selected model + pricing = MODEL_PRICING.get(model, {"input": 0, "output": 0}) + + await interaction.response.send_message( + f"✅ Model set to `{model}`\n" + f"💰 Pricing: ${pricing['input']:.2f}/1M input, ${pricing['output']:.2f}/1M output", + ephemeral=True + ) + @tree.command(name="search", description="Search on Google and send results to AI model.") @app_commands.describe(query="The search query") @check_blacklist() @@ -494,16 +562,22 @@ def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator async def help_command(interaction: discord.Interaction): """Sends a list of available commands to the user.""" help_message = ( - "**Available commands:**\n" - "/choose_model - Select which AI model to use for responses (openai/gpt-4o, openai/gpt-4o-mini, openai/gpt-5, openai/gpt-5-nano, openai/gpt-5-mini, openai/gpt-5-chat, openai/o1-preview, openai/o1-mini).\n" - "/search `` - Search Google and send results to the AI model.\n" - "/web `` - Scrape a webpage and send the data to the AI model.\n" - "/generate `` - Generate an image from a text prompt.\n" - "/toggle_tools - Toggle display of tool execution details (code, input, output).\n" - "/reset - Reset your chat history and token usage statistics.\n" - "/user_stat - Get information about your token usage, costs, and current model.\n" - "/prices - Display pricing information for all available AI models.\n" - "/help - Display this help message.\n" + "**🤖 Available Commands:**\n\n" + "**Model Selection:**\n" + "• `/choose_model` - Select AI model from a dropdown menu\n" + "• `/set_model ` - Set model directly with autocomplete\n\n" + "**Search & Web:**\n" + "• `/search ` - Search Google and analyze results with AI\n" + "• `/web ` - Scrape and analyze a webpage\n\n" + "**Image Generation:**\n" + "• `/generate ` - Generate images from text\n\n" + "**Settings & Stats:**\n" + "• `/toggle_tools` - Toggle tool execution details display\n" + "• `/user_stat` - View your token usage and costs\n" + "• `/prices` - Display model pricing information\n" + "• `/reset` - Clear your chat history and statistics\n\n" + "**Help:**\n" + "• `/help` - Display this help message\n" ) await interaction.response.send_message(help_message, ephemeral=True) diff --git a/src/config/pricing.py b/src/config/pricing.py new file mode 100644 index 0000000..879b3c9 --- /dev/null +++ b/src/config/pricing.py @@ -0,0 +1,100 @@ +""" +Centralized pricing configuration for OpenAI models. + +This module provides a single source of truth for model pricing, +eliminating duplication across the codebase. +""" + +from typing import Dict, Optional +from dataclasses import dataclass + + +@dataclass +class ModelPricing: + """Pricing information for a model (per 1M tokens in USD).""" + input: float + output: float + + def calculate_cost(self, input_tokens: int, output_tokens: int) -> float: + """Calculate total cost for given token counts.""" + input_cost = (input_tokens / 1_000_000) * self.input + output_cost = (output_tokens / 1_000_000) * self.output + return input_cost + output_cost + + +# Model pricing per 1M tokens (in USD) +# Centralized location - update prices here only +MODEL_PRICING: Dict[str, ModelPricing] = { + # GPT-4o Family + "openai/gpt-4o": ModelPricing(input=5.00, output=20.00), + "openai/gpt-4o-mini": ModelPricing(input=0.60, output=2.40), + + # GPT-4.1 Family + "openai/gpt-4.1": ModelPricing(input=2.00, output=8.00), + "openai/gpt-4.1-mini": ModelPricing(input=0.40, output=1.60), + "openai/gpt-4.1-nano": ModelPricing(input=0.10, output=0.40), + + # GPT-5 Family + "openai/gpt-5": ModelPricing(input=1.25, output=10.00), + "openai/gpt-5-mini": ModelPricing(input=0.25, output=2.00), + "openai/gpt-5-nano": ModelPricing(input=0.05, output=0.40), + "openai/gpt-5-chat": ModelPricing(input=1.25, output=10.00), + + # o1 Family (Reasoning models) + "openai/o1-preview": ModelPricing(input=15.00, output=60.00), + "openai/o1-mini": ModelPricing(input=1.10, output=4.40), + "openai/o1": ModelPricing(input=15.00, output=60.00), + + # o3 Family + "openai/o3-mini": ModelPricing(input=1.10, output=4.40), + "openai/o3": ModelPricing(input=2.00, output=8.00), + + # o4 Family + "openai/o4-mini": ModelPricing(input=2.00, output=8.00), +} + + +def get_model_pricing(model: str) -> Optional[ModelPricing]: + """ + Get pricing for a specific model. + + Args: + model: The model name (e.g., "openai/gpt-4o") + + Returns: + ModelPricing object or None if model not found + """ + return MODEL_PRICING.get(model) + + +def calculate_cost(model: str, input_tokens: int, output_tokens: int) -> float: + """ + Calculate the cost for a given model and token counts. + + Args: + model: The model name + input_tokens: Number of input tokens + output_tokens: Number of output tokens + + Returns: + Total cost in USD, or 0.0 if model not found + """ + pricing = get_model_pricing(model) + if pricing: + return pricing.calculate_cost(input_tokens, output_tokens) + return 0.0 + + +def get_all_models() -> list: + """Get list of all available models with pricing.""" + return list(MODEL_PRICING.keys()) + + +def format_cost(cost: float) -> str: + """Format cost for display.""" + if cost < 0.01: + return f"${cost:.6f}" + elif cost < 1.00: + return f"${cost:.4f}" + else: + return f"${cost:.2f}" diff --git a/src/database/db_handler.py b/src/database/db_handler.py index c829b43..ee44004 100644 --- a/src/database/db_handler.py +++ b/src/database/db_handler.py @@ -160,7 +160,12 @@ class DatabaseHandler: return [] def _filter_expired_images(self, history: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Filter out image links that are older than 23 hours""" + """ + Filter out image links that are older than 23 hours. + + Properly handles timezone-aware and timezone-naive datetime comparisons + to prevent issues with ISO string parsing. + """ current_time = datetime.now() expiration_time = current_time - timedelta(hours=23) @@ -183,11 +188,27 @@ class DatabaseHandler: # Check image items for timestamp elif item.get('type') == 'image_url': # If there's no timestamp or timestamp is newer than expiration time, keep it - timestamp = item.get('timestamp') - if not timestamp or datetime.fromisoformat(timestamp) > expiration_time: + timestamp_str = item.get('timestamp') + if not timestamp_str: + # No timestamp, keep the image filtered_content.append(item) else: - logging.info(f"Filtering out expired image URL (added at {timestamp})") + try: + # Parse the ISO timestamp, handling both timezone-aware and naive + timestamp = datetime.fromisoformat(timestamp_str.replace('Z', '+00:00')) + + # Make comparison timezone-naive for consistency + if timestamp.tzinfo is not None: + timestamp = timestamp.replace(tzinfo=None) + + if timestamp > expiration_time: + filtered_content.append(item) + else: + logging.debug(f"Filtering out expired image URL (added at {timestamp_str})") + except (ValueError, AttributeError) as e: + # If we can't parse the timestamp, keep the image to be safe + logging.warning(f"Could not parse image timestamp '{timestamp_str}': {e}") + filtered_content.append(item) # Update the message with filtered content if filtered_content: diff --git a/src/module/message_handler.py b/src/module/message_handler.py index 2d07fee..5b161d8 100644 --- a/src/module/message_handler.py +++ b/src/module/message_handler.py @@ -5,7 +5,7 @@ import logging import time import functools import concurrent.futures -from typing import Dict, Any, List +from typing import Dict, Any, List, Optional import io import aiohttp import os @@ -19,32 +19,16 @@ from src.utils.pdf_utils import process_pdf, send_response from src.utils.code_utils import extract_code_blocks from src.utils.reminder_utils import ReminderManager from src.config.config import PDF_ALLOWED_MODELS, MODEL_TOKEN_LIMITS, DEFAULT_TOKEN_LIMIT, DEFAULT_MODEL +from src.config.pricing import MODEL_PRICING, calculate_cost, format_cost +from src.utils.validators import validate_message_content, validate_prompt, sanitize_for_logging +from src.utils.discord_utils import send_long_message, create_error_embed, create_progress_embed # Global task and rate limiting tracking -user_tasks = {} -user_last_request = {} +user_tasks: Dict[int, Dict] = {} +user_last_request: Dict[int, List[float]] = {} RATE_LIMIT_WINDOW = 5 # seconds MAX_REQUESTS = 3 # max requests per window -# Model pricing per 1M tokens (in USD) -MODEL_PRICING = { - "openai/gpt-4o": {"input": 5.00, "output": 20.00}, - "openai/gpt-4o-mini": {"input": 0.60, "output": 2.40}, - "openai/gpt-4.1": {"input": 2.00, "output": 8.00}, - "openai/gpt-4.1-mini": {"input": 0.40, "output": 1.60}, - "openai/gpt-4.1-nano": {"input": 0.10, "output": 0.40}, - "openai/gpt-5": {"input": 1.25, "output": 10.00}, - "openai/gpt-5-mini": {"input": 0.25, "output": 2.00}, - "openai/gpt-5-nano": {"input": 0.05, "output": 0.40}, - "openai/gpt-5-chat": {"input": 1.25, "output": 10.00}, - "openai/o1-preview": {"input": 15.00, "output": 60.00}, - "openai/o1-mini": {"input": 1.10, "output": 4.40}, - "openai/o1": {"input": 15.00, "output": 60.00}, - "openai/o3-mini": {"input": 1.10, "output": 4.40}, - "openai/o3": {"input": 2.00, "output": 8.00}, - "openai/o4-mini": {"input": 2.00, "output": 8.00} -} - # File extensions that should be treated as text files TEXT_FILE_EXTENSIONS = [ '.txt', '.md', '.csv', '.json', '.xml', '.html', '.htm', '.css', @@ -1598,13 +1582,11 @@ print("\\n=== Correlation Analysis ===") output_tokens = getattr(response.usage, 'completion_tokens', 0) # Calculate cost based on model pricing - if model in MODEL_PRICING: - pricing = MODEL_PRICING[model] - input_cost = (input_tokens / 1_000_000) * pricing["input"] - output_cost = (output_tokens / 1_000_000) * pricing["output"] - total_cost = input_cost + output_cost + pricing = MODEL_PRICING.get(model) + if pricing: + total_cost = pricing.calculate_cost(input_tokens, output_tokens) - logging.info(f"API call - Model: {model}, Input tokens: {input_tokens}, Output tokens: {output_tokens}, Cost: ${total_cost:.6f}") + logging.info(f"API call - Model: {model}, Input tokens: {input_tokens}, Output tokens: {output_tokens}, Cost: {format_cost(total_cost)}") # Save token usage and cost to database await self.db.save_token_usage(user_id, model, input_tokens, output_tokens, total_cost) @@ -1704,14 +1686,12 @@ print("\\n=== Correlation Analysis ===") output_tokens += follow_up_output_tokens # Calculate additional cost - if model in MODEL_PRICING: - pricing = MODEL_PRICING[model] - additional_input_cost = (follow_up_input_tokens / 1_000_000) * pricing["input"] - additional_output_cost = (follow_up_output_tokens / 1_000_000) * pricing["output"] - additional_cost = additional_input_cost + additional_output_cost + pricing = MODEL_PRICING.get(model) + if pricing: + additional_cost = pricing.calculate_cost(follow_up_input_tokens, follow_up_output_tokens) total_cost += additional_cost - logging.info(f"Follow-up API call - Model: {model}, Input tokens: {follow_up_input_tokens}, Output tokens: {follow_up_output_tokens}, Additional cost: ${additional_cost:.6f}") + logging.info(f"Follow-up API call - Model: {model}, Input tokens: {follow_up_input_tokens}, Output tokens: {follow_up_output_tokens}, Additional cost: {format_cost(additional_cost)}") # Save additional token usage and cost to database await self.db.save_token_usage(user_id, model, follow_up_input_tokens, follow_up_output_tokens, additional_cost) @@ -1791,7 +1771,7 @@ print("\\n=== Correlation Analysis ===") # Log processing time and cost for performance monitoring processing_time = time.time() - start_time - logging.info(f"Message processed in {processing_time:.2f} seconds (User: {user_id}, Model: {model}, Cost: ${total_cost:.6f})") + logging.info(f"Message processed in {processing_time:.2f} seconds (User: {user_id}, Model: {model}, Cost: {format_cost(total_cost)})") except asyncio.CancelledError: # Handle cancellation cleanly diff --git a/src/utils/cache.py b/src/utils/cache.py new file mode 100644 index 0000000..9d8beae --- /dev/null +++ b/src/utils/cache.py @@ -0,0 +1,358 @@ +""" +Simple caching utilities for API responses and frequently accessed data. + +This module provides an in-memory LRU cache with optional TTL (time-to-live) +support, designed for caching API responses and reducing redundant calls. +""" + +import asyncio +import time +import logging +from typing import Any, Dict, Optional, Callable, TypeVar, Generic +from collections import OrderedDict +from dataclasses import dataclass, field +from functools import wraps + +logger = logging.getLogger(__name__) + +T = TypeVar('T') + + +@dataclass +class CacheEntry(Generic[T]): + """A single cache entry with value and expiration time.""" + value: T + expires_at: float + created_at: float = field(default_factory=time.time) + hits: int = 0 + + +class LRUCache(Generic[T]): + """ + Thread-safe LRU (Least Recently Used) cache with TTL support. + + Features: + - Configurable max size with automatic eviction + - Per-entry TTL (time-to-live) + - Automatic cleanup of expired entries + - Hit/miss statistics tracking + + Usage: + cache = LRUCache(max_size=1000, default_ttl=300) # 5 min TTL + cache.set("key", "value") + value = cache.get("key") # Returns value or None if expired + """ + + def __init__( + self, + max_size: int = 1000, + default_ttl: float = 300.0, # 5 minutes default + cleanup_interval: float = 60.0 + ): + """ + Initialize the LRU cache. + + Args: + max_size: Maximum number of entries + default_ttl: Default TTL in seconds + cleanup_interval: How often to run cleanup (seconds) + """ + self._cache: OrderedDict[str, CacheEntry[T]] = OrderedDict() + self._max_size = max_size + self._default_ttl = default_ttl + self._cleanup_interval = cleanup_interval + self._lock = asyncio.Lock() + + # Statistics + self._hits = 0 + self._misses = 0 + + # Background cleanup task + self._cleanup_task: Optional[asyncio.Task] = None + + async def start(self) -> None: + """Start the background cleanup task.""" + if self._cleanup_task is None: + self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + logger.debug("Cache cleanup task started") + + async def stop(self) -> None: + """Stop the background cleanup task.""" + if self._cleanup_task: + self._cleanup_task.cancel() + try: + await self._cleanup_task + except asyncio.CancelledError: + pass + self._cleanup_task = None + logger.debug("Cache cleanup task stopped") + + async def _cleanup_loop(self) -> None: + """Background task to periodically clean up expired entries.""" + while True: + await asyncio.sleep(self._cleanup_interval) + await self._cleanup_expired() + + async def _cleanup_expired(self) -> int: + """Remove expired entries. Returns count of removed entries.""" + now = time.time() + removed = 0 + + async with self._lock: + keys_to_remove = [ + key for key, entry in self._cache.items() + if entry.expires_at <= now + ] + + for key in keys_to_remove: + del self._cache[key] + removed += 1 + + if removed > 0: + logger.debug(f"Cache cleanup: removed {removed} expired entries") + + return removed + + async def get(self, key: str) -> Optional[T]: + """ + Get a value from the cache. + + Args: + key: Cache key + + Returns: + Cached value or None if not found/expired + """ + async with self._lock: + if key not in self._cache: + self._misses += 1 + return None + + entry = self._cache[key] + + # Check if expired + if entry.expires_at <= time.time(): + del self._cache[key] + self._misses += 1 + return None + + # Move to end (most recently used) + self._cache.move_to_end(key) + entry.hits += 1 + self._hits += 1 + + return entry.value + + async def set( + self, + key: str, + value: T, + ttl: Optional[float] = None + ) -> None: + """ + Set a value in the cache. + + Args: + key: Cache key + value: Value to cache + ttl: Optional TTL override (uses default if not provided) + """ + ttl = ttl if ttl is not None else self._default_ttl + expires_at = time.time() + ttl + + async with self._lock: + # Remove oldest entries if at capacity + while len(self._cache) >= self._max_size: + oldest_key = next(iter(self._cache)) + del self._cache[oldest_key] + logger.debug(f"Cache evicted oldest entry: {oldest_key}") + + self._cache[key] = CacheEntry( + value=value, + expires_at=expires_at + ) + self._cache.move_to_end(key) + + async def delete(self, key: str) -> bool: + """ + Delete a key from the cache. + + Args: + key: Cache key + + Returns: + True if key was found and deleted + """ + async with self._lock: + if key in self._cache: + del self._cache[key] + return True + return False + + async def clear(self) -> int: + """ + Clear all entries from the cache. + + Returns: + Number of entries cleared + """ + async with self._lock: + count = len(self._cache) + self._cache.clear() + return count + + async def has(self, key: str) -> bool: + """Check if a key exists and is not expired.""" + return await self.get(key) is not None + + def stats(self) -> Dict[str, Any]: + """ + Get cache statistics. + + Returns: + Dict with size, hits, misses, hit_rate + """ + total = self._hits + self._misses + hit_rate = (self._hits / total * 100) if total > 0 else 0.0 + + return { + "size": len(self._cache), + "max_size": self._max_size, + "hits": self._hits, + "misses": self._misses, + "hit_rate": f"{hit_rate:.2f}%", + "default_ttl": self._default_ttl + } + + +# Global cache instances for different purposes +_api_response_cache: Optional[LRUCache[Dict[str, Any]]] = None +_user_preference_cache: Optional[LRUCache[Dict[str, Any]]] = None + + +async def get_api_cache() -> LRUCache[Dict[str, Any]]: + """Get or create the API response cache.""" + global _api_response_cache + if _api_response_cache is None: + _api_response_cache = LRUCache( + max_size=500, + default_ttl=300.0 # 5 minutes + ) + await _api_response_cache.start() + return _api_response_cache + + +async def get_user_cache() -> LRUCache[Dict[str, Any]]: + """Get or create the user preference cache.""" + global _user_preference_cache + if _user_preference_cache is None: + _user_preference_cache = LRUCache( + max_size=1000, + default_ttl=600.0 # 10 minutes + ) + await _user_preference_cache.start() + return _user_preference_cache + + +def cached( + cache_key_func: Callable[..., str], + ttl: Optional[float] = None, + cache_getter: Callable = get_api_cache +): + """ + Decorator to cache async function results. + + Args: + cache_key_func: Function to generate cache key from args + ttl: Optional TTL override + cache_getter: Function to get the cache instance + + Usage: + @cached( + cache_key_func=lambda user_id: f"user:{user_id}", + ttl=300 + ) + async def get_user_data(user_id: int) -> dict: + # Expensive operation + return await fetch_from_api(user_id) + """ + def decorator(func: Callable): + @wraps(func) + async def wrapper(*args, **kwargs): + cache = await cache_getter() + key = cache_key_func(*args, **kwargs) + + # Try to get from cache + cached_value = await cache.get(key) + if cached_value is not None: + logger.debug(f"Cache hit for key: {key}") + return cached_value + + # Execute function and cache result + result = await func(*args, **kwargs) + await cache.set(key, result, ttl=ttl) + logger.debug(f"Cached result for key: {key}") + + return result + + return wrapper + return decorator + + +def invalidate_on_update( + cache_key_func: Callable[..., str], + cache_getter: Callable = get_api_cache +): + """ + Decorator to invalidate cache when a function (update operation) is called. + + Args: + cache_key_func: Function to generate cache key to invalidate + cache_getter: Function to get the cache instance + + Usage: + @invalidate_on_update( + cache_key_func=lambda user_id, **_: f"user:{user_id}" + ) + async def update_user_data(user_id: int, data: dict) -> None: + await save_to_db(user_id, data) + """ + def decorator(func: Callable): + @wraps(func) + async def wrapper(*args, **kwargs): + result = await func(*args, **kwargs) + + # Invalidate cache after update + cache = await cache_getter() + key = cache_key_func(*args, **kwargs) + await cache.delete(key) + logger.debug(f"Invalidated cache for key: {key}") + + return result + + return wrapper + return decorator + + +# Convenience functions for common caching patterns + +async def cache_user_model(user_id: int, model: str) -> None: + """Cache user's selected model.""" + cache = await get_user_cache() + await cache.set(f"user_model:{user_id}", {"model": model}) + + +async def get_cached_user_model(user_id: int) -> Optional[str]: + """Get user's cached model selection.""" + cache = await get_user_cache() + result = await cache.get(f"user_model:{user_id}") + return result["model"] if result else None + + +async def invalidate_user_cache(user_id: int) -> None: + """Invalidate all cached data for a user.""" + cache = await get_user_cache() + # Clear known user-related keys + await cache.delete(f"user_model:{user_id}") + await cache.delete(f"user_history:{user_id}") + await cache.delete(f"user_stats:{user_id}") diff --git a/src/utils/code_interpreter.py b/src/utils/code_interpreter.py index b316c21..49e11d1 100644 --- a/src/utils/code_interpreter.py +++ b/src/utils/code_interpreter.py @@ -71,19 +71,40 @@ APPROVED_PACKAGES = { 'more-itertools', 'toolz', 'cytoolz', 'funcy' } -# Blocked patterns +# Blocked patterns - Comprehensive security checks # Note: We allow open() for writing to enable saving plots and outputs # The sandboxed environment restricts file access to safe directories BLOCKED_PATTERNS = [ - # Dangerous system modules + # ==================== DANGEROUS SYSTEM MODULES ==================== + # OS module (except path) r'import\s+os\b(?!\s*\.path)', r'from\s+os\s+import\s+(?!path)', + + # File system modules r'import\s+shutil\b', r'from\s+shutil\s+import', + r'import\s+pathlib\b(?!\s*\.)', # Allow pathlib usage but monitor + + # Subprocess and execution modules r'import\s+subprocess\b', r'from\s+subprocess\s+import', - r'import\s+sys\b(?!\s*\.(?:path|version|platform))', - r'from\s+sys\s+import', + r'import\s+multiprocessing\b', + r'from\s+multiprocessing\s+import', + r'import\s+threading\b', + r'from\s+threading\s+import', + r'import\s+concurrent\b', + r'from\s+concurrent\s+import', + + # System access modules + r'import\s+sys\b(?!\s*\.(?:path|version|platform|stdout|stderr))', + r'from\s+sys\s+import\s+(?!path|version|platform|stdout|stderr)', + r'import\s+platform\b', + r'from\s+platform\s+import', + r'import\s+ctypes\b', + r'from\s+ctypes\s+import', + r'import\s+_[a-z]+', # Block private C modules + + # ==================== NETWORK MODULES ==================== r'import\s+socket\b', r'from\s+socket\s+import', r'import\s+urllib\b', @@ -92,19 +113,98 @@ BLOCKED_PATTERNS = [ r'from\s+requests\s+import', r'import\s+aiohttp\b', r'from\s+aiohttp\s+import', - # Dangerous code execution + r'import\s+httpx\b', + r'from\s+httpx\s+import', + r'import\s+http\.client\b', + r'from\s+http\.client\s+import', + r'import\s+ftplib\b', + r'from\s+ftplib\s+import', + r'import\s+smtplib\b', + r'from\s+smtplib\s+import', + r'import\s+telnetlib\b', + r'from\s+telnetlib\s+import', + r'import\s+ssl\b', + r'from\s+ssl\s+import', + r'import\s+paramiko\b', + r'from\s+paramiko\s+import', + + # ==================== DANGEROUS CODE EXECUTION ==================== r'__import__\s*\(', r'\beval\s*\(', r'\bexec\s*\(', r'\bcompile\s*\(', r'\bglobals\s*\(', r'\blocals\s*\(', - # File system operations (dangerous) + r'\bgetattr\s*\([^,]+,\s*[\'"]__', # Block getattr for dunder methods + r'\bsetattr\s*\([^,]+,\s*[\'"]__', # Block setattr for dunder methods + r'\bdelattr\s*\([^,]+,\s*[\'"]__', # Block delattr for dunder methods + r'\.\_\_\w+\_\_', # Block dunder method access + + # ==================== FILE SYSTEM OPERATIONS ==================== r'\.unlink\s*\(', r'\.rmdir\s*\(', r'\.remove\s*\(', r'\.chmod\s*\(', r'\.chown\s*\(', + r'\.rmtree\s*\(', + r'\.rename\s*\(', + r'\.replace\s*\(', + r'\.makedirs\s*\(', # Allow mkdir but block makedirs outside sandbox + r'Path\s*\(\s*[\'"]\/(?!tmp)', # Block absolute paths outside /tmp + r'open\s*\(\s*[\'"]\/(?!tmp)', # Block file access outside /tmp + + # ==================== PICKLE AND SERIALIZATION ==================== + r'pickle\.loads?\s*\(', + r'cPickle\.loads?\s*\(', + r'marshal\.loads?\s*\(', + r'shelve\.open\s*\(', + + # ==================== PROCESS MANIPULATION ==================== + r'os\.system\s*\(', + r'os\.popen\s*\(', + r'os\.spawn', + r'os\.exec', + r'os\.fork\s*\(', + r'os\.kill\s*\(', + r'os\.killpg\s*\(', + + # ==================== ENVIRONMENT ACCESS ==================== + r'os\.environ', + r'os\.getenv\s*\(', + r'os\.putenv\s*\(', + + # ==================== DANGEROUS BUILTINS ==================== + r'__builtins__', + r'__loader__', + r'__spec__', + + # ==================== CODE OBJECT MANIPULATION ==================== + r'\.f_code', + r'\.f_globals', + r'\.f_locals', + r'\.gi_frame', + r'\.co_code', + r'types\.CodeType', + r'types\.FunctionType', + + # ==================== IMPORT SYSTEM MANIPULATION ==================== + r'import\s+importlib\b', + r'from\s+importlib\s+import', + r'sys\.modules', + r'sys\.path\.(?:append|insert|extend)', + + # ==================== MEMORY OPERATIONS ==================== + r'gc\.', + r'sys\.getsizeof', + r'sys\.getrefcount', + r'id\s*\(', # Block id() which can leak memory addresses +] + +# Additional patterns that log warnings but don't block +WARNING_PATTERNS = [ + (r'while\s+True', "Infinite loop detected - ensure break condition exists"), + (r'for\s+\w+\s+in\s+range\s*\(\s*\d{6,}', "Very large loop detected"), + (r'recursion', "Recursion detected - ensure base case exists"), ] @@ -772,10 +872,54 @@ class CodeExecutor: logger.warning(f"Cleanup failed: {e}") def validate_code_security(self, code: str) -> Tuple[bool, str]: - """Validate code for security threats.""" + """ + Validate code for security threats. + + Performs comprehensive security checks including: + - Blocked patterns (dangerous imports, code execution, file ops) + - Warning patterns (potential issues that are logged) + - Code structure validation + + Args: + code: The Python code to validate + + Returns: + Tuple of (is_safe, message) + """ + # Check for blocked patterns for pattern in BLOCKED_PATTERNS: if re.search(pattern, code, re.IGNORECASE): - return False, f"Blocked unsafe operation: {pattern}" + logger.warning(f"Blocked code pattern detected: {pattern[:50]}...") + return False, f"Security violation: Unsafe operation detected" + + # Check for warning patterns (log but don't block) + for pattern, warning_msg in WARNING_PATTERNS: + if re.search(pattern, code, re.IGNORECASE): + logger.warning(f"Code warning: {warning_msg}") + + # Additional structural checks + try: + # Parse the AST to check for suspicious constructs + tree = ast.parse(code) + for node in ast.walk(tree): + # Check for suspicious attribute access + if isinstance(node, ast.Attribute): + if node.attr.startswith('_') and node.attr.startswith('__'): + logger.warning(f"Dunder attribute access detected: {node.attr}") + return False, "Security violation: Private attribute access not allowed" + + # Check for suspicious function calls + if isinstance(node, ast.Call): + if isinstance(node.func, ast.Name): + if node.func.id in ['eval', 'exec', 'compile', '__import__']: + return False, f"Security violation: {node.func.id}() is not allowed" + + except SyntaxError: + # Syntax errors will be caught during execution + pass + except Exception as e: + logger.warning(f"Error during AST validation: {e}") + return True, "Code passed security validation" def _extract_imports_from_code(self, code: str) -> List[str]: diff --git a/src/utils/discord_utils.py b/src/utils/discord_utils.py new file mode 100644 index 0000000..52db680 --- /dev/null +++ b/src/utils/discord_utils.py @@ -0,0 +1,417 @@ +""" +Discord response utilities for sending messages with proper handling. + +This module provides utilities for sending messages to Discord with +proper length handling, error recovery, and formatting. +""" + +import discord +import asyncio +import logging +import io +from typing import Optional, List, Union +from dataclasses import dataclass + + +# Discord message limits +MAX_MESSAGE_LENGTH = 2000 +MAX_EMBED_DESCRIPTION = 4096 +MAX_EMBED_FIELD_VALUE = 1024 +MAX_EMBED_FIELDS = 25 +MAX_FILE_SIZE = 8 * 1024 * 1024 # 8MB for non-nitro + + +@dataclass +class MessageChunk: + """A chunk of a message that fits within Discord limits.""" + content: str + is_code_block: bool = False + language: Optional[str] = None + + +def split_message( + content: str, + max_length: int = MAX_MESSAGE_LENGTH, + split_on: List[str] = None +) -> List[str]: + """ + Split a long message into chunks that fit within Discord limits. + + Args: + content: The message content to split + max_length: Maximum length per chunk + split_on: Preferred split points (default: newlines, spaces) + + Returns: + List of message chunks + """ + if len(content) <= max_length: + return [content] + + if split_on is None: + split_on = ['\n\n', '\n', '. ', ' '] + + chunks = [] + remaining = content + + while remaining: + if len(remaining) <= max_length: + chunks.append(remaining) + break + + # Find the best split point + split_index = max_length + + for delimiter in split_on: + # Look for delimiter before max_length + last_index = remaining.rfind(delimiter, 0, max_length) + if last_index > max_length // 2: # Don't split too early + split_index = last_index + len(delimiter) + break + + # If no good split point, hard cut at max_length + if split_index >= max_length: + split_index = max_length + + chunks.append(remaining[:split_index]) + remaining = remaining[split_index:] + + return chunks + + +def split_code_block( + code: str, + language: str = "", + max_length: int = MAX_MESSAGE_LENGTH +) -> List[str]: + """ + Split code into properly formatted code block chunks. + + Args: + code: The code content + language: The language for syntax highlighting + max_length: Maximum length per chunk + + Returns: + List of formatted code block strings + """ + # Account for code block markers + marker_length = len(f"```{language}\n") + len("```") + effective_max = max_length - marker_length - 20 # Extra buffer + + lines = code.split('\n') + chunks = [] + current_chunk = [] + current_length = 0 + + for line in lines: + line_length = len(line) + 1 # +1 for newline + + if current_length + line_length > effective_max and current_chunk: + # Finish current chunk + chunk_code = '\n'.join(current_chunk) + chunks.append(f"```{language}\n{chunk_code}\n```") + current_chunk = [line] + current_length = line_length + else: + current_chunk.append(line) + current_length += line_length + + # Add remaining chunk + if current_chunk: + chunk_code = '\n'.join(current_chunk) + chunks.append(f"```{language}\n{chunk_code}\n```") + + return chunks + + +async def send_long_message( + channel: discord.abc.Messageable, + content: str, + max_length: int = MAX_MESSAGE_LENGTH, + delay: float = 0.5 +) -> List[discord.Message]: + """ + Send a long message split across multiple Discord messages. + + Args: + channel: The channel to send to + content: The message content + max_length: Maximum length per message + delay: Delay between messages to avoid rate limiting + + Returns: + List of sent messages + """ + chunks = split_message(content, max_length) + messages = [] + + for i, chunk in enumerate(chunks): + try: + msg = await channel.send(chunk) + messages.append(msg) + + # Add delay between messages (except for the last one) + if i < len(chunks) - 1: + await asyncio.sleep(delay) + + except discord.HTTPException as e: + logging.error(f"Failed to send message chunk {i+1}: {e}") + # Try sending as file if message still too long + if "too long" in str(e).lower(): + file = discord.File( + io.StringIO(chunk), + filename=f"message_part_{i+1}.txt" + ) + msg = await channel.send(file=file) + messages.append(msg) + + return messages + + +async def send_code_response( + channel: discord.abc.Messageable, + code: str, + language: str = "python", + title: Optional[str] = None +) -> List[discord.Message]: + """ + Send code with proper formatting, handling long code. + + Args: + channel: The channel to send to + code: The code content + language: Programming language for highlighting + title: Optional title to display before code + + Returns: + List of sent messages + """ + messages = [] + + if title: + msg = await channel.send(title) + messages.append(msg) + + # If code is too long for code blocks, send as file + if len(code) > MAX_MESSAGE_LENGTH - 100: + file = discord.File( + io.StringIO(code), + filename=f"code.{language}" if language else "code.txt" + ) + msg = await channel.send("📎 Code attached as file:", file=file) + messages.append(msg) + else: + chunks = split_code_block(code, language) + for chunk in chunks: + msg = await channel.send(chunk) + messages.append(msg) + await asyncio.sleep(0.3) + + return messages + + +def create_error_embed( + title: str, + description: str, + error_type: str = "Error" +) -> discord.Embed: + """ + Create a standardized error embed. + + Args: + title: Error title + description: Error description + error_type: Type of error for categorization + + Returns: + Discord Embed object + """ + embed = discord.Embed( + title=f"❌ {title}", + description=description[:MAX_EMBED_DESCRIPTION], + color=discord.Color.red() + ) + embed.set_footer(text=f"Error Type: {error_type}") + return embed + + +def create_success_embed( + title: str, + description: str = "" +) -> discord.Embed: + """ + Create a standardized success embed. + + Args: + title: Success title + description: Success description + + Returns: + Discord Embed object + """ + embed = discord.Embed( + title=f"✅ {title}", + description=description[:MAX_EMBED_DESCRIPTION] if description else None, + color=discord.Color.green() + ) + return embed + + +def create_info_embed( + title: str, + description: str = "", + fields: List[tuple] = None +) -> discord.Embed: + """ + Create a standardized info embed with optional fields. + + Args: + title: Info title + description: Info description + fields: List of (name, value, inline) tuples + + Returns: + Discord Embed object + """ + embed = discord.Embed( + title=f"ℹ️ {title}", + description=description[:MAX_EMBED_DESCRIPTION] if description else None, + color=discord.Color.blue() + ) + + if fields: + for name, value, inline in fields[:MAX_EMBED_FIELDS]: + embed.add_field( + name=name[:256], + value=str(value)[:MAX_EMBED_FIELD_VALUE], + inline=inline + ) + + return embed + + +def create_progress_embed( + title: str, + description: str, + progress: float = 0.0 +) -> discord.Embed: + """ + Create a progress indicator embed. + + Args: + title: Progress title + description: Progress description + progress: Progress value 0.0 to 1.0 + + Returns: + Discord Embed object + """ + # Create progress bar + bar_length = 20 + filled = int(bar_length * progress) + bar = "█" * filled + "░" * (bar_length - filled) + percentage = int(progress * 100) + + embed = discord.Embed( + title=f"⏳ {title}", + description=f"{description}\n\n`{bar}` {percentage}%", + color=discord.Color.orange() + ) + return embed + + +async def edit_or_send( + message: Optional[discord.Message], + channel: discord.abc.Messageable, + content: str = None, + embed: discord.Embed = None +) -> discord.Message: + """ + Edit an existing message or send a new one if editing fails. + + Args: + message: Message to edit (or None to send new) + channel: Channel to send to if message is None + content: Message content + embed: Message embed + + Returns: + The edited or new message + """ + try: + if message: + await message.edit(content=content, embed=embed) + return message + else: + return await channel.send(content=content, embed=embed) + except discord.HTTPException: + return await channel.send(content=content, embed=embed) + + +class ProgressMessage: + """ + A message that can be updated to show progress. + + Usage: + async with ProgressMessage(channel, "Processing") as progress: + for i in range(100): + await progress.update(i / 100, f"Step {i}") + """ + + def __init__( + self, + channel: discord.abc.Messageable, + title: str, + description: str = "Starting..." + ): + self.channel = channel + self.title = title + self.description = description + self.message: Optional[discord.Message] = None + self._last_update = 0.0 + self._update_interval = 2.0 # Minimum seconds between updates + + async def __aenter__(self): + embed = create_progress_embed(self.title, self.description, 0.0) + self.message = await self.channel.send(embed=embed) + return self + + async def __aexit__(self, *args): + # Clean up or finalize + pass + + async def update(self, progress: float, description: str = None): + """Update the progress message.""" + import time + + now = time.monotonic() + if now - self._last_update < self._update_interval: + return + + self._last_update = now + + if description: + self.description = description + + try: + embed = create_progress_embed(self.title, self.description, progress) + await self.message.edit(embed=embed) + except discord.HTTPException: + pass # Ignore edit failures + + async def complete(self, message: str = "Complete!"): + """Mark the progress as complete.""" + try: + embed = create_success_embed(self.title, message) + await self.message.edit(embed=embed) + except discord.HTTPException: + pass + + async def error(self, message: str): + """Mark the progress as failed.""" + try: + embed = create_error_embed(self.title, message) + await self.message.edit(embed=embed) + except discord.HTTPException: + pass diff --git a/src/utils/monitoring.py b/src/utils/monitoring.py new file mode 100644 index 0000000..a2ffd21 --- /dev/null +++ b/src/utils/monitoring.py @@ -0,0 +1,446 @@ +""" +Monitoring and observability utilities. + +This module provides structured logging, error tracking with Sentry, +and performance monitoring for the Discord bot. +""" + +import os +import logging +import time +import asyncio +from typing import Any, Dict, Optional, Callable +from functools import wraps +from contextlib import contextmanager, asynccontextmanager +from dataclasses import dataclass, field +from datetime import datetime, timezone + +# Try to import Sentry +try: + import sentry_sdk + from sentry_sdk.integrations.asyncio import AsyncioIntegration + SENTRY_AVAILABLE = True +except ImportError: + SENTRY_AVAILABLE = False + sentry_sdk = None + +logger = logging.getLogger(__name__) + + +# ============================================================ +# Configuration +# ============================================================ + +@dataclass +class MonitoringConfig: + """Configuration for monitoring features.""" + sentry_dsn: Optional[str] = None + environment: str = "development" + sample_rate: float = 1.0 # 100% of events + traces_sample_rate: float = 0.1 # 10% of transactions + log_level: str = "INFO" + structured_logging: bool = True + + +def setup_monitoring(config: Optional[MonitoringConfig] = None) -> None: + """ + Initialize monitoring with optional Sentry integration. + + Args: + config: Monitoring configuration, uses env vars if not provided + """ + if config is None: + config = MonitoringConfig( + sentry_dsn=os.environ.get("SENTRY_DSN"), + environment=os.environ.get("ENVIRONMENT", "development"), + sample_rate=float(os.environ.get("SENTRY_SAMPLE_RATE", "1.0")), + traces_sample_rate=float(os.environ.get("SENTRY_TRACES_RATE", "0.1")), + log_level=os.environ.get("LOG_LEVEL", "INFO"), + ) + + # Setup logging + setup_structured_logging( + level=config.log_level, + structured=config.structured_logging + ) + + # Setup Sentry if available and configured + if SENTRY_AVAILABLE and config.sentry_dsn: + sentry_sdk.init( + dsn=config.sentry_dsn, + environment=config.environment, + sample_rate=config.sample_rate, + traces_sample_rate=config.traces_sample_rate, + integrations=[AsyncioIntegration()], + before_send=before_send_filter, + ) + logger.info(f"Sentry initialized for environment: {config.environment}") + else: + if config.sentry_dsn and not SENTRY_AVAILABLE: + logger.warning("Sentry DSN provided but sentry_sdk not installed") + logger.info("Running without Sentry error tracking") + + +def before_send_filter(event: Dict, hint: Dict) -> Optional[Dict]: + """Filter events before sending to Sentry.""" + # Don't send events for expected/handled errors + if "exc_info" in hint: + exc_type, exc_value, _ = hint["exc_info"] + + # Skip common non-critical errors + if exc_type.__name__ in [ + "NotFound", # Discord 404 + "Forbidden", # Discord 403 + "RateLimited", # Discord rate limit + ]: + return None + + return event + + +# ============================================================ +# Structured Logging +# ============================================================ + +class StructuredFormatter(logging.Formatter): + """JSON-like structured log formatter.""" + + def format(self, record: logging.LogRecord) -> str: + """Format log record as structured message.""" + log_entry = { + "timestamp": datetime.now(timezone.utc).isoformat(), + "level": record.levelname, + "logger": record.name, + "message": record.getMessage(), + } + + # Add extra fields + if hasattr(record, "user_id"): + log_entry["user_id"] = record.user_id + if hasattr(record, "guild_id"): + log_entry["guild_id"] = record.guild_id + if hasattr(record, "command"): + log_entry["command"] = record.command + if hasattr(record, "duration_ms"): + log_entry["duration_ms"] = record.duration_ms + if hasattr(record, "model"): + log_entry["model"] = record.model + + # Add exception info if present + if record.exc_info: + log_entry["exception"] = self.formatException(record.exc_info) + + # Format as key=value pairs for easy parsing + parts = [f"{k}={v!r}" for k, v in log_entry.items()] + return " ".join(parts) + + +def setup_structured_logging( + level: str = "INFO", + structured: bool = True +) -> None: + """ + Setup logging configuration. + + Args: + level: Log level (DEBUG, INFO, WARNING, ERROR) + structured: Use structured formatting + """ + log_level = getattr(logging, level.upper(), logging.INFO) + + # Create handler + handler = logging.StreamHandler() + handler.setLevel(log_level) + + if structured: + handler.setFormatter(StructuredFormatter()) + else: + handler.setFormatter(logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + )) + + # Configure root logger + root_logger = logging.getLogger() + root_logger.setLevel(log_level) + root_logger.handlers = [handler] + + +def get_logger(name: str) -> logging.Logger: + """Get a logger with the given name.""" + return logging.getLogger(name) + + +# ============================================================ +# Error Tracking +# ============================================================ + +def capture_exception( + exception: Exception, + context: Optional[Dict[str, Any]] = None +) -> Optional[str]: + """ + Capture and report an exception. + + Args: + exception: The exception to capture + context: Additional context to attach + + Returns: + Event ID if sent to Sentry, None otherwise + """ + logger.exception(f"Captured exception: {exception}") + + if SENTRY_AVAILABLE and sentry_sdk.Hub.current.client: + with sentry_sdk.push_scope() as scope: + if context: + for key, value in context.items(): + scope.set_extra(key, value) + return sentry_sdk.capture_exception(exception) + + return None + + +def capture_message( + message: str, + level: str = "info", + context: Optional[Dict[str, Any]] = None +) -> Optional[str]: + """ + Capture and report a message. + + Args: + message: The message to capture + level: Severity level (debug, info, warning, error, fatal) + context: Additional context to attach + + Returns: + Event ID if sent to Sentry, None otherwise + """ + log_method = getattr(logger, level, logger.info) + log_method(message) + + if SENTRY_AVAILABLE and sentry_sdk.Hub.current.client: + with sentry_sdk.push_scope() as scope: + if context: + for key, value in context.items(): + scope.set_extra(key, value) + return sentry_sdk.capture_message(message, level=level) + + return None + + +def set_user_context( + user_id: int, + username: Optional[str] = None, + guild_id: Optional[int] = None +) -> None: + """ + Set user context for error tracking. + + Args: + user_id: Discord user ID + username: Discord username + guild_id: Discord guild ID + """ + if SENTRY_AVAILABLE and sentry_sdk.Hub.current.client: + sentry_sdk.set_user({ + "id": str(user_id), + "username": username, + }) + if guild_id: + sentry_sdk.set_tag("guild_id", str(guild_id)) + + +# ============================================================ +# Performance Monitoring +# ============================================================ + +@dataclass +class PerformanceMetrics: + """Container for performance metrics.""" + name: str + start_time: float = field(default_factory=time.perf_counter) + end_time: Optional[float] = None + success: bool = True + error: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + @property + def duration_ms(self) -> float: + """Get duration in milliseconds.""" + end = self.end_time or time.perf_counter() + return (end - self.start_time) * 1000 + + def finish(self, success: bool = True, error: Optional[str] = None) -> None: + """Mark the operation as finished.""" + self.end_time = time.perf_counter() + self.success = success + self.error = error + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for logging.""" + return { + "name": self.name, + "duration_ms": round(self.duration_ms, 2), + "success": self.success, + "error": self.error, + **self.metadata + } + + +@contextmanager +def measure_sync(name: str, **metadata): + """ + Context manager to measure synchronous operation performance. + + Usage: + with measure_sync("database_query", table="users"): + result = db.query(...) + """ + metrics = PerformanceMetrics(name=name, metadata=metadata) + + try: + yield metrics + metrics.finish(success=True) + except Exception as e: + metrics.finish(success=False, error=str(e)) + raise + finally: + logger.info( + f"Performance: {metrics.name}", + extra={"duration_ms": metrics.duration_ms, **metrics.metadata} + ) + + +@asynccontextmanager +async def measure_async(name: str, **metadata): + """ + Async context manager to measure async operation performance. + + Usage: + async with measure_async("api_call", endpoint="chat"): + result = await api.call(...) + """ + metrics = PerformanceMetrics(name=name, metadata=metadata) + + # Start Sentry transaction if available + transaction = None + if SENTRY_AVAILABLE and sentry_sdk.Hub.current.client: + transaction = sentry_sdk.start_transaction( + op="task", + name=name + ) + + try: + yield metrics + metrics.finish(success=True) + except Exception as e: + metrics.finish(success=False, error=str(e)) + raise + finally: + if transaction: + transaction.set_status("ok" if metrics.success else "internal_error") + transaction.finish() + + logger.info( + f"Performance: {metrics.name}", + extra={"duration_ms": metrics.duration_ms, **metrics.metadata} + ) + + +def track_performance(name: Optional[str] = None): + """ + Decorator to track async function performance. + + Args: + name: Operation name (defaults to function name) + + Usage: + @track_performance("process_message") + async def handle_message(message): + ... + """ + def decorator(func: Callable): + op_name = name or func.__name__ + + @wraps(func) + async def wrapper(*args, **kwargs): + async with measure_async(op_name): + return await func(*args, **kwargs) + + return wrapper + + return decorator + + +# ============================================================ +# Health Check +# ============================================================ + +@dataclass +class HealthStatus: + """Health check status.""" + healthy: bool + checks: Dict[str, Dict[str, Any]] = field(default_factory=dict) + timestamp: str = field( + default_factory=lambda: datetime.now(timezone.utc).isoformat() + ) + + def add_check( + self, + name: str, + healthy: bool, + message: str = "", + details: Optional[Dict] = None + ) -> None: + """Add a health check result.""" + self.checks[name] = { + "healthy": healthy, + "message": message, + **(details or {}) + } + if not healthy: + self.healthy = False + + +async def check_health( + db_handler=None, + openai_client=None +) -> HealthStatus: + """ + Perform health checks on bot dependencies. + + Args: + db_handler: Database handler to check + openai_client: OpenAI client to check + + Returns: + HealthStatus with check results + """ + status = HealthStatus(healthy=True) + + # Check database + if db_handler: + try: + # Simple ping or list operation + await asyncio.wait_for( + db_handler.client.admin.command('ping'), + timeout=5.0 + ) + status.add_check("database", True, "MongoDB connected") + except Exception as e: + status.add_check("database", False, f"MongoDB error: {e}") + + # Check OpenAI + if openai_client: + try: + # List models as a simple check + await asyncio.wait_for( + openai_client.models.list(), + timeout=10.0 + ) + status.add_check("openai", True, "OpenAI API accessible") + except Exception as e: + status.add_check("openai", False, f"OpenAI error: {e}") + + return status diff --git a/src/utils/retry.py b/src/utils/retry.py new file mode 100644 index 0000000..389e179 --- /dev/null +++ b/src/utils/retry.py @@ -0,0 +1,280 @@ +""" +Retry utilities with exponential backoff for API calls. + +This module provides robust retry logic for external API calls +to handle transient failures gracefully. +""" + +import asyncio +import logging +import random +from typing import TypeVar, Callable, Optional, Any, Type, Tuple +from functools import wraps + +T = TypeVar('T') + +# Default configuration +DEFAULT_MAX_RETRIES = 3 +DEFAULT_BASE_DELAY = 1.0 # seconds +DEFAULT_MAX_DELAY = 60.0 # seconds +DEFAULT_EXPONENTIAL_BASE = 2 + + +class RetryError(Exception): + """Raised when all retry attempts have been exhausted.""" + + def __init__(self, message: str, last_exception: Optional[Exception] = None): + super().__init__(message) + self.last_exception = last_exception + + +async def async_retry_with_backoff( + func: Callable, + *args, + max_retries: int = DEFAULT_MAX_RETRIES, + base_delay: float = DEFAULT_BASE_DELAY, + max_delay: float = DEFAULT_MAX_DELAY, + exponential_base: float = DEFAULT_EXPONENTIAL_BASE, + retryable_exceptions: Tuple[Type[Exception], ...] = (Exception,), + jitter: bool = True, + on_retry: Optional[Callable[[int, Exception], None]] = None, + **kwargs +) -> Any: + """ + Execute an async function with exponential backoff retry. + + Args: + func: The async function to execute + *args: Positional arguments for the function + max_retries: Maximum number of retry attempts + base_delay: Initial delay between retries in seconds + max_delay: Maximum delay between retries + exponential_base: Base for exponential backoff calculation + retryable_exceptions: Tuple of exception types that should trigger retry + jitter: Whether to add randomness to delay + on_retry: Optional callback called on each retry with (attempt, exception) + **kwargs: Keyword arguments for the function + + Returns: + The return value of the function + + Raises: + RetryError: When all retries are exhausted + """ + last_exception = None + + for attempt in range(max_retries + 1): + try: + return await func(*args, **kwargs) + except retryable_exceptions as e: + last_exception = e + + if attempt == max_retries: + logging.error(f"All {max_retries} retries exhausted for {func.__name__}: {e}") + raise RetryError( + f"Failed after {max_retries} retries: {str(e)}", + last_exception=e + ) + + # Calculate delay with exponential backoff + delay = min(base_delay * (exponential_base ** attempt), max_delay) + + # Add jitter to prevent thundering herd + if jitter: + delay = delay * (0.5 + random.random()) + + logging.warning( + f"Retry {attempt + 1}/{max_retries} for {func.__name__} " + f"after {delay:.2f}s delay. Error: {e}" + ) + + if on_retry: + try: + on_retry(attempt + 1, e) + except Exception as callback_error: + logging.warning(f"on_retry callback failed: {callback_error}") + + await asyncio.sleep(delay) + + # Should not reach here, but just in case + raise RetryError("Unexpected retry loop exit", last_exception=last_exception) + + +def retry_decorator( + max_retries: int = DEFAULT_MAX_RETRIES, + base_delay: float = DEFAULT_BASE_DELAY, + max_delay: float = DEFAULT_MAX_DELAY, + retryable_exceptions: Tuple[Type[Exception], ...] = (Exception,), + jitter: bool = True +): + """ + Decorator for adding retry logic to async functions. + + Usage: + @retry_decorator(max_retries=3, base_delay=1.0) + async def my_api_call(): + ... + """ + def decorator(func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args, **kwargs): + return await async_retry_with_backoff( + func, + *args, + max_retries=max_retries, + base_delay=base_delay, + max_delay=max_delay, + retryable_exceptions=retryable_exceptions, + jitter=jitter, + **kwargs + ) + return wrapper + return decorator + + +# Common exception sets for different APIs +OPENAI_RETRYABLE_EXCEPTIONS = ( + # Add specific OpenAI exceptions as needed + TimeoutError, + ConnectionError, +) + +DISCORD_RETRYABLE_EXCEPTIONS = ( + # Add specific Discord exceptions as needed + TimeoutError, + ConnectionError, +) + +HTTP_RETRYABLE_EXCEPTIONS = ( + TimeoutError, + ConnectionError, + ConnectionResetError, +) + + +class RateLimiter: + """ + Simple rate limiter for API calls. + + Usage: + limiter = RateLimiter(calls_per_second=1) + async with limiter: + await make_api_call() + """ + + def __init__(self, calls_per_second: float = 1.0): + self.min_interval = 1.0 / calls_per_second + self.last_call = 0.0 + self._lock = asyncio.Lock() + + async def __aenter__(self): + async with self._lock: + import time + now = time.monotonic() + time_since_last = now - self.last_call + + if time_since_last < self.min_interval: + await asyncio.sleep(self.min_interval - time_since_last) + + self.last_call = time.monotonic() + return self + + async def __aexit__(self, *args): + pass + + +class CircuitBreaker: + """ + Circuit breaker pattern for preventing cascade failures. + + States: + - CLOSED: Normal operation, requests pass through + - OPEN: Too many failures, requests are rejected immediately + - HALF_OPEN: Testing if service recovered + """ + + CLOSED = "closed" + OPEN = "open" + HALF_OPEN = "half_open" + + def __init__( + self, + failure_threshold: int = 5, + recovery_timeout: float = 60.0, + half_open_requests: int = 3 + ): + self.failure_threshold = failure_threshold + self.recovery_timeout = recovery_timeout + self.half_open_requests = half_open_requests + + self.state = self.CLOSED + self.failure_count = 0 + self.last_failure_time = 0.0 + self.half_open_successes = 0 + self._lock = asyncio.Lock() + + async def call(self, func: Callable, *args, **kwargs) -> Any: + """ + Execute a function through the circuit breaker. + + Args: + func: The async function to execute + *args: Positional arguments + **kwargs: Keyword arguments + + Returns: + The function result + + Raises: + Exception: If circuit is open or function fails + """ + async with self._lock: + await self._check_state() + + if self.state == self.OPEN: + raise Exception("Circuit breaker is OPEN - service unavailable") + + try: + result = await func(*args, **kwargs) + await self._on_success() + return result + except Exception as e: + await self._on_failure() + raise + + async def _check_state(self): + """Check and potentially update circuit state.""" + import time + + if self.state == self.OPEN: + if time.monotonic() - self.last_failure_time >= self.recovery_timeout: + logging.info("Circuit breaker transitioning to HALF_OPEN") + self.state = self.HALF_OPEN + self.half_open_successes = 0 + + async def _on_success(self): + """Handle successful call.""" + async with self._lock: + if self.state == self.HALF_OPEN: + self.half_open_successes += 1 + if self.half_open_successes >= self.half_open_requests: + logging.info("Circuit breaker transitioning to CLOSED") + self.state = self.CLOSED + self.failure_count = 0 + elif self.state == self.CLOSED: + self.failure_count = 0 + + async def _on_failure(self): + """Handle failed call.""" + import time + + async with self._lock: + self.failure_count += 1 + self.last_failure_time = time.monotonic() + + if self.state == self.HALF_OPEN: + logging.warning("Circuit breaker transitioning to OPEN (half-open failure)") + self.state = self.OPEN + elif self.failure_count >= self.failure_threshold: + logging.warning(f"Circuit breaker transitioning to OPEN ({self.failure_count} failures)") + self.state = self.OPEN diff --git a/src/utils/token_counter.py b/src/utils/token_counter.py index 58dbc70..00017e5 100644 --- a/src/utils/token_counter.py +++ b/src/utils/token_counter.py @@ -304,8 +304,8 @@ class TokenCounter: Returns: Estimated cost in USD """ - # Import here to avoid circular dependency - from src.commands.commands import MODEL_PRICING + # Import from centralized pricing module + from src.config.pricing import MODEL_PRICING if model not in MODEL_PRICING: model = "openai/gpt-4o" # Default fallback diff --git a/src/utils/validators.py b/src/utils/validators.py new file mode 100644 index 0000000..4c38f0d --- /dev/null +++ b/src/utils/validators.py @@ -0,0 +1,287 @@ +""" +Input validation utilities for the Discord bot. + +This module provides centralized validation for user inputs, +enhancing security and reducing code duplication. +""" + +import re +import logging +from typing import Optional, Tuple, List +from dataclasses import dataclass + + +# Maximum allowed lengths for various inputs +MAX_MESSAGE_LENGTH = 4000 # Discord's limit is 2000, but we process longer +MAX_PROMPT_LENGTH = 32000 # Reasonable limit for AI prompts +MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB +MAX_FILENAME_LENGTH = 255 +MAX_URL_LENGTH = 2048 +MAX_CODE_LENGTH = 100000 # 100KB of code + + +@dataclass +class ValidationResult: + """Result of a validation check.""" + is_valid: bool + error_message: Optional[str] = None + sanitized_value: Optional[str] = None + + +def validate_message_content(content: str) -> ValidationResult: + """ + Validate and sanitize message content. + + Args: + content: The message content to validate + + Returns: + ValidationResult with validation status and sanitized content + """ + if not content: + return ValidationResult(is_valid=True, sanitized_value="") + + if len(content) > MAX_MESSAGE_LENGTH: + return ValidationResult( + is_valid=False, + error_message=f"Message too long. Maximum {MAX_MESSAGE_LENGTH} characters allowed." + ) + + # Remove null bytes and other control characters (except newlines/tabs) + sanitized = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', content) + + return ValidationResult(is_valid=True, sanitized_value=sanitized) + + +def validate_prompt(prompt: str) -> ValidationResult: + """ + Validate AI prompt content. + + Args: + prompt: The prompt to validate + + Returns: + ValidationResult with validation status + """ + if not prompt or not prompt.strip(): + return ValidationResult( + is_valid=False, + error_message="Prompt cannot be empty." + ) + + if len(prompt) > MAX_PROMPT_LENGTH: + return ValidationResult( + is_valid=False, + error_message=f"Prompt too long. Maximum {MAX_PROMPT_LENGTH} characters allowed." + ) + + # Remove null bytes + sanitized = prompt.replace('\x00', '') + + return ValidationResult(is_valid=True, sanitized_value=sanitized) + + +def validate_url(url: str) -> ValidationResult: + """ + Validate and sanitize a URL. + + Args: + url: The URL to validate + + Returns: + ValidationResult with validation status + """ + if not url: + return ValidationResult( + is_valid=False, + error_message="URL cannot be empty." + ) + + if len(url) > MAX_URL_LENGTH: + return ValidationResult( + is_valid=False, + error_message=f"URL too long. Maximum {MAX_URL_LENGTH} characters allowed." + ) + + # Basic URL pattern check + url_pattern = re.compile( + r'^https?://' # http:// or https:// + r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|' # domain + r'localhost|' # localhost + r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # or IP + r'(?::\d+)?' # optional port + r'(?:/?|[/?]\S+)$', re.IGNORECASE + ) + + if not url_pattern.match(url): + return ValidationResult( + is_valid=False, + error_message="Invalid URL format." + ) + + # Check for potentially dangerous URL schemes + dangerous_schemes = ['javascript:', 'data:', 'file:', 'vbscript:'] + url_lower = url.lower() + for scheme in dangerous_schemes: + if scheme in url_lower: + return ValidationResult( + is_valid=False, + error_message="URL contains potentially dangerous content." + ) + + return ValidationResult(is_valid=True, sanitized_value=url) + + +def validate_filename(filename: str) -> ValidationResult: + """ + Validate and sanitize a filename. + + Args: + filename: The filename to validate + + Returns: + ValidationResult with validation status and sanitized filename + """ + if not filename: + return ValidationResult( + is_valid=False, + error_message="Filename cannot be empty." + ) + + if len(filename) > MAX_FILENAME_LENGTH: + return ValidationResult( + is_valid=False, + error_message=f"Filename too long. Maximum {MAX_FILENAME_LENGTH} characters allowed." + ) + + # Remove path traversal attempts + sanitized = filename.replace('..', '').replace('/', '').replace('\\', '') + + # Remove dangerous characters + sanitized = re.sub(r'[<>:"|?*\x00-\x1f]', '', sanitized) + + # Ensure it's not empty after sanitization + if not sanitized: + return ValidationResult( + is_valid=False, + error_message="Filename contains only invalid characters." + ) + + return ValidationResult(is_valid=True, sanitized_value=sanitized) + + +def validate_file_size(size: int) -> ValidationResult: + """ + Validate file size. + + Args: + size: The file size in bytes + + Returns: + ValidationResult with validation status + """ + if size <= 0: + return ValidationResult( + is_valid=False, + error_message="File size must be greater than 0." + ) + + if size > MAX_FILE_SIZE: + max_mb = MAX_FILE_SIZE / (1024 * 1024) + return ValidationResult( + is_valid=False, + error_message=f"File too large. Maximum {max_mb:.0f}MB allowed." + ) + + return ValidationResult(is_valid=True) + + +def validate_code(code: str) -> ValidationResult: + """ + Validate code for execution. + + Args: + code: The code to validate + + Returns: + ValidationResult with validation status + """ + if not code or not code.strip(): + return ValidationResult( + is_valid=False, + error_message="Code cannot be empty." + ) + + if len(code) > MAX_CODE_LENGTH: + return ValidationResult( + is_valid=False, + error_message=f"Code too long. Maximum {MAX_CODE_LENGTH} characters allowed." + ) + + return ValidationResult(is_valid=True, sanitized_value=code) + + +def validate_user_id(user_id) -> ValidationResult: + """ + Validate a Discord user ID. + + Args: + user_id: The user ID to validate + + Returns: + ValidationResult with validation status + """ + try: + uid = int(user_id) + if uid <= 0: + return ValidationResult( + is_valid=False, + error_message="Invalid user ID." + ) + # Discord IDs are 17-19 digits + if len(str(uid)) < 17 or len(str(uid)) > 19: + return ValidationResult( + is_valid=False, + error_message="Invalid user ID format." + ) + return ValidationResult(is_valid=True) + except (ValueError, TypeError): + return ValidationResult( + is_valid=False, + error_message="User ID must be a valid integer." + ) + + +def sanitize_for_logging(text: str, max_length: int = 200) -> str: + """ + Sanitize text for safe logging (remove sensitive data, truncate). + + Args: + text: The text to sanitize + max_length: Maximum length for logged text + + Returns: + Sanitized text safe for logging + """ + if not text: + return "" + + # Remove potential secrets/tokens (common patterns) + patterns = [ + (r'(sk-[a-zA-Z0-9]{20,})', '[OPENAI_KEY]'), + (r'(xoxb-[a-zA-Z0-9-]+)', '[SLACK_TOKEN]'), + (r'([A-Za-z0-9_-]{24}\.[A-Za-z0-9_-]{6}\.[A-Za-z0-9_-]{27})', '[DISCORD_TOKEN]'), + (r'(mongodb\+srv://[^@]+@)', 'mongodb+srv://[REDACTED]@'), + (r'(Bearer\s+[A-Za-z0-9_-]+)', 'Bearer [TOKEN]'), + (r'(password["\']?\s*[:=]\s*["\']?)[^"\'\s]+', r'\1[REDACTED]'), + ] + + sanitized = text + for pattern, replacement in patterns: + sanitized = re.sub(pattern, replacement, sanitized, flags=re.IGNORECASE) + + # Truncate if needed + if len(sanitized) > max_length: + sanitized = sanitized[:max_length] + '...[truncated]' + + return sanitized diff --git a/tests/test_comprehensive.py b/tests/test_comprehensive.py new file mode 100644 index 0000000..f059640 --- /dev/null +++ b/tests/test_comprehensive.py @@ -0,0 +1,727 @@ +""" +Comprehensive test suite for the ChatGPT Discord Bot. + +This module contains unit tests and integration tests for all major components. +Uses pytest with pytest-asyncio for async test support. +""" + +import asyncio +import pytest +import os +import sys +import json +from unittest.mock import MagicMock, patch, AsyncMock +from datetime import datetime, timedelta +from typing import Dict, Any + +# Add parent directory to path for imports +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +# ============================================================ +# Test Fixtures +# ============================================================ + +@pytest.fixture +def mock_db_handler(): + """Create a mock database handler.""" + mock = MagicMock() + mock.get_history = AsyncMock(return_value=[]) + mock.save_history = AsyncMock() + mock.get_user_model = AsyncMock(return_value="openai/gpt-4o") + mock.save_user_model = AsyncMock() + mock.is_admin = AsyncMock(return_value=False) + mock.is_user_whitelisted = AsyncMock(return_value=True) + mock.is_user_blacklisted = AsyncMock(return_value=False) + mock.get_user_tool_display = AsyncMock(return_value=False) + mock.get_user_files = AsyncMock(return_value=[]) + mock.save_token_usage = AsyncMock() + return mock + + +@pytest.fixture +def mock_openai_client(): + """Create a mock OpenAI client.""" + mock = MagicMock() + + # Mock response structure + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Test response" + mock_response.choices[0].finish_reason = "stop" + mock_response.usage = MagicMock() + mock_response.usage.prompt_tokens = 100 + mock_response.usage.completion_tokens = 50 + + mock.chat.completions.create = AsyncMock(return_value=mock_response) + return mock + + +@pytest.fixture +def mock_discord_message(): + """Create a mock Discord message.""" + mock = MagicMock() + mock.author.id = 123456789 + mock.author.name = "TestUser" + mock.content = "Hello, bot!" + mock.channel.send = AsyncMock() + mock.channel.typing = MagicMock(return_value=AsyncMock().__aenter__()) + mock.attachments = [] + mock.reference = None + mock.guild = MagicMock() + return mock + + +# ============================================================ +# Pricing Module Tests +# ============================================================ + +class TestPricingModule: + """Tests for the pricing configuration module.""" + + def test_model_pricing_exists(self): + """Test that all expected models have pricing defined.""" + from src.config.pricing import MODEL_PRICING + + expected_models = [ + "openai/gpt-4o", + "openai/gpt-4o-mini", + "openai/gpt-4.1", + "openai/gpt-5", + "openai/o1", + ] + + for model in expected_models: + assert model in MODEL_PRICING, f"Missing pricing for {model}" + + def test_calculate_cost(self): + """Test cost calculation for known models.""" + from src.config.pricing import calculate_cost + + # GPT-4o: $5.00 input, $20.00 output per 1M tokens + cost = calculate_cost("openai/gpt-4o", 1_000_000, 1_000_000) + assert cost == 25.00 # $5 + $20 + + # Test smaller amounts + cost = calculate_cost("openai/gpt-4o", 1000, 1000) + assert cost == pytest.approx(0.025, rel=1e-6) # $0.005 + $0.020 + + def test_calculate_cost_unknown_model(self): + """Test that unknown models return 0 cost.""" + from src.config.pricing import calculate_cost + + cost = calculate_cost("unknown/model", 1000, 1000) + assert cost == 0.0 + + def test_format_cost(self): + """Test cost formatting for display.""" + from src.config.pricing import format_cost + + assert format_cost(0.000001) == "$0.000001" + assert format_cost(0.005) == "$0.005000" # 6 decimal places for small amounts + assert format_cost(1.50) == "$1.50" + assert format_cost(100.00) == "$100.00" + + +# ============================================================ +# Validator Module Tests +# ============================================================ + +class TestValidators: + """Tests for input validation utilities.""" + + def test_validate_message_content(self): + """Test message content validation.""" + from src.utils.validators import validate_message_content + + # Valid content + result = validate_message_content("Hello, world!") + assert result.is_valid + assert result.sanitized_value == "Hello, world!" + + # Empty content is valid + result = validate_message_content("") + assert result.is_valid + + # Content with null bytes should be sanitized + result = validate_message_content("Hello\x00World") + assert result.is_valid + assert "\x00" not in result.sanitized_value + + def test_validate_message_too_long(self): + """Test that overly long messages are rejected.""" + from src.utils.validators import validate_message_content, MAX_MESSAGE_LENGTH + + long_message = "x" * (MAX_MESSAGE_LENGTH + 1) + result = validate_message_content(long_message) + assert not result.is_valid + assert "too long" in result.error_message.lower() + + def test_validate_url(self): + """Test URL validation.""" + from src.utils.validators import validate_url + + # Valid URLs + assert validate_url("https://example.com").is_valid + assert validate_url("http://localhost:8080/path").is_valid + assert validate_url("https://api.example.com/v1/data?q=test").is_valid + + # Invalid URLs + assert not validate_url("").is_valid + assert not validate_url("not-a-url").is_valid + assert not validate_url("javascript:alert(1)").is_valid + assert not validate_url("file:///etc/passwd").is_valid + + def test_validate_filename(self): + """Test filename validation and sanitization.""" + from src.utils.validators import validate_filename + + # Valid filename + result = validate_filename("test_file.txt") + assert result.is_valid + assert result.sanitized_value == "test_file.txt" + + # Path traversal attempt + result = validate_filename("../../../etc/passwd") + assert result.is_valid # Sanitized, not rejected + assert ".." not in result.sanitized_value + assert "/" not in result.sanitized_value + + # Empty filename + result = validate_filename("") + assert not result.is_valid + + def test_sanitize_for_logging(self): + """Test that secrets are properly redacted for logging.""" + from src.utils.validators import sanitize_for_logging + + # Test OpenAI key redaction + text = "API key is sk-abcdefghijklmnopqrstuvwxyz123456" + sanitized = sanitize_for_logging(text) + assert "sk-" not in sanitized + assert "[OPENAI_KEY]" in sanitized + + # Test MongoDB URI redaction + text = "mongodb+srv://user:password@cluster.mongodb.net/db" + sanitized = sanitize_for_logging(text) + assert "password" not in sanitized + assert "[REDACTED]" in sanitized + + # Test truncation + long_text = "x" * 500 + sanitized = sanitize_for_logging(long_text, max_length=100) + assert len(sanitized) < 150 # Account for truncation marker + + +# ============================================================ +# Retry Module Tests +# ============================================================ + +class TestRetryModule: + """Tests for retry utilities.""" + + @pytest.mark.asyncio + async def test_retry_success_first_try(self): + """Test that successful functions don't retry.""" + from src.utils.retry import async_retry_with_backoff + + call_count = 0 + + async def success_func(): + nonlocal call_count + call_count += 1 + return "success" + + result = await async_retry_with_backoff(success_func, max_retries=3) + assert result == "success" + assert call_count == 1 + + @pytest.mark.asyncio + async def test_retry_eventual_success(self): + """Test that functions eventually succeed after retries.""" + from src.utils.retry import async_retry_with_backoff + + call_count = 0 + + async def eventual_success(): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise ConnectionError("Temporary failure") + return "success" + + result = await async_retry_with_backoff( + eventual_success, + max_retries=5, + base_delay=0.01, # Fast for testing + retryable_exceptions=(ConnectionError,) + ) + assert result == "success" + assert call_count == 3 + + @pytest.mark.asyncio + async def test_retry_exhausted(self): + """Test that RetryError is raised when retries are exhausted.""" + from src.utils.retry import async_retry_with_backoff, RetryError + + async def always_fail(): + raise ConnectionError("Always fails") + + with pytest.raises(RetryError): + await async_retry_with_backoff( + always_fail, + max_retries=2, + base_delay=0.01, + retryable_exceptions=(ConnectionError,) + ) + + +# ============================================================ +# Discord Utils Tests +# ============================================================ + +class TestDiscordUtils: + """Tests for Discord utility functions.""" + + def test_split_message_short(self): + """Test that short messages aren't split.""" + from src.utils.discord_utils import split_message + + short = "This is a short message." + chunks = split_message(short) + assert len(chunks) == 1 + assert chunks[0] == short + + def test_split_message_long(self): + """Test that long messages are properly split.""" + from src.utils.discord_utils import split_message + + # Create a message longer than 2000 characters + long = "Hello world. " * 200 + chunks = split_message(long, max_length=2000) + + assert len(chunks) > 1 + for chunk in chunks: + assert len(chunk) <= 2000 + + def test_split_code_block(self): + """Test code block splitting.""" + from src.utils.discord_utils import split_code_block + + code = "\n".join([f"line {i}" for i in range(100)]) + chunks = split_code_block(code, "python", max_length=500) + + assert len(chunks) > 1 + for chunk in chunks: + assert chunk.startswith("```python\n") + assert chunk.endswith("\n```") + assert len(chunk) <= 500 + + def test_create_error_embed(self): + """Test error embed creation.""" + from src.utils.discord_utils import create_error_embed + import discord + + embed = create_error_embed("Test Error", "Something went wrong", "ValidationError") + + assert isinstance(embed, discord.Embed) + assert "Test Error" in embed.title + assert embed.color == discord.Color.red() + + def test_create_success_embed(self): + """Test success embed creation.""" + from src.utils.discord_utils import create_success_embed + import discord + + embed = create_success_embed("Success!", "Operation completed") + + assert isinstance(embed, discord.Embed) + assert "Success!" in embed.title + assert embed.color == discord.Color.green() + + +# ============================================================ +# Code Interpreter Security Tests +# ============================================================ + +class TestCodeInterpreterSecurity: + """Tests for code interpreter security features.""" + + def test_blocked_imports(self): + """Test that dangerous imports are blocked.""" + from src.utils.code_interpreter import BLOCKED_PATTERNS + import re + + dangerous_code = [ + "import os", + "import subprocess", + "from os import system", + "import socket", + "import requests", + "__import__('os')", + "eval('print(1)')", + "exec('import os')", + ] + + for code in dangerous_code: + blocked = any( + re.search(pattern, code, re.IGNORECASE) + for pattern in BLOCKED_PATTERNS + ) + assert blocked, f"Should block: {code}" + + def test_allowed_imports(self): + """Test that safe imports are allowed.""" + from src.utils.code_interpreter import BLOCKED_PATTERNS + import re + + safe_code = [ + "import pandas as pd", + "import numpy as np", + "import matplotlib.pyplot as plt", + "from sklearn.model_selection import train_test_split", + "import os.path", # os.path is allowed + ] + + for code in safe_code: + blocked = any( + re.search(pattern, code, re.IGNORECASE) + for pattern in BLOCKED_PATTERNS + ) + assert not blocked, f"Should allow: {code}" + + def test_file_type_detection(self): + """Test file type detection for various extensions.""" + from src.utils.code_interpreter import FileManager + + fm = FileManager() + + assert fm._detect_file_type("data.csv") == "csv" + assert fm._detect_file_type("data.xlsx") == "excel" + assert fm._detect_file_type("config.json") == "json" + assert fm._detect_file_type("image.png") == "image" + assert fm._detect_file_type("script.py") == "python" + assert fm._detect_file_type("unknown.xyz") == "binary" + + +# ============================================================ +# OpenAI Utils Tests +# ============================================================ + +class TestOpenAIUtils: + """Tests for OpenAI utility functions.""" + + def test_count_tokens(self): + """Test token counting function.""" + from src.utils.openai_utils import count_tokens + + text = "Hello, world!" + tokens = count_tokens(text) + assert tokens > 0 + assert isinstance(tokens, int) + + def test_trim_content_to_token_limit(self): + """Test content trimming.""" + from src.utils.openai_utils import trim_content_to_token_limit + + # Short content should not be trimmed + short = "Hello, world!" + trimmed = trim_content_to_token_limit(short, max_tokens=100) + assert trimmed == short + + # Long content should be trimmed + long = "Hello " * 10000 + trimmed = trim_content_to_token_limit(long, max_tokens=100) + assert len(trimmed) < len(long) + + def test_prepare_messages_for_api(self): + """Test message preparation for API.""" + from src.utils.openai_utils import prepare_messages_for_api + + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"}, + ] + + prepared = prepare_messages_for_api(messages) + + assert len(prepared) == 3 + assert all(m.get("role") in ["user", "assistant", "system"] for m in prepared) + + def test_prepare_messages_filters_none_content(self): + """Test that messages with None content are filtered.""" + from src.utils.openai_utils import prepare_messages_for_api + + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": None}, + {"role": "user", "content": "World"}, + ] + + prepared = prepare_messages_for_api(messages) + + assert len(prepared) == 2 + + +# ============================================================ +# Database Handler Tests (with mocking) +# ============================================================ + +class TestDatabaseHandlerMocked: + """Tests for database handler using mocks.""" + + def test_filter_expired_images_no_images(self): + """Test that messages without images pass through unchanged.""" + from src.database.db_handler import DatabaseHandler + + with patch('motor.motor_asyncio.AsyncIOMotorClient'): + handler = DatabaseHandler("mongodb://localhost") + + history = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + + filtered = handler._filter_expired_images(history) + assert len(filtered) == 2 + assert filtered[0]["content"] == "Hello" + + def test_filter_expired_images_recent_image(self): + """Test that recent images are kept.""" + from src.database.db_handler import DatabaseHandler + + with patch('motor.motor_asyncio.AsyncIOMotorClient'): + handler = DatabaseHandler("mongodb://localhost") + + recent_timestamp = datetime.now().isoformat() + history = [ + {"role": "user", "content": [ + {"type": "text", "text": "Check this image"}, + {"type": "image_url", "image_url": {"url": "https://example.com/img.jpg"}, "timestamp": recent_timestamp} + ]} + ] + + filtered = handler._filter_expired_images(history) + assert len(filtered) == 1 + assert len(filtered[0]["content"]) == 2 # Both items kept + + def test_filter_expired_images_old_image(self): + """Test that old images are filtered out.""" + from src.database.db_handler import DatabaseHandler + + with patch('motor.motor_asyncio.AsyncIOMotorClient'): + handler = DatabaseHandler("mongodb://localhost") + + old_timestamp = (datetime.now() - timedelta(hours=24)).isoformat() + history = [ + {"role": "user", "content": [ + {"type": "text", "text": "Check this image"}, + {"type": "image_url", "image_url": {"url": "https://example.com/img.jpg"}, "timestamp": old_timestamp} + ]} + ] + + filtered = handler._filter_expired_images(history) + assert len(filtered) == 1 + assert len(filtered[0]["content"]) == 1 # Only text kept + + +# ============================================================ +# ============================================================ +# Cache Module Tests +# ============================================================ + +class TestLRUCache: + """Tests for the LRU cache implementation.""" + + @pytest.mark.asyncio + async def test_cache_set_and_get(self): + """Test basic cache set and get operations.""" + from src.utils.cache import LRUCache + + cache = LRUCache(max_size=100, default_ttl=60.0) + + await cache.set("key1", "value1") + result = await cache.get("key1") + assert result == "value1" + + @pytest.mark.asyncio + async def test_cache_expiration(self): + """Test that cache entries expire after TTL.""" + from src.utils.cache import LRUCache + + cache = LRUCache(max_size=100, default_ttl=0.1) # 100ms TTL + + await cache.set("key1", "value1") + + # Should exist immediately + assert await cache.get("key1") == "value1" + + # Wait for expiration + await asyncio.sleep(0.15) + + # Should be expired now + assert await cache.get("key1") is None + + @pytest.mark.asyncio + async def test_cache_lru_eviction(self): + """Test that LRU eviction works correctly.""" + from src.utils.cache import LRUCache + + cache = LRUCache(max_size=3, default_ttl=60.0) + + await cache.set("key1", "value1") + await cache.set("key2", "value2") + await cache.set("key3", "value3") + + # Access key1 to make it recently used + await cache.get("key1") + + # Add new key, should evict key2 (least recently used) + await cache.set("key4", "value4") + + assert await cache.get("key1") == "value1" # Should exist + assert await cache.get("key2") is None # Should be evicted + assert await cache.get("key3") == "value3" # Should exist + assert await cache.get("key4") == "value4" # Should exist + + @pytest.mark.asyncio + async def test_cache_stats(self): + """Test cache statistics tracking.""" + from src.utils.cache import LRUCache + + cache = LRUCache(max_size=100, default_ttl=60.0) + + await cache.set("key1", "value1") + await cache.get("key1") # Hit + await cache.get("key2") # Miss + await cache.get("key1") # Hit + + stats = cache.stats() + assert stats["hits"] == 2 + assert stats["misses"] == 1 + assert stats["size"] == 1 + + @pytest.mark.asyncio + async def test_cache_clear(self): + """Test cache clearing.""" + from src.utils.cache import LRUCache + + cache = LRUCache(max_size=100, default_ttl=60.0) + + await cache.set("key1", "value1") + await cache.set("key2", "value2") + + cleared = await cache.clear() + assert cleared == 2 + + assert await cache.get("key1") is None + assert await cache.get("key2") is None + + +# ============================================================ +# Monitoring Module Tests +# ============================================================ + +class TestMonitoring: + """Tests for the monitoring utilities.""" + + def test_performance_metrics(self): + """Test performance metrics tracking.""" + from src.utils.monitoring import PerformanceMetrics + import time + + metrics = PerformanceMetrics(name="test_operation") + time.sleep(0.01) # Small delay + metrics.finish(success=True) + + assert metrics.success + assert metrics.duration_ms > 0 + assert metrics.duration_ms < 1000 # Should be fast + + def test_measure_sync_context_manager(self): + """Test synchronous measurement context manager.""" + from src.utils.monitoring import measure_sync + import time + + with measure_sync("test_op", custom_field="value") as metrics: + time.sleep(0.01) + + assert metrics.duration_ms > 0 + assert metrics.metadata["custom_field"] == "value" + + @pytest.mark.asyncio + async def test_measure_async_context_manager(self): + """Test async measurement context manager.""" + from src.utils.monitoring import measure_async + + async with measure_async("async_op") as metrics: + await asyncio.sleep(0.01) + + assert metrics.duration_ms > 0 + assert metrics.success + + @pytest.mark.asyncio + async def test_track_performance_decorator(self): + """Test performance tracking decorator.""" + from src.utils.monitoring import track_performance + + call_count = 0 + + @track_performance("tracked_function") + async def tracked_func(): + nonlocal call_count + call_count += 1 + return "result" + + result = await tracked_func() + assert result == "result" + assert call_count == 1 + + def test_health_status(self): + """Test health status structure.""" + from src.utils.monitoring import HealthStatus + + status = HealthStatus(healthy=True) + + status.add_check("database", True, "Connected") + status.add_check("api", False, "Timeout") + + assert not status.healthy # Should be unhealthy due to API check + assert status.checks["database"]["healthy"] + assert not status.checks["api"]["healthy"] + + +# ============================================================ +# Integration Tests (require environment setup) +# ============================================================ + +@pytest.mark.integration +class TestIntegration: + """Integration tests that require actual services.""" + + @pytest.mark.asyncio + async def test_database_connection(self): + """Test actual database connection (skip if no MongoDB).""" + from dotenv import load_dotenv + load_dotenv() + + mongodb_uri = os.getenv("MONGODB_URI") + if not mongodb_uri: + pytest.skip("MONGODB_URI not set") + + from src.database.db_handler import DatabaseHandler + handler = DatabaseHandler(mongodb_uri) + + connected = await handler.ensure_connected() + assert connected + + await handler.close() + + +# ============================================================ +# Run tests +# ============================================================ + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"])