diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a9b1ecf..2cb730e 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -135,7 +135,7 @@ jobs: run: | mkdir build cd build - cmake -DRWKV_AVX2=OFF -DRWKV_FMA=OFF .. + cmake -DRWKV_AVX2=OFF -DRWKV_FMA=OFF -DRWKV_SANITIZE_ADDRESS=ON .. cmake --build . --config Release - name: Test diff --git a/rwkv.cpp b/rwkv.cpp index 0d4a548..4fd4083 100644 --- a/rwkv.cpp +++ b/rwkv.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #define _FILE_OFFSET_BITS 64 #define RWKV_MAYBE_BREAK @@ -459,35 +460,56 @@ struct ggml_tensor * rwkv_layer_norm(ggml_context * ctx, struct ggml_tensor * x, // --- Implementation --- -struct rwkv_layer_state { - struct ggml_tensor * ffn_xx; - struct ggml_tensor * att_xx; - struct ggml_tensor * att_aa; - struct ggml_tensor * att_bb; - struct ggml_tensor * att_pp; -}; - -struct rwkv_graph { - struct ggml_tensor * input_state; - std::unique_ptr input_layers; - std::unique_ptr output_layers; - struct ggml_tensor * token_index; - struct ggml_tensor * logits; - std::unique_ptr cgraph; +// Used to calculate the memory usage of GGML contexts before allocating them. +// Since GGML uses an internal bump allocator that can't be grown at runtime, we need to ensure we have enough space, +// while at the same time, not using more memory than necessary. +struct rwkv_ctx_size { + size_t objects_count = 0; + size_t objects_size = 0; + size_t scratch_size = 0; }; -struct rwkv_ggml_guard { +struct rwkv_ggml_context { + std::unique_ptr scratch; struct ggml_context * ctx; - ~rwkv_ggml_guard() { if (ctx) { ggml_free(ctx); } } + + rwkv_ggml_context(): ctx(NULL) {} + + rwkv_ggml_context(struct rwkv_ctx_size size): ctx(NULL) { + scratch.reset(new(std::nothrow) uint8_t [size.scratch_size]); + + if (!scratch) { + return; + } + + ctx = ggml_init({ size.objects_count * GGML_OBJECT_SIZE + size.objects_size, NULL, false}); + + if (!ctx) { + return; + } + + ggml_set_scratch(ctx, { 0, size.scratch_size, scratch.get() }); + } + + struct rwkv_ggml_context & operator=(struct rwkv_ggml_context && source) { + scratch.reset(source.scratch.release()); + std::swap(ctx, source.ctx); + return *this; + } + + ~rwkv_ggml_context() { + if (ctx) { + ggml_free(ctx); + } + } }; -// An instance of an RWKV model loaded into memory: +// An instance of an RWKV model loaded into memory. // Contains all the model weights. // Shared by one or more contexts. struct rwkv_instance { + struct rwkv_ggml_context ctx; struct rwkv_model model; - struct rwkv_ggml_guard ctx; - std::unique_ptr scratch; // TODO come up with a better solution to estimate "work tensor" size. // The ggml_cgraph allocates a "work tensor" the first time it is used. @@ -499,15 +521,56 @@ struct rwkv_instance { size_t ffn_key_size; }; +// The hidden state of a single RWKV layer. +// These are mostly used for dividing up the input state, and writing portions of the output state. +// But they're also used in building the computation graphs, to represent the operations used from input->output +// (operating "in place" on a rwkv_layer_state). +struct rwkv_layer_state { + struct ggml_tensor * ffn_xx; + struct ggml_tensor * att_xx; + struct ggml_tensor * att_aa; + struct ggml_tensor * att_bb; + struct ggml_tensor * att_pp; +}; + +// Holds a single computation graph and its GGML context. +// Graphs each have their own context so that they can be individually freed and rebuilt. +// Graphs read hidden state from the rwkv_context and then write it back to the rwkv_context. +// (see rwkv_context.input_layers and rwkv_context.output_layers) +struct rwkv_graph { + struct rwkv_ggml_context ctx; + struct ggml_tensor * tokens; + + // ggml_cgraph is so large that it can cause stack overflows if not stored on the heap + std::unique_ptr cgraph; +}; + // RWKV context for a specific instance. -// Contains the computation graph and is used for inference. +// Contains computation graphs and is used for inference. struct rwkv_context { std::shared_ptr instance; - struct ggml_context * ctx; - std::unique_ptr scratch; - struct rwkv_graph graph; + + // Reused by all graphs. + struct rwkv_ggml_context ctx; + struct ggml_tensor * input_state; + std::unique_ptr input_layers; + struct ggml_tensor * output_state; + std::unique_ptr output_layers; + struct ggml_tensor * logits; + + uint32_t n_threads; + + // The serial graph implements the traditional RNN mode that processes only one token at a time (serial mode). + struct rwkv_graph serial_graph; + + // The sequence graph implements the "sequence mode" (or transformer/GPT mode) that processes multiple tokens at a time. + // This can be an order of magnitude or so faster than serial execution if used properly. + size_t sequence_len; + struct rwkv_graph sequence_graph; + enum rwkv_error_flags last_error; bool print_errors; + size_t gpu_layers; size_t vram_total; }; @@ -580,12 +643,6 @@ bool rwkv_set_params(struct rwkv_model & model, F callback) { return true; } -struct rwkv_ctx_size { - size_t objects_count = 0; - size_t objects_size = 0; - size_t scratch_size = 0; -}; - void rwkv_ctx_size_add_objects(struct rwkv_ctx_size & ctx_size, size_t objects, size_t object_size = sizeof(struct ggml_tensor)) { ctx_size.objects_count += objects; ctx_size.objects_size += ((object_size + 15) & ~15) * objects; @@ -615,25 +672,106 @@ void rwkv_ctx_size_add_tensor(struct rwkv_ctx_size & size, const uint64_t tensor rwkv_ctx_size_add_tensor(size, tensors, views, rwkv_type_to_ggml[header.data_type], header.width, header.height); } -struct rwkv_ctx_size rwkv_single_att_size(const size_t n_embed = 0) { - size_t ptr_nelem = sizeof(void *) / sizeof(uint32_t); - +struct rwkv_ctx_size rwkv_xx_size(const size_t n_embed = 0, const size_t sequence_len = 1) { struct rwkv_ctx_size ctx_size; - /* x0 */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed); + if (sequence_len == 1) { + /* x0 */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed); + } else { + /* x0 */ rwkv_ctx_size_add_tensor(ctx_size, 4, 1, GGML_TYPE_F32, n_embed, sequence_len); + + /* xx */ rwkv_ctx_size_add_tensor(ctx_size, 1, 2, GGML_TYPE_F32, n_embed, sequence_len); + /* xx */ rwkv_ctx_size_add_objects(ctx_size, 2, sizeof(struct ggml_tensor) + rwkv_tensor_size(GGML_TYPE_I32, 5)); + /* xx */ rwkv_ctx_size_add_tensor(ctx_size, 0, 1, GGML_TYPE_F32, n_embed * sequence_len - 1); + + /* xx */ rwkv_ctx_size_add_tensor(ctx_size, 0, 1, GGML_TYPE_F32, n_embed); + } + + return ctx_size; +} + +void rwkv_xx(struct ggml_context * ctx, struct ggml_tensor * weight, struct ggml_tensor * bias, struct ggml_tensor *& x, struct ggml_tensor *& xx, struct ggml_tensor *& state) { + size_t n_embed = x->ne[0]; + size_t sequence_len = x->ne[1]; + + if (sequence_len == 1) { + // self.layer_norm(x, self.w.blocks[i].ln2) + x = rwkv_layer_norm(ctx, x, weight, bias); + + // xx = state[5*i+0] + xx = state; + + // state[5*i+0] = x + state = x; + } else { + // self.layer_norm(x, self.w.blocks[i].ln2) + x = rwkv_layer_norm(ctx, x, ggml_repeat(ctx, weight, x), ggml_repeat(ctx, bias, x)); + + // xx = torch.cat((state[5*i+0].to(dtype=self.FLOAT_MODE).unsqueeze(0), x[:-1,:])) + xx = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embed, sequence_len); + xx = ggml_set_1d_inplace(ctx, xx, state, 0); + xx = ggml_set_1d_inplace(ctx, xx, ggml_view_1d(ctx, x, n_embed * (sequence_len - 1), 0), n_embed * sizeof(float)); + + // state[5*i+0] = x[-1,:] + state = ggml_view_1d(ctx, x, n_embed, n_embed * (sequence_len - 1) * sizeof(float)); + } +} + +struct rwkv_ctx_size rwkv_att_rkv_size(const size_t n_embed = 0, const size_t sequence_len = 1) { + size_t ptr_nelem = sizeof(void *) / sizeof(uint32_t); - /* xk */ rwkv_ctx_size_add_tensor(ctx_size, 3, 1, GGML_TYPE_F32, n_embed); + struct rwkv_ctx_size ctx_size; + /* xk */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed, sequence_len); + /* xk */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); /* xk */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - /* xv */ rwkv_ctx_size_add_tensor(ctx_size, 3, 1, GGML_TYPE_F32, n_embed); + + /* xv */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed, sequence_len); + /* xv */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); /* xv */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - /* xr */ rwkv_ctx_size_add_tensor(ctx_size, 3, 1, GGML_TYPE_F32, n_embed); + + /* xr */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed, sequence_len); + /* xr */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); /* xr */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - /* r */ rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed); + /* r */ rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed, sequence_len); /* r */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - /* k */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); - /* v */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); + /* k */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed, sequence_len); + /* v */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed, sequence_len); + + return ctx_size; +} +void rwkv_att_rkv(struct ggml_context * ctx, struct rwkv_layer layer, struct ggml_tensor * x0, struct ggml_tensor * xx, struct ggml_tensor *& r, struct ggml_tensor *& k, struct ggml_tensor *& v) { + // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k) + struct ggml_tensor * xk = ggml_add_inplace(ctx, + ggml_mul(ctx, x0, layer.att_time_mix_k), + ggml_mul(ctx, xx, rwkv_1_minus_x(ctx, layer.att_time_mix_k)) + ); + + // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v) + struct ggml_tensor * xv = ggml_add_inplace(ctx, + ggml_mul(ctx, x0, layer.att_time_mix_v), + ggml_mul(ctx, xx, rwkv_1_minus_x(ctx, layer.att_time_mix_v)) + ); + + // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r) + struct ggml_tensor * xr = ggml_add_inplace(ctx, + ggml_mul(ctx, x0, layer.att_time_mix_r), + ggml_mul(ctx, xx, rwkv_1_minus_x(ctx, layer.att_time_mix_r)) + ); + + // r = torch.sigmoid(rw @ xr) + r = rwkv_sigmoid(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr)); + // k = kw @ xk + k = ggml_mul_mat(ctx, layer.att_key, xk); + // v = vw @ xv + v = ggml_mul_mat(ctx, layer.att_value, xv); +} + +struct rwkv_ctx_size rwkv_att_wkv_size(const size_t n_embed = 0) { + size_t ptr_nelem = sizeof(void *) / sizeof(uint32_t); + + struct rwkv_ctx_size ctx_size; /* ww */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); /* qq */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); /* qq */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); @@ -653,129 +791,117 @@ struct rwkv_ctx_size rwkv_single_att_size(const size_t n_embed = 0) { /* e2 */ rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed); /* e2 */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - /* xx */ rwkv_ctx_size_add_tensor(ctx_size, 0, 0, GGML_TYPE_F32, n_embed); /* aa */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed); /* bb */ rwkv_ctx_size_add_tensor(ctx_size, 1, 1, GGML_TYPE_F32, n_embed); /* pp */ rwkv_ctx_size_add_tensor(ctx_size, 0, 0, GGML_TYPE_F32, n_embed); /* wkv */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); - /* x */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed); return ctx_size; } -struct ggml_tensor * rwkv_single_att(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer & layer, struct rwkv_layer_state & state) { - // self.layer_norm(x, self.w.blocks[i].ln1) - struct ggml_tensor * x0 = rwkv_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias); - - // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k) - struct ggml_tensor * xk = ggml_add_inplace(ctx, - ggml_mul(ctx, x0, layer.att_time_mix_k), - ggml_mul(ctx, state.att_xx, rwkv_1_minus_x(ctx, layer.att_time_mix_k)) - ); - - // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v) - struct ggml_tensor * xv = ggml_add_inplace(ctx, - ggml_mul(ctx, x0, layer.att_time_mix_v), - ggml_mul(ctx, state.att_xx, rwkv_1_minus_x(ctx, layer.att_time_mix_v)) - ); - - // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r) - struct ggml_tensor * xr = ggml_add_inplace(ctx, - ggml_mul(ctx, x0, layer.att_time_mix_r), - ggml_mul(ctx, state.att_xx, rwkv_1_minus_x(ctx, layer.att_time_mix_r)) - ); - - // r = torch.sigmoid(rw @ xr) - struct ggml_tensor * r = rwkv_sigmoid(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr)); - // k = kw @ xk - struct ggml_tensor * k = ggml_mul_mat(ctx, layer.att_key, xk); - // v = vw @ xv - struct ggml_tensor * v = ggml_mul_mat(ctx, layer.att_value, xv); - +struct ggml_tensor * rwkv_att_wkv(struct ggml_context * ctx, struct ggml_tensor * att_time_first, struct ggml_tensor * att_time_decay, struct ggml_tensor * k, struct ggml_tensor * v, struct ggml_tensor *& aa, struct ggml_tensor *& bb, struct ggml_tensor *& pp) { // ww = time_first + k - struct ggml_tensor * ww = ggml_add(ctx, layer.att_time_first, k); + struct ggml_tensor * ww = ggml_add(ctx, att_time_first, k); // qq = torch.maximum(pp, ww) - struct ggml_tensor * qq = rwkv_max(ctx, state.att_pp, ww); + struct ggml_tensor * qq = rwkv_max(ctx, pp, ww); // e1 = torch.exp(pp - qq) - struct ggml_tensor * e1 = rwkv_exp(ctx, ggml_sub(ctx, state.att_pp, qq)); + struct ggml_tensor * e1 = rwkv_exp(ctx, ggml_sub(ctx, pp, qq)); // e2 = torch.exp(ww - qq) struct ggml_tensor * e2 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq)); // a = e1 * aa + e2 * v - struct ggml_tensor * a = ggml_add_inplace(ctx, ggml_mul(ctx, e1, state.att_aa), ggml_mul(ctx, e2, v)); + struct ggml_tensor * a = ggml_add_inplace(ctx, ggml_mul(ctx, e1, aa), ggml_mul(ctx, e2, v)); // b = e1 * bb + e2 - struct ggml_tensor * b = ggml_add_inplace(ctx, ggml_mul(ctx, e1, state.att_bb), e2); + struct ggml_tensor * b = ggml_add_inplace(ctx, ggml_mul(ctx, e1, bb), e2); // ww = pp + time_decay - ww = ggml_add(ctx, state.att_pp, layer.att_time_decay); + ww = ggml_add(ctx, pp, att_time_decay); // qq = torch.maximum(ww, k) qq = rwkv_max(ctx, ww, k); // e1 = torch.exp(ww - qq) e1 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq)); - // e2 = torch.exp(k - qq) + // e2 = torch.exp(k[t] - qq) e2 = rwkv_exp(ctx, ggml_sub(ctx, k, qq)); - // state[5 * i + 1] = x0 // state[5 * i + 2] = e1 * aa + e2 * v // state[5 * i + 3] = e1 * bb + e2 // state[5 * i + 4] = qq - state.att_xx = x0; - state.att_aa = ggml_add_inplace(ctx, ggml_mul(ctx, e1, state.att_aa), ggml_mul(ctx, e2, v)); - state.att_bb = ggml_add_inplace(ctx, ggml_mul(ctx, e1, state.att_bb), e2); - state.att_pp = qq; + aa = ggml_add_inplace(ctx, ggml_mul(ctx, e1, aa), ggml_mul(ctx, e2, v)); + bb = ggml_add_inplace(ctx, ggml_mul(ctx, e1, bb), e2); + pp = qq; // wkv = a / b - struct ggml_tensor * wkv = ggml_div(ctx, a, b); - - // ow @ (r * wkv) - return ggml_add_inplace(ctx, x, ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, wkv))); + return ggml_div(ctx, a, b); } -struct rwkv_ctx_size rwkv_single_ffn_size(const size_t n_embed = 0, const size_t ffn_key = 0) { +struct rwkv_ctx_size rwkv_att_size(const size_t n_embed = 0) { size_t ptr_nelem = sizeof(void *) / sizeof(uint32_t); struct rwkv_ctx_size ctx_size; + /* xx */ rwkv_ctx_size_add(ctx_size, 1, rwkv_xx_size(n_embed)); + /* rkv */ rwkv_ctx_size_add(ctx_size, 1, rwkv_att_rkv_size(n_embed)); + /* wkv */ rwkv_ctx_size_add(ctx_size, 1, rwkv_att_wkv_size(n_embed)); + /* x */ rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed); + + return ctx_size; +} + +struct ggml_tensor * rwkv_att(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) { + struct ggml_tensor * x0 = x, * xx; + rwkv_xx(ctx, layer.ln1_weight, layer.ln1_bias, x0, xx, state.att_xx); + + struct ggml_tensor * r, * k, * v; + rwkv_att_rkv(ctx, layer, x0, xx, r, k, v); + + struct ggml_tensor * wkv = rwkv_att_wkv(ctx, layer.att_time_first, layer.att_time_decay, k, v, state.att_aa, state.att_bb, state.att_pp); + + // ow @ (r * xx) + return ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, wkv)); +} + +struct rwkv_ctx_size rwkv_ffn_size(const size_t n_embed = 0, const size_t ffn_key = 0, const size_t sequence_len = 1) { + size_t ptr_nelem = sizeof(void *) / sizeof(uint32_t); - /* x0 */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed); + struct rwkv_ctx_size ctx_size; + /* xx */ rwkv_ctx_size_add(ctx_size, 1, rwkv_xx_size(n_embed, sequence_len)); - /* xk */ rwkv_ctx_size_add_tensor(ctx_size, 3, 1, GGML_TYPE_F32, n_embed); + /* xk */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed, sequence_len); + /* xk */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); /* xk */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - /* xr */ rwkv_ctx_size_add_tensor(ctx_size, 3, 1, GGML_TYPE_F32, n_embed); - /* xr */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - /* xx */ rwkv_ctx_size_add_tensor(ctx_size, 0, 0, GGML_TYPE_F32, n_embed); + /* xr */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed, sequence_len); + /* xr */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); + /* xr */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - /* r */ rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed); + /* r */ rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed, sequence_len); /* r */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - /* k */ rwkv_ctx_size_add_tensor(ctx_size, 3, 0, GGML_TYPE_F32, ffn_key); + /* k */ rwkv_ctx_size_add_tensor(ctx_size, 3, 0, GGML_TYPE_F32, ffn_key, sequence_len); - /* x */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed); + /* x */ rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed, sequence_len); return ctx_size; } -struct ggml_tensor * rwkv_single_ffn(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer & layer, struct rwkv_layer_state & state) { - // self.layer_norm(x, self.w.blocks[i].ln2) - struct ggml_tensor * x0 = rwkv_layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias); +struct ggml_tensor * rwkv_ffn(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) { + struct ggml_tensor * x0 = x, * xx; + rwkv_xx(ctx, layer.ln2_weight, layer.ln2_bias, x0, xx, state.ffn_xx); + // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k) // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k) struct ggml_tensor * xk = ggml_add_inplace( ctx, ggml_mul(ctx, x0, layer.ffn_time_mix_k), - ggml_mul(ctx, state.ffn_xx, rwkv_1_minus_x(ctx, layer.ffn_time_mix_k)) + ggml_mul(ctx, xx, rwkv_1_minus_x(ctx, layer.ffn_time_mix_k)) ); // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r) struct ggml_tensor * xr = ggml_add_inplace( ctx, ggml_mul(ctx, x0, layer.ffn_time_mix_r), - ggml_mul(ctx, state.ffn_xx, rwkv_1_minus_x(ctx, layer.ffn_time_mix_r)) + ggml_mul(ctx, xx, rwkv_1_minus_x(ctx, layer.ffn_time_mix_r)) ); - // state[5 * i + 0] = x - state.ffn_xx = x0; - // r = torch.sigmoid(rw @ xr) struct ggml_tensor * r = rwkv_sigmoid(ctx, ggml_mul_mat(ctx, layer.ffn_receptance, xr)); @@ -783,106 +909,169 @@ struct ggml_tensor * rwkv_single_ffn(struct ggml_context * ctx, struct ggml_tens struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu(ctx, ggml_mul_mat(ctx, layer.ffn_key, xk))); // r * (vw @ k) - return ggml_add_inplace(ctx, x, ggml_mul(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k))); + return ggml_mul(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k)); } -struct rwkv_ctx_size rwkv_single_graph_size(const size_t n_vocab = 0, const size_t n_embed = 0, const size_t n_layer = 0, const size_t ffn_key = 0) { - size_t ptr_nelem = sizeof(void *) / sizeof(uint32_t); - +struct rwkv_ctx_size rwkv_serial_graph_size(const size_t n_vocab, const size_t n_embed, const size_t n_layer, const size_t ffn_key_size) { struct rwkv_ctx_size ctx_size; - - /* state */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_layer * 5 * n_embed); - /* token */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, 1); /* x */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); /* x */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed); - /* ffn_xx */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer, GGML_TYPE_F32, n_embed); - /* att_xx */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer, GGML_TYPE_F32, n_embed); - /* att_aa */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer, GGML_TYPE_F32, n_embed); - /* att_bb */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer, GGML_TYPE_F32, n_embed); - /* att_pp */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer, GGML_TYPE_F32, n_embed); + /* att */ rwkv_ctx_size_add(ctx_size, n_layer, rwkv_att_size(n_embed)); + /* x */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer, GGML_TYPE_F32, n_embed); + /* ffn */ rwkv_ctx_size_add(ctx_size, n_layer, rwkv_ffn_size(n_embed, ffn_key_size)); + /* x */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer, GGML_TYPE_F32, n_embed); - /* att */ rwkv_ctx_size_add(ctx_size, n_layer, rwkv_single_att_size(n_embed)); - /* ffn */ rwkv_ctx_size_add(ctx_size, n_layer, rwkv_single_ffn_size(n_embed, ffn_key)); + /* output */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer * 5, GGML_TYPE_F32, n_embed); /* x */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed); - /* logits */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_vocab); + /* logits */ rwkv_ctx_size_add_tensor(ctx_size, 1, 1, GGML_TYPE_F32, n_vocab); return ctx_size; } -bool rwkv_single_graph(struct ggml_context * ctx, struct rwkv_model & model, const uint32_t n_threads, struct rwkv_graph & out) { - std::unique_ptr cgraph(new(std::nothrow) struct ggml_cgraph()); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, cgraph.get(), "Failed to allocate graph"); - cgraph->n_threads = n_threads; - +bool rwkv_build_serial_graph( + struct ggml_context * ctx, + struct rwkv_model & model, + struct ggml_tensor * tokens, + struct rwkv_layer_state * inputs, + struct rwkv_layer_state * outputs, + struct ggml_tensor * logits, + struct ggml_cgraph * cgraph +) { size_t n_embed = model.header.n_embed; - size_t n_layer = model.header.n_layer; - - struct ggml_tensor * input_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_layer * 5 * n_embed); - size_t output_part_size = n_embed * sizeof(float); - - // We collect parts of input state here. Each part is (n_embed) vector. - std::unique_ptr input_layers(new(std::nothrow) struct rwkv_layer_state [n_layer]); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, input_layers.get(), "Failed to allocate input state parts"); - - // We collect parts of output state here. Each part is (n_embed) vector. - std::unique_ptr output_layers(new(std::nothrow) struct rwkv_layer_state [n_layer]); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, output_layers.get(), "Failed to allocate output state parts"); // x = self.w.emb.weight[token] - struct ggml_tensor * token_index = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); - struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, token_index); + struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, tokens); // x = self.layer_norm(x, self.w.blocks[0].ln0) x = rwkv_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias); - for (size_t i = 0; i < n_layer; i++) { + for (size_t i = 0; i < model.header.n_layer; i++) { struct rwkv_layer & layer = model.layers[i]; - struct rwkv_layer_state & input_layer = input_layers[i]; - struct rwkv_layer_state & output_layer = output_layers[i]; - - size_t state_index = i * 5; - input_layer.ffn_xx = ggml_view_1d(ctx, input_state, n_embed, output_part_size * (state_index + 0)); - input_layer.att_xx = ggml_view_1d(ctx, input_state, n_embed, output_part_size * (state_index + 1)); - input_layer.att_aa = ggml_view_1d(ctx, input_state, n_embed, output_part_size * (state_index + 2)); - input_layer.att_bb = ggml_view_1d(ctx, input_state, n_embed, output_part_size * (state_index + 3)); - input_layer.att_pp = ggml_view_1d(ctx, input_state, n_embed, output_part_size * (state_index + 4)); - output_layer = input_layer; - - x = rwkv_single_att(ctx, x, layer, output_layer); - x = rwkv_single_ffn(ctx, x, layer, output_layer); + + struct rwkv_layer_state state = inputs[i]; + x = ggml_add_inplace(ctx, x, rwkv_att(ctx, x, layer, state)); + x = ggml_add_inplace(ctx, x, rwkv_ffn(ctx, x, layer, state)); + + struct rwkv_layer_state & output = outputs[i]; + ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.ffn_xx, output.ffn_xx)); + ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.att_xx, output.att_xx)); + ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.att_aa, output.att_aa)); + ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.att_bb, output.att_bb)); + ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.att_pp, output.att_pp)); } - // x = self.layer_norm(x, self.w.ln_out) + // x = self.layer_norm(x[-1,:], self.w.ln_out) x = rwkv_layer_norm(ctx, x, model.ln_out_weight, model.ln_out_bias); // x = (self.w.head.weight @ x).float() - struct ggml_tensor * logits = ggml_mul_mat(ctx, model.head, x); + ggml_build_forward_expand(cgraph, ggml_cpy(ctx, ggml_mul_mat(ctx, model.head, x), logits)); - ggml_build_forward_expand(cgraph.get(), logits); + return true; +} - for (uint32_t i = 0; i < n_layer; i++) { - struct rwkv_layer_state & output_layer = output_layers[i]; - ggml_build_forward_expand(cgraph.get(), output_layer.ffn_xx); - ggml_build_forward_expand(cgraph.get(), output_layer.att_xx); - ggml_build_forward_expand(cgraph.get(), output_layer.att_aa); - ggml_build_forward_expand(cgraph.get(), output_layer.att_bb); - ggml_build_forward_expand(cgraph.get(), output_layer.att_pp); +struct rwkv_ctx_size rwkv_sequence_graph_size(const size_t n_vocab, const size_t n_embed, const size_t n_layer, const size_t ffn_key_size, const size_t sequence_len) { + struct rwkv_ctx_size ctx_size; + /* x */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed, sequence_len); + /* x */ rwkv_ctx_size_add_tensor(ctx_size, 4, 1, GGML_TYPE_F32, n_embed, sequence_len); + + /* xx */ rwkv_ctx_size_add(ctx_size, n_layer, rwkv_xx_size(n_embed, sequence_len)); + /* rkv */ rwkv_ctx_size_add(ctx_size, n_layer, rwkv_att_rkv_size(n_embed, sequence_len)); + + /* kt */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer * sequence_len, GGML_TYPE_F32, n_embed); + /* vt */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer * sequence_len, GGML_TYPE_F32, n_embed); + /* xt */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer * sequence_len, GGML_TYPE_F32, n_embed); + /* wkv */ rwkv_ctx_size_add(ctx_size, n_layer * sequence_len, rwkv_att_wkv_size(n_embed)); + /* xt */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer * sequence_len, GGML_TYPE_F32, n_embed); + /* x */ rwkv_ctx_size_add_tensor(ctx_size, n_layer * 2, 0, GGML_TYPE_F32, n_embed, sequence_len); + + /* x */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer, GGML_TYPE_F32, n_embed, sequence_len); + /* ffn */ rwkv_ctx_size_add(ctx_size, n_layer, rwkv_ffn_size(n_embed, ffn_key_size, sequence_len)); + /* x */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer, GGML_TYPE_F32, n_embed, sequence_len); + + /* output */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer * 5, GGML_TYPE_F32, n_embed); + + /* x */ rwkv_ctx_size_add_tensor(ctx_size, 2, 2, GGML_TYPE_F32, n_embed); + /* logits */ rwkv_ctx_size_add_tensor(ctx_size, 1, 1, GGML_TYPE_F32, n_vocab); + + return ctx_size; +} + +bool rwkv_build_sequence_graph( + struct ggml_context * ctx, + struct rwkv_model & model, + struct ggml_tensor * tokens, + struct rwkv_layer_state * inputs, + struct rwkv_layer_state * outputs, + struct ggml_tensor * logits, + struct ggml_cgraph * cgraph +) { + const uint32_t n_embed = model.header.n_embed; + const size_t sequence_len = tokens->ne[0]; + + struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, tokens); + x = rwkv_layer_norm(ctx, x, ggml_repeat(ctx, model.ln0_weight, x), ggml_repeat(ctx, model.ln0_bias, x)); + + for (size_t i = 0; i < model.header.n_layer; i++) { + struct rwkv_layer & layer = model.layers[i]; + struct rwkv_layer_state state = inputs[i]; + + struct ggml_tensor * x0 = x, * xx; + rwkv_xx(ctx, layer.ln1_weight, layer.ln1_bias, x0, xx, state.att_xx); + + struct ggml_tensor * r, * k, * v; + rwkv_att_rkv(ctx, layer, x0, xx, r, k, v); + + ggml_build_forward_expand(cgraph, r); + + for (uint32_t t = 0; t < sequence_len; t++) { + struct ggml_tensor * kt = ggml_view_1d(ctx, k, n_embed, n_embed * sizeof(float) * t); + struct ggml_tensor * vt = ggml_view_1d(ctx, v, n_embed, n_embed * sizeof(float) * t); + struct ggml_tensor * xt = ggml_view_1d(ctx, xx, n_embed, n_embed * sizeof(float) * t); + struct ggml_tensor * wkv = rwkv_att_wkv(ctx, layer.att_time_first, layer.att_time_decay, kt, vt, state.att_aa, state.att_bb, state.att_pp); + ggml_build_forward_expand(cgraph, ggml_cpy(ctx, wkv, xt)); + } + + x = ggml_add_inplace(ctx, x, ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, xx))); + x = ggml_add_inplace(ctx, x, rwkv_ffn(ctx, x, layer, state)); + + struct rwkv_layer_state & output = outputs[i]; + ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.ffn_xx, output.ffn_xx)); + ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.att_xx, output.att_xx)); + ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.att_aa, output.att_aa)); + ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.att_bb, output.att_bb)); + ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.att_pp, output.att_pp)); } - out.input_state = input_state; - out.input_layers = std::move(input_layers); - out.output_layers = std::move(output_layers); - out.token_index = token_index; - out.logits = logits; - out.cgraph = std::move(cgraph); + // x = self.layer_norm(x[-1,:], self.w.ln_out) + x = rwkv_layer_norm(ctx, ggml_view_1d(ctx, x, n_embed, n_embed * sizeof(float) * (sequence_len - 1)), model.ln_out_weight, model.ln_out_bias); + + // x = (self.w.head.weight @ x).float() + ggml_build_forward_expand(cgraph, ggml_cpy(ctx, ggml_mul_mat(ctx, model.head, x), logits)); + return true; } -struct rwkv_file_guard { +size_t rwkv_estimate_graph_work(const enum ggml_type type, const size_t ffn_key_size, const uint32_t n_threads, const size_t sequence_len = 1) { +#ifdef GGML_USE_CUBLAS + enum ggml_type mul_mat_type = GGML_TYPE_F16; +#else + enum ggml_type mul_mat_type = ggml_is_quantized(type) ? GGML_TYPE_Q8_1 : type; +#endif + return rwkv_tensor_size(GGML_TYPE_I8, rwkv_tensor_size(mul_mat_type, ffn_key_size, sequence_len) * n_threads + 64 * (n_threads - 1)); +} + +struct rwkv_file { FILE * file; - ~rwkv_file_guard() { if (file) { fclose(file); } } + + rwkv_file(FILE * file): file(file) {} + + ~rwkv_file() { + if (file) { + fclose(file); + } + } }; void rwkv_set_print_errors(struct rwkv_context * ctx, bool print_errors) { @@ -902,59 +1091,50 @@ enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx) { } bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & instance) { - FILE * file = fopen(file_path, "rb"); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, file, "Failed to open file %s", file_path); - rwkv_file_guard file_guard { file }; - - // Be very careful when changing this code. It must support files larger than 2 GB by using 64-bit functions to get the file length. struct stat file_stat; - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_STAT, fstat(fileno(file), &file_stat) == 0, "Failed to stat file %s", file_path); + struct rwkv_model model; + struct rwkv_ggml_context ctx; + size_t ffn_key_size = 0; - struct rwkv_file_header header; - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE, rwkv_fread_file_header(file, header), "Invalid file header"); + std::unordered_map parameters; - size_t tensors_start = ftell(file); - struct rwkv_ctx_size ctx_size; + { + rwkv_file file(fopen(file_path, "rb")); - std::string name; - instance.ffn_key_size = 0; + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, file.file, "Failed to open file %s", file_path); + // Be very careful when changing this code. It must support files larger than 2 GB by using 64-bit functions to get the file length. + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_STAT, fstat(fileno(file.file), &file_stat) == 0, "Failed to stat file %s", file_path); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE, rwkv_fread_file_header(file.file, model.header), "Invalid file header"); - while ((size_t) ftell(file) < (size_t) file_stat.st_size) { - struct rwkv_tensor_header header; - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_tensor_header(file, header), "Invalid tensor header"); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_string(file, header.key_length, name), "Failed to read tensor name"); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file, rwkv_tensor_size(header), SEEK_CUR) == 0, "Failed to read tensor data"); - rwkv_ctx_size_add_tensor(ctx_size, 1, 0, header); + struct rwkv_tensor_header tensor_header; + std::string name; + struct rwkv_ctx_size ctx_size; - if (instance.ffn_key_size == 0 && name == "blocks.0.ffn.key.weight") { - instance.ffn_key_size = header.height; - } - } + while ((size_t) ftell(file.file) < (size_t) file_stat.st_size) { + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_tensor_header(file.file, tensor_header), "Invalid tensor header"); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_string(file.file, tensor_header.key_length, name), "Failed to read tensor name"); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file.file, rwkv_tensor_size(tensor_header), SEEK_CUR) == 0, "Failed to read tensor data"); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, instance.ffn_key_size, "Model is missing parameter blocks.0.ffn.key.weight"); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file, tensors_start, SEEK_SET) == 0, "Failed to seek in file"); + rwkv_ctx_size_add_tensor(ctx_size, 1, 0, tensor_header); - std::unique_ptr scratch(new(std::nothrow) uint8_t [ctx_size.scratch_size]); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, scratch.get(), "Failed to allocate model scratch space"); + if (ffn_key_size == 0 && name == "blocks.0.ffn.key.weight") { + ffn_key_size = tensor_header.height; + } + } - struct ggml_context * ctx = ggml_init({ ctx_size.objects_size + ctx_size.objects_count * GGML_OBJECT_SIZE, NULL, false}); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, ctx, "Failed to create GGML context"); - rwkv_ggml_guard ggml_guard { ctx }; + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, ffn_key_size, "Model is missing parameter blocks.0.ffn.key.weight"); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file.file, sizeof(struct rwkv_file_header), SEEK_SET) == 0, "Failed to seek in file"); - std::unordered_map parameters; - ggml_set_scratch(ctx, { 0, ctx_size.scratch_size, scratch.get() }); + ctx = ctx_size; + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, ctx.ctx, "Failed to allocate model context"); - while ((size_t) ftell(file) < (size_t) file_stat.st_size) { - std::string name; struct ggml_tensor * tensor; - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_ggml_tensor(file, ctx, name, tensor), "Failed to read model params"); - parameters[std::move(name)] = tensor; - } - file = NULL; - file_guard = { NULL }; - - struct rwkv_model model { header }; + while ((size_t) ftell(file.file) < (size_t) file_stat.st_size) { + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_ggml_tensor(file.file, ctx.ctx, name, tensor), "Failed to read model params"); + parameters[std::move(name)] = tensor; + } + } std::unordered_map & parameters_ref = parameters; RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, rwkv_set_params(model, [&](const char * key, struct ggml_tensor *& dest) { @@ -967,15 +1147,12 @@ bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & inst // Verify order of dimensions struct ggml_tensor * emb = model.emb; RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_SHAPE, emb->n_dims == 2, "Unexpected dimension count of embedding matrix %d", emb->n_dims); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[0] == header.n_embed, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[0]); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[1] == header.n_vocab, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[1]); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[0] == model.header.n_embed, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[0]); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[1] == model.header.n_vocab, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[1]); - // Don't free ggml context now - ggml_guard.ctx = NULL; - // Attach ggml context to instance - instance.ctx.ctx = ctx; + instance.ctx = std::move(ctx); instance.model = std::move(model); - instance.scratch = std::move(scratch); + instance.ffn_key_size = ffn_key_size; return true; } @@ -983,51 +1160,85 @@ struct rwkv_context * rwkv_new_context_impl(std::shared_ptrmodel.header; + const size_t n_vocab = header.n_vocab; + const size_t n_embed = header.n_embed; + const size_t n_layer = header.n_layer; + + struct rwkv_ctx_size ctx_size; + /* input */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed * 5 * n_layer); + /* output */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed * 5 * n_layer); + /* inputs */ rwkv_ctx_size_add_tensor(ctx_size, 0, 5 * n_layer, GGML_TYPE_F32, n_embed); + /* outputs */ rwkv_ctx_size_add_tensor(ctx_size, 0, 5 * n_layer, GGML_TYPE_F32, n_embed); + /* logits */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_vocab); - rwkv_ctx_size ctx_size; - rwkv_ctx_size_add(ctx_size, 1, rwkv_single_graph_size(header.n_vocab, header.n_embed, header.n_layer, instance->ffn_key_size)); - // And finally to end it all off: the graph work tensor - enum ggml_type mul_mat_type = ggml_is_quantized(rwkv_type_to_ggml[header.data_type]) ? GGML_TYPE_Q8_1 : rwkv_type_to_ggml[header.data_type]; - rwkv_ctx_size_add(ctx_size, 1, rwkv_tensor_size(GGML_TYPE_I8, rwkv_tensor_size(mul_mat_type, instance->ffn_key_size) * n_threads + 64 * (n_threads - 1))); + struct rwkv_ggml_context ctx(ctx_size); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, ctx.ctx, "Failed to allocate model context"); - std::unique_ptr scratch(new(std::nothrow) uint8_t [ctx_size.scratch_size]); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, scratch.get(), "Failed to allocate graph scratch space (%d)", ctx_size.scratch_size); + struct ggml_tensor * input = ggml_new_tensor_1d(ctx.ctx, GGML_TYPE_F32, n_embed * 5 * n_layer); + struct ggml_tensor * output = ggml_new_tensor_1d(ctx.ctx, GGML_TYPE_F32, n_embed * 5 * n_layer); - struct ggml_context * ctx = ggml_init({ ctx_size.objects_size + ctx_size.objects_count * GGML_OBJECT_SIZE, NULL, false}); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, ctx, "Failed to create GGML context"); - rwkv_ggml_guard ggml_guard { ctx }; + // We collect parts of input state here. Each part is (n_embed) vector. + std::unique_ptr inputs(new(std::nothrow) struct rwkv_layer_state [n_layer]); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_ALLOC, inputs.get(), "Failed to allocate input state parts"); - ggml_set_scratch(ctx, { 0, ctx_size.scratch_size, scratch.get() }); + // We collect parts of output state here. Each part is (n_embed) vector. + std::unique_ptr outputs(new(std::nothrow) struct rwkv_layer_state [n_layer]); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_ALLOC, outputs.get(), "Failed to allocate output state parts"); - // Build graph - struct rwkv_graph graph; - RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_single_graph(ctx, instance->model, n_threads, graph)); + for (size_t i = 0; i < n_layer; i++) { + struct rwkv_layer_state & input_state = inputs[i]; + input_state.ffn_xx = ggml_view_1d(ctx.ctx, input, n_embed, n_embed * (i * 5 + 0) * sizeof(float)); + input_state.att_xx = ggml_view_1d(ctx.ctx, input, n_embed, n_embed * (i * 5 + 1) * sizeof(float)); + input_state.att_aa = ggml_view_1d(ctx.ctx, input, n_embed, n_embed * (i * 5 + 2) * sizeof(float)); + input_state.att_bb = ggml_view_1d(ctx.ctx, input, n_embed, n_embed * (i * 5 + 3) * sizeof(float)); + input_state.att_pp = ggml_view_1d(ctx.ctx, input, n_embed, n_embed * (i * 5 + 4) * sizeof(float)); + + struct rwkv_layer_state & output_state = outputs[i]; + output_state.ffn_xx = ggml_view_1d(ctx.ctx, output, n_embed, n_embed * (i * 5 + 0) * sizeof(float)); + output_state.att_xx = ggml_view_1d(ctx.ctx, output, n_embed, n_embed * (i * 5 + 1) * sizeof(float)); + output_state.att_aa = ggml_view_1d(ctx.ctx, output, n_embed, n_embed * (i * 5 + 2) * sizeof(float)); + output_state.att_bb = ggml_view_1d(ctx.ctx, output, n_embed, n_embed * (i * 5 + 3) * sizeof(float)); + output_state.att_pp = ggml_view_1d(ctx.ctx, output, n_embed, n_embed * (i * 5 + 4) * sizeof(float)); + } - std::unique_ptr rwkv_ctx(new(std::nothrow) struct rwkv_context()); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, rwkv_ctx.get(), "Failed to allocate context"); + struct ggml_tensor * logits = ggml_new_tensor_1d(ctx.ctx, GGML_TYPE_F32, n_vocab); - // Don't free ggml context - ggml_guard.ctx = NULL; + struct rwkv_ctx_size graph_ctx_size; + /* token */ rwkv_ctx_size_add_objects(graph_ctx_size, 1, sizeof(struct ggml_tensor) + sizeof(uint32_t)); + /* graph */ rwkv_ctx_size_add(graph_ctx_size, 1, rwkv_serial_graph_size(n_vocab, n_embed, n_layer, instance->ffn_key_size)); + /* work */ rwkv_ctx_size_add(graph_ctx_size, 1, rwkv_estimate_graph_work(rwkv_type_to_ggml[header.data_type], instance->ffn_key_size, n_threads)); + struct rwkv_graph serial_graph; + serial_graph.ctx = graph_ctx_size; + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, serial_graph.ctx.ctx, "Failed to allocate serial graph context"); + serial_graph.tokens = ggml_new_i32(serial_graph.ctx.ctx, 0); + serial_graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph()); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_ALLOC, serial_graph.cgraph, "Failed to allocate serial graph"); + serial_graph.cgraph->n_threads = n_threads; + RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_build_serial_graph(serial_graph.ctx.ctx, instance->model, serial_graph.tokens, inputs.get(), outputs.get(), logits, serial_graph.cgraph.get())); + + std::unique_ptr rwkv_ctx(new(std::nothrow) struct rwkv_context()); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, rwkv_ctx, "Failed to allocate rwkv_context"); rwkv_ctx->instance = std::move(instance); - rwkv_ctx->ctx = ctx; - rwkv_ctx->scratch = std::move(scratch); - rwkv_ctx->graph = std::move(graph); + rwkv_ctx->ctx = std::move(ctx); + rwkv_ctx->input_state = input; + rwkv_ctx->input_layers = std::move(inputs); + rwkv_ctx->output_state = output; + rwkv_ctx->output_layers = std::move(outputs); + rwkv_ctx->logits = logits; + rwkv_ctx->n_threads = n_threads; + rwkv_ctx->serial_graph = std::move(serial_graph); rwkv_ctx->last_error = RWKV_ERROR_NONE; rwkv_ctx->print_errors = global_print_errors; - rwkv_ctx->gpu_layers = 0; - rwkv_ctx->vram_total = 0; - return rwkv_ctx.release(); } struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) { global_last_error = RWKV_ERROR_NONE; - std::shared_ptr instance(new(std::nothrow) struct rwkv_instance); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, instance.get(), "Failed to allocate instance"); + std::shared_ptr instance(new(std::nothrow) struct rwkv_instance()); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, instance, "Failed to allocate instance"); RWKV_ENSURE_OR_NULL(rwkv_instance_from_file(file_path, *instance.get())); - return rwkv_new_context_impl(instance, n_threads); } @@ -1070,46 +1281,84 @@ bool rwkv_gpu_offload_layers(const struct rwkv_context * ctx, const uint32_t n_g return true; } +void rwkv_set_inputs(const struct rwkv_context * ctx, const float * state_in) { + if (state_in) { + memcpy(ctx->input_state->data, state_in, ggml_nbytes(ctx->input_state)); + } else { + ggml_set_f32(ctx->input_state, 0.0F); + + for (size_t i = 0; i < ctx->instance->model.header.n_layer; i++) { + ggml_set_f32(ctx->input_layers[i].att_pp, -1e30F); + } + } +} + +void rwkv_get_outputs(const struct rwkv_context * ctx, float * state_out, float * logits_out) { + if (state_out) { + memcpy(state_out, ctx->output_state->data, ggml_nbytes(ctx->output_state)); + } + + if (logits_out) { + memcpy(logits_out, ctx->logits->data, ggml_nbytes(ctx->logits)); + } +} + bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out) { ((struct rwkv_context *) ctx)->last_error = RWKV_ERROR_NONE; const struct rwkv_file_header & header = ctx->instance->model.header; - RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, token < header.n_vocab, "Token is out of range 0..%d", header.n_vocab - 1); - - const struct rwkv_graph & graph = ctx->graph; - ggml_set_i32_1d(graph.token_index, 0, token); - - if (state_in == NULL) { - for (size_t i = 0; i < header.n_layer; i++) { - struct rwkv_layer_state & layer = graph.input_layers[i]; - ggml_set_f32(layer.ffn_xx, 0.0F); - ggml_set_f32(layer.att_xx, 0.0F); - ggml_set_f32(layer.att_aa, 0.0F); - ggml_set_f32(layer.att_bb, 0.0F); - ggml_set_f32(layer.att_pp, -1e30F); - } - } else { - memcpy(graph.input_state->data, state_in, ggml_nbytes(graph.input_state)); - } + const size_t n_vocab = header.n_vocab; + RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, token < n_vocab, "Token (%" PRId32 ") is out of range (0 ..= %zu)", token, n_vocab - 1); - ggml_graph_compute(ctx->ctx, graph.cgraph.get()); + rwkv_set_inputs(ctx, state_in); + ggml_set_i32(ctx->serial_graph.tokens, token); + ggml_graph_compute(ctx->serial_graph.ctx.ctx, ctx->serial_graph.cgraph.get()); + rwkv_get_outputs(ctx, state_out, logits_out); - if (state_out) { - size_t part_size = rwkv_tensor_size(GGML_TYPE_F32, header.n_embed); - for (size_t i = 0; i < header.n_layer; i++) { - struct rwkv_layer_state & layer = graph.output_layers[i]; - - float * dest = state_out + i * header.n_embed * 5; - memcpy(dest + header.n_embed * 0, layer.ffn_xx->data, part_size); - memcpy(dest + header.n_embed * 1, layer.att_xx->data, part_size); - memcpy(dest + header.n_embed * 2, layer.att_aa->data, part_size); - memcpy(dest + header.n_embed * 3, layer.att_bb->data, part_size); - memcpy(dest + header.n_embed * 4, layer.att_pp->data, part_size); + return true; +} + +bool rwkv_eval_sequence(const struct rwkv_context * ctx, const uint32_t * sequence, const size_t sequence_len, const float * state_in, float * state_out, float * logits_out) { + ((struct rwkv_context *) ctx)->last_error = RWKV_ERROR_NONE; + + const struct rwkv_file_header & header = ctx->instance->model.header; + const size_t n_vocab = header.n_vocab; + const size_t n_embed = header.n_embed; + const size_t n_layer = header.n_layer; + + if (sequence) { + for (size_t i = 0; i < sequence_len; i++) { + const uint32_t token = sequence[i]; + RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, token < n_vocab, "Tokens[%zu] (%" PRId32 ") is out of range (0 ..= %zu)", i, token, n_vocab - 1); } } - if (logits_out) { - memcpy(logits_out, graph.logits->data, ggml_nbytes(graph.logits)); + if (ctx->sequence_len != sequence_len) { + // Build new sequence graph + struct rwkv_ctx_size ctx_size; + /* tokens */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, sequence_len); + /* graph */ rwkv_ctx_size_add(ctx_size, 1, rwkv_sequence_graph_size(n_vocab, n_embed, n_layer, ctx->instance->ffn_key_size, sequence_len)); + /* work */ rwkv_ctx_size_add(ctx_size, 1, rwkv_estimate_graph_work(rwkv_type_to_ggml[header.data_type], ctx->instance->ffn_key_size, 1, sequence_len)); + + struct rwkv_graph graph; + graph.ctx = ctx_size; + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, graph.ctx.ctx, "Failed to allocate sequence graph context"); + graph.tokens = ggml_new_tensor_1d(graph.ctx.ctx, GGML_TYPE_I32, sequence_len); + graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph()); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, graph.cgraph, "Failed to allocate sequence graph"); + graph.cgraph->n_threads = 1; + RWKV_ASSERT_FALSE(RWKV_ERROR_GRAPH, rwkv_build_sequence_graph(graph.ctx.ctx, ctx->instance->model, graph.tokens, ctx->input_layers.get(), ctx->output_layers.get(), ctx->logits, graph.cgraph.get())); + + ((struct rwkv_context *) ctx)->sequence_len = sequence_len; + ((struct rwkv_context *) ctx)->sequence_graph = std::move(graph); + } + + // Allow building the sequence graph without actually evaluating, by specifying sequence = NULL. + if (sequence) { + rwkv_set_inputs(ctx, state_in); + memcpy(ctx->sequence_graph.tokens->data, sequence, sequence_len * sizeof(uint32_t)); + ggml_graph_compute(ctx->sequence_graph.ctx.ctx, ctx->sequence_graph.cgraph.get()); + rwkv_get_outputs(ctx, state_out, logits_out); } return true; @@ -1125,7 +1374,6 @@ uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx) { void rwkv_free(struct rwkv_context * ctx) { std::unique_ptr rwkv_ctx(ctx); - ggml_free(ctx->ctx); } bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const char * type_name) { @@ -1137,19 +1385,18 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const RWKV_MSG("Loading model from '%s'\n", in_path); struct stat in_stat; - FILE * in_file = fopen(in_path, "rb"); - rwkv_file_guard in_guard { in_file }; - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, in_file, "Failed to open %s for reading", in_path); - FILE * out_file = fopen(out_path, "wb"); - rwkv_file_guard out_guard { out_file }; - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, out_file, "Failed to open %s for writing", out_path); + struct rwkv_file in_file(fopen(in_path, "rb")); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, in_file.file, "Failed to open %s for reading", in_path); // Be very careful when changing this code. It must support files larger than 2 GB by using 64-bit functions to the get file length. - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_STAT, fstat(fileno(in_file), &in_stat) == 0, "failed to stat file %s", in_path); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_STAT, fstat(fileno(in_file.file), &in_stat) == 0, "failed to stat file %s", in_path); + + struct rwkv_file out_file(fopen(out_path, "wb")); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, out_file.file, "Failed to open %s for writing", out_path); struct rwkv_file_header in_header; - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE, rwkv_fread_file_header(in_file, in_header), "Invalid file header"); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE, rwkv_fread_file_header(in_file.file, in_header), "Invalid file header"); enum ggml_type in_type = rwkv_type_to_ggml[in_header.data_type]; RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE, in_type == GGML_TYPE_F32 || in_type == GGML_TYPE_F16, "Unsupported input data type (%s); needs to be f32 or f16", rwkv_type_to_string[rwkv_type_from_ggml[in_type]]); @@ -1157,7 +1404,7 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const struct rwkv_file_header out_header = in_header; out_header.version = RWKV_FILE_VERSION; out_header.data_type = rwkv_type_from_ggml[out_type]; - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE, rwkv_fwrite_file_header(out_file, out_header), "Failed to write file header"); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE, rwkv_fwrite_file_header(out_file.file, out_header), "Failed to write file header"); // Process parameters size_t orig_total_size = 0; @@ -1171,9 +1418,9 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const size_t max_out_size = 0; size_t max_key_length = 0; - while (ftell(in_file) < in_stat.st_size) { + while (ftell(in_file.file) < in_stat.st_size) { struct rwkv_tensor_header header; - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, rwkv_fread_tensor_header_and_skip(in_file, header)); + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, rwkv_fread_tensor_header_and_skip(in_file.file, header)); size_t in_size = rwkv_tensor_size(header); @@ -1205,8 +1452,7 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const } } - rewind(in_file); - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(in_file, sizeof(struct rwkv_file_header), SEEK_CUR) == 0); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(in_file.file, sizeof(struct rwkv_file_header), SEEK_SET) == 0, "Failed to seek in file"); // This is a histogram of quantized values. If it shows single 1.0, then all 0.0, something went very wrong! int64_t hist_all[16] {}; @@ -1222,16 +1468,16 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const std::string & name = tensor.name; uint8_t *& data = tensor.data; - while (ftell(in_file) < in_stat.st_size) { - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_tensor_header(in_file, header), "Failed to read tensor header"); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_string(in_file, header.key_length, name), "Failed to read tensor name"); + while (ftell(in_file.file) < in_stat.st_size) { + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_tensor_header(in_file.file, header), "Failed to read tensor header"); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_string(in_file.file, header.key_length, name), "Failed to read tensor name"); const char * name_str = name.c_str(); RWKV_MSG("%*s - [%5" PRId32 ", %5" PRId32 "], type = %6s ", (int) max_key_length, name_str, header.width, header.height, rwkv_type_to_string[header.data_type]); data = header.data_type == TYPE_F16 ? out_buf : in_buf; size_t orig_size = rwkv_tensor_size(header), new_size = orig_size; - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_data(in_file, orig_size, data), "\nFailed to read tensor data of %s", name_str); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_data(in_file.file, orig_size, data), "\nFailed to read tensor data of %s", name_str); // Quantize only 2D tensors, except embedding and head matrices. // Embedding and head take not too much space, especially in bigger models; @@ -1262,7 +1508,7 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const RWKV_MSG("size = %8.3f MB\n", orig_size / 1024.0 / 1024.0); } - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_WRITE, rwkv_fwrite_tensor(out_file, tensor), "Failed to write tensor %s", name_str); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_WRITE, rwkv_fwrite_tensor(out_file.file, tensor), "Failed to write tensor %s", name_str); orig_total_size += orig_size; new_total_size += new_size; } diff --git a/rwkv.h b/rwkv.h index b62b40c..38bf906 100644 --- a/rwkv.h +++ b/rwkv.h @@ -106,10 +106,22 @@ extern "C" { // Returns false on any error. Error messages would be printed to stderr. // - token: next token index, in range 0 <= token < n_vocab. // - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is a first pass. - // - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to. - // - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to. + // - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to if non-NULL. + // - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to if non-NULL. RWKV_API bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out); + // Evaluates the model for a sequence of tokens. + // Uses a faster algorithm than rwkv_eval if you do not need the state and logits for every token. Best used with batch sizes of 64 or so. + // Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length. + // - tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed. (Useful for initialization.) + // Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread. + // Returns false on any error. Error messages would be printed to stderr. + // - sequence_len: number of tokens to read from the array. + // - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count, or NULL if this is a first pass. + // - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to if non-NULL. + // - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to if non-NULL. + RWKV_API bool rwkv_eval_sequence(const struct rwkv_context * ctx, const uint32_t * tokens, size_t sequence_len, const float * state_in, float * state_out, float * logits_out); + // Returns count of FP32 elements in state buffer. RWKV_API uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx); diff --git a/tests/test_tiny_rwkv.c b/tests/test_tiny_rwkv.c index 286e528..3b8e075 100644 --- a/tests/test_tiny_rwkv.c +++ b/tests/test_tiny_rwkv.c @@ -41,6 +41,7 @@ void test_model(const char * model_path, const float * expected_logits, const fl float * logits = malloc(sizeof(float) * n_vocab); char * prompt = "\"in"; + uint32_t prompt_seq[] = { '"', 'i', 'n' }; const size_t prompt_length = strlen(prompt); @@ -59,6 +60,19 @@ void test_model(const char * model_path, const float * expected_logits, const fl // When something breaks, difference would be way more than 10 ASSERT(fabsf(diff_sum) <= fabsf(max_diff) + 0.01F, "Too big difference %f, expected no more than %f", (double) diff_sum, (double) max_diff); + rwkv_eval_sequence(model, prompt_seq, prompt_length, NULL, state, logits); + + diff_sum = 0.0F; + + for (uint32_t i = 0; i < n_vocab; i++) { + diff_sum += logits[i] - expected_logits[i]; + } + + fprintf(stderr, "Sequence difference sum: %f\n", diff_sum); + + // When something breaks, difference would be way more than 10 + ASSERT(fabsf(diff_sum) <= fabsf(max_diff) + 0.01F, "Too big sequence difference %f, expected no more than %f", (double) diff_sum, (double) max_diff); + rwkv_free(model); free(state);