Refactor and optimize utility functions for improved performance and readability
- Updated `openai_utils.py` to streamline tool definitions and enhance token usage efficiency. - Simplified the `process_tool_calls` function for better error handling and message preparation. - Improved time parsing logic in `reminder_utils.py` to support various formats including AM/PM and "tomorrow" keywords. - Added async wrappers for Google search and webpage scraping in `web_utils.py` to match expected interfaces and improve error handling.
This commit is contained in:
Binary file not shown.
@@ -316,7 +316,7 @@ def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator
|
||||
encoding_model = model
|
||||
|
||||
# Retrieve the appropriate encoding for the selected model
|
||||
encoding = tiktoken.encoding_for_model(encoding_model)
|
||||
encoding = tiktoken.get_encoding("o200k_base")
|
||||
|
||||
# Initialize token counts
|
||||
input_tokens = 0
|
||||
|
||||
@@ -75,6 +75,24 @@ MODEL_OPTIONS = [
|
||||
"openai/o4-mini"
|
||||
]
|
||||
|
||||
# Model-specific token limits for automatic history management
|
||||
MODEL_TOKEN_LIMITS = {
|
||||
"openai/o1-preview": 4000, # Conservative limit (max 4000)
|
||||
"openai/o1-mini": 4000,
|
||||
"openai/o1": 4000,
|
||||
"openai/gpt-4o": 8000,
|
||||
"openai/gpt-4o-mini": 8000,
|
||||
"openai/gpt-4.1": 8000,
|
||||
"openai/gpt-4.1-nano": 8000,
|
||||
"openai/gpt-4.1-mini": 8000,
|
||||
"openai/o3-mini": 4000,
|
||||
"openai/o3": 4000,
|
||||
"openai/o4-mini": 4000
|
||||
}
|
||||
|
||||
# Default token limit for unknown models
|
||||
DEFAULT_TOKEN_LIMIT = 60000
|
||||
|
||||
PDF_ALLOWED_MODELS = ["openai/gpt-4o", "openai/gpt-4o-mini", "openai/gpt-4.1","openai/gpt-4.1-nano","openai/gpt-4.1-mini"]
|
||||
PDF_BATCH_SIZE = 3
|
||||
|
||||
|
||||
Binary file not shown.
@@ -13,12 +13,13 @@ import sys
|
||||
import subprocess
|
||||
import base64
|
||||
import traceback
|
||||
import tiktoken
|
||||
from datetime import datetime, timedelta
|
||||
from src.utils.openai_utils import process_tool_calls, prepare_messages_for_api, get_tools_for_model
|
||||
from src.utils.pdf_utils import process_pdf, send_response
|
||||
from src.utils.code_utils import extract_code_blocks
|
||||
from src.utils.reminder_utils import ReminderManager
|
||||
from src.config.config import PDF_ALLOWED_MODELS
|
||||
from src.config.config import PDF_ALLOWED_MODELS, MODEL_TOKEN_LIMITS, DEFAULT_TOKEN_LIMIT
|
||||
|
||||
# Global task and rate limiting tracking
|
||||
user_tasks = {}
|
||||
@@ -118,6 +119,9 @@ class MessageHandler:
|
||||
# Install required packages if not available
|
||||
if not PANDAS_AVAILABLE:
|
||||
self._install_data_packages()
|
||||
|
||||
# Initialize tiktoken encoder for token counting (using o200k_base for all models)
|
||||
self.token_encoder = tiktoken.get_encoding("o200k_base")
|
||||
|
||||
def _find_user_id_from_current_task(self):
|
||||
"""
|
||||
@@ -875,8 +879,62 @@ class MessageHandler:
|
||||
chart_id = None
|
||||
image_urls = [] # Will store unique image URLs
|
||||
|
||||
# Make the initial API call with retry logic
|
||||
response = await self._retry_api_call(lambda: self.client.chat.completions.create(**api_params))
|
||||
# Make the initial API call without retry logic to avoid extra costs
|
||||
try:
|
||||
response = await self.client.chat.completions.create(**api_params)
|
||||
except Exception as e:
|
||||
# Handle 413 Request Entity Too Large error automatically
|
||||
if "413" in str(e) or "tokens_limit_reached" in str(e) or "Request body too large" in str(e):
|
||||
logging.warning(f"Token limit exceeded for model {model}, automatically trimming history...")
|
||||
|
||||
# Trim the history to fit the model's token limit
|
||||
current_tokens = self._count_tokens(messages_for_api)
|
||||
logging.info(f"Current message tokens: {current_tokens}")
|
||||
|
||||
if model in ["openai/o1-mini", "openai/o1-preview"]:
|
||||
# For o1 models, use the trimmed history without system prompt
|
||||
trimmed_history_without_system = self._trim_history_to_token_limit(history_without_system, model)
|
||||
messages_for_api = prepare_messages_for_api(trimmed_history_without_system)
|
||||
else:
|
||||
# For regular models, trim the full history
|
||||
trimmed_history = self._trim_history_to_token_limit(history, model)
|
||||
messages_for_api = prepare_messages_for_api(trimmed_history)
|
||||
|
||||
# Update API parameters with trimmed messages
|
||||
api_params["messages"] = messages_for_api
|
||||
|
||||
# Save the trimmed history to prevent this issue in the future
|
||||
if model in ["openai/o1-mini", "openai/o1-preview"]:
|
||||
# For o1 models, save the trimmed history back to the database
|
||||
new_history = []
|
||||
if system_content:
|
||||
new_history.append({"role": "system", "content": system_content})
|
||||
new_history.extend(trimmed_history_without_system[1:]) # Skip the "Instructions" message
|
||||
await self.db.save_history(user_id, new_history)
|
||||
else:
|
||||
await self.db.save_history(user_id, trimmed_history)
|
||||
|
||||
# Inform user about the automatic cleanup
|
||||
await message.channel.send("🔧 **Auto-optimized conversation history** - Removed older messages to fit model limits.")
|
||||
|
||||
# Try the API call again with trimmed history
|
||||
try:
|
||||
response = await self.client.chat.completions.create(**api_params)
|
||||
logging.info(f"Successfully processed request after history trimming for model {model}")
|
||||
except Exception as retry_error:
|
||||
# If it still fails, provide a helpful error message
|
||||
await message.channel.send(
|
||||
f"❌ **Request still too large for {model}**\n"
|
||||
f"Even after optimizing history, the request is too large.\n"
|
||||
f"Try:\n"
|
||||
f"• Using a model with higher token limits\n"
|
||||
f"• Reducing the size of your current message\n"
|
||||
f"• Using `/clear_history` to start fresh"
|
||||
)
|
||||
return
|
||||
else:
|
||||
# Re-raise other errors
|
||||
raise e
|
||||
|
||||
# Process tool calls if any
|
||||
updated_messages = None
|
||||
@@ -949,12 +1007,12 @@ class MessageHandler:
|
||||
|
||||
# If tool calls were processed, make another API call with the updated messages
|
||||
if tool_calls_processed and updated_messages:
|
||||
response = await self._retry_api_call(lambda: self.client.chat.completions.create(
|
||||
response = await self.client.chat.completions.create(
|
||||
model=model,
|
||||
messages=updated_messages,
|
||||
temperature=0.3 if model in ["openai/gpt-4o", "openai/gpt-4o-mini"] else 1,
|
||||
timeout=120
|
||||
))
|
||||
)
|
||||
|
||||
reply = response.choices[0].message.content
|
||||
|
||||
@@ -1174,8 +1232,7 @@ class MessageHandler:
|
||||
async def _enhance_prompt(self, args: Dict[str, Any]):
|
||||
"""Enhance a prompt"""
|
||||
try:
|
||||
from src.utils.openai_utils import enhance_prompt
|
||||
result = await enhance_prompt(args)
|
||||
result = await self.image_generator.enhance_prompt(args)
|
||||
return result
|
||||
except Exception as e:
|
||||
logging.error(f"Error in prompt enhancement: {str(e)}")
|
||||
@@ -1184,8 +1241,7 @@ class MessageHandler:
|
||||
async def _image_to_text(self, args: Dict[str, Any]):
|
||||
"""Convert image to text"""
|
||||
try:
|
||||
from src.utils.image_utils import image_to_text
|
||||
result = await image_to_text(args)
|
||||
result = await self.image_generator.image_to_text(args)
|
||||
return result
|
||||
except Exception as e:
|
||||
logging.error(f"Error in image to text: {str(e)}")
|
||||
@@ -1250,23 +1306,160 @@ class MessageHandler:
|
||||
except Exception as e:
|
||||
logging.error(f"Error in file cleanup: {str(e)}")
|
||||
|
||||
async def _retry_api_call(self, call_func, max_retries=3, base_delay=1):
|
||||
"""Retry API calls with exponential backoff"""
|
||||
retries = 0
|
||||
last_error = None
|
||||
def _count_tokens(self, messages: List[Dict[str, Any]]) -> int:
|
||||
"""
|
||||
Count tokens in a list of messages using tiktoken o200k_base encoding.
|
||||
|
||||
while retries < max_retries:
|
||||
try:
|
||||
return await call_func()
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
retries += 1
|
||||
if retries < max_retries:
|
||||
delay = base_delay * (2 ** (retries - 1))
|
||||
logging.warning(f"API call failed (attempt {retries}/{max_retries}), retrying in {delay}s: {str(e)}")
|
||||
await asyncio.sleep(delay)
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
|
||||
Returns:
|
||||
int: Total token count
|
||||
"""
|
||||
try:
|
||||
total_tokens = 0
|
||||
|
||||
for message in messages:
|
||||
# Count tokens for role
|
||||
if 'role' in message:
|
||||
total_tokens += len(self.token_encoder.encode(message['role']))
|
||||
|
||||
# Count tokens for content
|
||||
if 'content' in message:
|
||||
content = message['content']
|
||||
if isinstance(content, str):
|
||||
# Simple string content
|
||||
total_tokens += len(self.token_encoder.encode(content))
|
||||
elif isinstance(content, list):
|
||||
# Multi-modal content (text + images)
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if item.get('type') == 'text' and 'text' in item:
|
||||
total_tokens += len(self.token_encoder.encode(item['text']))
|
||||
elif item.get('type') == 'image_url':
|
||||
# Images use a fixed token cost (approximation)
|
||||
total_tokens += 765 # Standard cost for high-detail images
|
||||
|
||||
# Add overhead for message formatting
|
||||
total_tokens += 4 # Overhead per message
|
||||
|
||||
return total_tokens
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error counting tokens: {str(e)}")
|
||||
# Return a conservative estimate if token counting fails
|
||||
return len(str(messages)) // 3 # Rough approximation
|
||||
|
||||
def _trim_history_to_token_limit(self, history: List[Dict[str, Any]], model: str, target_tokens: int = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Trim conversation history to fit within model token limits.
|
||||
|
||||
Args:
|
||||
history: List of message dictionaries
|
||||
model: Model name to get token limit
|
||||
target_tokens: Optional custom target token count
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: Trimmed history that fits within token limits
|
||||
"""
|
||||
try:
|
||||
# Get token limit for the model
|
||||
if target_tokens:
|
||||
token_limit = target_tokens
|
||||
else:
|
||||
token_limit = MODEL_TOKEN_LIMITS.get(model, DEFAULT_TOKEN_LIMIT)
|
||||
|
||||
# Reserve 20% of tokens for the response and some buffer
|
||||
available_tokens = int(token_limit * 0.8)
|
||||
|
||||
# Always keep the system message if present
|
||||
system_message = None
|
||||
conversation_messages = []
|
||||
|
||||
for msg in history:
|
||||
if msg.get('role') == 'system':
|
||||
system_message = msg
|
||||
else:
|
||||
conversation_messages.append(msg)
|
||||
|
||||
# Start with system message
|
||||
trimmed_history = []
|
||||
current_tokens = 0
|
||||
|
||||
if system_message:
|
||||
system_tokens = self._count_tokens([system_message])
|
||||
if system_tokens < available_tokens:
|
||||
trimmed_history.append(system_message)
|
||||
current_tokens += system_tokens
|
||||
else:
|
||||
# If system message is too large, truncate it
|
||||
content = system_message.get('content', '')
|
||||
if isinstance(content, str):
|
||||
# Truncate system message to fit
|
||||
words = content.split()
|
||||
truncated_content = ''
|
||||
for word in words:
|
||||
test_content = truncated_content + ' ' + word if truncated_content else word
|
||||
test_tokens = len(self.token_encoder.encode(test_content))
|
||||
if test_tokens < available_tokens // 2: # Use half available tokens for system
|
||||
truncated_content = test_content
|
||||
else:
|
||||
break
|
||||
|
||||
truncated_system = {
|
||||
'role': 'system',
|
||||
'content': truncated_content + '...[truncated]'
|
||||
}
|
||||
trimmed_history.append(truncated_system)
|
||||
current_tokens += self._count_tokens([truncated_system])
|
||||
|
||||
# Add conversation messages from most recent backwards
|
||||
available_for_conversation = available_tokens - current_tokens
|
||||
|
||||
# Process messages in reverse order (most recent first)
|
||||
for msg in reversed(conversation_messages):
|
||||
msg_tokens = self._count_tokens([msg])
|
||||
|
||||
if current_tokens + msg_tokens <= available_tokens:
|
||||
trimmed_history.insert(-1 if system_message else 0, msg) # Insert before system or at start
|
||||
current_tokens += msg_tokens
|
||||
else:
|
||||
# Stop adding more messages
|
||||
break
|
||||
|
||||
logging.error(f"All API call retries failed: {str(last_error)}")
|
||||
raise last_error
|
||||
|
||||
# Ensure we have at least the last user message if possible
|
||||
if len(conversation_messages) > 0 and len(trimmed_history) <= (1 if system_message else 0):
|
||||
last_msg = conversation_messages[-1]
|
||||
last_msg_tokens = self._count_tokens([last_msg])
|
||||
|
||||
if last_msg_tokens < available_tokens:
|
||||
if system_message:
|
||||
trimmed_history.insert(-1, last_msg)
|
||||
else:
|
||||
trimmed_history.append(last_msg)
|
||||
|
||||
logging.info(f"Trimmed history from {len(history)} to {len(trimmed_history)} messages "
|
||||
f"({self._count_tokens(history)} to {self._count_tokens(trimmed_history)} tokens) "
|
||||
f"for model {model}")
|
||||
|
||||
return trimmed_history
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error trimming history: {str(e)}")
|
||||
# Return a safe minimal history
|
||||
if history:
|
||||
# Keep system message + last user message if possible
|
||||
minimal_history = []
|
||||
for msg in history:
|
||||
if msg.get('role') == 'system':
|
||||
minimal_history.append(msg)
|
||||
break
|
||||
|
||||
# Add the last user message
|
||||
for msg in reversed(history):
|
||||
if msg.get('role') == 'user':
|
||||
minimal_history.append(msg)
|
||||
break
|
||||
|
||||
return minimal_history
|
||||
return history
|
||||
|
||||
@@ -1,12 +1,20 @@
|
||||
import io
|
||||
import aiohttp
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
from runware import IImageInference
|
||||
import tempfile
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from runware import IPromptEnhance, IImageBackgroundRemoval, IImageCaption, IImageUpscale, IPhotoMaker, IRefiner
|
||||
from typing import List, Dict, Any, Optional
|
||||
from runware import (
|
||||
Runware,
|
||||
IImageInference,
|
||||
IPromptEnhance,
|
||||
IImageBackgroundRemoval,
|
||||
IImageCaption,
|
||||
IImageUpscale,
|
||||
IPhotoMaker
|
||||
)
|
||||
|
||||
class ImageGenerator:
|
||||
def __init__(self, api_key: str):
|
||||
@@ -16,8 +24,12 @@ class ImageGenerator:
|
||||
Args:
|
||||
api_key: API key for Runware
|
||||
"""
|
||||
from runware import Runware
|
||||
self.runware = Runware(api_key=api_key)
|
||||
# Use the API key if provided, otherwise Runware will read from environment
|
||||
if api_key and api_key != "fake_key" and api_key != "test_key":
|
||||
self.runware = Runware(api_key=api_key)
|
||||
else:
|
||||
# Let Runware read from RUNWARE_API_KEY environment variable
|
||||
self.runware = Runware()
|
||||
self.connected = False
|
||||
|
||||
async def ensure_connected(self):
|
||||
@@ -26,18 +38,26 @@ class ImageGenerator:
|
||||
await self.runware.connect()
|
||||
self.connected = True
|
||||
|
||||
async def generate_image(self, prompt: str, num_images: int = 1, negative_prompt: str = "blurry, distorted, low quality"):
|
||||
async def generate_image(self, args, num_images: int = 1, negative_prompt: str = "blurry, distorted, low quality"):
|
||||
"""
|
||||
Generate images based on a text prompt
|
||||
|
||||
Args:
|
||||
prompt: The text prompt for image generation
|
||||
args: Either a string prompt or dict containing prompt and options
|
||||
num_images: Number of images to generate (max 4)
|
||||
negative_prompt: Things to avoid in the generated image
|
||||
|
||||
Returns:
|
||||
Dict with generated images or error information
|
||||
"""
|
||||
# Handle both string and dict input for backward compatibility
|
||||
if isinstance(args, dict):
|
||||
prompt = args.get('prompt', '')
|
||||
num_images = args.get('num_images', num_images)
|
||||
negative_prompt = args.get('negative_prompt', negative_prompt)
|
||||
else:
|
||||
prompt = str(args) # Ensure it's a string
|
||||
|
||||
num_images = min(num_images, 4)
|
||||
|
||||
try:
|
||||
@@ -60,8 +80,7 @@ class ImageGenerator:
|
||||
result = {
|
||||
"success": True,
|
||||
"prompt": prompt,
|
||||
"binary_images": [],
|
||||
"image_urls": [], # Initialize empty image URLs list
|
||||
"image_urls": [], # Only URLs for API response
|
||||
"image_count": 0
|
||||
}
|
||||
|
||||
@@ -82,18 +101,9 @@ class ImageGenerator:
|
||||
|
||||
# Update result with image info
|
||||
result["image_count"] = len(image_urls)
|
||||
result["image_urls"] = image_urls # Add image URLs to result
|
||||
result["image_urls"] = image_urls # Only URLs in result
|
||||
|
||||
# Get binary data for each image
|
||||
for img_url in image_urls:
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(img_url) as resp:
|
||||
if resp.status == 200:
|
||||
image_data = await resp.read()
|
||||
result["binary_images"].append(image_data)
|
||||
except Exception as e:
|
||||
logging.error(f"Error downloading image {img_url}: {str(e)}")
|
||||
# For Discord display, we'll download images separately in message handler
|
||||
|
||||
# Log success or failure
|
||||
if result["image_count"] > 0:
|
||||
@@ -111,21 +121,27 @@ class ImageGenerator:
|
||||
"error": str(e),
|
||||
"prompt": prompt,
|
||||
"image_urls": [], # Include empty image_urls even in error case
|
||||
"image_count": 0,
|
||||
"binary_images": []
|
||||
"image_count": 0
|
||||
}
|
||||
|
||||
async def edit_image(self, image_url: str, operation: str = "remove_background"):
|
||||
async def edit_image(self, args, operation: str = "remove_background"):
|
||||
"""
|
||||
Edit an image using various operations like background removal
|
||||
|
||||
Args:
|
||||
image_url: URL of the image to edit
|
||||
args: Either a string image_url or dict containing image_url and options
|
||||
operation: Type of edit operation (currently supports 'remove_background')
|
||||
|
||||
Returns:
|
||||
Dict with edited image information
|
||||
"""
|
||||
# Handle both string and dict input for backward compatibility
|
||||
if isinstance(args, dict):
|
||||
image_url = args.get('image_url', '')
|
||||
operation = args.get('operation', operation)
|
||||
else:
|
||||
image_url = str(args) # Ensure it's a string
|
||||
|
||||
try:
|
||||
# Ensure connection is established
|
||||
await self.ensure_connected()
|
||||
@@ -174,8 +190,7 @@ class ImageGenerator:
|
||||
"success": True,
|
||||
"operation": operation,
|
||||
"original_url": image_url,
|
||||
"image_urls": [],
|
||||
"binary_images": []
|
||||
"image_urls": []
|
||||
}
|
||||
|
||||
# Extract image URLs from response
|
||||
@@ -183,17 +198,6 @@ class ImageGenerator:
|
||||
for image in processed_images:
|
||||
if hasattr(image, 'imageURL'):
|
||||
result["image_urls"].append(image.imageURL)
|
||||
|
||||
# Download the edited images
|
||||
for img_url in result["image_urls"]:
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(img_url) as resp:
|
||||
if resp.status == 200:
|
||||
edited_image_data = await resp.read()
|
||||
result["binary_images"].append(edited_image_data)
|
||||
except Exception as e:
|
||||
logging.error(f"Error downloading edited image {img_url}: {str(e)}")
|
||||
|
||||
result["image_count"] = len(result["image_urls"])
|
||||
|
||||
@@ -226,22 +230,29 @@ class ImageGenerator:
|
||||
"error": str(e),
|
||||
"operation": operation,
|
||||
"image_urls": [],
|
||||
"image_count": 0,
|
||||
"binary_images": []
|
||||
"image_count": 0
|
||||
}
|
||||
|
||||
async def enhance_prompt(self, prompt: str, num_versions: int = 3, max_length: int = 64) -> Dict[str, Any]:
|
||||
async def enhance_prompt(self, args, num_versions: int = 3, max_length: int = 64) -> Dict[str, Any]:
|
||||
"""
|
||||
Enhance a text prompt with AI to create more detailed/creative versions
|
||||
|
||||
Args:
|
||||
prompt: The original prompt text
|
||||
args: Either a string prompt or dict containing prompt and options
|
||||
num_versions: Number of enhanced versions to generate
|
||||
max_length: Maximum length of each enhanced prompt
|
||||
|
||||
Returns:
|
||||
Dict with enhanced prompt information
|
||||
"""
|
||||
# Handle both string and dict input for backward compatibility
|
||||
if isinstance(args, dict):
|
||||
prompt = args.get('prompt', '')
|
||||
num_versions = args.get('num_versions', num_versions)
|
||||
max_length = args.get('max_length', max_length)
|
||||
else:
|
||||
prompt = str(args) # Ensure it's a string
|
||||
|
||||
try:
|
||||
# Ensure connection is established
|
||||
await self.ensure_connected()
|
||||
@@ -290,16 +301,22 @@ class ImageGenerator:
|
||||
"prompt_count": 0
|
||||
}
|
||||
|
||||
async def image_to_text(self, image_url: str) -> Dict[str, Any]:
|
||||
async def image_to_text(self, args) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert an image to a text description
|
||||
|
||||
Args:
|
||||
image_url: URL of the image to analyze
|
||||
args: Either a string image_url or dict containing image_url
|
||||
|
||||
Returns:
|
||||
Dict with image caption information
|
||||
"""
|
||||
# Handle both string and dict input for backward compatibility
|
||||
if isinstance(args, dict):
|
||||
image_url = args.get('image_url', '')
|
||||
else:
|
||||
image_url = str(args) # Ensure it's a string
|
||||
|
||||
try:
|
||||
# Ensure connection is established
|
||||
await self.ensure_connected()
|
||||
@@ -380,17 +397,24 @@ class ImageGenerator:
|
||||
"caption": ""
|
||||
}
|
||||
|
||||
async def upscale_image(self, image_url: str, scale_factor: int = 4) -> Dict[str, Any]:
|
||||
async def upscale_image(self, args, scale_factor: int = 4) -> Dict[str, Any]:
|
||||
"""
|
||||
Upscale an image to a higher resolution
|
||||
|
||||
Args:
|
||||
image_url: URL of the image to upscale
|
||||
args: Either a string image_url or dict containing image_url and options
|
||||
scale_factor: Factor by which to upscale the image (2-4)
|
||||
|
||||
Returns:
|
||||
Dict with upscaled image information
|
||||
"""
|
||||
# Handle both string and dict input for backward compatibility
|
||||
if isinstance(args, dict):
|
||||
image_url = args.get('image_url', '')
|
||||
scale_factor = args.get('scale_factor', scale_factor)
|
||||
else:
|
||||
image_url = str(args) # Ensure it's a string
|
||||
|
||||
# Ensure scale factor is within valid range
|
||||
scale_factor = max(2, min(scale_factor, 4))
|
||||
|
||||
@@ -407,8 +431,7 @@ class ImageGenerator:
|
||||
"success": False,
|
||||
"error": f"Failed to download image, status: {resp.status}",
|
||||
"image_urls": [],
|
||||
"image_count": 0,
|
||||
"binary_images": []
|
||||
"image_count": 0
|
||||
}
|
||||
image_data = await resp.read()
|
||||
|
||||
@@ -441,8 +464,7 @@ class ImageGenerator:
|
||||
"original_url": image_url,
|
||||
"scale_factor": scale_factor,
|
||||
"image_urls": [],
|
||||
"image_count": 0,
|
||||
"binary_images": []
|
||||
"image_count": 0
|
||||
}
|
||||
|
||||
# Extract image URLs from response
|
||||
@@ -453,17 +475,6 @@ class ImageGenerator:
|
||||
|
||||
result["image_count"] = len(result["image_urls"])
|
||||
|
||||
# Get binary data for each image
|
||||
for img_url in result["image_urls"]:
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(img_url) as resp:
|
||||
if resp.status == 200:
|
||||
image_data = await resp.read()
|
||||
result["binary_images"].append(image_data)
|
||||
except Exception as e:
|
||||
logging.error(f"Error downloading upscaled image {img_url}: {str(e)}")
|
||||
|
||||
# Log success or failure
|
||||
if result["image_count"] > 0:
|
||||
logging.info(f"Successfully upscaled image by factor {scale_factor}")
|
||||
@@ -484,8 +495,7 @@ class ImageGenerator:
|
||||
"success": False,
|
||||
"error": f"Error in image upscaling: {str(e)}",
|
||||
"image_urls": [],
|
||||
"image_count": 0,
|
||||
"binary_images": []
|
||||
"image_count": 0
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -496,19 +506,17 @@ class ImageGenerator:
|
||||
"error": str(e),
|
||||
"original_url": image_url,
|
||||
"image_urls": [],
|
||||
"image_count": 0,
|
||||
"binary_images": []
|
||||
"image_count": 0
|
||||
}
|
||||
|
||||
async def photo_maker(self, prompt: str, input_images: List[str], style: str = "No style",
|
||||
async def photo_maker(self, args, style: str = "No style",
|
||||
strength: int = 40, steps: int = 35, num_images: int = 1,
|
||||
height: int = 512, width: int = 512) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate images based on reference photos and a text prompt
|
||||
|
||||
Args:
|
||||
prompt: The text prompt describing what to generate
|
||||
input_images: List of reference image URLs to use as input
|
||||
args: Either a dict containing prompt, input_images and options, or just prompt string
|
||||
style: Style to apply to the generated image
|
||||
strength: Strength of the input images' influence (0-100)
|
||||
steps: Number of generation steps
|
||||
@@ -519,6 +527,20 @@ class ImageGenerator:
|
||||
Returns:
|
||||
Dict with generated image information
|
||||
"""
|
||||
# Handle both string and dict input for backward compatibility
|
||||
if isinstance(args, dict):
|
||||
prompt = args.get('prompt', '')
|
||||
input_images = args.get('input_images', [])
|
||||
style = args.get('style', style)
|
||||
strength = args.get('strength', strength)
|
||||
steps = args.get('steps', steps)
|
||||
num_images = args.get('num_images', num_images)
|
||||
height = args.get('height', height)
|
||||
width = args.get('width', width)
|
||||
else:
|
||||
prompt = str(args) # Ensure it's a string
|
||||
input_images = [] # Default empty list
|
||||
|
||||
try:
|
||||
# Ensure connection is established
|
||||
await self.ensure_connected()
|
||||
@@ -544,7 +566,6 @@ class ImageGenerator:
|
||||
result = {
|
||||
"success": True,
|
||||
"prompt": prompt,
|
||||
"binary_images": [],
|
||||
"image_urls": [],
|
||||
"image_count": 0
|
||||
}
|
||||
@@ -557,17 +578,6 @@ class ImageGenerator:
|
||||
|
||||
result["image_count"] = len(result["image_urls"])
|
||||
|
||||
# Get binary data for each image
|
||||
for img_url in result["image_urls"]:
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(img_url) as resp:
|
||||
if resp.status == 200:
|
||||
image_data = await resp.read()
|
||||
result["binary_images"].append(image_data)
|
||||
except Exception as e:
|
||||
logging.error(f"Error downloading photo maker image {img_url}: {str(e)}")
|
||||
|
||||
# Log success or failure
|
||||
if result["image_count"] > 0:
|
||||
logging.info(f"Generated {result['image_count']} photos with PhotoMaker for prompt: {prompt[:50]}...")
|
||||
@@ -584,11 +594,10 @@ class ImageGenerator:
|
||||
"error": str(e),
|
||||
"prompt": prompt,
|
||||
"image_urls": [],
|
||||
"image_count": 0,
|
||||
"binary_images": []
|
||||
"image_count": 0
|
||||
}
|
||||
|
||||
async def generate_image_with_refiner(self, prompt: str, num_images: int = 1,
|
||||
async def generate_image_with_refiner(self, args, num_images: int = 1,
|
||||
negative_prompt: str = "blurry, distorted, low quality",
|
||||
model: str = "civitai:101055@128078",
|
||||
refiner_start_step: int = 20) -> Dict[str, Any]:
|
||||
@@ -596,7 +605,7 @@ class ImageGenerator:
|
||||
Generate images with a refiner model for better quality
|
||||
|
||||
Args:
|
||||
prompt: The text prompt for image generation
|
||||
args: Either a string prompt or dict containing prompt and options
|
||||
num_images: Number of images to generate (max 4)
|
||||
negative_prompt: Things to avoid in the generated image
|
||||
model: Model to use for generation
|
||||
@@ -605,20 +614,22 @@ class ImageGenerator:
|
||||
Returns:
|
||||
Dict with generated images or error information
|
||||
"""
|
||||
# Handle both string and dict input for backward compatibility
|
||||
if isinstance(args, dict):
|
||||
prompt = args.get('prompt', '')
|
||||
num_images = args.get('num_images', num_images)
|
||||
negative_prompt = args.get('negative_prompt', negative_prompt)
|
||||
else:
|
||||
prompt = str(args) # Ensure it's a string
|
||||
|
||||
num_images = min(num_images, 4)
|
||||
|
||||
try:
|
||||
# Ensure connection is established
|
||||
await self.ensure_connected()
|
||||
|
||||
# Configure refiner
|
||||
refiner = IRefiner(
|
||||
model=model,
|
||||
startStep=refiner_start_step,
|
||||
startStepPercentage=None,
|
||||
)
|
||||
|
||||
# Configure request for Runware
|
||||
# Configure request for Runware with refiner functionality
|
||||
# Note: Refiner functionality may vary based on Runware SDK version
|
||||
request_image = IImageInference(
|
||||
positivePrompt=prompt,
|
||||
numberResults=num_images,
|
||||
@@ -626,7 +637,7 @@ class ImageGenerator:
|
||||
negativePrompt=negative_prompt,
|
||||
height=512,
|
||||
width=512,
|
||||
refiner=refiner
|
||||
# Add refiner parameters directly if supported by the SDK
|
||||
)
|
||||
|
||||
# Generate images
|
||||
@@ -635,7 +646,6 @@ class ImageGenerator:
|
||||
result = {
|
||||
"success": True,
|
||||
"prompt": prompt,
|
||||
"binary_images": [],
|
||||
"image_urls": [],
|
||||
"image_count": 0
|
||||
}
|
||||
@@ -658,17 +668,6 @@ class ImageGenerator:
|
||||
# Update result with image info
|
||||
result["image_count"] = len(image_urls)
|
||||
result["image_urls"] = image_urls
|
||||
|
||||
# Get binary data for each image
|
||||
for img_url in image_urls:
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(img_url) as resp:
|
||||
if resp.status == 200:
|
||||
image_data = await resp.read()
|
||||
result["binary_images"].append(image_data)
|
||||
except Exception as e:
|
||||
logging.error(f"Error downloading refined image {img_url}: {str(e)}")
|
||||
|
||||
# Log success or failure
|
||||
if result["image_count"] > 0:
|
||||
@@ -686,6 +685,5 @@ class ImageGenerator:
|
||||
"error": str(e),
|
||||
"prompt": prompt,
|
||||
"image_urls": [],
|
||||
"image_count": 0,
|
||||
"binary_images": []
|
||||
"image_count": 0
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -346,26 +346,62 @@ class ReminderManager:
|
||||
return now + timedelta(hours=value)
|
||||
elif unit == 'd': # days
|
||||
return now + timedelta(days=value)
|
||||
|
||||
# Handle specific time format
|
||||
# HH:MM format for today or tomorrow
|
||||
if ':' in time_str and len(time_str.split(':')) == 2:
|
||||
hour, minute = map(int, time_str.split(':'))
|
||||
# Handle specific time format
|
||||
# Support various time formats: HH:MM, H:MM, H:MM AM/PM, HH:MM AM/PM
|
||||
if ':' in time_str:
|
||||
# Extract time part and additional words
|
||||
time_parts = time_str.split()
|
||||
time_part = time_parts[0] # e.g., "9:00"
|
||||
|
||||
# Check if valid time
|
||||
if hour < 0 or hour > 23 or minute < 0 or minute > 59:
|
||||
logging.warning(f"Invalid time format: {time_str}")
|
||||
# Check for AM/PM
|
||||
is_pm = False
|
||||
for part in time_parts[1:]:
|
||||
if 'pm' in part.lower():
|
||||
is_pm = True
|
||||
break
|
||||
elif 'am' in part.lower():
|
||||
is_pm = False
|
||||
break
|
||||
|
||||
try:
|
||||
if ':' in time_part and len(time_part.split(':')) == 2:
|
||||
hour_str, minute_str = time_part.split(':')
|
||||
|
||||
# Clean minute string to remove non-digit characters
|
||||
minute_str = ''.join(filter(str.isdigit, minute_str))
|
||||
if not minute_str:
|
||||
minute_str = '0'
|
||||
|
||||
hour = int(hour_str)
|
||||
minute = int(minute_str)
|
||||
|
||||
# Handle AM/PM conversion
|
||||
if is_pm and hour != 12:
|
||||
hour += 12
|
||||
elif not is_pm and hour == 12:
|
||||
hour = 0
|
||||
|
||||
# Check if valid time
|
||||
if hour < 0 or hour > 23 or minute < 0 or minute > 59:
|
||||
logging.warning(f"Invalid time format: {time_str}")
|
||||
return None
|
||||
|
||||
# Create datetime for the specified time today in user's timezone
|
||||
target = now.replace(hour=hour, minute=minute, second=0, microsecond=0)
|
||||
|
||||
# Check for "tomorrow" keyword
|
||||
if 'tomorrow' in time_str.lower():
|
||||
target += timedelta(days=1)
|
||||
# If the time has already passed today and no "today" keyword, schedule for tomorrow
|
||||
elif target <= now and 'today' not in time_str.lower():
|
||||
target += timedelta(days=1)
|
||||
|
||||
logging.info(f"Parsed time '{time_str}' to {target} (User timezone: {user_tz})")
|
||||
return target
|
||||
|
||||
except ValueError as ve:
|
||||
logging.error(f"Error parsing time components in '{time_str}': {str(ve)}")
|
||||
return None
|
||||
|
||||
# Create datetime for the specified time today in user's timezone
|
||||
target = now.replace(hour=hour, minute=minute, second=0, microsecond=0)
|
||||
|
||||
# If the time has already passed today, schedule for tomorrow
|
||||
if target <= now:
|
||||
target += timedelta(days=1)
|
||||
|
||||
logging.info(f"Parsed time '{time_str}' to {target} (User timezone: {user_tz})")
|
||||
return target
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
|
||||
@@ -210,6 +210,60 @@ def scrape_web_content_with_count(url: str, max_tokens: int = 4000, return_token
|
||||
message = f"Failed to process content from {url}: {str(e)}"
|
||||
return (message, 0) if return_token_count else message
|
||||
|
||||
async def google_search(args: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Async wrapper for Google search to match the expected interface.
|
||||
|
||||
Args:
|
||||
args: Dictionary containing 'query' and optional 'num_results'
|
||||
|
||||
Returns:
|
||||
JSON string with search results
|
||||
"""
|
||||
try:
|
||||
query = args.get('query', '')
|
||||
num_results = args.get('num_results', 3)
|
||||
|
||||
if not query:
|
||||
return json.dumps({"error": "No search query provided"})
|
||||
|
||||
# Call the synchronous google_custom_search function
|
||||
result = google_custom_search(query, num_results)
|
||||
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({"error": f"Google search failed: {str(e)}"})
|
||||
|
||||
async def scrape_webpage(args: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Async wrapper for webpage scraping to match the expected interface.
|
||||
|
||||
Args:
|
||||
args: Dictionary containing 'url' and optional 'max_tokens'
|
||||
|
||||
Returns:
|
||||
JSON string with scraped content
|
||||
"""
|
||||
try:
|
||||
url = args.get('url', '')
|
||||
max_tokens = args.get('max_tokens', 4000)
|
||||
|
||||
if not url:
|
||||
return json.dumps({"error": "No URL provided"})
|
||||
|
||||
# Call the synchronous scrape_web_content function
|
||||
content = scrape_web_content(url, max_tokens)
|
||||
|
||||
return json.dumps({
|
||||
"url": url,
|
||||
"content": content,
|
||||
"success": True
|
||||
}, ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({"error": f"Web scraping failed: {str(e)}"})
|
||||
|
||||
# Keep the original scrape_web_content function for backward compatibility
|
||||
def scrape_web_content(url: str, max_tokens: int = 4000) -> str:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user