Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00
Custom tokenizer for rwkv. (#1711)
* Custom tokenizer for rwkv.
* Custom tokenizer.
* Getting the tokenizer to work.
@@ -4,23 +4,21 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
 
-use anyhow::{Error as E, Result};
+use anyhow::Result;
 use clap::{Parser, ValueEnum};
 
-use candle_transformers::models::rwkv_v5::{Config, Model, State};
+use candle_transformers::models::rwkv_v5::{Config, Model, State, Tokenizer};
 
 use candle::{DType, Device, Tensor};
-use candle_examples::token_output_stream::TokenOutputStream;
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
-use tokenizers::Tokenizer;
 
 struct TextGeneration {
     model: Model,
     config: Config,
     device: Device,
-    tokenizer: TokenOutputStream,
+    tokenizer: Tokenizer,
     logits_processor: LogitsProcessor,
     repeat_penalty: f32,
     repeat_last_n: usize,
@@ -43,7 +41,7 @@ impl TextGeneration {
         Self {
             model,
             config,
-            tokenizer: TokenOutputStream::new(tokenizer),
+            tokenizer,
             logits_processor,
             repeat_penalty,
             repeat_last_n,
@@ -53,28 +51,15 @@ impl TextGeneration {
 
     fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
         use std::io::Write;
-        self.tokenizer.clear();
-        let mut tokens = self
-            .tokenizer
-            .tokenizer()
-            .encode(prompt, true)
-            .map_err(E::msg)?
-            .get_ids()
-            .to_vec();
+        let mut tokens = self.tokenizer.encode(prompt)?;
         let mut generated_tokens = 0usize;
-        let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
-            Some(token) => token,
-            None => anyhow::bail!("cannot find the </s> token"),
-        };
         let mut state = State::new(1, &self.config, &self.device)?;
         let mut next_logits = None;
         for &t in tokens.iter() {
             let input = Tensor::new(&[[t]], &self.device)?;
             let logits = self.model.forward(&input, &mut state)?;
             next_logits = Some(logits);
-            if let Some(t) = self.tokenizer.next_token(t)? {
-                print!("{t}")
-            }
+            print!("{}", self.tokenizer.decode(&[t])?)
         }
         std::io::stdout().flush()?;
 
@@ -98,22 +83,13 @@ impl TextGeneration {
             let next_token = self.logits_processor.sample(&logits)?;
             tokens.push(next_token);
             generated_tokens += 1;
-            if next_token == eos_token {
-                break;
-            }
-            if let Some(t) = self.tokenizer.next_token(next_token)? {
-                print!("{t}");
-                std::io::stdout().flush()?;
-            }
+            print!("{}", self.tokenizer.decode(&[next_token])?);
+            std::io::stdout().flush()?;
 
             let input = Tensor::new(&[[next_token]], &self.device)?;
             next_logits = Some(self.model.forward(&input, &mut state)?)
         }
         let dt = start_gen.elapsed();
-        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
-            print!("{rest}");
-        }
-        std::io::stdout().flush()?;
         println!(
             "\n{generated_tokens} tokens generated ({:.2} token/s)",
             generated_tokens as f64 / dt.as_secs_f64(),
@@ -192,7 +168,7 @@ struct Args {
     revision: Option<String>,
 
     #[arg(long)]
-    tokenizer_file: Option<String>,
+    tokenizer: Option<String>,
 
     #[arg(long)]
     weight_files: Option<String>,
@@ -244,12 +220,11 @@ fn main() -> Result<()> {
         args.revision
             .unwrap_or_else(|| args.which.revision().to_string()),
     ));
-    let tokenizer_filename = match args.tokenizer_file {
+    let tokenizer = match args.tokenizer {
         Some(file) => std::path::PathBuf::from(file),
         None => api
-            // TODO: Use the appropriate tokenizer here.
-            .model("EleutherAI/gpt-neox-20b".to_string())
-            .get("tokenizer.json")?,
+            .model("lmz/candle-rwkv".to_string())
+            .get("rwkv_vocab_v20230424.json")?,
     };
     let config_filename = match args.config_file {
         Some(file) => std::path::PathBuf::from(file),
@@ -265,7 +240,7 @@ fn main() -> Result<()> {
         }
     };
     println!("retrieved the files in {:?}", start.elapsed());
-    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+    let tokenizer = Tokenizer::new(tokenizer)?;
 
     let start = std::time::Instant::now();
     let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
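Taken together, the hunks above switch the rwkv example from the Hugging Face `tokenizers::Tokenizer` (previously fetched from EleutherAI/gpt-neox-20b) to the new `rwkv_v5::Tokenizer`, loaded from the RWKV world vocabulary `rwkv_vocab_v20230424.json` hosted in the `lmz/candle-rwkv` repo, and print generated tokens through `decode` directly instead of a `TokenOutputStream`. A minimal sketch of the new wiring, using only the API shown in this diff (the `tokenizer_demo` helper is illustrative, not part of the commit):

    use anyhow::Result;
    use candle_transformers::models::rwkv_v5::Tokenizer;
    use hf_hub::api::sync::Api;

    fn tokenizer_demo(prompt: &str) -> Result<()> {
        // Fetch the RWKV world vocabulary instead of a gpt-neox tokenizer.json.
        let vocab = Api::new()?
            .model("lmz/candle-rwkv".to_string())
            .get("rwkv_vocab_v20230424.json")?;
        let tokenizer = Tokenizer::new(vocab)?;

        // Prompt tokens are printed as they are fed to the model; the custom
        // tokenizer decodes ids directly, so no TokenOutputStream is needed.
        for &t in tokenizer.encode(prompt)?.iter() {
            print!("{}", tokenizer.decode(&[t])?);
        }
        println!();
        Ok(())
    }

The remaining hunks add the Tokenizer itself to the rwkv_v5 model module.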
@@ -1,6 +1,7 @@
 use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear};
 use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::{embedding, Embedding, Module, VarBuilder};
+use std::collections::{HashMap, HashSet};
 
 fn default_num_attention_heads() -> usize {
     64
@@ -315,3 +316,94 @@ impl Model {
         Ok(xs)
     }
 }
+
+type Bytes = Vec<u8>;
+
+// https://github.com/BlinkDL/ChatRWKV/blob/095e812aef15a1f74107f6c39d13578a2412dc46/RWKV_v5_demo.py#L14
+pub struct Tokenizer {
+    table: Vec<Vec<Vec<Bytes>>>,
+    good: Vec<HashSet<u8>>,
+    idx2token: HashMap<u32, Vec<u8>>,
+    token2idx: HashMap<Vec<u8>, u32>,
+}
+
+impl Tokenizer {
+    pub fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
+        let file = std::fs::File::open(p)?;
+        let token2idx: HashMap<String, u32> =
+            serde_json::from_reader(file).map_err(candle::Error::wrap)?;
+        let token2idx = token2idx
+            .into_iter()
+            .map(|(key, value)| (key.into_bytes(), value))
+            .collect::<HashMap<_, _>>();
+        let idx2token = token2idx
+            .iter()
+            .map(|(key, value)| (*value, key.to_vec()))
+            .collect::<HashMap<_, _>>();
+
+        let max_idx = token2idx.values().copied().max().unwrap_or(0);
+
+        let mut table = vec![vec![vec![]; 256]; 256];
+        let mut good = vec![HashSet::new(); 256];
+        for idx in (0..(1 + max_idx)).rev() {
+            let s = match idx2token.get(&idx) {
+                None => continue,
+                Some(s) => s,
+            };
+            if s.len() >= 2 {
+                let (s0, s1) = (s[0], s[1]);
+                table[s0 as usize][s1 as usize].push(s.to_vec());
+                good[s0 as usize].insert(s1);
+            }
+        }
+        Ok(Self {
+            table,
+            good,
+            idx2token,
+            token2idx,
+        })
+    }
+
+    pub fn decode_bytes(&self, tokens: &[u32]) -> Vec<u8> {
+        let mut v = Vec::new();
+        for token_id in tokens.iter() {
+            if let Some(token) = self.idx2token.get(token_id) {
+                v.extend_from_slice(token.as_slice())
+            }
+        }
+        v
+    }
+
+    pub fn decode(&self, tokens: &[u32]) -> Result<String> {
+        let bytes = self.decode_bytes(tokens);
+        String::from_utf8(bytes).map_err(candle::Error::wrap)
+    }
+
+    pub fn encode_bytes(&self, bytes: &[u8]) -> Result<Vec<u32>> {
+        let mut tokens = Vec::new();
+        let mut i = 0;
+        while i < bytes.len() {
+            let mut s = vec![bytes[i]];
+            if i + 1 < bytes.len() && self.good[bytes[i] as usize].contains(&bytes[i + 1]) {
+                let table = &self.table[bytes[i] as usize][bytes[i + 1] as usize];
+                for table_elem in table.iter() {
+                    if bytes[i..].starts_with(table_elem) {
+                        s = table_elem.to_vec();
+                        break;
+                    }
+                }
+            }
+            i += s.len();
+            let token = match self.token2idx.get(&s) {
+                None => candle::bail!("unexpected token '{}' {s:?}", String::from_utf8_lossy(&s)),
+                Some(token) => *token,
+            };
+            tokens.push(token)
+        }
+        Ok(tokens)
+    }
+
+    pub fn encode(&self, str: &str) -> Result<Vec<u32>> {
+        self.encode_bytes(str.as_bytes())
+    }
+}
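The tokenizer mirrors the reference Python implementation linked in the comment above: token ids map to raw byte strings, and `encode_bytes` scans the input greedily, consulting the 256×256 `table` keyed by the next two bytes. Entries are inserted from the highest token id down, which for the RWKV world vocabulary means longer matches are tried first; if nothing matches, the encoder falls back to a single-byte token. A short roundtrip sketch, assuming the vocabulary file has already been downloaded locally (the path and function name are illustrative only):

    // Illustrative only: exercises the Tokenizer added above; the vocab path is a
    // placeholder for a locally downloaded rwkv_vocab_v20230424.json.
    use candle_transformers::models::rwkv_v5::Tokenizer;

    fn roundtrip_demo() -> candle::Result<()> {
        let tokenizer = Tokenizer::new("rwkv_vocab_v20230424.json")?;
        // Greedy longest-match over raw bytes: multi-byte vocabulary entries are
        // preferred whenever the two-byte prefix table allows it.
        let ids = tokenizer.encode("Hello, world!")?;
        // Decoding concatenates the byte string behind each id and re-validates UTF-8.
        assert_eq!(tokenizer.decode(&ids)?, "Hello, world!");
        Ok(())
    }

Encoding only fails if a byte has no entry in the vocabulary at all; decoding silently skips unknown ids and errors only on invalid UTF-8.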