Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: native zlib support #10243

Merged
merged 14 commits into from
May 11, 2024
5 changes: 4 additions & 1 deletion packages/ws/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ const manager = new WebSocketManager({
intents: 0, // for no intents
rest,
// uncomment if you have zlib-sync installed and want to use compression
// compression: CompressionMethod.ZlibStream,
// compression: CompressionMethod.ZlibSync,

// alternatively, we support compression using node's native `node:zlib` module:
// compression: CompressionMethod.ZlibNative,
});

manager.on(WebSocketShardEvents.Dispatch, (event) => {
Expand Down
9 changes: 8 additions & 1 deletion packages/ws/src/utils/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,19 @@ export enum Encoding {
* Valid compression methods
*/
export enum CompressionMethod {
ZlibStream = 'zlib-stream',
ZlibNative,
ZlibSync,
}

export const DefaultDeviceProperty = `@discordjs/ws [VI]{{inject}}[/VI]` as `@discordjs/ws ${string}`;

const getDefaultSessionStore = lazy(() => new Collection<number, SessionInfo | null>());

export const CompressionParameterMap = {
[CompressionMethod.ZlibNative]: 'zlib-stream',
[CompressionMethod.ZlibSync]: 'zlib-stream',
} as const satisfies Record<CompressionMethod, string>;
vladfrangu marked this conversation as resolved.
Show resolved Hide resolved

/**
* Default options used by the manager
*/
Expand All @@ -46,6 +52,7 @@ export const DefaultWebSocketManagerOptions = {
version: APIVersion,
encoding: Encoding.JSON,
compression: null,
useIdentifyCompression: false,
retrieveSessionInfo(shardId) {
const store = getDefaultSessionStore();
return store.get(shardId) ?? null;
Expand Down
10 changes: 8 additions & 2 deletions packages/ws/src/ws/WebSocketManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ export interface OptionalWebSocketManagerOptions {
*/
buildStrategy(manager: WebSocketManager): IShardingStrategy;
/**
* The compression method to use
* The transport compression method to use - mutually exclusive with `useIdentifyCompression`
*
* @defaultValue `null` (no compression)
* @defaultValue `null` (no transport compression)
*/
compression: CompressionMethod | null;
/**
Expand Down Expand Up @@ -176,6 +176,12 @@ export interface OptionalWebSocketManagerOptions {
* Function used to store session information for a given shard
*/
updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null): Awaitable<void>;
/**
* Whether to use the `compress` option when identifying
*
* @defaultValue `false`
*/
useIdentifyCompression: boolean;
/**
* The gateway version to use
*
Expand Down
195 changes: 138 additions & 57 deletions packages/ws/src/ws/WebSocketShard.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
/* eslint-disable id-length */
import { Buffer } from 'node:buffer';
import { once } from 'node:events';
import { clearInterval, clearTimeout, setInterval, setTimeout } from 'node:timers';
import { setTimeout as sleep } from 'node:timers/promises';
import { URLSearchParams } from 'node:url';
import { TextDecoder } from 'node:util';
import { inflate } from 'node:zlib';
import type * as nativeZlib from 'node:zlib';
import { Collection } from '@discordjs/collection';
import { lazy, shouldUseGlobalFetchAndWebSocket } from '@discordjs/util';
import { AsyncQueue } from '@sapphire/async-queue';
Expand All @@ -21,13 +20,20 @@ import {
type GatewaySendPayload,
} from 'discord-api-types/v10';
import { WebSocket, type Data } from 'ws';
import type { Inflate } from 'zlib-sync';
import type { IContextFetchingStrategy } from '../strategies/context/IContextFetchingStrategy.js';
import { ImportantGatewayOpcodes, getInitialSendRateLimitState } from '../utils/constants.js';
import type * as ZlibSync from 'zlib-sync';
import type { IContextFetchingStrategy } from '../strategies/context/IContextFetchingStrategy';
import {
CompressionMethod,
CompressionParameterMap,
ImportantGatewayOpcodes,
getInitialSendRateLimitState,
} from '../utils/constants.js';
import type { SessionInfo } from './WebSocketManager.js';

// eslint-disable-next-line promise/prefer-await-to-then
/* eslint-disable promise/prefer-await-to-then */
const getZlibSync = lazy(async () => import('zlib-sync').then((mod) => mod.default).catch(() => null));
const getNativeZlib = lazy(async () => import('node:zlib').then((mod) => mod).catch(() => null));
/* eslint-enable promise/prefer-await-to-then */

export enum WebSocketShardEvents {
Closed = 'closed',
Expand Down Expand Up @@ -86,9 +92,9 @@ const WebSocketConstructor: typeof WebSocket = shouldUseGlobalFetchAndWebSocket(
export class WebSocketShard extends AsyncEventEmitter<WebSocketShardEventsMap> {
private connection: WebSocket | null = null;

private useIdentifyCompress = false;
private nativeInflate: nativeZlib.Inflate | null = null;

private inflate: Inflate | null = null;
private zLibSyncInflate: ZlibSync.Inflate | null = null;

private readonly textDecoder = new TextDecoder();

Expand Down Expand Up @@ -120,6 +126,18 @@ export class WebSocketShard extends AsyncEventEmitter<WebSocketShardEventsMap> {

#status: WebSocketShardStatus = WebSocketShardStatus.Idle;

private identifyCompressionEnabled = false;

/**
* @privateRemarks
*
* This is needed because `this.strategy.options.compression` is not an actual reflection of the compression method
* used, but rather the compression method that the user wants to use. This is because the libraries could just be missing.
*/
private get transportCompressionEnabled() {
return this.strategy.options.compression !== null && (this.nativeInflate ?? this.zLibSyncInflate) !== null;
}

public get status(): WebSocketShardStatus {
return this.#status;
}
Expand Down Expand Up @@ -161,21 +179,63 @@ export class WebSocketShard extends AsyncEventEmitter<WebSocketShardEventsMap> {
throw new Error("Tried to connect a shard that wasn't idle");
}

const { version, encoding, compression } = this.strategy.options;
const { version, encoding, compression, useIdentifyCompression } = this.strategy.options;
this.identifyCompressionEnabled = useIdentifyCompression;

// eslint-disable-next-line id-length
const params = new URLSearchParams({ v: version, encoding });
if (compression) {
const zlib = await getZlibSync();
if (zlib) {
params.append('compress', compression);
this.inflate = new zlib.Inflate({
chunkSize: 65_535,
to: 'string',
});
} else if (!this.useIdentifyCompress) {
this.useIdentifyCompress = true;
console.warn(
'WebSocketShard: Compression is enabled but zlib-sync is not installed, falling back to identify compress',
);
if (compression !== null) {
if (useIdentifyCompression) {
console.warn('WebSocketShard: transport compression is enabled, disabling identify compression');
this.identifyCompressionEnabled = false;
}

params.append('compress', CompressionParameterMap[compression]);

switch (compression) {
case CompressionMethod.ZlibNative: {
const zlib = await getNativeZlib();
if (zlib) {
const inflate = zlib.createInflate({
chunkSize: 65_535,
flush: zlib.constants.Z_SYNC_FLUSH,
});

inflate.on('error', (error) => {
this.emit(WebSocketShardEvents.Error, { error });
});

this.nativeInflate = inflate;
} else {
console.warn('WebSocketShard: Compression is set to native but node:zlib is not available.');
params.delete('compress');
}

break;
}

case CompressionMethod.ZlibSync: {
const zlib = await getZlibSync();
if (zlib) {
this.zLibSyncInflate = new zlib.Inflate({
chunkSize: 65_535,
to: 'string',
});
} else {
console.warn('WebSocketShard: Compression is set to zlib-sync, but it is not installed.');
params.delete('compress');
}

break;
}
}
}

if (this.identifyCompressionEnabled) {
const zlib = await getNativeZlib();
if (!zlib) {
console.warn('WebSocketShard: Identify compression is enabled, but node:zlib is not available.');
this.identifyCompressionEnabled = false;
}
}

Expand Down Expand Up @@ -451,28 +511,29 @@ export class WebSocketShard extends AsyncEventEmitter<WebSocketShardEventsMap> {
`shard id: ${this.id.toString()}`,
`shard count: ${this.strategy.options.shardCount}`,
`intents: ${this.strategy.options.intents}`,
`compression: ${this.inflate ? 'zlib-stream' : this.useIdentifyCompress ? 'identify' : 'none'}`,
`compression: ${this.transportCompressionEnabled ? CompressionParameterMap[this.strategy.options.compression!] : this.identifyCompressionEnabled ? 'identify' : 'none'}`,
]);

const d: GatewayIdentifyData = {
const data: GatewayIdentifyData = {
token: this.strategy.options.token,
properties: this.strategy.options.identifyProperties,
intents: this.strategy.options.intents,
compress: this.useIdentifyCompress,
compress: this.identifyCompressionEnabled,
shard: [this.id, this.strategy.options.shardCount],
};

if (this.strategy.options.largeThreshold) {
d.large_threshold = this.strategy.options.largeThreshold;
data.large_threshold = this.strategy.options.largeThreshold;
}

if (this.strategy.options.initialPresence) {
d.presence = this.strategy.options.initialPresence;
data.presence = this.strategy.options.initialPresence;
}

await this.send({
op: GatewayOpcodes.Identify,
d,
// eslint-disable-next-line id-length
d: data,
});

await this.waitForEvent(WebSocketShardEvents.Ready, this.strategy.options.readyTimeout);
Expand All @@ -490,6 +551,7 @@ export class WebSocketShard extends AsyncEventEmitter<WebSocketShardEventsMap> {
this.replayedEvents = 0;
return this.send({
op: GatewayOpcodes.Resume,
// eslint-disable-next-line id-length
d: {
token: this.strategy.options.token,
seq: session.sequence,
Expand All @@ -507,13 +569,22 @@ export class WebSocketShard extends AsyncEventEmitter<WebSocketShardEventsMap> {

await this.send({
op: GatewayOpcodes.Heartbeat,
// eslint-disable-next-line id-length
d: session?.sequence ?? null,
});

this.lastHeartbeatAt = Date.now();
this.isAck = false;
}

private parseInflateResult(result: any): GatewayReceivePayload | null {
if (!result) {
return null;
}

return JSON.parse(typeof result === 'string' ? result : this.textDecoder.decode(result)) as GatewayReceivePayload;
}

private async unpackMessage(data: Data, isBinary: boolean): Promise<GatewayReceivePayload | null> {
// Deal with no compression
if (!isBinary) {
Expand All @@ -528,10 +599,12 @@ export class WebSocketShard extends AsyncEventEmitter<WebSocketShardEventsMap> {
const decompressable = new Uint8Array(data as ArrayBuffer);

// Deal with identify compress
if (this.useIdentifyCompress) {
return new Promise((resolve, reject) => {
if (this.identifyCompressionEnabled) {
// eslint-disable-next-line no-async-promise-executor
return new Promise(async (resolve, reject) => {
const zlib = (await getNativeZlib())!;
// eslint-disable-next-line promise/prefer-await-to-callbacks
inflate(decompressable, { chunkSize: 65_535 }, (err, result) => {
zlib.inflate(decompressable, { chunkSize: 65_535 }, (err, result) => {
if (err) {
reject(err);
return;
Expand All @@ -542,42 +615,50 @@ export class WebSocketShard extends AsyncEventEmitter<WebSocketShardEventsMap> {
});
}

// Deal with gw wide zlib-stream compression
if (this.inflate) {
const l = decompressable.length;
// Deal with transport compression
if (this.transportCompressionEnabled) {
const flush =
l >= 4 &&
decompressable[l - 4] === 0x00 &&
decompressable[l - 3] === 0x00 &&
decompressable[l - 2] === 0xff &&
decompressable[l - 1] === 0xff;
decompressable.length >= 4 &&
decompressable.at(-4) === 0x00 &&
decompressable.at(-3) === 0x00 &&
decompressable.at(-2) === 0xff &&
decompressable.at(-1) === 0xff;

const zlib = (await getZlibSync())!;
this.inflate.push(Buffer.from(decompressable), flush ? zlib.Z_SYNC_FLUSH : zlib.Z_NO_FLUSH);
if (this.nativeInflate) {
this.nativeInflate.write(decompressable, 'binary');

if (this.inflate.err) {
this.emit(WebSocketShardEvents.Error, {
error: new Error(`${this.inflate.err}${this.inflate.msg ? `: ${this.inflate.msg}` : ''}`),
});
}
if (!flush) {
return null;
}

if (!flush) {
return null;
}
const [result] = await once(this.nativeInflate, 'data');
return this.parseInflateResult(result);
} else if (this.zLibSyncInflate) {
const zLibSync = (await getZlibSync())!;
this.zLibSyncInflate.push(Buffer.from(decompressable), flush ? zLibSync.Z_SYNC_FLUSH : zLibSync.Z_NO_FLUSH);

if (this.zLibSyncInflate.err) {
this.emit(WebSocketShardEvents.Error, {
error: new Error(
`${this.zLibSyncInflate.err}${this.zLibSyncInflate.msg ? `: ${this.zLibSyncInflate.msg}` : ''}`,
),
});
}

const { result } = this.inflate;
if (!result) {
return null;
}
if (!flush) {
return null;
}

return JSON.parse(typeof result === 'string' ? result : this.textDecoder.decode(result)) as GatewayReceivePayload;
const { result } = this.zLibSyncInflate;
return this.parseInflateResult(result);
}
}

this.debug([
'Received a message we were unable to decompress',
`isBinary: ${isBinary.toString()}`,
`useIdentifyCompress: ${this.useIdentifyCompress.toString()}`,
`inflate: ${Boolean(this.inflate).toString()}`,
`useIdentifyCompression: ${this.identifyCompressionEnabled.toString()}`,
`inflate: ${this.transportCompressionEnabled ? CompressionMethod[this.strategy.options.compression!] : 'none'}`,
didinele marked this conversation as resolved.
Show resolved Hide resolved
]);

return null;
Expand Down Expand Up @@ -838,7 +919,7 @@ export class WebSocketShard extends AsyncEventEmitter<WebSocketShardEventsMap> {
messages.length > 1
? `\n${messages
.slice(1)
.map((m) => ` ${m}`)
.map((message) => ` ${message}`)
.join('\n')}`
: ''
}`;
Expand Down