Skip to content

Commit

Permalink
I think i just fixed caching
Browse files Browse the repository at this point in the history
  • Loading branch information
Shell1010 committed Jan 8, 2024
1 parent d8a5b1e commit 0f9e063
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 35 deletions.
14 changes: 12 additions & 2 deletions selfcord/api/events.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import itertools
from time import perf_counter
from aioconsole import aprint
import asyncio
from ..models import Guild, Convert, User, Message, Member, MessageAck, MessageReactionAdd, PresenceUpdate, DMChannel
import ujson

Expand Down Expand Up @@ -181,12 +182,21 @@ async def handle_ready_supplemental(self, data: dict):

await self.bot.inbuilt_commands()
await self.bot.emit("ready_supplemental")
await asyncio.sleep(2)

for guild in self.bot.user.guilds:
if guild.member_count >= 1000:
for channel in guild.channels:
if channel.type == 0:
await self.bot.gateway.cache_guild(guild, guild.channels[0])
break

async def handle_message_create(self, data: dict):
message = Message(data, self.bot)
self.bot.cached_messages[message.id] = message
if message.author.id not in self.bot.cached_users:
self.bot.cached_users[message.author.id] = message.author

await self.bot.process_commands(message)
await self.bot.emit("message", message)

Expand Down Expand Up @@ -236,15 +246,15 @@ async def handle_thread_delete(self, data: dict):
async def handle_guild_create(self, data: dict):
guild = Guild(data, self.bot)
self.bot.user.guilds.append(guild)
await self.bot.emit("guild_create")
await self.bot.emit("guild_create", guild)

async def handle_guild_delete(self, data: dict):
guild = self.bot.fetch_guild(data['id'])
await self.bot.emit("guild_delete", guild)
del guild

async def handle_guild_member_list_update(self, data: dict):
print(data)
# print(data)
pass

async def handle_presence_update(self, data: dict):
Expand Down
79 changes: 57 additions & 22 deletions selfcord/api/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import asyncio
from .events import Handler
import websockets
from aioconsole import aprint
import ujson

if TYPE_CHECKING:
Expand Down Expand Up @@ -40,20 +41,22 @@ def __init__(self, bot: Bot, decompress: bool = True) -> None:
self.last_ack: float = 0
self.last_send: float = 0
self.latency: float = float("inf")
self.ws: Optional[Connect] = None
self.ws: Connect
self.alive = False
self.URL = (
"wss://gateway.discord.gg/?encoding=json&v=9&compress=zlib-stream"
if self.decompress else
"wss://gateway.discord.gg/?encoding=json&v=9"
)
self.resume_url: Optional[str]
self.session_id: Optional[str]


async def send_json(self, payload: dict):
if self.ws:
await self.ws.send(ujson.dumps(payload))
try:
await self.ws.send(ujson.dumps(payload))
except Exception:
await self.close()
await self.connect(self.bot.resume_url)

async def load_async(self, item):
loop = asyncio.get_event_loop()
Expand All @@ -70,14 +73,10 @@ async def recv_json(self):
if len(item) < 4 or item[-4:] != self.zlib_suffix:
return
n = len(item)
try:
item = self.zlib.decompress(item)
self.zlib.flush(n)
# self.zlib = decompressobj(15)
except Exception as e:
# with open("test.txt", "a+") as f:
# f.write(f"{item}\n")
print(e)

item = self.zlib.decompress(item)
self.zlib.flush(n)

item = await self.load_async(item)

# await asyncio.sleep(1)
Expand All @@ -86,6 +85,7 @@ async def recv_json(self):
op = item["op"]
data = item["d"]
event = item["t"]
seq = item["s"]

if op == self.HELLO:
interval = data["heartbeat_interval"] / 1000.0
Expand All @@ -95,35 +95,67 @@ async def recv_json(self):
elif op == self.HEARTBEAT_ACK:
self.heartbeat_ack()

elif op == self.RECONNECT:
await self.close()
await asyncio.sleep(3)
await self.connect(f"{self.bot.resume_url}?encoding=json&v=9&compress=zlib-stream")

await self.send_json({
"op": 6,
"d": {"token": self.token, "session_id": self.bot.session_id, "seq": seq},
})

