Custom tokenizer for rwkv. (#1711)

* Custom tokenizer for rwkv.

* Custom tokenizer.

* Getting the tokenizer to work.
This commit is contained in:
Laurent Mazare
2024-02-14 15:11:38 +01:00
committed by GitHub
parent 121a71e01f
commit 26fe162ab5
2 changed files with 105 additions and 38 deletions

View File

@ -4,23 +4,21 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use anyhow::Result;
use clap::{Parser, ValueEnum};
use candle_transformers::models::rwkv_v5::{Config, Model, State};
use candle_transformers::models::rwkv_v5::{Config, Model, State, Tokenizer};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
config: Config,
device: Device,
tokenizer: TokenOutputStream,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
@ -43,7 +41,7 @@ impl TextGeneration {
Self {
model,
config,
tokenizer: TokenOutputStream::new(tokenizer),
tokenizer,
logits_processor,
repeat_penalty,
repeat_last_n,
@ -53,28 +51,15 @@ impl TextGeneration {
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut tokens = self.tokenizer.encode(prompt)?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => anyhow::bail!("cannot find the </s> token"),
};
let mut state = State::new(1, &self.config, &self.device)?;
let mut next_logits = None;
for &t in tokens.iter() {
let input = Tensor::new(&[[t]], &self.device)?;
let logits = self.model.forward(&input, &mut state)?;
next_logits = Some(logits);
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
print!("{}", self.tokenizer.decode(&[t])?)
}
std::io::stdout().flush()?;
@ -98,22 +83,13 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
print!("{}", self.tokenizer.decode(&[next_token])?);
std::io::stdout().flush()?;
let input = Tensor::new(&[[next_token]], &self.device)?;
next_logits = Some(self.model.forward(&input, &mut state)?)
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
@ -192,7 +168,7 @@ struct Args {
revision: Option<String>,
#[arg(long)]
tokenizer_file: Option<String>,
tokenizer: Option<String>,
#[arg(long)]
weight_files: Option<String>,
@ -244,12 +220,11 @@ fn main() -> Result<()> {
args.revision
.unwrap_or_else(|| args.which.revision().to_string()),
));
let tokenizer_filename = match args.tokenizer_file {
let tokenizer = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => api
// TODO: Use the appropriate tokenizer here.
.model("EleutherAI/gpt-neox-20b".to_string())
.get("tokenizer.json")?,
.model("lmz/candle-rwkv".to_string())
.get("rwkv_vocab_v20230424.json")?,
};
let config_filename = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
@ -265,7 +240,7 @@ fn main() -> Result<()> {
}
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let tokenizer = Tokenizer::new(tokenizer)?;
let start = std::time::Instant::now();
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;

View File

@ -1,6 +1,7 @@
use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{embedding, Embedding, Module, VarBuilder};
use std::collections::{HashMap, HashSet};
fn default_num_attention_heads() -> usize {
64
@ -315,3 +316,94 @@ impl Model {
Ok(xs)
}
}
type Bytes = Vec<u8>;
// https://github.com/BlinkDL/ChatRWKV/blob/095e812aef15a1f74107f6c39d13578a2412dc46/RWKV_v5_demo.py#L14
pub struct Tokenizer {
table: Vec<Vec<Vec<Bytes>>>,
good: Vec<HashSet<u8>>,
idx2token: HashMap<u32, Vec<u8>>,
token2idx: HashMap<Vec<u8>, u32>,
}
impl Tokenizer {
pub fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
let file = std::fs::File::open(p)?;
let token2idx: HashMap<String, u32> =
serde_json::from_reader(file).map_err(candle::Error::wrap)?;
let token2idx = token2idx
.into_iter()
.map(|(key, value)| (key.into_bytes(), value))
.collect::<HashMap<_, _>>();
let idx2token = token2idx
.iter()
.map(|(key, value)| (*value, key.to_vec()))
.collect::<HashMap<_, _>>();
let max_idx = token2idx.values().copied().max().unwrap_or(0);
let mut table = vec![vec![vec![]; 256]; 256];
let mut good = vec![HashSet::new(); 256];
for idx in (0..(1 + max_idx)).rev() {
let s = match idx2token.get(&idx) {
None => continue,
Some(s) => s,
};
if s.len() >= 2 {
let (s0, s1) = (s[0], s[1]);
table[s0 as usize][s1 as usize].push(s.to_vec());
good[s0 as usize].insert(s1);
}
}
Ok(Self {
table,
good,
idx2token,
token2idx,
})
}
pub fn decode_bytes(&self, tokens: &[u32]) -> Vec<u8> {
let mut v = Vec::new();
for token_id in tokens.iter() {
if let Some(token) = self.idx2token.get(token_id) {
v.extend_from_slice(token.as_slice())
}
}
v
}
pub fn decode(&self, tokens: &[u32]) -> Result<String> {
let bytes = self.decode_bytes(tokens);
String::from_utf8(bytes).map_err(candle::Error::wrap)
}
pub fn encode_bytes(&self, bytes: &[u8]) -> Result<Vec<u32>> {
let mut tokens = Vec::new();
let mut i = 0;
while i < bytes.len() {
let mut s = vec![bytes[i]];
if i + 1 < bytes.len() && self.good[bytes[i] as usize].contains(&bytes[i + 1]) {
let table = &self.table[bytes[i] as usize][bytes[i + 1] as usize];
for table_elem in table.iter() {
if bytes[i..].starts_with(table_elem) {
s = table_elem.to_vec();
break;
}
}
}
i += s.len();
let token = match self.token2idx.get(&s) {
None => candle::bail!("unexpected token '{}' {s:?}", String::from_utf8_lossy(&s)),
Some(token) => *token,
};
tokens.push(token)
}
Ok(tokens)
}
pub fn encode(&self, str: &str) -> Result<Vec<u32>> {
self.encode_bytes(str.as_bytes())
}
}