refactor: optimize memory usage by removing unnecessary global variables and implementing cleanup mechanisms
bot.py
@@ -166,8 +166,8 @@ async def main():
     bot = commands.Bot(
         command_prefix="//quocanhvu",
         intents=intents,
-        heartbeat_timeout=120,
-        max_messages=10000  # Cache more messages to improve experience
+        heartbeat_timeout=180
+        # Removed max_messages to reduce RAM usage
     )

     # Initialize database handler
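Note: a minimal sketch of the trade-off behind this change, assuming discord.py 2.x. max_messages sizes the internal message cache, which speeds up edit/delete/reaction lookups at the cost of RAM; passing None disables the cache. The prefixes and values below are illustrative, not taken from the repository.

    import discord
    from discord.ext import commands

    intents = discord.Intents.default()
    intents.message_content = True

    # Larger cache: faster message lookups for edits/reactions, more RAM held.
    heavy_bot = commands.Bot(command_prefix="!", intents=intents, max_messages=10000)

    # Leaner footprint: keep the library default (1000) or disable with None.
    lean_bot = commands.Bot(command_prefix="!", intents=intents, max_messages=None)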
@@ -7,21 +7,18 @@ import logging
 import re

 class DatabaseHandler:
-    # Class-level cache for database results
-    _cache = {}
-    _cache_expiry = {}
-    _cache_lock = asyncio.Lock()
-
     def __init__(self, mongodb_uri: str):
         """Initialize database connection with optimized settings"""
-        # Set up a connection pool with sensible timeouts
+        # Set up a memory-optimized connection pool
         self.client = AsyncIOMotorClient(
             mongodb_uri,
-            maxIdleTimeMS=45000,
-            connectTimeoutMS=10000,
-            serverSelectionTimeoutMS=15000,
-            waitQueueTimeoutMS=5000,
-            socketTimeoutMS=30000,
+            maxIdleTimeMS=30000,             # Reduced from 45000
+            connectTimeoutMS=8000,           # Reduced from 10000
+            serverSelectionTimeoutMS=12000,  # Reduced from 15000
+            waitQueueTimeoutMS=3000,         # Reduced from 5000
+            socketTimeoutMS=25000,           # Reduced from 30000
+            maxPoolSize=8,                   # Limit connection pool size
+            minPoolSize=2,                   # Maintain minimum connections
             retryWrites=True
         )
         self.db = self.client['chatgpt_discord_bot']  # Database name
@@ -37,42 +34,15 @@ class DatabaseHandler:

         logging.info("Database handler initialized")

-    # Helper for caching results
-    async def _get_cached_result(self, cache_key, fetch_func, expiry_seconds=60):
-        """Get result from cache or execute fetch_func if not cached/expired"""
-        current_time = datetime.now()
-
-        # Check if we have a cached result that's still valid
-        async with self._cache_lock:
-            if (cache_key in self._cache and
-                    cache_key in self._cache_expiry and
-                    current_time < self._cache_expiry[cache_key]):
-                return self._cache[cache_key]
-
-        # Not in cache or expired, fetch new result
-        result = await fetch_func()
-
-        # Cache the new result
-        async with self._cache_lock:
-            self._cache[cache_key] = result
-            self._cache_expiry[cache_key] = current_time + timedelta(seconds=expiry_seconds)
-
-        return result
-
     # User history methods
     async def get_history(self, user_id: int) -> List[Dict[str, Any]]:
-        """Get user conversation history with caching and filter expired image links"""
-        cache_key = f"history_{user_id}"
-
-        async def fetch_history():
-            user_data = await self.db.user_histories.find_one({'user_id': user_id})
-            if user_data and 'history' in user_data:
-                # Filter out expired image links
-                filtered_history = self._filter_expired_images(user_data['history'])
-                return filtered_history
-            return []
-
-        return await self._get_cached_result(cache_key, fetch_history, 30)  # 30 second cache
+        """Get user conversation history and filter expired image links"""
+        user_data = await self.db.user_histories.find_one({'user_id': user_id})
+        if user_data and 'history' in user_data:
+            # Filter out expired image links
+            filtered_history = self._filter_expired_images(user_data['history'])
+            return filtered_history
+        return []

     def _filter_expired_images(self, history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         """Filter out image links that are older than 23 hours"""
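Note: the helper removed here is a read-through cache with per-key expiry. A self-contained sketch of the same pattern, written generically rather than as the bot's exact code, for reference:

    import asyncio
    from datetime import datetime, timedelta

    class TTLCache:
        """Read-through cache: return a fresh value, or call fetch_func and remember it."""

        def __init__(self):
            self._values = {}
            self._expires = {}
            self._lock = asyncio.Lock()

        async def get(self, key, fetch_func, ttl_seconds=60):
            now = datetime.now()
            async with self._lock:
                if key in self._values and now < self._expires.get(key, now):
                    return self._values[key]      # still valid, skip the database
            value = await fetch_func()            # cache miss: hit the real source
            async with self._lock:
                self._values[key] = value
                self._expires[key] = now + timedelta(seconds=ttl_seconds)
            return value

Dropping this layer trades extra database round-trips for a smaller, simpler process footprint, which is the stated goal of the commit.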
@@ -121,43 +91,26 @@ class DatabaseHandler:
         return filtered_history

     async def save_history(self, user_id: int, history: List[Dict[str, Any]]) -> None:
-        """Save user conversation history and update cache"""
+        """Save user conversation history"""
         await self.db.user_histories.update_one(
             {'user_id': user_id},
             {'$set': {'history': history}},
             upsert=True
         )
-
-        # Update the cache
-        cache_key = f"history_{user_id}"
-        async with self._cache_lock:
-            self._cache[cache_key] = history
-            self._cache_expiry[cache_key] = datetime.now() + timedelta(seconds=30)

     # User model preferences with caching
     async def get_user_model(self, user_id: int) -> Optional[str]:
-        """Get user's preferred model with caching"""
-        cache_key = f"model_{user_id}"
-
-        async def fetch_model():
-            user_data = await self.db.user_models.find_one({'user_id': user_id})
-            return user_data['model'] if user_data else None
-
-        return await self._get_cached_result(cache_key, fetch_model, 300)  # 5 minute cache
+        """Get user's preferred model"""
+        user_data = await self.db.user_models.find_one({'user_id': user_id})
+        return user_data['model'] if user_data else None

     async def save_user_model(self, user_id: int, model: str) -> None:
-        """Save user's preferred model and update cache"""
+        """Save user's preferred model"""
         await self.db.user_models.update_one(
             {'user_id': user_id},
             {'$set': {'model': model}},
             upsert=True
         )
-
-        # Update the cache
-        cache_key = f"model_{user_id}"
-        async with self._cache_lock:
-            self._cache[cache_key] = model
-            self._cache_expiry[cache_key] = datetime.now() + timedelta(seconds=300)

     # Admin and permissions management with caching
     async def is_admin(self, user_id: int) -> bool:
@@ -167,78 +120,42 @@ class DatabaseHandler:
         return admin_id == ADMIN_ID

     async def is_user_whitelisted(self, user_id: int) -> bool:
-        """Check if the user is whitelisted with caching"""
+        """Check if the user is whitelisted"""
         if await self.is_admin(user_id):
             return True

-        cache_key = f"whitelist_{user_id}"
-
-        async def check_whitelist():
-            user_data = await self.db.whitelist.find_one({'user_id': user_id})
-            return user_data is not None
-
-        return await self._get_cached_result(cache_key, check_whitelist, 300)  # 5 minute cache
+        user_data = await self.db.whitelist.find_one({'user_id': user_id})
+        return user_data is not None

     async def add_user_to_whitelist(self, user_id: int) -> None:
-        """Add user to whitelist and update cache"""
+        """Add user to whitelist"""
         await self.db.whitelist.update_one(
             {'user_id': user_id},
             {'$set': {'user_id': user_id}},
             upsert=True
         )
-
-        # Update the cache
-        cache_key = f"whitelist_{user_id}"
-        async with self._cache_lock:
-            self._cache[cache_key] = True
-            self._cache_expiry[cache_key] = datetime.now() + timedelta(seconds=300)

     async def remove_user_from_whitelist(self, user_id: int) -> bool:
-        """Remove user from whitelist and update cache"""
+        """Remove user from whitelist"""
         result = await self.db.whitelist.delete_one({'user_id': user_id})
-
-        # Update the cache
-        cache_key = f"whitelist_{user_id}"
-        async with self._cache_lock:
-            self._cache[cache_key] = False
-            self._cache_expiry[cache_key] = datetime.now() + timedelta(seconds=300)

         return result.deleted_count > 0

     async def is_user_blacklisted(self, user_id: int) -> bool:
-        """Check if the user is blacklisted with caching"""
-        cache_key = f"blacklist_{user_id}"
-
-        async def check_blacklist():
-            user_data = await self.db.blacklist.find_one({'user_id': user_id})
-            return user_data is not None
-
-        return await self._get_cached_result(cache_key, check_blacklist, 300)  # 5 minute cache
+        """Check if the user is blacklisted"""
+        user_data = await self.db.blacklist.find_one({'user_id': user_id})
+        return user_data is not None

     async def add_user_to_blacklist(self, user_id: int) -> None:
-        """Add user to blacklist and update cache"""
+        """Add user to blacklist"""
         await self.db.blacklist.update_one(
             {'user_id': user_id},
             {'$set': {'user_id': user_id}},
             upsert=True
         )
-
-        # Update the cache
-        cache_key = f"blacklist_{user_id}"
-        async with self._cache_lock:
-            self._cache[cache_key] = True
-            self._cache_expiry[cache_key] = datetime.now() + timedelta(seconds=300)

     async def remove_user_from_blacklist(self, user_id: int) -> bool:
-        """Remove user from blacklist and update cache"""
+        """Remove user from blacklist"""
         result = await self.db.blacklist.delete_one({'user_id': user_id})
-
-        # Update the cache
-        cache_key = f"blacklist_{user_id}"
-        async with self._cache_lock:
-            self._cache[cache_key] = False
-            self._cache_expiry[cache_key] = datetime.now() + timedelta(seconds=300)

         return result.deleted_count > 0

     # Connection management and cleanup
@@ -42,9 +42,8 @@ DATA_FILE_EXTENSIONS = ['.csv', '.xlsx', '.xls']
 # File extensions for image files (should never be processed as data)
 IMAGE_FILE_EXTENSIONS = ['.png', '.jpg', '.jpeg', '.gif', '.webp', '.bmp', '.svg', '.tiff', '.ico']

-# Storage for user data files and charts
-user_data_files = {}
-user_charts = {}
+# Note: Removed global user_data_files and user_charts dictionaries for memory optimization
+# Data files are now processed immediately and cleaned up

 # Try to import data analysis libraries early
 try:
@@ -81,6 +80,11 @@ class MessageHandler:
         # Initialize reminder manager
         self.reminder_manager = ReminderManager(bot, db_handler)

+        # Memory-optimized user data tracking (with TTL)
+        self.user_data_files = {}  # Will be cleaned up periodically
+        self.user_charts = {}      # Will be cleaned up periodically
+        self.max_user_files = 20   # Limit concurrent user files
+
         # Tool mapping for API integration
         self.tool_mapping = {
             "google_search": self._google_search,
@@ -98,8 +102,10 @@ class MessageHandler:
             "generate_image_with_refiner": self._generate_image_with_refiner
         }

-        # Thread pool for CPU-bound tasks
-        self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=5)
+        # Thread pool for CPU-bound tasks (balanced for performance)
+        import multiprocessing
+        max_workers = min(4, multiprocessing.cpu_count())  # Capped at 4 for better concurrency
+        self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)

         # Create session for HTTP requests
         asyncio.create_task(self._setup_aiohttp_session())
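Note: a sketch of how a CPU-bound helper is typically offloaded to such a pool from async code (standard asyncio usage; the function names here are illustrative, not from the repository):

    import asyncio
    import concurrent.futures
    import multiprocessing

    max_workers = min(4, multiprocessing.cpu_count())
    thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)

    def tokenize_blocking(text: str) -> int:
        # Stand-in for a CPU-bound task such as token counting.
        return len(text.split())

    async def count_tokens_async(text: str) -> int:
        loop = asyncio.get_running_loop()
        # run_in_executor keeps the event loop responsive while the pool works.
        return await loop.run_in_executor(thread_pool, tokenize_blocking, text)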
@@ -171,11 +177,16 @@ class MessageHandler:
             logging.error(f"Error installing packages: {str(e)}")

     async def _setup_aiohttp_session(self):
-        """Create a reusable aiohttp session for better performance"""
+        """Create a memory-optimized aiohttp session"""
         if self.aiohttp_session is None or self.aiohttp_session.closed:
             self.aiohttp_session = aiohttp.ClientSession(
-                timeout=aiohttp.ClientTimeout(total=240),
-                connector=aiohttp.TCPConnector(limit=20, ttl_dns_cache=300)
+                timeout=aiohttp.ClientTimeout(total=120),  # Reduced timeout
+                connector=aiohttp.TCPConnector(
+                    limit=8,                     # Reduced from 20 to 8
+                    ttl_dns_cache=600,           # Increased DNS cache for efficiency
+                    enable_cleanup_closed=True,  # Enable connection cleanup
+                    keepalive_timeout=30         # Shorter keepalive
+                )
             )

     def _setup_event_handlers(self):
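Note: a standalone sketch of the tuned session; the connector parameters are real aiohttp options and the limits mirror the diff, but the surrounding functions are illustrative. The session should be closed on shutdown so the connector releases its sockets:

    import aiohttp

    async def make_session() -> aiohttp.ClientSession:
        return aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=120),
            connector=aiohttp.TCPConnector(
                limit=8,                     # max simultaneous connections
                ttl_dns_cache=600,           # seconds to cache DNS lookups
                enable_cleanup_closed=True,  # reap half-closed SSL transports
                keepalive_timeout=30,        # drop idle connections sooner
            ),
        )

    async def shutdown(session: aiohttp.ClientSession) -> None:
        if not session.closed:
            await session.close()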
@@ -204,8 +215,8 @@ class MessageHandler:
         user_id = self._find_user_id_from_current_task()

         # Add file context if user has uploaded data files
-        if user_id and user_id in user_data_files:
-            file_info = user_data_files[user_id]
+        if user_id and user_id in self.user_data_files:
+            file_info = self.user_data_files[user_id]
             file_context = f"\n\n# Data file available: {file_info['filename']}\n"
             file_context += f"# File path: {file_info['file_path']}\n"
             file_context += f"# You can access this file using: pd.read_csv('{file_info['file_path']}') or similar\n\n"
@@ -502,7 +513,10 @@ class MessageHandler:
                 "file_path": temp_file_path,
                 "timestamp": datetime.now()
             }
-            user_data_files[user_id] = file_info
+
+            # Memory-efficient storage with cleanup
+            self._cleanup_old_user_files()
+            self.user_data_files[user_id] = file_info

             logging.info(f"Downloaded and saved data file: {temp_file_path}")
             return {"success": True, "file_info": file_info}
@@ -1061,9 +1075,9 @@ class MessageHandler:
                 new_history.append({"role": "system", "content": system_content})
                 new_history.extend(history_without_system[1:])  # Skip the first "Instructions" message

-                # Only keep a reasonable amount of history
-                if len(new_history) > 20:
-                    new_history = new_history[:1] + new_history[-19:]  # Keep system prompt + last 19 messages
+                # Only keep a reasonable amount of history (reduced for memory)
+                if len(new_history) > 15:  # Reduced from 20 to 15
+                    new_history = new_history[:1] + new_history[-14:]  # Keep system prompt + last 14 messages

                 await self.db.save_history(user_id, new_history)
             else:
@@ -1073,9 +1087,9 @@ class MessageHandler:
             else:
                 history.append({"role": "assistant", "content": reply})

-                # Only keep a reasonable amount of history
-                if len(history) > 20:
-                    history = history[:1] + history[-19:]  # Keep system prompt + last 19 messages
+                # Only keep a reasonable amount of history (reduced for memory)
+                if len(history) > 15:  # Reduced from 20 to 15
+                    history = history[:1] + history[-14:]  # Keep system prompt + last 14 messages

                 await self.db.save_history(user_id, history)

@@ -1083,9 +1097,9 @@ class MessageHandler:
             await send_response(message.channel, reply)

             # Handle charts from code interpreter if present
-            if chart_id and chart_id in user_charts:
+            if chart_id and chart_id in self.user_charts:
                 try:
-                    chart_data = user_charts[chart_id]["image"]
+                    chart_data = self.user_charts[chart_id]["image"]
                     chart_filename = f"chart_{chart_id}.png"

                     # Send the chart to Discord and get the URL
@@ -1303,23 +1317,74 @@ class MessageHandler:
         return None

     async def _run_chart_cleanup(self):
-        """Run periodic chart cleanup"""
+        """Run aggressive chart cleanup for memory optimization"""
         while True:
             try:
-                await asyncio.sleep(3600)  # Run every hour
-                # Cleanup logic here
+                await asyncio.sleep(600)  # Every 10 minutes (reduced from 1 hour)
+                current_time = datetime.now()
+
+                # Clean charts older than 30 minutes
+                expired_charts = [
+                    chart_id for chart_id, data in self.user_charts.items()
+                    if current_time - data.get('timestamp', current_time) > timedelta(minutes=30)
+                ]
+
+                for chart_id in expired_charts:
+                    self.user_charts.pop(chart_id, None)
+
+                if expired_charts:
+                    logging.info(f"Cleaned up {len(expired_charts)} expired charts")
+
             except Exception as e:
                 logging.error(f"Error in chart cleanup: {str(e)}")

     async def _run_file_cleanup(self):
-        """Run periodic file cleanup"""
+        """Run aggressive file cleanup for memory optimization"""
         while True:
             try:
-                await asyncio.sleep(7200)  # Run every 2 hours
-                # Cleanup logic here
+                await asyncio.sleep(900)  # Every 15 minutes (reduced from 2 hours)
+                self._cleanup_old_user_files()
             except Exception as e:
                 logging.error(f"Error in file cleanup: {str(e)}")

+    def _cleanup_old_user_files(self):
+        """Clean up old user data files to prevent memory bloat"""
+        current_time = datetime.now()
+
+        # Remove files older than 1 hour
+        expired_users = [
+            user_id for user_id, file_info in self.user_data_files.items()
+            if current_time - file_info['timestamp'] > timedelta(hours=1)
+        ]
+
+        for user_id in expired_users:
+            file_info = self.user_data_files.pop(user_id, None)
+            if file_info and os.path.exists(file_info['file_path']):
+                try:
+                    os.remove(file_info['file_path'])
+                except Exception as e:
+                    logging.error(f"Error removing file: {e}")
+
+        # Limit total number of cached files
+        if len(self.user_data_files) > self.max_user_files:
+            # Remove oldest files
+            sorted_files = sorted(
+                self.user_data_files.items(),
+                key=lambda x: x[1]['timestamp']
+            )
+
+            files_to_remove = len(self.user_data_files) - self.max_user_files
+            for user_id, file_info in sorted_files[:files_to_remove]:
+                self.user_data_files.pop(user_id, None)
+                if os.path.exists(file_info['file_path']):
+                    try:
+                        os.remove(file_info['file_path'])
+                    except Exception as e:
+                        logging.error(f"Error removing file: {e}")
+
+        if expired_users:
+            logging.info(f"Cleaned up {len(expired_users)} expired user files")
+
     def _count_tokens(self, messages: List[Dict[str, Any]]) -> int:
         """
         Count tokens in a list of messages using tiktoken o200k_base encoding.
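Note: the two loops above are plain infinite coroutines, so they have to be scheduled once and cancelled on shutdown. A hedged sketch of how that wiring might look; the attribute and method names below (other than the two loop coroutines shown in the diff) are assumptions, not confirmed by the commit:

    import asyncio

    class CleanupTaskMixin:
        def start_cleanup_tasks(self):
            # Schedule the periodic loops alongside the bot's event loop.
            self._cleanup_tasks = [
                asyncio.create_task(self._run_chart_cleanup()),
                asyncio.create_task(self._run_file_cleanup()),
            ]

        async def stop_cleanup_tasks(self):
            tasks = getattr(self, "_cleanup_tasks", [])
            for task in tasks:
                task.cancel()
            # Wait for cancellation to finish, swallowing CancelledError.
            await asyncio.gather(*tasks, return_exceptions=True)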
@@ -246,69 +246,48 @@ async def execute_code_safely(code: str, input_data: str, timeout: int) -> Dict[
     except ImportError:
         pd = None

-    # Create execution namespace
+    # Create minimal execution namespace (memory optimized)
     exec_globals = {
         "__builtins__": {
-            # Safe builtins
-            "print": print,
-            "len": len,
-            "range": range,
-            "enumerate": enumerate,
-            "zip": zip,
-            "map": map,
-            "filter": filter,
-            "sum": sum,
-            "min": min,
-            "max": max,
-            "abs": abs,
-            "round": round,
-            "sorted": sorted,
-            "reversed": reversed,
-            "list": list,
-            "dict": dict,
-            "set": set,
-            "tuple": tuple,
-            "str": str,
-            "int": int,
-            "float": float,
-            "bool": bool,
-            "type": type,
-            "isinstance": isinstance,
-            "hasattr": hasattr,
-            "getattr": getattr,
-            "setattr": setattr,
-            "dir": dir,
-            "help": help,
-            "__import__": __import__,  # Allow controlled imports
-            "ValueError": ValueError,
-            "TypeError": TypeError,
-            "IndexError": IndexError,
-            "KeyError": KeyError,
-            "AttributeError": AttributeError,
-            "ImportError": ImportError,
-            "Exception": Exception,
+            # Essential builtins only
+            "print": print, "len": len, "range": range, "enumerate": enumerate,
+            "zip": zip, "sum": sum, "min": min, "max": max, "abs": abs,
+            "round": round, "sorted": sorted, "list": list, "dict": dict,
+            "set": set, "tuple": tuple, "str": str, "int": int, "float": float,
+            "bool": bool, "type": type, "isinstance": isinstance,
+            "__import__": __import__,  # Fixed: Added missing __import__
+            "ValueError": ValueError, "TypeError": TypeError, "IndexError": IndexError,
+            "KeyError": KeyError, "Exception": Exception,
         },
-        # Add available libraries
+        # Essential modules only
         "math": __import__("math"),
         "random": __import__("random"),
         "json": __import__("json"),
         "time": __import__("time"),
         "datetime": __import__("datetime"),
         "collections": __import__("collections"),
         "itertools": __import__("itertools"),
         "functools": __import__("functools"),
     }

-    # Add optional libraries if available
-    if np is not None:
-        exec_globals["np"] = np
-        exec_globals["numpy"] = np
-    if pd is not None:
-        exec_globals["pd"] = pd
-        exec_globals["pandas"] = pd
-    if plt is not None:
-        exec_globals["plt"] = plt
-        exec_globals["matplotlib"] = matplotlib
+    # Add optional libraries only when needed (lazy loading for memory)
+    if "numpy" in code or "np." in code:
+        try:
+            exec_globals["np"] = __import__("numpy")
+            exec_globals["numpy"] = __import__("numpy")
+        except ImportError:
+            pass
+
+    if "pandas" in code or "pd." in code:
+        try:
+            exec_globals["pd"] = __import__("pandas")
+            exec_globals["pandas"] = __import__("pandas")
+        except ImportError:
+            pass
+
+    if "matplotlib" in code or "plt." in code:
+        try:
+            matplotlib = __import__("matplotlib")
+            matplotlib.use('Agg')
+            exec_globals["plt"] = __import__("matplotlib.pyplot")
+            exec_globals["matplotlib"] = matplotlib
+        except ImportError:
+            pass

     # Override input function if input_data is provided
     if input_data:
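Note: for reference, a minimal sketch of how such a restricted namespace is typically used. The untrusted snippet runs under exec() with only the whitelisted builtins, and stdout is captured for the reply. This is a generic pattern under assumed names, not the project's exact wrapper, and restricting __builtins__ is a namespace limitation rather than a real sandbox:

    import io
    import contextlib

    def run_snippet(code: str) -> str:
        exec_globals = {
            "__builtins__": {"print": print, "len": len, "range": range, "sum": sum},
            "math": __import__("math"),
        }
        buffer = io.StringIO()
        with contextlib.redirect_stdout(buffer):
            exec(code, exec_globals)  # open(), os, etc. are simply not in scope
        return buffer.getvalue()

    # Example: prints "12.566..."
    print(run_snippet("print(4 * math.pi)"))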
@@ -367,6 +346,14 @@ async def execute_code_safely(code: str, input_data: str, timeout: int) -> Dict[
         stdout_output = stdout_capture.getvalue()
         stderr_output = stderr_capture.getvalue()

+        # Force cleanup and garbage collection for memory optimization
+        import gc
+        if 'plt' in exec_globals:
+            plt = exec_globals['plt']
+            plt.close('all')
+        exec_globals.clear()  # Clear execution environment
+        gc.collect()  # Force garbage collection
+
         # Check for any image paths in the output
         image_paths = re.findall(IMAGE_PATH_PATTERN, stdout_output)
         for img_path in image_paths:
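Note: a small illustration of why the explicit cleanup matters. Matplotlib keeps every figure registered in pyplot's global state until it is closed, so a long-running process that renders charts should close figures and let the collector reclaim the buffers. Generic example, not the project's chart code:

    import gc
    import matplotlib
    matplotlib.use("Agg")  # headless backend, no GUI event loop
    import matplotlib.pyplot as plt

    def render_chart(values, path):
        fig, ax = plt.subplots()
        ax.plot(values)
        fig.savefig(path)
        plt.close(fig)  # drop the figure from pyplot's registry
        gc.collect()    # optional: reclaim large buffers promptly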
@@ -1,11 +1,19 @@
 import requests
 import json
 import re
+import logging
 from bs4 import BeautifulSoup
 from typing import Dict, List, Any, Optional, Tuple
 from src.config.config import GOOGLE_API_KEY, GOOGLE_CX
 import tiktoken  # Add tiktoken for token counting

+# Global tiktoken encoder - initialized once to avoid blocking
+try:
+    TIKTOKEN_ENCODER = tiktoken.get_encoding("o200k_base")
+except Exception as e:
+    logging.error(f"Failed to initialize tiktoken encoder: {e}")
+    TIKTOKEN_ENCODER = None
+
 def google_custom_search(query: str, num_results: int = 5, max_tokens: int = 4000) -> dict:
     """
     Perform a Google search using the Google Custom Search API and scrape content
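Note: a sketch of how the module-level encoder is then reused for counting and truncating scraped text. The tiktoken calls are its public API; the helper name and fallback heuristic are illustrative:

    import tiktoken

    try:
        TIKTOKEN_ENCODER = tiktoken.get_encoding("o200k_base")
    except Exception:
        TIKTOKEN_ENCODER = None

    def truncate_to_tokens(text: str, max_tokens: int) -> str:
        if TIKTOKEN_ENCODER is None:
            return text[: max_tokens * 4]  # rough fallback: ~4 chars per token
        tokens = TIKTOKEN_ENCODER.encode(text)
        if len(tokens) <= max_tokens:
            return text
        return TIKTOKEN_ENCODER.decode(tokens[:max_tokens])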
@@ -83,10 +91,8 @@ def scrape_multiple_links(urls: List[str], max_tokens: int = 4000) -> Tuple[str,
     total_tokens = 0
     used_urls = []

-    try:
-        encoding = tiktoken.get_encoding("cl100k_base")
-    except:
-        encoding = None
+    # Use global encoder directly (no async needed since it's pre-initialized)
+    encoding = TIKTOKEN_ENCODER

     for url in urls:
         # Skip empty URLs
@@ -183,10 +189,11 @@ def scrape_web_content_with_count(url: str, max_tokens: int = 4000, return_token
         # Count tokens
         token_count = 0
         try:
-            # Use cl100k_base encoder which is used by most recent models
-            encoding = tiktoken.get_encoding("cl100k_base")
-            tokens = encoding.encode(text)
-            token_count = len(tokens)
+            # Use global o200k_base encoder
+            encoding = TIKTOKEN_ENCODER
+            if encoding:
+                tokens = encoding.encode(text)
+                token_count = len(tokens)

             # Truncate if token count exceeds max_tokens and we're not returning token count
             if len(tokens) > max_tokens and not return_token_count: