Skip to content

Commit

Permalink
Add WebsocketCompressionNegotiator
Browse files Browse the repository at this point in the history
  • Loading branch information
trowski committed Oct 21, 2023
1 parent 904902c commit 5bf7962
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 15 deletions.
7 changes: 3 additions & 4 deletions src/Rfc6455ClientFactory.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
use Amp\Http\Server\Request;
use Amp\Http\Server\Response;
use Amp\Socket\Socket;
use Amp\Websocket\Compression\WebsocketCompressionContextFactory;
use Amp\Websocket\Compression\WebsocketCompressionContext;
use Amp\Websocket\ConstantRateLimit;
use Amp\Websocket\Parser\Rfc6455ParserFactory;
use Amp\Websocket\Parser\WebsocketParserFactory;
Expand All @@ -24,13 +24,10 @@ final class Rfc6455ClientFactory implements WebsocketClientFactory
use ForbidSerialization;

/**
* @param WebsocketCompressionContextFactory|null $compressionContextFactory Deprecated. This argument is unused.
* Compression is not supported in v3.x but will be in v4.x.
* @param WebsocketHeartbeatQueue|null $heartbeatQueue Use null to disable automatic heartbeats (pings).
* @param WebsocketRateLimit|null $rateLimit Use null to disable client rate limits.
*/
public function __construct(
private readonly ?WebsocketCompressionContextFactory $compressionContextFactory = null,
private readonly ?WebsocketHeartbeatQueue $heartbeatQueue = new PeriodicHeartbeatQueue(),
private readonly ?WebsocketRateLimit $rateLimit = new ConstantRateLimit(),
private readonly WebsocketParserFactory $parserFactory = new Rfc6455ParserFactory(),
Expand All @@ -43,6 +40,7 @@ public function createClient(
Request $request,
Response $response,
Socket $socket,
?WebsocketCompressionContext $compressionContext,
): WebsocketClient {
if ($socket instanceof ResourceStream) {
$socketResource = $socket->getResource();
Expand All @@ -69,6 +67,7 @@ public function createClient(
socket: $socket,
masked: false,
parserFactory: $this->parserFactory,
compressionContext: $compressionContext,
heartbeatQueue: $this->heartbeatQueue,
rateLimit: $this->rateLimit,
frameSplitThreshold: $this->frameSplitThreshold,
Expand Down
34 changes: 34 additions & 0 deletions src/Rfc7692CompressionNegotiator.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
<?php declare(strict_types=1);

namespace Amp\Websocket\Server;

use Amp\Http;
use Amp\Http\Server\Request;
use Amp\Http\Server\Response;
use Amp\Websocket\Compression\Rfc7692CompressionFactory;
use Amp\Websocket\Compression\WebsocketCompressionContext;

final class Rfc7692CompressionNegotiator implements WebsocketCompressionNegotiator
{
private readonly Rfc7692CompressionFactory $compressionContextFactory;

public function __construct()
{
$this->compressionContextFactory = new Rfc7692CompressionFactory();
}

public function negotiateCompression(Request $request, Response $response): ?WebsocketCompressionContext
{
$extensions = Http\splitHeader($request, 'sec-websocket-extensions') ?? [];
foreach ($extensions as $extension) {
if ($compressionContext = $this->compressionContextFactory->fromClientHeader($extension, $headerLine)) {
/** @psalm-suppress PossiblyNullArgument */
$response->setHeader('sec-websocket-extensions', $headerLine);

return $compressionContext;
}
}

return null;
}
}
23 changes: 19 additions & 4 deletions src/Websocket.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
use Amp\Http\Server\Request;
use Amp\Http\Server\RequestHandler;
use Amp\Http\Server\Response;
use Amp\Websocket\Compression\WebsocketCompressionContext;
use Amp\Websocket\WebsocketClient;
use Amp\Websocket\WebsocketCloseCode;
use Amp\Websocket\WebsocketClosedException;
Expand All @@ -32,6 +33,7 @@ public function __construct(
private readonly PsrLogger $logger,
private readonly WebsocketAcceptor $acceptor,
private readonly WebsocketClientHandler $clientHandler,
private readonly ?WebsocketCompressionNegotiator $compressionNegotiator = null,
private readonly WebsocketClientFactory $clientFactory = new Rfc6455ClientFactory(),
) {
/** @psalm-suppress PropertyTypeCoercion */
Expand All @@ -45,27 +47,40 @@ public function handleRequest(Request $request): Response
$response = $this->acceptor->handleHandshake($request);

if ($response->getStatus() !== HttpStatus::SWITCHING_PROTOCOLS) {
$response->removeHeader('sec-websocket-accept');
$response->setHeader('connection', 'close');
return $this->modifyNonUpgradeResponse($response);
}

$compressionContext = $this->compressionNegotiator?->negotiateCompression($request, $response);

return $response;
if ($response->getStatus() !== HttpStatus::SWITCHING_PROTOCOLS) {
return $this->modifyNonUpgradeResponse($response);
}

$response->upgrade(fn (UpgradedSocket $socket) => $this->reapClient(
socket: $socket,
request: $request,
response: $response,
compressionContext: $compressionContext,
));

return $response;
}

private function modifyNonUpgradeResponse(Response $response): Response
{
$response->removeHeader('sec-websocket-accept');
$response->setHeader('connection', 'close');

return $response;
}

private function reapClient(
UpgradedSocket $socket,
Request $request,
Response $response,
?WebsocketCompressionContext $compressionContext,
): void {
$client = $this->clientFactory->createClient($request, $response, $socket);
$client = $this->clientFactory->createClient($request, $response, $socket, $compressionContext);

/** @psalm-suppress RedundantCondition */
\assert($this->logger->debug(\sprintf(
Expand Down
10 changes: 8 additions & 2 deletions src/WebsocketClientFactory.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,18 @@
use Amp\Http\Server\Request;
use Amp\Http\Server\Response;
use Amp\Socket\Socket;
use Amp\Websocket\Compression\WebsocketCompressionContext;
use Amp\Websocket\WebsocketClient;

interface WebsocketClientFactory
{
/**
* Creates a Client object based on the given Request.
* Creates a {@see WebsocketClient} object based on the given Request.
*/
public function createClient(Request $request, Response $response, Socket $socket): WebsocketClient;
public function createClient(
Request $request,
Response $response,
Socket $socket,
?WebsocketCompressionContext $compressionContext,
): WebsocketClient;
}
17 changes: 17 additions & 0 deletions src/WebsocketCompressionNegotiator.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
<?php declare(strict_types=1);

namespace Amp\Websocket\Server;

use Amp\Http\Server\Request;
use Amp\Http\Server\Response;
use Amp\Websocket\Compression\WebsocketCompressionContext;

interface WebsocketCompressionNegotiator
{
/**
* Examine the given {@see Request} and {@see Response} returned from a {@see WebsocketAcceptor} to determine
* if compression should be enabled for the client. If so, return an instance of {@see WebsocketCompressionContext}
* and modify the {@see Response} object accordingly.
*/
public function negotiateCompression(Request $request, Response $response): ?WebsocketCompressionContext;
}
38 changes: 33 additions & 5 deletions test/WebsocketTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ protected function execute(\Closure $onConnect, WebsocketClient $client): void
{
\assert($client instanceof MockObject);

$factory = $this->createMock(WebsocketClientFactory::class);
$factory->method('createClient')
$clientFactory = $this->createMock(WebsocketClientFactory::class);
$clientFactory->method('createClient')
->willReturn($client);

$deferred = new DeferredFuture;

$webserver = $this->createWebsocketServer(
$factory,
$clientFactory,
function (WebsocketGateway $gateway, WebsocketClient $client) use ($onConnect, $deferred): void {
$deferred->complete($onConnect($gateway, $client));
}
Expand Down Expand Up @@ -66,7 +66,7 @@ function (WebsocketGateway $gateway, WebsocketClient $client) use ($onConnect, $
* @param \Closure(WebsocketGateway, WebsocketClient):void $clientHandler
*/
protected function createWebsocketServer(
WebsocketClientFactory $factory,
WebsocketClientFactory $clientFactory,
\Closure $clientHandler,
WebsocketGateway $gateway = new WebsocketClientGateway(),
): SocketHttpServer {
Expand All @@ -90,7 +90,7 @@ public function handleClient(WebsocketClient $client, Request $request, Response
($this->clientHandler)($this->gateway, $client);
}
},
clientFactory: $factory,
clientFactory: $clientFactory,
);

$httpServer->expose(new Socket\InternetAddress('127.0.0.1', 0));
Expand Down Expand Up @@ -124,6 +124,7 @@ public function testHandshake(Request $request, int $status, array $expectedHead
logger: $logger,
acceptor: $acceptor,
clientHandler: $this->createMock(WebsocketClientHandler::class),
compressionNegotiator: new Rfc7692CompressionNegotiator(),
);
$server->start($websocket, $this->createMock(ErrorHandler::class));

Expand Down Expand Up @@ -218,6 +219,33 @@ public function provideHandshakes(): iterable
HttpStatus::BAD_REQUEST,
["sec-websocket-version" => ["13"]],
];

// 8 ----- compression: valid header ------------------------------------------------------>
$request = $this->createRequest();
$request->setHeader("sec-websocket-extensions", "permessage-deflate; client_max_window_bits");
yield 'With Valid Compression' => [
$request,
HttpStatus::SWITCHING_PROTOCOLS,
[
"upgrade" => ["websocket"],
"connection" => ["upgrade"],
"sec-websocket-accept" => ["HSmrc0sMlYUkAGmm5OPpG2HaGWk="],
"sec-websocket-extensions" => ["permessage-deflate; client_max_window_bits=15"],
],
];

// 9 ----- compression: invalid header ---------------------------------------------------->
$request = $this->createRequest();
$request->setHeader("sec-websocket-extensions", "permessage-deflate; client_max_window_bits=8;");
yield 'With Invalid Compression' => [
$request,
HttpStatus::SWITCHING_PROTOCOLS,
[
"upgrade" => ["websocket"],
"connection" => ["upgrade"],
"sec-websocket-accept" => ["HSmrc0sMlYUkAGmm5OPpG2HaGWk="],
],
];
}

public function testBroadcast(): void
Expand Down

0 comments on commit 5bf7962

Please sign in to comment.