diff --git a/discord/abc.py b/discord/abc.py index e7a585b9ae..dbe5ffa07b 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -1338,6 +1338,9 @@ async def invites(self) -> list[Invite]: for invite in data ] + def search(self, **params): + return self.guild.search(channels=[self], **params) + class Messageable: """An ABC that details the common operations on a model that can send messages. diff --git a/discord/enums.py b/discord/enums.py index beae2474e3..2fca1faabb 100644 --- a/discord/enums.py +++ b/discord/enums.py @@ -85,6 +85,9 @@ "SubscriptionStatus", "SeparatorSpacingSize", "SelectDefaultValueType", + "SearchEmbedType", + "SearchSortMode", + "SearchSortOrder", "ApplicationEventWebhookStatus", "InviteTargetUsersJobStatusCode", ) @@ -1136,6 +1139,39 @@ class SelectDefaultValueType(Enum): user = "user" +class SearchEmbedType(Enum): + """The types of media embedded on a message.""" + + image = "image" + video = "video" + gif = "gif" + sound = "sound" + article = "article" + + def __str__(self): + return self.value + + +class SearchSortMode(Enum): + """The sorting algorithm used for message searches.""" + + timestamp = "timestamp" + relevance = "relevance" + + def __str__(self): + return self.value + + +class SearchSortOrder(Enum): + """The order to sort message searches.""" + + asc = "asc" + desc = "desc" + + def __str__(self): + return self.value + + class RoleType(IntEnum): """Represents the type of role. diff --git a/discord/guild.py b/discord/guild.py index 4ecbfe57bd..cd6f1f7cf2 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -64,6 +64,9 @@ RoleType, ScheduledEventLocationType, ScheduledEventPrivacyLevel, + SearchEmbedType, + SearchSortMode, + SearchSortOrder, SortOrder, VerificationLevel, VideoQualityMode, @@ -81,6 +84,7 @@ BanIterator, EntitlementIterator, MemberIterator, + MessageSearchIterator, ) from .member import Member, VoiceState from .mixins import Hashable @@ -98,7 +102,7 @@ from .welcome_screen import WelcomeScreen, WelcomeScreenChannel from .widget import Widget -__all__ = ("BanEntry", "Guild", "GuildRoleCounts") +__all__ = ("BanEntry", "Guild", "GuildRoleCounts", "SearchHas", "SearchAuthors") MISSING = utils.MISSING @@ -152,6 +156,64 @@ class _GuildLimit(NamedTuple): filesize: int +class Parsable: + # idk this kinda sucks lmao + def __init__(self, **values): + self._values = values + for k, v in values.items(): + setattr(self, k, v) + + def parse(self) -> list[str]: + true, false = [], [] + for k, v in self._values.items(): + if v: + true.append(k) + elif v is False: + false.append(f"-{k}") + return true + false + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} resolved={self.parse()!r}>" + + +class SearchAuthors(Parsable): + def __init__( + self, + *, + user: bool | None = None, + bot: bool | None = None, + webhook: bool | None = None, + ): + super().__init__(user=user, bot=bot, webhook=webhook) + + +class SearchHas(Parsable): + def __init__( + self, + *, + image: bool | None = None, + sound: bool | None = None, + video: bool | None = None, + file: bool | None = None, + sticker: bool | None = None, + embed: bool | None = None, + link: bool | None = None, + poll: bool | None = None, + snapshot: bool | None = None, + ): + super().__init__( + image=image, + sound=sound, + video=video, + file=file, + sticker=sticker, + embed=embed, + link=link, + poll=poll, + snapshot=snapshot, + ) + + class GuildRoleCounts(dict[int, int]): """A dictionary subclass that maps role IDs to their member counts. @@ -4724,3 +4786,157 @@ def get_sound(self, sound_id: int) -> SoundboardSound | None: The sound or ``None`` if not found. """ return self._sounds.get(sound_id) + + def search( + self, + *, + limit: int | None = 25, + offset: int | None = None, + after: Snowflake | None = None, + before: Snowflake | None = None, + slop: int | None = 2, + content: str | None = None, + channels: list[Snowflake] | None = None, + author_types: SearchAuthors | None = None, + authors: list[Snowflake] | None = None, + mentions: list[Snowflake] | None = None, + mentions_roles: list[Snowflake] | None = None, + mention_everyone: bool | None = None, + replied_to_users: list[Snowflake] | None = None, + replied_to_messages: list[Snowflake] | None = None, + pinned: bool | None = None, + has: SearchHas | None = None, + embed_types: list[SearchEmbedType] | None = None, + embed_providers: list[str] | None = None, + link_hostnames: list[str] | None = None, + attachment_filenames: list[str] | None = None, + attachment_extensions: list[str] | None = None, + sort_by: SearchSortMode | None = None, + sort_order: SearchSortOrder | None = SearchSortOrder.desc, + include_nsfw: bool | None = False, + ) -> list[Message]: + """etc...""" + + params = {} + + if limit: + if limit <= 0: + raise ValueError("limit must be above 1") + params["limit"] = limit if limit <= 25 else limit + + if offset is not None: + if offset > 9975 or offset < 0: + raise ValueError("offset must be between 0 and 9975") + params["offset"] = offset + + if after: + params["min_id"] = after.id + + if before: + params["max_id"] = before.id + + if slop is not None: + if slop > 100 or slop < 0: + raise ValueError("slop must be between 0 and 100") + params["slop"] = int(slop) + + if content: + if len(content) > 1024: + raise ValueError("content must be under 1024 characters") + params["content"] = content + + if channels: + if len(channels) > 500: + raise ValueError("can only specify up to 500 channels") + params["channel_id"] = [c.id for c in channels] + + if author_types: + params["author_type"] = author_types.parse() + + if authors: + if len(authors) > 100: + raise ValueError("can only specify up to 100 authors") + params["author_id"] = [a.id for a in authors] + + if mentions: + if len(mentions) > 100: + raise ValueError("can only specify up to 100 mentions") + params["mentions"] = [m.id for m in mentions] + + if mentions_roles: + if len(mentions_roles) > 100: + raise ValueError("can only specify up to 100 mentions_roles") + params["mentions_role_id"] = [m.id for m in mentions_roles] + + if mention_everyone is not None: + params["mention_everyone"] = mention_everyone + + if replied_to_users: + if len(replied_to_users) > 100: + raise ValueError("can only specify up to 100 replied_to_users") + params["replied_to_user_id"] = [u.id for u in replied_to_users] + + if replied_to_messages: + if len(replied_to_messages) > 100: + raise ValueError("can only specify up to 100 replied_to_messages") + params["replied_to_message_id"] = [m.id for u in replied_to_messages] + + if pinned is not None: + params["pinned"] = pinned + + if has: + params["has"] = has.parse() + + if embed_types: + params["embed_type"] = [str(t) for t in embed_types] + + if embed_providers: + if len(embed_providers) > 100: + raise ValueError("can only specify up to 100 embed_providers") + for e in embed_providers: + if len(e) > 256: + raise ValueError( + f"embed_provider {e!r} must be up to 256 characters." + ) + params["embed_provider"] = embed_providers + + if link_hostnames: + if len(link_hostnames) > 100: + raise ValueError("can only specify up to 100 link_hostnames") + for l in link_hostnames: + if len(l) > 256: + raise ValueError( + f"link_hostname {l!r} must be up to 256 characters." + ) + params["link_hostname"] = link_hostnames + + if attachment_filenames: + if len(attachment_filenames) > 100: + raise ValueError("can only specify up to 100 attachment_filenames") + for a in attachment_filenames: + if len(a) > 1024: + raise ValueError( + f"attachment_filename {a!r} must be up to 1024 characters." + ) + params["attachment_filename"] = attachment_filenames + + if attachment_extensions: + if len(attachment_extensions) > 100: + raise ValueError("can only specify up to 100 attachment_extensions") + for a in attachment_extensions: + if len(a) > 256: + raise ValueError( + f"attachment_extension {a!r} must be up to 256 characters." + ) + params["attachment_extension"] = attachment_extensions + + if sort_by: + params["sort_by"] = str(sort_by) + + if sort_order: + params["sort_order"] = str(sort_order) + + if include_nsfw is not None: + params["include_nsfw"] = include_nsfw + + return MessageSearchIterator(self, limit, params) diff --git a/discord/http.py b/discord/http.py index 0717feadf5..68b5d733c0 100644 --- a/discord/http.py +++ b/discord/http.py @@ -958,6 +958,70 @@ def legacy_pins_from( Route("GET", "/channels/{channel_id}/pins", channel_id=channel_id) ) + def message_search( + self, + guild_id: Snowflake, + *, + limit: int | None = None, + offset: int | None = None, + min_id: Snowflake | None = None, + max_id: Snowflake | None = None, + slop: int | None = None, + content: str | None = None, + channel_id: SnowflakeList | None = None, + author_type: message.SearchAuthorTypes | None = None, + author_id: SnowflakeList | None = None, + mentions: SnowflakeList | None = None, + mentions_role_id: SnowflakeList | None = None, + mention_everyone: bool | None = None, + replied_to_user_id: SnowflakeList | None = None, + replied_to_message_id: SnowflakeList | None = None, + pinned: bool | None = None, + has: list[message.SearchHasTypes] | None = None, + embed_type: list[message.SearchEmbedTypes] | None = None, + embed_provider: list[str] | None = None, + link_hostname: list[str] | None = None, + attachment_filename: list[str] | None = None, + attachment_extension: list[str] | None = None, + sort_by: message.SearchSortModes | None = None, + sort_order: message.SearchSortOrders | None = None, + include_nsfw: bool | None = None, + ) -> Response[message.MessageSearchResults]: + + p = { + "limit": limit, + "offset": offset, + "min_id": min_id, + "max_id": max_id, + "slop": slop, + "content": content, + "channel_id": channel_id, + "author_type": author_type, + "author_id": author_id, + "mentions": mentions, + "mentions_role_id": mentions_role_id, + "mention_everyone": ( + int(mention_everyone) if mention_everyone is not None else None + ), + "replied_to_user_id": replied_to_user_id, + "replied_to_message_id": replied_to_message_id, + "pinned": int(pinned) if pinned is not None else None, + "has": has, + "embed_type": embed_type, + "embed_provider": embed_provider, + "link_hostname": link_hostname, + "attachment_filename": attachment_filename, + "attachment_extension": attachment_extension, + "sort_by": sort_by, + "sort_order": sort_order, + "include_nsfw": int(include_nsfw) if include_nsfw is not None else None, + } + params = {k: v for k, v in p.items() if v is not None} + return self.request( + Route("GET", "/guilds/{guild_id}/messages/search", guild_id=guild_id), + params=params, + ) + # Member management def kick( diff --git a/discord/iterators.py b/discord/iterators.py index b074aefdc4..2959dd407b 100644 --- a/discord/iterators.py +++ b/discord/iterators.py @@ -50,6 +50,7 @@ "AuditLogIterator", "GuildIterator", "MemberIterator", + "MessageSearchIterator", "ScheduledEventSubscribersIterator", "EntitlementIterator", "SubscriptionIterator", @@ -68,6 +69,7 @@ from .types.guild import Guild as GuildPayload from .types.message import Message as MessagePayload from .types.message import MessagePin as MessagePinPayload + from .types.message import MessageSearch as MessageSearchPayload from .types.monetization import Entitlement as EntitlementPayload from .types.monetization import Subscription as SubscriptionPayload from .types.threads import Thread as ThreadPayload @@ -1283,3 +1285,74 @@ def __await__(self) -> Generator[Any, Any, MessagePin]: reference="The documentation of pins()", ) return self.retrieve_inner().__await__() + + +class MessageSearchIterator(_AsyncIterator["Message"]): + """Iterator for receiving a guild's search results.""" + + def __init__( + self, + guild, + limit, + params, + ): + self.guild = guild + self.limit = limit + self.params = params + if "limit" in params: + if params["limit"] is None or params["limit"] > 25: + params["limit"] = 25 + + self.state = self.guild._state + self.search = self.state.http.message_search + self.messages = asyncio.Queue() + self.message_ids = [] + + async def next(self) -> Message: + if self.messages.empty(): + await self.fill_messages() + + try: + return self.messages.get_nowait() + except asyncio.QueueEmpty: + raise NoMoreItems() + + def _get_retrieve(self) -> bool: + l = self.limit + if l is None or l > 25: + r = 25 + else: + r = l + self.retrieve = r + return r > 0 + + async def fill_messages(self): + + if self._get_retrieve(): + data = await self._retrieve_messages(self.retrieve) + if not data["messages"]: + # "Clients should not rely on the length of the `messages` array to paginate results" + self.limit = 0 # terminate the infinite loop + + data.get("threads", []) + members = data.get("members", []) # do something here + + for element in data["messages"]: + message = element[0] + if int(message["id"]) not in self.message_ids: + ch = self.guild.get_channel(int(message["channel_id"])) + channel = await ch._get_channel() + await self.messages.put( + self.state.create_message(channel=channel, data=message) + ) + self.message_ids.append(int(message["id"])) + + async def _retrieve_messages(self, retrieve: int) -> list[MessagePayload]: + data: list[MessageSearchPayload] = await self.search( + self.guild.id, **self.params + ) + self.params["offset"] = self.params.get("offset", 0) + retrieve + if data["messages"]: + if self.limit is not None: + self.limit -= retrieve + return data diff --git a/discord/member.py b/discord/member.py index 0353472c8d..04db0ed535 100644 --- a/discord/member.py +++ b/discord/member.py @@ -1315,3 +1315,12 @@ def get_role(self, role_id: int, /) -> Role | None: The role or ``None`` if not found in the member's roles. """ return self.guild.get_role(role_id) if self._roles.has(role_id) else None + + def search_messages(self, **params): + return self.guild.search(authors=[self], **params) + + def search_replies(self, **params): + return self.guild.search(replied_to_users=[self], **params) + + def search_mentions(self, **params): + return self.guild.search(mentions=[self], **params) diff --git a/discord/message.py b/discord/message.py index 736397ee4e..7b200dee20 100644 --- a/discord/message.py +++ b/discord/message.py @@ -2275,6 +2275,9 @@ async def end_poll(self) -> Message: return message + def search_replies(self, **params): + return self.guild.search(replied_to_messages=[self], **params) + def to_reference( self, *, fail_if_not_exists: bool = True, type: MessageReferenceType = None ) -> MessageReference: diff --git a/discord/role.py b/discord/role.py index 0508a87671..8dcf6b54c9 100644 --- a/discord/role.py +++ b/discord/role.py @@ -917,3 +917,6 @@ async def delete(self, *, reason: str | None = None) -> None: """ await self._state.http.delete_role(self.guild.id, self.id, reason=reason) + + def search_mentions(self, **params): + return self.guild.search(mentions_roles=[self], **params) diff --git a/discord/types/message.py b/discord/types/message.py index c6a48881c7..33365fbaef 100644 --- a/discord/types/message.py +++ b/discord/types/message.py @@ -35,7 +35,7 @@ from .poll import Poll from .snowflake import Snowflake, SnowflakeList from .sticker import StickerItem -from .threads import Thread +from .threads import Thread, ThreadMember from .user import User if TYPE_CHECKING: @@ -194,3 +194,69 @@ class AllowedMentions(TypedDict): roles: SnowflakeList users: SnowflakeList replied_user: bool + + +SearchAuthorTypes = Literal["user", "bot", "webhook", "-user", "-bot", "-webhook"] +SearchHasTypes = Literal[ + "image", + "sound", + "video", + "file", + "sticker", + "embed", + "link", + "poll", + "snapshot", + "-image", + "-sound", + "-video", + "-file", + "-sticker", + "-embed", + "-link", + "-poll", + "-snapshot", +] +SearchEmbedTypes = Literal["image", "video", "gif", "sound", "article"] +SearchSortModes = Literal["relevance", "timestamp"] +SearchSortOrders = Literal["asc", "desc"] + + +class MessageSearch(TypedDict): + limit: NotRequired[int] + offset: NotRequired[int] + max_id: NotRequired[Snowflake] + min_id: NotRequired[Snowflake] + slop: NotRequired[int] + content: NotRequired[str] + channel_id: NotRequired[SnowflakeList] + author_type: NotRequired[list[SearchAuthorTypes]] + author_id: NotRequired[SnowflakeList] + mentions: NotRequired[SnowflakeList] + mentions_role_id: NotRequired[SnowflakeList] + mention_everyone: NotRequired[bool] + replied_to_user_id: NotRequired[SnowflakeList] + replied_to_message_id: NotRequired[SnowflakeList] + pinned: NotRequired[bool] + has: NotRequired[list[SearchHasTypes]] + embed_type: NotRequired[list[SearchEmbedTypes]] + embed_provider: NotRequired[list[str]] + link_hostname: NotRequired[list[str]] + attachment_filename: NotRequired[list[str]] + attachment_extension: NotRequired[list[str]] + sort_by: NotRequired[SearchSortModes] + sort_order: NotRequired[SearchSortOrders] + include_nsfw: NotRequired[bool] + cursor: NotRequired[dict] + command_id: NotRequired[Snowflake] + command_name: NotRequired[str] + contents: NotRequired[list[str]] + + +class MessageSearchResults(TypedDict): + doing_deep_historical_index: bool + documents_indexed: NotRequired[int] + total_results: int + messages: list[list[Message]] # ????? + threads: NotRequired[list[Thread]] + members: NotRequired[list[ThreadMember]]