Add Claude API integration alongside existing GPT API
Co-authored-by: cauvang32 <113093128+cauvang32@users.noreply.github.com>
This commit is contained in:
@@ -18,6 +18,11 @@ OPENAI_API_KEY=your_openai_api_key_here
|
||||
# Use OpenAI directly: https://api.openai.com/v1
|
||||
OPENAI_BASE_URL=https://models.github.ai/inference
|
||||
|
||||
# Claude API Key (Anthropic)
|
||||
# Get from: https://console.anthropic.com/
|
||||
# Leave empty to disable Claude models
|
||||
ANTHROPIC_API_KEY=your_anthropic_api_key_here
|
||||
|
||||
# ============================================
|
||||
# Image Generation (Optional)
|
||||
# ============================================
|
||||
|
||||
6
.gitignore
vendored
6
.gitignore
vendored
@@ -2,8 +2,10 @@ test.py
|
||||
.env
|
||||
chat_history.db
|
||||
bot_copy.py
|
||||
__pycache__/bot.cpython-312.pyc
|
||||
tests/__pycache__/test_bot.cpython-312.pyc
|
||||
__pycache__/
|
||||
**/__pycache__/
|
||||
*.pyc
|
||||
*.pyo
|
||||
.vscode/settings.json
|
||||
chatgpt.zip
|
||||
response.txt
|
||||
|
||||
20
bot.py
20
bot.py
@@ -17,7 +17,7 @@ from src.config.config import (
|
||||
DISCORD_TOKEN, MONGODB_URI, RUNWARE_API_KEY, STATUSES,
|
||||
LOGGING_CONFIG, ENABLE_WEBHOOK_LOGGING, LOGGING_WEBHOOK_URL,
|
||||
WEBHOOK_LOG_LEVEL, WEBHOOK_APP_NAME, WEBHOOK_BATCH_SIZE,
|
||||
WEBHOOK_FLUSH_INTERVAL, LOG_LEVEL_MAP
|
||||
WEBHOOK_FLUSH_INTERVAL, LOG_LEVEL_MAP, ANTHROPIC_API_KEY
|
||||
)
|
||||
|
||||
# Import webhook logger
|
||||
@@ -124,6 +124,20 @@ async def main():
|
||||
logging.error(f"Error initializing OpenAI client: {e}")
|
||||
return
|
||||
|
||||
# Initialize the Claude (Anthropic) client if API key is available
|
||||
claude_client = None
|
||||
if ANTHROPIC_API_KEY:
|
||||
try:
|
||||
from anthropic import AsyncAnthropic
|
||||
claude_client = AsyncAnthropic(api_key=ANTHROPIC_API_KEY)
|
||||
logging.info("Claude (Anthropic) client initialized successfully")
|
||||
except ImportError:
|
||||
logging.warning("Failed to import Anthropic. Make sure it's installed: pip install anthropic")
|
||||
except Exception as e:
|
||||
logging.warning(f"Error initializing Claude client: {e}")
|
||||
else:
|
||||
logging.info("ANTHROPIC_API_KEY not set - Claude models will not be available")
|
||||
|
||||
# Global references to objects that need cleanup
|
||||
message_handler = None
|
||||
db_handler = None
|
||||
@@ -191,14 +205,14 @@ async def main():
|
||||
await ctx.send(f"Error: {error_msg}")
|
||||
|
||||
# Initialize message handler
|
||||
message_handler = MessageHandler(bot, db_handler, openai_client, image_generator)
|
||||
message_handler = MessageHandler(bot, db_handler, openai_client, image_generator, claude_client)
|
||||
|
||||
# Attach db_handler to bot for cogs
|
||||
bot.db_handler = db_handler
|
||||
|
||||
# Set up slash commands
|
||||
from src.commands.commands import setup_commands
|
||||
setup_commands(bot, db_handler, openai_client, image_generator)
|
||||
setup_commands(bot, db_handler, openai_client, image_generator, claude_client)
|
||||
|
||||
# Load file management commands
|
||||
try:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Discord Bot Core
|
||||
discord.py>=2.3.0
|
||||
openai>=1.40.0
|
||||
anthropic>=0.39.0
|
||||
python-dotenv>=1.0.0
|
||||
|
||||
# Database
|
||||
|
||||
@@ -12,6 +12,7 @@ from src.utils.image_utils import ImageGenerator
|
||||
from src.utils.web_utils import google_custom_search, scrape_web_content
|
||||
from src.utils.pdf_utils import process_pdf, send_response
|
||||
from src.utils.openai_utils import prepare_file_from_path
|
||||
from src.utils.claude_utils import is_claude_model, call_claude_api
|
||||
from src.utils.token_counter import token_counter
|
||||
from src.utils.code_interpreter import delete_all_user_files
|
||||
from src.utils.discord_utils import create_info_embed, create_error_embed, create_success_embed
|
||||
@@ -69,7 +70,7 @@ async def image_model_autocomplete(
|
||||
for model in matches[:25]
|
||||
]
|
||||
|
||||
def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator: ImageGenerator):
|
||||
def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator: ImageGenerator, claude_client=None):
|
||||
"""
|
||||
Set up all slash commands for the bot.
|
||||
|
||||
@@ -78,6 +79,7 @@ def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator
|
||||
db_handler: Database handler instance
|
||||
openai_client: OpenAI client instance
|
||||
image_generator: Image generator instance
|
||||
claude_client: Claude (Anthropic) client instance (optional)
|
||||
"""
|
||||
tree = bot.tree
|
||||
|
||||
@@ -265,24 +267,53 @@ def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator
|
||||
f"(text: {input_token_count['text_tokens']}, images: {input_token_count['image_tokens']})"
|
||||
)
|
||||
|
||||
# Send to the AI model
|
||||
api_params = {
|
||||
"model": model if model in ["openai/gpt-4o", "openai/gpt-4o-mini", "openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"] else "openai/gpt-4o",
|
||||
"messages": messages
|
||||
}
|
||||
|
||||
# Add temperature only for models that support it (exclude GPT-5 family)
|
||||
if model not in ["openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"]:
|
||||
api_params["temperature"] = 0.5
|
||||
|
||||
response = await openai_client.chat.completions.create(**api_params)
|
||||
# Check if using Claude model
|
||||
if is_claude_model(model):
|
||||
if not claude_client:
|
||||
await interaction.followup.send(
|
||||
"❌ Claude API not configured. Please set ANTHROPIC_API_KEY.",
|
||||
ephemeral=True
|
||||
)
|
||||
return
|
||||
|
||||
# Call Claude API
|
||||
claude_response = await call_claude_api(
|
||||
claude_client,
|
||||
messages,
|
||||
model,
|
||||
max_tokens=4096,
|
||||
temperature=0.5
|
||||
)
|
||||
|
||||
if not claude_response.get("success"):
|
||||
await interaction.followup.send(
|
||||
f"❌ Claude API Error: {claude_response.get('error', 'Unknown error')}",
|
||||
ephemeral=True
|
||||
)
|
||||
return
|
||||
|
||||
reply = claude_response.get("content", "")
|
||||
actual_input_tokens = claude_response.get("input_tokens", 0)
|
||||
actual_output_tokens = claude_response.get("output_tokens", 0)
|
||||
else:
|
||||
# Send to the OpenAI model
|
||||
api_params = {
|
||||
"model": model if model in ["openai/gpt-4o", "openai/gpt-4o-mini", "openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"] else "openai/gpt-4o",
|
||||
"messages": messages
|
||||
}
|
||||
|
||||
# Add temperature only for models that support it (exclude GPT-5 family)
|
||||
if model not in ["openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"]:
|
||||
api_params["temperature"] = 0.5
|
||||
|
||||
response = await openai_client.chat.completions.create(**api_params)
|
||||
|
||||
reply = response.choices[0].message.content
|
||||
|
||||
# Get actual token usage from API response
|
||||
usage = response.usage
|
||||
actual_input_tokens = usage.prompt_tokens if usage else input_token_count['total_tokens']
|
||||
actual_output_tokens = usage.completion_tokens if usage else token_counter.count_text_tokens(reply, model)
|
||||
reply = response.choices[0].message.content
|
||||
|
||||
# Get actual token usage from API response
|
||||
usage = response.usage
|
||||
actual_input_tokens = usage.prompt_tokens if usage else input_token_count['total_tokens']
|
||||
actual_output_tokens = usage.completion_tokens if usage else token_counter.count_text_tokens(reply, model)
|
||||
|
||||
# Calculate cost
|
||||
cost = token_counter.estimate_cost(actual_input_tokens, actual_output_tokens, model)
|
||||
@@ -362,19 +393,47 @@ def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator
|
||||
{"role": "user", "content": f"Content from {url}:\n{content}"}
|
||||
]
|
||||
|
||||
api_params = {
|
||||
"model": model if model in ["openai/gpt-4o", "openai/gpt-4o-mini", "openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"] else "openai/gpt-4o",
|
||||
"messages": messages
|
||||
}
|
||||
|
||||
# Add temperature and top_p only for models that support them (exclude GPT-5 family)
|
||||
if model not in ["openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"]:
|
||||
api_params["temperature"] = 0.3
|
||||
api_params["top_p"] = 0.7
|
||||
|
||||
response = await openai_client.chat.completions.create(**api_params)
|
||||
# Check if using Claude model
|
||||
if is_claude_model(model):
|
||||
if not claude_client:
|
||||
await interaction.followup.send(
|
||||
"❌ Claude API not configured. Please set ANTHROPIC_API_KEY.",
|
||||
ephemeral=True
|
||||
)
|
||||
return
|
||||
|
||||
# Call Claude API
|
||||
claude_response = await call_claude_api(
|
||||
claude_client,
|
||||
messages,
|
||||
model,
|
||||
max_tokens=4096,
|
||||
temperature=0.3
|
||||
)
|
||||
|
||||
if not claude_response.get("success"):
|
||||
await interaction.followup.send(
|
||||
f"❌ Claude API Error: {claude_response.get('error', 'Unknown error')}",
|
||||
ephemeral=True
|
||||
)
|
||||
return
|
||||
|
||||
reply = claude_response.get("content", "")
|
||||
else:
|
||||
# Send to the OpenAI model
|
||||
api_params = {
|
||||
"model": model if model in ["openai/gpt-4o", "openai/gpt-4o-mini", "openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"] else "openai/gpt-4o",
|
||||
"messages": messages
|
||||
}
|
||||
|
||||
# Add temperature and top_p only for models that support them (exclude GPT-5 family)
|
||||
if model not in ["openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"]:
|
||||
api_params["temperature"] = 0.3
|
||||
api_params["top_p"] = 0.7
|
||||
|
||||
response = await openai_client.chat.completions.create(**api_params)
|
||||
|
||||
reply = response.choices[0].message.content
|
||||
reply = response.choices[0].message.content
|
||||
|
||||
# Add the interaction to history
|
||||
history.append({"role": "user", "content": f"Scraped content from {url}"})
|
||||
|
||||
Binary file not shown.
@@ -101,7 +101,10 @@ MODEL_OPTIONS = [
|
||||
"openai/o1",
|
||||
"openai/o3-mini",
|
||||
"openai/o3",
|
||||
"openai/o4-mini"
|
||||
"openai/o4-mini",
|
||||
"claude/claude-3-5-sonnet",
|
||||
"claude/claude-3-5-haiku",
|
||||
"claude/claude-3-opus",
|
||||
]
|
||||
|
||||
# ==================== IMAGE GENERATION MODELS ====================
|
||||
@@ -175,7 +178,10 @@ MODEL_TOKEN_LIMITS = {
|
||||
"openai/gpt-5": 4000,
|
||||
"openai/gpt-5-nano": 4000,
|
||||
"openai/gpt-5-mini": 4000,
|
||||
"openai/gpt-5-chat": 4000
|
||||
"openai/gpt-5-chat": 4000,
|
||||
"claude/claude-3-5-sonnet": 8000,
|
||||
"claude/claude-3-5-haiku": 8000,
|
||||
"claude/claude-3-opus": 8000,
|
||||
}
|
||||
|
||||
# Default token limit for unknown models
|
||||
@@ -403,6 +409,7 @@ RUNWARE_API_KEY = os.getenv("RUNWARE_API_KEY")
|
||||
MONGODB_URI = os.getenv("MONGODB_URI")
|
||||
ADMIN_ID = os.getenv("ADMIN_ID") # Add ADMIN_ID if you're using it
|
||||
TIMEZONE = os.getenv("TIMEZONE", "UTC") # Default to UTC if not specified
|
||||
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY") # Claude API key
|
||||
|
||||
# File management settings
|
||||
FILE_EXPIRATION_HOURS = int(os.getenv("FILE_EXPIRATION_HOURS", "48")) # Hours until files expire (-1 for never)
|
||||
@@ -416,5 +423,7 @@ if not MONGODB_URI:
|
||||
print("WARNING: MONGODB_URI not found in .env file")
|
||||
if not RUNWARE_API_KEY:
|
||||
print("WARNING: RUNWARE_API_KEY not found in .env file")
|
||||
if not ANTHROPIC_API_KEY:
|
||||
print("INFO: ANTHROPIC_API_KEY not found in .env file - Claude models will not be available")
|
||||
if ENABLE_WEBHOOK_LOGGING and not LOGGING_WEBHOOK_URL:
|
||||
print("WARNING: Webhook logging enabled but LOGGING_WEBHOOK_URL not found in .env file")
|
||||
@@ -51,6 +51,11 @@ MODEL_PRICING: Dict[str, ModelPricing] = {
|
||||
|
||||
# o4 Family
|
||||
"openai/o4-mini": ModelPricing(input=2.00, output=8.00),
|
||||
|
||||
# Claude Family (Anthropic)
|
||||
"claude/claude-3-5-sonnet": ModelPricing(input=3.00, output=15.00),
|
||||
"claude/claude-3-5-haiku": ModelPricing(input=0.80, output=4.00),
|
||||
"claude/claude-3-opus": ModelPricing(input=15.00, output=75.00),
|
||||
}
|
||||
|
||||
|
||||
|
||||
Binary file not shown.
@@ -15,6 +15,7 @@ import base64
|
||||
import traceback
|
||||
from datetime import datetime, timedelta
|
||||
from src.utils.openai_utils import process_tool_calls, prepare_messages_for_api, get_tools_for_model
|
||||
from src.utils.claude_utils import is_claude_model, call_claude_api, convert_messages_for_claude
|
||||
from src.utils.pdf_utils import process_pdf, send_response
|
||||
from src.utils.code_utils import extract_code_blocks
|
||||
from src.utils.reminder_utils import ReminderManager
|
||||
@@ -95,7 +96,7 @@ except ImportError as e:
|
||||
logging.warning(f"Data analysis libraries not available: {str(e)}")
|
||||
|
||||
class MessageHandler:
|
||||
def __init__(self, bot, db_handler, openai_client, image_generator):
|
||||
def __init__(self, bot, db_handler, openai_client, image_generator, claude_client=None):
|
||||
"""
|
||||
Initialize the message handler.
|
||||
|
||||
@@ -104,10 +105,12 @@ class MessageHandler:
|
||||
db_handler: Database handler instance
|
||||
openai_client: OpenAI client instance
|
||||
image_generator: Image generator instance
|
||||
claude_client: Claude (Anthropic) client instance (optional)
|
||||
"""
|
||||
self.bot = bot
|
||||
self.db = db_handler
|
||||
self.client = openai_client
|
||||
self.claude_client = claude_client
|
||||
self.image_generator = image_generator
|
||||
self.aiohttp_session = None
|
||||
|
||||
@@ -1514,6 +1517,7 @@ print("\\n=== Correlation Analysis ===")
|
||||
|
||||
# Determine which models should have tools available
|
||||
# openai/o1-mini and openai/o1-preview do not support tools
|
||||
# Claude models also don't support OpenAI-style tools (yet)
|
||||
use_tools = model in ["openai/gpt-4o", "openai/gpt-4o-mini", "openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat", "openai/o1", "openai/o3-mini", "openai/gpt-4.1", "openai/gpt-4.1-mini", "openai/gpt-4.1-nano", "openai/o3", "openai/o4-mini"]
|
||||
|
||||
# Count tokens being sent to API
|
||||
@@ -1535,6 +1539,79 @@ print("\\n=== Correlation Analysis ===")
|
||||
logging.info(f"API Request Debug - Model: {model}, Messages: {len(messages_for_api)}, "
|
||||
f"Est. tokens: {estimated_tokens}, Content length: {total_content_length} chars")
|
||||
|
||||
# Initialize variables to track tool responses
|
||||
image_generation_used = False
|
||||
chart_id = None
|
||||
image_urls = [] # Will store unique image URLs
|
||||
|
||||
# Check if this is a Claude model
|
||||
if is_claude_model(model):
|
||||
# Handle Claude API call
|
||||
if not self.claude_client:
|
||||
await message.channel.send(
|
||||
f"❌ **Claude API not configured**\n"
|
||||
f"The Claude model `{model}` requires an Anthropic API key.\n"
|
||||
f"Please set `ANTHROPIC_API_KEY` in your environment variables."
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# Call Claude API
|
||||
claude_response = await call_claude_api(
|
||||
self.claude_client,
|
||||
messages_for_api,
|
||||
model,
|
||||
max_tokens=4096,
|
||||
temperature=0.7
|
||||
)
|
||||
|
||||
if not claude_response.get("success"):
|
||||
error_msg = claude_response.get("error", "Unknown error")
|
||||
await message.channel.send(f"❌ **Claude API Error:** {error_msg}")
|
||||
return
|
||||
|
||||
reply = claude_response.get("content", "")
|
||||
input_tokens = claude_response.get("input_tokens", 0)
|
||||
output_tokens = claude_response.get("output_tokens", 0)
|
||||
|
||||
# Calculate cost
|
||||
pricing = MODEL_PRICING.get(model)
|
||||
if pricing:
|
||||
total_cost = pricing.calculate_cost(input_tokens, output_tokens)
|
||||
logging.info(f"Claude API call - Model: {model}, Input tokens: {input_tokens}, Output tokens: {output_tokens}, Cost: {format_cost(total_cost)}")
|
||||
await self.db.save_token_usage(user_id, model, input_tokens, output_tokens, total_cost)
|
||||
else:
|
||||
total_cost = 0.0
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e)
|
||||
if "rate_limit" in error_str.lower():
|
||||
await message.channel.send(
|
||||
f"❌ **Rate limit exceeded**\n"
|
||||
f"Please wait a moment before trying again."
|
||||
)
|
||||
else:
|
||||
await message.channel.send(f"❌ **Claude API Error:** {error_str}")
|
||||
return
|
||||
|
||||
# Store response in history for Claude models
|
||||
history.append({"role": "assistant", "content": reply})
|
||||
|
||||
# Only keep a reasonable amount of history
|
||||
if len(history) > 15:
|
||||
history = history[:1] + history[-14:]
|
||||
|
||||
await self.db.save_history(user_id, history)
|
||||
|
||||
# Send the response text
|
||||
await send_response(message.channel, reply)
|
||||
|
||||
# Log processing time and cost
|
||||
processing_time = time.time() - start_time
|
||||
logging.info(f"Message processed in {processing_time:.2f} seconds (User: {user_id}, Model: {model}, Cost: {format_cost(total_cost)})")
|
||||
return
|
||||
|
||||
# Handle OpenAI API call (existing logic)
|
||||
# Prepare API call parameters
|
||||
api_params = {
|
||||
"model": model,
|
||||
@@ -1556,11 +1633,6 @@ print("\\n=== Correlation Analysis ===")
|
||||
tools = get_tools_for_model()
|
||||
api_params["tools"] = tools
|
||||
|
||||
# Initialize variables to track tool responses
|
||||
image_generation_used = False
|
||||
chart_id = None
|
||||
image_urls = [] # Will store unique image URLs
|
||||
|
||||
# Make the initial API call
|
||||
try:
|
||||
response = await self.client.chat.completions.create(**api_params)
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
214
src/utils/claude_utils.py
Normal file
214
src/utils/claude_utils.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""
|
||||
Claude API utilities for handling Anthropic Claude model interactions.
|
||||
This module provides similar functionality to openai_utils.py but for Claude models.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
|
||||
|
||||
# Map from internal model names to Anthropic API model names
|
||||
CLAUDE_MODEL_MAP = {
|
||||
"claude/claude-3-5-sonnet": "claude-3-5-sonnet-20241022",
|
||||
"claude/claude-3-5-haiku": "claude-3-5-haiku-20241022",
|
||||
"claude/claude-3-opus": "claude-3-opus-20240229",
|
||||
}
|
||||
|
||||
|
||||
def get_anthropic_model_name(model: str) -> str:
|
||||
"""Convert internal model name to Anthropic API model name."""
|
||||
return CLAUDE_MODEL_MAP.get(model, model)
|
||||
|
||||
|
||||
def is_claude_model(model: str) -> bool:
|
||||
"""Check if the model is a Claude model."""
|
||||
return model.startswith("claude/")
|
||||
|
||||
|
||||
def convert_messages_for_claude(messages: List[Dict[str, Any]]) -> Tuple[Optional[str], List[Dict[str, Any]]]:
|
||||
"""
|
||||
Convert OpenAI-style messages to Claude format.
|
||||
|
||||
Claude requires:
|
||||
- System message as a separate parameter (not in messages array)
|
||||
- Messages array without system messages
|
||||
- Different image format
|
||||
|
||||
Args:
|
||||
messages: List of OpenAI-style messages
|
||||
|
||||
Returns:
|
||||
Tuple of (system_prompt, converted_messages)
|
||||
"""
|
||||
system_prompt = None
|
||||
converted_messages = []
|
||||
|
||||
for msg in messages:
|
||||
role = msg.get('role', '')
|
||||
content = msg.get('content', '')
|
||||
|
||||
# Extract system message
|
||||
if role == 'system':
|
||||
system_prompt = content if isinstance(content, str) else str(content)
|
||||
continue
|
||||
|
||||
# Skip tool and tool_call messages for now (Claude handles tools differently)
|
||||
if role in ['tool', 'function']:
|
||||
continue
|
||||
|
||||
# Convert content based on type
|
||||
if isinstance(content, str):
|
||||
converted_messages.append({
|
||||
"role": role,
|
||||
"content": content
|
||||
})
|
||||
elif isinstance(content, list):
|
||||
# Handle mixed content (text + images)
|
||||
claude_content = []
|
||||
for item in content:
|
||||
item_type = item.get('type', '')
|
||||
|
||||
if item_type == 'text':
|
||||
claude_content.append({
|
||||
"type": "text",
|
||||
"text": item.get('text', '')
|
||||
})
|
||||
elif item_type == 'image_url':
|
||||
# Convert image_url format to Claude's format
|
||||
image_url_data = item.get('image_url', {})
|
||||
url = image_url_data.get('url') if isinstance(image_url_data, dict) else str(image_url_data)
|
||||
|
||||
if url:
|
||||
# Claude expects base64 data or URLs in a specific format
|
||||
if url.startswith('data:'):
|
||||
# Handle base64 encoded images
|
||||
# Format: data:image/png;base64,<base64data>
|
||||
try:
|
||||
media_type = url.split(';')[0].split(':')[1]
|
||||
base64_data = url.split(',')[1]
|
||||
claude_content.append({
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": media_type,
|
||||
"data": base64_data
|
||||
}
|
||||
})
|
||||
except (IndexError, ValueError) as e:
|
||||
logging.warning(f"Failed to parse base64 image: {e}")
|
||||
else:
|
||||
# For URLs, Claude requires downloading the image
|
||||
# We'll include it as a URL reference in text for now
|
||||
claude_content.append({
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "url",
|
||||
"url": url
|
||||
}
|
||||
})
|
||||
|
||||
if claude_content:
|
||||
converted_messages.append({
|
||||
"role": role,
|
||||
"content": claude_content
|
||||
})
|
||||
elif content is not None:
|
||||
# Handle any other content types by converting to string
|
||||
converted_messages.append({
|
||||
"role": role,
|
||||
"content": str(content)
|
||||
})
|
||||
|
||||
return system_prompt, converted_messages
|
||||
|
||||
|
||||
async def call_claude_api(
|
||||
client,
|
||||
messages: List[Dict[str, Any]],
|
||||
model: str,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Call the Claude API with the given messages.
|
||||
|
||||
Args:
|
||||
client: Anthropic client instance
|
||||
messages: List of messages in OpenAI format
|
||||
model: Model name (internal format like "claude/claude-3-5-sonnet")
|
||||
max_tokens: Maximum tokens in response
|
||||
temperature: Sampling temperature
|
||||
|
||||
Returns:
|
||||
Dict containing response content and usage info
|
||||
"""
|
||||
try:
|
||||
# Convert model name to Anthropic format
|
||||
anthropic_model = get_anthropic_model_name(model)
|
||||
|
||||
# Convert messages to Claude format
|
||||
system_prompt, claude_messages = convert_messages_for_claude(messages)
|
||||
|
||||
# Prepare API parameters
|
||||
api_params = {
|
||||
"model": anthropic_model,
|
||||
"max_tokens": max_tokens,
|
||||
"messages": claude_messages,
|
||||
}
|
||||
|
||||
# Add system prompt if present
|
||||
if system_prompt:
|
||||
api_params["system"] = system_prompt
|
||||
|
||||
# Add temperature (Claude supports 0-1 range)
|
||||
api_params["temperature"] = min(max(temperature, 0), 1)
|
||||
|
||||
# Make the API call
|
||||
response = await client.messages.create(**api_params)
|
||||
|
||||
# Extract content from response
|
||||
content = ""
|
||||
if response.content:
|
||||
for block in response.content:
|
||||
if hasattr(block, 'text'):
|
||||
content += block.text
|
||||
|
||||
# Extract usage information
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
if hasattr(response, 'usage'):
|
||||
input_tokens = getattr(response.usage, 'input_tokens', 0)
|
||||
output_tokens = getattr(response.usage, 'output_tokens', 0)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"content": content,
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"stop_reason": getattr(response, 'stop_reason', None),
|
||||
"model": anthropic_model
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Claude API call failed: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"content": None,
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0
|
||||
}
|
||||
|
||||
|
||||
def get_claude_tools() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get tool definitions in Claude format.
|
||||
Claude uses a different tool format than OpenAI.
|
||||
|
||||
Note: Tool support for Claude is simplified for now.
|
||||
Full tool support would require more extensive integration.
|
||||
"""
|
||||
# For now, we return an empty list as tool support is complex
|
||||
# and would require significant changes to the tool handling logic
|
||||
# Users can use Claude models for text-only interactions
|
||||
return []
|
||||
Binary file not shown.
Binary file not shown.
@@ -375,5 +375,64 @@ class TestPDFUtils(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertIn('file', kwargs)
|
||||
|
||||
|
||||
class TestClaudeUtils(unittest.TestCase):
|
||||
"""Test Claude utility functions"""
|
||||
|
||||
def test_is_claude_model(self):
|
||||
from src.utils.claude_utils import is_claude_model
|
||||
|
||||
# Test Claude models
|
||||
self.assertTrue(is_claude_model("claude/claude-3-5-sonnet"))
|
||||
self.assertTrue(is_claude_model("claude/claude-3-5-haiku"))
|
||||
self.assertTrue(is_claude_model("claude/claude-3-opus"))
|
||||
|
||||
# Test non-Claude models
|
||||
self.assertFalse(is_claude_model("openai/gpt-4o"))
|
||||
self.assertFalse(is_claude_model("openai/gpt-4o-mini"))
|
||||
self.assertFalse(is_claude_model("gpt-4"))
|
||||
|
||||
def test_get_anthropic_model_name(self):
|
||||
from src.utils.claude_utils import get_anthropic_model_name
|
||||
|
||||
# Test model name mapping
|
||||
self.assertEqual(get_anthropic_model_name("claude/claude-3-5-sonnet"), "claude-3-5-sonnet-20241022")
|
||||
self.assertEqual(get_anthropic_model_name("claude/claude-3-5-haiku"), "claude-3-5-haiku-20241022")
|
||||
self.assertEqual(get_anthropic_model_name("claude/claude-3-opus"), "claude-3-opus-20240229")
|
||||
|
||||
# Test unknown model (returns as-is)
|
||||
self.assertEqual(get_anthropic_model_name("unknown-model"), "unknown-model")
|
||||
|
||||
def test_convert_messages_for_claude(self):
|
||||
from src.utils.claude_utils import convert_messages_for_claude
|
||||
|
||||
# Test with system message
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello!"},
|
||||
{"role": "assistant", "content": "Hi there!"}
|
||||
]
|
||||
system_prompt, converted = convert_messages_for_claude(messages)
|
||||
|
||||
self.assertEqual(system_prompt, "You are a helpful assistant.")
|
||||
self.assertEqual(len(converted), 2) # System message should be extracted
|
||||
self.assertEqual(converted[0]["role"], "user")
|
||||
self.assertEqual(converted[0]["content"], "Hello!")
|
||||
self.assertEqual(converted[1]["role"], "assistant")
|
||||
self.assertEqual(converted[1]["content"], "Hi there!")
|
||||
|
||||
def test_convert_messages_without_system(self):
|
||||
from src.utils.claude_utils import convert_messages_for_claude
|
||||
|
||||
# Test without system message
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello!"},
|
||||
{"role": "assistant", "content": "Hi!"}
|
||||
]
|
||||
system_prompt, converted = convert_messages_for_claude(messages)
|
||||
|
||||
self.assertIsNone(system_prompt)
|
||||
self.assertEqual(len(converted), 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user