Implement on-the-fly quantization
saharNooby committed Jun 13, 2023
1 parent 69639f2 commit 65cdb51
Showing 6 changed files with 107 additions and 24 deletions.
111 changes: 95 additions & 16 deletions rwkv.cpp
@@ -363,27 +363,85 @@ bool rwkv_fread_tensor(FILE * file, struct rwkv_tensor & output, void * buffer =
     return true;
 }
 
-bool rwkv_fread_ggml_tensor_data(FILE * file, const struct rwkv_tensor_header & header, struct ggml_context * ctx, std::string & name, struct ggml_tensor *& tensor) {
+bool rwkv_should_be_quantized(const ggml_type source_type, const ggml_type target_type, const std::string & name, const uint32_t dim_count) {
+    // Quantize only 2D tensors, except embedding and head matrices.
+    // Embedding and head take not too much space, especially in bigger models;
+    // but they significantly increase perplexity when quantized.
+    return target_type != GGML_TYPE_COUNT &&
+        target_type != source_type &&
+        (source_type == GGML_TYPE_F32 || source_type == GGML_TYPE_F16) &&
+        dim_count == 2 &&
+        name != "emb.weight" &&
+        name != "head.weight";
+}
+
+bool rwkv_fread_ggml_tensor_data(
+    FILE * file,
+    const struct rwkv_tensor_header & header,
+    struct ggml_context * ctx,
+    std::string & name,
+    struct ggml_tensor *& tensor,
+    const ggml_type target_type
+) {
     RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_string(file, header.key_length, name), "Failed to read tensor name");
 
     enum ggml_type ggml_type = rwkv_type_to_ggml[header.data_type];
     RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_UNSUPPORTED, ggml_type != GGML_TYPE_UNKNOWN, "Unsupported tensor data type %s from %s", rwkv_type_to_string[header.data_type], name.c_str());
 
-    tensor = header.dim_count == 1
-        ? ggml_new_tensor_1d(ctx, ggml_type, header.width)
-        : ggml_new_tensor_2d(ctx, ggml_type, header.width, header.height);
+    if (rwkv_should_be_quantized(ggml_type, target_type, name, header.dim_count)) {
+        // TODO Remove
+        fprintf(stderr, "Quantizing %s on the fly\n", name.c_str());
+
+        size_t buffer_size_bytes = header.dim_count == 1
+            ? rwkv_tensor_size(ggml_type, header.width)
+            : rwkv_tensor_size(ggml_type, header.width, header.height);
+
+        tensor = header.dim_count == 1
+            ? ggml_new_tensor_1d(ctx, target_type, header.width)
+            : ggml_new_tensor_2d(ctx, target_type, header.width, header.height);
+
+        RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, tensor, "Failed to allocate tensor");
+        ggml_set_name(tensor, name.c_str());
+
+        // TODO Make safer (free on return)
+        char * buffer = (char *) malloc(buffer_size_bytes);
+
+        RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, buffer_size_bytes, buffer), "Failed to read tensor data from %s", name.c_str());
+
+        // Quantization works only with FP32 values
+        if (header.data_type == TYPE_F16) {
+            float * float_buffer = (float *) malloc(buffer_size_bytes * 2);
+
+            ggml_fp16_to_fp32_row((const ggml_fp16_t *) buffer, (float *) float_buffer, ggml_nelements(tensor));
+
+            free(buffer);
+
+            buffer = (char *) float_buffer;
+        }
+
+        int64_t histogram[16] {};
+
+        ggml_quantize_chunk(target_type, (const float *) buffer, tensor->data, 0, ggml_nelements(tensor), histogram);
+
+        free(buffer);
+    } else {
+        tensor = header.dim_count == 1
+            ? ggml_new_tensor_1d(ctx, ggml_type, header.width)
+            : ggml_new_tensor_2d(ctx, ggml_type, header.width, header.height);
 
-    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, tensor, "Failed to allocate tensor");
-    ggml_set_name(tensor, name.c_str());
+        RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, tensor, "Failed to allocate tensor");
+        ggml_set_name(tensor, name.c_str());
 
-    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, ggml_nbytes(tensor), tensor->data), "Failed to read tensor data from %s", name.c_str());
+        RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, ggml_nbytes(tensor), tensor->data), "Failed to read tensor data from %s", name.c_str());
+    }
 
     return true;
 }
 
-bool rwkv_fread_ggml_tensor(FILE * file, struct ggml_context * ctx, std::string & name, struct ggml_tensor *& tensor) {
+bool rwkv_fread_ggml_tensor(FILE * file, struct ggml_context * ctx, std::string & name, struct ggml_tensor *& tensor, const ggml_type target_type) {
     struct rwkv_tensor_header header;
     RWKV_ENSURE_OR_FALSE_MSG(rwkv_fread_tensor_header(file, header), "Invalid tensor header");
-    return rwkv_fread_ggml_tensor_data(file, header, ctx, name, tensor);
+    return rwkv_fread_ggml_tensor_data(file, header, ctx, name, tensor, target_type);
 }
 
 bool rwkv_fwrite_tensor(FILE * file, const struct rwkv_tensor & tensor) {
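
The conversion order in the new read path matters: ggml_quantize_chunk() consumes only FP32 input, so FP16 tensors are widened first. A minimal standalone sketch of that flow, reusing the same ggml calls as the hunk above (the Q5_1 target and the helper name are illustrative assumptions, not part of the commit):

    #include <stdint.h>
    #include <stdlib.h>
    #include "ggml.h"

    // Sketch: quantize a row of n FP16 weights into dst (error handling omitted).
    void quantize_fp16_row_sketch(const ggml_fp16_t * src, void * dst, const int n) {
        // Widen FP16 -> FP32, since quantization works only with FP32 values.
        float * f32 = (float *) malloc(n * sizeof(float));
        ggml_fp16_to_fp32_row(src, f32, n);

        // Histogram of quant bucket usage; required by the API, unused here.
        int64_t histogram[16] = {0};
        ggml_quantize_chunk(GGML_TYPE_Q5_1, f32, dst, 0, n, histogram);

        free(f32);
    }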
@@ -1115,7 +1173,7 @@ struct rwkv_file {
     }
 };
 
-bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & instance) {
+bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & instance, const ggml_type target_type) {
     struct stat file_stat;
     struct rwkv_model model;
     struct rwkv_ggml_context ctx;
@@ -1140,7 +1198,14 @@ bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & inst
         RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_string(file.file, tensor_header.key_length, name), "Failed to read tensor name");
         RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file.file, rwkv_tensor_size(tensor_header), SEEK_CUR) == 0, "Failed to read tensor data");
 
-        rwkv_ctx_size_add_tensor(ctx_size, 1, 0, tensor_header);
+        if (rwkv_should_be_quantized(rwkv_type_to_ggml[tensor_header.data_type], target_type, name, tensor_header.dim_count)) {
+            // TODO Remove
+            fprintf(stderr, "Allocating less bytes for quantized tensor\n");
+
+            rwkv_ctx_size_add_tensor(ctx_size, 1, 0, target_type, tensor_header.width, tensor_header.height);
+        } else {
+            rwkv_ctx_size_add_tensor(ctx_size, 1, 0, tensor_header);
+        }
 
         if (ffn_key_size == 0 && name == "blocks.0.ffn.key.weight") {
             ffn_key_size = tensor_header.height;
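
Why the quantized branch can reserve less memory: quantized types pack fixed-size blocks of weights. A rough worked example, assuming ggml's Q5_1 block layout at the time of this commit (32 weights per 24-byte block: FP16 scale, FP16 min, 4 bytes of high bits, 16 bytes of 4-bit quants):

    #include <stdio.h>

    int main(void) {
        const long n = 2048L * 2048L;              // elements in a 2048x2048 matrix
        printf("FP16: %ld bytes\n", n * 2);        // 8388608 bytes (8 MiB)
        printf("Q5_1: %ld bytes\n", n * 24 / 32);  // 3145728 bytes (3 MiB)
        return 0;
    }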
@@ -1156,7 +1221,7 @@ bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & inst
        struct ggml_tensor * tensor;
 
        while ((size_t) ftell(file.file) < (size_t) file_stat.st_size) {
-           RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_ggml_tensor(file.file, ctx.ctx, name, tensor), "Failed to read model params");
+           RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_ggml_tensor(file.file, ctx.ctx, name, tensor, target_type), "Failed to read model params");
            parameters[std::move(name)] = tensor;
        }
    }
@@ -1258,12 +1323,25 @@ struct rwkv_context * rwkv_new_context_impl(std::shared_ptr<struct rwkv_instance
     return rwkv_ctx.release();
 }
 
-struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) {
+struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads, const char * target_format_name) {
     global_last_error = RWKV_ERROR_NONE;
 
+    enum ggml_type target_type = GGML_TYPE_COUNT;
+
+    if (strcmp(target_format_name, "") != 0) {
+        target_type = rwkv_type_to_ggml[rwkv_type_from_string(target_format_name)];
+
+        RWKV_ASSERT_NULL_MSG(
+            RWKV_ERROR_ARGS | RWKV_ERROR_DATA_TYPE,
+            ggml_is_quantized(target_type),
+            "Unsupported target format (%s)",
+            rwkv_type_to_string[rwkv_type_from_ggml[target_type]]
+        );
+    }
+
     std::shared_ptr<struct rwkv_instance> instance(new(std::nothrow) struct rwkv_instance());
     RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, instance, "Failed to allocate instance");
-    RWKV_ENSURE_OR_NULL(rwkv_instance_from_file(file_path, *instance.get()));
+    RWKV_ENSURE_OR_NULL(rwkv_instance_from_file(file_path, *instance.get(), target_type));
     return rwkv_new_context_impl(instance, n_threads);
 }

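Taken together, the new argument's contract is: an empty string leaves target_type at GGML_TYPE_COUNT, so every tensor loads unchanged (the previous behavior); a quantized format name such as "Q5_1" causes tensors eligible per rwkv_should_be_quantized to be converted while they are read; and a non-quantized name such as "FP16" fails the ggml_is_quantized check, so loading aborts with RWKV_ERROR_ARGS | RWKV_ERROR_DATA_TYPE.
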
@@ -1438,10 +1516,10 @@ void rwkv_free(struct rwkv_context * ctx) {
     std::unique_ptr<struct rwkv_context> rwkv_ctx(ctx);
 }
 
-bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const char * type_name) {
+bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const char * format_name) {
     global_last_error = RWKV_ERROR_NONE;
 
-    enum ggml_type out_type = rwkv_type_to_ggml[rwkv_type_from_string(type_name)];
+    enum ggml_type out_type = rwkv_type_to_ggml[rwkv_type_from_string(format_name)];
     RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ARGS | RWKV_ERROR_DATA_TYPE, ggml_is_quantized(out_type), "Unsupported output data type (%s)", rwkv_type_to_string[rwkv_type_from_ggml[out_type]]);
 
     RWKV_MSG("Loading model from '%s'\n", in_path);
@@ -1549,6 +1627,7 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const
         // Quantize only 2D tensors, except embedding and head matrices.
         // Embedding and head take not too much space, especially in bigger models;
         // but they significantly increase perplexity when quantized.
+        // TODO Use rwkv_should_be_quantized
         if ((header.data_type == TYPE_F32 || header.data_type == TYPE_F16) && header.dim_count == 2 && name != "emb.weight" && name != "head.weight") {
            RWKV_MSG("quantizing... ");
 
3 changes: 2 additions & 1 deletion rwkv.h
@@ -85,9 +85,10 @@ extern "C" {
 
     // Loads the model from a file and prepares it for inference.
     // Returns NULL on any error.
+    // TODO Split for compatibility and document
     // - model_file_path: path to model file in ggml format.
     // - n_threads: count of threads to use, must be positive.
-    RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads);
+    RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads, const char * target_format_name);
 
     // Creates a new context from an existing one.
     // This can allow you to run multiple rwkv_eval's in parallel, without having to load a single model multiple times.
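
A minimal caller-side sketch of the extended API (the model file name and target format are placeholders; an empty string preserves the previous behavior):

    #include <stdio.h>
    #include "rwkv.h"

    int main(void) {
        // Quantize eligible tensors of an FP16 model to Q5_1 while loading.
        struct rwkv_context * ctx = rwkv_init_from_file("model-FP16.bin", 2, "Q5_1");

        if (!ctx) {
            fprintf(stderr, "Load failed, error 0x%X\n", (unsigned int) rwkv_get_last_error(NULL));
            return 1;
        }

        rwkv_free(ctx);
        return 0;
    }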
4 changes: 3 additions & 1 deletion rwkv/rwkv_cpp_model.py
@@ -9,12 +9,14 @@ class RWKVModel:
     PyTorch wrapper around rwkv.cpp model.
     """
 
+    # TODO Document target format parameter
     def __init__(
             self,
             shared_library: rwkv_cpp_shared_library.RWKVSharedLibrary,
             model_path: str,
             thread_count: int = max(1, multiprocessing.cpu_count() // 2),
             gpu_layers_count: int = 0,
+            target_format_name: str = ''
     ):
         """
         Loads the model and prepares it for inference.
@@ -36,7 +38,7 @@ def __init__(
 
         self._library = shared_library
 
-        self._ctx = self._library.rwkv_init_from_file(model_path, thread_count)
+        self._ctx = self._library.rwkv_init_from_file(model_path, thread_count, target_format_name)
 
         if gpu_layers_count > 0:
             self._library.rwkv_gpu_offload_layers(self._ctx, gpu_layers_count)
7 changes: 4 additions & 3 deletions rwkv/rwkv_cpp_shared_library.py
@@ -37,7 +37,7 @@ def __init__(self, shared_library_path: str):
 
         self.library = ctypes.cdll.LoadLibrary(shared_library_path)
 
-        self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32]
+        self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32, ctypes.c_char_p]
         self.library.rwkv_init_from_file.restype = ctypes.c_void_p
 
         self.library.rwkv_gpu_offload_layers.argtypes = [ctypes.c_void_p, ctypes.c_uint32]
@@ -70,7 +70,8 @@ def __init__(self, shared_library_path: str):
         self.library.rwkv_get_system_info_string.argtypes = []
         self.library.rwkv_get_system_info_string.restype = ctypes.c_char_p
 
-    def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVContext:
+    # TODO Document target format parameter
+    def rwkv_init_from_file(self, model_file_path: str, thread_count: int, target_format_name: str = '') -> RWKVContext:
         """
         Loads the model from a file and prepares it for inference.
         Throws an exception in case of any error. Error messages would be printed to stderr.
@@ -83,7 +84,7 @@ def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVCo
             Count of threads to use, must be positive.
         """
 
-        ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count))
+        ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count), target_format_name.encode('utf-8'))
 
         assert ptr is not None, 'rwkv_init_from_file failed, check stderr'
 
4 changes: 2 additions & 2 deletions tests/test_context_cloning.c
@@ -1,11 +1,11 @@
-#include <rwkv.h>
+#include "rwkv.h"
 
 #include <stdlib.h>
 #include <stdio.h>
 #include <string.h>
 
 int main() {
-    struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2);
+    struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2, "");
 
     if (!ctx) {
         enum rwkv_error_flags error = rwkv_get_last_error(NULL);
2 changes: 1 addition & 1 deletion tests/test_tiny_rwkv.c
@@ -26,7 +26,7 @@
 void test_model(const char * model_path, const float * expected_logits, const float max_diff) {
     fprintf(stderr, "Testing %s\n", model_path);
 
-    struct rwkv_context * model = rwkv_init_from_file(model_path, N_THREADS);
+    struct rwkv_context * model = rwkv_init_from_file(model_path, N_THREADS, "");
     enum rwkv_error_flags error = rwkv_get_last_error(NULL);
     ASSERT(error == 0, "Unexpected error %d", error);
 #ifdef GGML_USE_CUBLAS
