Skip to content

Commit

Permalink
feat: add support for beast websockets
Browse files Browse the repository at this point in the history
  • Loading branch information
Nerixyz committed Feb 21, 2024
1 parent 2b44292 commit 32649bf
Show file tree
Hide file tree
Showing 11 changed files with 434 additions and 24 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/unittest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
name: unittest
on: [push, pull_request]

# stop in-progress builds on push
concurrency:
group: unittest-${{ github.ref }}
cancel-in-progress: true

jobs:
unittest-boost-asio:
name: "${{matrix.generator}} ${{matrix.toolset}} Boost ${{matrix.boost_version}} ${{matrix.build_type}} C++${{matrix.standard}} ${{matrix.name_args}}"
Expand Down Expand Up @@ -121,6 +126,7 @@ jobs:
"${GITHUB_WORKSPACE}"
env:
BOOST_ROOT: ${{env.BOOST_INSTALL_PATH}}/boost
OPENSSL_ROOT: ${{matrix.generator == 'MinGW Makefiles' && 'C:/Program Files/OpenSSL' || null}}

- name: Build
working-directory: build
Expand Down
79 changes: 79 additions & 0 deletions include/wintls/beast.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
//
// Copyright (c) 2023 Kasper Laudrup (laudrup at stacktrace dot dk)
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//

#ifndef BOOST_WINTLS_BEAST_HPP
#define BOOST_WINTLS_BEAST_HPP

#if __has_include("boost/beast.hpp")

#include <boost/version.hpp>
#include <boost/beast/websocket.hpp>
#include <wintls/context.hpp>
#include <wintls/stream.hpp>

namespace boost {
namespace beast {

namespace detail {

template<class AsyncStream>
struct wintls_shutdown_op : boost::asio::coroutine {
wintls_shutdown_op(wintls::stream<AsyncStream>& s, role_type role)
: s_(s)
, role_(role) {
}

template<class Self>
void operator()(Self& self, error_code ec = {}, std::size_t = 0) {
BOOST_ASIO_CORO_REENTER(*this) {
#if (BOOST_VERSION / 100 % 1000) >= 77
self.reset_cancellation_state(net::enable_total_cancellation());
#endif

BOOST_ASIO_CORO_YIELD
s_.async_shutdown(std::move(self));
ec_ = ec;

using boost::beast::websocket::async_teardown;
BOOST_ASIO_CORO_YIELD
async_teardown(role_, s_.next_layer(), std::move(self));
if (!ec_) {
ec_ = ec;
}

self.complete(ec_);
}
}

private:
wintls::stream<AsyncStream>& s_;
role_type role_;
error_code ec_;
};

} // namespace detail

template<class AsyncStream, class TeardownHandler>
void async_teardown(role_type role, wintls::stream<AsyncStream>& stream, TeardownHandler&& handler) {
return boost::asio::async_compose<TeardownHandler, void(error_code)>(
detail::wintls_shutdown_op<AsyncStream>(stream, role), handler, stream);
}

template<class AsyncStream>
void teardown(boost::beast::role_type role, wintls::stream<AsyncStream>& stream, boost::system::error_code& ec) {
stream.shutdown(ec);
using boost::beast::websocket::teardown;
boost::system::error_code ec2;
teardown(role, stream.next_layer(), ec ? ec2 : ec);
}

} // namespace beast
} // namespace boost

#endif

#endif
16 changes: 16 additions & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ if(MSVC)
target_compile_options(unittest PRIVATE "-bigobj")
endif()

if(MINGW)
target_compile_options(unittest PRIVATE "-Wa,-mbig-obj")
endif()

target_compile_definitions(unittest PRIVATE
TEST_CERTIFICATES_PATH="${CMAKE_CURRENT_LIST_DIR}/test_certificates/gen/"
)
Expand Down Expand Up @@ -83,6 +87,18 @@ if(${CMAKE_CXX_STANDARD} LESS 17 AND ENABLE_WINTLS_STANDALONE_ASIO)
)
endif()

if(NOT ENABLE_WINTLS_STANDALONE_ASIO)
if(MSVC AND ${Boost_VERSION} VERSION_LESS "1.76")
# Unreferenced formal parameter in boost/beast/websocket/impl/ssl.hpp
target_compile_options(unittest PRIVATE /wd4100)
endif()

