feat: Enhance token usage tracking and history management
- Implement proactive history trimming in DatabaseHandler to limit conversation messages to the last 50 while preserving system messages.
- Add methods for saving and retrieving user token usage statistics, including breakdowns by model.
- Introduce model pricing for token usage calculations in MessageHandler.
- Update message handling to count tokens accurately and log costs associated with API calls (see the illustrative sketch below).
- Refactor token counting methods to utilize API response data instead of internal counting.
- Improve tool definitions in get_tools_for_model for clarity and conciseness.
- Remove deprecated python_executor_new.py file.
- Adjust web_utils to use tiktoken for preprocessing content before API calls.
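Illustrative sketch (not part of the diff) of how the per-call cost is derived from MODEL_PRICING and the API usage fields. The pricing values are copied from the MODEL_PRICING table in the diff below; the helper name `estimate_cost` is hypothetical.

```python
# Hypothetical helper mirroring the cost math added in message_handler.py.
# Pricing values are USD per 1M tokens, taken from MODEL_PRICING below.
MODEL_PRICING = {
    "openai/gpt-4o": {"input": 5.00, "output": 20.00},
    "openai/gpt-4o-mini": {"input": 0.60, "output": 2.40},
}

def estimate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
    """Return the estimated USD cost of one API call (0.0 for unknown models)."""
    pricing = MODEL_PRICING.get(model)
    if pricing is None:
        return 0.0
    input_cost = (input_tokens / 1_000_000) * pricing["input"]
    output_cost = (output_tokens / 1_000_000) * pricing["output"]
    return input_cost + output_cost

# Example: 1,200 prompt tokens + 300 completion tokens on gpt-4o-mini
# -> (1200/1e6)*0.60 + (300/1e6)*2.40 = 0.00072 + 0.00072 = $0.001440
print(f"${estimate_cost('openai/gpt-4o-mini', 1200, 300):.6f}")
```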
@@ -2,7 +2,6 @@ discord.py
openai
motor
pymongo
tiktoken
pypdf
beautifulsoup4
requests
@@ -15,6 +14,5 @@ openpyxl
seaborn
tzlocal
numpy
scipy
plotly
nbformat
tiktoken
@@ -12,6 +12,25 @@ from src.utils.web_utils import google_custom_search, scrape_web_content
from src.utils.pdf_utils import process_pdf, send_response
from src.utils.openai_utils import prepare_file_from_path

# Model pricing per 1M tokens (in USD)
MODEL_PRICING = {
"openai/gpt-4o": {"input": 5.00, "output": 20.00},
"openai/gpt-4o-mini": {"input": 0.60, "output": 2.40},
"openai/gpt-4.1": {"input": 2.00, "output": 8.00},
"openai/gpt-4.1-mini": {"input": 0.40, "output": 1.60},
"openai/gpt-4.1-nano": {"input": 0.10, "output": 0.40},
"openai/gpt-5": {"input": 1.25, "output": 10.00},
"openai/gpt-5-mini": {"input": 0.25, "output": 2.00},
"openai/gpt-5-nano": {"input": 0.05, "output": 0.40},
"openai/gpt-5-chat": {"input": 1.25, "output": 10.00},
"openai/o1-preview": {"input": 15.00, "output": 60.00},
"openai/o1-mini": {"input": 1.10, "output": 4.40},
"openai/o1": {"input": 15.00, "output": 60.00},
"openai/o3-mini": {"input": 1.10, "output": 4.40},
"openai/o3": {"input": 2.00, "output": 8.00},
"openai/o4-mini": {"input": 2.00, "output": 8.00}
}

# Dictionary to keep track of user requests and their cooldowns
user_requests = {}
# Dictionary to store user tasks
@@ -298,74 +317,91 @@ def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator

await process_request(interaction, process_image_generation, prompt)

@tree.command(name="reset", description="Reset the bot by clearing user data.")
@tree.command(name="reset", description="Reset the bot by clearing user data and token usage statistics.")
@check_blacklist()
async def reset(interaction: discord.Interaction):
"""Resets the bot by clearing user data."""
user_id = interaction.user.id
await db_handler.save_history(user_id, [])
await interaction.response.send_message("Your conversation history has been cleared and reset!", ephemeral=True)
await db_handler.reset_user_token_stats(user_id)
await interaction.response.send_message("Your conversation history and token usage statistics have been cleared and reset!", ephemeral=True)

@tree.command(name="user_stat", description="Get your current input token, output token, and model.")
@tree.command(name="user_stat", description="Get your current token usage, costs, and model.")
@check_blacklist()
async def user_stat(interaction: discord.Interaction):
"""Fetches and displays the current input token, output token, and model for the user."""
"""Fetches and displays the current token usage, costs, and model for the user."""
await interaction.response.defer(thinking=True, ephemeral=True)

async def process_user_stat(interaction: discord.Interaction):
import tiktoken

user_id = interaction.user.id
history = await db_handler.get_history(user_id)
model = await db_handler.get_user_model(user_id) or DEFAULT_MODEL # Default model

# Adjust model for encoding purposes
if model in ["openai/gpt-4o", "openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat", "openai/o1", "openai/o1-preview", "openai/o1-mini", "openai/o3-mini"]:
encoding_model = "openai/gpt-4o"
else:
encoding_model = model

# Retrieve the appropriate encoding for the selected model
encoding = tiktoken.get_encoding("o200k_base")

# Initialize token counts
input_tokens = 0
output_tokens = 0

# Calculate input and output tokens
if history:
for item in history:
content = item.get('content', '')

# Handle case where content is a list or other type
if isinstance(content, list):
content_str = ""
for part in content:
if isinstance(part, dict) and 'text' in part:
content_str += part['text'] + " "
content = content_str

# Ensure content is a string before processing
if isinstance(content, str):
tokens = len(encoding.encode(content))
if item.get('role') == 'user':
input_tokens += tokens
elif item.get('role') == 'assistant':
output_tokens += tokens
model = await db_handler.get_user_model(user_id) or DEFAULT_MODEL

# Get token usage from database
token_stats = await db_handler.get_user_token_usage(user_id)

total_input_tokens = token_stats.get('total_input_tokens', 0)
total_output_tokens = token_stats.get('total_output_tokens', 0)
total_cost = token_stats.get('total_cost', 0.0)

# Get usage by model for detailed breakdown
model_usage = await db_handler.get_user_token_usage_by_model(user_id)

# Create the statistics message
stat_message = (
f"**User Statistics:**\n"
f"Model: `{model}`\n"
f"Input Tokens: `{input_tokens}`\n"
f"Output Tokens: `{output_tokens}`\n"
f"**📊 User Statistics**\n"
f"Current Model: `{model}`\n"
f"Total Input Tokens: `{total_input_tokens:,}`\n"
f"Total Output Tokens: `{total_output_tokens:,}`\n"
f"**💰 Total Cost: `${total_cost:.6f}`**\n\n"
)

# Add breakdown by model if available
if model_usage:
stat_message += "**Model Usage Breakdown:**\n"
for model_name, usage in model_usage.items():
input_tokens = usage.get('input_tokens', 0)
output_tokens = usage.get('output_tokens', 0)
cost = usage.get('cost', 0.0)
stat_message += f"`{model_name.replace('openai/', '')}`: {input_tokens:,} in, {output_tokens:,} out, ${cost:.6f}\n"

# Send the response
await interaction.followup.send(stat_message, ephemeral=True)

await process_request(interaction, process_user_stat)

@tree.command(name="prices", description="Display pricing information for all available AI models.")
@check_blacklist()
async def prices_command(interaction: discord.Interaction):
"""Displays pricing information for all available AI models."""
await interaction.response.defer(thinking=True, ephemeral=True)

async def process_prices(interaction: discord.Interaction):
# Create the pricing message
pricing_message = (
"**💰 Model Pricing (per 1M tokens)**\n"
"```\n"
f"{'Model':<20} {'Input':<8} {'Output':<8}\n"
f"{'-' * 40}\n"
)

for model, pricing in MODEL_PRICING.items():
model_short = model.replace("openai/", "")
pricing_message += f"{model_short:<20} ${pricing['input']:<7.2f} ${pricing['output']:<7.2f}\n"

pricing_message += "```\n"
pricing_message += (
"**💡 Cost Examples:**\n"
"• A typical conversation (~1,000 tokens) with `gpt-4o-mini`: ~$0.002\n"
"• A typical conversation (~1,000 tokens) with `gpt-4o`: ~$0.025\n"
"• A typical conversation (~1,000 tokens) with `o1-preview`: ~$0.075\n\n"
"Use `/user_stat` to see your total usage and costs!"
)

# Send the response
await interaction.followup.send(pricing_message, ephemeral=True)

await process_request(interaction, process_prices)

@tree.command(name="help", description="Display a list of available commands.")
@check_blacklist()
async def help_command(interaction: discord.Interaction):
@@ -377,8 +413,9 @@ def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator
"/web `<url>` - Scrape a webpage and send the data to the AI model.\n"
"/generate `<prompt>` - Generate an image from a text prompt.\n"
"/toggle_tools - Toggle display of tool execution details (code, input, output).\n"
"/reset - Reset your chat history.\n"
"/user_stat - Get information about your input tokens, output tokens, and current model.\n"
"/reset - Reset your chat history and token usage statistics.\n"
"/user_stat - Get information about your token usage, costs, and current model.\n"
"/prices - Display pricing information for all available AI models.\n"
"/help - Display this help message.\n"
)
await interaction.response.send_message(help_message, ephemeral=True)

@@ -41,7 +41,26 @@ class DatabaseHandler:
if user_data and 'history' in user_data:
# Filter out expired image links
filtered_history = self._filter_expired_images(user_data['history'])
return filtered_history

# Proactive history trimming: Keep only the last 50 messages to prevent excessive token usage
# Always preserve system messages
system_messages = [msg for msg in filtered_history if msg.get('role') == 'system']
conversation_messages = [msg for msg in filtered_history if msg.get('role') != 'system']

# Keep only the last 50 conversation messages
if len(conversation_messages) > 50:
conversation_messages = conversation_messages[-50:]
logging.info(f"Trimmed history for user {user_id}: kept last 50 conversation messages")

# Combine system messages with trimmed conversation
trimmed_history = system_messages + conversation_messages

# If history was trimmed, save the trimmed version back to DB
if len(trimmed_history) < len(filtered_history):
await self.save_history(user_id, trimmed_history)
logging.info(f"Saved trimmed history for user {user_id}: {len(trimmed_history)} messages")

return trimmed_history
return []

def _filter_expired_images(self, history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
@@ -180,6 +199,8 @@ class DatabaseHandler:
await self.db.user_preferences.create_index("user_id")
await self.db.whitelist.create_index("user_id")
await self.db.blacklist.create_index("user_id")
await self.db.token_usage.create_index([("user_id", 1), ("timestamp", -1)])
await self.db.user_token_stats.create_index("user_id")

async def ensure_reminders_collection(self):
"""
@@ -190,6 +211,90 @@ class DatabaseHandler:
await self.reminders_collection.create_index([("remind_at", 1), ("sent", 1)])
logging.info("Ensured reminders collection and indexes")

# Token usage tracking methods
async def save_token_usage(self, user_id: int, model: str, input_tokens: int, output_tokens: int, cost: float):
"""Save token usage and cost for a user"""
try:
usage_data = {
"user_id": user_id,
"model": model,
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"cost": cost,
"timestamp": datetime.now()
}

# Insert usage record
await self.db.token_usage.insert_one(usage_data)

# Escape model name for MongoDB field names (replace dots and other special chars)
escaped_model = model.replace(".", "_DOT_").replace("/", "_SLASH_").replace("$", "_DOLLAR_")

# Update user's total usage
await self.db.user_token_stats.update_one(
{"user_id": user_id},
{
"$inc": {
"total_input_tokens": input_tokens,
"total_output_tokens": output_tokens,
"total_cost": cost,
f"models.{escaped_model}.input_tokens": input_tokens,
f"models.{escaped_model}.output_tokens": output_tokens,
f"models.{escaped_model}.cost": cost
},
"$set": {"last_updated": datetime.now()}
},
upsert=True
)

except Exception as e:
logging.error(f"Error saving token usage: {e}")

async def get_user_token_usage(self, user_id: int) -> Dict[str, Any]:
"""Get total token usage for a user"""
try:
user_stats = await self.db.user_token_stats.find_one({"user_id": user_id})
if user_stats:
return {
"total_input_tokens": user_stats.get("total_input_tokens", 0),
"total_output_tokens": user_stats.get("total_output_tokens", 0),
"total_cost": user_stats.get("total_cost", 0.0)
}
return {"total_input_tokens": 0, "total_output_tokens": 0, "total_cost": 0.0}
except Exception as e:
logging.error(f"Error getting user token usage: {e}")
return {"total_input_tokens": 0, "total_output_tokens": 0, "total_cost": 0.0}

async def get_user_token_usage_by_model(self, user_id: int) -> Dict[str, Dict[str, Any]]:
"""Get token usage breakdown by model for a user"""
try:
user_stats = await self.db.user_token_stats.find_one({"user_id": user_id})
if user_stats and "models" in user_stats:
# Unescape model names for display
unescaped_models = {}
for escaped_model, usage in user_stats["models"].items():
# Reverse the escaping
original_model = escaped_model.replace("_DOT_", ".").replace("_SLASH_", "/").replace("_DOLLAR_", "$")
unescaped_models[original_model] = usage
return unescaped_models
return {}
except Exception as e:
logging.error(f"Error getting user token usage by model: {e}")
return {}

async def reset_user_token_stats(self, user_id: int) -> None:
"""Reset all token usage statistics for a user"""
try:
# Delete the user's token stats document
await self.db.user_token_stats.delete_one({"user_id": user_id})

# Optionally, also delete individual usage records
await self.db.token_usage.delete_many({"user_id": user_id})

logging.info(f"Reset token statistics for user {user_id}")
except Exception as e:
logging.error(f"Error resetting user token stats: {e}")

async def close(self):
"""Properly close the database connection"""
self.client.close()

@@ -13,7 +13,6 @@ import sys
import subprocess
import base64
import traceback
import tiktoken
from datetime import datetime, timedelta
from src.utils.openai_utils import process_tool_calls, prepare_messages_for_api, get_tools_for_model
from src.utils.pdf_utils import process_pdf, send_response
@@ -27,6 +26,25 @@ user_last_request = {}
RATE_LIMIT_WINDOW = 5 # seconds
MAX_REQUESTS = 3 # max requests per window

# Model pricing per 1M tokens (in USD)
MODEL_PRICING = {
"openai/gpt-4o": {"input": 5.00, "output": 20.00},
"openai/gpt-4o-mini": {"input": 0.60, "output": 2.40},
"openai/gpt-4.1": {"input": 2.00, "output": 8.00},
"openai/gpt-4.1-mini": {"input": 0.40, "output": 1.60},
"openai/gpt-4.1-nano": {"input": 0.10, "output": 0.40},
"openai/gpt-5": {"input": 1.25, "output": 10.00},
"openai/gpt-5-mini": {"input": 0.25, "output": 2.00},
"openai/gpt-5-nano": {"input": 0.05, "output": 0.40},
"openai/gpt-5-chat": {"input": 1.25, "output": 10.00},
"openai/o1-preview": {"input": 15.00, "output": 60.00},
"openai/o1-mini": {"input": 1.10, "output": 4.40},
"openai/o1": {"input": 15.00, "output": 60.00},
"openai/o3-mini": {"input": 1.10, "output": 4.40},
"openai/o3": {"input": 2.00, "output": 8.00},
"openai/o4-mini": {"input": 2.00, "output": 8.00}
}

# File extensions that should be treated as text files
TEXT_FILE_EXTENSIONS = [
'.txt', '.md', '.csv', '.json', '.xml', '.html', '.htm', '.css',
@@ -126,8 +144,14 @@ class MessageHandler:
if not PANDAS_AVAILABLE:
self._install_data_packages()

# Initialize tiktoken encoder for token counting (using o200k_base for all models)
self.token_encoder = tiktoken.get_encoding("o200k_base")
# Initialize tiktoken encoder for internal operations (trimming, estimation)
try:
import tiktoken
self.token_encoder = tiktoken.get_encoding("o200k_base")
logging.info("Tiktoken encoder initialized for internal operations")
except Exception as e:
logging.warning(f"Failed to initialize tiktoken encoder: {e}")
self.token_encoder = None

def _find_user_id_from_current_task(self):
"""
@@ -145,6 +169,18 @@ class MessageHandler:
return user_id
return None

def _count_tokens_with_tiktoken(self, text: str) -> int:
"""Count tokens using tiktoken encoder for internal operations."""
if self.token_encoder is None:
# Fallback estimation if tiktoken is not available
return len(text) // 4

try:
return len(self.token_encoder.encode(text))
except Exception as e:
logging.warning(f"Error counting tokens with tiktoken: {e}")
return len(text) // 4

def _get_discord_message_from_current_task(self):
"""
Utility method to get the Discord message from the current asyncio task.
@@ -1092,23 +1128,22 @@ class MessageHandler:
messages_for_api = prepare_messages_for_api(history)

# Proactively trim history to avoid context overload while preserving system prompt
current_tokens = self._count_tokens(messages_for_api)
token_limit = MODEL_TOKEN_LIMITS.get(model, DEFAULT_TOKEN_LIMIT)
max_tokens = int(token_limit * 0.8) # Use 80% of limit to leave room for response
# Simplified: just check message count instead of tokens
max_messages = 20

if current_tokens > max_tokens:
logging.info(f"Proactively trimming history: {current_tokens} tokens > {max_tokens} limit for {model}")
if len(messages_for_api) > max_messages:
logging.info(f"Proactively trimming history: {len(messages_for_api)} messages > {max_messages} limit for {model}")

if model in ["openai/o1-mini", "openai/o1-preview"]:
# For o1 models, trim the history without system prompt
trimmed_history_without_system = self._trim_history_to_token_limit(history_without_system, model, max_tokens)
trimmed_history_without_system = self._trim_history_to_token_limit(history_without_system, model)
messages_for_api = prepare_messages_for_api(trimmed_history_without_system)

# Update the history tracking
history_without_system = trimmed_history_without_system
else:
# For regular models, trim the full history (preserving system prompt)
trimmed_history = self._trim_history_to_token_limit(history, model, max_tokens)
trimmed_history = self._trim_history_to_token_limit(history, model)
messages_for_api = prepare_messages_for_api(trimmed_history)

# Update the history tracking
@@ -1124,13 +1159,31 @@ class MessageHandler:
else:
await self.db.save_history(user_id, history)

final_tokens = self._count_tokens(messages_for_api)
logging.info(f"History trimmed from {current_tokens} to {final_tokens} tokens")
logging.info(f"History trimmed from multiple messages to {len(messages_for_api)} messages")

# Determine which models should have tools available
# openai/o1-mini and openai/o1-preview do not support tools
use_tools = model in ["openai/gpt-4o", "openai/gpt-4o-mini", "openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat", "openai/o1", "openai/o3-mini", "openai/gpt-4.1", "openai/gpt-4.1-mini", "openai/gpt-4.1-nano", "openai/o3", "openai/o4-mini"]

# Count tokens being sent to API
total_content_length = 0
for msg in messages_for_api:
content = msg.get('content', '')
if isinstance(content, list):
# Handle list content (mixed text/images)
for item in content:
if item.get('type') == 'text':
total_content_length += len(str(item.get('text', '')))
elif isinstance(content, str):
total_content_length += len(content)

estimated_tokens = self._count_tokens_with_tiktoken(' '.join([
str(msg.get('content', '')) for msg in messages_for_api
]))

logging.info(f"API Request Debug - Model: {model}, Messages: {len(messages_for_api)}, "
f"Est. tokens: {estimated_tokens}, Content length: {total_content_length} chars")

# Prepare API call parameters
api_params = {
"model": model,
@@ -1149,7 +1202,8 @@ class MessageHandler:

# Add tools if using a supported model
if use_tools:
api_params["tools"] = get_tools_for_model()
tools = get_tools_for_model()
api_params["tools"] = tools

# Initialize variables to track tool responses
image_generation_used = False
@@ -1176,6 +1230,27 @@ class MessageHandler:
# Re-raise other errors
raise e

# Extract token usage and calculate cost
input_tokens = 0
output_tokens = 0
total_cost = 0.0

if hasattr(response, 'usage') and response.usage:
input_tokens = getattr(response.usage, 'prompt_tokens', 0)
output_tokens = getattr(response.usage, 'completion_tokens', 0)

# Calculate cost based on model pricing
if model in MODEL_PRICING:
pricing = MODEL_PRICING[model]
input_cost = (input_tokens / 1_000_000) * pricing["input"]
output_cost = (output_tokens / 1_000_000) * pricing["output"]
total_cost = input_cost + output_cost

logging.info(f"API call - Model: {model}, Input tokens: {input_tokens}, Output tokens: {output_tokens}, Cost: ${total_cost:.6f}")

# Save token usage and cost to database
await self.db.save_token_usage(user_id, model, input_tokens, output_tokens, total_cost)

# Process tool calls if any
updated_messages = None
if use_tools and response.choices[0].finish_reason == "tool_calls":
@@ -1261,6 +1336,27 @@ class MessageHandler:
follow_up_params["temperature"] = 1

response = await self.client.chat.completions.create(**follow_up_params)

# Extract token usage and calculate cost for follow-up call
if hasattr(response, 'usage') and response.usage:
follow_up_input_tokens = getattr(response.usage, 'prompt_tokens', 0)
follow_up_output_tokens = getattr(response.usage, 'completion_tokens', 0)

input_tokens += follow_up_input_tokens
output_tokens += follow_up_output_tokens

# Calculate additional cost
if model in MODEL_PRICING:
pricing = MODEL_PRICING[model]
additional_input_cost = (follow_up_input_tokens / 1_000_000) * pricing["input"]
additional_output_cost = (follow_up_output_tokens / 1_000_000) * pricing["output"]
additional_cost = additional_input_cost + additional_output_cost
total_cost += additional_cost

logging.info(f"Follow-up API call - Model: {model}, Input tokens: {follow_up_input_tokens}, Output tokens: {follow_up_output_tokens}, Additional cost: ${additional_cost:.6f}")

# Save additional token usage and cost to database
await self.db.save_token_usage(user_id, model, follow_up_input_tokens, follow_up_output_tokens, additional_cost)

reply = response.choices[0].message.content

@@ -1335,9 +1431,9 @@ class MessageHandler:
except Exception as e:
logging.error(f"Error handling chart: {str(e)}")

# Log processing time for performance monitoring
# Log processing time and cost for performance monitoring
processing_time = time.time() - start_time
logging.info(f"Message processed in {processing_time:.2f} seconds (User: {user_id}, Model: {model})")
logging.info(f"Message processed in {processing_time:.2f} seconds (User: {user_id}, Model: {model}, Cost: ${total_cost:.6f})")

except asyncio.CancelledError:
# Handle cancellation cleanly
@@ -1780,163 +1876,91 @@ class MessageHandler:

def _count_tokens(self, messages: List[Dict[str, Any]]) -> int:
"""
Count tokens in a list of messages using tiktoken o200k_base encoding.
DEPRECATED: Token counting is now handled by API response.
This method is kept for backward compatibility but returns 0.

Args:
messages: List of message dictionaries

Returns:
int: Total token count
int: Always returns 0 (use API response for actual counts)
"""
try:
total_tokens = 0

for message in messages:
# Count tokens for role
if 'role' in message:
total_tokens += len(self.token_encoder.encode(message['role']))

# Count tokens for content
if 'content' in message:
content = message['content']
if isinstance(content, str):
# Simple string content
total_tokens += len(self.token_encoder.encode(content))
elif isinstance(content, list):
# Multi-modal content (text + images)
for item in content:
if isinstance(item, dict):
if item.get('type') == 'text' and 'text' in item:
total_tokens += len(self.token_encoder.encode(item['text']))
elif item.get('type') == 'image_url':
# Images use a fixed token cost (approximation)
total_tokens += 765 # Standard cost for high-detail images

# Add overhead for message formatting
total_tokens += 4 # Overhead per message

return total_tokens

except Exception as e:
logging.error(f"Error counting tokens: {str(e)}")
# Return a conservative estimate if token counting fails
return len(str(messages)) // 3 # Rough approximation
logging.warning("_count_tokens is deprecated. Use API response usage field instead.")
return 0

def _trim_history_to_token_limit(self, history: List[Dict[str, Any]], model: str, target_tokens: int = None) -> List[Dict[str, Any]]:
"""
Trim conversation history to fit within model token limits.
Trim conversation history using tiktoken for accurate token counting.
This is for internal operations only - billing uses API response tokens.

Args:
history: List of message dictionaries
model: Model name to get token limit
target_tokens: Optional custom target token count
model: Model name (for logging)
target_tokens: Maximum tokens to keep (default varies by model)

Returns:
List[Dict[str, Any]]: Trimmed history that fits within token limits
List[Dict[str, Any]]: Trimmed history within token limits
"""
try:
# Get token limit for the model
if target_tokens:
token_limit = target_tokens
else:
token_limit = MODEL_TOKEN_LIMITS.get(model, DEFAULT_TOKEN_LIMIT)
# Set reasonable token limits based on model
if target_tokens is None:
if "gpt-4" in model.lower():
target_tokens = 6000 # Conservative for gpt-4 models
elif "gpt-3.5" in model.lower():
target_tokens = 3000 # Conservative for gpt-3.5
else:
target_tokens = 4000 # Default for other models

# Reserve 20% of tokens for the response and some buffer
available_tokens = int(token_limit * 0.8)

# Always keep the system message if present
system_message = None
# Separate system messages from conversation
system_messages = []
conversation_messages = []

for msg in history:
if msg.get('role') == 'system':
system_message = msg
system_messages.append(msg)
else:
conversation_messages.append(msg)

# Start with system message
trimmed_history = []
# Calculate tokens for system messages (always keep these)
system_token_count = 0
for msg in system_messages:
content = str(msg.get('content', ''))
system_token_count += self._count_tokens_with_tiktoken(content)

# Available tokens for conversation
available_tokens = max(0, target_tokens - system_token_count)

# Trim conversation messages from the beginning if needed
current_tokens = 0
trimmed_conversation = []

if system_message:
system_tokens = self._count_tokens([system_message])
if system_tokens < available_tokens:
trimmed_history.append(system_message)
current_tokens += system_tokens
else:
# If system message is too large, truncate it
content = system_message.get('content', '')
if isinstance(content, str):
# Truncate system message to fit
words = content.split()
truncated_content = ''
for word in words:
test_content = truncated_content + ' ' + word if truncated_content else word
test_tokens = len(self.token_encoder.encode(test_content))
if test_tokens < available_tokens // 2: # Use half available tokens for system
truncated_content = test_content
else:
break

truncated_system = {
'role': 'system',
'content': truncated_content + '...[truncated]'
}
trimmed_history.append(truncated_system)
current_tokens += self._count_tokens([truncated_system])

# Add conversation messages from most recent backwards
available_for_conversation = available_tokens - current_tokens

# Process messages in reverse order (most recent first)
# Start from the end (most recent) and work backwards
for msg in reversed(conversation_messages):
msg_tokens = self._count_tokens([msg])
content = str(msg.get('content', ''))
msg_tokens = self._count_tokens_with_tiktoken(content)

if current_tokens + msg_tokens <= available_tokens:
if system_message:
# Insert after system message (position 1)
trimmed_history.insert(1, msg)
else:
# Insert at start if no system message
trimmed_history.insert(0, msg)
trimmed_conversation.insert(0, msg)
current_tokens += msg_tokens
else:
# Stop adding more messages
# If this message would exceed the limit, stop trimming
break

# Ensure we have at least the last user message if possible
if len(conversation_messages) > 0 and len(trimmed_history) <= (1 if system_message else 0):
last_msg = conversation_messages[-1]
last_msg_tokens = self._count_tokens([last_msg])

if last_msg_tokens < available_tokens:
if system_message:
trimmed_history.insert(-1, last_msg)
else:
trimmed_history.append(last_msg)
# Combine system messages with trimmed conversation
result = system_messages + trimmed_conversation

logging.info(f"Trimmed history from {len(history)} to {len(trimmed_history)} messages "
f"({self._count_tokens(history)} to {self._count_tokens(trimmed_history)} tokens) "
f"for model {model}")
logging.info(f"Trimmed history from {len(history)} to {len(result)} messages "
f"(~{current_tokens + system_token_count} tokens for {model})")

return trimmed_history
return result

except Exception as e:
logging.error(f"Error trimming history: {str(e)}")
# Return a safe minimal history
if history:
# Keep system message + last user message if possible
minimal_history = []
for msg in history:
if msg.get('role') == 'system':
minimal_history.append(msg)
break

# Add the last user message
for msg in reversed(history):
if msg.get('role') == 'user':
minimal_history.append(msg)
break

return minimal_history
logging.error(f"Error trimming history: {e}")
# Fallback: simple message count limit
max_messages = 15
if len(history) > max_messages:
# Keep system messages and last N conversation messages
system_msgs = [msg for msg in history if msg.get('role') == 'system']
other_msgs = [msg for msg in history if msg.get('role') != 'system']
return system_msgs + other_msgs[-max_messages:]
return history

@@ -22,23 +22,19 @@ if PROJECT_ROOT not in sys.path:


def get_tools_for_model() -> List[Dict[str, Any]]:
"""Returns concise tool definitions optimized for token usage."""
"""Returns minimal tool definitions optimized for token usage."""
return [
{
"type": "function",
"function": {
"name": "analyze_data_file",
"description": "Analyze CSV/Excel files with templates or custom analysis.",
"description": "Analyze CSV/Excel files.",
"parameters": {
"type": "object",
"properties": {
"file_path": {"type": "string", "description": "Path to CSV/Excel file"},
"analysis_type": {
"type": "string",
"enum": ["summary", "correlation", "distribution", "comprehensive"],
"default": "comprehensive"
},
"custom_analysis": {"type": "string", "description": "Custom analysis request"}
"file_path": {"type": "string"},
"analysis_type": {"type": "string", "enum": ["summary", "correlation", "distribution", "comprehensive"]},
"custom_analysis": {"type": "string"}
},
"required": ["file_path"]
}
@@ -48,12 +44,12 @@ def get_tools_for_model() -> List[Dict[str, Any]]:
"type": "function",
"function": {
"name": "edit_image",
"description": "Edit images (remove background, etc). Returns image URLs.",
"description": "Edit images (remove background). Returns URLs.",
"parameters": {
"type": "object",
"properties": {
"image_url": {"type": "string", "description": "Image URL to edit"},
"operation": {"type": "string", "enum": ["remove_background"], "default": "remove_background"}
"image_url": {"type": "string"},
"operation": {"type": "string", "enum": ["remove_background"]}
},
"required": ["image_url"]
}
@@ -63,12 +59,12 @@ def get_tools_for_model() -> List[Dict[str, Any]]:
"type": "function",
"function": {
"name": "enhance_prompt",
"description": "Create enhanced versions of text prompts.",
"description": "Create enhanced prompt versions.",
"parameters": {
"type": "object",
"properties": {
"prompt": {"type": "string", "description": "Original prompt"},
"num_versions": {"type": "integer", "default": 3, "minimum": 1, "maximum": 5}
"prompt": {"type": "string"},
"num_versions": {"type": "integer", "minimum": 1, "maximum": 5}
},
"required": ["prompt"]
}
@@ -78,12 +74,10 @@ def get_tools_for_model() -> List[Dict[str, Any]]:
"type": "function",
"function": {
"name": "image_to_text",
"description": "Convert image to text description.",
"description": "Convert image to text.",
"parameters": {
"type": "object",
"properties": {
"image_url": {"type": "string", "description": "Image URL to analyze"}
},
"properties": {"image_url": {"type": "string"}},
"required": ["image_url"]
}
}
@@ -92,12 +86,12 @@ def get_tools_for_model() -> List[Dict[str, Any]]:
"type": "function",
"function": {
"name": "upscale_image",
"description": "Upscale image resolution. Returns image URLs.",
"description": "Upscale image resolution. Returns URLs.",
"parameters": {
"type": "object",
"properties": {
"image_url": {"type": "string", "description": "Image URL to upscale"},
"scale_factor": {"type": "integer", "enum": [2, 3, 4], "default": 4}
"image_url": {"type": "string"},
"scale_factor": {"type": "integer", "enum": [2, 3, 4]}
},
"required": ["image_url"]
}
@@ -107,14 +101,14 @@ def get_tools_for_model() -> List[Dict[str, Any]]:
"type": "function",
"function": {
"name": "photo_maker",
"description": "Generate images based on reference photos. Returns image URLs.",
"description": "Generate images from reference photos. Returns URLs.",
"parameters": {
"type": "object",
"properties": {
"prompt": {"type": "string", "description": "Generation prompt"},
"input_images": {"type": "array", "items": {"type": "string"}, "description": "Reference image URLs"},
"strength": {"type": "integer", "default": 40, "minimum": 1, "maximum": 100},
"num_images": {"type": "integer", "default": 1, "minimum": 1, "maximum": 4}
"prompt": {"type": "string"},
"input_images": {"type": "array", "items": {"type": "string"}},
"strength": {"type": "integer", "minimum": 1, "maximum": 100},
"num_images": {"type": "integer", "minimum": 1, "maximum": 4}
},
"required": ["prompt", "input_images"]
}
@@ -124,13 +118,13 @@ def get_tools_for_model() -> List[Dict[str, Any]]:
"type": "function",
"function": {
"name": "generate_image_with_refiner",
"description": "Generate high-quality images. Returns image URLs.",
"description": "Generate high-quality images. Returns URLs.",
"parameters": {
"type": "object",
"properties": {
"prompt": {"type": "string", "description": "Image prompt"},
"num_images": {"type": "integer", "default": 1, "minimum": 1, "maximum": 4},
"negative_prompt": {"type": "string", "default": "blurry, low quality"}
"prompt": {"type": "string"},
"num_images": {"type": "integer", "minimum": 1, "maximum": 4},
"negative_prompt": {"type": "string"}
},
"required": ["prompt"]
}
@@ -144,8 +138,8 @@ def get_tools_for_model() -> List[Dict[str, Any]]:
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string", "description": "Search query"},
"num_results": {"type": "integer", "default": 3, "minimum": 1, "maximum": 10}
"query": {"type": "string"},
"num_results": {"type": "integer", "minimum": 1, "maximum": 10}
},
"required": ["query"]
}
@@ -158,9 +152,7 @@ def get_tools_for_model() -> List[Dict[str, Any]]:
"description": "Extract content from webpage.",
"parameters": {
"type": "object",
"properties": {
"url": {"type": "string", "description": "Webpage URL"}
},
"properties": {"url": {"type": "string"}},
"required": ["url"]
}
}
@@ -169,12 +161,12 @@ def get_tools_for_model() -> List[Dict[str, Any]]:
"type": "function",
"function": {
"name": "generate_image",
"description": "Generate images from text prompts. Returns image URLs.",
"description": "Generate images from text. Returns URLs.",
"parameters": {
"type": "object",
"properties": {
"prompt": {"type": "string", "description": "Image prompt"},
"num_images": {"type": "integer", "default": 1, "minimum": 1, "maximum": 4}
"prompt": {"type": "string"},
"num_images": {"type": "integer", "minimum": 1, "maximum": 4}
},
"required": ["prompt"]
}
@@ -184,19 +176,15 @@ def get_tools_for_model() -> List[Dict[str, Any]]:
"type": "function",
"function": {
"name": "execute_python_code",
"description": "Execute Python code with automatic package installation. IMPORTANT: If your code imports any library (pandas, numpy, requests, matplotlib, etc.), you MUST include it in 'install_packages' parameter or the code will fail. Always use print() statements to show output. Examples of packages: numpy, pandas, matplotlib, seaborn, requests, beautifulsoup4, opencv-python, scikit-learn, plotly, etc.",
"description": "Execute Python code with package installation. MUST use install_packages for any imports.",
"parameters": {
"type": "object",
"properties": {
"code": {"type": "string", "description": "Python code with print() statements for output"},
"input_data": {"type": "string", "description": "Optional input data"},
"install_packages": {
"type": "array",
"items": {"type": "string"},
"description": "REQUIRED: List ALL pip packages your code imports. Examples: ['pandas'] for pd.read_csv(), ['matplotlib'] for plt.plot(), ['requests'] for HTTP requests, ['numpy'] for arrays, ['beautifulsoup4'] for HTML parsing, etc. If you use ANY import statements, add the package here!"
},
"enable_visualization": {"type": "boolean", "description": "For charts/graphs"},
"timeout": {"type": "integer", "default": 60, "minimum": 1, "maximum": 300, "description": "Execution timeout in seconds (default 60, max 300)"}
"code": {"type": "string"},
"input_data": {"type": "string"},
"install_packages": {"type": "array", "items": {"type": "string"}},
"enable_visualization": {"type": "boolean"},
"timeout": {"type": "integer", "minimum": 1, "maximum": 300}
},
"required": ["code"]
}
@@ -206,12 +194,12 @@ def get_tools_for_model() -> List[Dict[str, Any]]:
"type": "function",
"function": {
"name": "set_reminder",
"description": "Set user reminder. Supports relative time (30m, 2h, 1d), specific times (9:00, 15:30, 9:00 pm, 2:30 am), keywords (tomorrow, tonight, noon), and combinations (9:00 pm today, 2:00 pm tomorrow).",
"description": "Set user reminder with flexible time formats.",
"parameters": {
"type": "object",
"properties": {
"content": {"type": "string", "description": "Reminder content"},
"time": {"type": "string", "description": "Time in formats like: '30m', '2h', '1d', '9:00', '15:30', '9:00 pm', '2:30 am', '9:00 pm today', '2:00 pm tomorrow', 'tomorrow', 'tonight', 'noon'"}
"content": {"type": "string"},
"time": {"type": "string"}
},
"required": ["content", "time"]
}
@@ -222,7 +210,7 @@ def get_tools_for_model() -> List[Dict[str, Any]]:
"function": {
"name": "get_reminders",
"description": "Get user reminders list.",
"parameters": {"type": "object", "properties": {}, "required": []}
"parameters": {"type": "object", "properties": {}}
}
}
]
@@ -324,15 +312,8 @@ def prepare_messages_for_api(messages: List[Dict[str, Any]]) -> List[Dict[str, A
"""Prepare message history for the OpenAI API with image URL handling."""
prepared_messages = []

# Check if there's a system message already
has_system_message = any(msg.get('role') == 'system' for msg in messages)

# If no system message exists, add a default one
if not has_system_message:
prepared_messages.append({
"role": "system",
"content": "You are a helpful AI assistant. For image tools, you'll receive image URLs in responses - reference them instead of sending binary data."
})
# Note: System message handling is done in message_handler.py
# We don't add a default system message here to avoid duplication

for msg in messages:
# Skip messages with None content

@@ -5,13 +5,13 @@ import logging
from bs4 import BeautifulSoup
from typing import Dict, List, Any, Optional, Tuple
from src.config.config import GOOGLE_API_KEY, GOOGLE_CX
import tiktoken # Add tiktoken for token counting
import tiktoken # Used only for preprocessing content before API calls

# Global tiktoken encoder - initialized once to avoid blocking
# Global tiktoken encoder for preprocessing - initialized once to avoid blocking
try:
TIKTOKEN_ENCODER = tiktoken.get_encoding("o200k_base")
except Exception as e:
logging.error(f"Failed to initialize tiktoken encoder: {e}")
logging.error(f"Failed to initialize tiktoken encoder for preprocessing: {e}")
TIKTOKEN_ENCODER = None

def google_custom_search(query: str, num_results: int = 5, max_tokens: int = 4000) -> dict:
@@ -91,7 +91,7 @@ def scrape_multiple_links(urls: List[str], max_tokens: int = 4000) -> Tuple[str,
total_tokens = 0
used_urls = []

# Use global encoder directly (no async needed since it's pre-initialized)
# Use tiktoken for preprocessing estimation only
encoding = TIKTOKEN_ENCODER

for url in urls:
@@ -111,6 +111,7 @@ def scrape_multiple_links(urls: List[str], max_tokens: int = 4000) -> Tuple[str,
# If this is the first URL and it's too large, we need to truncate it
if not combined_content:
if encoding:
# Use tiktoken for accurate preprocessing truncation
tokens = encoding.encode(content)
truncated_tokens = tokens[:max_tokens]
truncated_content = encoding.decode(truncated_tokens)
@@ -186,23 +187,27 @@ def scrape_web_content_with_count(url: str, max_tokens: int = 4000, return_token
lines = (line.strip() for line in text.splitlines())
text = '\n'.join(line for line in lines if line)

# Count tokens
# Count tokens using tiktoken for preprocessing accuracy
token_count = 0
try:
# Use global o200k_base encoder
encoding = TIKTOKEN_ENCODER
if encoding:
tokens = encoding.encode(text)
if TIKTOKEN_ENCODER:
tokens = TIKTOKEN_ENCODER.encode(text)
token_count = len(tokens)

# Truncate if token count exceeds max_tokens and we're not returning token count
if len(tokens) > max_tokens and not return_token_count:
truncated_tokens = tokens[:max_tokens]
text = encoding.decode(truncated_tokens)
text += "...\n[Content truncated due to token limit]"
except ImportError:

# Truncate if content exceeds max_tokens and we're not returning token count
if len(tokens) > max_tokens and not return_token_count:
truncated_tokens = tokens[:max_tokens]
text = TIKTOKEN_ENCODER.decode(truncated_tokens)
text += "...\n[Content truncated due to token limit]"
else:
# Fallback to character-based estimation
token_count = len(text) // 4
if len(text) > max_tokens * 4 and not return_token_count:
text = text[:max_tokens * 4] + "...\n[Content truncated due to length]"
except Exception as e:
logging.warning(f"Token counting failed for preprocessing: {e}")
# Fallback to character-based estimation
token_count = len(text) // 4 # Rough estimate: 1 token ≈ 4 characters
token_count = len(text) // 4
if len(text) > max_tokens * 4 and not return_token_count:
text = text[:max_tokens * 4] + "...\n[Content truncated due to length]"
