feat: Implement data analysis and code execution utilities

- Added data_analyzer.py for comprehensive data analysis with templates for summary, correlation, distribution, and custom analysis.
- Integrated logging for tracking analysis processes and errors.
- Included package installation functionality for required libraries.
- Developed python_executor.py to safely execute user-provided Python code with sanitization and timeout features.
- Implemented security measures to prevent unsafe code execution.
- Enhanced output handling to capture visualizations and standard output.
commit d3b92f8bef (parent cce1ff506b)
2025-06-20 17:56:27 +07:00
32 changed files with 2376 additions and 2680 deletions

.idea/.gitignore (deleted)

@@ -1,8 +0,0 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

.idea/ChatGPT-Discord-Bot.iml (deleted)

@@ -1,24 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="Flask">
<option name="enabled" value="true" />
</component>
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$">
<excludeFolder url="file://$MODULE_DIR$/.venv" />
<excludeFolder url="file://$MODULE_DIR$/.venv1" />
</content>
<orderEntry type="jdk" jdkName="Python 3.12 (ChatGPT-Discord-Bot)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
<component name="TemplatesService">
<option name="TEMPLATE_CONFIGURATION" value="Jinja2" />
</component>
<component name="TestRunnerService">
<option name="PROJECT_TEST_RUNNER" value="py.test" />
</component>
</module>

.idea/dataSources.xml (deleted)

@@ -1,13 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DataSourceManagerImpl" format="xml" multifile-model="true">
<data-source source="LOCAL" name="@cluster0.akbzf5r.mongodb.net" uuid="1fd5036f-4eb6-4575-b3b7-1a641f728a88">
<driver-ref>documentdb</driver-ref>
<synchronize>true</synchronize>
<configured-by-url>true</configured-by-url>
<jdbc-driver>com.dbschema.MongoJdbcDriver</jdbc-driver>
<jdbc-url>mongodb+srv://chatgpt:Anhtt2021@cluster0.akbzf5r.mongodb.net/?retryWrites=true&amp;w=majority&amp;appName=Cluster0</jdbc-url>
<working-dir>$ProjectFileDir$</working-dir>
</data-source>
</component>
</project>

.idea/inspectionProfiles/profiles_settings.xml (deleted)

@@ -1,6 +0,0 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

.idea/misc.xml (deleted)

@@ -1,7 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="Python 3.12 (ChatGPT-Discord-Bot)" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12 (ChatGPT-Discord-Bot)" project-jdk-type="Python SDK" />
</project>

.idea/modules.xml (deleted)

@@ -1,8 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/ChatGPT-Discord-Bot.iml" filepath="$PROJECT_DIR$/.idea/ChatGPT-Discord-Bot.iml" />
</modules>
</component>
</project>

.idea/vcs.xml (deleted)

@@ -1,6 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>

logs/data_analyzer.log (new, empty file)


@@ -81,7 +81,29 @@ PDF_BATCH_SIZE = 3
# Prompt templates
WEB_SCRAPING_PROMPT = "You are a Web Scraping Assistant. You analyze content from webpages to extract key information. Integrate insights from the scraped content to give comprehensive, fact-based responses. When analyzing web content: 1) Focus on the most relevant information, 2) Cite specific sections when appropriate, 3) Maintain a neutral tone, and 4) Organize information logically. Present your response in a clear, conversational manner suitable for Discord."
NORMAL_CHAT_PROMPT = "You're ChatGPT for Discord! You have access to powerful tools that can enhance your responses. When appropriate, use: 1) Google Search (google_search) to find current information, 2) Web Scraping (scrape_webpage) to analyze webpages, 3) Code Interpreter (code_interpreter) to run code in Python to support calculating and analyzing (priority to use python default internal library), and 4) Image Generation (generate_image) to create images from text descriptions, 5) data analysis (analyze_data) to draw chart based on user data file and 6) Reminder (set_reminder) to set a remind based on user request. When solving problems, follow a step-by-step approach: identify what information is needed, determine which tools might help, and explain your reasoning clearly. For code tasks, always share both the code you're running and its output. Craft responses that are easy to read in Discord without any markdown and latex (except for code you must use markdown). You MUST respond in the same language as the user. You MUST use code_interpreter with Python language for your own code for correct of any calculation. All user request MUST be completed in one single request"
NORMAL_CHAT_PROMPT = """You're ChatGPT for Discord! You have access to powerful tools that enhance your capabilities:
🔍 **Information Tools:**
- google_search: Find current information from the web
- scrape_webpage: Extract and analyze content from websites
💻 **Programming & Data Tools:**
- execute_python_code: Run Python code for calculations, math problems, algorithms, and custom programming tasks
- analyze_data_file: Analyze CSV/Excel files with pre-built templates when users request data analysis or insights
🎨 **Creative Tools:**
- generate_image: Create images from text descriptions
- edit_image: Modify existing images (remove backgrounds, etc.)
📅 **Utility Tools:**
- set_reminder/get_reminders: Manage user reminders
**For Data Files:** When users upload CSV/Excel files:
- If they ask for "analysis", "insights", "statistics" → file is automatically processed with analyze_data_file
- If they want custom programming or specific code → file path is provided to execute_python_code
- Both tools can install packages and create visualizations automatically displayed in Discord
Always explain your approach step-by-step and provide clear, Discord-friendly responses without excessive markdown. You MUST respond in the same language as the user."""
SEARCH_PROMPT = "You are a Research Assistant with access to Google Search results. Your task is to synthesize information from search results to provide accurate, comprehensive answers. When analyzing search results: 1) Prioritize information from credible sources, 2) Compare and contrast different perspectives when available, 3) Acknowledge when information is limited or unclear, and 4) Cite specific sources when presenting facts. Structure your response in a clear, logical manner, focusing on directly answering the user's question while providing relevant context."

File diff suppressed because it is too large.

Binary file not shown.

Binary file not shown.

src/utils/code_interpreter.py

@@ -4,17 +4,18 @@ import io
import re
import logging
import asyncio
import subprocess
import tempfile
import time
import uuid
from logging.handlers import RotatingFileHandler
import traceback
import contextlib
from typing import Dict, Any, Optional, List
import matplotlib
matplotlib.use('Agg') # Use non-interactive backend
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from .code_utils import sanitize_code, get_temporary_file_path, generate_analysis_code, analyze_data, DATA_FILES_DIR, format_output_path, clean_old_files
# Import the new separated modules
from .python_executor import execute_python_code
from .data_analyzer import analyze_data_file
# Configure logging
log_file = 'logs/code_interpreter.log'
@@ -26,11 +27,43 @@ logger = logging.getLogger('code_interpreter')
logger.setLevel(logging.INFO)
logger.addHandler(file_handler)
# Regular expression to find image file paths in output
IMAGE_PATH_PATTERN = r'(sandbox:)?(\/media\/quocanh\/.*\.(png|jpg|jpeg|gif))'
async def execute_code(args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Main entry point for code execution - routes to the appropriate handler.

    This function maintains backward compatibility while routing requests
    to the appropriate specialized handler.

    Args:
        args: Dictionary containing execution parameters

    Returns:
        Dict containing execution results
    """
    try:
        # Check if this is a data analysis request
        file_path = args.get("file_path", "")
        analysis_request = args.get("analysis_request", "")
        if file_path and (analysis_request or args.get("analysis_type")):
            # Route to data analyzer
            logger.info("Routing to data analyzer")
            return await analyze_data_file(args)
        else:
            # Route to Python executor
            logger.info("Routing to Python executor")
            return await execute_python_code(args)
    except Exception as e:
        error_msg = f"Error in code execution router: {str(e)}"
        logger.error(f"{error_msg}\n{traceback.format_exc()}")
        return {
            "success": False,
            "error": error_msg,
            "output": "",
            "traceback": traceback.format_exc()
        }
"""
Execute code with support for data analysis and visualization.
Args:
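As an illustration (an editor's sketch, not part of the diff), here is how the new router might be driven; the argument shapes follow execute_code() above, and the file path is hypothetical:

```python
import asyncio

async def main():
    # Has file_path plus an analysis request -> routed to analyze_data_file.
    analysis = await execute_code({
        "file_path": "data/sales.csv",  # hypothetical uploaded file
        "analysis_request": "summary statistics",
    })
    # Plain code, no data file -> routed to execute_python_code.
    calc = await execute_code({"code": "print(2 ** 10)"})
    print(analysis.get("success"), calc.get("output"))

asyncio.run(main())
```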

src/utils/data_analyzer.py (new file)

@@ -0,0 +1,547 @@
import os
import sys
import io
import logging
import asyncio
import traceback
import contextlib
import tempfile
import uuid
import time
from typing import Dict, Any, Optional, List, Tuple
from datetime import datetime
from logging.handlers import RotatingFileHandler
# Import data analysis libraries
try:
    import pandas as pd
    import numpy as np
    import matplotlib
    matplotlib.use('Agg')  # Use non-interactive backend
    import matplotlib.pyplot as plt
    import seaborn as sns
    import plotly.graph_objects as go
    import plotly.express as px
    LIBRARIES_AVAILABLE = True
except ImportError as e:
    LIBRARIES_AVAILABLE = False
    logging.warning(f"Data analysis libraries not available: {str(e)}")
# Import utility functions
from .code_utils import DATA_FILES_DIR, format_output_path, clean_old_files
# Configure logging
log_file = 'logs/data_analyzer.log'
os.makedirs(os.path.dirname(log_file), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler = RotatingFileHandler(log_file, maxBytes=10*1024*1024, backupCount=5)
file_handler.setFormatter(formatter)
logger = logging.getLogger('data_analyzer')
logger.setLevel(logging.INFO)
logger.addHandler(file_handler)
def _is_valid_python_code(code_string: str) -> bool:
    """
    Check if a string contains valid Python code or is natural language.

    Args:
        code_string: String to check

    Returns:
        bool: True if it's valid Python code, False if it's natural language
    """
    try:
        # Strip whitespace and check for common natural language patterns
        stripped = code_string.strip()
        # Check for obvious natural language patterns
        natural_language_indicators = [
            'analyze', 'create', 'show', 'display', 'plot', 'visualize',
            'tell me', 'give me', 'what is', 'how many', 'find'
        ]
        # If it starts with typical natural language words, it's likely not Python
        first_words = stripped.lower().split()[:3]
        if any(indicator in ' '.join(first_words) for indicator in natural_language_indicators):
            return False
        # Try to compile as Python code
        compile(stripped, '<string>', 'exec')
        return True
    except SyntaxError:
        return False
    except Exception:
        return False
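A quick illustration (editor's sketch, not part of the diff) of how this heuristic classifies inputs:

```python
# Starts with a natural-language indicator ("analyze") -> rejected early.
print(_is_valid_python_code("analyze the sales column"))  # False
# Compiles cleanly as Python -> accepted.
print(_is_valid_python_code("total = sum(range(10))"))    # True
# Raises SyntaxError when compiled -> rejected.
print(_is_valid_python_code("for x in"))                  # False
```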
# Data analysis templates
ANALYSIS_TEMPLATES = {
    "summary": """
# Data Summary Analysis
# User request: {custom_request}
import pandas as pd
import numpy as np

# Load the data
df = pd.read_csv('{file_path}') if '{file_path}'.endswith('.csv') else pd.read_excel('{file_path}')

print("=== DATA SUMMARY ===")
print(f"Shape: {{df.shape}}")
print(f"Columns: {{list(df.columns)}}")
print("\\n=== DATA TYPES ===")
print(df.dtypes)
print("\\n=== MISSING VALUES ===")
print(df.isnull().sum())
print("\\n=== BASIC STATISTICS ===")
print(df.describe())
""",
    "correlation": """
# Correlation Analysis
# User request: {custom_request}
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Load the data
df = pd.read_csv('{file_path}') if '{file_path}'.endswith('.csv') else pd.read_excel('{file_path}')

# Select only numeric columns
numeric_df = df.select_dtypes(include=[np.number])

if len(numeric_df.columns) > 1:
    # Calculate correlation matrix
    correlation_matrix = numeric_df.corr()

    # Create correlation heatmap
    plt.figure(figsize=(10, 8))
    sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', center=0,
                square=True, linewidths=0.5)
    plt.title('Correlation Matrix')
    plt.tight_layout()
    plt.savefig('{output_path}')
    plt.close()

    print("=== CORRELATION ANALYSIS ===")
    print(correlation_matrix)

    # Find strong correlations
    strong_corr = []
    for i in range(len(correlation_matrix.columns)):
        for j in range(i+1, len(correlation_matrix.columns)):
            corr_val = correlation_matrix.iloc[i, j]
            if abs(corr_val) > 0.7:
                strong_corr.append((correlation_matrix.columns[i],
                                    correlation_matrix.columns[j], corr_val))

    if strong_corr:
        print("\\n=== STRONG CORRELATIONS (|r| > 0.7) ===")
        for col1, col2, corr in strong_corr:
            print(f"{{col1}} <-> {{col2}}: {{corr:.3f}}")
else:
    print("Not enough numeric columns for correlation analysis")
""",
    "distribution": """
# Distribution Analysis
# User request: {custom_request}
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Load the data
df = pd.read_csv('{file_path}') if '{file_path}'.endswith('.csv') else pd.read_excel('{file_path}')

# Select numeric columns
numeric_cols = df.select_dtypes(include=[np.number]).columns

if len(numeric_cols) > 0:
    # Create distribution plots
    n_cols = min(len(numeric_cols), 4)
    n_rows = (len(numeric_cols) + n_cols - 1) // n_cols

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4*n_cols, 4*n_rows))
    if n_rows == 1 and n_cols == 1:
        axes = [axes]
    elif n_rows == 1:
        axes = list(axes)
    else:
        axes = axes.flatten()

    for i, col in enumerate(numeric_cols):
        if i < len(axes):
            df[col].dropna().hist(bins=30, alpha=0.7, edgecolor='black', ax=axes[i])
            axes[i].set_title(f'Distribution of {{col}}')
            axes[i].set_xlabel(col)
            axes[i].set_ylabel('Frequency')

    # Hide extra subplots
    for i in range(len(numeric_cols), len(axes)):
        axes[i].set_visible(False)

    plt.tight_layout()
    plt.savefig('{output_path}')
    plt.close()

    print("=== DISTRIBUTION ANALYSIS ===")
    for col in numeric_cols:
        print(f"\\n{{col}}:")
        print(f"  Mean: {{df[col].mean():.2f}}")
        print(f"  Median: {{df[col].median():.2f}}")
        print(f"  Std: {{df[col].std():.2f}}")
        print(f"  Skewness: {{df[col].skew():.2f}}")
else:
    print("No numeric columns found for distribution analysis")
""",
    "comprehensive": """
# Comprehensive Data Analysis
# User request: {custom_request}
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Load the data
df = pd.read_csv('{file_path}') if '{file_path}'.endswith('.csv') else pd.read_excel('{file_path}')

print("=== COMPREHENSIVE DATA ANALYSIS ===")
print(f"Dataset shape: {{df.shape}}")
print(f"Columns: {{list(df.columns)}}")

# Basic info
print("\\n=== DATA TYPES ===")
print(df.dtypes)
print("\\n=== MISSING VALUES ===")
missing = df.isnull().sum()
print(missing[missing > 0])
print("\\n=== BASIC STATISTICS ===")
print(df.describe())

# Numeric analysis
numeric_cols = df.select_dtypes(include=[np.number]).columns
if len(numeric_cols) > 0:
    print("\\n=== NUMERIC COLUMNS ANALYSIS ===")

    # Create subplot layout
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # 1. Correlation heatmap
    if len(numeric_cols) > 1:
        corr_matrix = df[numeric_cols].corr()
        sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', ax=axes[0,0])
        axes[0,0].set_title('Correlation Matrix')

    # 2. Distribution of first numeric column
    if len(numeric_cols) >= 1:
        df[numeric_cols[0]].hist(bins=30, ax=axes[0,1])
        axes[0,1].set_title(f'Distribution of {{numeric_cols[0]}}')

    # 3. Box plot of numeric columns
    if len(numeric_cols) <= 5:
        df[numeric_cols].boxplot(ax=axes[1,0])
        axes[1,0].set_title('Box Plot of Numeric Columns')
        axes[1,0].tick_params(axis='x', rotation=45)

    # 4. Scatter plot for the first two numeric columns
    if len(numeric_cols) >= 2:
        scatter_cols = numeric_cols[:min(3, len(numeric_cols))]
        if len(scatter_cols) == 2:
            axes[1,1].scatter(df[scatter_cols[0]], df[scatter_cols[1]], alpha=0.6)
            axes[1,1].set_xlabel(scatter_cols[0])
            axes[1,1].set_ylabel(scatter_cols[1])
            axes[1,1].set_title(f'{{scatter_cols[0]}} vs {{scatter_cols[1]}}')

    plt.tight_layout()
    plt.savefig('{output_path}')
    plt.close()

# Categorical analysis
categorical_cols = df.select_dtypes(include=['object']).columns
if len(categorical_cols) > 0:
    print("\\n=== CATEGORICAL COLUMNS ANALYSIS ===")
    for col in categorical_cols[:3]:  # Limit to first 3 categorical columns
        print(f"\\n{{col}}:")
        print(df[col].value_counts().head())
"""
}
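As an aside (not part of the diff), this is how one of these templates gets instantiated: str.format() fills the single-brace placeholders, while the doubled braces collapse to single braces so the generated script's f-strings still work. Paths are illustrative, and format() simply ignores keyword arguments a template does not use:

```python
code = ANALYSIS_TEMPLATES["summary"].format(
    file_path="data/sales.csv",          # hypothetical upload location
    output_path="output/analysis.png",   # unused by "summary", ignored by format()
    custom_request="Summarize the sales data",
)
print(code)  # a complete script, ready for execute_analysis_code()
```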
async def install_packages(packages: List[str]) -> Dict[str, Any]:
    """
    Install Python packages in a sandboxed environment.

    Args:
        packages: List of package names to install

    Returns:
        Dict containing installation results
    """
    try:
        import subprocess
        installed = []
        failed = []
        for package in packages:
            try:
                # Use pip to install package
                result = subprocess.run([
                    sys.executable, "-m", "pip", "install", package
                ], capture_output=True, text=True, timeout=120)
                if result.returncode == 0:
                    installed.append(package)
                    logger.info(f"Successfully installed package: {package}")
                else:
                    failed.append({"package": package, "error": result.stderr})
                    logger.error(f"Failed to install package {package}: {result.stderr}")
            except subprocess.TimeoutExpired:
                failed.append({"package": package, "error": "Installation timeout"})
                logger.error(f"Installation timeout for package: {package}")
            except Exception as e:
                failed.append({"package": package, "error": str(e)})
                logger.error(f"Error installing package {package}: {str(e)}")
        return {
            "success": True,
            "installed": installed,
            "failed": failed,
            "message": f"Installed {len(installed)} packages, {len(failed)} failed"
        }
    except Exception as e:
        logger.error(f"Error in package installation: {str(e)}")
        return {
            "success": False,
            "error": str(e),
            "installed": [],
            "failed": packages
        }
async def analyze_data_file(args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Analyze data files with pre-built templates and custom analysis.

    Args:
        args: Dictionary containing:
            - file_path: Path to the data file (CSV/Excel)
            - analysis_type: Type of analysis (summary, correlation, distribution, comprehensive)
            - custom_analysis: Optional custom analysis request in natural language
            - user_id: Optional user ID for file management
            - install_packages: Optional list of packages to install

    Returns:
        Dict containing analysis results
    """
    try:
        if not LIBRARIES_AVAILABLE:
            return {
                "success": False,
                "error": "Data analysis libraries not available. Please install pandas, numpy, matplotlib, seaborn."
            }
        file_path = args.get("file_path", "")
        analysis_type = args.get("analysis_type", "comprehensive")
        custom_analysis = args.get("custom_analysis", "")
        user_id = args.get("user_id")
        packages_to_install = args.get("install_packages", [])

        # Install packages if requested
        if packages_to_install:
            install_result = await install_packages(packages_to_install)
            if not install_result["success"]:
                logger.warning(f"Package installation issues: {install_result}")

        # Validate file path
        if not file_path or not os.path.exists(file_path):
            return {
                "success": False,
                "error": f"Data file not found: {file_path}"
            }

        # Check file extension
        file_ext = os.path.splitext(file_path)[1].lower()
        if file_ext not in ['.csv', '.xlsx', '.xls']:
            return {
                "success": False,
                "error": "Unsupported file format. Please use CSV or Excel files."
            }

        # Generate output path for visualizations
        timestamp = int(time.time())
        output_filename = f"analysis_{user_id or 'user'}_{timestamp}.png"
        output_path = format_output_path(output_filename)

        # Determine analysis code
        if custom_analysis:
            # Check if custom_analysis contains valid Python code or is natural language
            is_python_code = _is_valid_python_code(custom_analysis)
            if is_python_code:
                # Generate custom analysis code with valid Python
                code = f"""
# Custom Data Analysis
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Load the data
df = pd.read_csv('{file_path}') if '{file_path}'.endswith('.csv') else pd.read_excel('{file_path}')
print("=== CUSTOM DATA ANALYSIS ===")
print(f"Dataset loaded: {{df.shape}}")

# Custom analysis based on user request
{custom_analysis}

# Save any plots
if plt.get_fignums():
    plt.savefig('{output_path}')
    plt.close()
"""
            else:
                # For natural language queries, use comprehensive analysis with comment
                logger.info(f"Natural language query detected: {custom_analysis}")
                analysis_type = "comprehensive"
                code = ANALYSIS_TEMPLATES[analysis_type].format(
                    file_path=file_path,
                    output_path=output_path,
                    custom_request=custom_analysis
                )
        else:
            # Use predefined template
            if analysis_type not in ANALYSIS_TEMPLATES:
                analysis_type = "comprehensive"
            # Format template with default values
            template_vars = {
                'file_path': file_path,
                'output_path': output_path,
                'custom_request': custom_analysis or 'General data analysis'
            }
            code = ANALYSIS_TEMPLATES[analysis_type].format(**template_vars)

        # Execute the analysis code
        result = await execute_analysis_code(code, output_path)

        # Add file information to result
        result.update({
            "file_path": file_path,
            "analysis_type": analysis_type,
            "custom_analysis": bool(custom_analysis)
        })

        # Clean up old files
        clean_old_files()
        return result
    except Exception as e:
        error_msg = f"Error in data analysis: {str(e)}"
        logger.error(f"{error_msg}\n{traceback.format_exc()}")
        return {
            "success": False,
            "error": error_msg,
            "traceback": traceback.format_exc()
        }
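A minimal invocation sketch (illustrative, not part of the diff; the CSV path stands in for a Discord upload):

```python
import asyncio

result = asyncio.run(analyze_data_file({
    "file_path": "data/sales.csv",   # hypothetical uploaded file
    "analysis_type": "correlation",
    "user_id": 12345,
}))
if result["success"]:
    print(result["output"])
    print("Charts:", result.get("visualizations", []))
else:
    print("Analysis failed:", result["error"])
```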
async def execute_analysis_code(code: str, output_path: str) -> Dict[str, Any]:
    """
    Execute data analysis code in a controlled environment.

    Args:
        code: Python code to execute
        output_path: Path where visualizations should be saved

    Returns:
        Dict containing execution results
    """
    try:
        # Capture stdout
        old_stdout = sys.stdout
        sys.stdout = captured_output = io.StringIO()

        # Create a controlled execution environment
        exec_globals = {
            "__builtins__": __builtins__,
            "pd": pd,
            "np": np,
            "plt": plt,
            "sns": sns,
            "print": print,
        }
        # Try to import plotly if available
        try:
            exec_globals["go"] = go
            exec_globals["px"] = px
        except:
            pass

        # Execute the code
        exec(code, exec_globals)

        # Restore stdout
        sys.stdout = old_stdout

        # Get the output
        output = captured_output.getvalue()

        # Check if visualization was created
        visualizations = []
        if os.path.exists(output_path):
            visualizations.append(output_path)

        logger.info(f"Data analysis executed successfully, output length: {len(output)}")
        return {
            "success": True,
            "output": output,
            "visualizations": visualizations,
            "has_visualization": len(visualizations) > 0
        }
    except Exception as e:
        # Restore stdout
        sys.stdout = old_stdout
        error_msg = f"Error executing analysis code: {str(e)}"
        logger.error(f"{error_msg}\n{traceback.format_exc()}")
        return {
            "success": False,
            "error": error_msg,
            "output": captured_output.getvalue() if 'captured_output' in locals() else "",
            "traceback": traceback.format_exc()
        }
# Utility function to validate data analysis requests
def validate_analysis_request(args: Dict[str, Any]) -> Tuple[bool, str]:
    """
    Validate data analysis request parameters.

    Args:
        args: Analysis request arguments

    Returns:
        Tuple of (is_valid, error_message)
    """
    required_fields = ["file_path"]
    for field in required_fields:
        if field not in args or not args[field]:
            return False, f"Missing required field: {field}"
    # Validate analysis type
    analysis_type = args.get("analysis_type", "comprehensive")
    valid_types = list(ANALYSIS_TEMPLATES.keys())
    if analysis_type not in valid_types:
        return False, f"Invalid analysis type. Valid types: {valid_types}"
    return True, ""
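For example (editor's sketch), the validator accepts a complete request and rejects one missing its file path:

```python
ok, err = validate_analysis_request({"file_path": "data/sales.csv",
                                     "analysis_type": "distribution"})
print(ok, err)  # True ""

ok, err = validate_analysis_request({"analysis_type": "summary"})
print(ok, err)  # False "Missing required field: file_path"
```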


@@ -25,26 +25,75 @@ def get_tools_for_model() -> List[Dict[str, Any]]:
"""
Returns the list of tools available to the model.
IMPORTANT: CODE EXECUTION HAS BEEN SEPARATED INTO TWO SPECIALIZED TOOLS:
1. execute_python_code: For general programming, calculations, algorithms, custom code
- Use when: math problems, programming tasks, custom scripts, algorithm implementations
- Can create visualizations from scratch
- Has sandboxed environment with package installation
- Automatically gets context about uploaded files when available
2. analyze_data_file: For structured data analysis from CSV/Excel files
- Use when: user explicitly requests data analysis, statistics, or insights from data files
- Has pre-built analysis templates (summary, correlation, distribution, comprehensive)
- Automatically handles data loading and creates appropriate visualizations
- Specialized for data science workflows
- Best for quick data exploration and standard analysis
AI MODEL GUIDANCE FOR DATA FILE UPLOADS:
- When user uploads a data file (.csv, .xlsx, .xls) to Discord:
* File is automatically downloaded and saved
* User intent is detected (data analysis vs. general programming)
* If intent is "data analysis" → analyze_data_file is automatically called
* If intent is "general programming" → file context is added to conversation
- When to use analyze_data_file:
* User uploads data file and asks for analysis, insights, statistics
* User wants standard data exploration (correlations, distributions, summaries)
* User requests "analyze this data" type queries
- When to use execute_python_code:
* User asks for calculations, math problems, algorithms
* User wants custom code or programming solutions
* User uploads data file but wants custom processing (not standard analysis)
* User needs specific data transformations or custom visualizations
* File paths will be automatically provided in the execution environment
DISCORD FILE INTEGRATION:
- Data files uploaded to Discord are automatically downloaded and saved
- File paths are automatically provided to the appropriate tools
- Context about uploaded files is added to Python execution environment
- Visualizations are automatically uploaded to Discord and displayed
Returns:
List of tool objects
List of tool objects for the OpenAI API
"""
return [
{
"type": "function",
"function": {
"name": "analyze_data_file",
"description": "Analyze a data file (CSV or Excel) and generate visualizations. Use this tool when a user uploads a data file and wants insights or visualizations. The visualizations will be automatically displayed in Discord. When describing the results, refer to visualizations by their chart_id and explain what they show. Always inform the user they can see the visualizations directly in the Discord chat.",
"description": "**DATA ANALYSIS TOOL** - Use this tool for structured data analysis from CSV/Excel files. This tool specializes in analyzing data with pre-built templates (summary, correlation, distribution, comprehensive) and generates appropriate visualizations automatically. Use this tool when: (1) User explicitly requests data analysis or insights, (2) User asks for statistics, correlations, or data exploration, (3) User wants standard data science analysis. The tool handles file loading, data validation, and creates visualizations that are automatically displayed in Discord.",
"parameters": {
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "Path to the data file to analyze"
"description": "Path to the data file to analyze (CSV or Excel format required)"
},
"analysis_type": {
"type": "string",
"description": "Type of analysis to perform (e.g., 'summary', 'correlation', 'distribution')",
"enum": ["summary", "correlation", "distribution", "comprehensive"]
"description": "Type of pre-built analysis template to use",
"enum": ["summary", "correlation", "distribution", "comprehensive"],
"default": "comprehensive"
},
"custom_analysis": {
"type": "string",
"description": "Optional custom analysis request in natural language. If provided, this overrides the analysis_type and generates custom Python code for the specific analysis requested."
},
"install_packages": {
"type": "array",
"items": {"type": "string"},
"description": "Optional list of Python packages to install before analysis"
}
},
"required": ["file_path"]
@@ -279,41 +328,39 @@ def get_tools_for_model() -> List[Dict[str, Any]]:
"required": ["prompt"]
}
}
},
{
"type": "function",
"function": {
"name": "code_interpreter",
"description": "Execute Python code to solve problems, perform calculations, or create data visualizations. Use this for data analysis, generating charts, and processing data. When analyzing data, ALWAYS include code for visualizations (using matplotlib, seaborn, or plotly) if the user requests charts or graphs. When visualizations are created, tell the user they can view the charts directly in Discord, and reference visualizations by their chart_id in your descriptions.",
}, {
"type": "function", "function": {
"name": "execute_python_code",
"description": "**GENERAL PYTHON EXECUTION TOOL** - Use this tool for general programming tasks, mathematical calculations, algorithm implementations, and custom Python scripts. Use this tool when: (1) User asks for calculations or math problems, (2) User wants to run custom Python code, (3) User needs algorithm implementations, (4) User requests programming solutions, (5) Creating custom visualizations or data processing, (6) User uploads data files but wants custom code rather than standard analysis. File paths for uploaded data files are automatically made available in the execution environment.",
"parameters": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "The Python code to execute. For data analysis, include necessary imports (pandas, matplotlib, etc.) and visualization code."
"description": "The Python code to execute. Include all necessary imports and ensure code is complete and runnable."
},
"language": {
"input_data": {
"type": "string",
"description": "Programming language (only Python supported)",
"enum": ["python", "py"]
"description": "Optional input data to be made available to the code as a variable named 'input_data'"
},
"input": {
"type": "string",
"description": "Optional input data for the code"
"install_packages": {
"type": "array",
"items": {"type": "string"},
"description": "Optional list of Python packages to install before execution (e.g., ['numpy', 'matplotlib'])"
},
"file_path": {
"type": "string",
"description": "Optional path to a data file to analyze (supports CSV and Excel files)"
},
"analysis_request": {
"type": "string",
"description": "Natural language description of the analysis to perform. If this includes visualization requests, the generated code must include plotting code using matplotlib, seaborn, or plotly."
},
"include_visualization": {
"enable_visualization": {
"type": "boolean",
"description": "Whether to include visualizations using matplotlib/seaborn"
"description": "Set to true when creating charts, graphs, or any visual output with matplotlib/seaborn"
},
"timeout": {
"type": "integer",
"description": "Maximum execution time in seconds (default: 30, max: 120)",
"default": 30,
"minimum": 1,
"maximum": 120
}
}
},
"required": ["code"]
}
}
},
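The guidance above leaves the tool choice to the model; on the bot side, dispatching the resulting tool call might look roughly like this (a sketch assuming the OpenAI tool-calling response shape; the handler names come from this diff, but the dispatcher itself is illustrative):

```python
import json

async def dispatch_tool_call(tool_call) -> dict:
    # tool_call stands in for one entry of message.tool_calls.
    name = tool_call.function.name
    args = json.loads(tool_call.function.arguments)
    if name == "analyze_data_file":
        return await analyze_data_file(args)    # templated data analysis
    if name == "execute_python_code":
        return await execute_python_code(args)  # general sandboxed Python
    return {"success": False, "error": f"Unknown tool: {name}"}
```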

src/utils/python_executor.py (new file)

@@ -0,0 +1,415 @@
import os
import sys
import io
import re
import logging
import asyncio
import subprocess
import tempfile
import time
import uuid
from logging.handlers import RotatingFileHandler
import traceback
import contextlib
from typing import Dict, Any, Optional, List
# Import utility functions
from .code_utils import DATA_FILES_DIR, format_output_path, clean_old_files
# Configure logging
log_file = 'logs/code_interpreter.log'
os.makedirs(os.path.dirname(log_file), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler = RotatingFileHandler(log_file, maxBytes=10*1024*1024, backupCount=5)
file_handler.setFormatter(formatter)
logger = logging.getLogger('code_interpreter')
logger.setLevel(logging.INFO)
logger.addHandler(file_handler)
# Regular expression to find image file paths in output
IMAGE_PATH_PATTERN = r'(\/media\/quocanh\/.*\.(png|jpg|jpeg|gif))'

# Unsafe patterns for code security
UNSAFE_IMPORTS = [
    r'import\s+os\b', r'from\s+os\s+import',
    r'import\s+subprocess\b', r'from\s+subprocess\s+import',
    r'import\s+shutil\b', r'from\s+shutil\s+import',
    r'__import__\([\'"]os[\'"]\)', r'__import__\([\'"]subprocess[\'"]\)',
    r'import\s+sys\b(?!\s+import\s+path)', r'from\s+sys\s+import'
]
UNSAFE_FUNCTIONS = [
    r'os\.', r'subprocess\.', r'shutil\.',
    r'eval\(', r'exec\(', r'sys\.',
    r'open\([\'"][^\'"]*/[^\']*[\'"]',  # File system access
    r'__import__\(', r'globals\(\)', r'locals\(\)'
]
def sanitize_python_code(code: str) -> tuple[bool, str]:
    """
    Check Python code for potentially unsafe operations.

    Args:
        code: The code to check

    Returns:
        Tuple of (is_safe, sanitized_code_or_error_message)
    """
    # Check for unsafe imports
    for pattern in UNSAFE_IMPORTS:
        if re.search(pattern, code):
            return False, f"Forbidden import detected: {pattern}"
    # Check for unsafe function calls
    for pattern in UNSAFE_FUNCTIONS:
        if re.search(pattern, code):
            return False, f"Forbidden function call detected: {pattern}"
    # Add safety imports and commonly used libraries
    safe_imports = """
import math
import random
import json
import time
from datetime import datetime, timedelta
import collections
import itertools
import functools
try:
    import numpy as np
except ImportError:
    pass
try:
    import pandas as pd
except ImportError:
    pass
try:
    import matplotlib.pyplot as plt
    import matplotlib
    matplotlib.use('Agg')
except ImportError:
    pass
try:
    import seaborn as sns
except ImportError:
    pass
"""
    return True, safe_imports + "\n" + code
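Behaviorally (editor's sketch), the sanitizer either prepends the safe imports or returns the offending pattern:

```python
ok, payload = sanitize_python_code("print(sum(range(5)))")
print(ok)          # True; payload is safe_imports + "\n" + the original code

ok, reason = sanitize_python_code("import os\nos.system('ls')")
print(ok, reason)  # False; reason names the pattern: import\s+os\b
```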
async def install_packages(packages: List[str]) -> Dict[str, Any]:
    """
    Install Python packages in a sandboxed environment.

    Args:
        packages: List of package names to install

    Returns:
        Dict containing installation results
    """
    try:
        installed = []
        failed = []
        for package in packages:
            try:
                # Use pip to install package with timeout
                result = subprocess.run([
                    sys.executable, "-m", "pip", "install", package, "--user", "--quiet"
                ], capture_output=True, text=True, timeout=120)
                if result.returncode == 0:
                    installed.append(package)
                    logger.info(f"Successfully installed package: {package}")
                else:
                    failed.append({"package": package, "error": result.stderr})
                    logger.error(f"Failed to install package {package}: {result.stderr}")
            except subprocess.TimeoutExpired:
                failed.append({"package": package, "error": "Installation timeout"})
                logger.error(f"Installation timeout for package: {package}")
            except Exception as e:
                failed.append({"package": package, "error": str(e)})
                logger.error(f"Error installing package {package}: {str(e)}")
        return {
            "success": True,
            "installed": installed,
            "failed": failed,
            "message": f"Installed {len(installed)} packages, {len(failed)} failed"
        }
    except Exception as e:
        logger.error(f"Error in package installation: {str(e)}")
        return {
            "success": False,
            "error": str(e),
            "installed": [],
            "failed": packages
        }
async def execute_python_code(args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Execute Python code in a controlled sandbox environment.

    Args:
        args: Dictionary containing:
            - code: The Python code to execute
            - input: Optional input data for the code
            - install_packages: Optional list of packages to install
            - timeout: Optional timeout in seconds (default: 30)

    Returns:
        Dict containing execution results
    """
    try:
        code = args.get("code", "")
        input_data = args.get("input", "")
        packages_to_install = args.get("install_packages", [])
        timeout = args.get("timeout", 30)
        if not code:
            return {
                "success": False,
                "error": "No code provided",
                "output": ""
            }
        # Install packages if requested
        if packages_to_install:
            install_result = await install_packages(packages_to_install)
            if not install_result["success"]:
                logger.warning(f"Package installation issues: {install_result}")
        # Sanitize the code
        is_safe, sanitized_code = sanitize_python_code(code)
        if not is_safe:
            logger.warning(f"Code sanitization failed: {sanitized_code}")
            return {
                "success": False,
                "error": sanitized_code,
                "output": ""
            }
        # Clean up old files before execution
        clean_old_files()
        # Execute code in controlled environment
        result = await execute_code_safely(sanitized_code, input_data, timeout)
        return result
    except Exception as e:
        error_msg = f"Error in Python code execution: {str(e)}"
        logger.error(f"{error_msg}\n{traceback.format_exc()}")
        return {
            "success": False,
            "error": error_msg,
            "output": "",
            "traceback": traceback.format_exc()
        }
async def execute_code_safely(code: str, input_data: str, timeout: int) -> Dict[str, Any]:
    """
    Execute code in a safe environment with proper isolation.

    Args:
        code: Sanitized Python code to execute
        input_data: Input data for the code
        timeout: Execution timeout in seconds

    Returns:
        Dict containing execution results
    """
    try:
        # Capture stdout and stderr
        old_stdout = sys.stdout
        old_stderr = sys.stderr
        stdout_capture = io.StringIO()
        stderr_capture = io.StringIO()

        # Import commonly used libraries for the execution environment
        try:
            import matplotlib
            matplotlib.use('Agg')  # Use non-interactive backend
            import matplotlib.pyplot as plt
        except ImportError:
            plt = None
        try:
            import numpy as np
        except ImportError:
            np = None
        try:
            import pandas as pd
        except ImportError:
            pd = None

        # Create execution namespace
        exec_globals = {
            "__builtins__": {
                # Safe builtins
                "print": print,
                "len": len,
                "range": range,
                "enumerate": enumerate,
                "zip": zip,
                "map": map,
                "filter": filter,
                "sum": sum,
                "min": min,
                "max": max,
                "abs": abs,
                "round": round,
                "sorted": sorted,
                "reversed": reversed,
                "list": list,
                "dict": dict,
                "set": set,
                "tuple": tuple,
                "str": str,
                "int": int,
                "float": float,
                "bool": bool,
                "type": type,
                "isinstance": isinstance,
                "hasattr": hasattr,
                "getattr": getattr,
                "setattr": setattr,
                "dir": dir,
                "help": help,
                "__import__": __import__,  # Allow controlled imports
                "ValueError": ValueError,
                "TypeError": TypeError,
                "IndexError": IndexError,
                "KeyError": KeyError,
                "AttributeError": AttributeError,
                "ImportError": ImportError,
                "Exception": Exception,
            },
            # Add available libraries
            "math": __import__("math"),
            "random": __import__("random"),
            "json": __import__("json"),
            "time": __import__("time"),
            "datetime": __import__("datetime"),
            "collections": __import__("collections"),
            "itertools": __import__("itertools"),
            "functools": __import__("functools"),
        }
        # Add optional libraries if available
        if np is not None:
            exec_globals["np"] = np
            exec_globals["numpy"] = np
        if pd is not None:
            exec_globals["pd"] = pd
            exec_globals["pandas"] = pd
        if plt is not None:
            exec_globals["plt"] = plt
            exec_globals["matplotlib"] = matplotlib

        # Override input function if input_data is provided
        if input_data:
            input_lines = input_data.strip().split('\n')
            input_iter = iter(input_lines)
            exec_globals["input"] = lambda prompt="": next(input_iter, "")

        # Set up output capture
        sys.stdout = stdout_capture
        sys.stderr = stderr_capture

        # Generate output file path for any plots
        timestamp = int(time.time())
        output_filename = f"python_output_{timestamp}.png"
        output_path = format_output_path(output_filename)

        # Execute the code with timeout
        try:
            # Use asyncio.wait_for for timeout
            await asyncio.wait_for(
                asyncio.to_thread(exec, code, exec_globals),
                timeout=timeout
            )
            # Check for any matplotlib figures and save them
            visualizations = []
            if plt is not None and plt.get_fignums():
                for i, fig_num in enumerate(plt.get_fignums()):
                    try:
                        fig = plt.figure(fig_num)
                        if len(fig.get_axes()) > 0:
                            # Save to output path
                            fig_path = output_path.replace('.png', f'_{i}.png')
                            fig.savefig(fig_path, bbox_inches='tight', dpi=150)
                            visualizations.append(fig_path)
                        plt.close(fig)
                    except Exception as e:
                        logger.error(f"Error saving figure {i}: {str(e)}")
                # Clear all figures
                plt.close('all')
        except asyncio.TimeoutError:
            # Restore stdout and stderr before returning
            sys.stdout = old_stdout
            sys.stderr = old_stderr
            return {
                "success": False,
                "error": f"Code execution timed out after {timeout} seconds",
                "output": stdout_capture.getvalue(),
                "stderr": stderr_capture.getvalue()
            }

        # Restore stdout and stderr
        sys.stdout = old_stdout
        sys.stderr = old_stderr

        # Get the outputs
        stdout_output = stdout_capture.getvalue()
        stderr_output = stderr_capture.getvalue()

        # Check for any image paths in the output (findall returns (path, ext)
        # tuples because the pattern has two groups; keep only the full path)
        image_paths = [match[0] for match in re.findall(IMAGE_PATH_PATTERN, stdout_output)]
        for img_path in image_paths:
            if os.path.exists(img_path):
                visualizations.append(img_path)
        # Remove image paths from output text
        clean_output = stdout_output
        for img_path in image_paths:
            clean_output = clean_output.replace(img_path, "[Image saved]")

        logger.info(f"Python code executed successfully, output length: {len(clean_output)}")
        if visualizations:
            logger.info(f"Generated {len(visualizations)} visualizations")
        return {
            "success": True,
            "output": clean_output,
            "stderr": stderr_output,
            "visualizations": visualizations,
            "has_visualization": len(visualizations) > 0,
            "execution_time": f"Completed in under {timeout}s"
        }
    except Exception as e:
        # Restore stdout and stderr
        sys.stdout = old_stdout
        sys.stderr = old_stderr
        error_msg = f"Error executing Python code: {str(e)}"
        logger.error(f"{error_msg}\n{traceback.format_exc()}")
        return {
            "success": False,
            "error": error_msg,
            "output": stdout_capture.getvalue() if 'stdout_capture' in locals() else "",
            "stderr": stderr_capture.getvalue() if 'stderr_capture' in locals() else "",
            "traceback": traceback.format_exc()
        }


# Backward compatibility - keep the old function name
async def execute_code(args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Backward compatibility wrapper for execute_python_code.
    """
    return await execute_python_code(args)
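An end-to-end call against this executor might look like the following (illustrative, not part of the diff):

```python
import asyncio

result = asyncio.run(execute_python_code({
    "code": "values = [x ** 2 for x in range(5)]\nprint(values)",
    "timeout": 10,
}))
print(result["success"])  # True
print(result["output"])   # [0, 1, 4, 9, 16]
```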

tests/__init__.py (deleted)

@@ -1 +0,0 @@
# Initialize the tests package


@@ -1,379 +0,0 @@
import asyncio
import unittest
import os
import sys
import json
import io
from unittest.mock import MagicMock, patch, AsyncMock
from dotenv import load_dotenv
import re
# Add parent directory to path for imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Import modules for testing
from src.database.db_handler import DatabaseHandler
from src.utils.openai_utils import count_tokens, trim_content_to_token_limit, prepare_messages_for_api
from src.utils.code_utils import sanitize_code, extract_code_blocks
from src.utils.web_utils import scrape_web_content
from src.utils.pdf_utils import send_response
class TestDatabaseHandler(unittest.IsolatedAsyncioTestCase):
    """Test database handler functionality"""

    def setUp(self):
        # Load environment variables
        load_dotenv()
        # Try to get MongoDB URI from environment
        self.mongodb_uri = os.getenv("MONGODB_URI")
        self.using_real_db = bool(self.mongodb_uri)
        if not self.using_real_db:
            # Use mock if no real URI available
            self.mock_client_patcher = patch('motor.motor_asyncio.AsyncIOMotorClient')
            self.mock_client = self.mock_client_patcher.start()
            # Setup mock database and collections
            self.mock_db = self.mock_client.return_value.__getitem__.return_value
            self.mock_histories = MagicMock()
            self.mock_models = MagicMock()  # Store mock_models as instance variable
            self.mock_db.__getitem__.side_effect = lambda x: {
                'user_histories': self.mock_histories,
                'user_models': self.mock_models,  # Use the instance variable
                'whitelist': MagicMock(),
                'blacklist': MagicMock()
            }[x]
            # Initialize handler with mock connection string
            self.db_handler = DatabaseHandler("mongodb://localhost:27017")
        else:
            # Use real database connection
            print(f"Testing with real MongoDB at: {self.mongodb_uri}")
            self.db_handler = DatabaseHandler(self.mongodb_uri)
            # Extract database name from URI for later use
            self.db_name = self._extract_db_name_from_uri(self.mongodb_uri)

    async def asyncSetUp(self):
        # No additional async setup needed, but required by IsolatedAsyncioTestCase
        pass

    async def asyncTearDown(self):
        if not self.using_real_db:
            self.mock_client_patcher.stop()
        else:
            # Clean up test data if using real database
            await self.cleanup_test_data()

    def _extract_db_name_from_uri(self, uri):
        """Extract database name from MongoDB URI more reliably"""
        # Default database name if extraction fails
        default_db_name = 'chatgpt_discord_bot'
        try:
            # Handle standard MongoDB URI format
            # mongodb://[username:password@]host1[:port1][,...hostN[:portN]][/database][?options]
            match = re.search(r'\/([^/?]+)(\?|$)', uri)
            if match:
                return match.group(1)
            # If no database in URI, return default
            return default_db_name
        except:
            # If any error occurs, return default name
            return default_db_name

    async def cleanup_test_data(self):
        """Remove test data from real database"""
        if self.using_real_db:
            try:
                # Use the database name we extracted in setUp
                db = self.db_handler.client.get_database(self.db_name)
                await db.user_histories.delete_one({'user_id': 12345})
                await db.user_models.delete_one({'user_id': 12345})
            except Exception as e:
                print(f"Error during test cleanup: {e}")

    async def test_get_history_empty(self):
        if self.using_real_db:
            # Clean up any existing history first
            await self.cleanup_test_data()
            # Test with real database
            result = await self.db_handler.get_history(12345)
            self.assertEqual(result, [])
        else:
            # Mock find_one to return None (no history)
            self.mock_histories.find_one = AsyncMock(return_value=None)
            # Test getting non-existent history
            result = await self.db_handler.get_history(12345)
            self.assertEqual(result, [])
            self.mock_histories.find_one.assert_called_once_with({'user_id': 12345})

    async def test_get_history_existing(self):
        # Sample history data
        sample_history = [
            {'role': 'user', 'content': 'Hello'},
            {'role': 'assistant', 'content': 'Hi there!'}
        ]
        if self.using_real_db:
            # Save test history first
            await self.db_handler.save_history(12345, sample_history)
            # Test getting existing history
            result = await self.db_handler.get_history(12345)
            self.assertEqual(result, sample_history)
        else:
            # Mock find_one to return existing history
            self.mock_histories.find_one = AsyncMock(return_value={'user_id': 12345, 'history': sample_history})
            # Test getting existing history
            result = await self.db_handler.get_history(12345)
            self.assertEqual(result, sample_history)

    async def test_save_history(self):
        # Sample history to save
        sample_history = [
            {'role': 'user', 'content': 'Test message'},
            {'role': 'assistant', 'content': 'Test response'}
        ]
        if self.using_real_db:
            # Test saving history to real database
            await self.db_handler.save_history(12345, sample_history)
            # Verify it was saved
            result = await self.db_handler.get_history(12345)
            self.assertEqual(result, sample_history)
        else:
            # Mock update_one method
            self.mock_histories.update_one = AsyncMock()
            # Test saving history
            await self.db_handler.save_history(12345, sample_history)
            # Verify update_one was called with correct parameters
            self.mock_histories.update_one.assert_called_once_with(
                {'user_id': 12345},
                {'$set': {'history': sample_history}},
                upsert=True
            )

    async def test_user_model_operations(self):
        if self.using_real_db:
            # Save a model and then retrieve it
            await self.db_handler.save_user_model(12345, 'openai/gpt-4o')
            model = await self.db_handler.get_user_model(12345)
            self.assertEqual(model, 'openai/gpt-4o')
            # Test updating model
            await self.db_handler.save_user_model(12345, 'openai/gpt-4o-mini')
            updated_model = await self.db_handler.get_user_model(12345)
            self.assertEqual(updated_model, 'openai/gpt-4o-mini')
        else:
            # Setup mock for user_models collection
            # Use self.mock_models instead of creating a new mock
            self.mock_models.find_one = AsyncMock(return_value={'user_id': 12345, 'model': 'openai/gpt-4o'})
            self.mock_models.update_one = AsyncMock()
            # Test getting user model
            model = await self.db_handler.get_user_model(12345)
            self.assertEqual(model, 'openai/gpt-4o')
            # Test saving user model
            await self.db_handler.save_user_model(12345, 'openai/gpt-4o-mini')
            self.mock_models.update_one.assert_called_once_with(
                {'user_id': 12345},
                {'$set': {'model': 'openai/gpt-4o-mini'}},
                upsert=True
            )
class TestOpenAIUtils(unittest.TestCase):
    """Test OpenAI utility functions"""

    def test_count_tokens(self):
        # Test token counting
        self.assertGreater(count_tokens("Hello, world!"), 0)
        self.assertGreater(count_tokens("This is a longer text that should have more tokens."),
                           count_tokens("Short text"))

    def test_trim_content_to_token_limit(self):
        # Create a long text
        long_text = "This is a test. " * 1000
        # Test trimming
        trimmed = trim_content_to_token_limit(long_text, 100)
        self.assertLess(count_tokens(trimmed), count_tokens(long_text))
        self.assertLessEqual(count_tokens(trimmed), 100)
        # Test no trimming needed
        short_text = "This is a short text."
        untrimmed = trim_content_to_token_limit(short_text, 100)
        self.assertEqual(untrimmed, short_text)

    def test_prepare_messages_for_api(self):
        # Test empty messages
        empty_result = prepare_messages_for_api([])
        self.assertEqual(len(empty_result), 1)  # Should have system message
        self.assertEqual(empty_result[0]["role"], "system")  # Verify it's a system message
        # Test regular messages
        messages = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
            {"role": "user", "content": "How are you?"}
        ]
        result = prepare_messages_for_api(messages)
        self.assertEqual(len(result), 4)  # Should have system message + 3 original messages
        # Test with null content
        messages_with_null = [
            {"role": "user", "content": None},
            {"role": "assistant", "content": "Response"}
        ]
        result_fixed = prepare_messages_for_api(messages_with_null)
        self.assertEqual(len(result_fixed), 2)  # Should have system message + 1 valid message
        # Verify the content is correct (system message + only the assistant message)
        self.assertEqual(result_fixed[0]["role"], "system")
        self.assertEqual(result_fixed[1]["role"], "assistant")
        self.assertEqual(result_fixed[1]["content"], "Response")
class TestCodeUtils(unittest.TestCase):
    """Test code utility functions"""

    def test_sanitize_python_code_safe(self):
        # Safe Python code
        code = """
def factorial(n):
    if n <= 1:
        return 1
    return n * factorial(n-1)

print(factorial(5))
"""
        is_safe, sanitized = sanitize_code(code, "python")
        self.assertTrue(is_safe)
        self.assertIn("def factorial", sanitized)

    def test_sanitize_python_code_unsafe(self):
        # Unsafe Python code with os.system
        unsafe_code = """
import os
os.system('rm -rf /')
"""
        is_safe, message = sanitize_code(unsafe_code, "python")
        self.assertFalse(is_safe)
        self.assertIn("Forbidden", message)

    def test_sanitize_cpp_code_safe(self):
        # Safe C++ code
        code = """
#include <iostream>
using namespace std;

int main() {
    cout << "Hello, world!" << endl;
    return 0;
}
"""
        is_safe, sanitized = sanitize_code(code, "cpp")
        self.assertTrue(is_safe)
        self.assertIn("Hello, world!", sanitized)

    def test_sanitize_cpp_code_unsafe(self):
        # Unsafe C++ code with system
        unsafe_code = """
#include <stdlib.h>

int main() {
    system("rm -rf /");
    return 0;
}
"""
        is_safe, message = sanitize_code(unsafe_code, "cpp")
        self.assertFalse(is_safe)
        self.assertIn("Forbidden", message)

    def test_extract_code_blocks(self):
        # Test message with code block
        message = """
Here's a Python function to calculate factorial:
```python
def factorial(n):
    if n <= 1:
        return 1
    return n * factorial(n-1)
```
And here's a C++ version:
```cpp
int factorial(int n) {
    if (n <= 1) return 1;
    return n * factorial(n-1);
}
```
"""
        blocks = extract_code_blocks(message)
        self.assertEqual(len(blocks), 2)
        self.assertEqual(blocks[0][0], "python")
        self.assertEqual(blocks[1][0], "cpp")
        # Test without language specifier
        message_no_lang = """
Here's some code:
```
print("Hello world")
```
"""
        blocks_no_lang = extract_code_blocks(message_no_lang)
        self.assertEqual(len(blocks_no_lang), 1)
#class TestWebUtils(unittest.TestCase):
# """Test web utilities"""
#
# @patch('requests.get')
# def test_scrape_web_content(self, mock_get):
# # Mock the response
# mock_response = MagicMock()
# mock_response.text = '<html><body><h1>Test Heading</h1><p>Test paragraph</p></body></html>'
# mock_response.status_code = 200
# mock_get.return_value = mock_response
#
# # Test scraping
# content = scrape_web_content("example.com")
# self.assertIn("Test Heading", content)
# self.assertIn("Test paragraph", content)
class TestPDFUtils(unittest.IsolatedAsyncioTestCase):
    """Test PDF utilities"""

    async def test_send_response(self):
        # Create mock channel
        mock_channel = AsyncMock()
        mock_channel.send = AsyncMock()
        # Test sending short response
        short_response = "This is a short response"
        await send_response(mock_channel, short_response)
        mock_channel.send.assert_called_once_with(short_response)
        # Reset mock
        mock_channel.send.reset_mock()
        # Mock for long response (testing would need file operations)
        with patch('builtins.open', new_callable=unittest.mock.mock_open):
            with patch('discord.File', return_value="mocked_file"):
                # Test sending long response
                long_response = "X" * 2500  # Over 2000 character limit
                await send_response(mock_channel, long_response)
                mock_channel.send.assert_called_once()
                # Verify it's called with the file argument
                args, kwargs = mock_channel.send.call_args
                self.assertIn('file', kwargs)

if __name__ == "__main__":
    unittest.main()