refactor: optimize memory usage by removing unnecessary global variables and implementing cleanup mechanisms

2025-08-15 00:12:15 +07:00
parent 8cad2c541f
commit 7b19756932
5 changed files with 179 additions and 203 deletions

bot.py
View File

@@ -166,8 +166,8 @@ async def main():
bot = commands.Bot(
command_prefix="//quocanhvu",
intents=intents,
heartbeat_timeout=120,
max_messages=10000 # Cache more messages to improve experience
heartbeat_timeout=180
# Removed max_messages to reduce RAM usage
)
# Initialize database handler
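For reference, omitting max_messages does not turn the message cache off: discord.py falls back to its default of 1000 cached messages, while passing max_messages=None disables the cache entirely. A minimal sketch of the fully cache-free construction, assuming discord.py 2.x and the same intents object built earlier in bot.py:

from discord.ext import commands

bot = commands.Bot(
    command_prefix="//quocanhvu",
    intents=intents,        # built earlier in main(), as in the original file
    heartbeat_timeout=180,
    max_messages=None,      # None disables the internal message cache; omitting it keeps the default of 1000
)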

View File

@@ -7,21 +7,18 @@ import logging
import re
class DatabaseHandler:
# Class-level cache for database results
_cache = {}
_cache_expiry = {}
_cache_lock = asyncio.Lock()
def __init__(self, mongodb_uri: str):
"""Initialize database connection with optimized settings"""
# Set up a connection pool with sensible timeouts
# Set up a memory-optimized connection pool
self.client = AsyncIOMotorClient(
mongodb_uri,
maxIdleTimeMS=45000,
connectTimeoutMS=10000,
serverSelectionTimeoutMS=15000,
waitQueueTimeoutMS=5000,
socketTimeoutMS=30000,
maxIdleTimeMS=30000, # Reduced from 45000
connectTimeoutMS=8000, # Reduced from 10000
serverSelectionTimeoutMS=12000, # Reduced from 15000
waitQueueTimeoutMS=3000, # Reduced from 5000
socketTimeoutMS=25000, # Reduced from 30000
maxPoolSize=8, # Limit connection pool size
minPoolSize=2, # Maintain minimum connections
retryWrites=True
)
self.db = self.client['chatgpt_discord_bot'] # Database name
@@ -37,42 +34,15 @@ class DatabaseHandler:
logging.info("Database handler initialized")
# Helper for caching results
async def _get_cached_result(self, cache_key, fetch_func, expiry_seconds=60):
"""Get result from cache or execute fetch_func if not cached/expired"""
current_time = datetime.now()
# Check if we have a cached result that's still valid
async with self._cache_lock:
if (cache_key in self._cache and
cache_key in self._cache_expiry and
current_time < self._cache_expiry[cache_key]):
return self._cache[cache_key]
# Not in cache or expired, fetch new result
result = await fetch_func()
# Cache the new result
async with self._cache_lock:
self._cache[cache_key] = result
self._cache_expiry[cache_key] = current_time + timedelta(seconds=expiry_seconds)
return result
# User history methods
async def get_history(self, user_id: int) -> List[Dict[str, Any]]:
"""Get user conversation history with caching and filter expired image links"""
cache_key = f"history_{user_id}"
async def fetch_history():
user_data = await self.db.user_histories.find_one({'user_id': user_id})
if user_data and 'history' in user_data:
# Filter out expired image links
filtered_history = self._filter_expired_images(user_data['history'])
return filtered_history
return []
return await self._get_cached_result(cache_key, fetch_history, 30) # 30 second cache
"""Get user conversation history and filter expired image links"""
user_data = await self.db.user_histories.find_one({'user_id': user_id})
if user_data and 'history' in user_data:
# Filter out expired image links
filtered_history = self._filter_expired_images(user_data['history'])
return filtered_history
return []
def _filter_expired_images(self, history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Filter out image links that are older than 23 hours"""
@@ -121,43 +91,26 @@ class DatabaseHandler:
return filtered_history
async def save_history(self, user_id: int, history: List[Dict[str, Any]]) -> None:
"""Save user conversation history and update cache"""
"""Save user conversation history"""
await self.db.user_histories.update_one(
{'user_id': user_id},
{'$set': {'history': history}},
upsert=True
)
# Update the cache
cache_key = f"history_{user_id}"
async with self._cache_lock:
self._cache[cache_key] = history
self._cache_expiry[cache_key] = datetime.now() + timedelta(seconds=30)
# User model preferences with caching
async def get_user_model(self, user_id: int) -> Optional[str]:
"""Get user's preferred model with caching"""
cache_key = f"model_{user_id}"
async def fetch_model():
user_data = await self.db.user_models.find_one({'user_id': user_id})
return user_data['model'] if user_data else None
return await self._get_cached_result(cache_key, fetch_model, 300) # 5 minute cache
"""Get user's preferred model"""
user_data = await self.db.user_models.find_one({'user_id': user_id})
return user_data['model'] if user_data else None
async def save_user_model(self, user_id: int, model: str) -> None:
"""Save user's preferred model and update cache"""
"""Save user's preferred model"""
await self.db.user_models.update_one(
{'user_id': user_id},
{'$set': {'model': model}},
upsert=True
)
# Update the cache
cache_key = f"model_{user_id}"
async with self._cache_lock:
self._cache[cache_key] = model
self._cache_expiry[cache_key] = datetime.now() + timedelta(seconds=300)
# Admin and permissions management with caching
async def is_admin(self, user_id: int) -> bool:
@@ -167,78 +120,42 @@ class DatabaseHandler:
return admin_id == ADMIN_ID
async def is_user_whitelisted(self, user_id: int) -> bool:
"""Check if the user is whitelisted with caching"""
"""Check if the user is whitelisted"""
if await self.is_admin(user_id):
return True
cache_key = f"whitelist_{user_id}"
async def check_whitelist():
user_data = await self.db.whitelist.find_one({'user_id': user_id})
return user_data is not None
return await self._get_cached_result(cache_key, check_whitelist, 300) # 5 minute cache
user_data = await self.db.whitelist.find_one({'user_id': user_id})
return user_data is not None
async def add_user_to_whitelist(self, user_id: int) -> None:
"""Add user to whitelist and update cache"""
"""Add user to whitelist"""
await self.db.whitelist.update_one(
{'user_id': user_id},
{'$set': {'user_id': user_id}},
upsert=True
)
# Update the cache
cache_key = f"whitelist_{user_id}"
async with self._cache_lock:
self._cache[cache_key] = True
self._cache_expiry[cache_key] = datetime.now() + timedelta(seconds=300)
async def remove_user_from_whitelist(self, user_id: int) -> bool:
"""Remove user from whitelist and update cache"""
"""Remove user from whitelist"""
result = await self.db.whitelist.delete_one({'user_id': user_id})
# Update the cache
cache_key = f"whitelist_{user_id}"
async with self._cache_lock:
self._cache[cache_key] = False
self._cache_expiry[cache_key] = datetime.now() + timedelta(seconds=300)
return result.deleted_count > 0
async def is_user_blacklisted(self, user_id: int) -> bool:
"""Check if the user is blacklisted with caching"""
cache_key = f"blacklist_{user_id}"
async def check_blacklist():
user_data = await self.db.blacklist.find_one({'user_id': user_id})
return user_data is not None
return await self._get_cached_result(cache_key, check_blacklist, 300) # 5 minute cache
"""Check if the user is blacklisted"""
user_data = await self.db.blacklist.find_one({'user_id': user_id})
return user_data is not None
async def add_user_to_blacklist(self, user_id: int) -> None:
"""Add user to blacklist and update cache"""
"""Add user to blacklist"""
await self.db.blacklist.update_one(
{'user_id': user_id},
{'$set': {'user_id': user_id}},
upsert=True
)
# Update the cache
cache_key = f"blacklist_{user_id}"
async with self._cache_lock:
self._cache[cache_key] = True
self._cache_expiry[cache_key] = datetime.now() + timedelta(seconds=300)
async def remove_user_from_blacklist(self, user_id: int) -> bool:
"""Remove user from blacklist and update cache"""
"""Remove user from blacklist"""
result = await self.db.blacklist.delete_one({'user_id': user_id})
# Update the cache
cache_key = f"blacklist_{user_id}"
async with self._cache_lock:
self._cache[cache_key] = False
self._cache_expiry[cache_key] = datetime.now() + timedelta(seconds=300)
return result.deleted_count > 0
# Connection management and cleanup
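The class-level _cache and _cache_expiry dictionaries removed above grew without bound, one entry per user per key, which is what made them a memory liability rather than the caching itself. If read caching is ever reintroduced, a bounded TTL cache keeps the footprint predictable. A sketch only, not part of this commit, assuming the third-party cachetools package:

from cachetools import TTLCache

class DatabaseHandler:
    def __init__(self, mongodb_uri: str):
        # ... existing connection setup ...
        # At most 256 entries, each evicted after 60 seconds, so memory stays bounded.
        self._cache = TTLCache(maxsize=256, ttl=60)

    async def get_user_model(self, user_id: int):
        key = f"model_{user_id}"
        if key in self._cache:
            return self._cache[key]
        user_data = await self.db.user_models.find_one({'user_id': user_id})
        model = user_data['model'] if user_data else None
        self._cache[key] = model
        return model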

View File

@@ -42,9 +42,8 @@ DATA_FILE_EXTENSIONS = ['.csv', '.xlsx', '.xls']
# File extensions for image files (should never be processed as data)
IMAGE_FILE_EXTENSIONS = ['.png', '.jpg', '.jpeg', '.gif', '.webp', '.bmp', '.svg', '.tiff', '.ico']
# Storage for user data files and charts
user_data_files = {}
user_charts = {}
# Note: Removed global user_data_files and user_charts dictionaries for memory optimization
# Data files are now processed immediately and cleaned up
# Try to import data analysis libraries early
try:
@@ -81,6 +80,11 @@ class MessageHandler:
# Initialize reminder manager
self.reminder_manager = ReminderManager(bot, db_handler)
# Memory-optimized user data tracking (with TTL)
self.user_data_files = {} # Will be cleaned up periodically
self.user_charts = {} # Will be cleaned up periodically
self.max_user_files = 20 # Limit concurrent user files
# Tool mapping for API integration
self.tool_mapping = {
"google_search": self._google_search,
@@ -98,8 +102,10 @@ class MessageHandler:
"generate_image_with_refiner": self._generate_image_with_refiner
}
# Thread pool for CPU-bound tasks
self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=5)
# Thread pool for CPU-bound tasks (balanced for performance)
import multiprocessing
max_workers = min(4, multiprocessing.cpu_count()) # Capped at 4 (down from 5) to balance concurrency against memory
self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
# Create session for HTTP requests
asyncio.create_task(self._setup_aiohttp_session())
@@ -171,11 +177,16 @@ class MessageHandler:
logging.error(f"Error installing packages: {str(e)}")
async def _setup_aiohttp_session(self):
"""Create a reusable aiohttp session for better performance"""
"""Create a memory-optimized aiohttp session"""
if self.aiohttp_session is None or self.aiohttp_session.closed:
self.aiohttp_session = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=240),
connector=aiohttp.TCPConnector(limit=20, ttl_dns_cache=300)
timeout=aiohttp.ClientTimeout(total=120), # Reduced timeout
connector=aiohttp.TCPConnector(
limit=8, # Reduced from 20 to 8
ttl_dns_cache=600, # Increased DNS cache for efficiency
enable_cleanup_closed=True, # Enable connection cleanup
keepalive_timeout=30 # Shorter keepalive
)
)
def _setup_event_handlers(self):
@@ -204,8 +215,8 @@ class MessageHandler:
user_id = self._find_user_id_from_current_task()
# Add file context if user has uploaded data files
if user_id and user_id in user_data_files:
file_info = user_data_files[user_id]
if user_id and user_id in self.user_data_files:
file_info = self.user_data_files[user_id]
file_context = f"\n\n# Data file available: {file_info['filename']}\n"
file_context += f"# File path: {file_info['file_path']}\n"
file_context += f"# You can access this file using: pd.read_csv('{file_info['file_path']}') or similar\n\n"
@@ -502,7 +513,10 @@ class MessageHandler:
"file_path": temp_file_path,
"timestamp": datetime.now()
}
user_data_files[user_id] = file_info
# Memory-efficient storage with cleanup
self._cleanup_old_user_files()
self.user_data_files[user_id] = file_info
logging.info(f"Downloaded and saved data file: {temp_file_path}")
return {"success": True, "file_info": file_info}
@@ -1061,9 +1075,9 @@ class MessageHandler:
new_history.append({"role": "system", "content": system_content})
new_history.extend(history_without_system[1:]) # Skip the first "Instructions" message
# Only keep a reasonable amount of history
if len(new_history) > 20:
new_history = new_history[:1] + new_history[-19:] # Keep system prompt + last 19 messages
# Only keep a reasonable amount of history (reduced for memory)
if len(new_history) > 15: # Reduced from 20 to 15
new_history = new_history[:1] + new_history[-14:] # Keep system prompt + last 14 messages
await self.db.save_history(user_id, new_history)
else:
@@ -1073,9 +1087,9 @@ class MessageHandler:
else:
history.append({"role": "assistant", "content": reply})
# Only keep a reasonable amount of history
if len(history) > 20:
history = history[:1] + history[-19:] # Keep system prompt + last 19 messages
# Only keep a reasonable amount of history (reduced for memory)
if len(history) > 15: # Reduced from 20 to 15
history = history[:1] + history[-14:] # Keep system prompt + last 14 messages
await self.db.save_history(user_id, history)
@@ -1083,9 +1097,9 @@ class MessageHandler:
await send_response(message.channel, reply)
# Handle charts from code interpreter if present
if chart_id and chart_id in user_charts:
if chart_id and chart_id in self.user_charts:
try:
chart_data = user_charts[chart_id]["image"]
chart_data = self.user_charts[chart_id]["image"]
chart_filename = f"chart_{chart_id}.png"
# Send the chart to Discord and get the URL
@@ -1303,23 +1317,74 @@ class MessageHandler:
return None
async def _run_chart_cleanup(self):
"""Run periodic chart cleanup"""
"""Run aggressive chart cleanup for memory optimization"""
while True:
try:
await asyncio.sleep(3600) # Run every hour
# Cleanup logic here
await asyncio.sleep(600) # Every 10 minutes (reduced from 1 hour)
current_time = datetime.now()
# Clean charts older than 30 minutes
expired_charts = [
chart_id for chart_id, data in self.user_charts.items()
if current_time - data.get('timestamp', current_time) > timedelta(minutes=30)
]
for chart_id in expired_charts:
self.user_charts.pop(chart_id, None)
if expired_charts:
logging.info(f"Cleaned up {len(expired_charts)} expired charts")
except Exception as e:
logging.error(f"Error in chart cleanup: {str(e)}")
async def _run_file_cleanup(self):
"""Run periodic file cleanup"""
"""Run aggressive file cleanup for memory optimization"""
while True:
try:
await asyncio.sleep(7200) # Run every 2 hours
# Cleanup logic here
await asyncio.sleep(900) # Every 15 minutes (reduced from 2 hours)
self._cleanup_old_user_files()
except Exception as e:
logging.error(f"Error in file cleanup: {str(e)}")
def _cleanup_old_user_files(self):
"""Clean up old user data files to prevent memory bloat"""
current_time = datetime.now()
# Remove files older than 1 hour
expired_users = [
user_id for user_id, file_info in self.user_data_files.items()
if current_time - file_info['timestamp'] > timedelta(hours=1)
]
for user_id in expired_users:
file_info = self.user_data_files.pop(user_id, None)
if file_info and os.path.exists(file_info['file_path']):
try:
os.remove(file_info['file_path'])
except Exception as e:
logging.error(f"Error removing file: {e}")
# Limit total number of cached files
if len(self.user_data_files) > self.max_user_files:
# Remove oldest files
sorted_files = sorted(
self.user_data_files.items(),
key=lambda x: x[1]['timestamp']
)
files_to_remove = len(self.user_data_files) - self.max_user_files
for user_id, file_info in sorted_files[:files_to_remove]:
self.user_data_files.pop(user_id, None)
if os.path.exists(file_info['file_path']):
try:
os.remove(file_info['file_path'])
except Exception as e:
logging.error(f"Error removing file: {e}")
if expired_users:
logging.info(f"Cleaned up {len(expired_users)} expired user files")
def _count_tokens(self, messages: List[Dict[str, Any]]) -> int:
"""
Count tokens in a list of messages using tiktoken o200k_base encoding.
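The _count_tokens body is cut off in the hunk above. For context, a standalone sketch of per-message counting with the o200k_base encoding the docstring names; everything beyond the docstring is an assumption, not the file's actual implementation:

import tiktoken
from typing import Any, Dict, List

_ENCODING = tiktoken.get_encoding("o200k_base")

def count_tokens(messages: List[Dict[str, Any]]) -> int:
    # Sum the encoded length of every string "content" field; non-text parts are skipped.
    total = 0
    for message in messages:
        content = message.get("content", "")
        if isinstance(content, str):
            total += len(_ENCODING.encode(content))
    return total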

View File

@@ -246,69 +246,48 @@ async def execute_code_safely(code: str, input_data: str, timeout: int) -> Dict[
except ImportError:
pd = None
# Create execution namespace
# Create minimal execution namespace (memory optimized)
exec_globals = {
"__builtins__": {
# Safe builtins
"print": print,
"len": len,
"range": range,
"enumerate": enumerate,
"zip": zip,
"map": map,
"filter": filter,
"sum": sum,
"min": min,
"max": max,
"abs": abs,
"round": round,
"sorted": sorted,
"reversed": reversed,
"list": list,
"dict": dict,
"set": set,
"tuple": tuple,
"str": str,
"int": int,
"float": float,
"bool": bool,
"type": type,
"isinstance": isinstance,
"hasattr": hasattr,
"getattr": getattr,
"setattr": setattr,
"dir": dir,
"help": help,
"__import__": __import__, # Allow controlled imports
"ValueError": ValueError,
"TypeError": TypeError,
"IndexError": IndexError,
"KeyError": KeyError,
"AttributeError": AttributeError,
"ImportError": ImportError,
"Exception": Exception,
# Essential builtins only
"print": print, "len": len, "range": range, "enumerate": enumerate,
"zip": zip, "sum": sum, "min": min, "max": max, "abs": abs,
"round": round, "sorted": sorted, "list": list, "dict": dict,
"set": set, "tuple": tuple, "str": str, "int": int, "float": float,
"bool": bool, "type": type, "isinstance": isinstance,
"__import__": __import__, # Fixed: Added missing __import__
"ValueError": ValueError, "TypeError": TypeError, "IndexError": IndexError,
"KeyError": KeyError, "Exception": Exception,
},
# Add available libraries
# Essential modules only
"math": __import__("math"),
"random": __import__("random"),
"json": __import__("json"),
"time": __import__("time"),
"datetime": __import__("datetime"),
"collections": __import__("collections"),
"itertools": __import__("itertools"),
"functools": __import__("functools"),
}
# Add optional libraries if available
if np is not None:
exec_globals["np"] = np
exec_globals["numpy"] = np
if pd is not None:
exec_globals["pd"] = pd
exec_globals["pandas"] = pd
if plt is not None:
exec_globals["plt"] = plt
exec_globals["matplotlib"] = matplotlib
# Add optional libraries only when needed (lazy loading for memory)
if "numpy" in code or "np." in code:
try:
exec_globals["np"] = __import__("numpy")
exec_globals["numpy"] = __import__("numpy")
except ImportError:
pass
if "pandas" in code or "pd." in code:
try:
exec_globals["pd"] = __import__("pandas")
exec_globals["pandas"] = __import__("pandas")
except ImportError:
pass
if "matplotlib" in code or "plt." in code:
try:
matplotlib = __import__("matplotlib")
matplotlib.use('Agg')
exec_globals["plt"] = __import__("matplotlib.pyplot")
exec_globals["matplotlib"] = matplotlib
except ImportError:
pass
# Override input function if input_data is provided
if input_data:
@@ -367,6 +346,14 @@ async def execute_code_safely(code: str, input_data: str, timeout: int) -> Dict[
stdout_output = stdout_capture.getvalue()
stderr_output = stderr_capture.getvalue()
# Force cleanup and garbage collection for memory optimization
import gc
if 'plt' in exec_globals:
plt = exec_globals['plt']
plt.close('all')
exec_globals.clear() # Clear execution environment
gc.collect() # Force garbage collection
# Check for any image paths in the output
image_paths = re.findall(IMAGE_PATH_PATTERN, stdout_output)
for img_path in image_paths:
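The input() override flagged earlier in this file ("Override input function if input_data is provided") is elided from the diff. One common way to wire it up, shown here purely as an assumption about the missing body, is to feed input_data to the sandboxed code line by line:

if input_data:
    _stdin_lines = iter(input_data.splitlines())
    # Each input() call inside the executed code consumes the next pre-supplied line,
    # returning "" once the lines run out instead of blocking.
    exec_globals["__builtins__"]["input"] = lambda prompt="": next(_stdin_lines, "")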

View File

@@ -1,11 +1,19 @@
import requests
import json
import re
import logging
from bs4 import BeautifulSoup
from typing import Dict, List, Any, Optional, Tuple
from src.config.config import GOOGLE_API_KEY, GOOGLE_CX
import tiktoken # Add tiktoken for token counting
# Global tiktoken encoder - initialized once to avoid blocking
try:
TIKTOKEN_ENCODER = tiktoken.get_encoding("o200k_base")
except Exception as e:
logging.error(f"Failed to initialize tiktoken encoder: {e}")
TIKTOKEN_ENCODER = None
def google_custom_search(query: str, num_results: int = 5, max_tokens: int = 4000) -> dict:
"""
Perform a Google search using the Google Custom Search API and scrape content
@@ -83,10 +91,8 @@ def scrape_multiple_links(urls: List[str], max_tokens: int = 4000) -> Tuple[str,
total_tokens = 0
used_urls = []
try:
encoding = tiktoken.get_encoding("cl100k_base")
except:
encoding = None
# Use global encoder directly (no async needed since it's pre-initialized)
encoding = TIKTOKEN_ENCODER
for url in urls:
# Skip empty URLs
@@ -183,10 +189,11 @@ def scrape_web_content_with_count(url: str, max_tokens: int = 4000, return_token
# Count tokens
token_count = 0
try:
# Use cl100k_base encoder which is used by most recent models
encoding = tiktoken.get_encoding("cl100k_base")
tokens = encoding.encode(text)
token_count = len(tokens)
# Use global o200k_base encoder
encoding = TIKTOKEN_ENCODER
if encoding:
tokens = encoding.encode(text)
token_count = len(tokens)
# Truncate if token count exceeds max_tokens and we're not returning token count
if len(tokens) > max_tokens and not return_token_count:
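The truncation branch is cut off above. One detail worth guarding: TIKTOKEN_ENCODER can be None if initialization failed, and depending on indentation the length check may then reference tokens before it is assigned. A small token-based truncation sketch that covers that case, offered as an illustration rather than the file's actual code:

def truncate_to_tokens(text: str, max_tokens: int) -> str:
    # Fall back to the untruncated text when the shared encoder is unavailable.
    if TIKTOKEN_ENCODER is None:
        return text
    tokens = TIKTOKEN_ENCODER.encode(text)
    if len(tokens) <= max_tokens:
        return text
    return TIKTOKEN_ENCODER.decode(tokens[:max_tokens])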