From 6b909183398ed118f8fdba4d89e0d108c6efe8f1 Mon Sep 17 00:00:00 2001 From: PavelBrm Date: Wed, 1 Sep 2021 21:34:57 +1200 Subject: [PATCH] Release code --- Makefile | 40 ++ README.md | 111 ++++ client.go | 371 ++++++++++++ encoding.go | 641 +++++++++++++++++++++ go.mod | 6 + go.sum | 21 + internal/echo/echo.go | 101 ++++ internal/echo/echo.pb.go | 219 +++++++ internal/echo/echo.proto | 12 + internal/wire/README.md | 45 ++ internal/wire/wire.pb.go | 346 +++++++++++ internal/wire/wire.proto | 22 + rpcz.go | 1176 ++++++++++++++++++++++++++++++++++++++ rpcz_bench_test.go | 254 ++++++++ rpcz_pkg_test.go | 205 +++++++ rpcz_test.go | 574 +++++++++++++++++++ testdata/localhost.crt | 28 + testdata/localhost.key | 52 ++ 18 files changed, 4224 insertions(+) create mode 100644 Makefile create mode 100644 client.go create mode 100644 encoding.go create mode 100644 go.sum create mode 100644 internal/echo/echo.go create mode 100644 internal/echo/echo.pb.go create mode 100644 internal/echo/echo.proto create mode 100644 internal/wire/README.md create mode 100644 internal/wire/wire.pb.go create mode 100644 internal/wire/wire.proto create mode 100644 rpcz.go create mode 100644 rpcz_bench_test.go create mode 100644 rpcz_pkg_test.go create mode 100644 rpcz_test.go create mode 100644 testdata/localhost.crt create mode 100644 testdata/localhost.key diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..ef97971 --- /dev/null +++ b/Makefile @@ -0,0 +1,40 @@ +BASE_PATH := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST)))) +MKFILE_PATH := $(BASE_PATH)/Makefile +COVER_OUT := cover.out +BENCH_OUT := results.bench + +.DEFAULT_GOAL := help + +test: ## Run tests + go test ./... -coverprofile=$(COVER_OUT) -v + +bench: ## Run benchmarks + go test -benchmem -benchtime 5s -bench=Benchmark -count 5 -timeout 900s -cpu=8 + +bench-results: ## Run benchmarks and save the output + go test -benchmem -benchtime 5s -bench=Benchmark -count 5 -timeout 600s -cpu=8 | tee $(BENCH_OUT) + +cover: ## Show test coverage + @if [ -f $(COVER_OUT) ]; then \ + go tool cover -func=$(COVER_OUT); \ + rm -f $(COVER_OUT); \ + else \ + echo "$(COVER_OUT) is missing. Please run 'make test'"; \ + fi + +results: ## Show benchmark results + @if [ -f $(BENCH_OUT) ]; then \ + benchstat $(BENCH_OUT); \ + else \ + echo "$(BENCH_OUT) is missing. Please run 'make bench-results'"; \ + fi + +clean: ## Remove binaries + @rm -f $(COVER_OUT) + @rm -f $(BENCH_OUT) + @find $(BASE_PATH) -name ".DS_Store" -depth -exec rm {} \; + +help: ## Show help message + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: test bench cover clean help diff --git a/README.md b/README.md index acfc5a5..04d3a06 100644 --- a/README.md +++ b/README.md @@ -1 +1,112 @@ # rpcz + +RPCz is a library for RPC-style interservice communication over network. It's simple, lightweight yet performant. It provides you with the infrastructure for building services and clients, taking care of the low level details such as connection handling and lifecycle management. + +It's designed to work over plain TCP or Unix sockets, or with TLS on top of it. The primary encoding is Protobuf, with JSON available as an alternative. Adjustable bufferring helps achieve better throughput based on the needs of your application. + +While providing similar functionality to gRPC, Twirp and the standard library's `rpc` package, RPCz differs from each in some way: +- it's lightweight +- it focuses on solving one particular task – performant communication over TCP +- the public API is minimal, and consists primarily of a handful of convenience constructors +- supports only two encodings +- it handles connections gracefully, and supports (even enforces) use of `context.Context`. + +The library code uses only one external dependency, Protobuf. + + +## Examples + +Examples of a server running an RPC service, and a client that interacts with it is available in [this repository](https://github.com/golocron/rpcz-example). + + +## In Other Languages + +The most obvious uses case is communication between services developed in Go. However, one of the main reasons for creating this library was the need to interact between software written in different languages. + +The simple protocol of RPCz and the use of Protobuf allow for such cross-language comminication. A server implementing the RPCz protocol must accept and correctly handle a request from a client implementing the protocol. + +| Language | Server | Client | Status | +|--- | --- | --- | --- | +| Go | ✅ | ✅ | Initial release | +| Rust | ⚪ | ⚪ | Planned | +| Zig | ⚪ | ⚪ | Planned | + +There are currently no plans on implementations for languages other than the listed above. + + +## Benchmarks + +The results below were obtained by running: + +```bash +go test -benchmem -benchtime 5s -bench=Benchmark -count 5 -timeout 600s -cpu=8 | tee results.bench + +benchstat results.bench +``` + +```text +name time/op +Protobuf_1K-8 13.4µs ± 7% +Protobuf_4K-8 19.2µs ± 1% +Protobuf_16K-8 38.0µs ± 3% +Protobuf_32K-8 65.9µs ± 2% +Protobuf_64K-8 91.6µs ±14% +ProtobufTLS_4K-8 19.6µs ± 3% +ProtobufNoBuf_4K-8 25.9µs ± 5% +JSON_1K-8 20.4µs ± 4% +JSON_4K-8 53.6µs ± 5% +JSON_16K-8 185µs ± 2% +JSONTLS_4K-8 55.9µs ± 2% + +name speed +Protobuf_1K-8 153MB/s ± 7% +Protobuf_4K-8 427MB/s ± 1% +Protobuf_16K-8 862MB/s ± 3% +Protobuf_32K-8 1.00GB/s ± 2% +Protobuf_64K-8 1.44GB/s ±13% +ProtobufTLS_4K-8 417MB/s ± 3% +ProtobufNoBuf_4K-8 317MB/s ± 5% +JSON_1K-8 100MB/s ± 4% +JSON_4K-8 153MB/s ± 5% +JSON_16K-8 177MB/s ± 2% +JSONTLS_4K-8 147MB/s ± 2% + +name alloc/op +Protobuf_1K-8 9.42kB ± 0% +Protobuf_4K-8 40.1kB ± 0% +Protobuf_16K-8 155kB ± 0% +Protobuf_32K-8 333kB ± 0% +Protobuf_64K-8 614kB ± 0% +ProtobufTLS_4K-8 37.9kB ± 0% +ProtobufNoBuf_4K-8 47.6kB ± 0% +JSON_1K-8 5.56kB ± 0% +JSON_4K-8 19.2kB ± 0% +JSON_16K-8 72.3kB ± 0% +JSONTLS_4K-8 19.3kB ± 0% + +name allocs/op +Protobuf_1K-8 21.0 ± 0% +Protobuf_4K-8 21.0 ± 0% +Protobuf_16K-8 21.0 ± 0% +Protobuf_32K-8 21.0 ± 0% +Protobuf_64K-8 21.0 ± 0% +ProtobufTLS_4K-8 22.0 ± 0% +ProtobufNoBuf_4K-8 27.0 ± 0% +JSON_1K-8 29.0 ± 0% +JSON_4K-8 29.0 ± 0% +JSON_16K-8 29.0 ± 0% +JSONTLS_4K-8 30.0 ± 0% +``` + +You can also have a look here for benchmarks of other systems, such as gRPC and the `rpc` package from the standard library. + + +## Open-Source, not Open-Contribution + +Similar to [SQLite](https://www.sqlite.org/copyright.html) and [Litestream](https://github.com/benbjohnson/litestream#open-source-not-open-contribution), RPCz is open-source, but is not open to contributions. This helps keep the code base free of confusions with licenses or proprietary changes, and prevent feature bloat. + +In addition, experiences of many open-source projects have shown that maintenance of an open-source code base can be quite resource demanding. Time, mental health and energy – are just a few. + +Taking the above into account, I've made the decision to make the project closed to contributions. + +Thank you for understanding! diff --git a/client.go b/client.go new file mode 100644 index 0000000..946a41d --- /dev/null +++ b/client.go @@ -0,0 +1,371 @@ +package rpcz + +import ( + "context" + "crypto/tls" + "net" + "sync" + "time" + + "github.com/golocron/rpcz/internal/wire" +) + +// ServiceError represents an error returned by a remote method. +type ServiceError string + +// Error returns the string representation of e. +func (e ServiceError) Error() string { + return string(e) +} + +// ClientOptions allows to setup Client. +// +// An empty value of ClientOptions uses safe defaults: +// - Protobuf encoding +// - 8192 bytes per connection for a read buffer +// - 8192 bytes per connection for a write buffer. +type ClientOptions struct { + Encoding Encoding + ReadBufSize int + WriteBufSize int +} + +// Copy returns a full copy of o. +func (o *ClientOptions) Copy() *ClientOptions { + if o == nil { + return &ClientOptions{} + } + + result := &ClientOptions{ + Encoding: o.Encoding, + ReadBufSize: o.ReadBufSize, + WriteBufSize: o.WriteBufSize, + } + + return result +} + +func (o *ClientOptions) check() error { + if o == nil { + return ErrInvalidClientOptions + } + + if !checkEncoding(o.Encoding) { + return ErrInvalidEncoding + } + + o.adjReadBufSize() + o.adjWriteBufSize() + + return nil +} + +func (o *ClientOptions) adjReadBufSize() { + o.ReadBufSize = calcBufSize(o.ReadBufSize) +} + +func (o *ClientOptions) adjWriteBufSize() { + o.WriteBufSize = calcBufSize(o.WriteBufSize) +} + +// Client makes requests to remote methods on services registered on a remote server. +type Client struct { + cfg *ClientOptions + c *conn + enc clientEncoder + + onceStop *sync.Once + onceClose *sync.Once + onceDone *sync.Once + stopping chan struct{} + closed chan struct{} + recvDone chan struct{} + + reqList requestRetainer + + mu *sync.Mutex + id uint64 + results map[uint64]*Result +} + +// NewClient returns a client connected to saddr with Protobuf encoding and default options. +func NewClient(saddr string) (*Client, error) { + return NewClientWithOptions(saddr, &ClientOptions{}) +} + +// NewClientWithOptions returns a client connected to saddr and set up with opts. +func NewClientWithOptions(saddr string, opts *ClientOptions) (*Client, error) { + nc, err := net.Dial("tcp", saddr) + if err != nil { + return nil, err + } + + return NewClientWithConn(opts, nc) +} + +// NewClientTLS returns a client connected to saddr in TLS mode. +func NewClientTLS(saddr string, cfg *tls.Config) (*Client, error) { + return NewClientTLSOptions(saddr, cfg, &ClientOptions{}) +} + +// NewClientTLSOptions creates a client connected to saddr over TLS with opts. +func NewClientTLSOptions(saddr string, cfg *tls.Config, opts *ClientOptions) (*Client, error) { + nc, err := tls.Dial("tcp", saddr, cfg) + if err != nil { + return nil, err + } + + return NewClientWithConn(opts, nc) +} + +// NewClientWithConn creates a client set up with opts connected to nc. +func NewClientWithConn(opts *ClientOptions, nc net.Conn) (*Client, error) { + cfg := opts.Copy() + if err := cfg.check(); err != nil { + return nil, err + } + + c := newConn(nc.RemoteAddr().String(), nc) + result := &Client{ + cfg: cfg, + c: c, + enc: newClientEncoder(cfg.Encoding, newBufReader(c.rwc, cfg.ReadBufSize), newBufWriter(c.rwc, cfg.WriteBufSize)), + + onceStop: &sync.Once{}, + onceClose: &sync.Once{}, + onceDone: &sync.Once{}, + stopping: make(chan struct{}), + closed: make(chan struct{}), + recvDone: make(chan struct{}), + + reqList: newRequestRetainer(retainRequestNum), + + mu: &sync.Mutex{}, + results: make(map[uint64]*Result), + } + + go result.recv(context.Background()) + + return result, nil +} + +// Do asynchronously calls mtd on svc with given args and resp. +func (c *Client) Do(ctx context.Context, svc, mtd string, args, resp interface{}) *Result { + result := newResult(svc, mtd, args, resp) + + c.send(ctx, result) + + return result +} + +// SyncDo calls Do and awaits for the response. +func (c *Client) SyncDo(ctx context.Context, svc, mtd string, args, resp interface{}) error { + result := c.Do(ctx, svc, mtd, args, resp) + + return result.Err() +} + +// Peer returns the address of the server c is connected to. +func (c *Client) Peer() string { + if c.c == nil { + return "" + } + + return c.c.raddr +} + +// Close closes the connection and shuts down the client. +// +// Any unfinished requests are aborted. +func (c *Client) Close() error { + var err error + + c.closeRecv() + + c.onceClose.Do(func() { + closeOrSkip(c.closed) + err = c.c.close() + }) + + for !c.isRecvDone() { + time.Sleep(closeLoopDelay) + } + + return err +} + +func (c *Client) send(ctx context.Context, fut *Result) { + if c.isStopping() { + fut.done(ErrShutdown) + return + } + + if c.isClosed() { + fut.done(ErrClosed) + return + } + + c.mu.Lock() + id := c.id + c.id++ + c.results[id] = fut + c.mu.Unlock() + + req := c.reqList.obtainReq() + req.Kind = int32(c.cfg.Encoding) + req.Id = id + req.Service = fut.svc + req.Method = fut.mtd + + if err := c.enc.writeRequest(ctx, req, fut.args); err != nil { + c.mu.Lock() + fut = c.results[id] + delete(c.results, id) + c.mu.Unlock() + + if fut != nil { + fut.done(err) + } + } + + c.reqList.retainReq(req) +} + +func (c *Client) recv(ctx context.Context) { + var ( + err error + rerr error + resp = &wire.Response{} + ) + + for { + select { + case <-c.recvDone: + return + case <-c.stopping: + c.closeRecv() + return + case <-c.closed: + c.closeRecv() + return + case <-ctx.Done(): + c.closeRecv() + c.stop(ctx.Err()) + + return + default: + *resp = wire.Response{} + if err = c.enc.readResponse(ctx, resp); err != nil { + if c.isClosed() && c.isRecvDone() { + c.stop(ErrClosed) + + return + } + + c.closeRecv() + c.stop(err) + + return + } + + c.mu.Lock() + fut := c.results[resp.Id] + delete(c.results, resp.Id) + c.mu.Unlock() + + switch { + case fut == nil: + // No op. + case resp.Error != "": + rerr = ServiceError(resp.Error) + fut.done(rerr) + default: + if err = unmarshal(Encoding(resp.Kind), resp.Data, fut.resp); err != nil { + fut.done(err) + + continue + } + + fut.done(nil) + } + } + } +} + +func (c *Client) stop(err error) { + c.onceStop.Do(func() { closeOrSkip(c.stopping) }) + + c.mu.Lock() + for k, fut := range c.results { + delete(c.results, k) + fut.done(err) + } + c.mu.Unlock() +} + +func (c *Client) closeRecv() { + c.onceDone.Do(func() { closeOrSkip(c.recvDone) }) +} + +func (c *Client) isStopping() bool { + select { + case <-c.stopping: + return true + default: + return false + } +} + +func (c *Client) isClosed() bool { + select { + case <-c.closed: + return true + default: + return false + } +} + +func (c *Client) isRecvDone() bool { + select { + case <-c.recvDone: + return true + default: + return false + } +} + +// Result represents a result of a request. +type Result struct { + svc string + mtd string + args interface{} + resp interface{} + err chan error +} + +func newResult(svc, mtd string, args, resp interface{}) *Result { + result := &Result{ + svc: svc, + mtd: mtd, + args: args, + resp: resp, + err: make(chan error, 1), + } + + return result +} + +// Err waits for the request to finish and returns nil on success, and an error otherwise. +func (r *Result) Err() error { + return <-r.err +} + +// ErrChan returns a channel that indicates completion of the request by sending a nil or non-nil error. +func (r *Result) ErrChan() <-chan error { + return r.err +} + +func (r *Result) done(err error) { + r.err <- err + close(r.err) +} diff --git a/encoding.go b/encoding.go new file mode 100644 index 0000000..692ebd8 --- /dev/null +++ b/encoding.go @@ -0,0 +1,641 @@ +package rpcz + +import ( + "bufio" + "context" + "encoding/binary" + "encoding/json" + "io" + "net" + "sync" + "time" + + "github.com/golang/protobuf/proto" + + "github.com/golocron/rpcz/internal/wire" +) + +type serverEncoder interface { + readRequest(ctx context.Context, req *wire.Request) error + writeResponse(ctx context.Context, resp *wire.Response, v interface{}) error + + reader() bufReader + writer() bufWriter +} + +type clientEncoder interface { + readResponse(ctx context.Context, resp *wire.Response) error + writeRequest(ctx context.Context, req *wire.Request, v interface{}) error + + reader() bufReader + writer() bufWriter +} + +func newServerEncoder(kind Encoding, r bufReader, w bufWriter) serverEncoder { + if kind == JSON { + return newSrvEncoderJSON(kind, r, w) + } + + return newSrvEncoderPbf(kind, r, w) +} + +func newClientEncoder(kind Encoding, r bufReader, w bufWriter) clientEncoder { + if kind == JSON { + return newClientEncoderJSON(kind, r, w) + } + + return newClientEncoderPbf(kind, r, w) +} + +func newBufReader(rd io.Reader, size int) bufReader { + if size < 0 { + return newNoBufReader(rd) + } + + return bufio.NewReaderSize(rd, size) +} + +func newBufWriter(w io.Writer, size int) bufWriter { + if size < 0 { + return newNoBufWriter(w) + } + + return bufio.NewWriterSize(w, size) +} + +type bufReader interface { + io.Reader + io.ByteReader + peekDiscarder + Buffered() int + Reset(r io.Reader) +} + +type bufWriter interface { + io.Writer + flusher + Reset(w io.Writer) +} + +type peekDiscarder interface { + Peek(n int) ([]byte, error) + Discard(n int) (int, error) +} + +type flusher interface { + Flush() error +} + +// Protobuf encoding. + +type srvEncoderPbf struct { + kind Encoding + rmu *sync.Mutex + wmu *sync.Mutex + d *encDriverPbf +} + +func newSrvEncoderPbf(kind Encoding, r bufReader, w bufWriter) *srvEncoderPbf { + result := &srvEncoderPbf{ + kind: kind, + rmu: &sync.Mutex{}, + wmu: &sync.Mutex{}, + d: newEncDriverPbf(r, w), + } + + return result +} + +func (c *srvEncoderPbf) readRequest(ctx context.Context, req *wire.Request) error { + c.rmu.Lock() + defer c.rmu.Unlock() + + return c.d.decode(ctx, req) +} + +func (c *srvEncoderPbf) writeResponse(ctx context.Context, resp *wire.Response, v interface{}) error { + msg, ok := v.(proto.Message) + if !ok { + return errInvalidProtobufMsg + } + + data, err := proto.Marshal(msg) + if err != nil { + return err + } + + resp.Data = data + raw, err := proto.Marshal(resp) + if err != nil { + return err + } + + c.wmu.Lock() + err = c.d.encode(ctx, raw) + c.wmu.Unlock() + + return err +} + +func (c *srvEncoderPbf) reader() bufReader { + return c.d.r +} + +func (c *srvEncoderPbf) writer() bufWriter { + return c.d.w +} + +type clientEncoderPbf struct { + kind Encoding + rmu *sync.Mutex + wmu *sync.Mutex + d *encDriverPbf +} + +func newClientEncoderPbf(kind Encoding, r bufReader, w bufWriter) *clientEncoderPbf { + result := &clientEncoderPbf{ + kind: kind, + rmu: &sync.Mutex{}, + wmu: &sync.Mutex{}, + d: newEncDriverPbf(r, w), + } + + return result +} + +func (c *clientEncoderPbf) readResponse(ctx context.Context, resp *wire.Response) error { + c.rmu.Lock() + defer c.rmu.Unlock() + + return c.d.decode(ctx, resp) +} + +func (c *clientEncoderPbf) writeRequest(ctx context.Context, req *wire.Request, v interface{}) error { + msg, ok := v.(proto.Message) + if !ok { + return errInvalidProtobufMsg + } + + data, err := proto.Marshal(msg) + if err != nil { + return err + } + + req.Data = data + raw, err := proto.Marshal(req) + if err != nil { + return err + } + + c.wmu.Lock() + err = c.d.encode(ctx, raw) + c.wmu.Unlock() + + return err +} + +func (c *clientEncoderPbf) reader() bufReader { + return c.d.r +} + +func (c *clientEncoderPbf) writer() bufWriter { + return c.d.w +} + +type encDriverPbf struct { + r bufReader + w bufWriter +} + +func newEncDriverPbf(r bufReader, w bufWriter) *encDriverPbf { + result := &encDriverPbf{r: r, w: w} + + return result +} + +func (d *encDriverPbf) decode(ctx context.Context, dst proto.Message) error { + size, err := binary.ReadUvarint(d.r) + if err != nil { + return err + } + + if size == 0 { + return nil + } + + isize := int(size) + + if d.r.Buffered() >= isize { + data, err := d.r.Peek(isize) + if err != nil { + return err + } + + err = proto.Unmarshal(data, dst) + + _, derr := d.r.Discard(isize) + if err == nil { + return derr + } + + return err + } + + data := make([]byte, isize) + if _, err := io.ReadFull(d.r, data); err != nil { + return err + } + + return proto.Unmarshal(data, dst) +} + +func (d *encDriverPbf) encode(ctx context.Context, parts ...[]byte) error { + for _, part := range parts { + if err := d.encodePart(ctx, part); err != nil { + _ = d.w.Flush() + + return err + } + } + + return d.w.Flush() +} + +func (d *encDriverPbf) encodePart(ctx context.Context, data []byte) error { + var bin [binary.MaxVarintLen64]byte + head := bin[:] + + if len(data) == 0 { + hlen := binary.PutUvarint(head, uint64(0)) + if err := d.write(ctx, head[:hlen]); err != nil { + return err + } + + return nil + } + + hlen := binary.PutUvarint(head, uint64(len(data))) + if err := d.write(ctx, head[:hlen]); err != nil { + return err + } + + if err := d.write(ctx, data); err != nil { + return err + } + + return nil +} + +func (d *encDriverPbf) write(ctx context.Context, data []byte) error { + boff := newDefaultBackoff() + + for i, size := 0, len(data); i < size; { + n, err := d.w.Write(data[i:]) + if err != nil { + nerr, ok := err.(net.Error) + if !ok || !nerr.Temporary() { + return err + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(boff.next()): + } + } + + if err == nil { + boff.resetd() + } + + i += n + } + + return nil +} + +// JSON encoding. + +type srequestJSON struct { + Kind Encoding `json:"kind"` + ID uint64 `json:"id"` + Service string `json:"service"` + Method string `json:"method"` + Data *json.RawMessage `json:"data"` +} + +type sresponseJSON struct { + Kind Encoding `json:"kind"` + ID uint64 `json:"id"` + Service string `json:"service"` + Method string `json:"method"` + Error string `json:"error"` + Data interface{} `json:"data"` +} + +type srvEncoderJSON struct { + kind Encoding + rmu *sync.Mutex + wmu *sync.Mutex + d *encDriverJSON +} + +func newSrvEncoderJSON(kind Encoding, r bufReader, w bufWriter) *srvEncoderJSON { + result := &srvEncoderJSON{ + kind: kind, + rmu: &sync.Mutex{}, + wmu: &sync.Mutex{}, + d: newEncDriverJSON(r, w), + } + + return result +} + +func (c *srvEncoderJSON) readRequest(ctx context.Context, req *wire.Request) error { + c.rmu.Lock() + defer c.rmu.Unlock() + + jreq := &srequestJSON{} + if err := c.d.decode(ctx, jreq); err != nil { + return err + } + + req.Kind = int32(jreq.Kind) + req.Id = jreq.ID + req.Service = jreq.Service + req.Method = jreq.Method + + if jreq.Data != nil { + req.Data = *jreq.Data + } + + return nil +} + +func (c *srvEncoderJSON) writeResponse(ctx context.Context, resp *wire.Response, v interface{}) error { + jresp := &sresponseJSON{ + Kind: Encoding(resp.Kind), + ID: resp.Id, + Service: resp.Service, + Method: resp.Method, + Error: resp.Error, + Data: v, + } + + c.wmu.Lock() + err := c.d.encode(ctx, jresp) + c.wmu.Unlock() + + return err +} + +func (c *srvEncoderJSON) reader() bufReader { + return c.d.r +} + +func (c *srvEncoderJSON) writer() bufWriter { + return c.d.w +} + +type crequestJSON struct { + Kind Encoding `json:"kind"` + ID uint64 `json:"id"` + Service string `json:"service"` + Method string `json:"method"` + Data interface{} `json:"data"` +} + +type cresponseJSON struct { + Kind Encoding `json:"kind"` + ID uint64 `json:"id"` + Service string `json:"service"` + Method string `json:"method"` + Error string `json:"error"` + Data *json.RawMessage `json:"data"` +} + +type clientEncoderJSON struct { + kind Encoding + rmu *sync.Mutex + wmu *sync.Mutex + d *encDriverJSON +} + +func newClientEncoderJSON(kind Encoding, r bufReader, w bufWriter) *clientEncoderJSON { + result := &clientEncoderJSON{ + kind: kind, + rmu: &sync.Mutex{}, + wmu: &sync.Mutex{}, + d: newEncDriverJSON(r, w), + } + + return result +} + +func (c *clientEncoderJSON) readResponse(ctx context.Context, resp *wire.Response) error { + c.rmu.Lock() + defer c.rmu.Unlock() + + jresp := &cresponseJSON{} + if err := c.d.decode(ctx, jresp); err != nil { + return err + } + + resp.Kind = int32(jresp.Kind) + resp.Id = jresp.ID + resp.Service = jresp.Service + resp.Method = jresp.Method + resp.Error = jresp.Error + + if jresp.Data != nil { + resp.Data = *jresp.Data + } + + return nil +} + +func (c *clientEncoderJSON) writeRequest(ctx context.Context, req *wire.Request, v interface{}) error { + jreq := &crequestJSON{ + Kind: Encoding(req.Kind), + ID: req.Id, + Service: req.Service, + Method: req.Method, + Data: v, + } + + c.wmu.Lock() + err := c.d.encode(ctx, jreq) + c.wmu.Unlock() + + return err +} + +func (c *clientEncoderJSON) reader() bufReader { + return c.d.r +} + +func (c *clientEncoderJSON) writer() bufWriter { + return c.d.w +} + +type encDriverJSON struct { + r bufReader + w bufWriter + + d *json.Decoder + e *json.Encoder +} + +func newEncDriverJSON(r bufReader, w bufWriter) *encDriverJSON { + result := &encDriverJSON{ + r: r, + w: w, + + d: json.NewDecoder(r), + e: json.NewEncoder(w), + } + + return result +} + +func (d *encDriverJSON) decode(ctx context.Context, v interface{}) error { + return d.d.Decode(v) +} + +func (d *encDriverJSON) encode(ctx context.Context, parts ...interface{}) error { + for _, part := range parts { + if err := d.encodePart(ctx, part); err != nil { + _ = d.w.Flush() + + return err + } + } + + return d.w.Flush() +} + +func (d *encDriverJSON) encodePart(ctx context.Context, data interface{}) error { + return d.e.Encode(data) +} + +type noBufReader struct { + r io.Reader +} + +func newNoBufReader(r io.Reader) *noBufReader { + return &noBufReader{r: r} +} + +// Read delegates the call directly to the underlying r. +func (r *noBufReader) Read(p []byte) (int, error) { + return r.r.Read(p) +} + +// ReadByte reads exactly one byte using Read. +func (r *noBufReader) ReadByte() (byte, error) { + var one [1]byte + b := one[:] + + n, err := r.r.Read(b) + if err != nil { + return 0, err + } + + if n == 0 { + return 0, io.EOF + } + + return b[0], nil +} + +// Buffered is a no-op. +// +// It returns -1 to always be less than the size of a message. +func (r *noBufReader) Buffered() int { + return -1 +} + +// Peek returns nil and no error. +// +// Calls to Peek are usually preceeded by calling Buffered(). +// Since Buffered() always returns -1, this method must not be called. +func (r *noBufReader) Peek(n int) ([]byte, error) { + return nil, nil +} + +// Discard returns 0 and no error. +func (r *noBufReader) Discard(n int) (int, error) { + return 0, nil +} + +// Reset sets the underlying r to rd. +func (r *noBufReader) Reset(rd io.Reader) { + r.r = rd +} + +type noBufWriter struct { + w io.Writer +} + +func newNoBufWriter(w io.Writer) *noBufWriter { + return &noBufWriter{w: w} +} + +// Write delegates the call directly to the underlying w. +func (w *noBufWriter) Write(p []byte) (int, error) { + return w.w.Write(p) +} + +// Flush is a no-op. +func (w *noBufWriter) Flush() error { + return nil +} + +// Reset sets the underlying w to wt. +func (w *noBufWriter) Reset(wt io.Writer) { + w.w = wt +} + +func marshal(kind Encoding, v interface{}) ([]byte, error) { + switch kind { + case Protobuf: + return marshalPbf(v) + case JSON: + return json.Marshal(v) + default: + return nil, errInvalidEncoding + } +} + +func unmarshal(kind Encoding, data []byte, v interface{}) error { + switch kind { + case Protobuf: + return unmarshalPbf(data, v) + case JSON: + return json.Unmarshal(data, v) + default: + return errInvalidEncoding + } +} + +func marshalPbf(v interface{}) ([]byte, error) { + if v == nil { + return nil, errInvalidProtobufMsg + } + + m, ok := v.(proto.Message) + if !ok { + return nil, errInvalidProtobufMsg + } + + return proto.Marshal(m) +} + +func unmarshalPbf(data []byte, v interface{}) error { + m, ok := v.(proto.Message) + if !ok { + return errInvalidProtobufMsg + } + + return proto.Unmarshal(data, m) +} diff --git a/go.mod b/go.mod index 571764a..5cdff76 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,9 @@ module github.com/golocron/rpcz go 1.16 + +require ( + github.com/golang/protobuf v1.5.2 + github.com/stretchr/testify v1.7.0 + google.golang.org/protobuf v1.26.0 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..6d02d24 --- /dev/null +++ b/go.sum @@ -0,0 +1,21 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/echo/echo.go b/internal/echo/echo.go new file mode 100644 index 0000000..d33c1c5 --- /dev/null +++ b/internal/echo/echo.go @@ -0,0 +1,101 @@ +// package echo shows a simple RPC service that can be served with rpcz. +package echo + +import ( + "context" + "errors" + "time" +) + +const ( + defaultTimeout = 60 * time.Second +) + +var ( + errInvalidMsg = errors.New("echo: invalid message") +) + +// Echo service replies back with the message it receives. +type Echo struct{} + +// New returns an instance of Echo. +func New() *Echo { return &Echo{} } + +// Echo handles req and fills in resp. +func (s *Echo) Echo(ctx context.Context, req *EchoRequest, resp *EchoResponse) error { + if req.GetMsg() == "" { + return errInvalidMsg + } + + resp.Msg = req.Msg + + return nil +} + +// ExtendedEcho service replies back with the message it receives. +// +// It shows an example of service-side and caller-defined timeouts. +type ExtendedEcho struct { + echo *Echo + timeout time.Duration +} + +// NewExtendedEcho returns a new ExtendedEcho with the specified timeout. +func NewExtendedEcho(timeout time.Duration) *ExtendedEcho { + result := &ExtendedEcho{timeout: timeout} + if result.timeout == 0 { + result.timeout = defaultTimeout + } + + return result +} + +// Echo handles req and fills in resp. +// +// The service may wait for req.Delay, if specified, but no longer than s.timeout. +func (s *ExtendedEcho) Echo(ctx context.Context, req *EchoRequest, resp *EchoResponse) error { + return s.handle(ctx, req, resp) +} + +func (s *ExtendedEcho) handle(ctx context.Context, req *EchoRequest, resp *EchoResponse) error { + lctx, cancel := context.WithTimeout(ctx, s.timeout) + defer cancel() + + out := s.do(ctx, req, resp) + + select { + case <-lctx.Done(): + return lctx.Err() + case err := <-out: + return err + } +} + +func (s *ExtendedEcho) do(ctx context.Context, req *EchoRequest, resp *EchoResponse) chan error { + out := make(chan error, 1) + + go doEcho(ctx, out, s.echo, req, resp) + + return out +} + +func doEcho(ctx context.Context, dst chan<- error, s *Echo, req *EchoRequest, resp *EchoResponse) { + defer close(dst) + + if req.Delay <= 0 { + dst <- s.Echo(ctx, req, resp) + return + } + + timer := time.NewTimer(time.Duration(req.Delay)) + defer func() { _ = timer.Stop() }() + + select { + case <-ctx.Done(): + dst <- ctx.Err() + return + case <-timer.C: + } + + dst <- s.Echo(ctx, req, resp) +} diff --git a/internal/echo/echo.pb.go b/internal/echo/echo.pb.go new file mode 100644 index 0000000..29f502d --- /dev/null +++ b/internal/echo/echo.pb.go @@ -0,0 +1,219 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.23.0 +// protoc v3.17.3 +// source: echo.proto + +package echo + +import ( + proto "github.com/golang/protobuf/proto" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// This is a compile-time assertion that a sufficiently up-to-date version +// of the legacy proto package is being used. +const _ = proto.ProtoPackageIsVersion4 + +type EchoRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Msg string `protobuf:"bytes,1,opt,name=msg,proto3" json:"msg,omitempty"` + Delay int64 `protobuf:"varint,2,opt,name=delay,proto3" json:"delay,omitempty"` +} + +func (x *EchoRequest) Reset() { + *x = EchoRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_echo_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *EchoRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EchoRequest) ProtoMessage() {} + +func (x *EchoRequest) ProtoReflect() protoreflect.Message { + mi := &file_echo_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EchoRequest.ProtoReflect.Descriptor instead. +func (*EchoRequest) Descriptor() ([]byte, []int) { + return file_echo_proto_rawDescGZIP(), []int{0} +} + +func (x *EchoRequest) GetMsg() string { + if x != nil { + return x.Msg + } + return "" +} + +func (x *EchoRequest) GetDelay() int64 { + if x != nil { + return x.Delay + } + return 0 +} + +type EchoResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Msg string `protobuf:"bytes,1,opt,name=msg,proto3" json:"msg,omitempty"` +} + +func (x *EchoResponse) Reset() { + *x = EchoResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_echo_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *EchoResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EchoResponse) ProtoMessage() {} + +func (x *EchoResponse) ProtoReflect() protoreflect.Message { + mi := &file_echo_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EchoResponse.ProtoReflect.Descriptor instead. +func (*EchoResponse) Descriptor() ([]byte, []int) { + return file_echo_proto_rawDescGZIP(), []int{1} +} + +func (x *EchoResponse) GetMsg() string { + if x != nil { + return x.Msg + } + return "" +} + +var File_echo_proto protoreflect.FileDescriptor + +var file_echo_proto_rawDesc = []byte{ + 0x0a, 0x0a, 0x65, 0x63, 0x68, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x04, 0x65, 0x63, + 0x68, 0x6f, 0x22, 0x35, 0x0a, 0x0b, 0x45, 0x63, 0x68, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x73, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, + 0x6d, 0x73, 0x67, 0x12, 0x14, 0x0a, 0x05, 0x64, 0x65, 0x6c, 0x61, 0x79, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x03, 0x52, 0x05, 0x64, 0x65, 0x6c, 0x61, 0x79, 0x22, 0x20, 0x0a, 0x0c, 0x45, 0x63, 0x68, + 0x6f, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x73, 0x67, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x73, 0x67, 0x42, 0x28, 0x5a, 0x26, 0x67, + 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x6f, 0x6c, 0x6f, 0x63, 0x72, + 0x6f, 0x6e, 0x2f, 0x72, 0x70, 0x63, 0x7a, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, + 0x2f, 0x65, 0x63, 0x68, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_echo_proto_rawDescOnce sync.Once + file_echo_proto_rawDescData = file_echo_proto_rawDesc +) + +func file_echo_proto_rawDescGZIP() []byte { + file_echo_proto_rawDescOnce.Do(func() { + file_echo_proto_rawDescData = protoimpl.X.CompressGZIP(file_echo_proto_rawDescData) + }) + return file_echo_proto_rawDescData +} + +var file_echo_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_echo_proto_goTypes = []interface{}{ + (*EchoRequest)(nil), // 0: echo.EchoRequest + (*EchoResponse)(nil), // 1: echo.EchoResponse +} +var file_echo_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_echo_proto_init() } +func file_echo_proto_init() { + if File_echo_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_echo_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*EchoRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_echo_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*EchoResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_echo_proto_rawDesc, + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_echo_proto_goTypes, + DependencyIndexes: file_echo_proto_depIdxs, + MessageInfos: file_echo_proto_msgTypes, + }.Build() + File_echo_proto = out.File + file_echo_proto_rawDesc = nil + file_echo_proto_goTypes = nil + file_echo_proto_depIdxs = nil +} diff --git a/internal/echo/echo.proto b/internal/echo/echo.proto new file mode 100644 index 0000000..cee9765 --- /dev/null +++ b/internal/echo/echo.proto @@ -0,0 +1,12 @@ +syntax = "proto3"; +package echo; +option go_package = "github.com/golocron/rpcz/internal/echo"; + +message EchoRequest { + string msg = 1; + int64 delay = 2; +} + +message EchoResponse { + string msg = 1; +} diff --git a/internal/wire/README.md b/internal/wire/README.md new file mode 100644 index 0000000..041b0e6 --- /dev/null +++ b/internal/wire/README.md @@ -0,0 +1,45 @@ +# Wire + +This is an internal package that defines the communication protocol for `rpcz`. + +The definitions can be generated by running: + +```bash +protoc -I=./ --go_out=./ --go_opt=paths=source_relative ./wire.proto +``` + + +## Request + +A request is represented by the `Request` type. Its fileds have the following characteristics: + +| Name | Type | Meaning | +| --- | --- | --- | +| `kind` | `int32` | Holds the encoding of a request. | +| `id` | `uint64` | Specifies the unique id of a request, set by the client. | +| `service` | `string` | The target service the client is requesting. | +| `method` | `string` | The method on the service the client is calling. | +| `data` | `bytes` | Protbuf-encoded* arguments of the request. | + +This is used to encode any type of arguments. Both client and server encoders take care of properly dealing with this. It's a simplified version of the official method for encoding a value of any type using Protobuf. We just don't send the type information since it's already known by the server from the methods' signatures. + + +## Response + +A response is represented similarly to `Request` with one extra field for an error: + +| Name | Type | Meaning | +| --- | --- | --- | +| `kind` | `int32` | Holds the encoding of a request. | +| `id` | `uint64` | Specifies the unique id of a request, set by the client. | +| `service` | `string` | The target service the client is requesting. | +| `method` | `string` | The method on the service the client is calling. | +| `error` | `string` | Holds the error returned by the method. | +| `data` | `bytes` | Protbuf-encoded* arguments of the request. | + +See the explanation above for the `Request` type. + + +## Invalid Request + +The `InvalidRequest` type is used as a body on a response that has resulted into an error. diff --git a/internal/wire/wire.pb.go b/internal/wire/wire.pb.go new file mode 100644 index 0000000..7d9aa93 --- /dev/null +++ b/internal/wire/wire.pb.go @@ -0,0 +1,346 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.23.0 +// protoc v3.17.3 +// source: wire.proto + +package wire + +import ( + proto "github.com/golang/protobuf/proto" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// This is a compile-time assertion that a sufficiently up-to-date version +// of the legacy proto package is being used. +const _ = proto.ProtoPackageIsVersion4 + +type Request struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Kind int32 `protobuf:"varint,1,opt,name=kind,proto3" json:"kind,omitempty"` + Id uint64 `protobuf:"varint,2,opt,name=id,proto3" json:"id,omitempty"` + Service string `protobuf:"bytes,3,opt,name=service,proto3" json:"service,omitempty"` + Method string `protobuf:"bytes,4,opt,name=method,proto3" json:"method,omitempty"` + Data []byte `protobuf:"bytes,5,opt,name=data,proto3" json:"data,omitempty"` +} + +func (x *Request) Reset() { + *x = Request{} + if protoimpl.UnsafeEnabled { + mi := &file_wire_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Request) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Request) ProtoMessage() {} + +func (x *Request) ProtoReflect() protoreflect.Message { + mi := &file_wire_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Request.ProtoReflect.Descriptor instead. +func (*Request) Descriptor() ([]byte, []int) { + return file_wire_proto_rawDescGZIP(), []int{0} +} + +func (x *Request) GetKind() int32 { + if x != nil { + return x.Kind + } + return 0 +} + +func (x *Request) GetId() uint64 { + if x != nil { + return x.Id + } + return 0 +} + +func (x *Request) GetService() string { + if x != nil { + return x.Service + } + return "" +} + +func (x *Request) GetMethod() string { + if x != nil { + return x.Method + } + return "" +} + +func (x *Request) GetData() []byte { + if x != nil { + return x.Data + } + return nil +} + +type Response struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Kind int32 `protobuf:"varint,1,opt,name=kind,proto3" json:"kind,omitempty"` + Id uint64 `protobuf:"varint,2,opt,name=id,proto3" json:"id,omitempty"` + Service string `protobuf:"bytes,3,opt,name=service,proto3" json:"service,omitempty"` + Method string `protobuf:"bytes,4,opt,name=method,proto3" json:"method,omitempty"` + Error string `protobuf:"bytes,5,opt,name=error,proto3" json:"error,omitempty"` + Data []byte `protobuf:"bytes,6,opt,name=data,proto3" json:"data,omitempty"` +} + +func (x *Response) Reset() { + *x = Response{} + if protoimpl.UnsafeEnabled { + mi := &file_wire_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Response) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Response) ProtoMessage() {} + +func (x *Response) ProtoReflect() protoreflect.Message { + mi := &file_wire_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Response.ProtoReflect.Descriptor instead. +func (*Response) Descriptor() ([]byte, []int) { + return file_wire_proto_rawDescGZIP(), []int{1} +} + +func (x *Response) GetKind() int32 { + if x != nil { + return x.Kind + } + return 0 +} + +func (x *Response) GetId() uint64 { + if x != nil { + return x.Id + } + return 0 +} + +func (x *Response) GetService() string { + if x != nil { + return x.Service + } + return "" +} + +func (x *Response) GetMethod() string { + if x != nil { + return x.Method + } + return "" +} + +func (x *Response) GetError() string { + if x != nil { + return x.Error + } + return "" +} + +func (x *Response) GetData() []byte { + if x != nil { + return x.Data + } + return nil +} + +type InvalidRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *InvalidRequest) Reset() { + *x = InvalidRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_wire_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *InvalidRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*InvalidRequest) ProtoMessage() {} + +func (x *InvalidRequest) ProtoReflect() protoreflect.Message { + mi := &file_wire_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use InvalidRequest.ProtoReflect.Descriptor instead. +func (*InvalidRequest) Descriptor() ([]byte, []int) { + return file_wire_proto_rawDescGZIP(), []int{2} +} + +var File_wire_proto protoreflect.FileDescriptor + +var file_wire_proto_rawDesc = []byte{ + 0x0a, 0x0a, 0x77, 0x69, 0x72, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x04, 0x77, 0x69, + 0x72, 0x65, 0x22, 0x73, 0x0a, 0x07, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, + 0x04, 0x6b, 0x69, 0x6e, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x04, 0x6b, 0x69, 0x6e, + 0x64, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x04, 0x52, 0x02, 0x69, + 0x64, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x6d, + 0x65, 0x74, 0x68, 0x6f, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x74, + 0x68, 0x6f, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x0c, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x22, 0x8a, 0x01, 0x0a, 0x08, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x6b, 0x69, 0x6e, 0x64, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x05, 0x52, 0x04, 0x6b, 0x69, 0x6e, 0x64, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x04, 0x52, 0x02, 0x69, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, + 0x69, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x69, + 0x63, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, + 0x72, 0x6f, 0x72, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, + 0x12, 0x12, 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, + 0x64, 0x61, 0x74, 0x61, 0x22, 0x10, 0x0a, 0x0e, 0x49, 0x6e, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x42, 0x28, 0x5a, 0x26, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, + 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x6f, 0x6c, 0x6f, 0x63, 0x72, 0x6f, 0x6e, 0x2f, 0x72, 0x70, + 0x63, 0x7a, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x77, 0x69, 0x72, 0x65, + 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_wire_proto_rawDescOnce sync.Once + file_wire_proto_rawDescData = file_wire_proto_rawDesc +) + +func file_wire_proto_rawDescGZIP() []byte { + file_wire_proto_rawDescOnce.Do(func() { + file_wire_proto_rawDescData = protoimpl.X.CompressGZIP(file_wire_proto_rawDescData) + }) + return file_wire_proto_rawDescData +} + +var file_wire_proto_msgTypes = make([]protoimpl.MessageInfo, 3) +var file_wire_proto_goTypes = []interface{}{ + (*Request)(nil), // 0: wire.Request + (*Response)(nil), // 1: wire.Response + (*InvalidRequest)(nil), // 2: wire.InvalidRequest +} +var file_wire_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_wire_proto_init() } +func file_wire_proto_init() { + if File_wire_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_wire_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Request); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_wire_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Response); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_wire_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*InvalidRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_wire_proto_rawDesc, + NumEnums: 0, + NumMessages: 3, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_wire_proto_goTypes, + DependencyIndexes: file_wire_proto_depIdxs, + MessageInfos: file_wire_proto_msgTypes, + }.Build() + File_wire_proto = out.File + file_wire_proto_rawDesc = nil + file_wire_proto_goTypes = nil + file_wire_proto_depIdxs = nil +} diff --git a/internal/wire/wire.proto b/internal/wire/wire.proto new file mode 100644 index 0000000..0f2170a --- /dev/null +++ b/internal/wire/wire.proto @@ -0,0 +1,22 @@ +syntax = "proto3"; +package wire; +option go_package = "github.com/golocron/rpcz/internal/wire"; + +message Request { + int32 kind = 1; + uint64 id = 2; + string service = 3; + string method = 4; + bytes data = 5; +} + +message Response { + int32 kind = 1; + uint64 id = 2; + string service = 3; + string method = 4; + string error = 5; + bytes data = 6; +} + +message InvalidRequest {} diff --git a/rpcz.go b/rpcz.go new file mode 100644 index 0000000..a70d634 --- /dev/null +++ b/rpcz.go @@ -0,0 +1,1176 @@ +/* + Package rpcz allows exposing and accessing methods of a service over network. + + This package supports communication over TCP, Unix sockets, and with TLS on top of it. + The default and recommended encoding is Protobuf, with JSON available as an alternative. + No user-defined encodings (aka codecs) are supported, nor it is planned. + It is as an alternative to rpc from the standard library and gRPC. + + The server allows to register and expose one or more services. A service must be a pointer to a value of an exported type. + Each method meeting the following criteria will be registered and can be called remotely: + + - A method must be exported. + - A method has three arguments. + - The first argument is context.Context. + - The second and third arguments are both pointers and exported. + - A method has only one return parameter of the type error. + + The following line illustrates a valid method's signature: + + func (s *Service) SomeMethod(ctx context.Context, req *SomeMethodRequest, resp *SomeMethodResponse) error + + The request and response types must be marshallable to Protobuf, i.e. implement the proto.Message interface. + + The first argument of a service's method is passed in the server's context. This context is cancelled when + the server is requested to shutdown. The server is not concerned with specifying timeouts for service's methods. + It's the service's responsibility to implement timeouts should they be needed. The primary use of the parent context + is to be notified when shutdown has been requested to gracefully finish any ongoing operations. + + At the moment, the context does not include any data, but it might be exteneded at a later point with useful information + such as a request trace identificator. Reminder: contextes MUST NOT be used for dependency injection. + + The second argument represents a request to the method of a service. The third argument is passed in a pointer to a value + to which the method writes the response. + + If a method returns an error, it's sent over the wire as a string, and when the client returns it as ServiceError. +*/ +package rpcz + +import ( + "context" + "crypto/tls" + "errors" + "go/token" + "io" + "net" + "reflect" + "runtime" + "sync" + "sync/atomic" + "time" + + "github.com/golocron/rpcz/internal/wire" +) + +const ( + // Supported encodings. + Unknown Encoding = -1 + iota + Protobuf + JSON +) + +const ( + bufSize8k = 8 << 10 + bufSize256k = 256 << 10 + bufSize512k = 512 << 10 + defBufSize = bufSize8k + maxBufSize = bufSize512k + + retainRequestNum = 256 + retainResponseNum = 256 + + shutdownLoopDelay = 10 * time.Millisecond + closeLoopDelay = 5 * time.Millisecond + + msgInvalidReturnNum = "rpcz: method returned invalid number of return values" + msgInvalidReturnType = "rpcz: method returned invalid type" + msgUnexpectedError = "rpcz: unexpected error" +) + +var ( + ErrInvalidEncoding = errors.New("rpcz: invalid encoding") + ErrInvalidServerOptions = errors.New("rpcz: invalid server options: must not be nil") + + ErrInvalidAddr = errors.New("rpcz: invalid server address") + ErrInvalidListener = errors.New("rpcz: invalid server listener") + ErrSrvStarted = errors.New("rpcz: server already started") + ErrSrvClosed = errors.New("rpcz: server already shutdown") + + ErrInvalidClientOptions = errors.New("rpcz: invalid client options: must not be nil") + ErrShutdown = errors.New("rpcz: client is shutdown") + ErrClosed = errors.New("rpcz: client is closed") + + errInvalidEncoding = errors.New("rpcz: invalid encoding") + errInvalidProtobufMsg = errors.New("rpcz: invalid protobuf message") + + errInvalidServiceName = errors.New("rpcz: invalid service name") + errUnexportedService = errors.New("rpcz: service is not exported") + errInvalidServiceType = errors.New("rpcz: service is not a pointer") + errInvalidServiceMethods = errors.New("rpcz: service has no valid methods") + + errInvalidReqServiceName = errors.New("rpcz: invalid service") + errInvalidReqServiceMethod = errors.New("rpcz: invalid service method") + errUnknownReqService = errors.New("rpcz: unknown service") + errUnknownReqServiceMethod = errors.New("rpcz: unknown method") +) + +// Encoding specifies which encoding is supported by a concrete server, client, and transport. +type Encoding int8 + +// ServerOptions allows to setup Server. +// +// An empty value of ServerOptions uses safe defaults: +// - Protobuf encoding +// - 8192 bytes per connection for a read buffer +// - 8192 bytes per connection for a write buffer. +type ServerOptions struct { + Encoding Encoding + ReadBufSize int + WriteBufSize int +} + +// Copy returns a full copy of o. +func (o *ServerOptions) Copy() *ServerOptions { + if o == nil { + return &ServerOptions{} + } + + result := &ServerOptions{ + Encoding: o.Encoding, + ReadBufSize: o.ReadBufSize, + WriteBufSize: o.WriteBufSize, + } + + return result +} + +func (o *ServerOptions) check() error { + if o == nil { + return ErrInvalidServerOptions + } + + if !checkEncoding(o.Encoding) { + return ErrInvalidEncoding + } + + o.adjReadBufSize() + o.adjWriteBufSize() + + return nil +} + +func (o *ServerOptions) adjReadBufSize() { + o.ReadBufSize = calcBufSize(o.ReadBufSize) +} + +func (o *ServerOptions) adjWriteBufSize() { + o.WriteBufSize = calcBufSize(o.WriteBufSize) +} + +// Server handles requests from clients by calling methods on registered services. +type Server struct { + cfg *ServerOptions + addr net.Addr + ln net.Listener + + onceStart *sync.Once + onceStop *sync.Once + onceClose *sync.Once + started chan struct{} + stopping chan struct{} + closed chan struct{} + + rbufPool *sync.Pool + wbufPool *sync.Pool + + reqList requestRetainer + respList responseRetainer + + mu *sync.Mutex + conns map[*conn]struct{} + + smu *sync.Mutex + sset map[string]*service + + wg *sync.WaitGroup +} + +// NewServer returns a server listening on laddr with Protobuf encoding and default options. +func NewServer(laddr string) (*Server, error) { + return NewServerWithOptions(laddr, &ServerOptions{}) +} + +// NewServerWithOptions returns a server listening on laddr and set up with opts. +func NewServerWithOptions(laddr string, opts *ServerOptions) (*Server, error) { + ln, err := net.Listen("tcp", laddr) + if err != nil { + return nil, err + } + + return NewServerWithListener(opts, ln) +} + +// NewServerTLS returns a server listening on laddr in TLS mode. +func NewServerTLS(laddr string, cfg *tls.Config) (*Server, error) { + return NewServerTLSOptions(laddr, cfg, &ServerOptions{}) +} + +// NewServerTLSOptions creates a server listening on laddr in TLS mode with opts. +func NewServerTLSOptions(laddr string, cfg *tls.Config, opts *ServerOptions) (*Server, error) { + ln, err := tls.Listen("tcp", laddr, cfg) + if err != nil { + return nil, err + } + + return NewServerWithListener(opts, ln) +} + +// NewServerWithListener returns a server set up with opts listening on ln. +func NewServerWithListener(opts *ServerOptions, ln net.Listener) (*Server, error) { + cfg := opts.Copy() + if err := cfg.check(); err != nil { + return nil, err + } + + result := newServer(cfg, ln) + + return result, nil +} + +func newServer(cfg *ServerOptions, ln net.Listener) *Server { + result := &Server{ + cfg: cfg, + addr: ln.Addr(), + ln: ln, + + onceStart: &sync.Once{}, + onceStop: &sync.Once{}, + onceClose: &sync.Once{}, + started: make(chan struct{}), + stopping: make(chan struct{}), + closed: make(chan struct{}), + + mu: &sync.Mutex{}, + conns: make(map[*conn]struct{}), + + smu: &sync.Mutex{}, + sset: make(map[string]*service), + + wg: &sync.WaitGroup{}, + rbufPool: &sync.Pool{New: func() interface{} { return newBufReader(nil, cfg.ReadBufSize) }}, + wbufPool: &sync.Pool{New: func() interface{} { return newBufWriter(nil, cfg.WriteBufSize) }}, + + reqList: newRequestRetainer(retainRequestNum), + respList: newResponseRetainer(retainResponseNum), + } + + return result +} + +// Register registers the given svc. +// +// The svc is registered if it's exported and at least one method satisfies the following conditions: +// - a method is exported +// - a method has three arguments +// - the first argument is context.Context +// - the second and third arguments are both pointers to exported values that implement proto.Message +// - a method has only one return parameter of the type error. +// New services can be added while the server is runnning. +func (s *Server) Register(svc interface{}) error { + rsvc, err := newService(svc) + if err != nil { + return err + } + + s.smu.Lock() + s.sset[rsvc.name] = rsvc + s.smu.Unlock() + + return nil +} + +// Start starts accepting connections. +// +// After a successful call to Start, subsequent calls to return ErrServerStarted. +func (s *Server) Start(ctx context.Context) error { + if err := s.preStart(); err != nil { + return err + } + + s.onceStart.Do(func() { closeOrSkip(s.started) }) + + boff := newDefaultBackoff() + + for { + select { + case <-s.stopping: + return nil + case <-s.closed: + return nil + case <-ctx.Done(): + return ctx.Err() + + default: + nc, err := s.ln.Accept() + if err != nil { + nerr, ok := err.(net.Error) + if !ok || !nerr.Temporary() { + if s.isStopping() || s.isClosed() { + return nil + } + + return err + } + + select { + case <-s.stopping: + return ErrSrvClosed + case <-s.closed: + return ErrSrvClosed + case <-ctx.Done(): + return ctx.Err() + case <-time.After(boff.next()): + } + + continue + } + + boff.resetd() + + c := newConn(nc.RemoteAddr().String(), nc) + s.addConn(c) + + enc := newServerEncoder(s.cfg.Encoding, s.obtainBufReader(c.rwc), s.obtainBufWriter(c.rwc)) + + s.wg.Add(1) + go s.handle(ctx, c, enc) + } + } + + s.wg.Wait() + + return nil +} + +// Run is an alias for Start. +func (s *Server) Run(ctx context.Context) error { + return s.Start(ctx) +} + +// Shutdown gracefully shuts down the server. A successful call returns no error. +// +// It iteratively tries to close new and idle connections until no such left, +// unless/until interrupted by cancellation of ctx. +// Subsequent calls return ErrSrvClosed. +func (s *Server) Shutdown(ctx context.Context) error { + if !s.isStarted() || s.ln == nil { + return nil + } + + var ( + err error + didDo bool + ) + + s.onceStop.Do(func() { + closeOrSkip(s.stopping) + + err = s.ln.Close() + didDo = true + }) + + if !didDo && err == nil { + return ErrSrvClosed + } + + tc := time.NewTicker(shutdownLoopDelay) + defer tc.Stop() + + for { + if s.closeIdle() && s.isStopping() { + break + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-tc.C: + } + } + + s.wg.Wait() + + return err +} + +// Started returns true if the server is started. +func (s *Server) Started() bool { + return s.isStarted() +} + +// preStart checks if Server can be started. +// +// It must be called only once inside Start. +// Subsequent calls return ErrSrvStarted. +func (s *Server) preStart() error { + if s.cfg == nil { + return ErrInvalidServerOptions + } + + if s.addr == nil { + return ErrInvalidAddr + } + + if s.ln == nil { + return ErrInvalidListener + } + + if err := s.cfg.check(); err != nil { + return err + } + + if s.isStopping() || s.isClosed() { + return ErrSrvClosed + } + + if s.isStarted() { + return ErrSrvStarted + } + + return nil +} + +func (s *Server) isStarted() bool { + select { + case <-s.started: + return true + default: + return false + } +} + +func (s *Server) isStopping() bool { + select { + case <-s.stopping: + return true + default: + return false + } +} + +func (s *Server) isClosed() bool { + select { + case <-s.closed: + return true + default: + return false + } +} + +func (s *Server) addConn(c *conn) { + s.mu.Lock() + s.conns[c] = struct{}{} + s.mu.Unlock() +} + +func (s *Server) delConn(c *conn) { + s.mu.Lock() + delete(s.conns, c) + s.mu.Unlock() +} + +// closeIdle closes new and idle connections and returns true if no connections left. +// +// A successful call is broadcasted by closing s.closed. +func (s *Server) closeIdle() bool { + s.mu.Lock() + defer s.mu.Unlock() + + empty := true + for c := range s.conns { + if c == nil { + continue + } + + st := c.getState() + if st != stateNew && st != stateIdle { + empty = false + continue + } + + c.close() + delete(s.conns, c) + } + + if empty { + s.onceClose.Do(func() { closeOrSkip(s.closed) }) + } + + return empty +} + +func (s *Server) handle(pctx context.Context, c *conn, enc serverEncoder) { + defer s.wg.Done() + + defer func() { + c.close() + + s.retainBufReader(enc.reader()) + s.retainBufWriter(enc.writer()) + + c.toClosed() + s.delConn(c) + }() + + wg := &sync.WaitGroup{} + + ctx, cancel := context.WithCancel(pctx) + defer cancel() + + defer wg.Wait() + + for { + select { + case <-s.stopping: + return + case <-s.closed: + return + case <-ctx.Done(): + return + default: + var ( + req = s.reqList.obtainReq() + err error + ) + + if err = enc.readRequest(ctx, req); err != nil { + _ = s.shouldHangErr(err) + cancel() + + s.reqList.retainReq(req) + + return + } + + c.toActive() + + resp := s.respList.obtainResp() + + resp.Kind = req.Kind + resp.Id = req.Id + resp.Service = req.Service + resp.Method = req.Method + + svc, err := s.getService(req) + if err != nil { + if err = sendRespErr(ctx, enc, resp, err); err != nil { + if s.shouldHangErr(err) { + cancel() + + s.reqList.retainReq(req) + s.respList.retainResp(resp) + + return + } + } + + s.reqList.retainReq(req) + s.respList.retainResp(resp) + + c.toIdleIfNoReqs() + + continue + } + + c.addReq() + + wg.Add(1) + go func() { + defer wg.Done() + defer c.toIdleIfNoReqs() + defer c.doneReq() + + if err := svc.do(ctx, enc, req, resp); err != nil { + if s.shouldHangErr(err) { + cancel() + + s.reqList.retainReq(req) + s.respList.retainResp(resp) + + return + } + } + + s.reqList.retainReq(req) + s.respList.retainResp(resp) + }() + + c.toIdleIfNoReqs() + } + } +} + +func (s *Server) getService(req *wire.Request) (*service, error) { + svcn := req.GetService() + if svcn == "" { + return nil, errInvalidReqServiceName + } + + s.smu.Lock() + svc, ok := s.sset[svcn] + s.smu.Unlock() + + if !ok { + return nil, errUnknownReqService + } + + return svc, nil +} + +func (s *Server) shouldHangErr(err error) bool { + if err == nil { + return false + } + + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + return true + } + + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return true + } + + return false +} + +func (s *Server) obtainBufReader(r io.Reader) bufReader { + raw := s.rbufPool.Get() + if raw == nil { + return newBufReader(r, s.cfg.ReadBufSize) + } + + result, ok := raw.(bufReader) + if !ok { + return newBufReader(r, s.cfg.ReadBufSize) + } + + result.Reset(r) + + return result +} + +func (s *Server) retainBufReader(br bufReader) { + br.Reset(nil) + + s.rbufPool.Put(br) +} + +func (s *Server) obtainBufWriter(w io.Writer) bufWriter { + raw := s.wbufPool.Get() + if raw == nil { + return newBufWriter(w, s.cfg.WriteBufSize) + } + + result, ok := raw.(bufWriter) + if !ok { + return newBufWriter(w, s.cfg.WriteBufSize) + } + + result.Reset(w) + + return result +} + +func (s *Server) retainBufWriter(bw bufWriter) { + bw.Reset(nil) + + s.wbufPool.Put(bw) +} + +// --- + +const ( + stateNew cstate = iota + stateActive + stateIdle + stateClosed +) + +type cstate uint32 + +type conn struct { + raddr string + state struct{ value uint32 } + + // Protects only nreq. + mu *sync.Mutex + nreq uint64 + + rwc io.ReadWriteCloser +} + +func newConn(raddr string, rwc io.ReadWriteCloser) *conn { + c := &conn{ + raddr: raddr, + mu: &sync.Mutex{}, + rwc: rwc, + } + + return c +} + +func (c *conn) close() error { + return c.rwc.Close() +} + +func (c *conn) toNew() { + c.setState(stateNew) +} + +func (c *conn) toActive() { + c.setState(stateActive) +} + +func (c *conn) toIdle() { + c.setState(stateIdle) +} + +func (c *conn) toIdleIfNoReqs() { + if c.hasNoReqs() { + c.toIdle() + } +} + +func (c *conn) toClosed() { + c.setState(stateClosed) +} + +func (c *conn) getState() cstate { + return cstate(atomic.LoadUint32(&c.state.value)) +} + +func (c *conn) setState(s cstate) { + atomic.StoreUint32(&c.state.value, uint32(s)) +} + +func (c *conn) addReq() { + c.mu.Lock() + c.nreq += 1 + c.mu.Unlock() +} + +func (c *conn) doneReq() { + c.mu.Lock() + if c.nreq > 0 { + c.nreq -= 1 + } + c.mu.Unlock() +} + +func (c *conn) resetReq() { + c.mu.Lock() + c.nreq = 0 + c.mu.Unlock() +} + +func (c *conn) hasNoReqs() bool { + c.mu.Lock() + defer c.mu.Unlock() + + return c.nreq == 0 +} + +// Service. + +type service struct { + name string + + recvt reflect.Type + recvv reflect.Value + + mu *sync.Mutex + mset map[string]*method +} + +func newService(raw interface{}) (*service, error) { + recvt := reflect.TypeOf(raw) + if recvt.Kind() != reflect.Ptr { + return nil, errInvalidServiceType + } + + recvv := reflect.ValueOf(raw) + name := reflect.Indirect(recvv).Type().Name() + + if name == "" { + return nil, errInvalidServiceName + } + + if !token.IsExported(name) { + return nil, errUnexportedService + } + + mtds := createMethods(recvt, recvv) + if len(mtds) == 0 { + return nil, errInvalidServiceMethods + } + + result := &service{ + name: name, + recvt: recvt, + recvv: recvv, + mu: &sync.Mutex{}, + mset: mtds, + } + + return result, nil +} + +func (s *service) do(ctx context.Context, enc serverEncoder, req *wire.Request, resp *wire.Response) error { + mtd, err := s.getMethod(req) + if err != nil { + if err = sendRespErr(ctx, enc, resp, err); err != nil { + return err + } + + return nil + } + + respval, err := mtd.do(ctx, req, resp) + if err != nil { + if err = sendRespErr(ctx, enc, resp, err); err != nil { + return err + } + + return nil + } + + return enc.writeResponse(ctx, resp, respval) +} + +func (s *service) getMethod(req *wire.Request) (*method, error) { + mtn := req.GetMethod() + if mtn == "" { + return nil, errInvalidReqServiceMethod + } + + s.mu.Lock() + mtd, ok := s.mset[mtn] + s.mu.Unlock() + + if !ok { + return nil, errUnknownReqServiceMethod + } + + return mtd, nil +} + +type method struct { + name string + recvv reflect.Value + + mtd reflect.Method + argt reflect.Type + respt reflect.Type +} + +func (m *method) do(ctx context.Context, req *wire.Request, resp *wire.Response) (interface{}, error) { + argv, err := m.parseArgs(req) + if err != nil { + return nil, err + } + + respv := m.makeResp() + retval := m.mtd.Func.Call([]reflect.Value{ + m.recvv, + reflect.ValueOf(ctx), + argv, + respv, + }) + + m.checkRetval(resp, retval) + + var respval interface{} + if resp.Error != "" { + respval = &wire.InvalidRequest{} + } else { + respval = respv.Interface() + } + + return respval, nil +} + +func (m *method) parseArgs(req *wire.Request) (reflect.Value, error) { + argv := reflect.New(m.argt.Elem()) + if err := unmarshal(Encoding(req.Kind), req.Data, argv.Interface()); err != nil { + return reflect.Value{}, err + } + + return argv, nil +} + +func (m *method) makeResp() reflect.Value { + respv := reflect.New(m.respt.Elem()) + switch m.respt.Elem().Kind() { + case reflect.Map: + respv.Elem().Set(reflect.MakeMap(m.respt.Elem())) + case reflect.Slice: + respv.Elem().Set(reflect.MakeSlice(m.respt.Elem(), 0, 0)) + } + + return respv +} + +func (m *method) checkRetval(resp *wire.Response, retval []reflect.Value) { + switch { + case len(retval) == 1: + v := retval[0].Interface() + if v == nil { + return + } + + verr, ok := v.(error) + if ok { + resp.Error = verr.Error() + } else { + resp.Error = msgInvalidReturnType + } + default: + resp.Error = msgInvalidReturnNum + } +} + +// --- + +// backoff is a simplest backoff calculator. +// +// It is NOT safe for concurrent use. +// The counter calculates delay by multiplying start by factor. +// Once delay exceeds max, max is returned on every call to next(). +// The value of start must not be greater than max. +// The value of delay can be reset by calling resetd(). +type backoff struct { + start time.Duration + max time.Duration + delay time.Duration + factor uint8 +} + +func newBackoff(start, max time.Duration, factor uint8) *backoff { + result := &backoff{ + start: start, + max: max, + factor: factor, + } + + return result +} + +func newDefaultBackoff() *backoff { + return newBackoff(10*time.Millisecond, 1*time.Second, 2) +} + +func (b *backoff) next() time.Duration { + if b == nil { + return 0 + } + + if b.delay == b.max { + return b.delay + } + + if b.delay == 0 { + b.delay = b.start + return b.delay + } + + b.delay *= time.Duration(b.factor) + if b.delay > b.max { + b.delay = b.max + } + + return b.delay +} + +func (b *backoff) resetd() { + if b == nil { + return + } + + b.delay = 0 +} + +type requestRetainer chan *wire.Request + +func newRequestRetainer(size int) requestRetainer { + if size == 0 { + size = runtime.NumCPU() + } + + return make(requestRetainer, size) +} + +func (r requestRetainer) obtainReq() *wire.Request { + select { + case req := <-r: + return req + default: + return &wire.Request{} + } +} + +func (r requestRetainer) retainReq(req *wire.Request) { + req.Reset() + + select { + case r <- req: + default: + } +} + +type responseRetainer chan *wire.Response + +func newResponseRetainer(size int) responseRetainer { + if size == 0 { + size = runtime.NumCPU() + } + + return make(responseRetainer, size) +} + +func (r responseRetainer) obtainResp() *wire.Response { + select { + case resp := <-r: + return resp + default: + return &wire.Response{} + } +} + +func (r responseRetainer) retainResp(resp *wire.Response) { + resp.Reset() + + select { + case r <- resp: + default: + } +} + +func checkEncoding(kind Encoding) bool { + switch kind { + case Protobuf: + return true + case JSON: + return true + default: + return false + } +} + +var ( + tCtxArg = reflect.TypeOf((*context.Context)(nil)).Elem() + terrArg = reflect.TypeOf((*error)(nil)).Elem() +) + +func createMethods(svct reflect.Type, svcv reflect.Value) map[string]*method { + result := make(map[string]*method) + + for n, max := 0, svct.NumMethod(); n < max; n++ { + mtd := svct.Method(n) + + if !token.IsExported(mtd.Name) { + continue + } + + if mtd.Type.NumIn() != 4 { + continue + } + + if mtd.Type.NumOut() != 1 { + continue + } + + if tctx := mtd.Type.In(1); tctx != tCtxArg { + continue + } + + argt := mtd.Type.In(2) + if argt.Kind() != reflect.Ptr || !isUsableType(argt) { + continue + } + + respt := mtd.Type.In(3) + if respt.Kind() != reflect.Ptr || !isUsableType(respt) { + continue + } + + if rtp := mtd.Type.Out(0); rtp != terrArg { + continue + } + + result[mtd.Name] = &method{ + name: mtd.Name, + recvv: svcv, + mtd: mtd, + argt: argt, + respt: respt, + } + } + + return result +} + +func isUsableType(tp reflect.Type) bool { + for tp.Kind() == reflect.Ptr { + tp = tp.Elem() + } + + return token.IsExported(tp.Name()) || tp.PkgPath() == "" +} + +func sendRespErr(ctx context.Context, enc serverEncoder, resp *wire.Response, rerr error) error { + if rerr != nil { + resp.Error = rerr.Error() + } else { + resp.Error = msgUnexpectedError + } + + return enc.writeResponse(ctx, resp, &wire.InvalidRequest{}) +} + +func closeOrSkip(c chan struct{}) { + select { + case <-c: + default: + close(c) + } +} + +func calcBufSize(x int) int { + if x < 0 { + return x + } + + if x >= 0 && x < bufSize8k { + return defBufSize + } + + if x > bufSize256k { + return bufSize512k + } + + return nextPow2(x) +} + +// nextPow2 returns a number which is the next power of 2 larger than x. +// +// The algorithm was found here http://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2. +func nextPow2(x int) int { + // Special case. Don't care about negatives. + if x < 0 { + return 1 + } + + if x == 0 { + return 1 + } + + x-- + x |= x >> 1 + x |= x >> 2 + x |= x >> 4 + x |= x >> 8 + x |= x >> 16 + x++ + + return x +} diff --git a/rpcz_bench_test.go b/rpcz_bench_test.go new file mode 100644 index 0000000..4180050 --- /dev/null +++ b/rpcz_bench_test.go @@ -0,0 +1,254 @@ +package rpcz_test + +import ( + "context" + "crypto/tls" + "crypto/x509" + "math/rand" + "sync" + "testing" + "time" + + "github.com/golocron/rpcz" + "github.com/golocron/rpcz/internal/echo" +) + +const ( + alpabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + + sz1k = 1 << 10 + sz4k = 4 << 10 + sz16k = 16 << 10 + sz32k = 32 << 10 + sz64k = 64 << 10 +) + +func randString(n int) string { + result := make([]byte, n) + abLen := int64(len(alpabet)) + + for i := range result { + if pos := rand.Int63() % abLen; pos < abLen { + result[i] = alpabet[pos] + } + } + + return string(result) +} + +func BenchmarkProtobuf_1K(b *testing.B) { + benchmarkRPCZ(b, sz1k, "tcp", "127.0.0.1:11201", rpcz.Protobuf, sz1k) +} + +func BenchmarkProtobuf_4K(b *testing.B) { + benchmarkRPCZ(b, sz4k, "tcp", "127.0.0.1:11202", rpcz.Protobuf, sz4k) +} + +func BenchmarkProtobuf_16K(b *testing.B) { + benchmarkRPCZ(b, sz16k, "tcp", "127.0.0.1:11203", rpcz.Protobuf, sz16k+1) +} + +func BenchmarkProtobuf_32K(b *testing.B) { + benchmarkRPCZ(b, sz32k, "tcp", "127.0.0.1:11204", rpcz.Protobuf, sz32k+1) +} + +func BenchmarkProtobuf_64K(b *testing.B) { + benchmarkRPCZ(b, sz64k, "tcp", "127.0.0.1:11205", rpcz.Protobuf, sz64k+1) +} + +func BenchmarkProtobufTLS_4K(b *testing.B) { + benchmarkRPCZ(b, sz4k, "tls", "127.0.0.1:11202", rpcz.Protobuf, sz4k) +} + +func BenchmarkProtobufNoBuf_4K(b *testing.B) { + benchmarkRPCZ(b, sz4k, "tcp", "127.0.0.1:11202", rpcz.Protobuf, -1) +} + +func BenchmarkJSON_1K(b *testing.B) { + benchmarkRPCZ(b, sz1k, "tcp", "127.0.0.1:11301", rpcz.JSON, sz1k) +} + +func BenchmarkJSON_4K(b *testing.B) { + benchmarkRPCZ(b, sz4k, "tcp", "127.0.0.1:11302", rpcz.JSON, sz4k) +} + +func BenchmarkJSON_16K(b *testing.B) { + benchmarkRPCZ(b, sz16k, "tcp", "127.0.0.1:11303", rpcz.JSON, sz16k+1) +} + +func BenchmarkJSONTLS_4K(b *testing.B) { + benchmarkRPCZ(b, sz4k, "tls", "127.0.0.1:11302", rpcz.JSON, sz4k) +} + +func benchmarkRPCZ(b *testing.B, tsize int, nw, addr string, kind rpcz.Encoding, bsize int) { + const ( + svc = "Echo" + mtd = "Echo" + ) + + srv, err := newServer(nw, addr, kind, bsize) + if err != nil { + b.Errorf("failed to init server: %v", err) + b.FailNow() + } + + esvc := echo.New() + if err := srv.Register(esvc); err != nil { + b.Errorf("failed to init server: %v", err) + b.FailNow() + } + + ctx := context.TODO() + wg := &sync.WaitGroup{} + + wg.Add(1) + go func() { + defer wg.Done() + + srv.Start(ctx) + }() + + for !srv.Started() { + time.Sleep(100 * time.Millisecond) + } + + cl, err := newClient(nw, addr, kind, bsize) + if err != nil { + b.Errorf("failed to init client: %v", err) + b.FailNow() + } + + msg := randString(tsize) + reqList := make(reqGen, 128) + respList := make(respGen, 128) + + b.SetBytes(2 * int64(len(msg))) + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + req := reqList.obtainReq() + req.Msg = msg + + resp := respList.obtainResp() + + if err := cl.SyncDo(ctx, svc, mtd, req, resp); err != nil { + reqList.retainReq(req) + respList.retainResp(resp) + + b.Errorf("failed to make req: %v", err) + return + } + + if m := resp.Msg; m != msg { + reqList.retainReq(req) + respList.retainResp(resp) + + b.Errorf("expected: %q\nactual: %q", msg, m) + return + } + + reqList.retainReq(req) + respList.retainResp(resp) + } + }) + + b.StopTimer() + + cl.Close() + srv.Shutdown(ctx) + wg.Wait() +} + +func newServer(nw, addr string, kind rpcz.Encoding, bsize int) (*rpcz.Server, error) { + cfg := &rpcz.ServerOptions{Encoding: kind, ReadBufSize: bsize, WriteBufSize: bsize} + + switch nw { + case "tls": + return newServerTLS(addr, cfg) + default: + return rpcz.NewServerWithOptions(addr, cfg) + } +} + +func newServerTLS(addr string, cfg *rpcz.ServerOptions) (*rpcz.Server, error) { + cert, err := tls.LoadX509KeyPair("testdata/localhost.crt", "testdata/localhost.key") + if err != nil { + return nil, err + } + + tcfg := &tls.Config{Certificates: []tls.Certificate{cert}} + + return rpcz.NewServerTLSOptions(addr, tcfg, cfg) +} + +func newClient(nw, addr string, kind rpcz.Encoding, bsize int) (*rpcz.Client, error) { + cfg := &rpcz.ClientOptions{Encoding: kind, ReadBufSize: bsize, WriteBufSize: bsize} + + switch nw { + case "tls": + return newClientTLS(addr, cfg) + default: + return rpcz.NewClientWithOptions(addr, cfg) + } +} + +func newClientTLS(addr string, cfg *rpcz.ClientOptions) (*rpcz.Client, error) { + cert, err := tls.LoadX509KeyPair("testdata/localhost.crt", "testdata/localhost.key") + if err != nil { + return nil, err + } + + xcert, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return nil, err + + } + + rcert := x509.NewCertPool() + rcert.AddCert(xcert) + + tcfg := &tls.Config{RootCAs: rcert, ServerName: "localhost"} + + return rpcz.NewClientTLSOptions(addr, tcfg, cfg) +} + +type reqGen chan *echo.EchoRequest + +func (r reqGen) obtainReq() *echo.EchoRequest { + select { + case req := <-r: + return req + default: + return &echo.EchoRequest{} + } +} + +func (r reqGen) retainReq(req *echo.EchoRequest) { + req.Reset() + + select { + case r <- req: + default: + } +} + +type respGen chan *echo.EchoResponse + +func (r respGen) obtainResp() *echo.EchoResponse { + select { + case req := <-r: + return req + default: + return &echo.EchoResponse{} + } +} + +func (r respGen) retainResp(resp *echo.EchoResponse) { + resp.Reset() + + select { + case r <- resp: + default: + } +} diff --git a/rpcz_pkg_test.go b/rpcz_pkg_test.go new file mode 100644 index 0000000..d4965f7 --- /dev/null +++ b/rpcz_pkg_test.go @@ -0,0 +1,205 @@ +package rpcz_test + +import ( + "context" + "sync" + "testing" + "time" + + should "github.com/stretchr/testify/assert" + must "github.com/stretchr/testify/require" + + "github.com/golocron/rpcz" + "github.com/golocron/rpcz/internal/echo" +) + +func TestServer_Full_Protobuf_0K(t *testing.T) { + testServer_full(t, "tcp", "127.0.0.1:10001", rpcz.Protobuf, -1) +} + +func TestServer_Full_Protobuf_4K(t *testing.T) { + testServer_full(t, "tcp", "127.0.0.1:10001", rpcz.Protobuf, 4<<10) +} + +func TestServer_Full_ProtobufTLS_4K(t *testing.T) { + testServer_full(t, "tls", "127.0.0.1:10001", rpcz.Protobuf, 4<<10) +} + +func TestServer_Full_JSON_0K(t *testing.T) { + testServer_full(t, "tcp", "127.0.0.1:10001", rpcz.JSON, -1) +} + +func TestServer_Full_JSON_4K(t *testing.T) { + testServer_full(t, "tcp", "127.0.0.1:10001", rpcz.JSON, 4<<10) +} + +func TestServer_Full_JSONTLS_4K(t *testing.T) { + testServer_full(t, "tls", "127.0.0.1:10001", rpcz.JSON, 4<<10) +} + +func TestServerClient_Errors(t *testing.T) { + const ( + nw = "tcp" + addr = "127.0.0.1:10003" + kind = rpcz.Protobuf + bsize = 4 << 10 + ) + + t.Run("server_stops_early", func(t *testing.T) { + srv, err1 := newServer(nw, addr, kind, bsize) + must.Equal(t, nil, err1) + + svc := echo.New() + err2 := srv.Register(svc) + must.Equal(t, nil, err2) + + ctx := context.TODO() + wg := &sync.WaitGroup{} + + wg.Add(1) + go func() { + defer wg.Done() + + if err3 := srv.Start(ctx); err3 != nil { + t.Logf("srv.Start returned unexpected error: %s", err3) + } + }() + + for !srv.Started() { + time.Sleep(100 * time.Millisecond) + } + + cl, err4 := newClient(nw, addr, kind, bsize) + must.Equal(t, nil, err4) + + err5 := srv.Shutdown(ctx) + should.Equal(t, nil, err5) + + req := &echo.EchoRequest{Msg: "This is the wrong Way"} + resp := &echo.EchoResponse{} + + err6 := cl.SyncDo(ctx, "Echo", "Echo", req, resp) + should.Error(t, err6) + + wg.Wait() + + err7 := cl.Close() + should.Equal(t, nil, err7) + }) +} + +func testServer_full(t *testing.T, nw, addr string, kind rpcz.Encoding, bsize int) { + srv, err1 := newServer(nw, addr, kind, bsize) + must.Equal(t, nil, err1) + + svc := echo.New() + err2 := srv.Register(svc) + must.Equal(t, nil, err2) + + ctx := context.TODO() + wg := &sync.WaitGroup{} + + wg.Add(1) + go func() { + defer wg.Done() + + if err3 := srv.Start(ctx); err3 != nil { + t.Logf("srv.Start returned unexpected error: %s", err3) + } + }() + + for !srv.Started() { + time.Sleep(100 * time.Millisecond) + } + + cl, err4 := newClient(nw, addr, kind, bsize) + must.Equal(t, nil, err4) + + peer := cl.Peer() + should.Equal(t, addr, peer) + + type tcGiven struct { + svc string + mtd string + msg string + } + + type tcExpected struct { + msg string + err error + } + + type testCase struct { + name string + given *tcGiven + exp *tcExpected + } + + tests := []*testCase{ + &testCase{ + name: "invalid_service_requested", + given: &tcGiven{mtd: "Echo"}, + + exp: &tcExpected{err: rpcz.ServiceError("rpcz: invalid service")}, + }, + + &testCase{ + name: "unknown_service_requested", + given: &tcGiven{svc: "sdsds", mtd: "Echo"}, + + exp: &tcExpected{err: rpcz.ServiceError("rpcz: unknown service")}, + }, + + &testCase{ + name: "invalid_method_requested", + given: &tcGiven{svc: "Echo"}, + + exp: &tcExpected{err: rpcz.ServiceError("rpcz: invalid service method")}, + }, + + &testCase{ + name: "unknown_method_requested", + given: &tcGiven{svc: "Echo", mtd: "sdsds"}, + + exp: &tcExpected{err: rpcz.ServiceError("rpcz: unknown method")}, + }, + + &testCase{ + name: "method_returns_error", + given: &tcGiven{svc: "Echo", mtd: "Echo"}, + + exp: &tcExpected{err: rpcz.ServiceError("echo: invalid message")}, + }, + + &testCase{ + name: "successful_call", + given: &tcGiven{svc: "Echo", mtd: "Echo", msg: "This is the Way"}, + + exp: &tcExpected{msg: "This is the Way"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req := &echo.EchoRequest{Msg: tc.given.msg} + resp := &echo.EchoResponse{} + + err := cl.SyncDo(ctx, tc.given.svc, tc.given.mtd, req, resp) + should.Equal(t, tc.exp.err, err) + + if tc.exp.err != nil { + return + } + + should.Equal(t, tc.exp.msg, resp.Msg) + }) + } + + err5 := cl.Close() + should.Equal(t, nil, err5) + + err6 := srv.Shutdown(ctx) + should.Equal(t, nil, err6) + + wg.Wait() +} diff --git a/rpcz_test.go b/rpcz_test.go new file mode 100644 index 0000000..731ac3f --- /dev/null +++ b/rpcz_test.go @@ -0,0 +1,574 @@ +package rpcz + +import ( + "context" + "errors" + "io" + "net" + "reflect" + "sync" + "testing" + "time" + + should "github.com/stretchr/testify/assert" + must "github.com/stretchr/testify/require" + + "github.com/golocron/rpcz/internal/wire" +) + +func TestServer_preStart(t *testing.T) { + t.Run("invalid_options", func(t *testing.T) { + srv := &Server{} + err := srv.preStart() + should.Equal(t, ErrInvalidServerOptions, err) + }) + + t.Run("invalid_addr", func(t *testing.T) { + srv := &Server{cfg: &ServerOptions{}} + err := srv.preStart() + should.Equal(t, ErrInvalidAddr, err) + }) + + t.Run("invalid_addr", func(t *testing.T) { + addr, err1 := net.ResolveTCPAddr("tcp4", "127.0.0.1:0") + must.Equal(t, nil, err1) + + srv := &Server{cfg: &ServerOptions{}, addr: addr} + err := srv.preStart() + should.Equal(t, ErrInvalidListener, err) + }) + + t.Run("invalid_encoding", func(t *testing.T) { + ln, err1 := net.Listen("tcp4", "127.0.0.1:0") + must.Equal(t, nil, err1) + defer ln.Close() + + srv := &Server{cfg: &ServerOptions{Encoding: Encoding(3)}, addr: ln.Addr(), ln: ln} + err := srv.preStart() + should.Equal(t, ErrInvalidEncoding, err) + }) + + t.Run("invalid_stopping", func(t *testing.T) { + srv, err1 := NewServerWithOptions("127.0.0.1:0", &ServerOptions{}) + must.Equal(t, nil, err1) + defer srv.ln.Close() + + close(srv.stopping) + + err := srv.preStart() + should.Equal(t, ErrSrvClosed, err) + }) + + t.Run("invalid_closed", func(t *testing.T) { + srv, err1 := NewServerWithOptions("127.0.0.1:0", &ServerOptions{}) + must.Equal(t, nil, err1) + defer srv.ln.Close() + + close(srv.closed) + + err := srv.preStart() + should.Equal(t, ErrSrvClosed, err) + }) + + t.Run("invalid_started", func(t *testing.T) { + srv, err1 := NewServerWithOptions("127.0.0.1:0", &ServerOptions{}) + must.Equal(t, nil, err1) + defer srv.ln.Close() + + close(srv.started) + + err := srv.preStart() + should.Equal(t, ErrSrvStarted, err) + }) + + t.Run("successful_call", func(t *testing.T) { + srv, err1 := NewServerWithOptions("127.0.0.1:0", &ServerOptions{}) + must.Equal(t, nil, err1) + defer srv.ln.Close() + + err := srv.preStart() + should.Equal(t, nil, err) + }) +} + +func TestServer_Shutdown(t *testing.T) { + t.Run("not_started", func(t *testing.T) { + srv, err1 := NewServerWithOptions("127.0.0.1:0", &ServerOptions{}) + must.Equal(t, nil, err1) + defer srv.ln.Close() + + act := srv.Shutdown(context.TODO()) + should.Equal(t, nil, act) + }) + + t.Run("ln_nil", func(t *testing.T) { + srv, err1 := NewServerWithOptions("127.0.0.1:0", &ServerOptions{}) + must.Equal(t, nil, err1) + + ln := srv.ln + defer ln.Close() + + srv.ln = nil + + act := srv.Shutdown(context.TODO()) + should.Equal(t, nil, act) + }) + + t.Run("already_closed", func(t *testing.T) { + srv, err1 := NewServerWithOptions("127.0.0.1:0", &ServerOptions{}) + must.Equal(t, nil, err1) + defer srv.ln.Close() + + ctx := context.TODO() + wg := &sync.WaitGroup{} + + wg.Add(1) + go func() { + defer wg.Done() + + srv.Start(ctx) + }() + + for !srv.Started() { + time.Sleep(100 * time.Millisecond) + } + + act1 := srv.Shutdown(ctx) + should.Equal(t, nil, act1) + + act2 := srv.Shutdown(ctx) + should.Equal(t, ErrSrvClosed, act2) + + wg.Wait() + }) + + t.Run("successful_call", func(t *testing.T) { + srv, err1 := NewServerWithOptions("127.0.0.1:0", &ServerOptions{}) + must.Equal(t, nil, err1) + defer srv.ln.Close() + + ctx := context.TODO() + wg := &sync.WaitGroup{} + + wg.Add(1) + go func() { + defer wg.Done() + + srv.Start(ctx) + }() + + for !srv.Started() { + time.Sleep(100 * time.Millisecond) + } + + c01 := newConn("fake_01", &fakeRWC{}) + srv.addConn(c01) + c01.toActive() + + c02 := newConn("fake_02", &fakeRWC{}) + srv.addConn(c02) + c02.toActive() + + c03 := newConn("fake_03", &fakeRWC{}) + srv.addConn(c03) + + wg.Add(1) + go func() { + defer wg.Done() + + time.Sleep(20 * time.Millisecond) + c01.toClosed() + srv.delConn(c01) + + time.Sleep(20 * time.Millisecond) + c02.toIdle() + }() + + act := srv.Shutdown(ctx) + should.Equal(t, nil, act) + + wg.Wait() + }) +} + +func TestServer_closeIdle(t *testing.T) { + t.Run("all_active", func(t *testing.T) { + srv, err1 := NewServerWithOptions("127.0.0.1:0", &ServerOptions{}) + must.Equal(t, nil, err1) + defer srv.ln.Close() + + for _, name := range []string{"fake_01", "fake_02"} { + c := newConn(name, &fakeRWC{}) + srv.addConn(c) + c.toActive() + } + + act := srv.closeIdle() + should.Equal(t, false, act) + }) + + t.Run("successful_call", func(t *testing.T) { + srv, err1 := NewServerWithOptions("127.0.0.1:0", &ServerOptions{}) + must.Equal(t, nil, err1) + defer srv.ln.Close() + + c01 := newConn("fake_01", &fakeRWC{}) + srv.addConn(c01) + c01.toActive() + + c02 := newConn("fake_02", &fakeRWC{}) + srv.addConn(c02) + c02.toActive() + + c03 := newConn("fake_03", &fakeRWC{}) + srv.addConn(c03) + + srv.mu.Lock() + should.Equal(t, len(srv.conns), 3) + srv.mu.Unlock() + + act1 := srv.closeIdle() + should.Equal(t, false, act1) + + srv.mu.Lock() + should.Equal(t, len(srv.conns), 2) + should.Equal(t, struct{}{}, srv.conns[c01]) + should.Equal(t, struct{}{}, srv.conns[c02]) + srv.mu.Unlock() + + c01.toIdle() + act2 := srv.closeIdle() + should.Equal(t, false, act2) + + srv.mu.Lock() + should.Equal(t, len(srv.conns), 1) + should.Equal(t, struct{}{}, srv.conns[c02]) + srv.mu.Unlock() + + c02.toIdle() + act3 := srv.closeIdle() + should.Equal(t, true, act3) + + srv.mu.Lock() + should.Equal(t, len(srv.conns), 0) + srv.mu.Unlock() + }) +} + +func TestServer_shouldHangErr(t *testing.T) { + type testCase struct { + name string + given error + exp bool + } + + tests := []*testCase{ + &testCase{name: "nil_false"}, + + &testCase{name: "EOF_true", given: io.EOF, exp: true}, + + &testCase{name: "UnexpectedEOF_true", given: io.ErrUnexpectedEOF, exp: true}, + + &testCase{name: "CtxCanelled_true", given: context.Canceled, exp: true}, + + &testCase{name: "CtxDeadlineExceeded_true", given: context.DeadlineExceeded, exp: true}, + + &testCase{name: "any_false", given: errors.New("some error")}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + srv, err1 := NewServerWithOptions("127.0.0.1:0", &ServerOptions{}) + must.Equal(t, nil, err1) + defer srv.ln.Close() + + act := srv.shouldHangErr(tc.given) + should.Equal(t, tc.exp, act) + }) + } +} + +func TestNewService(t *testing.T) { + t.Run("not_pointer", func(t *testing.T) { + type tService MockService + + svc := tService{} + + act, err := newService(svc) + should.Equal(t, errInvalidServiceType, err) + should.Equal(t, (*service)(nil), act) + }) + + t.Run("unexported_type", func(t *testing.T) { + type tService MockService + + svc := &tService{} + + act, err := newService(svc) + should.Equal(t, errUnexportedService, err) + should.Equal(t, (*service)(nil), act) + }) + + t.Run("no_methods", func(t *testing.T) { + type Service MockService + + svc := &Service{} + + act, err := newService(svc) + should.Equal(t, errInvalidServiceMethods, err) + should.Equal(t, (*service)(nil), act) + }) + + t.Run("successful_call", func(t *testing.T) { + svc := &MockService{} + + act, err := newService(svc) + must.Equal(t, nil, err) + + should.Equal(t, "MockService", act.name) + should.Equal(t, reflect.TypeOf(svc), act.recvt) + should.Equal(t, reflect.ValueOf(svc), act.recvv) + should.Equal(t, reflect.ValueOf(svc), act.recvv) + + mtd, ok := act.mset["UsableMethod"] + must.Equal(t, true, ok) + + should.Equal(t, "UsableMethod", mtd.name) + should.Equal(t, act.recvv, mtd.recvv) + + should.Equal(t, act.recvt.Method(8).Name, mtd.mtd.Name) + should.Equal(t, act.recvt.Method(8).Type, mtd.mtd.Type) + should.Equal(t, reflect.TypeOf(&UsableTypeArg{}), mtd.argt) + should.Equal(t, reflect.TypeOf(&UsableTypeRepl{}), mtd.respt) + }) +} + +func TestCreateMethods(t *testing.T) { + t.Run("multiple_conditions", func(t *testing.T) { + svc := &MockService{} + + recvt := reflect.TypeOf(svc) + must.Equal(t, recvt.Kind(), reflect.Ptr) + + recvv := reflect.ValueOf(svc) + + mtds := createMethods(recvt, recvv) + must.Equal(t, 1, len(mtds)) + + mtd, ok := mtds["UsableMethod"] + must.Equal(t, true, ok) + + should.Equal(t, "UsableMethod", mtd.name) + should.Equal(t, recvv, mtd.recvv) + should.Equal(t, recvt.Method(8).Name, mtd.mtd.Name) + should.Equal(t, recvt.Method(8).Type, mtd.mtd.Type) + should.Equal(t, reflect.TypeOf(&UsableTypeArg{}), mtd.argt) + should.Equal(t, reflect.TypeOf(&UsableTypeRepl{}), mtd.respt) + }) +} + +func TestBackoff(t *testing.T) { + t.Run("successful_case", func(t *testing.T) { + b := newDefaultBackoff() + + do := func() { + for i := 0; i < 7; i++ { + should.Equal(t, (10<