From ac8fd924c184d528661460b8c7fafd2019d47b05 Mon Sep 17 00:00:00 2001 From: cauvang32 Date: Fri, 20 Jun 2025 21:23:03 +0700 Subject: [PATCH] feat: Add comprehensive unit tests for DatabaseHandler, OpenAI utilities, and code utilities --- test_executor.py | 37 ----- tests/test_bot.py | 379 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 379 insertions(+), 37 deletions(-) delete mode 100644 test_executor.py create mode 100644 tests/test_bot.py diff --git a/test_executor.py b/test_executor.py deleted file mode 100644 index d2842ee..0000000 --- a/test_executor.py +++ /dev/null @@ -1,37 +0,0 @@ -#!/usr/bin/env python3 -import asyncio -import sys -import os - -# Add the src directory to the path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src')) - -from utils.python_executor import execute_python_code - -async def test_calculation(): - # Test with proper print statement (what the AI should generate now) - args = { - 'code': 'print((3+2+1+1231231+2139018230912)/3+120/99+2012)' - } - print('Testing with proper print statement:', repr(args['code'])) - - result = await execute_python_code(args) - print('Success:', result.get('success', False)) - print('Output:', repr(result.get('output', ''))) - print('Expected result: 713006489396.2122') - - # Test another calculation - args2 = { - 'code': ''' -result = (3+2+1+1231231+2139018230912)/3+120/99+2012 -print(f"The calculation result is: {result}") -''' - } - print('\nTesting with formatted output:', repr(args2['code'])) - - result2 = await execute_python_code(args2) - print('Success:', result2.get('success', False)) - print('Output:', repr(result2.get('output', ''))) - -if __name__ == "__main__": - asyncio.run(test_calculation()) diff --git a/tests/test_bot.py b/tests/test_bot.py new file mode 100644 index 0000000..4f1b9b0 --- /dev/null +++ b/tests/test_bot.py @@ -0,0 +1,379 @@ +import asyncio +import unittest +import os +import sys +import json +import io +from unittest.mock import MagicMock, patch, AsyncMock +from dotenv import load_dotenv +import re + +# Add parent directory to path for imports +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# Import modules for testing +from src.database.db_handler import DatabaseHandler +from src.utils.openai_utils import count_tokens, trim_content_to_token_limit, prepare_messages_for_api +from src.utils.code_utils import sanitize_code, extract_code_blocks +from src.utils.web_utils import scrape_web_content +from src.utils.pdf_utils import send_response + + +class TestDatabaseHandler(unittest.IsolatedAsyncioTestCase): + """Test database handler functionality""" + + def setUp(self): + # Load environment variables + load_dotenv() + + # Try to get MongoDB URI from environment + self.mongodb_uri = os.getenv("MONGODB_URI") + self.using_real_db = bool(self.mongodb_uri) + + if not self.using_real_db: + # Use mock if no real URI available + self.mock_client_patcher = patch('motor.motor_asyncio.AsyncIOMotorClient') + self.mock_client = self.mock_client_patcher.start() + + # Setup mock database and collections + self.mock_db = self.mock_client.return_value.__getitem__.return_value + self.mock_histories = MagicMock() + self.mock_models = MagicMock() # Store mock_models as instance variable + self.mock_db.__getitem__.side_effect = lambda x: { + 'user_histories': self.mock_histories, + 'user_models': self.mock_models, # Use the instance variable + 'whitelist': MagicMock(), + 'blacklist': MagicMock() + }[x] + + # Initialize handler with mock connection string + self.db_handler = DatabaseHandler("mongodb://localhost:27017") + else: + # Use real database connection + print(f"Testing with real MongoDB at: {self.mongodb_uri}") + self.db_handler = DatabaseHandler(self.mongodb_uri) + + # Extract database name from URI for later use + self.db_name = self._extract_db_name_from_uri(self.mongodb_uri) + + async def asyncSetUp(self): + # No additional async setup needed, but required by IsolatedAsyncioTestCase + pass + + async def asyncTearDown(self): + if not self.using_real_db: + self.mock_client_patcher.stop() + else: + # Clean up test data if using real database + await self.cleanup_test_data() + + def _extract_db_name_from_uri(self, uri): + """Extract database name from MongoDB URI more reliably""" + # Default database name if extraction fails + default_db_name = 'chatgpt_discord_bot' + + try: + # Handle standard MongoDB URI format + # mongodb://[username:password@]host1[:port1][,...hostN[:portN]][/database][?options] + match = re.search(r'\/([^/?]+)(\?|$)', uri) + if match: + return match.group(1) + + # If no database in URI, return default + return default_db_name + except: + # If any error occurs, return default name + return default_db_name + + async def cleanup_test_data(self): + """Remove test data from real database""" + if self.using_real_db: + try: + # Use the database name we extracted in setUp + db = self.db_handler.client.get_database(self.db_name) + await db.user_histories.delete_one({'user_id': 12345}) + await db.user_models.delete_one({'user_id': 12345}) + except Exception as e: + print(f"Error during test cleanup: {e}") + + async def test_get_history_empty(self): + if self.using_real_db: + # Clean up any existing history first + await self.cleanup_test_data() + # Test with real database + result = await self.db_handler.get_history(12345) + self.assertEqual(result, []) + else: + # Mock find_one to return None (no history) + self.mock_histories.find_one = AsyncMock(return_value=None) + + # Test getting non-existent history + result = await self.db_handler.get_history(12345) + self.assertEqual(result, []) + self.mock_histories.find_one.assert_called_once_with({'user_id': 12345}) + + async def test_get_history_existing(self): + # Sample history data + sample_history = [ + {'role': 'user', 'content': 'Hello'}, + {'role': 'assistant', 'content': 'Hi there!'} + ] + + if self.using_real_db: + # Save test history first + await self.db_handler.save_history(12345, sample_history) + + # Test getting existing history + result = await self.db_handler.get_history(12345) + self.assertEqual(result, sample_history) + else: + # Mock find_one to return existing history + self.mock_histories.find_one = AsyncMock(return_value={'user_id': 12345, 'history': sample_history}) + + # Test getting existing history + result = await self.db_handler.get_history(12345) + self.assertEqual(result, sample_history) + + async def test_save_history(self): + # Sample history to save + sample_history = [ + {'role': 'user', 'content': 'Test message'}, + {'role': 'assistant', 'content': 'Test response'} + ] + + if self.using_real_db: + # Test saving history to real database + await self.db_handler.save_history(12345, sample_history) + + # Verify it was saved + result = await self.db_handler.get_history(12345) + self.assertEqual(result, sample_history) + else: + # Mock update_one method + self.mock_histories.update_one = AsyncMock() + + # Test saving history + await self.db_handler.save_history(12345, sample_history) + + # Verify update_one was called with correct parameters + self.mock_histories.update_one.assert_called_once_with( + {'user_id': 12345}, + {'$set': {'history': sample_history}}, + upsert=True + ) + + async def test_user_model_operations(self): + if self.using_real_db: + # Save a model and then retrieve it + await self.db_handler.save_user_model(12345, 'openai/gpt-4o') + model = await self.db_handler.get_user_model(12345) + self.assertEqual(model, 'openai/gpt-4o') + + # Test updating model + await self.db_handler.save_user_model(12345, 'openai/gpt-4o-mini') + updated_model = await self.db_handler.get_user_model(12345) + self.assertEqual(updated_model, 'openai/gpt-4o-mini') + else: + # Setup mock for user_models collection + # Use self.mock_models instead of creating a new mock + self.mock_models.find_one = AsyncMock(return_value={'user_id': 12345, 'model': 'openai/gpt-4o'}) + self.mock_models.update_one = AsyncMock() + + # Test getting user model + model = await self.db_handler.get_user_model(12345) + self.assertEqual(model, 'openai/gpt-4o') + + # Test saving user model + await self.db_handler.save_user_model(12345, 'openai/gpt-4o-mini') + self.mock_models.update_one.assert_called_once_with( + {'user_id': 12345}, + {'$set': {'model': 'openai/gpt-4o-mini'}}, + upsert=True + ) + + +class TestOpenAIUtils(unittest.TestCase): + """Test OpenAI utility functions""" + + def test_count_tokens(self): + # Test token counting + self.assertGreater(count_tokens("Hello, world!"), 0) + self.assertGreater(count_tokens("This is a longer text that should have more tokens."), + count_tokens("Short text")) + + def test_trim_content_to_token_limit(self): + # Create a long text + long_text = "This is a test. " * 1000 + + # Test trimming + trimmed = trim_content_to_token_limit(long_text, 100) + self.assertLess(count_tokens(trimmed), count_tokens(long_text)) + self.assertLessEqual(count_tokens(trimmed), 100) + + # Test no trimming needed + short_text = "This is a short text." + untrimmed = trim_content_to_token_limit(short_text, 100) + self.assertEqual(untrimmed, short_text) + + def test_prepare_messages_for_api(self): + # Test empty messages + empty_result = prepare_messages_for_api([]) + self.assertEqual(len(empty_result), 1) # Should have system message + self.assertEqual(empty_result[0]["role"], "system") # Verify it's a system message + + # Test regular messages + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"} + ] + result = prepare_messages_for_api(messages) + self.assertEqual(len(result), 4) # Should have system message + 3 original messages + + # Test with null content + messages_with_null = [ + {"role": "user", "content": None}, + {"role": "assistant", "content": "Response"} + ] + result_fixed = prepare_messages_for_api(messages_with_null) + self.assertEqual(len(result_fixed), 2) # Should have system message + 1 valid message + # Verify the content is correct (system message + only the assistant message) + self.assertEqual(result_fixed[0]["role"], "system") + self.assertEqual(result_fixed[1]["role"], "assistant") + self.assertEqual(result_fixed[1]["content"], "Response") + +class TestCodeUtils(unittest.TestCase): + """Test code utility functions""" + + def test_sanitize_python_code_safe(self): + # Safe Python code + code = """ +def factorial(n): + if n <= 1: + return 1 + return n * factorial(n-1) + +print(factorial(5)) +""" + is_safe, sanitized = sanitize_code(code, "python") + self.assertTrue(is_safe) + self.assertIn("def factorial", sanitized) + + def test_sanitize_python_code_unsafe(self): + # Unsafe Python code with os.system + unsafe_code = """ +import os +os.system('rm -rf /') +""" + is_safe, message = sanitize_code(unsafe_code, "python") + self.assertFalse(is_safe) + self.assertIn("Forbidden", message) + + def test_sanitize_cpp_code_safe(self): + # Safe C++ code + code = """ +#include +using namespace std; + +int main() { + cout << "Hello, world!" << endl; + return 0; +} +""" + is_safe, sanitized = sanitize_code(code, "cpp") + self.assertTrue(is_safe) + self.assertIn("Hello, world!", sanitized) + + def test_sanitize_cpp_code_unsafe(self): + # Unsafe C++ code with system + unsafe_code = """ +#include +int main() { + system("rm -rf /"); + return 0; +} +""" + is_safe, message = sanitize_code(unsafe_code, "cpp") + self.assertFalse(is_safe) + self.assertIn("Forbidden", message) + + def test_extract_code_blocks(self): + # Test message with code block + message = """ +Here's a Python function to calculate factorial: +```python +def factorial(n): + if n <= 1: + return 1 + return n * factorial(n-1) +``` +And here's a C++ version: +```cpp +int factorial(int n) { + if (n <= 1) return 1; + return n * factorial(n-1); +} +``` +""" + blocks = extract_code_blocks(message) + self.assertEqual(len(blocks), 2) + self.assertEqual(blocks[0][0], "python") + self.assertEqual(blocks[1][0], "cpp") + + # Test without language specifier + message_no_lang = """ +Here's some code: +``` +print("Hello world") +``` +""" + blocks_no_lang = extract_code_blocks(message_no_lang) + self.assertEqual(len(blocks_no_lang), 1) + + +#class TestWebUtils(unittest.TestCase): +# """Test web utilities""" +# +# @patch('requests.get') +# def test_scrape_web_content(self, mock_get): +# # Mock the response +# mock_response = MagicMock() +# mock_response.text = '

Test Heading

Test paragraph

' +# mock_response.status_code = 200 +# mock_get.return_value = mock_response +# +# # Test scraping +# content = scrape_web_content("example.com") +# self.assertIn("Test Heading", content) +# self.assertIn("Test paragraph", content) + +class TestPDFUtils(unittest.IsolatedAsyncioTestCase): + """Test PDF utilities""" + + async def test_send_response(self): + # Create mock channel + mock_channel = AsyncMock() + mock_channel.send = AsyncMock() + + # Test sending short response + short_response = "This is a short response" + await send_response(mock_channel, short_response) + mock_channel.send.assert_called_once_with(short_response) + + # Reset mock + mock_channel.send.reset_mock() + + # Mock for long response (testing would need file operations) + with patch('builtins.open', new_callable=unittest.mock.mock_open): + with patch('discord.File', return_value="mocked_file"): + # Test sending long response + long_response = "X" * 2500 # Over 2000 character limit + await send_response(mock_channel, long_response) + mock_channel.send.assert_called_once() + # Verify it's called with the file argument + args, kwargs = mock_channel.send.call_args + self.assertIn('file', kwargs) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file