refactor: Clean up code formatting and enhance test cases for better readability and coverage
@@ -1,13 +1,13 @@
 __pycache__/
 *.py[cod]
 *$py.class
 *.so
 .git/
 .env
 .venv
 env/
 venv/
 ENV/
 .idea/
 .vscode/
+.github/
86 .github/workflows/main.yml vendored
@@ -1,9 +1,13 @@
-name: Build and Run ChatGPT-Discord-Bot Docker
+name: Build and Deploy ChatGPT-Discord-Bot
 on:
   push:
     branches:
       - main
+    paths-ignore:
+      - '**.md'
+      - 'LICENSE'
+      - '.gitignore'
+  workflow_dispatch: # Allow manual triggering

 jobs:
   tests:
@@ -14,31 +18,33 @@ jobs:
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4
+        with:
+          fetch-depth: 0 # Fetch all history for better versioning

       - name: Set up Python
         uses: actions/setup-python@v5.3.0
         with:
           python-version: '3.12.3'
-
-      - name: Cache Python dependencies
-        uses: actions/cache@v3
-        with:
-          path: ~/.cache/pip
-          key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
-          restore-keys: |
-            ${{ runner.os }}-pip-
+          cache: 'pip' # Use built-in pip caching

       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install pytest
+          pip install pytest pytest-cov flake8
           pip install -r requirements.txt

-      - name: Run unit tests
+      - name: Lint code
         run: |
-          python -m pytest tests/
+          # Stop the build if there are Python syntax errors or undefined names
+          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+          # Exit-zero treats all errors as warnings
+          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics

-      - name: pyupio/safety-action
+      - name: Run unit tests with coverage
+        run: |
+          python -m pytest tests/ --cov=src
+
+      - name: Check dependencies for security issues
         uses: pyupio/safety-action@v1.0.1
         with:
           api-key: ${{ secrets.SAFETY_API_KEY }}
@@ -47,12 +53,28 @@ jobs:
     runs-on: ubuntu-latest
     environment: Private Server Deploy
     needs: tests
+    outputs:
+      image: ghcr.io/coder-vippro/chatgpt-discord-bot
+      version: ${{ steps.version.outputs.version }}
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4

+      - name: Generate version number
+        id: version
+        run: |
+          # Create a version with timestamp and short commit hash
+          VERSION=$(date +'%Y%m%d%H%M')-${GITHUB_SHA::7}
+          echo "version=${VERSION}" >> $GITHUB_OUTPUT
+          echo "Version: ${VERSION}"
+
+      - name: Set up QEMU for multi-architecture builds
+        uses: docker/setup-qemu-action@v3
+
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
+        with:
+          buildkitd-flags: --debug

       - name: Log in to GitHub Container Registry
         uses: docker/login-action@v3
@@ -67,9 +89,39 @@ jobs:
           context: .
           push: true
           platforms: linux/amd64,linux/arm64
-          tags: ghcr.io/coder-vippro/chatgpt-discord-bot:latest
-          cache-from: type=gha,scope=pip-dependencies
-          cache-to: type=gha,mode=min,scope=pip-dependencies
+          tags: |
+            ghcr.io/coder-vippro/chatgpt-discord-bot:latest
+            ghcr.io/coder-vippro/chatgpt-discord-bot:${{ steps.version.outputs.version }}
+          labels: |
+            org.opencontainers.image.title=ChatGPT-Discord-Bot
+            org.opencontainers.image.description=Discord bot powered by OpenAI ChatGPT
+            org.opencontainers.image.source=https://github.com/coder-vippro/ChatGPT-Discord-Bot
+            org.opencontainers.image.created=${{ github.event.repository.updated_at }}
+            org.opencontainers.image.revision=${{ github.sha }}
+            org.opencontainers.image.version=${{ steps.version.outputs.version }}
+          cache-from: type=gha,scope=build-cache
+          cache-to: type=gha,mode=max,scope=build-cache
+          github-token: ${{ secrets.GITHUB_TOKEN }}
+          build-args: |
+            BUILD_DATE=${{ github.event.repository.updated_at }}
+            VCS_REF=${{ github.sha }}
+            VERSION=${{ steps.version.outputs.version }}
+
+  deploy-notification:
+    runs-on: ubuntu-latest
+    needs: build-and-push
+    if: ${{ success() }}
+    steps:
+      - name: Send deployment notification
+        uses: sarisia/actions-status-discord@v1
+        with:
+          webhook: ${{ secrets.DISCORD_WEBHOOK }}
+          title: "✅ New deployment successful!"
+          description: |
+            Image: ${{ needs.build-and-push.outputs.image }}:${{ needs.build-and-push.outputs.version }}
+            Commit: ${{ github.sha }}
+            Repository: ${{ github.repository }}
+          color: 0x00ff00
+          username: GitHub Actions
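A quick way to sanity-check the Generate version number step locally is to reproduce the same timestamp-plus-short-hash scheme in Python. This is only an illustrative sketch: the sha value is a dummy stand-in for ${GITHUB_SHA::7}, and it uses UTC where the runner's date command uses the runner's local time.

# Sketch of the workflow's version scheme: YYYYMMDDHHMM-<short sha>.
from datetime import datetime, timezone

def build_version(sha: str) -> str:
    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d%H%M")
    return f"{timestamp}-{sha[:7]}"

# Dummy commit hash, for illustration only.
print(build_version("0123456789abcdef"))  # e.g. 202505171230-0123456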
18 .gitignore vendored
@@ -1,10 +1,10 @@
 test.py
 .env
 chat_history.db
 bot_copy.py
 __pycache__/bot.cpython-312.pyc
 tests/__pycache__/test_bot.cpython-312.pyc
 .vscode/settings.json
 chatgpt.zip
 response.txt
+logs
102 Dockerfile
@@ -1,51 +1,51 @@
# Build stage with all build dependencies
FROM python:3.12.3-alpine AS builder

# Set environment variables
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1

# Install build dependencies
RUN apk add --no-cache \
    curl \
    g++ \
    gcc \
    musl-dev \
    make \
    rust \
    cargo \
    build-base

# Set the working directory
WORKDIR /app

# Copy requirements file
COPY requirements.txt .

# Install Python packages with BuildKit cache
RUN --mount=type=cache,target=/root/.cache/pip \
    pip install --no-cache-dir -r requirements.txt

# Runtime stage with minimal dependencies
FROM python:3.12.3-alpine

# Set environment variables
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1

# Install runtime dependencies
RUN apk add --no-cache libstdc++
RUN apk add --no-cache g++

# Set the working directory
WORKDIR /usr/src/discordbot

# Copy installed Python packages from builder - using correct site-packages path
COPY --from=builder /usr/local/lib/python3.12 /usr/local/lib/python3.12
COPY --from=builder /usr/local/bin /usr/local/bin

# Copy the application source code
COPY . .

# Command to run the application
CMD ["python3", "bot.py"]
596 bot.py
@@ -1,298 +1,298 @@
import os
import sys
import discord
import logging
import asyncio
import signal
import traceback
import time
import logging.config
from discord.ext import commands, tasks
from concurrent.futures import ThreadPoolExecutor
from dotenv import load_dotenv
from discord import app_commands

# Import configuration
from src.config.config import (
    DISCORD_TOKEN, MONGODB_URI, RUNWARE_API_KEY, STATUSES,
    LOGGING_CONFIG, ENABLE_WEBHOOK_LOGGING, LOGGING_WEBHOOK_URL,
    WEBHOOK_LOG_LEVEL, WEBHOOK_APP_NAME, WEBHOOK_BATCH_SIZE,
    WEBHOOK_FLUSH_INTERVAL, LOG_LEVEL_MAP
)

# Import webhook logger
from src.utils.webhook_logger import webhook_log_manager, webhook_logger

# Import database handler
from src.database.db_handler import DatabaseHandler

# Import the message handler
from src.module.message_handler import MessageHandler

# Import various utility modules
from src.utils.image_utils import ImageGenerator

# Global shutdown flag
shutdown_flag = asyncio.Event()

# Load environment variables
load_dotenv()

# Configure logging with more detail, rotation, and webhook integration
def setup_logging():
    # Apply the dictionary config
    try:
        logging.config.dictConfig(LOGGING_CONFIG)
        logging.info("Configured logging from dictionary configuration")
    except Exception as e:
        # Fall back to basic configuration
        log_formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s'
        )

        # Console handler
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(log_formatter)

        # File handler with rotation (keep 5 files of 5MB each)
        try:
            from logging.handlers import RotatingFileHandler
            os.makedirs('logs', exist_ok=True)
            file_handler = RotatingFileHandler(
                'logs/discord_bot.log',
                maxBytes=5*1024*1024,  # 5MB
                backupCount=5
            )
            file_handler.setFormatter(log_formatter)

            # Configure root logger
            root_logger = logging.getLogger()
            root_logger.setLevel(logging.INFO)
            root_logger.addHandler(console_handler)
            root_logger.addHandler(file_handler)
        except Exception as e:
            # Fall back to basic logging if file logging fails
            logging.basicConfig(
                level=logging.INFO,
                format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                stream=sys.stdout
            )
            logging.warning(f"Could not set up file logging: {str(e)}")

    # Set up webhook logging if enabled
    if ENABLE_WEBHOOK_LOGGING and LOGGING_WEBHOOK_URL:
        try:
            # Convert string log level to int using our mapping
            log_level = LOG_LEVEL_MAP.get(WEBHOOK_LOG_LEVEL.upper(), logging.INFO)

            # Set up webhook logging
            webhook_log_manager.setup_webhook_logging(
                webhook_url=LOGGING_WEBHOOK_URL,
                app_name=WEBHOOK_APP_NAME,
                level=log_level,
                loggers=None,  # Use root logger
                batch_size=WEBHOOK_BATCH_SIZE,
                flush_interval=WEBHOOK_FLUSH_INTERVAL
            )
            logging.info(f"Webhook logging enabled at level {WEBHOOK_LOG_LEVEL}")
        except Exception as e:
            logging.error(f"Failed to set up webhook logging: {str(e)}")

# Create a function to change bot status periodically
async def change_status_loop(bot):
    """Change bot status every 5 minutes"""
    while not shutdown_flag.is_set():
        for status in STATUSES:
            await bot.change_presence(activity=discord.Game(name=status))
            try:
                # Wait but be interruptible
                await asyncio.wait_for(shutdown_flag.wait(), timeout=300)
                if shutdown_flag.is_set():
                    break
            except asyncio.TimeoutError:
                # Normal timeout, continue to next status
                continue

async def main():
    # Set up logging
    setup_logging()

    # Check if required environment variables are set
    missing_vars = []
    if not DISCORD_TOKEN:
        missing_vars.append("DISCORD_TOKEN")

    if not MONGODB_URI:
        missing_vars.append("MONGODB_URI")

    if missing_vars:
        logging.error(f"The following required environment variables are not set: {', '.join(missing_vars)}")
        return

    if not RUNWARE_API_KEY:
        logging.warning("RUNWARE_API_KEY environment variable not set - image generation will not work")

    # Initialize the OpenAI client
    try:
        from openai import AsyncOpenAI
        openai_client = AsyncOpenAI()
        logging.info("OpenAI client initialized successfully")
    except ImportError:
        logging.error("Failed to import OpenAI. Make sure it's installed: pip install openai")
        return
    except Exception as e:
        logging.error(f"Error initializing OpenAI client: {e}")
        return

    # Global references to objects that need cleanup
    message_handler = None
    db_handler = None

    try:
        # Initialize image generator if API key is available
        image_generator = None
        if RUNWARE_API_KEY:
            try:
                image_generator = ImageGenerator(RUNWARE_API_KEY)
                logging.info("Image generator initialized successfully")
            except Exception as e:
                logging.error(f"Error initializing image generator: {e}")

        # Set up Discord intents
        intents = discord.Intents.default()
        intents.message_content = True

        # Initialize the bot with command prefixes and more robust timeout settings
        bot = commands.Bot(
            command_prefix="//quocanhvu",
            intents=intents,
            heartbeat_timeout=120,
            max_messages=10000  # Cache more messages to improve experience
        )

        # Initialize database handler
        db_handler = DatabaseHandler(MONGODB_URI)

        # Create database indexes for performance
        await db_handler.create_indexes()
        logging.info("Database indexes created")

        # Initialize the reminders collection
        await db_handler.ensure_reminders_collection()

        # Event handler when the bot is ready
        @bot.event
        async def on_ready():
            """Bot startup event to sync slash commands and start status loop."""
            await bot.tree.sync()  # Sync slash commands
            bot_info = f"Logged in as {bot.user} (ID: {bot.user.id})"
            logging.info("=" * len(bot_info))
            logging.info(bot_info)
            logging.info(f"Connected to {len(bot.guilds)} guilds")
            logging.info("=" * len(bot_info))

            # Start the status changing task
            asyncio.create_task(change_status_loop(bot))

        # Handle general errors to prevent crashes
        @bot.event
        async def on_error(event, *args, **kwargs):
            error_msg = traceback.format_exc()
            logging.error(f"Discord event error in {event}:\n{error_msg}")

        @bot.event
        async def on_command_error(ctx, error):
            if isinstance(error, commands.CommandNotFound):
                return

            error_msg = str(error)
            trace = "".join(traceback.format_exception(type(error), error, error.__traceback__))
            logging.error(f"Command error: {error_msg}\n{trace}")
            await ctx.send(f"Error: {error_msg}")

        # Initialize message handler
        message_handler = MessageHandler(bot, db_handler, openai_client, image_generator)

        # Set up slash commands
        from src.commands.commands import setup_commands
        setup_commands(bot, db_handler, openai_client, image_generator)

        # Handle shutdown signals
        loop = asyncio.get_running_loop()

        # Signal handlers for graceful shutdown
        for sig in (signal.SIGINT, signal.SIGTERM):
            try:
                loop.add_signal_handler(
                    sig,
                    lambda sig=sig: asyncio.create_task(shutdown(sig, loop, bot, db_handler, message_handler))
                )
            except (NotImplementedError, RuntimeError):
                # Windows doesn't support SIGTERM or add_signal_handler
                # Use fallback for Windows
                pass

        logging.info("Starting bot...")
        await bot.start(DISCORD_TOKEN)

    except Exception as e:
        error_msg = traceback.format_exc()
        logging.critical(f"Fatal error in main function: {str(e)}\n{error_msg}")

        # Clean up resources if initialization failed halfway
        await cleanup_resources(bot=None, db_handler=db_handler, message_handler=message_handler)

async def shutdown(sig, loop, bot, db_handler, message_handler):
    """Handle graceful shutdown of the bot"""
    logging.info(f"Received signal {sig.name}. Starting graceful shutdown...")

    # Set shutdown flag to stop ongoing tasks
    shutdown_flag.set()

    # Give running tasks a moment to detect shutdown flag
    await asyncio.sleep(1)

    # Start cleanup
    await cleanup_resources(bot, db_handler, message_handler)

    # Stop the event loop
    loop.stop()

async def cleanup_resources(bot, db_handler, message_handler):
    """Clean up all resources during shutdown"""
    try:
        # Close the bot connection
        if bot is not None:
            logging.info("Closing bot connection...")
            await bot.close()

        # Close message handler resources
        if message_handler is not None:
            logging.info("Closing message handler resources...")
            await message_handler.close()

        # Close database connection
        if db_handler is not None:
            logging.info("Closing database connection...")
            await db_handler.close()

        # Clean up webhook logging
        if ENABLE_WEBHOOK_LOGGING and LOGGING_WEBHOOK_URL:
            logging.info("Cleaning up webhook logging...")
            webhook_log_manager.cleanup()

        logging.info("Cleanup completed successfully")
    except Exception as e:
        logging.error(f"Error during cleanup: {str(e)}")

if __name__ == "__main__":
    try:
        # Use asyncio.run to properly run the async main function
        asyncio.run(main())
    except KeyboardInterrupt:
        logging.info("Bot stopped via keyboard interrupt")
    except Exception as e:
        logging.critical(f"Unhandled exception in main thread: {str(e)}")
        traceback.print_exc()
    finally:
        logging.info("Bot shut down completely")
@@ -1,17 +1,18 @@
 discord.py>=2.3.0
 openai>=1.3.0
 motor>=3.3.0
 pymongo>=4.6.0
 tiktoken>=0.5.0
 PyPDF2>=3.0.0
 beautifulsoup4>=4.12.0
 requests>=2.31.0
 aiohttp>=3.9.0
 runware>=0.2.0
 python-dotenv>=1.0.0
 webdriver-manager
 matplotlib
 pandas
 openpyxl
 pytz
 xlrd
+scipy
@@ -1,466 +1,466 @@
import discord
from discord import app_commands
from discord.ext import commands
import logging
import io
import asyncio
from typing import Optional, Dict, List, Any, Callable

from src.config.config import MODEL_OPTIONS, PDF_ALLOWED_MODELS
from src.utils.image_utils import ImageGenerator
from src.utils.web_utils import google_custom_search, scrape_web_content
from src.utils.pdf_utils import process_pdf, send_response

# Dictionary to keep track of user requests and their cooldowns
user_requests = {}
# Dictionary to store user tasks
user_tasks = {}

def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator: ImageGenerator):
    """
    Set up all slash commands for the bot.

    Args:
        bot: Discord bot instance
        db_handler: Database handler instance
        openai_client: OpenAI client instance
        image_generator: Image generator instance
    """
    tree = bot.tree

    def check_blacklist():
        """Decorator to check if a user is blacklisted before executing a command."""
        async def predicate(interaction: discord.Interaction):
            if await db_handler.is_admin(interaction.user.id):
                return True
            if await db_handler.is_user_blacklisted(interaction.user.id):
                await interaction.response.send_message("You have been blacklisted from using this bot. Please contact the admin if you think this is a mistake.", ephemeral=True)
                return False
            return True
        return app_commands.check(predicate)

    # Processes a command request with rate limiting and queuing.
    async def process_request(interaction, command_func, *args):
        user_id = interaction.user.id
        now = discord.utils.utcnow().timestamp()

        if user_id not in user_requests:
            user_requests[user_id] = {'last_request': 0, 'queue': asyncio.Queue()}

        last_request = user_requests[user_id]['last_request']

        if now - last_request < 5:
            await interaction.followup.send("You are sending requests too quickly. Please wait a moment.", ephemeral=True)
            return

        # Update last request time
        user_requests[user_id]['last_request'] = now

        # Add request to queue
        queue = user_requests[user_id]['queue']
        await queue.put((command_func, args))

        # Start processing if it's the only request in the queue
        if queue.qsize() == 1:
            await process_queue(interaction)

    # Processes requests in the user's queue sequentially.
    async def process_queue(interaction):
        user_id = interaction.user.id
        queue = user_requests[user_id]['queue']

        while not queue.empty():
            command_func, args = await queue.get()
            try:
                await command_func(interaction, *args)
            except Exception as e:
                logging.error(f"Error processing command: {str(e)}")
                await interaction.followup.send(f"An error occurred: {str(e)}", ephemeral=True)
            await asyncio.sleep(1)  # Optional delay between processing

    @tree.command(name="choose_model", description="Select the AI model to use for responses.")
    @check_blacklist()
    async def choose_model(interaction: discord.Interaction):
        """Lets users choose an AI model and saves it to the database."""
        options = [discord.SelectOption(label=model, value=model) for model in MODEL_OPTIONS]
        select_menu = discord.ui.Select(placeholder="Choose a model", options=options)

        async def select_callback(interaction: discord.Interaction):
            selected_model = select_menu.values[0]
            user_id = interaction.user.id

            # Save the model selection to the database
            await db_handler.save_user_model(user_id, selected_model)
            await interaction.response.send_message(
                f"Model set to `{selected_model}` for your responses.", ephemeral=True
            )

        select_menu.callback = select_callback
        view = discord.ui.View()
        view.add_item(select_menu)
        await interaction.response.send_message("Choose a model:", view=view, ephemeral=True)

    @tree.command(name="search", description="Search on Google and send results to AI model.")
    @app_commands.describe(query="The search query")
    @check_blacklist()
    async def search(interaction: discord.Interaction, query: str):
        """Searches Google and sends results to the AI model."""
        await interaction.response.defer(thinking=True)

        async def process_search(interaction: discord.Interaction, query: str):
            user_id = interaction.user.id
            model = await db_handler.get_user_model(user_id) or "gpt-4o"
            history = await db_handler.get_history(user_id)

            try:
                # Perform Google search
                search_results = google_custom_search(query)

                if not search_results or not search_results.get('results'):
                    await interaction.followup.send("No search results found.")
                    return

                # Format search results for the AI model
                from src.config.config import SEARCH_PROMPT
                formatted_results = f"Search results for: {query}\n\n"

                for i, result in enumerate(search_results.get('results', [])):
                    formatted_results += f"{i+1}. {result.get('title')}\n"
                    formatted_results += f"URL: {result.get('link')}\n"
                    formatted_results += f"Snippet: {result.get('snippet')}\n"
                    if 'scraped_content' in result:
                        content_preview = result['scraped_content'][:300] + "..." if len(result['scraped_content']) > 300 else result['scraped_content']
                        formatted_results += f"Content: {content_preview}\n"
                    formatted_results += "\n"

                # Prepare messages for the AI model, handling system prompts appropriately
                messages = []
                if model in ["o1-mini", "o1-preview"]:
                    messages = [
                        {"role": "user", "content": f"Instructions: {SEARCH_PROMPT}\n\n{formatted_results}\n\nUser query: {query}"}
                    ]
                else:
                    messages = [
                        {"role": "system", "content": SEARCH_PROMPT},
                        {"role": "user", "content": f"{formatted_results}\n\nUser query: {query}"}
                    ]

                # Send to the AI model
                response = await openai_client.chat.completions.create(
                    model=model if model in ["gpt-4o", "gpt-4o-mini"] else "gpt-4o",
                    messages=messages,
                    temperature=0.5
                )

                reply = response.choices[0].message.content

                # Add the interaction to history
                history.append({"role": "user", "content": f"Search query: {query}"})
                history.append({"role": "assistant", "content": reply})
                await db_handler.save_history(user_id, history)

                # Check if the reply exceeds Discord's character limit (2000)
                if len(reply) > 2000:
                    # Create a text file with the full response
                    file_bytes = io.BytesIO(reply.encode('utf-8'))
                    file = discord.File(file_bytes, filename="search_response.txt")

                    # Send a short message with the file attachment
                    await interaction.followup.send(
                        f"The search response for '{query}' is too long for Discord (>{len(reply)} characters). Here's the full response as a text file:",
                        file=file
                    )
                else:
                    # Send as normal message if within limits
                    await interaction.followup.send(reply)

            except Exception as e:
                error_message = f"Search error: {str(e)}"
                logging.error(error_message)
                await interaction.followup.send(f"An error occurred while searching: {str(e)}")

        await process_request(interaction, process_search, query)

    @tree.command(name="web", description="Scrape a webpage and send data to AI model.")
    @app_commands.describe(url="The webpage URL to scrape")
    @check_blacklist()
    async def web(interaction: discord.Interaction, url: str):
        """Scrapes a webpage and sends data to the AI model."""
        await interaction.response.defer(thinking=True)

        async def process_web(interaction: discord.Interaction, url: str):
            user_id = interaction.user.id
            model = await db_handler.get_user_model(user_id) or "gpt-4o"
            history = await db_handler.get_history(user_id)

            try:
                content = scrape_web_content(url)
                if content.startswith("Failed"):
                    await interaction.followup.send(content)
                    return

                from src.config.config import WEB_SCRAPING_PROMPT

                if model in ["o1-mini", "o1-preview"]:
                    messages = [
                        {"role": "user", "content": f"Instructions: {WEB_SCRAPING_PROMPT}\n\nContent from {url}:\n{content}"}
                    ]
                else:
                    messages = [
                        {"role": "system", "content": WEB_SCRAPING_PROMPT},
                        {"role": "user", "content": f"Content from {url}:\n{content}"}
                    ]

                response = await openai_client.chat.completions.create(
                    model=model if model in ["gpt-4o", "gpt-4o-mini"] else "gpt-4o",
                    messages=messages,
                    temperature=0.3,
                    top_p=0.7
                )

                reply = response.choices[0].message.content

                # Add the interaction to history
                history.append({"role": "user", "content": f"Scraped content from {url}"})
                history.append({"role": "assistant", "content": reply})
                await db_handler.save_history(user_id, history)

                # Check if the reply exceeds Discord's character limit (2000)
                if len(reply) > 2000:
                    # Create a text file with the full response
                    file_bytes = io.BytesIO(reply.encode('utf-8'))
                    file = discord.File(file_bytes, filename="web_response.txt")

                    # Send a short message with the file attachment
                    await interaction.followup.send(
                        f"The response from analyzing {url} is too long for Discord (>{len(reply)} characters). Here's the full response as a text file:",
                        file=file
                    )
                else:
                    # Send as normal message if within limits
                    await interaction.followup.send(reply)

            except Exception as e:
                await interaction.followup.send(f"Error: {str(e)}", ephemeral=True)

        await process_request(interaction, process_web, url)

    @tree.command(name='generate', description='Generates an image from a text prompt.')
    @app_commands.describe(prompt='The prompt for image generation')
    @check_blacklist()
    async def generate_image_command(interaction: discord.Interaction, prompt: str):
        """Generates an image from a text prompt."""
        await interaction.response.defer(thinking=True)  # Indicate that the bot is processing

        async def process_image_generation(interaction: discord.Interaction, prompt: str):
            try:
                # Generate images
                result = await image_generator.generate_image(prompt, 4)  # Generate 4 images

                if not result['success']:
                    await interaction.followup.send(f"Error: {result.get('error', 'Unknown error')}")
                    return

                # Send images as attachments
                if result["binary_images"]:
                    await interaction.followup.send(
                        f"Generated {len(result['binary_images'])} images for prompt: \"{prompt}\"",
                        files=[discord.File(io.BytesIO(img), filename=f"image_{i}.png")
                               for i, img in enumerate(result["binary_images"])]
                    )
                else:
                    await interaction.followup.send("No images were generated.")

            except Exception as e:
                error_message = f"An error occurred: {str(e)}"
                logging.error(f"Error in generate_image_command: {error_message}")
                await interaction.followup.send(error_message)

        await process_request(interaction, process_image_generation, prompt)

    @tree.command(name="reset", description="Reset the bot by clearing user data.")
    @check_blacklist()
    async def reset(interaction: discord.Interaction):
        """Resets the bot by clearing user data."""
        user_id = interaction.user.id
        await db_handler.save_history(user_id, [])
        await interaction.response.send_message("Your conversation history has been cleared and reset!", ephemeral=True)

    @tree.command(name="user_stat", description="Get your current input token, output token, and model.")
    @check_blacklist()
    async def user_stat(interaction: discord.Interaction):
        """Fetches and displays the current input token, output token, and model for the user."""
        await interaction.response.defer(thinking=True, ephemeral=True)

        async def process_user_stat(interaction: discord.Interaction):
            import tiktoken

            user_id = interaction.user.id
            history = await db_handler.get_history(user_id)
            model = await db_handler.get_user_model(user_id) or "gpt-4o"  # Default model

            # Adjust model for encoding purposes
            if model in ["gpt-4o", "o1", "o1-preview", "o1-mini", "o3-mini"]:
                encoding_model = "gpt-4o"
            else:
                encoding_model = model

            # Retrieve the appropriate encoding for the selected model
            encoding = tiktoken.encoding_for_model(encoding_model)

            # Initialize token counts
            input_tokens = 0
            output_tokens = 0

            # Calculate input and output tokens
            if history:
                for item in history:
                    content = item.get('content', '')

                    # Handle case where content is a list or other type
                    if isinstance(content, list):
                        content_str = ""
                        for part in content:
                            if isinstance(part, dict) and 'text' in part:
                                content_str += part['text'] + " "
                        content = content_str

                    # Ensure content is a string before processing
                    if isinstance(content, str):
                        tokens = len(encoding.encode(content))
                        if item.get('role') == 'user':
                            input_tokens += tokens
                        elif item.get('role') == 'assistant':
                            output_tokens += tokens

            # Create the statistics message
            stat_message = (
                f"**User Statistics:**\n"
                f"Model: `{model}`\n"
                f"Input Tokens: `{input_tokens}`\n"
                f"Output Tokens: `{output_tokens}`\n"
            )

            # Send the response
            await interaction.followup.send(stat_message, ephemeral=True)

        await process_request(interaction, process_user_stat)

    @tree.command(name="help", description="Display a list of available commands.")
    @check_blacklist()
    async def help_command(interaction: discord.Interaction):
        """Sends a list of available commands to the user."""
        help_message = (
            "**Available commands:**\n"
            "/choose_model - Select which AI model to use for responses (gpt-4o, gpt-4o-mini, o1-preview, o1-mini).\n"
            "/search `<query>` - Search Google and send results to the AI model.\n"
            "/web `<url>` - Scrape a webpage and send the data to the AI model.\n"
            "/generate `<prompt>` - Generate an image from a text prompt.\n"
            "/reset - Reset your chat history.\n"
            "/user_stat - Get information about your input tokens, output tokens, and current model.\n"
            "/help - Display this help message.\n"
        )
        await interaction.response.send_message(help_message, ephemeral=True)

    @tree.command(name="stop", description="Stop any process or queue of the user. Admins can stop other users' tasks by providing their ID.")
    @app_commands.describe(user_id="The Discord user ID to stop tasks for (admin only)")
    @check_blacklist()
    async def stop(interaction: discord.Interaction, user_id: str = None):
        """Stops any process or queue of the user. Admins can stop other users' tasks by providing their ID."""
        # Defer the interaction first
        await interaction.response.defer(ephemeral=True)

        if user_id and not await db_handler.is_admin(interaction.user.id):
            await interaction.followup.send("You don't have permission to stop other users' tasks.", ephemeral=True)
            return

        target_user_id = int(user_id) if user_id else interaction.user.id
        await stop_user_tasks(target_user_id)
        await interaction.followup.send(f"Stopped all tasks for user {target_user_id}.", ephemeral=True)

    # Admin commands
    @tree.command(name="whitelist_add", description="Add a user to the PDF processing whitelist")
    @app_commands.describe(user_id="The Discord user ID to whitelist")
    async def whitelist_add(interaction: discord.Interaction, user_id: str):
        """Adds a user to the PDF processing whitelist."""
        if not await db_handler.is_admin(interaction.user.id):
            await interaction.response.send_message("You don't have permission to use this command. Only admin can use whitelist commands.", ephemeral=True)
            return

        try:
            user_id = int(user_id)
            if await db_handler.is_admin(user_id):
                await interaction.response.send_message("Admins are automatically whitelisted and don't need to be added.", ephemeral=True)
                return
            await db_handler.add_user_to_whitelist(user_id)
            await interaction.response.send_message(f"User {user_id} has been added to the PDF processing whitelist.", ephemeral=True)
        except ValueError:
            await interaction.response.send_message("Invalid user ID. Please provide a valid Discord user ID.", ephemeral=True)

    @tree.command(name="whitelist_remove", description="Remove a user from the PDF processing whitelist")
    @app_commands.describe(user_id="The Discord user ID to remove from whitelist")
    async def whitelist_remove(interaction: discord.Interaction, user_id: str):
        """Removes a user from the PDF processing whitelist."""
        if not await db_handler.is_admin(interaction.user.id):
            await interaction.response.send_message("You don't have permission to use this command. Only admin can use whitelist commands.", ephemeral=True)
            return

        try:
            user_id = int(user_id)
            if await db_handler.remove_user_from_whitelist(user_id):
                await interaction.response.send_message(f"User {user_id} has been removed from the PDF processing whitelist.", ephemeral=True)
            else:
                await interaction.response.send_message(f"User {user_id} was not found in the whitelist.", ephemeral=True)
        except ValueError:
            await interaction.response.send_message("Invalid user ID. Please provide a valid Discord user ID.", ephemeral=True)

    @tree.command(name="blacklist_add", description="Add a user to the bot blacklist")
    @app_commands.describe(user_id="The Discord user ID to blacklist")
    async def blacklist_add(interaction: discord.Interaction, user_id: str):
        """Adds a user to the bot blacklist."""
        if not await db_handler.is_admin(interaction.user.id):
            await interaction.response.send_message("You don't have permission to use this command. Only admin can use blacklist commands.", ephemeral=True)
            return

        try:
            user_id = int(user_id)
            if await db_handler.is_admin(user_id):
                await interaction.response.send_message("Cannot blacklist an admin.", ephemeral=True)
                return
            await db_handler.add_user_to_blacklist(user_id)
            await interaction.response.send_message(f"User {user_id} has been added to the bot blacklist. They can no longer use any bot features.", ephemeral=True)
        except ValueError:
            await interaction.response.send_message("Invalid user ID. Please provide a valid Discord user ID.", ephemeral=True)

    @tree.command(name="blacklist_remove", description="Remove a user from the bot blacklist")
    @app_commands.describe(user_id="The Discord user ID to remove from blacklist")
    async def blacklist_remove(interaction: discord.Interaction, user_id: str):
        """Removes a user from the bot blacklist."""
        if not await db_handler.is_admin(interaction.user.id):
            await interaction.response.send_message("You don't have permission to use this command. Only admin can use blacklist commands.", ephemeral=True)
            return

        try:
            user_id = int(user_id)
            if await db_handler.remove_user_from_blacklist(user_id):
                await interaction.response.send_message(f"User {user_id} has been removed from the bot blacklist. They can now use bot features again.", ephemeral=True)
            else:
                await interaction.response.send_message(f"User {user_id} was not found in the blacklist.", ephemeral=True)
        except ValueError:
            await interaction.response.send_message("Invalid user ID. Please provide a valid Discord user ID.", ephemeral=True)

# Helper function to stop user tasks
async def stop_user_tasks(user_id: int):
    """Stop all tasks for a specific user."""
    if user_id in user_tasks:
        for task in user_tasks[user_id]:
            task.cancel()
        user_tasks[user_id] = []

    # Clear any queued requests
    if user_id in user_requests:
        while not user_requests[user_id]['queue'].empty():
            try:
                user_requests[user_id]['queue'].get_nowait()
            except:
                break
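process_request and process_queue above combine a 5-second per-user cooldown with a per-user FIFO queue. A self-contained sketch of that pattern, with a plain coroutine in place of Discord interactions and time.monotonic() in place of discord.utils.utcnow():

# Sketch of the per-user cooldown + queue pattern from process_request/process_queue.
import asyncio
import time

COOLDOWN_SECONDS = 5
user_state = {}  # user_id -> {'last_request': float, 'queue': asyncio.Queue}

async def submit(user_id, func, *args):
    state = user_state.setdefault(
        user_id,
        # -inf so the very first request always passes the cooldown check.
        {'last_request': float('-inf'), 'queue': asyncio.Queue()}
    )
    now = time.monotonic()
    if now - state['last_request'] < COOLDOWN_SECONDS:
        print(f"user {user_id}: too many requests, dropped")
        return
    state['last_request'] = now
    await state['queue'].put((func, args))
    if state['queue'].qsize() == 1:  # nothing else queued, so drain now
        await drain(user_id)

async def drain(user_id):
    queue = user_state[user_id]['queue']
    while not queue.empty():
        func, args = await queue.get()
        await func(*args)

async def demo(name):
    print(f"running {name}")

asyncio.run(submit(1, demo, "first task"))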
|
||||
import discord
|
||||
from discord import app_commands
|
||||
from discord.ext import commands
|
||||
import logging
|
||||
import io
|
||||
import asyncio
|
||||
from typing import Optional, Dict, List, Any, Callable
|
||||
|
||||
from src.config.config import MODEL_OPTIONS, PDF_ALLOWED_MODELS
|
||||
from src.utils.image_utils import ImageGenerator
|
||||
from src.utils.web_utils import google_custom_search, scrape_web_content
|
||||
from src.utils.pdf_utils import process_pdf, send_response
|
||||
|
||||
# Dictionary to keep track of user requests and their cooldowns
|
||||
user_requests = {}
|
||||
# Dictionary to store user tasks
|
||||
user_tasks = {}
|
||||
|
||||
def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator: ImageGenerator):
|
||||
"""
|
||||
Set up all slash commands for the bot.
|
||||
|
||||
Args:
|
||||
bot: Discord bot instance
|
||||
db_handler: Database handler instance
|
||||
openai_client: OpenAI client instance
|
||||
image_generator: Image generator instance
|
||||
"""
|
||||
tree = bot.tree
|
||||
|
||||
def check_blacklist():
|
||||
"""Decorator to check if a user is blacklisted before executing a command."""
|
||||
async def predicate(interaction: discord.Interaction):
|
||||
if await db_handler.is_admin(interaction.user.id):
|
||||
return True
|
||||
if await db_handler.is_user_blacklisted(interaction.user.id):
|
||||
await interaction.response.send_message("You have been blacklisted from using this bot. Please contact the admin if you think this is a mistake.", ephemeral=True)
|
||||
return False
|
||||
return True
|
||||
return app_commands.check(predicate)
|
||||
|
||||
# Processes a command request with rate limiting and queuing.
|
||||
async def process_request(interaction, command_func, *args):
|
||||
user_id = interaction.user.id
|
||||
now = discord.utils.utcnow().timestamp()
|
||||
|
||||
if user_id not in user_requests:
|
||||
user_requests[user_id] = {'last_request': 0, 'queue': asyncio.Queue()}
|
||||
|
||||
last_request = user_requests[user_id]['last_request']
|
||||
|
||||
if now - last_request < 5:
|
||||
await interaction.followup.send("You are sending requests too quickly. Please wait a moment.", ephemeral=True)
|
||||
return
|
||||
|
||||
# Update last request time
|
||||
user_requests[user_id]['last_request'] = now
|
||||
|
||||
# Add request to queue
|
||||
queue = user_requests[user_id]['queue']
|
||||
await queue.put((command_func, args))
|
||||
|
||||
# Start processing if it's the only request in the queue
|
||||
if queue.qsize() == 1:
|
||||
await process_queue(interaction)
|
||||
|
||||
# Processes requests in the user's queue sequentially.
|
||||
async def process_queue(interaction):
|
||||
user_id = interaction.user.id
|
||||
queue = user_requests[user_id]['queue']
|
||||
|
||||
while not queue.empty():
|
||||
command_func, args = await queue.get()
|
||||
try:
|
||||
await command_func(interaction, *args)
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing command: {str(e)}")
|
||||
await interaction.followup.send(f"An error occurred: {str(e)}", ephemeral=True)
|
||||
await asyncio.sleep(1) # Optional delay between processing
|
||||
|
||||
@tree.command(name="choose_model", description="Select the AI model to use for responses.")
|
||||
@check_blacklist()
|
||||
async def choose_model(interaction: discord.Interaction):
|
||||
"""Lets users choose an AI model and saves it to the database."""
|
||||
options = [discord.SelectOption(label=model, value=model) for model in MODEL_OPTIONS]
|
||||
select_menu = discord.ui.Select(placeholder="Choose a model", options=options)
|
||||
|
||||
async def select_callback(interaction: discord.Interaction):
|
||||
selected_model = select_menu.values[0]
|
||||
user_id = interaction.user.id
|
||||
|
||||
# Save the model selection to the database
|
||||
await db_handler.save_user_model(user_id, selected_model)
|
||||
await interaction.response.send_message(
|
||||
f"Model set to `{selected_model}` for your responses.", ephemeral=True
|
||||
)
|
||||
|
||||
select_menu.callback = select_callback
|
||||
view = discord.ui.View()
|
||||
view.add_item(select_menu)
|
||||
await interaction.response.send_message("Choose a model:", view=view, ephemeral=True)

@tree.command(name="search", description="Search on Google and send results to AI model.")
@app_commands.describe(query="The search query")
@check_blacklist()
async def search(interaction: discord.Interaction, query: str):
    """Searches Google and sends results to the AI model."""
    await interaction.response.defer(thinking=True)

    async def process_search(interaction: discord.Interaction, query: str):
        user_id = interaction.user.id
        model = await db_handler.get_user_model(user_id) or "gpt-4o"
        history = await db_handler.get_history(user_id)

        try:
            # Perform the Google search
            search_results = google_custom_search(query)

            if not search_results or not search_results.get('results'):
                await interaction.followup.send("No search results found.")
                return

            # Format the search results for the AI model
            from src.config.config import SEARCH_PROMPT
            formatted_results = f"Search results for: {query}\n\n"

            for i, result in enumerate(search_results.get('results', [])):
                formatted_results += f"{i+1}. {result.get('title')}\n"
                formatted_results += f"URL: {result.get('link')}\n"
                formatted_results += f"Snippet: {result.get('snippet')}\n"
                if 'scraped_content' in result:
                    content_preview = result['scraped_content'][:300] + "..." if len(result['scraped_content']) > 300 else result['scraped_content']
                    formatted_results += f"Content: {content_preview}\n"
                formatted_results += "\n"

            # Prepare messages for the AI model; the o1 models don't accept a
            # "system" role, so fold the instructions into the user message for them
            if model in ["o1-mini", "o1-preview"]:
                messages = [
                    {"role": "user", "content": f"Instructions: {SEARCH_PROMPT}\n\n{formatted_results}\n\nUser query: {query}"}
                ]
            else:
                messages = [
                    {"role": "system", "content": SEARCH_PROMPT},
                    {"role": "user", "content": f"{formatted_results}\n\nUser query: {query}"}
                ]

            # Send to the AI model
            response = await openai_client.chat.completions.create(
                model=model if model in ["gpt-4o", "gpt-4o-mini"] else "gpt-4o",
                messages=messages,
                temperature=0.5
            )

            reply = response.choices[0].message.content

            # Add the interaction to history
            history.append({"role": "user", "content": f"Search query: {query}"})
            history.append({"role": "assistant", "content": reply})
            await db_handler.save_history(user_id, history)

            # Check if the reply exceeds Discord's character limit (2000)
            if len(reply) > 2000:
                # Create a text file with the full response
                file_bytes = io.BytesIO(reply.encode('utf-8'))
                file = discord.File(file_bytes, filename="search_response.txt")

                # Send a short message with the file attachment
                await interaction.followup.send(
                    f"The search response for '{query}' is too long for Discord ({len(reply)} characters). Here's the full response as a text file:",
                    file=file
                )
            else:
                # Send as a normal message if within limits
                await interaction.followup.send(reply)

        except Exception as e:
            error_message = f"Search error: {str(e)}"
            logging.error(error_message)
            await interaction.followup.send(f"An error occurred while searching: {str(e)}")

    await process_request(interaction, process_search, query)
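# From its use above, google_custom_search(query) is expected to return a dict
# shaped roughly like the following (scraped_content being optional per result;
# the values here are hypothetical):
#
#     {"results": [{"title": "...", "link": "https://...",
#                   "snippet": "...", "scraped_content": "..."}]}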

@tree.command(name="web", description="Scrape a webpage and send data to AI model.")
@app_commands.describe(url="The webpage URL to scrape")
@check_blacklist()
async def web(interaction: discord.Interaction, url: str):
    """Scrapes a webpage and sends data to the AI model."""
    await interaction.response.defer(thinking=True)

    async def process_web(interaction: discord.Interaction, url: str):
        user_id = interaction.user.id
        model = await db_handler.get_user_model(user_id) or "gpt-4o"
        history = await db_handler.get_history(user_id)

        try:
            content = scrape_web_content(url)
            if content.startswith("Failed"):
                await interaction.followup.send(content)
                return

            from src.config.config import WEB_SCRAPING_PROMPT

            if model in ["o1-mini", "o1-preview"]:
                messages = [
                    {"role": "user", "content": f"Instructions: {WEB_SCRAPING_PROMPT}\n\nContent from {url}:\n{content}"}
                ]
            else:
                messages = [
                    {"role": "system", "content": WEB_SCRAPING_PROMPT},
                    {"role": "user", "content": f"Content from {url}:\n{content}"}
                ]

            response = await openai_client.chat.completions.create(
                model=model if model in ["gpt-4o", "gpt-4o-mini"] else "gpt-4o",
                messages=messages,
                temperature=0.3,
                top_p=0.7
            )

            reply = response.choices[0].message.content

            # Add the interaction to history
            history.append({"role": "user", "content": f"Scraped content from {url}"})
            history.append({"role": "assistant", "content": reply})
            await db_handler.save_history(user_id, history)

            # Check if the reply exceeds Discord's character limit (2000)
            if len(reply) > 2000:
                # Create a text file with the full response
                file_bytes = io.BytesIO(reply.encode('utf-8'))
                file = discord.File(file_bytes, filename="web_response.txt")

                # Send a short message with the file attachment
                await interaction.followup.send(
                    f"The response from analyzing {url} is too long for Discord ({len(reply)} characters). Here's the full response as a text file:",
                    file=file
                )
            else:
                # Send as a normal message if within limits
                await interaction.followup.send(reply)

        except Exception as e:
            await interaction.followup.send(f"Error: {str(e)}", ephemeral=True)

    await process_request(interaction, process_web, url)

@tree.command(name='generate', description='Generates an image from a text prompt.')
@app_commands.describe(prompt='The prompt for image generation')
@check_blacklist()
async def generate_image_command(interaction: discord.Interaction, prompt: str):
    """Generates an image from a text prompt."""
    await interaction.response.defer(thinking=True)  # Indicate that the bot is processing

    async def process_image_generation(interaction: discord.Interaction, prompt: str):
        try:
            # Generate four images for the prompt
            result = await image_generator.generate_image(prompt, 4)

            if not result['success']:
                await interaction.followup.send(f"Error: {result.get('error', 'Unknown error')}")
                return

            # Send the images as attachments
            if result["binary_images"]:
                await interaction.followup.send(
                    f"Generated {len(result['binary_images'])} images for prompt: \"{prompt}\"",
                    files=[discord.File(io.BytesIO(img), filename=f"image_{i}.png")
                           for i, img in enumerate(result["binary_images"])]
                )
            else:
                await interaction.followup.send("No images were generated.")

        except Exception as e:
            error_message = f"An error occurred: {str(e)}"
            logging.error(f"Error in generate_image_command: {error_message}")
            await interaction.followup.send(error_message)

    await process_request(interaction, process_image_generation, prompt)

@tree.command(name="reset", description="Reset the bot by clearing user data.")
@check_blacklist()
async def reset(interaction: discord.Interaction):
    """Resets the bot by clearing the user's conversation history."""
    user_id = interaction.user.id
    await db_handler.save_history(user_id, [])
    await interaction.response.send_message("Your conversation history has been cleared and reset!", ephemeral=True)

@tree.command(name="user_stat", description="Get your current input tokens, output tokens, and model.")
@check_blacklist()
async def user_stat(interaction: discord.Interaction):
    """Fetches and displays the current input tokens, output tokens, and model for the user."""
    await interaction.response.defer(thinking=True, ephemeral=True)

    async def process_user_stat(interaction: discord.Interaction):
        import tiktoken

        user_id = interaction.user.id
        history = await db_handler.get_history(user_id)
        model = await db_handler.get_user_model(user_id) or "gpt-4o"  # Default model

        # Adjust the model for encoding purposes (tiktoken has no dedicated
        # encodings for the o-series models, so fall back to gpt-4o's encoding)
        if model in ["gpt-4o", "o1", "o1-preview", "o1-mini", "o3-mini"]:
            encoding_model = "gpt-4o"
        else:
            encoding_model = model

        # Retrieve the appropriate encoding for the selected model
        encoding = tiktoken.encoding_for_model(encoding_model)

        # Initialize the token counts
        input_tokens = 0
        output_tokens = 0

        # Calculate the input and output tokens
        if history:
            for item in history:
                content = item.get('content', '')

                # Handle the case where content is a list of parts rather than a string
                if isinstance(content, list):
                    content_str = ""
                    for part in content:
                        if isinstance(part, dict) and 'text' in part:
                            content_str += part['text'] + " "
                    content = content_str

                # Ensure content is a string before processing
                if isinstance(content, str):
                    tokens = len(encoding.encode(content))
                    if item.get('role') == 'user':
                        input_tokens += tokens
                    elif item.get('role') == 'assistant':
                        output_tokens += tokens

        # Create the statistics message
        stat_message = (
            f"**User Statistics:**\n"
            f"Model: `{model}`\n"
            f"Input Tokens: `{input_tokens}`\n"
            f"Output Tokens: `{output_tokens}`\n"
        )

        # Send the response
        await interaction.followup.send(stat_message, ephemeral=True)

    await process_request(interaction, process_user_stat)
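# A minimal sketch of the token counting done above: tiktoken resolves a model
# name to its encoding, and encode() returns the list of token ids.
#
#     import tiktoken
#     enc = tiktoken.encoding_for_model("gpt-4o")
#     n_tokens = len(enc.encode("hello world"))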

@tree.command(name="help", description="Display a list of available commands.")
@check_blacklist()
async def help_command(interaction: discord.Interaction):
    """Sends a list of available commands to the user."""
    help_message = (
        "**Available commands:**\n"
        "/choose_model - Select which AI model to use for responses (gpt-4o, gpt-4o-mini, o1-preview, o1-mini).\n"
        "/search `<query>` - Search Google and send results to the AI model.\n"
        "/web `<url>` - Scrape a webpage and send the data to the AI model.\n"
        "/generate `<prompt>` - Generate an image from a text prompt.\n"
        "/reset - Reset your chat history.\n"
        "/user_stat - Get information about your input tokens, output tokens, and current model.\n"
        "/help - Display this help message.\n"
    )
    await interaction.response.send_message(help_message, ephemeral=True)

@tree.command(name="stop", description="Stop any process or queue of the user. Admins can stop other users' tasks by providing their ID.")
@app_commands.describe(user_id="The Discord user ID to stop tasks for (admin only)")
@check_blacklist()
async def stop(interaction: discord.Interaction, user_id: str = None):
    """Stops any process or queue of the user. Admins can stop other users' tasks by providing their ID."""
    # Defer the interaction first
    await interaction.response.defer(ephemeral=True)

    if user_id and not await db_handler.is_admin(interaction.user.id):
        await interaction.followup.send("You don't have permission to stop other users' tasks.", ephemeral=True)
        return

    try:
        target_user_id = int(user_id) if user_id else interaction.user.id
    except ValueError:
        await interaction.followup.send("Invalid user ID. Please provide a valid Discord user ID.", ephemeral=True)
        return
    await stop_user_tasks(target_user_id)
    await interaction.followup.send(f"Stopped all tasks for user {target_user_id}.", ephemeral=True)

# Admin commands
@tree.command(name="whitelist_add", description="Add a user to the PDF processing whitelist")
@app_commands.describe(user_id="The Discord user ID to whitelist")
async def whitelist_add(interaction: discord.Interaction, user_id: str):
    """Adds a user to the PDF processing whitelist."""
    if not await db_handler.is_admin(interaction.user.id):
        await interaction.response.send_message("You don't have permission to use this command. Only admins can use whitelist commands.", ephemeral=True)
        return

    try:
        user_id = int(user_id)
        if await db_handler.is_admin(user_id):
            await interaction.response.send_message("Admins are automatically whitelisted and don't need to be added.", ephemeral=True)
            return
        await db_handler.add_user_to_whitelist(user_id)
        await interaction.response.send_message(f"User {user_id} has been added to the PDF processing whitelist.", ephemeral=True)
    except ValueError:
        await interaction.response.send_message("Invalid user ID. Please provide a valid Discord user ID.", ephemeral=True)

@tree.command(name="whitelist_remove", description="Remove a user from the PDF processing whitelist")
@app_commands.describe(user_id="The Discord user ID to remove from whitelist")
async def whitelist_remove(interaction: discord.Interaction, user_id: str):
    """Removes a user from the PDF processing whitelist."""
    if not await db_handler.is_admin(interaction.user.id):
        await interaction.response.send_message("You don't have permission to use this command. Only admins can use whitelist commands.", ephemeral=True)
        return

    try:
        user_id = int(user_id)
        if await db_handler.remove_user_from_whitelist(user_id):
            await interaction.response.send_message(f"User {user_id} has been removed from the PDF processing whitelist.", ephemeral=True)
        else:
            await interaction.response.send_message(f"User {user_id} was not found in the whitelist.", ephemeral=True)
    except ValueError:
        await interaction.response.send_message("Invalid user ID. Please provide a valid Discord user ID.", ephemeral=True)

@tree.command(name="blacklist_add", description="Add a user to the bot blacklist")
@app_commands.describe(user_id="The Discord user ID to blacklist")
async def blacklist_add(interaction: discord.Interaction, user_id: str):
    """Adds a user to the bot blacklist."""
    if not await db_handler.is_admin(interaction.user.id):
        await interaction.response.send_message("You don't have permission to use this command. Only admins can use blacklist commands.", ephemeral=True)
        return

    try:
        user_id = int(user_id)
        if await db_handler.is_admin(user_id):
            await interaction.response.send_message("Cannot blacklist an admin.", ephemeral=True)
            return
        await db_handler.add_user_to_blacklist(user_id)
        await interaction.response.send_message(f"User {user_id} has been added to the bot blacklist. They can no longer use any bot features.", ephemeral=True)
    except ValueError:
        await interaction.response.send_message("Invalid user ID. Please provide a valid Discord user ID.", ephemeral=True)

@tree.command(name="blacklist_remove", description="Remove a user from the bot blacklist")
@app_commands.describe(user_id="The Discord user ID to remove from blacklist")
async def blacklist_remove(interaction: discord.Interaction, user_id: str):
    """Removes a user from the bot blacklist."""
    if not await db_handler.is_admin(interaction.user.id):
        await interaction.response.send_message("You don't have permission to use this command. Only admins can use blacklist commands.", ephemeral=True)
        return

    try:
        user_id = int(user_id)
        if await db_handler.remove_user_from_blacklist(user_id):
            await interaction.response.send_message(f"User {user_id} has been removed from the bot blacklist. They can now use bot features again.", ephemeral=True)
        else:
            await interaction.response.send_message(f"User {user_id} was not found in the blacklist.", ephemeral=True)
    except ValueError:
        await interaction.response.send_message("Invalid user ID. Please provide a valid Discord user ID.", ephemeral=True)

# Helper function to stop user tasks
async def stop_user_tasks(user_id: int):
    """Stop all tasks for a specific user."""
    if user_id in user_tasks:
        for task in user_tasks[user_id]:
            task.cancel()
        user_tasks[user_id] = []

    # Clear any queued requests
    if user_id in user_requests:
        while not user_requests[user_id]['queue'].empty():
            try:
                user_requests[user_id]['queue'].get_nowait()
            except asyncio.QueueEmpty:
                break
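# Cancellation sketch (user_tasks is assumed to be a module-level dict mapping
# user id -> list of asyncio.Task; workers would need to register themselves):
#
#     task = asyncio.create_task(my_worker(interaction, arg))
#     user_tasks.setdefault(interaction.user.id, []).append(task)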
@@ -1,265 +1,265 @@
from motor.motor_asyncio import AsyncIOMotorClient
from typing import List, Dict, Any, Optional
import functools
import asyncio
from datetime import datetime, timedelta
import logging
import re

class DatabaseHandler:
    # Class-level cache for database results
    _cache = {}
    _cache_expiry = {}
    _cache_lock = asyncio.Lock()

    def __init__(self, mongodb_uri: str):
        """Initialize the database connection with optimized settings"""
        # Set up a connection pool with sensible timeouts
        self.client = AsyncIOMotorClient(
            mongodb_uri,
            maxPoolSize=50,
            minPoolSize=10,
            maxIdleTimeMS=45000,
            connectTimeoutMS=2000,
            serverSelectionTimeoutMS=3000,
            waitQueueTimeoutMS=1000,
            retryWrites=True
        )
        self.db = self.client['chatgpt_discord_bot']  # Database name

        # Collections
        self.users_collection = self.db.users
        self.history_collection = self.db.history
        self.admin_collection = self.db.admin
        self.blacklist_collection = self.db.blacklist
        self.whitelist_collection = self.db.whitelist
        self.logs_collection = self.db.logs
        self.reminders_collection = self.db.reminders

        logging.info("Database handler initialized")

    # Helper for caching results
    async def _get_cached_result(self, cache_key, fetch_func, expiry_seconds=60):
        """Get a result from the cache, or execute fetch_func if not cached or expired"""
        current_time = datetime.now()

        # Check whether we have a cached result that's still valid
        async with self._cache_lock:
            if (cache_key in self._cache and
                    cache_key in self._cache_expiry and
                    current_time < self._cache_expiry[cache_key]):
                return self._cache[cache_key]

        # Not in the cache or expired, so fetch a fresh result
        result = await fetch_func()

        # Cache the new result
        async with self._cache_lock:
            self._cache[cache_key] = result
            self._cache_expiry[cache_key] = current_time + timedelta(seconds=expiry_seconds)

        return result
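    # Usage sketch: every cached read below follows the same pattern, pairing a
    # key with an async fetch closure, along the lines of:
    #
    #     async def fetch():
    #         return await self.db.some_collection.find_one({'user_id': user_id})
    #     value = await self._get_cached_result(f"some_{user_id}", fetch, 60)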

    # User history methods
    async def get_history(self, user_id: int) -> List[Dict[str, Any]]:
        """Get user conversation history with caching, filtering out expired image links"""
        cache_key = f"history_{user_id}"

        async def fetch_history():
            user_data = await self.db.user_histories.find_one({'user_id': user_id})
            if user_data and 'history' in user_data:
                # Filter out expired image links
                return self._filter_expired_images(user_data['history'])
            return []

        return await self._get_cached_result(cache_key, fetch_history, 30)  # 30-second cache

    def _filter_expired_images(self, history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Filter out image links that are older than 23 hours"""
        current_time = datetime.now()
        expiration_time = current_time - timedelta(hours=23)

        filtered_history = []
        for msg in history:
            # Keep system messages unchanged
            if msg.get('role') == 'system':
                filtered_history.append(msg)
                continue

            # Check whether the message's 'content' field is a list (which may contain image URLs)
            content = msg.get('content')
            if isinstance(content, list):
                # Filter the content items
                filtered_content = []
                for item in content:
                    # Keep text items
                    if item.get('type') == 'text':
                        filtered_content.append(item)
                    # Check image items for a timestamp
                    elif item.get('type') == 'image_url':
                        # If there's no timestamp, or the timestamp is newer than the expiration time, keep it
                        timestamp = item.get('timestamp')
                        if not timestamp or datetime.fromisoformat(timestamp) > expiration_time:
                            filtered_content.append(item)
                        else:
                            logging.info(f"Filtering out expired image URL (added at {timestamp})")

                # Update the message with the filtered content
                if filtered_content:
                    new_msg = dict(msg)
                    new_msg['content'] = filtered_content
                    filtered_history.append(new_msg)
                else:
                    # If no content survives the filtering, add a placeholder text
                    new_msg = dict(msg)
                    new_msg['content'] = [{"type": "text", "text": "[Image content expired]"}]
                    filtered_history.append(new_msg)
            else:
                # For string content or other formats, keep as-is
                filtered_history.append(msg)

        return filtered_history
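    # The filter above assumes multimodal history items shaped roughly like this
    # (the ISO "timestamp" field is added when the image is stored; the values
    # here are hypothetical):
    #
    #     {"role": "user", "content": [
    #         {"type": "text", "text": "What is in this picture?"},
    #         {"type": "image_url", "image_url": {"url": "https://..."},
    #          "timestamp": "2025-01-01T12:00:00"}]}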

    async def save_history(self, user_id: int, history: List[Dict[str, Any]]) -> None:
        """Save user conversation history and update the cache"""
        await self.db.user_histories.update_one(
            {'user_id': user_id},
            {'$set': {'history': history}},
            upsert=True
        )

        # Update the cache
        cache_key = f"history_{user_id}"
        async with self._cache_lock:
            self._cache[cache_key] = history
            self._cache_expiry[cache_key] = datetime.now() + timedelta(seconds=30)

    # User model preferences with caching
    async def get_user_model(self, user_id: int) -> Optional[str]:
        """Get the user's preferred model with caching"""
        cache_key = f"model_{user_id}"

        async def fetch_model():
            user_data = await self.db.user_models.find_one({'user_id': user_id})
            return user_data['model'] if user_data else None

        return await self._get_cached_result(cache_key, fetch_model, 300)  # 5-minute cache

    async def save_user_model(self, user_id: int, model: str) -> None:
        """Save the user's preferred model and update the cache"""
        await self.db.user_models.update_one(
            {'user_id': user_id},
            {'$set': {'model': model}},
            upsert=True
        )

        # Update the cache
        cache_key = f"model_{user_id}"
        async with self._cache_lock:
            self._cache[cache_key] = model
            self._cache_expiry[cache_key] = datetime.now() + timedelta(seconds=300)

    # Admin and permissions management with caching
    async def is_admin(self, user_id: int) -> bool:
        """Check if the user is an admin (no caching, for security)"""
        admin_id = str(user_id)  # Convert to string for comparison
        from src.config.config import ADMIN_ID
        return admin_id == ADMIN_ID

    async def is_user_whitelisted(self, user_id: int) -> bool:
        """Check if the user is whitelisted, with caching"""
        if await self.is_admin(user_id):
            return True

        cache_key = f"whitelist_{user_id}"

        async def check_whitelist():
            user_data = await self.db.whitelist.find_one({'user_id': user_id})
            return user_data is not None

        return await self._get_cached_result(cache_key, check_whitelist, 300)  # 5-minute cache

    async def add_user_to_whitelist(self, user_id: int) -> None:
        """Add a user to the whitelist and update the cache"""
        await self.db.whitelist.update_one(
            {'user_id': user_id},
            {'$set': {'user_id': user_id}},
            upsert=True
        )

        # Update the cache
        cache_key = f"whitelist_{user_id}"
        async with self._cache_lock:
            self._cache[cache_key] = True
            self._cache_expiry[cache_key] = datetime.now() + timedelta(seconds=300)

    async def remove_user_from_whitelist(self, user_id: int) -> bool:
        """Remove a user from the whitelist and update the cache"""
        result = await self.db.whitelist.delete_one({'user_id': user_id})

        # Update the cache
        cache_key = f"whitelist_{user_id}"
        async with self._cache_lock:
            self._cache[cache_key] = False
            self._cache_expiry[cache_key] = datetime.now() + timedelta(seconds=300)

        return result.deleted_count > 0

    async def is_user_blacklisted(self, user_id: int) -> bool:
        """Check if the user is blacklisted, with caching"""
        cache_key = f"blacklist_{user_id}"

        async def check_blacklist():
            user_data = await self.db.blacklist.find_one({'user_id': user_id})
            return user_data is not None

        return await self._get_cached_result(cache_key, check_blacklist, 300)  # 5-minute cache

    async def add_user_to_blacklist(self, user_id: int) -> None:
        """Add a user to the blacklist and update the cache"""
        await self.db.blacklist.update_one(
            {'user_id': user_id},
            {'$set': {'user_id': user_id}},
            upsert=True
        )

        # Update the cache
        cache_key = f"blacklist_{user_id}"
        async with self._cache_lock:
            self._cache[cache_key] = True
            self._cache_expiry[cache_key] = datetime.now() + timedelta(seconds=300)

    async def remove_user_from_blacklist(self, user_id: int) -> bool:
        """Remove a user from the blacklist and update the cache"""
        result = await self.db.blacklist.delete_one({'user_id': user_id})

        # Update the cache
        cache_key = f"blacklist_{user_id}"
        async with self._cache_lock:
            self._cache[cache_key] = False
            self._cache_expiry[cache_key] = datetime.now() + timedelta(seconds=300)

        return result.deleted_count > 0

    # Connection management and cleanup
    async def create_indexes(self):
        """Create indexes for better query performance"""
        await self.db.user_histories.create_index("user_id")
        await self.db.user_models.create_index("user_id")
        await self.db.whitelist.create_index("user_id")
        await self.db.blacklist.create_index("user_id")

    async def ensure_reminders_collection(self):
        """Ensure the reminders collection exists and create the necessary indexes"""
        # Creating an index also creates the collection if it doesn't exist yet
        await self.reminders_collection.create_index([("user_id", 1), ("sent", 1)])
        await self.reminders_collection.create_index([("remind_at", 1), ("sent", 1)])
        logging.info("Ensured reminders collection and indexes")

    async def close(self):
        """Properly close the database connection"""
        self.client.close()
        logging.info("Database connection closed")
File diff suppressed because it is too large
Binary file not shown.
File diff suppressed because it is too large
@@ -1,291 +1,532 @@

import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import io
import logging
import asyncio
import functools
import os
import numpy as np
from datetime import datetime
from typing import Tuple, Dict, Any, Optional, List

# Ensure matplotlib doesn't require a GUI backend
matplotlib.use('Agg')

async def process_data_file(file_bytes: bytes, filename: str, query: str) -> Tuple[str, Optional[bytes], Optional[Dict[str, Any]]]:
    """
    Analyze and visualize data from CSV/Excel files.

    Args:
        file_bytes: File content as bytes
        filename: File name
        query: User command/query

    Returns:
        Tuple containing a text summary, image bytes (if any) and metadata
    """
    try:
        # Use a thread pool to avoid blocking the event loop with CPU-bound work
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(
            None,
            functools.partial(_process_data_file_sync, file_bytes, filename, query)
        )
    except Exception as e:
        logging.error(f"Error processing data file: {str(e)}")
        return f"Error processing file {filename}: {str(e)}", None, None

def _process_data_file_sync(file_bytes: bytes, filename: str, query: str) -> Tuple[str, Optional[bytes], Optional[Dict[str, Any]]]:
    """Synchronous version of process_data_file to run in a thread pool"""
    file_obj = io.BytesIO(file_bytes)

    try:
        # Read the file based on its format, with improved error handling
        if filename.lower().endswith('.csv'):
            try:
                # Try UTF-8 first
                df = pd.read_csv(file_obj, encoding='utf-8')
            except UnicodeDecodeError:
                # Reset the file pointer and try a different encoding
                file_obj.seek(0)
                df = pd.read_csv(file_obj, encoding='latin1')
        elif filename.lower().endswith(('.xlsx', '.xls')):
            try:
                df = pd.read_excel(file_obj)
            except Exception as excel_err:
                logging.error(f"Excel read error: {excel_err}")
                # Retry with an explicit engine
                file_obj.seek(0)
                df = pd.read_excel(file_obj, engine='openpyxl')
        else:
            return "Unsupported file format. Please use CSV or Excel.", None, None

        if df.empty:
            return "The file does not contain any data.", None, None

        # Clean the column names
        df.columns = [str(col).strip() for col in df.columns]

        # Create metadata for the dataframe
        rows = len(df)
        columns = len(df.columns)
        column_names = list(df.columns)

        # Data preprocessing for better analysis:
        # convert columns that might be dates but are stored as strings
        for col in df.columns:
            if df[col].dtype == 'object':
                try:
                    # Check whether the column might contain dates
                    sample = df[col].dropna().iloc[0] if not df[col].dropna().empty else None
                    if sample and isinstance(sample, str) and ('/' in sample or '-' in sample):
                        df[col] = pd.to_datetime(df[col], errors='ignore')
                except Exception:
                    pass  # Skip if the conversion fails

        # Categorize the columns for analysis
        numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
        categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
        date_cols = df.select_dtypes(include=['datetime']).columns.tolist()

        # Basic information about the data
        summary = f"Data analysis for {filename}:\n"
        summary += f"- Rows: {rows}\n"
        summary += f"- Columns: {columns}\n"
        summary += f"- Column names: {', '.join(column_names)}\n\n"

        # Basic descriptive statistics
        if len(numeric_cols) > 0:
            summary += "Statistics for numeric data:\n"
            desc_stats = df[numeric_cols].describe().round(2)
            summary += desc_stats.to_string() + "\n\n"

        # Statistics for categorical data
        if len(categorical_cols) > 0:
            summary += "Value distribution for categorical columns:\n"
            for col in categorical_cols[:3]:  # Limit the displayed columns
                value_counts = df[col].value_counts().head(5)
                summary += f"{col}: {dict(value_counts)}\n"
            if len(categorical_cols) > 3:
                summary += f"...and {len(categorical_cols) - 3} other categorical columns.\n"
            summary += "\n"

        # Determine whether a chart should be created
        chart_keywords = ["chart", "graph", "plot", "visualization", "visualize",
                          "histogram", "bar chart", "line chart", "pie chart", "scatter"]
        create_chart = any(keyword in query.lower() for keyword in chart_keywords)

        # Metadata to return
        metadata = {
            "filename": filename,
            "rows": rows,
            "columns": columns,
            "column_names": column_names,
            "numeric_columns": numeric_cols,
            "categorical_columns": categorical_cols,
            "date_columns": date_cols
        }

        # Create a chart if requested, or by default for short queries
        chart_image = None
        if (create_chart or len(query) < 10) and len(numeric_cols) > 0:
            plt.figure(figsize=(10, 6))

            # Better chart styling
            plt.style.use('seaborn-v0_8')

            # Determine the chart type based on keywords in the query
            if any(keyword in query.lower() for keyword in ["pie", "circle"]):
                # Pie chart - works best with categorical data
                if len(categorical_cols) > 0 and len(numeric_cols) > 0:
                    cat_col = categorical_cols[0]
                    num_col = numeric_cols[0]
                    # Limit the pie to the top 5 categories
                    top_categories = df.groupby(cat_col)[num_col].sum().nlargest(5)
                    # Add an "Other" slice if there are more than 5 categories
                    if len(df[cat_col].unique()) > 5:
                        other_sum = df.groupby(cat_col)[num_col].sum().sum() - top_categories.sum()
                        if other_sum > 0:
                            top_categories["Other"] = other_sum

                    plt.figure(figsize=(10, 7))
                    plt.pie(top_categories, labels=top_categories.index, autopct='%1.1f%%',
                            shadow=True, startangle=90)
                    plt.axis('equal')  # Equal aspect ratio ensures the pie is drawn as a circle
                    plt.title(f"Pie Chart: {num_col} by {cat_col}")
            elif any(keyword in query.lower() for keyword in ["bar", "column"]):
                # Bar chart
                if len(categorical_cols) > 0 and len(numeric_cols) > 0:
                    cat_col = categorical_cols[0]
                    num_col = numeric_cols[0]
                    # Sort for better visualization
                    top_values = df.groupby(cat_col)[num_col].sum().nlargest(10)
                    top_values.plot.bar(color='skyblue', edgecolor='black')
                    plt.grid(axis='y', linestyle='--', alpha=0.7)
                    plt.title(f"Bar Chart: {num_col} by {cat_col} (Top 10)")
                    plt.xticks(rotation=45, ha='right')
                    plt.tight_layout()
                else:
                    df[numeric_cols[0]].nlargest(10).plot.bar(color='skyblue', edgecolor='black')
                    plt.title(f"Top 10 highest values of {numeric_cols[0]}")
                    plt.grid(axis='y', linestyle='--', alpha=0.7)
                    plt.tight_layout()
            elif any(keyword in query.lower() for keyword in ["scatter", "dispersion"]):
                # Scatter plot
                if len(numeric_cols) >= 2:
                    plt.figure(figsize=(9, 6))
                    plt.scatter(df[numeric_cols[0]], df[numeric_cols[1]], alpha=0.6,
                                edgecolor='w', s=50)
                    plt.xlabel(numeric_cols[0])
                    plt.ylabel(numeric_cols[1])
                    plt.title(f"Scatter plot: {numeric_cols[0]} vs {numeric_cols[1]}")
                    # Add a trend line if there seems to be a correlation
                    if abs(df[numeric_cols[0]].corr(df[numeric_cols[1]])) > 0.3:
                        z = np.polyfit(df[numeric_cols[0]].dropna(), df[numeric_cols[1]].dropna(), 1)
                        p = np.poly1d(z)
                        plt.plot(df[numeric_cols[0]].sort_values(),
                                 p(df[numeric_cols[0]].sort_values()),
                                 "r--", linewidth=1)
                    plt.grid(True, alpha=0.3)
            elif any(keyword in query.lower() for keyword in ["histogram", "hist", "distribution"]):
                # Histogram with better binning
                plt.figure(figsize=(10, 6))
                # Calculate the number of bins using Sturges' rule: bins = ceil(log2(n)) + 1
                data = df[numeric_cols[0]].dropna()
                bins = int(np.ceil(np.log2(len(data))) + 1) if len(data) > 0 else 10

                plt.hist(data, bins=min(bins, 30), color='skyblue', edgecolor='black')
                plt.title(f"Distribution of {numeric_cols[0]}")
                plt.xlabel(numeric_cols[0])
                plt.ylabel("Frequency")
                plt.grid(axis='y', linestyle='--', alpha=0.7)
            elif len(date_cols) > 0 and len(numeric_cols) > 0:
                # Time series chart if we have dates and numeric data
                plt.figure(figsize=(12, 6))
                date_col = date_cols[0]
                num_col = numeric_cols[0]

                # Sort by date and plot
                temp_df = df[[date_col, num_col]].dropna().sort_values(date_col)
                plt.plot(temp_df[date_col], temp_df[num_col], marker='o', markersize=3,
                         linestyle='-', linewidth=1)
                plt.title(f"{num_col} over time")
                plt.xticks(rotation=45)
                plt.grid(True, alpha=0.3)
                plt.tight_layout()
            else:
                # Default: line chart if we have numeric values
                if len(numeric_cols) > 0:
                    plt.figure(figsize=(10, 6))
                    df[numeric_cols[0]].plot(color='#1f77b4', alpha=0.8)
                    plt.title(f"Line chart for {numeric_cols[0]}")
                    plt.grid(True, alpha=0.3)
                    plt.tight_layout()

            # Save the chart to a bytes buffer
            buf = io.BytesIO()
            plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
            buf.seek(0)
            chart_image = buf.read()
            plt.close()

            # Create a timestamp for the chart file
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            chart_filename = f"chart_{timestamp}.png"

            # Save the chart to a temporary file
            chart_dir = os.path.join(os.getcwd(), "temp_charts")
            if not os.path.exists(chart_dir):
                os.makedirs(chart_dir)

            chart_path = os.path.join(chart_dir, chart_filename)
            with open(chart_path, "wb") as f:
                f.write(chart_image)

            summary += "Chart created based on the data with improved visualization."

            # Add the chart details to the metadata
            metadata["chart_filename"] = chart_filename
            metadata["chart_path"] = chart_path
            metadata["chart_created_at"] = datetime.now().timestamp()

        return summary, chart_image, metadata

    except Exception as e:
        logging.error(f"Error in _process_data_file_sync: {str(e)}")
        return f"Could not analyze file {filename}. Error: {str(e)}", None, None

async def cleanup_old_charts(max_age_hours=1):
    """
    Clean up chart images older than the specified time

    Args:
        max_age_hours: Maximum age in hours before charts are deleted
    """
    try:
        chart_dir = os.path.join(os.getcwd(), "temp_charts")
        if not os.path.exists(chart_dir):
            return

        now = datetime.now().timestamp()
        deleted_count = 0

        for filename in os.listdir(chart_dir):
            if filename.startswith("chart_") and filename.endswith(".png"):
                file_path = os.path.join(chart_dir, filename)
                file_modified_time = os.path.getmtime(file_path)

                # If the file is older than max_age_hours, delete it
                if now - file_modified_time > (max_age_hours * 3600):
                    try:
                        os.remove(file_path)
                        deleted_count += 1
                    except Exception as e:
                        logging.error(f"Error deleting chart file {filename}: {str(e)}")

        if deleted_count > 0:
            logging.info(f"Cleaned up {deleted_count} chart files older than {max_age_hours} hours")
    except Exception as e:
        logging.error(f"Error cleaning up chart files: {str(e)}")
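# Scheduling sketch (illustrative; the actual wiring lives in the bot startup
# code, which isn't shown in this diff): the cleanup coroutine lends itself to
# a periodic background task, e.g.
#
#     async def chart_cleanup_loop():
#         while True:
#             await cleanup_old_charts(max_age_hours=1)
#             await asyncio.sleep(1800)  # every 30 minutes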
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import io
import logging
import asyncio
import functools
import os
import numpy as np
from datetime import datetime
from typing import Tuple, Dict, Any, Optional, List

# Ensure matplotlib doesn't require a GUI backend
matplotlib.use('Agg')

# Set global matplotlib parameters for better readability
plt.rcParams.update({
    'font.size': 12,
    'axes.titlesize': 16,
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
    'figure.titlesize': 18
})

async def process_data_file(file_bytes: bytes, filename: str, query: str) -> Tuple[str, Optional[bytes], Optional[Dict[str, Any]]]:
    """
    Analyze and visualize data from CSV/Excel files.

    Args:
        file_bytes: File content as bytes
        filename: File name
        query: User command/query

    Returns:
        Tuple containing a text summary, image bytes (if any) and metadata
    """
    try:
        # Use a thread pool to avoid blocking the event loop with CPU-bound work
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(
            None,
            functools.partial(_process_data_file_sync, file_bytes, filename, query)
        )
    except Exception as e:
        logging.error(f"Error processing data file: {str(e)}")
        return f"Error processing file {filename}: {str(e)}", None, None
|
||||
|
||||
def _process_data_file_sync(file_bytes: bytes, filename: str, query: str) -> Tuple[str, Optional[bytes], Optional[Dict[str, Any]]]:
|
||||
"""Synchronous version of process_data_file to run in thread pool"""
|
||||
file_obj = io.BytesIO(file_bytes)
|
||||
|
||||
try:
|
||||
# Read file based on format with improved error handling
|
||||
if filename.lower().endswith('.csv'):
|
||||
try:
|
||||
# Try multiple encodings and separator detection
|
||||
df = pd.read_csv(file_obj, encoding='utf-8')
|
||||
except UnicodeDecodeError:
|
||||
# Reset file pointer and try different encoding
|
||||
file_obj.seek(0)
|
||||
df = pd.read_csv(file_obj, encoding='latin1')
|
||||
elif filename.lower().endswith(('.xlsx', '.xls')):
|
||||
try:
|
||||
df = pd.read_excel(file_obj)
|
||||
except Exception as excel_err:
|
||||
logging.error(f"Excel read error: {excel_err}")
|
||||
# Try with engine specification
|
||||
file_obj.seek(0)
|
||||
df = pd.read_excel(file_obj, engine='openpyxl')
|
||||
else:
|
||||
return "Unsupported file format. Please use CSV or Excel.", None, None
|
||||
|
||||
if df.empty:
|
||||
return "The file does not contain any data.", None, None
|
||||
|
||||
# Clean column names
|
||||
df.columns = [str(col).strip() for col in df.columns]
|
||||
|
||||
# Create metadata for dataframe
|
||||
rows = len(df)
|
||||
columns = len(df.columns)
|
||||
column_names = list(df.columns)
|
||||
|
||||
# Data preprocessing for better analysis
|
||||
# Convert potential date columns
|
||||
for col in df.columns:
|
||||
# Try to convert columns that might be dates but are stored as strings
|
||||
if df[col].dtype == 'object':
|
||||
try:
|
||||
# Check if the column might contain dates
|
||||
sample = df[col].dropna().iloc[0] if not df[col].dropna().empty else None
|
||||
if sample and isinstance(sample, str) and ('/' in sample or '-' in sample):
|
||||
df[col] = pd.to_datetime(df[col], errors='ignore')
|
||||
except Exception:
|
||||
pass # Skip if conversion fails
|
||||
|
||||
# Format data for analysis
|
||||
numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
|
||||
categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
|
||||
date_cols = df.select_dtypes(include=['datetime']).columns.tolist()
|
||||
|
||||
# Create basic information about the data
|
||||
summary = f"Data analysis for {filename}:\n"
|
||||
summary += f"- Rows: {rows}\n"
|
||||
summary += f"- Columns: {columns}\n"
|
||||
summary += f"- Column names: {', '.join(column_names)}\n\n"
|
||||
|
||||
# Basic descriptive statistics
|
||||
if len(numeric_cols) > 0:
|
||||
summary += "Statistics for numeric data:\n"
|
||||
desc_stats = df[numeric_cols].describe().round(2)
|
||||
summary += desc_stats.to_string() + "\n\n"
|
||||
|
||||
# Statistics for categorical data
|
||||
if len(categorical_cols) > 0:
|
||||
summary += "Value distribution for categorical columns:\n"
|
||||
for col in categorical_cols[:3]: # Limit displayed columns
|
||||
value_counts = df[col].value_counts().head(5)
|
||||
summary += f"{col}: {dict(value_counts)}\n"
|
||||
if len(categorical_cols) > 3:
|
||||
summary += f"...and {len(categorical_cols) - 3} other categorical columns.\n"
|
||||
summary += "\n"
|
||||
|
||||
# Determine if a chart should be created
|
||||
chart_keywords = ["chart", "graph", "plot", "visualization", "visualize",
|
||||
"histogram", "bar chart", "line chart", "pie chart", "scatter"]
|
||||
create_chart = any(keyword in query.lower() for keyword in chart_keywords)
|
||||
|
||||
# Metadata to return
|
||||
metadata = {
|
||||
"filename": filename,
|
||||
"rows": rows,
|
||||
"columns": columns,
|
||||
"column_names": column_names,
|
||||
"numeric_columns": numeric_cols,
|
||||
"categorical_columns": categorical_cols,
|
||||
"date_columns": date_cols
|
||||
}
|
||||
        # Create chart if requested or by default
        chart_image = None
        if (create_chart or len(query) < 10) and len(numeric_cols) > 0:  # Default to chart for short queries
            # Use a better visual style
            plt.style.use('ggplot')

            # Determine chart type based on keywords in the query
            if any(keyword in query.lower() for keyword in ["pie", "circle"]):
                # Pie chart - works best with categorical data
                if len(categorical_cols) > 0 and len(numeric_cols) > 0:
                    cat_col = categorical_cols[0]
                    num_col = numeric_cols[0]

                    # Better handling of pie chart data - limit to top 6 categories max
                    top_categories = df.groupby(cat_col)[num_col].sum().nlargest(6)

                    # Add "Other" category if there are more than 6 categories
                    if len(df[cat_col].unique()) > 6:
                        other_sum = df.groupby(cat_col)[num_col].sum().sum() - top_categories.sum()
                        if other_sum > 0:
                            top_categories["Other"] = other_sum

                    # Larger figure size for better readability
                    plt.figure(figsize=(12, 9))

                    # Enhanced pie chart (a wedge width below 1 renders it as a donut)
                    wedges, texts, autotexts = plt.pie(
                        top_categories,
                        labels=None,  # We'll add a legend instead of cluttering the pie
                        autopct='%1.1f%%',
                        shadow=False,
                        startangle=90,
                        explode=[0.05] * len(top_categories),  # Slight separation for visibility
                        textprops={'color': 'white', 'weight': 'bold', 'fontsize': 14},
                        wedgeprops={'width': 0.6, 'edgecolor': 'white', 'linewidth': 2}
                    )

                    # Make the percentage labels more visible
                    for autotext in autotexts:
                        autotext.set_fontsize(12)
                        autotext.set_weight('bold')

                    # Add a legend outside the pie for better readability
                    plt.legend(
                        wedges,
                        top_categories.index,
                        title=cat_col,
                        loc="center left",
                        bbox_to_anchor=(1, 0, 0.5, 1)
                    )

                    plt.axis('equal')  # Equal aspect ratio ensures that pie is drawn as a circle
                    plt.title(f"Distribution of {num_col} by {cat_col}", pad=20)
                    plt.tight_layout()

            elif any(keyword in query.lower() for keyword in ["bar", "column"]):
                # Bar chart
                if len(categorical_cols) > 0 and len(numeric_cols) > 0:
                    cat_col = categorical_cols[0]
                    num_col = numeric_cols[0]

                    # Determine optimal figure size based on number of categories
                    category_count = min(10, len(df[cat_col].unique()))
                    fig_height = max(6, category_count * 0.4 + 4)  # Dynamic height based on categories
                    plt.figure(figsize=(12, fig_height))

                    # Sort for better visualization
                    top_values = df.groupby(cat_col)[num_col].sum().nlargest(10)

                    # Use a horizontal bar chart for better label readability with many categories
                    if len(top_values) > 5:
                        ax = top_values.plot.barh(
                            color='#5975a4',
                            edgecolor='#344e7a',
                            linewidth=1.5
                        )
                        # Add data labels at the end of each bar
                        for i, v in enumerate(top_values):
                            ax.text(v * 1.01, i, f'{v:,.1f}', va='center', fontweight='bold')
                        plt.xlabel(num_col)
                        plt.ylabel(cat_col)
                    else:
                        # For fewer categories, use vertical bars
                        ax = top_values.plot.bar(
                            color='#5975a4',
                            edgecolor='#344e7a',
                            linewidth=1.5
                        )
                        # Add data labels on top of each bar
                        for i, v in enumerate(top_values):
                            ax.text(i, v * 1.01, f'{v:,.1f}', ha='center', fontweight='bold')
                        plt.ylabel(num_col)
                        plt.xlabel(cat_col)
                        plt.xticks(rotation=30, ha='right')

                    plt.grid(axis='both', linestyle='--', alpha=0.7)
                    plt.title(f"{num_col} by {cat_col} (Top {len(top_values)})", pad=20)
                    plt.tight_layout(pad=2)
                else:
                    # Improved bar chart for numeric data only
                    plt.figure(figsize=(12, 7))
                    top_values = df[numeric_cols[0]].nlargest(10)
                    ax = top_values.plot.bar(
                        color='#5975a4',
                        edgecolor='#344e7a',
                        linewidth=1.5
                    )

                    # Add value labels on top of bars
                    for i, v in enumerate(top_values):
                        ax.text(i, v * 1.01, f'{v:,.1f}', ha='center', fontweight='bold')

                    plt.title(f"Top {len(top_values)} highest values of {numeric_cols[0]}", pad=20)
                    plt.grid(axis='y', linestyle='--', alpha=0.7)
                    plt.xticks(rotation=30, ha='right')
                    plt.tight_layout(pad=2)

            elif any(keyword in query.lower() for keyword in ["scatter", "dispersion"]):
                # Enhanced scatter plot
                if len(numeric_cols) >= 2:
                    plt.figure(figsize=(12, 8))

                    # If we have a categorical column, use it for coloring
                    if len(categorical_cols) > 0:
                        # Limit to a reasonable number of categories for coloring
                        cat_col = categorical_cols[0]
                        top_cats = df[cat_col].value_counts().nlargest(8).index.tolist()

                        # Create a color map
                        colormap = plt.cm.get_cmap('tab10', len(top_cats))

                        # Plot each category with different color
                        for i, category in enumerate(top_cats):
                            subset = df[df[cat_col] == category]
                            plt.scatter(
                                subset[numeric_cols[0]],
                                subset[numeric_cols[1]],
                                alpha=0.7,
                                edgecolor='w',
                                s=80,
                                label=category,
                                color=colormap(i)
                            )
                        plt.legend(title=cat_col, loc='best')
                    else:
                        # Regular scatter plot with improved visibility
                        scatter = plt.scatter(
                            df[numeric_cols[0]],
                            df[numeric_cols[1]],
                            alpha=0.7,
                            edgecolor='w',
                            s=80,
                            c=df[numeric_cols[0]],  # Color by x-axis value for visual enhancement
                            cmap='viridis'
                        )
                        plt.colorbar(scatter, label=numeric_cols[0])

                    plt.xlabel(numeric_cols[0], fontweight='bold')
                    plt.ylabel(numeric_cols[1], fontweight='bold')
                    plt.title(f"Scatter plot: {numeric_cols[0]} vs {numeric_cols[1]}", pad=20)

                    # Add trend line if there seems to be a correlation
                    if abs(df[numeric_cols[0]].corr(df[numeric_cols[1]])) > 0.3:
                        # Drop rows pairwise so x and y stay the same length
                        # (independent dropna() calls can misalign the two series)
                        pair = df[[numeric_cols[0], numeric_cols[1]]].dropna()
                        z = np.polyfit(pair[numeric_cols[0]], pair[numeric_cols[1]], 1)
                        p = np.poly1d(z)
                        x_sorted = pair[numeric_cols[0]].sort_values()
                        plt.plot(
                            x_sorted,
                            p(x_sorted),
                            "r--",
                            linewidth=2,
                            label=f"Trend line (r={df[numeric_cols[0]].corr(df[numeric_cols[1]]):.2f})"
                        )
                        plt.legend()

                    plt.grid(True, alpha=0.3)
                    plt.tight_layout(pad=2)

            elif any(keyword in query.lower() for keyword in ["histogram", "hist", "distribution"]):
                # Enhanced histogram
                plt.figure(figsize=(12, 7))
                # Calculate optimal number of bins
                data = df[numeric_cols[0]].dropna()

                # Better bin calculation based on data distribution
                iqr = np.percentile(data, 75) - np.percentile(data, 25)
                bin_width = 2 * iqr / (len(data) ** (1/3))  # Freedman-Diaconis rule
                if bin_width > 0:
                    bins = int((data.max() - data.min()) / bin_width)
                    bins = min(max(bins, 10), 50)  # Between 10 and 50 bins
                else:
                    bins = 15  # Default if calculation fails
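                # Freedman-Diaconis, worked on hypothetical numbers: with
                # n = 1000 points and IQR = 8.0, bin_width = 2 * 8.0 / 1000**(1/3) = 1.6,
                # so a data range of 0..48 gives int(48 / 1.6) = 30 bins, which the
                # 10..50 clamp above leaves unchanged.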
                # Plot histogram with KDE
                ax = plt.subplot(111)
                n, bins_arr, patches = ax.hist(
                    data,
                    bins=bins,
                    alpha=0.7,
                    color='#5975a4',
                    edgecolor='#344e7a',
                    linewidth=1.5,
                    density=True  # Normalize for KDE overlay
                )

                # Add KDE line for smoother visualization
                from scipy import stats
                kde_x = np.linspace(data.min(), data.max(), 1000)
                kde = stats.gaussian_kde(data)
                ax.plot(kde_x, kde(kde_x), 'r-', linewidth=2, label='Density')

                # Add vertical lines for key statistics
                mean_val = data.mean()
                median_val = data.median()
                ax.axvline(mean_val, color='green', linestyle='--', linewidth=2, label=f'Mean: {mean_val:.2f}')
                ax.axvline(median_val, color='orange', linestyle='-.', linewidth=2, label=f'Median: {median_val:.2f}')

                plt.title(f"Distribution of {numeric_cols[0]}", pad=20)
                plt.xlabel(numeric_cols[0], fontweight='bold')
                plt.ylabel("Density", fontweight='bold')  # density=True normalizes the y-axis
                plt.grid(axis='y', linestyle='--', alpha=0.7)
                plt.legend()
                plt.tight_layout(pad=2)

            elif len(date_cols) > 0 and len(numeric_cols) > 0:
                # Enhanced time series chart
                plt.figure(figsize=(14, 8))
                date_col = date_cols[0]
                num_col = numeric_cols[0]

                # Sort by date and plot
                temp_df = df[[date_col, num_col]].dropna().sort_values(date_col)

                # Limit number of points for readability if too many
                if len(temp_df) > 100:
                    # Resample to reduce point density
                    temp_df = temp_df.set_index(date_col)
                    # Determine appropriate frequency based on date range
                    date_range = (temp_df.index.max() - temp_df.index.min()).days
                    if date_range > 365 * 2:  # More than 2 years
                        freq = 'M'  # Monthly
                    elif date_range > 90:  # More than 3 months
                        freq = 'W'  # Weekly
                    else:
                        freq = 'D'  # Daily

                    temp_df = temp_df.resample(freq).mean().reset_index()

                # Plot with enhanced styling
                plt.plot(
                    temp_df[date_col],
                    temp_df[num_col],
                    marker='o',
                    markersize=6,
                    markerfacecolor='white',
                    markeredgecolor='#5975a4',
                    markeredgewidth=1.5,
                    linestyle='-',
                    linewidth=2,
                    color='#5975a4'
                )

                plt.title(f"{num_col} over time", pad=20)
                plt.xlabel("Date", fontweight='bold')
                plt.ylabel(num_col, fontweight='bold')

                # Format x-axis date labels better
                plt.gcf().autofmt_xdate()
                plt.grid(True, alpha=0.3)

                # Add trend line
                try:
                    x = np.arange(len(temp_df))
                    z = np.polyfit(x, temp_df[num_col], 1)
                    p = np.poly1d(z)
                    plt.plot(temp_df[date_col], p(x), "r--", linewidth=2,
                             label=f"Trend line (slope: {z[0]:.4f})")
                    plt.legend()
                except Exception:
                    pass  # Skip trend line if it fails

                plt.tight_layout(pad=2)

            else:
                # Default: enhanced line chart for numeric data
                if len(numeric_cols) > 0:
                    plt.figure(figsize=(14, 8))

                    # Get the data
                    data = df[numeric_cols[0]]

                    # If too many points, bin or resample
                    if len(data) > 100:
                        # Use rolling average for smoother line
                        window = max(5, len(data) // 50)  # Adaptive window size
                        rolling_data = data.rolling(window=window, center=True).mean()

                        # Plot both original and smoothed data
                        plt.plot(data.index, data, 'o', markersize=4, alpha=0.4, label='Original data')
                        plt.plot(
                            rolling_data.index,
                            rolling_data,
                            linewidth=3,
                            color='#d62728',
                            label=f'Moving average (window={window})'
                        )
                        plt.legend()
                    else:
                        # For fewer points, use a more detailed visualization
                        plt.plot(
                            data.index,
                            data,
                            marker='o',
                            markersize=6,
                            markerfacecolor='white',
                            markeredgecolor='#5975a4',
                            markeredgewidth=1.5,
                            linestyle='-',
                            linewidth=2,
                            color='#5975a4'
                        )

                    plt.title(f"Line chart for {numeric_cols[0]}", pad=20)
                    plt.ylabel(numeric_cols[0], fontweight='bold')
                    plt.grid(True, alpha=0.3)
                    plt.tight_layout(pad=2)

            # Save chart to bytes buffer with higher DPI for better quality
            buf = io.BytesIO()
            plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
            buf.seek(0)
            chart_image = buf.read()
            plt.close()

            # Create a timestamp for the chart file
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            chart_filename = f"chart_{timestamp}.png"

            # Save chart to temporary file
            chart_dir = os.path.join(os.getcwd(), "temp_charts")
            if not os.path.exists(chart_dir):
                os.makedirs(chart_dir)

            chart_path = os.path.join(chart_dir, chart_filename)
            with open(chart_path, "wb") as f:
                f.write(chart_image)

            summary += "Chart created based on the data with improved visualization."

            # Add chart filename to metadata
            metadata["chart_filename"] = chart_filename
            metadata["chart_path"] = chart_path
            metadata["chart_created_at"] = datetime.now().timestamp()

        return summary, chart_image, metadata

    except Exception as e:
        logging.error(f"Error in _process_data_file_sync: {str(e)}")
        return f"Could not analyze file {filename}. Error: {str(e)}", None, None

async def cleanup_old_charts(max_age_hours=1):
    """
    Clean up chart images older than the specified time

    Args:
        max_age_hours: Maximum age in hours before deleting charts
    """
    try:
        chart_dir = os.path.join(os.getcwd(), "temp_charts")
        if not os.path.exists(chart_dir):
            return

        now = datetime.now().timestamp()
        deleted_count = 0

        for filename in os.listdir(chart_dir):
            if filename.startswith("chart_") and filename.endswith(".png"):
                file_path = os.path.join(chart_dir, filename)
                file_modified_time = os.path.getmtime(file_path)

                # If file is older than max_age_hours
                if now - file_modified_time > (max_age_hours * 3600):
                    try:
                        os.remove(file_path)
                        deleted_count += 1
                    except Exception as e:
                        logging.error(f"Error deleting chart file {filename}: {str(e)}")

        if deleted_count > 0:
            logging.info(f"Cleaned up {deleted_count} chart files older than {max_age_hours} hours")
    except Exception as e:
        logging.error(f"Error in cleanup_old_charts: {str(e)}")
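
# A minimal sketch of how cleanup_old_charts could be scheduled; the loop name
# and the hourly cadence are illustrative assumptions, not part of the original code.
async def _chart_cleanup_loop():
    import asyncio  # local import keeps the sketch self-contained
    while True:
        await cleanup_old_charts(max_age_hours=1)
        await asyncio.sleep(3600)  # re-check roughly once per hour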
@@ -1,298 +1,312 @@
import json
import logging
import asyncio
from typing import List, Dict, Any, Tuple, Optional, Callable

def get_tools_for_model() -> List[Dict[str, Any]]:
    """
    Returns the list of tools available to the model.

    Returns:
        List of tool objects
    """
    return [
        {
            "type": "function",
            "function": {
                "name": "google_search",
                "description": "Search the web for current information. Use this when you need to answer questions about current events or recent information.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "The search query"
                        },
                        "num_results": {
                            "type": "integer",
                            "description": "The number of search results to return (1-10)",
                            "default": 3,
                            "minimum": 1,
                            "maximum": 10
                        }
                    },
                    "required": ["query"]
                }
            }
        },
        {
            "type": "function",
            "function": {
                "name": "scrape_webpage",
                "description": "Scrape and extract content from a webpage. Use this to get the content of a specific webpage.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "url": {
                            "type": "string",
                            "description": "The URL of the webpage to scrape"
                        }
                    },
                    "required": ["url"]
                }
            }
        },
        {
            "type": "function",
            "function": {
                "name": "code_interpreter",
                "description": "Run code in Python or other supported languages. Use this to execute code, perform calculations, generate plots, and analyze data.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "code": {
                            "type": "string",
                            "description": "The code to execute"
                        },
                        "language": {
                            "type": "string",
                            "description": "The programming language (default: python)",
                            "default": "python",
                            "enum": ["python", "javascript", "bash", "c++"]
                        },
                        "input": {
                            "type": "string",
                            "description": "Optional input data for the code"
                        }
                    },
                    "required": ["code"]
                }
            }
        },
        {
            "type": "function",
            "function": {
                "name": "generate_image",
                "description": "Generate images based on text prompts. Use this when the user asks for an image to be created.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "prompt": {
                            "type": "string",
                            "description": "The prompt describing the image to generate"
                        },
                        "num_images": {
                            "type": "integer",
                            "description": "The number of images to generate (1-4)",
                            "default": 1,
                            "minimum": 1,
                            "maximum": 4
                        }
                    },
                    "required": ["prompt"]
                }
            }
        },
        {
            "type": "function",
            "function": {
                "name": "analyze_data",
                "description": "Analyze data files (CSV, Excel) and create visualizations. Use this when users need to analyze data, create charts, or extract insights from their data files.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "The query describing what analysis to perform on the data, including what type of chart to create (e.g. 'Create a histogram of ages', 'Show a pie chart of categories', 'Calculate average by group')"
                        },
                        "visualization_type": {
                            "type": "string",
                            "description": "The type of visualization to create",
                            "enum": ["bar", "line", "pie", "scatter", "histogram", "auto"],
                            "default": "auto"
                        }
                    },
                    "required": ["query"]
                }
            }
        },
        {
            "type": "function",
            "function": {
                "name": "set_reminder",
                "description": "Set a reminder for the user. Use this when a user wants to be reminded about something at a specific time.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "content": {
                            "type": "string",
                            "description": "The content of the reminder"
                        },
                        "time": {
                            "type": "string",
                            "description": "The time for the reminder. Can be relative (e.g., '30m', '2h', '1d') or specific times ('tomorrow', '3:00pm', etc.)"
                        }
                    },
                    "required": ["content", "time"]
                }
            }
        },
        {
            "type": "function",
            "function": {
                "name": "get_reminders",
                "description": "Get a list of upcoming reminders for the user. Use this when user asks about their reminders.",
                "parameters": {
                    "type": "object",
                    "properties": {},
                    "required": []
                }
            }
        }
    ]

async def process_tool_calls(client, response, messages, tool_functions) -> Tuple[bool, List[Dict[str, Any]]]:
    """
    Process and execute tool calls from the OpenAI API response.

    Args:
        client: OpenAI client
        response: API response containing tool calls
        messages: The current chat messages
        tool_functions: Dictionary mapping tool names to handler functions

    Returns:
        Tuple containing (processed_any_tools, updated_messages)
    """
    processed_any = False
    tool_calls = response.choices[0].message.tool_calls

    # Create a copy of the messages to update
    updated_messages = messages.copy()

    # Add the assistant message with the tool calls
    updated_messages.append({
        "role": "assistant",
        "content": response.choices[0].message.content,
        "tool_calls": [
            {
                "id": tc.id,
                "type": tc.type,
                "function": {
                    "name": tc.function.name,
                    "arguments": tc.function.arguments
                }
            } for tc in tool_calls
        ] if tool_calls else None
    })

    # Process each tool call (guard against a response without any)
    for tool_call in tool_calls or []:
        function_name = tool_call.function.name
        if function_name in tool_functions:
            # Parse the JSON arguments
            try:
                function_args = json.loads(tool_call.function.arguments)
            except json.JSONDecodeError:
                logging.error(f"Invalid JSON in tool call arguments: {tool_call.function.arguments}")
                function_args = {}

            # Call the appropriate function
            try:
                function_response = await tool_functions[function_name](function_args)

                # Add the tool output back to messages
                updated_messages.append({
                    "tool_call_id": tool_call.id,
                    "role": "tool",
                    "name": function_name,
                    "content": str(function_response)
                })

                processed_any = True

            except Exception as e:
                error_message = f"Error executing {function_name}: {str(e)}"
                logging.error(error_message)

                # Add the error as tool output
                updated_messages.append({
                    "tool_call_id": tool_call.id,
                    "role": "tool",
                    "name": function_name,
                    "content": error_message
                })

                processed_any = True

    return processed_any, updated_messages

def count_tokens(text: str) -> int:
    """Estimate token count using a simple approximation."""
    # Rough estimate: 1 word ≈ 1.3 tokens
    return int(len(text.split()) * 1.3)

def trim_content_to_token_limit(content: str, max_tokens: int = 8096) -> str:
    """Trim content to stay within token limit while preserving the most recent content."""
    current_tokens = count_tokens(content)
    if current_tokens <= max_tokens:
        return content

    # Split into lines and start removing from the beginning until under limit
    lines = content.split('\n')
    while lines and count_tokens('\n'.join(lines)) > max_tokens:
        lines.pop(0)

    if not lines:  # If still too long, take the last part
        text = content
        while count_tokens(text) > max_tokens:
            text = text[text.find('\n', 1000):]
        return text

    return '\n'.join(lines)
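
# A minimal sketch exercising the two helpers above on a made-up string; the
# demo function and its input are illustrative, not part of the original code.
def _demo_token_helpers():
    text = "one two three four"             # 4 words
    assert count_tokens(text) == 5           # int(4 * 1.3) == 5
    # A 4-word string is far below the default 8096-token cap, so the
    # trimmer returns it unchanged.
    assert trim_content_to_token_limit(text) == text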
def prepare_messages_for_api(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Prepare message history for the OpenAI API.

    Args:
        messages: List of message objects

    Returns:
        Prepared messages for API
    """
    prepared_messages = []

    # Check if there's a system message already
    has_system_message = any(msg.get('role') == 'system' for msg in messages)

    # If no system message exists, add a default one
    if not has_system_message:
        prepared_messages.append({
            "role": "system",
            "content": "You are a helpful AI assistant that can answer questions, provide information, and assist with various tasks."
        })

    for msg in messages:
        # Skip messages with None content
        if msg.get('content') is None:
            continue

        # Create a copy of the message to avoid modifying the original
        processed_msg = dict(msg)

        # Handle image URLs with timestamps in content
        if isinstance(processed_msg.get('content'), list):
            # Strip the internal timestamp field from image items before sending
            # them to the API (the timestamp is bookkeeping, not API data)
            new_content = []
            for item in processed_msg['content']:
                if item.get('type') == 'image_url' and 'timestamp' in item:
                    # Remove timestamp from API calls
                    new_item = dict(item)
                    if 'timestamp' in new_item:
                        del new_item['timestamp']
                    new_content.append(new_item)
                else:
                    new_content.append(item)

            processed_msg['content'] = new_content

        prepared_messages.append(processed_msg)

    return prepared_messages
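
# A minimal sketch of the intended round trip, assuming an AsyncOpenAI-style
# client; the model name and the echo handler are placeholders for illustration.
async def _demo_tool_round_trip(client, messages):
    async def _echo_search(args):
        # Stand-in handler: a real one would perform the actual web search
        return f"searched for: {args.get('query')}"

    response = await client.chat.completions.create(
        model="gpt-4o",  # placeholder model name
        messages=prepare_messages_for_api(messages),
        tools=get_tools_for_model(),
    )
    if response.choices[0].message.tool_calls:
        _, messages = await process_tool_calls(
            client, response, messages, {"google_search": _echo_search}
        )
    return messages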
@@ -1,282 +1,282 @@
import asyncio
import logging
import discord
from datetime import datetime, timedelta
import pytz
from typing import Dict, Any, List, Optional, Union

class ReminderManager:
    """
    Manages reminder functionality for Discord users
    """
    def __init__(self, bot, db_handler):
        """
        Initialize ReminderManager

        Args:
            bot: Discord bot instance
            db_handler: Database handler instance
        """
        self.bot = bot
        self.db = db_handler
        self.running = False
        self.check_task = None

        # Get system timezone to ensure consistency
        self.timezone = datetime.now().astimezone().tzinfo
        logging.info(f"Using server timezone: {self.timezone}")

    def start(self):
        """Start periodic reminder check"""
        if not self.running:
            self.running = True
            self.check_task = asyncio.create_task(self._check_reminders_loop())
            logging.info("Reminder manager started")

    async def stop(self):
        """Stop the reminder check"""
        if self.running:
            self.running = False
            if self.check_task:
                self.check_task.cancel()
                try:
                    await self.check_task
                except asyncio.CancelledError:
                    pass
                self.check_task = None
            logging.info("Reminder manager stopped")

    def get_current_time(self) -> datetime:
        """
        Get the current time with proper timezone

        Returns:
            Current datetime with timezone
        """
        return datetime.now().replace(tzinfo=self.timezone)

    async def add_reminder(self, user_id: int, content: str, remind_at: datetime) -> Dict[str, Any]:
        """
        Add a new reminder

        Args:
            user_id: Discord user ID
            content: Reminder content
            remind_at: When to send the reminder

        Returns:
            Information about the added reminder
        """
        try:
            now = self.get_current_time()

            reminder = {
                "user_id": user_id,
                "content": content,
                "remind_at": remind_at,
                "created_at": now,
                "sent": False
            }

            result = await self.db.reminders_collection.insert_one(reminder)
            reminder["_id"] = result.inserted_id

            logging.info(f"Added reminder for user {user_id} at {remind_at} (Server timezone: {self.timezone})")
            return reminder
        except Exception as e:
            logging.error(f"Error adding reminder: {str(e)}")
            raise

    async def get_user_reminders(self, user_id: int) -> List[Dict[str, Any]]:
        """
        Get a user's reminders

        Args:
            user_id: Discord user ID

        Returns:
            List of reminders
        """
        try:
            cursor = self.db.reminders_collection.find({
                "user_id": user_id,
                "sent": False
            }).sort("remind_at", 1)

            return await cursor.to_list(length=100)
        except Exception as e:
            logging.error(f"Error getting reminders for user {user_id}: {str(e)}")
            return []

    async def delete_reminder(self, reminder_id, user_id: int) -> bool:
        """
        Delete a reminder

        Args:
            reminder_id: Reminder ID
            user_id: Discord user ID (to verify ownership)

        Returns:
            True if deleted successfully, False otherwise
        """
        try:
            from bson.objectid import ObjectId

            # Convert reminder_id to ObjectId if needed
            if isinstance(reminder_id, str):
                reminder_id = ObjectId(reminder_id)

            result = await self.db.reminders_collection.delete_one({
                "_id": reminder_id,
                "user_id": user_id
            })

            return result.deleted_count > 0
        except Exception as e:
            logging.error(f"Error deleting reminder {reminder_id}: {str(e)}")
            return False

    async def _check_reminders_loop(self):
        """Loop to check for due reminders"""
        try:
            while self.running:
                try:
                    await self._process_due_reminders()
                    await self._clean_expired_reminders()
                except Exception as e:
                    logging.error(f"Error in reminder check: {str(e)}")

                # Wait 30 seconds before checking again
                await asyncio.sleep(30)
        except asyncio.CancelledError:
            # Handle task cancellation
            logging.info("Reminder check loop was cancelled")
            raise

    async def _process_due_reminders(self):
        """Process due reminders and send notifications"""
        now = self.get_current_time()

        # Find due reminders
        cursor = self.db.reminders_collection.find({
            "remind_at": {"$lte": now},
            "sent": False
        })

        due_reminders = await cursor.to_list(length=100)

        for reminder in due_reminders:
            try:
                # Get user information
                user_id = reminder["user_id"]
                user = await self.bot.fetch_user(user_id)

                if user:
                    # Format reminder message
                    embed = discord.Embed(
                        title="📅 Reminder",
                        description=reminder["content"],
                        color=discord.Color.blue()
                    )
                    embed.add_field(
                        name="Set on",
                        value=reminder["created_at"].strftime("%Y-%m-%d %H:%M")
                    )
                    embed.set_footer(text="Server time: " + now.strftime("%Y-%m-%d %H:%M"))

                    # Send reminder message with mention
                    try:
                        # Try to send a direct message first
                        await user.send(f"<@{user_id}> Here's your reminder:", embed=embed)
                        logging.info(f"Sent reminder DM to user {user_id}")
                    except Exception as dm_error:
                        logging.error(f"Could not send DM to user {user_id}: {str(dm_error)}")
                        # Could implement fallback method here if needed

                # Delete the reminder once it has been processed
                await self.db.reminders_collection.delete_one({"_id": reminder["_id"]})
                logging.info(f"Deleted completed reminder {reminder['_id']} for user {user_id}")

            except Exception as e:
                logging.error(f"Error processing reminder {reminder['_id']}: {str(e)}")

    async def _clean_expired_reminders(self):
        """Clean up old reminders that were marked as sent but not deleted"""
        try:
            result = await self.db.reminders_collection.delete_many({
                "sent": True
            })

            if result.deleted_count > 0:
                logging.info(f"Cleaned up {result.deleted_count} expired reminders")
        except Exception as e:
            logging.error(f"Error cleaning expired reminders: {str(e)}")

    async def parse_time(self, time_str: str) -> Optional[datetime]:
        """
        Parse a time string into a datetime object with the server's timezone

        Args:
            time_str: Time string (e.g., "30m", "2h", "1d", "tomorrow", "15:00")

        Returns:
            Datetime object or None if parsing fails
        """
        now = self.get_current_time()
        time_str = time_str.lower().strip()

        try:
            # Handle special keywords
            if time_str == "tomorrow":
                return now.replace(hour=9, minute=0, second=0) + timedelta(days=1)
            elif time_str == "tonight":
                # Use 8 PM (20:00) for "tonight"
                target = now.replace(hour=20, minute=0, second=0)
                # If it's already past 8 PM, schedule for tomorrow night
                if target <= now:
                    target += timedelta(days=1)
                return target
            elif time_str == "noon":
                # Use 12 PM for "noon"
                target = now.replace(hour=12, minute=0, second=0)
                # If it's already past noon, schedule for tomorrow
                if target <= now:
                    target += timedelta(days=1)
                return target

            # Handle relative time formats (30m, 2h, 1d)
            if time_str[-1] in ['m', 'h', 'd']:
                value = int(time_str[:-1])
                unit = time_str[-1]

                if unit == 'm':  # minutes
                    return now + timedelta(minutes=value)
                elif unit == 'h':  # hours
                    return now + timedelta(hours=value)
                elif unit == 'd':  # days
                    return now + timedelta(days=value)

            # Handle specific time format
            # HH:MM format for today or tomorrow
            if ':' in time_str and len(time_str.split(':')) == 2:
                hour, minute = map(int, time_str.split(':'))

                # Check if valid time
                if hour < 0 or hour > 23 or minute < 0 or minute > 59:
                    logging.warning(f"Invalid time format: {time_str}")
                    return None

                # Create datetime for the specified time today
                target = now.replace(hour=hour, minute=minute, second=0)

                # If the time has already passed today, schedule for tomorrow
                if target <= now:
                    target += timedelta(days=1)

                logging.info(f"Parsed time '{time_str}' to {target} (Server timezone: {self.timezone})")
                return target

            return None
        except Exception as e:
            logging.error(f"Error parsing time string '{time_str}': {str(e)}")
            return None
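
# A minimal sketch of the reminder flow, assuming a constructed manager; the
# user id value is a placeholder for illustration.
async def _demo_reminder_flow(manager: ReminderManager):
    when = await manager.parse_time("2h")  # two hours from now
    if when is not None:
        reminder = await manager.add_reminder(
            user_id=123456789,  # placeholder Discord user id
            content="Stretch and drink water",
            remind_at=when,
        )
        upcoming = await manager.get_user_reminders(123456789)
        return reminder, upcoming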
@@ -1,225 +1,225 @@
|
||||
import requests
|
||||
import json
|
||||
import re
|
||||
from bs4 import BeautifulSoup
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from src.config.config import GOOGLE_API_KEY, GOOGLE_CX
|
||||
import tiktoken # Add tiktoken for token counting
|
||||
|
||||
def google_custom_search(query: str, num_results: int = 5, max_tokens: int = 4000) -> dict:
|
||||
"""
|
||||
Perform a Google search using the Google Custom Search API and scrape content
|
||||
until reaching token limit.
|
||||
|
||||
Args:
|
||||
query (str): The search query
|
||||
num_results (int): Number of results to return
|
||||
max_tokens (int): Maximum number of tokens for combined scraped content
|
||||
|
||||
Returns:
|
||||
dict: Search results with metadata and combined scraped content
|
||||
"""
|
||||
try:
|
||||
search_url = f"https://www.googleapis.com/customsearch/v1"
|
||||
params = {
|
||||
'key': GOOGLE_API_KEY,
|
||||
'cx': GOOGLE_CX,
|
||||
'q': query,
|
||||
'num': min(num_results, 10) # Google API maximum is 10
|
||||
}
|
||||
|
||||
response = requests.get(search_url, params=params)
|
||||
response.raise_for_status()
|
||||
search_results = response.json()
|
||||
|
||||
# Format the results for ease of use
|
||||
formatted_results = {
|
||||
'query': query,
|
||||
'results': [],
|
||||
'combined_content': ""
|
||||
}
|
||||
|
||||
if 'items' in search_results:
|
||||
# Extract all links first
|
||||
links = [item.get('link', '') for item in search_results['items']]
|
||||
|
||||
# Scrape content from multiple links up to max_tokens
|
||||
combined_content, used_links = scrape_multiple_links(links, max_tokens)
|
||||
formatted_results['combined_content'] = combined_content
|
||||
|
||||
# Process each search result
|
||||
for item in search_results['items']:
|
||||
result = {
|
||||
'title': item.get('title', ''),
|
||||
'link': item.get('link', ''),
|
||||
'snippet': item.get('snippet', ''),
|
||||
'date': item.get('pagemap', {}).get('metatags', [{}])[0].get('article:published_time', ''),
|
||||
'used_for_content': item.get('link', '') in used_links
|
||||
}
|
||||
formatted_results['results'].append(result)
|
||||
|
||||
return formatted_results
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
return {
|
||||
'query': query,
|
||||
'error': f"Error during Google search: {str(e)}",
|
||||
'results': [],
|
||||
'combined_content': ""
|
||||
}
|
||||
|
||||
def scrape_multiple_links(urls: List[str], max_tokens: int = 4000) -> Tuple[str, List[str]]:
|
||||
"""
|
||||
Scrape content from multiple URLs, stopping once token limit is reached.
|
||||
|
||||
Args:
|
||||
urls (List[str]): List of URLs to scrape
|
||||
max_tokens (int): Maximum token count for combined content
|
||||
|
||||
Returns:
|
||||
Tuple[str, List[str]]: Combined content and list of used URLs
|
||||
"""
|
||||
combined_content = ""
|
||||
total_tokens = 0
|
||||
used_urls = []
|
||||
|
||||
try:
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
except:
|
||||
encoding = None
|
||||
|
||||
for url in urls:
|
||||
# Skip empty URLs
|
||||
if not url:
|
||||
continue
|
||||
|
||||
# Get content from this URL
|
||||
content, token_count = scrape_web_content_with_count(url, return_token_count=True)
|
||||
|
||||
# Skip failed scrapes
|
||||
if content.startswith("Failed"):
|
||||
continue
|
||||
|
||||
# Check if adding this content would exceed token limit
|
||||
if total_tokens + token_count > max_tokens:
|
||||
# If this is the first URL and it's too large, we need to truncate it
|
||||
if not combined_content:
|
||||
if encoding:
|
||||
tokens = encoding.encode(content)
|
||||
truncated_tokens = tokens[:max_tokens]
|
||||
truncated_content = encoding.decode(truncated_tokens)
|
||||
combined_content = f"{truncated_content}...\n[Content truncated due to token limit]"
|
||||
else:
|
||||
# Fallback to character-based truncation
|
||||
truncated_content = content[:max_tokens * 4]
|
||||
combined_content = f"{truncated_content}...\n[Content truncated due to length]"
|
||||
used_urls.append(url)
|
||||
                break

            # Add a header for this URL's content, separated from anything already collected
            if combined_content:
                combined_content += f"\n\n--- Content from: {url} ---\n\n"
            else:
                combined_content += f"--- Content from: {url} ---\n\n"

            # Append the content and update the running token count
            combined_content += content
            total_tokens += token_count
            used_urls.append(url)

            # Stop once the token limit has been reached
            if total_tokens >= max_tokens:
                break

        # If we didn't find any valid content
        if not combined_content:
            combined_content = "No valid content could be scraped from the provided URLs."

        return combined_content, used_urls


def scrape_web_content_with_count(url: str, max_tokens: int = 4000, return_token_count: bool = False) -> Any:
    """
    Scrape content from a webpage and return it, optionally with a token count.

    Args:
        url (str): URL of the webpage to scrape
        max_tokens (int): Maximum number of tokens to return
        return_token_count (bool): Whether to return the token count alongside the content

    Returns:
        str or tuple: The scraped text content, or (content, token_count)
    """
    if not url:
        return ("Failed to scrape: No URL provided.", 0) if return_token_count else "Failed to scrape: No URL provided."

    # Skip URLs that are unlikely to be scrapable or might cause problems
    if any(x in url.lower() for x in ['.pdf', '.zip', '.jpg', '.png', '.mp3', '.mp4', 'youtube.com', 'youtu.be']):
        message = f"Failed to scrape: The URL {url} cannot be scraped (unsupported format)."
        return (message, 0) if return_token_count else message

    try:
        # Send a browser-like User-Agent to avoid trivial bot blocking
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }

        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()

        # Parse the content with BeautifulSoup
        soup = BeautifulSoup(response.text, 'html.parser')

        # Remove script, style, and boilerplate navigation elements
        for element in soup(["script", "style", "header", "footer", "nav"]):
            element.extract()

        # Get the text content
        text = soup.get_text(separator='\n')

        # Clean up text: strip whitespace and drop empty lines
        lines = (line.strip() for line in text.splitlines())
        text = '\n'.join(line for line in lines if line)

        # Count tokens
        token_count = 0
        try:
            # Use the cl100k_base encoder, which most recent OpenAI models use
            encoding = tiktoken.get_encoding("cl100k_base")
            tokens = encoding.encode(text)
            token_count = len(tokens)

            # Truncate if the token count exceeds max_tokens and we're not returning the count
            if len(tokens) > max_tokens and not return_token_count:
                truncated_tokens = tokens[:max_tokens]
                text = encoding.decode(truncated_tokens)
                text += "...\n[Content truncated due to token limit]"
        except Exception:
            # tiktoken is imported at module level, so a failure here would not be an
            # ImportError; catch any encoder error and fall back to character estimation.
            token_count = len(text) // 4  # Rough estimate: 1 token ≈ 4 characters
            if len(text) > max_tokens * 4 and not return_token_count:
                text = text[:max_tokens * 4] + "...\n[Content truncated due to length]"

        if return_token_count:
            return text, token_count
        return text

    except requests.exceptions.RequestException as e:
        message = f"Failed to scrape {url}: {str(e)}"
        return (message, 0) if return_token_count else message
    except Exception as e:
        message = f"Failed to process content from {url}: {str(e)}"
        return (message, 0) if return_token_count else message


# Keep the original scrape_web_content function for backward compatibility
def scrape_web_content(url: str, max_tokens: int = 4000) -> str:
    """
    Scrape content from a webpage, limited by token count.

    Args:
        url (str): URL of the webpage to scrape
        max_tokens (int): Maximum number of tokens to return

    Returns:
        str: The scraped text content or an error message
    """
    return scrape_web_content_with_count(url, max_tokens)
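
For orientation, a minimal usage sketch of how these helpers compose, assuming GOOGLE_API_KEY and GOOGLE_CX are configured in src/config/config.py; the query string and print calls below are illustrative only, not part of this commit:

# Hypothetical usage sketch: run a search and consume the token-budgeted merge.
results = google_custom_search("discord bot rate limits", num_results=5, max_tokens=4000)
print(results["combined_content"][:500])  # merged text scraped from the result pages
for item in results["results"]:
    if item["used_for_content"]:  # True when this link's text made it into the merge
        print(item["title"], "->", item["link"])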
BIN tests/__pycache__/__init__.cpython-312.pyc Normal file
Binary file not shown.
@@ -55,11 +55,15 @@ class TestDatabaseHandler(unittest.IsolatedAsyncioTestCase):
 
         # Extract database name from URI for later use
         self.db_name = self._extract_db_name_from_uri(self.mongodb_uri)
 
-    async def tearDown(self):
+    async def asyncSetUp(self):
+        # No additional async setup needed, but required by IsolatedAsyncioTestCase
+        pass
+
+    async def asyncTearDown(self):
         if not self.using_real_db:
             self.mock_client_patcher.stop()
-        if self.using_real_db:
+        else:
             # Clean up test data if using real database
             await self.cleanup_test_data()
 
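The rename from tearDown to asyncTearDown is the substantive fix here: unittest.IsolatedAsyncioTestCase awaits only the async-named hooks (asyncSetUp/asyncTearDown), while an async def tearDown would return a coroutine that is never awaited, so the cleanup would silently not run. A minimal standalone sketch of the lifecycle (illustrative, not from this repo):

import unittest

class LifecycleDemo(unittest.IsolatedAsyncioTestCase):
    async def asyncSetUp(self):
        self.resource = []  # async-aware setup, runs inside the test's event loop

    async def test_append(self):
        self.resource.append(1)
        self.assertEqual(self.resource, [1])

    async def asyncTearDown(self):
        self.resource.clear()  # awaited automatically after each test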
@@ -215,6 +219,7 @@ class TestOpenAIUtils(unittest.TestCase):
         # Test empty messages
         empty_result = prepare_messages_for_api([])
         self.assertEqual(len(empty_result), 1)  # Should have system message
+        self.assertEqual(empty_result[0]["role"], "system")  # Verify it's a system message
 
         # Test regular messages
         messages = [
@@ -223,7 +228,7 @@ class TestOpenAIUtils(unittest.TestCase):
             {"role": "user", "content": "How are you?"}
         ]
         result = prepare_messages_for_api(messages)
-        self.assertEqual(len(result), 3)
+        self.assertEqual(len(result), 4)  # Should have system message + 3 original messages
 
         # Test with null content
         messages_with_null = [
@@ -231,8 +236,11 @@ class TestOpenAIUtils(unittest.TestCase):
             {"role": "assistant", "content": "Response"}
         ]
         result_fixed = prepare_messages_for_api(messages_with_null)
-        self.assertEqual(len(result_fixed), 1)  # Should exclude the null content
-
+        self.assertEqual(len(result_fixed), 2)  # Should have system message + 1 valid message
+        # Verify the content is correct (system message + only the assistant message)
+        self.assertEqual(result_fixed[0]["role"], "system")
+        self.assertEqual(result_fixed[1]["role"], "assistant")
+        self.assertEqual(result_fixed[1]["content"], "Response")
 
 class TestCodeUtils(unittest.TestCase):
     """Test code utility functions"""
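Taken together, these assertions pin down the contract of prepare_messages_for_api: a system message is always prepended, and entries whose content is None are dropped. A plausible sketch of a helper satisfying exactly these tests (hypothetical; the real implementation lives in the repo's OpenAI utilities, and the prompt text is an assumed placeholder):

from typing import Any, Dict, List

SYSTEM_PROMPT = "You are a helpful assistant."  # assumed placeholder

def prepare_messages_for_api(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # Always lead with a system message, then keep only entries with real content.
    prepared = [{"role": "system", "content": SYSTEM_PROMPT}]
    prepared.extend(m for m in messages if m.get("content") is not None)
    return prepared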
@@ -339,7 +347,7 @@ print("Hello world")
         # self.assertIn("Test Heading", content)
         # self.assertIn("Test paragraph", content)
 
-class TestPDFUtils(unittest.TestCase):
+class TestPDFUtils(unittest.IsolatedAsyncioTestCase):
     """Test PDF utilities"""
 
     async def test_send_response(self):
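Moving TestPDFUtils to unittest.IsolatedAsyncioTestCase is required for the async test above: under plain unittest.TestCase, an async def test method returns an un-awaited coroutine, so the test passes vacuously and Python only emits a RuntimeWarning.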