elif op == self.DISPATCH:
if hasattr(self.handler, f"handle_{event.lower()}"):
method = getattr(
self.handler, f"handle_{event.lower()}")
asyncio.create_task(method(data))

async def connect(self):
async def connect(self, url: str):
self.ws = await websockets.connect(
self.URL, origin="https://discord.com", max_size=None,
url, origin="https://discord.com", max_size=None,
extra_headers={"user_agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/118.0.0.0"},
read_limit=1000000, max_queue=100, write_limit=1000000,
)
self.alive = True

async def start(self, token: str):
await self.connect()
await self.connect(self.URL)

self.token = token
while self.alive:
try:
await self.recv_json()
except Exception:
await self.close()
await self.connect(self.bot.resume_url)

async def cache_guild(self, guild: Guild, channel):
payload = {
"op": 14,
"d": {
"guild_id": guild.id,
"typing": True,
"threads": False,
"activities": True,
"members": [],
"channels": {
str(channel.id): [
[
0,
99
]
]
}
}
}
await self.send_json(payload)

async def close(self):
"""This function closes the websocket
"""
self.alive = False
if self.ws:
await self.ws.close()
await self.ws.close()

async def identify(self):
payload = {
Expand Down Expand Up @@ -211,7 +243,7 @@ def chunks(self, lst, n):
break
yield lst[: i + 1]

async def chunk_members(self, guild: Guild):
def correct_channels(self, guild: Guild):
roles = guild.me.roles

channels = []
Expand All @@ -223,7 +255,7 @@ async def chunk_members(self, guild: Guild):
for name, value in permission.items():
if name == "VIEW_CHANNEL":
if value:
print(channel.name, "has view channel permission, original")

channels.append(channel)
break

Expand All @@ -233,13 +265,15 @@ async def chunk_members(self, guild: Guild):
if name == "VIEW_CHANNEL":
break
else:
print(channel.name, "has view channel permission")

channels.append(channel)
break
print(len(channels))

return list(set(channels))


async def chunk_members(self, guild: Guild):
channels = self.correct_channels(guild)
channels = channels[:5]
ranges = []

if guild.member_count is not None:
Expand All @@ -256,6 +290,7 @@ async def chunk_members(self, guild: Guild):
"d": {
"guild_id": guild.id,
"typing": True,
"threads": True
}
}
data = payload['d']
Expand Down
15 changes: 12 additions & 3 deletions selfcord/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,6 @@ def decorator(coro):

return decorator



async def load_tokens(self, tokens: list, prefixes: list[str] = ["!"], eval: bool = False):
rmv = []
Expand Down Expand Up @@ -493,27 +492,37 @@ async def process_commands(self, msg):
asyncio.create_task(context.invoke())

def fetch_message(self, message_id: str) -> Optional[Message]:
if message_id is None:
return
return self.cached_messages.get(message_id)

def fetch_user(self, user_id: str) -> Optional[User]:
if user_id is None:
return
return self.cached_users.get(user_id)

def fetch_channel(self, channel_id: str) -> Optional[Messageable]:
if channel_id is None:
return
return self.cached_channels.get(channel_id)

# Cry it's O(N) - max 100 guilds so it's cool
def fetch_guild(self, guild_id: str) -> Optional[Guild]:
if guild_id is None:
return
for guild in self.user.guilds:
if guild.id == guild_id:
return guild
return


def fetch_role(self, role_id: str) -> Optional[Role]:
if role_id is None:
return
for guild in self.user.guilds:
for role in guild.roles:
if role.id == role_id:
return role
return



async def get_user(self, user_id: str) -> Optional[User]:
Expand Down
20 changes: 17 additions & 3 deletions selfcord/models/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ def __init__(self, payload: dict, bot: Bot):
def update(self, payload: dict):
self.id: str = payload["id"]
self.type: int = payload["type"]
self.allow: int = Permission(payload["allow"], self.bot)
self.deny: int = Permission(payload["deny"], self.bot)
self.allow: Permission = Permission(payload["allow"], self.bot)
self.deny: Permission = Permission(payload["deny"], self.bot)

