refactor: Clean up code formatting and enhance test cases for better readability and coverage

Commit b9c43ed50b (parent 80713ac94f)
Date: 2025-03-20 12:50:52 +07:00
17 changed files with 4139 additions and 3823 deletions

.dockerignore

__pycache__/
*.py[cod]
*$py.class
*.so
.git/
.env
.venv
env/
venv/
ENV/
.idea/
.vscode/
.github/

GitHub Actions workflow

@@ -1,9 +1,13 @@
name: Build and Deploy ChatGPT-Discord-Bot
on:
  push:
    branches:
      - main
    paths-ignore:
      - '**.md'
      - 'LICENSE'
      - '.gitignore'
  workflow_dispatch: # Allow manual triggering
jobs:
  tests:
@@ -14,31 +18,33 @@ jobs:
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0 # Fetch all history for better versioning
      - name: Set up Python
        uses: actions/setup-python@v5.3.0
        with:
          python-version: '3.12.3'
          cache: 'pip' # Use built-in pip caching
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pytest pytest-cov flake8
          pip install -r requirements.txt
      - name: Lint code
        run: |
          # Stop the build if there are Python syntax errors or undefined names
          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
          # Exit-zero treats all errors as warnings
          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
      - name: Run unit tests with coverage
        run: |
          python -m pytest tests/ --cov=src
      - name: Check dependencies for security issues
        uses: pyupio/safety-action@v1.0.1
        with:
          api-key: ${{ secrets.SAFETY_API_KEY }}
@@ -47,12 +53,28 @@ jobs:
    runs-on: ubuntu-latest
    environment: Private Server Deploy
    needs: tests
    outputs:
      image: ghcr.io/coder-vippro/chatgpt-discord-bot
      version: ${{ steps.version.outputs.version }}
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
      - name: Generate version number
        id: version
        run: |
          # Create a version with timestamp and short commit hash
          VERSION=$(date +'%Y%m%d%H%M')-${GITHUB_SHA::7}
          echo "version=${VERSION}" >> $GITHUB_OUTPUT
          echo "Version: ${VERSION}"
      - name: Set up QEMU for multi-architecture builds
        uses: docker/setup-qemu-action@v3
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
        with:
          buildkitd-flags: --debug
      - name: Log in to GitHub Container Registry
        uses: docker/login-action@v3
@@ -67,9 +89,39 @@ jobs:
          context: .
          push: true
          platforms: linux/amd64,linux/arm64
          tags: |
            ghcr.io/coder-vippro/chatgpt-discord-bot:latest
            ghcr.io/coder-vippro/chatgpt-discord-bot:${{ steps.version.outputs.version }}
          labels: |
            org.opencontainers.image.title=ChatGPT-Discord-Bot
            org.opencontainers.image.description=Discord bot powered by OpenAI ChatGPT
            org.opencontainers.image.source=https://github.com/coder-vippro/ChatGPT-Discord-Bot
            org.opencontainers.image.created=${{ github.event.repository.updated_at }}
            org.opencontainers.image.revision=${{ github.sha }}
            org.opencontainers.image.version=${{ steps.version.outputs.version }}
          cache-from: type=gha,scope=build-cache
          cache-to: type=gha,mode=max,scope=build-cache
          github-token: ${{ secrets.GITHUB_TOKEN }}
          build-args: |
            BUILD_DATE=${{ github.event.repository.updated_at }}
            VCS_REF=${{ github.sha }}
            VERSION=${{ steps.version.outputs.version }}
  deploy-notification:
    runs-on: ubuntu-latest
    needs: build-and-push
    if: ${{ success() }}
    steps:
      - name: Send deployment notification
        uses: sarisia/actions-status-discord@v1
        with:
          webhook: ${{ secrets.DISCORD_WEBHOOK }}
          title: "✅ New deployment successful!"
          description: |
            Image: ${{ needs.build-and-push.outputs.image }}:${{ needs.build-and-push.outputs.version }}
            Commit: ${{ github.sha }}
            Repository: ${{ github.repository }}
          color: 0x00ff00
          username: GitHub Actions
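
The version step above tags each image as <timestamp>-<short sha>. A minimal Python sketch of the same scheme (illustrative only; the sha below is a placeholder, not this commit's hash):

from datetime import datetime

def build_version(sha: str) -> str:
    # Mirrors VERSION=$(date +'%Y%m%d%H%M')-${GITHUB_SHA::7}
    return datetime.now().strftime("%Y%m%d%H%M") + "-" + sha[:7]

print(build_version("0123456789abcdef"))  # e.g. 202503201250-0123456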

.gitignore (18 changes)

test.py
.env
chat_history.db
bot_copy.py
__pycache__/bot.cpython-312.pyc
tests/__pycache__/test_bot.cpython-312.pyc
.vscode/settings.json
chatgpt.zip
response.txt
logs

Dockerfile

# Build stage with all build dependencies
FROM python:3.12.3-alpine AS builder
# Set environment variables
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1
# Install build dependencies
RUN apk add --no-cache \
curl \
g++ \
gcc \
musl-dev \
make \
rust \
cargo \
build-base
# Set the working directory
WORKDIR /app
# Copy requirements file
COPY requirements.txt .
# Install Python packages with BuildKit cache
RUN --mount=type=cache,target=/root/.cache/pip \
pip install --no-cache-dir -r requirements.txt
# Runtime stage with minimal dependencies
FROM python:3.12.3-alpine
# Set environment variables
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1
# Install runtime dependencies
RUN apk add --no-cache libstdc++
RUN apk add --no-cache g++
# Set the working directory
WORKDIR /usr/src/discordbot
# Copy installed Python packages from builder - using correct site-packages path
COPY --from=builder /usr/local/lib/python3.12 /usr/local/lib/python3.12
COPY --from=builder /usr/local/bin /usr/local/bin
# Copy the application source code
COPY . .
# Command to run the application
CMD ["python3", "bot.py"]

bot.py (596 changes)

import os
import sys
import discord
import logging
import asyncio
import signal
import traceback
import time
import logging.config
from discord.ext import commands, tasks
from concurrent.futures import ThreadPoolExecutor
from dotenv import load_dotenv
from discord import app_commands

# Import configuration
from src.config.config import (
    DISCORD_TOKEN, MONGODB_URI, RUNWARE_API_KEY, STATUSES,
    LOGGING_CONFIG, ENABLE_WEBHOOK_LOGGING, LOGGING_WEBHOOK_URL,
    WEBHOOK_LOG_LEVEL, WEBHOOK_APP_NAME, WEBHOOK_BATCH_SIZE,
    WEBHOOK_FLUSH_INTERVAL, LOG_LEVEL_MAP
)

# Import webhook logger
from src.utils.webhook_logger import webhook_log_manager, webhook_logger

# Import database handler
from src.database.db_handler import DatabaseHandler

# Import the message handler
from src.module.message_handler import MessageHandler

# Import various utility modules
from src.utils.image_utils import ImageGenerator

# Global shutdown flag
shutdown_flag = asyncio.Event()

# Load environment variables
load_dotenv()

# Configure logging with more detail, rotation, and webhook integration
def setup_logging():
    # Apply the dictionary config
    try:
        logging.config.dictConfig(LOGGING_CONFIG)
        logging.info("Configured logging from dictionary configuration")
    except Exception as e:
        # Fall back to basic configuration
        log_formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s'
        )
        # Console handler
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(log_formatter)
        # File handler with rotation (keep 5 files of 5MB each)
        try:
            from logging.handlers import RotatingFileHandler
            os.makedirs('logs', exist_ok=True)
            file_handler = RotatingFileHandler(
                'logs/discord_bot.log',
                maxBytes=5*1024*1024,  # 5MB
                backupCount=5
            )
            file_handler.setFormatter(log_formatter)
            # Configure root logger
            root_logger = logging.getLogger()
            root_logger.setLevel(logging.INFO)
            root_logger.addHandler(console_handler)
            root_logger.addHandler(file_handler)
        except Exception as e:
            # Fall back to basic logging if file logging fails
            logging.basicConfig(
                level=logging.INFO,
                format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                stream=sys.stdout
            )
            logging.warning(f"Could not set up file logging: {str(e)}")

    # Set up webhook logging if enabled
    if ENABLE_WEBHOOK_LOGGING and LOGGING_WEBHOOK_URL:
        try:
            # Convert string log level to int using our mapping
            log_level = LOG_LEVEL_MAP.get(WEBHOOK_LOG_LEVEL.upper(), logging.INFO)
            # Set up webhook logging
            webhook_log_manager.setup_webhook_logging(
                webhook_url=LOGGING_WEBHOOK_URL,
                app_name=WEBHOOK_APP_NAME,
                level=log_level,
                loggers=None,  # Use root logger
                batch_size=WEBHOOK_BATCH_SIZE,
                flush_interval=WEBHOOK_FLUSH_INTERVAL
            )
            logging.info(f"Webhook logging enabled at level {WEBHOOK_LOG_LEVEL}")
        except Exception as e:
            logging.error(f"Failed to set up webhook logging: {str(e)}")

# Create a function to change bot status periodically
async def change_status_loop(bot):
    """Change bot status every 5 minutes"""
    while not shutdown_flag.is_set():
        for status in STATUSES:
            await bot.change_presence(activity=discord.Game(name=status))
            try:
                # Wait but be interruptible
                await asyncio.wait_for(shutdown_flag.wait(), timeout=300)
                if shutdown_flag.is_set():
                    break
            except asyncio.TimeoutError:
                # Normal timeout, continue to next status
                continue

async def main():
    # Set up logging
    setup_logging()

    # Check if required environment variables are set
    missing_vars = []
    if not DISCORD_TOKEN:
        missing_vars.append("DISCORD_TOKEN")
    if not MONGODB_URI:
        missing_vars.append("MONGODB_URI")
    if missing_vars:
        logging.error(f"The following required environment variables are not set: {', '.join(missing_vars)}")
        return
    if not RUNWARE_API_KEY:
        logging.warning("RUNWARE_API_KEY environment variable not set - image generation will not work")

    # Initialize the OpenAI client
    try:
        from openai import AsyncOpenAI
        openai_client = AsyncOpenAI()
        logging.info("OpenAI client initialized successfully")
    except ImportError:
        logging.error("Failed to import OpenAI. Make sure it's installed: pip install openai")
        return
    except Exception as e:
        logging.error(f"Error initializing OpenAI client: {e}")
        return

    # Global references to objects that need cleanup
    message_handler = None
    db_handler = None

    try:
        # Initialize image generator if API key is available
        image_generator = None
        if RUNWARE_API_KEY:
            try:
                image_generator = ImageGenerator(RUNWARE_API_KEY)
                logging.info("Image generator initialized successfully")
            except Exception as e:
                logging.error(f"Error initializing image generator: {e}")

        # Set up Discord intents
        intents = discord.Intents.default()
        intents.message_content = True

        # Initialize the bot with command prefixes and more robust timeout settings
        bot = commands.Bot(
            command_prefix="//quocanhvu",
            intents=intents,
            heartbeat_timeout=120,
            max_messages=10000  # Cache more messages to improve experience
        )

        # Initialize database handler
        db_handler = DatabaseHandler(MONGODB_URI)
        # Create database indexes for performance
        await db_handler.create_indexes()
        logging.info("Database indexes created")
        # Initialize the reminders collection
        await db_handler.ensure_reminders_collection()

        # Event handler when the bot is ready
        @bot.event
        async def on_ready():
            """Bot startup event to sync slash commands and start status loop."""
            await bot.tree.sync()  # Sync slash commands
            bot_info = f"Logged in as {bot.user} (ID: {bot.user.id})"
            logging.info("=" * len(bot_info))
            logging.info(bot_info)
            logging.info(f"Connected to {len(bot.guilds)} guilds")
            logging.info("=" * len(bot_info))
            # Start the status changing task
            asyncio.create_task(change_status_loop(bot))

        # Handle general errors to prevent crashes
        @bot.event
        async def on_error(event, *args, **kwargs):
            error_msg = traceback.format_exc()
            logging.error(f"Discord event error in {event}:\n{error_msg}")

        @bot.event
        async def on_command_error(ctx, error):
            if isinstance(error, commands.CommandNotFound):
                return
            error_msg = str(error)
            trace = "".join(traceback.format_exception(type(error), error, error.__traceback__))
            logging.error(f"Command error: {error_msg}\n{trace}")
            await ctx.send(f"Error: {error_msg}")

        # Initialize message handler
        message_handler = MessageHandler(bot, db_handler, openai_client, image_generator)

        # Set up slash commands
        from src.commands.commands import setup_commands
        setup_commands(bot, db_handler, openai_client, image_generator)

        # Handle shutdown signals
        loop = asyncio.get_running_loop()
        # Signal handlers for graceful shutdown
        for sig in (signal.SIGINT, signal.SIGTERM):
            try:
                loop.add_signal_handler(
                    sig,
                    lambda sig=sig: asyncio.create_task(shutdown(sig, loop, bot, db_handler, message_handler))
                )
            except (NotImplementedError, RuntimeError):
                # Windows doesn't support SIGTERM or add_signal_handler
                # Use fallback for Windows
                pass

        logging.info("Starting bot...")
        await bot.start(DISCORD_TOKEN)
    except Exception as e:
        error_msg = traceback.format_exc()
        logging.critical(f"Fatal error in main function: {str(e)}\n{error_msg}")
        # Clean up resources if initialization failed halfway
        await cleanup_resources(bot=None, db_handler=db_handler, message_handler=message_handler)

async def shutdown(sig, loop, bot, db_handler, message_handler):
    """Handle graceful shutdown of the bot"""
    logging.info(f"Received signal {sig.name}. Starting graceful shutdown...")
    # Set shutdown flag to stop ongoing tasks
    shutdown_flag.set()
    # Give running tasks a moment to detect shutdown flag
    await asyncio.sleep(1)
    # Start cleanup
    await cleanup_resources(bot, db_handler, message_handler)
    # Stop the event loop
    loop.stop()

async def cleanup_resources(bot, db_handler, message_handler):
    """Clean up all resources during shutdown"""
    try:
        # Close the bot connection
        if bot is not None:
            logging.info("Closing bot connection...")
            await bot.close()
        # Close message handler resources
        if message_handler is not None:
            logging.info("Closing message handler resources...")
            await message_handler.close()
        # Close database connection
        if db_handler is not None:
            logging.info("Closing database connection...")
            await db_handler.close()
        # Clean up webhook logging
        if ENABLE_WEBHOOK_LOGGING and LOGGING_WEBHOOK_URL:
            logging.info("Cleaning up webhook logging...")
            webhook_log_manager.cleanup()
        logging.info("Cleanup completed successfully")
    except Exception as e:
        logging.error(f"Error during cleanup: {str(e)}")

if __name__ == "__main__":
    try:
        # Use asyncio.run to properly run the async main function
        asyncio.run(main())
    except KeyboardInterrupt:
        logging.info("Bot stopped via keyboard interrupt")
    except Exception as e:
        logging.critical(f"Unhandled exception in main thread: {str(e)}")
        traceback.print_exc()
    finally:
        logging.info("Bot shut down completely")

requirements.txt

discord.py>=2.3.0
openai>=1.3.0
motor>=3.3.0
pymongo>=4.6.0
tiktoken>=0.5.0
PyPDF2>=3.0.0
beautifulsoup4>=4.12.0
requests>=2.31.0
aiohttp>=3.9.0
runware>=0.2.0
python-dotenv>=1.0.0
webdriver-manager
matplotlib
pandas
openpyxl
pytz
xlrd
scipy

src/commands/commands.py

import discord
from discord import app_commands
from discord.ext import commands
import logging
import io
import asyncio
from typing import Optional, Dict, List, Any, Callable

from src.config.config import MODEL_OPTIONS, PDF_ALLOWED_MODELS
from src.utils.image_utils import ImageGenerator
from src.utils.web_utils import google_custom_search, scrape_web_content
from src.utils.pdf_utils import process_pdf, send_response

# Dictionary to keep track of user requests and their cooldowns
user_requests = {}
# Dictionary to store user tasks
user_tasks = {}

def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator: ImageGenerator):
    """
    Set up all slash commands for the bot.

    Args:
        bot: Discord bot instance
        db_handler: Database handler instance
        openai_client: OpenAI client instance
        image_generator: Image generator instance
    """
    tree = bot.tree

    def check_blacklist():
        """Decorator to check if a user is blacklisted before executing a command."""
        async def predicate(interaction: discord.Interaction):
            if await db_handler.is_admin(interaction.user.id):
                return True
            if await db_handler.is_user_blacklisted(interaction.user.id):
                await interaction.response.send_message("You have been blacklisted from using this bot. Please contact the admin if you think this is a mistake.", ephemeral=True)
                return False
            return True
        return app_commands.check(predicate)

    # Processes a command request with rate limiting and queuing.
    async def process_request(interaction, command_func, *args):
        user_id = interaction.user.id
        now = discord.utils.utcnow().timestamp()
        if user_id not in user_requests:
            user_requests[user_id] = {'last_request': 0, 'queue': asyncio.Queue()}
        last_request = user_requests[user_id]['last_request']
        if now - last_request < 5:
            await interaction.followup.send("You are sending requests too quickly. Please wait a moment.", ephemeral=True)
            return
        # Update last request time
        user_requests[user_id]['last_request'] = now
        # Add request to queue
        queue = user_requests[user_id]['queue']
        await queue.put((command_func, args))
        # Start processing if it's the only request in the queue
        if queue.qsize() == 1:
            await process_queue(interaction)

    # Processes requests in the user's queue sequentially.
    async def process_queue(interaction):
        user_id = interaction.user.id
        queue = user_requests[user_id]['queue']
        while not queue.empty():
            command_func, args = await queue.get()
            try:
                await command_func(interaction, *args)
            except Exception as e:
                logging.error(f"Error processing command: {str(e)}")
                await interaction.followup.send(f"An error occurred: {str(e)}", ephemeral=True)
            await asyncio.sleep(1)  # Optional delay between processing

    @tree.command(name="choose_model", description="Select the AI model to use for responses.")
    @check_blacklist()
    async def choose_model(interaction: discord.Interaction):
        """Lets users choose an AI model and saves it to the database."""
        options = [discord.SelectOption(label=model, value=model) for model in MODEL_OPTIONS]
        select_menu = discord.ui.Select(placeholder="Choose a model", options=options)

        async def select_callback(interaction: discord.Interaction):
            selected_model = select_menu.values[0]
            user_id = interaction.user.id
            # Save the model selection to the database
            await db_handler.save_user_model(user_id, selected_model)
            await interaction.response.send_message(
                f"Model set to `{selected_model}` for your responses.", ephemeral=True
            )

        select_menu.callback = select_callback
        view = discord.ui.View()
        view.add_item(select_menu)
        await interaction.response.send_message("Choose a model:", view=view, ephemeral=True)

    @tree.command(name="search", description="Search on Google and send results to AI model.")
    @app_commands.describe(query="The search query")
    @check_blacklist()
    async def search(interaction: discord.Interaction, query: str):
        """Searches Google and sends results to the AI model."""
        await interaction.response.defer(thinking=True)

        async def process_search(interaction: discord.Interaction, query: str):
            user_id = interaction.user.id
            model = await db_handler.get_user_model(user_id) or "gpt-4o"
            history = await db_handler.get_history(user_id)
            try:
                # Perform Google search
                search_results = google_custom_search(query)
                if not search_results or not search_results.get('results'):
                    await interaction.followup.send("No search results found.")
                    return
                # Format search results for the AI model
                from src.config.config import SEARCH_PROMPT
                formatted_results = f"Search results for: {query}\n\n"
                for i, result in enumerate(search_results.get('results', [])):
                    formatted_results += f"{i+1}. {result.get('title')}\n"
                    formatted_results += f"URL: {result.get('link')}\n"
                    formatted_results += f"Snippet: {result.get('snippet')}\n"
                    if 'scraped_content' in result:
                        content_preview = result['scraped_content'][:300] + "..." if len(result['scraped_content']) > 300 else result['scraped_content']
                        formatted_results += f"Content: {content_preview}\n"
                    formatted_results += "\n"
                # Prepare messages for the AI model, handling system prompts appropriately
                messages = []
                if model in ["o1-mini", "o1-preview"]:
                    messages = [
                        {"role": "user", "content": f"Instructions: {SEARCH_PROMPT}\n\n{formatted_results}\n\nUser query: {query}"}
                    ]
                else:
                    messages = [
                        {"role": "system", "content": SEARCH_PROMPT},
                        {"role": "user", "content": f"{formatted_results}\n\nUser query: {query}"}
                    ]
                # Send to the AI model
                response = await openai_client.chat.completions.create(
                    model=model if model in ["gpt-4o", "gpt-4o-mini"] else "gpt-4o",
                    messages=messages,
                    temperature=0.5
                )
                reply = response.choices[0].message.content
                # Add the interaction to history
                history.append({"role": "user", "content": f"Search query: {query}"})
                history.append({"role": "assistant", "content": reply})
                await db_handler.save_history(user_id, history)
                # Check if the reply exceeds Discord's character limit (2000)
                if len(reply) > 2000:
                    # Create a text file with the full response
                    file_bytes = io.BytesIO(reply.encode('utf-8'))
                    file = discord.File(file_bytes, filename="search_response.txt")
                    # Send a short message with the file attachment
                    await interaction.followup.send(
                        f"The search response for '{query}' is too long for Discord (>{len(reply)} characters). Here's the full response as a text file:",
                        file=file
                    )
                else:
                    # Send as normal message if within limits
                    await interaction.followup.send(reply)
            except Exception as e:
                error_message = f"Search error: {str(e)}"
                logging.error(error_message)
                await interaction.followup.send(f"An error occurred while searching: {str(e)}")

        await process_request(interaction, process_search, query)

    @tree.command(name="web", description="Scrape a webpage and send data to AI model.")
    @app_commands.describe(url="The webpage URL to scrape")
    @check_blacklist()
    async def web(interaction: discord.Interaction, url: str):
        """Scrapes a webpage and sends data to the AI model."""
        await interaction.response.defer(thinking=True)

        async def process_web(interaction: discord.Interaction, url: str):
            user_id = interaction.user.id
            model = await db_handler.get_user_model(user_id) or "gpt-4o"
            history = await db_handler.get_history(user_id)
            try:
                content = scrape_web_content(url)
                if content.startswith("Failed"):
                    await interaction.followup.send(content)
                    return
                from src.config.config import WEB_SCRAPING_PROMPT
                if model in ["o1-mini", "o1-preview"]:
                    messages = [
                        {"role": "user", "content": f"Instructions: {WEB_SCRAPING_PROMPT}\n\nContent from {url}:\n{content}"}
                    ]
                else:
                    messages = [
                        {"role": "system", "content": WEB_SCRAPING_PROMPT},
                        {"role": "user", "content": f"Content from {url}:\n{content}"}
                    ]
                response = await openai_client.chat.completions.create(
                    model=model if model in ["gpt-4o", "gpt-4o-mini"] else "gpt-4o",
                    messages=messages,
                    temperature=0.3,
                    top_p=0.7
                )
                reply = response.choices[0].message.content
                # Add the interaction to history
                history.append({"role": "user", "content": f"Scraped content from {url}"})
                history.append({"role": "assistant", "content": reply})
                await db_handler.save_history(user_id, history)
                # Check if the reply exceeds Discord's character limit (2000)
                if len(reply) > 2000:
                    # Create a text file with the full response
                    file_bytes = io.BytesIO(reply.encode('utf-8'))
                    file = discord.File(file_bytes, filename="web_response.txt")
                    # Send a short message with the file attachment
                    await interaction.followup.send(
                        f"The response from analyzing {url} is too long for Discord (>{len(reply)} characters). Here's the full response as a text file:",
                        file=file
                    )
                else:
                    # Send as normal message if within limits
                    await interaction.followup.send(reply)
            except Exception as e:
                await interaction.followup.send(f"Error: {str(e)}", ephemeral=True)

        await process_request(interaction, process_web, url)

    @tree.command(name='generate', description='Generates an image from a text prompt.')
    @app_commands.describe(prompt='The prompt for image generation')
    @check_blacklist()
    async def generate_image_command(interaction: discord.Interaction, prompt: str):
        """Generates an image from a text prompt."""
        await interaction.response.defer(thinking=True)  # Indicate that the bot is processing

        async def process_image_generation(interaction: discord.Interaction, prompt: str):
            try:
                # Generate images
                result = await image_generator.generate_image(prompt, 4)  # Generate 4 images
                if not result['success']:
                    await interaction.followup.send(f"Error: {result.get('error', 'Unknown error')}")
                    return
                # Send images as attachments
                if result["binary_images"]:
                    await interaction.followup.send(
                        f"Generated {len(result['binary_images'])} images for prompt: \"{prompt}\"",
                        files=[discord.File(io.BytesIO(img), filename=f"image_{i}.png")
                               for i, img in enumerate(result["binary_images"])]
                    )
                else:
                    await interaction.followup.send("No images were generated.")
            except Exception as e:
                error_message = f"An error occurred: {str(e)}"
                logging.error(f"Error in generate_image_command: {error_message}")
                await interaction.followup.send(error_message)

        await process_request(interaction, process_image_generation, prompt)

    @tree.command(name="reset", description="Reset the bot by clearing user data.")
    @check_blacklist()
    async def reset(interaction: discord.Interaction):
        """Resets the bot by clearing user data."""
        user_id = interaction.user.id
        await db_handler.save_history(user_id, [])
        await interaction.response.send_message("Your conversation history has been cleared and reset!", ephemeral=True)

    @tree.command(name="user_stat", description="Get your current input token, output token, and model.")
    @check_blacklist()
    async def user_stat(interaction: discord.Interaction):
        """Fetches and displays the current input token, output token, and model for the user."""
        await interaction.response.defer(thinking=True, ephemeral=True)

        async def process_user_stat(interaction: discord.Interaction):
            import tiktoken
            user_id = interaction.user.id
            history = await db_handler.get_history(user_id)
            model = await db_handler.get_user_model(user_id) or "gpt-4o"  # Default model
            # Adjust model for encoding purposes
            if model in ["gpt-4o", "o1", "o1-preview", "o1-mini", "o3-mini"]:
                encoding_model = "gpt-4o"
            else:
                encoding_model = model
            # Retrieve the appropriate encoding for the selected model
            encoding = tiktoken.encoding_for_model(encoding_model)
            # Initialize token counts
            input_tokens = 0
            output_tokens = 0
            # Calculate input and output tokens
            if history:
                for item in history:
                    content = item.get('content', '')
                    # Handle case where content is a list or other type
                    if isinstance(content, list):
                        content_str = ""
                        for part in content:
                            if isinstance(part, dict) and 'text' in part:
                                content_str += part['text'] + " "
                        content = content_str
                    # Ensure content is a string before processing
                    if isinstance(content, str):
                        tokens = len(encoding.encode(content))
                        if item.get('role') == 'user':
                            input_tokens += tokens
                        elif item.get('role') == 'assistant':
                            output_tokens += tokens
            # Create the statistics message
            stat_message = (
                f"**User Statistics:**\n"
                f"Model: `{model}`\n"
                f"Input Tokens: `{input_tokens}`\n"
                f"Output Tokens: `{output_tokens}`\n"
            )
            # Send the response
            await interaction.followup.send(stat_message, ephemeral=True)

        await process_request(interaction, process_user_stat)

    @tree.command(name="help", description="Display a list of available commands.")
    @check_blacklist()
    async def help_command(interaction: discord.Interaction):
        """Sends a list of available commands to the user."""
        help_message = (
            "**Available commands:**\n"
            "/choose_model - Select which AI model to use for responses (gpt-4o, gpt-4o-mini, o1-preview, o1-mini).\n"
            "/search `<query>` - Search Google and send results to the AI model.\n"
            "/web `<url>` - Scrape a webpage and send the data to the AI model.\n"
            "/generate `<prompt>` - Generate an image from a text prompt.\n"
            "/reset - Reset your chat history.\n"
            "/user_stat - Get information about your input tokens, output tokens, and current model.\n"
            "/help - Display this help message.\n"
        )
        await interaction.response.send_message(help_message, ephemeral=True)

    @tree.command(name="stop", description="Stop any process or queue of the user. Admins can stop other users' tasks by providing their ID.")
    @app_commands.describe(user_id="The Discord user ID to stop tasks for (admin only)")
    @check_blacklist()
    async def stop(interaction: discord.Interaction, user_id: str = None):
        """Stops any process or queue of the user. Admins can stop other users' tasks by providing their ID."""
        # Defer the interaction first
        await interaction.response.defer(ephemeral=True)
        if user_id and not await db_handler.is_admin(interaction.user.id):
            await interaction.followup.send("You don't have permission to stop other users' tasks.", ephemeral=True)
            return
        target_user_id = int(user_id) if user_id else interaction.user.id
        await stop_user_tasks(target_user_id)
        await interaction.followup.send(f"Stopped all tasks for user {target_user_id}.", ephemeral=True)

    # Admin commands
    @tree.command(name="whitelist_add", description="Add a user to the PDF processing whitelist")
    @app_commands.describe(user_id="The Discord user ID to whitelist")
    async def whitelist_add(interaction: discord.Interaction, user_id: str):
        """Adds a user to the PDF processing whitelist."""
        if not await db_handler.is_admin(interaction.user.id):
            await interaction.response.send_message("You don't have permission to use this command. Only admin can use whitelist commands.", ephemeral=True)
            return
        try:
            user_id = int(user_id)
            if await db_handler.is_admin(user_id):
                await interaction.response.send_message("Admins are automatically whitelisted and don't need to be added.", ephemeral=True)
                return
            await db_handler.add_user_to_whitelist(user_id)
            await interaction.response.send_message(f"User {user_id} has been added to the PDF processing whitelist.", ephemeral=True)
        except ValueError:
            await interaction.response.send_message("Invalid user ID. Please provide a valid Discord user ID.", ephemeral=True)

    @tree.command(name="whitelist_remove", description="Remove a user from the PDF processing whitelist")
    @app_commands.describe(user_id="The Discord user ID to remove from whitelist")
    async def whitelist_remove(interaction: discord.Interaction, user_id: str):
        """Removes a user from the PDF processing whitelist."""
        if not await db_handler.is_admin(interaction.user.id):
            await interaction.response.send_message("You don't have permission to use this command. Only admin can use whitelist commands.", ephemeral=True)
            return
        try:
            user_id = int(user_id)
            if await db_handler.remove_user_from_whitelist(user_id):
                await interaction.response.send_message(f"User {user_id} has been removed from the PDF processing whitelist.", ephemeral=True)
            else:
                await interaction.response.send_message(f"User {user_id} was not found in the whitelist.", ephemeral=True)
        except ValueError:
            await interaction.response.send_message("Invalid user ID. Please provide a valid Discord user ID.", ephemeral=True)

    @tree.command(name="blacklist_add", description="Add a user to the bot blacklist")
    @app_commands.describe(user_id="The Discord user ID to blacklist")
    async def blacklist_add(interaction: discord.Interaction, user_id: str):
        """Adds a user to the bot blacklist."""
        if not await db_handler.is_admin(interaction.user.id):
            await interaction.response.send_message("You don't have permission to use this command. Only admin can use blacklist commands.", ephemeral=True)
            return
        try:
            user_id = int(user_id)
            if await db_handler.is_admin(user_id):
                await interaction.response.send_message("Cannot blacklist an admin.", ephemeral=True)
                return
            await db_handler.add_user_to_blacklist(user_id)
            await interaction.response.send_message(f"User {user_id} has been added to the bot blacklist. They can no longer use any bot features.", ephemeral=True)
        except ValueError:
            await interaction.response.send_message("Invalid user ID. Please provide a valid Discord user ID.", ephemeral=True)

    @tree.command(name="blacklist_remove", description="Remove a user from the bot blacklist")
    @app_commands.describe(user_id="The Discord user ID to remove from blacklist")
    async def blacklist_remove(interaction: discord.Interaction, user_id: str):
        """Removes a user from the bot blacklist."""
        if not await db_handler.is_admin(interaction.user.id):
            await interaction.response.send_message("You don't have permission to use this command. Only admin can use blacklist commands.", ephemeral=True)
            return
        try:
            user_id = int(user_id)
            if await db_handler.remove_user_from_blacklist(user_id):
                await interaction.response.send_message(f"User {user_id} has been removed from the bot blacklist. They can now use bot features again.", ephemeral=True)
            else:
                await interaction.response.send_message(f"User {user_id} was not found in the blacklist.", ephemeral=True)
        except ValueError:
            await interaction.response.send_message("Invalid user ID. Please provide a valid Discord user ID.", ephemeral=True)

    # Helper function to stop user tasks
    async def stop_user_tasks(user_id: int):
        """Stop all tasks for a specific user."""
        if user_id in user_tasks:
            for task in user_tasks[user_id]:
                task.cancel()
            user_tasks[user_id] = []
        # Clear any queued requests
        if user_id in user_requests:
            while not user_requests[user_id]['queue'].empty():
                try:
                    user_requests[user_id]['queue'].get_nowait()
                except:
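
process_request and process_queue above implement a per-user cooldown plus a sequential queue: a request arriving within 5 seconds of the user's last one is rejected, anything else is queued, and the queue drains one task at a time. A stripped-down sketch of that pattern (plain callables instead of Discord interactions):

import asyncio
import time

user_requests = {}

async def process_request(user_id, func, *args):
    now = time.monotonic()
    entry = user_requests.setdefault(user_id, {"last": 0.0, "queue": asyncio.Queue()})
    if now - entry["last"] < 5:
        print("too fast, try again in a moment")
        return
    entry["last"] = now
    await entry["queue"].put((func, args))
    # Drain the queue only if this request is the sole occupant
    if entry["queue"].qsize() == 1:
        while not entry["queue"].empty():
            f, a = await entry["queue"].get()
            await f(*a)

async def job(name):
    print("running", name)

asyncio.run(process_request(1, job, "demo"))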
import discord
from discord import app_commands
from discord.ext import commands
import logging
import io
import asyncio
from typing import Optional, Dict, List, Any, Callable
from src.config.config import MODEL_OPTIONS, PDF_ALLOWED_MODELS
from src.utils.image_utils import ImageGenerator
from src.utils.web_utils import google_custom_search, scrape_web_content
from src.utils.pdf_utils import process_pdf, send_response
# Dictionary to keep track of user requests and their cooldowns
user_requests = {}
# Dictionary to store user tasks
user_tasks = {}
def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator: ImageGenerator):
"""
Set up all slash commands for the bot.
Args:
bot: Discord bot instance
db_handler: Database handler instance
openai_client: OpenAI client instance
image_generator: Image generator instance
"""
tree = bot.tree
def check_blacklist():
"""Decorator to check if a user is blacklisted before executing a command."""
async def predicate(interaction: discord.Interaction):
if await db_handler.is_admin(interaction.user.id):
return True
if await db_handler.is_user_blacklisted(interaction.user.id):
await interaction.response.send_message("You have been blacklisted from using this bot. Please contact the admin if you think this is a mistake.", ephemeral=True)
return False
return True
return app_commands.check(predicate)
# Processes a command request with rate limiting and queuing.
async def process_request(interaction, command_func, *args):
user_id = interaction.user.id
now = discord.utils.utcnow().timestamp()
if user_id not in user_requests:
user_requests[user_id] = {'last_request': 0, 'queue': asyncio.Queue()}
last_request = user_requests[user_id]['last_request']
if now - last_request < 5:
await interaction.followup.send("You are sending requests too quickly. Please wait a moment.", ephemeral=True)
return
# Update last request time
user_requests[user_id]['last_request'] = now
# Add request to queue
queue = user_requests[user_id]['queue']
await queue.put((command_func, args))
# Start processing if it's the only request in the queue
if queue.qsize() == 1:
await process_queue(interaction)
# Processes requests in the user's queue sequentially.
async def process_queue(interaction):
user_id = interaction.user.id
queue = user_requests[user_id]['queue']
while not queue.empty():
command_func, args = await queue.get()
try:
await command_func(interaction, *args)
except Exception as e:
logging.error(f"Error processing command: {str(e)}")
await interaction.followup.send(f"An error occurred: {str(e)}", ephemeral=True)
await asyncio.sleep(1) # Optional delay between processing
@tree.command(name="choose_model", description="Select the AI model to use for responses.")
@check_blacklist()
async def choose_model(interaction: discord.Interaction):
"""Lets users choose an AI model and saves it to the database."""
options = [discord.SelectOption(label=model, value=model) for model in MODEL_OPTIONS]
select_menu = discord.ui.Select(placeholder="Choose a model", options=options)
async def select_callback(interaction: discord.Interaction):
selected_model = select_menu.values[0]
user_id = interaction.user.id
# Save the model selection to the database
await db_handler.save_user_model(user_id, selected_model)
await interaction.response.send_message(
f"Model set to `{selected_model}` for your responses.", ephemeral=True
)
select_menu.callback = select_callback
view = discord.ui.View()
view.add_item(select_menu)
await interaction.response.send_message("Choose a model:", view=view, ephemeral=True)
@tree.command(name="search", description="Search on Google and send results to AI model.")
@app_commands.describe(query="The search query")
@check_blacklist()
async def search(interaction: discord.Interaction, query: str):
"""Searches Google and sends results to the AI model."""
await interaction.response.defer(thinking=True)
async def process_search(interaction: discord.Interaction, query: str):
user_id = interaction.user.id
model = await db_handler.get_user_model(user_id) or "gpt-4o"
history = await db_handler.get_history(user_id)
try:
# Perform Google search
search_results = google_custom_search(query)
if not search_results or not search_results.get('results'):
await interaction.followup.send("No search results found.")
return
# Format search results for the AI model
from src.config.config import SEARCH_PROMPT
formatted_results = f"Search results for: {query}\n\n"
for i, result in enumerate(search_results.get('results', [])):
formatted_results += f"{i+1}. {result.get('title')}\n"
formatted_results += f"URL: {result.get('link')}\n"
formatted_results += f"Snippet: {result.get('snippet')}\n"
if 'scraped_content' in result:
content_preview = result['scraped_content'][:300] + "..." if len(result['scraped_content']) > 300 else result['scraped_content']
formatted_results += f"Content: {content_preview}\n"
formatted_results += "\n"
# Prepare messages for the AI model, handling system prompts appropriately
messages = []
if model in ["o1-mini", "o1-preview"]:
messages = [
{"role": "user", "content": f"Instructions: {SEARCH_PROMPT}\n\n{formatted_results}\n\nUser query: {query}"}
]
else:
messages = [
{"role": "system", "content": SEARCH_PROMPT},
{"role": "user", "content": f"{formatted_results}\n\nUser query: {query}"}
]
# Send to the AI model
response = await openai_client.chat.completions.create(
model=model if model in ["gpt-4o", "gpt-4o-mini"] else "gpt-4o",
messages=messages,
temperature=0.5
)
reply = response.choices[0].message.content
# Add the interaction to history
history.append({"role": "user", "content": f"Search query: {query}"})
history.append({"role": "assistant", "content": reply})
await db_handler.save_history(user_id, history)
# Check if the reply exceeds Discord's character limit (2000)
if len(reply) > 2000:
# Create a text file with the full response
file_bytes = io.BytesIO(reply.encode('utf-8'))
file = discord.File(file_bytes, filename="search_response.txt")
# Send a short message with the file attachment
await interaction.followup.send(
f"The search response for '{query}' is too long for Discord (>{len(reply)} characters). Here's the full response as a text file:",
file=file
)
else:
# Send as normal message if within limits
await interaction.followup.send(reply)
except Exception as e:
error_message = f"Search error: {str(e)}"
logging.error(error_message)
await interaction.followup.send(f"An error occurred while searching: {str(e)}")
await process_request(interaction, process_search, query)
@tree.command(name="web", description="Scrape a webpage and send data to AI model.")
@app_commands.describe(url="The webpage URL to scrape")
@check_blacklist()
async def web(interaction: discord.Interaction, url: str):
"""Scrapes a webpage and sends data to the AI model."""
await interaction.response.defer(thinking=True)
async def process_web(interaction: discord.Interaction, url: str):
user_id = interaction.user.id
model = await db_handler.get_user_model(user_id) or "gpt-4o"
history = await db_handler.get_history(user_id)
try:
content = scrape_web_content(url)
if content.startswith("Failed"):
await interaction.followup.send(content)
return
from src.config.config import WEB_SCRAPING_PROMPT
if model in ["o1-mini", "o1-preview"]:
messages = [
{"role": "user", "content": f"Instructions: {WEB_SCRAPING_PROMPT}\n\nContent from {url}:\n{content}"}
]
else:
messages = [
{"role": "system", "content": WEB_SCRAPING_PROMPT},
{"role": "user", "content": f"Content from {url}:\n{content}"}
]
response = await openai_client.chat.completions.create(
model=model if model in ["gpt-4o", "gpt-4o-mini"] else "gpt-4o",
messages=messages,
temperature=0.3,
top_p=0.7
)
reply = response.choices[0].message.content
# Add the interaction to history
history.append({"role": "user", "content": f"Scraped content from {url}"})
history.append({"role": "assistant", "content": reply})
await db_handler.save_history(user_id, history)
# Check if the reply exceeds Discord's character limit (2000)
if len(reply) > 2000:
# Create a text file with the full response
file_bytes = io.BytesIO(reply.encode('utf-8'))
file = discord.File(file_bytes, filename="web_response.txt")
# Send a short message with the file attachment
await interaction.followup.send(
f"The response from analyzing {url} is too long for Discord (>{len(reply)} characters). Here's the full response as a text file:",
file=file
)
else:
# Send as normal message if within limits
await interaction.followup.send(reply)
except Exception as e:
await interaction.followup.send(f"Error: {str(e)}", ephemeral=True)
await process_request(interaction, process_web, url)
@tree.command(name='generate', description='Generates an image from a text prompt.')
@app_commands.describe(prompt='The prompt for image generation')
@check_blacklist()
async def generate_image_command(interaction: discord.Interaction, prompt: str):
"""Generates an image from a text prompt."""
await interaction.response.defer(thinking=True) # Indicate that the bot is processing
async def process_image_generation(interaction: discord.Interaction, prompt: str):
try:
# Generate images
result = await image_generator.generate_image(prompt, 4) # Generate 4 images
if not result['success']:
await interaction.followup.send(f"Error: {result.get('error', 'Unknown error')}")
return
# Send images as attachments
if result["binary_images"]:
await interaction.followup.send(
f"Generated {len(result['binary_images'])} images for prompt: \"{prompt}\"",
files=[discord.File(io.BytesIO(img), filename=f"image_{i}.png")
for i, img in enumerate(result["binary_images"])]
)
else:
await interaction.followup.send("No images were generated.")
except Exception as e:
error_message = f"An error occurred: {str(e)}"
logging.error(f"Error in generate_image_command: {error_message}")
await interaction.followup.send(error_message)
await process_request(interaction, process_image_generation, prompt)
@tree.command(name="reset", description="Reset the bot by clearing user data.")
@check_blacklist()
async def reset(interaction: discord.Interaction):
"""Resets the bot by clearing user data."""
user_id = interaction.user.id
await db_handler.save_history(user_id, [])
await interaction.response.send_message("Your conversation history has been cleared and reset!", ephemeral=True)
@tree.command(name="user_stat", description="Get your current input token, output token, and model.")
@check_blacklist()
async def user_stat(interaction: discord.Interaction):
"""Fetches and displays the current input token, output token, and model for the user."""
await interaction.response.defer(thinking=True, ephemeral=True)
async def process_user_stat(interaction: discord.Interaction):
import tiktoken
user_id = interaction.user.id
history = await db_handler.get_history(user_id)
model = await db_handler.get_user_model(user_id) or "gpt-4o" # Default model
# Adjust model for encoding purposes
if model in ["gpt-4o", "o1", "o1-preview", "o1-mini", "o3-mini"]:
encoding_model = "gpt-4o"
else:
encoding_model = model
# Retrieve the appropriate encoding for the selected model
encoding = tiktoken.encoding_for_model(encoding_model)
# Initialize token counts
input_tokens = 0
output_tokens = 0
# Calculate input and output tokens
if history:
for item in history:
content = item.get('content', '')
# Handle case where content is a list or other type
if isinstance(content, list):
content_str = ""
for part in content:
if isinstance(part, dict) and 'text' in part:
content_str += part['text'] + " "
content = content_str
# Ensure content is a string before processing
if isinstance(content, str):
tokens = len(encoding.encode(content))
if item.get('role') == 'user':
input_tokens += tokens
elif item.get('role') == 'assistant':
output_tokens += tokens
# Create the statistics message
stat_message = (
f"**User Statistics:**\n"
f"Model: `{model}`\n"
f"Input Tokens: `{input_tokens}`\n"
f"Output Tokens: `{output_tokens}`\n"
)
# Send the response
await interaction.followup.send(stat_message, ephemeral=True)
await process_request(interaction, process_user_stat)
@tree.command(name="help", description="Display a list of available commands.")
@check_blacklist()
async def help_command(interaction: discord.Interaction):
"""Sends a list of available commands to the user."""
help_message = (
"**Available commands:**\n"
"/choose_model - Select which AI model to use for responses (gpt-4o, gpt-4o-mini, o1-preview, o1-mini).\n"
"/search `<query>` - Search Google and send results to the AI model.\n"
"/web `<url>` - Scrape a webpage and send the data to the AI model.\n"
"/generate `<prompt>` - Generate an image from a text prompt.\n"
"/reset - Reset your chat history.\n"
"/user_stat - Get information about your input tokens, output tokens, and current model.\n"
"/help - Display this help message.\n"
)
await interaction.response.send_message(help_message, ephemeral=True)
@tree.command(name="stop", description="Stop any process or queue of the user. Admins can stop other users' tasks by providing their ID.")
@app_commands.describe(user_id="The Discord user ID to stop tasks for (admin only)")
@check_blacklist()
async def stop(interaction: discord.Interaction, user_id: str = None):
"""Stops any process or queue of the user. Admins can stop other users' tasks by providing their ID."""
# Defer the interaction first
await interaction.response.defer(ephemeral=True)
if user_id and not await db_handler.is_admin(interaction.user.id):
await interaction.followup.send("You don't have permission to stop other users' tasks.", ephemeral=True)
return
target_user_id = int(user_id) if user_id else interaction.user.id
await stop_user_tasks(target_user_id)
await interaction.followup.send(f"Stopped all tasks for user {target_user_id}.", ephemeral=True)
# Admin commands
@tree.command(name="whitelist_add", description="Add a user to the PDF processing whitelist")
@app_commands.describe(user_id="The Discord user ID to whitelist")
async def whitelist_add(interaction: discord.Interaction, user_id: str):
"""Adds a user to the PDF processing whitelist."""
if not await db_handler.is_admin(interaction.user.id):
await interaction.response.send_message("You don't have permission to use this command. Only admin can use whitelist commands.", ephemeral=True)
return
try:
user_id = int(user_id)
if await db_handler.is_admin(user_id):
await interaction.response.send_message("Admins are automatically whitelisted and don't need to be added.", ephemeral=True)
return
await db_handler.add_user_to_whitelist(user_id)
await interaction.response.send_message(f"User {user_id} has been added to the PDF processing whitelist.", ephemeral=True)
except ValueError:
await interaction.response.send_message("Invalid user ID. Please provide a valid Discord user ID.", ephemeral=True)
@tree.command(name="whitelist_remove", description="Remove a user from the PDF processing whitelist")
@app_commands.describe(user_id="The Discord user ID to remove from whitelist")
async def whitelist_remove(interaction: discord.Interaction, user_id: str):
"""Removes a user from the PDF processing whitelist."""
if not await db_handler.is_admin(interaction.user.id):
await interaction.response.send_message("You don't have permission to use this command. Only admin can use whitelist commands.", ephemeral=True)
return
try:
user_id = int(user_id)
if await db_handler.remove_user_from_whitelist(user_id):
await interaction.response.send_message(f"User {user_id} has been removed from the PDF processing whitelist.", ephemeral=True)
else:
await interaction.response.send_message(f"User {user_id} was not found in the whitelist.", ephemeral=True)
except ValueError:
await interaction.response.send_message("Invalid user ID. Please provide a valid Discord user ID.", ephemeral=True)
@tree.command(name="blacklist_add", description="Add a user to the bot blacklist")
@app_commands.describe(user_id="The Discord user ID to blacklist")
async def blacklist_add(interaction: discord.Interaction, user_id: str):
"""Adds a user to the bot blacklist."""
if not await db_handler.is_admin(interaction.user.id):
await interaction.response.send_message("You don't have permission to use this command. Only admin can use blacklist commands.", ephemeral=True)
return
try:
user_id = int(user_id)
if await db_handler.is_admin(user_id):
await interaction.response.send_message("Cannot blacklist an admin.", ephemeral=True)
return
await db_handler.add_user_to_blacklist(user_id)
await interaction.response.send_message(f"User {user_id} has been added to the bot blacklist. They can no longer use any bot features.", ephemeral=True)
except ValueError:
await interaction.response.send_message("Invalid user ID. Please provide a valid Discord user ID.", ephemeral=True)
@tree.command(name="blacklist_remove", description="Remove a user from the bot blacklist")
@app_commands.describe(user_id="The Discord user ID to remove from blacklist")
async def blacklist_remove(interaction: discord.Interaction, user_id: str):
"""Removes a user from the bot blacklist."""
if not await db_handler.is_admin(interaction.user.id):
await interaction.response.send_message("You don't have permission to use this command. Only admin can use blacklist commands.", ephemeral=True)
return
try:
user_id = int(user_id)
if await db_handler.remove_user_from_blacklist(user_id):
await interaction.response.send_message(f"User {user_id} has been removed from the bot blacklist. They can now use bot features again.", ephemeral=True)
else:
await interaction.response.send_message(f"User {user_id} was not found in the blacklist.", ephemeral=True)
except ValueError:
await interaction.response.send_message("Invalid user ID. Please provide a valid Discord user ID.", ephemeral=True)
# Helper function to stop user tasks
async def stop_user_tasks(user_id: int):
"""Stop all tasks for a specific user."""
if user_id in user_tasks:
for task in user_tasks[user_id]:
task.cancel()
user_tasks[user_id] = []
# Clear any queued requests
if user_id in user_requests:
while not user_requests[user_id]['queue'].empty():
    try:
        user_requests[user_id]['queue'].get_nowait()
    except asyncio.QueueEmpty:
        # Another consumer drained the queue between empty() and get_nowait()
        break
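For context, a hypothetical sketch of how tasks might be registered so this helper can cancel them later; the shapes of user_tasks and user_requests are assumptions inferred from how they are read above, not code from this commit:

import asyncio
from typing import Any, Dict, List

user_tasks: Dict[int, List[asyncio.Task]] = {}     # assumed shape
user_requests: Dict[int, Dict[str, Any]] = {}      # assumed shape: holds an asyncio.Queue per user

def register_user_task(user_id: int, coro) -> asyncio.Task:
    """Track a coroutine as a Task so stop_user_tasks can cancel it later."""
    task = asyncio.create_task(coro)
    user_tasks.setdefault(user_id, []).append(task)
    return task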

View File

@@ -1,265 +1,265 @@
from motor.motor_asyncio import AsyncIOMotorClient
from typing import List, Dict, Any, Optional
import functools
import asyncio
from datetime import datetime, timedelta
import logging
import re
class DatabaseHandler:
# Class-level cache for database results
_cache = {}
_cache_expiry = {}
_cache_lock = asyncio.Lock()
def __init__(self, mongodb_uri: str):
"""Initialize database connection with optimized settings"""
# Set up a connection pool with sensible timeouts
self.client = AsyncIOMotorClient(
mongodb_uri,
maxPoolSize=50,
minPoolSize=10,
maxIdleTimeMS=45000,
connectTimeoutMS=2000,
serverSelectionTimeoutMS=3000,
waitQueueTimeoutMS=1000,
retryWrites=True
)
self.db = self.client['chatgpt_discord_bot'] # Database name
# Collections
self.users_collection = self.db.users
self.history_collection = self.db.history
self.admin_collection = self.db.admin
self.blacklist_collection = self.db.blacklist
self.whitelist_collection = self.db.whitelist
self.logs_collection = self.db.logs
self.reminders_collection = self.db.reminders
logging.info("Database handler initialized")
# Helper for caching results
async def _get_cached_result(self, cache_key, fetch_func, expiry_seconds=60):
"""Get result from cache or execute fetch_func if not cached/expired"""
current_time = datetime.now()
# Check if we have a cached result that's still valid
async with self._cache_lock:
if (cache_key in self._cache and
cache_key in self._cache_expiry and
current_time < self._cache_expiry[cache_key]):
return self._cache[cache_key]
# Not in cache or expired, fetch new result
result = await fetch_func()
# Cache the new result
async with self._cache_lock:
self._cache[cache_key] = result
self._cache_expiry[cache_key] = current_time + timedelta(seconds=expiry_seconds)
return result
# User history methods
async def get_history(self, user_id: int) -> List[Dict[str, Any]]:
"""Get user conversation history with caching and filter expired image links"""
cache_key = f"history_{user_id}"
async def fetch_history():
user_data = await self.db.user_histories.find_one({'user_id': user_id})
if user_data and 'history' in user_data:
# Filter out expired image links
filtered_history = self._filter_expired_images(user_data['history'])
return filtered_history
return []
return await self._get_cached_result(cache_key, fetch_history, 30) # 30 second cache
def _filter_expired_images(self, history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Filter out image links that are older than 23 hours"""
current_time = datetime.now()
expiration_time = current_time - timedelta(hours=23)
filtered_history = []
for msg in history:
# Keep system messages unchanged
if msg.get('role') == 'system':
filtered_history.append(msg)
continue
# Check if message has 'content' field as a list (which may contain image URLs)
content = msg.get('content')
if isinstance(content, list):
# Filter content items
filtered_content = []
for item in content:
# Keep text items
if item.get('type') == 'text':
filtered_content.append(item)
# Check image items for timestamp
elif item.get('type') == 'image_url':
# If there's no timestamp or timestamp is newer than expiration time, keep it
timestamp = item.get('timestamp')
if not timestamp or datetime.fromisoformat(timestamp) > expiration_time:
filtered_content.append(item)
else:
logging.info(f"Filtering out expired image URL (added at {timestamp})")
# Update the message with filtered content
if filtered_content:
new_msg = dict(msg)
new_msg['content'] = filtered_content
filtered_history.append(new_msg)
else:
# If after filtering there's no content, add a placeholder text
new_msg = dict(msg)
new_msg['content'] = [{"type": "text", "text": "[Image content expired]"}]
filtered_history.append(new_msg)
else:
# For string content or other formats, keep as is
filtered_history.append(msg)
return filtered_history
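For reference, a minimal sketch of the message shape this filter expects, assuming handler is a DatabaseHandler instance; the per-item timestamp field is the bot's own convention used above, and the URL is a placeholder:

from datetime import datetime, timedelta

old_ts = (datetime.now() - timedelta(hours=30)).isoformat()
history = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": [
        {"type": "text", "text": "What is in this picture?"},
        # Added more than 23 hours ago, so the filter drops it
        {"type": "image_url", "image_url": {"url": "https://example.com/img.png"},
         "timestamp": old_ts},
    ]},
]
filtered = handler._filter_expired_images(history)
# The expired image item is removed; the text item survives, so no
# "[Image content expired]" placeholder is needed in this case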
async def save_history(self, user_id: int, history: List[Dict[str, Any]]) -> None:
"""Save user conversation history and update cache"""
await self.db.user_histories.update_one(
{'user_id': user_id},
{'$set': {'history': history}},
upsert=True
)
# Update the cache
cache_key = f"history_{user_id}"
async with self._cache_lock:
self._cache[cache_key] = history
self._cache_expiry[cache_key] = datetime.now() + timedelta(seconds=30)
# User model preferences with caching
async def get_user_model(self, user_id: int) -> Optional[str]:
"""Get user's preferred model with caching"""
cache_key = f"model_{user_id}"
async def fetch_model():
user_data = await self.db.user_models.find_one({'user_id': user_id})
return user_data['model'] if user_data else None
return await self._get_cached_result(cache_key, fetch_model, 300) # 5 minute cache
async def save_user_model(self, user_id: int, model: str) -> None:
"""Save user's preferred model and update cache"""
await self.db.user_models.update_one(
{'user_id': user_id},
{'$set': {'model': model}},
upsert=True
)
# Update the cache
cache_key = f"model_{user_id}"
async with self._cache_lock:
self._cache[cache_key] = model
self._cache_expiry[cache_key] = datetime.now() + timedelta(seconds=300)
# Admin and permissions management with caching
async def is_admin(self, user_id: int) -> bool:
"""Check if the user is an admin (no caching for security)"""
admin_id = str(user_id) # Convert to string for comparison
from src.config.config import ADMIN_ID
return admin_id == ADMIN_ID
async def is_user_whitelisted(self, user_id: int) -> bool:
"""Check if the user is whitelisted with caching"""
if await self.is_admin(user_id):
return True
cache_key = f"whitelist_{user_id}"
async def check_whitelist():
user_data = await self.db.whitelist.find_one({'user_id': user_id})
return user_data is not None
return await self._get_cached_result(cache_key, check_whitelist, 300) # 5 minute cache
async def add_user_to_whitelist(self, user_id: int) -> None:
"""Add user to whitelist and update cache"""
await self.db.whitelist.update_one(
{'user_id': user_id},
{'$set': {'user_id': user_id}},
upsert=True
)
# Update the cache
cache_key = f"whitelist_{user_id}"
async with self._cache_lock:
self._cache[cache_key] = True
self._cache_expiry[cache_key] = datetime.now() + timedelta(seconds=300)
async def remove_user_from_whitelist(self, user_id: int) -> bool:
"""Remove user from whitelist and update cache"""
result = await self.db.whitelist.delete_one({'user_id': user_id})
# Update the cache
cache_key = f"whitelist_{user_id}"
async with self._cache_lock:
self._cache[cache_key] = False
self._cache_expiry[cache_key] = datetime.now() + timedelta(seconds=300)
return result.deleted_count > 0
async def is_user_blacklisted(self, user_id: int) -> bool:
"""Check if the user is blacklisted with caching"""
cache_key = f"blacklist_{user_id}"
async def check_blacklist():
user_data = await self.db.blacklist.find_one({'user_id': user_id})
return user_data is not None
return await self._get_cached_result(cache_key, check_blacklist, 300) # 5 minute cache
async def add_user_to_blacklist(self, user_id: int) -> None:
"""Add user to blacklist and update cache"""
await self.db.blacklist.update_one(
{'user_id': user_id},
{'$set': {'user_id': user_id}},
upsert=True
)
# Update the cache
cache_key = f"blacklist_{user_id}"
async with self._cache_lock:
self._cache[cache_key] = True
self._cache_expiry[cache_key] = datetime.now() + timedelta(seconds=300)
async def remove_user_from_blacklist(self, user_id: int) -> bool:
"""Remove user from blacklist and update cache"""
result = await self.db.blacklist.delete_one({'user_id': user_id})
# Update the cache
cache_key = f"blacklist_{user_id}"
async with self._cache_lock:
self._cache[cache_key] = False
self._cache_expiry[cache_key] = datetime.now() + timedelta(seconds=300)
return result.deleted_count > 0
# Connection management and cleanup
async def create_indexes(self):
"""Create indexes for better query performance"""
await self.db.user_histories.create_index("user_id")
await self.db.user_models.create_index("user_id")
await self.db.whitelist.create_index("user_id")
await self.db.blacklist.create_index("user_id")
async def ensure_reminders_collection(self):
"""
Ensure the reminders collection exists and create necessary indexes
"""
# Create the collection if it doesn't exist
await self.reminders_collection.create_index([("user_id", 1), ("sent", 1)])
await self.reminders_collection.create_index([("remind_at", 1), ("sent", 1)])
logging.info("Ensured reminders collection and indexes")
async def close(self):
"""Properly close the database connection"""
self.client.close()
logging.info("Database connection closed")

File diff suppressed because it is too large.

File diff suppressed because it is too large.

View File

@@ -1,291 +1,532 @@
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import io
import logging
import asyncio
import functools
import os
import numpy as np
from datetime import datetime
from typing import Tuple, Dict, Any, Optional, List
# Ensure matplotlib doesn't require a GUI backend
matplotlib.use('Agg')
async def process_data_file(file_bytes: bytes, filename: str, query: str) -> Tuple[str, Optional[bytes], Optional[Dict[str, Any]]]:
"""
Analyze and visualize data from CSV/Excel files.
Args:
file_bytes: File content as bytes
filename: File name
query: User command/query
Returns:
Tuple containing text summary, image bytes (if any) and metadata
"""
try:
# Use thread pool to avoid blocking event loop with CPU-bound tasks
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None,
functools.partial(_process_data_file_sync, file_bytes, filename, query)
)
except Exception as e:
logging.error(f"Error processing data file: {str(e)}")
return f"Error processing file {filename}: {str(e)}", None, None
def _process_data_file_sync(file_bytes: bytes, filename: str, query: str) -> Tuple[str, Optional[bytes], Optional[Dict[str, Any]]]:
"""Synchronous version of process_data_file to run in thread pool"""
file_obj = io.BytesIO(file_bytes)
try:
# Read file based on format with improved error handling
if filename.lower().endswith('.csv'):
try:
# Try UTF-8 first, falling back to Latin-1 below
df = pd.read_csv(file_obj, encoding='utf-8')
except UnicodeDecodeError:
# Reset file pointer and try different encoding
file_obj.seek(0)
df = pd.read_csv(file_obj, encoding='latin1')
elif filename.lower().endswith(('.xlsx', '.xls')):
try:
df = pd.read_excel(file_obj)
except Exception as excel_err:
logging.error(f"Excel read error: {excel_err}")
# Try with engine specification
file_obj.seek(0)
df = pd.read_excel(file_obj, engine='openpyxl')
else:
return "Unsupported file format. Please use CSV or Excel.", None, None
if df.empty:
return "The file does not contain any data.", None, None
# Clean column names
df.columns = [str(col).strip() for col in df.columns]
# Create metadata for dataframe
rows = len(df)
columns = len(df.columns)
column_names = list(df.columns)
# Data preprocessing for better analysis
# Convert potential date columns
for col in df.columns:
# Try to convert columns that might be dates but are stored as strings
if df[col].dtype == 'object':
try:
# Check if the column might contain dates
sample = df[col].dropna().iloc[0] if not df[col].dropna().empty else None
if sample and isinstance(sample, str) and ('/' in sample or '-' in sample):
df[col] = pd.to_datetime(df[col])  # errors='ignore' is deprecated; the except below leaves non-date columns unchanged
except Exception:
pass # Skip if conversion fails
# Format data for analysis
numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
date_cols = df.select_dtypes(include=['datetime']).columns.tolist()
# Create basic information about the data
summary = f"Data analysis for {filename}:\n"
summary += f"- Rows: {rows}\n"
summary += f"- Columns: {columns}\n"
summary += f"- Column names: {', '.join(column_names)}\n\n"
# Basic descriptive statistics
if len(numeric_cols) > 0:
summary += "Statistics for numeric data:\n"
desc_stats = df[numeric_cols].describe().round(2)
summary += desc_stats.to_string() + "\n\n"
# Statistics for categorical data
if len(categorical_cols) > 0:
summary += "Value distribution for categorical columns:\n"
for col in categorical_cols[:3]: # Limit displayed columns
value_counts = df[col].value_counts().head(5)
summary += f"{col}: {dict(value_counts)}\n"
if len(categorical_cols) > 3:
summary += f"...and {len(categorical_cols) - 3} other categorical columns.\n"
summary += "\n"
# Determine if a chart should be created
chart_keywords = ["chart", "graph", "plot", "visualization", "visualize",
"histogram", "bar chart", "line chart", "pie chart", "scatter"]
create_chart = any(keyword in query.lower() for keyword in chart_keywords)
# Metadata to return
metadata = {
"filename": filename,
"rows": rows,
"columns": columns,
"column_names": column_names,
"numeric_columns": numeric_cols,
"categorical_columns": categorical_cols,
"date_columns": date_cols
}
# Create chart if requested or by default
chart_image = None
if (create_chart or len(query) < 10) and len(numeric_cols) > 0: # Default to chart for short queries
plt.figure(figsize=(10, 6))
# Better chart styling
plt.style.use('seaborn-v0_8')
# Determine chart type based on keywords in the query
if any(keyword in query.lower() for keyword in ["pie", "circle"]):
# Pie chart - works best with categorical data
if len(categorical_cols) > 0 and len(numeric_cols) > 0:
cat_col = categorical_cols[0]
num_col = numeric_cols[0]
# Better handling of pie chart data
top_categories = df.groupby(cat_col)[num_col].sum().nlargest(5)
# Add "Other" category if there are more than 5 categories
if len(df[cat_col].unique()) > 5:
other_sum = df.groupby(cat_col)[num_col].sum().sum() - top_categories.sum()
if other_sum > 0:
top_categories["Other"] = other_sum
plt.figure(figsize=(10, 7))
plt.pie(top_categories, labels=top_categories.index, autopct='%1.1f%%',
shadow=True, startangle=90)
plt.axis('equal') # Equal aspect ratio ensures that pie is drawn as a circle
plt.title(f"Pie Chart: {num_col} by {cat_col}")
elif any(keyword in query.lower() for keyword in ["bar", "column"]):
# Bar chart
if len(categorical_cols) > 0 and len(numeric_cols) > 0:
cat_col = categorical_cols[0]
num_col = numeric_cols[0]
# Sort for better visualization
top_values = df.groupby(cat_col)[num_col].sum().nlargest(10)
top_values.plot.bar(color='skyblue', edgecolor='black')
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.title(f"Bar Chart: {num_col} by {cat_col} (Top 10)")
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
else:
df[numeric_cols[0]].nlargest(10).plot.bar(color='skyblue', edgecolor='black')
plt.title(f"Top 10 highest values of {numeric_cols[0]}")
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.tight_layout()
elif any(keyword in query.lower() for keyword in ["scatter", "dispersion"]):
# Scatter plot
if len(numeric_cols) >= 2:
plt.figure(figsize=(9, 6))
plt.scatter(df[numeric_cols[0]], df[numeric_cols[1]], alpha=0.6,
edgecolor='w', s=50)
plt.xlabel(numeric_cols[0])
plt.ylabel(numeric_cols[1])
plt.title(f"Scatter plot: {numeric_cols[0]} vs {numeric_cols[1]}")
# Add trend line if there seems to be a correlation
if abs(df[numeric_cols[0]].corr(df[numeric_cols[1]])) > 0.3:
# Drop rows where either column is NaN so both fit arrays stay aligned
clean = df[[numeric_cols[0], numeric_cols[1]]].dropna()
z = np.polyfit(clean[numeric_cols[0]], clean[numeric_cols[1]], 1)
p = np.poly1d(z)
x_sorted = clean[numeric_cols[0]].sort_values()
plt.plot(x_sorted, p(x_sorted), "r--", linewidth=1)
plt.grid(True, alpha=0.3)
elif any(keyword in query.lower() for keyword in ["histogram", "hist", "distribution"]):
# Histogram with better binning
plt.figure(figsize=(10, 6))
# Calculate optimal number of bins using Sturges' rule
data = df[numeric_cols[0]].dropna()
bins = int(np.ceil(np.log2(len(data))) + 1) if len(data) > 0 else 10
plt.hist(data, bins=min(bins, 30), color='skyblue', edgecolor='black')
plt.title(f"Distribution of {numeric_cols[0]}")
plt.xlabel(numeric_cols[0])
plt.ylabel("Frequency")
plt.grid(axis='y', linestyle='--', alpha=0.7)
elif len(date_cols) > 0 and len(numeric_cols) > 0:
# Time series chart if we have dates and numeric data
plt.figure(figsize=(12, 6))
date_col = date_cols[0]
num_col = numeric_cols[0]
# Sort by date and plot
temp_df = df[[date_col, num_col]].dropna().sort_values(date_col)
plt.plot(temp_df[date_col], temp_df[num_col], marker='o', markersize=3,
linestyle='-', linewidth=1)
plt.title(f"{num_col} over time")
plt.xticks(rotation=45)
plt.grid(True, alpha=0.3)
plt.tight_layout()
else:
# Default: line chart if we have multiple numeric values
if len(numeric_cols) > 0:
plt.figure(figsize=(10, 6))
df[numeric_cols[0]].plot(color='#1f77b4', alpha=0.8)
plt.title(f"Line chart for {numeric_cols[0]}")
plt.grid(True, alpha=0.3)
plt.tight_layout()
# Save chart to bytes buffer
buf = io.BytesIO()
plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
buf.seek(0)
chart_image = buf.read()
plt.close()
# Create a timestamp for the chart file
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
chart_filename = f"chart_{timestamp}.png"
# Save chart to temporary file
chart_dir = os.path.join(os.getcwd(), "temp_charts")
if not os.path.exists(chart_dir):
os.makedirs(chart_dir)
chart_path = os.path.join(chart_dir, chart_filename)
with open(chart_path, "wb") as f:
f.write(chart_image)
summary += f"Chart created based on the data with improved visualization."
# Add chart filename to metadata
metadata["chart_filename"] = chart_filename
metadata["chart_path"] = chart_path
metadata["chart_created_at"] = datetime.now().timestamp()
return summary, chart_image, metadata
except Exception as e:
logging.error(f"Error in _process_data_file_sync: {str(e)}")
return f"Could not analyze file {filename}. Error: {str(e)}", None, None
async def cleanup_old_charts(max_age_hours=1):
"""
Clean up chart images older than the specified time
Args:
max_age_hours: Maximum age in hours before deleting charts
"""
try:
chart_dir = os.path.join(os.getcwd(), "temp_charts")
if not os.path.exists(chart_dir):
return
now = datetime.now().timestamp()
deleted_count = 0
for filename in os.listdir(chart_dir):
if filename.startswith("chart_") and filename.endswith(".png"):
file_path = os.path.join(chart_dir, filename)
file_modified_time = os.path.getmtime(file_path)
# If file is older than max_age_hours
if now - file_modified_time > (max_age_hours * 3600):
try:
os.remove(file_path)
deleted_count += 1
except Exception as e:
logging.error(f"Error deleting chart file {filename}: {str(e)}")
if deleted_count > 0:
logging.info(f"Cleaned up {deleted_count} chart files older than {max_age_hours} hours")
except Exception as e:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import io
import logging
import asyncio
import functools
import os
import numpy as np
from datetime import datetime
from typing import Tuple, Dict, Any, Optional, List
# Ensure matplotlib doesn't require a GUI backend
matplotlib.use('Agg')
# Set global matplotlib parameters for better readability
plt.rcParams.update({
'font.size': 12,
'axes.titlesize': 16,
'axes.labelsize': 14,
'xtick.labelsize': 12,
'ytick.labelsize': 12,
'legend.fontsize': 12,
'figure.titlesize': 18
})
async def process_data_file(file_bytes: bytes, filename: str, query: str) -> Tuple[str, Optional[bytes], Optional[Dict[str, Any]]]:
"""
Analyze and visualize data from CSV/Excel files.
Args:
file_bytes: File content as bytes
filename: File name
query: User command/query
Returns:
Tuple containing text summary, image bytes (if any) and metadata
"""
try:
# Use thread pool to avoid blocking event loop with CPU-bound tasks
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None,
functools.partial(_process_data_file_sync, file_bytes, filename, query)
)
except Exception as e:
logging.error(f"Error processing data file: {str(e)}")
return f"Error processing file {filename}: {str(e)}", None, None
def _process_data_file_sync(file_bytes: bytes, filename: str, query: str) -> Tuple[str, Optional[bytes], Optional[Dict[str, Any]]]:
"""Synchronous version of process_data_file to run in thread pool"""
file_obj = io.BytesIO(file_bytes)
try:
# Read file based on format with improved error handling
if filename.lower().endswith('.csv'):
try:
# Try UTF-8 first, falling back to Latin-1 below
df = pd.read_csv(file_obj, encoding='utf-8')
except UnicodeDecodeError:
# Reset file pointer and try different encoding
file_obj.seek(0)
df = pd.read_csv(file_obj, encoding='latin1')
elif filename.lower().endswith(('.xlsx', '.xls')):
try:
df = pd.read_excel(file_obj)
except Exception as excel_err:
logging.error(f"Excel read error: {excel_err}")
# Try with engine specification
file_obj.seek(0)
df = pd.read_excel(file_obj, engine='openpyxl')
else:
return "Unsupported file format. Please use CSV or Excel.", None, None
if df.empty:
return "The file does not contain any data.", None, None
# Clean column names
df.columns = [str(col).strip() for col in df.columns]
# Create metadata for dataframe
rows = len(df)
columns = len(df.columns)
column_names = list(df.columns)
# Data preprocessing for better analysis
# Convert potential date columns
for col in df.columns:
# Try to convert columns that might be dates but are stored as strings
if df[col].dtype == 'object':
try:
# Check if the column might contain dates
sample = df[col].dropna().iloc[0] if not df[col].dropna().empty else None
if sample and isinstance(sample, str) and ('/' in sample or '-' in sample):
df[col] = pd.to_datetime(df[col])  # errors='ignore' is deprecated; the except below leaves non-date columns unchanged
except Exception:
pass # Skip if conversion fails
# Format data for analysis
numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
date_cols = df.select_dtypes(include=['datetime']).columns.tolist()
# Create basic information about the data
summary = f"Data analysis for {filename}:\n"
summary += f"- Rows: {rows}\n"
summary += f"- Columns: {columns}\n"
summary += f"- Column names: {', '.join(column_names)}\n\n"
# Basic descriptive statistics
if len(numeric_cols) > 0:
summary += "Statistics for numeric data:\n"
desc_stats = df[numeric_cols].describe().round(2)
summary += desc_stats.to_string() + "\n\n"
# Statistics for categorical data
if len(categorical_cols) > 0:
summary += "Value distribution for categorical columns:\n"
for col in categorical_cols[:3]: # Limit displayed columns
value_counts = df[col].value_counts().head(5)
summary += f"{col}: {dict(value_counts)}\n"
if len(categorical_cols) > 3:
summary += f"...and {len(categorical_cols) - 3} other categorical columns.\n"
summary += "\n"
# Determine if a chart should be created
chart_keywords = ["chart", "graph", "plot", "visualization", "visualize",
"histogram", "bar chart", "line chart", "pie chart", "scatter"]
create_chart = any(keyword in query.lower() for keyword in chart_keywords)
# Metadata to return
metadata = {
"filename": filename,
"rows": rows,
"columns": columns,
"column_names": column_names,
"numeric_columns": numeric_cols,
"categorical_columns": categorical_cols,
"date_columns": date_cols
}
# Create chart if requested or by default
chart_image = None
if (create_chart or len(query) < 10) and len(numeric_cols) > 0: # Default to chart for short queries
# Use a better visual style
plt.style.use('ggplot')
# Determine chart type based on keywords in the query
if any(keyword in query.lower() for keyword in ["pie", "circle"]):
# Pie chart - works best with categorical data
if len(categorical_cols) > 0 and len(numeric_cols) > 0:
cat_col = categorical_cols[0]
num_col = numeric_cols[0]
# Better handling of pie chart data - limit to top 6 categories max
top_categories = df.groupby(cat_col)[num_col].sum().nlargest(6)
# Add "Other" category if there are more than 6 categories
if len(df[cat_col].unique()) > 6:
other_sum = df.groupby(cat_col)[num_col].sum().sum() - top_categories.sum()
if other_sum > 0:
top_categories["Other"] = other_sum
# Larger figure size for better readability
plt.figure(figsize=(12, 9))
# Enhanced pie chart
wedges, texts, autotexts = plt.pie(
top_categories,
labels=None, # We'll add a legend instead of cluttering the pie
autopct='%1.1f%%',
shadow=False,
startangle=90,
explode=[0.05] * len(top_categories), # Slight separation for visibility
textprops={'color': 'white', 'weight': 'bold', 'fontsize': 14},
wedgeprops={'width': 0.6, 'edgecolor': 'white', 'linewidth': 2}
)
# Make the percentage labels more visible
for autotext in autotexts:
autotext.set_fontsize(12)
autotext.set_weight('bold')
# Add a legend outside the pie for better readability
plt.legend(
wedges,
top_categories.index,
title=cat_col,
loc="center left",
bbox_to_anchor=(1, 0, 0.5, 1)
)
plt.axis('equal') # Equal aspect ratio ensures that pie is drawn as a circle
plt.title(f"Distribution of {num_col} by {cat_col}", pad=20)
plt.tight_layout()
elif any(keyword in query.lower() for keyword in ["bar", "column"]):
# Bar chart
if len(categorical_cols) > 0 and len(numeric_cols) > 0:
cat_col = categorical_cols[0]
num_col = numeric_cols[0]
# Determine optimal figure size based on number of categories
category_count = min(10, len(df[cat_col].unique()))
fig_height = max(6, category_count * 0.4 + 4) # Dynamic height based on categories
plt.figure(figsize=(12, fig_height))
# Sort for better visualization
top_values = df.groupby(cat_col)[num_col].sum().nlargest(10)
# Use a horizontal bar chart for better label readability with many categories
if len(top_values) > 5:
ax = top_values.plot.barh(
color='#5975a4',
edgecolor='#344e7a',
linewidth=1.5
)
# Add data labels at the end of each bar
for i, v in enumerate(top_values):
ax.text(v * 1.01, i, f'{v:,.1f}', va='center', fontweight='bold')
plt.xlabel(num_col)
plt.ylabel(cat_col)
else:
# For fewer categories, use vertical bars
ax = top_values.plot.bar(
color='#5975a4',
edgecolor='#344e7a',
linewidth=1.5
)
# Add data labels on top of each bar
for i, v in enumerate(top_values):
ax.text(i, v * 1.01, f'{v:,.1f}', ha='center', fontweight='bold')
plt.ylabel(num_col)
plt.xlabel(cat_col)
plt.xticks(rotation=30, ha='right')
plt.grid(axis='both', linestyle='--', alpha=0.7)
plt.title(f"{num_col} by {cat_col} (Top {len(top_values)})", pad=20)
plt.tight_layout(pad=2)
else:
# Improved bar chart for numeric data only
plt.figure(figsize=(12, 7))
top_values = df[numeric_cols[0]].nlargest(10)
ax = top_values.plot.bar(
color='#5975a4',
edgecolor='#344e7a',
linewidth=1.5
)
# Add value labels on top of bars
for i, v in enumerate(top_values):
ax.text(i, v * 1.01, f'{v:,.1f}', ha='center', fontweight='bold')
plt.title(f"Top {len(top_values)} highest values of {numeric_cols[0]}", pad=20)
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.xticks(rotation=30, ha='right')
plt.tight_layout(pad=2)
elif any(keyword in query.lower() for keyword in ["scatter", "dispersion"]):
# Enhanced scatter plot
if len(numeric_cols) >= 2:
plt.figure(figsize=(12, 8))
# If we have a categorical column, use it for coloring
if len(categorical_cols) > 0:
# Limit to a reasonable number of categories for coloring
cat_col = categorical_cols[0]
top_cats = df[cat_col].value_counts().nlargest(8).index.tolist()
# Create a color map
colormap = matplotlib.colormaps['tab10']  # plt.cm.get_cmap was removed in matplotlib 3.9
# Plot each category with different color
for i, category in enumerate(top_cats):
subset = df[df[cat_col] == category]
plt.scatter(
subset[numeric_cols[0]],
subset[numeric_cols[1]],
alpha=0.7,
edgecolor='w',
s=80,
label=category,
color=colormap(i)
)
plt.legend(title=cat_col, loc='best')
else:
# Regular scatter plot with improved visibility
scatter = plt.scatter(
df[numeric_cols[0]],
df[numeric_cols[1]],
alpha=0.7,
edgecolor='w',
s=80,
c=df[numeric_cols[0]], # Color by x-axis value for visual enhancement
cmap='viridis'
)
plt.colorbar(scatter, label=numeric_cols[0])
plt.xlabel(numeric_cols[0], fontweight='bold')
plt.ylabel(numeric_cols[1], fontweight='bold')
plt.title(f"Scatter plot: {numeric_cols[0]} vs {numeric_cols[1]}", pad=20)
# Add trend line if there seems to be a correlation
if abs(df[numeric_cols[0]].corr(df[numeric_cols[1]])) > 0.3:
# Drop rows where either column is NaN so both fit arrays stay aligned
clean = df[[numeric_cols[0], numeric_cols[1]]].dropna()
z = np.polyfit(clean[numeric_cols[0]], clean[numeric_cols[1]], 1)
p = np.poly1d(z)
x_sorted = clean[numeric_cols[0]].sort_values()
plt.plot(
    x_sorted,
    p(x_sorted),
    "r--",
    linewidth=2,
    label=f"Trend line (r={clean[numeric_cols[0]].corr(clean[numeric_cols[1]]):.2f})"
)
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout(pad=2)
elif any(keyword in query.lower() for keyword in ["histogram", "hist", "distribution"]):
# Enhanced histogram
plt.figure(figsize=(12, 7))
# Calculate optimal number of bins
data = df[numeric_cols[0]].dropna()
# Better bin calculation based on data distribution (Freedman-Diaconis rule)
if len(data) > 0:
    iqr = np.percentile(data, 75) - np.percentile(data, 25)
    bin_width = 2 * iqr / (len(data) ** (1 / 3))
else:
    bin_width = 0  # Guard against empty data to avoid a ZeroDivisionError
if bin_width > 0:
    bins = int((data.max() - data.min()) / bin_width)
    bins = min(max(bins, 10), 50)  # Between 10 and 50 bins
else:
    bins = 15  # Default if calculation fails
# Plot histogram with KDE
ax = plt.subplot(111)
n, bins_arr, patches = ax.hist(
data,
bins=bins,
alpha=0.7,
color='#5975a4',
edgecolor='#344e7a',
linewidth=1.5,
density=True # Normalize for KDE overlay
)
# Add KDE line for smoother visualization (skipped when the data is
# too small or constant for a kernel density estimate)
from scipy import stats
try:
    kde = stats.gaussian_kde(data)
    kde_x = np.linspace(data.min(), data.max(), 1000)
    ax.plot(kde_x, kde(kde_x), 'r-', linewidth=2, label='Density')
except (ValueError, np.linalg.LinAlgError):
    pass
# Add vertical lines for key statistics
mean_val = data.mean()
median_val = data.median()
ax.axvline(mean_val, color='green', linestyle='--', linewidth=2, label=f'Mean: {mean_val:.2f}')
ax.axvline(median_val, color='orange', linestyle='-.', linewidth=2, label=f'Median: {median_val:.2f}')
plt.title(f"Distribution of {numeric_cols[0]}", pad=20)
plt.xlabel(numeric_cols[0], fontweight='bold')
plt.ylabel("Frequency", fontweight='bold')
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.legend()
plt.tight_layout(pad=2)
elif len(date_cols) > 0 and len(numeric_cols) > 0:
# Enhanced time series chart
plt.figure(figsize=(14, 8))
date_col = date_cols[0]
num_col = numeric_cols[0]
# Sort by date and plot
temp_df = df[[date_col, num_col]].dropna().sort_values(date_col)
# Limit number of points for readability if too many
if len(temp_df) > 100:
# Resample to reduce point density
temp_df = temp_df.set_index(date_col)
# Determine appropriate frequency based on date range
date_range = (temp_df.index.max() - temp_df.index.min()).days
if date_range > 365*2: # More than 2 years
freq = 'M' # Monthly
elif date_range > 90: # More than 3 months
freq = 'W' # Weekly
else:
freq = 'D' # Daily
temp_df = temp_df.resample(freq).mean().reset_index()
# Plot with enhanced styling
plt.plot(
temp_df[date_col],
temp_df[num_col],
marker='o',
markersize=6,
markerfacecolor='white',
markeredgecolor='#5975a4',
markeredgewidth=1.5,
linestyle='-',
linewidth=2,
color='#5975a4'
)
plt.title(f"{num_col} over time", pad=20)
plt.xlabel("Date", fontweight='bold')
plt.ylabel(num_col, fontweight='bold')
# Format x-axis date labels better
plt.gcf().autofmt_xdate()
plt.grid(True, alpha=0.3)
# Add trend line
try:
x = np.arange(len(temp_df))
z = np.polyfit(x, temp_df[num_col], 1)
p = np.poly1d(z)
plt.plot(temp_df[date_col], p(x), "r--", linewidth=2,
label=f"Trend line (slope: {z[0]:.4f})")
plt.legend()
except Exception:
pass # Skip trend line if it fails
plt.tight_layout(pad=2)
else:
# Default: enhanced line chart for numeric data
if len(numeric_cols) > 0:
plt.figure(figsize=(14, 8))
# Get the data
data = df[numeric_cols[0]]
# If too many points, bin or resample
if len(data) > 100:
# Use rolling average for smoother line
window = max(5, len(data) // 50) # Adaptive window size
rolling_data = data.rolling(window=window, center=True).mean()
# Plot both original and smoothed data
plt.plot(data.index, data, 'o', markersize=4, alpha=0.4, label='Original data')
plt.plot(
rolling_data.index,
rolling_data,
linewidth=3,
color='#d62728',
label=f'Moving average (window={window})'
)
plt.legend()
else:
# For fewer points, use a more detailed visualization
plt.plot(
data.index,
data,
marker='o',
markersize=6,
markerfacecolor='white',
markeredgecolor='#5975a4',
markeredgewidth=1.5,
linestyle='-',
linewidth=2,
color='#5975a4'
)
plt.title(f"Line chart for {numeric_cols[0]}", pad=20)
plt.ylabel(numeric_cols[0], fontweight='bold')
plt.grid(True, alpha=0.3)
plt.tight_layout(pad=2)
# Save chart to bytes buffer with higher DPI for better quality
buf = io.BytesIO()
plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
buf.seek(0)
chart_image = buf.read()
plt.close()
# Create a timestamp for the chart file
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
chart_filename = f"chart_{timestamp}.png"
# Save chart to temporary file
chart_dir = os.path.join(os.getcwd(), "temp_charts")
if not os.path.exists(chart_dir):
os.makedirs(chart_dir)
chart_path = os.path.join(chart_dir, chart_filename)
with open(chart_path, "wb") as f:
f.write(chart_image)
summary += f"Chart created based on the data with improved visualization."
# Add chart filename to metadata
metadata["chart_filename"] = chart_filename
metadata["chart_path"] = chart_path
metadata["chart_created_at"] = datetime.now().timestamp()
return summary, chart_image, metadata
except Exception as e:
logging.error(f"Error in _process_data_file_sync: {str(e)}")
return f"Could not analyze file {filename}. Error: {str(e)}", None, None
async def cleanup_old_charts(max_age_hours=1):
"""
Clean up chart images older than the specified time
Args:
max_age_hours: Maximum age in hours before deleting charts
"""
try:
chart_dir = os.path.join(os.getcwd(), "temp_charts")
if not os.path.exists(chart_dir):
return
now = datetime.now().timestamp()
deleted_count = 0
for filename in os.listdir(chart_dir):
if filename.startswith("chart_") and filename.endswith(".png"):
file_path = os.path.join(chart_dir, filename)
file_modified_time = os.path.getmtime(file_path)
# If file is older than max_age_hours
if now - file_modified_time > (max_age_hours * 3600):
try:
os.remove(file_path)
deleted_count += 1
except Exception as e:
logging.error(f"Error deleting chart file {filename}: {str(e)}")
if deleted_count > 0:
logging.info(f"Cleaned up {deleted_count} chart files older than {max_age_hours} hours")
except Exception as e:
logging.error(f"Error in cleanup_old_charts: {str(e)}")

View File

@@ -1,298 +1,312 @@
import json
import logging
import asyncio
from typing import List, Dict, Any, Tuple, Optional, Callable
def get_tools_for_model() -> List[Dict[str, Any]]:
"""
Returns the list of tools available to the model.
Returns:
List of tool objects
"""
return [
{
"type": "function",
"function": {
"name": "google_search",
"description": "Search the web for current information. Use this when you need to answer questions about current events or recent information.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The search query"
},
"num_results": {
"type": "integer",
"description": "The number of search results to return (1-10)",
"default": 3,
"minimum": 1,
"maximum": 10
}
},
"required": ["query"]
}
}
},
{
"type": "function",
"function": {
"name": "scrape_webpage",
"description": "Scrape and extract content from a webpage. Use this to get the content of a specific webpage.",
"parameters": {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The URL of the webpage to scrape"
}
},
"required": ["url"]
}
}
},
{
"type": "function",
"function": {
"name": "code_interpreter",
"description": "Run code in Python or other supported languages. Use this to execute code, perform calculations, generate plots, and analyze data.",
"parameters": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "The code to execute"
},
"language": {
"type": "string",
"description": "The programming language (default: python)",
"default": "python",
"enum": ["python", "javascript", "bash", "c++"]
},
"input": {
"type": "string",
"description": "Optional input data for the code"
}
},
"required": ["code"]
}
}
},
{
"type": "function",
"function": {
"name": "generate_image",
"description": "Generate images based on text prompts. Use this when the user asks for an image to be created.",
"parameters": {
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "The prompt describing the image to generate"
},
"num_images": {
"type": "integer",
"description": "The number of images to generate (1-4)",
"default": 1,
"minimum": 1,
"maximum": 4
}
},
"required": ["prompt"]
}
}
},
{
"type": "function",
"function": {
"name": "analyze_data",
"description": "Analyze data files (CSV, Excel) and create visualizations. Use this when users need to analyze data, create charts, or extract insights from their data files.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The query describing what analysis to perform on the data, including what type of chart to create (e.g. 'Create a histogram of ages', 'Show a pie chart of categories', 'Calculate average by group')"
},
"visualization_type": {
"type": "string",
"description": "The type of visualization to create",
"enum": ["bar", "line", "pie", "scatter", "histogram", "auto"],
"default": "auto"
}
},
"required": ["query"]
}
}
},
{
"type": "function",
"function": {
"name": "set_reminder",
"description": "Set a reminder for the user. Use this when a user wants to be reminded about something at a specific time.",
"parameters": {
"type": "object",
"properties": {
"content": {
"type": "string",
"description": "The content of the reminder"
},
"time": {
"type": "string",
"description": "The time for the reminder. Can be relative (e.g., '30m', '2h', '1d') or specific times ('tomorrow', '3:00pm', etc.)"
}
},
"required": ["content", "time"]
}
}
},
{
"type": "function",
"function": {
"name": "get_reminders",
"description": "Get a list of upcoming reminders for the user. Use this when user asks about their reminders.",
"parameters": {
"type": "object",
"properties": {},
"required": []
}
}
}
]
async def process_tool_calls(client, response, messages, tool_functions) -> Tuple[bool, List[Dict[str, Any]]]:
"""
Process and execute tool calls from the OpenAI API response.
Args:
client: OpenAI client
response: API response containing tool calls
messages: The current chat messages
tool_functions: Dictionary mapping tool names to handler functions
Returns:
Tuple containing (processed_any_tools, updated_messages)
"""
processed_any = False
tool_calls = response.choices[0].message.tool_calls
# Create a copy of the messages to update
updated_messages = messages.copy()
# Add the assistant message with the tool calls
updated_messages.append({
"role": "assistant",
"content": response.choices[0].message.content,
"tool_calls": [
{
"id": tc.id,
"type": tc.type,
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments
}
} for tc in tool_calls
] if tool_calls else None
})
# Process each tool call
for tool_call in (tool_calls or []):  # tool_calls may be None when the model requested no tools
function_name = tool_call.function.name
if function_name in tool_functions:
# Parse the JSON arguments
try:
function_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError:
logging.error(f"Invalid JSON in tool call arguments: {tool_call.function.arguments}")
function_args = {}
# Call the appropriate function
try:
function_response = await tool_functions[function_name](function_args)
# Add the tool output back to messages
updated_messages.append({
"tool_call_id": tool_call.id,
"role": "tool",
"name": function_name,
"content": str(function_response)
})
processed_any = True
except Exception as e:
error_message = f"Error executing {function_name}: {str(e)}"
logging.error(error_message)
# Add the error as tool output
updated_messages.append({
"tool_call_id": tool_call.id,
"role": "tool",
"name": function_name,
"content": error_message
})
processed_any = True
return processed_any, updated_messages
def count_tokens(text: str) -> int:
"""Estimate token count using a simple approximation."""
# Rough estimate: 1 word ≈ 1.3 tokens
return int(len(text.split()) * 1.3)
def trim_content_to_token_limit(content: str, max_tokens: int = 8096) -> str:
"""Trim content to stay within token limit while preserving the most recent content."""
current_tokens = count_tokens(content)
if current_tokens <= max_tokens:
return content
# Split into lines and start removing from the beginning until under limit
lines = content.split('\n')
while lines and count_tokens('\n'.join(lines)) > max_tokens:
lines.pop(0)
if not lines:  # If still too long, fall back to hard cuts from the front
    text = content
    while count_tokens(text) > max_tokens:
        cut = text.find('\n', 1000)
        # find() returns -1 when no newline remains; cut raw characters instead
        text = text[cut:] if cut != -1 else text[1000:]
    return text
return '\n'.join(lines)
def prepare_messages_for_api(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Prepare message history for the OpenAI API.
Args:
messages: List of message objects
Returns:
Prepared messages for API
"""
prepared_messages = []
for msg in messages:
# Create a copy of the message to avoid modifying the original
processed_msg = dict(msg)
# Handle image URLs with timestamps in content
if isinstance(processed_msg.get('content'), list):
# Strip the internal timestamp field from image items before sending to the API
new_content = []
for item in processed_msg['content']:
if item.get('type') == 'image_url' and 'timestamp' in item:
# Remove timestamp from API calls
new_item = dict(item)
if 'timestamp' in new_item:
del new_item['timestamp']
new_content.append(new_item)
else:
new_content.append(item)
processed_msg['content'] = new_content
prepared_messages.append(processed_msg)
return prepared_messages
import json
import logging
import asyncio
from typing import List, Dict, Any, Tuple, Optional, Callable
def get_tools_for_model() -> List[Dict[str, Any]]:
"""
Returns the list of tools available to the model.
Returns:
List of tool objects
"""
return [
{
"type": "function",
"function": {
"name": "google_search",
"description": "Search the web for current information. Use this when you need to answer questions about current events or recent information.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The search query"
},
"num_results": {
"type": "integer",
"description": "The number of search results to return (1-10)",
"default": 3,
"minimum": 1,
"maximum": 10
}
},
"required": ["query"]
}
}
},
{
"type": "function",
"function": {
"name": "scrape_webpage",
"description": "Scrape and extract content from a webpage. Use this to get the content of a specific webpage.",
"parameters": {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The URL of the webpage to scrape"
}
},
"required": ["url"]
}
}
},
{
"type": "function",
"function": {
"name": "code_interpreter",
"description": "Run code in Python or other supported languages. Use this to execute code, perform calculations, generate plots, and analyze data.",
"parameters": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "The code to execute"
},
"language": {
"type": "string",
"description": "The programming language (default: python)",
"default": "python",
"enum": ["python", "javascript", "bash", "c++"]
},
"input": {
"type": "string",
"description": "Optional input data for the code"
}
},
"required": ["code"]
}
}
},
{
"type": "function",
"function": {
"name": "generate_image",
"description": "Generate images based on text prompts. Use this when the user asks for an image to be created.",
"parameters": {
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "The prompt describing the image to generate"
},
"num_images": {
"type": "integer",
"description": "The number of images to generate (1-4)",
"default": 1,
"minimum": 1,
"maximum": 4
}
},
"required": ["prompt"]
}
}
},
{
"type": "function",
"function": {
"name": "analyze_data",
"description": "Analyze data files (CSV, Excel) and create visualizations. Use this when users need to analyze data, create charts, or extract insights from their data files.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The query describing what analysis to perform on the data, including what type of chart to create (e.g. 'Create a histogram of ages', 'Show a pie chart of categories', 'Calculate average by group')"
},
"visualization_type": {
"type": "string",
"description": "The type of visualization to create",
"enum": ["bar", "line", "pie", "scatter", "histogram", "auto"],
"default": "auto"
}
},
"required": ["query"]
}
}
},
{
"type": "function",
"function": {
"name": "set_reminder",
"description": "Set a reminder for the user. Use this when a user wants to be reminded about something at a specific time.",
"parameters": {
"type": "object",
"properties": {
"content": {
"type": "string",
"description": "The content of the reminder"
},
"time": {
"type": "string",
"description": "The time for the reminder. Can be relative (e.g., '30m', '2h', '1d') or specific times ('tomorrow', '3:00pm', etc.)"
}
},
"required": ["content", "time"]
}
}
},
{
"type": "function",
"function": {
"name": "get_reminders",
"description": "Get a list of upcoming reminders for the user. Use this when user asks about their reminders.",
"parameters": {
"type": "object",
"properties": {},
"required": []
}
}
}
]
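# Note: these definitions are passed verbatim as the `tools` parameter of
# client.chat.completions.create(...); a hypothetical end-to-end call is
# sketched at the bottom of this file.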
async def process_tool_calls(client, response, messages, tool_functions) -> Tuple[bool, List[Dict[str, Any]]]:
"""
Process and execute tool calls from the OpenAI API response.
Args:
client: OpenAI client
response: API response containing tool calls
messages: The current chat messages
tool_functions: Dictionary mapping tool names to handler functions
Returns:
Tuple containing (processed_any_tools, updated_messages)
"""
processed_any = False
tool_calls = response.choices[0].message.tool_calls
# Create a copy of the messages to update
updated_messages = messages.copy()
# Add the assistant message with the tool calls
updated_messages.append({
"role": "assistant",
"content": response.choices[0].message.content,
"tool_calls": [
{
"id": tc.id,
"type": tc.type,
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments
}
} for tc in tool_calls
] if tool_calls else None
})
    # Process each tool call (tool_calls may be None when the model replies without tools)
    for tool_call in tool_calls or []:
function_name = tool_call.function.name
if function_name in tool_functions:
# Parse the JSON arguments
try:
function_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError:
logging.error(f"Invalid JSON in tool call arguments: {tool_call.function.arguments}")
function_args = {}
# Call the appropriate function
try:
function_response = await tool_functions[function_name](function_args)
# Add the tool output back to messages
updated_messages.append({
"tool_call_id": tool_call.id,
"role": "tool",
"name": function_name,
"content": str(function_response)
})
processed_any = True
except Exception as e:
error_message = f"Error executing {function_name}: {str(e)}"
logging.error(error_message)
# Add the error as tool output
updated_messages.append({
"tool_call_id": tool_call.id,
"role": "tool",
"name": function_name,
"content": error_message
})
                processed_any = True
        else:
            # The API requires a tool message for every tool_call_id, so report
            # unknown tools instead of silently skipping them
            updated_messages.append({
                "tool_call_id": tool_call.id,
                "role": "tool",
                "name": function_name,
                "content": f"Error: tool '{function_name}' is not available."
            })
            processed_any = True
    return processed_any, updated_messages
def count_tokens(text: str) -> int:
"""Estimate token count using a simple approximation."""
# Rough estimate: 1 word ≈ 1.3 tokens
return int(len(text.split()) * 1.3)
def trim_content_to_token_limit(content: str, max_tokens: int = 8096) -> str:
"""Trim content to stay within token limit while preserving the most recent content."""
current_tokens = count_tokens(content)
if current_tokens <= max_tokens:
return content
# Split into lines and start removing from the beginning until under limit
lines = content.split('\n')
while lines and count_tokens('\n'.join(lines)) > max_tokens:
lines.pop(0)
    if not lines:  # Fallback: character-based trimming when no line split helps
        text = content
        while count_tokens(text) > max_tokens:
            cut = text.find('\n', 1000)
            if cut == -1:  # no newline past position 1000; drop a fixed chunk
                text = text[1000:]
            else:
                text = text[cut:]
        return text
return '\n'.join(lines)
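# Worked example of the heuristic above (illustrative only):
# count_tokens("alpha beta gamma delta") -> int(4 * 1.3) == 5,
# so a 4-word string survives trimming whenever max_tokens >= 5.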
def prepare_messages_for_api(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Prepare message history for the OpenAI API.
Args:
messages: List of message objects
Returns:
Prepared messages for API
"""
prepared_messages = []
# Check if there's a system message already
has_system_message = any(msg.get('role') == 'system' for msg in messages)
# If no system message exists, add a default one
if not has_system_message:
prepared_messages.append({
"role": "system",
"content": "You are a helpful AI assistant that can answer questions, provide information, and assist with various tasks."
})
for msg in messages:
# Skip messages with None content
if msg.get('content') is None:
continue
# Create a copy of the message to avoid modifying the original
processed_msg = dict(msg)
# Handle image URLs with timestamps in content
if isinstance(processed_msg.get('content'), list):
# Filter out images that have a timestamp (they're already handled specially)
new_content = []
for item in processed_msg['content']:
if item.get('type') == 'image_url' and 'timestamp' in item:
# Remove timestamp from API calls
new_item = dict(item)
if 'timestamp' in new_item:
del new_item['timestamp']
new_content.append(new_item)
else:
new_content.append(item)
processed_msg['content'] = new_content
prepared_messages.append(processed_msg)
return prepared_messages
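For orientation, a minimal sketch of how these helpers fit together in a chat loop; the client setup, model name, and handler mapping are illustrative assumptions, not part of this module:

import asyncio
from openai import AsyncOpenAI

async def _echo_search(args: dict) -> str:
    # Hypothetical stand-in for the real google_search handler
    return f"search results for {args.get('query')!r}"

async def demo() -> None:
    client = AsyncOpenAI()  # assumes OPENAI_API_KEY is set in the environment
    messages = prepare_messages_for_api([
        {"role": "user", "content": "Search the web for today's weather."}
    ])
    response = await client.chat.completions.create(
        model="gpt-4o",  # hypothetical model choice
        messages=messages,
        tools=get_tools_for_model(),
    )
    processed, messages = await process_tool_calls(
        client, response, messages, {"google_search": _echo_search}
    )
    if processed:
        # Feed the tool outputs back to the model for a final answer
        response = await client.chat.completions.create(
            model="gpt-4o", messages=messages
        )
    print(response.choices[0].message.content)

asyncio.run(demo())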

View File

@@ -1,282 +1,282 @@
import asyncio
import logging
import discord
from datetime import datetime, timedelta
import pytz
from typing import Dict, Any, List, Optional, Union
class ReminderManager:
"""
Manages reminder functionality for Discord users
"""
def __init__(self, bot, db_handler):
"""
Initialize ReminderManager
Args:
bot: Discord bot instance
db_handler: Database handler instance
"""
self.bot = bot
self.db = db_handler
self.running = False
self.check_task = None
# Get system timezone to ensure consistency
self.timezone = datetime.now().astimezone().tzinfo
logging.info(f"Using server timezone: {self.timezone}")
def start(self):
"""Start periodic reminder check"""
if not self.running:
self.running = True
self.check_task = asyncio.create_task(self._check_reminders_loop())
logging.info("Reminder manager started")
async def stop(self):
"""Stop the reminder check"""
if self.running:
self.running = False
if self.check_task:
self.check_task.cancel()
try:
await self.check_task
except asyncio.CancelledError:
pass
self.check_task = None
logging.info("Reminder manager stopped")
def get_current_time(self) -> datetime:
"""
Get the current time with proper timezone
Returns:
Current datetime with timezone
"""
        return datetime.now(self.timezone)
async def add_reminder(self, user_id: int, content: str, remind_at: datetime) -> Dict[str, Any]:
"""
Add a new reminder
Args:
user_id: Discord user ID
content: Reminder content
remind_at: When to send the reminder
Returns:
Information about the added reminder
"""
try:
now = self.get_current_time()
reminder = {
"user_id": user_id,
"content": content,
"remind_at": remind_at,
"created_at": now,
"sent": False
}
result = await self.db.reminders_collection.insert_one(reminder)
reminder["_id"] = result.inserted_id
logging.info(f"Added reminder for user {user_id} at {remind_at} (Server timezone: {self.timezone})")
return reminder
except Exception as e:
logging.error(f"Error adding reminder: {str(e)}")
raise
async def get_user_reminders(self, user_id: int) -> List[Dict[str, Any]]:
"""
Get a user's reminders
Args:
user_id: Discord user ID
Returns:
List of reminders
"""
try:
cursor = self.db.reminders_collection.find({
"user_id": user_id,
"sent": False
}).sort("remind_at", 1)
return await cursor.to_list(length=100)
except Exception as e:
logging.error(f"Error getting reminders for user {user_id}: {str(e)}")
return []
async def delete_reminder(self, reminder_id, user_id: int) -> bool:
"""
Delete a reminder
Args:
reminder_id: Reminder ID
user_id: Discord user ID (to verify ownership)
Returns:
True if deleted successfully, False otherwise
"""
try:
from bson.objectid import ObjectId
# Convert reminder_id to ObjectId if needed
if isinstance(reminder_id, str):
reminder_id = ObjectId(reminder_id)
result = await self.db.reminders_collection.delete_one({
"_id": reminder_id,
"user_id": user_id
})
return result.deleted_count > 0
except Exception as e:
logging.error(f"Error deleting reminder {reminder_id}: {str(e)}")
return False
async def _check_reminders_loop(self):
"""Loop to check for due reminders"""
try:
while self.running:
try:
await self._process_due_reminders()
await self._clean_expired_reminders()
except Exception as e:
logging.error(f"Error in reminder check: {str(e)}")
# Wait 30 seconds before checking again
await asyncio.sleep(30)
except asyncio.CancelledError:
# Handle task cancellation
logging.info("Reminder check loop was cancelled")
raise
async def _process_due_reminders(self):
"""Process due reminders and send notifications"""
now = self.get_current_time()
# Find due reminders
cursor = self.db.reminders_collection.find({
"remind_at": {"$lte": now},
"sent": False
})
due_reminders = await cursor.to_list(length=100)
for reminder in due_reminders:
try:
# Get user information
user_id = reminder["user_id"]
user = await self.bot.fetch_user(user_id)
if user:
# Format reminder message
embed = discord.Embed(
title="📅 Reminder",
description=reminder["content"],
color=discord.Color.blue()
)
embed.add_field(
name="Set on",
value=reminder["created_at"].strftime("%Y-%m-%d %H:%M")
)
embed.set_footer(text="Server time: " + now.strftime("%Y-%m-%d %H:%M"))
# Send reminder message with mention
try:
# Try to send a direct message first
await user.send(f"<@{user_id}> Here's your reminder:", embed=embed)
logging.info(f"Sent reminder DM to user {user_id}")
except Exception as dm_error:
logging.error(f"Could not send DM to user {user_id}: {str(dm_error)}")
# Could implement fallback method here if needed
# Mark reminder as sent and delete it
await self.db.reminders_collection.delete_one({"_id": reminder["_id"]})
logging.info(f"Deleted completed reminder {reminder['_id']} for user {user_id}")
except Exception as e:
logging.error(f"Error processing reminder {reminder['_id']}: {str(e)}")
async def _clean_expired_reminders(self):
"""Clean up old reminders that were marked as sent but not deleted"""
try:
result = await self.db.reminders_collection.delete_many({
"sent": True
})
if result.deleted_count > 0:
logging.info(f"Cleaned up {result.deleted_count} expired reminders")
except Exception as e:
logging.error(f"Error cleaning expired reminders: {str(e)}")
async def parse_time(self, time_str: str) -> Optional[datetime]:
"""
Parse a time string into a datetime object with the server's timezone
Args:
time_str: Time string (e.g., "30m", "2h", "1d", "tomorrow", "15:00")
Returns:
Datetime object or None if parsing fails
"""
now = self.get_current_time()
time_str = time_str.lower().strip()
try:
# Handle special keywords
if time_str == "tomorrow":
return now.replace(hour=9, minute=0, second=0) + timedelta(days=1)
elif time_str == "tonight":
# Use 8 PM (20:00) for "tonight"
target = now.replace(hour=20, minute=0, second=0)
# If it's already past 8 PM, schedule for tomorrow night
if target <= now:
target += timedelta(days=1)
return target
elif time_str == "noon":
# Use 12 PM for "noon"
target = now.replace(hour=12, minute=0, second=0)
# If it's already past noon, schedule for tomorrow
if target <= now:
target += timedelta(days=1)
return target
# Handle relative time formats (30m, 2h, 1d)
if time_str[-1] in ['m', 'h', 'd']:
value = int(time_str[:-1])
unit = time_str[-1]
if unit == 'm': # minutes
return now + timedelta(minutes=value)
elif unit == 'h': # hours
return now + timedelta(hours=value)
elif unit == 'd': # days
return now + timedelta(days=value)
# Handle specific time format
# HH:MM format for today or tomorrow
if ':' in time_str and len(time_str.split(':')) == 2:
hour, minute = map(int, time_str.split(':'))
# Check if valid time
if hour < 0 or hour > 23 or minute < 0 or minute > 59:
logging.warning(f"Invalid time format: {time_str}")
return None
# Create datetime for the specified time today
target = now.replace(hour=hour, minute=minute, second=0)
# If the time has already passed today, schedule for tomorrow
if target <= now:
target += timedelta(days=1)
logging.info(f"Parsed time '{time_str}' to {target} (Server timezone: {self.timezone})")
return target
return None
except Exception as e:
logging.error(f"Error parsing time string '{time_str}': {str(e)}")
return None
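A brief, hypothetical usage sketch; the bot and database handler are assumed to come from the application's startup code:

manager = ReminderManager(bot, db_handler)
manager.start()  # begins the 30-second polling loop

async def remind_me(user_id: int) -> None:
    remind_at = await manager.parse_time("2h")  # two hours from now
    if remind_at is not None:
        await manager.add_reminder(user_id, "Stand up and stretch", remind_at)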

View File

@@ -1,225 +1,225 @@
import requests
import json
import re
from bs4 import BeautifulSoup
from typing import Dict, List, Any, Optional, Tuple
from src.config.config import GOOGLE_API_KEY, GOOGLE_CX
import tiktoken # Add tiktoken for token counting
def google_custom_search(query: str, num_results: int = 5, max_tokens: int = 4000) -> dict:
"""
Perform a Google search using the Google Custom Search API and scrape content
until reaching token limit.
Args:
query (str): The search query
num_results (int): Number of results to return
max_tokens (int): Maximum number of tokens for combined scraped content
Returns:
dict: Search results with metadata and combined scraped content
"""
try:
search_url = f"https://www.googleapis.com/customsearch/v1"
params = {
'key': GOOGLE_API_KEY,
'cx': GOOGLE_CX,
'q': query,
'num': min(num_results, 10) # Google API maximum is 10
}
response = requests.get(search_url, params=params)
response.raise_for_status()
search_results = response.json()
# Format the results for ease of use
formatted_results = {
'query': query,
'results': [],
'combined_content': ""
}
if 'items' in search_results:
# Extract all links first
links = [item.get('link', '') for item in search_results['items']]
# Scrape content from multiple links up to max_tokens
combined_content, used_links = scrape_multiple_links(links, max_tokens)
formatted_results['combined_content'] = combined_content
# Process each search result
for item in search_results['items']:
result = {
'title': item.get('title', ''),
'link': item.get('link', ''),
'snippet': item.get('snippet', ''),
'date': item.get('pagemap', {}).get('metatags', [{}])[0].get('article:published_time', ''),
'used_for_content': item.get('link', '') in used_links
}
formatted_results['results'].append(result)
return formatted_results
except requests.exceptions.RequestException as e:
return {
'query': query,
'error': f"Error during Google search: {str(e)}",
'results': [],
'combined_content': ""
}
def scrape_multiple_links(urls: List[str], max_tokens: int = 4000) -> Tuple[str, List[str]]:
"""
Scrape content from multiple URLs, stopping once token limit is reached.
Args:
urls (List[str]): List of URLs to scrape
max_tokens (int): Maximum token count for combined content
Returns:
Tuple[str, List[str]]: Combined content and list of used URLs
"""
combined_content = ""
total_tokens = 0
used_urls = []
try:
encoding = tiktoken.get_encoding("cl100k_base")
    except Exception:  # avoid a bare except; fall back to character-based truncation
encoding = None
for url in urls:
# Skip empty URLs
if not url:
continue
# Get content from this URL
content, token_count = scrape_web_content_with_count(url, return_token_count=True)
# Skip failed scrapes
if content.startswith("Failed"):
continue
# Check if adding this content would exceed token limit
if total_tokens + token_count > max_tokens:
# If this is the first URL and it's too large, we need to truncate it
if not combined_content:
if encoding:
tokens = encoding.encode(content)
truncated_tokens = tokens[:max_tokens]
truncated_content = encoding.decode(truncated_tokens)
combined_content = f"{truncated_content}...\n[Content truncated due to token limit]"
else:
# Fallback to character-based truncation
truncated_content = content[:max_tokens * 4]
combined_content = f"{truncated_content}...\n[Content truncated due to length]"
used_urls.append(url)
break
# Add separator if not the first URL
if combined_content:
combined_content += f"\n\n--- Content from: {url} ---\n\n"
else:
combined_content += f"--- Content from: {url} ---\n\n"
# Add content and update token count
combined_content += content
total_tokens += token_count
used_urls.append(url)
# If we've reached the token limit, stop
if total_tokens >= max_tokens:
break
# If we didn't find any valid content
if not combined_content:
combined_content = "No valid content could be scraped from the provided URLs."
return combined_content, used_urls
def scrape_web_content_with_count(url: str, max_tokens: int = 4000, return_token_count: bool = False) -> Any:
"""
Scrape content from a webpage and return with token count if needed.
Args:
url (str): URL of the webpage to scrape
max_tokens (int): Maximum number of tokens to return
return_token_count (bool): Whether to return token count with the content
Returns:
str or tuple: The scraped text content or (content, token_count)
"""
if not url:
return ("Failed to scrape: No URL provided.", 0) if return_token_count else "Failed to scrape: No URL provided."
# Ignore URLs that are unlikely to be scrapable or might cause problems
if any(x in url.lower() for x in ['.pdf', '.zip', '.jpg', '.png', '.mp3', '.mp4', 'youtube.com', 'youtu.be']):
message = f"Failed to scrape: The URL {url} cannot be scraped (unsupported format)."
return (message, 0) if return_token_count else message
try:
# Add user agent to mimic a browser
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
}
response = requests.get(url, headers=headers, timeout=10)
response.raise_for_status()
# Parse the content with BeautifulSoup
soup = BeautifulSoup(response.text, 'html.parser')
# Remove script and style elements
for script in soup(["script", "style", "header", "footer", "nav"]):
script.extract()
# Get the text content
text = soup.get_text(separator='\n')
# Clean up text: remove extra whitespace and empty lines
lines = (line.strip() for line in text.splitlines())
text = '\n'.join(line for line in lines if line)
# Count tokens
token_count = 0
try:
# Use cl100k_base encoder which is used by most recent models
encoding = tiktoken.get_encoding("cl100k_base")
tokens = encoding.encode(text)
token_count = len(tokens)
# Truncate if token count exceeds max_tokens and we're not returning token count
if len(tokens) > max_tokens and not return_token_count:
truncated_tokens = tokens[:max_tokens]
text = encoding.decode(truncated_tokens)
text += "...\n[Content truncated due to token limit]"
        except Exception:
            # tiktoken is imported at module scope, so ImportError can't occur
            # here; any encoding failure falls back to character-based estimation
token_count = len(text) // 4 # Rough estimate: 1 token ≈ 4 characters
if len(text) > max_tokens * 4 and not return_token_count:
text = text[:max_tokens * 4] + "...\n[Content truncated due to length]"
if return_token_count:
return text, token_count
return text
except requests.exceptions.RequestException as e:
message = f"Failed to scrape {url}: {str(e)}"
return (message, 0) if return_token_count else message
except Exception as e:
message = f"Failed to process content from {url}: {str(e)}"
return (message, 0) if return_token_count else message
# Keep the original scrape_web_content function for backward compatibility
def scrape_web_content(url: str, max_tokens: int = 4000) -> str:
"""
Scrape content from a webpage and limit by token count.
Args:
url (str): URL of the webpage to scrape
max_tokens (int): Maximum number of tokens to return
Returns:
str: The scraped text content or error message
"""
return scrape_web_content_with_count(url, max_tokens)
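A short, hypothetical usage sketch; it assumes GOOGLE_API_KEY and GOOGLE_CX are configured in src.config.config:

results = google_custom_search("python asyncio tutorial", num_results=3, max_tokens=2000)
for item in results["results"]:
    marker = "(scraped)" if item["used_for_content"] else ""
    print(item["title"], item["link"], marker)
page_text = scrape_web_content("https://example.com", max_tokens=1000)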

Binary file not shown.

View File

@@ -55,11 +55,15 @@ class TestDatabaseHandler(unittest.IsolatedAsyncioTestCase):
# Extract database name from URI for later use
self.db_name = self._extract_db_name_from_uri(self.mongodb_uri)
async def tearDown(self):
async def asyncSetUp(self):
# No additional async setup needed; asyncSetUp is the IsolatedAsyncioTestCase hook for it
pass
async def asyncTearDown(self):
if not self.using_real_db:
self.mock_client_patcher.stop()
if self.using_real_db:
else:
# Clean up test data if using real database
await self.cleanup_test_data()
@@ -215,6 +219,7 @@ class TestOpenAIUtils(unittest.TestCase):
# Test empty messages
empty_result = prepare_messages_for_api([])
self.assertEqual(len(empty_result), 1) # Should have system message
self.assertEqual(empty_result[0]["role"], "system") # Verify it's a system message
# Test regular messages
messages = [
@@ -223,7 +228,7 @@ class TestOpenAIUtils(unittest.TestCase):
{"role": "user", "content": "How are you?"}
]
result = prepare_messages_for_api(messages)
self.assertEqual(len(result), 3)
self.assertEqual(len(result), 4) # Should have system message + 3 original messages
# Test with null content
messages_with_null = [
@@ -231,8 +236,11 @@ class TestOpenAIUtils(unittest.TestCase):
{"role": "assistant", "content": "Response"}
]
result_fixed = prepare_messages_for_api(messages_with_null)
self.assertEqual(len(result_fixed), 1) # Should exclude the null content
self.assertEqual(len(result_fixed), 2) # Should have system message + 1 valid message
# Verify the content is correct (system message + only the assistant message)
self.assertEqual(result_fixed[0]["role"], "system")
self.assertEqual(result_fixed[1]["role"], "assistant")
self.assertEqual(result_fixed[1]["content"], "Response")
class TestCodeUtils(unittest.TestCase):
"""Test code utility functions"""
@@ -339,7 +347,7 @@ print("Hello world")
# self.assertIn("Test Heading", content)
# self.assertIn("Test paragraph", content)
class TestPDFUtils(unittest.TestCase):
class TestPDFUtils(unittest.IsolatedAsyncioTestCase):
"""Test PDF utilities"""
async def test_send_response(self):
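For context on the base-class change above: plain unittest.TestCase never awaits an async def test method, so its body is silently skipped (with a "coroutine was never awaited" warning). A minimal sketch with an illustrative fixture:

import unittest

class ExampleAsyncTest(unittest.IsolatedAsyncioTestCase):
    async def asyncSetUp(self):
        self.value = 41  # illustrative fixture

    async def test_answer(self):
        # The async-aware runner awaits this coroutine; under plain TestCase
        # it would be collected but never actually executed.
        self.assertEqual(self.value + 1, 42)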