if(MSVC AND ${Boost_VERSION} VERSION_LESS "1.85")
# Unreachable code in boost/beast/core/impl/buffers_cat.hpp
target_compile_options(unittest PRIVATE /wd4702)
endif()
endif()

include(CTest)
include(Catch)
catch_discover_tests(unittest TEST_SPEC "*")
73 changes: 73 additions & 0 deletions test/async_ws_echo_client.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
//
// Copyright (c) 2023 Kasper Laudrup (laudrup at stacktrace dot dk)
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//

#ifndef BOOST_WINTLS_TEST_ASYNC_ECHO_CLIENT_HPP
#define BOOST_WINTLS_TEST_ASYNC_ECHO_CLIENT_HPP

#include "unittest.hpp"

template<typename Stream>
struct async_ws_echo_client : public Stream {
public:
using Stream::stream;

async_ws_echo_client(net::io_context& context, const std::string& message)
: Stream(context)
, message_(message)
, ws_(stream) {
}

void run() {
do_tls_handshake();
}

std::string received_message() const {
return std::string(net::buffers_begin(recv_buffer_.data()),
net::buffers_begin(recv_buffer_.data()) + static_cast<std::ptrdiff_t>(recv_buffer_.size()));
}

private:
void do_tls_handshake() {
ws_.next_layer().async_handshake(Stream::handshake_type::client, [this](const boost::system::error_code& ec) {
REQUIRE_FALSE(ec);
do_ws_handshake();
});
}

void do_ws_handshake() {
ws_.async_handshake("localhost", "/", [this](const boost::system::error_code& ec) {
REQUIRE_FALSE(ec);
do_write();
});
}

void do_write() {
ws_.async_write(net::buffer(message_), [this](const boost::system::error_code& ec, std::size_t) {
REQUIRE_FALSE(ec);
do_read();
});
}

void do_read() {
ws_.async_read(recv_buffer_, [this](const boost::system::error_code& ec, std::size_t) {
REQUIRE_FALSE(ec);
do_shutdown();
});
}

void do_shutdown() {
ws_.async_close(websocket::close_code::normal, [](const boost::system::error_code& ec) {
REQUIRE_FALSE(ec);
});
}

std::string message_;
beast::flat_buffer recv_buffer_;
websocket::stream<decltype(stream)&> ws_;
};

#endif // BOOST_WINTLS_TEST_ASYNC_ECHO_CLIENT_HPP
67 changes: 67 additions & 0 deletions test/async_ws_echo_server.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
//
// Copyright (c) 2023 Kasper Laudrup (laudrup at stacktrace dot dk)
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//

#ifndef BOOST_WINTLS_TEST_ASYNC_ECHO_SERVER_HPP
#define BOOST_WINTLS_TEST_ASYNC_ECHO_SERVER_HPP

#include "unittest.hpp"

template<typename Stream>
class async_ws_echo_server : public Stream {
public:
using Stream::stream;

async_ws_echo_server(net::io_context& context)
: Stream(context)
, ws_(stream) {
}

void run() {
do_tls_handshake();
}

private:
void do_tls_handshake() {
ws_.next_layer().async_handshake(Stream::handshake_type::server, [this](const boost::system::error_code& ec) {
REQUIRE_FALSE(ec);
do_ws_handshake();
});
}

void do_ws_handshake() {
ws_.async_accept([this](const boost::system::error_code& ec) {
REQUIRE_FALSE(ec);
do_read();
});
}

void do_read() {
ws_.async_read(recv_buffer_, [this](const boost::system::error_code& ec, std::size_t) {
REQUIRE_FALSE(ec);
ws_.text(ws_.got_text());
do_write();
});
}

void do_write() {
ws_.async_write(recv_buffer_.data(), [this](const boost::system::error_code& ec, std::size_t) {
REQUIRE_FALSE(ec);
do_shutdown();
});
}

void do_shutdown() {
ws_.async_close(websocket::close_code::normal, [](const boost::system::error_code& ec) {
REQUIRE_FALSE(ec);
});
}

beast::flat_buffer recv_buffer_;
websocket::stream<decltype(stream)&> ws_;
};

