Add a text generation pipeline for Falcon. (#98)

Laurent Mazare
2023-07-07 06:34:22 +01:00
committed by GitHub
parent 2b8e8c9f14
commit bac4ef40f3
2 changed files with 93 additions and 16 deletions


@@ -1,11 +1,14 @@
 #![allow(dead_code)]
+// TODO: KV cache.
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 use anyhow::{Error as E, Result};
-use candle::{DType, Device, Tensor};
+use candle::{DType, Device, Tensor, D};
 use clap::Parser;
+use rand::{distributions::Distribution, SeedableRng};
 use tokenizers::Tokenizer;
 mod model;
 use model::{Config, Falcon, VarBuilder};
@@ -15,6 +18,78 @@ const DTYPE: DType = DType::F32;
 #[cfg(not(feature = "mkl"))]
 const DTYPE: DType = DType::BF16;
+
+const TEMPERATURE: Option<f64> = None;
+
+struct TextGeneration {
+    model: Falcon,
+    rng: rand::rngs::StdRng,
+    device: Device,
+    tokenizer: Tokenizer,
+}
+
+impl TextGeneration {
+    fn new(model: Falcon, tokenizer: Tokenizer, seed: u64, device: &Device) -> Self {
+        Self {
+            model,
+            tokenizer,
+            rng: rand::rngs::StdRng::seed_from_u64(seed),
+            device: device.clone(),
+        }
+    }
+
+    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
+        println!("starting the inference loop");
+        let mut tokens = self
+            .tokenizer
+            .encode(prompt, true)
+            .map_err(E::msg)?
+            .get_ids()
+            .to_vec();
+        let mut new_tokens = vec![];
+        let start_gen = std::time::Instant::now();
+        for index in 0..sample_len {
+            let start_gen = std::time::Instant::now();
+            let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
+            let logits = self.model.forward(&input)?;
+            let logits = logits.squeeze(0)?;
+            let next_token = if let Some(temperature) = TEMPERATURE {
+                let prs = (&logits / temperature)?.softmax(D::Minus1)?;
+                let logits_v: Vec<f32> = prs.to_vec1()?;
+                let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
+                distr.sample(&mut self.rng) as u32
+            } else {
+                let logits_v: Vec<f32> = logits.to_vec1()?;
+                logits_v
+                    .iter()
+                    .enumerate()
+                    .max_by(|(_, u), (_, v)| u.total_cmp(v))
+                    .map(|(i, _)| i as u32)
+                    .unwrap()
+            };
+            tokens.push(next_token);
+            new_tokens.push(next_token);
+            println!("> {:?}", start_gen.elapsed());
+            println!(
+                "{} token: {} '{}'",
+                index + 1,
+                next_token,
+                self.tokenizer
+                    .decode(vec![next_token], true)
+                    .map_err(E::msg)?
+            );
+        }
+        let dt = start_gen.elapsed();
+        println!(
+            "{sample_len} tokens generated ({} token/s)\n----\n{}\n----",
+            sample_len as f64 / dt.as_secs_f64(),
+            self.tokenizer.decode(new_tokens, true).map_err(E::msg)?
+        );
+        Ok(())
+    }
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@@ -29,6 +104,10 @@ struct Args {
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
 
+    /// The length of the sample to generate (in tokens).
+    #[arg(long, default_value_t = 100)]
+    sample_len: usize,
+
     #[arg(long, default_value = "tiiuae/falcon-7b")]
     model_id: String,
@@ -38,7 +117,6 @@ struct Args {
 fn main() -> Result<()> {
     use candle_hub::{api::sync::Api, Repo, RepoType};
-    use tokenizers::Tokenizer;
 
     let args = Args::parse();
     let device = if args.cpu {
@@ -75,16 +153,10 @@ fn main() -> Result<()> {
     let vb = VarBuilder::from_safetensors(weights, DTYPE, &device);
     let config = Config::falcon7b();
     config.validate()?;
-    let mut model = Falcon::load(&vb, config)?;
+    let model = Falcon::load(&vb, config)?;
     println!("loaded the model in {:?}", start.elapsed());
-    let tokens = tokenizer
-        .encode(args.prompt, true)
-        .map_err(E::msg)?
-        .get_ids()
-        .to_vec();
-    let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
-    let logits = model.forward(&tokens)?;
-    println!("{}", logits);
+
+    let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, &device);
+    pipeline.run(&args.prompt, args.sample_len)?;
     Ok(())
 }
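The interesting piece of the new `TextGeneration::run` loop is the token-selection branch: with `TEMPERATURE` set, the logits are tempered and softmaxed and the next token is drawn from the resulting distribution; otherwise the argmax is taken. Below is a minimal standalone sketch of those two strategies over a plain slice of logits rather than a candle `Tensor`; the `pick_token` helper and its `main` driver are illustrative additions, not part of this commit, and assume the `rand` 0.8 API imported above.

```rust
use rand::distributions::{Distribution, WeightedIndex};
use rand::SeedableRng;

/// Mirror of the branch in `TextGeneration::run`: sample from a tempered
/// softmax when a temperature is set, otherwise take the argmax.
fn pick_token(logits: &[f32], temperature: Option<f64>, rng: &mut rand::rngs::StdRng) -> u32 {
    match temperature {
        Some(t) => {
            // Tempered softmax: subtract the max for numerical stability,
            // divide by the temperature, exponentiate. WeightedIndex
            // normalizes the weights itself, so no explicit sum is needed.
            let max = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
            let weights: Vec<f64> = logits.iter().map(|&l| ((l - max) as f64 / t).exp()).collect();
            let distr = WeightedIndex::new(&weights).expect("non-empty, non-negative weights");
            distr.sample(rng) as u32
        }
        None => {
            // Greedy decoding: the index of the largest logit.
            logits
                .iter()
                .enumerate()
                .max_by(|(_, u), (_, v)| u.total_cmp(v))
                .map(|(i, _)| i as u32)
                .expect("empty logits")
        }
    }
}

fn main() {
    let mut rng = rand::rngs::StdRng::seed_from_u64(299792458);
    let logits = [1.0_f32, 4.0, 0.25];
    assert_eq!(pick_token(&logits, None, &mut rng), 1); // argmax picks index 1
    println!("sampled: {}", pick_token(&logits, Some(0.8), &mut rng));
}
```

Note that the pipeline re-runs the full token sequence through the model at every step; the `// TODO: KV cache.` at the top of the file marks the planned fix for that quadratic cost.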


@@ -544,8 +544,9 @@ impl FalconDecoderLayer {
 #[derive(Debug)]
 pub struct Falcon {
     word_embeddings: Embedding,
-    h: Vec<FalconDecoderLayer>,
+    blocks: Vec<FalconDecoderLayer>,
     ln_f: LayerNorm,
+    lm_head: Linear,
     config: Config,
 }
@@ -572,7 +573,7 @@ impl Falcon {
             "transformer.word_embeddings",
             vb,
         )?;
-        let h = (0..cfg.num_hidden_layers)
+        let blocks = (0..cfg.num_hidden_layers)
             .map(|i| FalconDecoderLayer::load(&format!("transformer.h.{i}"), vb, &cfg))
             .collect::<Result<Vec<_>>>()?;
         let ln_f = LayerNorm::load(
@@ -581,10 +582,12 @@ impl Falcon {
             "transformer.ln_f",
             vb,
         )?;
+        let lm_head = Linear::load(cfg.hidden_size, cfg.vocab_size, false, "lm_head", vb)?;
         Ok(Self {
             word_embeddings,
-            h,
+            blocks,
             ln_f,
+            lm_head,
             config: cfg,
         })
     }
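As an aside on the `blocks` construction just above: mapping each layer index to a fallible `FalconDecoderLayer::load` call and collecting into `Result<Vec<_>>` short-circuits on the first error. A tiny self-contained sketch of that pattern (the `parse_all` helper is ours, for illustration only):

```rust
// Collecting an iterator of Results into a single Result: the first Err
// aborts the collection, otherwise all Ok values end up in the Vec.
fn parse_all(inputs: &[&str]) -> Result<Vec<i32>, std::num::ParseIntError> {
    inputs.iter().map(|s| s.parse::<i32>()).collect()
}

fn main() {
    assert_eq!(parse_all(&["1", "2"]), Ok(vec![1, 2]));
    assert!(parse_all(&["1", "x"]).is_err());
}
```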
@@ -593,10 +596,12 @@ impl Falcon {
         let (b_sz, seq_len) = input_ids.shape().r2()?;
         let mut hidden_state = self.word_embeddings.forward(input_ids)?;
         let causal_mask = prepare_attn_mask(b_sz, seq_len)?.to_device(&input_ids.device())?;
-        for block in self.h.iter_mut() {
+        for block in self.blocks.iter_mut() {
             hidden_state = block.forward(&hidden_state, &causal_mask)?;
         }
         let hidden_state = self.ln_f.forward(&hidden_state)?;
-        Ok(hidden_state)
+        let hidden_state = hidden_state.narrow(1, seq_len - 1, 1)?;
+        let logits = self.lm_head.forward(&hidden_state)?.squeeze(1)?;
+        Ok(logits)
     }
 }
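The reworked `forward` ends with `narrow(1, seq_len - 1, 1)` before the `lm_head` projection: since the pipeline only needs logits for the next token, there is no point computing the vocabulary projection for every position in the prompt. A toy sketch of that shape logic using plain nested `Vec`s (the `last_token_logits` helper is hypothetical, not the candle API; the batch dimension is omitted):

```rust
/// `hidden` holds seq_len rows of hidden_size activations; `lm_head` holds
/// vocab_size rows of hidden_size weights (no bias, as in Falcon's lm_head).
fn last_token_logits(hidden: &[Vec<f32>], lm_head: &[Vec<f32>]) -> Vec<f32> {
    // Counterpart of `hidden_state.narrow(1, seq_len - 1, 1)`:
    // keep only the final position.
    let last = hidden.last().expect("empty sequence");
    // Counterpart of `self.lm_head.forward(..)` on that single position: one
    // dot product per vocab entry instead of seq_len * vocab_size of them.
    lm_head
        .iter()
        .map(|row| row.iter().zip(last.iter()).map(|(w, h)| w * h).sum())
        .collect()
}

fn main() {
    let hidden = vec![vec![1.0_f32, 0.0], vec![0.0, 2.0]]; // seq_len = 2
    let lm_head = vec![vec![1.0_f32, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]]; // vocab = 3
    assert_eq!(last_token_logits(&hidden, &lm_head), vec![0.0, 2.0, 2.0]);
}
```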