Skip to content

Commit

Permalink
Rename xx to x_prev
Browse files Browse the repository at this point in the history
probably should slip this in now before we forget it's a thing.
  • Loading branch information
LoganDark committed Jun 21, 2023
1 parent e077496 commit 01f2907
Showing 1 changed file with 64 additions and 58 deletions.
122 changes: 64 additions & 58 deletions rwkv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -760,71 +760,77 @@ bool rwkv_set_params(struct rwkv_model & model, F callback) {
return true;
}

// Mirrors rwkv_carry_x on rwkv_future_tensor values: layer-normalizes x, hands
// the incoming carry out as x_prev, and overwrites carry with a subview of the
// normalized x.
// NOTE(review): presumably this is the planning/size-estimation twin of
// rwkv_carry_x; the rwkv_future_* types are declared outside this view — confirm.
//
// NOTE(review): this span is a rendered diff without +/- markers. Each removed
// line using the old names (rwkv_future_xx, xx, state) is immediately followed
// by its replacement using the new names (rwkv_future_carry_x, x_prev, carry).
//
// weight, bias  layer-norm parameters applied to x
// x             [in/out] replaced by its layer-normalized value
// x_prev        [out]    carry for a single-row x; otherwise a dup of x with
//                        carry and a subview of x written into it
// carry         [in/out] read as the previous value, then overwritten from x
void rwkv_future_xx(struct rwkv_future_ctx & ctx,
void rwkv_future_carry_x(struct rwkv_future_ctx & ctx,
const struct rwkv_future_tensor weight,
const struct rwkv_future_tensor bias,
struct rwkv_future_tensor & x,
struct rwkv_future_tensor & xx,
struct rwkv_future_tensor & state
struct rwkv_future_tensor & x_prev,
struct rwkv_future_tensor & carry
) {
// x.height == 1 takes the simple path: no repeat of weight/bias and no copy.
// NOTE(review): height presumably corresponds to sequence length, as
// sequence_len does in rwkv_carry_x — confirm.
if (x.height == 1) {
x = x.layer_norm(ctx, weight, bias);
xx = state;
state = x;
x_prev = carry;
carry = x;
} else {
// Multi-row path: weight and bias are repeated to x's shape before the norm.
x = x.layer_norm(ctx, weight.repeat(ctx, x), bias.repeat(ctx, x));

// x_prev starts as a dup of x, then receives carry and a (width, height - 1)
// subview of x, matching the two ggml_set_1d_inplace calls in rwkv_carry_x.
xx = x.dup(ctx)
.set_inplace(ctx, state)
x_prev = x.dup(ctx)
.set_inplace(ctx, carry)
.set_inplace(ctx, x.subview(ctx, x.width, x.height - 1));

// carry becomes a width-sized subview of x (presumably one row — confirm).
state = x.subview(ctx, x.width);
carry = x.subview(ctx, x.width);
}
}

// Layer-normalizes x, produces x_prev (the input shifted back by one position,
// with carry filling the first slot), and updates carry to the last row of the
// normalized x so the next call can continue from it.
//
// NOTE(review): this span is a rendered diff without +/- markers. Each removed
// line using the old names (rwkv_xx, xx, state) is immediately followed by its
// replacement using the new names (rwkv_carry_x, x_prev, carry).
//
// weight, bias  layer-norm parameters; visible callers pass ln1_* with
//               state.att_xx and ln2_* with state.ffn_xx
// x             [in/out] n_embed x sequence_len tensor (ne[0] x ne[1]); replaced
//               by its layer-normalized value
// x_prev        [out]    carry itself when sequence_len == 1, otherwise a new
//               F32 n_embed x sequence_len tensor
// carry         [in/out] previous position's normalized x on entry; last row of
//               the normalized x on exit
void rwkv_xx(struct ggml_context * ctx, struct ggml_tensor * weight, struct ggml_tensor * bias, struct ggml_tensor *& x, struct ggml_tensor *& xx, struct ggml_tensor *& state) {
size_t n_embed = x->ne[0];
size_t sequence_len = x->ne[1];
void rwkv_carry_x(struct ggml_context * ctx,
struct ggml_tensor * weight,
struct ggml_tensor * bias,
struct ggml_tensor *& x,
struct ggml_tensor *& x_prev,
struct ggml_tensor *& carry
) {
const size_t n_embed = x->ne[0];
const size_t sequence_len = x->ne[1];

if (sequence_len == 1) {
// self.layer_norm(x, self.w.blocks[i].ln2)
// (the reference line names ln2; this helper is also called with ln1)
x = rwkv_layer_norm(ctx, x, weight, bias);

// xx = state[5*i+0]
// (x_prev = carry after the rename; the state index depends on which
// carry slot the caller passes)
xx = state;
x_prev = carry;

// state[5*i+0] = x
state = x;
carry = x;
} else {
// self.layer_norm(x, self.w.blocks[i].ln2)
// weight and bias are repeated to x's shape so the norm covers every row
x = rwkv_layer_norm(ctx, x, ggml_repeat(ctx, weight, x), ggml_repeat(ctx, bias, x));

// xx = torch.cat((state[5*i+0].to(dtype=self.FLOAT_MODE).unsqueeze(0), x[:-1,:]))
// Row 0 of the new tensor is carry (byte offset 0); rows 1..sequence_len-1
// are the first sequence_len-1 rows of x, written at byte offset
// n_embed * sizeof(float).
xx = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embed, sequence_len);
xx = ggml_set_1d_inplace(ctx, xx, state, 0);
xx = ggml_set_1d_inplace(ctx, xx, ggml_view_1d(ctx, x, n_embed * (sequence_len - 1), 0), n_embed * sizeof(float));
x_prev = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embed, sequence_len);
x_prev = ggml_set_1d_inplace(ctx, x_prev, carry, 0);
x_prev = ggml_set_1d_inplace(ctx, x_prev, ggml_view_1d(ctx, x, n_embed * (sequence_len - 1), 0), n_embed * sizeof(float));

// state[5*i+0] = x[-1,:]
// A view of n_embed elements starting at the last row of x (no copy here).
state = ggml_view_1d(ctx, x, n_embed, n_embed * (sequence_len - 1) * sizeof(float));
carry = ggml_view_1d(ctx, x, n_embed, n_embed * (sequence_len - 1) * sizeof(float));
}
}

