Merge pull request #7 from Coder-Vippro/cauvang32/add-slash-command
Add slash command for remaining chat turns and reset
This commit was merged in pull request #7.
This commit is contained in:
51
bot.py
51
bot.py
@@ -176,6 +176,33 @@ def save_user_model(user_id, model):
|
||||
upsert=True
|
||||
)
|
||||
|
||||
# New function to get the remaining chat turns for a user and model
|
||||
def get_remaining_turns(user_id, model):
|
||||
user_turns = db.chat_turns.find_one({'user_id': user_id, 'model': model})
|
||||
if user_turns and 'remaining_turns' in user_turns:
|
||||
return user_turns['remaining_turns']
|
||||
else:
|
||||
# Define rate limits for each model
|
||||
rate_limits = {
|
||||
"o1": 8,
|
||||
"o1-preview": 8,
|
||||
"o1-mini": 12,
|
||||
"gpt-4o": 50,
|
||||
"gpt-4o-mini": 150
|
||||
}
|
||||
return rate_limits.get(model, 100) # Default to 100 turns if not found
|
||||
|
||||
# New function to update the remaining chat turns for a user and model
|
||||
def update_remaining_turns(user_id, model, remaining_turns):
|
||||
db.chat_turns.update_one(
|
||||
{'user_id': user_id, 'model': model},
|
||||
{'$set': {'remaining_turns': remaining_turns}},
|
||||
upsert=True
|
||||
)
|
||||
|
||||
# New function to reset the remaining chat turns for all users and models
|
||||
def reset_remaining_turns():
|
||||
db.chat_turns.update_many({}, {'$set': {'remaining_turns': 100}})
|
||||
|
||||
# Intents and bot initialization
|
||||
intents = discord.Intents.default()
|
||||
@@ -387,6 +414,18 @@ async def reset(interaction: discord.Interaction):
|
||||
db.user_histories.delete_one({'user_id': user_id})
|
||||
await interaction.response.send_message("Your data has been cleared and reset!", ephemeral=True)
|
||||
|
||||
# Slash command to check remaining chat turns (/remaining_turns)
|
||||
@tree.command(name="remaining_turns", description="Check the remaining chat turns for each model.")
|
||||
async def remaining_turns(interaction: discord.Interaction):
|
||||
"""Checks the remaining chat turns for each model."""
|
||||
user_id = interaction.user.id
|
||||
remaining_turns_info = []
|
||||
|
||||
for model in MODEL_OPTIONS:
|
||||
remaining_turns = get_remaining_turns(user_id, model)
|
||||
remaining_turns_info.append(f"{model}: {remaining_turns} turns left")
|
||||
|
||||
await interaction.response.send_message("\n".join(remaining_turns_info), ephemeral=True)
|
||||
|
||||
# Slash command for help (/help)
|
||||
@tree.command(name="help", description="Display a list of available commands.")
|
||||
@@ -399,6 +438,7 @@ async def help_command(interaction: discord.Interaction):
|
||||
"/web `<url>` - Scrape a webpage and send data to AI model.\n"
|
||||
"/generate `<prompt>` - Generate an image from a text prompt.\n"
|
||||
"/reset - Reset your conversation history.\n"
|
||||
"/remaining_turns - Check the remaining chat turns for each model.\n"
|
||||
"/help - Display this help message.\n"
|
||||
"**Các lệnh có sẵn:**\n"
|
||||
"/choose_model - Chọn mô hình AI để sử dụng cho phản hồi (gpt-4o, gpt-4o-mini, o1-preview, o1-mini).\n"
|
||||
@@ -406,6 +446,7 @@ async def help_command(interaction: discord.Interaction):
|
||||
"/web `<url>` - Thu thập dữ liệu từ trang web và gửi đến mô hình AI.\n"
|
||||
"/generate `<gợi ý>` - Tạo hình ảnh từ gợi ý văn bản.\n"
|
||||
"/reset - Đặt lại lịch sử trò chuyện của bạn.\n"
|
||||
"/remaining_turns - Kiểm tra số lượt trò chuyện còn lại cho mỗi mô hình.\n"
|
||||
"/help - Hiển thị tin nhắn trợ giúp này.\n"
|
||||
)
|
||||
await interaction.response.send_message(help_message, ephemeral=True)
|
||||
@@ -569,7 +610,7 @@ async def handle_user_message(message: discord.Message):
|
||||
for msg in messages_to_send:
|
||||
if msg["role"] == "user" and isinstance(msg["content"], list):
|
||||
msg["content"] = [
|
||||
part for part in msg["content"] if part.get("type") != "image_url"
|
||||
part for part in msg["content"] if part["type"] != "image_url"
|
||||
]
|
||||
messages_to_send = [
|
||||
msg for msg in messages_to_send if msg.get("role") != "system"
|
||||
@@ -688,12 +729,18 @@ async def change_status():
|
||||
await bot.change_presence(activity=discord.Game(name=status))
|
||||
await asyncio.sleep(300) # Change every 60 seconds
|
||||
|
||||
# Task to reset chat turns daily
|
||||
@tasks.loop(hours=24)
|
||||
async def daily_reset():
|
||||
reset_remaining_turns()
|
||||
|
||||
@bot.event
|
||||
async def on_ready():
|
||||
"""Bot startup event to sync slash commands and start status loop."""
|
||||
await tree.sync() # Sync slash commands
|
||||
print(f"Logged in as {bot.user}")
|
||||
change_status.start() # Start the status changing loop
|
||||
daily_reset.start() # Start the daily reset loop
|
||||
|
||||
# Start Flask in a separate thread
|
||||
flask_thread = threading.Thread(target=run_flask)
|
||||
@@ -703,4 +750,4 @@ flask_thread.start()
|
||||
# Main bot startup
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
|
||||
bot.run(TOKEN)
|
||||
bot.run(TOKEN)
|
||||
|
||||
@@ -20,6 +20,9 @@ from bot import (
|
||||
save_history,
|
||||
get_user_model,
|
||||
save_user_model,
|
||||
get_remaining_turns,
|
||||
update_remaining_turns,
|
||||
reset_remaining_turns,
|
||||
bot,
|
||||
process_request,
|
||||
process_queue,
|
||||
@@ -27,6 +30,7 @@ from bot import (
|
||||
search,
|
||||
web,
|
||||
reset,
|
||||
remaining_turns,
|
||||
help_command,
|
||||
should_respond_to_message,
|
||||
handle_user_message,
|
||||
@@ -34,6 +38,7 @@ from bot import (
|
||||
generate_image,
|
||||
_generate_image_command,
|
||||
change_status,
|
||||
daily_reset,
|
||||
on_ready
|
||||
)
|
||||
|
||||
@@ -108,6 +113,15 @@ class TestFullBot(unittest.TestCase):
|
||||
def test_reset_command(self):
|
||||
self.loop.run_until_complete(self.async_test_reset())
|
||||
|
||||
async def async_test_remaining_turns(self):
|
||||
interaction = AsyncMock()
|
||||
interaction.user.id = 1234
|
||||
await remaining_turns.callback(interaction)
|
||||
interaction.response.send_message.assert_called()
|
||||
|
||||
def test_remaining_turns_command(self):
|
||||
self.loop.run_until_complete(self.async_test_remaining_turns())
|
||||
|
||||
async def async_test_help_command(self):
|
||||
interaction = AsyncMock()
|
||||
await help_command(interaction)
|
||||
@@ -153,3 +167,31 @@ class TestFullBot(unittest.TestCase):
|
||||
model = get_user_model(1234)
|
||||
self.assertEqual(model, "gpt-4o")
|
||||
|
||||
async def async_test_daily_reset(self):
|
||||
await daily_reset()
|
||||
# Verify that the reset_remaining_turns function was called
|
||||
with patch("bot.reset_remaining_turns") as mock_reset:
|
||||
await daily_reset()
|
||||
mock_reset.assert_called_once()
|
||||
|
||||
def test_daily_reset_task(self):
|
||||
self.loop.run_until_complete(self.async_test_daily_reset())
|
||||
|
||||
def test_get_remaining_turns(self):
|
||||
with patch("bot.db.chat_turns.find_one", return_value={"remaining_turns": 5}):
|
||||
remaining_turns = get_remaining_turns(1234, "gpt-4o")
|
||||
self.assertEqual(remaining_turns, 5, "Should return the correct remaining turns from the database.")
|
||||
|
||||
def test_update_remaining_turns(self):
|
||||
with patch("bot.db.chat_turns.update_one") as mock_update:
|
||||
update_remaining_turns(1234, "gpt-4o", 10)
|
||||
mock_update.assert_called_once_with(
|
||||
{'user_id': 1234, 'model': "gpt-4o"},
|
||||
{'$set': {'remaining_turns': 10}},
|
||||
upsert=True
|
||||
)
|
||||
|
||||
def test_reset_remaining_turns(self):
|
||||
with patch("bot.db.chat_turns.update_many") as mock_update:
|
||||
reset_remaining_turns()
|
||||
mock_update.assert_called_once_with({}, {'$set': {'remaining_turns': 100}})
|
||||
|
||||
Reference in New Issue
Block a user