diff --git a/discord/cog.py b/discord/cog.py index 7fc26f35..18670162 100644 --- a/discord/cog.py +++ b/discord/cog.py @@ -1,838 +1,837 @@ """ The MIT License (MIT) Copyright (c) 2015-present Rapptz Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ from __future__ import annotations -import inspect import importlib +import inspect import sys -import discord.utils import types -from . import errors -from .commands import SlashCommand, UserCommand, MessageCommand, ApplicationCommand, SlashCommandGroup +from typing import Any, Callable, Mapping, ClassVar, Dict, Generator, List, Optional, TYPE_CHECKING, Tuple, TypeVar, \ + Type -from typing import Any, Callable, Mapping, ClassVar, Dict, Generator, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Type - -from .commands.commands import _BaseCommand +import discord.utils +from . 
import errors +from .commands.core import ApplicationCommand, _BaseCommand if TYPE_CHECKING: from .commands import ApplicationContext, ApplicationCommand __all__ = ( 'CogMeta', 'Cog', 'CogMixin', ) CogT = TypeVar('CogT', bound='Cog') FuncT = TypeVar('FuncT', bound=Callable[..., Any]) MISSING: Any = discord.utils.MISSING def _is_submodule(parent: str, child: str) -> bool: return parent == child or child.startswith(parent + ".") class CogMeta(type): """A metaclass for defining a cog. Note that you should probably not use this directly. It is exposed purely for documentation purposes along with making custom metaclasses to intermix with other metaclasses such as the :class:`abc.ABCMeta` metaclass. For example, to create an abstract cog mixin class, the following would be done. .. code-block:: python3 import abc class CogABCMeta(commands.CogMeta, abc.ABCMeta): pass class SomeMixin(metaclass=abc.ABCMeta): pass class SomeCogMixin(SomeMixin, commands.Cog, metaclass=CogABCMeta): pass .. note:: When passing an attribute of a metaclass that is documented below, note that you must pass it as a keyword-only argument to the class creation like the following example: .. code-block:: python3 class MyCog(commands.Cog, name='My Cog'): pass Attributes ----------- name: :class:`str` The cog name. By default, it is the name of the class with no modification. description: :class:`str` The cog description. By default, it is the cleaned docstring of the class. .. versionadded:: 1.6 command_attrs: :class:`dict` A list of attributes to apply to every command inside this cog. The dictionary is passed into the :class:`Command` options at ``__init__``. If you specify attributes inside the command attribute in the class, it will override the one specified inside this attribute. For example: .. 
code-block:: python3 class MyCog(commands.Cog, command_attrs=dict(hidden=True)): @commands.command() async def foo(self, ctx): pass # hidden -> True @commands.command(hidden=False) async def bar(self, ctx): pass # hidden -> False """ __cog_name__: str __cog_settings__: Dict[str, Any] __cog_commands__: List[ApplicationCommand] __cog_listeners__: List[Tuple[str, str]] def __new__(cls: Type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta: name, bases, attrs = args attrs['__cog_name__'] = kwargs.pop('name', name) attrs['__cog_settings__'] = kwargs.pop('command_attrs', {}) description = kwargs.pop('description', None) if description is None: description = inspect.cleandoc(attrs.get('__doc__', '')) attrs['__cog_description__'] = description commands = {} listeners = {} no_bot_cog = 'Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})' new_cls = super().__new__(cls, name, bases, attrs, **kwargs) valid_commands = [(c for i, c in j.__dict__.items() if isinstance(c, _BaseCommand)) for j in reversed(new_cls.__mro__)] if any(isinstance(i, ApplicationCommand) for i in valid_commands) and any(not isinstance(i, _BaseCommand) for i in valid_commands): _filter = ApplicationCommand else: _filter = _BaseCommand for base in reversed(new_cls.__mro__): for elem, value in base.__dict__.items(): if elem in commands: del commands[elem] if elem in listeners: del listeners[elem] try: if getattr(value, "parent") is not None and isinstance(value, ApplicationCommand): # Skip commands if they are a part of a group continue except AttributeError: pass is_static_method = isinstance(value, staticmethod) if is_static_method: value = value.__func__ if isinstance(value, _filter): if is_static_method: raise TypeError(f'Command in method {base}.{elem!r} must not be staticmethod.') if elem.startswith(('cog_', 'bot_')): raise TypeError(no_bot_cog.format(base, elem)) commands[elem] = value elif inspect.iscoroutinefunction(value): try: getattr(value, '__cog_listener__') 
except AttributeError: continue else: if elem.startswith(('cog_', 'bot_')): raise TypeError(no_bot_cog.format(base, elem)) listeners[elem] = value new_cls.__cog_commands__ = list(commands.values()) listeners_as_list = [] for listener in listeners.values(): for listener_name in listener.__cog_listener_names__: # I use __name__ instead of just storing the value so I can inject # the self attribute when the time comes to add them to the bot listeners_as_list.append((listener_name, listener.__name__)) new_cls.__cog_listeners__ = listeners_as_list cmd_attrs = new_cls.__cog_settings__ # Either update the command with the cog provided defaults or copy it. # r.e type ignore, type-checker complains about overriding a ClassVar new_cls.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in new_cls.__cog_commands__) # type: ignore lookup = { cmd.qualified_name: cmd for cmd in new_cls.__cog_commands__ } # Update the Command instances dynamically as well for command in new_cls.__cog_commands__: if not isinstance(command, ApplicationCommand): setattr(new_cls, command.callback.__name__, command) parent = command.parent if parent is not None: # Get the latest parent reference parent = lookup[parent.qualified_name] # type: ignore # Update our parent's reference to our self parent.remove_command(command.name) # type: ignore parent.add_command(command) # type: ignore return new_cls def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args) @classmethod def qualified_name(cls) -> str: return cls.__cog_name__ def _cog_special_method(func: FuncT) -> FuncT: func.__cog_special_method__ = None return func class Cog(metaclass=CogMeta): """The base class that all cogs must inherit from. A cog is a collection of commands, listeners, and optional state to help group commands together. More information on them can be found on the :ref:`ext_commands_cogs` page. When inheriting from this class, the options shown in :class:`CogMeta` are equally valid here. 
""" __cog_name__: ClassVar[str] __cog_settings__: ClassVar[Dict[str, Any]] __cog_commands__: ClassVar[List[ApplicationCommand]] __cog_listeners__: ClassVar[List[Tuple[str, str]]] def __new__(cls: Type[CogT], *args: Any, **kwargs: Any) -> CogT: # For issue 426, we need to store a copy of the command objects # since we modify them to inject `self` to them. # To do this, we need to interfere with the Cog creation process. self = super().__new__(cls) return self def get_commands(self) -> List[ApplicationCommand]: r""" Returns -------- List[:class:`.ApplicationCommand`] A :class:`list` of :class:`.ApplicationCommand`\s that are defined inside this cog. .. note:: This does not include subcommands. """ return [c for c in self.__cog_commands__ if isinstance(c, ApplicationCommand) and c.parent is None] @property def qualified_name(self) -> str: """:class:`str`: Returns the cog's specified name, not the class name.""" return self.__cog_name__ @property def description(self) -> str: """:class:`str`: Returns the cog's description, typically the cleaned docstring.""" return self.__cog_description__ @description.setter def description(self, description: str) -> None: self.__cog_description__ = description def walk_commands(self) -> Generator[ApplicationCommand, None, None]: """An iterator that recursively walks through this cog's commands and subcommands. Yields ------ Union[:class:`.Command`, :class:`.Group`] A command or group from the cog. """ for command in self.__cog_commands__: if command.parent is None: yield command def get_listeners(self) -> List[Tuple[str, Callable[..., Any]]]: """Returns a :class:`list` of (name, function) listener pairs that are defined in this cog. Returns -------- List[Tuple[:class:`str`, :ref:`coroutine `]] The listeners defined in this cog. 
""" return [(name, getattr(self, method_name)) for name, method_name in self.__cog_listeners__] @classmethod def _get_overridden_method(cls, method: FuncT) -> Optional[FuncT]: """Return None if the method is not overridden. Otherwise returns the overridden method.""" return getattr(getattr(method, "__func__", method), '__cog_special_method__', method) @classmethod def listener(cls, name: str = MISSING) -> Callable[[FuncT], FuncT]: """A decorator that marks a function as a listener. This is the cog equivalent of :meth:`.Bot.listen`. Parameters ------------ name: :class:`str` The name of the event being listened to. If not provided, it defaults to the function's name. Raises -------- TypeError The function is not a coroutine function or a string was not passed as the name. """ if name is not MISSING and not isinstance(name, str): raise TypeError(f'Cog.listener expected str but received {name.__class__.__name__!r} instead.') def decorator(func: FuncT) -> FuncT: actual = func if isinstance(actual, staticmethod): actual = actual.__func__ if not inspect.iscoroutinefunction(actual): raise TypeError('Listener function must be a coroutine function.') actual.__cog_listener__ = True to_assign = name or actual.__name__ try: actual.__cog_listener_names__.append(to_assign) except AttributeError: actual.__cog_listener_names__ = [to_assign] # we have to return `func` instead of `actual` because # we need the type to be `staticmethod` for the metaclass # to pick it up but the metaclass unfurls the function and # thus the assignments need to be on the actual function return func return decorator def has_error_handler(self) -> bool: """:class:`bool`: Checks whether the cog has an error handler. .. versionadded:: 1.7 """ return not hasattr(self.cog_command_error.__func__, '__cog_special_method__') @_cog_special_method def cog_unload(self) -> None: """A special method that is called when the cog gets removed. This function **cannot** be a coroutine. It must be a regular function. 
Subclasses must replace this if they want special unloading behaviour. """ pass @_cog_special_method def bot_check_once(self, ctx: ApplicationContext) -> bool: """A special method that registers as a :meth:`.Bot.check_once` check. This function **can** be a coroutine and must take a sole parameter, ``ctx``, to represent the :class:`.Context`. """ return True @_cog_special_method def bot_check(self, ctx: ApplicationContext) -> bool: """A special method that registers as a :meth:`.Bot.check` check. This function **can** be a coroutine and must take a sole parameter, ``ctx``, to represent the :class:`.Context`. """ return True @_cog_special_method def cog_check(self, ctx: ApplicationContext) -> bool: """A special method that registers as a :func:`~discord.ext.commands.check` for every command and subcommand in this cog. This function **can** be a coroutine and must take a sole parameter, ``ctx``, to represent the :class:`.Context`. """ return True @_cog_special_method async def cog_command_error(self, ctx: ApplicationContext, error: Exception) -> None: """A special method that is called whenever an error is dispatched inside this cog. This is similar to :func:`.on_command_error` except only applying to the commands inside this cog. This **must** be a coroutine. Parameters ----------- ctx: :class:`.Context` The invocation context where the error happened. error: :class:`CommandError` The error that happened. """ pass @_cog_special_method async def cog_before_invoke(self, ctx: ApplicationContext) -> None: """A special method that acts as a cog local pre-invoke hook. This is similar to :meth:`.Command.before_invoke`. This **must** be a coroutine. Parameters ----------- ctx: :class:`.Context` The invocation context. """ pass @_cog_special_method async def cog_after_invoke(self, ctx: ApplicationContext) -> None: """A special method that acts as a cog local post-invoke hook. This is similar to :meth:`.Command.after_invoke`. This **must** be a coroutine. 
Parameters ----------- ctx: :class:`.Context` The invocation context. """ pass def _inject(self: CogT, bot) -> CogT: cls = self.__class__ # realistically, the only thing that can cause loading errors # is essentially just the command loading, which raises if there are # duplicates. When this condition is met, we want to undo all what # we've added so far for some form of atomic loading. for index, command in enumerate(self.__cog_commands__): command._set_cog(self) if not isinstance(command, ApplicationCommand): if command.parent is None: try: bot.add_command(command) except Exception as e: # undo our additions for to_undo in self.__cog_commands__[:index]: if to_undo.parent is None: bot.remove_command(to_undo.name) raise e else: bot.add_application_command(command) # check if we're overriding the default if cls.bot_check is not Cog.bot_check: bot.add_check(self.bot_check) if cls.bot_check_once is not Cog.bot_check_once: bot.add_check(self.bot_check_once, call_once=True) # while Bot.add_listener can raise if it's not a coroutine, # this precondition is already met by the listener decorator # already, thus this should never raise. # Outside of, memory errors and the like... 
for name, method_name in self.__cog_listeners__: bot.add_listener(getattr(self, method_name), name) return self def _eject(self, bot) -> None: cls = self.__class__ try: for command in self.__cog_commands__: if isinstance(command, ApplicationCommand): bot.remove_application_command(command) else: if command.parent is None: bot.remove_command(command.name) for _, method_name in self.__cog_listeners__: bot.remove_listener(getattr(self, method_name)) if cls.bot_check is not Cog.bot_check: bot.remove_check(self.bot_check) if cls.bot_check_once is not Cog.bot_check_once: bot.remove_check(self.bot_check_once, call_once=True) finally: try: self.cog_unload() except Exception: pass class CogMixin: def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.__cogs: Dict[str, Cog] = {} self.__extensions: Dict[str, types.ModuleType] = {} def add_cog(self, cog: Cog, *, override: bool = False) -> None: """Adds a "cog" to the bot. A cog is a class that has its own event listeners and commands. .. versionchanged:: 2.0 :exc:`.ClientException` is raised when a cog with the same name is already loaded. Parameters ----------- cog: :class:`.Cog` The cog to register to the bot. override: :class:`bool` If a previously loaded cog with the same name should be ejected instead of raising an error. .. versionadded:: 2.0 Raises ------- TypeError The cog does not inherit from :class:`.Cog`. CommandError An error happened during loading. .ClientException A cog with the same name is already loaded. """ if not isinstance(cog, Cog): raise TypeError('cogs must derive from Cog') cog_name = cog.__cog_name__ existing = self.__cogs.get(cog_name) if existing is not None: if not override: raise discord.ClientException(f'Cog named {cog_name!r} already loaded') self.remove_cog(cog_name) cog = cog._inject(self) self.__cogs[cog_name] = cog def get_cog(self, name: str) -> Optional[Cog]: """Gets the cog instance requested. If the cog is not found, ``None`` is returned instead. 
Parameters ----------- name: :class:`str` The name of the cog you are requesting. This is equivalent to the name passed via keyword argument in class creation or the class name if unspecified. Returns -------- Optional[:class:`Cog`] The cog that was requested. If not found, returns ``None``. """ return self.__cogs.get(name) def remove_cog(self, name: str) -> Optional[Cog]: """Removes a cog from the bot and returns it. All registered commands and event listeners that the cog has registered will be removed as well. If no cog is found then this method has no effect. Parameters ----------- name: :class:`str` The name of the cog to remove. Returns ------- Optional[:class:`.Cog`] The cog that was removed. ``None`` if not found. """ cog = self.__cogs.pop(name, None) if cog is None: return if hasattr(self, "_help_command"): help_command = self._help_command if help_command and help_command.cog is cog: help_command.cog = None cog._eject(self) return cog @property def cogs(self) -> Mapping[str, Cog]: """Mapping[:class:`str`, :class:`Cog`]: A read-only mapping of cog name to cog.""" return types.MappingProxyType(self.__cogs) # extensions def _remove_module_references(self, name: str) -> None: # find all references to the module # remove the cogs registered from the module for cogname, cog in self.__cogs.copy().items(): if _is_submodule(name, cog.__module__): self.remove_cog(cogname) # remove all the commands from the module for cmd in self.all_commands.copy().values(): if cmd.module is not None and _is_submodule(name, cmd.module): # if isinstance(cmd, GroupMixin): # cmd.recursively_remove_all_commands() self.remove_command(cmd.name) # remove all the listeners from the module for event_list in self.extra_events.copy().values(): remove = [] for index, event in enumerate(event_list): if event.__module__ is not None and _is_submodule(name, event.__module__): remove.append(index) for index in reversed(remove): del event_list[index] def _call_module_finalizers(self, lib: 
types.ModuleType, key: str) -> None: try: func = getattr(lib, 'teardown') except AttributeError: pass else: try: func(self) except Exception: pass finally: self.__extensions.pop(key, None) sys.modules.pop(key, None) name = lib.__name__ for module in list(sys.modules.keys()): if _is_submodule(name, module): del sys.modules[module] def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None: # precondition: key not in self.__extensions lib = importlib.util.module_from_spec(spec) sys.modules[key] = lib try: spec.loader.exec_module(lib) # type: ignore except Exception as e: del sys.modules[key] raise errors.ExtensionFailed(key, e) from e try: setup = getattr(lib, 'setup') except AttributeError: del sys.modules[key] raise errors.NoEntryPointError(key) try: setup(self) except Exception as e: del sys.modules[key] self._remove_module_references(lib.__name__) self._call_module_finalizers(lib, key) raise errors.ExtensionFailed(key, e) from e else: self.__extensions[key] = lib def _resolve_name(self, name: str, package: Optional[str]) -> str: try: return importlib.util.resolve_name(name, package) except ImportError: raise errors.ExtensionNotFound(name) def load_extension(self, name: str, *, package: Optional[str] = None) -> None: """Loads an extension. An extension is a python module that contains commands, cogs, or listeners. An extension must have a global function, ``setup`` defined as the entry point on what to do when the extension is loaded. This entry point must have a single argument, the ``bot``. Parameters ------------ name: :class:`str` The extension name to load. It must be dot separated like regular Python imports if accessing a sub-module. e.g. ``foo.test`` if you want to import ``foo/test.py``. package: Optional[:class:`str`] The package name to resolve relative imports with. This is required when loading an extension using a relative path, e.g ``.foo.test``. Defaults to ``None``. .. 
versionadded:: 1.7 Raises -------- ExtensionNotFound The extension could not be imported. This is also raised if the name of the extension could not be resolved using the provided ``package`` parameter. ExtensionAlreadyLoaded The extension is already loaded. NoEntryPointError The extension does not have a setup function. ExtensionFailed The extension or its setup function had an execution error. """ name = self._resolve_name(name, package) if name in self.__extensions: raise errors.ExtensionAlreadyLoaded(name) spec = importlib.util.find_spec(name) if spec is None: raise errors.ExtensionNotFound(name) self._load_from_module_spec(spec, name) def unload_extension(self, name: str, *, package: Optional[str] = None) -> None: """Unloads an extension. When the extension is unloaded, all commands, listeners, and cogs are removed from the bot and the module is un-imported. The extension can provide an optional global function, ``teardown``, to do miscellaneous clean-up if necessary. This function takes a single parameter, the ``bot``, similar to ``setup`` from :meth:`~.Bot.load_extension`. Parameters ------------ name: :class:`str` The extension name to unload. It must be dot separated like regular Python imports if accessing a sub-module. e.g. ``foo.test`` if you want to import ``foo/test.py``. package: Optional[:class:`str`] The package name to resolve relative imports with. This is required when unloading an extension using a relative path, e.g ``.foo.test``. Defaults to ``None``. .. versionadded:: 1.7 Raises ------- ExtensionNotFound The name of the extension could not be resolved using the provided ``package`` parameter. ExtensionNotLoaded The extension was not loaded. 
""" name = self._resolve_name(name, package) lib = self.__extensions.get(name) if lib is None: raise errors.ExtensionNotLoaded(name) self._remove_module_references(lib.__name__) self._call_module_finalizers(lib, name) def reload_extension(self, name: str, *, package: Optional[str] = None) -> None: """Atomically reloads an extension. This replaces the extension with the same extension, only refreshed. This is equivalent to a :meth:`unload_extension` followed by a :meth:`load_extension` except done in an atomic way. That is, if an operation fails mid-reload then the bot will roll-back to the prior working state. Parameters ------------ name: :class:`str` The extension name to reload. It must be dot separated like regular Python imports if accessing a sub-module. e.g. ``foo.test`` if you want to import ``foo/test.py``. package: Optional[:class:`str`] The package name to resolve relative imports with. This is required when reloading an extension using a relative path, e.g ``.foo.test``. Defaults to ``None``. .. versionadded:: 1.7 Raises ------- ExtensionNotLoaded The extension was not loaded. ExtensionNotFound The extension could not be imported. This is also raised if the name of the extension could not be resolved using the provided ``package`` parameter. NoEntryPointError The extension does not have a setup function. ExtensionFailed The extension setup function had an execution error. """ name = self._resolve_name(name, package) lib = self.__extensions.get(name) if lib is None: raise errors.ExtensionNotLoaded(name) # get the previous module states from sys modules modules = { name: module for name, module in sys.modules.items() if _is_submodule(lib.__name__, name) } try: # Unload and then load the module... 
self._remove_module_references(lib.__name__) self._call_module_finalizers(lib, name) self.load_extension(name) except Exception: # if the load failed, the remnants should have been # cleaned from the load_extension function call # so let's load it from our old compiled library. lib.setup(self) # type: ignore self.__extensions[name] = lib # revert sys.modules back to normal and raise back to caller sys.modules.update(modules) raise @property def extensions(self) -> Mapping[str, types.ModuleType]: """Mapping[:class:`str`, :class:`py:types.ModuleType`]: A read-only mapping of extension name to extension.""" return types.MappingProxyType(self.__extensions) diff --git a/discord/commands/__init__.py b/discord/commands/__init__.py index 66628133..ef15ffd7 100644 --- a/discord/commands/__init__.py +++ b/discord/commands/__init__.py @@ -1,29 +1,29 @@ """ The MIT License (MIT) Copyright (c) 2015-2021 Rapptz Copyright (c) 2021-present Pycord Development Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
""" -from .commands import * from .context import * +from .core import * from .errors import * from .permissions import * diff --git a/discord/commands/context.py b/discord/commands/context.py index b1fb8c96..2d057ab8 100644 --- a/discord/commands/context.py +++ b/discord/commands/context.py @@ -1,268 +1,267 @@ """ The MIT License (MIT) Copyright (c) 2015-2021 Rapptz Copyright (c) 2021-present Pycord Development Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
""" from __future__ import annotations -from typing import Callable, TYPE_CHECKING, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Optional, TypeVar, Union import discord.abc if TYPE_CHECKING: from typing_extensions import ParamSpec import discord from discord import Bot from discord.state import ConnectionState - from .commands import ApplicationCommand, Option + from .core import ApplicationCommand, Option from ..cog import Cog from ..webhook import WebhookMessage from typing import Callable from ..guild import Guild from ..interactions import Interaction, InteractionResponse from ..member import Member from ..message import Message from ..user import User from ..utils import cached_property -from ..webhook import Webhook T = TypeVar('T') CogT = TypeVar('CogT', bound="Cog") if TYPE_CHECKING: P = ParamSpec('P') else: P = TypeVar('P') __all__ = ("ApplicationContext", "AutocompleteContext") class ApplicationContext(discord.abc.Messageable): """Represents a Discord application command interaction context. This class is not created manually and is instead passed to application commands as the first parameter. .. versionadded:: 2.0 Attributes ----------- bot: :class:`.Bot` The bot that the command belongs to. interaction: :class:`.Interaction` The interaction object that invoked the command. command: :class:`.ApplicationCommand` The command that this context belongs to. 
""" def __init__(self, bot: Bot, interaction: Interaction): self.bot = bot self.interaction = interaction # below attributes will be set after initialization self.command: ApplicationCommand = None # type: ignore self.focused: Option = None # type: ignore self.value: str = None # type: ignore self.options: dict = None # type: ignore self._state: ConnectionState = self.interaction._state async def _get_channel(self) -> discord.abc.Messageable: return self.channel async def invoke(self, command: ApplicationCommand[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T: r"""|coro| Calls a command with the arguments given. This is useful if you want to just call the callback that a :class:`.ApplicationCommand` holds internally. .. note:: This does not handle converters, checks, cooldowns, pre-invoke, or after-invoke hooks in any matter. It calls the internal callback directly as-if it was a regular function. You must take care in passing the proper arguments when using this function. Parameters ----------- command: :class:`.ApplicationCommand` The command that is going to be called. \*args The arguments to use. \*\*kwargs The keyword arguments to use. Raises ------- TypeError The command argument to invoke is missing. 
""" return await command(self, *args, **kwargs) @cached_property def channel(self): return self.interaction.channel @cached_property def channel_id(self) -> Optional[int]: return self.interaction.channel_id @cached_property def guild(self) -> Optional[Guild]: return self.interaction.guild @cached_property def guild_id(self) -> Optional[int]: return self.interaction.guild_id @cached_property def locale(self) -> Optional[str]: return self.interaction.locale @cached_property def guild_locale(self) -> Optional[str]: return self.interaction.guild_locale @cached_property def me(self) -> Union[Member, User]: return self.guild.me if self.guild is not None else self.bot.user @cached_property def message(self) -> Optional[Message]: return self.interaction.message @cached_property def user(self) -> Optional[Union[Member, User]]: return self.interaction.user @cached_property def author(self) -> Optional[Union[Member, User]]: return self.user @property def voice_client(self): if self.guild is None: return None return self.guild.voice_client @cached_property def response(self) -> InteractionResponse: return self.interaction.response @property def respond(self) -> Callable[..., Union[Interaction, WebhookMessage]]: """Callable[..., Union[:class:`~.Interaction`, :class:`~.Webhook`]]: Sends either a response or a followup response depending if the interaction has been responded to yet or not.""" if not self.response.is_done(): return self.interaction.response.send_message # self.response else: return self.followup.send # self.send_followup @property def send_response(self): if not self.response.is_done(): return self.interaction.response.send_message else: raise RuntimeError( f"Interaction was already issued a response. Try using {type(self).__name__}.send_followup() instead." ) @property def send_followup(self): if self.response.is_done(): return self.followup.send else: raise RuntimeError( f"Interaction was not yet issued a response. 
Try using {type(self).__name__}.respond() first." ) @property def defer(self): return self.interaction.response.defer @property def followup(self): return self.interaction.followup async def delete(self): """Calls :attr:`~discord.commands.ApplicationContext.respond`. If the response is done, then calls :attr:`~discord.commands.ApplicationContext.respond` first.""" if not self.response.is_done(): await self.defer() return await self.interaction.delete_original_message() @property def edit(self): return self.interaction.edit_original_message @property def cog(self) -> Optional[Cog]: """Optional[:class:`.Cog`]: Returns the cog associated with this context's command. ``None`` if it does not exist.""" if self.command is None: return None return self.command.cog class AutocompleteContext: """Represents context for a slash command's option autocomplete. This class is not created manually and is instead passed to an Option's autocomplete callback. .. versionadded:: 2.0 Attributes ----------- bot: :class:`.Bot` The bot that the command belongs to. interaction: :class:`.Interaction` The interaction object that invoked the autocomplete. command: :class:`.ApplicationCommand` The command that this context belongs to. focused: :class:`.Option` The option the user is currently typing. value: :class:`.str` The content of the focused option. options :class:`.dict` A name to value mapping of the options that the user has selected before this option. """ __slots__ = ("bot", "interaction", "command", "focused", "value", "options") def __init__(self, bot: Bot, interaction: Interaction) -> None: self.bot = bot self.interaction = interaction self.command: ApplicationCommand = None # type: ignore self.focused: Option = None # type: ignore self.value: str = None # type: ignore self.options: dict = None # type: ignore @property def cog(self) -> Optional[Cog]: """Optional[:class:`.Cog`]: Returns the cog associated with this context's command. 
``None`` if it does not exist.""" if self.command is None: return None return self.command.cog diff --git a/discord/commands/commands.py b/discord/commands/core.py similarity index 99% rename from discord/commands/commands.py rename to discord/commands/core.py index 34da68d3..10200da3 100644 --- a/discord/commands/commands.py +++ b/discord/commands/core.py @@ -1,1453 +1,1452 @@ """ The MIT License (MIT) Copyright (c) 2015-2021 Rapptz Copyright (c) 2021-present Pycord Development Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
""" from __future__ import annotations import asyncio import datetime import functools import inspect import re import types from collections import OrderedDict from typing import Any, Callable, Dict, Generator, Generic, List, Literal, Optional, Type, TypeVar, Union, TYPE_CHECKING from .context import ApplicationContext, AutocompleteContext from .errors import ApplicationCommandError, CheckFailure, ApplicationCommandInvokeError from .permissions import CommandPermission from ..enums import SlashCommandOptionType, ChannelType from ..errors import ValidationError, ClientException from ..member import Member from ..message import Message from ..user import User from ..utils import find, get_or_fetch, async_all __all__ = ( "_BaseCommand", "ApplicationCommand", "SlashCommand", "Option", "OptionChoice", "option", "slash_command", "application_command", "user_command", "message_command", "command", "SlashCommandGroup", "ContextMenuCommand", "UserCommand", "MessageCommand", ) if TYPE_CHECKING: from typing_extensions import ParamSpec from ..cog import Cog - from ..interactions import Interaction T = TypeVar('T') CogT = TypeVar('CogT', bound='Cog') if TYPE_CHECKING: P = ParamSpec('P') else: P = TypeVar('P') def wrap_callback(coro): @functools.wraps(coro) async def wrapped(*args, **kwargs): try: ret = await coro(*args, **kwargs) except ApplicationCommandError: raise except asyncio.CancelledError: return except Exception as exc: raise ApplicationCommandInvokeError(exc) from exc return ret return wrapped def hooked_wrapped_callback(command, ctx, coro): @functools.wraps(coro) async def wrapped(arg): try: ret = await coro(arg) except ApplicationCommandError: raise except asyncio.CancelledError: return except Exception as exc: raise ApplicationCommandInvokeError(exc) from exc finally: if hasattr(command, '_max_concurrency') and command._max_concurrency is not None: await command._max_concurrency.release(ctx) await command.call_after_hooks(ctx) return ret return wrapped class 
_BaseCommand: __slots__ = () class ApplicationCommand(_BaseCommand, Generic[CogT, P, T]): cog = None def __init__(self, func: Callable, **kwargs) -> None: from ..ext.commands.cooldowns import CooldownMapping, BucketType, MaxConcurrency try: cooldown = func.__commands_cooldown__ except AttributeError: cooldown = kwargs.get('cooldown') if cooldown is None: buckets = CooldownMapping(cooldown, BucketType.default) elif isinstance(cooldown, CooldownMapping): buckets = cooldown else: raise TypeError("Cooldown must be a an instance of CooldownMapping or None.") self._buckets: CooldownMapping = buckets try: max_concurrency = func.__commands_max_concurrency__ except AttributeError: max_concurrency = kwargs.get('max_concurrency') self._max_concurrency: Optional[MaxConcurrency] = max_concurrency def __repr__(self): return f"" def __eq__(self, other) -> bool: if hasattr(self, "id") and hasattr(other, "id"): check = self.id == other.id else: check = ( self.name == other.name and self.guild_ids == self.guild_ids ) return ( isinstance(other, self.__class__) and self.parent == other.parent and check ) async def __call__(self, ctx, *args, **kwargs): """|coro| Calls the command's callback. This method bypasses all checks that a command has and does not convert the arguments beforehand, so take care to pass the correct arguments in. 
""" return await self.callback(ctx, *args, **kwargs) def _prepare_cooldowns(self, ctx: ApplicationContext): if self._buckets.valid: current = datetime.datetime.now().timestamp() bucket = self._buckets.get_bucket(ctx, current) # type: ignore (ctx instead of non-existent message) if bucket is not None: retry_after = bucket.update_rate_limit(current) if retry_after: from ..ext.commands.errors import CommandOnCooldown raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore async def prepare(self, ctx: ApplicationContext) -> None: # This should be same across all 3 types ctx.command = self if not await self.can_run(ctx): raise CheckFailure(f'The check functions for the command {self.name} failed') if hasattr(self, "_max_concurrency"): if self._max_concurrency is not None: # For this application, context can be duck-typed as a Message await self._max_concurrency.acquire(ctx) # type: ignore (ctx instead of non-existent message) try: self._prepare_cooldowns(ctx) await self.call_before_hooks(ctx) except: if self._max_concurrency is not None: await self._max_concurrency.release(ctx) # type: ignore (ctx instead of non-existent message) raise def reset_cooldown(self, ctx: ApplicationContext) -> None: """Resets the cooldown on this command. Parameters ----------- ctx: :class:`.ApplicationContext` The invocation context to reset the cooldown under. """ if self._buckets.valid: bucket = self._buckets.get_bucket(ctx) # type: ignore (ctx instead of non-existent message) bucket.reset() async def invoke(self, ctx: ApplicationContext) -> None: await self.prepare(ctx) injected = hooked_wrapped_callback(self, ctx, self._invoke) await injected(ctx) async def can_run(self, ctx: ApplicationContext) -> bool: if not await ctx.bot.can_run(ctx): raise CheckFailure(f'The global check functions for command {self.name} failed.') predicates = self.checks if not predicates: # since we have no checks, then we just return True. 
return True return await async_all(predicate(ctx) for predicate in predicates) # type: ignore async def dispatch_error(self, ctx: ApplicationContext, error: Exception) -> None: ctx.command_failed = True cog = self.cog try: coro = self.on_error except AttributeError: pass else: injected = wrap_callback(coro) if cog is not None: await injected(cog, ctx, error) else: await injected(ctx, error) try: if cog is not None: local = cog.__class__._get_overridden_method(cog.cog_command_error) if local is not None: wrapped = wrap_callback(local) await wrapped(ctx, error) finally: ctx.bot.dispatch('application_command_error', ctx, error) def _get_signature_parameters(self): return OrderedDict(inspect.signature(self.callback).parameters) def error(self, coro): """A decorator that registers a coroutine as a local error handler. A local error handler is an :func:`.on_command_error` event limited to a single command. However, the :func:`.on_command_error` is still invoked afterwards as the catch-all. Parameters ----------- coro: :ref:`coroutine ` The coroutine to register as the local error handler. Raises ------- TypeError The coroutine passed is not actually a coroutine. """ if not asyncio.iscoroutinefunction(coro): raise TypeError('The error handler must be a coroutine.') self.on_error = coro return coro def has_error_handler(self) -> bool: """:class:`bool`: Checks whether the command has an error handler registered. """ return hasattr(self, 'on_error') def before_invoke(self, coro): """A decorator that registers a coroutine as a pre-invoke hook. A pre-invoke hook is called directly before the command is called. This makes it a useful function to set up database connections or any type of set up required. This pre-invoke hook takes a sole parameter, a :class:`.Context`. See :meth:`.Bot.before_invoke` for more info. Parameters ----------- coro: :ref:`coroutine ` The coroutine to register as the pre-invoke hook. 
Raises ------- TypeError The coroutine passed is not actually a coroutine. """ if not asyncio.iscoroutinefunction(coro): raise TypeError('The pre-invoke hook must be a coroutine.') self._before_invoke = coro return coro def after_invoke(self, coro): """A decorator that registers a coroutine as a post-invoke hook. A post-invoke hook is called directly after the command is called. This makes it a useful function to clean-up database connections or any type of clean up required. This post-invoke hook takes a sole parameter, a :class:`.Context`. See :meth:`.Bot.after_invoke` for more info. Parameters ----------- coro: :ref:`coroutine ` The coroutine to register as the post-invoke hook. Raises ------- TypeError The coroutine passed is not actually a coroutine. """ if not asyncio.iscoroutinefunction(coro): raise TypeError('The post-invoke hook must be a coroutine.') self._after_invoke = coro return coro async def call_before_hooks(self, ctx: ApplicationContext) -> None: # now that we're done preparing we can call the pre-command hooks # first, call the command local hook: cog = self.cog if self._before_invoke is not None: # should be cog if @commands.before_invoke is used instance = getattr(self._before_invoke, '__self__', cog) # __self__ only exists for methods, not functions # however, if @command.before_invoke is used, it will be a function if instance: await self._before_invoke(instance, ctx) # type: ignore else: await self._before_invoke(ctx) # type: ignore # call the cog local hook if applicable: if cog is not None: hook = cog.__class__._get_overridden_method(cog.cog_before_invoke) if hook is not None: await hook(ctx) # call the bot global hook if necessary hook = ctx.bot._before_invoke if hook is not None: await hook(ctx) async def call_after_hooks(self, ctx: ApplicationContext) -> None: cog = self.cog if self._after_invoke is not None: instance = getattr(self._after_invoke, '__self__', cog) if instance: await self._after_invoke(instance, ctx) # type: ignore else: 
await self._after_invoke(ctx) # type: ignore # call the cog local hook if applicable: if cog is not None: hook = cog.__class__._get_overridden_method(cog.cog_after_invoke) if hook is not None: await hook(ctx) hook = ctx.bot._after_invoke if hook is not None: await hook(ctx) @property def cooldown(self): return self._buckets._cooldown @property def full_parent_name(self) -> str: """:class:`str`: Retrieves the fully qualified parent command name. This the base command name required to execute it. For example, in ``/one two three`` the parent name would be ``one two``. """ entries = [] command = self while command.parent is not None and hasattr(command.parent, "name"): command = command.parent entries.append(command.name) return ' '.join(reversed(entries)) @property def qualified_name(self) -> str: """:class:`str`: Retrieves the fully qualified command name. This is the full parent name with the command name as well. For example, in ``/one two three`` the qualified name would be ``one two three``. """ parent = self.full_parent_name if parent: return parent + ' ' + self.name else: return self.name def _set_cog(self, cog): self.cog = cog class SlashCommand(ApplicationCommand): r"""A class that implements the protocol for a slash command. These are not created manually, instead they are created via the decorator or functional interface. Attributes ----------- name: :class:`str` The name of the command. callback: :ref:`coroutine ` The coroutine that is executed when the command is called. description: Optional[:class:`str`] The description for the command. guild_ids: Optional[List[:class:`int`]] The ids of the guilds where this command will be registered. options: List[:class:`Option`] The parameters for this command. parent: Optional[:class:`SlashCommandGroup`] The parent group that this command belongs to. ``None`` if there isn't one. default_permission: :class:`bool` Whether the command is enabled by default when it is added to a guild. 
permissions: List[:class:`CommandPermission`] The permissions for this command. .. note:: If this is not empty then default_permissions will be set to False. cog: Optional[:class:`Cog`] The cog that this command belongs to. ``None`` if there isn't one. checks: List[Callable[[:class:`.ApplicationContext`], :class:`bool`]] A list of predicates that verifies if the command could be executed with the given :class:`.ApplicationContext` as the sole parameter. If an exception is necessary to be thrown to signal failure, then one inherited from :exc:`.CommandError` should be used. Note that if the checks fail then :exc:`.CheckFailure` exception is raised to the :func:`.on_application_command_error` event. cooldown: Optional[:class:`~discord.ext.commands.Cooldown`] The cooldown applied when the command is invoked. ``None`` if the command doesn't have a cooldown. .. versionadded:: 2.0 """ type = 1 def __new__(cls, *args, **kwargs) -> SlashCommand: self = super().__new__(cls) self.__original_kwargs__ = kwargs.copy() return self def __init__(self, func: Callable, *args, **kwargs) -> None: super().__init__(func, **kwargs) if not asyncio.iscoroutinefunction(func): raise TypeError("Callback must be a coroutine.") self.callback = func self.guild_ids: Optional[List[int]] = kwargs.get("guild_ids", None) name = kwargs.get("name") or func.__name__ validate_chat_input_name(name) self.name: str = name self.id = None description = kwargs.get("description") or ( inspect.cleandoc(func.__doc__).splitlines()[0] if func.__doc__ is not None else "No description provided" ) validate_chat_input_description(description) self.description: str = description self.parent = kwargs.get('parent') self.attached_to_group: bool = False self.cog = None params = self._get_signature_parameters() if (kwop := kwargs.get('options', None)): self.options: List[Option] = self._match_option_param_names(params, kwop) else: self.options: List[Option] = self._parse_options(params) try: checks = func.__commands_checks__ 
checks.reverse() except AttributeError: checks = kwargs.get('checks', []) self.checks = checks self._before_invoke = None self._after_invoke = None # Permissions self.default_permission = kwargs.get("default_permission", True) self.permissions: List[CommandPermission] = getattr(func, "__app_cmd_perms__", []) + kwargs.get("permissions", []) if self.permissions and self.default_permission: self.default_permission = False def _parse_options(self, params) -> List[Option]: final_options = [] if list(params.items())[0][0] == "self": temp = list(params.items()) temp.pop(0) params = dict(temp) params = iter(params.items()) # next we have the 'ctx' as the next parameter try: next(params) except StopIteration: raise ClientException( f'Callback for {self.name} command is missing "ctx" parameter.' ) final_options = [] for p_name, p_obj in params: option = p_obj.annotation if option == inspect.Parameter.empty: option = str if self._is_typing_union(option): if self._is_typing_optional(option): option = Option( option.__args__[0], "No description provided", required=False ) else: option = Option( option.__args__, "No description provided" ) if not isinstance(option, Option): option = Option(option, "No description provided") if p_obj.default != inspect.Parameter.empty: option.required = False option.default = option.default if option.default is not None else p_obj.default if option.default == inspect.Parameter.empty: option.default = None if option.name is None: option.name = p_name option._parameter_name = p_name validate_chat_input_name(option.name) validate_chat_input_description(option.description) final_options.append(option) return final_options def _match_option_param_names(self, params, options): if list(params.items())[0][0] == "self": temp = list(params.items()) temp.pop(0) params = dict(temp) params = iter(params.items()) # next we have the 'ctx' as the next parameter try: next(params) except StopIteration: raise ClientException( f'Callback for {self.name} command is 
missing "ctx" parameter.' ) check_annotations = [ lambda o, a: o.input_type == SlashCommandOptionType.string and o.converter is not None, # pass on converters lambda o, a: isinstance(o._raw_type, tuple) and a == Union[o._raw_type], # union types lambda o, a: self._is_typing_optional(a) and not o.required and o._raw_type in a.__args__, # optional lambda o, a: inspect.isclass(a) and issubclass(a, o._raw_type) # 'normal' types ] for o in options: validate_chat_input_name(o.name) validate_chat_input_description(o.description) try: p_name, p_obj = next(params) except StopIteration: # not enough params for all the options raise ClientException( f"Too many arguments passed to the options kwarg." ) p_obj = p_obj.annotation if not any(c(o, p_obj) for c in check_annotations): raise TypeError(f"Parameter {p_name} does not match input type of {o.name}.") o._parameter_name = p_name left_out_params = OrderedDict() left_out_params[''] = '' # bypass first iter (ctx) for k, v in params: left_out_params[k] = v options.extend(self._parse_options(left_out_params)) return options def _is_typing_union(self, annotation): return ( getattr(annotation, '__origin__', None) is Union or type(annotation) is getattr(types, "UnionType", Union) ) # type: ignore def _is_typing_optional(self, annotation): return self._is_typing_union(annotation) and type(None) in annotation.__args__ # type: ignore @property def is_subcommand(self) -> bool: return self.parent is not None def to_dict(self) -> Dict: as_dict = { "name": self.name, "description": self.description, "options": [o.to_dict() for o in self.options], "default_permission": self.default_permission, } if self.is_subcommand: as_dict["type"] = SlashCommandOptionType.sub_command.value return as_dict async def _invoke(self, ctx: ApplicationContext) -> None: # TODO: Parse the args better kwargs = {} for arg in ctx.interaction.data.get("options", []): op = find(lambda x: x.name == arg["name"], self.options) arg = arg["value"] # Checks if input_type is 
user, role or channel if ( SlashCommandOptionType.user.value <= op.input_type.value <= SlashCommandOptionType.role.value ): if ctx.guild is None and op.input_type.name == "user": _data = ctx.interaction.data["resolved"]["users"][arg] _data["id"] = int(arg) arg = User(state=ctx.interaction._state, data=_data) else: name = "member" if op.input_type.name == "user" else op.input_type.name arg = await get_or_fetch(ctx.guild, name, int(arg), default=int(arg)) elif op.input_type == SlashCommandOptionType.mentionable: arg_id = int(arg) arg = await get_or_fetch(ctx.guild, "member", arg_id) if arg is None: arg = ctx.guild.get_role(arg_id) or arg_id elif op.input_type == SlashCommandOptionType.string and (converter := op.converter) is not None: arg = await converter.convert(converter, ctx, arg) kwargs[op._parameter_name] = arg for o in self.options: if o._parameter_name not in kwargs: kwargs[o._parameter_name] = o.default if self.cog is not None: await self.callback(self.cog, ctx, **kwargs) elif self.parent is not None and self.attached_to_group is True: await self.callback(self.parent, ctx, **kwargs) else: await self.callback(ctx, **kwargs) async def invoke_autocomplete_callback(self, ctx: AutocompleteContext): values = { i.name: i.default for i in self.options } for op in ctx.interaction.data.get("options", []): if op.get("focused", False): option = find(lambda o: o.name == op["name"], self.options) values.update({ i["name"]:i["value"] for i in ctx.interaction.data["options"] }) ctx.command = self ctx.focused = option ctx.value = op.get("value") ctx.options = values if len(inspect.signature(option.autocomplete).parameters) == 2: instance = getattr(option.autocomplete, "__self__", ctx.cog) result = option.autocomplete(instance, ctx) else: result = option.autocomplete(ctx) if asyncio.iscoroutinefunction(option.autocomplete): result = await result choices = [ o if isinstance(o, OptionChoice) else OptionChoice(o) for o in result ][:25] return await 
ctx.interaction.response.send_autocomplete_result(choices=choices) def copy(self): """Creates a copy of this command. Returns -------- :class:`SlashCommand` A new instance of this command. """ ret = self.__class__(self.callback, **self.__original_kwargs__) return self._ensure_assignment_on_copy(ret) def _ensure_assignment_on_copy(self, other): other._before_invoke = self._before_invoke other._after_invoke = self._after_invoke if self.checks != other.checks: other.checks = self.checks.copy() #if self._buckets.valid and not other._buckets.valid: # other._buckets = self._buckets.copy() #if self._max_concurrency != other._max_concurrency: # # _max_concurrency won't be None at this point # other._max_concurrency = self._max_concurrency.copy() # type: ignore try: other.on_error = self.on_error except AttributeError: pass return other def _update_copy(self, kwargs: Dict[str, Any]): if kwargs: kw = kwargs.copy() kw.update(self.__original_kwargs__) copy = self.__class__(self.callback, **kw) return self._ensure_assignment_on_copy(copy) else: return self.copy() channel_type_map = { 'TextChannel': ChannelType.text, 'VoiceChannel': ChannelType.voice, 'StageChannel': ChannelType.stage_voice, 'CategoryChannel': ChannelType.category, 'Thread': ChannelType.public_thread } class ThreadOption: def __init__(self, thread_type: Literal["public", "private", "news"]): type_map = { "public": ChannelType.public_thread, "private": ChannelType.private_thread, "news": ChannelType.news_thread, } self._type = type_map[thread_type] @property def __name__(self): return 'ThreadOption' class Option: def __init__( self, input_type: Any, /, description: str = None, **kwargs ) -> None: self.name: Optional[str] = kwargs.pop("name", None) self.description = description or "No description provided" self.converter = None self._raw_type = input_type self.channel_types: List[SlashCommandOptionType] = kwargs.pop("channel_types", []) if not isinstance(input_type, SlashCommandOptionType): if hasattr(input_type, 
"convert"): self.converter = input_type input_type = SlashCommandOptionType.string else: _type = SlashCommandOptionType.from_datatype(input_type) if _type == SlashCommandOptionType.channel: if not isinstance(input_type, tuple): input_type = (input_type,) for i in input_type: if i.__name__ == 'GuildChannel': continue if isinstance(i, ThreadOption): self.channel_types.append(i._type) continue channel_type = channel_type_map[i.__name__] self.channel_types.append(channel_type) input_type = _type self.input_type = input_type self.default = kwargs.pop("default", None) self.required: bool = kwargs.pop("required", True) if self.default is None else False self.choices: List[OptionChoice] = [ o if isinstance(o, OptionChoice) else OptionChoice(o) for o in kwargs.pop("choices", list()) ] if self.input_type == SlashCommandOptionType.integer: minmax_types = (int, type(None)) elif self.input_type == SlashCommandOptionType.number: minmax_types = (int, float, type(None)) else: minmax_types = (type(None),) minmax_typehint = Optional[Union[minmax_types]] # type: ignore self.min_value: minmax_typehint = kwargs.pop("min_value", None) self.max_value: minmax_typehint = kwargs.pop("max_value", None) if not (isinstance(self.min_value, minmax_types) or self.min_value is None): raise TypeError(f"Expected {minmax_typehint} for min_value, got \"{type(self.min_value).__name__}\"") if not (isinstance(self.max_value, minmax_types) or self.min_value is None): raise TypeError(f"Expected {minmax_typehint} for max_value, got \"{type(self.max_value).__name__}\"") self.autocomplete = kwargs.pop("autocomplete", None) def to_dict(self) -> Dict: as_dict = { "name": self.name, "description": self.description, "type": self.input_type.value, "required": self.required, "choices": [c.to_dict() for c in self.choices], "autocomplete": bool(self.autocomplete) } if self.channel_types: as_dict["channel_types"] = [t.value for t in self.channel_types] if self.min_value is not None: as_dict["min_value"] = 
self.min_value if self.max_value is not None: as_dict["max_value"] = self.max_value return as_dict def __repr__(self): return f"" class OptionChoice: def __init__(self, name: str, value: Optional[Union[str, int, float]] = None): self.name = name self.value = value if value is not None else name def to_dict(self) -> Dict[str, Union[str, int, float]]: return {"name": self.name, "value": self.value} def option(name, type=None, **kwargs): """A decorator that can be used instead of typehinting Option""" def decor(func): nonlocal type type = type or func.__annotations__.get(name, str) func.__annotations__[name] = Option(type, **kwargs) return func return decor class SlashCommandGroup(ApplicationCommand): r"""A class that implements the protocol for a slash command group. These can be created manually, but they should be created via the decorator or functional interface. Attributes ----------- name: :class:`str` The name of the command. description: Optional[:class:`str`] The description for the command. guild_ids: Optional[List[:class:`int`]] The ids of the guilds where this command will be registered. parent: Optional[:class:`SlashCommandGroup`] The parent group that this group belongs to. ``None`` if there isn't one. subcommands: List[Union[:class:`SlashCommand`, :class:`SlashCommandGroup`]] The list of all subcommands under this group. cog: Optional[:class:`Cog`] The cog that this command belongs to. ``None`` if there isn't one. checks: List[Callable[[:class:`.ApplicationContext`], :class:`bool`]] A list of predicates that verifies if the command could be executed with the given :class:`.ApplicationContext` as the sole parameter. If an exception is necessary to be thrown to signal failure, then one inherited from :exc:`.CommandError` should be used. Note that if the checks fail then :exc:`.CheckFailure` exception is raised to the :func:`.on_application_command_error` event. 
""" type = 1 def __new__(cls, *args, **kwargs) -> SlashCommandGroup: self = super().__new__(cls) self.__original_kwargs__ = kwargs.copy() self.__initial_commands__ = [] for i, c in cls.__dict__.items(): if isinstance(c, type) and SlashCommandGroup in c.__bases__: c = c( c.__name__, ( inspect.cleandoc(cls.__doc__).splitlines()[0] if cls.__doc__ is not None else "No description provided" ) ) if isinstance(c, (SlashCommand, SlashCommandGroup)): c.parent = self c.attached_to_group = True self.__initial_commands__.append(c) return self def __init__( self, name: str, description: str, guild_ids: Optional[List[int]] = None, parent: Optional[SlashCommandGroup] = None, **kwargs ) -> None: validate_chat_input_name(name) validate_chat_input_description(description) self.name = name self.description = description self.input_type = SlashCommandOptionType.sub_command_group self.subcommands: List[Union[SlashCommand, SlashCommandGroup]] = self.__initial_commands__ self.guild_ids = guild_ids self.parent = parent self.checks = [] self._before_invoke = None self._after_invoke = None self.cog = None # Permissions self.default_permission = kwargs.get("default_permission", True) self.permissions: List[CommandPermission] = kwargs.get("permissions", []) if self.permissions and self.default_permission: self.default_permission = False def to_dict(self) -> Dict: as_dict = { "name": self.name, "description": self.description, "options": [c.to_dict() for c in self.subcommands], "default_permission": self.default_permission, } if self.parent is not None: as_dict["type"] = self.input_type.value return as_dict def command(self, **kwargs) -> SlashCommand: def wrap(func) -> SlashCommand: command = SlashCommand(func, parent=self, **kwargs) self.subcommands.append(command) return command return wrap def create_subgroup(self, name, description) -> SlashCommandGroup: if self.parent is not None: # TODO: Improve this error message raise Exception("Subcommands can only be nested once") sub_command_group = 
SlashCommandGroup(name, description, parent=self) self.subcommands.append(sub_command_group) return sub_command_group def subgroup( self, name: Optional[str] = None, description: Optional[str] = None, guild_ids: Optional[List[int]] = None, ) -> Callable[[Type[SlashCommandGroup]], SlashCommandGroup]: """A shortcut decorator that initializes the provided subclass of :class:`.SlashCommandGroup` as a subgroup. .. versionadded:: 2.0 Parameters ---------- name: Optional[:class:`str`] The name of the group to create. This will resolve to the name of the decorated class if ``None`` is passed. description: Optional[:class:`str`] The description of the group to create. guild_ids: Optional[List[:class:`int`]] A list of the IDs of each guild this group should be added to, making it a guild command. This will be a global command if ``None`` is passed. Returns -------- Callable[[Type[SlashCommandGroup]], SlashCommandGroup] The slash command group that was created. """ def inner(cls: Type[SlashCommandGroup]) -> SlashCommandGroup: group = cls( name or cls.__name__, description or ( inspect.cleandoc(cls.__doc__).splitlines()[0] if cls.__doc__ is not None else "No description provided" ), guild_ids=guild_ids, parent=self, ) self.subcommands.append(group) return group return inner async def _invoke(self, ctx: ApplicationContext) -> None: option = ctx.interaction.data["options"][0] command = find(lambda x: x.name == option["name"], self.subcommands) ctx.interaction.data = option await command.invoke(ctx) async def invoke_autocomplete_callback(self, ctx: AutocompleteContext) -> None: option = ctx.interaction.data["options"][0] command = find(lambda x: x.name == option["name"], self.subcommands) ctx.interaction.data = option await command.invoke_autocomplete_callback(ctx) def walk_commands(self) -> Generator[SlashCommand, None, None]: """An iterator that recursively walks through all slash commands in this group. Yields ------ :class:`.SlashCommand` A slash command from the group. 
""" for command in self.subcommands: if isinstance(command, SlashCommandGroup): yield from command.walk_commands() yield command def copy(self): """Creates a copy of this command group. Returns -------- :class:`SlashCommandGroup` A new instance of this command group. """ ret = self.__class__( name=self.name, description=self.description, **self.__original_kwargs__, ) return self._ensure_assignment_on_copy(ret) def _ensure_assignment_on_copy(self, other): other.parent = self.parent other._before_invoke = self._before_invoke other._after_invoke = self._after_invoke if self.subcommands != other.subcommands: other.subcommands = self.subcommands.copy() if self.checks != other.checks: other.checks = self.checks.copy() return other def _update_copy(self, kwargs: Dict[str, Any]): if kwargs: kw = kwargs.copy() kw.update(self.__original_kwargs__) copy = self.__class__(self.callback, **kw) return self._ensure_assignment_on_copy(copy) else: return self.copy() def _set_cog(self, cog): self.cog = cog for subcommand in self.subcommands: subcommand._set_cog(cog) class ContextMenuCommand(ApplicationCommand): r"""A class that implements the protocol for context menu commands. These are not created manually, instead they are created via the decorator or functional interface. Attributes ----------- name: :class:`str` The name of the command. callback: :ref:`coroutine ` The coroutine that is executed when the command is called. guild_ids: Optional[List[:class:`int`]] The ids of the guilds where this command will be registered. default_permission: :class:`bool` Whether the command is enabled by default when it is added to a guild. permissions: List[:class:`.CommandPermission`] The permissions for this command. .. note:: If this is not empty then default_permissions will be set to ``False``. cog: Optional[:class:`Cog`] The cog that this command belongs to. ``None`` if there isn't one. 
checks: List[Callable[[:class:`.ApplicationContext`], :class:`bool`]] A list of predicates that verifies if the command could be executed with the given :class:`.ApplicationContext` as the sole parameter. If an exception is necessary to be thrown to signal failure, then one inherited from :exc:`.CommandError` should be used. Note that if the checks fail then :exc:`.CheckFailure` exception is raised to the :func:`.on_application_command_error` event. cooldown: Optional[:class:`~discord.ext.commands.Cooldown`] The cooldown applied when the command is invoked. ``None`` if the command doesn't have a cooldown. .. versionadded:: 2.0 """ def __new__(cls, *args, **kwargs) -> ContextMenuCommand: self = super().__new__(cls) self.__original_kwargs__ = kwargs.copy() return self def __init__(self, func: Callable, *args, **kwargs) -> None: super().__init__(func, **kwargs) if not asyncio.iscoroutinefunction(func): raise TypeError("Callback must be a coroutine.") self.callback = func self.guild_ids: Optional[List[int]] = kwargs.get("guild_ids", None) # Discord API doesn't support setting descriptions for context menu commands # so it must be empty self.description = "" self.name: str = kwargs.pop("name", func.__name__) if not isinstance(self.name, str): raise TypeError("Name of a command must be a string.") self.cog = None try: checks = func.__commands_checks__ checks.reverse() except AttributeError: checks = kwargs.get('checks', []) self.checks = checks self._before_invoke = None self._after_invoke = None self.validate_parameters() self.default_permission = kwargs.get("default_permission", True) self.permissions: List[CommandPermission] = getattr(func, "__app_cmd_perms__", []) + kwargs.get("permissions", []) if self.permissions and self.default_permission: self.default_permission = False # Context Menu commands can't have parents self.parent = None def validate_parameters(self): params = self._get_signature_parameters() if list(params.items())[0][0] == "self": temp = 
list(params.items()) temp.pop(0) params = dict(temp) params = iter(params) # next we have the 'ctx' as the next parameter try: next(params) except StopIteration: raise ClientException( f'Callback for {self.name} command is missing "ctx" parameter.' ) # next we have the 'user/message' as the next parameter try: next(params) except StopIteration: cmd = "user" if type(self) == UserCommand else "message" raise ClientException( f'Callback for {self.name} command is missing "{cmd}" parameter.' ) # next there should be no more parameters try: next(params) raise ClientException( f"Callback for {self.name} command has too many parameters." ) except StopIteration: pass @property def qualified_name(self): return self.name def to_dict(self) -> Dict[str, Union[str, int]]: return {"name": self.name, "description": self.description, "type": self.type, "default_permission": self.default_permission} class UserCommand(ContextMenuCommand): type = 2 def __new__(cls, *args, **kwargs) -> UserCommand: self = super().__new__(cls) self.__original_kwargs__ = kwargs.copy() return self async def _invoke(self, ctx: ApplicationContext) -> None: if "members" not in ctx.interaction.data["resolved"]: _data = ctx.interaction.data["resolved"]["users"] for i, v in _data.items(): v["id"] = int(i) user = v target = User(state=ctx.interaction._state, data=user) else: _data = ctx.interaction.data["resolved"]["members"] for i, v in _data.items(): v["id"] = int(i) member = v _data = ctx.interaction.data["resolved"]["users"] for i, v in _data.items(): v["id"] = int(i) user = v member["user"] = user target = Member( data=member, guild=ctx.interaction._state._get_guild(ctx.interaction.guild_id), state=ctx.interaction._state, ) if self.cog is not None: await self.callback(self.cog, ctx, target) else: await self.callback(ctx, target) def copy(self): """Creates a copy of this command. Returns -------- :class:`UserCommand` A new instance of this command. 
""" ret = self.__class__(self.callback, **self.__original_kwargs__) return self._ensure_assignment_on_copy(ret) def _ensure_assignment_on_copy(self, other): other._before_invoke = self._before_invoke other._after_invoke = self._after_invoke if self.checks != other.checks: other.checks = self.checks.copy() #if self._buckets.valid and not other._buckets.valid: # other._buckets = self._buckets.copy() #if self._max_concurrency != other._max_concurrency: # # _max_concurrency won't be None at this point # other._max_concurrency = self._max_concurrency.copy() # type: ignore try: other.on_error = self.on_error except AttributeError: pass return other def _update_copy(self, kwargs: Dict[str, Any]): if kwargs: kw = kwargs.copy() kw.update(self.__original_kwargs__) copy = self.__class__(self.callback, **kw) return self._ensure_assignment_on_copy(copy) else: return self.copy() class MessageCommand(ContextMenuCommand): type = 3 def __new__(cls, *args, **kwargs) -> MessageCommand: self = super().__new__(cls) self.__original_kwargs__ = kwargs.copy() return self async def _invoke(self, ctx: ApplicationContext): _data = ctx.interaction.data["resolved"]["messages"] for i, v in _data.items(): v["id"] = int(i) message = v channel = ctx.interaction._state.get_channel(int(message["channel_id"])) if channel is None: data = await ctx.interaction._state.http.start_private_message( int(message["author"]["id"]) ) channel = ctx.interaction._state.add_dm_channel(data) target = Message(state=ctx.interaction._state, channel=channel, data=message) if self.cog is not None: await self.callback(self.cog, ctx, target) else: await self.callback(ctx, target) def copy(self): """Creates a copy of this command. Returns -------- :class:`MessageCommand` A new instance of this command. 
""" ret = self.__class__(self.callback, **self.__original_kwargs__) return self._ensure_assignment_on_copy(ret) def _ensure_assignment_on_copy(self, other): other._before_invoke = self._before_invoke other._after_invoke = self._after_invoke if self.checks != other.checks: other.checks = self.checks.copy() #if self._buckets.valid and not other._buckets.valid: # other._buckets = self._buckets.copy() #if self._max_concurrency != other._max_concurrency: # # _max_concurrency won't be None at this point # other._max_concurrency = self._max_concurrency.copy() # type: ignore try: other.on_error = self.on_error except AttributeError: pass return other def _update_copy(self, kwargs: Dict[str, Any]): if kwargs: kw = kwargs.copy() kw.update(self.__original_kwargs__) copy = self.__class__(self.callback, **kw) return self._ensure_assignment_on_copy(copy) else: return self.copy() def slash_command(**kwargs): """Decorator for slash commands that invokes :func:`application_command`. .. versionadded:: 2.0 Returns -------- Callable[..., :class:`SlashCommand`] A decorator that converts the provided method into a :class:`.SlashCommand`. """ return application_command(cls=SlashCommand, **kwargs) def user_command(**kwargs): """Decorator for user commands that invokes :func:`application_command`. .. versionadded:: 2.0 Returns -------- Callable[..., :class:`UserCommand`] A decorator that converts the provided method into a :class:`.UserCommand`. """ return application_command(cls=UserCommand, **kwargs) def message_command(**kwargs): """Decorator for message commands that invokes :func:`application_command`. .. versionadded:: 2.0 Returns -------- Callable[..., :class:`MessageCommand`] A decorator that converts the provided method into a :class:`.MessageCommand`. """ return application_command(cls=MessageCommand, **kwargs) def application_command(cls=SlashCommand, **attrs): """A decorator that transforms a function into an :class:`.ApplicationCommand`. 
More specifically, usually one of :class:`.SlashCommand`, :class:`.UserCommand`, or :class:`.MessageCommand`. The exact class depends on the ``cls`` parameter. By default the ``description`` attribute is received automatically from the docstring of the function and is cleaned up with the use of ``inspect.cleandoc``. If the docstring is ``bytes``, then it is decoded into :class:`str` using utf-8 encoding. The ``name`` attribute also defaults to the function name unchanged. .. versionadded:: 2.0 Parameters ----------- cls: :class:`.ApplicationCommand` The class to construct with. By default this is :class:`.SlashCommand`. You usually do not change this. attrs Keyword arguments to pass into the construction of the class denoted by ``cls``. Raises ------- TypeError If the function is not a coroutine or is already a command. """ def decorator(func: Callable) -> cls: if isinstance(func, ApplicationCommand): func = func.callback elif not callable(func): raise TypeError( "func needs to be a callable or a subclass of ApplicationCommand." ) return cls(func, **attrs) return decorator def command(**kwargs): """There is an alias for :meth:`application_command`. .. note:: This decorator is overridden by :func:`commands.command`. .. versionadded:: 2.0 Returns -------- Callable[..., :class:`ApplicationCommand`] A decorator that converts the provided method into an :class:`.ApplicationCommand`. """ return application_command(**kwargs) docs = "https://discord.com/developers/docs" # Validation def validate_chat_input_name(name: Any): # Must meet the regex ^[\w-]{1,32}$ if not isinstance(name, str): raise TypeError(f"Chat input command names and options must be of type str. Received {name}") if not re.match(r"^[\w-]{1,32}$", name): raise ValidationError( r'Chat input command names and options must follow the regex "^[\w-]{1,32}$". For more information, see ' f"{docs}/interactions/application-commands#application-command-object-application-command-naming. 
Received " f"{name}" ) if not 1 <= len(name) <= 32: raise ValidationError( f"Chat input command names and options must be 1-32 characters long. Received {name}" ) if not name.lower() == name: # Can't use islower() as it fails if none of the chars can be lower. See #512. raise ValidationError(f"Chat input command names and options must be lowercase. Received {name}") def validate_chat_input_description(description: Any): if not isinstance(description, str): raise TypeError(f"Command description must be of type str. Received {description}") if not 1 <= len(description) <= 100: raise ValidationError( f"Command description must be 1-100 characters long. Received {description}" ) diff --git a/examples/views/button_roles.py b/examples/views/button_roles.py index d30726af..6104fbf8 100644 --- a/examples/views/button_roles.py +++ b/examples/views/button_roles.py @@ -1,109 +1,109 @@ import discord -from discord.commands.commands import slash_command +from discord.commands.core import slash_command from discord.ext import commands """ Let users assign themselves roles by clicking on Buttons. The view made is persistent, so it will work even when the bot restarts. See this example for more information about persistent views https://github.com/Pycord-Development/pycord/blob/master/examples/views/persistent.py Make sure to load this cog when your bot starts! """ # this is the list of role IDs that will be added as buttons. role_ids = [...] class RoleButton(discord.ui.Button): def __init__(self, role: discord.Role): """ A button for one role. `custom_id` is needed for persistent views. 
""" super().__init__( label=role.name, style=discord.enums.ButtonStyle.primary, custom_id=str(role.id), ) async def callback(self, interaction: discord.Interaction): """This function will be called any time a user clicks on this button Parameters ---------- interaction : discord.Interaction The interaction object that was created when the user clicked on the button """ # figure out who clicked the button user = interaction.user # get the role this button is for (stored in the custom ID) role = interaction.guild.get_role(int(self.custom_id)) if role is None: # if this role doesn't exist, ignore # you can do some error handling here return # passed all checks # add the role and send a response to the uesr ephemerally (hidden to other users) if role not in user.roles: # give the user the role if they don't already have it await user.add_roles(role) await interaction.response.send_message( f"🎉 You have been given the role {role.mention}", ephemeral=True ) else: # else, take the role from the user await user.remove_roles(role) await interaction.response.send_message( f"❌ The {role.mention} role has been taken from you", ephemeral=True ) class ButtonRoleCog(commands.Cog): """A cog with a slash command for posting the message with buttons and to initialize the view again when the bot is restarted """ def __init__(self, bot): self.bot = bot # make sure to set the guild ID here to whatever server you want the buttons in @slash_command(guild_ids=[...], description="Post the button role message") async def post(self, ctx: commands.Context): """A slash command to post a new view with a button for each role""" # timeout is None because we want this view to be persistent view = discord.ui.View(timeout=None) # loop through the list of roles and add a new button to the view for each role for role_id in role_ids: # get the role the guild by ID role = ctx.guild.get_role(role_id) view.add_item(RoleButton(role)) await ctx.respond("Click a button to assign yourself a role", view=view) 
@commands.Cog.listener() async def on_ready(self): """This function is called every time the bot restarts. If a view was already created before (with the same custom IDs for buttons) it will be loaded and the bot will start watching for button clicks again. """ # we recreate the view as we did in the /post command view = discord.ui.View(timeout=None) # make sure to set the guild ID here to whatever server you want the buttons in guild = self.bot.get_guild(...) for role_id in role_ids: role = guild.get_role(role_id) view.add_item(RoleButton(role)) # add the view to the bot so it will watch for button interactions self.bot.add_view(view) def setup(bot): # load the cog bot.add_cog(ButtonRoleCog(bot))