diff --git a/Cargo.lock b/Cargo.lock index 1908f06..2b6932a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -24,7 +24,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" dependencies = [ "crypto-common", - "generic-array", + "generic-array 0.14.7", ] [[package]] @@ -87,18 +87,17 @@ dependencies = [ [[package]] name = "ai00-core" -version = "0.5.2" +version = "0.5.3" dependencies = [ "anyhow", - "bit-set", - "bnf_sampler", "bytemuck", "cbor4ii", "derivative", "fastrand", "flume", "half", - "itertools 0.13.0", + "itertools", + "kbnf", "log", "memmap2", "qp-trie", @@ -113,7 +112,7 @@ dependencies = [ [[package]] name = "ai00-server" -version = "0.5.2" +version = "0.5.3" dependencies = [ "ai00-core", "anyhow", @@ -121,7 +120,7 @@ dependencies = [ "derivative", "flume", "futures-util", - "itertools 0.13.0", + "itertools", "jsonwebtoken", "log", "memmap2", @@ -385,39 +384,7 @@ version = "0.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" dependencies = [ - "generic-array", -] - -[[package]] -name = "bnf" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14c09ea5795b3dd735ff47c4b8adf64c46e3ce056fa3c4880b865a352e4c40a2" -dependencies = [ - "getrandom", - "nom", - "rand", - "serde", - "serde_json", -] - -[[package]] -name = "bnf_sampler" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f672b04599d545028652564805fdfbf6c3a035b971ed8006b68cb9c647ef67e7" -dependencies = [ - "anyhow", - "bit-set", - "bnf", - "itertools 0.12.1", - "lazy_static", - "memchr", - "mimalloc", - "nohash-hasher", - "qp-trie", - "regex", - "rustc-hash", + "generic-array 0.14.7", ] [[package]] @@ -707,7 +674,7 @@ checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" [[package]] name = "converter" -version = "0.5.2" +version = "0.5.3" dependencies = [ "anyhow", "clap", @@ -795,7 +762,7 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" dependencies = [ - "generic-array", + "generic-array 0.14.7", "rand_core", "typenum", ] @@ -970,6 +937,12 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" +[[package]] +name = "fixedbitset-stack" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d498da3487b4ea426e370276db9e93c29624b667855b415ccd01da983dd1237" + [[package]] name = "flate2" version = "1.0.30" @@ -1154,6 +1127,15 @@ dependencies = [ "version_check", ] +[[package]] +name = "generic-array" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe739944a5406424e080edccb6add95685130b9f160d5407c639c7df0c5836b0" +dependencies = [ + "typenum", +] + [[package]] name = "getrandom" version = "0.2.14" @@ -1395,6 +1377,15 @@ dependencies = [ "digest", ] +[[package]] +name = "html-escape" +version = "0.2.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d1ad449764d627e22bfd7cd5e8868264fc9236e07c752972b4080cd351cb476" +dependencies = [ + "utf8-width", +] + [[package]] name = "http" version = "1.1.0" @@ -1567,7 +1558,7 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5" dependencies = [ - "generic-array", + "generic-array 0.14.7", ] [[package]] @@ -1600,15 +1591,6 @@ version = "1.70.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8478577c03552c21db0e2724ffb8986a5ce7af88107e6be5d2ee6e158c12800" -[[package]] -name = "itertools" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" -dependencies = [ - "either", -] - [[package]] name = "itertools" version = "0.13.0" @@ -1624,6 +1606,18 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +[[package]] +name = "jaggedarray" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d18d4d488e472f5e805edd8d7e710c255ff19511e3f5fbb0932932bee0e804d5" +dependencies = [ + "generic-array 1.0.0", + "num", + "tinyvec", + "typenum", +] + [[package]] name = "jni-sys" version = "0.3.0" @@ -1663,6 +1657,55 @@ dependencies = [ "simple_asn1", ] +[[package]] +name = "kbnf" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d1722638274535080c14a131457152942c92534ed8989980d53aab4cf9da8a" +dependencies = [ + "ahash 0.8.11", + "displaydoc", + "fixedbitset-stack", + "jaggedarray", + "kbnf-regex-automata", + "kbnf-syntax", + "nom", + "nonmax", + "num", + "serde", + "string-interner", + "strum", + "thiserror", + "tinyvec", +] + +[[package]] +name = "kbnf-regex-automata" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "499fe4f128763db90782429ea71d01a5503d177d53118589a36e41611286d8a0" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "kbnf-syntax" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef4c575a961efe76523068d29c9362744f6185e6a8d0de7d4563079fc7e0365b" +dependencies = [ + "kbnf-regex-automata", + "nom", + "parse-hyperlinks", + "regex-syntax", + "rustc-hash", + "serde", + "string-interner", + "thiserror", +] + [[package]] name = "khronos-egl" version = "6.0.0" @@ -1712,16 +1755,6 @@ dependencies = [ "windows-targets 0.52.5", ] -[[package]] -name = "libmimalloc-sys" -version = "0.1.37" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81eb4061c0582dedea1cbc7aff2240300dd6982e0239d1c99e65c1dbf4a30ba7" -dependencies = [ - "cc", - "libc", -] - [[package]] name = "linux-raw-sys" version = "0.4.13" @@ -1789,15 +1822,6 @@ dependencies = [ "paste", ] -[[package]] -name = "mimalloc" -version = "0.1.41" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f41a2280ded0da56c8cf898babb86e8f10651a34adcfff190ae9a1159c6908d" -dependencies = [ - "libmimalloc-sys", -] - [[package]] name = "mime" version = "0.3.17" @@ -1941,12 +1965,6 @@ dependencies = [ "libc", ] -[[package]] -name = "nohash-hasher" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2bf50223579dc7cdcfb3bfcacf7069ff68243f8c363f62ffa99cf000a6b9c451" - [[package]] name = "nom" version = "7.1.3" @@ -1957,17 +1975,45 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nonmax" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "610a5acd306ec67f907abe5567859a3c693fb9886eb1f012ab8f2a47bef3db51" + +[[package]] +name = "num" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + [[package]] name = "num-bigint" -version = "0.4.4" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0" +checksum = "c165a9ab64cf766f73521c0dd2cfdff64f488b8f0b3e621face3462d3db536d7" dependencies = [ - "autocfg", "num-integer", "num-traits", ] +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + [[package]] name = "num-conv" version = "0.1.0" @@ -1983,6 +2029,28 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-iter" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -2117,6 +2185,18 @@ dependencies = [ "windows-targets 0.52.5", ] +[[package]] +name = "parse-hyperlinks" +version = "0.23.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0181d37c4d5ae35cc8be7cf823c1a933005661da6a08bcb2855aa392c9a54b8e" +dependencies = [ + "html-escape", + "nom", + "percent-encoding", + "thiserror", +] + [[package]] name = "password-hash" version = "0.4.2" @@ -2771,6 +2851,12 @@ dependencies = [ "untrusted 0.9.0", ] +[[package]] +name = "rustversion" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" + [[package]] name = "ryu" version = "1.0.17" @@ -3107,9 +3193,9 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.200" +version = "1.0.203" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddc6f9cc94d67c0e21aaf7eda3a010fd3af78ebf6e096aa6e2e13c79749cce4f" +checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094" dependencies = [ "serde_derive", ] @@ -3137,9 +3223,9 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.200" +version = "1.0.203" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "856f046b9400cee3c8c94ed572ecdb752444c24528c035cd35882aad6f492bcb" +checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" dependencies = [ "proc-macro2", "quote", @@ -3322,12 +3408,45 @@ version = "3.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ceb97b7225c713c2fd4db0153cb6b3cab244eb37900c3f634ed4d43310d8c34" +[[package]] +name = "string-interner" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c6a0d765f5807e98a091107bae0a56ea3799f66a5de47b2c84c94a39c09974e" +dependencies = [ + "cfg-if", + "hashbrown 0.14.5", + "serde", +] + [[package]] name = "strsim" version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "strum" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d8cec3501a5194c432b2b7976db6b7d10ec95c253208b45f83f7136aa985e29" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.60", +] + [[package]] name = "subtle" version = "2.5.0" @@ -3441,18 +3560,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.59" +version = "1.0.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0126ad08bff79f29fc3ae6a55cc72352056dfff61e3ff8bb7129476d44b23aa" +checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.59" +version = "1.0.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1cd413b5d558b4c5bf3680e324a6fa5014e7b7c067a51e69dbdf47eb7148b66" +checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" dependencies = [ "proc-macro2", "quote", @@ -3848,6 +3967,12 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" +[[package]] +name = "utf8-width" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86bd8d4e895da8537e5315b8254664e6b769c4ff3db18321b297a1e7004392e3" + [[package]] name = "utf8parse" version = "0.2.1" @@ -3988,9 +4113,9 @@ dependencies = [ [[package]] name = "web-rwkv" -version = "0.8.11" +version = "0.8.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8974a3084c612cc10e0d7e1b7f947b9951360f322977f6f9d344e7271e99a97" +checksum = "cf8c1d41f4ca9a792c17e1744a38646282fce55de19936fbaf7017b88691794e" dependencies = [ "ahash 0.8.11", "anyhow", @@ -4002,7 +4127,7 @@ dependencies = [ "gpp", "half", "instant", - "itertools 0.13.0", + "itertools", "log", "regex", "rustc-hash", @@ -4065,9 +4190,9 @@ dependencies = [ [[package]] name = "wgpu" -version = "0.20.0" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32ff1bfee408e1028e2e3acbf6d32d98b08a5a059ccbf5f33305534453ba5d3e" +checksum = "90e37c7b9921b75dfd26dd973fdcbce36f13dfa6e2dc82aece584e0ed48c355c" dependencies = [ "arrayvec", "cfg-if", @@ -4091,9 +4216,9 @@ dependencies = [ [[package]] name = "wgpu-core" -version = "0.20.0" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac6a86eaa5e763e59c73cf9e97d55fffd4dfda69fd8bda19589fcf851ddfef1f" +checksum = "d59e0d5fc509601c69e4e1fa06c1eb3c4c9f12956a5e30c79b61ef1c1be7daf0" dependencies = [ "arrayvec", "bit-vec", @@ -4118,9 +4243,9 @@ dependencies = [ [[package]] name = "wgpu-hal" -version = "0.20.0" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d71c8ae05170583049b65ee562fd839fdc0b3e9ddb84f4e40c9d5f8ea0d4c8c" +checksum = "6aa24c3889f885a3fb9133b454c8418bfcfaadcfe4ed3be96ac80e76703b863b" dependencies = [ "android_system_properties", "arrayvec", diff --git a/Cargo.toml b/Cargo.toml index 68a645a..b2479de 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ keywords = ["LLM", "deep-learning", "model", "rwkv"] license = "MIT OR Apache-2.0" repository = "https://github.com/cgisky1980/ai00_rwkv_server" rust-version = "1.76" -version = "0.5.2" +version = "0.5.3" [workspace.dependencies] anyhow = "1" @@ -34,7 +34,7 @@ path = "crates/ai00-core" # path = "../web-rwkv" default-features = false features = ["native"] -version = "0.8.11" +version = "0.8.12" [profile.release] lto = false diff --git a/crates/ai00-core/Cargo.toml b/crates/ai00-core/Cargo.toml index d018495..f3a3c11 100644 --- a/crates/ai00-core/Cargo.toml +++ b/crates/ai00-core/Cargo.toml @@ -12,12 +12,11 @@ rust-version.workspace = true version.workspace = true [dependencies] -bit-set = "0.5.3" -bnf_sampler = "0.3.7" bytemuck = "1" cbor4ii = { version = "0.3.2", features = ["serde1"] } fastrand = "2" half = "2.4" +kbnf = "0.1.2" qp-trie = "0.8" rustc-hash = "1.1.0" uuid = { version = "1.8.0", features = ["serde", "v4"] } diff --git a/crates/ai00-core/src/lib.rs b/crates/ai00-core/src/lib.rs index 5472562..ec4dc5e 100644 --- a/crates/ai00-core/src/lib.rs +++ b/crates/ai00-core/src/lib.rs @@ -5,7 +5,6 @@ use std::{ }; use anyhow::{anyhow, bail, Result}; -use bnf_sampler::{utils::U8ArrayWrapper, vocabulary::Vocabulary}; use derivative::Derivative; use flume::{Receiver, Sender}; use half::f16; @@ -267,24 +266,6 @@ async fn load_tokenizer(path: impl AsRef) -> Result { Ok(Tokenizer::new(&contents)?) } -fn load_vocab(tokenizer: &Tokenizer) -> Vocabulary { - let vocab = tokenizer.bytes_to_token_index(); - let token_to_id = vocab - .iter() - .map(|(k, v)| (U8ArrayWrapper(k.clone().into_boxed_slice()), *v as u32)) - .collect(); - let id_to_token = vocab.iter().map(|(k, v)| (*v as u32, k.clone())).collect(); - let id_to_token_string = vocab - .iter() - .map(|(k, v)| (*v as u32, String::from_utf8_lossy(k).to_string())) - .collect(); - Vocabulary { - token_to_id, - id_to_token, - id_to_token_string, - } -} - async fn load_init_state( context: &Context, info: &ModelInfo, @@ -317,7 +298,6 @@ async fn load_runtime( } = reload.clone(); let tokenizer = load_tokenizer(tokenizer_path).await?; - let vocab = load_vocab(&tokenizer); let file = File::open(model_path).await?; let data = unsafe { Mmap::map(&file) }?; @@ -406,7 +386,7 @@ async fn load_runtime( ($version, $precision) => { let model = Build::<$model>::build(builder).await?; let builder = <$runtime>::new(model, max_batch); - Runtime::new(context, builder, reload, states, tokenizer, vocab).await + Runtime::new(context, builder, reload, states, tokenizer).await } )+ } @@ -441,7 +421,7 @@ async fn load_runtime( let seed: Seed<_, $model> = Seed::new(&context); let model = seed.deserialize(&mut deserializer)?; let builder = <$runtime>::new(model, reload.max_batch); - Runtime::new(context, builder, reload, states, tokenizer, vocab).await + Runtime::new(context, builder, reload, states, tokenizer).await } )+ } diff --git a/crates/ai00-core/src/run.rs b/crates/ai00-core/src/run.rs index 9c06f89..aee164b 100644 --- a/crates/ai00-core/src/run.rs +++ b/crates/ai00-core/src/run.rs @@ -8,7 +8,6 @@ use std::{ }; use anyhow::Result; -use bnf_sampler::{grammar::Grammar, vocabulary::Vocabulary}; use derivative::Derivative; use flume::{Receiver, Sender}; use itertools::Itertools; @@ -36,8 +35,6 @@ use crate::{ const END_OF_LINE_TOKEN: u16 = 261; const MIN_PROMPT_CACHE_TOKENS: usize = 32; const MAX_CACHE_ITEMS: usize = 256; -const SAMPLER_ARENA_CAPACITY: usize = 1048576; -const GRAMMAR_ARENA_CAPACITY: usize = 1024; #[derive(Debug)] pub enum SlotResult { @@ -401,7 +398,6 @@ pub struct Runtime { model: Arc, runtime: JobRuntime, tokenizer: Arc, - vocab: Arc, slots: Mutex>, caches: Mutex, } @@ -413,7 +409,6 @@ impl Runtime { reload: ReloadRequest, states: Vec, tokenizer: Tokenizer, - vocab: Vocabulary, ) -> Self where J: Job, @@ -441,7 +436,7 @@ impl Runtime { let info = builder.info(); let state = Arc::new(builder.state()); let model = Arc::new(Model(builder.model())); - let runtime = JobRuntime::new(builder).await; + let runtime = JobRuntime::::new(builder).await; Self { context, @@ -451,7 +446,6 @@ impl Runtime { model, runtime, tokenizer: Arc::new(tokenizer), - vocab: Arc::new(vocab), slots: Mutex::new(slots), caches: Mutex::new(caches), } @@ -571,16 +565,7 @@ impl Runtime { /// Compile and cache the given schema into a BNF sampler. async fn compile_bnf_schema(&self, schema: String) -> Result { - let grammar = Grammar::new(&schema, self.vocab.clone(), GRAMMAR_ARENA_CAPACITY)?; - let start_nonterminal = self.reload.bnf.start_nonterminal.clone(); - let sampler = bnf_sampler::sampler::Sampler::new( - grammar, - start_nonterminal, - self.vocab.clone(), - SAMPLER_ARENA_CAPACITY, - self.reload.bnf.enable_bytes_cache, - )?; - Ok(BnfSampler::new(sampler)) + BnfSampler::new(&self.tokenizer, &schema) } /// Queue an inference task. @@ -950,10 +935,10 @@ impl Runtime { }; // update the transformer (BNF) state - let mut exhausted = false; + let mut halt = false; for transformer in context.transformers.iter() { let mut transformer = transformer.write().await; - exhausted |= transformer.update(token); + halt |= transformer.update(token); } // here we detect if there is a stop word in our buffer @@ -1045,7 +1030,7 @@ impl Runtime { } let _ = context.sender.send(Token::Choose(perplexities)); done = true; - } else if exhausted || stop_matched { + } else if halt || stop_matched { let output = String::from_utf8_lossy(head); let _ = context.sender.send(Token::Content(output.into())); stop(FinishReason::Stop); diff --git a/crates/ai00-core/src/sampler/bnf.rs b/crates/ai00-core/src/sampler/bnf.rs index df490aa..bb048db 100644 --- a/crates/ai00-core/src/sampler/bnf.rs +++ b/crates/ai00-core/src/sampler/bnf.rs @@ -1,65 +1,46 @@ -use bit_set::BitSet; -use bnf_sampler::sampler::{AcceptTokenResult, PossibleTokensResult, Sampler}; +use anyhow::Result; +use kbnf::{ + engine_like::AcceptTokenError, AcceptTokenResult, Engine, EngineLike, Token, Vocabulary, +}; +use web_rwkv::tokenizer::Tokenizer; use super::Transformer; #[derive(Debug)] -pub struct BnfSampler { - sampler: Sampler, - current_token_ids: BitSet, -} +pub struct BnfSampler(Engine); impl BnfSampler { - pub fn new(mut sampler: Sampler) -> Self { - let current_token_ids = match sampler.all_possible_next_tokens(None) { - Ok(PossibleTokensResult::Continue(tokens)) => tokens.clone(), - _ => BitSet::new(), - }; - Self { - sampler, - current_token_ids, - } - } - - #[inline] - pub fn current_token_ids(&self) -> &BitSet { - &self.current_token_ids - } -} - -impl std::ops::Deref for BnfSampler { - type Target = Sampler; - - fn deref(&self) -> &Self::Target { - &self.sampler - } -} - -impl std::ops::DerefMut for BnfSampler { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.sampler + pub fn new(tokenizer: &Tokenizer, schema: &str) -> Result { + let tokens = tokenizer + .token_index_to_bytes() + .iter() + .enumerate() + .map(|(k, v)| (k as u32, Token(v.clone().into_boxed_slice()))) + .collect(); + let strings = tokenizer + .token_index_to_bytes() + .iter() + .enumerate() + .map(|(k, v)| (k as u32, String::from_utf8_lossy(v).to_string())) + .collect(); + let vocab = Vocabulary::new(tokens, strings)?; + let engine = Engine::new(schema, vocab)?; + Ok(Self(engine)) } } impl Transformer for BnfSampler { fn transform(&self, output: &mut [f32]) { - output - .iter_mut() - .enumerate() - .filter(|&(token, _)| !self.current_token_ids().contains(token)) - .for_each(|(_, logits)| *logits = f32::MIN) + self.0.mask_logits(output).expect("bnf transform error") } fn update(&mut self, token: u16) -> bool { - let token = Some(token as u32); - let accept = self.accept_a_token(token).expect("invalid input token"); - self.current_token_ids = match self.sampler.all_possible_next_tokens(None) { - Ok(PossibleTokensResult::Continue(tokens)) => tokens.clone(), - _ => BitSet::new(), + let halt = match self.0.try_accept_new_token(token as u32) { + Ok(AcceptTokenResult::Finished) | Err(AcceptTokenError::Finished) => true, + Ok(AcceptTokenResult::Ongoing) => false, + Err(_) => self.0.is_finished(), }; - match accept { - AcceptTokenResult::Continue => false, - AcceptTokenResult::End | AcceptTokenResult::Failed => true, - } + self.0.compute_allowed_token_ids(); + halt } }