From 26fe162ab5e64faca39cf6ccaa6a8513ce1240f0 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Wed, 14 Feb 2024 15:11:38 +0100 Subject: [PATCH] Custom tokenizer for rwkv. (#1711) * Custom tokenizer for rwkv. * Custom tokenizer. * Getting the tokenizer to work. --- candle-examples/examples/rwkv/main.rs | 51 ++++--------- candle-transformers/src/models/rwkv_v5.rs | 92 +++++++++++++++++++++++ 2 files changed, 105 insertions(+), 38 deletions(-) diff --git a/candle-examples/examples/rwkv/main.rs b/candle-examples/examples/rwkv/main.rs index 7fd1f76c..0ccf2ec3 100644 --- a/candle-examples/examples/rwkv/main.rs +++ b/candle-examples/examples/rwkv/main.rs @@ -4,23 +4,21 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; -use anyhow::{Error as E, Result}; +use anyhow::Result; use clap::{Parser, ValueEnum}; -use candle_transformers::models::rwkv_v5::{Config, Model, State}; +use candle_transformers::models::rwkv_v5::{Config, Model, State, Tokenizer}; use candle::{DType, Device, Tensor}; -use candle_examples::token_output_stream::TokenOutputStream; use candle_nn::VarBuilder; use candle_transformers::generation::LogitsProcessor; use hf_hub::{api::sync::Api, Repo, RepoType}; -use tokenizers::Tokenizer; struct TextGeneration { model: Model, config: Config, device: Device, - tokenizer: TokenOutputStream, + tokenizer: Tokenizer, logits_processor: LogitsProcessor, repeat_penalty: f32, repeat_last_n: usize, @@ -43,7 +41,7 @@ impl TextGeneration { Self { model, config, - tokenizer: TokenOutputStream::new(tokenizer), + tokenizer, logits_processor, repeat_penalty, repeat_last_n, @@ -53,28 +51,15 @@ impl TextGeneration { fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> { use std::io::Write; - self.tokenizer.clear(); - let mut tokens = self - .tokenizer - .tokenizer() - .encode(prompt, true) - .map_err(E::msg)? - .get_ids() - .to_vec(); + let mut tokens = self.tokenizer.encode(prompt)?; let mut generated_tokens = 0usize; - let eos_token = match self.tokenizer.get_token("<|endoftext|>") { - Some(token) => token, - None => anyhow::bail!("cannot find the token"), - }; let mut state = State::new(1, &self.config, &self.device)?; let mut next_logits = None; for &t in tokens.iter() { let input = Tensor::new(&[[t]], &self.device)?; let logits = self.model.forward(&input, &mut state)?; next_logits = Some(logits); - if let Some(t) = self.tokenizer.next_token(t)? { - print!("{t}") - } + print!("{}", self.tokenizer.decode(&[t])?) } std::io::stdout().flush()?; @@ -98,22 +83,13 @@ impl TextGeneration { let next_token = self.logits_processor.sample(&logits)?; tokens.push(next_token); generated_tokens += 1; - if next_token == eos_token { - break; - } - if let Some(t) = self.tokenizer.next_token(next_token)? { - print!("{t}"); - std::io::stdout().flush()?; - } + print!("{}", self.tokenizer.decode(&[next_token])?); + std::io::stdout().flush()?; let input = Tensor::new(&[[next_token]], &self.device)?; next_logits = Some(self.model.forward(&input, &mut state)?) } let dt = start_gen.elapsed(); - if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? { - print!("{rest}"); - } - std::io::stdout().flush()?; println!( "\n{generated_tokens} tokens generated ({:.2} token/s)", generated_tokens as f64 / dt.as_secs_f64(), @@ -192,7 +168,7 @@ struct Args { revision: Option, #[arg(long)] - tokenizer_file: Option, + tokenizer: Option, #[arg(long)] weight_files: Option, @@ -244,12 +220,11 @@ fn main() -> Result<()> { args.revision .unwrap_or_else(|| args.which.revision().to_string()), )); - let tokenizer_filename = match args.tokenizer_file { + let tokenizer = match args.tokenizer { Some(file) => std::path::PathBuf::from(file), None => api - // TODO: Use the appropriate tokenizer here. - .model("EleutherAI/gpt-neox-20b".to_string()) - .get("tokenizer.json")?, + .model("lmz/candle-rwkv".to_string()) + .get("rwkv_vocab_v20230424.json")?, }; let config_filename = match args.config_file { Some(file) => std::path::PathBuf::from(file), @@ -265,7 +240,7 @@ fn main() -> Result<()> { } }; println!("retrieved the files in {:?}", start.elapsed()); - let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + let tokenizer = Tokenizer::new(tokenizer)?; let start = std::time::Instant::now(); let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?; diff --git a/candle-transformers/src/models/rwkv_v5.rs b/candle-transformers/src/models/rwkv_v5.rs index d8ea7b20..04dbfc45 100644 --- a/candle-transformers/src/models/rwkv_v5.rs +++ b/candle-transformers/src/models/rwkv_v5.rs @@ -1,6 +1,7 @@ use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear}; use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_nn::{embedding, Embedding, Module, VarBuilder}; +use std::collections::{HashMap, HashSet}; fn default_num_attention_heads() -> usize { 64 @@ -315,3 +316,94 @@ impl Model { Ok(xs) } } + +type Bytes = Vec; + +// https://github.com/BlinkDL/ChatRWKV/blob/095e812aef15a1f74107f6c39d13578a2412dc46/RWKV_v5_demo.py#L14 +pub struct Tokenizer { + table: Vec>>, + good: Vec>, + idx2token: HashMap>, + token2idx: HashMap, u32>, +} + +impl Tokenizer { + pub fn new>(p: P) -> Result { + let file = std::fs::File::open(p)?; + let token2idx: HashMap = + serde_json::from_reader(file).map_err(candle::Error::wrap)?; + let token2idx = token2idx + .into_iter() + .map(|(key, value)| (key.into_bytes(), value)) + .collect::>(); + let idx2token = token2idx + .iter() + .map(|(key, value)| (*value, key.to_vec())) + .collect::>(); + + let max_idx = token2idx.values().copied().max().unwrap_or(0); + + let mut table = vec![vec![vec![]; 256]; 256]; + let mut good = vec![HashSet::new(); 256]; + for idx in (0..(1 + max_idx)).rev() { + let s = match idx2token.get(&idx) { + None => continue, + Some(s) => s, + }; + if s.len() >= 2 { + let (s0, s1) = (s[0], s[1]); + table[s0 as usize][s1 as usize].push(s.to_vec()); + good[s0 as usize].insert(s1); + } + } + Ok(Self { + table, + good, + idx2token, + token2idx, + }) + } + + pub fn decode_bytes(&self, tokens: &[u32]) -> Vec { + let mut v = Vec::new(); + for token_id in tokens.iter() { + if let Some(token) = self.idx2token.get(token_id) { + v.extend_from_slice(token.as_slice()) + } + } + v + } + + pub fn decode(&self, tokens: &[u32]) -> Result { + let bytes = self.decode_bytes(tokens); + String::from_utf8(bytes).map_err(candle::Error::wrap) + } + + pub fn encode_bytes(&self, bytes: &[u8]) -> Result> { + let mut tokens = Vec::new(); + let mut i = 0; + while i < bytes.len() { + let mut s = vec![bytes[i]]; + if i + 1 < bytes.len() && self.good[bytes[i] as usize].contains(&bytes[i + 1]) { + let table = &self.table[bytes[i] as usize][bytes[i + 1] as usize]; + for table_elem in table.iter() { + if bytes[i..].starts_with(table_elem) { + s = table_elem.to_vec(); + break; + } + } + } + i += s.len(); + let token = match self.token2idx.get(&s) { + None => candle::bail!("unexpected token '{}' {s:?}", String::from_utf8_lossy(&s)), + Some(token) => *token, + }; + tokens.push(token) + } + Ok(tokens) + } + + pub fn encode(&self, str: &str) -> Result> { + self.encode_bytes(str.as_bytes()) + } +}