Skip to content

Commit

Permalink
simplify compare definitions and other improvements (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
geseq committed Dec 26, 2023
1 parent 2ce3199 commit 7f1e82d
Show file tree
Hide file tree
Showing 9 changed files with 109 additions and 86 deletions.
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
cmake_minimum_required(VERSION 3.20 FATAL_ERROR)

find_program(CCACHE_PROGRAM ccache)
if (CCACHE_PROGRAM)
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${CCACHE_PROGRAM}")
endif()

set(CPP_ORDERBOOK orderbook)
project(${CPP_ORDERBOOK} LANGUAGES CXX)

Expand Down
16 changes: 8 additions & 8 deletions include/orderbook.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ class OrderBook {
OrderBook(NotificationInterface<Notification>& n, size_t price_level_pool_size = 16384, size_t order_pool_size = 16384)
: order_pool_(order_pool_size),
notification_(static_cast<Notification&>(n)),
bids_(PriceLevel<CmpGreater>(PriceType::Bid, price_level_pool_size)),
asks_(PriceLevel<CmpLess>(PriceType::Ask, price_level_pool_size)),
trigger_over_(PriceLevel<CmpGreater>(PriceType::Trigger, price_level_pool_size)),
trigger_under_(PriceLevel<CmpLess>(PriceType::Trigger, price_level_pool_size)),
bids_(PriceLevel<PriceType::Bid>(price_level_pool_size)),
asks_(PriceLevel<PriceType::Ask>(price_level_pool_size)),
trigger_over_(PriceLevel<PriceType::TriggerOver>(price_level_pool_size)),
trigger_under_(PriceLevel<PriceType::TriggerUnder>(price_level_pool_size)),
orders_(OrderMap()),
trig_orders_(OrderMap()){};

Expand All @@ -39,10 +39,10 @@ class OrderBook {
private:
pool::AdaptiveObjectPool<Order> order_pool_;

PriceLevel<CmpGreater> bids_;
PriceLevel<CmpLess> asks_;
PriceLevel<CmpGreater> trigger_over_;
PriceLevel<CmpLess> trigger_under_;
PriceLevel<PriceType::Bid> bids_;
PriceLevel<PriceType::Ask> asks_;
PriceLevel<PriceType::TriggerOver> trigger_over_;
PriceLevel<PriceType::TriggerUnder> trigger_under_;

OrderMap orders_;
OrderMap trig_orders_;
Expand Down
5 changes: 3 additions & 2 deletions include/orderqueue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class OrderQueue : public boost::intrusive::set_base_hook<boost::intrusive::opti
Order *head_ = nullptr;
Order *tail_ = nullptr;

Decimal price_;
Decimal total_qty_;
uint64_t size_ = 0;

Expand All @@ -30,11 +31,11 @@ class OrderQueue : public boost::intrusive::set_base_hook<boost::intrusive::opti
void remove(Order *o);
Decimal process(const TradeNotification &tn, const PostOrderFill &postFill, OrderID takerOrderID, Decimal qty);

Decimal price_;

friend bool operator<(const OrderQueue &a, const OrderQueue &b) { return a.price_ < b.price_; }
friend bool operator>(const OrderQueue &a, const OrderQueue &b) { return a.price_ > b.price_; }
friend bool operator==(const OrderQueue &a, const OrderQueue &b) { return a.price_ == b.price_; }

friend class PriceCompare;
};

struct PriceCompare {
Expand Down
44 changes: 11 additions & 33 deletions include/pricelevel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,58 +20,36 @@ class Compare {
using CmpGreater = boost::intrusive::compare<std::greater<>>;
using CmpLess = boost::intrusive::compare<std::less<>>;

template <class CompareType>
template <PriceType P>
class PriceLevel {
pool::AdaptiveObjectPool<OrderQueue> queue_pool_;

using CompareType = std::conditional_t<(P == PriceType::Bid || P == PriceType::TriggerOver), CmpGreater, CmpLess>;
using PriceTree = boost::intrusive::rbtree<OrderQueue, CompareType>;
PriceTree price_tree_;

PriceType price_type_;
PriceType price_type_ = P;
Decimal volume_;
uint64_t num_orders_ = 0;
uint64_t depth_ = 0;

public:
PriceLevel(PriceType price_type, size_t price_level_pool_size) : price_type_(price_type), queue_pool_(price_level_pool_size){};
PriceLevel(size_t price_level_pool_size) : queue_pool_(price_level_pool_size){};
uint64_t len();
uint64_t depth();
Decimal volume();
OrderQueue* getQueue();
[[nodiscard]] OrderQueue* getQueue();
[[nodiscard]] OrderQueue* getNextQueue(const Decimal& price);
[[nodiscard]] OrderQueue* largestLessThan(const Decimal& price);
[[nodiscard]] OrderQueue* smallestGreaterThan(const Decimal& price);

void append(Order* order);
void remove(Order* order);

Decimal processMarketOrder(const TradeNotification& tn, const PostOrderFill& pf, OrderID takerOrderID, Decimal qty, Flag flag);
Decimal processLimitOrder(const TradeNotification& tn, const PostOrderFill& pf, OrderID& takerOrderID, Decimal& price, Decimal qty, Flag& flag);

PriceTree& price_tree() { return price_tree_; }

OrderQueue* LargestLessThan(const Decimal& price) {
auto it = price_tree_.lower_bound(price, PriceCompare());
if (it != price_tree_.begin()) {
--it;
return &(*it);
}
return nullptr;
}

OrderQueue* SmallestGreaterThan(const Decimal& price) {
auto it = price_tree_.upper_bound(price, PriceCompare());
if (it != price_tree_.end()) {
return &(*it);
}
return nullptr;
}

OrderQueue* GetNextQueue(const Decimal& price) {
switch (price_type_) {
case PriceType::Bid:
return LargestLessThan(price);
case PriceType::Ask:
return SmallestGreaterThan(price);
default:
throw std::runtime_error("invalid call to GetQueue");
}
}
PriceTree& price_tree() { return price_tree_; };
};

} // namespace orderbook
Expand Down
4 changes: 3 additions & 1 deletion include/types.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <cstdint>
#include <cwchar>
#include <iostream>

#include "decimal.hpp"
Expand Down Expand Up @@ -44,7 +45,8 @@ std::ostream& operator<<(std::ostream& os, const OrderStatus& status);
enum class PriceType {
Bid,
Ask,
Trigger,
TriggerOver,
TriggerUnder,
};

std::ostream& operator<<(std::ostream& os, const PriceType& priceType);
Expand Down
2 changes: 1 addition & 1 deletion src/order.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
namespace orderbook {

Decimal Order::getPrice(PriceType pt) {
if (pt == PriceType::Trigger) {
if (pt == PriceType::TriggerOver || pt == PriceType::TriggerUnder) [[unlikely]] {
return trig_price;
}

Expand Down
74 changes: 53 additions & 21 deletions src/pricelevel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,23 @@

namespace orderbook {

template <class CompareType>
uint64_t PriceLevel<CompareType>::len() {
template <PriceType P>
uint64_t PriceLevel<P>::len() {
return num_orders_;
}

template <class CompareType>
uint64_t PriceLevel<CompareType>::depth() {
template <PriceType P>
uint64_t PriceLevel<P>::depth() {
return depth_;
}

template <class CompareType>
Decimal PriceLevel<CompareType>::volume() {
template <PriceType P>
Decimal PriceLevel<P>::volume() {
return volume_;
}

template <class CompareType>
void PriceLevel<CompareType>::append(Order* order) {
template <PriceType P>
void PriceLevel<P>::append(Order* order) {
auto price = order->getPrice(price_type_);

auto it = price_tree_.find(price);
Expand All @@ -45,8 +45,8 @@ void PriceLevel<CompareType>::append(Order* order) {
q->append(order);
}

template <class CompareType>
void PriceLevel<CompareType>::remove(Order* order) {
template <PriceType P>
void PriceLevel<P>::remove(Order* order) {
auto price = order->getPrice(price_type_);

auto q = order->queue;
Expand All @@ -64,8 +64,8 @@ void PriceLevel<CompareType>::remove(Order* order) {
volume_ -= order->qty;
}

template <class CompareType>
OrderQueue* PriceLevel<CompareType>::getQueue() {
template <PriceType P>
OrderQueue* PriceLevel<P>::getQueue() {
auto q = price_tree_.begin();
if (q != price_tree_.end()) {
return &*q;
Expand All @@ -74,8 +74,8 @@ OrderQueue* PriceLevel<CompareType>::getQueue() {
return nullptr;
}

template <class CompareType>
Decimal PriceLevel<CompareType>::processMarketOrder(const TradeNotification& tn, const PostOrderFill& pf, OrderID takerOrderID, Decimal qty, Flag flag) {
template <PriceType P>
Decimal PriceLevel<P>::processMarketOrder(const TradeNotification& tn, const PostOrderFill& pf, OrderID takerOrderID, Decimal qty, Flag flag) {
// TODO: this won't work as pricelevel volumes aren't accounted for correctly
if ((flag & (AoN | FoK)) != 0 && qty > volume_) {
return uint64_t(0);
Expand All @@ -92,9 +92,8 @@ Decimal PriceLevel<CompareType>::processMarketOrder(const TradeNotification& tn,
return uint64_t(0);
};

template <class CompareType>
Decimal PriceLevel<CompareType>::processLimitOrder(const TradeNotification& tn, const PostOrderFill& pf, OrderID& takerOrderID, Decimal& price, Decimal qty,
Flag& flag) {
template <PriceType P>
Decimal PriceLevel<P>::processLimitOrder(const TradeNotification& tn, const PostOrderFill& pf, OrderID& takerOrderID, Decimal& price, Decimal qty, Flag& flag) {
Decimal qtyProcessed = {};
auto orderQueue = getQueue();

Expand Down Expand Up @@ -127,7 +126,7 @@ Decimal PriceLevel<CompareType>::processLimitOrder(const TradeNotification& tn,
break;
}
aQty -= orderQueue->totalQty();
orderQueue = GetNextQueue(orderQueue->price());
orderQueue = getNextQueue(orderQueue->price());
}
} else {
while (orderQueue != nullptr && price > orderQueue->price()) {
Expand All @@ -136,7 +135,7 @@ Decimal PriceLevel<CompareType>::processLimitOrder(const TradeNotification& tn,
break;
}
aQty -= orderQueue->totalQty();
orderQueue = GetNextQueue(orderQueue->price());
orderQueue = getNextQueue(orderQueue->price());
}
}

Expand All @@ -157,7 +156,40 @@ Decimal PriceLevel<CompareType>::processLimitOrder(const TradeNotification& tn,
return qtyProcessed;
};

template class PriceLevel<CmpGreater>;
template class PriceLevel<CmpLess>;
template <PriceType P>
OrderQueue* PriceLevel<P>::largestLessThan(const Decimal& price) {
auto it = price_tree_.lower_bound(price, PriceCompare());
if (it != price_tree_.begin()) {
--it;
return &(*it);
}
return nullptr;
}

template <PriceType P>
OrderQueue* PriceLevel<P>::smallestGreaterThan(const Decimal& price) {
auto it = price_tree_.upper_bound(price, PriceCompare());
if (it != price_tree_.end()) {
return &(*it);
}
return nullptr;
}

template <PriceType P>
OrderQueue* PriceLevel<P>::getNextQueue(const Decimal& price) {
switch (price_type_) {
case PriceType::Bid:
return largestLessThan(price);
case PriceType::Ask:
return smallestGreaterThan(price);
default:
throw std::runtime_error("invalid call to GetQueue");
}
}

template class PriceLevel<PriceType::Bid>;
template class PriceLevel<PriceType::Ask>;
template class PriceLevel<PriceType::TriggerOver>;
template class PriceLevel<PriceType::TriggerUnder>;

} // namespace orderbook
6 changes: 4 additions & 2 deletions src/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@ std::ostream& operator<<(std::ostream& os, const PriceType& priceType) {
return os << "Bid";
case PriceType::Ask:
return os << "Ask";
case PriceType::Trigger:
return os << "Trigger";
case PriceType::TriggerOver:
return os << "TriggerOver";
case PriceType::TriggerUnder:
return os << "TriggerUnder";
}
return os << "Unknown";
}
Expand Down
39 changes: 21 additions & 18 deletions test/pricelevel_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class PriceLevelTest : public ::testing::Test {
};

TEST_F(PriceLevelTest, TestPriceLevel) {
PriceLevel<CmpLess> bidLevel(PriceType::Bid, 10);
PriceLevel<PriceType::Bid> bidLevel(10);

auto o1 = std::make_shared<Order>(1, Type::Limit, Side::Buy, Decimal(10, 0), Decimal(10, 0), Decimal(uint64_t(0)), Flag::None);
auto o2 = std::make_shared<Order>(2, Type::Limit, Side::Buy, Decimal(10, 0), Decimal(20, 0), Decimal(uint64_t(0)), Flag::None);
Expand All @@ -43,9 +43,10 @@ TEST_F(PriceLevelTest, TestPriceLevel) {
ASSERT_EQ(bidLevel.depth(), 2);
ASSERT_EQ(bidLevel.len(), 2);

if (tree.begin()->head() != o1.get() || tree.begin()->tail() != o1.get() || tree.rbegin()->head() != o2.get() || tree.rbegin()->tail() != o2.get()) {
FAIL() << "invalid price levels";
}
ASSERT_EQ(tree.begin()->head(), o2.get()) << "Invalid price levels: head of the first element does not match o1";
ASSERT_EQ(tree.begin()->tail(), o2.get()) << "Invalid price levels: tail of the first element does not match o1";
ASSERT_EQ(tree.rbegin()->head(), o1.get()) << "Invalid price levels: head of the last element does not match o2";
ASSERT_EQ(tree.rbegin()->tail(), o1.get()) << "Invalid price levels: tail of the last element does not match o2";

bidLevel.remove(o1.get());

Expand All @@ -62,7 +63,7 @@ TEST_F(PriceLevelTest, TestPriceLevel) {
}

TEST_F(PriceLevelTest, TestPriceFinding) {
PriceLevel<CmpLess> askLevel(PriceType::Ask, 10);
PriceLevel<PriceType::Ask> askLevel(10);

askLevel.append(new Order(1, Type::Limit, Side::Sell, Decimal(5, 0), Decimal(130, 0), Decimal(uint64_t(0)), Flag::None));
askLevel.append(new Order(2, Type::Limit, Side::Sell, Decimal(5, 0), Decimal(170, 0), Decimal(uint64_t(0)), Flag::None));
Expand All @@ -75,17 +76,17 @@ TEST_F(PriceLevelTest, TestPriceFinding) {

ASSERT_EQ(askLevel.volume(), Decimal(40, 0));

ASSERT_EQ(askLevel.LargestLessThan(Decimal(101, 0))->price(), Decimal(100, 0));
ASSERT_EQ(askLevel.LargestLessThan(Decimal(150, 0))->price(), Decimal(140, 0));
ASSERT_EQ(askLevel.LargestLessThan(Decimal(100, 0)), nullptr);
ASSERT_EQ(askLevel.largestLessThan(Decimal(101, 0))->price(), Decimal(100, 0));
ASSERT_EQ(askLevel.largestLessThan(Decimal(150, 0))->price(), Decimal(140, 0));
ASSERT_EQ(askLevel.largestLessThan(Decimal(100, 0)), nullptr);

ASSERT_EQ(askLevel.SmallestGreaterThan(Decimal(169, 0))->price(), Decimal(170, 0));
ASSERT_EQ(askLevel.SmallestGreaterThan(Decimal(150, 0))->price(), Decimal(160, 0));
ASSERT_EQ(askLevel.SmallestGreaterThan(Decimal(170, 0)), nullptr);
ASSERT_EQ(askLevel.smallestGreaterThan(Decimal(169, 0))->price(), Decimal(170, 0));
ASSERT_EQ(askLevel.smallestGreaterThan(Decimal(150, 0))->price(), Decimal(160, 0));
ASSERT_EQ(askLevel.smallestGreaterThan(Decimal(170, 0)), nullptr);
}

TEST_F(PriceLevelTest, TestStopQueuePriceFinding) {
PriceLevel<CmpLess> trigLevel(PriceType::Trigger, 10);
PriceLevel<PriceType::TriggerUnder> trigLevel(10);

trigLevel.append(new Order(1, Type::Limit, Side::Sell, Decimal(5, 0), Decimal(10, 0), Decimal(130, 0), Flag::None));
trigLevel.append(new Order(2, Type::Limit, Side::Sell, Decimal(5, 0), Decimal(20, 0), Decimal(170, 0), Flag::None));
Expand All @@ -98,13 +99,15 @@ TEST_F(PriceLevelTest, TestStopQueuePriceFinding) {

ASSERT_EQ(trigLevel.volume(), Decimal(40, 0));

ASSERT_EQ(trigLevel.LargestLessThan(Decimal(101, 0))->price(), Decimal(100, 0));
ASSERT_EQ(trigLevel.LargestLessThan(Decimal(150, 0))->price(), Decimal(140, 0));
ASSERT_EQ(trigLevel.LargestLessThan(Decimal(100, 0)), nullptr);
std::cout << 1 << std::endl;
ASSERT_EQ(trigLevel.largestLessThan(Decimal(101, 0))->price(), Decimal(100, 0));
std::cout << 2 << std::endl;
ASSERT_EQ(trigLevel.largestLessThan(Decimal(150, 0))->price(), Decimal(140, 0));
ASSERT_EQ(trigLevel.largestLessThan(Decimal(100, 0)), nullptr);

ASSERT_EQ(trigLevel.SmallestGreaterThan(Decimal(169, 0))->price(), Decimal(170, 0));
ASSERT_EQ(trigLevel.SmallestGreaterThan(Decimal(150, 0))->price(), Decimal(160, 0));
ASSERT_EQ(trigLevel.SmallestGreaterThan(Decimal(170, 0)), nullptr);
ASSERT_EQ(trigLevel.smallestGreaterThan(Decimal(169, 0))->price(), Decimal(170, 0));
ASSERT_EQ(trigLevel.smallestGreaterThan(Decimal(150, 0))->price(), Decimal(160, 0));
ASSERT_EQ(trigLevel.smallestGreaterThan(Decimal(170, 0)), nullptr);
}

} // namespace test
Expand Down

0 comments on commit 7f1e82d

Please sign in to comment.