Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
455360bfae | ||
|
|
f36424b3de | ||
|
|
d87dd0072c | ||
|
|
01079f2751 |
@@ -18,6 +18,11 @@ OPENAI_API_KEY=your_openai_api_key_here
|
|||||||
# Use OpenAI directly: https://api.openai.com/v1
|
# Use OpenAI directly: https://api.openai.com/v1
|
||||||
OPENAI_BASE_URL=https://models.github.ai/inference
|
OPENAI_BASE_URL=https://models.github.ai/inference
|
||||||
|
|
||||||
|
# Anthropic API Key (for Claude models)
|
||||||
|
# Get from: https://console.anthropic.com/
|
||||||
|
# Leave empty to disable Claude models
|
||||||
|
ANTHROPIC_API_KEY=your_anthropic_api_key_here
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# Image Generation (Optional)
|
# Image Generation (Optional)
|
||||||
# ============================================
|
# ============================================
|
||||||
|
|||||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -2,8 +2,8 @@ test.py
|
|||||||
.env
|
.env
|
||||||
chat_history.db
|
chat_history.db
|
||||||
bot_copy.py
|
bot_copy.py
|
||||||
__pycache__/bot.cpython-312.pyc
|
__pycache__/
|
||||||
tests/__pycache__/test_bot.cpython-312.pyc
|
*.pyc
|
||||||
.vscode/settings.json
|
.vscode/settings.json
|
||||||
chatgpt.zip
|
chatgpt.zip
|
||||||
response.txt
|
response.txt
|
||||||
@@ -12,4 +12,5 @@ venv
|
|||||||
temp_charts
|
temp_charts
|
||||||
.idea
|
.idea
|
||||||
temp_data_files
|
temp_data_files
|
||||||
logs/
|
logs/
|
||||||
|
.pytest_cache/
|
||||||
17
README.md
17
README.md
@@ -18,11 +18,11 @@
|
|||||||
|
|
||||||
## 🌟 Overview
|
## 🌟 Overview
|
||||||
|
|
||||||
**ChatGPT Discord Bot** brings the power of AI directly to your Discord server! Powered by OpenAI's latest models, this bot goes beyond simple chat - it's a complete AI assistant with **code interpretation**, **file management**, **data analysis**, and much more.
|
**ChatGPT Discord Bot** brings the power of AI directly to your Discord server! Powered by OpenAI's latest models and Anthropic's Claude, this bot goes beyond simple chat - it's a complete AI assistant with **code interpretation**, **file management**, **data analysis**, and much more.
|
||||||
|
|
||||||
### 🎯 What Makes This Bot Special?
|
### 🎯 What Makes This Bot Special?
|
||||||
|
|
||||||
- 🧠 **Latest AI Models** - GPT-4o, GPT-5, o1, o3-mini, and more
|
- 🧠 **Latest AI Models** - GPT-4o, GPT-5, o1, o3-mini, Claude 4, and more
|
||||||
- 💻 **Code Interpreter** - Execute Python code like ChatGPT (NEW in v2.0!)
|
- 💻 **Code Interpreter** - Execute Python code like ChatGPT (NEW in v2.0!)
|
||||||
- 📁 **Smart File Management** - Handle 200+ file types with automatic cleanup
|
- 📁 **Smart File Management** - Handle 200+ file types with automatic cleanup
|
||||||
- 📊 **Data Analysis** - Upload and analyze CSV, Excel, and scientific data
|
- 📊 **Data Analysis** - Upload and analyze CSV, Excel, and scientific data
|
||||||
@@ -164,6 +164,15 @@ Set reminders naturally:
|
|||||||
- `o1`
|
- `o1`
|
||||||
- `o3-mini`
|
- `o3-mini`
|
||||||
|
|
||||||
|
</td>
|
||||||
|
<td>
|
||||||
|
|
||||||
|
**Claude (Anthropic)**
|
||||||
|
- `claude-sonnet-4-20250514`
|
||||||
|
- `claude-opus-4-20250514`
|
||||||
|
- `claude-3.5-sonnet`
|
||||||
|
- `claude-3.5-haiku`
|
||||||
|
|
||||||
</td>
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
</table>
|
</table>
|
||||||
@@ -179,6 +188,7 @@ Before you begin, ensure you have:
|
|||||||
- 🐳 **Docker** (recommended) or Python 3.13+
|
- 🐳 **Docker** (recommended) or Python 3.13+
|
||||||
- 🎮 **Discord Bot Token** ([Create one here](https://discord.com/developers/applications))
|
- 🎮 **Discord Bot Token** ([Create one here](https://discord.com/developers/applications))
|
||||||
- 🔑 **OpenAI API Key** ([Get it here](https://platform.openai.com/api-keys))
|
- 🔑 **OpenAI API Key** ([Get it here](https://platform.openai.com/api-keys))
|
||||||
|
- 🧠 **Anthropic API Key** (Optional, for Claude models - [Get it here](https://console.anthropic.com/))
|
||||||
- 🎨 **Runware API Key** ([Sign up here](https://runware.ai/))
|
- 🎨 **Runware API Key** ([Sign up here](https://runware.ai/))
|
||||||
- 🔍 **Google API Key** ([Google Cloud Console](https://console.cloud.google.com/))
|
- 🔍 **Google API Key** ([Google Cloud Console](https://console.cloud.google.com/))
|
||||||
- 🗄️ **MongoDB** ([MongoDB Atlas](https://cloud.mongodb.com/) - Free tier available)
|
- 🗄️ **MongoDB** ([MongoDB Atlas](https://cloud.mongodb.com/) - Free tier available)
|
||||||
@@ -195,6 +205,9 @@ DISCORD_TOKEN=your_discord_bot_token_here
|
|||||||
OPENAI_API_KEY=your_openai_api_key_here
|
OPENAI_API_KEY=your_openai_api_key_here
|
||||||
OPENAI_BASE_URL=https://api.openai.com/v1
|
OPENAI_BASE_URL=https://api.openai.com/v1
|
||||||
|
|
||||||
|
# Anthropic (Claude) - Optional
|
||||||
|
ANTHROPIC_API_KEY=your_anthropic_api_key_here
|
||||||
|
|
||||||
# Image Generation
|
# Image Generation
|
||||||
RUNWARE_API_KEY=your_runware_api_key_here
|
RUNWARE_API_KEY=your_runware_api_key_here
|
||||||
|
|
||||||
|
|||||||
20
bot.py
20
bot.py
@@ -17,7 +17,7 @@ from src.config.config import (
|
|||||||
DISCORD_TOKEN, MONGODB_URI, RUNWARE_API_KEY, STATUSES,
|
DISCORD_TOKEN, MONGODB_URI, RUNWARE_API_KEY, STATUSES,
|
||||||
LOGGING_CONFIG, ENABLE_WEBHOOK_LOGGING, LOGGING_WEBHOOK_URL,
|
LOGGING_CONFIG, ENABLE_WEBHOOK_LOGGING, LOGGING_WEBHOOK_URL,
|
||||||
WEBHOOK_LOG_LEVEL, WEBHOOK_APP_NAME, WEBHOOK_BATCH_SIZE,
|
WEBHOOK_LOG_LEVEL, WEBHOOK_APP_NAME, WEBHOOK_BATCH_SIZE,
|
||||||
WEBHOOK_FLUSH_INTERVAL, LOG_LEVEL_MAP
|
WEBHOOK_FLUSH_INTERVAL, LOG_LEVEL_MAP, ANTHROPIC_API_KEY
|
||||||
)
|
)
|
||||||
|
|
||||||
# Import webhook logger
|
# Import webhook logger
|
||||||
@@ -124,6 +124,20 @@ async def main():
|
|||||||
logging.error(f"Error initializing OpenAI client: {e}")
|
logging.error(f"Error initializing OpenAI client: {e}")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Initialize the Anthropic client (for Claude models)
|
||||||
|
anthropic_client = None
|
||||||
|
if ANTHROPIC_API_KEY:
|
||||||
|
try:
|
||||||
|
from anthropic import AsyncAnthropic
|
||||||
|
anthropic_client = AsyncAnthropic(api_key=ANTHROPIC_API_KEY)
|
||||||
|
logging.info("Anthropic client initialized successfully")
|
||||||
|
except ImportError:
|
||||||
|
logging.warning("Anthropic package not installed. Claude models will not be available. Install with: pip install anthropic")
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Error initializing Anthropic client: {e}. Claude models will not be available.")
|
||||||
|
else:
|
||||||
|
logging.info("ANTHROPIC_API_KEY not set - Claude models will not be available")
|
||||||
|
|
||||||
# Global references to objects that need cleanup
|
# Global references to objects that need cleanup
|
||||||
message_handler = None
|
message_handler = None
|
||||||
db_handler = None
|
db_handler = None
|
||||||
@@ -191,14 +205,14 @@ async def main():
|
|||||||
await ctx.send(f"Error: {error_msg}")
|
await ctx.send(f"Error: {error_msg}")
|
||||||
|
|
||||||
# Initialize message handler
|
# Initialize message handler
|
||||||
message_handler = MessageHandler(bot, db_handler, openai_client, image_generator)
|
message_handler = MessageHandler(bot, db_handler, openai_client, image_generator, anthropic_client)
|
||||||
|
|
||||||
# Attach db_handler to bot for cogs
|
# Attach db_handler to bot for cogs
|
||||||
bot.db_handler = db_handler
|
bot.db_handler = db_handler
|
||||||
|
|
||||||
# Set up slash commands
|
# Set up slash commands
|
||||||
from src.commands.commands import setup_commands
|
from src.commands.commands import setup_commands
|
||||||
setup_commands(bot, db_handler, openai_client, image_generator)
|
setup_commands(bot, db_handler, openai_client, image_generator, anthropic_client)
|
||||||
|
|
||||||
# Load file management commands
|
# Load file management commands
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ beautifulsoup4>=4.12.0
|
|||||||
# AI & ML
|
# AI & ML
|
||||||
runware>=0.4.33
|
runware>=0.4.33
|
||||||
tiktoken>=0.7.0
|
tiktoken>=0.7.0
|
||||||
|
anthropic>=0.40.0
|
||||||
|
|
||||||
# Data Processing
|
# Data Processing
|
||||||
pandas>=2.1.0
|
pandas>=2.1.0
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from src.utils.openai_utils import prepare_file_from_path
|
|||||||
from src.utils.token_counter import token_counter
|
from src.utils.token_counter import token_counter
|
||||||
from src.utils.code_interpreter import delete_all_user_files
|
from src.utils.code_interpreter import delete_all_user_files
|
||||||
from src.utils.discord_utils import create_info_embed, create_error_embed, create_success_embed
|
from src.utils.discord_utils import create_info_embed, create_error_embed, create_success_embed
|
||||||
|
from src.utils.claude_utils import is_claude_model, call_claude_api
|
||||||
|
|
||||||
# Dictionary to keep track of user requests and their cooldowns
|
# Dictionary to keep track of user requests and their cooldowns
|
||||||
user_requests: Dict[int, Dict[str, Any]] = {}
|
user_requests: Dict[int, Dict[str, Any]] = {}
|
||||||
@@ -69,7 +70,7 @@ async def image_model_autocomplete(
|
|||||||
for model in matches[:25]
|
for model in matches[:25]
|
||||||
]
|
]
|
||||||
|
|
||||||
def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator: ImageGenerator):
|
def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator: ImageGenerator, anthropic_client=None):
|
||||||
"""
|
"""
|
||||||
Set up all slash commands for the bot.
|
Set up all slash commands for the bot.
|
||||||
|
|
||||||
@@ -78,6 +79,7 @@ def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator
|
|||||||
db_handler: Database handler instance
|
db_handler: Database handler instance
|
||||||
openai_client: OpenAI client instance
|
openai_client: OpenAI client instance
|
||||||
image_generator: Image generator instance
|
image_generator: Image generator instance
|
||||||
|
anthropic_client: Anthropic client instance (optional, for Claude models)
|
||||||
"""
|
"""
|
||||||
tree = bot.tree
|
tree = bot.tree
|
||||||
|
|
||||||
@@ -265,24 +267,46 @@ def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator
|
|||||||
f"(text: {input_token_count['text_tokens']}, images: {input_token_count['image_tokens']})"
|
f"(text: {input_token_count['text_tokens']}, images: {input_token_count['image_tokens']})"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send to the AI model
|
# Check if using Claude model
|
||||||
api_params = {
|
if is_claude_model(model):
|
||||||
"model": model if model in ["openai/gpt-4o", "openai/gpt-4o-mini", "openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"] else "openai/gpt-4o",
|
if anthropic_client is None:
|
||||||
"messages": messages
|
await interaction.followup.send(
|
||||||
}
|
"❌ Claude model not available. ANTHROPIC_API_KEY is not configured.",
|
||||||
|
ephemeral=True
|
||||||
# Add temperature only for models that support it (exclude GPT-5 family)
|
)
|
||||||
if model not in ["openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"]:
|
return
|
||||||
api_params["temperature"] = 0.5
|
|
||||||
|
# Use Claude API
|
||||||
response = await openai_client.chat.completions.create(**api_params)
|
claude_response = await call_claude_api(
|
||||||
|
anthropic_client,
|
||||||
|
messages,
|
||||||
|
model,
|
||||||
|
max_tokens=4096,
|
||||||
|
use_tools=False
|
||||||
|
)
|
||||||
|
|
||||||
|
reply = claude_response.get("content", "")
|
||||||
|
actual_input_tokens = claude_response.get("input_tokens", 0)
|
||||||
|
actual_output_tokens = claude_response.get("output_tokens", 0)
|
||||||
|
else:
|
||||||
|
# Send to the OpenAI model
|
||||||
|
api_params = {
|
||||||
|
"model": model if model in ["openai/gpt-4o", "openai/gpt-4o-mini", "openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"] else "openai/gpt-4o",
|
||||||
|
"messages": messages
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add temperature only for models that support it (exclude GPT-5 family)
|
||||||
|
if model not in ["openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"]:
|
||||||
|
api_params["temperature"] = 0.5
|
||||||
|
|
||||||
|
response = await openai_client.chat.completions.create(**api_params)
|
||||||
|
|
||||||
reply = response.choices[0].message.content
|
reply = response.choices[0].message.content
|
||||||
|
|
||||||
# Get actual token usage from API response
|
# Get actual token usage from API response
|
||||||
usage = response.usage
|
usage = response.usage
|
||||||
actual_input_tokens = usage.prompt_tokens if usage else input_token_count['total_tokens']
|
actual_input_tokens = usage.prompt_tokens if usage else input_token_count['total_tokens']
|
||||||
actual_output_tokens = usage.completion_tokens if usage else token_counter.count_text_tokens(reply, model)
|
actual_output_tokens = usage.completion_tokens if usage else token_counter.count_text_tokens(reply, model)
|
||||||
|
|
||||||
# Calculate cost
|
# Calculate cost
|
||||||
cost = token_counter.estimate_cost(actual_input_tokens, actual_output_tokens, model)
|
cost = token_counter.estimate_cost(actual_input_tokens, actual_output_tokens, model)
|
||||||
@@ -362,19 +386,38 @@ def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator
|
|||||||
{"role": "user", "content": f"Content from {url}:\n{content}"}
|
{"role": "user", "content": f"Content from {url}:\n{content}"}
|
||||||
]
|
]
|
||||||
|
|
||||||
api_params = {
|
# Check if using Claude model
|
||||||
"model": model if model in ["openai/gpt-4o", "openai/gpt-4o-mini", "openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"] else "openai/gpt-4o",
|
if is_claude_model(model):
|
||||||
"messages": messages
|
if anthropic_client is None:
|
||||||
}
|
await interaction.followup.send(
|
||||||
|
"❌ Claude model not available. ANTHROPIC_API_KEY is not configured.",
|
||||||
# Add temperature and top_p only for models that support them (exclude GPT-5 family)
|
ephemeral=True
|
||||||
if model not in ["openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"]:
|
)
|
||||||
api_params["temperature"] = 0.3
|
return
|
||||||
api_params["top_p"] = 0.7
|
|
||||||
|
# Use Claude API
|
||||||
response = await openai_client.chat.completions.create(**api_params)
|
claude_response = await call_claude_api(
|
||||||
|
anthropic_client,
|
||||||
|
messages,
|
||||||
|
model,
|
||||||
|
max_tokens=4096,
|
||||||
|
use_tools=False
|
||||||
|
)
|
||||||
|
reply = claude_response.get("content", "")
|
||||||
|
else:
|
||||||
|
api_params = {
|
||||||
|
"model": model if model in ["openai/gpt-4o", "openai/gpt-4o-mini", "openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"] else "openai/gpt-4o",
|
||||||
|
"messages": messages
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add temperature and top_p only for models that support them (exclude GPT-5 family)
|
||||||
|
if model not in ["openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"]:
|
||||||
|
api_params["temperature"] = 0.3
|
||||||
|
api_params["top_p"] = 0.7
|
||||||
|
|
||||||
|
response = await openai_client.chat.completions.create(**api_params)
|
||||||
|
|
||||||
reply = response.choices[0].message.content
|
reply = response.choices[0].message.content
|
||||||
|
|
||||||
# Add the interaction to history
|
# Add the interaction to history
|
||||||
history.append({"role": "user", "content": f"Scraped content from {url}"})
|
history.append({"role": "user", "content": f"Scraped content from {url}"})
|
||||||
|
|||||||
@@ -101,7 +101,11 @@ MODEL_OPTIONS = [
|
|||||||
"openai/o1",
|
"openai/o1",
|
||||||
"openai/o3-mini",
|
"openai/o3-mini",
|
||||||
"openai/o3",
|
"openai/o3",
|
||||||
"openai/o4-mini"
|
"openai/o4-mini",
|
||||||
|
"anthropic/claude-sonnet-4-20250514",
|
||||||
|
"anthropic/claude-opus-4-20250514",
|
||||||
|
"anthropic/claude-3.5-sonnet",
|
||||||
|
"anthropic/claude-3.5-haiku",
|
||||||
]
|
]
|
||||||
|
|
||||||
# ==================== IMAGE GENERATION MODELS ====================
|
# ==================== IMAGE GENERATION MODELS ====================
|
||||||
@@ -175,7 +179,12 @@ MODEL_TOKEN_LIMITS = {
|
|||||||
"openai/gpt-5": 4000,
|
"openai/gpt-5": 4000,
|
||||||
"openai/gpt-5-nano": 4000,
|
"openai/gpt-5-nano": 4000,
|
||||||
"openai/gpt-5-mini": 4000,
|
"openai/gpt-5-mini": 4000,
|
||||||
"openai/gpt-5-chat": 4000
|
"openai/gpt-5-chat": 4000,
|
||||||
|
# Claude models (200K context window, using conservative limits)
|
||||||
|
"anthropic/claude-sonnet-4-20250514": 16000,
|
||||||
|
"anthropic/claude-opus-4-20250514": 16000,
|
||||||
|
"anthropic/claude-3.5-sonnet": 16000,
|
||||||
|
"anthropic/claude-3.5-haiku": 16000,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Default token limit for unknown models
|
# Default token limit for unknown models
|
||||||
@@ -184,7 +193,7 @@ DEFAULT_TOKEN_LIMIT = 4000
|
|||||||
# Default model for new users
|
# Default model for new users
|
||||||
DEFAULT_MODEL = "openai/gpt-4.1"
|
DEFAULT_MODEL = "openai/gpt-4.1"
|
||||||
|
|
||||||
PDF_ALLOWED_MODELS = ["openai/gpt-4o", "openai/gpt-4o-mini", "openai/gpt-4.1","openai/gpt-4.1-nano","openai/gpt-4.1-mini"]
|
PDF_ALLOWED_MODELS = ["openai/gpt-4o", "openai/gpt-4o-mini", "openai/gpt-4.1","openai/gpt-4.1-nano","openai/gpt-4.1-mini", "anthropic/claude-sonnet-4-20250514", "anthropic/claude-opus-4-20250514", "anthropic/claude-3.5-sonnet", "anthropic/claude-3.5-haiku"]
|
||||||
PDF_BATCH_SIZE = 3
|
PDF_BATCH_SIZE = 3
|
||||||
|
|
||||||
# Prompt templates
|
# Prompt templates
|
||||||
@@ -403,6 +412,7 @@ RUNWARE_API_KEY = os.getenv("RUNWARE_API_KEY")
|
|||||||
MONGODB_URI = os.getenv("MONGODB_URI")
|
MONGODB_URI = os.getenv("MONGODB_URI")
|
||||||
ADMIN_ID = os.getenv("ADMIN_ID") # Add ADMIN_ID if you're using it
|
ADMIN_ID = os.getenv("ADMIN_ID") # Add ADMIN_ID if you're using it
|
||||||
TIMEZONE = os.getenv("TIMEZONE", "UTC") # Default to UTC if not specified
|
TIMEZONE = os.getenv("TIMEZONE", "UTC") # Default to UTC if not specified
|
||||||
|
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY") # Anthropic API key for Claude models
|
||||||
|
|
||||||
# File management settings
|
# File management settings
|
||||||
FILE_EXPIRATION_HOURS = int(os.getenv("FILE_EXPIRATION_HOURS", "48")) # Hours until files expire (-1 for never)
|
FILE_EXPIRATION_HOURS = int(os.getenv("FILE_EXPIRATION_HOURS", "48")) # Hours until files expire (-1 for never)
|
||||||
@@ -416,5 +426,7 @@ if not MONGODB_URI:
|
|||||||
print("WARNING: MONGODB_URI not found in .env file")
|
print("WARNING: MONGODB_URI not found in .env file")
|
||||||
if not RUNWARE_API_KEY:
|
if not RUNWARE_API_KEY:
|
||||||
print("WARNING: RUNWARE_API_KEY not found in .env file")
|
print("WARNING: RUNWARE_API_KEY not found in .env file")
|
||||||
|
if not ANTHROPIC_API_KEY:
|
||||||
|
print("INFO: ANTHROPIC_API_KEY not found in .env file - Claude models will not be available")
|
||||||
if ENABLE_WEBHOOK_LOGGING and not LOGGING_WEBHOOK_URL:
|
if ENABLE_WEBHOOK_LOGGING and not LOGGING_WEBHOOK_URL:
|
||||||
print("WARNING: Webhook logging enabled but LOGGING_WEBHOOK_URL not found in .env file")
|
print("WARNING: Webhook logging enabled but LOGGING_WEBHOOK_URL not found in .env file")
|
||||||
@@ -51,6 +51,14 @@ MODEL_PRICING: Dict[str, ModelPricing] = {
|
|||||||
|
|
||||||
# o4 Family
|
# o4 Family
|
||||||
"openai/o4-mini": ModelPricing(input=2.00, output=8.00),
|
"openai/o4-mini": ModelPricing(input=2.00, output=8.00),
|
||||||
|
|
||||||
|
# Claude 4 Family (Anthropic - latest models)
|
||||||
|
"anthropic/claude-sonnet-4-20250514": ModelPricing(input=3.00, output=15.00),
|
||||||
|
"anthropic/claude-opus-4-20250514": ModelPricing(input=15.00, output=75.00),
|
||||||
|
|
||||||
|
# Claude 3.5 Family (Anthropic)
|
||||||
|
"anthropic/claude-3.5-sonnet": ModelPricing(input=3.00, output=15.00),
|
||||||
|
"anthropic/claude-3.5-haiku": ModelPricing(input=0.80, output=4.00),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import base64
|
|||||||
import traceback
|
import traceback
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from src.utils.openai_utils import process_tool_calls, prepare_messages_for_api, get_tools_for_model
|
from src.utils.openai_utils import process_tool_calls, prepare_messages_for_api, get_tools_for_model
|
||||||
|
from src.utils.claude_utils import is_claude_model, call_claude_api, convert_claude_tool_calls_to_openai
|
||||||
from src.utils.pdf_utils import process_pdf, send_response
|
from src.utils.pdf_utils import process_pdf, send_response
|
||||||
from src.utils.code_utils import extract_code_blocks
|
from src.utils.code_utils import extract_code_blocks
|
||||||
from src.utils.reminder_utils import ReminderManager
|
from src.utils.reminder_utils import ReminderManager
|
||||||
@@ -95,7 +96,7 @@ except ImportError as e:
|
|||||||
logging.warning(f"Data analysis libraries not available: {str(e)}")
|
logging.warning(f"Data analysis libraries not available: {str(e)}")
|
||||||
|
|
||||||
class MessageHandler:
|
class MessageHandler:
|
||||||
def __init__(self, bot, db_handler, openai_client, image_generator):
|
def __init__(self, bot, db_handler, openai_client, image_generator, anthropic_client=None):
|
||||||
"""
|
"""
|
||||||
Initialize the message handler.
|
Initialize the message handler.
|
||||||
|
|
||||||
@@ -104,10 +105,12 @@ class MessageHandler:
|
|||||||
db_handler: Database handler instance
|
db_handler: Database handler instance
|
||||||
openai_client: OpenAI client instance
|
openai_client: OpenAI client instance
|
||||||
image_generator: Image generator instance
|
image_generator: Image generator instance
|
||||||
|
anthropic_client: Anthropic client instance (optional, for Claude models)
|
||||||
"""
|
"""
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.db = db_handler
|
self.db = db_handler
|
||||||
self.client = openai_client
|
self.client = openai_client
|
||||||
|
self.anthropic_client = anthropic_client
|
||||||
self.image_generator = image_generator
|
self.image_generator = image_generator
|
||||||
self.aiohttp_session = None
|
self.aiohttp_session = None
|
||||||
|
|
||||||
@@ -172,6 +175,26 @@ class MessageHandler:
|
|||||||
logging.warning(f"Failed to initialize tiktoken encoder: {e}")
|
logging.warning(f"Failed to initialize tiktoken encoder: {e}")
|
||||||
self.token_encoder = None
|
self.token_encoder = None
|
||||||
|
|
||||||
|
def _build_claude_tool_result_message(self, tool_call_id: str, content: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Build a tool result message for Claude API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_call_id: The ID of the tool call this result is for
|
||||||
|
content: The result content from the tool execution
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: A properly formatted Claude tool result message
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"role": "user",
|
||||||
|
"content": [{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": tool_call_id,
|
||||||
|
"content": content
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
|
||||||
def _find_user_id_from_current_task(self):
|
def _find_user_id_from_current_task(self):
|
||||||
"""
|
"""
|
||||||
Utility method to find user_id from the current asyncio task.
|
Utility method to find user_id from the current asyncio task.
|
||||||
@@ -1514,7 +1537,14 @@ print("\\n=== Correlation Analysis ===")
|
|||||||
|
|
||||||
# Determine which models should have tools available
|
# Determine which models should have tools available
|
||||||
# openai/o1-mini and openai/o1-preview do not support tools
|
# openai/o1-mini and openai/o1-preview do not support tools
|
||||||
use_tools = model in ["openai/gpt-4o", "openai/gpt-4o-mini", "openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat", "openai/o1", "openai/o3-mini", "openai/gpt-4.1", "openai/gpt-4.1-mini", "openai/gpt-4.1-nano", "openai/o3", "openai/o4-mini"]
|
# Claude models support tools
|
||||||
|
use_tools = model in [
|
||||||
|
"openai/gpt-4o", "openai/gpt-4o-mini", "openai/gpt-5", "openai/gpt-5-nano",
|
||||||
|
"openai/gpt-5-mini", "openai/gpt-5-chat", "openai/o1", "openai/o3-mini",
|
||||||
|
"openai/gpt-4.1", "openai/gpt-4.1-mini", "openai/gpt-4.1-nano", "openai/o3", "openai/o4-mini",
|
||||||
|
"anthropic/claude-sonnet-4-20250514", "anthropic/claude-opus-4-20250514",
|
||||||
|
"anthropic/claude-3.5-sonnet", "anthropic/claude-3.5-haiku"
|
||||||
|
]
|
||||||
|
|
||||||
# Count tokens being sent to API
|
# Count tokens being sent to API
|
||||||
total_content_length = 0
|
total_content_length = 0
|
||||||
@@ -1535,177 +1565,310 @@ print("\\n=== Correlation Analysis ===")
|
|||||||
logging.info(f"API Request Debug - Model: {model}, Messages: {len(messages_for_api)}, "
|
logging.info(f"API Request Debug - Model: {model}, Messages: {len(messages_for_api)}, "
|
||||||
f"Est. tokens: {estimated_tokens}, Content length: {total_content_length} chars")
|
f"Est. tokens: {estimated_tokens}, Content length: {total_content_length} chars")
|
||||||
|
|
||||||
# Prepare API call parameters
|
|
||||||
api_params = {
|
|
||||||
"model": model,
|
|
||||||
"messages": messages_for_api,
|
|
||||||
"timeout": 240 # Increased timeout for better response handling
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add temperature and top_p only for models that support them (exclude GPT-5 family)
|
|
||||||
if model in ["openai/gpt-4o", "openai/gpt-4o-mini"]:
|
|
||||||
api_params["temperature"] = 0.3
|
|
||||||
api_params["top_p"] = 0.7
|
|
||||||
elif model not in ["openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"]:
|
|
||||||
# For other models (not GPT-4o family and not GPT-5 family)
|
|
||||||
api_params["temperature"] = 1
|
|
||||||
api_params["top_p"] = 1
|
|
||||||
|
|
||||||
# Add tools if using a supported model
|
|
||||||
if use_tools:
|
|
||||||
tools = get_tools_for_model()
|
|
||||||
api_params["tools"] = tools
|
|
||||||
|
|
||||||
# Initialize variables to track tool responses
|
# Initialize variables to track tool responses
|
||||||
image_generation_used = False
|
image_generation_used = False
|
||||||
chart_id = None
|
chart_id = None
|
||||||
image_urls = [] # Will store unique image URLs
|
image_urls = [] # Will store unique image URLs
|
||||||
|
|
||||||
# Make the initial API call
|
# Check if this is a Claude model
|
||||||
try:
|
if is_claude_model(model):
|
||||||
response = await self.client.chat.completions.create(**api_params)
|
# Use Claude API
|
||||||
except Exception as e:
|
if self.anthropic_client is None:
|
||||||
# Handle 413 Request Entity Too Large error with a user-friendly message
|
|
||||||
if "413" in str(e) or "tokens_limit_reached" in str(e) or "Request body too large" in str(e):
|
|
||||||
await message.channel.send(
|
await message.channel.send(
|
||||||
f"❌ **Request too large for {model}**\n"
|
f"❌ **Claude model not available**\n"
|
||||||
f"Your conversation history or message is too large for this model.\n"
|
f"The Anthropic API key is not configured. Please set ANTHROPIC_API_KEY in your .env file."
|
||||||
f"Try:\n"
|
|
||||||
f"• Using `/reset` to start fresh\n"
|
|
||||||
f"• Using a model with higher token limits\n"
|
|
||||||
f"• Reducing the size of your current message\n"
|
|
||||||
f"• Breaking up large files into smaller pieces"
|
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
else:
|
|
||||||
# Re-raise other errors
|
|
||||||
raise e
|
|
||||||
|
|
||||||
# Extract token usage and calculate cost
|
|
||||||
input_tokens = 0
|
|
||||||
output_tokens = 0
|
|
||||||
total_cost = 0.0
|
|
||||||
|
|
||||||
if hasattr(response, 'usage') and response.usage:
|
|
||||||
input_tokens = getattr(response.usage, 'prompt_tokens', 0)
|
|
||||||
output_tokens = getattr(response.usage, 'completion_tokens', 0)
|
|
||||||
|
|
||||||
# Calculate cost based on model pricing
|
try:
|
||||||
pricing = MODEL_PRICING.get(model)
|
claude_response = await call_claude_api(
|
||||||
if pricing:
|
self.anthropic_client,
|
||||||
total_cost = pricing.calculate_cost(input_tokens, output_tokens)
|
messages_for_api,
|
||||||
|
model,
|
||||||
|
max_tokens=4096,
|
||||||
|
use_tools=use_tools
|
||||||
|
)
|
||||||
|
|
||||||
logging.info(f"API call - Model: {model}, Input tokens: {input_tokens}, Output tokens: {output_tokens}, Cost: {format_cost(total_cost)}")
|
# Extract token usage and calculate cost for Claude
|
||||||
|
input_tokens = claude_response.get("input_tokens", 0)
|
||||||
|
output_tokens = claude_response.get("output_tokens", 0)
|
||||||
|
total_cost = 0.0
|
||||||
|
|
||||||
# Save token usage and cost to database
|
# Calculate cost based on model pricing
|
||||||
await self.db.save_token_usage(user_id, model, input_tokens, output_tokens, total_cost)
|
pricing = MODEL_PRICING.get(model)
|
||||||
|
if pricing:
|
||||||
# Process tool calls if any
|
total_cost = pricing.calculate_cost(input_tokens, output_tokens)
|
||||||
updated_messages = None
|
logging.info(f"Claude API call - Model: {model}, Input tokens: {input_tokens}, Output tokens: {output_tokens}, Cost: {format_cost(total_cost)}")
|
||||||
if use_tools and response.choices[0].finish_reason == "tool_calls":
|
await self.db.save_token_usage(user_id, model, input_tokens, output_tokens, total_cost)
|
||||||
# Process tools
|
|
||||||
tool_calls = response.choices[0].message.tool_calls
|
|
||||||
tool_messages = {}
|
|
||||||
|
|
||||||
# Track which tools are being called
|
|
||||||
for tool_call in tool_calls:
|
|
||||||
if tool_call.function.name in self.tool_mapping:
|
|
||||||
tool_messages[tool_call.function.name] = True
|
|
||||||
if tool_call.function.name == "generate_image":
|
|
||||||
image_generation_used = True
|
|
||||||
elif tool_call.function.name == "edit_image":
|
|
||||||
# Display appropriate message for image editing
|
|
||||||
await message.channel.send("🖌️ Editing image...")
|
|
||||||
|
|
||||||
# Display appropriate messages based on which tools are being called
|
|
||||||
if tool_messages.get("google_search") or tool_messages.get("scrape_webpage"):
|
|
||||||
await message.channel.send("🔍 Researching information...")
|
|
||||||
|
|
||||||
if tool_messages.get("execute_python_code") or tool_messages.get("analyze_data_file"):
|
|
||||||
await message.channel.send("💻 Running code...")
|
|
||||||
|
|
||||||
if tool_messages.get("generate_image"):
|
|
||||||
await message.channel.send("🎨 Generating images...")
|
|
||||||
|
|
||||||
if tool_messages.get("set_reminder") or tool_messages.get("get_reminders"):
|
# Process tool calls if any
|
||||||
await message.channel.send("📅 Processing reminders...")
|
updated_messages = None
|
||||||
|
if use_tools and claude_response.get("tool_calls"):
|
||||||
if not tool_messages:
|
tool_calls = convert_claude_tool_calls_to_openai(claude_response["tool_calls"])
|
||||||
await message.channel.send("🤔 Processing...")
|
tool_messages = {}
|
||||||
|
|
||||||
# Process any tool calls and get the updated messages
|
|
||||||
tool_calls_processed, updated_messages = await process_tool_calls(
|
|
||||||
self.client,
|
|
||||||
response,
|
|
||||||
messages_for_api,
|
|
||||||
self.tool_mapping
|
|
||||||
)
|
|
||||||
|
|
||||||
# Process tool responses to extract important data (images, charts)
|
|
||||||
if updated_messages:
|
|
||||||
# Look for image generation and code interpreter tool responses
|
|
||||||
for msg in updated_messages:
|
|
||||||
if msg.get('role') == 'tool' and msg.get('name') == 'generate_image':
|
|
||||||
try:
|
|
||||||
tool_result = json.loads(msg.get('content', '{}'))
|
|
||||||
if tool_result.get('image_urls'):
|
|
||||||
image_urls.extend(tool_result['image_urls'])
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
elif msg.get('role') == 'tool' and msg.get('name') == 'edit_image':
|
# Track which tools are being called
|
||||||
try:
|
for tool_call in tool_calls:
|
||||||
tool_result = json.loads(msg.get('content', '{}'))
|
if tool_call.function.name in self.tool_mapping:
|
||||||
if tool_result.get('image_url'):
|
tool_messages[tool_call.function.name] = True
|
||||||
image_urls.append(tool_result['image_url'])
|
if tool_call.function.name == "generate_image":
|
||||||
except:
|
image_generation_used = True
|
||||||
pass
|
elif tool_call.function.name == "edit_image":
|
||||||
|
await message.channel.send("🖌️ Editing image...")
|
||||||
|
|
||||||
elif msg.get('role') == 'tool' and msg.get('name') in ['execute_python_code', 'analyze_data_file']:
|
# Display appropriate messages
|
||||||
try:
|
if tool_messages.get("google_search") or tool_messages.get("scrape_webpage"):
|
||||||
tool_result = json.loads(msg.get('content', '{}'))
|
await message.channel.send("🔍 Researching information...")
|
||||||
if tool_result.get('chart_id'):
|
if tool_messages.get("execute_python_code") or tool_messages.get("analyze_data_file"):
|
||||||
chart_id = tool_result['chart_id']
|
await message.channel.send("💻 Running code...")
|
||||||
except:
|
if tool_messages.get("generate_image"):
|
||||||
pass
|
await message.channel.send("🎨 Generating images...")
|
||||||
|
if tool_messages.get("set_reminder") or tool_messages.get("get_reminders"):
|
||||||
# If tool calls were processed, make another API call with the updated messages
|
await message.channel.send("📅 Processing reminders...")
|
||||||
if tool_calls_processed and updated_messages:
|
if not tool_messages:
|
||||||
# Prepare API parameters for follow-up call
|
await message.channel.send("🤔 Processing...")
|
||||||
follow_up_params = {
|
|
||||||
"model": model,
|
|
||||||
"messages": updated_messages,
|
|
||||||
"timeout": 240
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add temperature only for models that support it (exclude GPT-5 family)
|
|
||||||
if model in ["openai/gpt-4o", "openai/gpt-4o-mini"]:
|
|
||||||
follow_up_params["temperature"] = 0.3
|
|
||||||
elif model not in ["openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"]:
|
|
||||||
follow_up_params["temperature"] = 1
|
|
||||||
|
|
||||||
response = await self.client.chat.completions.create(**follow_up_params)
|
|
||||||
|
|
||||||
# Extract token usage and calculate cost for follow-up call
|
|
||||||
if hasattr(response, 'usage') and response.usage:
|
|
||||||
follow_up_input_tokens = getattr(response.usage, 'prompt_tokens', 0)
|
|
||||||
follow_up_output_tokens = getattr(response.usage, 'completion_tokens', 0)
|
|
||||||
|
|
||||||
input_tokens += follow_up_input_tokens
|
# Process tool calls manually for Claude
|
||||||
output_tokens += follow_up_output_tokens
|
tool_results = []
|
||||||
|
for tool_call in tool_calls:
|
||||||
|
function_name = tool_call.function.name
|
||||||
|
if function_name in self.tool_mapping:
|
||||||
|
try:
|
||||||
|
function_args = json.loads(tool_call.function.arguments)
|
||||||
|
function_response = await self.tool_mapping[function_name](function_args)
|
||||||
|
tool_results.append({
|
||||||
|
"tool_call_id": tool_call.id,
|
||||||
|
"content": str(function_response)
|
||||||
|
})
|
||||||
|
|
||||||
|
# Extract image URLs if generated
|
||||||
|
if function_name == "generate_image":
|
||||||
|
try:
|
||||||
|
tool_result = json.loads(function_response) if isinstance(function_response, str) else function_response
|
||||||
|
if tool_result.get('image_urls'):
|
||||||
|
image_urls.extend(tool_result['image_urls'])
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error executing {function_name}: {e}")
|
||||||
|
tool_results.append({
|
||||||
|
"tool_call_id": tool_call.id,
|
||||||
|
"content": f"Error: {str(e)}"
|
||||||
|
})
|
||||||
|
|
||||||
|
# Build updated messages with tool results for follow-up call
|
||||||
|
updated_messages = messages_for_api.copy()
|
||||||
|
updated_messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": claude_response.get("content", "")
|
||||||
|
})
|
||||||
|
for result in tool_results:
|
||||||
|
updated_messages.append(
|
||||||
|
self._build_claude_tool_result_message(result["tool_call_id"], result["content"])
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make follow-up call
|
||||||
|
follow_up_response = await call_claude_api(
|
||||||
|
self.anthropic_client,
|
||||||
|
updated_messages,
|
||||||
|
model,
|
||||||
|
max_tokens=4096,
|
||||||
|
use_tools=False # Don't need tools for follow-up
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update token usage
|
||||||
|
follow_up_input = follow_up_response.get("input_tokens", 0)
|
||||||
|
follow_up_output = follow_up_response.get("output_tokens", 0)
|
||||||
|
input_tokens += follow_up_input
|
||||||
|
output_tokens += follow_up_output
|
||||||
|
|
||||||
# Calculate additional cost
|
|
||||||
pricing = MODEL_PRICING.get(model)
|
|
||||||
if pricing:
|
if pricing:
|
||||||
additional_cost = pricing.calculate_cost(follow_up_input_tokens, follow_up_output_tokens)
|
additional_cost = pricing.calculate_cost(follow_up_input, follow_up_output)
|
||||||
total_cost += additional_cost
|
total_cost += additional_cost
|
||||||
|
await self.db.save_token_usage(user_id, model, follow_up_input, follow_up_output, additional_cost)
|
||||||
|
|
||||||
|
reply = follow_up_response.get("content", "")
|
||||||
|
else:
|
||||||
|
reply = claude_response.get("content", "")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_str = str(e)
|
||||||
|
if "overloaded" in error_str.lower():
|
||||||
|
await message.channel.send(
|
||||||
|
f"⚠️ **Claude is currently overloaded**\n"
|
||||||
|
f"Please try again in a moment or switch to an OpenAI model."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
else:
|
||||||
|
# Use OpenAI API (existing logic)
|
||||||
|
# Prepare API call parameters
|
||||||
|
api_params = {
|
||||||
|
"model": model,
|
||||||
|
"messages": messages_for_api,
|
||||||
|
"timeout": 240 # Increased timeout for better response handling
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add temperature and top_p only for models that support them (exclude GPT-5 family)
|
||||||
|
if model in ["openai/gpt-4o", "openai/gpt-4o-mini"]:
|
||||||
|
api_params["temperature"] = 0.3
|
||||||
|
api_params["top_p"] = 0.7
|
||||||
|
elif model not in ["openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"]:
|
||||||
|
# For other models (not GPT-4o family and not GPT-5 family)
|
||||||
|
api_params["temperature"] = 1
|
||||||
|
api_params["top_p"] = 1
|
||||||
|
|
||||||
|
# Add tools if using a supported model
|
||||||
|
if use_tools:
|
||||||
|
tools = get_tools_for_model()
|
||||||
|
api_params["tools"] = tools
|
||||||
|
|
||||||
|
# Make the initial API call
|
||||||
|
try:
|
||||||
|
response = await self.client.chat.completions.create(**api_params)
|
||||||
|
except Exception as e:
|
||||||
|
# Handle 413 Request Entity Too Large error with a user-friendly message
|
||||||
|
if "413" in str(e) or "tokens_limit_reached" in str(e) or "Request body too large" in str(e):
|
||||||
|
await message.channel.send(
|
||||||
|
f"❌ **Request too large for {model}**\n"
|
||||||
|
f"Your conversation history or message is too large for this model.\n"
|
||||||
|
f"Try:\n"
|
||||||
|
f"• Using `/reset` to start fresh\n"
|
||||||
|
f"• Using a model with higher token limits\n"
|
||||||
|
f"• Reducing the size of your current message\n"
|
||||||
|
f"• Breaking up large files into smaller pieces"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# Re-raise other errors
|
||||||
|
raise e
|
||||||
|
|
||||||
|
# Extract token usage and calculate cost
|
||||||
|
input_tokens = 0
|
||||||
|
output_tokens = 0
|
||||||
|
total_cost = 0.0
|
||||||
|
|
||||||
|
if hasattr(response, 'usage') and response.usage:
|
||||||
|
input_tokens = getattr(response.usage, 'prompt_tokens', 0)
|
||||||
|
output_tokens = getattr(response.usage, 'completion_tokens', 0)
|
||||||
|
|
||||||
|
# Calculate cost based on model pricing
|
||||||
|
pricing = MODEL_PRICING.get(model)
|
||||||
|
if pricing:
|
||||||
|
total_cost = pricing.calculate_cost(input_tokens, output_tokens)
|
||||||
|
|
||||||
|
logging.info(f"API call - Model: {model}, Input tokens: {input_tokens}, Output tokens: {output_tokens}, Cost: {format_cost(total_cost)}")
|
||||||
|
|
||||||
|
# Save token usage and cost to database
|
||||||
|
await self.db.save_token_usage(user_id, model, input_tokens, output_tokens, total_cost)
|
||||||
|
|
||||||
|
# Process tool calls if any (OpenAI)
|
||||||
|
updated_messages = None
|
||||||
|
if use_tools and response.choices[0].finish_reason == "tool_calls":
|
||||||
|
# Process tools
|
||||||
|
tool_calls = response.choices[0].message.tool_calls
|
||||||
|
tool_messages = {}
|
||||||
|
|
||||||
|
# Track which tools are being called
|
||||||
|
for tool_call in tool_calls:
|
||||||
|
if tool_call.function.name in self.tool_mapping:
|
||||||
|
tool_messages[tool_call.function.name] = True
|
||||||
|
if tool_call.function.name == "generate_image":
|
||||||
|
image_generation_used = True
|
||||||
|
elif tool_call.function.name == "edit_image":
|
||||||
|
# Display appropriate message for image editing
|
||||||
|
await message.channel.send("🖌️ Editing image...")
|
||||||
|
|
||||||
|
# Display appropriate messages based on which tools are being called
|
||||||
|
if tool_messages.get("google_search") or tool_messages.get("scrape_webpage"):
|
||||||
|
await message.channel.send("🔍 Researching information...")
|
||||||
|
|
||||||
|
if tool_messages.get("execute_python_code") or tool_messages.get("analyze_data_file"):
|
||||||
|
await message.channel.send("💻 Running code...")
|
||||||
|
|
||||||
|
if tool_messages.get("generate_image"):
|
||||||
|
await message.channel.send("🎨 Generating images...")
|
||||||
|
|
||||||
|
if tool_messages.get("set_reminder") or tool_messages.get("get_reminders"):
|
||||||
|
await message.channel.send("📅 Processing reminders...")
|
||||||
|
|
||||||
|
if not tool_messages:
|
||||||
|
await message.channel.send("🤔 Processing...")
|
||||||
|
|
||||||
|
# Process any tool calls and get the updated messages
|
||||||
|
tool_calls_processed, updated_messages = await process_tool_calls(
|
||||||
|
self.client,
|
||||||
|
response,
|
||||||
|
messages_for_api,
|
||||||
|
self.tool_mapping
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process tool responses to extract important data (images, charts)
|
||||||
|
if updated_messages:
|
||||||
|
# Look for image generation and code interpreter tool responses
|
||||||
|
for msg in updated_messages:
|
||||||
|
if msg.get('role') == 'tool' and msg.get('name') == 'generate_image':
|
||||||
|
try:
|
||||||
|
tool_result = json.loads(msg.get('content', '{}'))
|
||||||
|
if tool_result.get('image_urls'):
|
||||||
|
image_urls.extend(tool_result['image_urls'])
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
logging.info(f"Follow-up API call - Model: {model}, Input tokens: {follow_up_input_tokens}, Output tokens: {follow_up_output_tokens}, Additional cost: {format_cost(additional_cost)}")
|
elif msg.get('role') == 'tool' and msg.get('name') == 'edit_image':
|
||||||
|
try:
|
||||||
|
tool_result = json.loads(msg.get('content', '{}'))
|
||||||
|
if tool_result.get('image_url'):
|
||||||
|
image_urls.append(tool_result['image_url'])
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
# Save additional token usage and cost to database
|
elif msg.get('role') == 'tool' and msg.get('name') in ['execute_python_code', 'analyze_data_file']:
|
||||||
await self.db.save_token_usage(user_id, model, follow_up_input_tokens, follow_up_output_tokens, additional_cost)
|
try:
|
||||||
|
tool_result = json.loads(msg.get('content', '{}'))
|
||||||
reply = response.choices[0].message.content
|
if tool_result.get('chart_id'):
|
||||||
|
chart_id = tool_result['chart_id']
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# If tool calls were processed, make another API call with the updated messages
|
||||||
|
if tool_calls_processed and updated_messages:
|
||||||
|
# Prepare API parameters for follow-up call
|
||||||
|
follow_up_params = {
|
||||||
|
"model": model,
|
||||||
|
"messages": updated_messages,
|
||||||
|
"timeout": 240
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add temperature only for models that support it (exclude GPT-5 family)
|
||||||
|
if model in ["openai/gpt-4o", "openai/gpt-4o-mini"]:
|
||||||
|
follow_up_params["temperature"] = 0.3
|
||||||
|
elif model not in ["openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"]:
|
||||||
|
follow_up_params["temperature"] = 1
|
||||||
|
|
||||||
|
response = await self.client.chat.completions.create(**follow_up_params)
|
||||||
|
|
||||||
|
# Extract token usage and calculate cost for follow-up call
|
||||||
|
if hasattr(response, 'usage') and response.usage:
|
||||||
|
follow_up_input_tokens = getattr(response.usage, 'prompt_tokens', 0)
|
||||||
|
follow_up_output_tokens = getattr(response.usage, 'completion_tokens', 0)
|
||||||
|
|
||||||
|
input_tokens += follow_up_input_tokens
|
||||||
|
output_tokens += follow_up_output_tokens
|
||||||
|
|
||||||
|
# Calculate additional cost
|
||||||
|
pricing = MODEL_PRICING.get(model)
|
||||||
|
if pricing:
|
||||||
|
additional_cost = pricing.calculate_cost(follow_up_input_tokens, follow_up_output_tokens)
|
||||||
|
total_cost += additional_cost
|
||||||
|
|
||||||
|
logging.info(f"Follow-up API call - Model: {model}, Input tokens: {follow_up_input_tokens}, Output tokens: {follow_up_output_tokens}, Additional cost: {format_cost(additional_cost)}")
|
||||||
|
|
||||||
|
# Save additional token usage and cost to database
|
||||||
|
await self.db.save_token_usage(user_id, model, follow_up_input_tokens, follow_up_output_tokens, additional_cost)
|
||||||
|
|
||||||
|
reply = response.choices[0].message.content
|
||||||
|
|
||||||
# Add image URLs to assistant content if any were found
|
# Add image URLs to assistant content if any were found
|
||||||
has_images = len(image_urls) > 0
|
has_images = len(image_urls) > 0
|
||||||
@@ -1724,7 +1887,15 @@ print("\\n=== Correlation Analysis ===")
|
|||||||
})
|
})
|
||||||
|
|
||||||
# Store the response in history for models that support it
|
# Store the response in history for models that support it
|
||||||
if model in ["openai/gpt-4o", "openai/gpt-4o-mini", "openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat", "openai/o1", "openai/o1-mini", "openai/o3-mini", "openai/gpt-4.1", "openai/gpt-4.1-nano", "openai/gpt-4.1-mini", "openai/o3", "openai/o4-mini", "openai/o1-preview"]:
|
models_with_history = [
|
||||||
|
"openai/gpt-4o", "openai/gpt-4o-mini", "openai/gpt-5", "openai/gpt-5-nano",
|
||||||
|
"openai/gpt-5-mini", "openai/gpt-5-chat", "openai/o1", "openai/o1-mini",
|
||||||
|
"openai/o3-mini", "openai/gpt-4.1", "openai/gpt-4.1-nano", "openai/gpt-4.1-mini",
|
||||||
|
"openai/o3", "openai/o4-mini", "openai/o1-preview",
|
||||||
|
"anthropic/claude-sonnet-4-20250514", "anthropic/claude-opus-4-20250514",
|
||||||
|
"anthropic/claude-3.5-sonnet", "anthropic/claude-3.5-haiku"
|
||||||
|
]
|
||||||
|
if model in models_with_history:
|
||||||
if model in ["openai/o1-mini", "openai/o1-preview"]:
|
if model in ["openai/o1-mini", "openai/o1-preview"]:
|
||||||
# For models without system prompt support, keep track separately
|
# For models without system prompt support, keep track separately
|
||||||
if has_images:
|
if has_images:
|
||||||
|
|||||||
531
src/utils/claude_utils.py
Normal file
531
src/utils/claude_utils.py
Normal file
@@ -0,0 +1,531 @@
|
|||||||
|
"""
|
||||||
|
Claude (Anthropic) API utility functions.
|
||||||
|
|
||||||
|
This module provides utilities for interacting with Anthropic's Claude models,
|
||||||
|
including message conversion and API calls compatible with the existing bot structure.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
from typing import List, Dict, Any, Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
def is_claude_model(model: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the model is a Claude/Anthropic model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model name (e.g., "anthropic/claude-sonnet-4-20250514")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if it's a Claude model, False otherwise
|
||||||
|
"""
|
||||||
|
return model.startswith("anthropic/")
|
||||||
|
|
||||||
|
|
||||||
|
def get_claude_model_id(model: str) -> str:
|
||||||
|
"""
|
||||||
|
Extract the Claude model ID from the full model name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Full model name (e.g., "anthropic/claude-sonnet-4-20250514")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Claude model ID (e.g., "claude-sonnet-4-20250514")
|
||||||
|
"""
|
||||||
|
if model.startswith("anthropic/"):
|
||||||
|
return model[len("anthropic/"):]
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def convert_openai_messages_to_claude(messages: List[Dict[str, Any]]) -> Tuple[Optional[str], List[Dict[str, Any]]]:
|
||||||
|
"""
|
||||||
|
Convert OpenAI message format to Claude message format.
|
||||||
|
|
||||||
|
OpenAI uses:
|
||||||
|
- {"role": "system", "content": "..."}
|
||||||
|
- {"role": "user", "content": "..."}
|
||||||
|
- {"role": "assistant", "content": "..."}
|
||||||
|
|
||||||
|
Claude uses:
|
||||||
|
- system parameter (separate from messages)
|
||||||
|
- {"role": "user", "content": "..."}
|
||||||
|
- {"role": "assistant", "content": "..."}
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of messages in OpenAI format
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (system_prompt, claude_messages)
|
||||||
|
"""
|
||||||
|
system_prompt = None
|
||||||
|
claude_messages = []
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
role = msg.get("role")
|
||||||
|
content = msg.get("content")
|
||||||
|
|
||||||
|
# Skip messages with None content
|
||||||
|
if content is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if role == "system":
|
||||||
|
# Claude uses a separate system parameter
|
||||||
|
if isinstance(content, str):
|
||||||
|
system_prompt = content
|
||||||
|
elif isinstance(content, list):
|
||||||
|
# Extract text from list content
|
||||||
|
text_parts = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, dict) and item.get("type") == "text":
|
||||||
|
text_parts.append(item.get("text", ""))
|
||||||
|
elif isinstance(item, str):
|
||||||
|
text_parts.append(item)
|
||||||
|
system_prompt = " ".join(text_parts)
|
||||||
|
elif role in ["user", "assistant"]:
|
||||||
|
# Convert content format
|
||||||
|
converted_content = convert_content_to_claude(content)
|
||||||
|
if converted_content:
|
||||||
|
claude_messages.append({
|
||||||
|
"role": role,
|
||||||
|
"content": converted_content
|
||||||
|
})
|
||||||
|
elif role == "tool":
|
||||||
|
# Claude handles tool results differently - add as user message with tool result
|
||||||
|
tool_call_id = msg.get("tool_call_id", "")
|
||||||
|
tool_name = msg.get("name", "unknown")
|
||||||
|
claude_messages.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": tool_call_id,
|
||||||
|
"content": str(content)
|
||||||
|
}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
# Claude requires alternating user/assistant messages
|
||||||
|
# Merge consecutive messages of the same role
|
||||||
|
merged_messages = merge_consecutive_messages(claude_messages)
|
||||||
|
|
||||||
|
return system_prompt, merged_messages
|
||||||
|
|
||||||
|
|
||||||
|
def convert_content_to_claude(content: Any) -> Any:
|
||||||
|
"""
|
||||||
|
Convert content from OpenAI format to Claude format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Content in OpenAI format (string or list)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Content in Claude format
|
||||||
|
"""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
|
||||||
|
if isinstance(content, list):
|
||||||
|
claude_content = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
item_type = item.get("type")
|
||||||
|
|
||||||
|
if item_type == "text":
|
||||||
|
claude_content.append({
|
||||||
|
"type": "text",
|
||||||
|
"text": item.get("text", "")
|
||||||
|
})
|
||||||
|
elif item_type == "image_url":
|
||||||
|
# Convert image_url format to Claude format
|
||||||
|
image_url_data = item.get("image_url", {})
|
||||||
|
if isinstance(image_url_data, dict):
|
||||||
|
url = image_url_data.get("url", "")
|
||||||
|
else:
|
||||||
|
url = str(image_url_data)
|
||||||
|
|
||||||
|
if url:
|
||||||
|
# Claude requires base64 data or URLs
|
||||||
|
if url.startswith("data:"):
|
||||||
|
# Parse base64 data URL
|
||||||
|
try:
|
||||||
|
media_type, base64_data = parse_data_url(url)
|
||||||
|
claude_content.append({
|
||||||
|
"type": "image",
|
||||||
|
"source": {
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": media_type,
|
||||||
|
"data": base64_data
|
||||||
|
}
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Failed to parse data URL: {e}")
|
||||||
|
else:
|
||||||
|
# Regular URL - Claude supports URLs directly
|
||||||
|
claude_content.append({
|
||||||
|
"type": "image",
|
||||||
|
"source": {
|
||||||
|
"type": "url",
|
||||||
|
"url": url
|
||||||
|
}
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
# Handle other types as text
|
||||||
|
if "text" in item:
|
||||||
|
claude_content.append({
|
||||||
|
"type": "text",
|
||||||
|
"text": str(item.get("text", ""))
|
||||||
|
})
|
||||||
|
elif isinstance(item, str):
|
||||||
|
claude_content.append({
|
||||||
|
"type": "text",
|
||||||
|
"text": item
|
||||||
|
})
|
||||||
|
|
||||||
|
return claude_content if claude_content else None
|
||||||
|
|
||||||
|
return str(content) if content else None
|
||||||
|
|
||||||
|
|
||||||
|
def parse_data_url(data_url: str) -> Tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Parse a data URL into media type and base64 data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_url: Data URL (e.g., "data:image/png;base64,...")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (media_type, base64_data)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the data URL format is invalid
|
||||||
|
"""
|
||||||
|
if not data_url.startswith("data:"):
|
||||||
|
raise ValueError(f"Not a data URL: expected 'data:' prefix, got '{data_url[:20]}...'")
|
||||||
|
|
||||||
|
# Remove "data:" prefix
|
||||||
|
content = data_url[5:]
|
||||||
|
|
||||||
|
# Split by semicolon and comma
|
||||||
|
parts = content.split(";base64,")
|
||||||
|
if len(parts) != 2:
|
||||||
|
raise ValueError(f"Invalid data URL format: expected ';base64,' separator, got '{content[:50]}...'")
|
||||||
|
|
||||||
|
media_type = parts[0]
|
||||||
|
base64_data = parts[1]
|
||||||
|
|
||||||
|
return media_type, base64_data
|
||||||
|
|
||||||
|
|
||||||
|
def merge_consecutive_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Merge consecutive messages with the same role.
|
||||||
|
Claude requires alternating user/assistant messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of messages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of merged messages
|
||||||
|
"""
|
||||||
|
if not messages:
|
||||||
|
return []
|
||||||
|
|
||||||
|
merged = []
|
||||||
|
current_role = None
|
||||||
|
current_content = []
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
role = msg.get("role")
|
||||||
|
content = msg.get("content")
|
||||||
|
|
||||||
|
if role == current_role:
|
||||||
|
# Same role, merge content
|
||||||
|
if isinstance(content, str):
|
||||||
|
if current_content and isinstance(current_content[-1], dict) and current_content[-1].get("type") == "text":
|
||||||
|
current_content[-1]["text"] += "\n" + content
|
||||||
|
else:
|
||||||
|
current_content.append({"type": "text", "text": content})
|
||||||
|
elif isinstance(content, list):
|
||||||
|
current_content.extend(content)
|
||||||
|
else:
|
||||||
|
# Different role, save previous and start new
|
||||||
|
if current_role is not None and current_content:
|
||||||
|
merged.append({
|
||||||
|
"role": current_role,
|
||||||
|
"content": simplify_content(current_content)
|
||||||
|
})
|
||||||
|
|
||||||
|
current_role = role
|
||||||
|
if isinstance(content, str):
|
||||||
|
current_content = [{"type": "text", "text": content}]
|
||||||
|
elif isinstance(content, list):
|
||||||
|
current_content = content.copy()
|
||||||
|
else:
|
||||||
|
current_content = []
|
||||||
|
|
||||||
|
# Don't forget the last message
|
||||||
|
if current_role is not None and current_content:
|
||||||
|
merged.append({
|
||||||
|
"role": current_role,
|
||||||
|
"content": simplify_content(current_content)
|
||||||
|
})
|
||||||
|
|
||||||
|
return merged
|
||||||
|
|
||||||
|
|
||||||
|
def simplify_content(content: List[Dict[str, Any]]) -> Any:
    """
    Collapse a list of content blocks to a plain string when possible.

    Args:
        content: List of content items

    Returns:
        A newline-joined string when every block is text (a single text
        block therefore becomes its bare text), the original list when
        any non-text block is present, and "" for empty input.
    """
    if not content:
        return ""

    # All-text lists (including the single-item case) flatten to one string.
    if all(block.get("type") == "text" for block in content):
        return "\n".join(block.get("text", "") for block in content)

    return content
|
||||||
|
|
||||||
|
|
||||||
|
def get_claude_tools() -> List[Dict[str, Any]]:
    """
    Get tool definitions for Claude API.
    Claude expects an "input_schema" JSON-schema object per tool rather
    than OpenAI's "parameters" wrapper.

    Returns:
        List of tool definitions in Claude format
    """
    def _tool(name, description, properties, required=None):
        # Build one tool entry; "required" is omitted entirely when empty.
        schema: Dict[str, Any] = {"type": "object", "properties": properties}
        if required:
            schema["required"] = required
        return {"name": name, "description": description, "input_schema": schema}

    # Shared description for tools that operate on the user's last upload.
    latest_image_desc = "Pass 'latest_image' to use the user's most recently uploaded image"

    return [
        _tool(
            "google_search",
            "Search the web for current information",
            {
                "query": {"type": "string", "description": "The search query"},
                "num_results": {"type": "integer", "description": "Number of results (max 10)", "maximum": 10},
            },
            ["query"],
        ),
        _tool(
            "scrape_webpage",
            "Extract and read content from a webpage URL",
            {"url": {"type": "string", "description": "The webpage URL to scrape"}},
            ["url"],
        ),
        _tool(
            "execute_python_code",
            "Run Python code. Packages auto-install. Use load_file('file_id') for user files. Output files auto-sent to user.",
            {
                "code": {"type": "string", "description": "Python code to execute"},
                "timeout": {"type": "integer", "description": "Timeout in seconds", "maximum": 300},
            },
            ["code"],
        ),
        _tool(
            "generate_image",
            "Create/generate images from text. Models: flux (best), flux-dev, sdxl, realistic (photos), anime, dreamshaper.",
            {
                "prompt": {"type": "string", "description": "Detailed description of the image to create"},
                "model": {"type": "string", "description": "Model to use", "enum": ["flux", "flux-dev", "sdxl", "realistic", "anime", "dreamshaper"]},
                "num_images": {"type": "integer", "description": "Number of images (1-4)", "maximum": 4},
                "aspect_ratio": {"type": "string", "description": "Aspect ratio preset", "enum": ["1:1", "16:9", "9:16", "4:3", "3:4", "3:2", "2:3", "21:9"]},
            },
            ["prompt"],
        ),
        _tool(
            "set_reminder",
            "Set a reminder",
            {
                "content": {"type": "string", "description": "Reminder content"},
                "time": {"type": "string", "description": "Reminder time"},
            },
            ["content", "time"],
        ),
        _tool("get_reminders", "List all reminders", {}),
        _tool(
            "upscale_image",
            "Enlarge/upscale an image to higher resolution. Pass 'latest_image' to use the user's most recently uploaded image.",
            {
                "image_url": {"type": "string", "description": latest_image_desc},
                "scale_factor": {"type": "integer", "description": "Scale factor (2 or 4)", "enum": [2, 4]},
                "model": {"type": "string", "description": "Upscale model", "enum": ["clarity", "ccsr", "sd-latent", "swinir"]},
            },
            ["image_url"],
        ),
        _tool(
            "remove_background",
            "Remove background from an image. Pass 'latest_image' to use the user's most recently uploaded image.",
            {
                "image_url": {"type": "string", "description": latest_image_desc},
                "model": {"type": "string", "description": "Background removal model", "enum": ["bria", "rembg", "birefnet-base", "birefnet-general", "birefnet-portrait"]},
            },
            ["image_url"],
        ),
        _tool(
            "image_to_text",
            "Generate a text description/caption of an image. Pass 'latest_image' to use the user's most recently uploaded image.",
            {"image_url": {"type": "string", "description": latest_image_desc}},
            ["image_url"],
        ),
    ]
|
||||||
|
|
||||||
|
|
||||||
|
async def call_claude_api(
    anthropic_client,
    messages: List[Dict[str, Any]],
    model: str,
    max_tokens: int = 4096,
    use_tools: bool = True
) -> Dict[str, Any]:
    """
    Call the Claude API with the given messages.

    Args:
        anthropic_client: Anthropic client instance
        messages: List of messages in OpenAI format
        model: Model name (e.g., "anthropic/claude-sonnet-4-20250514")
        max_tokens: Maximum tokens in response
        use_tools: Whether to include tools

    Returns:
        Dict with response data including:
        - content: Response text
        - input_tokens: Number of input tokens
        - output_tokens: Number of output tokens
        - tool_calls: Any tool calls made (OpenAI-style dicts)
        - stop_reason: Why the response stopped
    """
    try:
        # Translate the OpenAI-format conversation into Claude's format.
        system_prompt, claude_messages = convert_openai_messages_to_claude(messages)

        request: Dict[str, Any] = {
            "model": get_claude_model_id(model),
            "max_tokens": max_tokens,
            "messages": claude_messages,
        }
        # "system" is a top-level parameter for Claude, not a message role.
        if system_prompt:
            request["system"] = system_prompt
        if use_tools:
            request["tools"] = get_claude_tools()

        response = await anthropic_client.messages.create(**request)

        # Walk the content blocks: collect text and any tool-use requests.
        text_parts: List[str] = []
        tool_calls: List[Dict[str, Any]] = []
        for block in response.content:
            if block.type == "text":
                text_parts.append(block.text)
            elif block.type == "tool_use":
                tool_calls.append({
                    "id": block.id,
                    "type": "function",
                    "function": {
                        "name": block.name,
                        "arguments": json.dumps(block.input),
                    },
                })

        usage = response.usage
        return {
            "content": "".join(text_parts),
            "input_tokens": usage.input_tokens if usage else 0,
            "output_tokens": usage.output_tokens if usage else 0,
            "tool_calls": tool_calls,
            "stop_reason": response.stop_reason,
        }

    except Exception as e:
        logging.error(f"Error calling Claude API: {e}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
def convert_claude_tool_calls_to_openai(tool_calls: List[Dict[str, Any]]) -> List[Any]:
    """
    Convert Claude tool calls to OpenAI format for compatibility with existing code.

    Args:
        tool_calls: Tool calls from Claude API

    Returns:
        Objects exposing .id, .type and .function (with .name/.arguments),
        mirroring OpenAI's tool-call attribute interface.
    """
    from dataclasses import dataclass

    @dataclass
    class FunctionCall:
        name: str
        arguments: str

    @dataclass
    class ToolCall:
        id: str
        type: str
        function: FunctionCall

    return [
        ToolCall(
            id=call["id"],
            type=call["type"],
            function=FunctionCall(
                name=call["function"]["name"],
                arguments=call["function"]["arguments"],
            ),
        )
        for call in tool_calls
    ]
|
||||||
@@ -89,6 +89,10 @@ class TestPricingModule:
|
|||||||
"openai/gpt-4.1",
|
"openai/gpt-4.1",
|
||||||
"openai/gpt-5",
|
"openai/gpt-5",
|
||||||
"openai/o1",
|
"openai/o1",
|
||||||
|
"anthropic/claude-sonnet-4-20250514",
|
||||||
|
"anthropic/claude-opus-4-20250514",
|
||||||
|
"anthropic/claude-3.5-sonnet",
|
||||||
|
"anthropic/claude-3.5-haiku",
|
||||||
]
|
]
|
||||||
|
|
||||||
for model in expected_models:
|
for model in expected_models:
|
||||||
@@ -105,6 +109,10 @@ class TestPricingModule:
|
|||||||
# Test smaller amounts
|
# Test smaller amounts
|
||||||
cost = calculate_cost("openai/gpt-4o", 1000, 1000)
|
cost = calculate_cost("openai/gpt-4o", 1000, 1000)
|
||||||
assert cost == pytest.approx(0.025, rel=1e-6) # $0.005 + $0.020
|
assert cost == pytest.approx(0.025, rel=1e-6) # $0.005 + $0.020
|
||||||
|
|
||||||
|
# Test Claude model
|
||||||
|
cost = calculate_cost("anthropic/claude-3.5-sonnet", 1_000_000, 1_000_000)
|
||||||
|
assert cost == 18.00 # $3 + $15
|
||||||
|
|
||||||
def test_calculate_cost_unknown_model(self):
|
def test_calculate_cost_unknown_model(self):
|
||||||
"""Test that unknown models return 0 cost."""
|
"""Test that unknown models return 0 cost."""
|
||||||
@@ -404,6 +412,92 @@ class TestCodeInterpreterSecurity:
|
|||||||
assert fm._detect_file_type("unknown.xyz") == "binary"
|
assert fm._detect_file_type("unknown.xyz") == "binary"
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Claude Utils Tests
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
class TestClaudeUtils:
    """Tests for Claude utility functions."""

    def test_is_claude_model(self):
        """Test Claude model detection."""
        from src.utils.claude_utils import is_claude_model

        # Claude models (plain truthiness asserts instead of `== True`, per E712)
        assert is_claude_model("anthropic/claude-sonnet-4-20250514")
        assert is_claude_model("anthropic/claude-opus-4-20250514")
        assert is_claude_model("anthropic/claude-3.5-sonnet")
        assert is_claude_model("anthropic/claude-3.5-haiku")

        # Non-Claude models
        assert not is_claude_model("openai/gpt-4o")
        assert not is_claude_model("openai/gpt-4o-mini")
        assert not is_claude_model("gpt-4")

    def test_get_claude_model_id(self):
        """Test Claude model ID extraction."""
        from src.utils.claude_utils import get_claude_model_id

        assert get_claude_model_id("anthropic/claude-sonnet-4-20250514") == "claude-sonnet-4-20250514"
        assert get_claude_model_id("anthropic/claude-3.5-sonnet") == "claude-3.5-sonnet"
        # Already-bare IDs pass through unchanged.
        assert get_claude_model_id("claude-3.5-sonnet") == "claude-3.5-sonnet"

    def test_convert_openai_messages_to_claude(self):
        """Test message conversion from OpenAI to Claude format."""
        from src.utils.claude_utils import convert_openai_messages_to_claude

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
            {"role": "user", "content": "How are you?"},
        ]

        system, claude_messages = convert_openai_messages_to_claude(messages)

        # System should be extracted into Claude's top-level system parameter
        assert system == "You are a helpful assistant."

        # Messages should not contain system
        assert all(m.get("role") != "system" for m in claude_messages)

        # Should have user and assistant messages
        assert len(claude_messages) >= 2

    def test_convert_content_to_claude(self):
        """Test content conversion."""
        from src.utils.claude_utils import convert_content_to_claude

        # String content passes through
        assert convert_content_to_claude("Hello") == "Hello"

        # List content with text keeps one entry per block
        list_content = [
            {"type": "text", "text": "Hello"},
            {"type": "text", "text": "World"}
        ]
        result = convert_content_to_claude(list_content)
        assert isinstance(result, list)
        assert len(result) == 2

    def test_merge_consecutive_messages(self):
        """Test merging consecutive messages with same role."""
        from src.utils.claude_utils import merge_consecutive_messages

        messages = [
            {"role": "user", "content": "Hello"},
            {"role": "user", "content": "How are you?"},
            {"role": "assistant", "content": "Hi!"},
        ]

        merged = merge_consecutive_messages(messages)

        # Should merge two user messages into one
        assert len(merged) == 2
        assert merged[0]["role"] == "user"
        assert merged[1]["role"] == "assistant"
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
# OpenAI Utils Tests
|
# OpenAI Utils Tests
|
||||||
# ============================================================
|
# ============================================================
|
||||||
|
|||||||
Reference in New Issue
Block a user