#endif // BOOST_WINTLS_TEST_ASYNC_ECHO_SERVER_HPP
83 changes: 64 additions & 19 deletions test/echo_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//

#ifndef WINTLS_USE_STANDALONE_ASIO
#include "ws_echo_server.hpp"
#include "ws_echo_client.hpp"
#include "async_ws_echo_server.hpp"
#include "async_ws_echo_client.hpp"
#endif // !WINTLS_USE_STANDALONE_ASIO

#include "echo_server.hpp"
#include "echo_client.hpp"
#include "async_echo_server.hpp"
Expand All @@ -30,28 +37,66 @@ std::string generate_data(std::size_t size) {
return ret;
}

}

using TestTypes = std::tuple<std::tuple<asio_ssl_client_stream, asio_ssl_server_stream>,
std::tuple<wintls_client_stream, asio_ssl_server_stream>,
std::tuple<asio_ssl_client_stream, wintls_server_stream>,
std::tuple<wintls_client_stream, wintls_server_stream>>;

TEMPLATE_LIST_TEST_CASE("echo test", "", TestTypes) {
using ClientStream = typename std::tuple_element<0, TestType>::type;
using ServerStream = typename std::tuple_element<1, TestType>::type;

auto test_data_size = GENERATE(0x100, 0x100 - 1, 0x100 + 1,
0x1000, 0x1000 - 1, 0x1000 + 1,
0x10000, 0x10000 - 1, 0x10000 + 1,
0x100000, 0x100000 - 1, 0x100000 + 1);
} // namespace

#ifdef WINTLS_USE_STANDALONE_ASIO

#define WINTLS_TEST_TYPES (Tls)

#else // WINTLS_USE_STANDALONE_ASIO

template<typename S>
struct WebSocket {
using ClientStream = typename std::tuple_element<0, S>::type;
using ServerStream = typename std::tuple_element<1, S>::type;

using Server = ws_echo_server<ServerStream>;
using AsyncServer = async_ws_echo_server<ServerStream>;
using Client = ws_echo_client<ClientStream>;
using AsyncClient = async_ws_echo_client<ClientStream>;
};

#define WINTLS_TEST_TYPES (WebSocket, Tls)

#endif // !WINTLS_USE_STANDALONE_ASIO

template<typename S>
struct Tls {
using ClientStream = typename std::tuple_element<0, S>::type;
using ServerStream = typename std::tuple_element<1, S>::type;

using Server = echo_server<ServerStream>;
using AsyncServer = async_echo_server<ServerStream>;
using Client = echo_client<ClientStream>;
using AsyncClient = async_echo_client<ClientStream>;
};

TEMPLATE_PRODUCT_TEST_CASE("echo test",
"",
WINTLS_TEST_TYPES,
((std::tuple<asio_ssl_client_stream, asio_ssl_server_stream>),
(std::tuple<wintls_client_stream, asio_ssl_server_stream>),
(std::tuple<asio_ssl_client_stream, wintls_server_stream>),
(std::tuple<wintls_client_stream, wintls_server_stream>))) {
auto test_data_size = GENERATE(0x100,
0x100 - 1,
0x100 + 1,
0x1000,
0x1000 - 1,
0x1000 + 1,
0x10000,
0x10000 - 1,
0x10000 + 1,
0x100000,
0x100000 - 1,
0x100000 + 1);
const std::string test_data = generate_data(static_cast<std::size_t>(test_data_size));

net::io_context io_context;

SECTION("sync test") {
echo_client<ClientStream> client(io_context);
echo_server<ServerStream> server(io_context);
typename TestType::Client client(io_context);
typename TestType::Server server(io_context);

client.stream.next_layer().connect(server.stream.next_layer());

Expand All @@ -72,8 +117,8 @@ TEMPLATE_LIST_TEST_CASE("echo test", "", TestTypes) {
}

SECTION("async test") {
async_echo_server<ServerStream> server(io_context);
async_echo_client<ClientStream> client(io_context, test_data);
typename TestType::AsyncServer server(io_context);
typename TestType::AsyncClient client(io_context, test_data);
client.stream.next_layer().connect(server.stream.next_layer());
server.run();
client.run();
Expand Down
Loading

0 comments on commit 32649bf

Please sign in to comment.