feat: Implement data analysis and code execution utilities
- Added data_analyzer.py for comprehensive data analysis with templates for summary, correlation, distribution, and custom analysis.
- Integrated logging for tracking analysis processes and errors.
- Included package installation functionality for required libraries.
- Developed python_executor.py to safely execute user-provided Python code with sanitization and timeout features.
- Implemented security measures to prevent unsafe code execution.
- Enhanced output handling to capture visualizations and standard output.
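For quick reference, a minimal sketch of how the two new entry points can be driven (argument keys follow the diffs below; the Discord wiring and the sample file path are assumptions for illustration):

import asyncio

from src.utils.python_executor import execute_python_code
from src.utils.data_analyzer import analyze_data_file

async def demo():
    # General-purpose execution: sanitized and time-limited
    run = await execute_python_code({"code": "print(sum(range(10)))", "timeout": 10})
    print(run["success"], run["output"])

    # Template-driven analysis of an uploaded file (path is hypothetical)
    report = await analyze_data_file({"file_path": "data/sales.csv", "analysis_type": "summary"})
    print(report["success"], report.get("visualizations", []))

asyncio.run(demo())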
8	.idea/.gitignore	generated	vendored
@@ -1,8 +0,0 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
24	.idea/ChatGPT-Discord-Bot.iml	generated
@@ -1,24 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
  <component name="Flask">
    <option name="enabled" value="true" />
  </component>
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$">
      <excludeFolder url="file://$MODULE_DIR$/.venv" />
      <excludeFolder url="file://$MODULE_DIR$/.venv1" />
    </content>
    <orderEntry type="jdk" jdkName="Python 3.12 (ChatGPT-Discord-Bot)" jdkType="Python SDK" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
  <component name="PyDocumentationSettings">
    <option name="format" value="PLAIN" />
    <option name="myDocStringFormat" value="Plain" />
  </component>
  <component name="TemplatesService">
    <option name="TEMPLATE_CONFIGURATION" value="Jinja2" />
  </component>
  <component name="TestRunnerService">
    <option name="PROJECT_TEST_RUNNER" value="py.test" />
  </component>
</module>
13	.idea/dataSources.xml	generated
@@ -1,13 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="DataSourceManagerImpl" format="xml" multifile-model="true">
    <data-source source="LOCAL" name="@cluster0.akbzf5r.mongodb.net" uuid="1fd5036f-4eb6-4575-b3b7-1a641f728a88">
      <driver-ref>documentdb</driver-ref>
      <synchronize>true</synchronize>
      <configured-by-url>true</configured-by-url>
      <jdbc-driver>com.dbschema.MongoJdbcDriver</jdbc-driver>
      <jdbc-url>mongodb+srv://chatgpt:Anhtt2021@cluster0.akbzf5r.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0</jdbc-url>
      <working-dir>$ProjectFileDir$</working-dir>
    </data-source>
  </component>
</project>
6	.idea/inspectionProfiles/profiles_settings.xml	generated
@@ -1,6 +0,0 @@
<component name="InspectionProjectProfileManager">
  <settings>
    <option name="USE_PROJECT_PROFILE" value="false" />
    <version value="1.0" />
  </settings>
</component>
7	.idea/misc.xml	generated
@@ -1,7 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="Black">
    <option name="sdkName" value="Python 3.12 (ChatGPT-Discord-Bot)" />
  </component>
  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12 (ChatGPT-Discord-Bot)" project-jdk-type="Python SDK" />
</project>
8	.idea/modules.xml	generated
@@ -1,8 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/ChatGPT-Discord-Bot.iml" filepath="$PROJECT_DIR$/.idea/ChatGPT-Discord-Bot.iml" />
    </modules>
  </component>
</project>
6	.idea/vcs.xml	generated
@@ -1,6 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="VcsDirectoryMappings">
    <mapping directory="" vcs="Git" />
  </component>
</project>
0	logs/data_analyzer.log	Normal file
Binary file not shown.
@@ -81,7 +81,29 @@ PDF_BATCH_SIZE = 3
# Prompt templates
WEB_SCRAPING_PROMPT = "You are a Web Scraping Assistant. You analyze content from webpages to extract key information. Integrate insights from the scraped content to give comprehensive, fact-based responses. When analyzing web content: 1) Focus on the most relevant information, 2) Cite specific sections when appropriate, 3) Maintain a neutral tone, and 4) Organize information logically. Present your response in a clear, conversational manner suitable for Discord."

NORMAL_CHAT_PROMPT = "You're ChatGPT for Discord! You have access to powerful tools that can enhance your responses. When appropriate, use: 1) Google Search (google_search) to find current information, 2) Web Scraping (scrape_webpage) to analyze webpages, 3) Code Interpreter (code_interpreter) to run code in Python to support calculating and analyzing (priority to use python default internal library), and 4) Image Generation (generate_image) to create images from text descriptions, 5) data analysis (analyze_data) to draw chart based on user data file and 6) Reminder (set_reminder) to set a remind based on user request. When solving problems, follow a step-by-step approach: identify what information is needed, determine which tools might help, and explain your reasoning clearly. For code tasks, always share both the code you're running and its output. Craft responses that are easy to read in Discord without any markdown and latex (except for code you must use markdown). You MUST respond in the same language as the user. You MUST use code_interpreter with Python language for your own code for correct of any calculation. All user request MUST be completed in one single request"
NORMAL_CHAT_PROMPT = """You're ChatGPT for Discord! You have access to powerful tools that enhance your capabilities:

🔍 **Information Tools:**
- google_search: Find current information from the web
- scrape_webpage: Extract and analyze content from websites

💻 **Programming & Data Tools:**
- execute_python_code: Run Python code for calculations, math problems, algorithms, and custom programming tasks
- analyze_data_file: Analyze CSV/Excel files with pre-built templates when users request data analysis or insights

🎨 **Creative Tools:**
- generate_image: Create images from text descriptions
- edit_image: Modify existing images (remove backgrounds, etc.)

📅 **Utility Tools:**
- set_reminder/get_reminders: Manage user reminders

**For Data Files:** When users upload CSV/Excel files:
- If they ask for "analysis", "insights", "statistics" → file is automatically processed with analyze_data_file
- If they want custom programming or specific code → file path is provided to execute_python_code
- Both tools can install packages and create visualizations automatically displayed in Discord

Always explain your approach step-by-step and provide clear, Discord-friendly responses without excessive markdown. You MUST respond in the same language as the user."""

SEARCH_PROMPT = "You are a Research Assistant with access to Google Search results. Your task is to synthesize information from search results to provide accurate, comprehensive answers. When analyzing search results: 1) Prioritize information from credible sources, 2) Compare and contrast different perspectives when available, 3) Acknowledge when information is limited or unclear, and 4) Cite specific sources when presenting facts. Structure your response in a clear, logical manner, focusing on directly answering the user's question while providing relevant context."
Binary file not shown.
File diff suppressed because it is too large.
Binary file not shown.
BIN	src/utils/__pycache__/data_analyzer.cpython-312.pyc	Normal file
Binary file not shown.
BIN	src/utils/__pycache__/python_executor.cpython-312.pyc	Normal file
Binary file not shown.
@@ -4,17 +4,18 @@ import io
import re
import logging
import asyncio
import subprocess
import tempfile
import time
import uuid
from logging.handlers import RotatingFileHandler
import traceback
import contextlib
from typing import Dict, Any, Optional, List
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from .code_utils import sanitize_code, get_temporary_file_path, generate_analysis_code, analyze_data, DATA_FILES_DIR, format_output_path, clean_old_files

# Import the new separated modules
from .python_executor import execute_python_code
from .data_analyzer import analyze_data_file

# Configure logging
log_file = 'logs/code_interpreter.log'
@@ -26,11 +27,43 @@ logger = logging.getLogger('code_interpreter')
logger.setLevel(logging.INFO)
logger.addHandler(file_handler)

# Regular expression to find image file paths in output
IMAGE_PATH_PATTERN = r'(sandbox:)?(\/media\/quocanh\/.*\.(png|jpg|jpeg|gif))'

async def execute_code(args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Main entry point for code execution - routes to the appropriate handler.

    This function maintains backward compatibility while routing requests
    to the appropriate specialized handler.

    Args:
        args: Dictionary containing execution parameters

    Returns:
        Dict containing execution results
    """
    try:
        # Check if this is a data analysis request
        file_path = args.get("file_path", "")
        analysis_request = args.get("analysis_request", "")

        if file_path and (analysis_request or args.get("analysis_type")):
            # Route to data analyzer
            logger.info("Routing to data analyzer")
            return await analyze_data_file(args)
        else:
            # Route to Python executor
            logger.info("Routing to Python executor")
            return await execute_python_code(args)

    except Exception as e:
        error_msg = f"Error in code execution router: {str(e)}"
        logger.error(f"{error_msg}\n{traceback.format_exc()}")
        return {
            "success": False,
            "error": error_msg,
            "output": "",
            "traceback": traceback.format_exc()
        }
    """
    Execute code with support for data analysis and visualization.

    Args:
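The router above keeps the old execute_code entry point; which backend it picks depends only on whether file_path plus an analysis hint is present. A hedged illustration (paths hypothetical, results ignored):

import asyncio

async def demo_routing():
    # file_path + analysis_type -> routed to analyze_data_file
    await execute_code({"file_path": "data/sales.csv", "analysis_type": "correlation"})
    # plain code -> routed to execute_python_code
    await execute_code({"code": "print(2 ** 10)"})

asyncio.run(demo_routing())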
547	src/utils/data_analyzer.py	Normal file
@@ -0,0 +1,547 @@
import os
import sys
import io
import logging
import asyncio
import traceback
import contextlib
import tempfile
import uuid
import time
from typing import Dict, Any, Optional, List, Tuple
from datetime import datetime
from logging.handlers import RotatingFileHandler

# Import data analysis libraries
try:
    import pandas as pd
    import numpy as np
    import matplotlib
    matplotlib.use('Agg')  # Use non-interactive backend
    import matplotlib.pyplot as plt
    import seaborn as sns
    import plotly.graph_objects as go
    import plotly.express as px
    LIBRARIES_AVAILABLE = True
except ImportError as e:
    LIBRARIES_AVAILABLE = False
    logging.warning(f"Data analysis libraries not available: {str(e)}")

# Import utility functions
from .code_utils import DATA_FILES_DIR, format_output_path, clean_old_files

# Configure logging
log_file = 'logs/data_analyzer.log'
os.makedirs(os.path.dirname(log_file), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler = RotatingFileHandler(log_file, maxBytes=10*1024*1024, backupCount=5)
file_handler.setFormatter(formatter)
logger = logging.getLogger('data_analyzer')
logger.setLevel(logging.INFO)
logger.addHandler(file_handler)

def _is_valid_python_code(code_string: str) -> bool:
    """
    Check if a string contains valid Python code or is natural language.

    Args:
        code_string: String to check

    Returns:
        bool: True if it's valid Python code, False if it's natural language
    """
    try:
        # Strip whitespace and check for common natural language patterns
        stripped = code_string.strip()

        # Check for obvious natural language patterns
        natural_language_indicators = [
            'analyze', 'create', 'show', 'display', 'plot', 'visualize',
            'tell me', 'give me', 'what is', 'how many', 'find'
        ]

        # If it starts with typical natural language words, it's likely not Python
        first_words = stripped.lower().split()[:3]
        if any(indicator in ' '.join(first_words) for indicator in natural_language_indicators):
            return False

        # Try to compile as Python code
        compile(stripped, '<string>', 'exec')
        return True
    except SyntaxError:
        return False
    except Exception:
        return False

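The heuristic deliberately errs toward treating short imperative requests as natural language even when they would compile. A sketch of the expected behavior (not a test from this commit):

_is_valid_python_code("df.describe()")           # True: no NL indicator in the first words, and it compiles
_is_valid_python_code("show me the top 5 rows")  # False: leading 'show' is an NL indicator
_is_valid_python_code("plot = 1 + 1")            # False: 'plot' in the first words trips the NL check despite compiling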
# Data analysis templates
ANALYSIS_TEMPLATES = {
    "summary": """
# Data Summary Analysis
# User request: {custom_request}
import pandas as pd
import numpy as np

# Load the data
df = pd.read_csv('{file_path}') if '{file_path}'.endswith('.csv') else pd.read_excel('{file_path}')

print("=== DATA SUMMARY ===")
print(f"Shape: {{df.shape}}")
print(f"Columns: {{list(df.columns)}}")
print("\\n=== DATA TYPES ===")
print(df.dtypes)
print("\\n=== MISSING VALUES ===")
print(df.isnull().sum())
print("\\n=== BASIC STATISTICS ===")
print(df.describe())
""",

    "correlation": """
# Correlation Analysis
# User request: {custom_request}
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Load the data
df = pd.read_csv('{file_path}') if '{file_path}'.endswith('.csv') else pd.read_excel('{file_path}')

# Select only numeric columns
numeric_df = df.select_dtypes(include=[np.number])

if len(numeric_df.columns) > 1:
    # Calculate correlation matrix
    correlation_matrix = numeric_df.corr()

    # Create correlation heatmap
    plt.figure(figsize=(10, 8))
    sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', center=0,
                square=True, linewidths=0.5)
    plt.title('Correlation Matrix')
    plt.tight_layout()
    plt.savefig('{output_path}')
    plt.close()

    print("=== CORRELATION ANALYSIS ===")
    print(correlation_matrix)

    # Find strong correlations
    strong_corr = []
    for i in range(len(correlation_matrix.columns)):
        for j in range(i+1, len(correlation_matrix.columns)):
            corr_val = correlation_matrix.iloc[i, j]
            if abs(corr_val) > 0.7:
                strong_corr.append((correlation_matrix.columns[i],
                                    correlation_matrix.columns[j], corr_val))

    if strong_corr:
        print("\\n=== STRONG CORRELATIONS (|r| > 0.7) ===")
        for col1, col2, corr in strong_corr:
            print(f"{{col1}} <-> {{col2}}: {{corr:.3f}}")
else:
    print("Not enough numeric columns for correlation analysis")
""",

    "distribution": """
# Distribution Analysis
# User request: {custom_request}
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Load the data
df = pd.read_csv('{file_path}') if '{file_path}'.endswith('.csv') else pd.read_excel('{file_path}')

# Select numeric columns
numeric_cols = df.select_dtypes(include=[np.number]).columns

if len(numeric_cols) > 0:
    # Create distribution plots
    n_cols = min(len(numeric_cols), 4)
    n_rows = (len(numeric_cols) + n_cols - 1) // n_cols

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4*n_cols, 4*n_rows))
    if n_rows == 1 and n_cols == 1:
        axes = [axes]
    elif n_rows == 1:
        axes = list(axes)
    else:
        axes = axes.flatten()

    for i, col in enumerate(numeric_cols):
        if i < len(axes):
            df[col].dropna().hist(bins=30, alpha=0.7, edgecolor='black', ax=axes[i])
            axes[i].set_title(f'Distribution of {{col}}')
            axes[i].set_xlabel(col)
            axes[i].set_ylabel('Frequency')

    # Hide extra subplots
    for i in range(len(numeric_cols), len(axes)):
        axes[i].set_visible(False)

    plt.tight_layout()
    plt.savefig('{output_path}')
    plt.close()

    print("=== DISTRIBUTION ANALYSIS ===")
    for col in numeric_cols:
        print(f"\\n{{col}}:")
        print(f"  Mean: {{df[col].mean():.2f}}")
        print(f"  Median: {{df[col].median():.2f}}")
        print(f"  Std: {{df[col].std():.2f}}")
        print(f"  Skewness: {{df[col].skew():.2f}}")
else:
    print("No numeric columns found for distribution analysis")
""",

    "comprehensive": """
# Comprehensive Data Analysis
# User request: {custom_request}
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Load the data
df = pd.read_csv('{file_path}') if '{file_path}'.endswith('.csv') else pd.read_excel('{file_path}')

print("=== COMPREHENSIVE DATA ANALYSIS ===")
print(f"Dataset shape: {{df.shape}}")
print(f"Columns: {{list(df.columns)}}")

# Basic info
print("\\n=== DATA TYPES ===")
print(df.dtypes)

print("\\n=== MISSING VALUES ===")
missing = df.isnull().sum()
print(missing[missing > 0])

print("\\n=== BASIC STATISTICS ===")
print(df.describe())

# Numeric analysis
numeric_cols = df.select_dtypes(include=[np.number]).columns
if len(numeric_cols) > 0:
    print("\\n=== NUMERIC COLUMNS ANALYSIS ===")

    # Create subplot layout
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # 1. Correlation heatmap
    if len(numeric_cols) > 1:
        corr_matrix = df[numeric_cols].corr()
        sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', ax=axes[0,0])
        axes[0,0].set_title('Correlation Matrix')

    # 2. Distribution of first numeric column
    if len(numeric_cols) >= 1:
        df[numeric_cols[0]].hist(bins=30, ax=axes[0,1])
        axes[0,1].set_title(f'Distribution of {{numeric_cols[0]}}')

    # 3. Box plot of numeric columns
    if len(numeric_cols) <= 5:
        df[numeric_cols].boxplot(ax=axes[1,0])
        axes[1,0].set_title('Box Plot of Numeric Columns')
        axes[1,0].tick_params(axis='x', rotation=45)

    # 4. Scatter plot for the first two numeric columns
    if len(numeric_cols) >= 2:
        scatter_cols = numeric_cols[:min(3, len(numeric_cols))]
        if len(scatter_cols) == 2:
            axes[1,1].scatter(df[scatter_cols[0]], df[scatter_cols[1]], alpha=0.6)
            axes[1,1].set_xlabel(scatter_cols[0])
            axes[1,1].set_ylabel(scatter_cols[1])
            axes[1,1].set_title(f'{{scatter_cols[0]}} vs {{scatter_cols[1]}}')

    plt.tight_layout()
    plt.savefig('{output_path}')
    plt.close()

# Categorical analysis
categorical_cols = df.select_dtypes(include=['object']).columns
if len(categorical_cols) > 0:
    print("\\n=== CATEGORICAL COLUMNS ANALYSIS ===")
    for col in categorical_cols[:3]:  # Limit to first 3 categorical columns
        print(f"\\n{{col}}:")
        print(df[col].value_counts().head())
"""
}

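The templates are plain str.format strings: single braces ({file_path}, {output_path}, {custom_request}) are substitution slots, while doubled braces ({{df.shape}}) survive formatting as literal braces so the generated code's f-strings still work. A minimal sketch (paths hypothetical):

code = ANALYSIS_TEMPLATES["summary"].format(
    file_path="data/sales.csv",
    output_path="output/analysis.png",
    custom_request="General data analysis",
)
# e.g. the template line  print(f"Shape: {{df.shape}}")  renders as  print(f"Shape: {df.shape}")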
async def install_packages(packages: List[str]) -> Dict[str, Any]:
    """
    Install Python packages in a sandboxed environment.

    Args:
        packages: List of package names to install

    Returns:
        Dict containing installation results
    """
    try:
        import subprocess

        installed = []
        failed = []

        for package in packages:
            try:
                # Use pip to install package
                result = subprocess.run([
                    sys.executable, "-m", "pip", "install", package
                ], capture_output=True, text=True, timeout=120)

                if result.returncode == 0:
                    installed.append(package)
                    logger.info(f"Successfully installed package: {package}")
                else:
                    failed.append({"package": package, "error": result.stderr})
                    logger.error(f"Failed to install package {package}: {result.stderr}")

            except subprocess.TimeoutExpired:
                failed.append({"package": package, "error": "Installation timeout"})
                logger.error(f"Installation timeout for package: {package}")
            except Exception as e:
                failed.append({"package": package, "error": str(e)})
                logger.error(f"Error installing package {package}: {str(e)}")

        return {
            "success": True,
            "installed": installed,
            "failed": failed,
            "message": f"Installed {len(installed)} packages, {len(failed)} failed"
        }

    except Exception as e:
        logger.error(f"Error in package installation: {str(e)}")
        return {
            "success": False,
            "error": str(e),
            "installed": [],
            "failed": packages
        }

async def analyze_data_file(args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Analyze data files with pre-built templates and custom analysis.

    Args:
        args: Dictionary containing:
            - file_path: Path to the data file (CSV/Excel)
            - analysis_type: Type of analysis (summary, correlation, distribution, comprehensive)
            - custom_analysis: Optional custom analysis request in natural language
            - user_id: Optional user ID for file management
            - install_packages: Optional list of packages to install

    Returns:
        Dict containing analysis results
    """
    try:
        if not LIBRARIES_AVAILABLE:
            return {
                "success": False,
                "error": "Data analysis libraries not available. Please install pandas, numpy, matplotlib, seaborn."
            }

        file_path = args.get("file_path", "")
        analysis_type = args.get("analysis_type", "comprehensive")
        custom_analysis = args.get("custom_analysis", "")
        user_id = args.get("user_id")
        packages_to_install = args.get("install_packages", [])

        # Install packages if requested
        if packages_to_install:
            install_result = await install_packages(packages_to_install)
            if not install_result["success"]:
                logger.warning(f"Package installation issues: {install_result}")

        # Validate file path
        if not file_path or not os.path.exists(file_path):
            return {
                "success": False,
                "error": f"Data file not found: {file_path}"
            }

        # Check file extension
        file_ext = os.path.splitext(file_path)[1].lower()
        if file_ext not in ['.csv', '.xlsx', '.xls']:
            return {
                "success": False,
                "error": "Unsupported file format. Please use CSV or Excel files."
            }

        # Generate output path for visualizations
        timestamp = int(time.time())
        output_filename = f"analysis_{user_id or 'user'}_{timestamp}.png"
        output_path = format_output_path(output_filename)

        # Determine analysis code
        if custom_analysis:
            # Check if custom_analysis contains valid Python code or is natural language
            is_python_code = _is_valid_python_code(custom_analysis)

            if is_python_code:
                # Generate custom analysis code with valid Python
                code = f"""
# Custom Data Analysis
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Load the data
df = pd.read_csv('{file_path}') if '{file_path}'.endswith('.csv') else pd.read_excel('{file_path}')

print("=== CUSTOM DATA ANALYSIS ===")
print(f"Dataset loaded: {{df.shape}}")

# Custom analysis based on user request
{custom_analysis}

# Save any plots
if plt.get_fignums():
    plt.savefig('{output_path}')
    plt.close()
"""
            else:
                # For natural language queries, use comprehensive analysis with comment
                logger.info(f"Natural language query detected: {custom_analysis}")
                analysis_type = "comprehensive"
                code = ANALYSIS_TEMPLATES[analysis_type].format(
                    file_path=file_path,
                    output_path=output_path,
                    custom_request=custom_analysis
                )
        else:
            # Use predefined template
            if analysis_type not in ANALYSIS_TEMPLATES:
                analysis_type = "comprehensive"

            # Format template with default values
            template_vars = {
                'file_path': file_path,
                'output_path': output_path,
                'custom_request': custom_analysis or 'General data analysis'
            }
            code = ANALYSIS_TEMPLATES[analysis_type].format(**template_vars)

        # Execute the analysis code
        result = await execute_analysis_code(code, output_path)

        # Add file information to result
        result.update({
            "file_path": file_path,
            "analysis_type": analysis_type,
            "custom_analysis": bool(custom_analysis)
        })

        # Clean up old files
        clean_old_files()

        return result

    except Exception as e:
        error_msg = f"Error in data analysis: {str(e)}"
        logger.error(f"{error_msg}\n{traceback.format_exc()}")
        return {
            "success": False,
            "error": error_msg,
            "traceback": traceback.format_exc()
        }

async def execute_analysis_code(code: str, output_path: str) -> Dict[str, Any]:
    """
    Execute data analysis code in a controlled environment.

    Args:
        code: Python code to execute
        output_path: Path where visualizations should be saved

    Returns:
        Dict containing execution results
    """
    try:
        # Capture stdout
        old_stdout = sys.stdout
        sys.stdout = captured_output = io.StringIO()

        # Create a controlled execution environment
        exec_globals = {
            "__builtins__": __builtins__,
            "pd": pd,
            "np": np,
            "plt": plt,
            "sns": sns,
            "print": print,
        }

        # Expose plotly if it was imported successfully at module load
        try:
            exec_globals["go"] = go
            exec_globals["px"] = px
        except NameError:
            pass

        # Execute the code
        exec(code, exec_globals)

        # Restore stdout
        sys.stdout = old_stdout

        # Get the output
        output = captured_output.getvalue()

        # Check if visualization was created
        visualizations = []
        if os.path.exists(output_path):
            visualizations.append(output_path)

        logger.info(f"Data analysis executed successfully, output length: {len(output)}")

        return {
            "success": True,
            "output": output,
            "visualizations": visualizations,
            "has_visualization": len(visualizations) > 0
        }

    except Exception as e:
        # Restore stdout
        sys.stdout = old_stdout

        error_msg = f"Error executing analysis code: {str(e)}"
        logger.error(f"{error_msg}\n{traceback.format_exc()}")

        return {
            "success": False,
            "error": error_msg,
            "output": captured_output.getvalue() if 'captured_output' in locals() else "",
            "traceback": traceback.format_exc()
        }

# Utility function to validate data analysis requests
def validate_analysis_request(args: Dict[str, Any]) -> Tuple[bool, str]:
    """
    Validate data analysis request parameters.

    Args:
        args: Analysis request arguments

    Returns:
        Tuple of (is_valid, error_message)
    """
    required_fields = ["file_path"]

    for field in required_fields:
        if field not in args or not args[field]:
            return False, f"Missing required field: {field}"

    # Validate analysis type
    analysis_type = args.get("analysis_type", "comprehensive")
    valid_types = list(ANALYSIS_TEMPLATES.keys())

    if analysis_type not in valid_types:
        return False, f"Invalid analysis type. Valid types: {valid_types}"

    return True, ""
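Callers can pre-flight a request before invoking the analyzer; a hedged sketch (path hypothetical):

ok, err = validate_analysis_request({"file_path": "data/sales.csv", "analysis_type": "distribution"})
if not ok:
    raise ValueError(err)  # e.g. "Invalid analysis type. Valid types: [...]"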
@@ -25,26 +25,75 @@ def get_tools_for_model() -> List[Dict[str, Any]]:
    """
    Returns the list of tools available to the model.

    IMPORTANT: CODE EXECUTION HAS BEEN SEPARATED INTO TWO SPECIALIZED TOOLS:

    1. execute_python_code: For general programming, calculations, algorithms, custom code
       - Use when: math problems, programming tasks, custom scripts, algorithm implementations
       - Can create visualizations from scratch
       - Has sandboxed environment with package installation
       - Automatically gets context about uploaded files when available

    2. analyze_data_file: For structured data analysis from CSV/Excel files
       - Use when: user explicitly requests data analysis, statistics, or insights from data files
       - Has pre-built analysis templates (summary, correlation, distribution, comprehensive)
       - Automatically handles data loading and creates appropriate visualizations
       - Specialized for data science workflows
       - Best for quick data exploration and standard analysis

    AI MODEL GUIDANCE FOR DATA FILE UPLOADS:
    - When user uploads a data file (.csv, .xlsx, .xls) to Discord:
      * File is automatically downloaded and saved
      * User intent is detected (data analysis vs. general programming)
      * If intent is "data analysis" → analyze_data_file is automatically called
      * If intent is "general programming" → file context is added to conversation

    - When to use analyze_data_file:
      * User uploads data file and asks for analysis, insights, statistics
      * User wants standard data exploration (correlations, distributions, summaries)
      * User requests "analyze this data" type queries

    - When to use execute_python_code:
      * User asks for calculations, math problems, algorithms
      * User wants custom code or programming solutions
      * User uploads data file but wants custom processing (not standard analysis)
      * User needs specific data transformations or custom visualizations
      * File paths will be automatically provided in the execution environment

    DISCORD FILE INTEGRATION:
    - Data files uploaded to Discord are automatically downloaded and saved
    - File paths are automatically provided to the appropriate tools
    - Context about uploaded files is added to Python execution environment
    - Visualizations are automatically uploaded to Discord and displayed

    Returns:
        List of tool objects
        List of tool objects for the OpenAI API
    """
    return [
        {
            "type": "function",
            "function": {
                "name": "analyze_data_file",
                "description": "Analyze a data file (CSV or Excel) and generate visualizations. Use this tool when a user uploads a data file and wants insights or visualizations. The visualizations will be automatically displayed in Discord. When describing the results, refer to visualizations by their chart_id and explain what they show. Always inform the user they can see the visualizations directly in the Discord chat.",
                "description": "**DATA ANALYSIS TOOL** - Use this tool for structured data analysis from CSV/Excel files. This tool specializes in analyzing data with pre-built templates (summary, correlation, distribution, comprehensive) and generates appropriate visualizations automatically. Use this tool when: (1) User explicitly requests data analysis or insights, (2) User asks for statistics, correlations, or data exploration, (3) User wants standard data science analysis. The tool handles file loading, data validation, and creates visualizations that are automatically displayed in Discord.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "file_path": {
                            "type": "string",
                            "description": "Path to the data file to analyze"
                            "description": "Path to the data file to analyze (CSV or Excel format required)"
                        },
                        "analysis_type": {
                            "type": "string",
                            "description": "Type of analysis to perform (e.g., 'summary', 'correlation', 'distribution')",
                            "enum": ["summary", "correlation", "distribution", "comprehensive"]
                            "description": "Type of pre-built analysis template to use",
                            "enum": ["summary", "correlation", "distribution", "comprehensive"],
                            "default": "comprehensive"
                        },
                        "custom_analysis": {
                            "type": "string",
                            "description": "Optional custom analysis request in natural language. If provided, this overrides the analysis_type and generates custom Python code for the specific analysis requested."
                        },
                        "install_packages": {
                            "type": "array",
                            "items": {"type": "string"},
                            "description": "Optional list of Python packages to install before analysis"
                        }
                    },
                    "required": ["file_path"]
@@ -279,41 +328,39 @@ def get_tools_for_model() -> List[Dict[str, Any]]:
                    "required": ["prompt"]
                }
            }
        },
        {
            "type": "function",
            "function": {
                "name": "code_interpreter",
                "description": "Execute Python code to solve problems, perform calculations, or create data visualizations. Use this for data analysis, generating charts, and processing data. When analyzing data, ALWAYS include code for visualizations (using matplotlib, seaborn, or plotly) if the user requests charts or graphs. When visualizations are created, tell the user they can view the charts directly in Discord, and reference visualizations by their chart_id in your descriptions.",
        }, {
            "type": "function", "function": {
                "name": "execute_python_code",
                "description": "**GENERAL PYTHON EXECUTION TOOL** - Use this tool for general programming tasks, mathematical calculations, algorithm implementations, and custom Python scripts. Use this tool when: (1) User asks for calculations or math problems, (2) User wants to run custom Python code, (3) User needs algorithm implementations, (4) User requests programming solutions, (5) Creating custom visualizations or data processing, (6) User uploads data files but wants custom code rather than standard analysis. File paths for uploaded data files are automatically made available in the execution environment.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "code": {
                            "type": "string",
                            "description": "The Python code to execute. For data analysis, include necessary imports (pandas, matplotlib, etc.) and visualization code."
                            "description": "The Python code to execute. Include all necessary imports and ensure code is complete and runnable."
                        },
                        "language": {
                        "input_data": {
                            "type": "string",
                            "description": "Programming language (only Python supported)",
                            "enum": ["python", "py"]
                            "description": "Optional input data to be made available to the code as a variable named 'input_data'"
                        },
                        "input": {
                            "type": "string",
                            "description": "Optional input data for the code"
                        "install_packages": {
                            "type": "array",
                            "items": {"type": "string"},
                            "description": "Optional list of Python packages to install before execution (e.g., ['numpy', 'matplotlib'])"
                        },
                        "file_path": {
                            "type": "string",
                            "description": "Optional path to a data file to analyze (supports CSV and Excel files)"
                        },
                        "analysis_request": {
                            "type": "string",
                            "description": "Natural language description of the analysis to perform. If this includes visualization requests, the generated code must include plotting code using matplotlib, seaborn, or plotly."
                        },
                        "include_visualization": {
                        "enable_visualization": {
                            "type": "boolean",
                            "description": "Whether to include visualizations using matplotlib/seaborn"
                        }
                            "description": "Set to true when creating charts, graphs, or any visual output with matplotlib/seaborn"
                        },
                        "timeout": {
                            "type": "integer",
                            "description": "Maximum execution time in seconds (default: 30, max: 120)",
                            "default": 30,
                            "minimum": 1,
                            "maximum": 120
                        }
                    },
                    "required": ["code"]
                }
            }
        },
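These definitions follow the OpenAI function-calling schema, so the list plugs straight into a chat-completions call; a minimal sketch (client setup and model choice are assumptions):

from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
response = client.chat.completions.create(
    model="gpt-4o",
    messages=[{"role": "user", "content": "Analyze the file I just uploaded"}],
    tools=get_tools_for_model(),
)
tool_calls = response.choices[0].message.tool_calls  # dispatch each call to the matching handler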
415	src/utils/python_executor.py	Normal file
@@ -0,0 +1,415 @@
import os
import sys
import io
import re
import logging
import asyncio
import subprocess
import tempfile
import time
import uuid
from logging.handlers import RotatingFileHandler
import traceback
import contextlib
from typing import Dict, Any, Optional, List

# Import utility functions
from .code_utils import DATA_FILES_DIR, format_output_path, clean_old_files

# Configure logging
log_file = 'logs/code_interpreter.log'
os.makedirs(os.path.dirname(log_file), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler = RotatingFileHandler(log_file, maxBytes=10*1024*1024, backupCount=5)
file_handler.setFormatter(formatter)
logger = logging.getLogger('code_interpreter')
logger.setLevel(logging.INFO)
logger.addHandler(file_handler)

# Regular expression to find image file paths in output
# (extension group is non-capturing so re.findall returns full path strings)
IMAGE_PATH_PATTERN = r'(\/media\/quocanh\/.*\.(?:png|jpg|jpeg|gif))'

# Unsafe patterns for code security
UNSAFE_IMPORTS = [
    r'import\s+os\b', r'from\s+os\s+import',
    r'import\s+subprocess\b', r'from\s+subprocess\s+import',
    r'import\s+shutil\b', r'from\s+shutil\s+import',
    r'__import__\([\'"]os[\'"]\)', r'__import__\([\'"]subprocess[\'"]\)',
    r'import\s+sys\b(?!\s+import\s+path)', r'from\s+sys\s+import'
]

UNSAFE_FUNCTIONS = [
    r'os\.', r'subprocess\.', r'shutil\.',
    r'eval\(', r'exec\(', r'sys\.',
    r'open\([\'"][^\'"]*/[^\']*[\'"]',  # File system access
    r'__import__\(', r'globals\(\)', r'locals\(\)'
]

def sanitize_python_code(code: str) -> tuple[bool, str]:
    """
    Check Python code for potentially unsafe operations.

    Args:
        code: The code to check

    Returns:
        Tuple of (is_safe, sanitized_code_or_error_message)
    """
    # Check for unsafe imports
    for pattern in UNSAFE_IMPORTS:
        if re.search(pattern, code):
            return False, f"Forbidden import detected: {pattern}"

    # Check for unsafe function calls
    for pattern in UNSAFE_FUNCTIONS:
        if re.search(pattern, code):
            return False, f"Forbidden function call detected: {pattern}"

    # Add safety imports and commonly used libraries
    safe_imports = """
import math
import random
import json
import time
from datetime import datetime, timedelta
import collections
import itertools
import functools
try:
    import numpy as np
except ImportError:
    pass
try:
    import pandas as pd
except ImportError:
    pass
try:
    import matplotlib.pyplot as plt
    import matplotlib
    matplotlib.use('Agg')
except ImportError:
    pass
try:
    import seaborn as sns
except ImportError:
    pass
"""

    return True, safe_imports + "\n" + code

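The sanitizer is a regex gate followed by a guarded import prelude, so rejection happens before anything runs; an illustrative sketch:

ok, result = sanitize_python_code("import os\nos.listdir('/')")
# ok is False; result names the match, e.g. "Forbidden import detected: import\s+os\b"

ok, result = sanitize_python_code("print(math.sqrt(2))")
# ok is True; result is the safe-import prelude followed by the original code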
async def install_packages(packages: List[str]) -> Dict[str, Any]:
    """
    Install Python packages in a sandboxed environment.

    Args:
        packages: List of package names to install

    Returns:
        Dict containing installation results
    """
    try:
        installed = []
        failed = []

        for package in packages:
            try:
                # Use pip to install package with timeout
                result = subprocess.run([
                    sys.executable, "-m", "pip", "install", package, "--user", "--quiet"
                ], capture_output=True, text=True, timeout=120)

                if result.returncode == 0:
                    installed.append(package)
                    logger.info(f"Successfully installed package: {package}")
                else:
                    failed.append({"package": package, "error": result.stderr})
                    logger.error(f"Failed to install package {package}: {result.stderr}")

            except subprocess.TimeoutExpired:
                failed.append({"package": package, "error": "Installation timeout"})
                logger.error(f"Installation timeout for package: {package}")
            except Exception as e:
                failed.append({"package": package, "error": str(e)})
                logger.error(f"Error installing package {package}: {str(e)}")

        return {
            "success": True,
            "installed": installed,
            "failed": failed,
            "message": f"Installed {len(installed)} packages, {len(failed)} failed"
        }

    except Exception as e:
        logger.error(f"Error in package installation: {str(e)}")
        return {
            "success": False,
            "error": str(e),
            "installed": [],
            "failed": packages
        }

async def execute_python_code(args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Execute Python code in a controlled sandbox environment.

    Args:
        args: Dictionary containing:
            - code: The Python code to execute
            - input: Optional input data for the code
            - install_packages: Optional list of packages to install
            - timeout: Optional timeout in seconds (default: 30)

    Returns:
        Dict containing execution results
    """
    try:
        code = args.get("code", "")
        input_data = args.get("input", "")
        packages_to_install = args.get("install_packages", [])
        timeout = args.get("timeout", 30)

        if not code:
            return {
                "success": False,
                "error": "No code provided",
                "output": ""
            }

        # Install packages if requested
        if packages_to_install:
            install_result = await install_packages(packages_to_install)
            if not install_result["success"]:
                logger.warning(f"Package installation issues: {install_result}")

        # Sanitize the code
        is_safe, sanitized_code = sanitize_python_code(code)
        if not is_safe:
            logger.warning(f"Code sanitization failed: {sanitized_code}")
            return {
                "success": False,
                "error": sanitized_code,
                "output": ""
            }

        # Clean up old files before execution
        clean_old_files()

        # Execute code in controlled environment
        result = await execute_code_safely(sanitized_code, input_data, timeout)

        return result

    except Exception as e:
        error_msg = f"Error in Python code execution: {str(e)}"
        logger.error(f"{error_msg}\n{traceback.format_exc()}")
        return {
            "success": False,
            "error": error_msg,
            "output": "",
            "traceback": traceback.format_exc()
        }

async def execute_code_safely(code: str, input_data: str, timeout: int) -> Dict[str, Any]:
    """
    Execute code in a safe environment with proper isolation.

    Args:
        code: Sanitized Python code to execute
        input_data: Input data for the code
        timeout: Execution timeout in seconds

    Returns:
        Dict containing execution results
    """
    try:
        # Capture stdout and stderr
        old_stdout = sys.stdout
        old_stderr = sys.stderr
        stdout_capture = io.StringIO()
        stderr_capture = io.StringIO()

        # Import commonly used libraries for the execution environment
        try:
            import matplotlib
            matplotlib.use('Agg')  # Use non-interactive backend
            import matplotlib.pyplot as plt
        except ImportError:
            plt = None

        try:
            import numpy as np
        except ImportError:
            np = None

        try:
            import pandas as pd
        except ImportError:
            pd = None

        # Create execution namespace
        exec_globals = {
            "__builtins__": {
                # Safe builtins
                "print": print,
                "len": len,
                "range": range,
                "enumerate": enumerate,
                "zip": zip,
                "map": map,
                "filter": filter,
                "sum": sum,
                "min": min,
                "max": max,
                "abs": abs,
                "round": round,
                "sorted": sorted,
                "reversed": reversed,
                "list": list,
                "dict": dict,
                "set": set,
                "tuple": tuple,
                "str": str,
                "int": int,
                "float": float,
                "bool": bool,
                "type": type,
                "isinstance": isinstance,
                "hasattr": hasattr,
                "getattr": getattr,
                "setattr": setattr,
                "dir": dir,
                "help": help,
                "__import__": __import__,  # Allow controlled imports
                "ValueError": ValueError,
                "TypeError": TypeError,
                "IndexError": IndexError,
                "KeyError": KeyError,
                "AttributeError": AttributeError,
                "ImportError": ImportError,
                "Exception": Exception,
            },
            # Add available libraries
            "math": __import__("math"),
            "random": __import__("random"),
            "json": __import__("json"),
            "time": __import__("time"),
            "datetime": __import__("datetime"),
            "collections": __import__("collections"),
            "itertools": __import__("itertools"),
            "functools": __import__("functools"),
        }

        # Add optional libraries if available
        if np is not None:
            exec_globals["np"] = np
            exec_globals["numpy"] = np
        if pd is not None:
            exec_globals["pd"] = pd
            exec_globals["pandas"] = pd
        if plt is not None:
            exec_globals["plt"] = plt
            exec_globals["matplotlib"] = matplotlib

        # Override input function if input_data is provided
        if input_data:
            input_lines = input_data.strip().split('\n')
            input_iter = iter(input_lines)
            exec_globals["input"] = lambda prompt="": next(input_iter, "")

        # Set up output capture
        sys.stdout = stdout_capture
        sys.stderr = stderr_capture

        # Generate output file path for any plots
        timestamp = int(time.time())
        output_filename = f"python_output_{timestamp}.png"
        output_path = format_output_path(output_filename)

        # Execute the code with timeout
        try:
            # Use asyncio.wait_for for timeout
            await asyncio.wait_for(
                asyncio.to_thread(exec, code, exec_globals),
                timeout=timeout
            )

            # Check for any matplotlib figures and save them
            visualizations = []
            if plt is not None and plt.get_fignums():
                for i, fig_num in enumerate(plt.get_fignums()):
                    try:
                        fig = plt.figure(fig_num)
                        if len(fig.get_axes()) > 0:
                            # Save to output path
                            fig_path = output_path.replace('.png', f'_{i}.png')
                            fig.savefig(fig_path, bbox_inches='tight', dpi=150)
                            visualizations.append(fig_path)
                        plt.close(fig)
                    except Exception as e:
                        logger.error(f"Error saving figure {i}: {str(e)}")

                # Clear all figures
                plt.close('all')

        except asyncio.TimeoutError:
            # Restore stdout and stderr before returning
            sys.stdout = old_stdout
            sys.stderr = old_stderr
            return {
                "success": False,
                "error": f"Code execution timed out after {timeout} seconds",
                "output": stdout_capture.getvalue(),
                "stderr": stderr_capture.getvalue()
            }

        # Restore stdout and stderr
        sys.stdout = old_stdout
        sys.stderr = old_stderr

        # Get the outputs
        stdout_output = stdout_capture.getvalue()
        stderr_output = stderr_capture.getvalue()

        # Check for any image paths in the output
        image_paths = re.findall(IMAGE_PATH_PATTERN, stdout_output)
        for img_path in image_paths:
            if os.path.exists(img_path):
                visualizations.append(img_path)

        # Remove image paths from output text
        clean_output = stdout_output
        for img_path in image_paths:
            clean_output = clean_output.replace(img_path, "[Image saved]")

        logger.info(f"Python code executed successfully, output length: {len(clean_output)}")
        if visualizations:
            logger.info(f"Generated {len(visualizations)} visualizations")

        return {
            "success": True,
            "output": clean_output,
            "stderr": stderr_output,
            "visualizations": visualizations,
            "has_visualization": len(visualizations) > 0,
            "execution_time": f"Completed in under {timeout}s"
        }

    except Exception as e:
        # Restore stdout and stderr
        sys.stdout = old_stdout
        sys.stderr = old_stderr

        error_msg = f"Error executing Python code: {str(e)}"
        logger.error(f"{error_msg}\n{traceback.format_exc()}")

        return {
            "success": False,
            "error": error_msg,
            "output": stdout_capture.getvalue() if 'stdout_capture' in locals() else "",
            "stderr": stderr_capture.getvalue() if 'stderr_capture' in locals() else "",
            "traceback": traceback.format_exc()
        }

# Backward compatibility - keep the old function name
async def execute_code(args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Backward compatibility wrapper for execute_python_code.
    """
    return await execute_python_code(args)
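Note that the implementation reads the key "input" while the tool schema above advertises "input_data"; stdin is emulated by replacing the input() builtin line by line. A usage sketch against the implementation as written:

import asyncio

async def demo_input():
    res = await execute_python_code({
        "code": "name = input()\nprint('Hello, ' + name)",
        "input": "Discord",
        "timeout": 5,
    })
    print(res["output"])  # Hello, Discord

asyncio.run(demo_input())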
@@ -1 +0,0 @@
# Initialize the tests package
Binary file not shown.
@@ -1,379 +0,0 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import io
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
from dotenv import load_dotenv
|
||||
import re
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
# Import modules for testing
|
||||
from src.database.db_handler import DatabaseHandler
|
||||
from src.utils.openai_utils import count_tokens, trim_content_to_token_limit, prepare_messages_for_api
|
||||
from src.utils.code_utils import sanitize_code, extract_code_blocks
|
||||
from src.utils.web_utils import scrape_web_content
|
||||
from src.utils.pdf_utils import send_response
|
||||
|
||||
|
||||
class TestDatabaseHandler(unittest.IsolatedAsyncioTestCase):
|
||||
"""Test database handler functionality"""
|
||||
|
||||
def setUp(self):
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Try to get MongoDB URI from environment
|
||||
self.mongodb_uri = os.getenv("MONGODB_URI")
|
||||
self.using_real_db = bool(self.mongodb_uri)
|
||||
|
||||
if not self.using_real_db:
|
||||
# Use mock if no real URI available
|
||||
self.mock_client_patcher = patch('motor.motor_asyncio.AsyncIOMotorClient')
|
||||
self.mock_client = self.mock_client_patcher.start()
|
||||
|
||||
# Setup mock database and collections
|
||||
self.mock_db = self.mock_client.return_value.__getitem__.return_value
|
||||
self.mock_histories = MagicMock()
|
||||
self.mock_models = MagicMock() # Store mock_models as instance variable
|
||||
self.mock_db.__getitem__.side_effect = lambda x: {
|
||||
'user_histories': self.mock_histories,
|
||||
'user_models': self.mock_models, # Use the instance variable
|
||||
'whitelist': MagicMock(),
|
||||
'blacklist': MagicMock()
|
||||
}[x]
|
||||
|
||||
# Initialize handler with mock connection string
|
||||
self.db_handler = DatabaseHandler("mongodb://localhost:27017")
|
||||
else:
|
||||
# Use real database connection
|
||||
print(f"Testing with real MongoDB at: {self.mongodb_uri}")
|
||||
self.db_handler = DatabaseHandler(self.mongodb_uri)
|
||||
|
||||
# Extract database name from URI for later use
|
||||
self.db_name = self._extract_db_name_from_uri(self.mongodb_uri)
|
||||
|
||||
async def asyncSetUp(self):
|
||||
# No additional async setup needed, but required by IsolatedAsyncioTestCase
|
||||
pass
|
||||
|
||||
async def asyncTearDown(self):
|
||||
if not self.using_real_db:
|
||||
self.mock_client_patcher.stop()
|
||||
else:
|
||||
# Clean up test data if using real database
|
||||
await self.cleanup_test_data()
|
||||
|
||||
def _extract_db_name_from_uri(self, uri):
|
||||
"""Extract database name from MongoDB URI more reliably"""
|
||||
# Default database name if extraction fails
|
||||
default_db_name = 'chatgpt_discord_bot'
|
||||
|
||||
try:
|
||||
# Handle standard MongoDB URI format
|
||||
# mongodb://[username:password@]host1[:port1][,...hostN[:portN]][/database][?options]
|
||||
match = re.search(r'\/([^/?]+)(\?|$)', uri)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
# If no database in URI, return default
|
||||
return default_db_name
|
||||
except:
|
||||
# If any error occurs, return default name
|
||||
return default_db_name
|
||||
|
||||
async def cleanup_test_data(self):
|
||||
"""Remove test data from real database"""
|
||||
if self.using_real_db:
|
||||
try:
|
||||
# Use the database name we extracted in setUp
|
||||
db = self.db_handler.client.get_database(self.db_name)
|
||||
await db.user_histories.delete_one({'user_id': 12345})
|
||||
await db.user_models.delete_one({'user_id': 12345})
|
||||
except Exception as e:
|
||||
print(f"Error during test cleanup: {e}")
|
||||
|
||||
async def test_get_history_empty(self):
|
||||
if self.using_real_db:
|
||||
# Clean up any existing history first
|
||||
await self.cleanup_test_data()
|
||||
# Test with real database
|
||||
result = await self.db_handler.get_history(12345)
|
||||
self.assertEqual(result, [])
|
||||
else:
|
||||
# Mock find_one to return None (no history)
|
||||
self.mock_histories.find_one = AsyncMock(return_value=None)
|
||||
|
||||
# Test getting non-existent history
|
||||
result = await self.db_handler.get_history(12345)
|
||||
self.assertEqual(result, [])
|
||||
self.mock_histories.find_one.assert_called_once_with({'user_id': 12345})
|
||||
|
||||
    async def test_get_history_existing(self):
        # Sample history data
        sample_history = [
            {'role': 'user', 'content': 'Hello'},
            {'role': 'assistant', 'content': 'Hi there!'}
        ]

        if self.using_real_db:
            # Save the test history first
            await self.db_handler.save_history(12345, sample_history)

            # Test getting the existing history
            result = await self.db_handler.get_history(12345)
            self.assertEqual(result, sample_history)
        else:
            # Mock find_one to return an existing history
            self.mock_histories.find_one = AsyncMock(return_value={'user_id': 12345, 'history': sample_history})

            # Test getting the existing history
            result = await self.db_handler.get_history(12345)
            self.assertEqual(result, sample_history)

    async def test_save_history(self):
        # Sample history to save
        sample_history = [
            {'role': 'user', 'content': 'Test message'},
            {'role': 'assistant', 'content': 'Test response'}
        ]

        if self.using_real_db:
            # Test saving history to the real database
            await self.db_handler.save_history(12345, sample_history)

            # Verify it was saved
            result = await self.db_handler.get_history(12345)
            self.assertEqual(result, sample_history)
        else:
            # Mock the update_one method
            self.mock_histories.update_one = AsyncMock()

            # Test saving history
            await self.db_handler.save_history(12345, sample_history)

            # Verify update_one was called with the correct parameters
            self.mock_histories.update_one.assert_called_once_with(
                {'user_id': 12345},
                {'$set': {'history': sample_history}},
                upsert=True
            )

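    # The assertion above pins down how DatabaseHandler.save_history is
    # expected to hit Mongo -- roughly the following (a sketch inferred from
    # the mock call, not the actual implementation; the collection access
    # style is an assumption):
    #
    #     async def save_history(self, user_id, history):
    #         await self.db['user_histories'].update_one(
    #             {'user_id': user_id},
    #             {'$set': {'history': history}},
    #             upsert=True,
    #         )
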
    async def test_user_model_operations(self):
        if self.using_real_db:
            # Save a model and then retrieve it
            await self.db_handler.save_user_model(12345, 'openai/gpt-4o')
            model = await self.db_handler.get_user_model(12345)
            self.assertEqual(model, 'openai/gpt-4o')

            # Test updating the model
            await self.db_handler.save_user_model(12345, 'openai/gpt-4o-mini')
            updated_model = await self.db_handler.get_user_model(12345)
            self.assertEqual(updated_model, 'openai/gpt-4o-mini')
        else:
            # Set up mocks on the user_models collection
            # (use self.mock_models instead of creating a new mock)
            self.mock_models.find_one = AsyncMock(return_value={'user_id': 12345, 'model': 'openai/gpt-4o'})
            self.mock_models.update_one = AsyncMock()

            # Test getting the user model
            model = await self.db_handler.get_user_model(12345)
            self.assertEqual(model, 'openai/gpt-4o')

            # Test saving the user model
            await self.db_handler.save_user_model(12345, 'openai/gpt-4o-mini')
            self.mock_models.update_one.assert_called_once_with(
                {'user_id': 12345},
                {'$set': {'model': 'openai/gpt-4o-mini'}},
                upsert=True
            )


class TestOpenAIUtils(unittest.TestCase):
    """Test OpenAI utility functions"""

    def test_count_tokens(self):
        # Test token counting
        self.assertGreater(count_tokens("Hello, world!"), 0)
        self.assertGreater(count_tokens("This is a longer text that should have more tokens."),
                           count_tokens("Short text"))

    def test_trim_content_to_token_limit(self):
        # Create a long text
        long_text = "This is a test. " * 1000

        # Test trimming
        trimmed = trim_content_to_token_limit(long_text, 100)
        self.assertLess(count_tokens(trimmed), count_tokens(long_text))
        self.assertLessEqual(count_tokens(trimmed), 100)

        # Test that no trimming happens when it is not needed
        short_text = "This is a short text."
        untrimmed = trim_content_to_token_limit(short_text, 100)
        self.assertEqual(untrimmed, short_text)

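    # The two tests above only pin down the helpers' contract: count_tokens()
    # grows with text length, and trim_content_to_token_limit() returns text
    # at or under the limit while leaving short input untouched. Assuming a
    # tiktoken-style encode/decode pair, a minimal sketch of that contract:
    #
    #     def trim_sketch(text, limit, encode, decode):
    #         tokens = encode(text)
    #         return text if len(tokens) <= limit else decode(tokens[:limit])
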
    def test_prepare_messages_for_api(self):
        # Test empty messages
        empty_result = prepare_messages_for_api([])
        self.assertEqual(len(empty_result), 1)  # Should have a system message
        self.assertEqual(empty_result[0]["role"], "system")  # Verify it's a system message

        # Test regular messages
        messages = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
            {"role": "user", "content": "How are you?"}
        ]
        result = prepare_messages_for_api(messages)
        self.assertEqual(len(result), 4)  # Should have the system message + 3 original messages

        # Test with null content
        messages_with_null = [
            {"role": "user", "content": None},
            {"role": "assistant", "content": "Response"}
        ]
        result_fixed = prepare_messages_for_api(messages_with_null)
        self.assertEqual(len(result_fixed), 2)  # Should have the system message + 1 valid message
        # Verify the content is correct (system message + only the assistant message)
        self.assertEqual(result_fixed[0]["role"], "system")
        self.assertEqual(result_fixed[1]["role"], "assistant")
        self.assertEqual(result_fixed[1]["content"], "Response")


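# From the assertions above, prepare_messages_for_api() prepends a system
# message and drops entries whose content is None. A minimal sketch of that
# contract (the actual system prompt text is unknown, so a placeholder is
# used here):
def _prepare_messages_sketch(messages):
    prepared = [{"role": "system", "content": "<system prompt placeholder>"}]
    # Keep only messages that actually carry content
    prepared.extend(m for m in messages if m.get("content") is not None)
    return prepared

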
class TestCodeUtils(unittest.TestCase):
    """Test code utility functions"""

    def test_sanitize_python_code_safe(self):
        # Safe Python code
        code = """
def factorial(n):
    if n <= 1:
        return 1
    return n * factorial(n-1)

print(factorial(5))
"""
        is_safe, sanitized = sanitize_code(code, "python")
        self.assertTrue(is_safe)
        self.assertIn("def factorial", sanitized)

    def test_sanitize_python_code_unsafe(self):
        # Unsafe Python code with os.system
        unsafe_code = """
import os
os.system('rm -rf /')
"""
        is_safe, message = sanitize_code(unsafe_code, "python")
        self.assertFalse(is_safe)
        self.assertIn("Forbidden", message)

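    # These tests establish the shape of sanitize_code(): it returns a tuple
    # of (is_safe, sanitized_code_or_error_message), and the failure message
    # contains the word "Forbidden". A plausible blocklist-based sketch
    # (an assumption -- the real util may use AST inspection instead):
    #
    #     FORBIDDEN_PY = ('os.system', 'subprocess', 'eval(', 'exec(')
    #
    #     def sanitize_sketch(code, language):
    #         for pattern in FORBIDDEN_PY:
    #             if pattern in code:
    #                 return False, f"Forbidden pattern: {pattern}"
    #         return True, code
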
    def test_sanitize_cpp_code_safe(self):
        # Safe C++ code
        code = """
#include <iostream>
using namespace std;

int main() {
    cout << "Hello, world!" << endl;
    return 0;
}
"""
        is_safe, sanitized = sanitize_code(code, "cpp")
        self.assertTrue(is_safe)
        self.assertIn("Hello, world!", sanitized)

    def test_sanitize_cpp_code_unsafe(self):
        # Unsafe C++ code calling system()
        unsafe_code = """
#include <stdlib.h>
int main() {
    system("rm -rf /");
    return 0;
}
"""
        is_safe, message = sanitize_code(unsafe_code, "cpp")
        self.assertFalse(is_safe)
        self.assertIn("Forbidden", message)

    def test_extract_code_blocks(self):
        # Test a message with code blocks
        message = """
Here's a Python function to calculate factorial:
```python
def factorial(n):
    if n <= 1:
        return 1
    return n * factorial(n-1)
```
And here's a C++ version:
```cpp
int factorial(int n) {
    if (n <= 1) return 1;
    return n * factorial(n-1);
}
```
"""
        blocks = extract_code_blocks(message)
        self.assertEqual(len(blocks), 2)
        self.assertEqual(blocks[0][0], "python")
        self.assertEqual(blocks[1][0], "cpp")

        # Test without a language specifier
        message_no_lang = """
Here's some code:
```
print("Hello world")
```
"""
        blocks_no_lang = extract_code_blocks(message_no_lang)
        self.assertEqual(len(blocks_no_lang), 1)


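# extract_code_blocks() evidently returns a list of (language, code) tuples
# and still captures fenced blocks that carry no language tag. A plausible
# regex-based sketch of that behavior (an assumption -- the real util may
# differ in detail; `re` is already imported at the top of this file):
def _extract_code_blocks_sketch(message):
    # Optional language word after the opening fence, then a lazy body match
    return re.findall(r"```(\w*)\n(.*?)```", message, re.DOTALL)

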
#class TestWebUtils(unittest.TestCase):
#    """Test web utilities"""
#
#    @patch('requests.get')
#    def test_scrape_web_content(self, mock_get):
#        # Mock the response
#        mock_response = MagicMock()
#        mock_response.text = '<html><body><h1>Test Heading</h1><p>Test paragraph</p></body></html>'
#        mock_response.status_code = 200
#        mock_get.return_value = mock_response
#
#        # Test scraping
#        content = scrape_web_content("example.com")
#        self.assertIn("Test Heading", content)
#        self.assertIn("Test paragraph", content)

class TestPDFUtils(unittest.IsolatedAsyncioTestCase):
    """Test PDF utilities"""

    async def test_send_response(self):
        # Create a mock channel
        mock_channel = AsyncMock()
        mock_channel.send = AsyncMock()

        # Test sending a short response
        short_response = "This is a short response"
        await send_response(mock_channel, short_response)
        mock_channel.send.assert_called_once_with(short_response)

        # Reset the mock
        mock_channel.send.reset_mock()

        # Mock file operations for the long-response path
        with patch('builtins.open', new_callable=unittest.mock.mock_open):
            with patch('discord.File', return_value="mocked_file"):
                # Test sending a long response
                long_response = "X" * 2500  # Over Discord's 2000-character limit
                await send_response(mock_channel, long_response)
                mock_channel.send.assert_called_once()
                # Verify it was called with a file argument
                args, kwargs = mock_channel.send.call_args
                self.assertIn('file', kwargs)


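# The test above pins down send_response()'s behavior at Discord's
# 2000-character message limit: short text goes out as a plain message,
# longer text is written to a file and sent as an attachment. A minimal
# sketch of that contract (the file name and wording are assumptions):
#
#     async def send_response_sketch(channel, response):
#         if len(response) <= 2000:
#             await channel.send(response)
#         else:
#             with open("response.txt", "w") as f:
#                 f.write(response)
#             await channel.send("Response attached:",
#                                file=discord.File("response.txt"))
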
if __name__ == "__main__":
    unittest.main()