void rwkv_future_att_rkv(struct rwkv_future_ctx & ctx,
const struct rwkv_future_tensor time_mix_k,
const struct rwkv_future_tensor time_mix_v,
const struct rwkv_future_tensor time_mix_r,
const struct rwkv_future_tensor x0,
const struct rwkv_future_tensor xx,
const struct rwkv_future_tensor x,
const struct rwkv_future_tensor x_prev,
const struct rwkv_future_tensor att_r,
const struct rwkv_future_tensor att_k,
const struct rwkv_future_tensor att_v,
struct rwkv_future_tensor & r,
struct rwkv_future_tensor & k,
struct rwkv_future_tensor & v
) {
const struct rwkv_future_tensor xk = x0.combine(ctx, time_mix_k).consume(ctx, xx.combine(ctx, time_mix_k.fn(ctx)));
const struct rwkv_future_tensor xv = x0.combine(ctx, time_mix_v).consume(ctx, xx.combine(ctx, time_mix_v.fn(ctx)));
const struct rwkv_future_tensor xr = x0.combine(ctx, time_mix_r).consume(ctx, xx.combine(ctx, time_mix_r.fn(ctx)));
const struct rwkv_future_tensor xk = x.combine(ctx, time_mix_k).consume(ctx, x_prev.combine(ctx, time_mix_k.fn(ctx)));
const struct rwkv_future_tensor xv = x.combine(ctx, time_mix_v).consume(ctx, x_prev.combine(ctx, time_mix_v.fn(ctx)));
const struct rwkv_future_tensor xr = x.combine(ctx, time_mix_r).consume(ctx, x_prev.combine(ctx, time_mix_r.fn(ctx)));

r = att_r.mul_mat(ctx, xr).fn(ctx);
k = att_k.mul_mat(ctx, xk);
Expand All @@ -834,28 +840,28 @@ void rwkv_future_att_rkv(struct rwkv_future_ctx & ctx,
void rwkv_att_rkv(
struct ggml_context * ctx,
struct rwkv_layer layer,
struct ggml_tensor * x0,
struct ggml_tensor * xx,
struct ggml_tensor * x,
struct ggml_tensor * x_prev,
struct ggml_tensor *& r,
struct ggml_tensor *& k,
struct ggml_tensor *& v
) {
// xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
struct ggml_tensor * xk = ggml_add_inplace(ctx,
ggml_mul(ctx, x0, layer.att_time_mix_k),
ggml_mul(ctx, xx, rwkv_1_minus_x(ctx, layer.att_time_mix_k))
ggml_mul(ctx, x, layer.att_time_mix_k),
ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_k))
);

// xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v)
struct ggml_tensor * xv = ggml_add_inplace(ctx,
ggml_mul(ctx, x0, layer.att_time_mix_v),
ggml_mul(ctx, xx, rwkv_1_minus_x(ctx, layer.att_time_mix_v))
ggml_mul(ctx, x, layer.att_time_mix_v),
ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_v))
);

// xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r)
struct ggml_tensor * xr = ggml_add_inplace(ctx,
ggml_mul(ctx, x0, layer.att_time_mix_r),
ggml_mul(ctx, xx, rwkv_1_minus_x(ctx, layer.att_time_mix_r))
ggml_mul(ctx, x, layer.att_time_mix_r),
ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_r))
);

// r = torch.sigmoid(rw @ xr)
Expand Down Expand Up @@ -954,29 +960,29 @@ struct rwkv_future_tensor rwkv_future_att(struct rwkv_future_ctx & ctx,
const struct rwkv_future_tensor att_k,
const struct rwkv_future_tensor att_v,
const struct rwkv_future_tensor att_output,
const struct rwkv_future_tensor x,
struct rwkv_future_tensor x,
struct rwkv_future_tensor & att_xx,
struct rwkv_future_tensor & att_aa,
struct rwkv_future_tensor & att_bb,
struct rwkv_future_tensor & att_pp
) {
struct rwkv_future_tensor x0 = x, xx;
rwkv_future_xx(ctx, ln1_weight, ln1_bias, x0, xx, att_xx);
struct rwkv_future_tensor x_prev;
rwkv_future_carry_x(ctx, ln1_weight, ln1_bias, x, x_prev, att_xx);

struct rwkv_future_tensor r, k, v;
rwkv_future_att_rkv(ctx, time_mix_k, time_mix_v, time_mix_r, x0, xx, att_r, att_k, att_v, r, k, v);
rwkv_future_att_rkv(ctx, time_mix_k, time_mix_v, time_mix_r, x, x_prev, att_r, att_k, att_v, r, k, v);

struct rwkv_future_tensor wkv = rwkv_future_att_wkv(ctx, time_first, time_decay, att_aa, att_bb, att_pp, k, v);

return att_output.mul_mat(ctx, r.combine(ctx, wkv));
}

struct ggml_tensor * rwkv_att(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) {
struct ggml_tensor * x0 = x, * xx;
rwkv_xx(ctx, layer.ln1_weight, layer.ln1_bias, x0, xx, state.att_xx);
struct ggml_tensor * x_prev;
rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x, x_prev, state.att_xx);

struct ggml_tensor * r, * k, * v;
rwkv_att_rkv(ctx, layer, x0, xx, r, k, v);
rwkv_att_rkv(ctx, layer, x, x_prev, r, k, v);

struct ggml_tensor * wkv = rwkv_att_wkv(ctx, layer.att_time_first, layer.att_time_decay, k, v, state.att_aa, state.att_bb, state.att_pp);

Expand All @@ -992,14 +998,14 @@ struct rwkv_future_tensor rwkv_future_ffn(struct rwkv_future_ctx & ctx,
const struct rwkv_future_tensor ffn_k,
const struct rwkv_future_tensor ffn_v,
const struct rwkv_future_tensor ffn_r,
const struct rwkv_future_tensor x,
struct rwkv_future_tensor x,
struct rwkv_future_tensor & ffn_xx
) {
struct rwkv_future_tensor x0 = x, xx;
rwkv_future_xx(ctx, ln2_weight, ln2_bias, x0, xx, ffn_xx);
struct rwkv_future_tensor x_prev;
rwkv_future_carry_x(ctx, ln2_weight, ln2_bias, x, x_prev, ffn_xx);

struct rwkv_future_tensor xk = x0.combine(ctx, time_mix_k).consume(ctx, xx.combine(ctx, time_mix_k.fn(ctx)));
struct rwkv_future_tensor xr = x0.combine(ctx, time_mix_r).consume(ctx, xx.combine(ctx, time_mix_r.fn(ctx)));
struct rwkv_future_tensor xk = x.combine(ctx, time_mix_k).consume(ctx, x_prev.combine(ctx, time_mix_k.fn(ctx)));
struct rwkv_future_tensor xr = x.combine(ctx, time_mix_r).consume(ctx, x_prev.combine(ctx, time_mix_r.fn(ctx)));

struct rwkv_future_tensor r = ffn_r.mul_mat(ctx, xr).fn(ctx);
struct rwkv_future_tensor k = ffn_k.mul_mat(ctx, xk).view(ctx).view(ctx);
Expand All @@ -1008,22 +1014,22 @@ struct rwkv_future_tensor rwkv_future_ffn(struct rwkv_future_ctx & ctx,
}

struct ggml_tensor * rwkv_ffn(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) {
struct ggml_tensor * x0 = x, * xx;
rwkv_xx(ctx, layer.ln2_weight, layer.ln2_bias, x0, xx, state.ffn_xx);
struct ggml_tensor * x_prev;
rwkv_carry_x(ctx, layer.ln2_weight, layer.ln2_bias, x, x_prev, state.ffn_xx);

// xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
// xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k)
struct ggml_tensor * xk = ggml_add_inplace(
ctx,
ggml_mul(ctx, x0, layer.ffn_time_mix_k),
ggml_mul(ctx, xx, rwkv_1_minus_x(ctx, layer.ffn_time_mix_k))
ggml_mul(ctx, x, layer.ffn_time_mix_k),
ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_k))
);

// xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r)
struct ggml_tensor * xr = ggml_add_inplace(
ctx,
ggml_mul(ctx, x0, layer.ffn_time_mix_r),
ggml_mul(ctx, xx, rwkv_1_minus_x(ctx, layer.ffn_time_mix_r))
ggml_mul(ctx, x, layer.ffn_time_mix_r),
ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_r))
);

// r = torch.sigmoid(rw @ xr)
Expand Down Expand Up @@ -1195,21 +1201,21 @@ struct rwkv_future_tensor rwkv_future_sequence_graph(struct rwkv_future_ctx & ct
x = x.layer_norm(ctx, ln0_weight.repeat(ctx, x), ln0_bias.repeat(ctx, x));

for (size_t i = 0; i < n_layer; i++) {
struct rwkv_future_tensor x0 = x, xx;
rwkv_future_xx(ctx, ln1_weight, ln1_bias, x0, xx, att_xx);
struct rwkv_future_tensor x0 = x, x_prev;
rwkv_future_carry_x(ctx, ln1_weight, ln1_bias, x0, x_prev, att_xx);

struct rwkv_future_tensor r, k, v;
rwkv_future_att_rkv(ctx, att_time_mix_k, att_time_mix_v, att_time_mix_r, x0, xx, att_r, att_k, att_v, r, k, v);
rwkv_future_att_rkv(ctx, att_time_mix_k, att_time_mix_v, att_time_mix_r, x0, x_prev, att_r, att_k, att_v, r, k, v);

for (size_t i = 0; i < tokens.width; i++) {
struct rwkv_future_tensor kt = k.subview(ctx, emb.width);
struct rwkv_future_tensor vt = v.subview(ctx, emb.width);
struct rwkv_future_tensor xt = xx.subview(ctx, emb.width);
struct rwkv_future_tensor xt = x_prev.subview(ctx, emb.width);
struct rwkv_future_tensor wkv = rwkv_future_att_wkv(ctx, att_time_first, att_time_decay, att_aa, att_bb, att_pp, k, v);
wkv.view(ctx);
}

x = x.consume(ctx, att_output.mul_mat(ctx, r.combine(ctx, xx)));
x = x.consume(ctx, att_output.mul_mat(ctx, r.combine(ctx, x_prev)));
x = x.consume(ctx, rwkv_future_ffn(ctx, ln2_weight, ln2_bias, ffn_time_mix_k, ffn_time_mix_r, ffn_k, ffn_v, ffn_r, x, ffn_xx));

ffn_xx.view(ctx);
Expand Down Expand Up @@ -1245,23 +1251,23 @@ bool rwkv_build_sequence_graph(
struct rwkv_layer & layer = model.layers[i];
struct rwkv_layer_state state = inputs[i];

struct ggml_tensor * x0 = x, * xx;
rwkv_xx(ctx, layer.ln1_weight, layer.ln1_bias, x0, xx, state.att_xx);
struct ggml_tensor * x0 = x, * x_prev;
rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x0, x_prev, state.att_xx);

struct ggml_tensor * r, * k, * v;
rwkv_att_rkv(ctx, layer, x0, xx, r, k, v);
rwkv_att_rkv(ctx, layer, x0, x_prev, r, k, v);

ggml_build_forward_expand(cgraph, r);

for (uint32_t t = 0; t < sequence_len; t++) {
struct ggml_tensor * kt = ggml_view_1d(ctx, k, n_embed, n_embed * sizeof(float) * t);
struct ggml_tensor * vt = ggml_view_1d(ctx, v, n_embed, n_embed * sizeof(float) * t);
struct ggml_tensor * xt = ggml_view_1d(ctx, xx, n_embed, n_embed * sizeof(float) * t);
struct ggml_tensor * xt = ggml_view_1d(ctx, x_prev, n_embed, n_embed * sizeof(float) * t);
struct ggml_tensor * wkv = rwkv_att_wkv(ctx, layer.att_time_first, layer.att_time_decay, kt, vt, state.att_aa, state.att_bb, state.att_pp);
ggml_build_forward_expand(cgraph, ggml_cpy(ctx, wkv, xt));
}

x = ggml_add_inplace(ctx, x, ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, xx)));
x = ggml_add_inplace(ctx, x, ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, x_prev)));
x = ggml_add_inplace(ctx, x, rwkv_ffn(ctx, x, layer, state));

struct rwkv_layer_state & output = outputs[i];
Expand Down

0 comments on commit 01f2907

Please sign in to comment.