Custom tokenizer for rwkv. (#1711)

* Custom tokenizer for rwkv.

* Custom tokenizer.

* Getting the tokenizer to work.
Author: Laurent Mazare
Date: 2024-02-14 15:11:38 +01:00 (committed by GitHub)
Parent: 121a71e01f
Commit: 26fe162ab5
2 changed files with 105 additions and 38 deletions
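
The diff below swaps the generic `tokenizers::Tokenizer` (wrapped in a `TokenOutputStream`) for a custom `Tokenizer` exported from `candle_transformers::models::rwkv_v5`. A minimal sketch of the new API as used in this diff; the local vocab path is an assumption and should point at a copy of `rwkv_vocab_v20230424.json`:

    use candle_transformers::models::rwkv_v5::Tokenizer;

    fn demo() -> anyhow::Result<()> {
        // Load the RWKV vocab from a local file (path assumed to exist).
        let vocab = std::path::PathBuf::from("rwkv_vocab_v20230424.json");
        let tokenizer = Tokenizer::new(vocab)?;
        // encode turns a prompt into token ids...
        let tokens: Vec<u32> = tokenizer.encode("Hello candle!")?;
        // ...and decode maps ids back to text, here one id at a time,
        // mirroring how the example streams its output.
        for &t in tokens.iter() {
            print!("{}", tokenizer.decode(&[t])?);
        }
        Ok(())
    }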


@@ -4,23 +4,21 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;

-use anyhow::{Error as E, Result};
+use anyhow::Result;
 use clap::{Parser, ValueEnum};

-use candle_transformers::models::rwkv_v5::{Config, Model, State};
+use candle_transformers::models::rwkv_v5::{Config, Model, State, Tokenizer};

 use candle::{DType, Device, Tensor};
-use candle_examples::token_output_stream::TokenOutputStream;
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
-use tokenizers::Tokenizer;

 struct TextGeneration {
     model: Model,
     config: Config,
     device: Device,
-    tokenizer: TokenOutputStream,
+    tokenizer: Tokenizer,
     logits_processor: LogitsProcessor,
     repeat_penalty: f32,
     repeat_last_n: usize,
@@ -43,7 +41,7 @@ impl TextGeneration {
         Self {
             model,
             config,
-            tokenizer: TokenOutputStream::new(tokenizer),
+            tokenizer,
             logits_processor,
             repeat_penalty,
             repeat_last_n,
@@ -53,28 +51,15 @@ impl TextGeneration {
     fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
         use std::io::Write;
-        self.tokenizer.clear();
-        let mut tokens = self
-            .tokenizer
-            .tokenizer()
-            .encode(prompt, true)
-            .map_err(E::msg)?
-            .get_ids()
-            .to_vec();
+        let mut tokens = self.tokenizer.encode(prompt)?;
         let mut generated_tokens = 0usize;
-        let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
-            Some(token) => token,
-            None => anyhow::bail!("cannot find the </s> token"),
-        };
         let mut state = State::new(1, &self.config, &self.device)?;
         let mut next_logits = None;
         for &t in tokens.iter() {
             let input = Tensor::new(&[[t]], &self.device)?;
             let logits = self.model.forward(&input, &mut state)?;
             next_logits = Some(logits);
-            if let Some(t) = self.tokenizer.next_token(t)? {
-                print!("{t}")
-            }
+            print!("{}", self.tokenizer.decode(&[t])?)
         }
         std::io::stdout().flush()?;
@@ -98,22 +83,13 @@ impl TextGeneration {
             let next_token = self.logits_processor.sample(&logits)?;
             tokens.push(next_token);
             generated_tokens += 1;
-            if next_token == eos_token {
-                break;
-            }
-            if let Some(t) = self.tokenizer.next_token(next_token)? {
-                print!("{t}");
-                std::io::stdout().flush()?;
-            }
+            print!("{}", self.tokenizer.decode(&[next_token])?);
+            std::io::stdout().flush()?;
             let input = Tensor::new(&[[next_token]], &self.device)?;
             next_logits = Some(self.model.forward(&input, &mut state)?)
         }
         let dt = start_gen.elapsed();
-        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
-            print!("{rest}");
-        }
         std::io::stdout().flush()?;
         println!(
             "\n{generated_tokens} tokens generated ({:.2} token/s)",
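
Two behavioural changes ride along in the hunk above: the `<|endoftext|>` early-stop is gone, so generation now always runs for the full `sample_len`, and output is printed by decoding each id directly instead of buffering through `TokenOutputStream`, which is presumably fine here because each entry in the RWKV vocab decodes to a self-contained byte sequence.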
@@ -192,7 +168,7 @@ struct Args {
     revision: Option<String>,

     #[arg(long)]
-    tokenizer_file: Option<String>,
+    tokenizer: Option<String>,

     #[arg(long)]
     weight_files: Option<String>,
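
With the flag renamed from `--tokenizer-file` to `--tokenizer`, a local vocab can be passed explicitly. A possible invocation (the example name `rwkv` and the local path are assumptions, following candle's usual example layout):

    cargo run --example rwkv --release -- --tokenizer ./rwkv_vocab_v20230424.json

If the flag is omitted, the file is fetched from the hub, as the next hunk shows.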
@@ -244,12 +220,11 @@ fn main() -> Result<()> {
         args.revision
             .unwrap_or_else(|| args.which.revision().to_string()),
     ));
-    let tokenizer_filename = match args.tokenizer_file {
+    let tokenizer = match args.tokenizer {
         Some(file) => std::path::PathBuf::from(file),
         None => api
-            // TODO: Use the appropriate tokenizer here.
-            .model("EleutherAI/gpt-neox-20b".to_string())
-            .get("tokenizer.json")?,
+            .model("lmz/candle-rwkv".to_string())
+            .get("rwkv_vocab_v20230424.json")?,
     };
     let config_filename = match args.config_file {
         Some(file) => std::path::PathBuf::from(file),
@@ -265,7 +240,7 @@ fn main() -> Result<()> {
         }
     };
     println!("retrieved the files in {:?}", start.elapsed());
-    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+    let tokenizer = Tokenizer::new(tokenizer)?;

     let start = std::time::Instant::now();
     let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
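
For completeness, the hub fallback in the two hunks above reduces to one hf-hub call. A standalone sketch; the repo id and filename are taken verbatim from the diff, while the function name is illustrative:

    use hf_hub::api::sync::Api;

    // Download (or reuse a cached copy of) the RWKV vocab from the
    // lmz/candle-rwkv model repo on the Hugging Face Hub.
    fn fetch_vocab() -> anyhow::Result<std::path::PathBuf> {
        let api = Api::new()?;
        let path = api
            .model("lmz/candle-rwkv".to_string())
            .get("rwkv_vocab_v20230424.json")?;
        Ok(path)
    }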