diff --git a/rwkv.cpp b/rwkv.cpp
index b99ac6a..cbcea87 100644
--- a/rwkv.cpp
+++ b/rwkv.cpp
@@ -363,27 +363,85 @@ bool rwkv_fread_tensor(FILE * file, struct rwkv_tensor & output, void * buffer =
     return true;
 }

-bool rwkv_fread_ggml_tensor_data(FILE * file, const struct rwkv_tensor_header & header, struct ggml_context * ctx, std::string & name, struct ggml_tensor *& tensor) {
+bool rwkv_should_be_quantized(const ggml_type source_type, const ggml_type target_type, const std::string & name, const uint32_t dim_count) {
+    // Quantize only 2D tensors, except embedding and head matrices.
+    // Embedding and head take not too much space, especially in bigger models;
+    // but they significantly increase perplexity when quantized.
+    return target_type != GGML_TYPE_COUNT &&
+        target_type != source_type &&
+        (source_type == GGML_TYPE_F32 || source_type == GGML_TYPE_F16) &&
+        dim_count == 2 &&
+        name != "emb.weight" &&
+        name != "head.weight";
+}
+
+bool rwkv_fread_ggml_tensor_data(
+    FILE * file,
+    const struct rwkv_tensor_header & header,
+    struct ggml_context * ctx,
+    std::string & name,
+    struct ggml_tensor *& tensor,
+    const ggml_type target_type
+) {
     RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_string(file, header.key_length, name), "Failed to read tensor name");

     enum ggml_type ggml_type = rwkv_type_to_ggml[header.data_type];
     RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_UNSUPPORTED, ggml_type != GGML_TYPE_UNKNOWN, "Unsupported tensor data type %s from %s", rwkv_type_to_string[header.data_type], name.c_str());

-    tensor = header.dim_count == 1
-        ? ggml_new_tensor_1d(ctx, ggml_type, header.width)
-        : ggml_new_tensor_2d(ctx, ggml_type, header.width, header.height);
+    if (rwkv_should_be_quantized(ggml_type, target_type, name, header.dim_count)) {
+        // TODO Remove
+        fprintf(stderr, "Quantizing %s on the fly\n", name.c_str());
+
+        size_t buffer_size_bytes = header.dim_count == 1
+            ? rwkv_tensor_size(ggml_type, header.width)
+            : rwkv_tensor_size(ggml_type, header.width, header.height);
+
+        tensor = header.dim_count == 1
+            ? ggml_new_tensor_1d(ctx, target_type, header.width)
+            : ggml_new_tensor_2d(ctx, target_type, header.width, header.height);
+
+        RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, tensor, "Failed to allocate tensor");
+        ggml_set_name(tensor, name.c_str());
+
+        // TODO Make safer (free on return)
+        char * buffer = (char *) malloc(buffer_size_bytes);
+
+        RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, buffer_size_bytes, buffer), "Failed to read tensor data from %s", name.c_str());
+
+        // Quantization works only with FP32 values
+        if (header.data_type == TYPE_F16) {
+            float * float_buffer = (float *) malloc(buffer_size_bytes * 2);
+
+            ggml_fp16_to_fp32_row((const ggml_fp16_t *) buffer, (float *) float_buffer, ggml_nelements(tensor));
+
+            free(buffer);
+
+            buffer = (char *) float_buffer;
+        }
+
+        int64_t histogram[16] {};
+
+        ggml_quantize_chunk(target_type, (const float *) buffer, tensor->data, 0, ggml_nelements(tensor), histogram);
+
+        free(buffer);
+    } else {
+        tensor = header.dim_count == 1
+            ? ggml_new_tensor_1d(ctx, ggml_type, header.width)
+            : ggml_new_tensor_2d(ctx, ggml_type, header.width, header.height);

-    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, tensor, "Failed to allocate tensor");
-    ggml_set_name(tensor, name.c_str());
+        RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, tensor, "Failed to allocate tensor");
+        ggml_set_name(tensor, name.c_str());
+
+        RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, ggml_nbytes(tensor), tensor->data), "Failed to read tensor data from %s", name.c_str());
+    }

-    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, ggml_nbytes(tensor), tensor->data), "Failed to read tensor data from %s", name.c_str());
     return true;
 }

-bool rwkv_fread_ggml_tensor(FILE * file, struct ggml_context * ctx, std::string & name, struct ggml_tensor *& tensor) {
+bool rwkv_fread_ggml_tensor(FILE * file, struct ggml_context * ctx, std::string & name, struct ggml_tensor *& tensor, const ggml_type target_type) {
     struct rwkv_tensor_header header;
     RWKV_ENSURE_OR_FALSE_MSG(rwkv_fread_tensor_header(file, header), "Invalid tensor header");
-    return rwkv_fread_ggml_tensor_data(file, header, ctx, name, tensor);
+    return rwkv_fread_ggml_tensor_data(file, header, ctx, name, tensor, target_type);
 }

 bool rwkv_fwrite_tensor(FILE * file, const struct rwkv_tensor & tensor) {
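Note on the `// TODO Make safer (free on return)` in the hunk above: every RWKV_ASSERT_FALSE_MSG that fires after the malloc returns early and leaks the staging buffer. A minimal sketch of the cleanup the TODO asks for, assuming C++11 and the existing rwkv.cpp macros and helpers; the `malloc_ptr` alias is hypothetical and not part of this patch:

    #include <cstdlib>
    #include <memory>

    // std::unique_ptr with free() as the deleter releases the staging buffer on
    // every exit path, including the early returns inside the RWKV_ASSERT_* macros.
    using malloc_ptr = std::unique_ptr<char, decltype(&std::free)>;

    malloc_ptr buffer((char *) std::malloc(buffer_size_bytes), &std::free);
    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, buffer.get() != NULL, "Failed to allocate buffer");
    RWKV_ASSERT_FALSE_MSG(
        RWKV_ERROR_FILE_READ,
        rwkv_fread_data(file, buffer_size_bytes, buffer.get()),
        "Failed to read tensor data from %s", name.c_str()
    );
    // The F16 branch would then call buffer.reset((char *) float_buffer)
    // instead of the manual free()/reassign pair.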
@@ -1115,7 +1173,7 @@ struct rwkv_file {
     }
 };

-bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & instance) {
+bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & instance, const ggml_type target_type) {
     struct stat file_stat;
     struct rwkv_model model;
     struct rwkv_ggml_context ctx;
@@ -1140,7 +1198,14 @@ bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & inst
             RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_string(file.file, tensor_header.key_length, name), "Failed to read tensor name");
             RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file.file, rwkv_tensor_size(tensor_header), SEEK_CUR) == 0, "Failed to read tensor data");

-            rwkv_ctx_size_add_tensor(ctx_size, 1, 0, tensor_header);
+            if (rwkv_should_be_quantized(rwkv_type_to_ggml[tensor_header.data_type], target_type, name, tensor_header.dim_count)) {
+                // TODO Remove
+                fprintf(stderr, "Allocating less bytes for quantized tensor\n");
+
+                rwkv_ctx_size_add_tensor(ctx_size, 1, 0, target_type, tensor_header.width, tensor_header.height);
+            } else {
+                rwkv_ctx_size_add_tensor(ctx_size, 1, 0, tensor_header);
+            }

             if (ffn_key_size == 0 && name == "blocks.0.ffn.key.weight") {
                 ffn_key_size = tensor_header.height;
@@ -1156,7 +1221,7 @@ bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & inst
         struct ggml_tensor * tensor;

         while ((size_t) ftell(file.file) < (size_t) file_stat.st_size) {
-            RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_ggml_tensor(file.file, ctx.ctx, name, tensor), "Failed to read model params");
+            RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_ggml_tensor(file.file, ctx.ctx, name, tensor, target_type), "Failed to read model params");
             parameters[std::move(name)] = tensor;
         }
     }
@@ -1258,12 +1323,25 @@ struct rwkv_context * rwkv_new_context_impl(std::shared_ptr<st
     std::shared_ptr<struct rwkv_instance> instance(new(std::nothrow) struct rwkv_instance());
     RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, instance, "Failed to allocate instance");
-    RWKV_ENSURE_OR_NULL(rwkv_instance_from_file(file_path, *instance.get()));
+    RWKV_ENSURE_OR_NULL(rwkv_instance_from_file(file_path, *instance.get(), target_type));
     return rwkv_new_context_impl(instance, n_threads);
 }
@@ -1438,10 +1516,10 @@ void rwkv_free(struct rwkv_context * ctx) {
     std::unique_ptr<struct rwkv_context> rwkv_ctx(ctx);
 }

-bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const char * type_name) {
+bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const char * format_name) {
     global_last_error = RWKV_ERROR_NONE;

-    enum ggml_type out_type = rwkv_type_to_ggml[rwkv_type_from_string(type_name)];
+    enum ggml_type out_type = rwkv_type_to_ggml[rwkv_type_from_string(format_name)];
     RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ARGS | RWKV_ERROR_DATA_TYPE, ggml_is_quantized(out_type), "Unsupported output data type (%s)", rwkv_type_to_string[rwkv_type_from_ggml[out_type]]);

     RWKV_MSG("Loading model from '%s'\n", in_path);
@@ -1549,6 +1627,7 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const

         // Quantize only 2D tensors, except embedding and head matrices.
         // Embedding and head take not too much space, especially in bigger models;
         // but they significantly increase perplexity when quantized.
+        // TODO Use rwkv_should_be_quantized
         if ((header.data_type == TYPE_F32 || header.data_type == TYPE_F16) && header.dim_count == 2 && name != "emb.weight" && name != "head.weight") {
             RWKV_MSG("quantizing... ");
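On the `// TODO Use rwkv_should_be_quantized` in the last hunk: the condition in rwkv_quantize_model_file duplicates the new helper. Since out_type is always a quantized type distinct from F32/F16 at that point, the helper's two extra clauses (target != GGML_TYPE_COUNT, target != source) always hold, so the duplicated check could collapse to a single call without changing behavior. A sketch, not part of this patch:

    if (rwkv_should_be_quantized(rwkv_type_to_ggml[header.data_type], out_type, name, header.dim_count)) {
        RWKV_MSG("quantizing... ");
        // existing quantization path, unchanged
    } else {
        // existing pass-through path, unchanged
    }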
diff --git a/rwkv.h b/rwkv.h
index 8327425..dc4c5e2 100644
--- a/rwkv.h
+++ b/rwkv.h
@@ -85,9 +85,10 @@ extern "C" {

     // Loads the model from a file and prepares it for inference.
     // Returns NULL on any error.
+    // TODO Split for compatibility and document
     // - model_file_path: path to model file in ggml format.
     // - n_threads: count of threads to use, must be positive.
-    RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads);
+    RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads, const char * target_format_name);

     // Creates a new context from an existing one.
     // This can allow you to run multiple rwkv_eval's in parallel, without having to load a single model multiple times.
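For C callers, the new third argument selects the on-the-fly target format. A hedged usage example: the updated tests pass "" (which evidently keeps tensors in their stored format), and the format names are assumed to be the same strings rwkv_quantize_model_file already accepts (e.g. "Q5_1"); the header does not document this yet, per the TODO above, and the model file name here is illustrative only:

    #include "rwkv.h"

    #include <stdio.h>

    int main(void) {
        // Load the model as stored on disk (no conversion).
        struct rwkv_context * ctx = rwkv_init_from_file("model-FP16.bin", 4, "");

        // Assumed: quantize eligible 2D tensors to Q5_1 while loading.
        struct rwkv_context * q_ctx = rwkv_init_from_file("model-FP16.bin", 4, "Q5_1");

        if (!ctx || !q_ctx) {
            fprintf(stderr, "Load failed, error flags: %d\n", rwkv_get_last_error(NULL));
            return 1;
        }

        rwkv_free(ctx);
        rwkv_free(q_ctx);
        return 0;
    }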
diff --git a/rwkv/rwkv_cpp_model.py b/rwkv/rwkv_cpp_model.py
index b38c7ce..79262ff 100644
--- a/rwkv/rwkv_cpp_model.py
+++ b/rwkv/rwkv_cpp_model.py
@@ -9,12 +9,14 @@ class RWKVModel:
     PyTorch wrapper around rwkv.cpp model.
     """

+    # TODO Document target format parameter
     def __init__(
             self,
             shared_library: rwkv_cpp_shared_library.RWKVSharedLibrary,
             model_path: str,
             thread_count: int = max(1, multiprocessing.cpu_count() // 2),
             gpu_layers_count: int = 0,
+            target_format_name: str = ''
     ):
         """
         Loads the model and prepares it for inference.
@@ -36,7 +38,7 @@ def __init__(

         self._library = shared_library

-        self._ctx = self._library.rwkv_init_from_file(model_path, thread_count)
+        self._ctx = self._library.rwkv_init_from_file(model_path, thread_count, target_format_name)

         if gpu_layers_count > 0:
             self._library.rwkv_gpu_offload_layers(self._ctx, gpu_layers_count)
diff --git a/rwkv/rwkv_cpp_shared_library.py b/rwkv/rwkv_cpp_shared_library.py
index a38cbbb..a6b950d 100644
--- a/rwkv/rwkv_cpp_shared_library.py
+++ b/rwkv/rwkv_cpp_shared_library.py
@@ -37,7 +37,7 @@ def __init__(self, shared_library_path: str):

         self.library = ctypes.cdll.LoadLibrary(shared_library_path)

-        self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32]
+        self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32, ctypes.c_char_p]
         self.library.rwkv_init_from_file.restype = ctypes.c_void_p

         self.library.rwkv_gpu_offload_layers.argtypes = [ctypes.c_void_p, ctypes.c_uint32]
@@ -70,7 +70,8 @@ def __init__(self, shared_library_path: str):
         self.library.rwkv_get_system_info_string.argtypes = []
         self.library.rwkv_get_system_info_string.restype = ctypes.c_char_p

-    def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVContext:
+    # TODO Document target format parameter
+    def rwkv_init_from_file(self, model_file_path: str, thread_count: int, target_format_name: str = '') -> RWKVContext:
         """
         Loads the model from a file and prepares it for inference.
         Throws an exception in case of any error. Error messages would be printed to stderr.
@@ -83,7 +84,7 @@ def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVCo
             Count of threads to use, must be positive.
         """

-        ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count))
+        ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count), target_format_name.encode('utf-8'))

         assert ptr is not None, 'rwkv_init_from_file failed, check stderr'
diff --git a/tests/test_context_cloning.c b/tests/test_context_cloning.c
index eb0f7c4..3ff1255 100644
--- a/tests/test_context_cloning.c
+++ b/tests/test_context_cloning.c
@@ -1,11 +1,11 @@
-#include <rwkv.h>
+#include "rwkv.h"

 #include <stdio.h>
 #include <stdlib.h>
 #include <string.h>

 int main() {
-    struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2);
+    struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2, "");

     if (!ctx) {
         enum rwkv_error_flags error = rwkv_get_last_error(NULL);
diff --git a/tests/test_tiny_rwkv.c b/tests/test_tiny_rwkv.c
index adb0de7..8a68787 100644
--- a/tests/test_tiny_rwkv.c
+++ b/tests/test_tiny_rwkv.c
@@ -26,7 +26,7 @@ void test_model(const char * model_path, const float * expected_logits, const float max_diff) {
     fprintf(stderr, "Testing %s\n", model_path);

-    struct rwkv_context * model = rwkv_init_from_file(model_path, N_THREADS);
+    struct rwkv_context * model = rwkv_init_from_file(model_path, N_THREADS, "");

     enum rwkv_error_flags error = rwkv_get_last_error(NULL);
     ASSERT(error == 0, "Unexpected error %d", error);
 #ifdef GGML_USE_CUBLAS