From 01f2907c7eb2dbadc0a36d7c5233d454bb6370a4 Mon Sep 17 00:00:00 2001 From: LoganDark Date: Wed, 21 Jun 2023 15:23:28 -0700 Subject: [PATCH] Rename xx to x_prev probably should slip this in now before we forget it's a thing. --- rwkv.cpp | 122 +++++++++++++++++++++++++++++-------------------------- 1 file changed, 64 insertions(+), 58 deletions(-) diff --git a/rwkv.cpp b/rwkv.cpp index 13559fa..c7a5950 100644 --- a/rwkv.cpp +++ b/rwkv.cpp @@ -760,52 +760,58 @@ bool rwkv_set_params(struct rwkv_model & model, F callback) { return true; } -void rwkv_future_xx(struct rwkv_future_ctx & ctx, +void rwkv_future_carry_x(struct rwkv_future_ctx & ctx, const struct rwkv_future_tensor weight, const struct rwkv_future_tensor bias, struct rwkv_future_tensor & x, - struct rwkv_future_tensor & xx, - struct rwkv_future_tensor & state + struct rwkv_future_tensor & x_prev, + struct rwkv_future_tensor & carry ) { if (x.height == 1) { x = x.layer_norm(ctx, weight, bias); - xx = state; - state = x; + x_prev = carry; + carry = x; } else { x = x.layer_norm(ctx, weight.repeat(ctx, x), bias.repeat(ctx, x)); - xx = x.dup(ctx) - .set_inplace(ctx, state) + x_prev = x.dup(ctx) + .set_inplace(ctx, carry) .set_inplace(ctx, x.subview(ctx, x.width, x.height - 1)); - state = x.subview(ctx, x.width); + carry = x.subview(ctx, x.width); } } -void rwkv_xx(struct ggml_context * ctx, struct ggml_tensor * weight, struct ggml_tensor * bias, struct ggml_tensor *& x, struct ggml_tensor *& xx, struct ggml_tensor *& state) { - size_t n_embed = x->ne[0]; - size_t sequence_len = x->ne[1]; +void rwkv_carry_x(struct ggml_context * ctx, + struct ggml_tensor * weight, + struct ggml_tensor * bias, + struct ggml_tensor *& x, + struct ggml_tensor *& x_prev, + struct ggml_tensor *& carry +) { + const size_t n_embed = x->ne[0]; + const size_t sequence_len = x->ne[1]; if (sequence_len == 1) { // self.layer_norm(x, self.w.blocks[i].ln2) x = rwkv_layer_norm(ctx, x, weight, bias); // xx = state[5*i+0] - xx = state; + x_prev = carry; // state[5*i+0] = x - state = x; + carry = x; } else { // self.layer_norm(x, self.w.blocks[i].ln2) x = rwkv_layer_norm(ctx, x, ggml_repeat(ctx, weight, x), ggml_repeat(ctx, bias, x)); // xx = torch.cat((state[5*i+0].to(dtype=self.FLOAT_MODE).unsqueeze(0), x[:-1,:])) - xx = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embed, sequence_len); - xx = ggml_set_1d_inplace(ctx, xx, state, 0); - xx = ggml_set_1d_inplace(ctx, xx, ggml_view_1d(ctx, x, n_embed * (sequence_len - 1), 0), n_embed * sizeof(float)); + x_prev = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embed, sequence_len); + x_prev = ggml_set_1d_inplace(ctx, x_prev, carry, 0); + x_prev = ggml_set_1d_inplace(ctx, x_prev, ggml_view_1d(ctx, x, n_embed * (sequence_len - 1), 0), n_embed * sizeof(float)); // state[5*i+0] = x[-1,:] - state = ggml_view_1d(ctx, x, n_embed, n_embed * (sequence_len - 1) * sizeof(float)); + carry = ggml_view_1d(ctx, x, n_embed, n_embed * (sequence_len - 1) * sizeof(float)); } } @@ -813,8 +819,8 @@ void rwkv_future_att_rkv(struct rwkv_future_ctx & ctx, const struct rwkv_future_tensor time_mix_k, const struct rwkv_future_tensor time_mix_v, const struct rwkv_future_tensor time_mix_r, - const struct rwkv_future_tensor x0, - const struct rwkv_future_tensor xx, + const struct rwkv_future_tensor x, + const struct rwkv_future_tensor x_prev, const struct rwkv_future_tensor att_r, const struct rwkv_future_tensor att_k, const struct rwkv_future_tensor att_v, @@ -822,9 +828,9 @@ void rwkv_future_att_rkv(struct rwkv_future_ctx & ctx, struct rwkv_future_tensor & k, struct rwkv_future_tensor & v ) { - const struct rwkv_future_tensor xk = x0.combine(ctx, time_mix_k).consume(ctx, xx.combine(ctx, time_mix_k.fn(ctx))); - const struct rwkv_future_tensor xv = x0.combine(ctx, time_mix_v).consume(ctx, xx.combine(ctx, time_mix_v.fn(ctx))); - const struct rwkv_future_tensor xr = x0.combine(ctx, time_mix_r).consume(ctx, xx.combine(ctx, time_mix_r.fn(ctx))); + const struct rwkv_future_tensor xk = x.combine(ctx, time_mix_k).consume(ctx, x_prev.combine(ctx, time_mix_k.fn(ctx))); + const struct rwkv_future_tensor xv = x.combine(ctx, time_mix_v).consume(ctx, x_prev.combine(ctx, time_mix_v.fn(ctx))); + const struct rwkv_future_tensor xr = x.combine(ctx, time_mix_r).consume(ctx, x_prev.combine(ctx, time_mix_r.fn(ctx))); r = att_r.mul_mat(ctx, xr).fn(ctx); k = att_k.mul_mat(ctx, xk); @@ -834,28 +840,28 @@ void rwkv_future_att_rkv(struct rwkv_future_ctx & ctx, void rwkv_att_rkv( struct ggml_context * ctx, struct rwkv_layer layer, - struct ggml_tensor * x0, - struct ggml_tensor * xx, + struct ggml_tensor * x, + struct ggml_tensor * x_prev, struct ggml_tensor *& r, struct ggml_tensor *& k, struct ggml_tensor *& v ) { // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k) struct ggml_tensor * xk = ggml_add_inplace(ctx, - ggml_mul(ctx, x0, layer.att_time_mix_k), - ggml_mul(ctx, xx, rwkv_1_minus_x(ctx, layer.att_time_mix_k)) + ggml_mul(ctx, x, layer.att_time_mix_k), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_k)) ); // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v) struct ggml_tensor * xv = ggml_add_inplace(ctx, - ggml_mul(ctx, x0, layer.att_time_mix_v), - ggml_mul(ctx, xx, rwkv_1_minus_x(ctx, layer.att_time_mix_v)) + ggml_mul(ctx, x, layer.att_time_mix_v), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_v)) ); // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r) struct ggml_tensor * xr = ggml_add_inplace(ctx, - ggml_mul(ctx, x0, layer.att_time_mix_r), - ggml_mul(ctx, xx, rwkv_1_minus_x(ctx, layer.att_time_mix_r)) + ggml_mul(ctx, x, layer.att_time_mix_r), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_r)) ); // r = torch.sigmoid(rw @ xr) @@ -954,17 +960,17 @@ struct rwkv_future_tensor rwkv_future_att(struct rwkv_future_ctx & ctx, const struct rwkv_future_tensor att_k, const struct rwkv_future_tensor att_v, const struct rwkv_future_tensor att_output, - const struct rwkv_future_tensor x, + struct rwkv_future_tensor x, struct rwkv_future_tensor & att_xx, struct rwkv_future_tensor & att_aa, struct rwkv_future_tensor & att_bb, struct rwkv_future_tensor & att_pp ) { - struct rwkv_future_tensor x0 = x, xx; - rwkv_future_xx(ctx, ln1_weight, ln1_bias, x0, xx, att_xx); + struct rwkv_future_tensor x_prev; + rwkv_future_carry_x(ctx, ln1_weight, ln1_bias, x, x_prev, att_xx); struct rwkv_future_tensor r, k, v; - rwkv_future_att_rkv(ctx, time_mix_k, time_mix_v, time_mix_r, x0, xx, att_r, att_k, att_v, r, k, v); + rwkv_future_att_rkv(ctx, time_mix_k, time_mix_v, time_mix_r, x, x_prev, att_r, att_k, att_v, r, k, v); struct rwkv_future_tensor wkv = rwkv_future_att_wkv(ctx, time_first, time_decay, att_aa, att_bb, att_pp, k, v); @@ -972,11 +978,11 @@ struct rwkv_future_tensor rwkv_future_att(struct rwkv_future_ctx & ctx, } struct ggml_tensor * rwkv_att(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) { - struct ggml_tensor * x0 = x, * xx; - rwkv_xx(ctx, layer.ln1_weight, layer.ln1_bias, x0, xx, state.att_xx); + struct ggml_tensor * x_prev; + rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x, x_prev, state.att_xx); struct ggml_tensor * r, * k, * v; - rwkv_att_rkv(ctx, layer, x0, xx, r, k, v); + rwkv_att_rkv(ctx, layer, x, x_prev, r, k, v); struct ggml_tensor * wkv = rwkv_att_wkv(ctx, layer.att_time_first, layer.att_time_decay, k, v, state.att_aa, state.att_bb, state.att_pp); @@ -992,14 +998,14 @@ struct rwkv_future_tensor rwkv_future_ffn(struct rwkv_future_ctx & ctx, const struct rwkv_future_tensor ffn_k, const struct rwkv_future_tensor ffn_v, const struct rwkv_future_tensor ffn_r, - const struct rwkv_future_tensor x, + struct rwkv_future_tensor x, struct rwkv_future_tensor & ffn_xx ) { - struct rwkv_future_tensor x0 = x, xx; - rwkv_future_xx(ctx, ln2_weight, ln2_bias, x0, xx, ffn_xx); + struct rwkv_future_tensor x_prev; + rwkv_future_carry_x(ctx, ln2_weight, ln2_bias, x, x_prev, ffn_xx); - struct rwkv_future_tensor xk = x0.combine(ctx, time_mix_k).consume(ctx, xx.combine(ctx, time_mix_k.fn(ctx))); - struct rwkv_future_tensor xr = x0.combine(ctx, time_mix_r).consume(ctx, xx.combine(ctx, time_mix_r.fn(ctx))); + struct rwkv_future_tensor xk = x.combine(ctx, time_mix_k).consume(ctx, x_prev.combine(ctx, time_mix_k.fn(ctx))); + struct rwkv_future_tensor xr = x.combine(ctx, time_mix_r).consume(ctx, x_prev.combine(ctx, time_mix_r.fn(ctx))); struct rwkv_future_tensor r = ffn_r.mul_mat(ctx, xr).fn(ctx); struct rwkv_future_tensor k = ffn_k.mul_mat(ctx, xk).view(ctx).view(ctx); @@ -1008,22 +1014,22 @@ struct rwkv_future_tensor rwkv_future_ffn(struct rwkv_future_ctx & ctx, } struct ggml_tensor * rwkv_ffn(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) { - struct ggml_tensor * x0 = x, * xx; - rwkv_xx(ctx, layer.ln2_weight, layer.ln2_bias, x0, xx, state.ffn_xx); + struct ggml_tensor * x_prev; + rwkv_carry_x(ctx, layer.ln2_weight, layer.ln2_bias, x, x_prev, state.ffn_xx); // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k) // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k) struct ggml_tensor * xk = ggml_add_inplace( ctx, - ggml_mul(ctx, x0, layer.ffn_time_mix_k), - ggml_mul(ctx, xx, rwkv_1_minus_x(ctx, layer.ffn_time_mix_k)) + ggml_mul(ctx, x, layer.ffn_time_mix_k), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_k)) ); // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r) struct ggml_tensor * xr = ggml_add_inplace( ctx, - ggml_mul(ctx, x0, layer.ffn_time_mix_r), - ggml_mul(ctx, xx, rwkv_1_minus_x(ctx, layer.ffn_time_mix_r)) + ggml_mul(ctx, x, layer.ffn_time_mix_r), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_r)) ); // r = torch.sigmoid(rw @ xr) @@ -1195,21 +1201,21 @@ struct rwkv_future_tensor rwkv_future_sequence_graph(struct rwkv_future_ctx & ct x = x.layer_norm(ctx, ln0_weight.repeat(ctx, x), ln0_bias.repeat(ctx, x)); for (size_t i = 0; i < n_layer; i++) { - struct rwkv_future_tensor x0 = x, xx; - rwkv_future_xx(ctx, ln1_weight, ln1_bias, x0, xx, att_xx); + struct rwkv_future_tensor x0 = x, x_prev; + rwkv_future_carry_x(ctx, ln1_weight, ln1_bias, x0, x_prev, att_xx); struct rwkv_future_tensor r, k, v; - rwkv_future_att_rkv(ctx, att_time_mix_k, att_time_mix_v, att_time_mix_r, x0, xx, att_r, att_k, att_v, r, k, v); + rwkv_future_att_rkv(ctx, att_time_mix_k, att_time_mix_v, att_time_mix_r, x0, x_prev, att_r, att_k, att_v, r, k, v); for (size_t i = 0; i < tokens.width; i++) { struct rwkv_future_tensor kt = k.subview(ctx, emb.width); struct rwkv_future_tensor vt = v.subview(ctx, emb.width); - struct rwkv_future_tensor xt = xx.subview(ctx, emb.width); + struct rwkv_future_tensor xt = x_prev.subview(ctx, emb.width); struct rwkv_future_tensor wkv = rwkv_future_att_wkv(ctx, att_time_first, att_time_decay, att_aa, att_bb, att_pp, k, v); wkv.view(ctx); } - x = x.consume(ctx, att_output.mul_mat(ctx, r.combine(ctx, xx))); + x = x.consume(ctx, att_output.mul_mat(ctx, r.combine(ctx, x_prev))); x = x.consume(ctx, rwkv_future_ffn(ctx, ln2_weight, ln2_bias, ffn_time_mix_k, ffn_time_mix_r, ffn_k, ffn_v, ffn_r, x, ffn_xx)); ffn_xx.view(ctx); @@ -1245,23 +1251,23 @@ bool rwkv_build_sequence_graph( struct rwkv_layer & layer = model.layers[i]; struct rwkv_layer_state state = inputs[i]; - struct ggml_tensor * x0 = x, * xx; - rwkv_xx(ctx, layer.ln1_weight, layer.ln1_bias, x0, xx, state.att_xx); + struct ggml_tensor * x0 = x, * x_prev; + rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x0, x_prev, state.att_xx); struct ggml_tensor * r, * k, * v; - rwkv_att_rkv(ctx, layer, x0, xx, r, k, v); + rwkv_att_rkv(ctx, layer, x0, x_prev, r, k, v); ggml_build_forward_expand(cgraph, r); for (uint32_t t = 0; t < sequence_len; t++) { struct ggml_tensor * kt = ggml_view_1d(ctx, k, n_embed, n_embed * sizeof(float) * t); struct ggml_tensor * vt = ggml_view_1d(ctx, v, n_embed, n_embed * sizeof(float) * t); - struct ggml_tensor * xt = ggml_view_1d(ctx, xx, n_embed, n_embed * sizeof(float) * t); + struct ggml_tensor * xt = ggml_view_1d(ctx, x_prev, n_embed, n_embed * sizeof(float) * t); struct ggml_tensor * wkv = rwkv_att_wkv(ctx, layer.att_time_first, layer.att_time_decay, kt, vt, state.att_aa, state.att_bb, state.att_pp); ggml_build_forward_expand(cgraph, ggml_cpy(ctx, wkv, xt)); } - x = ggml_add_inplace(ctx, x, ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, xx))); + x = ggml_add_inplace(ctx, x, ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, x_prev))); x = ggml_add_inplace(ctx, x, rwkv_ffn(ctx, x, layer, state)); struct rwkv_layer_state & output = outputs[i];