Skip to content

Commit

Permalink
Add "get" to other getters
Browse files Browse the repository at this point in the history
  • Loading branch information
LoganDark committed Jun 12, 2023
1 parent 1c442bd commit 1fdfe09
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 19 deletions.
8 changes: 4 additions & 4 deletions rwkv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1363,12 +1363,12 @@ bool rwkv_eval_sequence(const struct rwkv_context * ctx, const uint32_t * sequen

// Provided for compatibility
extern "C" RWKV_API uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx) {
return rwkv_state_len(ctx);
return rwkv_get_state_len(ctx);
}

// Provided for compatibility
extern "C" RWKV_API uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx) {
return rwkv_logits_len(ctx);
return rwkv_get_logits_len(ctx);
}

size_t rwkv_get_n_vocab(const struct rwkv_context * ctx) {
Expand All @@ -1383,12 +1383,12 @@ size_t rwkv_get_n_layer(const struct rwkv_context * ctx) {
return (size_t) ctx->instance->model.header.n_layer;
}

size_t rwkv_state_len(const struct rwkv_context * ctx) {
size_t rwkv_get_state_len(const struct rwkv_context * ctx) {
const struct rwkv_file_header & header = ctx->instance->model.header;
return (size_t) header.n_embed * 5 * (size_t) header.n_layer;
}

size_t rwkv_logits_len(const struct rwkv_context * ctx) {
size_t rwkv_get_logits_len(const struct rwkv_context * ctx) {
return (size_t) ctx->instance->model.header.n_vocab;
}

Expand Down
18 changes: 9 additions & 9 deletions rwkv.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ extern "C" {
// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
// Returns false on any error. Error messages would be printed to stderr.
// - token: next token index, in range 0 <= token < n_vocab.
// - state_in: FP32 buffer of size rwkv_state_len(); or NULL, if this is a first pass.
// - state_out: FP32 buffer of size rwkv_state_len(). This buffer will be written to if non-NULL.
// - logits_out: FP32 buffer of size rwkv_logits_len(). This buffer will be written to if non-NULL.
// - state_in: FP32 buffer of size rwkv_get_state_len(); or NULL, if this is a first pass.
// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
// - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
RWKV_API bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out);

// Evaluates the model for a sequence of tokens.
Expand All @@ -117,9 +117,9 @@ extern "C" {
// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
// Returns false on any error. Error messages would be printed to stderr.
// - sequence_len: number of tokens to read from the array.
// - state_in: FP32 buffer of size rwkv_state_len(), or NULL if this is a first pass.
// - state_out: FP32 buffer of size rwkv_state_len(). This buffer will be written to if non-NULL.
// - logits_out: FP32 buffer of size rwkv_logits_len(). This buffer will be written to if non-NULL.
// - state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is a first pass.
// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
// - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
RWKV_API bool rwkv_eval_sequence(const struct rwkv_context * ctx, const uint32_t * tokens, size_t sequence_len, const float * state_in, float * state_out, float * logits_out);

// Returns the number of tokens in the given model's vocabulary.
Expand All @@ -136,16 +136,16 @@ extern "C" {

// Returns the number of float elements in a complete state for the given model.
// This is the number of elements you'll need to allocate for a call to rwkv_eval, rwkv_eval_sequence, or rwkv_init_state.
RWKV_API size_t rwkv_state_len(const struct rwkv_context * ctx);
RWKV_API size_t rwkv_get_state_len(const struct rwkv_context * ctx);

// Returns the number of elements in the logits output of a given model.
// This is currently always identical to n_vocab.
RWKV_API size_t rwkv_logits_len(const struct rwkv_context * ctx);
RWKV_API size_t rwkv_get_logits_len(const struct rwkv_context * ctx);

// Initializes the given state so that passing it to rwkv_eval or rwkv_eval_sequence would be identical to passing NULL.
// Useful in cases where tracking the first call to these functions may be annoying or expensive.
// State must be initialized for behavior to be defined, passing a zeroed state to rwkv.cpp functions will result in NaNs.
// - state: FP32 buffer of size rwkv_state_len() to initialize
// - state: FP32 buffer of size rwkv_get_state_len() to initialize
RWKV_API void rwkv_init_state(const struct rwkv_context * ctx, float * state);

// Frees all allocated memory and the context.
Expand Down
8 changes: 4 additions & 4 deletions tests/test_context_cloning.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ int main() {
return EXIT_FAILURE;
}

float * state = calloc(rwkv_state_len(ctx), sizeof(float));
float * logits = calloc(rwkv_logits_len(ctx), sizeof(float));
float * state = calloc(rwkv_get_state_len(ctx), sizeof(float));
float * logits = calloc(rwkv_get_logits_len(ctx), sizeof(float));

if (!state || !logits) {
fprintf(stderr, "Failed to allocate state/logits\n");
Expand All @@ -31,7 +31,7 @@ int main() {
}

float * expected_logits = logits;
logits = calloc(rwkv_logits_len(ctx), sizeof(float));
logits = calloc(rwkv_get_logits_len(ctx), sizeof(float));

if (!logits) {
fprintf(stderr, "Failed to allocate state/logits\n");
Expand All @@ -46,7 +46,7 @@ int main() {
rwkv_eval(ctx, *token, state, state, logits);
}

if (memcmp(expected_logits, logits, rwkv_logits_len(ctx) * sizeof(float))) {
if (memcmp(expected_logits, logits, rwkv_get_logits_len(ctx) * sizeof(float))) {
fprintf(stderr, "results not identical :(\n");
return EXIT_FAILURE;
} else {
Expand Down
4 changes: 2 additions & 2 deletions tests/test_tiny_rwkv.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ void test_model(const char * model_path, const float * expected_logits, const fl
ASSERT(rwkv_gpu_offload_layers(model, N_GPU_LAYERS), "Unexpected error %d", rwkv_get_last_error(model));
#endif

uint32_t n_vocab = rwkv_logits_len(model);
uint32_t n_vocab = rwkv_get_logits_len(model);

ASSERT(n_vocab == N_VOCAB, "Unexpected n_vocab in the model");

float * state = malloc(sizeof(float) * rwkv_state_len(model));
float * state = malloc(sizeof(float) * rwkv_get_state_len(model));
float * logits = malloc(sizeof(float) * n_vocab);

char * prompt = "\"in";
Expand Down

0 comments on commit 1fdfe09

Please sign in to comment.