Compare commits: 2.0.4...copilot/ad (16 commits)

| SHA1 |
|---|
| 455360bfae |
| f36424b3de |
| d87dd0072c |
| 01079f2751 |
| 41e7c6bac5 |
| 211e8dc6c7 |
| 6634ea7ad6 |
| 0c5c27a660 |
| db90774e17 |
| d549c43844 |
| 668f56020d |
| fd143aba17 |
| f17081b185 |
| e2b961e9c0 |
| c9531b37bb |
| 86875bf93d |
**.env.example** (28 additions)

@@ -18,6 +18,11 @@ OPENAI_API_KEY=your_openai_api_key_here
 # Use OpenAI directly: https://api.openai.com/v1
 OPENAI_BASE_URL=https://models.github.ai/inference

+# Anthropic API Key (for Claude models)
+# Get from: https://console.anthropic.com/
+# Leave empty to disable Claude models
+ANTHROPIC_API_KEY=your_anthropic_api_key_here
+
 # ============================================
 # Image Generation (Optional)
 # ============================================
@@ -88,3 +93,26 @@ TIMEZONE=UTC
 # 168 = 1 week
 # -1 = Never expire (permanent storage)
 FILE_EXPIRATION_HOURS=48
+
+# ============================================
+# Monitoring & Observability (Optional)
+# ============================================
+
+# Sentry DSN for error tracking
+# Get from: https://sentry.io/ (create a project and copy the DSN)
+# Leave empty to disable Sentry error tracking
+SENTRY_DSN=
+
+# Environment name for Sentry (development, staging, production)
+ENVIRONMENT=development
+
+# Sentry sample rate (0.0 to 1.0) - percentage of errors to capture
+# 1.0 = 100% of errors, 0.5 = 50% of errors
+SENTRY_SAMPLE_RATE=1.0
+
+# Sentry traces sample rate for performance monitoring (0.0 to 1.0)
+# 0.1 = 10% of transactions, lower values recommended for high-traffic bots
+SENTRY_TRACES_RATE=0.1
+
+# Log level (DEBUG, INFO, WARNING, ERROR)
+LOG_LEVEL=INFO
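The new monitoring block maps naturally onto the `sentry-sdk` Python package's initializer. As a reference, here is a minimal sketch of how a bot might consume these variables at startup; the variable names come from the `.env.example` hunk above, but the wiring itself is illustrative and not taken from this repository's source:

```python
import os

import sentry_sdk

# Illustrative startup hook: initialize Sentry only when SENTRY_DSN is set,
# mirroring the "leave empty to disable" convention in .env.example above.
dsn = os.getenv("SENTRY_DSN")
if dsn:
    sentry_sdk.init(
        dsn=dsn,
        environment=os.getenv("ENVIRONMENT", "development"),
        # Fraction of errors to report (1.0 = all errors).
        sample_rate=float(os.getenv("SENTRY_SAMPLE_RATE", "1.0")),
        # Fraction of transactions to trace for performance monitoring.
        traces_sample_rate=float(os.getenv("SENTRY_TRACES_RATE", "0.1")),
    )
```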
**.github/workflows/deploy.yml** (new file, 31 additions, vendored)

@@ -0,0 +1,31 @@
name: Deploy ChatGPT-Discord-Bot
on:
  workflow_dispatch:
env:
  REGISTRY: ghcr.io
  IMAGE_NAME: ${{ github.repository }}
jobs:
  deploy:
    runs-on: quocanh
    permissions:
      contents: read
      packages: write
    # cd to /vps/chatgptdsc and do docker compose down, then docker compose pull, then docker compose up -d
    steps:
      - name: Log in to the Container registry
        uses: docker/login-action@v3
        with:
          registry: ${{ env.REGISTRY }}
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}
      - name: cd to deployment directory
        run: cd /home/vps/chatgptdsc

      - name: Pull latest images
        run: docker compose -f /home/vps/chatgptdsc/docker-compose.yml pull

      - name: Stop existing services
        run: docker compose -f /home/vps/chatgptdsc/docker-compose.yml down

      - name: Start services
        run: docker compose -f /home/vps/chatgptdsc/docker-compose.yml up -d
**.github/workflows/main.yml** (2 changes, vendored)

@@ -1,4 +1,4 @@
-name: Build and Deploy ChatGPT-Discord-Bot
+name: Build ChatGPT-Discord-Bot
 on:
   workflow_dispatch:
**.gitignore** (7 changes, vendored)

@@ -2,8 +2,8 @@ test.py
 .env
 chat_history.db
 bot_copy.py
-__pycache__/bot.cpython-312.pyc
-tests/__pycache__/test_bot.cpython-312.pyc
+__pycache__/
+*.pyc
 .vscode/settings.json
 chatgpt.zip
 response.txt
@@ -12,4 +12,5 @@ venv
 temp_charts
 .idea
 temp_data_files
-logs/
+logs/
+.pytest_cache/
**LICENSE** (695 changes)

@@ -1,674 +1,21 @@
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007

Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.

Preamble

The GNU General Public License is a free, copyleft license for
software and other kinds of works.

The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users. We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors. You can apply it to
your programs, too.

When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.

To protect your rights, we need to prevent others from denying you
these rights or asking you to surrender the rights. Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.

For example, if you distribute copies of such a program, whether
gratis or for a fee, you must pass on to the recipients the same
freedoms that you received. You must make sure that they, too, receive
or can get the source code. And you must show them these terms so they
know their rights.

Developers that use the GNU GPL protect your rights with two steps:
(1) assert copyright on the software, and (2) offer you this License
giving you legal permission to copy, distribute and/or modify it.

For the developers' and authors' protection, the GPL clearly explains
that there is no warranty for this free software. For both users' and
authors' sake, the GPL requires that modified versions be marked as
changed, so that their problems will not be attributed erroneously to
authors of previous versions.

Some devices are designed to deny users access to install or run
modified versions of the software inside them, although the manufacturer
can do so. This is fundamentally incompatible with the aim of
protecting users' freedom to change the software. The systematic
pattern of such abuse occurs in the area of products for individuals to
use, which is precisely where it is most unacceptable. Therefore, we
have designed this version of the GPL to prohibit the practice for those
products. If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.

Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of
software on general-purpose computers, but in those that do, we wish to
avoid the special danger that patents applied to a free program could
make it effectively proprietary. To prevent this, the GPL assures that
patents cannot be used to render the program non-free.

The precise terms and conditions for copying, distribution and
modification follow.

TERMS AND CONDITIONS

0. Definitions.

"This License" refers to version 3 of the GNU General Public License.

"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.

"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.

To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.

A "covered work" means either the unmodified Program or a work based
on the Program.

To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.

To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.

An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.

1. Source Code.

The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.

A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.

The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.

The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.

The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.

The Corresponding Source for a work in source code form is that
same work.

2. Basic Permissions.

All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.

You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.

Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.

3. Protecting Users' Legal Rights From Anti-Circumvention Law.

No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.

When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.

4. Conveying Verbatim Copies.

You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.

You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.

5. Conveying Modified Source Versions.

You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:

a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.

b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".

c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.

d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.

A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.

6. Conveying Non-Source Forms.

You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:

a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.

b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.

c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.

d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.

e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.

A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.

A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.

"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.

If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).

The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.

Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.

7. Additional Terms.

"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.

When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.

Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:

a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or

b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or

c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or

d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or

e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or

f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.

All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.

If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.

Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.

8. Termination.

You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).

However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.

Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.

Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.

9. Acceptance Not Required for Having Copies.

You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.

10. Automatic Licensing of Downstream Recipients.

Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.

An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.

You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.

11. Patents.

A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".

A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.

Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.

In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.

If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.

If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.

A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.

Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.

12. No Surrender of Others' Freedom.

If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.

13. Use with the GNU Affero General Public License.

Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU Affero General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the special requirements of the GNU Affero General Public License,
section 13, concerning interaction through a network will apply to the
combination as such.

14. Revised Versions of this License.

The Free Software Foundation may publish revised and/or new versions of
the GNU General Public License from time to time. Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.

Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU General Public License, you may choose any version ever published
by the Free Software Foundation.

If the Program specifies that a proxy can decide which future
versions of the GNU General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.

Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.

15. Disclaimer of Warranty.

THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.

16. Limitation of Liability.

IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.

17. Interpretation of Sections 15 and 16.

If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.

END OF TERMS AND CONDITIONS

How to Apply These Terms to Your New Programs

If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.

To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.

<one line to give the program's name and a brief idea of what it does.>
Copyright (C) <year> <name of author>

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.

Also add information on how to contact you by electronic and paper mail.

If the program does terminal interaction, make it output a short
notice like this when it starts in an interactive mode:

<program> Copyright (C) <year> <name of author>
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
This is free software, and you are welcome to redistribute it
under certain conditions; type `show c' for details.

The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License. Of course, your program's commands
might be different; for a GUI interface, you would use an "about box".

You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU GPL, see
<https://www.gnu.org/licenses/>.

The GNU General Public License does not permit incorporating your program
into proprietary programs. If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License. But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.

MIT License

Copyright (c) 2025 coder-vippro

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
**README.md** (982 changes)

@@ -1,181 +1,817 @@
# ChatGPT Discord Bot

<div align="center">

## Overview
Welcome to **ChatGPT Discord Bot**! This bot provides a powerful AI assistant for Discord users, powered by OpenAI's latest models. It not only generates text responses but also offers a range of advanced features including image generation, data analysis, web searching, and reminders. The bot is designed for easy deployment with Docker and includes CI/CD integration via GitHub Actions.
# 🤖 ChatGPT Discord Bot

## Features
- **Advanced AI Conversations**: Uses OpenAI's latest models (including openai/gpt-4o) for natural language interactions
- **Image Generation**: Creates custom images from text prompts using Runware's API
- **Data Analysis**: Analyzes CSV and Excel files with visualizations (distributions, correlations, box plots, etc.)
- **Code Interpretation**: Executes Python code for calculations and data processing
- **Package Installation**: Automatically installs required Python packages for code execution
- **Code Display**: Shows executed code, input, and output in Discord chat for transparency
- **Secure Sandbox**: Runs code in a controlled environment with safety restrictions
- **Reminder System**: Sets timed reminders with custom timezone support
- **Web Tools**:
  - **Google Search**: Searches the web and provides relevant information
  - **Web Scraping**: Extracts and summarizes content from websites
  - **PDF Analysis**: Processes and analyzes PDF documents
- **User Statistics**: Tracks token usage and model selection per user
- **Dockerized Deployment**: Ready for easy deployment with Docker
- **Automated CI/CD**: Integrated with GitHub Actions
### *Your AI-Powered Assistant with Code Interpreter & Advanced File Management*

## Prerequisites
To get started, ensure you have:
- Docker (for containerized deployment)
- Python 3.12.7
- Discord Bot Token
- OpenAI API Key
- Runware API Key ([Get yours at Runware](https://runware.ai/))
- Google API Key and Custom Search Engine ID (CX)
- MongoDB URL (Get from https://cloud.mongodb.com/)
[](https://github.com/coder-vippro/ChatGPT-Discord-Bot/actions)
[](https://github.com/Coder-Vippro/ChatGPT-Discord-Bot/releases)
[](https://www.python.org/downloads/)
[](LICENSE)
[](https://discord.com)

## Setup
### For Normal Use
#### Option A: Deploy with Docker
1. Create a `.env` file in the root directory with your configuration:
```properties
DISCORD_TOKEN=your_discord_token
OPENAI_API_KEY=your_openai_api_key
RUNWARE_API_KEY=your_runware_api_key
GOOGLE_API_KEY=your_google_api_key
GOOGLE_CX=your_google_cx
OPENAI_BASE_URL=https://api.openai.com/v1/models
MONGODB_URI=mongodb://localhost:27017/
ADMIN_ID=your_discord_user_id
TIMEZONE=Asia/Ho_Chi_Minh
```
[Features](#-features) • [Quick Start](#-quick-start) • [Documentation](#-documentation) • [Support](#-support)

2. Use the following `docker-compose.yml`:
```yaml
version: '3.8'
services:
  bot:
    image: ghcr.io/coder-vippro/chatgpt-discord-bot:latest
    env_file:
      - .env
    restart: always
```

3. Start the bot with:
```bash
docker-compose up -d
```

#### Option B: Deploy Without Docker
1. Clone the repository:
```bash
git clone https://github.com/Coder-Vippro/ChatGPT-Discord-Bot.git
cd ChatGPT-Discord-Bot
```

2. Create a `.env` file in the root directory with your configuration:
```properties
DISCORD_TOKEN=your_discord_token
OPENAI_API_KEY=your_openai_api_key
RUNWARE_API_KEY=your_runware_api_key
GOOGLE_API_KEY=your_google_api_key
GOOGLE_CX=your_google_cx
OPENAI_BASE_URL=https://api.openai.com/v1/models
MONGODB_URI=mongodb://localhost:27017/
ADMIN_ID=your_discord_user_id
TIMEZONE=Asia/Ho_Chi_Minh
```

3. Install the dependencies:
```bash
pip install -r requirements.txt
```

4. Run the bot:
```bash
python3 bot.py
```

### For Development
1. Clone the repository:
```bash
git clone https://github.com/Coder-Vippro/ChatGPT-Discord-Bot.git
cd ChatGPT-Discord-Bot
```

2. Install dependencies:
```bash
pip install -r requirements.txt
```

3. Run the bot:
```bash
python3 bot.py
```

### Running Tests
1. Install test dependencies:
```bash
pip install pytest
```

2. Run tests:
```bash
pytest tests/
```

## Usage
Once the bot is running, it connects to Discord using credentials from `.env`. Available features include:

### Text Commands
- **Normal chat**: Ping the bot with a question or send a DM to start a conversation
- **Image Generation**: `/generate prompt: "A futuristic cityscape"`
- **Web Content**: `/web url: "https://example.com"`
- **Google Search**: `/search prompt: "latest news in Vietnam"`
- **User Statistics**: `/user_stat` - Get your token usage and model information

### Advanced Features
- **Data Analysis**: Upload CSV or Excel files for automatic analysis and visualization
- **Code Execution**: The bot can execute Python code to solve problems or create visualizations
- **Reminders**: Ask the bot to set reminders like "Remind me to check email in 30 minutes"
- **PDF Analysis**: Upload PDF documents for the bot to analyze and summarize

### Available Models
The bot supports the following models:
- openai/gpt-4o
- openai/gpt-4o-mini
- openai/gpt-5
- openai/gpt-5-nano
- openai/gpt-5-mini
- openai/gpt-5-chat
- openai/o1-preview
- openai/o1-mini
- openai/o1
- openai/o3-mini

## Environment Variables
| Variable | Description | Default |
|----------|-------------|---------|
| DISCORD_TOKEN | Your Discord bot token | Required |
| OPENAI_API_KEY | Your OpenAI API key | Required |
| RUNWARE_API_KEY | Runware API key for image generation | Required |
| GOOGLE_API_KEY | Google API key for search | Required |
| GOOGLE_CX | Google Custom Search Engine ID | Required |
| MONGODB_URI | MongoDB connection string | Required |
| ADMIN_ID | Discord user ID of the admin | Optional |
| TIMEZONE | Timezone for reminder feature | UTC |
| ENABLE_WEBHOOK_LOGGING | Enable webhook logging | False |
| LOGGING_WEBHOOK_URL | URL for webhook logging | Optional |

## CI/CD
This project uses GitHub Actions for CI/CD, with workflows in `.github/workflows`.

## Security
For supported versions and vulnerability reporting, see [SECURITY.md](SECURITY.md).

## Contributing
Please read our [Code of Conduct](CODE_OF_CONDUCT.md) before contributing to this project.

## License
This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
</div>

---

Made with ❤️ by [coder-vippro](https://github.com/coder-vippro)
## 🌟 Overview

**ChatGPT Discord Bot** brings the power of AI directly to your Discord server! Powered by OpenAI's latest models and Anthropic's Claude, this bot goes beyond simple chat - it's a complete AI assistant with **code interpretation**, **file management**, **data analysis**, and much more.

### 🎯 What Makes This Bot Special?

- 🧠 **Latest AI Models** - GPT-4o, GPT-5, o1, o3-mini, Claude 4, and more
- 💻 **Code Interpreter** - Execute Python code like ChatGPT (NEW in v2.0!)
- 📁 **Smart File Management** - Handle 200+ file types with automatic cleanup
- 📊 **Data Analysis** - Upload and analyze CSV, Excel, and scientific data
- 🎨 **Image Generation** - Create stunning images from text prompts
- 🔍 **Web Tools** - Search Google and scrape websites
- ⏰ **Reminder System** - Never forget important tasks
- 🐳 **Docker Ready** - One-command deployment

---

## ✨ Features

### 🆕 New in Version 2.0.0

<table>
<tr>
<td width="50%">

#### 💻 **Unified Code Interpreter**
Execute Python code directly in Discord! Similar to ChatGPT's code interpreter.

```python
import pandas as pd
import matplotlib.pyplot as plt

df = load_file('your_file_id')
print(df.describe())
plt.plot(df['column'])
plt.savefig('plot.png')
```
**Features:**
- ✅ Auto-install packages
- ✅ Sandboxed execution
- ✅ File output capture
- ✅ 5-minute timeout protection

</td>
<td width="50%">

#### 📁 **Advanced File Management**
Upload, store, and process files with intelligent lifecycle management.

**Supports 200+ file types:**
- 📊 Data: CSV, Excel, JSON, Parquet
- 🖼️ Images: PNG, JPEG, GIF, SVG, PSD
- 📄 Documents: PDF, DOCX, Markdown
- 🔬 Scientific: MATLAB, HDF5, NumPy
- 🎵 Media: Audio, Video formats
- And many more!

**Smart Features:**
- Auto-expiration (configurable)
- Per-user storage limits
- `/files` command for management

</td>
</tr>
</table>

### 🎨 **Image Generation**

Generate stunning visuals from text prompts using Runware AI:

```
/generate prompt: A futuristic cyberpunk city at night with neon lights
```

- High-quality outputs
- Fast generation (2-5 seconds)
- Multiple style support

### 📊 **Data Analysis & Visualization**

Upload your data files and get instant insights:

```
📈 Statistical Analysis
• Descriptive statistics
• Correlation matrices
• Distribution plots
• Custom visualizations

📉 Supported Formats
• CSV, TSV, Excel
• JSON, Parquet, Feather
• SPSS, Stata, SAS
• And 50+ more formats
```

### 🔍 **Web Tools**

- **Google Search** - Get up-to-date information from the web
- **Web Scraping** - Extract and summarize website content
- **PDF Analysis** - Process and analyze PDF documents

### 🤖 **AI Conversation**

- Natural language understanding
- Context-aware responses
- Time-zone aware (knows current date/time)
- Multi-turn conversations
- DM and server support

### ⏰ **Reminder System**

Set reminders naturally:
```
"Remind me to check email in 30 minutes"
"Set a reminder for tomorrow at 3pm"
"Remind me about the meeting in 2 hours"
```

### 🎯 **Supported AI Models**

<table>
<tr>
<td>

**GPT-4 Series**
- `gpt-4o`
- `gpt-4o-mini`

</td>
<td>

**GPT-5 Series**
- `gpt-5`
- `gpt-5-mini`
- `gpt-5-nano`
- `gpt-5-chat`

</td>
<td>

**o1/o3 Series**
- `o1-preview`
- `o1-mini`
- `o1`
- `o3-mini`

</td>
<td>

**Claude (Anthropic)**
- `claude-sonnet-4-20250514`
- `claude-opus-4-20250514`
- `claude-3.5-sonnet`
- `claude-3.5-haiku`

</td>
</tr>
</table>

---

## 🚀 Quick Start

### Prerequisites

Before you begin, ensure you have:

- 🐳 **Docker** (recommended) or Python 3.13+
- 🎮 **Discord Bot Token** ([Create one here](https://discord.com/developers/applications))
- 🔑 **OpenAI API Key** ([Get it here](https://platform.openai.com/api-keys))
- 🧠 **Anthropic API Key** (Optional, for Claude models - [Get it here](https://console.anthropic.com/))
- 🎨 **Runware API Key** ([Sign up here](https://runware.ai/))
- 🔍 **Google API Key** ([Google Cloud Console](https://console.cloud.google.com/))
- 🗄️ **MongoDB** ([MongoDB Atlas](https://cloud.mongodb.com/) - Free tier available)

### 🐳 Option A: Docker Deployment (Recommended)

**Step 1:** Create `.env` file in your project directory

```env
# Discord Configuration
DISCORD_TOKEN=your_discord_bot_token_here

# AI Provider Keys
OPENAI_API_KEY=your_openai_api_key_here
OPENAI_BASE_URL=https://api.openai.com/v1

# Anthropic (Claude) - Optional
ANTHROPIC_API_KEY=your_anthropic_api_key_here

# Image Generation
RUNWARE_API_KEY=your_runware_api_key_here

# Google Search
GOOGLE_API_KEY=your_google_api_key_here
GOOGLE_CX=your_custom_search_engine_id_here

# Database
MONGODB_URI=your_mongodb_connection_string_here

# Bot Configuration
ADMIN_ID=your_discord_user_id
TIMEZONE=Asia/Ho_Chi_Minh

# File Management (NEW in v2.0)
MAX_FILES_PER_USER=20
FILE_EXPIRATION_HOURS=48

# Code Execution (NEW in v2.0)
CODE_EXECUTION_TIMEOUT=300
```

**Step 2:** Create `docker-compose.yml`

```yaml
version: '3.8'

services:
  bot:
    image: ghcr.io/coder-vippro/chatgpt-discord-bot:latest
    container_name: chatgpt-discord-bot
    env_file:
      - .env
    volumes:
      - ./data/user_files:/tmp/bot_code_interpreter/user_files
      - ./data/outputs:/tmp/bot_code_interpreter/outputs
    restart: unless-stopped
    deploy:
      resources:
        limits:
          cpus: '2.0'
          memory: 4G
```

**Step 3:** Start the bot

```bash
docker-compose up -d
```

**Step 4:** Check logs

```bash
docker-compose logs -f bot
```

✅ **Done!** Your bot is now running!

---

### 💻 Option B: Local Deployment

**Step 1:** Clone the repository

```bash
git clone https://github.com/Coder-Vippro/ChatGPT-Discord-Bot.git
cd ChatGPT-Discord-Bot
```

**Step 2:** Create and configure `.env` file

```bash
cp .env.example .env
# Edit .env with your API keys and configuration
```

**Step 3:** Install dependencies

```bash
pip install -r requirements.txt
```

**Step 4:** Run the bot

```bash
python3 bot.py
```

---
|
||||
|
||||
## 📖 Usage Guide
|
||||
|
||||
### 💬 Basic Chat
|
||||
|
||||
Simply mention the bot or DM it:
|
||||
|
||||
```
|
||||
@Bot Hello! How can you help me?
|
||||
```
|
||||
|
||||
### 🎨 Image Generation
|
||||
|
||||
Use the `/generate` command:
|
||||
|
||||
```
|
||||
/generate prompt: A serene Japanese garden with cherry blossoms
|
||||
```
|
||||
|
||||
### 📁 File Upload & Analysis
|
||||
|
||||
1. **Upload a file** - Drag and drop any file into the chat
|
||||
2. **Get file ID** - Bot confirms upload with file ID
|
||||
3. **Analyze** - Ask the bot to analyze your data
|
||||
|
||||
```
|
||||
User: *uploads data.csv*
|
||||
Bot: 📊 File Uploaded: data.csv
|
||||
🆔 File ID: 123456789_1234567890_abc123
|
||||
|
||||
User: Analyze this data and create visualizations
|
||||
Bot: *executes code and generates plots*
|
||||
```
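Once the file ID is known, a typical analysis turn inside the interpreter looks like the sketch below. `load_file()` is the sandbox helper documented in the prompt excerpts later on this page; the ID here is a placeholder for the one from your own upload message:

```python
import matplotlib.pyplot as plt

# Use the exact file ID from the bot's upload confirmation
df = load_file('123456789_1234567890_abc123')  # CSV -> pandas DataFrame

print(df.head())       # only printed output is shown back to you
print(df.describe())

df.plot(kind='bar')
plt.savefig('overview.png')  # generated files are sent back automatically
```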

### 💻 **Code Execution**

Ask the bot to write and execute code:

```
User: Calculate the Fibonacci sequence up to 100 and plot it

Bot: I'll calculate and plot the Fibonacci sequence for you.
```

```python
import matplotlib.pyplot as plt

def fibonacci(n):
    sequence = [0, 1]
    while sequence[-1] < n:
        sequence.append(sequence[-1] + sequence[-2])
    return sequence

fib = fibonacci(100)
plt.plot(fib)
plt.title('Fibonacci Sequence')
plt.savefig('fibonacci.png')
print(f"Generated {len(fib)} numbers")
```

```
✅ Output: Generated 13 numbers
📊 Generated file: fibonacci.png
```

### 📋 **File Management**

Use the `/files` command to manage your uploaded files:

```
/files
```

This shows:
- A list of all your files
- File sizes and types
- Expiration dates
- A delete option

### 🔍 **Web Search**

```
/search query: Latest AI developments 2025
```

### 🌐 **Web Scraping**

```
/web url: https://example.com/article
```

### 📊 **User Statistics**

```
/user_stat
```

Shows your token usage and model preferences.
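Costs are derived from the per-1M-token prices in `MODEL_PRICING` (shown in the `commands.py` diff further down this page); a worked sketch of the arithmetic, using the gpt-4o rates listed there:

```python
MODEL_PRICING = {"openai/gpt-4o": {"input": 5.00, "output": 20.00}}  # USD per 1M tokens

def estimate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
    price = MODEL_PRICING[model]
    return (input_tokens * price["input"] + output_tokens * price["output"]) / 1_000_000

# 10K input + 2K output tokens on gpt-4o: 0.05 + 0.04 = $0.09
print(f"${estimate_cost('openai/gpt-4o', 10_000, 2_000):.4f}")
```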

### 🔄 **Reset Conversation**

```
/reset
```

Clears your conversation history and deletes all of your uploaded files.

---

## ⚙️ Configuration

### Environment Variables

<details>
<summary><b>Click to expand full configuration options</b></summary>

#### Required Variables

| Variable | Description | Example |
|----------|-------------|---------|
| `DISCORD_TOKEN` | Your Discord bot token | `MTIzNDU2Nzg5MDEyMzQ1Njc4OQ...` |
| `OPENAI_API_KEY` | OpenAI API key | `sk-proj-...` |
| `RUNWARE_API_KEY` | Runware API key for images | `rw_...` |
| `GOOGLE_API_KEY` | Google API key | `AIza...` |
| `GOOGLE_CX` | Custom Search Engine ID | `a1b2c3d4e5f6g7h8i9` |
| `MONGODB_URI` | MongoDB connection string | `mongodb://localhost:27017/` |

#### Optional Variables

| Variable | Description | Default |
|----------|-------------|---------|
| `OPENAI_BASE_URL` | OpenAI API base URL | `https://api.openai.com/v1` |
| `ADMIN_ID` | Discord user ID for admin commands | None |
| `TIMEZONE` | Timezone for reminders | `UTC` |
| `MAX_FILES_PER_USER` | Max stored files per user | `20` |
| `FILE_EXPIRATION_HOURS` | Hours until uploaded files expire | `48` |
| `CODE_EXECUTION_TIMEOUT` | Code execution timeout in seconds | `300` |
| `ENABLE_WEBHOOK_LOGGING` | Enable webhook logging | `False` |
| `LOGGING_WEBHOOK_URL` | Webhook URL for logs | None |

</details>
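All of these are read at startup via python-dotenv, as `src/config/config.py` does (see its diff below); a minimal sketch of the pattern:

```python
import os
from dotenv import load_dotenv

load_dotenv()  # reads .env from the working directory

DISCORD_TOKEN = os.getenv("DISCORD_TOKEN")  # required
FILE_EXPIRATION_HOURS = int(os.getenv("FILE_EXPIRATION_HOURS", "48"))  # optional, with default

if not DISCORD_TOKEN:
    raise RuntimeError("DISCORD_TOKEN is not set")
```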

### File Management Settings

```env
# Maximum number of files each user can upload
MAX_FILES_PER_USER=20

# Hours until files expire and are auto-deleted
# Set to -1 for permanent storage (no expiration)
FILE_EXPIRATION_HOURS=48
```
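A minimal sketch of how the expiration rule can be evaluated (`-1` disables expiration as documented above; the bot's actual cleanup job may be structured differently):

```python
from datetime import datetime, timedelta, timezone

FILE_EXPIRATION_HOURS = 48  # from .env; -1 means never expire

def is_expired(uploaded_at: datetime) -> bool:
    if FILE_EXPIRATION_HOURS == -1:
        return False  # permanent storage
    age = datetime.now(timezone.utc) - uploaded_at
    return age > timedelta(hours=FILE_EXPIRATION_HOURS)

old = datetime.now(timezone.utc) - timedelta(hours=49)
print(is_expired(old))  # True
```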

### Code Execution Settings

```env
# Maximum time for code execution (in seconds)
CODE_EXECUTION_TIMEOUT=300

# Package cleanup period (in code_interpreter.py)
PACKAGE_CLEANUP_DAYS=7
```
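The timeout can be enforced with `asyncio.wait_for` around a subprocess, roughly as sketched here (illustrative only, not the interpreter's actual code):

```python
import asyncio

CODE_EXECUTION_TIMEOUT = 300  # seconds, from .env

async def run_user_code(code: str) -> str:
    proc = await asyncio.create_subprocess_exec(
        "python3", "-c", code,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.STDOUT,
    )
    try:
        out, _ = await asyncio.wait_for(proc.communicate(), timeout=CODE_EXECUTION_TIMEOUT)
        return out.decode()
    except asyncio.TimeoutError:
        proc.kill()
        return "⏱️ Execution timed out"

print(asyncio.run(run_user_code("print(2 + 2)")))  # 4
```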

---

## 📚 Documentation

### 📖 Comprehensive Guides

- [🚀 Quick Start Guide](docs/QUICK_REFERENCE.md)
- [📁 File Management Guide](docs/FILE_MANAGEMENT_GUIDE.md)
- [💻 Code Interpreter Guide](docs/CODE_INTERPRETER_GUIDE.md)
- [📦 Package Cleanup Guide](docs/PACKAGE_CLEANUP_GUIDE.md)
- [🐳 Docker Deployment Guide](docs/DOCKER_DEPLOYMENT_GUIDE.md)
- [⚙️ Environment Setup Guide](docs/ENV_SETUP_GUIDE.md)

### 🆕 What's New in v2.0

- [📋 Release Notes v2.0.0](RELEASE_NOTES_v2.0.0.md)
- [📝 Complete Implementation Summary](docs/COMPLETE_IMPLEMENTATION_SUMMARY.md)
- [🔧 All File Types & Timeout Update](docs/ALL_FILE_TYPES_AND_TIMEOUT_UPDATE.md)
- [🐛 Bug Fixes Documentation](docs/BUGFIX_DATABASE_METHODS.md)

### 🛠️ Technical Documentation

- [🏗️ Architecture Overview](docs/UNIFIED_FILE_SYSTEM_SUMMARY.md)
- [📊 Token Counting Guide](docs/TOKEN_COUNTING_GUIDE.md)
- [🕐 Current Time in Context](docs/CURRENT_TIME_IN_CONTEXT.md)
- [🔒 Security Guidelines](SECURITY.md)

---

## 🎯 Examples

### Example 1: Data Analysis

```
User: *uploads sales_data.csv*

User: Analyze this sales data and show me:
1. Monthly revenue trends
2. Top 5 products
3. Customer distribution by region

Bot: I'll analyze your sales data and create visualizations.

[Executes Python code with pandas and matplotlib]
[Generates 3 charts]

📊 Analysis Results:
- Total Revenue: $1.2M
- Best Month: December ($250K)
- Top Product: Product A ($300K)
- Main Region: North America (45%)

📈 Generated 3 visualizations
```
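The monthly-trend step in a turn like this reduces to a pandas groupby; a hypothetical sketch, with `date` and `revenue` as made-up column names:

```python
import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_csv("sales_data.csv", parse_dates=["date"])  # hypothetical columns

monthly = df.groupby(df["date"].dt.to_period("M"))["revenue"].sum()
print(monthly)

monthly.plot(kind="line", title="Monthly Revenue")
plt.savefig("monthly_revenue.png")
```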

### Example 2: Image Generation

```
User: /generate prompt: A serene Japanese garden with cherry blossoms, koi pond, and traditional temple in the background, anime style

Bot: 🎨 Generating your image...
[2 seconds later]
🖼️ Here's your generated image!
[Beautiful anime-style Japanese garden image]
```

### Example 3: Code Execution

```
User: Create a function to check if a number is prime, then find all primes up to 1000

Bot: I'll create a prime checker and find all primes up to 1000.
```

```python
def is_prime(n):
    if n < 2:
        return False
    for i in range(2, int(n**0.5) + 1):
        if n % i == 0:
            return False
    return True

primes = [n for n in range(2, 1001) if is_prime(n)]
print(f"Found {len(primes)} prime numbers")
print(f"First 10: {primes[:10]}")
print(f"Last 10: {primes[-10:]}")
```

```
✅ Output:
Found 168 prime numbers
First 10: [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
Last 10: [937, 941, 947, 953, 967, 971, 977, 983, 991, 997]
```

---

## 🤝 Contributing

We welcome contributions! Here's how you can help:

### Ways to Contribute

- 🐛 **Report Bugs** - [Open an issue](https://github.com/Coder-Vippro/ChatGPT-Discord-Bot/issues)
- ✨ **Suggest Features** - [Start a discussion](https://github.com/Coder-Vippro/ChatGPT-Discord-Bot/discussions)
- 📝 **Improve Docs** - Submit documentation updates
- 💻 **Submit Code** - Create pull requests

### Development Setup

```bash
# Fork and clone the repository
git clone https://github.com/YOUR_USERNAME/ChatGPT-Discord-Bot.git
cd ChatGPT-Discord-Bot

# Install dependencies
pip install -r requirements.txt

# Create a branch
git checkout -b feature/your-feature-name

# Make your changes and test
python3 bot.py

# Run tests
pytest tests/

# Commit and push
git add .
git commit -m "Add your feature"
git push origin feature/your-feature-name
```

### Code of Conduct

Please read our [Code of Conduct](CODE_OF_CONDUCT.md) before contributing.

---

## 🐛 Troubleshooting

<details>
<summary><b>Bot won't start</b></summary>

**Check:**
1. All required environment variables are set
2. The Discord token is valid
3. MongoDB is accessible
4. Port 27017 is not blocked (if using local MongoDB)

**Solution:**
```bash
# Check logs
docker-compose logs bot

# Verify .env file
cat .env | grep -v '^#'
```
</details>

<details>
<summary><b>Code execution fails</b></summary>

**Common causes:**
- Package installation timeout
- Code exceeds the 5-minute timeout
- Memory limit exceeded

**Solutions:**
```env
# Increase timeout
CODE_EXECUTION_TIMEOUT=600

# In docker-compose.yml, increase memory
memory: 8G
```
</details>

<details>
<summary><b>Files not uploading</b></summary>

**Check:**
1. File size (Discord limit: 25 MB for free users, 500 MB for Nitro)
2. Storage limit reached (default: 20 files per user)
3. Disk space available

**Solution:**
```env
# Increase file limit
MAX_FILES_PER_USER=50

# Set permanent storage
FILE_EXPIRATION_HOURS=-1
```
</details>

<details>
<summary><b>Docker "Resource busy" error</b></summary>

This is fixed in v2.0! The bot now uses the system Python in Docker.

**If you still see this error:**
```bash
# Rebuild from scratch
docker-compose down
docker-compose build --no-cache
docker-compose up -d
```
</details>

---

## 📊 Performance

### System Requirements

| Deployment | CPU | RAM | Disk | Network |
|------------|-----|-----|------|---------|
| **Minimal** | 1 core | 2GB | 2GB | 1 Mbps |
| **Recommended** | 2 cores | 4GB | 5GB | 10 Mbps |
| **High Load** | 4 cores | 8GB | 10GB | 100 Mbps |

### Benchmarks

```
📈 Response Times (avg):
- Simple chat: 1-2 seconds
- Code execution: 2-5 seconds
- Image generation: 3-5 seconds
- Data analysis: 5-10 seconds
- File upload: <1 second

💾 Resource Usage:
- Idle: ~200 MB RAM
- Active: ~500 MB RAM
- Peak: ~2 GB RAM
- Docker image: ~600 MB

🚀 Throughput:
- Concurrent users: 50+
- Messages/minute: 100+
- File uploads/hour: 500+
```

---

## 🔒 Security

### Security Features

- ✅ Sandboxed code execution
- ✅ Per-user file isolation
- ✅ Timeout protection
- ✅ Resource limits
- ✅ Input validation
- ✅ Package validation
- ✅ MongoDB injection prevention

### Reporting Security Issues

Found a vulnerability? Please **DO NOT** open a public issue.

See [SECURITY.md](SECURITY.md) for reporting guidelines.

---

## 📜 License

This project is licensed under the **MIT License** - see the [LICENSE](LICENSE) file for details.

---

## 🙏 Acknowledgments

Special thanks to:

- **[OpenAI](https://openai.com)** - For powering our AI capabilities
- **[Runware](https://runware.ai)** - For the image generation API
- **[Discord.py](https://discordpy.readthedocs.io/)** - For an excellent Discord library
- **[MongoDB](https://mongodb.com)** - For reliable database services
- **All Contributors** - For making this project better

---

## 📞 Support & Community

### Get Help

- 💬 **Discord Server**: [Join our community](https://discord.gg/yourserver)
- 🐛 **GitHub Issues**: [Report bugs](https://github.com/Coder-Vippro/ChatGPT-Discord-Bot/issues)
- 💡 **Discussions**: [Share ideas](https://github.com/Coder-Vippro/ChatGPT-Discord-Bot/discussions)

### Useful Commands

```bash
# View logs
docker-compose logs -f bot

# Restart the bot
docker-compose restart bot

# Check file storage usage
du -sh data/user_files/

# View the package cache
cat /tmp/bot_code_interpreter/package_cache.json | jq

# Update to the latest version
docker-compose pull
docker-compose up -d
```

---

## 📈 Stats & Updates

![GitHub stars](https://img.shields.io/github/stars/Coder-Vippro/ChatGPT-Discord-Bot?style=social)
![GitHub forks](https://img.shields.io/github/forks/Coder-Vippro/ChatGPT-Discord-Bot?style=social)
![GitHub issues](https://img.shields.io/github/issues/Coder-Vippro/ChatGPT-Discord-Bot)
![GitHub pull requests](https://img.shields.io/github/issues-pr/Coder-Vippro/ChatGPT-Discord-Bot)

**Latest Release**: v2.0.0 (October 3, 2025)
**Active Servers**: Growing daily

---

## 🗺️ Roadmap

### Version 2.1 (Q4 2025)
- [ ] Multi-language support
- [ ] Voice channel integration
- [ ] Usage analytics dashboard
- [ ] Advanced reminders (recurring)
- [ ] Custom tool creation

### Version 2.2 (Q1 2026)
- [ ] Collaborative code sessions
- [ ] Code version history
- [ ] Direct database connections
- [ ] Mobile companion app
- [ ] Workflow automation

[View full roadmap →](https://github.com/Coder-Vippro/ChatGPT-Discord-Bot/projects)

---

<div align="center">

### ⭐ Star Us on GitHub!

If you find this bot useful, please give it a star! It helps others discover the project.

---

Made with ❤️ by [Coder-Vippro](https://github.com/coder-vippro)

[⬆ Back to Top](#-chatgpt-discord-bot)

</div>
20
bot.py
@@ -17,7 +17,7 @@ from src.config.config import (
    DISCORD_TOKEN, MONGODB_URI, RUNWARE_API_KEY, STATUSES,
    LOGGING_CONFIG, ENABLE_WEBHOOK_LOGGING, LOGGING_WEBHOOK_URL,
    WEBHOOK_LOG_LEVEL, WEBHOOK_APP_NAME, WEBHOOK_BATCH_SIZE,
    WEBHOOK_FLUSH_INTERVAL, LOG_LEVEL_MAP
    WEBHOOK_FLUSH_INTERVAL, LOG_LEVEL_MAP, ANTHROPIC_API_KEY
)

# Import webhook logger

@@ -124,6 +124,20 @@ async def main():
        logging.error(f"Error initializing OpenAI client: {e}")
        return

    # Initialize the Anthropic client (for Claude models)
    anthropic_client = None
    if ANTHROPIC_API_KEY:
        try:
            from anthropic import AsyncAnthropic
            anthropic_client = AsyncAnthropic(api_key=ANTHROPIC_API_KEY)
            logging.info("Anthropic client initialized successfully")
        except ImportError:
            logging.warning("Anthropic package not installed. Claude models will not be available. Install with: pip install anthropic")
        except Exception as e:
            logging.warning(f"Error initializing Anthropic client: {e}. Claude models will not be available.")
    else:
        logging.info("ANTHROPIC_API_KEY not set - Claude models will not be available")

    # Global references to objects that need cleanup
    message_handler = None
    db_handler = None

@@ -191,14 +205,14 @@ async def main():
        await ctx.send(f"Error: {error_msg}")

    # Initialize message handler
    message_handler = MessageHandler(bot, db_handler, openai_client, image_generator)
    message_handler = MessageHandler(bot, db_handler, openai_client, image_generator, anthropic_client)

    # Attach db_handler to bot for cogs
    bot.db_handler = db_handler

    # Set up slash commands
    from src.commands.commands import setup_commands
    setup_commands(bot, db_handler, openai_client, image_generator)
    setup_commands(bot, db_handler, openai_client, image_generator, anthropic_client)

    # Load file management commands
    try:
266
config/image_config.json
Normal file
@@ -0,0 +1,266 @@
{
  "_comment": "Image Generation Configuration - Add/modify models here",
  "_version": "2.0.0",

  "settings": {
    "default_model": "flux",
    "default_upscale_model": "clarity",
    "default_background_removal_model": "bria",
    "connection_timeout": 120,
    "max_retries": 3,
    "retry_delay": 2,
    "output_format": "WEBP",
    "output_quality": 95
  },

  "image_models": {
    "flux": {
      "model_id": "runware:101@1",
      "name": "FLUX.1",
      "description": "High-quality FLUX model for general image generation",
      "default_width": 1024,
      "default_height": 1024,
      "min_width": 512,
      "min_height": 512,
      "max_width": 2048,
      "max_height": 2048,
      "step_size": 64,
      "default_steps": 30,
      "default_cfg_scale": 7.5,
      "supports_negative_prompt": true,
      "max_images": 4,
      "category": "general"
    },
    "flux-dev": {
      "model_id": "runware:100@1",
      "name": "FLUX.1 Dev",
      "description": "FLUX.1 Development version with more creative outputs",
      "default_width": 1024,
      "default_height": 1024,
      "min_width": 512,
      "min_height": 512,
      "max_width": 2048,
      "max_height": 2048,
      "step_size": 64,
      "default_steps": 25,
      "default_cfg_scale": 7.0,
      "supports_negative_prompt": true,
      "max_images": 4,
      "category": "general"
    },
    "flux-fill": {
      "model_id": "runware:102@1",
      "name": "FLUX Fill",
      "description": "FLUX model optimized for inpainting and editing",
      "default_width": 1024,
      "default_height": 1024,
      "min_width": 512,
      "min_height": 512,
      "max_width": 2048,
      "max_height": 2048,
      "step_size": 64,
      "default_steps": 30,
      "default_cfg_scale": 7.5,
      "supports_negative_prompt": true,
      "max_images": 4,
      "category": "editing"
    },
    "sdxl": {
      "model_id": "civitai:101055@128078",
      "name": "Stable Diffusion XL",
      "description": "Stable Diffusion XL for detailed, high-resolution images",
      "default_width": 1024,
      "default_height": 1024,
      "min_width": 512,
      "min_height": 512,
      "max_width": 2048,
      "max_height": 2048,
      "step_size": 64,
      "default_steps": 30,
      "default_cfg_scale": 7.0,
      "supports_negative_prompt": true,
      "max_images": 4,
      "category": "general"
    },
    "realistic": {
      "model_id": "civitai:4201@130072",
      "name": "Realistic Vision",
      "description": "Photorealistic image generation",
      "default_width": 768,
      "default_height": 768,
      "min_width": 512,
      "min_height": 512,
      "max_width": 1536,
      "max_height": 1536,
      "step_size": 64,
      "default_steps": 35,
      "default_cfg_scale": 7.5,
      "supports_negative_prompt": true,
      "max_images": 4,
      "category": "realistic"
    },
    "anime": {
      "model_id": "civitai:4384@128713",
      "name": "Anime Style",
      "description": "Anime and illustration style images",
      "default_width": 768,
      "default_height": 768,
      "min_width": 512,
      "min_height": 512,
      "max_width": 1536,
      "max_height": 1536,
      "step_size": 64,
      "default_steps": 28,
      "default_cfg_scale": 7.0,
      "supports_negative_prompt": true,
      "max_images": 4,
      "category": "anime"
    },
    "dreamshaper": {
      "model_id": "civitai:4384@128713",
      "name": "DreamShaper",
      "description": "Creative and artistic image generation",
      "default_width": 768,
      "default_height": 768,
      "min_width": 512,
      "min_height": 512,
      "max_width": 1536,
      "max_height": 1536,
      "step_size": 64,
      "default_steps": 30,
      "default_cfg_scale": 7.0,
      "supports_negative_prompt": true,
      "max_images": 4,
      "category": "artistic"
    }
  },

  "upscale_models": {
    "clarity": {
      "model_id": "runware:500@1",
      "name": "Clarity",
      "description": "High-quality clarity upscaling",
      "supported_factors": [2, 4],
      "max_input_size": 2048,
      "max_output_size": 4096,
      "supports_prompts": true
    },
    "ccsr": {
      "model_id": "runware:501@1",
      "name": "CCSR",
      "description": "Content-consistent super-resolution upscaling",
      "supported_factors": [2, 4],
      "max_input_size": 2048,
      "max_output_size": 4096,
      "supports_prompts": true
    },
    "sd-latent": {
      "model_id": "runware:502@1",
      "name": "SD Latent Upscaler",
      "description": "Stable Diffusion latent space upscaling",
      "supported_factors": [2],
      "max_input_size": 2048,
      "max_output_size": 4096,
      "supports_prompts": true
    },
    "swinir": {
      "model_id": "runware:503@1",
      "name": "SwinIR",
      "description": "Fast and efficient SwinIR upscaling (supports 4x)",
      "supported_factors": [2, 4],
      "max_input_size": 2048,
      "max_output_size": 4096,
      "supports_prompts": false
    }
  },

  "background_removal_models": {
    "bria": {
      "model_id": "runware:110@1",
      "name": "Bria RMBG 2.0",
      "description": "High-quality background removal by Bria",
      "supports_alpha_matting": false
    },
    "rembg": {
      "model_id": "runware:109@1",
      "name": "RemBG 1.4",
      "description": "Classic RemBG with alpha matting support",
      "supports_alpha_matting": true
    },
    "birefnet-base": {
      "model_id": "runware:112@1",
      "name": "BiRefNet Base",
      "description": "BiRefNet base model for background removal",
      "supports_alpha_matting": false
    },
    "birefnet-general": {
      "model_id": "runware:112@5",
      "name": "BiRefNet General",
      "description": "BiRefNet general purpose model",
      "supports_alpha_matting": false
    },
    "birefnet-portrait": {
      "model_id": "runware:112@10",
      "name": "BiRefNet Portrait",
      "description": "BiRefNet optimized for portraits",
      "supports_alpha_matting": false
    }
  },

  "controlnet_models": {
    "flux-canny": {
      "model_id": "runware:25@1",
      "name": "FLUX Canny",
      "description": "Edge detection control for FLUX models",
      "architecture": "flux"
    },
    "flux-depth": {
      "model_id": "runware:27@1",
      "name": "FLUX Depth",
      "description": "Depth map control for FLUX models",
      "architecture": "flux"
    },
    "flux-pose": {
      "model_id": "runware:29@1",
      "name": "FLUX Pose",
      "description": "Pose control for FLUX models",
      "architecture": "flux"
    },
    "sdxl-canny": {
      "model_id": "runware:20@1",
      "name": "SDXL Canny",
      "description": "Edge detection control for SDXL models",
      "architecture": "sdxl"
    },
    "sd15-canny": {
      "model_id": "civitai:38784@44716",
      "name": "SD 1.5 Canny",
      "description": "Edge detection control for SD 1.5 models",
      "architecture": "sd15"
    },
    "sd15-lineart": {
      "model_id": "civitai:38784@44877",
      "name": "SD 1.5 Line Art",
      "description": "Line art control for SD 1.5 models",
      "architecture": "sd15"
    }
  },

  "default_negative_prompts": {
    "general": "blurry, distorted, low quality, watermark, signature, text, bad anatomy, deformed",
    "realistic": "cartoon, anime, illustration, painting, drawing, bad anatomy, deformed, blurry, low quality",
    "anime": "realistic, photo, 3d render, bad anatomy, deformed hands, extra fingers, blurry",
    "artistic": "bad quality, low resolution, blurry, watermark, signature"
  },

  "aspect_ratios": {
    "1:1": {"width": 1024, "height": 1024, "description": "Square"},
    "16:9": {"width": 1344, "height": 768, "description": "Landscape Wide"},
    "9:16": {"width": 768, "height": 1344, "description": "Portrait Tall"},
    "4:3": {"width": 1152, "height": 896, "description": "Landscape"},
    "3:4": {"width": 896, "height": 1152, "description": "Portrait"},
    "3:2": {"width": 1248, "height": 832, "description": "Photo Landscape"},
    "2:3": {"width": 832, "height": 1248, "description": "Photo Portrait"},
    "21:9": {"width": 1536, "height": 640, "description": "Ultrawide"}
  }
}
@@ -1,18 +1,50 @@
discord.py
openai
motor
pymongo
pypdf
beautifulsoup4
requests
aiohttp
runware
python-dotenv
matplotlib
pandas
openpyxl
seaborn
tzlocal
numpy
plotly
tiktoken
# Discord Bot Core
discord.py>=2.3.0
openai>=1.40.0
python-dotenv>=1.0.0

# Database
motor>=3.3.0
pymongo[srv]>=4.6.0
dnspython>=2.5.0

# Web & HTTP
aiohttp>=3.9.0
requests>=2.31.0
beautifulsoup4>=4.12.0

# AI & ML
runware>=0.4.33
tiktoken>=0.7.0
anthropic>=0.40.0

# Data Processing
pandas>=2.1.0
numpy>=1.26.0
openpyxl>=3.1.0

# Visualization
matplotlib>=3.8.0
seaborn>=0.13.0
plotly>=5.18.0

# Document Processing
pypdf>=4.0.0
Pillow>=10.0.0

# Scheduling & Time
APScheduler>=3.10.0
tzlocal>=5.2

# Testing
pytest>=8.0.0
pytest-asyncio>=0.23.0
pytest-cov>=4.1.0
pytest-mock>=3.12.0

# Code Quality
ruff>=0.3.0

# Monitoring & Logging (Optional)
# sentry-sdk>=1.40.0  # Uncomment for error monitoring
# python-json-logger>=2.0.0  # Uncomment for structured logging
@@ -7,38 +7,70 @@ import asyncio
from typing import Optional, Dict, List, Any, Callable

from src.config.config import MODEL_OPTIONS, PDF_ALLOWED_MODELS, DEFAULT_MODEL
from src.config.pricing import MODEL_PRICING, calculate_cost, format_cost
from src.utils.image_utils import ImageGenerator
from src.utils.web_utils import google_custom_search, scrape_web_content
from src.utils.pdf_utils import process_pdf, send_response
from src.utils.openai_utils import prepare_file_from_path
from src.utils.token_counter import token_counter
from src.utils.code_interpreter import delete_all_user_files

# Model pricing per 1M tokens (in USD)
MODEL_PRICING = {
    "openai/gpt-4o": {"input": 5.00, "output": 20.00},
    "openai/gpt-4o-mini": {"input": 0.60, "output": 2.40},
    "openai/gpt-4.1": {"input": 2.00, "output": 8.00},
    "openai/gpt-4.1-mini": {"input": 0.40, "output": 1.60},
    "openai/gpt-4.1-nano": {"input": 0.10, "output": 0.40},
    "openai/gpt-5": {"input": 1.25, "output": 10.00},
    "openai/gpt-5-mini": {"input": 0.25, "output": 2.00},
    "openai/gpt-5-nano": {"input": 0.05, "output": 0.40},
    "openai/gpt-5-chat": {"input": 1.25, "output": 10.00},
    "openai/o1-preview": {"input": 15.00, "output": 60.00},
    "openai/o1-mini": {"input": 1.10, "output": 4.40},
    "openai/o1": {"input": 15.00, "output": 60.00},
    "openai/o3-mini": {"input": 1.10, "output": 4.40},
    "openai/o3": {"input": 2.00, "output": 8.00},
    "openai/o4-mini": {"input": 2.00, "output": 8.00}
}
from src.utils.discord_utils import create_info_embed, create_error_embed, create_success_embed
from src.utils.claude_utils import is_claude_model, call_claude_api

# Dictionary to keep track of user requests and their cooldowns
user_requests = {}
user_requests: Dict[int, Dict[str, Any]] = {}
# Dictionary to store user tasks
user_tasks = {}
user_tasks: Dict[int, List] = {}

def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator: ImageGenerator):

# ============================================================
# Autocomplete Functions
# ============================================================

async def model_autocomplete(
    interaction: discord.Interaction,
    current: str,
) -> List[app_commands.Choice[str]]:
    """
    Autocomplete function for model selection.
    Provides filtered model suggestions based on user input.
    """
    # Filter models based on current input
    matches = [
        model for model in MODEL_OPTIONS
        if current.lower() in model.lower()
    ]

    # If no matches, show all models
    if not matches:
        matches = MODEL_OPTIONS

    # Return up to 25 choices (Discord limit)
    return [
        app_commands.Choice(name=model, value=model)
        for model in matches[:25]
    ]


async def image_model_autocomplete(
    interaction: discord.Interaction,
    current: str,
) -> List[app_commands.Choice[str]]:
    """
    Autocomplete function for image generation model selection.
    """
    image_models = ["flux", "flux-dev", "sdxl", "realistic", "anime", "dreamshaper"]
    matches = [m for m in image_models if current.lower() in m.lower()]

    if not matches:
        matches = image_models

    return [
        app_commands.Choice(name=model, value=model)
        for model in matches[:25]
    ]

def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator: ImageGenerator, anthropic_client=None):
    """
    Set up all slash commands for the bot.

@@ -47,6 +79,7 @@ def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator
        db_handler: Database handler instance
        openai_client: OpenAI client instance
        image_generator: Image generator instance
        anthropic_client: Anthropic client instance (optional, for Claude models)
    """
    tree = bot.tree

@@ -112,7 +145,7 @@ def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator
    @tree.command(name="choose_model", description="Select the AI model to use for responses.")
    @check_blacklist()
    async def choose_model(interaction: discord.Interaction):
        """Lets users choose an AI model and saves it to the database."""
        """Lets users choose an AI model using a dropdown menu."""
        options = [discord.SelectOption(label=model, value=model) for model in MODEL_OPTIONS]
        select_menu = discord.ui.Select(placeholder="Choose a model", options=options)

@@ -131,6 +164,43 @@ def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator
        view.add_item(select_menu)
        await interaction.response.send_message("Choose a model:", view=view, ephemeral=True)

    @tree.command(name="set_model", description="Set AI model directly with autocomplete suggestions.")
    @app_commands.describe(model="The AI model to use (type to search)")
    @app_commands.autocomplete(model=model_autocomplete)
    @check_blacklist()
    async def set_model(interaction: discord.Interaction, model: str):
        """Sets the AI model directly using autocomplete."""
        user_id = interaction.user.id

        # Validate the model is in the allowed list
        if model not in MODEL_OPTIONS:
            # Find close matches for suggestions
            close_matches = [m for m in MODEL_OPTIONS if model.lower() in m.lower()]
            if close_matches:
                suggestions = ", ".join(f"`{m}`" for m in close_matches[:5])
                await interaction.response.send_message(
                    f"❌ Invalid model `{model}`. Did you mean: {suggestions}?",
                    ephemeral=True
                )
            else:
                await interaction.response.send_message(
                    f"❌ Invalid model `{model}`. Use `/choose_model` to see available options.",
                    ephemeral=True
                )
            return

        # Save the model selection
        await db_handler.save_user_model(user_id, model)

        # Get pricing info for the selected model
        pricing = MODEL_PRICING.get(model, {"input": 0, "output": 0})

        await interaction.response.send_message(
            f"✅ Model set to `{model}`\n"
            f"💰 Pricing: ${pricing['input']:.2f}/1M input, ${pricing['output']:.2f}/1M output",
            ephemeral=True
        )

    @tree.command(name="search", description="Search on Google and send results to AI model.")
    @app_commands.describe(query="The search query")
    @check_blacklist()

@@ -197,24 +267,46 @@ def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator
            f"(text: {input_token_count['text_tokens']}, images: {input_token_count['image_tokens']})"
        )

        # Send to the AI model
        api_params = {
            "model": model if model in ["openai/gpt-4o", "openai/gpt-4o-mini", "openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"] else "openai/gpt-4o",
            "messages": messages
        }

        # Add temperature only for models that support it (exclude GPT-5 family)
        if model not in ["openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"]:
            api_params["temperature"] = 0.5

        response = await openai_client.chat.completions.create(**api_params)
        # Check if using Claude model
        if is_claude_model(model):
            if anthropic_client is None:
                await interaction.followup.send(
                    "❌ Claude model not available. ANTHROPIC_API_KEY is not configured.",
                    ephemeral=True
                )
                return

            # Use Claude API
            claude_response = await call_claude_api(
                anthropic_client,
                messages,
                model,
                max_tokens=4096,
                use_tools=False
            )

            reply = claude_response.get("content", "")
            actual_input_tokens = claude_response.get("input_tokens", 0)
            actual_output_tokens = claude_response.get("output_tokens", 0)
        else:
            # Send to the OpenAI model
            api_params = {
                "model": model if model in ["openai/gpt-4o", "openai/gpt-4o-mini", "openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"] else "openai/gpt-4o",
                "messages": messages
            }

            # Add temperature only for models that support it (exclude GPT-5 family)
            if model not in ["openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"]:
                api_params["temperature"] = 0.5

            response = await openai_client.chat.completions.create(**api_params)

            reply = response.choices[0].message.content

            # Get actual token usage from API response
            usage = response.usage
            actual_input_tokens = usage.prompt_tokens if usage else input_token_count['total_tokens']
            actual_output_tokens = usage.completion_tokens if usage else token_counter.count_text_tokens(reply, model)
        reply = response.choices[0].message.content

        # Get actual token usage from API response
        usage = response.usage
        actual_input_tokens = usage.prompt_tokens if usage else input_token_count['total_tokens']
        actual_output_tokens = usage.completion_tokens if usage else token_counter.count_text_tokens(reply, model)

        # Calculate cost
        cost = token_counter.estimate_cost(actual_input_tokens, actual_output_tokens, model)

@@ -294,19 +386,38 @@ def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator
            {"role": "user", "content": f"Content from {url}:\n{content}"}
        ]

        api_params = {
            "model": model if model in ["openai/gpt-4o", "openai/gpt-4o-mini", "openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"] else "openai/gpt-4o",
            "messages": messages
        }

        # Add temperature and top_p only for models that support them (exclude GPT-5 family)
        if model not in ["openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"]:
            api_params["temperature"] = 0.3
            api_params["top_p"] = 0.7

        response = await openai_client.chat.completions.create(**api_params)
        # Check if using Claude model
        if is_claude_model(model):
            if anthropic_client is None:
                await interaction.followup.send(
                    "❌ Claude model not available. ANTHROPIC_API_KEY is not configured.",
                    ephemeral=True
                )
                return

            # Use Claude API
            claude_response = await call_claude_api(
                anthropic_client,
                messages,
                model,
                max_tokens=4096,
                use_tools=False
            )
            reply = claude_response.get("content", "")
        else:
            api_params = {
                "model": model if model in ["openai/gpt-4o", "openai/gpt-4o-mini", "openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"] else "openai/gpt-4o",
                "messages": messages
            }

            # Add temperature and top_p only for models that support them (exclude GPT-5 family)
            if model not in ["openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"]:
                api_params["temperature"] = 0.3
                api_params["top_p"] = 0.7

            response = await openai_client.chat.completions.create(**api_params)

            reply = response.choices[0].message.content
        reply = response.choices[0].message.content

        # Add the interaction to history
        history.append({"role": "user", "content": f"Scraped content from {url}"})

@@ -494,16 +605,22 @@ def setup_commands(bot: commands.Bot, db_handler, openai_client, image_generator
    async def help_command(interaction: discord.Interaction):
        """Sends a list of available commands to the user."""
        help_message = (
            "**Available commands:**\n"
            "/choose_model - Select which AI model to use for responses (openai/gpt-4o, openai/gpt-4o-mini, openai/gpt-5, openai/gpt-5-nano, openai/gpt-5-mini, openai/gpt-5-chat, openai/o1-preview, openai/o1-mini).\n"
            "/search `<query>` - Search Google and send results to the AI model.\n"
            "/web `<url>` - Scrape a webpage and send the data to the AI model.\n"
            "/generate `<prompt>` - Generate an image from a text prompt.\n"
            "/toggle_tools - Toggle display of tool execution details (code, input, output).\n"
            "/reset - Reset your chat history and token usage statistics.\n"
            "/user_stat - Get information about your token usage, costs, and current model.\n"
            "/prices - Display pricing information for all available AI models.\n"
            "/help - Display this help message.\n"
            "**🤖 Available Commands:**\n\n"
            "**Model Selection:**\n"
            "• `/choose_model` - Select AI model from a dropdown menu\n"
            "• `/set_model <model>` - Set model directly with autocomplete\n\n"
            "**Search & Web:**\n"
            "• `/search <query>` - Search Google and analyze results with AI\n"
            "• `/web <url>` - Scrape and analyze a webpage\n\n"
            "**Image Generation:**\n"
            "• `/generate <prompt>` - Generate images from text\n\n"
            "**Settings & Stats:**\n"
            "• `/toggle_tools` - Toggle tool execution details display\n"
            "• `/user_stat` - View your token usage and costs\n"
            "• `/prices` - Display model pricing information\n"
            "• `/reset` - Clear your chat history and statistics\n\n"
            "**Help:**\n"
            "• `/help` - Display this help message\n"
        )
        await interaction.response.send_message(help_message, ephemeral=True)
@@ -12,6 +12,7 @@ You have access to a powerful code interpreter environment that allows you to:
- Execute Python code in a secure, isolated environment
- Maximum execution time: 60 seconds
- Output limit: 100KB
- ⚠️ **IMPORTANT: Use print() to display results!** Only printed output is captured and shown to the user.

## 📦 **Package Management (Auto-Install)**
The code interpreter can AUTOMATICALLY install missing packages when needed!

@@ -43,18 +44,64 @@ import seaborn as sns  # Will auto-install if missing
import pandas as pd  # Will auto-install if missing

df = pd.DataFrame({'x': [1,2,3], 'y': [4,5,6]})
print(df)  # ⚠️ Use print() to show output!
sns.scatterplot(data=df, x='x', y='y')
plt.savefig('plot.png')
print("Chart saved!")  # Confirm completion
```

⚠️ **REMINDER: Only printed output is visible!** Always use print() for any data you want the user to see.

## 📁 **File Management (48-Hour Lifecycle)**

### **User-Uploaded Files**
- Users can upload files (CSV, Excel, JSON, images, etc.)
- Files are stored with unique `file_id`
- Access files using: `df = load_file('file_id_here')`
- Files expire after 48 hours automatically

### **CRITICAL: How to Load Files**

**Option 1: load_file() - Returns data directly (RECOMMENDED)**
```python
# For CSV files - returns DataFrame directly, DO NOT pass to pd.read_csv()!
# ⚠️ Use the ACTUAL file_id from the upload message, NOT this example!
df = load_file('<file_id_from_upload_message>')
print(df.head())  # Works immediately!
```

**Option 2: get_file_path() - Returns path for manual loading**
```python
# If you need the actual file path:
path = get_file_path('<file_id_from_upload_message>')
df = pd.read_csv(path)
```

### **COMMON MISTAKES TO AVOID**
```python
# ❌ WRONG - load_file() returns a DataFrame, NOT a path!
file_path = load_file('<file_id>')
df = pd.read_csv(file_path)  # ERROR: Cannot read DataFrame as CSV!

# ❌ WRONG - file_id is NOT a file path!
df = pd.read_csv('<file_id>')  # ERROR: File not found!

# ❌ WRONG - Using example IDs from documentation!
df = load_file('example_from_docs')  # ERROR: Use REAL file_id from upload!

# ✅ CORRECT - use load_file() with the ACTUAL file_id from upload message
df = load_file('<file_id_from_upload_message>')  # Copy exact ID from 📁 FILE UPLOADED
print(df.head())  # ⚠️ Use print() to show output!
print(df.describe())

# ✅ CORRECT - use get_file_path() if you need the path
path = get_file_path('<file_id_from_upload_message>')
df = pd.read_csv(path)
print(df.info())  # Always print results!
```

⚠️ CRITICAL: The file_id is shown in the conversation when a file is uploaded.
Look for: "📁 FILE UPLOADED" or "df = load_file('...')" in recent messages!

### **Generated Files**
- ANY file you create is captured and saved
- Supported types: images, CSVs, text, JSON, HTML, PDFs, etc. (80+ formats)

@@ -94,10 +141,14 @@ plt.savefig('plot.png')

**Load uploaded file:**
```python
# User uploaded 'sales_data.csv' with file_id: 'user_123_1234567890_abc123'
df = load_file('user_123_1234567890_abc123')
print(df.head())
print(f"Loaded {len(df)} rows")
# ⚠️ Find the ACTUAL file_id in the conversation's "📁 FILE UPLOADED" message!
# DO NOT copy this example - use the real file_id shown when the user uploaded!
df = load_file('<paste_actual_file_id_here>')

# ⚠️ CRITICAL: Always use print() to display results!
print(df.head())  # Show first rows
print(df.describe())  # Show statistics
print(f"Loaded {len(df)} rows, {len(df.columns)} columns")
```

**Create multiple output files:**
@@ -1,9 +1,34 @@
import os
import json
from pathlib import Path
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

# ==================== IMAGE CONFIGURATION ====================
# Load image configuration from JSON file
def load_image_config() -> dict:
    """Load image configuration from JSON file"""
    config_paths = [
        Path(__file__).parent.parent.parent / "config" / "image_config.json",
        Path(__file__).parent.parent / "config" / "image_config.json",
        Path("config/image_config.json"),
    ]

    for config_path in config_paths:
        if config_path.exists():
            try:
                with open(config_path, 'r') as f:
                    return json.load(f)
            except Exception as e:
                print(f"Warning: Error loading image config from {config_path}: {e}")

    return {}

# Load image config once at module import
_IMAGE_CONFIG = load_image_config()

# Bot statuses
STATUSES = [
    "Powered by openai/gpt-4o!",

@@ -76,9 +101,68 @@ MODEL_OPTIONS = [
    "openai/o1",
    "openai/o3-mini",
    "openai/o3",
    "openai/o4-mini"
    "openai/o4-mini",
    "anthropic/claude-sonnet-4-20250514",
    "anthropic/claude-opus-4-20250514",
    "anthropic/claude-3.5-sonnet",
    "anthropic/claude-3.5-haiku",
]

# ==================== IMAGE GENERATION MODELS ====================
# Models are loaded from config/image_config.json
# Edit that file to add/modify image models
IMAGE_MODELS = _IMAGE_CONFIG.get("image_models", {
    "flux": {
        "model_id": "runware:101@1",
        "name": "FLUX.1",
        "description": "High-quality image generation with FLUX",
        "default_width": 1024,
        "default_height": 1024,
        "max_width": 2048,
        "max_height": 2048,
        "supports_negative_prompt": True
    }
})

# Upscale models from config
UPSCALE_MODELS = _IMAGE_CONFIG.get("upscale_models", {
    "clarity": {
        "model_id": "runware:500@1",
        "name": "Clarity",
        "supported_factors": [2, 4]
    }
})

# Background removal models from config
BACKGROUND_REMOVAL_MODELS = _IMAGE_CONFIG.get("background_removal_models", {
    "bria": {
        "model_id": "runware:110@1",
        "name": "Bria RMBG 2.0"
    }
})

# Image settings from config
IMAGE_SETTINGS = _IMAGE_CONFIG.get("settings", {
    "default_model": "flux",
    "default_upscale_model": "clarity",
    "default_background_removal_model": "bria"
})

# Default image model
DEFAULT_IMAGE_MODEL = IMAGE_SETTINGS.get("default_model", "flux")

# Default negative prompts by category
DEFAULT_NEGATIVE_PROMPTS = _IMAGE_CONFIG.get("default_negative_prompts", {
    "general": "blurry, distorted, low quality, watermark, signature, text, bad anatomy, deformed"
})

# Aspect ratios from config
ASPECT_RATIOS = _IMAGE_CONFIG.get("aspect_ratios", {
    "1:1": {"width": 1024, "height": 1024},
    "16:9": {"width": 1344, "height": 768},
    "9:16": {"width": 768, "height": 1344}
})

# Model-specific token limits for automatic history management
MODEL_TOKEN_LIMITS = {
    "openai/o1-preview": 4000,  # Conservative limit (max 4000)

@@ -95,7 +179,12 @@ MODEL_TOKEN_LIMITS = {
    "openai/gpt-5": 4000,
    "openai/gpt-5-nano": 4000,
    "openai/gpt-5-mini": 4000,
    "openai/gpt-5-chat": 4000
    "openai/gpt-5-chat": 4000,
    # Claude models (200K context window, using conservative limits)
    "anthropic/claude-sonnet-4-20250514": 16000,
    "anthropic/claude-opus-4-20250514": 16000,
    "anthropic/claude-3.5-sonnet": 16000,
    "anthropic/claude-3.5-haiku": 16000,
}

# Default token limit for unknown models

@@ -104,7 +193,7 @@ DEFAULT_TOKEN_LIMIT = 4000

# Default model for new users
DEFAULT_MODEL = "openai/gpt-4.1"

PDF_ALLOWED_MODELS = ["openai/gpt-4o", "openai/gpt-4o-mini", "openai/gpt-4.1","openai/gpt-4.1-nano","openai/gpt-4.1-mini"]
PDF_ALLOWED_MODELS = ["openai/gpt-4o", "openai/gpt-4o-mini", "openai/gpt-4.1","openai/gpt-4.1-nano","openai/gpt-4.1-mini", "anthropic/claude-sonnet-4-20250514", "anthropic/claude-opus-4-20250514", "anthropic/claude-3.5-sonnet", "anthropic/claude-3.5-haiku"]
PDF_BATCH_SIZE = 3

# Prompt templates

@@ -112,101 +201,142 @@ WEB_SCRAPING_PROMPT = "Analyze webpage content and extract key information. Focu

NORMAL_CHAT_PROMPT = """You're ChatGPT for Discord. Be concise, helpful, safe. Reply in user's language. Use short paragraphs, bullets, minimal markdown.

Tools:
- google_search: real-time info, fact-checking, news
- scrape_webpage: extract/analyze webpage content
- execute_python_code: Python code execution with AUTO-INSTALL packages & file access
- image_suite: generate/edit/upscale/create portraits
- reminders: schedule/retrieve user reminders
- web_search_multi: parallel searches for comprehensive research
TOOLS:
1. google_search(query) - Web search for current info
2. scrape_webpage(url) - Extract webpage content
3. execute_python_code(code) - Run Python, packages auto-install. **FILE ACCESS: See critical instructions below!**
4. set_reminder(content, time) / get_reminders() - Manage reminders

🐍 Code Interpreter (execute_python_code):
⚠️ CRITICAL: Packages AUTO-INSTALL when imported! ALWAYS import what you need - installation is automatic.
═══════════════════════════════════════════════════════════════
⚠️ CRITICAL: FILE ACCESS IN CODE INTERPRETER
═══════════════════════════════════════════════════════════════

✅ Approved: pandas, numpy, matplotlib, seaborn, scikit-learn, tensorflow, pytorch, plotly, opencv, scipy, statsmodels, pillow, openpyxl, geopandas, folium, xgboost, lightgbm, bokeh, altair, and 80+ more.
When users upload files, you will see a message like:
📁 FILE UPLOADED - USE THIS FILE_ID:
Filename: data.csv
⚠️ TO ACCESS THIS FILE IN CODE, YOU MUST USE:
df = load_file('<THE_ACTUAL_FILE_ID_FROM_CONTEXT>')

📂 File Access: When users upload files, you'll receive the file_id in the conversation context (e.g., "File ID: abc123_xyz"). Use load_file('file_id') to access them. The function auto-detects file types:
- CSV/TSV → pandas DataFrame
- Excel (.xlsx, .xls) → pandas ExcelFile object (use .sheet_names and .parse('Sheet1'))
- JSON → dict or DataFrame
- Images → PIL Image object
- Text → string content
- And 200+ more formats...
**IMPORTANT: Copy the EXACT file_id from the file upload message - do NOT use examples!**

📊 Excel Files: load_file() returns ExcelFile object for multi-sheet support:
excel_file = load_file('file_id')
sheets = excel_file.sheet_names  # Get all sheet names
df = excel_file.parse('Sheet1')  # Read specific sheet
# Or: df = pd.read_excel(excel_file, sheet_name='Sheet1')
# Check if sheet has data: if not df.empty and len(df.columns) > 0
✅ CORRECT:
df = load_file('<file_id_from_upload_message>')
print(df.head())  # Use print() to show output!

⚠️ IMPORTANT: Always use print() to display results - code output is only captured via print()!

⚠️ IMPORTANT:
- If load_file() fails, error lists available file IDs - use the correct one
- Always check if DataFrames are empty before operations like .describe()
- Excel files may have empty sheets - skip or handle them gracefully
❌ WRONG - Using filename:
df = pd.read_csv('data.csv')  # FAILS - file not found!

❌ WRONG - Using example file_id from prompts:
df = load_file('example_id_from_docs')  # FAILS - use the REAL ID!

💾 Output Files: ALL generated files (CSV, images, JSON, text, plots, etc.) are AUTO-CAPTURED and sent to user. Files stored for 48h (configurable). Just create files - they're automatically shared!
⚠️ CRITICAL: Look for the 📁 FILE UPLOADED message in this conversation and copy the EXACT file_id shown there!

✅ DO:
- Import packages directly (auto-installs!)
- Use load_file('file_id') with the EXACT file_id from context
- Check if DataFrames are empty: if not df.empty and len(df.columns) > 0
- Handle errors gracefully (empty sheets, missing data, etc.)
- Create output files with descriptive names
- Generate visualizations (plt.savefig, etc.)
- Return multiple files (data + plots + reports)
═══════════════════════════════════════════════════════════════
IMAGE GENERATION & EDITING TOOLS
═══════════════════════════════════════════════════════════════

❌ DON'T:
- Check if packages are installed
- Use install_packages parameter
- Print large datasets (create CSV instead)
- Manually handle file paths
- Guess file_ids - use the exact ID from the upload message
5. generate_image(prompt, model, num_images, width, height, aspect_ratio, negative_prompt, steps, cfg_scale, seed)
   Create images from text descriptions.

   MODELS (use model parameter):
   • "flux" - FLUX.1 (default, best quality, 1024x1024)
   • "flux-dev" - FLUX.1 Dev (more creative outputs)
   • "sdxl" - Stable Diffusion XL (detailed, high-res)
   • "realistic" - Realistic Vision (photorealistic)
   • "anime" - Anime/illustration style
   • "dreamshaper" - Creative/artistic style

   ASPECT RATIOS (use aspect_ratio parameter):
   • "1:1" - Square (1024x1024)
   • "16:9" - Landscape wide (1344x768)
   • "9:16" - Portrait tall (768x1344)
   • "4:3" - Landscape (1152x896)
   • "3:4" - Portrait (896x1152)
   • "3:2" - Photo landscape (1248x832)
   • "2:3" - Photo portrait (832x1248)
   • "21:9" - Ultrawide (1536x640)

   Examples:
   generate_image("a dragon in a forest", "flux", 1)
   generate_image({"prompt": "sunset beach", "model": "realistic", "aspect_ratio": "16:9"})
   generate_image({"prompt": "anime girl", "model": "anime", "width": 768, "height": 1024})

Example:
```python
import pandas as pd
import seaborn as sns  # Auto-installs!
import matplotlib.pyplot as plt
6. generate_image_with_refiner(prompt, model, num_images)
   Generate high-quality images using SDXL with refiner for better details.
   Best for: detailed artwork, complex scenes
   Example: generate_image_with_refiner("detailed fantasy castle", "sdxl", 1)

# Load user's file (file_id from upload message: "File ID: 123456_abc")
data = load_file('123456_abc')  # Auto-detects type
7. upscale_image(image_url, scale_factor, model)
   Enlarge images to higher resolution.

   UPSCALE MODELS:
   • "clarity" - High-quality clarity upscaling (default)
   • "ccsr" - Content-consistent super-resolution
   • "sd-latent" - SD latent space upscaling
   • "swinir" - Fast SwinIR (supports 4x)

   SCALE FACTORS: 2 or 4 (depending on model)

   Requires: User must provide an image URL first
   Example: upscale_image("https://example.com/image.jpg", 2, "clarity")

# For Excel files:
if hasattr(data, 'sheet_names'):  # It's an ExcelFile
    for sheet in data.sheet_names:
        df = data.parse(sheet)
        if not df.empty and len(df.columns) > 0:
            # Process non-empty sheets
            summary = df.describe()
            summary.to_csv(f'{sheet}_summary.csv')
else:  # It's already a DataFrame (CSV, etc.)
    df = data
    summary = df.describe()
    summary.to_csv('summary_stats.csv')
8. remove_background(image_url, model) / edit_image(image_url, "remove_background")
   Remove background from images (outputs PNG with transparency).

   BACKGROUND REMOVAL MODELS:
   • "bria" - Bria RMBG 2.0 (default, high quality)
   • "rembg" - RemBG 1.4 (classic, supports alpha matting)
   • "birefnet-base" - BiRefNet base model
   • "birefnet-general" - BiRefNet general purpose
   • "birefnet-portrait" - BiRefNet optimized for portraits

   Requires: User must provide an image URL first
   Example: remove_background("https://example.com/photo.jpg", "bria")

# Create visualization
if not df.empty:
    sns.heatmap(df.corr(), annot=True)
    plt.savefig('correlation_plot.png')
9. photo_maker(prompt, input_images, style, strength, num_images)
|
||||
Generate images based on reference photos (identity preservation).
|
||||
|
||||
Parameters:
|
||||
• prompt: Text description of desired output
|
||||
• input_images: List of reference image URLs
|
||||
• style: Style to apply (default: "No style")
|
||||
• strength: Reference influence 0-100 (default: 40)
|
||||
|
||||
Requires: User must provide reference images first
|
||||
Example: photo_maker({"prompt": "professional headshot", "input_images": ["url1", "url2"], "style": "Photographic"})
|
||||
|
||||
# Everything is automatically sent to user!
|
||||
```
|
||||
10. image_to_text(image_url)
|
||||
Generate text description/caption from an image.
|
||||
Use for: Understanding image content, accessibility, OCR-like tasks
|
||||
Example: image_to_text("https://example.com/image.jpg")
|
||||
|
||||
Smart Usage:
|
||||
- Chain tools: search→scrape→analyze for deep research
|
||||
- Auto-suggest relevant tools based on user intent
|
||||
- Create multiple outputs (CSV, plots, reports) in one execution
|
||||
- Use execute_python_code for ALL data analysis (replaces old analyze_data_file tool)
|
||||
11. enhance_prompt(prompt, num_versions, max_length)
|
||||
Improve prompts for better image generation results.
|
||||
Returns multiple enhanced versions of your prompt.
|
||||
Example: enhance_prompt("cat on roof", 3, 200)
|
||||
|
||||
Rules:
|
||||
- One clarifying question if ambiguous
|
||||
- Prioritize answers over details
|
||||
- Cite sources: (Title – URL)
|
||||
- Use execute_python_code for complex math & data analysis
|
||||
- Never invent sources
|
||||
- Code fences for equations (no LaTeX)
|
||||
- Return image URLs with brief descriptions"""
|
||||
═══════════════════════════════════════════════════════════════
|
||||
USAGE GUIDELINES
|
||||
═══════════════════════════════════════════════════════════════
|
||||
|
||||
WHEN TO USE EACH TOOL:
|
||||
• "create/draw/generate/make an image of X" → generate_image
|
||||
• "high quality/detailed image" → generate_image_with_refiner
|
||||
• "remove/delete background" → remove_background (pass 'latest_image')
|
||||
• "make image bigger/larger/upscale" → upscale_image (pass 'latest_image')
|
||||
• "create image like this/based on this photo" → photo_maker (pass ['latest_image'])
|
||||
• "what's in this image/describe image" → image_to_text (pass 'latest_image')
|
||||
• "improve this prompt" → enhance_prompt
|
||||
|
||||
IMPORTANT NOTES:
|
||||
• For image tools (upscale, remove_background, photo_maker, image_to_text), when user uploads an image, pass 'latest_image' as the image_url parameter - the system automatically uses their most recent uploaded image
|
||||
• You don't need to extract or copy image URLs - just use 'latest_image'
|
||||
• Default model is "flux" - best for general use
|
||||
• Use "realistic" for photos, "anime" for illustrations
|
||||
• For math/data analysis → use execute_python_code instead
|
||||
• Always cite sources (Title–URL) when searching web"""
|
||||
|
||||
SEARCH_PROMPT = "Research Assistant with Google Search access. Synthesize search results into accurate answers. Prioritize credible sources, compare perspectives, acknowledge limitations, cite sources. Structure responses logically."
|
||||
|
||||
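To make the auto-detection described in the code interpreter section concrete, here is a minimal sketch of the kind of dispatch a load_file() helper can perform. The real helper is injected into the sandbox by the bot and is not shown in this diff; the FILE_REGISTRY mapping and the name load_file_sketch are hypothetical stand-ins.

```python
import json
import os

import pandas as pd
from PIL import Image

# Hypothetical: maps a file_id to its stored path (the real lookup is internal).
FILE_REGISTRY: dict = {}

def load_file_sketch(file_id: str):
    """Resolve a file_id, then pick a loader based on the file extension."""
    path = FILE_REGISTRY[file_id]
    ext = os.path.splitext(path)[1].lower()
    if ext in ('.csv', '.tsv'):
        return pd.read_csv(path, sep='\t' if ext == '.tsv' else ',')
    if ext in ('.xlsx', '.xls'):
        return pd.ExcelFile(path)  # caller uses .sheet_names / .parse()
    if ext == '.json':
        with open(path) as f:
            return json.load(f)
    if ext in ('.png', '.jpg', '.jpeg', '.gif', '.webp'):
        return Image.open(path)
    with open(path, encoding='utf-8', errors='replace') as f:
        return f.read()  # fall back to plain text content
```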
@@ -282,6 +412,7 @@ RUNWARE_API_KEY = os.getenv("RUNWARE_API_KEY")
 MONGODB_URI = os.getenv("MONGODB_URI")
 ADMIN_ID = os.getenv("ADMIN_ID")  # Add ADMIN_ID if you're using it
 TIMEZONE = os.getenv("TIMEZONE", "UTC")  # Default to UTC if not specified
+ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")  # Anthropic API key for Claude models

 # File management settings
 FILE_EXPIRATION_HOURS = int(os.getenv("FILE_EXPIRATION_HOURS", "48"))  # Hours until files expire (-1 for never)
@@ -295,5 +426,7 @@ if not MONGODB_URI:
     print("WARNING: MONGODB_URI not found in .env file")
 if not RUNWARE_API_KEY:
     print("WARNING: RUNWARE_API_KEY not found in .env file")
+if not ANTHROPIC_API_KEY:
+    print("INFO: ANTHROPIC_API_KEY not found in .env file - Claude models will not be available")
 if ENABLE_WEBHOOK_LOGGING and not LOGGING_WEBHOOK_URL:
     print("WARNING: Webhook logging enabled but LOGGING_WEBHOOK_URL not found in .env file")
108 src/config/pricing.py Normal file
@@ -0,0 +1,108 @@
"""
Centralized pricing configuration for OpenAI and Anthropic models.

This module provides a single source of truth for model pricing,
eliminating duplication across the codebase.
"""

from typing import Dict, Optional
from dataclasses import dataclass


@dataclass
class ModelPricing:
    """Pricing information for a model (per 1M tokens in USD)."""
    input: float
    output: float

    def calculate_cost(self, input_tokens: int, output_tokens: int) -> float:
        """Calculate total cost for given token counts."""
        input_cost = (input_tokens / 1_000_000) * self.input
        output_cost = (output_tokens / 1_000_000) * self.output
        return input_cost + output_cost


# Model pricing per 1M tokens (in USD)
# Centralized location - update prices here only
MODEL_PRICING: Dict[str, ModelPricing] = {
    # GPT-4o Family
    "openai/gpt-4o": ModelPricing(input=5.00, output=20.00),
    "openai/gpt-4o-mini": ModelPricing(input=0.60, output=2.40),

    # GPT-4.1 Family
    "openai/gpt-4.1": ModelPricing(input=2.00, output=8.00),
    "openai/gpt-4.1-mini": ModelPricing(input=0.40, output=1.60),
    "openai/gpt-4.1-nano": ModelPricing(input=0.10, output=0.40),

    # GPT-5 Family
    "openai/gpt-5": ModelPricing(input=1.25, output=10.00),
    "openai/gpt-5-mini": ModelPricing(input=0.25, output=2.00),
    "openai/gpt-5-nano": ModelPricing(input=0.05, output=0.40),
    "openai/gpt-5-chat": ModelPricing(input=1.25, output=10.00),

    # o1 Family (Reasoning models)
    "openai/o1-preview": ModelPricing(input=15.00, output=60.00),
    "openai/o1-mini": ModelPricing(input=1.10, output=4.40),
    "openai/o1": ModelPricing(input=15.00, output=60.00),

    # o3 Family
    "openai/o3-mini": ModelPricing(input=1.10, output=4.40),
    "openai/o3": ModelPricing(input=2.00, output=8.00),

    # o4 Family
    "openai/o4-mini": ModelPricing(input=2.00, output=8.00),

    # Claude 4 Family (Anthropic - latest models)
    "anthropic/claude-sonnet-4-20250514": ModelPricing(input=3.00, output=15.00),
    "anthropic/claude-opus-4-20250514": ModelPricing(input=15.00, output=75.00),

    # Claude 3.5 Family (Anthropic)
    "anthropic/claude-3.5-sonnet": ModelPricing(input=3.00, output=15.00),
    "anthropic/claude-3.5-haiku": ModelPricing(input=0.80, output=4.00),
}


def get_model_pricing(model: str) -> Optional[ModelPricing]:
    """
    Get pricing for a specific model.

    Args:
        model: The model name (e.g., "openai/gpt-4o")

    Returns:
        ModelPricing object or None if model not found
    """
    return MODEL_PRICING.get(model)


def calculate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
    """
    Calculate the cost for a given model and token counts.

    Args:
        model: The model name
        input_tokens: Number of input tokens
        output_tokens: Number of output tokens

    Returns:
        Total cost in USD, or 0.0 if model not found
    """
    pricing = get_model_pricing(model)
    if pricing:
        return pricing.calculate_cost(input_tokens, output_tokens)
    return 0.0


def get_all_models() -> list:
    """Get list of all available models with pricing."""
    return list(MODEL_PRICING.keys())


def format_cost(cost: float) -> str:
    """Format cost for display."""
    if cost < 0.01:
        return f"${cost:.6f}"
    elif cost < 1.00:
        return f"${cost:.4f}"
    else:
        return f"${cost:.2f}"
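For orientation, the helpers above compose as follows. This is a minimal usage sketch against the API defined in this new file; the token counts are invented for the example.

```python
from src.config.pricing import calculate_cost, format_cost, get_model_pricing

# Direct lookup: ModelPricing exposes calculate_cost for exact token counts.
pricing = get_model_pricing("openai/gpt-4o-mini")
if pricing:
    # (10_000 / 1M) * $0.60 + (2_000 / 1M) * $2.40 = $0.0060 + $0.0048
    print(format_cost(pricing.calculate_cost(10_000, 2_000)))  # $0.0108

# Convenience wrapper: unknown models fall back to a cost of 0.0.
print(format_cost(calculate_cost("openai/not-a-model", 1_000, 1_000)))  # $0.000000
```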
@@ -5,21 +5,43 @@ import asyncio
 from datetime import datetime, timedelta
 import logging
 import re
 import os

+# Configure DNS resolver to be more resilient
+try:
+    import dns.resolver
+    dns.resolver.default_resolver = dns.resolver.Resolver(configure=False)
+    dns.resolver.default_resolver.nameservers = ['8.8.8.8', '8.8.4.4', '1.1.1.1']
+    dns.resolver.default_resolver.lifetime = 15.0  # 15 second timeout for DNS
+except ImportError:
+    logging.warning("dnspython not installed, using system DNS resolver")
+except Exception as e:
+    logging.warning(f"Could not configure custom DNS resolver: {e}")

 class DatabaseHandler:
-    def __init__(self, mongodb_uri: str):
-        """Initialize database connection with optimized settings"""
-        # Set up a memory-optimized connection pool
+    def __init__(self, mongodb_uri: str, max_retries: int = 5):
+        """Initialize database connection with optimized settings and retry logic"""
+        self.mongodb_uri = mongodb_uri
+        self.max_retries = max_retries
+        self._connected = False
+        self._connection_lock = asyncio.Lock()
+
+        # Set up a memory-optimized connection pool with better resilience
         self.client = AsyncIOMotorClient(
             mongodb_uri,
-            maxIdleTimeMS=30000,  # Reduced from 45000
-            connectTimeoutMS=8000,  # Reduced from 10000
-            serverSelectionTimeoutMS=12000,  # Reduced from 15000
-            waitQueueTimeoutMS=3000,  # Reduced from 5000
-            socketTimeoutMS=25000,  # Reduced from 30000
-            maxPoolSize=8,  # Limit connection pool size
-            minPoolSize=2,  # Maintain minimum connections
-            retryWrites=True
+            maxIdleTimeMS=45000,  # Keep connections alive longer
+            connectTimeoutMS=20000,  # 20s connect timeout for DNS issues
+            serverSelectionTimeoutMS=30000,  # 30s for server selection
+            waitQueueTimeoutMS=10000,  # Wait longer for available connection
+            socketTimeoutMS=45000,  # Socket operations timeout
+            maxPoolSize=10,  # Slightly larger pool
+            minPoolSize=1,  # Keep at least 1 connection
+            retryWrites=True,
+            retryReads=True,  # Also retry reads
+            directConnection=False,  # Allow replica set discovery
+            appName="ChatGPT-Discord-Bot",
+            heartbeatFrequencyMS=30000,  # Reduce heartbeat frequency to avoid DNS issues
+            localThresholdMS=30,  # Local threshold for selecting servers
         )
         self.db = self.client['chatgpt_discord_bot']  # Database name

@@ -32,12 +54,86 @@ class DatabaseHandler:
         self.logs_collection = self.db.logs
         self.reminders_collection = self.db.reminders

-        logging.info("Database handler initialized")
+        logging.info("Database handler initialized with enhanced connection resilience")
+
+    async def _retry_operation(self, operation, *args, **kwargs):
+        """Execute a database operation with retry logic for transient errors"""
+        last_error = None
+        for attempt in range(self.max_retries):
+            try:
+                return await operation(*args, **kwargs)
+            except Exception as e:
+                last_error = e
+                error_str = str(e).lower()
+                # Check for transient/retryable errors (expanded list)
+                retryable_errors = [
+                    'timeout', 'connection', 'socket', 'dns', 'try again',
+                    'network', 'errno -3', 'gaierror', 'nodename', 'servname',
+                    'temporary failure', 'name resolution', 'unreachable',
+                    'reset by peer', 'broken pipe', 'not connected'
+                ]
+                if any(err in error_str for err in retryable_errors):
+                    wait_time = min((attempt + 1) * 2, 10)  # Linear backoff: 2s, 4s, 6s, 8s, 10s (max)
+                    logging.warning(f"Database operation failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s...")
+                    await asyncio.sleep(wait_time)
+                else:
+                    # Non-retryable error, raise immediately
+                    raise
+        # All retries exhausted
+        logging.error(f"Database operation failed after {self.max_retries} attempts: {last_error}")
+        raise last_error
+
+    async def ensure_connected(self) -> bool:
+        """Ensure database connection is established with retry logic"""
+        async with self._connection_lock:
+            if self._connected:
+                return True
+
+            for attempt in range(self.max_retries):
+                try:
+                    await self.client.admin.command('ping')
+                    self._connected = True
+                    logging.info("Database connection established successfully")
+                    return True
+                except Exception as e:
+                    wait_time = min((attempt + 1) * 2, 10)
+                    logging.warning(f"Database connection attempt {attempt + 1}/{self.max_retries} failed: {e}. Retrying in {wait_time}s...")
+                    await asyncio.sleep(wait_time)
+
+            logging.error("Failed to establish database connection after all retries")
+            return False
+
+    async def check_connection(self) -> bool:
+        """Check if database connection is alive with graceful error handling"""
+        try:
+            # Use a short timeout for the ping operation
+            await asyncio.wait_for(
+                self.client.admin.command('ping'),
+                timeout=10.0
+            )
+            self._connected = True
+            return True
+        except asyncio.TimeoutError:
+            logging.warning("Database ping timed out")
+            self._connected = False
+            return False
+        except Exception as e:
+            error_str = str(e).lower()
+            # Don't log DNS resolution failures as errors (they're often transient)
+            if any(err in error_str for err in ['errno -3', 'try again', 'dns', 'gaierror']):
+                logging.debug(f"Transient database connection check failed (DNS): {e}")
+            else:
+                logging.error(f"Database connection check failed: {e}")
+            self._connected = False
+            return False

     # User history methods
     async def get_history(self, user_id: int) -> List[Dict[str, Any]]:
         """Get user conversation history and filter expired image links"""
-        user_data = await self.db.user_histories.find_one({'user_id': user_id})
+        async def _get():
+            return await self.db.user_histories.find_one({'user_id': user_id})
+
+        user_data = await self._retry_operation(_get)
         if user_data and 'history' in user_data:
             # Filter out expired image links
             filtered_history = self._filter_expired_images(user_data['history'])
@@ -64,7 +160,12 @@ class DatabaseHandler:
         return []

     def _filter_expired_images(self, history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
-        """Filter out image links that are older than 23 hours"""
+        """
+        Filter out image links that are older than 23 hours.
+
+        Properly handles timezone-aware and timezone-naive datetime comparisons
+        to prevent issues with ISO string parsing.
+        """
         current_time = datetime.now()
         expiration_time = current_time - timedelta(hours=23)

@@ -87,11 +188,27 @@ class DatabaseHandler:
             # Check image items for timestamp
             elif item.get('type') == 'image_url':
                 # If there's no timestamp or timestamp is newer than expiration time, keep it
-                timestamp = item.get('timestamp')
-                if not timestamp or datetime.fromisoformat(timestamp) > expiration_time:
+                timestamp_str = item.get('timestamp')
+                if not timestamp_str:
+                    # No timestamp, keep the image
                     filtered_content.append(item)
                 else:
-                    logging.info(f"Filtering out expired image URL (added at {timestamp})")
+                    try:
+                        # Parse the ISO timestamp, handling both timezone-aware and naive
+                        timestamp = datetime.fromisoformat(timestamp_str.replace('Z', '+00:00'))
+
+                        # Make comparison timezone-naive for consistency
+                        if timestamp.tzinfo is not None:
+                            timestamp = timestamp.replace(tzinfo=None)
+
+                        if timestamp > expiration_time:
+                            filtered_content.append(item)
+                        else:
+                            logging.debug(f"Filtering out expired image URL (added at {timestamp_str})")
+                    except (ValueError, AttributeError) as e:
+                        # If we can't parse the timestamp, keep the image to be safe
+                        logging.warning(f"Could not parse image timestamp '{timestamp_str}': {e}")
+                        filtered_content.append(item)

             # Update the message with filtered content
             if filtered_content:
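The wrap-and-retry pattern introduced in get_history generalizes to any collection call. A minimal sketch of a write routed through _retry_operation follows; this save_history variant is illustrative and not part of the diff.

```python
# Illustrative only: a write wrapped the same way get_history wraps its read,
# so transient DNS/socket errors get the linear backoff from _retry_operation.
async def save_history(self, user_id: int, history: list) -> None:
    async def _save():
        await self.db.user_histories.update_one(
            {'user_id': user_id},
            {'$set': {'history': history}},
            upsert=True,
        )

    await self._retry_operation(_save)
```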
@@ -5,7 +5,7 @@ import logging
 import time
 import functools
 import concurrent.futures
-from typing import Dict, Any, List
+from typing import Dict, Any, List, Optional
 import io
 import aiohttp
 import os
@@ -15,36 +15,21 @@ import base64
 import traceback
 from datetime import datetime, timedelta
 from src.utils.openai_utils import process_tool_calls, prepare_messages_for_api, get_tools_for_model
+from src.utils.claude_utils import is_claude_model, call_claude_api, convert_claude_tool_calls_to_openai
 from src.utils.pdf_utils import process_pdf, send_response
 from src.utils.code_utils import extract_code_blocks
 from src.utils.reminder_utils import ReminderManager
 from src.config.config import PDF_ALLOWED_MODELS, MODEL_TOKEN_LIMITS, DEFAULT_TOKEN_LIMIT, DEFAULT_MODEL
+from src.config.pricing import MODEL_PRICING, calculate_cost, format_cost
 from src.utils.validators import validate_message_content, validate_prompt, sanitize_for_logging
 from src.utils.discord_utils import send_long_message, create_error_embed, create_progress_embed

 # Global task and rate limiting tracking
-user_tasks = {}
-user_last_request = {}
+user_tasks: Dict[int, Dict] = {}
+user_last_request: Dict[int, List[float]] = {}
 RATE_LIMIT_WINDOW = 5  # seconds
 MAX_REQUESTS = 3  # max requests per window

-# Model pricing per 1M tokens (in USD)
-MODEL_PRICING = {
-    "openai/gpt-4o": {"input": 5.00, "output": 20.00},
-    "openai/gpt-4o-mini": {"input": 0.60, "output": 2.40},
-    "openai/gpt-4.1": {"input": 2.00, "output": 8.00},
-    "openai/gpt-4.1-mini": {"input": 0.40, "output": 1.60},
-    "openai/gpt-4.1-nano": {"input": 0.10, "output": 0.40},
-    "openai/gpt-5": {"input": 1.25, "output": 10.00},
-    "openai/gpt-5-mini": {"input": 0.25, "output": 2.00},
-    "openai/gpt-5-nano": {"input": 0.05, "output": 0.40},
-    "openai/gpt-5-chat": {"input": 1.25, "output": 10.00},
-    "openai/o1-preview": {"input": 15.00, "output": 60.00},
-    "openai/o1-mini": {"input": 1.10, "output": 4.40},
-    "openai/o1": {"input": 15.00, "output": 60.00},
-    "openai/o3-mini": {"input": 1.10, "output": 4.40},
-    "openai/o3": {"input": 2.00, "output": 8.00},
-    "openai/o4-mini": {"input": 2.00, "output": 8.00}
-}

 # File extensions that should be treated as text files
 TEXT_FILE_EXTENSIONS = [
     '.txt', '.md', '.csv', '.json', '.xml', '.html', '.htm', '.css',
@@ -111,7 +96,7 @@ except ImportError as e:
     logging.warning(f"Data analysis libraries not available: {str(e)}")

 class MessageHandler:
-    def __init__(self, bot, db_handler, openai_client, image_generator):
+    def __init__(self, bot, db_handler, openai_client, image_generator, anthropic_client=None):
         """
         Initialize the message handler.

@@ -120,10 +105,12 @@ class MessageHandler:
             db_handler: Database handler instance
             openai_client: OpenAI client instance
             image_generator: Image generator instance
+            anthropic_client: Anthropic client instance (optional, for Claude models)
         """
         self.bot = bot
         self.db = db_handler
         self.client = openai_client
+        self.anthropic_client = anthropic_client
         self.image_generator = image_generator
         self.aiohttp_session = None

@@ -135,6 +122,9 @@ class MessageHandler:
         self.user_charts = {}  # Will be cleaned up periodically
         self.max_user_files = 20  # Limit concurrent user files

+        # Store latest image URL per user (in-memory, refreshed from attachments)
+        self.user_latest_image_url = {}
+
         # Tool mapping for API integration
         self.tool_mapping = {
             "google_search": self._google_search,
@@ -142,6 +132,7 @@ class MessageHandler:
             "execute_python_code": self._execute_python_code,
             "generate_image": self._generate_image,
             "edit_image": self._edit_image,
+            "remove_background": self._remove_background,
             "set_reminder": self._set_reminder,
             "get_reminders": self._get_reminders,
             "enhance_prompt": self._enhance_prompt,
@@ -184,6 +175,26 @@ class MessageHandler:
             logging.warning(f"Failed to initialize tiktoken encoder: {e}")
             self.token_encoder = None

+    def _build_claude_tool_result_message(self, tool_call_id: str, content: str) -> Dict[str, Any]:
+        """
+        Build a tool result message for Claude API.
+
+        Args:
+            tool_call_id: The ID of the tool call this result is for
+            content: The result content from the tool execution
+
+        Returns:
+            Dict: A properly formatted Claude tool result message
+        """
+        return {
+            "role": "user",
+            "content": [{
+                "type": "tool_result",
+                "tool_use_id": tool_call_id,
+                "content": content
+            }]
+        }
+
     def _find_user_id_from_current_task(self):
         """
         Utility method to find user_id from the current asyncio task.
@@ -200,6 +211,28 @@ class MessageHandler:
                 return user_id
         return None

+    async def _get_latest_image_url_from_db(self, user_id: int) -> Optional[str]:
+        """Get the latest valid image URL from user's history in database"""
+        try:
+            # Get history from database (already filtered for expired images)
+            history = await self.db.get_history(user_id)
+
+            # Find the latest image URL by iterating in reverse
+            for msg in reversed(history):
+                content = msg.get('content')
+                if isinstance(content, list):
+                    for item in reversed(content):
+                        if item.get('type') == 'image_url':
+                            image_url_data = item.get('image_url', {})
+                            url = image_url_data.get('url') if isinstance(image_url_data, dict) else None
+                            if url:
+                                logging.info(f"Found latest image URL from database: {url[:80]}...")
+                                return url
+            return None
+        except Exception as e:
+            logging.error(f"Error getting latest image URL from database: {e}")
+            return None
+
     def _count_tokens_with_tiktoken(self, text: str) -> int:
         """Count tokens using tiktoken encoder for internal operations."""
         if self.token_encoder is None:
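For a concrete picture of the message _build_claude_tool_result_message emits, here is an illustrative round trip. The tool_use id and tool output are made-up values, and handler stands for a MessageHandler instance; the shape matches Anthropic's tool_result content blocks.

```python
msg = handler._build_claude_tool_result_message(
    "toolu_01A2B3C4",  # hypothetical id copied from a Claude tool_use block
    '{"image_urls": ["https://example.com/out.png"]}'
)
# msg == {
#     "role": "user",
#     "content": [{
#         "type": "tool_result",
#         "tool_use_id": "toolu_01A2B3C4",
#         "content": '{"image_urls": ["https://example.com/out.png"]}'
#     }]
# }
```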
@@ -1212,11 +1245,20 @@ print("\\n=== Correlation Analysis ===")
         user_message = message.content.strip() if message.content else ""

         file_context = (
-            f"\n\n[User uploaded file: {filename}]\n"
-            f"[File ID: {file_id}]\n"
-            f"[File Type: {file_type}]\n"
-            f"[Size: {size_str}]\n"
-            f"[Available in code_interpreter via: load_file('{file_id}')]\n"
+            f"\n\n══════════════════════════════════════════════\n"
+            f"📁 FILE UPLOADED - USE THIS FILE_ID:\n"
+            f"══════════════════════════════════════════════\n"
+            f"Filename: {filename}\n"
+            f"File Type: {file_type}\n"
+            f"Size: {size_str}\n"
+            f"\n"
+            f"⚠️ TO ACCESS THIS FILE IN CODE, YOU MUST USE:\n"
+            f"    df = load_file('{file_id}')\n"
+            f"\n"
+            f"❌ DO NOT use the filename directly (e.g., pd.read_csv('{filename}'))\n"
+            f"❌ DO NOT use file_id as a path (e.g., pd.read_csv('{file_id}'))\n"
+            f"✅ ONLY use: load_file('{file_id}')\n"
+            f"══════════════════════════════════════════════\n"
         )

         if user_message:
@@ -1366,13 +1408,16 @@ print("\\n=== Correlation Analysis ===")
                 content.append({"type": "text", "text": f"[Error processing {attachment.filename}: {str(e)}]"})

             elif any(attachment.filename.endswith(ext) for ext in ['.png', '.jpg', '.jpeg', '.gif', '.webp']):
+                # Store latest image URL for this user
+                self.user_latest_image_url[user_id] = attachment.url
+                logging.info(f"Stored latest image URL for user {user_id}")
+
                 content.append({
                     "type": "image_url",
                     "image_url": {
                         "url": attachment.url,
                         "detail": "high"
-                    }
+                    },
+                    "timestamp": datetime.now().isoformat()  # Add timestamp to track image expiration
                 })
             else:
                 content.append({"type": "text", "text": f"[Attachment: {attachment.filename}] - I can't process this type of file directly."})
@@ -1492,7 +1537,14 @@ print("\\n=== Correlation Analysis ===")

         # Determine which models should have tools available
         # openai/o1-mini and openai/o1-preview do not support tools
-        use_tools = model in ["openai/gpt-4o", "openai/gpt-4o-mini", "openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat", "openai/o1", "openai/o3-mini", "openai/gpt-4.1", "openai/gpt-4.1-mini", "openai/gpt-4.1-nano", "openai/o3", "openai/o4-mini"]
+        # Claude models support tools
+        use_tools = model in [
+            "openai/gpt-4o", "openai/gpt-4o-mini", "openai/gpt-5", "openai/gpt-5-nano",
+            "openai/gpt-5-mini", "openai/gpt-5-chat", "openai/o1", "openai/o3-mini",
+            "openai/gpt-4.1", "openai/gpt-4.1-mini", "openai/gpt-4.1-nano", "openai/o3", "openai/o4-mini",
+            "anthropic/claude-sonnet-4-20250514", "anthropic/claude-opus-4-20250514",
+            "anthropic/claude-3.5-sonnet", "anthropic/claude-3.5-haiku"
+        ]

         # Count tokens being sent to API
         total_content_length = 0
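One design note on the allow-list above: the same model names also appear in the models_with_history list later in this diff, so a follow-up could hoist them into a single module-level constant. A sketch of that possible refactor (TOOL_CAPABLE_MODELS does not exist in this diff):

```python
# Hypothetical refactor: one shared frozenset, so a new model is added in one place.
TOOL_CAPABLE_MODELS = frozenset({
    "openai/gpt-4o", "openai/gpt-4o-mini",
    "openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat",
    "openai/o1", "openai/o3-mini", "openai/o3", "openai/o4-mini",
    "openai/gpt-4.1", "openai/gpt-4.1-mini", "openai/gpt-4.1-nano",
    "anthropic/claude-sonnet-4-20250514", "anthropic/claude-opus-4-20250514",
    "anthropic/claude-3.5-sonnet", "anthropic/claude-3.5-haiku",
})

use_tools = model in TOOL_CAPABLE_MODELS
```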
@@ -1513,181 +1565,310 @@ print("\\n=== Correlation Analysis ===")
             logging.info(f"API Request Debug - Model: {model}, Messages: {len(messages_for_api)}, "
                          f"Est. tokens: {estimated_tokens}, Content length: {total_content_length} chars")

-            # Prepare API call parameters
-            api_params = {
-                "model": model,
-                "messages": messages_for_api,
-                "timeout": 240  # Increased timeout for better response handling
-            }
-
-            # Add temperature and top_p only for models that support them (exclude GPT-5 family)
-            if model in ["openai/gpt-4o", "openai/gpt-4o-mini"]:
-                api_params["temperature"] = 0.3
-                api_params["top_p"] = 0.7
-            elif model not in ["openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"]:
-                # For other models (not GPT-4o family and not GPT-5 family)
-                api_params["temperature"] = 1
-                api_params["top_p"] = 1
-
-            # Add tools if using a supported model
-            if use_tools:
-                tools = get_tools_for_model()
-                api_params["tools"] = tools

             # Initialize variables to track tool responses
             image_generation_used = False
             chart_id = None
             image_urls = []  # Will store unique image URLs

-            # Make the initial API call
-            try:
-                response = await self.client.chat.completions.create(**api_params)
-            except Exception as e:
-                # Handle 413 Request Entity Too Large error with a user-friendly message
-                if "413" in str(e) or "tokens_limit_reached" in str(e) or "Request body too large" in str(e):
-                    await message.channel.send(
-                        f"❌ **Request too large for {model}**\n"
-                        f"Your conversation history or message is too large for this model.\n"
-                        f"Try:\n"
-                        f"• Using `/reset` to start fresh\n"
-                        f"• Using a model with higher token limits\n"
-                        f"• Reducing the size of your current message\n"
-                        f"• Breaking up large files into smaller pieces"
-                    )
-                    return
-                else:
-                    # Re-raise other errors
-                    raise e
-
-            # Extract token usage and calculate cost
-            input_tokens = 0
-            output_tokens = 0
-            total_cost = 0.0
-
-            if hasattr(response, 'usage') and response.usage:
-                input_tokens = getattr(response.usage, 'prompt_tokens', 0)
-                output_tokens = getattr(response.usage, 'completion_tokens', 0)
-
-                # Calculate cost based on model pricing
-                if model in MODEL_PRICING:
-                    pricing = MODEL_PRICING[model]
-                    input_cost = (input_tokens / 1_000_000) * pricing["input"]
-                    output_cost = (output_tokens / 1_000_000) * pricing["output"]
-                    total_cost = input_cost + output_cost
-
-            logging.info(f"API call - Model: {model}, Input tokens: {input_tokens}, Output tokens: {output_tokens}, Cost: ${total_cost:.6f}")
-
-            # Save token usage and cost to database
-            await self.db.save_token_usage(user_id, model, input_tokens, output_tokens, total_cost)
-
-            # Process tool calls if any
-            updated_messages = None
-            if use_tools and response.choices[0].finish_reason == "tool_calls":
-                # Process tools
-                tool_calls = response.choices[0].message.tool_calls
-                tool_messages = {}
-
-                # Track which tools are being called
-                for tool_call in tool_calls:
-                    if tool_call.function.name in self.tool_mapping:
-                        tool_messages[tool_call.function.name] = True
-                        if tool_call.function.name == "generate_image":
-                            image_generation_used = True
-                        elif tool_call.function.name == "edit_image":
-                            # Display appropriate message for image editing
-                            await message.channel.send("🖌️ Editing image...")
-
-                # Display appropriate messages based on which tools are being called
-                if tool_messages.get("google_search") or tool_messages.get("scrape_webpage"):
-                    await message.channel.send("🔍 Researching information...")
-
-                if tool_messages.get("execute_python_code") or tool_messages.get("analyze_data_file"):
-                    await message.channel.send("💻 Running code...")
-
-                if tool_messages.get("generate_image"):
-                    await message.channel.send("🎨 Generating images...")
-
-                if tool_messages.get("set_reminder") or tool_messages.get("get_reminders"):
-                    await message.channel.send("📅 Processing reminders...")
-
-                if not tool_messages:
-                    await message.channel.send("🤔 Processing...")
-
-                # Process any tool calls and get the updated messages
-                tool_calls_processed, updated_messages = await process_tool_calls(
-                    self.client,
-                    response,
-                    messages_for_api,
-                    self.tool_mapping
-                )
-
-                # Process tool responses to extract important data (images, charts)
-                if updated_messages:
-                    # Look for image generation and code interpreter tool responses
-                    for msg in updated_messages:
-                        if msg.get('role') == 'tool' and msg.get('name') == 'generate_image':
-                            try:
-                                tool_result = json.loads(msg.get('content', '{}'))
-                                if tool_result.get('image_urls'):
-                                    image_urls.extend(tool_result['image_urls'])
-                            except:
-                                pass
-
-                        elif msg.get('role') == 'tool' and msg.get('name') == 'edit_image':
-                            try:
-                                tool_result = json.loads(msg.get('content', '{}'))
-                                if tool_result.get('image_url'):
-                                    image_urls.append(tool_result['image_url'])
-                            except:
-                                pass
-
-                        elif msg.get('role') == 'tool' and msg.get('name') in ['execute_python_code', 'analyze_data_file']:
-                            try:
-                                tool_result = json.loads(msg.get('content', '{}'))
-                                if tool_result.get('chart_id'):
-                                    chart_id = tool_result['chart_id']
-                            except:
-                                pass
-
-                # If tool calls were processed, make another API call with the updated messages
-                if tool_calls_processed and updated_messages:
-                    # Prepare API parameters for follow-up call
-                    follow_up_params = {
-                        "model": model,
-                        "messages": updated_messages,
-                        "timeout": 240
-                    }
-
-                    # Add temperature only for models that support it (exclude GPT-5 family)
-                    if model in ["openai/gpt-4o", "openai/gpt-4o-mini"]:
-                        follow_up_params["temperature"] = 0.3
-                    elif model not in ["openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"]:
-                        follow_up_params["temperature"] = 1
-
-                    response = await self.client.chat.completions.create(**follow_up_params)
-
-                    # Extract token usage and calculate cost for follow-up call
-                    if hasattr(response, 'usage') and response.usage:
-                        follow_up_input_tokens = getattr(response.usage, 'prompt_tokens', 0)
-                        follow_up_output_tokens = getattr(response.usage, 'completion_tokens', 0)
-
-                        input_tokens += follow_up_input_tokens
-                        output_tokens += follow_up_output_tokens
-
-                        # Calculate additional cost
-                        if model in MODEL_PRICING:
-                            pricing = MODEL_PRICING[model]
-                            additional_input_cost = (follow_up_input_tokens / 1_000_000) * pricing["input"]
-                            additional_output_cost = (follow_up_output_tokens / 1_000_000) * pricing["output"]
-                            additional_cost = additional_input_cost + additional_output_cost
-
-                            logging.info(f"Follow-up API call - Model: {model}, Input tokens: {follow_up_input_tokens}, Output tokens: {follow_up_output_tokens}, Additional cost: ${additional_cost:.6f}")
-
-                            # Save additional token usage and cost to database
-                            await self.db.save_token_usage(user_id, model, follow_up_input_tokens, follow_up_output_tokens, additional_cost)
-
-            reply = response.choices[0].message.content
+            # Check if this is a Claude model
+            if is_claude_model(model):
+                # Use Claude API
+                if self.anthropic_client is None:
+                    await message.channel.send(
+                        f"❌ **Claude model not available**\n"
+                        f"The Anthropic API key is not configured. Please set ANTHROPIC_API_KEY in your .env file."
+                    )
+                    return
+                try:
+                    claude_response = await call_claude_api(
+                        self.anthropic_client,
+                        messages_for_api,
+                        model,
+                        max_tokens=4096,
+                        use_tools=use_tools
+                    )
+
+                    # Extract token usage and calculate cost for Claude
+                    input_tokens = claude_response.get("input_tokens", 0)
+                    output_tokens = claude_response.get("output_tokens", 0)
+                    total_cost = 0.0
+
+                    # Calculate cost based on model pricing
+                    pricing = MODEL_PRICING.get(model)
+                    if pricing:
+                        total_cost = pricing.calculate_cost(input_tokens, output_tokens)
+                    logging.info(f"Claude API call - Model: {model}, Input tokens: {input_tokens}, Output tokens: {output_tokens}, Cost: {format_cost(total_cost)}")
+                    await self.db.save_token_usage(user_id, model, input_tokens, output_tokens, total_cost)
+
+                    # Process tool calls if any
+                    updated_messages = None
+                    if use_tools and claude_response.get("tool_calls"):
+                        tool_calls = convert_claude_tool_calls_to_openai(claude_response["tool_calls"])
+                        tool_messages = {}
+
+                        # Track which tools are being called
+                        for tool_call in tool_calls:
+                            if tool_call.function.name in self.tool_mapping:
+                                tool_messages[tool_call.function.name] = True
+                                if tool_call.function.name == "generate_image":
+                                    image_generation_used = True
+                                elif tool_call.function.name == "edit_image":
+                                    await message.channel.send("🖌️ Editing image...")
+
+                        # Display appropriate messages
+                        if tool_messages.get("google_search") or tool_messages.get("scrape_webpage"):
+                            await message.channel.send("🔍 Researching information...")
+                        if tool_messages.get("execute_python_code") or tool_messages.get("analyze_data_file"):
+                            await message.channel.send("💻 Running code...")
+                        if tool_messages.get("generate_image"):
+                            await message.channel.send("🎨 Generating images...")
+                        if tool_messages.get("set_reminder") or tool_messages.get("get_reminders"):
+                            await message.channel.send("📅 Processing reminders...")
+                        if not tool_messages:
+                            await message.channel.send("🤔 Processing...")
+
+                        # Process tool calls manually for Claude
+                        tool_results = []
+                        for tool_call in tool_calls:
+                            function_name = tool_call.function.name
+                            if function_name in self.tool_mapping:
+                                try:
+                                    function_args = json.loads(tool_call.function.arguments)
+                                    function_response = await self.tool_mapping[function_name](function_args)
+                                    tool_results.append({
+                                        "tool_call_id": tool_call.id,
+                                        "content": str(function_response)
+                                    })
+
+                                    # Extract image URLs if generated
+                                    if function_name == "generate_image":
+                                        try:
+                                            tool_result = json.loads(function_response) if isinstance(function_response, str) else function_response
+                                            if tool_result.get('image_urls'):
+                                                image_urls.extend(tool_result['image_urls'])
+                                        except:
+                                            pass
+                                except Exception as e:
+                                    logging.error(f"Error executing {function_name}: {e}")
+                                    tool_results.append({
+                                        "tool_call_id": tool_call.id,
+                                        "content": f"Error: {str(e)}"
+                                    })
+
+                        # Build updated messages with tool results for follow-up call
+                        updated_messages = messages_for_api.copy()
+                        updated_messages.append({
+                            "role": "assistant",
+                            "content": claude_response.get("content", "")
+                        })
+                        for result in tool_results:
+                            updated_messages.append(
+                                self._build_claude_tool_result_message(result["tool_call_id"], result["content"])
+                            )
+
+                        # Make follow-up call
+                        follow_up_response = await call_claude_api(
+                            self.anthropic_client,
+                            updated_messages,
+                            model,
+                            max_tokens=4096,
+                            use_tools=False  # Don't need tools for follow-up
+                        )
+
+                        # Update token usage
+                        follow_up_input = follow_up_response.get("input_tokens", 0)
+                        follow_up_output = follow_up_response.get("output_tokens", 0)
+                        input_tokens += follow_up_input
+                        output_tokens += follow_up_output
+
+                        if pricing:
+                            additional_cost = pricing.calculate_cost(follow_up_input, follow_up_output)
+                            total_cost += additional_cost
+                            await self.db.save_token_usage(user_id, model, follow_up_input, follow_up_output, additional_cost)
+
+                        reply = follow_up_response.get("content", "")
+                    else:
+                        reply = claude_response.get("content", "")
+
+                except Exception as e:
+                    error_str = str(e)
+                    if "overloaded" in error_str.lower():
+                        await message.channel.send(
+                            f"⚠️ **Claude is currently overloaded**\n"
+                            f"Please try again in a moment or switch to an OpenAI model."
+                        )
+                        return
+                    else:
+                        raise e
+            else:
+                # Use OpenAI API (existing logic)
+                # Prepare API call parameters
+                api_params = {
+                    "model": model,
+                    "messages": messages_for_api,
+                    "timeout": 240  # Increased timeout for better response handling
+                }
+
+                # Add temperature and top_p only for models that support them (exclude GPT-5 family)
+                if model in ["openai/gpt-4o", "openai/gpt-4o-mini"]:
+                    api_params["temperature"] = 0.3
+                    api_params["top_p"] = 0.7
+                elif model not in ["openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"]:
+                    # For other models (not GPT-4o family and not GPT-5 family)
+                    api_params["temperature"] = 1
+                    api_params["top_p"] = 1
+
+                # Add tools if using a supported model
+                if use_tools:
+                    tools = get_tools_for_model()
+                    api_params["tools"] = tools
+
+                # Make the initial API call
+                try:
+                    response = await self.client.chat.completions.create(**api_params)
+                except Exception as e:
+                    # Handle 413 Request Entity Too Large error with a user-friendly message
+                    if "413" in str(e) or "tokens_limit_reached" in str(e) or "Request body too large" in str(e):
+                        await message.channel.send(
+                            f"❌ **Request too large for {model}**\n"
+                            f"Your conversation history or message is too large for this model.\n"
+                            f"Try:\n"
+                            f"• Using `/reset` to start fresh\n"
+                            f"• Using a model with higher token limits\n"
+                            f"• Reducing the size of your current message\n"
+                            f"• Breaking up large files into smaller pieces"
+                        )
+                        return
+                    else:
+                        # Re-raise other errors
+                        raise e
+
+                # Extract token usage and calculate cost
+                input_tokens = 0
+                output_tokens = 0
+                total_cost = 0.0
+
+                if hasattr(response, 'usage') and response.usage:
+                    input_tokens = getattr(response.usage, 'prompt_tokens', 0)
+                    output_tokens = getattr(response.usage, 'completion_tokens', 0)
+
+                    # Calculate cost based on model pricing
+                    pricing = MODEL_PRICING.get(model)
+                    if pricing:
+                        total_cost = pricing.calculate_cost(input_tokens, output_tokens)
+
+                logging.info(f"API call - Model: {model}, Input tokens: {input_tokens}, Output tokens: {output_tokens}, Cost: {format_cost(total_cost)}")
+
+                # Save token usage and cost to database
+                await self.db.save_token_usage(user_id, model, input_tokens, output_tokens, total_cost)
+
+                # Process tool calls if any (OpenAI)
+                updated_messages = None
+                if use_tools and response.choices[0].finish_reason == "tool_calls":
+                    # Process tools
+                    tool_calls = response.choices[0].message.tool_calls
+                    tool_messages = {}
+
+                    # Track which tools are being called
+                    for tool_call in tool_calls:
+                        if tool_call.function.name in self.tool_mapping:
+                            tool_messages[tool_call.function.name] = True
+                            if tool_call.function.name == "generate_image":
+                                image_generation_used = True
+                            elif tool_call.function.name == "edit_image":
+                                # Display appropriate message for image editing
+                                await message.channel.send("🖌️ Editing image...")
+
+                    # Display appropriate messages based on which tools are being called
+                    if tool_messages.get("google_search") or tool_messages.get("scrape_webpage"):
+                        await message.channel.send("🔍 Researching information...")
+
+                    if tool_messages.get("execute_python_code") or tool_messages.get("analyze_data_file"):
+                        await message.channel.send("💻 Running code...")
+
+                    if tool_messages.get("generate_image"):
+                        await message.channel.send("🎨 Generating images...")
+
+                    if tool_messages.get("set_reminder") or tool_messages.get("get_reminders"):
+                        await message.channel.send("📅 Processing reminders...")
+
+                    if not tool_messages:
+                        await message.channel.send("🤔 Processing...")
+
+                    # Process any tool calls and get the updated messages
+                    tool_calls_processed, updated_messages = await process_tool_calls(
+                        self.client,
+                        response,
+                        messages_for_api,
+                        self.tool_mapping
+                    )
+
+                    # Process tool responses to extract important data (images, charts)
+                    if updated_messages:
+                        # Look for image generation and code interpreter tool responses
+                        for msg in updated_messages:
+                            if msg.get('role') == 'tool' and msg.get('name') == 'generate_image':
+                                try:
+                                    tool_result = json.loads(msg.get('content', '{}'))
+                                    if tool_result.get('image_urls'):
+                                        image_urls.extend(tool_result['image_urls'])
+                                except:
+                                    pass
+
+                            elif msg.get('role') == 'tool' and msg.get('name') == 'edit_image':
+                                try:
+                                    tool_result = json.loads(msg.get('content', '{}'))
+                                    if tool_result.get('image_url'):
+                                        image_urls.append(tool_result['image_url'])
+                                except:
+                                    pass
+
+                            elif msg.get('role') == 'tool' and msg.get('name') in ['execute_python_code', 'analyze_data_file']:
+                                try:
+                                    tool_result = json.loads(msg.get('content', '{}'))
+                                    if tool_result.get('chart_id'):
+                                        chart_id = tool_result['chart_id']
+                                except:
+                                    pass
+
+                    # If tool calls were processed, make another API call with the updated messages
+                    if tool_calls_processed and updated_messages:
+                        # Prepare API parameters for follow-up call
+                        follow_up_params = {
+                            "model": model,
+                            "messages": updated_messages,
+                            "timeout": 240
+                        }
+
+                        # Add temperature only for models that support it (exclude GPT-5 family)
+                        if model in ["openai/gpt-4o", "openai/gpt-4o-mini"]:
+                            follow_up_params["temperature"] = 0.3
+                        elif model not in ["openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat"]:
+                            follow_up_params["temperature"] = 1
+
+                        response = await self.client.chat.completions.create(**follow_up_params)
+
+                        # Extract token usage and calculate cost for follow-up call
+                        if hasattr(response, 'usage') and response.usage:
+                            follow_up_input_tokens = getattr(response.usage, 'prompt_tokens', 0)
+                            follow_up_output_tokens = getattr(response.usage, 'completion_tokens', 0)
+
+                            input_tokens += follow_up_input_tokens
+                            output_tokens += follow_up_output_tokens
+
+                            # Calculate additional cost
+                            pricing = MODEL_PRICING.get(model)
+                            if pricing:
+                                additional_cost = pricing.calculate_cost(follow_up_input_tokens, follow_up_output_tokens)
+                                total_cost += additional_cost
+
+                            logging.info(f"Follow-up API call - Model: {model}, Input tokens: {follow_up_input_tokens}, Output tokens: {follow_up_output_tokens}, Additional cost: {format_cost(additional_cost)}")
+
+                            # Save additional token usage and cost to database
+                            await self.db.save_token_usage(user_id, model, follow_up_input_tokens, follow_up_output_tokens, additional_cost)
+
+                reply = response.choices[0].message.content

             # Add image URLs to assistant content if any were found
             has_images = len(image_urls) > 0
@@ -1706,7 +1887,15 @@ print("\\n=== Correlation Analysis ===")
             })

             # Store the response in history for models that support it
-            if model in ["openai/gpt-4o", "openai/gpt-4o-mini", "openai/gpt-5", "openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-5-chat", "openai/o1", "openai/o1-mini", "openai/o3-mini", "openai/gpt-4.1", "openai/gpt-4.1-nano", "openai/gpt-4.1-mini", "openai/o3", "openai/o4-mini", "openai/o1-preview"]:
+            models_with_history = [
+                "openai/gpt-4o", "openai/gpt-4o-mini", "openai/gpt-5", "openai/gpt-5-nano",
+                "openai/gpt-5-mini", "openai/gpt-5-chat", "openai/o1", "openai/o1-mini",
+                "openai/o3-mini", "openai/gpt-4.1", "openai/gpt-4.1-nano", "openai/gpt-4.1-mini",
+                "openai/o3", "openai/o4-mini", "openai/o1-preview",
+                "anthropic/claude-sonnet-4-20250514", "anthropic/claude-opus-4-20250514",
+                "anthropic/claude-3.5-sonnet", "anthropic/claude-3.5-haiku"
+            ]
+            if model in models_with_history:
                 if model in ["openai/o1-mini", "openai/o1-preview"]:
                     # For models without system prompt support, keep track separately
                     if has_images:
@@ -1762,7 +1951,7 @@ print("\\n=== Correlation Analysis ===")

             # Log processing time and cost for performance monitoring
             processing_time = time.time() - start_time
-            logging.info(f"Message processed in {processing_time:.2f} seconds (User: {user_id}, Model: {model}, Cost: ${total_cost:.6f})")
+            logging.info(f"Message processed in {processing_time:.2f} seconds (User: {user_id}, Model: {model}, Cost: {format_cost(total_cost)})")

         except asyncio.CancelledError:
             # Handle cancellation cleanly
@@ -2087,6 +2276,25 @@ print("\\n=== Correlation Analysis ===")
     async def _image_to_text(self, args: Dict[str, Any]):
         """Convert image to text"""
         try:
+            # Check if model passed "latest_image" - use stored URL
+            image_url = args.get("image_url", "")
+            if image_url == "latest_image" or not image_url:
+                user_id = self._find_user_id_from_current_task()
+                if user_id:
+                    # Try in-memory first (from current session), then database
+                    if user_id in self.user_latest_image_url:
+                        args["image_url"] = self.user_latest_image_url[user_id]
+                        logging.info("Using in-memory image URL for image_to_text")
+                    else:
+                        db_url = await self._get_latest_image_url_from_db(user_id)
+                        if db_url:
+                            args["image_url"] = db_url
+                            logging.info("Using database image URL for image_to_text")
+                        else:
+                            return json.dumps({"error": "No image found. Please upload an image first."})
+                else:
+                    return json.dumps({"error": "No image found. Please upload an image first."})
+
             result = await self.image_generator.image_to_text(args)
             return result
         except Exception as e:
@@ -2096,15 +2304,82 @@ print("\\n=== Correlation Analysis ===")
     async def _upscale_image(self, args: Dict[str, Any]):
         """Upscale an image"""
         try:
+            # Check if model passed "latest_image" - use stored URL
+            image_url = args.get("image_url", "")
+            if image_url == "latest_image" or not image_url:
+                user_id = self._find_user_id_from_current_task()
+                if user_id:
+                    # Try in-memory first (from current session), then database
+                    if user_id in self.user_latest_image_url:
+                        args["image_url"] = self.user_latest_image_url[user_id]
+                        logging.info("Using in-memory image URL for upscale")
+                    else:
+                        db_url = await self._get_latest_image_url_from_db(user_id)
+                        if db_url:
+                            args["image_url"] = db_url
+                            logging.info("Using database image URL for upscale")
+                        else:
+                            return json.dumps({"error": "No image found. Please upload an image first."})
+                else:
+                    return json.dumps({"error": "No image found. Please upload an image first."})
+
             result = await self.image_generator.upscale_image(args)
             return result
         except Exception as e:
             logging.error(f"Error in image upscaling: {str(e)}")
             return json.dumps({"error": f"Image upscaling failed: {str(e)}"})

     async def _remove_background(self, args: Dict[str, Any]):
         """Remove background from an image"""
         try:
+            # Check if model passed "latest_image" - use stored URL
+            image_url = args.get("image_url", "")
+            if image_url == "latest_image" or not image_url:
+                user_id = self._find_user_id_from_current_task()
+                if user_id:
+                    # Try in-memory first (from current session), then database
+                    if user_id in self.user_latest_image_url:
+                        args["image_url"] = self.user_latest_image_url[user_id]
+                        logging.info("Using in-memory image URL for background removal")
+                    else:
+                        db_url = await self._get_latest_image_url_from_db(user_id)
+                        if db_url:
+                            args["image_url"] = db_url
+                            logging.info("Using database image URL for background removal")
+                        else:
+                            return json.dumps({"error": "No image found. Please upload an image first."})
+                else:
+                    return json.dumps({"error": "No image found. Please upload an image first."})
+
             result = await self.image_generator.remove_background(args)
             return result
         except Exception as e:
             logging.error(f"Error in background removal: {str(e)}")
             return json.dumps({"error": f"Background removal failed: {str(e)}"})

     async def _photo_maker(self, args: Dict[str, Any]):
         """Create a photo"""
         try:
+            # Check if model passed "latest_image" in input_images - use stored URL
+            input_images = args.get("input_images", [])
+            if input_images and "latest_image" in input_images:
+                user_id = self._find_user_id_from_current_task()
+                if user_id:
+                    # Try in-memory first (from current session), then database
+                    if user_id in self.user_latest_image_url:
+                        url = self.user_latest_image_url[user_id]
+                        args["input_images"] = [url if img == "latest_image" else img for img in input_images]
+                        logging.info("Using in-memory image URL for photo_maker")
+                    else:
+                        db_url = await self._get_latest_image_url_from_db(user_id)
+                        if db_url:
+                            args["input_images"] = [db_url if img == "latest_image" else img for img in input_images]
+                            logging.info("Using database image URL for photo_maker")
+                        else:
+                            return json.dumps({"error": "No image found. Please upload an image first."})
+                else:
+                    return json.dumps({"error": "No image found. Please upload an image first."})
+
             result = await self.image_generator.photo_maker(args)
             return result
         except Exception as e:
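The 'latest_image' resolution block is now repeated nearly verbatim across _image_to_text, _upscale_image, _remove_background, and _photo_maker. A follow-up could factor it into one helper; a sketch under that assumption (_resolve_latest_image_url does not exist in this diff):

```python
async def _resolve_latest_image_url(self) -> Optional[str]:
    """Return the caller's most recent image URL: in-memory first, then database."""
    user_id = self._find_user_id_from_current_task()
    if not user_id:
        return None
    if user_id in self.user_latest_image_url:
        return self.user_latest_image_url[user_id]
    return await self._get_latest_image_url_from_db(user_id)

# Each tool wrapper would then collapse to roughly:
#     if args.get("image_url", "") in ("", "latest_image"):
#         url = await self._resolve_latest_image_url()
#         if not url:
#             return json.dumps({"error": "No image found. Please upload an image first."})
#         args["image_url"] = url
```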
358 src/utils/cache.py Normal file
@@ -0,0 +1,358 @@
"""
Simple caching utilities for API responses and frequently accessed data.

This module provides an in-memory LRU cache with optional TTL (time-to-live)
support, designed for caching API responses and reducing redundant calls.
"""

import asyncio
import time
import logging
from typing import Any, Dict, Optional, Callable, TypeVar, Generic
from collections import OrderedDict
from dataclasses import dataclass, field
from functools import wraps

logger = logging.getLogger(__name__)

T = TypeVar('T')


@dataclass
class CacheEntry(Generic[T]):
    """A single cache entry with value and expiration time."""
    value: T
    expires_at: float
    created_at: float = field(default_factory=time.time)
    hits: int = 0


class LRUCache(Generic[T]):
    """
    Thread-safe LRU (Least Recently Used) cache with TTL support.

    Features:
    - Configurable max size with automatic eviction
    - Per-entry TTL (time-to-live)
    - Automatic cleanup of expired entries
    - Hit/miss statistics tracking

    Usage:
        cache = LRUCache(max_size=1000, default_ttl=300)  # 5 min TTL
        await cache.set("key", "value")
        value = await cache.get("key")  # Returns value or None if expired
    """

    def __init__(
        self,
        max_size: int = 1000,
        default_ttl: float = 300.0,  # 5 minutes default
        cleanup_interval: float = 60.0
    ):
        """
        Initialize the LRU cache.

        Args:
            max_size: Maximum number of entries
            default_ttl: Default TTL in seconds
            cleanup_interval: How often to run cleanup (seconds)
        """
        self._cache: OrderedDict[str, CacheEntry[T]] = OrderedDict()
        self._max_size = max_size
        self._default_ttl = default_ttl
        self._cleanup_interval = cleanup_interval
        self._lock = asyncio.Lock()

        # Statistics
        self._hits = 0
        self._misses = 0

        # Background cleanup task
        self._cleanup_task: Optional[asyncio.Task] = None

    async def start(self) -> None:
        """Start the background cleanup task."""
        if self._cleanup_task is None:
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())
            logger.debug("Cache cleanup task started")

    async def stop(self) -> None:
        """Stop the background cleanup task."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None
            logger.debug("Cache cleanup task stopped")

    async def _cleanup_loop(self) -> None:
        """Background task to periodically clean up expired entries."""
        while True:
            await asyncio.sleep(self._cleanup_interval)
            await self._cleanup_expired()

    async def _cleanup_expired(self) -> int:
        """Remove expired entries. Returns count of removed entries."""
        now = time.time()
        removed = 0

        async with self._lock:
            keys_to_remove = [
                key for key, entry in self._cache.items()
                if entry.expires_at <= now
            ]

            for key in keys_to_remove:
                del self._cache[key]
                removed += 1

        if removed > 0:
            logger.debug(f"Cache cleanup: removed {removed} expired entries")

        return removed

    async def get(self, key: str) -> Optional[T]:
        """
        Get a value from the cache.

        Args:
            key: Cache key

        Returns:
            Cached value or None if not found/expired
        """
        async with self._lock:
            if key not in self._cache:
                self._misses += 1
                return None

            entry = self._cache[key]

            # Check if expired
            if entry.expires_at <= time.time():
                del self._cache[key]
                self._misses += 1
                return None

            # Move to end (most recently used)
            self._cache.move_to_end(key)
            entry.hits += 1
            self._hits += 1

            return entry.value

    async def set(
        self,
        key: str,
        value: T,
        ttl: Optional[float] = None
    ) -> None:
        """
        Set a value in the cache.

        Args:
            key: Cache key
            value: Value to cache
            ttl: Optional TTL override (uses default if not provided)
        """
        ttl = ttl if ttl is not None else self._default_ttl
        expires_at = time.time() + ttl

        async with self._lock:
            # Remove oldest entries if at capacity
            while len(self._cache) >= self._max_size:
                oldest_key = next(iter(self._cache))
                del self._cache[oldest_key]
                logger.debug(f"Cache evicted oldest entry: {oldest_key}")

            self._cache[key] = CacheEntry(
                value=value,
                expires_at=expires_at
            )
            self._cache.move_to_end(key)

    async def delete(self, key: str) -> bool:
        """
        Delete a key from the cache.

        Args:
            key: Cache key

        Returns:
            True if key was found and deleted
        """
        async with self._lock:
            if key in self._cache:
                del self._cache[key]
                return True
            return False

    async def clear(self) -> int:
        """
        Clear all entries from the cache.

        Returns:
            Number of entries cleared
        """
        async with self._lock:
            count = len(self._cache)
            self._cache.clear()
            return count

    async def has(self, key: str) -> bool:
        """Check if a key exists and is not expired."""
        return await self.get(key) is not None

    def stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dict with size, hits, misses, hit_rate
        """
        total = self._hits + self._misses
        hit_rate = (self._hits / total * 100) if total > 0 else 0.0

        return {
            "size": len(self._cache),
            "max_size": self._max_size,
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": f"{hit_rate:.2f}%",
            "default_ttl": self._default_ttl
        }
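
Taken together, the lifecycle is start, set/get under the lock, then stop on shutdown. A minimal sketch, assuming this file is importable as src.utils.cache:

import asyncio
from src.utils.cache import LRUCache

async def demo() -> None:
    cache: LRUCache[dict] = LRUCache(max_size=100, default_ttl=60.0)
    await cache.start()                     # begin periodic expiry sweeps
    await cache.set("models:list", {"count": 12})
    value = await cache.get("models:list")  # dict, or None once the TTL lapses
    print(value, cache.stats())             # stats() reports size/hits/misses/hit_rate
    await cache.stop()                      # cancel the cleanup task on shutdown

asyncio.run(demo())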

# Global cache instances for different purposes
_api_response_cache: Optional[LRUCache[Dict[str, Any]]] = None
_user_preference_cache: Optional[LRUCache[Dict[str, Any]]] = None


async def get_api_cache() -> LRUCache[Dict[str, Any]]:
    """Get or create the API response cache."""
    global _api_response_cache
    if _api_response_cache is None:
        _api_response_cache = LRUCache(
            max_size=500,
            default_ttl=300.0  # 5 minutes
        )
        await _api_response_cache.start()
    return _api_response_cache


async def get_user_cache() -> LRUCache[Dict[str, Any]]:
    """Get or create the user preference cache."""
    global _user_preference_cache
    if _user_preference_cache is None:
        _user_preference_cache = LRUCache(
            max_size=1000,
            default_ttl=600.0  # 10 minutes
        )
        await _user_preference_cache.start()
    return _user_preference_cache


def cached(
    cache_key_func: Callable[..., str],
    ttl: Optional[float] = None,
    cache_getter: Callable = get_api_cache
):
    """
    Decorator to cache async function results.

    Args:
        cache_key_func: Function to generate cache key from args
        ttl: Optional TTL override
        cache_getter: Function to get the cache instance

    Usage:
        @cached(
            cache_key_func=lambda user_id: f"user:{user_id}",
            ttl=300
        )
        async def get_user_data(user_id: int) -> dict:
            # Expensive operation
            return await fetch_from_api(user_id)
    """
    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            cache = await cache_getter()
            key = cache_key_func(*args, **kwargs)

            # Try to get from cache
            cached_value = await cache.get(key)
            if cached_value is not None:
                logger.debug(f"Cache hit for key: {key}")
                return cached_value

            # Execute function and cache result
            result = await func(*args, **kwargs)
            await cache.set(key, result, ttl=ttl)
            logger.debug(f"Cached result for key: {key}")

            return result

        return wrapper
    return decorator


def invalidate_on_update(
    cache_key_func: Callable[..., str],
    cache_getter: Callable = get_api_cache
):
    """
    Decorator to invalidate cache when a function (update operation) is called.

    Args:
        cache_key_func: Function to generate cache key to invalidate
        cache_getter: Function to get the cache instance

    Usage:
        @invalidate_on_update(
            cache_key_func=lambda user_id, **_: f"user:{user_id}"
        )
        async def update_user_data(user_id: int, data: dict) -> None:
            await save_to_db(user_id, data)
    """
    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            result = await func(*args, **kwargs)

            # Invalidate cache after update
            cache = await cache_getter()
            key = cache_key_func(*args, **kwargs)
            await cache.delete(key)
            logger.debug(f"Invalidated cache for key: {key}")

            return result

        return wrapper
    return decorator
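
The two decorators are designed to be paired on the read and write paths for the same key. A sketch with hypothetical fetch_from_db/save_to_db helpers, keyed the same way on both sides:

from src.utils.cache import cached, invalidate_on_update

@cached(cache_key_func=lambda user_id: f"user:{user_id}", ttl=300)
async def get_user_data(user_id: int) -> dict:
    return await fetch_from_db(user_id)    # hypothetical expensive read

@invalidate_on_update(cache_key_func=lambda user_id, data: f"user:{user_id}")
async def update_user_data(user_id: int, data: dict) -> None:
    await save_to_db(user_id, data)        # hypothetical write; evicts "user:<id>" afterwards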

# Convenience functions for common caching patterns

async def cache_user_model(user_id: int, model: str) -> None:
    """Cache user's selected model."""
    cache = await get_user_cache()
    await cache.set(f"user_model:{user_id}", {"model": model})


async def get_cached_user_model(user_id: int) -> Optional[str]:
    """Get user's cached model selection."""
    cache = await get_user_cache()
    result = await cache.get(f"user_model:{user_id}")
    return result["model"] if result else None


async def invalidate_user_cache(user_id: int) -> None:
    """Invalidate all cached data for a user."""
    cache = await get_user_cache()
    # Clear known user-related keys
    await cache.delete(f"user_model:{user_id}")
    await cache.delete(f"user_history:{user_id}")
    await cache.delete(f"user_stats:{user_id}")
531 src/utils/claude_utils.py Normal file
@@ -0,0 +1,531 @@
"""
Claude (Anthropic) API utility functions.

This module provides utilities for interacting with Anthropic's Claude models,
including message conversion and API calls compatible with the existing bot structure.
"""

import logging
import json
from typing import List, Dict, Any, Optional, Tuple


def is_claude_model(model: str) -> bool:
    """
    Check if the model is a Claude/Anthropic model.

    Args:
        model: Model name (e.g., "anthropic/claude-sonnet-4-20250514")

    Returns:
        True if it's a Claude model, False otherwise
    """
    return model.startswith("anthropic/")


def get_claude_model_id(model: str) -> str:
    """
    Extract the Claude model ID from the full model name.

    Args:
        model: Full model name (e.g., "anthropic/claude-sonnet-4-20250514")

    Returns:
        Claude model ID (e.g., "claude-sonnet-4-20250514")
    """
    if model.startswith("anthropic/"):
        return model[len("anthropic/"):]
    return model


def convert_openai_messages_to_claude(messages: List[Dict[str, Any]]) -> Tuple[Optional[str], List[Dict[str, Any]]]:
    """
    Convert OpenAI message format to Claude message format.

    OpenAI uses:
    - {"role": "system", "content": "..."}
    - {"role": "user", "content": "..."}
    - {"role": "assistant", "content": "..."}

    Claude uses:
    - system parameter (separate from messages)
    - {"role": "user", "content": "..."}
    - {"role": "assistant", "content": "..."}

    Args:
        messages: List of messages in OpenAI format

    Returns:
        Tuple of (system_prompt, claude_messages)
    """
    system_prompt = None
    claude_messages = []

    for msg in messages:
        role = msg.get("role")
        content = msg.get("content")

        # Skip messages with None content
        if content is None:
            continue

        if role == "system":
            # Claude uses a separate system parameter
            if isinstance(content, str):
                system_prompt = content
            elif isinstance(content, list):
                # Extract text from list content
                text_parts = []
                for item in content:
                    if isinstance(item, dict) and item.get("type") == "text":
                        text_parts.append(item.get("text", ""))
                    elif isinstance(item, str):
                        text_parts.append(item)
                system_prompt = " ".join(text_parts)
        elif role in ["user", "assistant"]:
            # Convert content format
            converted_content = convert_content_to_claude(content)
            if converted_content:
                claude_messages.append({
                    "role": role,
                    "content": converted_content
                })
        elif role == "tool":
            # Claude handles tool results differently - add as user message with tool result
            tool_call_id = msg.get("tool_call_id", "")
            tool_name = msg.get("name", "unknown")
            claude_messages.append({
                "role": "user",
                "content": [
                    {
                        "type": "tool_result",
                        "tool_use_id": tool_call_id,
                        "content": str(content)
                    }
                ]
            })

    # Claude requires alternating user/assistant messages
    # Merge consecutive messages of the same role
    merged_messages = merge_consecutive_messages(claude_messages)

    return system_prompt, merged_messages
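
A worked example of the conversion (values follow directly from the code above):

system, msgs = convert_openai_messages_to_claude([
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hi"},
    {"role": "user", "content": "Are you there?"},
])
# system == "You are terse."
# msgs == [{"role": "user", "content": "Hi\nAre you there?"}]  (consecutive turns merged)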

def convert_content_to_claude(content: Any) -> Any:
    """
    Convert content from OpenAI format to Claude format.

    Args:
        content: Content in OpenAI format (string or list)

    Returns:
        Content in Claude format
    """
    if isinstance(content, str):
        return content

    if isinstance(content, list):
        claude_content = []
        for item in content:
            if isinstance(item, dict):
                item_type = item.get("type")

                if item_type == "text":
                    claude_content.append({
                        "type": "text",
                        "text": item.get("text", "")
                    })
                elif item_type == "image_url":
                    # Convert image_url format to Claude format
                    image_url_data = item.get("image_url", {})
                    if isinstance(image_url_data, dict):
                        url = image_url_data.get("url", "")
                    else:
                        url = str(image_url_data)

                    if url:
                        # Claude requires base64 data or URLs
                        if url.startswith("data:"):
                            # Parse base64 data URL
                            try:
                                media_type, base64_data = parse_data_url(url)
                                claude_content.append({
                                    "type": "image",
                                    "source": {
                                        "type": "base64",
                                        "media_type": media_type,
                                        "data": base64_data
                                    }
                                })
                            except Exception as e:
                                logging.warning(f"Failed to parse data URL: {e}")
                        else:
                            # Regular URL - Claude supports URLs directly
                            claude_content.append({
                                "type": "image",
                                "source": {
                                    "type": "url",
                                    "url": url
                                }
                            })
                else:
                    # Handle other types as text
                    if "text" in item:
                        claude_content.append({
                            "type": "text",
                            "text": str(item.get("text", ""))
                        })
            elif isinstance(item, str):
                claude_content.append({
                    "type": "text",
                    "text": item
                })

        return claude_content if claude_content else None

    return str(content) if content else None


def parse_data_url(data_url: str) -> Tuple[str, str]:
    """
    Parse a data URL into media type and base64 data.

    Args:
        data_url: Data URL (e.g., "data:image/png;base64,...")

    Returns:
        Tuple of (media_type, base64_data)

    Raises:
        ValueError: If the data URL format is invalid
    """
    if not data_url.startswith("data:"):
        raise ValueError(f"Not a data URL: expected 'data:' prefix, got '{data_url[:20]}...'")

    # Remove "data:" prefix
    content = data_url[5:]

    # Split by semicolon and comma
    parts = content.split(";base64,")
    if len(parts) != 2:
        raise ValueError(f"Invalid data URL format: expected ';base64,' separator, got '{content[:50]}...'")

    media_type = parts[0]
    base64_data = parts[1]

    return media_type, base64_data
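
For example:

media_type, b64 = parse_data_url("data:image/png;base64,iVBORw0KGgo=")
# media_type == "image/png", b64 == "iVBORw0KGgo="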

def merge_consecutive_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Merge consecutive messages with the same role.
    Claude requires alternating user/assistant messages.

    Args:
        messages: List of messages

    Returns:
        List of merged messages
    """
    if not messages:
        return []

    merged = []
    current_role = None
    current_content = []

    for msg in messages:
        role = msg.get("role")
        content = msg.get("content")

        if role == current_role:
            # Same role, merge content
            if isinstance(content, str):
                if current_content and isinstance(current_content[-1], dict) and current_content[-1].get("type") == "text":
                    current_content[-1]["text"] += "\n" + content
                else:
                    current_content.append({"type": "text", "text": content})
            elif isinstance(content, list):
                current_content.extend(content)
        else:
            # Different role, save previous and start new
            if current_role is not None and current_content:
                merged.append({
                    "role": current_role,
                    "content": simplify_content(current_content)
                })

            current_role = role
            if isinstance(content, str):
                current_content = [{"type": "text", "text": content}]
            elif isinstance(content, list):
                current_content = content.copy()
            else:
                current_content = []

    # Don't forget the last message
    if current_role is not None and current_content:
        merged.append({
            "role": current_role,
            "content": simplify_content(current_content)
        })

    return merged


def simplify_content(content: List[Dict[str, Any]]) -> Any:
    """
    Simplify content list to string if it only contains text.

    Args:
        content: List of content items

    Returns:
        Simplified content (string or list)
    """
    if not content:
        return ""

    # If only one text item, return as string
    if len(content) == 1 and content[0].get("type") == "text":
        return content[0].get("text", "")

    # If all items are text, merge them
    if all(item.get("type") == "text" for item in content):
        texts = [item.get("text", "") for item in content]
        return "\n".join(texts)

    return content
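
Concretely:

simplify_content([{"type": "text", "text": "a"}, {"type": "text", "text": "b"}])
# -> "a\nb"
simplify_content([{"type": "text", "text": "a"},
                  {"type": "image", "source": {"type": "url", "url": "https://example.com/x.png"}}])
# -> returned unchanged, since not all items are text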

def get_claude_tools() -> List[Dict[str, Any]]:
    """
    Get tool definitions for Claude API.
    Claude uses a slightly different tool format than OpenAI.

    Returns:
        List of tool definitions in Claude format
    """
    return [
        {
            "name": "google_search",
            "description": "Search the web for current information",
            "input_schema": {
                "type": "object",
                "properties": {
                    "query": {"type": "string", "description": "The search query"},
                    "num_results": {"type": "integer", "description": "Number of results (max 10)", "maximum": 10}
                },
                "required": ["query"]
            }
        },
        {
            "name": "scrape_webpage",
            "description": "Extract and read content from a webpage URL",
            "input_schema": {
                "type": "object",
                "properties": {
                    "url": {"type": "string", "description": "The webpage URL to scrape"}
                },
                "required": ["url"]
            }
        },
        {
            "name": "execute_python_code",
            "description": "Run Python code. Packages auto-install. Use load_file('file_id') for user files. Output files auto-sent to user.",
            "input_schema": {
                "type": "object",
                "properties": {
                    "code": {"type": "string", "description": "Python code to execute"},
                    "timeout": {"type": "integer", "description": "Timeout in seconds", "maximum": 300}
                },
                "required": ["code"]
            }
        },
        {
            "name": "generate_image",
            "description": "Create/generate images from text. Models: flux (best), flux-dev, sdxl, realistic (photos), anime, dreamshaper.",
            "input_schema": {
                "type": "object",
                "properties": {
                    "prompt": {"type": "string", "description": "Detailed description of the image to create"},
                    "model": {"type": "string", "description": "Model to use", "enum": ["flux", "flux-dev", "sdxl", "realistic", "anime", "dreamshaper"]},
                    "num_images": {"type": "integer", "description": "Number of images (1-4)", "maximum": 4},
                    "aspect_ratio": {"type": "string", "description": "Aspect ratio preset", "enum": ["1:1", "16:9", "9:16", "4:3", "3:4", "3:2", "2:3", "21:9"]}
                },
                "required": ["prompt"]
            }
        },
        {
            "name": "set_reminder",
            "description": "Set a reminder",
            "input_schema": {
                "type": "object",
                "properties": {
                    "content": {"type": "string", "description": "Reminder content"},
                    "time": {"type": "string", "description": "Reminder time"}
                },
                "required": ["content", "time"]
            }
        },
        {
            "name": "get_reminders",
            "description": "List all reminders",
            "input_schema": {
                "type": "object",
                "properties": {}
            }
        },
        {
            "name": "upscale_image",
            "description": "Enlarge/upscale an image to higher resolution. Pass 'latest_image' to use the user's most recently uploaded image.",
            "input_schema": {
                "type": "object",
                "properties": {
                    "image_url": {"type": "string", "description": "Pass 'latest_image' to use the user's most recently uploaded image"},
                    "scale_factor": {"type": "integer", "description": "Scale factor (2 or 4)", "enum": [2, 4]},
                    "model": {"type": "string", "description": "Upscale model", "enum": ["clarity", "ccsr", "sd-latent", "swinir"]}
                },
                "required": ["image_url"]
            }
        },
        {
            "name": "remove_background",
            "description": "Remove background from an image. Pass 'latest_image' to use the user's most recently uploaded image.",
            "input_schema": {
                "type": "object",
                "properties": {
                    "image_url": {"type": "string", "description": "Pass 'latest_image' to use the user's most recently uploaded image"},
                    "model": {"type": "string", "description": "Background removal model", "enum": ["bria", "rembg", "birefnet-base", "birefnet-general", "birefnet-portrait"]}
                },
                "required": ["image_url"]
            }
        },
        {
            "name": "image_to_text",
            "description": "Generate a text description/caption of an image. Pass 'latest_image' to use the user's most recently uploaded image.",
            "input_schema": {
                "type": "object",
                "properties": {
                    "image_url": {"type": "string", "description": "Pass 'latest_image' to use the user's most recently uploaded image"}
                },
                "required": ["image_url"]
            }
        }
    ]


async def call_claude_api(
    anthropic_client,
    messages: List[Dict[str, Any]],
    model: str,
    max_tokens: int = 4096,
    use_tools: bool = True
) -> Dict[str, Any]:
    """
    Call the Claude API with the given messages.

    Args:
        anthropic_client: Anthropic client instance
        messages: List of messages in OpenAI format
        model: Model name (e.g., "anthropic/claude-sonnet-4-20250514")
        max_tokens: Maximum tokens in response
        use_tools: Whether to include tools

    Returns:
        Dict with response data including:
        - content: Response text
        - input_tokens: Number of input tokens
        - output_tokens: Number of output tokens
        - tool_calls: Any tool calls made
        - stop_reason: Why the response stopped
    """
    try:
        # Convert messages
        system_prompt, claude_messages = convert_openai_messages_to_claude(messages)

        # Get Claude model ID
        model_id = get_claude_model_id(model)

        # Build API parameters
        api_params = {
            "model": model_id,
            "max_tokens": max_tokens,
            "messages": claude_messages
        }

        if system_prompt:
            api_params["system"] = system_prompt

        if use_tools:
            api_params["tools"] = get_claude_tools()

        # Make API call
        response = await anthropic_client.messages.create(**api_params)

        # Extract response data
        result = {
            "content": "",
            "input_tokens": response.usage.input_tokens if response.usage else 0,
            "output_tokens": response.usage.output_tokens if response.usage else 0,
            "tool_calls": [],
            "stop_reason": response.stop_reason
        }

        # Process content blocks
        for block in response.content:
            if block.type == "text":
                result["content"] += block.text
            elif block.type == "tool_use":
                result["tool_calls"].append({
                    "id": block.id,
                    "type": "function",
                    "function": {
                        "name": block.name,
                        "arguments": json.dumps(block.input)
                    }
                })

        return result

    except Exception as e:
        logging.error(f"Error calling Claude API: {e}")
        raise
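
A sketch of a call site, assuming an anthropic.AsyncAnthropic client that picks up ANTHROPIC_API_KEY from the environment; the wrapper and model name are the ones defined above:

import anthropic

async def ask_claude() -> str:
    client = anthropic.AsyncAnthropic()  # reads ANTHROPIC_API_KEY from the environment
    result = await call_claude_api(
        client,
        messages=[
            {"role": "system", "content": "Be brief."},
            {"role": "user", "content": "Ping?"},
        ],
        model="anthropic/claude-sonnet-4-20250514",
        use_tools=False,  # skip the tool schemas for a plain completion
    )
    return result["content"]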

def convert_claude_tool_calls_to_openai(tool_calls: List[Dict[str, Any]]) -> List[Any]:
    """
    Convert Claude tool calls to OpenAI format for compatibility with existing code.

    Args:
        tool_calls: Tool calls from Claude API

    Returns:
        Tool calls in OpenAI format
    """
    from dataclasses import dataclass

    @dataclass
    class FunctionCall:
        name: str
        arguments: str

    @dataclass
    class ToolCall:
        id: str
        type: str
        function: FunctionCall

    result = []
    for tc in tool_calls:
        result.append(ToolCall(
            id=tc["id"],
            type=tc["type"],
            function=FunctionCall(
                name=tc["function"]["name"],
                arguments=tc["function"]["arguments"]
            )
        ))

    return result
@@ -71,19 +71,40 @@ APPROVED_PACKAGES = {
    'more-itertools', 'toolz', 'cytoolz', 'funcy'
}

# Blocked patterns - Comprehensive security checks
# Note: We allow open() for writing to enable saving plots and outputs
# The sandboxed environment restricts file access to safe directories
BLOCKED_PATTERNS = [
    # ==================== DANGEROUS SYSTEM MODULES ====================
    # OS module (except path)
    r'import\s+os\b(?!\s*\.path)',
    r'from\s+os\s+import\s+(?!path)',

    # File system modules
    r'import\s+shutil\b',
    r'from\s+shutil\s+import',
    r'import\s+pathlib\b(?!\s*\.)',  # Allow pathlib usage but monitor

    # Subprocess and execution modules
    r'import\s+subprocess\b',
    r'from\s+subprocess\s+import',
    r'import\s+multiprocessing\b',
    r'from\s+multiprocessing\s+import',
    r'import\s+threading\b',
    r'from\s+threading\s+import',
    r'import\s+concurrent\b',
    r'from\s+concurrent\s+import',

    # System access modules
    r'import\s+sys\b(?!\s*\.(?:path|version|platform|stdout|stderr))',
    r'from\s+sys\s+import\s+(?!path|version|platform|stdout|stderr)',
    r'import\s+platform\b',
    r'from\s+platform\s+import',
    r'import\s+ctypes\b',
    r'from\s+ctypes\s+import',
    r'import\s+_[a-z]+',  # Block private C modules

    # ==================== NETWORK MODULES ====================
    r'import\s+socket\b',
    r'from\s+socket\s+import',
    r'import\s+urllib\b',
@@ -92,19 +113,98 @@ BLOCKED_PATTERNS = [
    r'from\s+requests\s+import',
    r'import\s+aiohttp\b',
    r'from\s+aiohttp\s+import',
    r'import\s+httpx\b',
    r'from\s+httpx\s+import',
    r'import\s+http\.client\b',
    r'from\s+http\.client\s+import',
    r'import\s+ftplib\b',
    r'from\s+ftplib\s+import',
    r'import\s+smtplib\b',
    r'from\s+smtplib\s+import',
    r'import\s+telnetlib\b',
    r'from\s+telnetlib\s+import',
    r'import\s+ssl\b',
    r'from\s+ssl\s+import',
    r'import\s+paramiko\b',
    r'from\s+paramiko\s+import',

    # ==================== DANGEROUS CODE EXECUTION ====================
    r'__import__\s*\(',
    r'\beval\s*\(',
    r'\bexec\s*\(',
    r'\bcompile\s*\(',
    r'\bglobals\s*\(',
    r'\blocals\s*\(',
    r'\bgetattr\s*\([^,]+,\s*[\'"]__',  # Block getattr for dunder methods
    r'\bsetattr\s*\([^,]+,\s*[\'"]__',  # Block setattr for dunder methods
    r'\bdelattr\s*\([^,]+,\s*[\'"]__',  # Block delattr for dunder methods
    r'\.\_\_\w+\_\_',  # Block dunder method access

    # ==================== FILE SYSTEM OPERATIONS ====================
    r'\.unlink\s*\(',
    r'\.rmdir\s*\(',
    r'\.remove\s*\(',
    r'\.chmod\s*\(',
    r'\.chown\s*\(',
    r'\.rmtree\s*\(',
    r'\.rename\s*\(',
    r'\.replace\s*\(',
    r'\.makedirs\s*\(',  # Allow mkdir but block makedirs outside sandbox
    r'Path\s*\(\s*[\'"]\/(?!tmp)',  # Block absolute paths outside /tmp
    r'open\s*\(\s*[\'"]\/(?!tmp)',  # Block file access outside /tmp

    # ==================== PICKLE AND SERIALIZATION ====================
    r'pickle\.loads?\s*\(',
    r'cPickle\.loads?\s*\(',
    r'marshal\.loads?\s*\(',
    r'shelve\.open\s*\(',

    # ==================== PROCESS MANIPULATION ====================
    r'os\.system\s*\(',
    r'os\.popen\s*\(',
    r'os\.spawn',
    r'os\.exec',
    r'os\.fork\s*\(',
    r'os\.kill\s*\(',
    r'os\.killpg\s*\(',

    # ==================== ENVIRONMENT ACCESS ====================
    r'os\.environ',
    r'os\.getenv\s*\(',
    r'os\.putenv\s*\(',

    # ==================== DANGEROUS BUILTINS ====================
    r'__builtins__',
    r'__loader__',
    r'__spec__',

    # ==================== CODE OBJECT MANIPULATION ====================
    r'\.f_code',
    r'\.f_globals',
    r'\.f_locals',
    r'\.gi_frame',
    r'\.co_code',
    r'types\.CodeType',
    r'types\.FunctionType',

    # ==================== IMPORT SYSTEM MANIPULATION ====================
    r'import\s+importlib\b',
    r'from\s+importlib\s+import',
    r'sys\.modules',
    r'sys\.path\.(?:append|insert|extend)',

    # ==================== MEMORY OPERATIONS ====================
    r'gc\.',
    r'sys\.getsizeof',
    r'sys\.getrefcount',
    r'\bid\s*\(',  # Block id() which can leak memory addresses (\b ensures word boundary)
]

# Additional patterns that log warnings but don't block
WARNING_PATTERNS = [
    (r'while\s+True', "Infinite loop detected - ensure break condition exists"),
    (r'for\s+\w+\s+in\s+range\s*\(\s*\d{6,}', "Very large loop detected"),
    (r'recursion', "Recursion detected - ensure base case exists"),
]
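
The word boundaries and lookaheads are what keep safe idioms usable; for instance the os rule blocks a bare import while leaving os.path alone:

import re

rule = r'import\s+os\b(?!\s*\.path)'
assert re.search(rule, "import os")           # blocked: bare os import
assert not re.search(rule, "import os.path")  # allowed: the lookahead rejects the match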

@@ -772,10 +872,54 @@ class CodeExecutor:
            logger.warning(f"Cleanup failed: {e}")

    def validate_code_security(self, code: str) -> Tuple[bool, str]:
        """
        Validate code for security threats.

        Performs comprehensive security checks including:
        - Blocked patterns (dangerous imports, code execution, file ops)
        - Warning patterns (potential issues that are logged)
        - Code structure validation

        Args:
            code: The Python code to validate

        Returns:
            Tuple of (is_safe, message)
        """
        # Check for blocked patterns
        for pattern in BLOCKED_PATTERNS:
            if re.search(pattern, code, re.IGNORECASE):
                logger.warning(f"Blocked code pattern detected: {pattern[:50]}...")
                return False, "Security violation: Unsafe operation detected"

        # Check for warning patterns (log but don't block)
        for pattern, warning_msg in WARNING_PATTERNS:
            if re.search(pattern, code, re.IGNORECASE):
                logger.warning(f"Code warning: {warning_msg}")

        # Additional structural checks
        try:
            # Parse the AST to check for suspicious constructs
            tree = ast.parse(code)
            for node in ast.walk(tree):
                # Check for suspicious attribute access
                if isinstance(node, ast.Attribute):
                    if node.attr.startswith('__'):
                        logger.warning(f"Dunder attribute access detected: {node.attr}")
                        return False, "Security violation: Private attribute access not allowed"

                # Check for suspicious function calls
                if isinstance(node, ast.Call):
                    if isinstance(node.func, ast.Name):
                        if node.func.id in ['eval', 'exec', 'compile', '__import__']:
                            return False, f"Security violation: {node.func.id}() is not allowed"

        except SyntaxError:
            # Syntax errors will be caught during execution
            pass
        except Exception as e:
            logger.warning(f"Error during AST validation: {e}")

        return True, "Code passed security validation"
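
The AST pass complements the regexes by inspecting structure rather than text; in isolation the banned-call check looks like this:

import ast

tree = ast.parse("result = eval('1 + 1')")
banned = [
    node.func.id
    for node in ast.walk(tree)
    if isinstance(node, ast.Call)
    and isinstance(node.func, ast.Name)
    and node.func.id in ['eval', 'exec', 'compile', '__import__']
]
# banned == ['eval'], so validate_code_security() rejects the snippet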

    def _extract_imports_from_code(self, code: str) -> List[str]:
@@ -906,31 +1050,57 @@ import os

FILES = {json.dumps(file_paths_map)}

def get_file_path(file_id):
    '''
    Get the actual file path for a given file ID.
    Use this to get the path for pd.read_csv(), open(), etc.

    Args:
        file_id: The file ID provided when the file was uploaded

    Returns:
        str: The actual file path on disk

    Example:
        path = get_file_path('878573881449906208_1764556246_bdbaecc8')
        df = pd.read_csv(path)

    Available files: Use list(FILES.keys()) to see available files
    '''
    if file_id not in FILES:
        raise ValueError(f"File '{{file_id}}' not found. Available: {{list(FILES.keys())}}")
    return FILES[file_id]

def load_file(file_id):
    '''
    Load a file automatically based on its extension and return the data directly.
    DO NOT pass the result to pd.read_csv() - it already returns a DataFrame!

    Args:
        file_id: The file ID provided when the file was uploaded

    Returns:
        Loaded file data (varies by file type):
        - CSV/TSV: pandas DataFrame (ready to use!)
        - Excel (.xlsx, .xls): pandas ExcelFile object
        - JSON: pandas DataFrame or dict
        - Parquet/Feather: pandas DataFrame
        - Text files: string content
        - Images: PIL Image object
        - And 200+ more formats...

    CORRECT usage for CSV:
        df = load_file('file_id')  # Returns DataFrame directly
        print(df.head())

    WRONG usage (DO NOT DO THIS):
        file_path = load_file('file_id')  # WRONG! This is a DataFrame, not a path
        df = pd.read_csv(file_path)  # This will FAIL!

    If you need the file path instead, use get_file_path():
        path = get_file_path('file_id')
        df = pd.read_csv(path)

    Available files: Use list(FILES.keys()) to see available files
    '''
    if file_id not in FILES:
        available_files = list(FILES.keys())
417 src/utils/discord_utils.py Normal file
@@ -0,0 +1,417 @@
"""
Discord response utilities for sending messages with proper handling.

This module provides utilities for sending messages to Discord with
proper length handling, error recovery, and formatting.
"""

import discord
import asyncio
import logging
import io
from typing import Optional, List, Union
from dataclasses import dataclass


# Discord message limits
MAX_MESSAGE_LENGTH = 2000
MAX_EMBED_DESCRIPTION = 4096
MAX_EMBED_FIELD_VALUE = 1024
MAX_EMBED_FIELDS = 25
MAX_FILE_SIZE = 8 * 1024 * 1024  # 8MB for non-nitro


@dataclass
class MessageChunk:
    """A chunk of a message that fits within Discord limits."""
    content: str
    is_code_block: bool = False
    language: Optional[str] = None


def split_message(
    content: str,
    max_length: int = MAX_MESSAGE_LENGTH,
    split_on: Optional[List[str]] = None
) -> List[str]:
    """
    Split a long message into chunks that fit within Discord limits.

    Args:
        content: The message content to split
        max_length: Maximum length per chunk
        split_on: Preferred split points (default: newlines, spaces)

    Returns:
        List of message chunks
    """
    if len(content) <= max_length:
        return [content]

    if split_on is None:
        split_on = ['\n\n', '\n', '. ', ' ']

    chunks = []
    remaining = content

    while remaining:
        if len(remaining) <= max_length:
            chunks.append(remaining)
            break

        # Find the best split point
        split_index = max_length

        for delimiter in split_on:
            # Look for delimiter before max_length
            last_index = remaining.rfind(delimiter, 0, max_length)
            if last_index > max_length // 2:  # Don't split too early
                split_index = last_index + len(delimiter)
                break

        # If no good split point, hard cut at max_length
        if split_index >= max_length:
            split_index = max_length

        chunks.append(remaining[:split_index])
        remaining = remaining[split_index:]

    return chunks
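
For instance, a blank-line boundary past the midpoint wins over a hard cut:

text = "a" * 1500 + "\n\n" + "b" * 1500
chunks = split_message(text)
# -> ["a" * 1500 + "\n\n", "b" * 1500]; both chunks stay under MAX_MESSAGE_LENGTH
assert all(len(c) <= MAX_MESSAGE_LENGTH for c in chunks)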

def split_code_block(
    code: str,
    language: str = "",
    max_length: int = MAX_MESSAGE_LENGTH
) -> List[str]:
    """
    Split code into properly formatted code block chunks.

    Args:
        code: The code content
        language: The language for syntax highlighting
        max_length: Maximum length per chunk

    Returns:
        List of formatted code block strings
    """
    # Account for code block markers
    marker_length = len(f"```{language}\n") + len("```")
    effective_max = max_length - marker_length - 20  # Extra buffer

    lines = code.split('\n')
    chunks = []
    current_chunk = []
    current_length = 0

    for line in lines:
        line_length = len(line) + 1  # +1 for newline

        if current_length + line_length > effective_max and current_chunk:
            # Finish current chunk
            chunk_code = '\n'.join(current_chunk)
            chunks.append(f"```{language}\n{chunk_code}\n```")
            current_chunk = [line]
            current_length = line_length
        else:
            current_chunk.append(line)
            current_length += line_length

    # Add remaining chunk
    if current_chunk:
        chunk_code = '\n'.join(current_chunk)
        chunks.append(f"```{language}\n{chunk_code}\n```")

    return chunks

async def send_long_message(
    channel: discord.abc.Messageable,
    content: str,
    max_length: int = MAX_MESSAGE_LENGTH,
    delay: float = 0.5
) -> List[discord.Message]:
    """
    Send a long message split across multiple Discord messages.

    Args:
        channel: The channel to send to
        content: The message content
        max_length: Maximum length per message
        delay: Delay between messages to avoid rate limiting

    Returns:
        List of sent messages
    """
    chunks = split_message(content, max_length)
    messages = []

    for i, chunk in enumerate(chunks):
        try:
            msg = await channel.send(chunk)
            messages.append(msg)

            # Add delay between messages (except for the last one)
            if i < len(chunks) - 1:
                await asyncio.sleep(delay)

        except discord.HTTPException as e:
            logging.error(f"Failed to send message chunk {i+1}: {e}")
            # Try sending as file if message still too long
            if "too long" in str(e).lower():
                file = discord.File(
                    io.StringIO(chunk),
                    filename=f"message_part_{i+1}.txt"
                )
                msg = await channel.send(file=file)
                messages.append(msg)

    return messages


async def send_code_response(
    channel: discord.abc.Messageable,
    code: str,
    language: str = "python",
    title: Optional[str] = None
) -> List[discord.Message]:
    """
    Send code with proper formatting, handling long code.

    Args:
        channel: The channel to send to
        code: The code content
        language: Programming language for highlighting
        title: Optional title to display before code

    Returns:
        List of sent messages
    """
    messages = []

    if title:
        msg = await channel.send(title)
        messages.append(msg)

    # If code is too long for code blocks, send as file
    if len(code) > MAX_MESSAGE_LENGTH - 100:
        file = discord.File(
            io.StringIO(code),
            filename=f"code.{language}" if language else "code.txt"
        )
        msg = await channel.send("📎 Code attached as file:", file=file)
        messages.append(msg)
    else:
        chunks = split_code_block(code, language)
        for chunk in chunks:
            msg = await channel.send(chunk)
            messages.append(msg)
            await asyncio.sleep(0.3)

    return messages
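
A sketch of the intended call site, assuming a standard discord.py commands bot; build_report is a hypothetical helper:

import discord
from discord.ext import commands

bot = commands.Bot(command_prefix="!", intents=discord.Intents.default())

@bot.command()
async def report(ctx: commands.Context):
    text = build_report()  # hypothetical helper returning a very long string
    # splits at 2000 chars, paces the sends, and falls back to a file on HTTP errors
    await send_long_message(ctx.channel, text)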

def create_error_embed(
    title: str,
    description: str,
    error_type: str = "Error"
) -> discord.Embed:
    """
    Create a standardized error embed.

    Args:
        title: Error title
        description: Error description
        error_type: Type of error for categorization

    Returns:
        Discord Embed object
    """
    embed = discord.Embed(
        title=f"❌ {title}",
        description=description[:MAX_EMBED_DESCRIPTION],
        color=discord.Color.red()
    )
    embed.set_footer(text=f"Error Type: {error_type}")
    return embed


def create_success_embed(
    title: str,
    description: str = ""
) -> discord.Embed:
    """
    Create a standardized success embed.

    Args:
        title: Success title
        description: Success description

    Returns:
        Discord Embed object
    """
    embed = discord.Embed(
        title=f"✅ {title}",
        description=description[:MAX_EMBED_DESCRIPTION] if description else None,
        color=discord.Color.green()
    )
    return embed


def create_info_embed(
    title: str,
    description: str = "",
    fields: Optional[List[tuple]] = None
) -> discord.Embed:
    """
    Create a standardized info embed with optional fields.

    Args:
        title: Info title
        description: Info description
        fields: List of (name, value, inline) tuples

    Returns:
        Discord Embed object
    """
    embed = discord.Embed(
        title=f"ℹ️ {title}",
        description=description[:MAX_EMBED_DESCRIPTION] if description else None,
        color=discord.Color.blue()
    )

    if fields:
        for name, value, inline in fields[:MAX_EMBED_FIELDS]:
            embed.add_field(
                name=name[:256],
                value=str(value)[:MAX_EMBED_FIELD_VALUE],
                inline=inline
            )

    return embed


def create_progress_embed(
    title: str,
    description: str,
    progress: float = 0.0
) -> discord.Embed:
    """
    Create a progress indicator embed.

    Args:
        title: Progress title
        description: Progress description
        progress: Progress value 0.0 to 1.0

    Returns:
        Discord Embed object
    """
    # Create progress bar
    bar_length = 20
    filled = int(bar_length * progress)
    bar = "█" * filled + "░" * (bar_length - filled)
    percentage = int(progress * 100)

    embed = discord.Embed(
        title=f"⏳ {title}",
        description=f"{description}\n\n`{bar}` {percentage}%",
        color=discord.Color.orange()
    )
    return embed


async def edit_or_send(
    message: Optional[discord.Message],
    channel: discord.abc.Messageable,
    content: Optional[str] = None,
    embed: Optional[discord.Embed] = None
) -> discord.Message:
    """
    Edit an existing message or send a new one if editing fails.

    Args:
        message: Message to edit (or None to send new)
        channel: Channel to send to if message is None
        content: Message content
        embed: Message embed

    Returns:
        The edited or new message
    """
    try:
        if message:
            await message.edit(content=content, embed=embed)
            return message
        else:
            return await channel.send(content=content, embed=embed)
    except discord.HTTPException:
        return await channel.send(content=content, embed=embed)


class ProgressMessage:
    """
    A message that can be updated to show progress.

    Usage:
        async with ProgressMessage(channel, "Processing") as progress:
            for i in range(100):
                await progress.update(i / 100, f"Step {i}")
    """

    def __init__(
        self,
        channel: discord.abc.Messageable,
        title: str,
        description: str = "Starting..."
    ):
        self.channel = channel
        self.title = title
        self.description = description
        self.message: Optional[discord.Message] = None
        self._last_update = 0.0
        self._update_interval = 2.0  # Minimum seconds between updates

    async def __aenter__(self):
        embed = create_progress_embed(self.title, self.description, 0.0)
        self.message = await self.channel.send(embed=embed)
        return self

    async def __aexit__(self, *args):
        # Clean up or finalize
        pass

    async def update(self, progress: float, description: str = None):
        """Update the progress message."""
        import time

        now = time.monotonic()
        if now - self._last_update < self._update_interval:
            return

        self._last_update = now

        if description:
            self.description = description

        try:
            embed = create_progress_embed(self.title, self.description, progress)
            await self.message.edit(embed=embed)
        except discord.HTTPException:
            pass  # Ignore edit failures

    async def complete(self, message: str = "Complete!"):
        """Mark the progress as complete."""
        try:
            embed = create_success_embed(self.title, message)
            await self.message.edit(embed=embed)
        except discord.HTTPException:
            pass

    async def error(self, message: str):
        """Mark the progress as failed."""
        try:
            embed = create_error_embed(self.title, message)
            await self.message.edit(embed=embed)
        except discord.HTTPException:
            pass
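
Beyond the update loop in the docstring, complete() and error() finish the embed either way; a sketch (do_step is hypothetical):

async def run_job(channel: discord.abc.Messageable) -> None:
    async with ProgressMessage(channel, "Rendering") as progress:
        try:
            for i in range(10):
                await do_step(i)                     # hypothetical unit of work
                await progress.update((i + 1) / 10)  # throttled to one edit per 2 s
            await progress.complete("Render finished")
        except Exception as exc:
            await progress.error(str(exc))
            raise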
File diff suppressed because it is too large
446 src/utils/monitoring.py Normal file
@@ -0,0 +1,446 @@
"""
Monitoring and observability utilities.

This module provides structured logging, error tracking with Sentry,
and performance monitoring for the Discord bot.
"""

import os
import logging
import time
import asyncio
from typing import Any, Dict, Optional, Callable
from functools import wraps
from contextlib import contextmanager, asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone

# Try to import Sentry
try:
    import sentry_sdk
    from sentry_sdk.integrations.asyncio import AsyncioIntegration
    SENTRY_AVAILABLE = True
except ImportError:
    SENTRY_AVAILABLE = False
    sentry_sdk = None

logger = logging.getLogger(__name__)


# ============================================================
# Configuration
# ============================================================

@dataclass
class MonitoringConfig:
    """Configuration for monitoring features."""
    sentry_dsn: Optional[str] = None
    environment: str = "development"
    sample_rate: float = 1.0  # 100% of events
    traces_sample_rate: float = 0.1  # 10% of transactions
    log_level: str = "INFO"
    structured_logging: bool = True


def setup_monitoring(config: Optional[MonitoringConfig] = None) -> None:
    """
    Initialize monitoring with optional Sentry integration.

    Args:
        config: Monitoring configuration, uses env vars if not provided
    """
    if config is None:
        config = MonitoringConfig(
            sentry_dsn=os.environ.get("SENTRY_DSN"),
            environment=os.environ.get("ENVIRONMENT", "development"),
            sample_rate=float(os.environ.get("SENTRY_SAMPLE_RATE", "1.0")),
            traces_sample_rate=float(os.environ.get("SENTRY_TRACES_RATE", "0.1")),
            log_level=os.environ.get("LOG_LEVEL", "INFO"),
        )

    # Setup logging
    setup_structured_logging(
        level=config.log_level,
        structured=config.structured_logging
    )

    # Setup Sentry if available and configured
    if SENTRY_AVAILABLE and config.sentry_dsn:
        sentry_sdk.init(
            dsn=config.sentry_dsn,
            environment=config.environment,
            sample_rate=config.sample_rate,
            traces_sample_rate=config.traces_sample_rate,
            integrations=[AsyncioIntegration()],
            before_send=before_send_filter,
        )
        logger.info(f"Sentry initialized for environment: {config.environment}")
    else:
        if config.sentry_dsn and not SENTRY_AVAILABLE:
            logger.warning("Sentry DSN provided but sentry_sdk not installed")
        logger.info("Running without Sentry error tracking")


def before_send_filter(event: Dict, hint: Dict) -> Optional[Dict]:
    """Filter events before sending to Sentry."""
    # Don't send events for expected/handled errors
    if "exc_info" in hint:
        exc_type, exc_value, _ = hint["exc_info"]

        # Skip common non-critical errors
        if exc_type.__name__ in [
            "NotFound",  # Discord 404
            "Forbidden",  # Discord 403
            "RateLimited",  # Discord rate limit
        ]:
            return None

    return event
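
Typical wiring is a single call at process start; with no argument it reads SENTRY_DSN, ENVIRONMENT, SENTRY_SAMPLE_RATE, SENTRY_TRACES_RATE and LOG_LEVEL from the environment as shown above:

from src.utils.monitoring import setup_monitoring

setup_monitoring()  # Sentry stays off if SENTRY_DSN is unset; logging is configured either way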

# ============================================================
# Structured Logging
# ============================================================

class StructuredFormatter(logging.Formatter):
    """JSON-like structured log formatter."""

    def format(self, record: logging.LogRecord) -> str:
        """Format log record as structured message."""
        log_entry = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }

        # Add extra fields
        if hasattr(record, "user_id"):
            log_entry["user_id"] = record.user_id
        if hasattr(record, "guild_id"):
            log_entry["guild_id"] = record.guild_id
        if hasattr(record, "command"):
            log_entry["command"] = record.command
        if hasattr(record, "duration_ms"):
            log_entry["duration_ms"] = record.duration_ms
        if hasattr(record, "model"):
            log_entry["model"] = record.model

        # Add exception info if present
        if record.exc_info:
            log_entry["exception"] = self.formatException(record.exc_info)

        # Format as key=value pairs for easy parsing
        parts = [f"{k}={v!r}" for k, v in log_entry.items()]
        return " ".join(parts)


def setup_structured_logging(
    level: str = "INFO",
    structured: bool = True
) -> None:
    """
    Setup logging configuration.

    Args:
        level: Log level (DEBUG, INFO, WARNING, ERROR)
        structured: Use structured formatting
    """
    log_level = getattr(logging, level.upper(), logging.INFO)

    # Create handler
    handler = logging.StreamHandler()
    handler.setLevel(log_level)

    if structured:
        handler.setFormatter(StructuredFormatter())
    else:
        handler.setFormatter(logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
        ))

    # Configure root logger
    root_logger = logging.getLogger()
    root_logger.setLevel(log_level)
    root_logger.handlers = [handler]


def get_logger(name: str) -> logging.Logger:
    """Get a logger with the given name."""
    return logging.getLogger(name)
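
The extra-field hooks in StructuredFormatter line up with logging's standard extra mechanism; a quick check:

setup_structured_logging(level="INFO", structured=True)
log = get_logger("bot.chat")
log.info("reply sent", extra={"user_id": 1234, "command": "chat", "duration_ms": 412.5})
# -> timestamp='...' level='INFO' logger='bot.chat' message='reply sent' user_id=1234 command='chat' duration_ms=412.5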

# ============================================================
# Error Tracking
# ============================================================

def capture_exception(
    exception: Exception,
    context: Optional[Dict[str, Any]] = None
) -> Optional[str]:
    """
    Capture and report an exception.

    Args:
        exception: The exception to capture
        context: Additional context to attach

    Returns:
        Event ID if sent to Sentry, None otherwise
    """
    logger.exception(f"Captured exception: {exception}")

    if SENTRY_AVAILABLE and sentry_sdk.is_initialized():
        with sentry_sdk.push_scope() as scope:
            if context:
                for key, value in context.items():
                    scope.set_extra(key, value)
            return sentry_sdk.capture_exception(exception)

    return None


def capture_message(
    message: str,
    level: str = "info",
    context: Optional[Dict[str, Any]] = None
) -> Optional[str]:
    """
    Capture and report a message.

    Args:
        message: The message to capture
        level: Severity level (debug, info, warning, error, fatal)
        context: Additional context to attach

    Returns:
        Event ID if sent to Sentry, None otherwise
    """
    log_method = getattr(logger, level, logger.info)
    log_method(message)

    if SENTRY_AVAILABLE and sentry_sdk.is_initialized():
        with sentry_sdk.push_scope() as scope:
            if context:
                for key, value in context.items():
                    scope.set_extra(key, value)
            return sentry_sdk.capture_message(message, level=level)

    return None


def set_user_context(
    user_id: int,
    username: Optional[str] = None,
    guild_id: Optional[int] = None
) -> None:
    """
    Set user context for error tracking.

    Args:
        user_id: Discord user ID
        username: Discord username
        guild_id: Discord guild ID
    """
    if SENTRY_AVAILABLE and sentry_sdk.is_initialized():
        sentry_sdk.set_user({
            "id": str(user_id),
            "username": username,
        })
        if guild_id:
            sentry_sdk.set_tag("guild_id", str(guild_id))


# ============================================================
# Performance Monitoring
# ============================================================

@dataclass
class PerformanceMetrics:
    """Container for performance metrics."""
    name: str
    start_time: float = field(default_factory=time.perf_counter)
    end_time: Optional[float] = None
    success: bool = True
    error: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def duration_ms(self) -> float:
        """Get duration in milliseconds."""
        end = self.end_time or time.perf_counter()
        return (end - self.start_time) * 1000

    def finish(self, success: bool = True, error: Optional[str] = None) -> None:
        """Mark the operation as finished."""
        self.end_time = time.perf_counter()
        self.success = success
        self.error = error

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for logging."""
        return {
            "name": self.name,
            "duration_ms": round(self.duration_ms, 2),
            "success": self.success,
            "error": self.error,
            **self.metadata
        }


@contextmanager
def measure_sync(name: str, **metadata):
    """
    Context manager to measure synchronous operation performance.

    Usage:
        with measure_sync("database_query", table="users"):
            result = db.query(...)
    """
    metrics = PerformanceMetrics(name=name, metadata=metadata)

    try:
        yield metrics
        metrics.finish(success=True)
    except Exception as e:
        metrics.finish(success=False, error=str(e))
        raise
    finally:
        logger.info(
            f"Performance: {metrics.name}",
            extra={"duration_ms": metrics.duration_ms, **metrics.metadata}
        )


@asynccontextmanager
async def measure_async(name: str, **metadata):
    """
    Async context manager to measure async operation performance.

    Usage:
        async with measure_async("api_call", endpoint="chat"):
            result = await api.call(...)
    """
    metrics = PerformanceMetrics(name=name, metadata=metadata)

    # Start Sentry transaction if available
    transaction = None
    if SENTRY_AVAILABLE and sentry_sdk.is_initialized():
        transaction = sentry_sdk.start_transaction(
            op="task",
            name=name
        )

    try:
        yield metrics
        metrics.finish(success=True)
    except Exception as e:
        metrics.finish(success=False, error=str(e))
        raise
    finally:
        if transaction:
            transaction.set_status("ok" if metrics.success else "internal_error")
            transaction.finish()

        logger.info(
            f"Performance: {metrics.name}",
            extra={"duration_ms": metrics.duration_ms, **metrics.metadata}
        )


def track_performance(name: Optional[str] = None):
    """
    Decorator to track async function performance.

    Args:
        name: Operation name (defaults to function name)

    Usage:
        @track_performance("process_message")
        async def handle_message(message):
            ...
    """
    def decorator(func: Callable):
        op_name = name or func.__name__
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
async with measure_async(op_name):
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Health Check
|
||||
# ============================================================
|
||||
|
||||
@dataclass
|
||||
class HealthStatus:
|
||||
"""Health check status."""
|
||||
healthy: bool
|
||||
checks: Dict[str, Dict[str, Any]] = field(default_factory=dict)
|
||||
timestamp: str = field(
|
||||
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
||||
)
|
||||
|
||||
def add_check(
|
||||
self,
|
||||
name: str,
|
||||
healthy: bool,
|
||||
message: str = "",
|
||||
details: Optional[Dict] = None
|
||||
) -> None:
|
||||
"""Add a health check result."""
|
||||
self.checks[name] = {
|
||||
"healthy": healthy,
|
||||
"message": message,
|
||||
**(details or {})
|
||||
}
|
||||
if not healthy:
|
||||
self.healthy = False
|
||||
|
||||
|
||||
async def check_health(
|
||||
db_handler=None,
|
||||
openai_client=None
|
||||
) -> HealthStatus:
|
||||
"""
|
||||
Perform health checks on bot dependencies.
|
||||
|
||||
Args:
|
||||
db_handler: Database handler to check
|
||||
openai_client: OpenAI client to check
|
||||
|
||||
Returns:
|
||||
HealthStatus with check results
|
||||
"""
|
||||
status = HealthStatus(healthy=True)
|
||||
|
||||
# Check database
|
||||
if db_handler:
|
||||
try:
|
||||
# Simple ping or list operation
|
||||
await asyncio.wait_for(
|
||||
db_handler.client.admin.command('ping'),
|
||||
timeout=5.0
|
||||
)
|
||||
status.add_check("database", True, "MongoDB connected")
|
||||
except Exception as e:
|
||||
status.add_check("database", False, f"MongoDB error: {e}")
|
||||
|
||||
# Check OpenAI
|
||||
if openai_client:
|
||||
try:
|
||||
# List models as a simple check
|
||||
await asyncio.wait_for(
|
||||
openai_client.models.list(),
|
||||
timeout=10.0
|
||||
)
|
||||
status.add_check("openai", True, "OpenAI API accessible")
|
||||
except Exception as e:
|
||||
status.add_check("openai", False, f"OpenAI error: {e}")
|
||||
|
||||
return status
|
||||
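For orientation, here is a minimal usage sketch of the monitoring utilities above. This is a reviewer example, not part of the diff; it assumes the file is importable as src.utils.monitoring (the module path is our guess), while all the names come from the code itself:

import asyncio

from src.utils.monitoring import (
    setup_structured_logging, get_logger, track_performance, check_health
)

setup_structured_logging(level="INFO", structured=True)
logger = get_logger(__name__)

@track_performance("demo_task")
async def demo_task() -> str:
    # Wrapped in measure_async, so duration_ms is logged automatically
    await asyncio.sleep(0.1)
    return "done"

async def main() -> None:
    print(await demo_task())
    # With no handlers passed, check_health skips both checks and reports healthy
    status = await check_health()
    print(status.healthy, status.checks)

asyncio.run(main())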
@@ -28,12 +28,11 @@ def get_tools_for_model() -> List[Dict[str, Any]]:
         "type": "function",
         "function": {
             "name": "edit_image",
-            "description": "Edit images (remove background). Returns URLs.",
+            "description": "Remove background from an image. Requires image_url from user's uploaded image or a web URL.",
             "parameters": {
                 "type": "object",
                 "properties": {
-                    "image_url": {"type": "string"},
-                    "operation": {"type": "string", "enum": ["remove_background"]}
+                    "image_url": {"type": "string", "description": "URL of the image to edit"}
                 },
                 "required": ["image_url"]
             }
@@ -43,12 +42,12 @@ def get_tools_for_model() -> List[Dict[str, Any]]:
         "type": "function",
         "function": {
             "name": "enhance_prompt",
-            "description": "Create enhanced prompt versions.",
+            "description": "Improve and expand a prompt for better image generation results",
             "parameters": {
                 "type": "object",
                 "properties": {
-                    "prompt": {"type": "string"},
-                    "num_versions": {"type": "integer", "minimum": 1, "maximum": 5}
+                    "prompt": {"type": "string", "description": "The prompt to enhance"},
+                    "num_versions": {"type": "integer", "maximum": 5, "description": "Number of enhanced versions"}
                 },
                 "required": ["prompt"]
             }
@@ -58,10 +57,10 @@ def get_tools_for_model() -> List[Dict[str, Any]]:
         "type": "function",
         "function": {
             "name": "image_to_text",
-            "description": "Convert image to text.",
+            "description": "Generate a text description/caption of an image or extract text via OCR. When user uploads an image, pass 'latest_image' as image_url - the system will use the most recent uploaded image.",
             "parameters": {
                 "type": "object",
-                "properties": {"image_url": {"type": "string"}},
+                "properties": {"image_url": {"type": "string", "description": "Pass 'latest_image' to use the user's most recently uploaded image"}},
                 "required": ["image_url"]
             }
         }
@@ -70,12 +69,13 @@ def get_tools_for_model() -> List[Dict[str, Any]]:
         "type": "function",
         "function": {
             "name": "upscale_image",
-            "description": "Upscale image resolution. Returns URLs.",
+            "description": "Enlarge/upscale an image to higher resolution. When user uploads an image and wants to upscale it, pass 'latest_image' as the image_url - the system will use the most recent uploaded image.",
             "parameters": {
                 "type": "object",
                 "properties": {
-                    "image_url": {"type": "string"},
-                    "scale_factor": {"type": "integer", "enum": [2, 3, 4]}
+                    "image_url": {"type": "string", "description": "Pass 'latest_image' to use the user's most recently uploaded image"},
+                    "scale_factor": {"type": "integer", "enum": [2, 4], "description": "Scale factor (2 or 4)"},
+                    "model": {"type": "string", "enum": ["clarity", "ccsr", "sd-latent", "swinir"], "description": "Upscale model to use"}
                 },
                 "required": ["image_url"]
             }
@@ -85,14 +85,15 @@ def get_tools_for_model() -> List[Dict[str, Any]]:
         "type": "function",
         "function": {
             "name": "photo_maker",
-            "description": "Generate images from reference photos. Returns URLs.",
+            "description": "Generate new images based on reference photos. When user uploads an image and wants to use it as reference, pass ['latest_image'] as input_images - the system will use the most recent uploaded image.",
             "parameters": {
                 "type": "object",
                 "properties": {
-                    "prompt": {"type": "string"},
-                    "input_images": {"type": "array", "items": {"type": "string"}},
-                    "strength": {"type": "integer", "minimum": 1, "maximum": 100},
-                    "num_images": {"type": "integer", "minimum": 1, "maximum": 4}
+                    "prompt": {"type": "string", "description": "Description of the desired output image"},
+                    "input_images": {"type": "array", "items": {"type": "string"}, "description": "Pass ['latest_image'] to use the user's most recently uploaded image"},
+                    "style": {"type": "string", "description": "Style to apply (e.g., 'Photographic', 'Cinematic', 'Anime')"},
+                    "strength": {"type": "integer", "minimum": 0, "maximum": 100, "description": "Reference image influence (0-100)"},
+                    "num_images": {"type": "integer", "maximum": 4, "description": "Number of images to generate"}
                 },
                 "required": ["prompt", "input_images"]
             }
@@ -102,28 +103,44 @@ def get_tools_for_model() -> List[Dict[str, Any]]:
         "type": "function",
         "function": {
             "name": "generate_image_with_refiner",
-            "description": "Generate high-quality images. Returns URLs.",
+            "description": "Generate high-quality refined images with extra detail using SDXL refiner. Best for detailed artwork.",
             "parameters": {
                 "type": "object",
                 "properties": {
-                    "prompt": {"type": "string"},
-                    "num_images": {"type": "integer", "minimum": 1, "maximum": 4},
-                    "negative_prompt": {"type": "string"}
+                    "prompt": {"type": "string", "description": "Detailed description of the image to generate"},
+                    "model": {"type": "string", "enum": ["sdxl", "flux", "realistic"], "description": "Base model to use"},
+                    "num_images": {"type": "integer", "maximum": 4, "description": "Number of images to generate"},
+                    "negative_prompt": {"type": "string", "description": "Things to avoid in the image"}
                 },
                 "required": ["prompt"]
             }
         }
     },
+    {
+        "type": "function",
+        "function": {
+            "name": "remove_background",
+            "description": "Remove background from an image. When user uploads an image and wants to remove its background, pass 'latest_image' as the image_url - the system will use the most recent uploaded image.",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "image_url": {"type": "string", "description": "Pass 'latest_image' to use the user's most recently uploaded image"},
+                    "model": {"type": "string", "enum": ["bria", "rembg", "birefnet-base", "birefnet-general", "birefnet-portrait"], "description": "Background removal model"}
+                },
+                "required": ["image_url"]
+            }
+        }
+    },
     {
         "type": "function",
         "function": {
             "name": "google_search",
-            "description": "Search web for current information.",
+            "description": "Search the web for current information",
             "parameters": {
                 "type": "object",
                 "properties": {
                     "query": {"type": "string"},
-                    "num_results": {"type": "integer", "minimum": 1, "maximum": 10}
+                    "num_results": {"type": "integer", "maximum": 10}
                 },
                 "required": ["query"]
             }
@@ -133,10 +150,10 @@ def get_tools_for_model() -> List[Dict[str, Any]]:
         "type": "function",
         "function": {
             "name": "scrape_webpage",
-            "description": "Extract content from webpage.",
+            "description": "Extract and read content from a webpage URL",
             "parameters": {
                 "type": "object",
-                "properties": {"url": {"type": "string"}},
+                "properties": {"url": {"type": "string", "description": "The webpage URL to scrape"}},
                 "required": ["url"]
             }
         }
@@ -145,12 +162,20 @@ def get_tools_for_model() -> List[Dict[str, Any]]:
         "type": "function",
         "function": {
             "name": "generate_image",
-            "description": "Generate images from text. Returns URLs.",
+            "description": "Create/generate images from text. Models: flux (best), flux-dev, sdxl, realistic (photos), anime, dreamshaper. Supports aspect ratios.",
             "parameters": {
                 "type": "object",
                 "properties": {
-                    "prompt": {"type": "string"},
-                    "num_images": {"type": "integer", "minimum": 1, "maximum": 4}
+                    "prompt": {"type": "string", "description": "Detailed description of the image to create"},
+                    "model": {"type": "string", "enum": ["flux", "flux-dev", "sdxl", "realistic", "anime", "dreamshaper"], "description": "Model to use for generation"},
+                    "num_images": {"type": "integer", "maximum": 4, "description": "Number of images (1-4)"},
+                    "aspect_ratio": {"type": "string", "enum": ["1:1", "16:9", "9:16", "4:3", "3:4", "3:2", "2:3", "21:9"], "description": "Aspect ratio preset"},
+                    "width": {"type": "integer", "description": "Custom width (512-2048, divisible by 64)"},
+                    "height": {"type": "integer", "description": "Custom height (512-2048, divisible by 64)"},
+                    "negative_prompt": {"type": "string", "description": "Things to avoid in the image"},
+                    "steps": {"type": "integer", "minimum": 10, "maximum": 50, "description": "Inference steps (more = higher quality)"},
+                    "cfg_scale": {"type": "number", "minimum": 1, "maximum": 20, "description": "Guidance scale (higher = more prompt adherence)"},
+                    "seed": {"type": "integer", "description": "Random seed for reproducibility"}
                 },
                 "required": ["prompt"]
             }
@@ -160,33 +185,12 @@ def get_tools_for_model() -> List[Dict[str, Any]]:
         "type": "function",
         "function": {
             "name": "execute_python_code",
-            "description": """Execute Python with AUTO-INSTALL. Packages (pandas, numpy, matplotlib, seaborn, sklearn, plotly, opencv, etc.) install automatically when imported. Just use 'import' normally. Generated files (CSV, images, JSON) auto-captured and sent to user (stored 48h). Load user files: load_file('file_id'). Example: import pandas as pd; df=load_file('id'); df.to_csv('out.csv')""",
+            "description": "Run Python code. Packages auto-install. Use load_file('file_id') for user files. Output files auto-sent to user.",
             "parameters": {
                 "type": "object",
                 "properties": {
-                    "code": {
-                        "type": "string",
-                        "description": "Python code to execute. Import any approved package - they auto-install!"
-                    },
-                    "input_data": {
-                        "type": "string",
-                        "description": "Optional input data (DEPRECATED - use load_file() in code instead)"
-                    },
-                    "install_packages": {
-                        "type": "array",
-                        "items": {"type": "string"},
-                        "description": "OPTIONAL: Pre-install packages. Usually not needed as packages auto-install on import."
-                    },
-                    "enable_visualization": {
-                        "type": "boolean",
-                        "description": "DEPRECATED: Just use plt.savefig() to create images"
-                    },
-                    "timeout": {
-                        "type": "integer",
-                        "minimum": 1,
-                        "maximum": 300,
-                        "description": "Execution timeout in seconds (default: 60)"
-                    }
+                    "code": {"type": "string", "description": "Python code to execute"},
+                    "timeout": {"type": "integer", "maximum": 300, "description": "Timeout in seconds"}
                 },
                 "required": ["code"]
             }
@@ -196,7 +200,7 @@ def get_tools_for_model() -> List[Dict[str, Any]]:
         "type": "function",
         "function": {
             "name": "set_reminder",
-            "description": "Set user reminder with flexible time formats.",
+            "description": "Set reminder",
             "parameters": {
                 "type": "object",
                 "properties": {
@@ -211,7 +215,7 @@ def get_tools_for_model() -> List[Dict[str, Any]]:
         "type": "function",
         "function": {
             "name": "get_reminders",
-            "description": "Get user reminders list.",
+            "description": "List reminders",
             "parameters": {"type": "object", "properties": {}}
         }
     }
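A note on the recurring 'latest_image' sentinel in the new descriptions: the dispatcher is expected to substitute the user's most recent upload for it before invoking the tool. To see how these schemas are consumed, here is a hedged sketch of wiring get_tools_for_model() into a chat completion call (the import path src.utils.tools is hypothetical; the client usage is the standard OpenAI Python SDK):

from openai import AsyncOpenAI

from src.utils.tools import get_tools_for_model  # hypothetical import path

client = AsyncOpenAI()

async def ask(prompt: str):
    response = await client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
        tools=get_tools_for_model(),  # the schemas edited in this diff
        tool_choice="auto",
    )
    # When the model elects a tool, its name and JSON arguments arrive here
    return response.choices[0].message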
280 src/utils/retry.py Normal file
@@ -0,0 +1,280 @@
"""
Retry utilities with exponential backoff for API calls.

This module provides robust retry logic for external API calls
to handle transient failures gracefully.
"""

import asyncio
import logging
import random
from typing import TypeVar, Callable, Optional, Any, Type, Tuple
from functools import wraps

T = TypeVar('T')

# Default configuration
DEFAULT_MAX_RETRIES = 3
DEFAULT_BASE_DELAY = 1.0  # seconds
DEFAULT_MAX_DELAY = 60.0  # seconds
DEFAULT_EXPONENTIAL_BASE = 2


class RetryError(Exception):
    """Raised when all retry attempts have been exhausted."""

    def __init__(self, message: str, last_exception: Optional[Exception] = None):
        super().__init__(message)
        self.last_exception = last_exception


async def async_retry_with_backoff(
    func: Callable,
    *args,
    max_retries: int = DEFAULT_MAX_RETRIES,
    base_delay: float = DEFAULT_BASE_DELAY,
    max_delay: float = DEFAULT_MAX_DELAY,
    exponential_base: float = DEFAULT_EXPONENTIAL_BASE,
    retryable_exceptions: Tuple[Type[Exception], ...] = (Exception,),
    jitter: bool = True,
    on_retry: Optional[Callable[[int, Exception], None]] = None,
    **kwargs
) -> Any:
    """
    Execute an async function with exponential backoff retry.

    Args:
        func: The async function to execute
        *args: Positional arguments for the function
        max_retries: Maximum number of retry attempts
        base_delay: Initial delay between retries in seconds
        max_delay: Maximum delay between retries
        exponential_base: Base for exponential backoff calculation
        retryable_exceptions: Tuple of exception types that should trigger retry
        jitter: Whether to add randomness to delay
        on_retry: Optional callback called on each retry with (attempt, exception)
        **kwargs: Keyword arguments for the function

    Returns:
        The return value of the function

    Raises:
        RetryError: When all retries are exhausted
    """
    last_exception = None

    for attempt in range(max_retries + 1):
        try:
            return await func(*args, **kwargs)
        except retryable_exceptions as e:
            last_exception = e

            if attempt == max_retries:
                logging.error(f"All {max_retries} retries exhausted for {func.__name__}: {e}")
                raise RetryError(
                    f"Failed after {max_retries} retries: {str(e)}",
                    last_exception=e
                )

            # Calculate delay with exponential backoff
            delay = min(base_delay * (exponential_base ** attempt), max_delay)

            # Add jitter to prevent thundering herd
            if jitter:
                delay = delay * (0.5 + random.random())

            logging.warning(
                f"Retry {attempt + 1}/{max_retries} for {func.__name__} "
                f"after {delay:.2f}s delay. Error: {e}"
            )

            if on_retry:
                try:
                    on_retry(attempt + 1, e)
                except Exception as callback_error:
                    logging.warning(f"on_retry callback failed: {callback_error}")

            await asyncio.sleep(delay)

    # Should not reach here, but just in case
    raise RetryError("Unexpected retry loop exit", last_exception=last_exception)


def retry_decorator(
    max_retries: int = DEFAULT_MAX_RETRIES,
    base_delay: float = DEFAULT_BASE_DELAY,
    max_delay: float = DEFAULT_MAX_DELAY,
    retryable_exceptions: Tuple[Type[Exception], ...] = (Exception,),
    jitter: bool = True
):
    """
    Decorator for adding retry logic to async functions.

    Usage:
        @retry_decorator(max_retries=3, base_delay=1.0)
        async def my_api_call():
            ...
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(*args, **kwargs):
            return await async_retry_with_backoff(
                func,
                *args,
                max_retries=max_retries,
                base_delay=base_delay,
                max_delay=max_delay,
                retryable_exceptions=retryable_exceptions,
                jitter=jitter,
                **kwargs
            )
        return wrapper
    return decorator


# Common exception sets for different APIs
OPENAI_RETRYABLE_EXCEPTIONS = (
    # Add specific OpenAI exceptions as needed
    TimeoutError,
    ConnectionError,
)

DISCORD_RETRYABLE_EXCEPTIONS = (
    # Add specific Discord exceptions as needed
    TimeoutError,
    ConnectionError,
)

HTTP_RETRYABLE_EXCEPTIONS = (
    TimeoutError,
    ConnectionError,
    ConnectionResetError,
)


class RateLimiter:
    """
    Simple rate limiter for API calls.

    Usage:
        limiter = RateLimiter(calls_per_second=1)
        async with limiter:
            await make_api_call()
    """

    def __init__(self, calls_per_second: float = 1.0):
        self.min_interval = 1.0 / calls_per_second
        self.last_call = 0.0
        self._lock = asyncio.Lock()

    async def __aenter__(self):
        async with self._lock:
            import time
            now = time.monotonic()
            time_since_last = now - self.last_call

            if time_since_last < self.min_interval:
                await asyncio.sleep(self.min_interval - time_since_last)

            self.last_call = time.monotonic()
        return self

    async def __aexit__(self, *args):
        pass


class CircuitBreaker:
    """
    Circuit breaker pattern for preventing cascade failures.

    States:
        - CLOSED: Normal operation, requests pass through
        - OPEN: Too many failures, requests are rejected immediately
        - HALF_OPEN: Testing if service recovered
    """

    CLOSED = "closed"
    OPEN = "open"
    HALF_OPEN = "half_open"

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: float = 60.0,
        half_open_requests: int = 3
    ):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.half_open_requests = half_open_requests

        self.state = self.CLOSED
        self.failure_count = 0
        self.last_failure_time = 0.0
        self.half_open_successes = 0
        self._lock = asyncio.Lock()

    async def call(self, func: Callable, *args, **kwargs) -> Any:
        """
        Execute a function through the circuit breaker.

        Args:
            func: The async function to execute
            *args: Positional arguments
            **kwargs: Keyword arguments

        Returns:
            The function result

        Raises:
            Exception: If circuit is open or function fails
        """
        async with self._lock:
            await self._check_state()

            if self.state == self.OPEN:
                raise Exception("Circuit breaker is OPEN - service unavailable")

        try:
            result = await func(*args, **kwargs)
            await self._on_success()
            return result
        except Exception:
            await self._on_failure()
            raise

    async def _check_state(self):
        """Check and potentially update circuit state."""
        import time

        if self.state == self.OPEN:
            if time.monotonic() - self.last_failure_time >= self.recovery_timeout:
                logging.info("Circuit breaker transitioning to HALF_OPEN")
                self.state = self.HALF_OPEN
                self.half_open_successes = 0

    async def _on_success(self):
        """Handle successful call."""
        async with self._lock:
            if self.state == self.HALF_OPEN:
                self.half_open_successes += 1
                if self.half_open_successes >= self.half_open_requests:
                    logging.info("Circuit breaker transitioning to CLOSED")
                    self.state = self.CLOSED
                    self.failure_count = 0
            elif self.state == self.CLOSED:
                self.failure_count = 0

    async def _on_failure(self):
        """Handle failed call."""
        import time

        async with self._lock:
            self.failure_count += 1
            self.last_failure_time = time.monotonic()

            if self.state == self.HALF_OPEN:
                logging.warning("Circuit breaker transitioning to OPEN (half-open failure)")
                self.state = self.OPEN
            elif self.failure_count >= self.failure_threshold:
                logging.warning(f"Circuit breaker transitioning to OPEN ({self.failure_count} failures)")
                self.state = self.OPEN
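A short sketch of how these primitives compose (all names are defined in the file above; the call pattern is illustrative, not how the bot necessarily wires them):

import asyncio

from src.utils.retry import (
    retry_decorator, RateLimiter, CircuitBreaker, HTTP_RETRYABLE_EXCEPTIONS
)

limiter = RateLimiter(calls_per_second=2)
breaker = CircuitBreaker(failure_threshold=5, recovery_timeout=30.0)

async def _do_request() -> str:
    return "ok"  # stand-in for a real HTTP call

@retry_decorator(max_retries=3, base_delay=0.5,
                 retryable_exceptions=HTTP_RETRYABLE_EXCEPTIONS)
async def fetch_data() -> str:
    async with limiter:                        # throttle to 2 calls/second
        return await breaker.call(_do_request)  # fail fast when the service is down

asyncio.run(fetch_data())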
@@ -304,8 +304,8 @@ class TokenCounter:
         Returns:
             Estimated cost in USD
         """
-        # Import here to avoid circular dependency
-        from src.commands.commands import MODEL_PRICING
+        # Import from centralized pricing module
+        from src.config.pricing import MODEL_PRICING

         if model not in MODEL_PRICING:
             model = "openai/gpt-4o"  # Default fallback
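For reference, a sketch of the centralized module this hunk switches to; calculate_cost and format_cost are the helpers the new test suite below exercises, so their signatures are grounded there:

from src.config.pricing import MODEL_PRICING, calculate_cost, format_cost

# 1,000 prompt tokens + 1,000 completion tokens on GPT-4o
cost = calculate_cost("openai/gpt-4o", 1000, 1000)
# Assuming the $5/$20 per-1M rates asserted in the tests, this prints "$0.025000"
print(format_cost(cost))
assert "openai/gpt-4o" in MODEL_PRICING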
287 src/utils/validators.py Normal file
@@ -0,0 +1,287 @@
"""
Input validation utilities for the Discord bot.

This module provides centralized validation for user inputs,
enhancing security and reducing code duplication.
"""

import re
import logging
from typing import Optional, Tuple, List
from dataclasses import dataclass


# Maximum allowed lengths for various inputs
MAX_MESSAGE_LENGTH = 4000  # Discord's limit is 2000, but we process longer
MAX_PROMPT_LENGTH = 32000  # Reasonable limit for AI prompts
MAX_FILE_SIZE = 50 * 1024 * 1024  # 50MB
MAX_FILENAME_LENGTH = 255
MAX_URL_LENGTH = 2048
MAX_CODE_LENGTH = 100000  # 100KB of code


@dataclass
class ValidationResult:
    """Result of a validation check."""
    is_valid: bool
    error_message: Optional[str] = None
    sanitized_value: Optional[str] = None


def validate_message_content(content: str) -> ValidationResult:
    """
    Validate and sanitize message content.

    Args:
        content: The message content to validate

    Returns:
        ValidationResult with validation status and sanitized content
    """
    if not content:
        return ValidationResult(is_valid=True, sanitized_value="")

    if len(content) > MAX_MESSAGE_LENGTH:
        return ValidationResult(
            is_valid=False,
            error_message=f"Message too long. Maximum {MAX_MESSAGE_LENGTH} characters allowed."
        )

    # Remove null bytes and other control characters (except newlines/tabs)
    sanitized = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', content)

    return ValidationResult(is_valid=True, sanitized_value=sanitized)


def validate_prompt(prompt: str) -> ValidationResult:
    """
    Validate AI prompt content.

    Args:
        prompt: The prompt to validate

    Returns:
        ValidationResult with validation status
    """
    if not prompt or not prompt.strip():
        return ValidationResult(
            is_valid=False,
            error_message="Prompt cannot be empty."
        )

    if len(prompt) > MAX_PROMPT_LENGTH:
        return ValidationResult(
            is_valid=False,
            error_message=f"Prompt too long. Maximum {MAX_PROMPT_LENGTH} characters allowed."
        )

    # Remove null bytes
    sanitized = prompt.replace('\x00', '')

    return ValidationResult(is_valid=True, sanitized_value=sanitized)


def validate_url(url: str) -> ValidationResult:
    """
    Validate and sanitize a URL.

    Args:
        url: The URL to validate

    Returns:
        ValidationResult with validation status
    """
    if not url:
        return ValidationResult(
            is_valid=False,
            error_message="URL cannot be empty."
        )

    if len(url) > MAX_URL_LENGTH:
        return ValidationResult(
            is_valid=False,
            error_message=f"URL too long. Maximum {MAX_URL_LENGTH} characters allowed."
        )

    # Basic URL pattern check
    url_pattern = re.compile(
        r'^https?://'  # http:// or https://
        r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|'  # domain
        r'localhost|'  # localhost
        r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'  # or IP
        r'(?::\d+)?'  # optional port
        r'(?:/?|[/?]\S+)$', re.IGNORECASE
    )

    if not url_pattern.match(url):
        return ValidationResult(
            is_valid=False,
            error_message="Invalid URL format."
        )

    # Check for potentially dangerous URL schemes
    dangerous_schemes = ['javascript:', 'data:', 'file:', 'vbscript:']
    url_lower = url.lower()
    for scheme in dangerous_schemes:
        if scheme in url_lower:
            return ValidationResult(
                is_valid=False,
                error_message="URL contains potentially dangerous content."
            )

    return ValidationResult(is_valid=True, sanitized_value=url)


def validate_filename(filename: str) -> ValidationResult:
    """
    Validate and sanitize a filename.

    Args:
        filename: The filename to validate

    Returns:
        ValidationResult with validation status and sanitized filename
    """
    if not filename:
        return ValidationResult(
            is_valid=False,
            error_message="Filename cannot be empty."
        )

    if len(filename) > MAX_FILENAME_LENGTH:
        return ValidationResult(
            is_valid=False,
            error_message=f"Filename too long. Maximum {MAX_FILENAME_LENGTH} characters allowed."
        )

    # Remove path traversal attempts
    sanitized = filename.replace('..', '').replace('/', '').replace('\\', '')

    # Remove dangerous characters
    sanitized = re.sub(r'[<>:"|?*\x00-\x1f]', '', sanitized)

    # Ensure it's not empty after sanitization
    if not sanitized:
        return ValidationResult(
            is_valid=False,
            error_message="Filename contains only invalid characters."
        )

    return ValidationResult(is_valid=True, sanitized_value=sanitized)


def validate_file_size(size: int) -> ValidationResult:
    """
    Validate file size.

    Args:
        size: The file size in bytes

    Returns:
        ValidationResult with validation status
    """
    if size <= 0:
        return ValidationResult(
            is_valid=False,
            error_message="File size must be greater than 0."
        )

    if size > MAX_FILE_SIZE:
        max_mb = MAX_FILE_SIZE / (1024 * 1024)
        return ValidationResult(
            is_valid=False,
            error_message=f"File too large. Maximum {max_mb:.0f}MB allowed."
        )

    return ValidationResult(is_valid=True)


def validate_code(code: str) -> ValidationResult:
    """
    Validate code for execution.

    Args:
        code: The code to validate

    Returns:
        ValidationResult with validation status
    """
    if not code or not code.strip():
        return ValidationResult(
            is_valid=False,
            error_message="Code cannot be empty."
        )

    if len(code) > MAX_CODE_LENGTH:
        return ValidationResult(
            is_valid=False,
            error_message=f"Code too long. Maximum {MAX_CODE_LENGTH} characters allowed."
        )

    return ValidationResult(is_valid=True, sanitized_value=code)


def validate_user_id(user_id) -> ValidationResult:
    """
    Validate a Discord user ID.

    Args:
        user_id: The user ID to validate

    Returns:
        ValidationResult with validation status
    """
    try:
        uid = int(user_id)
        if uid <= 0:
            return ValidationResult(
                is_valid=False,
                error_message="Invalid user ID."
            )
        # Discord IDs are 17-19 digits
        if len(str(uid)) < 17 or len(str(uid)) > 19:
            return ValidationResult(
                is_valid=False,
                error_message="Invalid user ID format."
            )
        return ValidationResult(is_valid=True)
    except (ValueError, TypeError):
        return ValidationResult(
            is_valid=False,
            error_message="User ID must be a valid integer."
        )


def sanitize_for_logging(text: str, max_length: int = 200) -> str:
    """
    Sanitize text for safe logging (remove sensitive data, truncate).

    Args:
        text: The text to sanitize
        max_length: Maximum length for logged text

    Returns:
        Sanitized text safe for logging
    """
    if not text:
        return ""

    # Remove potential secrets/tokens (common patterns)
    patterns = [
        (r'(sk-[a-zA-Z0-9]{20,})', '[OPENAI_KEY]'),
        (r'(xoxb-[a-zA-Z0-9-]+)', '[SLACK_TOKEN]'),
        (r'([A-Za-z0-9_-]{24}\.[A-Za-z0-9_-]{6}\.[A-Za-z0-9_-]{27})', '[DISCORD_TOKEN]'),
        (r'(mongodb\+srv://[^@]+@)', 'mongodb+srv://[REDACTED]@'),
        (r'(Bearer\s+[A-Za-z0-9_-]+)', 'Bearer [TOKEN]'),
        (r'(password["\']?\s*[:=]\s*["\']?)[^"\'\s]+', r'\1[REDACTED]'),
    ]

    sanitized = text
    for pattern, replacement in patterns:
        sanitized = re.sub(pattern, replacement, sanitized, flags=re.IGNORECASE)

    # Truncate if needed
    if len(sanitized) > max_length:
        sanitized = sanitized[:max_length] + '...[truncated]'

    return sanitized
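A sketch of how a message handler might chain these validators (the handler shape is illustrative; the validator names and behavior are from the file above):

from src.utils.validators import (
    validate_message_content, sanitize_for_logging
)

def handle_user_input(raw: str) -> str:
    result = validate_message_content(raw)
    if not result.is_valid:
        return result.error_message          # surface the error to the user
    # Log safely: common secret patterns redacted, long text truncated
    print(sanitize_for_logging(result.sanitized_value, max_length=100))
    return result.sanitized_value

handle_user_input("hello\x00world")  # null byte stripped, then logged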
821 tests/test_comprehensive.py Normal file
@@ -0,0 +1,821 @@
|
||||
"""
|
||||
Comprehensive test suite for the ChatGPT Discord Bot.
|
||||
|
||||
This module contains unit tests and integration tests for all major components.
|
||||
Uses pytest with pytest-asyncio for async test support.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Test Fixtures
|
||||
# ============================================================
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_handler():
|
||||
"""Create a mock database handler."""
|
||||
mock = MagicMock()
|
||||
mock.get_history = AsyncMock(return_value=[])
|
||||
mock.save_history = AsyncMock()
|
||||
mock.get_user_model = AsyncMock(return_value="openai/gpt-4o")
|
||||
mock.save_user_model = AsyncMock()
|
||||
mock.is_admin = AsyncMock(return_value=False)
|
||||
mock.is_user_whitelisted = AsyncMock(return_value=True)
|
||||
mock.is_user_blacklisted = AsyncMock(return_value=False)
|
||||
mock.get_user_tool_display = AsyncMock(return_value=False)
|
||||
mock.get_user_files = AsyncMock(return_value=[])
|
||||
mock.save_token_usage = AsyncMock()
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_client():
|
||||
"""Create a mock OpenAI client."""
|
||||
mock = MagicMock()
|
||||
|
||||
# Mock response structure
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Test response"
|
||||
mock_response.choices[0].finish_reason = "stop"
|
||||
mock_response.usage = MagicMock()
|
||||
mock_response.usage.prompt_tokens = 100
|
||||
mock_response.usage.completion_tokens = 50
|
||||
|
||||
mock.chat.completions.create = AsyncMock(return_value=mock_response)
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord_message():
|
||||
"""Create a mock Discord message."""
|
||||
mock = MagicMock()
|
||||
mock.author.id = 123456789
|
||||
mock.author.name = "TestUser"
|
||||
mock.content = "Hello, bot!"
|
||||
mock.channel.send = AsyncMock()
|
||||
mock.channel.typing = MagicMock(return_value=AsyncMock().__aenter__())
|
||||
mock.attachments = []
|
||||
mock.reference = None
|
||||
mock.guild = MagicMock()
|
||||
return mock
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Pricing Module Tests
|
||||
# ============================================================
|
||||
|
||||
class TestPricingModule:
|
||||
"""Tests for the pricing configuration module."""
|
||||
|
||||
def test_model_pricing_exists(self):
|
||||
"""Test that all expected models have pricing defined."""
|
||||
from src.config.pricing import MODEL_PRICING
|
||||
|
||||
expected_models = [
|
||||
"openai/gpt-4o",
|
||||
"openai/gpt-4o-mini",
|
||||
"openai/gpt-4.1",
|
||||
"openai/gpt-5",
|
||||
"openai/o1",
|
||||
"anthropic/claude-sonnet-4-20250514",
|
||||
"anthropic/claude-opus-4-20250514",
|
||||
"anthropic/claude-3.5-sonnet",
|
||||
"anthropic/claude-3.5-haiku",
|
||||
]
|
||||
|
||||
for model in expected_models:
|
||||
assert model in MODEL_PRICING, f"Missing pricing for {model}"
|
||||
|
||||
def test_calculate_cost(self):
|
||||
"""Test cost calculation for known models."""
|
||||
from src.config.pricing import calculate_cost
|
||||
|
||||
# GPT-4o: $5.00 input, $20.00 output per 1M tokens
|
||||
cost = calculate_cost("openai/gpt-4o", 1_000_000, 1_000_000)
|
||||
assert cost == 25.00 # $5 + $20
|
||||
|
||||
# Test smaller amounts
|
||||
cost = calculate_cost("openai/gpt-4o", 1000, 1000)
|
||||
assert cost == pytest.approx(0.025, rel=1e-6) # $0.005 + $0.020
|
||||
|
||||
# Test Claude model
|
||||
cost = calculate_cost("anthropic/claude-3.5-sonnet", 1_000_000, 1_000_000)
|
||||
assert cost == 18.00 # $3 + $15
|
||||
|
||||
def test_calculate_cost_unknown_model(self):
|
||||
"""Test that unknown models return 0 cost."""
|
||||
from src.config.pricing import calculate_cost
|
||||
|
||||
cost = calculate_cost("unknown/model", 1000, 1000)
|
||||
assert cost == 0.0
|
||||
|
||||
def test_format_cost(self):
|
||||
"""Test cost formatting for display."""
|
||||
from src.config.pricing import format_cost
|
||||
|
||||
assert format_cost(0.000001) == "$0.000001"
|
||||
assert format_cost(0.005) == "$0.005000" # 6 decimal places for small amounts
|
||||
assert format_cost(1.50) == "$1.50"
|
||||
assert format_cost(100.00) == "$100.00"
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Validator Module Tests
|
||||
# ============================================================
|
||||
|
||||
class TestValidators:
|
||||
"""Tests for input validation utilities."""
|
||||
|
||||
def test_validate_message_content(self):
|
||||
"""Test message content validation."""
|
||||
from src.utils.validators import validate_message_content
|
||||
|
||||
# Valid content
|
||||
result = validate_message_content("Hello, world!")
|
||||
assert result.is_valid
|
||||
assert result.sanitized_value == "Hello, world!"
|
||||
|
||||
# Empty content is valid
|
||||
result = validate_message_content("")
|
||||
assert result.is_valid
|
||||
|
||||
# Content with null bytes should be sanitized
|
||||
result = validate_message_content("Hello\x00World")
|
||||
assert result.is_valid
|
||||
assert "\x00" not in result.sanitized_value
|
||||
|
||||
def test_validate_message_too_long(self):
|
||||
"""Test that overly long messages are rejected."""
|
||||
from src.utils.validators import validate_message_content, MAX_MESSAGE_LENGTH
|
||||
|
||||
long_message = "x" * (MAX_MESSAGE_LENGTH + 1)
|
||||
result = validate_message_content(long_message)
|
||||
assert not result.is_valid
|
||||
assert "too long" in result.error_message.lower()
|
||||
|
||||
def test_validate_url(self):
|
||||
"""Test URL validation."""
|
||||
from src.utils.validators import validate_url
|
||||
|
||||
# Valid URLs
|
||||
assert validate_url("https://example.com").is_valid
|
||||
assert validate_url("http://localhost:8080/path").is_valid
|
||||
assert validate_url("https://api.example.com/v1/data?q=test").is_valid
|
||||
|
||||
# Invalid URLs
|
||||
assert not validate_url("").is_valid
|
||||
assert not validate_url("not-a-url").is_valid
|
||||
assert not validate_url("javascript:alert(1)").is_valid
|
||||
assert not validate_url("file:///etc/passwd").is_valid
|
||||
|
||||
def test_validate_filename(self):
|
||||
"""Test filename validation and sanitization."""
|
||||
from src.utils.validators import validate_filename
|
||||
|
||||
# Valid filename
|
||||
result = validate_filename("test_file.txt")
|
||||
assert result.is_valid
|
||||
assert result.sanitized_value == "test_file.txt"
|
||||
|
||||
# Path traversal attempt
|
||||
result = validate_filename("../../../etc/passwd")
|
||||
assert result.is_valid # Sanitized, not rejected
|
||||
assert ".." not in result.sanitized_value
|
||||
assert "/" not in result.sanitized_value
|
||||
|
||||
# Empty filename
|
||||
result = validate_filename("")
|
||||
assert not result.is_valid
|
||||
|
||||
def test_sanitize_for_logging(self):
|
||||
"""Test that secrets are properly redacted for logging."""
|
||||
from src.utils.validators import sanitize_for_logging
|
||||
|
||||
# Test OpenAI key redaction
|
||||
text = "API key is sk-abcdefghijklmnopqrstuvwxyz123456"
|
||||
sanitized = sanitize_for_logging(text)
|
||||
assert "sk-" not in sanitized
|
||||
assert "[OPENAI_KEY]" in sanitized
|
||||
|
||||
# Test MongoDB URI redaction
|
||||
text = "mongodb+srv://user:password@cluster.mongodb.net/db"
|
||||
sanitized = sanitize_for_logging(text)
|
||||
assert "password" not in sanitized
|
||||
assert "[REDACTED]" in sanitized
|
||||
|
||||
# Test truncation
|
||||
long_text = "x" * 500
|
||||
sanitized = sanitize_for_logging(long_text, max_length=100)
|
||||
assert len(sanitized) < 150 # Account for truncation marker
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Retry Module Tests
|
||||
# ============================================================
|
||||
|
||||
class TestRetryModule:
|
||||
"""Tests for retry utilities."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_success_first_try(self):
|
||||
"""Test that successful functions don't retry."""
|
||||
from src.utils.retry import async_retry_with_backoff
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def success_func():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return "success"
|
||||
|
||||
result = await async_retry_with_backoff(success_func, max_retries=3)
|
||||
assert result == "success"
|
||||
assert call_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_eventual_success(self):
|
||||
"""Test that functions eventually succeed after retries."""
|
||||
from src.utils.retry import async_retry_with_backoff
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def eventual_success():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 3:
|
||||
raise ConnectionError("Temporary failure")
|
||||
return "success"
|
||||
|
||||
result = await async_retry_with_backoff(
|
||||
eventual_success,
|
||||
max_retries=5,
|
||||
base_delay=0.01, # Fast for testing
|
||||
retryable_exceptions=(ConnectionError,)
|
||||
)
|
||||
assert result == "success"
|
||||
assert call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_exhausted(self):
|
||||
"""Test that RetryError is raised when retries are exhausted."""
|
||||
from src.utils.retry import async_retry_with_backoff, RetryError
|
||||
|
||||
async def always_fail():
|
||||
raise ConnectionError("Always fails")
|
||||
|
||||
with pytest.raises(RetryError):
|
||||
await async_retry_with_backoff(
|
||||
always_fail,
|
||||
max_retries=2,
|
||||
base_delay=0.01,
|
||||
retryable_exceptions=(ConnectionError,)
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Discord Utils Tests
|
||||
# ============================================================
|
||||
|
||||
class TestDiscordUtils:
|
||||
"""Tests for Discord utility functions."""
|
||||
|
||||
def test_split_message_short(self):
|
||||
"""Test that short messages aren't split."""
|
||||
from src.utils.discord_utils import split_message
|
||||
|
||||
short = "This is a short message."
|
||||
chunks = split_message(short)
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0] == short
|
||||
|
||||
def test_split_message_long(self):
|
||||
"""Test that long messages are properly split."""
|
||||
from src.utils.discord_utils import split_message
|
||||
|
||||
# Create a message longer than 2000 characters
|
||||
long = "Hello world. " * 200
|
||||
chunks = split_message(long, max_length=2000)
|
||||
|
||||
assert len(chunks) > 1
|
||||
for chunk in chunks:
|
||||
assert len(chunk) <= 2000
|
||||
|
||||
def test_split_code_block(self):
|
||||
"""Test code block splitting."""
|
||||
from src.utils.discord_utils import split_code_block
|
||||
|
||||
code = "\n".join([f"line {i}" for i in range(100)])
|
||||
chunks = split_code_block(code, "python", max_length=500)
|
||||
|
||||
assert len(chunks) > 1
|
||||
for chunk in chunks:
|
||||
assert chunk.startswith("```python\n")
|
||||
assert chunk.endswith("\n```")
|
||||
assert len(chunk) <= 500
|
||||
|
||||
def test_create_error_embed(self):
|
||||
"""Test error embed creation."""
|
||||
from src.utils.discord_utils import create_error_embed
|
||||
import discord
|
||||
|
||||
embed = create_error_embed("Test Error", "Something went wrong", "ValidationError")
|
||||
|
||||
assert isinstance(embed, discord.Embed)
|
||||
assert "Test Error" in embed.title
|
||||
assert embed.color == discord.Color.red()
|
||||
|
||||
def test_create_success_embed(self):
|
||||
"""Test success embed creation."""
|
||||
from src.utils.discord_utils import create_success_embed
|
||||
import discord
|
||||
|
||||
embed = create_success_embed("Success!", "Operation completed")
|
||||
|
||||
assert isinstance(embed, discord.Embed)
|
||||
assert "Success!" in embed.title
|
||||
assert embed.color == discord.Color.green()
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Code Interpreter Security Tests
|
||||
# ============================================================
|
||||
|
||||
class TestCodeInterpreterSecurity:
|
||||
"""Tests for code interpreter security features."""
|
||||
|
||||
def test_blocked_imports(self):
|
||||
"""Test that dangerous imports are blocked."""
|
||||
from src.utils.code_interpreter import BLOCKED_PATTERNS
|
||||
import re
|
||||
|
||||
dangerous_code = [
|
||||
"import os",
|
||||
"import subprocess",
|
||||
"from os import system",
|
||||
"import socket",
|
||||
"import requests",
|
||||
"__import__('os')",
|
||||
"eval('print(1)')",
|
||||
"exec('import os')",
|
||||
]
|
||||
|
||||
for code in dangerous_code:
|
||||
blocked = any(
|
||||
re.search(pattern, code, re.IGNORECASE)
|
||||
for pattern in BLOCKED_PATTERNS
|
||||
)
|
||||
assert blocked, f"Should block: {code}"
|
||||
|
||||
def test_allowed_imports(self):
|
||||
"""Test that safe imports are allowed."""
|
||||
from src.utils.code_interpreter import BLOCKED_PATTERNS
|
||||
import re
|
||||
|
||||
safe_code = [
|
||||
"import pandas as pd",
|
||||
"import numpy as np",
|
||||
"import matplotlib.pyplot as plt",
|
||||
"from sklearn.model_selection import train_test_split",
|
||||
"import os.path", # os.path is allowed
|
||||
]
|
||||
|
||||
for code in safe_code:
|
||||
blocked = any(
|
||||
re.search(pattern, code, re.IGNORECASE)
|
||||
for pattern in BLOCKED_PATTERNS
|
||||
)
|
||||
assert not blocked, f"Should allow: {code}"
|
||||
|
||||
def test_file_type_detection(self):
|
||||
"""Test file type detection for various extensions."""
|
||||
from src.utils.code_interpreter import FileManager
|
||||
|
||||
fm = FileManager()
|
||||
|
||||
assert fm._detect_file_type("data.csv") == "csv"
|
||||
assert fm._detect_file_type("data.xlsx") == "excel"
|
||||
assert fm._detect_file_type("config.json") == "json"
|
||||
assert fm._detect_file_type("image.png") == "image"
|
||||
assert fm._detect_file_type("script.py") == "python"
|
||||
assert fm._detect_file_type("unknown.xyz") == "binary"
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Claude Utils Tests
|
||||
# ============================================================
|
||||
|
||||
class TestClaudeUtils:
|
||||
"""Tests for Claude utility functions."""
|
||||
|
||||
def test_is_claude_model(self):
|
||||
"""Test Claude model detection."""
|
||||
from src.utils.claude_utils import is_claude_model
|
||||
|
||||
# Claude models
|
||||
assert is_claude_model("anthropic/claude-sonnet-4-20250514") == True
|
||||
assert is_claude_model("anthropic/claude-opus-4-20250514") == True
|
||||
assert is_claude_model("anthropic/claude-3.5-sonnet") == True
|
||||
assert is_claude_model("anthropic/claude-3.5-haiku") == True
|
||||
|
||||
# Non-Claude models
|
||||
assert is_claude_model("openai/gpt-4o") == False
|
||||
assert is_claude_model("openai/gpt-4o-mini") == False
|
||||
assert is_claude_model("gpt-4") == False
|
||||
|
||||
def test_get_claude_model_id(self):
|
||||
"""Test Claude model ID extraction."""
|
||||
from src.utils.claude_utils import get_claude_model_id
|
||||
|
||||
assert get_claude_model_id("anthropic/claude-sonnet-4-20250514") == "claude-sonnet-4-20250514"
|
||||
assert get_claude_model_id("anthropic/claude-3.5-sonnet") == "claude-3.5-sonnet"
|
||||
assert get_claude_model_id("claude-3.5-sonnet") == "claude-3.5-sonnet"
|
||||
|
||||
def test_convert_openai_messages_to_claude(self):
|
||||
"""Test message conversion from OpenAI to Claude format."""
|
||||
from src.utils.claude_utils import convert_openai_messages_to_claude
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "user", "content": "How are you?"},
|
||||
]
|
||||
|
||||
system, claude_messages = convert_openai_messages_to_claude(messages)
|
||||
|
||||
# System should be extracted
|
||||
assert system == "You are a helpful assistant."
|
||||
|
||||
# Messages should not contain system
|
||||
assert all(m.get("role") != "system" for m in claude_messages)
|
||||
|
||||
# Should have user and assistant messages
|
||||
assert len(claude_messages) >= 2
|
||||
|
||||
def test_convert_content_to_claude(self):
|
||||
"""Test content conversion."""
|
||||
from src.utils.claude_utils import convert_content_to_claude
|
||||
|
||||
# String content
|
||||
assert convert_content_to_claude("Hello") == "Hello"
|
||||
|
||||
# List content with text
|
||||
list_content = [
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "text", "text": "World"}
|
||||
]
|
||||
result = convert_content_to_claude(list_content)
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 2
|
||||
|
||||
def test_merge_consecutive_messages(self):
|
||||
"""Test merging consecutive messages with same role."""
|
||||
from src.utils.claude_utils import merge_consecutive_messages
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "user", "content": "How are you?"},
|
||||
{"role": "assistant", "content": "Hi!"},
|
||||
]
|
||||
|
||||
merged = merge_consecutive_messages(messages)
|
||||
|
||||
# Should merge two user messages into one
|
||||
assert len(merged) == 2
|
||||
assert merged[0]["role"] == "user"
|
||||
assert merged[1]["role"] == "assistant"
|
||||
|
||||
|
||||
# ============================================================
|
||||
# OpenAI Utils Tests
|
||||
# ============================================================
|
||||
|
||||
class TestOpenAIUtils:
|
||||
"""Tests for OpenAI utility functions."""
|
||||
|
||||
def test_count_tokens(self):
|
||||
"""Test token counting function."""
|
||||
from src.utils.openai_utils import count_tokens
|
||||
|
||||
text = "Hello, world!"
|
||||
tokens = count_tokens(text)
|
||||
assert tokens > 0
|
||||
assert isinstance(tokens, int)
|
||||
|
||||
def test_trim_content_to_token_limit(self):
|
||||
"""Test content trimming."""
|
||||
from src.utils.openai_utils import trim_content_to_token_limit
|
||||
|
||||
# Short content should not be trimmed
|
||||
short = "Hello, world!"
|
||||
trimmed = trim_content_to_token_limit(short, max_tokens=100)
|
||||
assert trimmed == short
|
||||
|
||||
# Long content should be trimmed
|
||||
long = "Hello " * 10000
|
||||
trimmed = trim_content_to_token_limit(long, max_tokens=100)
|
||||
assert len(trimmed) < len(long)
|
||||
|
||||
def test_prepare_messages_for_api(self):
|
||||
"""Test message preparation for API."""
|
||||
from src.utils.openai_utils import prepare_messages_for_api
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "user", "content": "How are you?"},
|
||||
]
|
||||
|
||||
prepared = prepare_messages_for_api(messages)
|
||||
|
||||
assert len(prepared) == 3
|
||||
assert all(m.get("role") in ["user", "assistant", "system"] for m in prepared)
|
||||
|
||||
def test_prepare_messages_filters_none_content(self):
|
||||
"""Test that messages with None content are filtered."""
|
||||
from src.utils.openai_utils import prepare_messages_for_api
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": None},
|
||||
{"role": "user", "content": "World"},
|
||||
]
|
||||
|
||||
prepared = prepare_messages_for_api(messages)
|
||||
|
||||
assert len(prepared) == 2
|
||||
|
||||
|
||||
# ============================================================
# Database Handler Tests (with mocking)
# ============================================================


class TestDatabaseHandlerMocked:
    """Tests for the database handler using mocks."""

    def test_filter_expired_images_no_images(self):
        """Test that messages without images pass through unchanged."""
        from src.database.db_handler import DatabaseHandler

        with patch('motor.motor_asyncio.AsyncIOMotorClient'):
            handler = DatabaseHandler("mongodb://localhost")

            history = [
                {"role": "user", "content": "Hello"},
                {"role": "assistant", "content": "Hi there!"},
            ]

            filtered = handler._filter_expired_images(history)
            assert len(filtered) == 2
            assert filtered[0]["content"] == "Hello"

    def test_filter_expired_images_recent_image(self):
        """Test that recent images are kept."""
        from src.database.db_handler import DatabaseHandler

        with patch('motor.motor_asyncio.AsyncIOMotorClient'):
            handler = DatabaseHandler("mongodb://localhost")

            recent_timestamp = datetime.now().isoformat()
            history = [
                {"role": "user", "content": [
                    {"type": "text", "text": "Check this image"},
                    {"type": "image_url", "image_url": {"url": "https://example.com/img.jpg"}, "timestamp": recent_timestamp},
                ]}
            ]

            filtered = handler._filter_expired_images(history)
            assert len(filtered) == 1
            assert len(filtered[0]["content"]) == 2  # Both items kept

    def test_filter_expired_images_old_image(self):
        """Test that old images are filtered out."""
        from src.database.db_handler import DatabaseHandler

        with patch('motor.motor_asyncio.AsyncIOMotorClient'):
            handler = DatabaseHandler("mongodb://localhost")

            old_timestamp = (datetime.now() - timedelta(hours=24)).isoformat()
            history = [
                {"role": "user", "content": [
                    {"type": "text", "text": "Check this image"},
                    {"type": "image_url", "image_url": {"url": "https://example.com/img.jpg"}, "timestamp": old_timestamp},
                ]}
            ]

            filtered = handler._filter_expired_images(history)
            assert len(filtered) == 1
            assert len(filtered[0]["content"]) == 1  # Only the text part kept
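

# _filter_expired_images is private to DatabaseHandler and its body is not
# part of this diff. As a rough sketch of the behavior asserted above: keep
# text parts, drop image_url parts whose timestamp is older than a cutoff.
# The expiry_hours default here is an assumption chosen for illustration.


def _sketch_filter_expired_images(history: list, expiry_hours: int = 12) -> list:
    """Drop expired image_url parts; plain string messages pass through unchanged."""
    cutoff = datetime.now() - timedelta(hours=expiry_hours)
    filtered = []
    for message in history:
        content = message.get("content")
        if isinstance(content, list):
            kept = [
                part
                for part in content
                if part.get("type") != "image_url"
                or "timestamp" not in part
                or datetime.fromisoformat(part["timestamp"]) >= cutoff
            ]
            message = {**message, "content": kept}
        filtered.append(message)
    return filtered

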
# ============================================================
# Cache Module Tests
# ============================================================


class TestLRUCache:
    """Tests for the LRU cache implementation."""

    @pytest.mark.asyncio
    async def test_cache_set_and_get(self):
        """Test basic cache set and get operations."""
        from src.utils.cache import LRUCache

        cache = LRUCache(max_size=100, default_ttl=60.0)

        await cache.set("key1", "value1")
        result = await cache.get("key1")
        assert result == "value1"

    @pytest.mark.asyncio
    async def test_cache_expiration(self):
        """Test that cache entries expire after their TTL."""
        from src.utils.cache import LRUCache

        cache = LRUCache(max_size=100, default_ttl=0.1)  # 100 ms TTL

        await cache.set("key1", "value1")

        # Should exist immediately
        assert await cache.get("key1") == "value1"

        # Wait for expiration
        await asyncio.sleep(0.15)

        # Should be expired now
        assert await cache.get("key1") is None

    @pytest.mark.asyncio
    async def test_cache_lru_eviction(self):
        """Test that LRU eviction works correctly."""
        from src.utils.cache import LRUCache

        cache = LRUCache(max_size=3, default_ttl=60.0)

        await cache.set("key1", "value1")
        await cache.set("key2", "value2")
        await cache.set("key3", "value3")

        # Access key1 to make it recently used
        await cache.get("key1")

        # Adding a new key should evict key2 (the least recently used)
        await cache.set("key4", "value4")

        assert await cache.get("key1") == "value1"  # Should exist
        assert await cache.get("key2") is None      # Should be evicted
        assert await cache.get("key3") == "value3"  # Should exist
        assert await cache.get("key4") == "value4"  # Should exist

    @pytest.mark.asyncio
    async def test_cache_stats(self):
        """Test cache statistics tracking."""
        from src.utils.cache import LRUCache

        cache = LRUCache(max_size=100, default_ttl=60.0)

        await cache.set("key1", "value1")
        await cache.get("key1")  # Hit
        await cache.get("key2")  # Miss
        await cache.get("key1")  # Hit

        stats = cache.stats()
        assert stats["hits"] == 2
        assert stats["misses"] == 1
        assert stats["size"] == 1

    @pytest.mark.asyncio
    async def test_cache_clear(self):
        """Test cache clearing."""
        from src.utils.cache import LRUCache

        cache = LRUCache(max_size=100, default_ttl=60.0)

        await cache.set("key1", "value1")
        await cache.set("key2", "value2")

        cleared = await cache.clear()
        assert cleared == 2

        assert await cache.get("key1") is None
        assert await cache.get("key2") is None
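

# LRUCache comes from src.utils.cache, which is not shown in this diff. One
# way to satisfy the behavior asserted above (TTL expiry, LRU eviction,
# hit/miss stats) is sketched below; the real class presumably also guards
# its state with an asyncio.Lock, which is omitted here for brevity.

import time
from collections import OrderedDict


class _SketchLRUCache:
    """Minimal async LRU cache with per-entry TTL (illustration only)."""

    def __init__(self, max_size: int = 100, default_ttl: float = 60.0):
        self._entries = OrderedDict()  # key -> (value, expires_at)
        self._max_size = max_size
        self._default_ttl = default_ttl
        self._hits = 0
        self._misses = 0

    async def set(self, key, value, ttl=None):
        self._entries.pop(key, None)
        self._entries[key] = (value, time.monotonic() + (ttl or self._default_ttl))
        if len(self._entries) > self._max_size:
            self._entries.popitem(last=False)  # evict the least recently used key

    async def get(self, key):
        entry = self._entries.get(key)
        if entry is None or entry[1] < time.monotonic():
            self._entries.pop(key, None)  # drop the expired entry, if any
            self._misses += 1
            return None
        self._entries.move_to_end(key)  # mark as most recently used
        self._hits += 1
        return entry[0]

    async def clear(self) -> int:
        count = len(self._entries)
        self._entries.clear()
        return count

    def stats(self) -> dict:
        return {"hits": self._hits, "misses": self._misses, "size": len(self._entries)}

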
# ============================================================
# Monitoring Module Tests
# ============================================================


class TestMonitoring:
    """Tests for the monitoring utilities."""

    def test_performance_metrics(self):
        """Test performance metrics tracking."""
        from src.utils.monitoring import PerformanceMetrics
        import time

        metrics = PerformanceMetrics(name="test_operation")
        time.sleep(0.01)  # Small delay
        metrics.finish(success=True)

        assert metrics.success
        assert metrics.duration_ms > 0
        assert metrics.duration_ms < 1000  # Should be fast

    def test_measure_sync_context_manager(self):
        """Test the synchronous measurement context manager."""
        from src.utils.monitoring import measure_sync
        import time

        with measure_sync("test_op", custom_field="value") as metrics:
            time.sleep(0.01)

        assert metrics.duration_ms > 0
        assert metrics.metadata["custom_field"] == "value"

    @pytest.mark.asyncio
    async def test_measure_async_context_manager(self):
        """Test the async measurement context manager."""
        from src.utils.monitoring import measure_async

        async with measure_async("async_op") as metrics:
            await asyncio.sleep(0.01)

        assert metrics.duration_ms > 0
        assert metrics.success

    @pytest.mark.asyncio
    async def test_track_performance_decorator(self):
        """Test the performance tracking decorator."""
        from src.utils.monitoring import track_performance

        call_count = 0

        @track_performance("tracked_function")
        async def tracked_func():
            nonlocal call_count
            call_count += 1
            return "result"

        result = await tracked_func()
        assert result == "result"
        assert call_count == 1

    def test_health_status(self):
        """Test the health status structure."""
        from src.utils.monitoring import HealthStatus

        status = HealthStatus(healthy=True)

        status.add_check("database", True, "Connected")
        status.add_check("api", False, "Timeout")

        assert not status.healthy  # Should be unhealthy due to the failing API check
        assert status.checks["database"]["healthy"]
        assert not status.checks["api"]["healthy"]
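

# PerformanceMetrics and measure_sync come from src.utils.monitoring, which
# this diff does not include. A compact sketch of the contract the tests rely
# on follows; the attribute names mirror the assertions above, and everything
# beyond them is assumed.

import time
from contextlib import contextmanager


class _SketchPerformanceMetrics:
    """Timing record exposing the fields the tests assert on (illustration only)."""

    def __init__(self, name: str, **metadata):
        self.name = name
        self.metadata = metadata
        self.success = False
        self.duration_ms = 0.0
        self._start = time.monotonic()

    def finish(self, success: bool = True):
        self.duration_ms = (time.monotonic() - self._start) * 1000
        self.success = success


@contextmanager
def _sketch_measure_sync(name: str, **metadata):
    """Yield a metrics object, finishing it when the `with` block exits."""
    metrics = _SketchPerformanceMetrics(name, **metadata)
    try:
        yield metrics
        metrics.finish(success=True)
    except Exception:
        metrics.finish(success=False)
        raise

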
# ============================================================
# Integration Tests (require environment setup)
# ============================================================


@pytest.mark.integration
class TestIntegration:
    """Integration tests that require actual services."""

    @pytest.mark.asyncio
    async def test_database_connection(self):
        """Test an actual database connection (skipped if MongoDB is not configured)."""
        from dotenv import load_dotenv
        load_dotenv()

        mongodb_uri = os.getenv("MONGODB_URI")
        if not mongodb_uri:
            pytest.skip("MONGODB_URI not set")

        from src.database.db_handler import DatabaseHandler
        handler = DatabaseHandler(mongodb_uri)

        connected = await handler.ensure_connected()
        assert connected

        await handler.close()
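

# Integration tests are opt-in. Assuming the `integration` marker is
# registered in the pytest configuration (pytest.ini or pyproject.toml),
# typical invocations would be:
#
#   pytest -m "not integration"   # unit tests only
#   pytest -m integration         # integration tests (needs MONGODB_URI set)

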
# ============================================================
# Run tests
# ============================================================

if __name__ == "__main__":
    pytest.main([__file__, "-v", "--tb=short"])