Add retry utilities, input validation, and comprehensive tests
- Implemented async retry logic with exponential backoff in `src/utils/retry.py`.
- Created input validation utilities for the Discord bot in `src/utils/validators.py`.
- Refactored the token pricing import in `src/utils/token_counter.py`.
- Added a comprehensive test suite in `tests/test_comprehensive.py` covering pricing, validators, retry logic, and Discord utilities.
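The new `src/utils/retry.py` and `src/utils/validators.py` modules do not appear in the hunks below. For orientation only, here is a minimal sketch of what an async exponential-backoff retry decorator of the kind described above could look like; every name and parameter in it is an assumption, not the actual file contents:

```python
import asyncio
import logging
import random
from functools import wraps

logger = logging.getLogger(__name__)


def async_retry(max_attempts: int = 3, base_delay: float = 1.0,
                max_delay: float = 30.0, exceptions: tuple = (Exception,)):
    """Hypothetical decorator: retry an async callable with exponential backoff and jitter."""
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            for attempt in range(1, max_attempts + 1):
                try:
                    return await func(*args, **kwargs)
                except exceptions as exc:
                    if attempt == max_attempts:
                        raise
                    # Delay doubles each attempt (1s, 2s, 4s, ...) up to max_delay, plus jitter
                    delay = min(base_delay * 2 ** (attempt - 1), max_delay) + random.uniform(0, 0.5)
                    logger.warning(f"Attempt {attempt} failed ({exc}); retrying in {delay:.1f}s")
                    await asyncio.sleep(delay)
        return wrapper
    return decorator
```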
.env.example (23 additions)
@@ -88,3 +88,26 @@ TIMEZONE=UTC
 # 168 = 1 week
 # -1 = Never expire (permanent storage)
 FILE_EXPIRATION_HOURS=48
+
+# ============================================
+# Monitoring & Observability (Optional)
+# ============================================
+
+# Sentry DSN for error tracking
+# Get from: https://sentry.io/ (create a project and copy the DSN)
+# Leave empty to disable Sentry error tracking
+SENTRY_DSN=
+
+# Environment name for Sentry (development, staging, production)
+ENVIRONMENT=development
+
+# Sentry sample rate (0.0 to 1.0) - percentage of errors to capture
+# 1.0 = 100% of errors, 0.5 = 50% of errors
+SENTRY_SAMPLE_RATE=1.0
+
+# Sentry traces sample rate for performance monitoring (0.0 to 1.0)
+# 0.1 = 10% of transactions, lower values recommended for high-traffic bots
+SENTRY_TRACES_RATE=0.1
+
+# Log level (DEBUG, INFO, WARNING, ERROR)
+LOG_LEVEL=INFO
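These new variables line up with the standard `sentry_sdk.init()` parameters; a minimal initialization sketch, assuming `sentry-sdk` is installed (the requirements file below only lists it as a commented-out optional dependency):

```python
import os

import sentry_sdk

# An empty SENTRY_DSN disables error tracking entirely
dsn = os.getenv("SENTRY_DSN", "")
if dsn:
    sentry_sdk.init(
        dsn=dsn,
        environment=os.getenv("ENVIRONMENT", "development"),
        sample_rate=float(os.getenv("SENTRY_SAMPLE_RATE", "1.0")),
        traces_sample_rate=float(os.getenv("SENTRY_TRACES_RATE", "0.1")),
    )
```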
@@ -1,19 +1,49 @@
-discord.py
-openai
-motor
-pymongo[srv]
-dnspython>=2.0.0
-pypdf
-beautifulsoup4
-requests
-aiohttp
-runware>=0.4.33
-python-dotenv
-matplotlib
-pandas
-openpyxl
-seaborn
-tzlocal
-numpy
-plotly
-tiktoken
+# Discord Bot Core
+discord.py>=2.3.0
+openai>=1.40.0
+python-dotenv>=1.0.0
+
+# Database
+motor>=3.3.0
+pymongo[srv]>=4.6.0
+dnspython>=2.5.0
+
+# Web & HTTP
+aiohttp>=3.9.0
+requests>=2.31.0
+beautifulsoup4>=4.12.0
+
+# AI & ML
+runware>=0.4.33
+tiktoken>=0.7.0
+
+# Data Processing
+pandas>=2.1.0
+numpy>=1.26.0
+openpyxl>=3.1.0
+
+# Visualization
+matplotlib>=3.8.0
+seaborn>=0.13.0
+plotly>=5.18.0
+
+# Document Processing
+pypdf>=4.0.0
+Pillow>=10.0.0
+
+# Scheduling & Time
+APScheduler>=3.10.0
+tzlocal>=5.2
+
+# Testing
+pytest>=8.0.0
+pytest-asyncio>=0.23.0
+pytest-cov>=4.1.0
+pytest-mock>=3.12.0
+
+# Code Quality
+ruff>=0.3.0
+
+# Monitoring & Logging (Optional)
+# sentry-sdk>=1.40.0  # Uncomment for error monitoring
+# python-json-logger>=2.0.0  # Uncomment for structured logging
@@ -7,36 +7,67 @@ import asyncio
 from typing import Optional, Dict, List, Any, Callable
 
 from src.config.config import MODEL_OPTIONS, PDF_ALLOWED_MODELS, DEFAULT_MODEL
+from src.config.pricing import MODEL_PRICING, calculate_cost, format_cost
 from src.utils.image_utils import ImageGenerator
 from src.utils.web_utils import google_custom_search, scrape_web_content
 from src.utils.pdf_utils import process_pdf, send_response
 from src.utils.openai_utils import prepare_file_from_path
 from src.utils.token_counter import token_counter
 from src.utils.code_interpreter import delete_all_user_files
+from src.utils.discord_utils import create_info_embed, create_error_embed, create_success_embed
-
-# Model pricing per 1M tokens (in USD)
-MODEL_PRICING = {
-    "openai/gpt-4o": {"input": 5.00, "output": 20.00},
-    "openai/gpt-4o-mini": {"input": 0.60, "output": 2.40},
-    "openai/gpt-4.1": {"input": 2.00, "output": 8.00},
-    "openai/gpt-4.1-mini": {"input": 0.40, "output": 1.60},
-    "openai/gpt-4.1-nano": {"input": 0.10, "output": 0.40},
-    "openai/gpt-5": {"input": 1.25, "output": 10.00},
-    "openai/gpt-5-mini": {"input": 0.25, "output": 2.00},
-    "openai/gpt-5-nano": {"input": 0.05, "output": 0.40},
-    "openai/gpt-5-chat": {"input": 1.25, "output": 10.00},
-    "openai/o1-preview": {"input": 15.00, "output": 60.00},
-    "openai/o1-mini": {"input": 1.10, "output": 4.40},
-    "openai/o1": {"input": 15.00, "output": 60.00},
-    "openai/o3-mini": {"input": 1.10, "output": 4.40},
-    "openai/o3": {"input": 2.00, "output": 8.00},
-    "openai/o4-mini": {"input": 2.00, "output": 8.00}
-}
 
 # Dictionary to keep track of user requests and their cooldowns
-user_requests = {}
+user_requests: Dict[int, Dict[str, Any]] = {}
 # Dictionary to store user tasks
-user_tasks = {}
+user_tasks: Dict[int, List] = {}
 
 
+# ============================================================
+# Autocomplete Functions
+# ============================================================
+
+async def model_autocomplete(
+    interaction: discord.Interaction,
+    current: str,
+) -> List[app_commands.Choice[str]]:
+    """
+    Autocomplete function for model selection.
+    Provides filtered model suggestions based on user input.
+    """
+    # Filter models based on current input
+    matches = [
+        model for model in MODEL_OPTIONS
+        if current.lower() in model.lower()
+    ]
+
+    # If no matches, show all models
+    if not matches:
+        matches = MODEL_OPTIONS
+
+    # Return up to 25 choices (Discord limit)
+    return [
+        app_commands.Choice(name=model, value=model)
+        for model in matches[:25]
+    ]
+
+
+async def image_model_autocomplete(
+    interaction: discord.Interaction,
+    current: str,
+) -> List[app_commands.Choice[str]]:
+    """
+    Autocomplete function for image generation model selection.
+    """
+    image_models = ["flux", "flux-dev", "sdxl", "realistic", "anime", "dreamshaper"]
+    matches = [m for m in image_models if current.lower() in m.lower()]
+
+    if not matches:
+        matches = image_models
+
+    return [
+        app_commands.Choice(name=model, value=model)
+        for model in matches[:25]
+    ]
+
+
 def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator: ImageGenerator):
     """
@@ -112,7 +143,7 @@ def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator
     @tree.command(name="choose_model", description="Select the AI model to use for responses.")
     @check_blacklist()
     async def choose_model(interaction: discord.Interaction):
-        """Lets users choose an AI model and saves it to the database."""
+        """Lets users choose an AI model using a dropdown menu."""
         options = [discord.SelectOption(label=model, value=model) for model in MODEL_OPTIONS]
         select_menu = discord.ui.Select(placeholder="Choose a model", options=options)
 
@@ -131,6 +162,43 @@ def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator
         view.add_item(select_menu)
         await interaction.response.send_message("Choose a model:", view=view, ephemeral=True)
 
+    @tree.command(name="set_model", description="Set AI model directly with autocomplete suggestions.")
+    @app_commands.describe(model="The AI model to use (type to search)")
+    @app_commands.autocomplete(model=model_autocomplete)
+    @check_blacklist()
+    async def set_model(interaction: discord.Interaction, model: str):
+        """Sets the AI model directly using autocomplete."""
+        user_id = interaction.user.id
+
+        # Validate the model is in the allowed list
+        if model not in MODEL_OPTIONS:
+            # Find close matches for suggestions
+            close_matches = [m for m in MODEL_OPTIONS if model.lower() in m.lower()]
+            if close_matches:
+                suggestions = ", ".join(f"`{m}`" for m in close_matches[:5])
+                await interaction.response.send_message(
+                    f"❌ Invalid model `{model}`. Did you mean: {suggestions}?",
+                    ephemeral=True
+                )
+            else:
+                await interaction.response.send_message(
+                    f"❌ Invalid model `{model}`. Use `/choose_model` to see available options.",
+                    ephemeral=True
+                )
+            return
+
+        # Save the model selection
+        await db_handler.save_user_model(user_id, model)
+
+        # Get pricing info for the selected model
+        pricing = MODEL_PRICING.get(model, {"input": 0, "output": 0})
+
+        await interaction.response.send_message(
+            f"✅ Model set to `{model}`\n"
+            f"💰 Pricing: ${pricing['input']:.2f}/1M input, ${pricing['output']:.2f}/1M output",
+            ephemeral=True
+        )
+
     @tree.command(name="search", description="Search on Google and send results to AI model.")
     @app_commands.describe(query="The search query")
     @check_blacklist()
@@ -494,16 +562,22 @@ def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator
     async def help_command(interaction: discord.Interaction):
         """Sends a list of available commands to the user."""
         help_message = (
-            "**Available commands:**\n"
-            "/choose_model - Select which AI model to use for responses (openai/gpt-4o, openai/gpt-4o-mini, openai/gpt-5, openai/gpt-5-nano, openai/gpt-5-mini, openai/gpt-5-chat, openai/o1-preview, openai/o1-mini).\n"
-            "/search `<query>` - Search Google and send results to the AI model.\n"
-            "/web `<url>` - Scrape a webpage and send the data to the AI model.\n"
-            "/generate `<prompt>` - Generate an image from a text prompt.\n"
-            "/toggle_tools - Toggle display of tool execution details (code, input, output).\n"
-            "/reset - Reset your chat history and token usage statistics.\n"
-            "/user_stat - Get information about your token usage, costs, and current model.\n"
-            "/prices - Display pricing information for all available AI models.\n"
-            "/help - Display this help message.\n"
+            "**🤖 Available Commands:**\n\n"
+            "**Model Selection:**\n"
+            "• `/choose_model` - Select AI model from a dropdown menu\n"
+            "• `/set_model <model>` - Set model directly with autocomplete\n\n"
+            "**Search & Web:**\n"
+            "• `/search <query>` - Search Google and analyze results with AI\n"
+            "• `/web <url>` - Scrape and analyze a webpage\n\n"
+            "**Image Generation:**\n"
+            "• `/generate <prompt>` - Generate images from text\n\n"
+            "**Settings & Stats:**\n"
+            "• `/toggle_tools` - Toggle tool execution details display\n"
+            "• `/user_stat` - View your token usage and costs\n"
+            "• `/prices` - Display model pricing information\n"
+            "• `/reset` - Clear your chat history and statistics\n\n"
+            "**Help:**\n"
+            "• `/help` - Display this help message\n"
         )
         await interaction.response.send_message(help_message, ephemeral=True)
 
src/config/pricing.py (new file, 100 lines)
@@ -0,0 +1,100 @@
"""
Centralized pricing configuration for OpenAI models.

This module provides a single source of truth for model pricing,
eliminating duplication across the codebase.
"""

from typing import Dict, Optional
from dataclasses import dataclass


@dataclass
class ModelPricing:
    """Pricing information for a model (per 1M tokens in USD)."""
    input: float
    output: float

    def calculate_cost(self, input_tokens: int, output_tokens: int) -> float:
        """Calculate total cost for given token counts."""
        input_cost = (input_tokens / 1_000_000) * self.input
        output_cost = (output_tokens / 1_000_000) * self.output
        return input_cost + output_cost


# Model pricing per 1M tokens (in USD)
# Centralized location - update prices here only
MODEL_PRICING: Dict[str, ModelPricing] = {
    # GPT-4o Family
    "openai/gpt-4o": ModelPricing(input=5.00, output=20.00),
    "openai/gpt-4o-mini": ModelPricing(input=0.60, output=2.40),

    # GPT-4.1 Family
    "openai/gpt-4.1": ModelPricing(input=2.00, output=8.00),
    "openai/gpt-4.1-mini": ModelPricing(input=0.40, output=1.60),
    "openai/gpt-4.1-nano": ModelPricing(input=0.10, output=0.40),

    # GPT-5 Family
    "openai/gpt-5": ModelPricing(input=1.25, output=10.00),
    "openai/gpt-5-mini": ModelPricing(input=0.25, output=2.00),
    "openai/gpt-5-nano": ModelPricing(input=0.05, output=0.40),
    "openai/gpt-5-chat": ModelPricing(input=1.25, output=10.00),

    # o1 Family (Reasoning models)
    "openai/o1-preview": ModelPricing(input=15.00, output=60.00),
    "openai/o1-mini": ModelPricing(input=1.10, output=4.40),
    "openai/o1": ModelPricing(input=15.00, output=60.00),

    # o3 Family
    "openai/o3-mini": ModelPricing(input=1.10, output=4.40),
    "openai/o3": ModelPricing(input=2.00, output=8.00),

    # o4 Family
    "openai/o4-mini": ModelPricing(input=2.00, output=8.00),
}


def get_model_pricing(model: str) -> Optional[ModelPricing]:
    """
    Get pricing for a specific model.

    Args:
        model: The model name (e.g., "openai/gpt-4o")

    Returns:
        ModelPricing object or None if model not found
    """
    return MODEL_PRICING.get(model)


def calculate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
    """
    Calculate the cost for a given model and token counts.

    Args:
        model: The model name
        input_tokens: Number of input tokens
        output_tokens: Number of output tokens

    Returns:
        Total cost in USD, or 0.0 if model not found
    """
    pricing = get_model_pricing(model)
    if pricing:
        return pricing.calculate_cost(input_tokens, output_tokens)
    return 0.0


def get_all_models() -> list:
    """Get list of all available models with pricing."""
    return list(MODEL_PRICING.keys())


def format_cost(cost: float) -> str:
    """Format cost for display."""
    if cost < 0.01:
        return f"${cost:.6f}"
    elif cost < 1.00:
        return f"${cost:.4f}"
    else:
        return f"${cost:.2f}"
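A quick usage sketch of the new pricing helpers defined above; the numbers follow directly from the pricing table:

```python
from src.config.pricing import calculate_cost, format_cost, get_model_pricing

# 10,000 input + 2,000 output tokens on gpt-4o-mini:
# (10_000 / 1_000_000) * 0.60 + (2_000 / 1_000_000) * 2.40 = 0.006 + 0.0048 = 0.0108 USD
cost = calculate_cost("openai/gpt-4o-mini", input_tokens=10_000, output_tokens=2_000)
print(format_cost(cost))                    # "$0.0108" (four decimals, since 0.01 <= cost < 1)
print(get_model_pricing("unknown/model"))   # None for models missing from the table
```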
@@ -160,7 +160,12 @@ class DatabaseHandler:
         return []
 
     def _filter_expired_images(self, history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
-        """Filter out image links that are older than 23 hours"""
+        """
+        Filter out image links that are older than 23 hours.
+
+        Properly handles timezone-aware and timezone-naive datetime comparisons
+        to prevent issues with ISO string parsing.
+        """
         current_time = datetime.now()
         expiration_time = current_time - timedelta(hours=23)
 
@@ -183,11 +188,27 @@ class DatabaseHandler:
             # Check image items for timestamp
             elif item.get('type') == 'image_url':
                 # If there's no timestamp or timestamp is newer than expiration time, keep it
-                timestamp = item.get('timestamp')
-                if not timestamp or datetime.fromisoformat(timestamp) > expiration_time:
+                timestamp_str = item.get('timestamp')
+                if not timestamp_str:
+                    # No timestamp, keep the image
                     filtered_content.append(item)
                 else:
-                    logging.info(f"Filtering out expired image URL (added at {timestamp})")
+                    try:
+                        # Parse the ISO timestamp, handling both timezone-aware and naive
+                        timestamp = datetime.fromisoformat(timestamp_str.replace('Z', '+00:00'))
+
+                        # Make comparison timezone-naive for consistency
+                        if timestamp.tzinfo is not None:
+                            timestamp = timestamp.replace(tzinfo=None)
+
+                        if timestamp > expiration_time:
+                            filtered_content.append(item)
+                        else:
+                            logging.debug(f"Filtering out expired image URL (added at {timestamp_str})")
+                    except (ValueError, AttributeError) as e:
+                        # If we can't parse the timestamp, keep the image to be safe
+                        logging.warning(f"Could not parse image timestamp '{timestamp_str}': {e}")
+                        filtered_content.append(item)
 
             # Update the message with filtered content
             if filtered_content:
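The `'Z'` rewrite in this hunk matters because `datetime.fromisoformat()` only accepts a trailing `Z` on Python 3.11 and later; a small illustration of the parsing path used above (the timestamp literal is just a sample value):

```python
from datetime import datetime

ts = "2024-05-01T12:00:00Z"
# On Python < 3.11, fromisoformat() raises ValueError for the 'Z' suffix,
# hence the replace('Z', '+00:00') before parsing.
parsed = datetime.fromisoformat(ts.replace('Z', '+00:00'))
print(parsed.tzinfo)                 # UTC offset (timezone-aware)
print(parsed.replace(tzinfo=None))   # stripped to naive, comparable with datetime.now()
```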
@@ -5,7 +5,7 @@ import logging
 import time
 import functools
 import concurrent.futures
-from typing import Dict, Any, List
+from typing import Dict, Any, List, Optional
 import io
 import aiohttp
 import os
@@ -19,32 +19,16 @@ from src.utils.pdf_utils import process_pdf, send_response
 from src.utils.code_utils import extract_code_blocks
 from src.utils.reminder_utils import ReminderManager
 from src.config.config import PDF_ALLOWED_MODELS, MODEL_TOKEN_LIMITS, DEFAULT_TOKEN_LIMIT, DEFAULT_MODEL
+from src.config.pricing import MODEL_PRICING, calculate_cost, format_cost
+from src.utils.validators import validate_message_content, validate_prompt, sanitize_for_logging
+from src.utils.discord_utils import send_long_message, create_error_embed, create_progress_embed
 
 # Global task and rate limiting tracking
-user_tasks = {}
-user_last_request = {}
+user_tasks: Dict[int, Dict] = {}
+user_last_request: Dict[int, List[float]] = {}
 RATE_LIMIT_WINDOW = 5  # seconds
 MAX_REQUESTS = 3  # max requests per window
 
-# Model pricing per 1M tokens (in USD)
-MODEL_PRICING = {
-    "openai/gpt-4o": {"input": 5.00, "output": 20.00},
-    "openai/gpt-4o-mini": {"input": 0.60, "output": 2.40},
-    "openai/gpt-4.1": {"input": 2.00, "output": 8.00},
-    "openai/gpt-4.1-mini": {"input": 0.40, "output": 1.60},
-    "openai/gpt-4.1-nano": {"input": 0.10, "output": 0.40},
-    "openai/gpt-5": {"input": 1.25, "output": 10.00},
-    "openai/gpt-5-mini": {"input": 0.25, "output": 2.00},
-    "openai/gpt-5-nano": {"input": 0.05, "output": 0.40},
-    "openai/gpt-5-chat": {"input": 1.25, "output": 10.00},
-    "openai/o1-preview": {"input": 15.00, "output": 60.00},
-    "openai/o1-mini": {"input": 1.10, "output": 4.40},
-    "openai/o1": {"input": 15.00, "output": 60.00},
-    "openai/o3-mini": {"input": 1.10, "output": 4.40},
-    "openai/o3": {"input": 2.00, "output": 8.00},
-    "openai/o4-mini": {"input": 2.00, "output": 8.00}
-}
-
 # File extensions that should be treated as text files
 TEXT_FILE_EXTENSIONS = [
     '.txt', '.md', '.csv', '.json', '.xml', '.html', '.htm', '.css',
@@ -1598,13 +1582,11 @@ print("\\n=== Correlation Analysis ===")
         output_tokens = getattr(response.usage, 'completion_tokens', 0)
 
         # Calculate cost based on model pricing
-        if model in MODEL_PRICING:
-            pricing = MODEL_PRICING[model]
-            input_cost = (input_tokens / 1_000_000) * pricing["input"]
-            output_cost = (output_tokens / 1_000_000) * pricing["output"]
-            total_cost = input_cost + output_cost
+        pricing = MODEL_PRICING.get(model)
+        if pricing:
+            total_cost = pricing.calculate_cost(input_tokens, output_tokens)
 
-            logging.info(f"API call - Model: {model}, Input tokens: {input_tokens}, Output tokens: {output_tokens}, Cost: ${total_cost:.6f}")
+            logging.info(f"API call - Model: {model}, Input tokens: {input_tokens}, Output tokens: {output_tokens}, Cost: {format_cost(total_cost)}")
 
             # Save token usage and cost to database
             await self.db.save_token_usage(user_id, model, input_tokens, output_tokens, total_cost)
@@ -1704,14 +1686,12 @@ print("\\n=== Correlation Analysis ===")
             output_tokens += follow_up_output_tokens
 
             # Calculate additional cost
-            if model in MODEL_PRICING:
-                pricing = MODEL_PRICING[model]
-                additional_input_cost = (follow_up_input_tokens / 1_000_000) * pricing["input"]
-                additional_output_cost = (follow_up_output_tokens / 1_000_000) * pricing["output"]
-                additional_cost = additional_input_cost + additional_output_cost
+            pricing = MODEL_PRICING.get(model)
+            if pricing:
+                additional_cost = pricing.calculate_cost(follow_up_input_tokens, follow_up_output_tokens)
                 total_cost += additional_cost
 
-                logging.info(f"Follow-up API call - Model: {model}, Input tokens: {follow_up_input_tokens}, Output tokens: {follow_up_output_tokens}, Additional cost: ${additional_cost:.6f}")
+                logging.info(f"Follow-up API call - Model: {model}, Input tokens: {follow_up_input_tokens}, Output tokens: {follow_up_output_tokens}, Additional cost: {format_cost(additional_cost)}")
 
                 # Save additional token usage and cost to database
                 await self.db.save_token_usage(user_id, model, follow_up_input_tokens, follow_up_output_tokens, additional_cost)
@@ -1791,7 +1771,7 @@ print("\\n=== Correlation Analysis ===")
 
         # Log processing time and cost for performance monitoring
         processing_time = time.time() - start_time
-        logging.info(f"Message processed in {processing_time:.2f} seconds (User: {user_id}, Model: {model}, Cost: ${total_cost:.6f})")
+        logging.info(f"Message processed in {processing_time:.2f} seconds (User: {user_id}, Model: {model}, Cost: {format_cost(total_cost)})")
 
     except asyncio.CancelledError:
         # Handle cancellation cleanly
src/utils/cache.py (new file, 358 lines)
@@ -0,0 +1,358 @@
"""
Simple caching utilities for API responses and frequently accessed data.

This module provides an in-memory LRU cache with optional TTL (time-to-live)
support, designed for caching API responses and reducing redundant calls.
"""

import asyncio
import time
import logging
from typing import Any, Dict, Optional, Callable, TypeVar, Generic
from collections import OrderedDict
from dataclasses import dataclass, field
from functools import wraps

logger = logging.getLogger(__name__)

T = TypeVar('T')


@dataclass
class CacheEntry(Generic[T]):
    """A single cache entry with value and expiration time."""
    value: T
    expires_at: float
    created_at: float = field(default_factory=time.time)
    hits: int = 0


class LRUCache(Generic[T]):
    """
    Thread-safe LRU (Least Recently Used) cache with TTL support.

    Features:
    - Configurable max size with automatic eviction
    - Per-entry TTL (time-to-live)
    - Automatic cleanup of expired entries
    - Hit/miss statistics tracking

    Usage:
        cache = LRUCache(max_size=1000, default_ttl=300)  # 5 min TTL
        cache.set("key", "value")
        value = cache.get("key")  # Returns value or None if expired
    """

    def __init__(
        self,
        max_size: int = 1000,
        default_ttl: float = 300.0,  # 5 minutes default
        cleanup_interval: float = 60.0
    ):
        """
        Initialize the LRU cache.

        Args:
            max_size: Maximum number of entries
            default_ttl: Default TTL in seconds
            cleanup_interval: How often to run cleanup (seconds)
        """
        self._cache: OrderedDict[str, CacheEntry[T]] = OrderedDict()
        self._max_size = max_size
        self._default_ttl = default_ttl
        self._cleanup_interval = cleanup_interval
        self._lock = asyncio.Lock()

        # Statistics
        self._hits = 0
        self._misses = 0

        # Background cleanup task
        self._cleanup_task: Optional[asyncio.Task] = None

    async def start(self) -> None:
        """Start the background cleanup task."""
        if self._cleanup_task is None:
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())
            logger.debug("Cache cleanup task started")

    async def stop(self) -> None:
        """Stop the background cleanup task."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None
            logger.debug("Cache cleanup task stopped")

    async def _cleanup_loop(self) -> None:
        """Background task to periodically clean up expired entries."""
        while True:
            await asyncio.sleep(self._cleanup_interval)
            await self._cleanup_expired()

    async def _cleanup_expired(self) -> int:
        """Remove expired entries. Returns count of removed entries."""
        now = time.time()
        removed = 0

        async with self._lock:
            keys_to_remove = [
                key for key, entry in self._cache.items()
                if entry.expires_at <= now
            ]

            for key in keys_to_remove:
                del self._cache[key]
                removed += 1

        if removed > 0:
            logger.debug(f"Cache cleanup: removed {removed} expired entries")

        return removed

    async def get(self, key: str) -> Optional[T]:
        """
        Get a value from the cache.

        Args:
            key: Cache key

        Returns:
            Cached value or None if not found/expired
        """
        async with self._lock:
            if key not in self._cache:
                self._misses += 1
                return None

            entry = self._cache[key]

            # Check if expired
            if entry.expires_at <= time.time():
                del self._cache[key]
                self._misses += 1
                return None

            # Move to end (most recently used)
            self._cache.move_to_end(key)
            entry.hits += 1
            self._hits += 1

            return entry.value

    async def set(
        self,
        key: str,
        value: T,
        ttl: Optional[float] = None
    ) -> None:
        """
        Set a value in the cache.

        Args:
            key: Cache key
            value: Value to cache
            ttl: Optional TTL override (uses default if not provided)
        """
        ttl = ttl if ttl is not None else self._default_ttl
        expires_at = time.time() + ttl

        async with self._lock:
            # Remove oldest entries if at capacity
            while len(self._cache) >= self._max_size:
                oldest_key = next(iter(self._cache))
                del self._cache[oldest_key]
                logger.debug(f"Cache evicted oldest entry: {oldest_key}")

            self._cache[key] = CacheEntry(
                value=value,
                expires_at=expires_at
            )
            self._cache.move_to_end(key)

    async def delete(self, key: str) -> bool:
        """
        Delete a key from the cache.

        Args:
            key: Cache key

        Returns:
            True if key was found and deleted
        """
        async with self._lock:
            if key in self._cache:
                del self._cache[key]
                return True
            return False

    async def clear(self) -> int:
        """
        Clear all entries from the cache.

        Returns:
            Number of entries cleared
        """
        async with self._lock:
            count = len(self._cache)
            self._cache.clear()
            return count

    async def has(self, key: str) -> bool:
        """Check if a key exists and is not expired."""
        return await self.get(key) is not None

    def stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dict with size, hits, misses, hit_rate
        """
        total = self._hits + self._misses
        hit_rate = (self._hits / total * 100) if total > 0 else 0.0

        return {
            "size": len(self._cache),
            "max_size": self._max_size,
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": f"{hit_rate:.2f}%",
            "default_ttl": self._default_ttl
        }


# Global cache instances for different purposes
_api_response_cache: Optional[LRUCache[Dict[str, Any]]] = None
_user_preference_cache: Optional[LRUCache[Dict[str, Any]]] = None


async def get_api_cache() -> LRUCache[Dict[str, Any]]:
    """Get or create the API response cache."""
    global _api_response_cache
    if _api_response_cache is None:
        _api_response_cache = LRUCache(
            max_size=500,
            default_ttl=300.0  # 5 minutes
        )
        await _api_response_cache.start()
    return _api_response_cache


async def get_user_cache() -> LRUCache[Dict[str, Any]]:
    """Get or create the user preference cache."""
    global _user_preference_cache
    if _user_preference_cache is None:
        _user_preference_cache = LRUCache(
            max_size=1000,
            default_ttl=600.0  # 10 minutes
        )
        await _user_preference_cache.start()
    return _user_preference_cache


def cached(
    cache_key_func: Callable[..., str],
    ttl: Optional[float] = None,
    cache_getter: Callable = get_api_cache
):
    """
    Decorator to cache async function results.

    Args:
        cache_key_func: Function to generate cache key from args
        ttl: Optional TTL override
        cache_getter: Function to get the cache instance

    Usage:
        @cached(
            cache_key_func=lambda user_id: f"user:{user_id}",
            ttl=300
        )
        async def get_user_data(user_id: int) -> dict:
            # Expensive operation
            return await fetch_from_api(user_id)
    """
    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            cache = await cache_getter()
            key = cache_key_func(*args, **kwargs)

            # Try to get from cache
            cached_value = await cache.get(key)
            if cached_value is not None:
                logger.debug(f"Cache hit for key: {key}")
                return cached_value

            # Execute function and cache result
            result = await func(*args, **kwargs)
            await cache.set(key, result, ttl=ttl)
            logger.debug(f"Cached result for key: {key}")

            return result

        return wrapper
    return decorator


def invalidate_on_update(
    cache_key_func: Callable[..., str],
    cache_getter: Callable = get_api_cache
):
    """
    Decorator to invalidate cache when a function (update operation) is called.

    Args:
        cache_key_func: Function to generate cache key to invalidate
        cache_getter: Function to get the cache instance

    Usage:
        @invalidate_on_update(
            cache_key_func=lambda user_id, **_: f"user:{user_id}"
        )
        async def update_user_data(user_id: int, data: dict) -> None:
            await save_to_db(user_id, data)
    """
    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            result = await func(*args, **kwargs)

            # Invalidate cache after update
            cache = await cache_getter()
            key = cache_key_func(*args, **kwargs)
            await cache.delete(key)
            logger.debug(f"Invalidated cache for key: {key}")

            return result

        return wrapper
    return decorator


# Convenience functions for common caching patterns

async def cache_user_model(user_id: int, model: str) -> None:
    """Cache user's selected model."""
    cache = await get_user_cache()
    await cache.set(f"user_model:{user_id}", {"model": model})


async def get_cached_user_model(user_id: int) -> Optional[str]:
    """Get user's cached model selection."""
    cache = await get_user_cache()
    result = await cache.get(f"user_model:{user_id}")
    return result["model"] if result else None


async def invalidate_user_cache(user_id: int) -> None:
    """Invalidate all cached data for a user."""
    cache = await get_user_cache()
    # Clear known user-related keys
    await cache.delete(f"user_model:{user_id}")
    await cache.delete(f"user_history:{user_id}")
    await cache.delete(f"user_stats:{user_id}")
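A short usage sketch of the cache module as defined above; `fetch_profile` is a hypothetical stand-in for any expensive async call:

```python
import asyncio

from src.utils.cache import LRUCache, cached


@cached(cache_key_func=lambda user_id: f"profile:{user_id}", ttl=120)
async def fetch_profile(user_id: int) -> dict:
    # Hypothetical expensive lookup; a second call within 120s is served from the cache
    await asyncio.sleep(0.1)
    return {"id": user_id}


async def main():
    await fetch_profile(42)   # miss: runs the function and stores the result
    await fetch_profile(42)   # hit: returned from the API response cache

    cache = LRUCache(max_size=10, default_ttl=60)
    await cache.set("greeting", "hello")
    print(await cache.get("greeting"))   # "hello"
    print(cache.stats())                 # size / hits / misses / hit_rate summary

asyncio.run(main())
```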
@@ -71,19 +71,40 @@ APPROVED_PACKAGES = {
     'more-itertools', 'toolz', 'cytoolz', 'funcy'
 }
 
-# Blocked patterns
+# Blocked patterns - Comprehensive security checks
 # Note: We allow open() for writing to enable saving plots and outputs
 # The sandboxed environment restricts file access to safe directories
 BLOCKED_PATTERNS = [
-    # Dangerous system modules
+    # ==================== DANGEROUS SYSTEM MODULES ====================
+    # OS module (except path)
     r'import\s+os\b(?!\s*\.path)',
     r'from\s+os\s+import\s+(?!path)',
+
+    # File system modules
     r'import\s+shutil\b',
     r'from\s+shutil\s+import',
+    r'import\s+pathlib\b(?!\s*\.)',  # Allow pathlib usage but monitor
+
+    # Subprocess and execution modules
     r'import\s+subprocess\b',
     r'from\s+subprocess\s+import',
-    r'import\s+sys\b(?!\s*\.(?:path|version|platform))',
-    r'from\s+sys\s+import',
+    r'import\s+multiprocessing\b',
+    r'from\s+multiprocessing\s+import',
+    r'import\s+threading\b',
+    r'from\s+threading\s+import',
+    r'import\s+concurrent\b',
+    r'from\s+concurrent\s+import',
+
+    # System access modules
+    r'import\s+sys\b(?!\s*\.(?:path|version|platform|stdout|stderr))',
+    r'from\s+sys\s+import\s+(?!path|version|platform|stdout|stderr)',
+    r'import\s+platform\b',
+    r'from\s+platform\s+import',
+    r'import\s+ctypes\b',
+    r'from\s+ctypes\s+import',
+    r'import\s+_[a-z]+',  # Block private C modules
+
+    # ==================== NETWORK MODULES ====================
     r'import\s+socket\b',
     r'from\s+socket\s+import',
     r'import\s+urllib\b',
@@ -92,19 +113,98 @@ BLOCKED_PATTERNS = [
     r'from\s+requests\s+import',
     r'import\s+aiohttp\b',
     r'from\s+aiohttp\s+import',
-    # Dangerous code execution
+    r'import\s+httpx\b',
+    r'from\s+httpx\s+import',
+    r'import\s+http\.client\b',
+    r'from\s+http\.client\s+import',
+    r'import\s+ftplib\b',
+    r'from\s+ftplib\s+import',
+    r'import\s+smtplib\b',
+    r'from\s+smtplib\s+import',
+    r'import\s+telnetlib\b',
+    r'from\s+telnetlib\s+import',
+    r'import\s+ssl\b',
+    r'from\s+ssl\s+import',
+    r'import\s+paramiko\b',
+    r'from\s+paramiko\s+import',
+
+    # ==================== DANGEROUS CODE EXECUTION ====================
     r'__import__\s*\(',
     r'\beval\s*\(',
     r'\bexec\s*\(',
     r'\bcompile\s*\(',
     r'\bglobals\s*\(',
     r'\blocals\s*\(',
-    # File system operations (dangerous)
+    r'\bgetattr\s*\([^,]+,\s*[\'"]__',  # Block getattr for dunder methods
+    r'\bsetattr\s*\([^,]+,\s*[\'"]__',  # Block setattr for dunder methods
+    r'\bdelattr\s*\([^,]+,\s*[\'"]__',  # Block delattr for dunder methods
+    r'\.\_\_\w+\_\_',  # Block dunder method access
+
+    # ==================== FILE SYSTEM OPERATIONS ====================
     r'\.unlink\s*\(',
     r'\.rmdir\s*\(',
     r'\.remove\s*\(',
     r'\.chmod\s*\(',
     r'\.chown\s*\(',
+    r'\.rmtree\s*\(',
+    r'\.rename\s*\(',
+    r'\.replace\s*\(',
+    r'\.makedirs\s*\(',  # Allow mkdir but block makedirs outside sandbox
+    r'Path\s*\(\s*[\'"]\/(?!tmp)',  # Block absolute paths outside /tmp
+    r'open\s*\(\s*[\'"]\/(?!tmp)',  # Block file access outside /tmp
+
+    # ==================== PICKLE AND SERIALIZATION ====================
+    r'pickle\.loads?\s*\(',
+    r'cPickle\.loads?\s*\(',
+    r'marshal\.loads?\s*\(',
+    r'shelve\.open\s*\(',
+
+    # ==================== PROCESS MANIPULATION ====================
+    r'os\.system\s*\(',
+    r'os\.popen\s*\(',
+    r'os\.spawn',
+    r'os\.exec',
+    r'os\.fork\s*\(',
+    r'os\.kill\s*\(',
+    r'os\.killpg\s*\(',
+
+    # ==================== ENVIRONMENT ACCESS ====================
+    r'os\.environ',
+    r'os\.getenv\s*\(',
+    r'os\.putenv\s*\(',
+
+    # ==================== DANGEROUS BUILTINS ====================
+    r'__builtins__',
+    r'__loader__',
+    r'__spec__',
+
+    # ==================== CODE OBJECT MANIPULATION ====================
+    r'\.f_code',
+    r'\.f_globals',
+    r'\.f_locals',
+    r'\.gi_frame',
+    r'\.co_code',
+    r'types\.CodeType',
+    r'types\.FunctionType',
+
+    # ==================== IMPORT SYSTEM MANIPULATION ====================
+    r'import\s+importlib\b',
+    r'from\s+importlib\s+import',
+    r'sys\.modules',
+    r'sys\.path\.(?:append|insert|extend)',
+
+    # ==================== MEMORY OPERATIONS ====================
+    r'gc\.',
+    r'sys\.getsizeof',
+    r'sys\.getrefcount',
+    r'id\s*\(',  # Block id() which can leak memory addresses
+]
+
+# Additional patterns that log warnings but don't block
+WARNING_PATTERNS = [
+    (r'while\s+True', "Infinite loop detected - ensure break condition exists"),
+    (r'for\s+\w+\s+in\s+range\s*\(\s*\d{6,}', "Very large loop detected"),
+    (r'recursion', "Recursion detected - ensure base case exists"),
 ]
 
@@ -772,10 +872,54 @@ class CodeExecutor:
             logger.warning(f"Cleanup failed: {e}")
 
     def validate_code_security(self, code: str) -> Tuple[bool, str]:
-        """Validate code for security threats."""
+        """
+        Validate code for security threats.
+
+        Performs comprehensive security checks including:
+        - Blocked patterns (dangerous imports, code execution, file ops)
+        - Warning patterns (potential issues that are logged)
+        - Code structure validation
+
+        Args:
+            code: The Python code to validate
+
+        Returns:
+            Tuple of (is_safe, message)
+        """
+        # Check for blocked patterns
         for pattern in BLOCKED_PATTERNS:
             if re.search(pattern, code, re.IGNORECASE):
-                return False, f"Blocked unsafe operation: {pattern}"
+                logger.warning(f"Blocked code pattern detected: {pattern[:50]}...")
+                return False, f"Security violation: Unsafe operation detected"
+
+        # Check for warning patterns (log but don't block)
+        for pattern, warning_msg in WARNING_PATTERNS:
+            if re.search(pattern, code, re.IGNORECASE):
+                logger.warning(f"Code warning: {warning_msg}")
+
+        # Additional structural checks
+        try:
+            # Parse the AST to check for suspicious constructs
+            tree = ast.parse(code)
+            for node in ast.walk(tree):
+                # Check for suspicious attribute access
+                if isinstance(node, ast.Attribute):
+                    if node.attr.startswith('_') and node.attr.startswith('__'):
+                        logger.warning(f"Dunder attribute access detected: {node.attr}")
+                        return False, "Security violation: Private attribute access not allowed"
+
+                # Check for suspicious function calls
+                if isinstance(node, ast.Call):
+                    if isinstance(node.func, ast.Name):
+                        if node.func.id in ['eval', 'exec', 'compile', '__import__']:
+                            return False, f"Security violation: {node.func.id}() is not allowed"
+
+        except SyntaxError:
+            # Syntax errors will be caught during execution
+            pass
+        except Exception as e:
+            logger.warning(f"Error during AST validation: {e}")
+
         return True, "Code passed security validation"
 
     def _extract_imports_from_code(self, code: str) -> List[str]:
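To make the effect of the expanded checks concrete, here is a hypothetical driver against the method above; the `CodeExecutor` constructor arguments are assumed, and the snippets are made-up inputs:

```python
executor = CodeExecutor()  # assumed no-argument construction, for illustration only

ok, msg = executor.validate_code_security(
    "import pandas as pd\ndf = pd.DataFrame({'a': [1, 2]})\nprint(df.sum())"
)
print(ok, msg)   # True, "Code passed security validation"

ok, msg = executor.validate_code_security("eval('2 + 2')")
print(ok, msg)   # False, "Security violation: Unsafe operation detected"

ok, msg = executor.validate_code_security("import os\nprint(os.environ)")
print(ok, msg)   # False - both the os import and os.environ patterns are blocked
```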
src/utils/discord_utils.py (new file, 417 lines)
@@ -0,0 +1,417 @@
"""
Discord response utilities for sending messages with proper handling.

This module provides utilities for sending messages to Discord with
proper length handling, error recovery, and formatting.
"""

import discord
import asyncio
import logging
import io
from typing import Optional, List, Union
from dataclasses import dataclass


# Discord message limits
MAX_MESSAGE_LENGTH = 2000
MAX_EMBED_DESCRIPTION = 4096
MAX_EMBED_FIELD_VALUE = 1024
MAX_EMBED_FIELDS = 25
MAX_FILE_SIZE = 8 * 1024 * 1024  # 8MB for non-nitro


@dataclass
class MessageChunk:
    """A chunk of a message that fits within Discord limits."""
    content: str
    is_code_block: bool = False
    language: Optional[str] = None


def split_message(
    content: str,
    max_length: int = MAX_MESSAGE_LENGTH,
    split_on: List[str] = None
) -> List[str]:
    """
    Split a long message into chunks that fit within Discord limits.

    Args:
        content: The message content to split
        max_length: Maximum length per chunk
        split_on: Preferred split points (default: newlines, spaces)

    Returns:
        List of message chunks
    """
    if len(content) <= max_length:
        return [content]

    if split_on is None:
        split_on = ['\n\n', '\n', '. ', ' ']

    chunks = []
    remaining = content

    while remaining:
        if len(remaining) <= max_length:
            chunks.append(remaining)
            break

        # Find the best split point
        split_index = max_length

        for delimiter in split_on:
            # Look for delimiter before max_length
            last_index = remaining.rfind(delimiter, 0, max_length)
            if last_index > max_length // 2:  # Don't split too early
                split_index = last_index + len(delimiter)
                break

        # If no good split point, hard cut at max_length
        if split_index >= max_length:
            split_index = max_length

        chunks.append(remaining[:split_index])
        remaining = remaining[split_index:]

    return chunks


def split_code_block(
    code: str,
    language: str = "",
    max_length: int = MAX_MESSAGE_LENGTH
) -> List[str]:
    """
    Split code into properly formatted code block chunks.

    Args:
        code: The code content
        language: The language for syntax highlighting
        max_length: Maximum length per chunk

    Returns:
        List of formatted code block strings
    """
    # Account for code block markers
    marker_length = len(f"```{language}\n") + len("```")
    effective_max = max_length - marker_length - 20  # Extra buffer

    lines = code.split('\n')
    chunks = []
    current_chunk = []
    current_length = 0

    for line in lines:
        line_length = len(line) + 1  # +1 for newline

        if current_length + line_length > effective_max and current_chunk:
            # Finish current chunk
            chunk_code = '\n'.join(current_chunk)
            chunks.append(f"```{language}\n{chunk_code}\n```")
            current_chunk = [line]
            current_length = line_length
        else:
            current_chunk.append(line)
            current_length += line_length

    # Add remaining chunk
    if current_chunk:
        chunk_code = '\n'.join(current_chunk)
        chunks.append(f"```{language}\n{chunk_code}\n```")

    return chunks


async def send_long_message(
    channel: discord.abc.Messageable,
    content: str,
    max_length: int = MAX_MESSAGE_LENGTH,
    delay: float = 0.5
) -> List[discord.Message]:
    """
    Send a long message split across multiple Discord messages.

    Args:
        channel: The channel to send to
        content: The message content
        max_length: Maximum length per message
        delay: Delay between messages to avoid rate limiting

    Returns:
        List of sent messages
    """
    chunks = split_message(content, max_length)
    messages = []

    for i, chunk in enumerate(chunks):
        try:
            msg = await channel.send(chunk)
            messages.append(msg)

            # Add delay between messages (except for the last one)
            if i < len(chunks) - 1:
                await asyncio.sleep(delay)

        except discord.HTTPException as e:
            logging.error(f"Failed to send message chunk {i+1}: {e}")
            # Try sending as file if message still too long
            if "too long" in str(e).lower():
                file = discord.File(
                    io.StringIO(chunk),
                    filename=f"message_part_{i+1}.txt"
                )
                msg = await channel.send(file=file)
                messages.append(msg)

    return messages


async def send_code_response(
    channel: discord.abc.Messageable,
    code: str,
    language: str = "python",
    title: Optional[str] = None
) -> List[discord.Message]:
    """
    Send code with proper formatting, handling long code.

    Args:
        channel: The channel to send to
        code: The code content
        language: Programming language for highlighting
        title: Optional title to display before code

    Returns:
        List of sent messages
    """
    messages = []

    if title:
        msg = await channel.send(title)
        messages.append(msg)

    # If code is too long for code blocks, send as file
    if len(code) > MAX_MESSAGE_LENGTH - 100:
        file = discord.File(
            io.StringIO(code),
            filename=f"code.{language}" if language else "code.txt"
        )
        msg = await channel.send("📎 Code attached as file:", file=file)
        messages.append(msg)
    else:
        chunks = split_code_block(code, language)
        for chunk in chunks:
            msg = await channel.send(chunk)
            messages.append(msg)
            await asyncio.sleep(0.3)

    return messages


def create_error_embed(
    title: str,
    description: str,
    error_type: str = "Error"
) -> discord.Embed:
    """
    Create a standardized error embed.

    Args:
        title: Error title
        description: Error description
        error_type: Type of error for categorization

    Returns:
        Discord Embed object
    """
    embed = discord.Embed(
        title=f"❌ {title}",
        description=description[:MAX_EMBED_DESCRIPTION],
        color=discord.Color.red()
    )
    embed.set_footer(text=f"Error Type: {error_type}")
    return embed


def create_success_embed(
    title: str,
    description: str = ""
) -> discord.Embed:
    """
    Create a standardized success embed.

    Args:
        title: Success title
        description: Success description

    Returns:
        Discord Embed object
    """
    embed = discord.Embed(
        title=f"✅ {title}",
        description=description[:MAX_EMBED_DESCRIPTION] if description else None,
        color=discord.Color.green()
    )
    return embed


def create_info_embed(
    title: str,
    description: str = "",
    fields: List[tuple] = None
) -> discord.Embed:
    """
    Create a standardized info embed with optional fields.

    Args:
        title: Info title
        description: Info description
        fields: List of (name, value, inline) tuples

    Returns:
        Discord Embed object
    """
    embed = discord.Embed(
        title=f"ℹ️ {title}",
        description=description[:MAX_EMBED_DESCRIPTION] if description else None,
        color=discord.Color.blue()
    )

    if fields:
        for name, value, inline in fields[:MAX_EMBED_FIELDS]:
            embed.add_field(
                name=name[:256],
                value=str(value)[:MAX_EMBED_FIELD_VALUE],
                inline=inline
            )

    return embed


def create_progress_embed(
    title: str,
    description: str,
    progress: float = 0.0
) -> discord.Embed:
    """
    Create a progress indicator embed.

    Args:
        title: Progress title
        description: Progress description
        progress: Progress value 0.0 to 1.0

    Returns:
        Discord Embed object
    """
    # Create progress bar
    bar_length = 20
    filled = int(bar_length * progress)
    bar = "█" * filled + "░" * (bar_length - filled)
    percentage = int(progress * 100)

    embed = discord.Embed(
        title=f"⏳ {title}",
        description=f"{description}\n\n`{bar}` {percentage}%",
        color=discord.Color.orange()
    )
    return embed


async def edit_or_send(
    message: Optional[discord.Message],
    channel: discord.abc.Messageable,
    content: str = None,
    embed: discord.Embed = None
) -> discord.Message:
    """
    Edit an existing message or send a new one if editing fails.

    Args:
        message: Message to edit (or None to send new)
        channel: Channel to send to if message is None
        content: Message content
        embed: Message embed

    Returns:
        The edited or new message
    """
    try:
        if message:
            await message.edit(content=content, embed=embed)
            return message
        else:
            return await channel.send(content=content, embed=embed)
    except discord.HTTPException:
        return await channel.send(content=content, embed=embed)


class ProgressMessage:
    """
    A message that can be updated to show progress.

    Usage:
        async with ProgressMessage(channel, "Processing") as progress:
            for i in range(100):
                await progress.update(i / 100, f"Step {i}")
    """

    def __init__(
        self,
        channel: discord.abc.Messageable,
        title: str,
        description: str = "Starting..."
    ):
        self.channel = channel
        self.title = title
        self.description = description
        self.message: Optional[discord.Message] = None
        self._last_update = 0.0
        self._update_interval = 2.0  # Minimum seconds between updates

    async def __aenter__(self):
        embed = create_progress_embed(self.title, self.description, 0.0)
        self.message = await self.channel.send(embed=embed)
        return self

    async def __aexit__(self, *args):
        # Clean up or finalize
        pass

    async def update(self, progress: float, description: str = None):
|
||||||
|
"""Update the progress message."""
|
||||||
|
import time
|
||||||
|
|
||||||
|
now = time.monotonic()
|
||||||
|
if now - self._last_update < self._update_interval:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._last_update = now
|
||||||
|
|
||||||
|
if description:
|
||||||
|
self.description = description
|
||||||
|
|
||||||
|
try:
|
||||||
|
embed = create_progress_embed(self.title, self.description, progress)
|
||||||
|
await self.message.edit(embed=embed)
|
||||||
|
except discord.HTTPException:
|
||||||
|
pass # Ignore edit failures
|
||||||
|
|
||||||
|
async def complete(self, message: str = "Complete!"):
|
||||||
|
"""Mark the progress as complete."""
|
||||||
|
try:
|
||||||
|
embed = create_success_embed(self.title, message)
|
||||||
|
await self.message.edit(embed=embed)
|
||||||
|
except discord.HTTPException:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def error(self, message: str):
|
||||||
|
"""Mark the progress as failed."""
|
||||||
|
try:
|
||||||
|
embed = create_error_embed(self.title, message)
|
||||||
|
await self.message.edit(embed=embed)
|
||||||
|
except discord.HTTPException:
|
||||||
|
pass
|
||||||
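The embed and progress helpers above are meant to be composed inside command handlers. As a minimal usage sketch (the `channel` object, `items` list, and `process()` coroutine are hypothetical placeholders, not part of this commit):

```python
# Hypothetical usage sketch for ProgressMessage and the embed helpers above.
async def run_batch(channel, items, process):
    async with ProgressMessage(channel, "Processing files") as progress:
        for i, item in enumerate(items):
            await process(item)  # placeholder long-running coroutine
            await progress.update((i + 1) / len(items), f"Finished {item}")
        await progress.complete("All files processed")

    # Summarize with a standard info embed
    await channel.send(embed=create_info_embed(
        "Summary",
        "Processing finished",
        fields=[("Items", len(items), True), ("Errors", 0, True)],
    ))
```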
446  src/utils/monitoring.py  Normal file
@@ -0,0 +1,446 @@
"""
Monitoring and observability utilities.

This module provides structured logging, error tracking with Sentry,
and performance monitoring for the Discord bot.
"""

import os
import logging
import time
import asyncio
from typing import Any, Dict, Optional, Callable
from functools import wraps
from contextlib import contextmanager, asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone

# Try to import Sentry
try:
    import sentry_sdk
    from sentry_sdk.integrations.asyncio import AsyncioIntegration
    SENTRY_AVAILABLE = True
except ImportError:
    SENTRY_AVAILABLE = False
    sentry_sdk = None

logger = logging.getLogger(__name__)


# ============================================================
# Configuration
# ============================================================

@dataclass
class MonitoringConfig:
    """Configuration for monitoring features."""
    sentry_dsn: Optional[str] = None
    environment: str = "development"
    sample_rate: float = 1.0  # 100% of events
    traces_sample_rate: float = 0.1  # 10% of transactions
    log_level: str = "INFO"
    structured_logging: bool = True


def setup_monitoring(config: Optional[MonitoringConfig] = None) -> None:
    """
    Initialize monitoring with optional Sentry integration.

    Args:
        config: Monitoring configuration, uses env vars if not provided
    """
    if config is None:
        config = MonitoringConfig(
            sentry_dsn=os.environ.get("SENTRY_DSN"),
            environment=os.environ.get("ENVIRONMENT", "development"),
            sample_rate=float(os.environ.get("SENTRY_SAMPLE_RATE", "1.0")),
            traces_sample_rate=float(os.environ.get("SENTRY_TRACES_RATE", "0.1")),
            log_level=os.environ.get("LOG_LEVEL", "INFO"),
        )

    # Setup logging
    setup_structured_logging(
        level=config.log_level,
        structured=config.structured_logging
    )

    # Setup Sentry if available and configured
    if SENTRY_AVAILABLE and config.sentry_dsn:
        sentry_sdk.init(
            dsn=config.sentry_dsn,
            environment=config.environment,
            sample_rate=config.sample_rate,
            traces_sample_rate=config.traces_sample_rate,
            integrations=[AsyncioIntegration()],
            before_send=before_send_filter,
        )
        logger.info(f"Sentry initialized for environment: {config.environment}")
    else:
        if config.sentry_dsn and not SENTRY_AVAILABLE:
            logger.warning("Sentry DSN provided but sentry_sdk not installed")
        logger.info("Running without Sentry error tracking")


def before_send_filter(event: Dict, hint: Dict) -> Optional[Dict]:
    """Filter events before sending to Sentry."""
    # Don't send events for expected/handled errors
    if "exc_info" in hint:
        exc_type, exc_value, _ = hint["exc_info"]

        # Skip common non-critical errors
        if exc_type.__name__ in [
            "NotFound",      # Discord 404
            "Forbidden",     # Discord 403
            "RateLimited",   # Discord rate limit
        ]:
            return None

    return event


# ============================================================
# Structured Logging
# ============================================================

class StructuredFormatter(logging.Formatter):
    """JSON-like structured log formatter."""

    def format(self, record: logging.LogRecord) -> str:
        """Format log record as structured message."""
        log_entry = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }

        # Add extra fields
        if hasattr(record, "user_id"):
            log_entry["user_id"] = record.user_id
        if hasattr(record, "guild_id"):
            log_entry["guild_id"] = record.guild_id
        if hasattr(record, "command"):
            log_entry["command"] = record.command
        if hasattr(record, "duration_ms"):
            log_entry["duration_ms"] = record.duration_ms
        if hasattr(record, "model"):
            log_entry["model"] = record.model

        # Add exception info if present
        if record.exc_info:
            log_entry["exception"] = self.formatException(record.exc_info)

        # Format as key=value pairs for easy parsing
        parts = [f"{k}={v!r}" for k, v in log_entry.items()]
        return " ".join(parts)


def setup_structured_logging(
    level: str = "INFO",
    structured: bool = True
) -> None:
    """
    Setup logging configuration.

    Args:
        level: Log level (DEBUG, INFO, WARNING, ERROR)
        structured: Use structured formatting
    """
    log_level = getattr(logging, level.upper(), logging.INFO)

    # Create handler
    handler = logging.StreamHandler()
    handler.setLevel(log_level)

    if structured:
        handler.setFormatter(StructuredFormatter())
    else:
        handler.setFormatter(logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
        ))

    # Configure root logger
    root_logger = logging.getLogger()
    root_logger.setLevel(log_level)
    root_logger.handlers = [handler]


def get_logger(name: str) -> logging.Logger:
    """Get a logger with the given name."""
    return logging.getLogger(name)


# ============================================================
# Error Tracking
# ============================================================

def capture_exception(
    exception: Exception,
    context: Optional[Dict[str, Any]] = None
) -> Optional[str]:
    """
    Capture and report an exception.

    Args:
        exception: The exception to capture
        context: Additional context to attach

    Returns:
        Event ID if sent to Sentry, None otherwise
    """
    logger.exception(f"Captured exception: {exception}")

    if SENTRY_AVAILABLE and sentry_sdk.Hub.current.client:
        with sentry_sdk.push_scope() as scope:
            if context:
                for key, value in context.items():
                    scope.set_extra(key, value)
            return sentry_sdk.capture_exception(exception)

    return None


def capture_message(
    message: str,
    level: str = "info",
    context: Optional[Dict[str, Any]] = None
) -> Optional[str]:
    """
    Capture and report a message.

    Args:
        message: The message to capture
        level: Severity level (debug, info, warning, error, fatal)
        context: Additional context to attach

    Returns:
        Event ID if sent to Sentry, None otherwise
    """
    log_method = getattr(logger, level, logger.info)
    log_method(message)

    if SENTRY_AVAILABLE and sentry_sdk.Hub.current.client:
        with sentry_sdk.push_scope() as scope:
            if context:
                for key, value in context.items():
                    scope.set_extra(key, value)
            return sentry_sdk.capture_message(message, level=level)

    return None


def set_user_context(
    user_id: int,
    username: Optional[str] = None,
    guild_id: Optional[int] = None
) -> None:
    """
    Set user context for error tracking.

    Args:
        user_id: Discord user ID
        username: Discord username
        guild_id: Discord guild ID
    """
    if SENTRY_AVAILABLE and sentry_sdk.Hub.current.client:
        sentry_sdk.set_user({
            "id": str(user_id),
            "username": username,
        })
        if guild_id:
            sentry_sdk.set_tag("guild_id", str(guild_id))


# ============================================================
# Performance Monitoring
# ============================================================

@dataclass
class PerformanceMetrics:
    """Container for performance metrics."""
    name: str
    start_time: float = field(default_factory=time.perf_counter)
    end_time: Optional[float] = None
    success: bool = True
    error: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def duration_ms(self) -> float:
        """Get duration in milliseconds."""
        end = self.end_time or time.perf_counter()
        return (end - self.start_time) * 1000

    def finish(self, success: bool = True, error: Optional[str] = None) -> None:
        """Mark the operation as finished."""
        self.end_time = time.perf_counter()
        self.success = success
        self.error = error

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for logging."""
        return {
            "name": self.name,
            "duration_ms": round(self.duration_ms, 2),
            "success": self.success,
            "error": self.error,
            **self.metadata
        }


@contextmanager
def measure_sync(name: str, **metadata):
    """
    Context manager to measure synchronous operation performance.

    Usage:
        with measure_sync("database_query", table="users"):
            result = db.query(...)
    """
    metrics = PerformanceMetrics(name=name, metadata=metadata)

    try:
        yield metrics
        metrics.finish(success=True)
    except Exception as e:
        metrics.finish(success=False, error=str(e))
        raise
    finally:
        logger.info(
            f"Performance: {metrics.name}",
            extra={"duration_ms": metrics.duration_ms, **metrics.metadata}
        )


@asynccontextmanager
async def measure_async(name: str, **metadata):
    """
    Async context manager to measure async operation performance.

    Usage:
        async with measure_async("api_call", endpoint="chat"):
            result = await api.call(...)
    """
    metrics = PerformanceMetrics(name=name, metadata=metadata)

    # Start Sentry transaction if available
    transaction = None
    if SENTRY_AVAILABLE and sentry_sdk.Hub.current.client:
        transaction = sentry_sdk.start_transaction(
            op="task",
            name=name
        )

    try:
        yield metrics
        metrics.finish(success=True)
    except Exception as e:
        metrics.finish(success=False, error=str(e))
        raise
    finally:
        if transaction:
            transaction.set_status("ok" if metrics.success else "internal_error")
            transaction.finish()

        logger.info(
            f"Performance: {metrics.name}",
            extra={"duration_ms": metrics.duration_ms, **metrics.metadata}
        )


def track_performance(name: Optional[str] = None):
    """
    Decorator to track async function performance.

    Args:
        name: Operation name (defaults to function name)

    Usage:
        @track_performance("process_message")
        async def handle_message(message):
            ...
    """
    def decorator(func: Callable):
        op_name = name or func.__name__

        @wraps(func)
        async def wrapper(*args, **kwargs):
            async with measure_async(op_name):
                return await func(*args, **kwargs)

        return wrapper

    return decorator


# ============================================================
# Health Check
# ============================================================

@dataclass
class HealthStatus:
    """Health check status."""
    healthy: bool
    checks: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    timestamp: str = field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat()
    )

    def add_check(
        self,
        name: str,
        healthy: bool,
        message: str = "",
        details: Optional[Dict] = None
    ) -> None:
        """Add a health check result."""
        self.checks[name] = {
            "healthy": healthy,
            "message": message,
            **(details or {})
        }
        if not healthy:
            self.healthy = False


async def check_health(
    db_handler=None,
    openai_client=None
) -> HealthStatus:
    """
    Perform health checks on bot dependencies.

    Args:
        db_handler: Database handler to check
        openai_client: OpenAI client to check

    Returns:
        HealthStatus with check results
    """
    status = HealthStatus(healthy=True)

    # Check database
    if db_handler:
        try:
            # Simple ping or list operation
            await asyncio.wait_for(
                db_handler.client.admin.command('ping'),
                timeout=5.0
            )
            status.add_check("database", True, "MongoDB connected")
        except Exception as e:
            status.add_check("database", False, f"MongoDB error: {e}")

    # Check OpenAI
    if openai_client:
        try:
            # List models as a simple check
            await asyncio.wait_for(
                openai_client.models.list(),
                timeout=10.0
            )
            status.add_check("openai", True, "OpenAI API accessible")
        except Exception as e:
            status.add_check("openai", False, f"OpenAI error: {e}")

    return status
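As a rough sketch of how this module would be wired into the bot — the `handle_chat` handler, `db`, and `client` objects below are hypothetical placeholders, not part of this commit:

```python
# Hypothetical startup and usage sketch for src/utils/monitoring.py.
from src.utils.monitoring import (
    setup_monitoring, track_performance, set_user_context, check_health,
)

setup_monitoring()  # reads SENTRY_DSN, ENVIRONMENT, LOG_LEVEL, ... from the environment

@track_performance("handle_chat")
async def handle_chat(message, db, client):
    set_user_context(message.author.id, message.author.name)
    ...  # normal message handling

async def health_report(db, client):
    status = await check_health(db_handler=db, openai_client=client)
    return f"healthy={status.healthy} checks={status.checks}"
```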
280  src/utils/retry.py  Normal file
@@ -0,0 +1,280 @@
"""
Retry utilities with exponential backoff for API calls.

This module provides robust retry logic for external API calls
to handle transient failures gracefully.
"""

import asyncio
import logging
import random
from typing import TypeVar, Callable, Optional, Any, Type, Tuple
from functools import wraps

T = TypeVar('T')

# Default configuration
DEFAULT_MAX_RETRIES = 3
DEFAULT_BASE_DELAY = 1.0  # seconds
DEFAULT_MAX_DELAY = 60.0  # seconds
DEFAULT_EXPONENTIAL_BASE = 2


class RetryError(Exception):
    """Raised when all retry attempts have been exhausted."""

    def __init__(self, message: str, last_exception: Optional[Exception] = None):
        super().__init__(message)
        self.last_exception = last_exception


async def async_retry_with_backoff(
    func: Callable,
    *args,
    max_retries: int = DEFAULT_MAX_RETRIES,
    base_delay: float = DEFAULT_BASE_DELAY,
    max_delay: float = DEFAULT_MAX_DELAY,
    exponential_base: float = DEFAULT_EXPONENTIAL_BASE,
    retryable_exceptions: Tuple[Type[Exception], ...] = (Exception,),
    jitter: bool = True,
    on_retry: Optional[Callable[[int, Exception], None]] = None,
    **kwargs
) -> Any:
    """
    Execute an async function with exponential backoff retry.

    Args:
        func: The async function to execute
        *args: Positional arguments for the function
        max_retries: Maximum number of retry attempts
        base_delay: Initial delay between retries in seconds
        max_delay: Maximum delay between retries
        exponential_base: Base for exponential backoff calculation
        retryable_exceptions: Tuple of exception types that should trigger retry
        jitter: Whether to add randomness to delay
        on_retry: Optional callback called on each retry with (attempt, exception)
        **kwargs: Keyword arguments for the function

    Returns:
        The return value of the function

    Raises:
        RetryError: When all retries are exhausted
    """
    last_exception = None

    for attempt in range(max_retries + 1):
        try:
            return await func(*args, **kwargs)
        except retryable_exceptions as e:
            last_exception = e

            if attempt == max_retries:
                logging.error(f"All {max_retries} retries exhausted for {func.__name__}: {e}")
                raise RetryError(
                    f"Failed after {max_retries} retries: {str(e)}",
                    last_exception=e
                )

            # Calculate delay with exponential backoff
            delay = min(base_delay * (exponential_base ** attempt), max_delay)

            # Add jitter to prevent thundering herd
            if jitter:
                delay = delay * (0.5 + random.random())

            logging.warning(
                f"Retry {attempt + 1}/{max_retries} for {func.__name__} "
                f"after {delay:.2f}s delay. Error: {e}"
            )

            if on_retry:
                try:
                    on_retry(attempt + 1, e)
                except Exception as callback_error:
                    logging.warning(f"on_retry callback failed: {callback_error}")

            await asyncio.sleep(delay)

    # Should not reach here, but just in case
    raise RetryError("Unexpected retry loop exit", last_exception=last_exception)


def retry_decorator(
    max_retries: int = DEFAULT_MAX_RETRIES,
    base_delay: float = DEFAULT_BASE_DELAY,
    max_delay: float = DEFAULT_MAX_DELAY,
    retryable_exceptions: Tuple[Type[Exception], ...] = (Exception,),
    jitter: bool = True
):
    """
    Decorator for adding retry logic to async functions.

    Usage:
        @retry_decorator(max_retries=3, base_delay=1.0)
        async def my_api_call():
            ...
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(*args, **kwargs):
            return await async_retry_with_backoff(
                func,
                *args,
                max_retries=max_retries,
                base_delay=base_delay,
                max_delay=max_delay,
                retryable_exceptions=retryable_exceptions,
                jitter=jitter,
                **kwargs
            )
        return wrapper
    return decorator


# Common exception sets for different APIs
OPENAI_RETRYABLE_EXCEPTIONS = (
    # Add specific OpenAI exceptions as needed
    TimeoutError,
    ConnectionError,
)

DISCORD_RETRYABLE_EXCEPTIONS = (
    # Add specific Discord exceptions as needed
    TimeoutError,
    ConnectionError,
)

HTTP_RETRYABLE_EXCEPTIONS = (
    TimeoutError,
    ConnectionError,
    ConnectionResetError,
)


class RateLimiter:
    """
    Simple rate limiter for API calls.

    Usage:
        limiter = RateLimiter(calls_per_second=1)
        async with limiter:
            await make_api_call()
    """

    def __init__(self, calls_per_second: float = 1.0):
        self.min_interval = 1.0 / calls_per_second
        self.last_call = 0.0
        self._lock = asyncio.Lock()

    async def __aenter__(self):
        async with self._lock:
            import time
            now = time.monotonic()
            time_since_last = now - self.last_call

            if time_since_last < self.min_interval:
                await asyncio.sleep(self.min_interval - time_since_last)

            self.last_call = time.monotonic()
            return self

    async def __aexit__(self, *args):
        pass


class CircuitBreaker:
    """
    Circuit breaker pattern for preventing cascade failures.

    States:
        - CLOSED: Normal operation, requests pass through
        - OPEN: Too many failures, requests are rejected immediately
        - HALF_OPEN: Testing if service recovered
    """

    CLOSED = "closed"
    OPEN = "open"
    HALF_OPEN = "half_open"

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: float = 60.0,
        half_open_requests: int = 3
    ):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.half_open_requests = half_open_requests

        self.state = self.CLOSED
        self.failure_count = 0
        self.last_failure_time = 0.0
        self.half_open_successes = 0
        self._lock = asyncio.Lock()

    async def call(self, func: Callable, *args, **kwargs) -> Any:
        """
        Execute a function through the circuit breaker.

        Args:
            func: The async function to execute
            *args: Positional arguments
            **kwargs: Keyword arguments

        Returns:
            The function result

        Raises:
            Exception: If circuit is open or function fails
        """
        async with self._lock:
            await self._check_state()

        if self.state == self.OPEN:
            raise Exception("Circuit breaker is OPEN - service unavailable")

        try:
            result = await func(*args, **kwargs)
            await self._on_success()
            return result
        except Exception:
            await self._on_failure()
            raise

    async def _check_state(self):
        """Check and potentially update circuit state."""
        import time

        if self.state == self.OPEN:
            if time.monotonic() - self.last_failure_time >= self.recovery_timeout:
                logging.info("Circuit breaker transitioning to HALF_OPEN")
                self.state = self.HALF_OPEN
                self.half_open_successes = 0

    async def _on_success(self):
        """Handle successful call."""
        async with self._lock:
            if self.state == self.HALF_OPEN:
                self.half_open_successes += 1
                if self.half_open_successes >= self.half_open_requests:
                    logging.info("Circuit breaker transitioning to CLOSED")
                    self.state = self.CLOSED
                    self.failure_count = 0
            elif self.state == self.CLOSED:
                self.failure_count = 0

    async def _on_failure(self):
        """Handle failed call."""
        import time

        async with self._lock:
            self.failure_count += 1
            self.last_failure_time = time.monotonic()

            if self.state == self.HALF_OPEN:
                logging.warning("Circuit breaker transitioning to OPEN (half-open failure)")
                self.state = self.OPEN
            elif self.failure_count >= self.failure_threshold:
                logging.warning(f"Circuit breaker transitioning to OPEN ({self.failure_count} failures)")
                self.state = self.OPEN
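A short sketch of how the pieces above compose for an outbound API call; the `client` object and its call parameters are placeholders, not part of this commit:

```python
# Hypothetical sketch combining retry_decorator, RateLimiter and CircuitBreaker.
from src.utils.retry import (
    retry_decorator, RateLimiter, CircuitBreaker, OPENAI_RETRYABLE_EXCEPTIONS,
)

limiter = RateLimiter(calls_per_second=2)
breaker = CircuitBreaker(failure_threshold=5, recovery_timeout=30.0)

@retry_decorator(max_retries=3, base_delay=1.0,
                 retryable_exceptions=OPENAI_RETRYABLE_EXCEPTIONS)
async def call_openai(client, **params):
    async with limiter:  # throttle outgoing requests
        # reject immediately while the breaker is open, retry transient errors otherwise
        return await breaker.call(client.chat.completions.create, **params)
```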
src/utils/token_counter.py
@@ -304,8 +304,8 @@ class TokenCounter:
         Returns:
             Estimated cost in USD
         """
-        # Import here to avoid circular dependency
-        from src.commands.commands import MODEL_PRICING
+        # Import from centralized pricing module
+        from src.config.pricing import MODEL_PRICING

         if model not in MODEL_PRICING:
             model = "openai/gpt-4o"  # Default fallback
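The new import points at the centralized pricing module; its expected interface, as exercised by the tests later in this commit, looks roughly like the sketch below. The exact `format_cost` output string is an assumption inferred from those tests.

```python
# Illustrative use of src/config/pricing.py; signatures inferred from tests/test_comprehensive.py.
from src.config.pricing import MODEL_PRICING, calculate_cost, format_cost

cost = calculate_cost("openai/gpt-4o", 1_000_000, 1_000_000)  # $5 input + $20 output -> 25.0
print("openai/gpt-4o" in MODEL_PRICING)  # True
print(format_cost(cost))                 # e.g. "$25.00" (assumed formatting)
```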
287  src/utils/validators.py  Normal file
@@ -0,0 +1,287 @@
"""
Input validation utilities for the Discord bot.

This module provides centralized validation for user inputs,
enhancing security and reducing code duplication.
"""

import re
import logging
from typing import Optional, Tuple, List
from dataclasses import dataclass


# Maximum allowed lengths for various inputs
MAX_MESSAGE_LENGTH = 4000  # Discord's limit is 2000, but we process longer
MAX_PROMPT_LENGTH = 32000  # Reasonable limit for AI prompts
MAX_FILE_SIZE = 50 * 1024 * 1024  # 50MB
MAX_FILENAME_LENGTH = 255
MAX_URL_LENGTH = 2048
MAX_CODE_LENGTH = 100000  # 100KB of code


@dataclass
class ValidationResult:
    """Result of a validation check."""
    is_valid: bool
    error_message: Optional[str] = None
    sanitized_value: Optional[str] = None


def validate_message_content(content: str) -> ValidationResult:
    """
    Validate and sanitize message content.

    Args:
        content: The message content to validate

    Returns:
        ValidationResult with validation status and sanitized content
    """
    if not content:
        return ValidationResult(is_valid=True, sanitized_value="")

    if len(content) > MAX_MESSAGE_LENGTH:
        return ValidationResult(
            is_valid=False,
            error_message=f"Message too long. Maximum {MAX_MESSAGE_LENGTH} characters allowed."
        )

    # Remove null bytes and other control characters (except newlines/tabs)
    sanitized = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', content)

    return ValidationResult(is_valid=True, sanitized_value=sanitized)


def validate_prompt(prompt: str) -> ValidationResult:
    """
    Validate AI prompt content.

    Args:
        prompt: The prompt to validate

    Returns:
        ValidationResult with validation status
    """
    if not prompt or not prompt.strip():
        return ValidationResult(
            is_valid=False,
            error_message="Prompt cannot be empty."
        )

    if len(prompt) > MAX_PROMPT_LENGTH:
        return ValidationResult(
            is_valid=False,
            error_message=f"Prompt too long. Maximum {MAX_PROMPT_LENGTH} characters allowed."
        )

    # Remove null bytes
    sanitized = prompt.replace('\x00', '')

    return ValidationResult(is_valid=True, sanitized_value=sanitized)


def validate_url(url: str) -> ValidationResult:
    """
    Validate and sanitize a URL.

    Args:
        url: The URL to validate

    Returns:
        ValidationResult with validation status
    """
    if not url:
        return ValidationResult(
            is_valid=False,
            error_message="URL cannot be empty."
        )

    if len(url) > MAX_URL_LENGTH:
        return ValidationResult(
            is_valid=False,
            error_message=f"URL too long. Maximum {MAX_URL_LENGTH} characters allowed."
        )

    # Basic URL pattern check
    url_pattern = re.compile(
        r'^https?://'  # http:// or https://
        r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|'  # domain
        r'localhost|'  # localhost
        r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'  # or IP
        r'(?::\d+)?'  # optional port
        r'(?:/?|[/?]\S+)$', re.IGNORECASE
    )

    if not url_pattern.match(url):
        return ValidationResult(
            is_valid=False,
            error_message="Invalid URL format."
        )

    # Check for potentially dangerous URL schemes
    dangerous_schemes = ['javascript:', 'data:', 'file:', 'vbscript:']
    url_lower = url.lower()
    for scheme in dangerous_schemes:
        if scheme in url_lower:
            return ValidationResult(
                is_valid=False,
                error_message="URL contains potentially dangerous content."
            )

    return ValidationResult(is_valid=True, sanitized_value=url)


def validate_filename(filename: str) -> ValidationResult:
    """
    Validate and sanitize a filename.

    Args:
        filename: The filename to validate

    Returns:
        ValidationResult with validation status and sanitized filename
    """
    if not filename:
        return ValidationResult(
            is_valid=False,
            error_message="Filename cannot be empty."
        )

    if len(filename) > MAX_FILENAME_LENGTH:
        return ValidationResult(
            is_valid=False,
            error_message=f"Filename too long. Maximum {MAX_FILENAME_LENGTH} characters allowed."
        )

    # Remove path traversal attempts
    sanitized = filename.replace('..', '').replace('/', '').replace('\\', '')

    # Remove dangerous characters
    sanitized = re.sub(r'[<>:"|?*\x00-\x1f]', '', sanitized)

    # Ensure it's not empty after sanitization
    if not sanitized:
        return ValidationResult(
            is_valid=False,
            error_message="Filename contains only invalid characters."
        )

    return ValidationResult(is_valid=True, sanitized_value=sanitized)


def validate_file_size(size: int) -> ValidationResult:
    """
    Validate file size.

    Args:
        size: The file size in bytes

    Returns:
        ValidationResult with validation status
    """
    if size <= 0:
        return ValidationResult(
            is_valid=False,
            error_message="File size must be greater than 0."
        )

    if size > MAX_FILE_SIZE:
        max_mb = MAX_FILE_SIZE / (1024 * 1024)
        return ValidationResult(
            is_valid=False,
            error_message=f"File too large. Maximum {max_mb:.0f}MB allowed."
        )

    return ValidationResult(is_valid=True)


def validate_code(code: str) -> ValidationResult:
    """
    Validate code for execution.

    Args:
        code: The code to validate

    Returns:
        ValidationResult with validation status
    """
    if not code or not code.strip():
        return ValidationResult(
            is_valid=False,
            error_message="Code cannot be empty."
        )

    if len(code) > MAX_CODE_LENGTH:
        return ValidationResult(
            is_valid=False,
            error_message=f"Code too long. Maximum {MAX_CODE_LENGTH} characters allowed."
        )

    return ValidationResult(is_valid=True, sanitized_value=code)


def validate_user_id(user_id) -> ValidationResult:
    """
    Validate a Discord user ID.

    Args:
        user_id: The user ID to validate

    Returns:
        ValidationResult with validation status
    """
    try:
        uid = int(user_id)
        if uid <= 0:
            return ValidationResult(
                is_valid=False,
                error_message="Invalid user ID."
            )
        # Discord IDs are 17-19 digits
        if len(str(uid)) < 17 or len(str(uid)) > 19:
            return ValidationResult(
                is_valid=False,
                error_message="Invalid user ID format."
            )
        return ValidationResult(is_valid=True)
    except (ValueError, TypeError):
        return ValidationResult(
            is_valid=False,
            error_message="User ID must be a valid integer."
        )


def sanitize_for_logging(text: str, max_length: int = 200) -> str:
    """
    Sanitize text for safe logging (remove sensitive data, truncate).

    Args:
        text: The text to sanitize
        max_length: Maximum length for logged text

    Returns:
        Sanitized text safe for logging
    """
    if not text:
        return ""

    # Remove potential secrets/tokens (common patterns)
    patterns = [
        (r'(sk-[a-zA-Z0-9]{20,})', '[OPENAI_KEY]'),
        (r'(xoxb-[a-zA-Z0-9-]+)', '[SLACK_TOKEN]'),
        (r'([A-Za-z0-9_-]{24}\.[A-Za-z0-9_-]{6}\.[A-Za-z0-9_-]{27})', '[DISCORD_TOKEN]'),
        (r'(mongodb\+srv://[^@]+@)', 'mongodb+srv://[REDACTED]@'),
        (r'(Bearer\s+[A-Za-z0-9_-]+)', 'Bearer [TOKEN]'),
        (r'(password["\']?\s*[:=]\s*["\']?)[^"\'\s]+', r'\1[REDACTED]'),
    ]

    sanitized = text
    for pattern, replacement in patterns:
        sanitized = re.sub(pattern, replacement, sanitized, flags=re.IGNORECASE)

    # Truncate if needed
    if len(sanitized) > max_length:
        sanitized = sanitized[:max_length] + '...[truncated]'

    return sanitized
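A minimal sketch of wiring these validators into a message handler; the handler itself is a hypothetical placeholder, not part of this commit:

```python
# Hypothetical sketch: validating and logging user input with src/utils/validators.py.
import logging

from src.utils.validators import validate_message_content, sanitize_for_logging

async def handle_user_input(message):
    result = validate_message_content(message.content)
    if not result.is_valid:
        await message.channel.send(result.error_message)
        return
    # Secrets are redacted and long text truncated before it reaches the logs
    logging.info("prompt=%s", sanitize_for_logging(result.sanitized_value))
    ...  # continue processing with result.sanitized_value
```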
727
tests/test_comprehensive.py
Normal file
727
tests/test_comprehensive.py
Normal file
@@ -0,0 +1,727 @@
|
|||||||
|
"""
|
||||||
|
Comprehensive test suite for the ChatGPT Discord Bot.
|
||||||
|
|
||||||
|
This module contains unit tests and integration tests for all major components.
|
||||||
|
Uses pytest with pytest-asyncio for async test support.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
from unittest.mock import MagicMock, patch, AsyncMock
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
# Add parent directory to path for imports
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Test Fixtures
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_db_handler():
|
||||||
|
"""Create a mock database handler."""
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.get_history = AsyncMock(return_value=[])
|
||||||
|
mock.save_history = AsyncMock()
|
||||||
|
mock.get_user_model = AsyncMock(return_value="openai/gpt-4o")
|
||||||
|
mock.save_user_model = AsyncMock()
|
||||||
|
mock.is_admin = AsyncMock(return_value=False)
|
||||||
|
mock.is_user_whitelisted = AsyncMock(return_value=True)
|
||||||
|
mock.is_user_blacklisted = AsyncMock(return_value=False)
|
||||||
|
mock.get_user_tool_display = AsyncMock(return_value=False)
|
||||||
|
mock.get_user_files = AsyncMock(return_value=[])
|
||||||
|
mock.save_token_usage = AsyncMock()
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_openai_client():
|
||||||
|
"""Create a mock OpenAI client."""
|
||||||
|
mock = MagicMock()
|
||||||
|
|
||||||
|
# Mock response structure
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.choices = [MagicMock()]
|
||||||
|
mock_response.choices[0].message.content = "Test response"
|
||||||
|
mock_response.choices[0].finish_reason = "stop"
|
||||||
|
mock_response.usage = MagicMock()
|
||||||
|
mock_response.usage.prompt_tokens = 100
|
||||||
|
mock_response.usage.completion_tokens = 50
|
||||||
|
|
||||||
|
mock.chat.completions.create = AsyncMock(return_value=mock_response)
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_discord_message():
|
||||||
|
"""Create a mock Discord message."""
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.author.id = 123456789
|
||||||
|
mock.author.name = "TestUser"
|
||||||
|
mock.content = "Hello, bot!"
|
||||||
|
mock.channel.send = AsyncMock()
|
||||||
|
mock.channel.typing = MagicMock(return_value=AsyncMock().__aenter__())
|
||||||
|
mock.attachments = []
|
||||||
|
mock.reference = None
|
||||||
|
mock.guild = MagicMock()
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Pricing Module Tests
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
class TestPricingModule:
|
||||||
|
"""Tests for the pricing configuration module."""
|
||||||
|
|
||||||
|
def test_model_pricing_exists(self):
|
||||||
|
"""Test that all expected models have pricing defined."""
|
||||||
|
from src.config.pricing import MODEL_PRICING
|
||||||
|
|
||||||
|
expected_models = [
|
||||||
|
"openai/gpt-4o",
|
||||||
|
"openai/gpt-4o-mini",
|
||||||
|
"openai/gpt-4.1",
|
||||||
|
"openai/gpt-5",
|
||||||
|
"openai/o1",
|
||||||
|
]
|
||||||
|
|
||||||
|
for model in expected_models:
|
||||||
|
assert model in MODEL_PRICING, f"Missing pricing for {model}"
|
||||||
|
|
||||||
|
def test_calculate_cost(self):
|
||||||
|
"""Test cost calculation for known models."""
|
||||||
|
from src.config.pricing import calculate_cost
|
||||||
|
|
||||||
|
# GPT-4o: $5.00 input, $20.00 output per 1M tokens
|
||||||
|
cost = calculate_cost("openai/gpt-4o", 1_000_000, 1_000_000)
|
||||||
|
assert cost == 25.00 # $5 + $20
|
||||||
|
|
||||||
|
# Test smaller amounts
|
||||||
|
cost = calculate_cost("openai/gpt-4o", 1000, 1000)
|
||||||
|
assert cost == pytest.approx(0.025, rel=1e-6) # $0.005 + $0.020
|
||||||
|
|
||||||
|
def test_calculate_cost_unknown_model(self):
|
||||||
|
"""Test that unknown models return 0 cost."""
|
||||||
|
from src.config.pricing import calculate_cost
|
||||||
|
|
||||||
|
cost = calculate_cost("unknown/model", 1000, 1000)
|
||||||
|
assert cost == 0.0
|
||||||
|
|
||||||
|
def test_format_cost(self):
|
||||||
|
"""Test cost formatting for display."""
|
||||||
|
from src.config.pricing import format_cost
|
||||||
|
|
||||||
|
assert format_cost(0.000001) == "$0.000001"
|
||||||
|
assert format_cost(0.005) == "$0.005000" # 6 decimal places for small amounts
|
||||||
|
assert format_cost(1.50) == "$1.50"
|
||||||
|
assert format_cost(100.00) == "$100.00"
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Validator Module Tests
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
class TestValidators:
|
||||||
|
"""Tests for input validation utilities."""
|
||||||
|
|
||||||
|
def test_validate_message_content(self):
|
||||||
|
"""Test message content validation."""
|
||||||
|
from src.utils.validators import validate_message_content
|
||||||
|
|
||||||
|
# Valid content
|
||||||
|
result = validate_message_content("Hello, world!")
|
||||||
|
assert result.is_valid
|
||||||
|
assert result.sanitized_value == "Hello, world!"
|
||||||
|
|
||||||
|
# Empty content is valid
|
||||||
|
result = validate_message_content("")
|
||||||
|
assert result.is_valid
|
||||||
|
|
||||||
|
# Content with null bytes should be sanitized
|
||||||
|
result = validate_message_content("Hello\x00World")
|
||||||
|
assert result.is_valid
|
||||||
|
assert "\x00" not in result.sanitized_value
|
||||||
|
|
||||||
|
def test_validate_message_too_long(self):
|
||||||
|
"""Test that overly long messages are rejected."""
|
||||||
|
from src.utils.validators import validate_message_content, MAX_MESSAGE_LENGTH
|
||||||
|
|
||||||
|
long_message = "x" * (MAX_MESSAGE_LENGTH + 1)
|
||||||
|
result = validate_message_content(long_message)
|
||||||
|
assert not result.is_valid
|
||||||
|
assert "too long" in result.error_message.lower()
|
||||||
|
|
||||||
|
def test_validate_url(self):
|
||||||
|
"""Test URL validation."""
|
||||||
|
from src.utils.validators import validate_url
|
||||||
|
|
||||||
|
# Valid URLs
|
||||||
|
assert validate_url("https://example.com").is_valid
|
||||||
|
assert validate_url("http://localhost:8080/path").is_valid
|
||||||
|
assert validate_url("https://api.example.com/v1/data?q=test").is_valid
|
||||||
|
|
||||||
|
# Invalid URLs
|
||||||
|
assert not validate_url("").is_valid
|
||||||
|
assert not validate_url("not-a-url").is_valid
|
||||||
|
assert not validate_url("javascript:alert(1)").is_valid
|
||||||
|
assert not validate_url("file:///etc/passwd").is_valid
|
||||||
|
|
||||||
|
def test_validate_filename(self):
|
||||||
|
"""Test filename validation and sanitization."""
|
||||||
|
from src.utils.validators import validate_filename
|
||||||
|
|
||||||
|
# Valid filename
|
||||||
|
result = validate_filename("test_file.txt")
|
||||||
|
assert result.is_valid
|
||||||
|
assert result.sanitized_value == "test_file.txt"
|
||||||
|
|
||||||
|
# Path traversal attempt
|
||||||
|
result = validate_filename("../../../etc/passwd")
|
||||||
|
assert result.is_valid # Sanitized, not rejected
|
||||||
|
assert ".." not in result.sanitized_value
|
||||||
|
assert "/" not in result.sanitized_value
|
||||||
|
|
||||||
|
# Empty filename
|
||||||
|
result = validate_filename("")
|
||||||
|
assert not result.is_valid
|
||||||
|
|
||||||
|
def test_sanitize_for_logging(self):
|
||||||
|
"""Test that secrets are properly redacted for logging."""
|
||||||
|
from src.utils.validators import sanitize_for_logging
|
||||||
|
|
||||||
|
# Test OpenAI key redaction
|
||||||
|
text = "API key is sk-abcdefghijklmnopqrstuvwxyz123456"
|
||||||
|
sanitized = sanitize_for_logging(text)
|
||||||
|
assert "sk-" not in sanitized
|
||||||
|
assert "[OPENAI_KEY]" in sanitized
|
||||||
|
|
||||||
|
# Test MongoDB URI redaction
|
||||||
|
text = "mongodb+srv://user:password@cluster.mongodb.net/db"
|
||||||
|
sanitized = sanitize_for_logging(text)
|
||||||
|
assert "password" not in sanitized
|
||||||
|
assert "[REDACTED]" in sanitized
|
||||||
|
|
||||||
|
# Test truncation
|
||||||
|
long_text = "x" * 500
|
||||||
|
sanitized = sanitize_for_logging(long_text, max_length=100)
|
||||||
|
assert len(sanitized) < 150 # Account for truncation marker
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Retry Module Tests
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
class TestRetryModule:
|
||||||
|
"""Tests for retry utilities."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_retry_success_first_try(self):
|
||||||
|
"""Test that successful functions don't retry."""
|
||||||
|
from src.utils.retry import async_retry_with_backoff
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def success_func():
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
return "success"
|
||||||
|
|
||||||
|
result = await async_retry_with_backoff(success_func, max_retries=3)
|
||||||
|
assert result == "success"
|
||||||
|
assert call_count == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_retry_eventual_success(self):
|
||||||
|
"""Test that functions eventually succeed after retries."""
|
||||||
|
from src.utils.retry import async_retry_with_backoff
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def eventual_success():
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count < 3:
|
||||||
|
raise ConnectionError("Temporary failure")
|
||||||
|
return "success"
|
||||||
|
|
||||||
|
result = await async_retry_with_backoff(
|
||||||
|
eventual_success,
|
||||||
|
max_retries=5,
|
||||||
|
base_delay=0.01, # Fast for testing
|
||||||
|
retryable_exceptions=(ConnectionError,)
|
||||||
|
)
|
||||||
|
assert result == "success"
|
||||||
|
assert call_count == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_retry_exhausted(self):
|
||||||
|
"""Test that RetryError is raised when retries are exhausted."""
|
||||||
|
from src.utils.retry import async_retry_with_backoff, RetryError
|
||||||
|
|
||||||
|
async def always_fail():
|
||||||
|
raise ConnectionError("Always fails")
|
||||||
|
|
||||||
|
with pytest.raises(RetryError):
|
||||||
|
await async_retry_with_backoff(
|
||||||
|
always_fail,
|
||||||
|
max_retries=2,
|
||||||
|
base_delay=0.01,
|
||||||
|
retryable_exceptions=(ConnectionError,)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Discord Utils Tests
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
class TestDiscordUtils:
|
||||||
|
"""Tests for Discord utility functions."""
|
||||||
|
|
||||||
|
def test_split_message_short(self):
|
||||||
|
"""Test that short messages aren't split."""
|
||||||
|
from src.utils.discord_utils import split_message
|
||||||
|
|
||||||
|
short = "This is a short message."
|
||||||
|
chunks = split_message(short)
|
||||||
|
assert len(chunks) == 1
|
||||||
|
assert chunks[0] == short
|
||||||
|
|
||||||
|
def test_split_message_long(self):
|
||||||
|
"""Test that long messages are properly split."""
|
||||||
|
from src.utils.discord_utils import split_message
|
||||||
|
|
||||||
|
# Create a message longer than 2000 characters
|
||||||
|
long = "Hello world. " * 200
|
||||||
|
chunks = split_message(long, max_length=2000)
|
||||||
|
|
||||||
|
assert len(chunks) > 1
|
||||||
|
for chunk in chunks:
|
||||||
|
assert len(chunk) <= 2000
|
||||||
|
|
||||||
|
def test_split_code_block(self):
|
||||||
|
"""Test code block splitting."""
|
||||||
|
from src.utils.discord_utils import split_code_block
|
||||||
|
|
||||||
|
code = "\n".join([f"line {i}" for i in range(100)])
|
||||||
|
chunks = split_code_block(code, "python", max_length=500)
|
||||||
|
|
||||||
|
assert len(chunks) > 1
|
||||||
|
for chunk in chunks:
|
||||||
|
assert chunk.startswith("```python\n")
|
||||||
|
assert chunk.endswith("\n```")
|
||||||
|
assert len(chunk) <= 500
|
||||||
|
|
||||||
|
def test_create_error_embed(self):
|
||||||
|
"""Test error embed creation."""
|
||||||
|
from src.utils.discord_utils import create_error_embed
|
||||||
|
import discord
|
||||||
|
|
||||||
|
embed = create_error_embed("Test Error", "Something went wrong", "ValidationError")
|
||||||
|
|
||||||
|
assert isinstance(embed, discord.Embed)
|
||||||
|
assert "Test Error" in embed.title
|
||||||
|
assert embed.color == discord.Color.red()
|
||||||
|
|
||||||
|
def test_create_success_embed(self):
|
||||||
|
"""Test success embed creation."""
|
||||||
|
from src.utils.discord_utils import create_success_embed
|
||||||
|
import discord
|
||||||
|
|
||||||
|
embed = create_success_embed("Success!", "Operation completed")
|
||||||
|
|
||||||
|
assert isinstance(embed, discord.Embed)
|
||||||
|
assert "Success!" in embed.title
|
||||||
|
assert embed.color == discord.Color.green()
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Code Interpreter Security Tests
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
class TestCodeInterpreterSecurity:
|
||||||
|
"""Tests for code interpreter security features."""
|
||||||
|
|
||||||
|
def test_blocked_imports(self):
|
||||||
|
"""Test that dangerous imports are blocked."""
|
||||||
|
from src.utils.code_interpreter import BLOCKED_PATTERNS
|
||||||
|
import re
|
||||||
|
|
||||||
|
dangerous_code = [
|
||||||
|
"import os",
|
||||||
|
"import subprocess",
|
||||||
|
"from os import system",
|
||||||
|
"import socket",
|
||||||
|
"import requests",
|
||||||
|
"__import__('os')",
|
||||||
|
"eval('print(1)')",
|
||||||
|
"exec('import os')",
|
||||||
|
]
|
||||||
|
|
||||||
|
for code in dangerous_code:
|
||||||
|
blocked = any(
|
||||||
|
re.search(pattern, code, re.IGNORECASE)
|
||||||
|
for pattern in BLOCKED_PATTERNS
|
||||||
|
)
|
||||||
|
assert blocked, f"Should block: {code}"
|
||||||
|
|
||||||
|
def test_allowed_imports(self):
|
||||||
|
"""Test that safe imports are allowed."""
|
||||||
|
from src.utils.code_interpreter import BLOCKED_PATTERNS
|
||||||
|
import re
|
||||||
|
|
||||||
|
safe_code = [
|
||||||
|
"import pandas as pd",
|
||||||
|
"import numpy as np",
|
||||||
|
"import matplotlib.pyplot as plt",
|
||||||
|
"from sklearn.model_selection import train_test_split",
|
||||||
|
"import os.path", # os.path is allowed
|
||||||
|
]
|
||||||
|
|
||||||
|
for code in safe_code:
|
||||||
|
blocked = any(
|
||||||
|
re.search(pattern, code, re.IGNORECASE)
|
||||||
|
for pattern in BLOCKED_PATTERNS
|
||||||
|
)
|
||||||
|
assert not blocked, f"Should allow: {code}"
|
||||||
|
|
||||||
|
def test_file_type_detection(self):
|
||||||
|
"""Test file type detection for various extensions."""
|
||||||
|
from src.utils.code_interpreter import FileManager
|
||||||
|
|
||||||
|
fm = FileManager()
|
||||||
|
|
||||||
|
assert fm._detect_file_type("data.csv") == "csv"
|
||||||
|
assert fm._detect_file_type("data.xlsx") == "excel"
|
||||||
|
assert fm._detect_file_type("config.json") == "json"
|
||||||
|
assert fm._detect_file_type("image.png") == "image"
|
||||||
|
assert fm._detect_file_type("script.py") == "python"
|
||||||
|
assert fm._detect_file_type("unknown.xyz") == "binary"
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
# OpenAI Utils Tests
# ============================================================

class TestOpenAIUtils:
    """Tests for OpenAI utility functions."""

    def test_count_tokens(self):
        """Test token counting function."""
        from src.utils.openai_utils import count_tokens

        text = "Hello, world!"
        tokens = count_tokens(text)
        assert tokens > 0
        assert isinstance(tokens, int)

    def test_trim_content_to_token_limit(self):
        """Test content trimming."""
        from src.utils.openai_utils import trim_content_to_token_limit

        # Short content should not be trimmed
        short = "Hello, world!"
        trimmed = trim_content_to_token_limit(short, max_tokens=100)
        assert trimmed == short

        # Long content should be trimmed
        long = "Hello " * 10000
        trimmed = trim_content_to_token_limit(long, max_tokens=100)
        assert len(trimmed) < len(long)

    def test_prepare_messages_for_api(self):
        """Test message preparation for API."""
        from src.utils.openai_utils import prepare_messages_for_api

        messages = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
            {"role": "user", "content": "How are you?"},
        ]

        prepared = prepare_messages_for_api(messages)

        assert len(prepared) == 3
        assert all(m.get("role") in ["user", "assistant", "system"] for m in prepared)

    def test_prepare_messages_filters_none_content(self):
        """Test that messages with None content are filtered."""
        from src.utils.openai_utils import prepare_messages_for_api

        messages = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": None},
            {"role": "user", "content": "World"},
        ]

        prepared = prepare_messages_for_api(messages)

        assert len(prepared) == 2


# ============================================================
# Database Handler Tests (with mocking)
# ============================================================

class TestDatabaseHandlerMocked:
    """Tests for database handler using mocks."""

    def test_filter_expired_images_no_images(self):
        """Test that messages without images pass through unchanged."""
        from src.database.db_handler import DatabaseHandler

        with patch('motor.motor_asyncio.AsyncIOMotorClient'):
            handler = DatabaseHandler("mongodb://localhost")

            history = [
                {"role": "user", "content": "Hello"},
                {"role": "assistant", "content": "Hi there!"},
            ]

            filtered = handler._filter_expired_images(history)
            assert len(filtered) == 2
            assert filtered[0]["content"] == "Hello"

    def test_filter_expired_images_recent_image(self):
        """Test that recent images are kept."""
        from src.database.db_handler import DatabaseHandler

        with patch('motor.motor_asyncio.AsyncIOMotorClient'):
            handler = DatabaseHandler("mongodb://localhost")

            recent_timestamp = datetime.now().isoformat()
            history = [
                {"role": "user", "content": [
                    {"type": "text", "text": "Check this image"},
                    {"type": "image_url", "image_url": {"url": "https://example.com/img.jpg"}, "timestamp": recent_timestamp}
                ]}
            ]

            filtered = handler._filter_expired_images(history)
            assert len(filtered) == 1
            assert len(filtered[0]["content"]) == 2  # Both items kept

    def test_filter_expired_images_old_image(self):
        """Test that old images are filtered out."""
        from src.database.db_handler import DatabaseHandler

        with patch('motor.motor_asyncio.AsyncIOMotorClient'):
            handler = DatabaseHandler("mongodb://localhost")

            old_timestamp = (datetime.now() - timedelta(hours=24)).isoformat()
            history = [
                {"role": "user", "content": [
                    {"type": "text", "text": "Check this image"},
                    {"type": "image_url", "image_url": {"url": "https://example.com/img.jpg"}, "timestamp": old_timestamp}
                ]}
            ]

            filtered = handler._filter_expired_images(history)
            assert len(filtered) == 1
            assert len(filtered[0]["content"]) == 1  # Only text kept


# ============================================================
# Cache Module Tests
# ============================================================

class TestLRUCache:
    """Tests for the LRU cache implementation."""

    @pytest.mark.asyncio
    async def test_cache_set_and_get(self):
        """Test basic cache set and get operations."""
        from src.utils.cache import LRUCache

        cache = LRUCache(max_size=100, default_ttl=60.0)

        await cache.set("key1", "value1")
        result = await cache.get("key1")
        assert result == "value1"

    @pytest.mark.asyncio
    async def test_cache_expiration(self):
        """Test that cache entries expire after TTL."""
        from src.utils.cache import LRUCache

        cache = LRUCache(max_size=100, default_ttl=0.1)  # 100ms TTL

        await cache.set("key1", "value1")

        # Should exist immediately
        assert await cache.get("key1") == "value1"

        # Wait for expiration
        await asyncio.sleep(0.15)

        # Should be expired now
        assert await cache.get("key1") is None

    @pytest.mark.asyncio
    async def test_cache_lru_eviction(self):
        """Test that LRU eviction works correctly."""
        from src.utils.cache import LRUCache

        cache = LRUCache(max_size=3, default_ttl=60.0)

        await cache.set("key1", "value1")
        await cache.set("key2", "value2")
        await cache.set("key3", "value3")

        # Access key1 to make it recently used
        await cache.get("key1")

        # Add new key, should evict key2 (least recently used)
        await cache.set("key4", "value4")

        assert await cache.get("key1") == "value1"  # Should exist
        assert await cache.get("key2") is None  # Should be evicted
        assert await cache.get("key3") == "value3"  # Should exist
        assert await cache.get("key4") == "value4"  # Should exist

    @pytest.mark.asyncio
    async def test_cache_stats(self):
        """Test cache statistics tracking."""
        from src.utils.cache import LRUCache

        cache = LRUCache(max_size=100, default_ttl=60.0)

        await cache.set("key1", "value1")
        await cache.get("key1")  # Hit
        await cache.get("key2")  # Miss
        await cache.get("key1")  # Hit

        stats = cache.stats()
        assert stats["hits"] == 2
        assert stats["misses"] == 1
        assert stats["size"] == 1

    @pytest.mark.asyncio
    async def test_cache_clear(self):
        """Test cache clearing."""
        from src.utils.cache import LRUCache

        cache = LRUCache(max_size=100, default_ttl=60.0)

        await cache.set("key1", "value1")
        await cache.set("key2", "value2")

        cleared = await cache.clear()
        assert cleared == 2

        assert await cache.get("key1") is None
        assert await cache.get("key2") is None


# ============================================================
# Monitoring Module Tests
# ============================================================

class TestMonitoring:
    """Tests for the monitoring utilities."""

    def test_performance_metrics(self):
        """Test performance metrics tracking."""
        from src.utils.monitoring import PerformanceMetrics
        import time

        metrics = PerformanceMetrics(name="test_operation")
        time.sleep(0.01)  # Small delay
        metrics.finish(success=True)

        assert metrics.success
        assert metrics.duration_ms > 0
        assert metrics.duration_ms < 1000  # Should be fast

    def test_measure_sync_context_manager(self):
        """Test synchronous measurement context manager."""
        from src.utils.monitoring import measure_sync
        import time

        with measure_sync("test_op", custom_field="value") as metrics:
            time.sleep(0.01)

        assert metrics.duration_ms > 0
        assert metrics.metadata["custom_field"] == "value"

    @pytest.mark.asyncio
    async def test_measure_async_context_manager(self):
        """Test async measurement context manager."""
        from src.utils.monitoring import measure_async

        async with measure_async("async_op") as metrics:
            await asyncio.sleep(0.01)

        assert metrics.duration_ms > 0
        assert metrics.success

    @pytest.mark.asyncio
    async def test_track_performance_decorator(self):
        """Test performance tracking decorator."""
        from src.utils.monitoring import track_performance

        call_count = 0

        @track_performance("tracked_function")
        async def tracked_func():
            nonlocal call_count
            call_count += 1
            return "result"

        result = await tracked_func()
        assert result == "result"
        assert call_count == 1

    def test_health_status(self):
        """Test health status structure."""
        from src.utils.monitoring import HealthStatus

        status = HealthStatus(healthy=True)

        status.add_check("database", True, "Connected")
        status.add_check("api", False, "Timeout")

        assert not status.healthy  # Should be unhealthy due to API check
        assert status.checks["database"]["healthy"]
        assert not status.checks["api"]["healthy"]


# ============================================================
# Integration Tests (require environment setup)
# ============================================================

@pytest.mark.integration
class TestIntegration:
    """Integration tests that require actual services."""

    @pytest.mark.asyncio
    async def test_database_connection(self):
        """Test actual database connection (skip if no MongoDB)."""
        from dotenv import load_dotenv
        load_dotenv()

        mongodb_uri = os.getenv("MONGODB_URI")
        if not mongodb_uri:
            pytest.skip("MONGODB_URI not set")

        from src.database.db_handler import DatabaseHandler
        handler = DatabaseHandler(mongodb_uri)

        connected = await handler.ensure_connected()
        assert connected

        await handler.close()


# ============================================================
# Run tests
# ============================================================
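# Note: the integration tests above are tagged with @pytest.mark.integration.
# Assuming the "integration" marker is registered (e.g. in pytest.ini or
# pyproject.toml), they can be skipped locally with:
#   pytest tests/test_comprehensive.py -m "not integration"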
if __name__ == "__main__":
    pytest.main([__file__, "-v", "--tb=short"])