class Channel:
def __init__(self, payload: dict, bot: Bot):
Expand All @@ -38,6 +38,10 @@ def update(self, payload):
self.flags = payload.get("flags")
self.last_message_id = payload.get("last_message_id")
self.guild_id = payload.get("guild_id")

@property
def guild(self):
return self.bot.fetch_guild(self.guild_id)

async def delete(self):
await self.http.request(
Expand All @@ -49,6 +53,8 @@ class Callable(Channel):
def __init__(self, payload: dict, bot: Bot):
self.bot: Bot = bot
self.http: HttpClient = bot.http
super().__init__(payload, self.bot)
super().update(payload)
self.update(payload)

def update(self, payload):
Expand Down Expand Up @@ -76,6 +82,8 @@ class Messageable(Channel):
def __init__(self, payload: dict, bot: Bot):
self.bot: Bot = bot
self.http: HttpClient = bot.http
super().__init__(payload, self.bot)
super().update(payload)
self.update(payload)

def __repr__(self):
Expand Down Expand Up @@ -155,7 +163,9 @@ async def history(self, limit: int = 50, bot_user_only: bool = False):
return msgs

for message in json:
message.update({"guild_id": self.guild_id})
msg = Message(message, self.bot)

if bot_user_only:
if msg.author.id == self.bot.user.id:
msgs.append(msg)
Expand All @@ -180,6 +190,7 @@ async def history(self, limit: int = 50, bot_user_only: bool = False):
)

for message in json:
message.update({"guild_id": self.guild_id})
msg = Message(message, self.bot)
if bot_user_only:
if msg.author.id == self.bot.user.id:
Expand Down Expand Up @@ -219,7 +230,10 @@ async def search_base(
if json is not None:
for messages in json['messages']:
for message in messages:
self.bot.cached_messages[message.id] = Message(message, self.bot)
message.update({"guild_id": self.guild_id})
msg = Message(message, self.bot)

self.bot.cached_messages[message.id] = msg
return [Message(message, self.bot) for messages in json['messages'] for message in messages]

async def search(
Expand Down
4 changes: 3 additions & 1 deletion selfcord/models/guild.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,10 @@ def update(self, payload: dict):
)
self.preferred_locale: Optional[str] = properties.get(
"preferred_locale")

self.icon: Optional[Asset] = (
Asset(self.id, properties["icon"]).from_icon()
if properties.get("discovery_splash") is not None
if properties.get("icon") is not None
else None
)
self.latest_onboarding_question_id: Optional[str] = properties.get(
Expand Down Expand Up @@ -165,6 +166,7 @@ def partial_update(self, payload: dict):
else None
))
elif key == "icon":

setattr(self, key, (
Asset(self.id, payload["banner"]).from_avatar()
if payload.get("banner") is not None and self.id is not None
Expand Down
16 changes: 12 additions & 4 deletions selfcord/models/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,29 @@ def link(self):
else:
link = f"https://discord.com/channels/@me/{self.channel_id}/{self.id}"
return link

@property
def guild(self):
return self.bot.fetch_guild(self.guild_id)

@property
def channel(self):
return self.bot.fetch_channel(self.channel_id)

def update(self, payload: dict):
self.id: Optional[str] = payload.get("id")
self.content: Optional[str] = payload.get("content")
if self.content is not None:
self.content = self.content.replace("\x00", "")
self.type: int = payload.get("type", 0)
self.tts: bool = payload.get("tts", False)
self.timestamp: Optional[int] = payload.get("timestamp")
self.replied_message: Optional[Message] = payload.get("referenced_message")
self.pinned: Optional[bool] = payload.get("pinned")
self.nonce: Optional[int] = payload.get("nonce")
self.mentions: Optional[dict] = payload.get("mentions")
self.channel_id: str = payload.get("channel_id", "")
self.channel: Optional[Messageable] = self.bot.fetch_channel(self.channel_id)
self.guild_id: str = payload.get("guild_id")
self.guild: Optional[Guild] = self.bot.fetch_guild(self.guild_id)
self.channel_id: Optional[str] = payload.get("channel_id", "")
self.guild_id: Optional[str] = payload.get("guild_id")
# if payload.get("author") is None:
# print(payload)
self.author: Optional[User] = (
Expand Down

0 comments on commit 0f9e063

Please sign in to comment.