Skip to content

Commit

Permalink
Phase out very verbose element_count functions (#95)
Browse files Browse the repository at this point in the history
* Phase out very verbose element_count functions

This could have been done better

* Add "get" to other getters

* Specify "float elements" in rwkv_get_logits_len docs

* Use traditional for-loop for rwkv_init_state writes

* Newline
  • Loading branch information
LoganDark committed Jun 12, 2023
1 parent 43c78f2 commit 7199f5b
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 27 deletions.
54 changes: 45 additions & 9 deletions rwkv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1286,11 +1286,7 @@ void rwkv_set_inputs(const struct rwkv_context * ctx, const float * state_in) {
if (state_in) {
memcpy(ctx->input_state->data, state_in, ggml_nbytes(ctx->input_state));
} else {
ggml_set_f32(ctx->input_state, 0.0F);

for (size_t i = 0; i < ctx->instance->model.header.n_layer; i++) {
ggml_set_f32(ctx->input_layers[i].att_pp, -1e30F);
}
rwkv_init_state(ctx, (float *) ctx->input_state->data);
}
}

Expand Down Expand Up @@ -1365,12 +1361,52 @@ bool rwkv_eval_sequence(const struct rwkv_context * ctx, const uint32_t * sequen
return true;
}

uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx) {
return ctx->instance->model.header.n_layer * 5 * ctx->instance->model.header.n_embed;
// Legacy entry point kept for compatibility with older callers.
// Forwards to rwkv_get_state_len() and narrows the result to uint32_t.
extern "C" RWKV_API uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx) {
    const size_t state_len = rwkv_get_state_len(ctx);
    return (uint32_t) state_len;
}

// Legacy entry point kept for compatibility with older callers.
// Forwards to rwkv_get_logits_len() and narrows the result to uint32_t.
extern "C" RWKV_API uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx) {
    const size_t logits_len = rwkv_get_logits_len(ctx);
    return (uint32_t) logits_len;
}

// Reads n_vocab from the header of the model attached to ctx and returns it as size_t.
size_t rwkv_get_n_vocab(const struct rwkv_context * ctx) {
    const size_t n_vocab = (size_t) ctx->instance->model.header.n_vocab;
    return n_vocab;
}

uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx) {
return ctx->instance->model.header.n_vocab;
// Reads n_embed from the header of the model attached to ctx and returns it as size_t.
size_t rwkv_get_n_embed(const struct rwkv_context * ctx) {
    const size_t n_embed = (size_t) ctx->instance->model.header.n_embed;
    return n_embed;
}

// Reads n_layer from the header of the model attached to ctx and returns it as size_t.
size_t rwkv_get_n_layer(const struct rwkv_context * ctx) {
    const size_t n_layer = (size_t) ctx->instance->model.header.n_layer;
    return n_layer;
}

// Returns the total number of float elements in a full state buffer:
// five vectors of n_embed elements for each of the model's n_layer layers.
size_t rwkv_get_state_len(const struct rwkv_context * ctx) {
    const struct rwkv_file_header & header = ctx->instance->model.header;
    // Widen to size_t before multiplying so the arithmetic is not done in the header's narrower field type.
    const size_t per_layer_len = (size_t) header.n_embed * 5;
    return per_layer_len * (size_t) header.n_layer;
}

// Returns the number of float elements in the logits output, which is the header's n_vocab.
size_t rwkv_get_logits_len(const struct rwkv_context * ctx) {
    const size_t logits_len = (size_t) ctx->instance->model.header.n_vocab;
    return logits_len;
}

// Fills a caller-provided state buffer with the initial state values.
// - ctx: context whose model header supplies n_embed and n_layer.
// - state: buffer of n_layer * 5 * n_embed floats; every element is overwritten.
// Each layer occupies 5 * n_embed consecutive floats: the first 4 * n_embed are set to 0,
// the final n_embed are set to -1e30 (presumably the att_pp slot -- confirm against the state layout).
void rwkv_init_state(const struct rwkv_context * ctx, float * state) {
    const struct rwkv_file_header & header = ctx->instance->model.header;
    const size_t n_embed = (size_t) header.n_embed;
    const size_t n_layer = (size_t) header.n_layer;
    const size_t zeroed_len = n_embed * 4;
    const size_t per_layer_len = n_embed * 5;

    float * layer = state;

    for (size_t layer_index = 0; layer_index < n_layer; layer_index++) {
        for (size_t i = 0; i < per_layer_len; i++) {
            // Leading four vectors start at zero; the trailing vector starts at a very large negative value.
            layer[i] = i < zeroed_len ? 0.0F : -1e30F;
        }

        layer += per_layer_len;
    }
}

void rwkv_free(struct rwkv_context * ctx) {
Expand Down
40 changes: 30 additions & 10 deletions rwkv.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ extern "C" {
// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
// Returns false on any error. Error messages would be printed to stderr.
// - token: next token index, in range 0 <= token < n_vocab.
// - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is a first pass.
// - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to if non-NULL.
// - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to if non-NULL.
// - state_in: FP32 buffer of size rwkv_get_state_len(); or NULL, if this is a first pass.
// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
// - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
RWKV_API bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out);

// Evaluates the model for a sequence of tokens.
Expand All @@ -117,16 +117,36 @@ extern "C" {
// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
// Returns false on any error. Error messages would be printed to stderr.
// - sequence_len: number of tokens to read from the array.
// - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count, or NULL if this is a first pass.
// - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to if non-NULL.
// - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to if non-NULL.
// - state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is a first pass.
// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
// - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
RWKV_API bool rwkv_eval_sequence(const struct rwkv_context * ctx, const uint32_t * tokens, size_t sequence_len, const float * state_in, float * state_out, float * logits_out);

// Returns count of FP32 elements in state buffer.
RWKV_API uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx);
// Returns the number of tokens in the given model's vocabulary.
// Useful for telling legacy RWKV models (n_vocab = 50277) apart from modern World models (n_vocab = 65536).
RWKV_API size_t rwkv_get_n_vocab(const struct rwkv_context * ctx);

// Returns count of FP32 elements in logits buffer.
RWKV_API uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx);
// Returns the number of elements in a single embedding vector (n_embed) of the given model.
// Useful for reading individual fields of a model's hidden state, if desired.
RWKV_API size_t rwkv_get_n_embed(const struct rwkv_context * ctx);

// Returns the number of layers in the given model.
// Useful for always offloading the entire model to GPU, if desired.
RWKV_API size_t rwkv_get_n_layer(const struct rwkv_context * ctx);

// Returns the number of float elements in a complete state for the given model.
// This is the number of elements you'll need to allocate for a call to rwkv_eval, rwkv_eval_sequence, or rwkv_init_state.
RWKV_API size_t rwkv_get_state_len(const struct rwkv_context * ctx);

// Returns the number of float elements in the logits output of a given model.
// This is currently always identical to n_vocab.
RWKV_API size_t rwkv_get_logits_len(const struct rwkv_context * ctx);

// Initializes the given state so that passing it to rwkv_eval or rwkv_eval_sequence would be identical to passing NULL.
// Useful in cases where tracking the first call to these functions may be annoying or expensive.
// State must be initialized for behavior to be defined; passing a zeroed state to rwkv.cpp functions will result in NaNs.
// - state: FP32 buffer of size rwkv_get_state_len() to initialize
RWKV_API void rwkv_init_state(const struct rwkv_context * ctx, float * state);

// Frees all allocated memory and the context.
// Does not need to be the same thread that created the rwkv_context.
Expand Down
8 changes: 4 additions & 4 deletions tests/test_context_cloning.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ int main() {
return EXIT_FAILURE;
}

float * state = calloc(rwkv_get_state_buffer_element_count(ctx), sizeof(float));
float * logits = calloc(rwkv_get_logits_buffer_element_count(ctx), sizeof(float));
float * state = calloc(rwkv_get_state_len(ctx), sizeof(float));
float * logits = calloc(rwkv_get_logits_len(ctx), sizeof(float));

if (!state || !logits) {
fprintf(stderr, "Failed to allocate state/logits\n");
Expand All @@ -31,7 +31,7 @@ int main() {
}

float * expected_logits = logits;
logits = calloc(rwkv_get_logits_buffer_element_count(ctx), sizeof(float));
logits = calloc(rwkv_get_logits_len(ctx), sizeof(float));

if (!logits) {
fprintf(stderr, "Failed to allocate state/logits\n");
Expand All @@ -46,7 +46,7 @@ int main() {
rwkv_eval(ctx, *token, state, state, logits);
}

if (memcmp(expected_logits, logits, rwkv_get_logits_buffer_element_count(ctx) * sizeof(float))) {
if (memcmp(expected_logits, logits, rwkv_get_logits_len(ctx) * sizeof(float))) {
fprintf(stderr, "results not identical :(\n");
return EXIT_FAILURE;
} else {
Expand Down
11 changes: 7 additions & 4 deletions tests/test_tiny_rwkv.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,22 @@ void test_model(const char * model_path, const float * expected_logits, const fl
ASSERT(rwkv_gpu_offload_layers(model, N_GPU_LAYERS), "Unexpected error %d", rwkv_get_last_error(model));
#endif

uint32_t n_vocab = rwkv_get_logits_buffer_element_count(model);
uint32_t n_vocab = rwkv_get_logits_len(model);

ASSERT(n_vocab == N_VOCAB, "Unexpected n_vocab in the model");

float * state = malloc(sizeof(float) * rwkv_get_state_buffer_element_count(model));
float * state = malloc(sizeof(float) * rwkv_get_state_len(model));
float * logits = malloc(sizeof(float) * n_vocab);

char * prompt = "\"in";
uint32_t prompt_seq[] = { '"', 'i', 'n' };

const size_t prompt_length = strlen(prompt);

rwkv_init_state(model, state);

for (size_t i = 0; i < prompt_length; i++) {
rwkv_eval(model, prompt[i], i == 0 ? NULL : state, state, logits);
rwkv_eval(model, prompt[i], state, state, logits);
}

float diff_sum = 0.0F;
Expand All @@ -60,7 +62,8 @@ void test_model(const char * model_path, const float * expected_logits, const fl
// When something breaks, difference would be way more than 10
ASSERT(fabsf(diff_sum) <= fabsf(max_diff) + 0.01F, "Too big difference %f, expected no more than %f", (double) diff_sum, (double) max_diff);

rwkv_eval_sequence(model, prompt_seq, prompt_length, NULL, state, logits);
rwkv_init_state(model, state);
rwkv_eval_sequence(model, prompt_seq, prompt_length, state, state, logits);

diff_sum = 0.0F;

Expand Down

0 comments on commit 7199f5b

Please sign in to comment.