diff --git a/rwkv.cpp b/rwkv.cpp index d523f96..43d676b 100644 --- a/rwkv.cpp +++ b/rwkv.cpp @@ -1286,11 +1286,7 @@ void rwkv_set_inputs(const struct rwkv_context * ctx, const float * state_in) { if (state_in) { memcpy(ctx->input_state->data, state_in, ggml_nbytes(ctx->input_state)); } else { - ggml_set_f32(ctx->input_state, 0.0F); - - for (size_t i = 0; i < ctx->instance->model.header.n_layer; i++) { - ggml_set_f32(ctx->input_layers[i].att_pp, -1e30F); - } + rwkv_init_state(ctx, (float *) ctx->input_state->data); } } @@ -1365,12 +1361,52 @@ bool rwkv_eval_sequence(const struct rwkv_context * ctx, const uint32_t * sequen return true; } -uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx) { - return ctx->instance->model.header.n_layer * 5 * ctx->instance->model.header.n_embed; +// Provided for compatibility +extern "C" RWKV_API uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx) { + return rwkv_get_state_len(ctx); +} + +// Provided for compatibility +extern "C" RWKV_API uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx) { + return rwkv_get_logits_len(ctx); +} + +size_t rwkv_get_n_vocab(const struct rwkv_context * ctx) { + return (size_t) ctx->instance->model.header.n_vocab; } -uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx) { - return ctx->instance->model.header.n_vocab; +size_t rwkv_get_n_embed(const struct rwkv_context * ctx) { + return (size_t) ctx->instance->model.header.n_embed; +} + +size_t rwkv_get_n_layer(const struct rwkv_context * ctx) { + return (size_t) ctx->instance->model.header.n_layer; +} + +size_t rwkv_get_state_len(const struct rwkv_context * ctx) { + const struct rwkv_file_header & header = ctx->instance->model.header; + return (size_t) header.n_embed * 5 * (size_t) header.n_layer; +} + +size_t rwkv_get_logits_len(const struct rwkv_context * ctx) { + return (size_t) ctx->instance->model.header.n_vocab; +} + +void rwkv_init_state(const struct rwkv_context * ctx, float * state) { + const struct rwkv_file_header & header = ctx->instance->model.header; + const size_t layer_size = (size_t) header.n_embed * 5; + const size_t layer_zero = (size_t) header.n_embed * 4; + const size_t layers_size = (size_t) header.n_layer * layer_size; + + for (size_t start = 0; start < layers_size; start += layer_size) { + for (size_t i = 0; i < layer_zero; i++) { + state[start + i] = 0.0F; + } + + for (size_t i = layer_zero; i < layer_size; i++) { + state[start + i] = -1e30F; + } + } } void rwkv_free(struct rwkv_context * ctx) { diff --git a/rwkv.h b/rwkv.h index 38bf906..0744e37 100644 --- a/rwkv.h +++ b/rwkv.h @@ -105,9 +105,9 @@ extern "C" { // Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread. // Returns false on any error. Error messages would be printed to stderr. // - token: next token index, in range 0 <= token < n_vocab. - // - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is a first pass. - // - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to if non-NULL. - // - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to if non-NULL. + // - state_in: FP32 buffer of size rwkv_get_state_len(); or NULL, if this is a first pass. + // - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL. + // - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL. RWKV_API bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out); // Evaluates the model for a sequence of tokens. @@ -117,16 +117,36 @@ extern "C" { // Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread. // Returns false on any error. Error messages would be printed to stderr. // - sequence_len: number of tokens to read from the array. - // - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count, or NULL if this is a first pass. - // - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to if non-NULL. - // - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to if non-NULL. + // - state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is a first pass. + // - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL. + // - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL. RWKV_API bool rwkv_eval_sequence(const struct rwkv_context * ctx, const uint32_t * tokens, size_t sequence_len, const float * state_in, float * state_out, float * logits_out); - // Returns count of FP32 elements in state buffer. - RWKV_API uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx); + // Returns the number of tokens in the given model's vocabulary. + // Useful for telling legacy RWKV models (n_vocab = 50277) apart from modern World models (n_vocab = 65535). + RWKV_API size_t rwkv_get_n_vocab(const struct rwkv_context * ctx); - // Returns count of FP32 elements in logits buffer. - RWKV_API uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx); + // Returns the number of elements in the given model's embed weights. + // Useful for reading individual fields of a model's hidden state, if desired. + RWKV_API size_t rwkv_get_n_embed(const struct rwkv_context * ctx); + + // Returns the number of layers in the given model. + // Useful for always offloading the entire model to GPU, if desired. + RWKV_API size_t rwkv_get_n_layer(const struct rwkv_context * ctx); + + // Returns the number of float elements in a complete state for the given model. + // This is the number of elements you'll need to allocate for a call to rwkv_eval, rwkv_eval_sequence, or rwkv_init_state. + RWKV_API size_t rwkv_get_state_len(const struct rwkv_context * ctx); + + // Returns the number of float elements in the logits output of a given model. + // This is currently always identical to n_vocab. + RWKV_API size_t rwkv_get_logits_len(const struct rwkv_context * ctx); + + // Initializes the given state so that passing it to rwkv_eval or rwkv_eval_sequence would be identical to passing NULL. + // Useful in cases where tracking the first call to these functions may be annoying or expensive. + // State must be initialized for behavior to be defined, passing a zeroed state to rwkv.cpp functions will result in NaNs. + // - state: FP32 buffer of size rwkv_get_state_len() to initialize + RWKV_API void rwkv_init_state(const struct rwkv_context * ctx, float * state); // Frees all allocated memory and the context. // Does not need to be the same thread that created the rwkv_context. diff --git a/tests/test_context_cloning.c b/tests/test_context_cloning.c index 9585f16..1a8f4b1 100644 --- a/tests/test_context_cloning.c +++ b/tests/test_context_cloning.c @@ -13,8 +13,8 @@ int main() { return EXIT_FAILURE; } - float * state = calloc(rwkv_get_state_buffer_element_count(ctx), sizeof(float)); - float * logits = calloc(rwkv_get_logits_buffer_element_count(ctx), sizeof(float)); + float * state = calloc(rwkv_get_state_len(ctx), sizeof(float)); + float * logits = calloc(rwkv_get_logits_len(ctx), sizeof(float)); if (!state || !logits) { fprintf(stderr, "Failed to allocate state/logits\n"); @@ -31,7 +31,7 @@ int main() { } float * expected_logits = logits; - logits = calloc(rwkv_get_logits_buffer_element_count(ctx), sizeof(float)); + logits = calloc(rwkv_get_logits_len(ctx), sizeof(float)); if (!logits) { fprintf(stderr, "Failed to allocate state/logits\n"); @@ -46,7 +46,7 @@ int main() { rwkv_eval(ctx, *token, state, state, logits); } - if (memcmp(expected_logits, logits, rwkv_get_logits_buffer_element_count(ctx) * sizeof(float))) { + if (memcmp(expected_logits, logits, rwkv_get_logits_len(ctx) * sizeof(float))) { fprintf(stderr, "results not identical :(\n"); return EXIT_FAILURE; } else { diff --git a/tests/test_tiny_rwkv.c b/tests/test_tiny_rwkv.c index 3b8e075..fd69135 100644 --- a/tests/test_tiny_rwkv.c +++ b/tests/test_tiny_rwkv.c @@ -33,11 +33,11 @@ void test_model(const char * model_path, const float * expected_logits, const fl ASSERT(rwkv_gpu_offload_layers(model, N_GPU_LAYERS), "Unexpected error %d", rwkv_get_last_error(model)); #endif - uint32_t n_vocab = rwkv_get_logits_buffer_element_count(model); + uint32_t n_vocab = rwkv_get_logits_len(model); ASSERT(n_vocab == N_VOCAB, "Unexpected n_vocab in the model"); - float * state = malloc(sizeof(float) * rwkv_get_state_buffer_element_count(model)); + float * state = malloc(sizeof(float) * rwkv_get_state_len(model)); float * logits = malloc(sizeof(float) * n_vocab); char * prompt = "\"in"; @@ -45,8 +45,10 @@ void test_model(const char * model_path, const float * expected_logits, const fl const size_t prompt_length = strlen(prompt); + rwkv_init_state(model, state); + for (size_t i = 0; i < prompt_length; i++) { - rwkv_eval(model, prompt[i], i == 0 ? NULL : state, state, logits); + rwkv_eval(model, prompt[i], state, state, logits); } float diff_sum = 0.0F; @@ -60,7 +62,8 @@ void test_model(const char * model_path, const float * expected_logits, const fl // When something breaks, difference would be way more than 10 ASSERT(fabsf(diff_sum) <= fabsf(max_diff) + 0.01F, "Too big difference %f, expected no more than %f", (double) diff_sum, (double) max_diff); - rwkv_eval_sequence(model, prompt_seq, prompt_length, NULL, state, logits); + rwkv_init_state(model, state); + rwkv_eval_sequence(model, prompt_seq, prompt_length, state, state, logits); diff_sum = 0.0F;