From bac4ef40f328fb11d99d6653ac4b2bd5c743ea11 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Fri, 7 Jul 2023 06:34:22 +0100
Subject: [PATCH] Add some text generation pipeline for falcon. (#98)

---
 candle-examples/examples/falcon/main.rs  | 94 +++++++++++++++++++++---
 candle-examples/examples/falcon/model.rs | 15 ++--
 2 files changed, 93 insertions(+), 16 deletions(-)

diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs
index 7f1f4c51..be48c63a 100644
--- a/candle-examples/examples/falcon/main.rs
+++ b/candle-examples/examples/falcon/main.rs
@@ -1,11 +1,14 @@
 #![allow(dead_code)]
+// TODO: KV cache.
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
 use anyhow::{Error as E, Result};
-use candle::{DType, Device, Tensor};
+use candle::{DType, Device, Tensor, D};
 use clap::Parser;
+use rand::{distributions::Distribution, SeedableRng};
+use tokenizers::Tokenizer;
 
 mod model;
 use model::{Config, Falcon, VarBuilder};
 
@@ -15,6 +18,78 @@ const DTYPE: DType = DType::F32;
 #[cfg(not(feature = "mkl"))]
 const DTYPE: DType = DType::BF16;
 
+const TEMPERATURE: Option<f64> = None;
+
+struct TextGeneration {
+    model: Falcon,
+    rng: rand::rngs::StdRng,
+    device: Device,
+    tokenizer: Tokenizer,
+}
+
+impl TextGeneration {
+    fn new(model: Falcon, tokenizer: Tokenizer, seed: u64, device: &Device) -> Self {
+        Self {
+            model,
+            tokenizer,
+            rng: rand::rngs::StdRng::seed_from_u64(seed),
+            device: device.clone(),
+        }
+    }
+
+    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
+        println!("starting the inference loop");
+        let mut tokens = self
+            .tokenizer
+            .encode(prompt, true)
+            .map_err(E::msg)?
+            .get_ids()
+            .to_vec();
+
+        let mut new_tokens = vec![];
+        let start_gen = std::time::Instant::now();
+        for index in 0..sample_len {
+            let start_gen = std::time::Instant::now();
+            let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
+            let logits = self.model.forward(&input)?;
+            let logits = logits.squeeze(0)?;
+
+            let next_token = if let Some(temperature) = TEMPERATURE {
+                let prs = (&logits / temperature)?.softmax(D::Minus1)?;
+                let logits_v: Vec<f32> = prs.to_vec1()?;
+                let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
+                distr.sample(&mut self.rng) as u32
+            } else {
+                let logits_v: Vec<f32> = logits.to_vec1()?;
+                logits_v
+                    .iter()
+                    .enumerate()
+                    .max_by(|(_, u), (_, v)| u.total_cmp(v))
+                    .map(|(i, _)| i as u32)
+                    .unwrap()
+            };
+            tokens.push(next_token);
+            new_tokens.push(next_token);
+            println!("> {:?}", start_gen.elapsed());
+            println!(
+                "{} token: {} '{}'",
+                index + 1,
+                next_token,
+                self.tokenizer
+                    .decode(vec![next_token], true)
+                    .map_err(E::msg)?
+            );
+        }
+        let dt = start_gen.elapsed();
+        println!(
+            "{sample_len} tokens generated ({} token/s)\n----\n{}\n----",
+            sample_len as f64 / dt.as_secs_f64(),
+            self.tokenizer.decode(new_tokens, true).map_err(E::msg)?
+        );
+        Ok(())
+    }
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@@ -29,6 +104,10 @@ struct Args {
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
 
+    /// The length of the sample to generate (in tokens).
+    #[arg(long, default_value_t = 100)]
+    sample_len: usize,
+
     #[arg(long, default_value = "tiiuae/falcon-7b")]
     model_id: String,
 
@@ -38,7 +117,6 @@ struct Args {
 
 fn main() -> Result<()> {
     use candle_hub::{api::sync::Api, Repo, RepoType};
-    use tokenizers::Tokenizer;
 
     let args = Args::parse();
     let device = if args.cpu {
@@ -75,16 +153,10 @@ fn main() -> Result<()> {
     let vb = VarBuilder::from_safetensors(weights, DTYPE, &device);
     let config = Config::falcon7b();
     config.validate()?;
-    let mut model = Falcon::load(&vb, config)?;
+    let model = Falcon::load(&vb, config)?;
     println!("loaded the model in {:?}", start.elapsed());
 
-    let tokens = tokenizer
-        .encode(args.prompt, true)
-        .map_err(E::msg)?
-        .get_ids()
-        .to_vec();
-    let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
-    let logits = model.forward(&tokens)?;
-    println!("{}", logits);
+    let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, &device);
+    pipeline.run(&args.prompt, args.sample_len)?;
     Ok(())
 }
diff --git a/candle-examples/examples/falcon/model.rs b/candle-examples/examples/falcon/model.rs
index e367e049..82e89841 100644
--- a/candle-examples/examples/falcon/model.rs
+++ b/candle-examples/examples/falcon/model.rs
@@ -544,8 +544,9 @@ impl FalconDecoderLayer {
 #[derive(Debug)]
 pub struct Falcon {
     word_embeddings: Embedding,
-    h: Vec<FalconDecoderLayer>,
+    blocks: Vec<FalconDecoderLayer>,
     ln_f: LayerNorm,
+    lm_head: Linear,
     config: Config,
 }
 
@@ -572,7 +573,7 @@ impl Falcon {
             "transformer.word_embeddings",
             vb,
         )?;
-        let h = (0..cfg.num_hidden_layers)
+        let blocks = (0..cfg.num_hidden_layers)
            .map(|i| FalconDecoderLayer::load(&format!("transformer.h.{i}"), vb, &cfg))
            .collect::<Result<Vec<_>>>()?;
         let ln_f = LayerNorm::load(
@@ -581,10 +582,12 @@ impl Falcon {
             "transformer.ln_f",
             vb,
         )?;
+        let lm_head = Linear::load(cfg.hidden_size, cfg.vocab_size, false, "lm_head", vb)?;
         Ok(Self {
             word_embeddings,
-            h,
+            blocks,
             ln_f,
+            lm_head,
             config: cfg,
         })
     }
@@ -593,10 +596,12 @@ impl Falcon {
         let (b_sz, seq_len) = input_ids.shape().r2()?;
         let mut hidden_state = self.word_embeddings.forward(input_ids)?;
         let causal_mask = prepare_attn_mask(b_sz, seq_len)?.to_device(&input_ids.device())?;
-        for block in self.h.iter_mut() {
+        for block in self.blocks.iter_mut() {
             hidden_state = block.forward(&hidden_state, &causal_mask)?;
         }
         let hidden_state = self.ln_f.forward(&hidden_state)?;
-        Ok(hidden_state)
+        let hidden_state = hidden_state.narrow(1, seq_len - 1, 1)?;
+        let logits = self.lm_head.forward(&hidden_state)?.squeeze(1)?;
+        Ok(logits)
     }
 }
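
Note, not part of the patch: the next-token rule in `TextGeneration::run` is easy to sanity-check in isolation. Below is a minimal sketch of the same rule over a plain `Vec<f32>` instead of a candle tensor, assuming `rand` 0.8; the `next_token` helper and the toy values in `main` are hypothetical, purely for illustration. With `Some(t)` it samples from softmax(logits / t) through `WeightedIndex`; with `None` it takes the argmax, which is the path the patch actually exercises since `TEMPERATURE` defaults to `None`.

    use rand::{
        distributions::{Distribution, WeightedIndex},
        rngs::StdRng,
        SeedableRng,
    };

    /// Standalone version of the patch's next-token choice.
    fn next_token(logits: &[f32], temperature: Option<f64>, rng: &mut StdRng) -> u32 {
        match temperature {
            Some(t) => {
                // softmax(logits / t); the max is subtracted before
                // exponentiating for numerical stability.
                let max = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b)) as f64;
                let weights: Vec<f64> = logits
                    .iter()
                    .map(|&l| ((l as f64 - max) / t).exp())
                    .collect();
                // WeightedIndex normalizes its weights, so dividing by the
                // sum is unnecessary here.
                WeightedIndex::new(&weights).unwrap().sample(rng) as u32
            }
            // No temperature: greedy decoding, i.e. the argmax of the logits.
            None => logits
                .iter()
                .enumerate()
                .max_by(|(_, u), (_, v)| u.total_cmp(v))
                .map(|(i, _)| i as u32)
                .unwrap(),
        }
    }

    fn main() {
        let mut rng = StdRng::seed_from_u64(299792458);
        let logits = [1.0f32, 3.0, 0.5];
        assert_eq!(next_token(&logits, None, &mut rng), 1); // greedy picks index 1
        println!("sampled: {}", next_token(&logits, Some(0.8), &mut rng));
    }

One difference from the hunk above: the in-tree code does the softmax on the tensor side (`(&logits / temperature)?.softmax(D::Minus1)?`) and hands probabilities to `WeightedIndex`, while the sketch does it in plain Rust.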
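
Separately from the sampling rule, two details of the model change are worth calling out. `Falcon::forward` now ends with `narrow(1, seq_len - 1, 1)`, `lm_head`, and `squeeze(1)`, so its output goes from `(batch, seq_len, hidden_size)` hidden states to `(batch, vocab_size)` logits for the last position only; that is why `run` can `squeeze(0)` the result and sample directly. And per the `// TODO: KV cache.` comment, every iteration re-encodes the entire prompt-plus-generated prefix, so step `n` costs work linear in `n` and a whole sample is quadratic. A minimal sketch of that loop shape, with a hypothetical `forward` closure standing in for the model:

    /// Greedy decoding loop mirroring `TextGeneration::run`. The `forward`
    /// closure maps the full token history to next-token logits.
    fn generate(
        mut tokens: Vec<u32>,
        sample_len: usize,
        mut forward: impl FnMut(&[u32]) -> Vec<f32>,
    ) -> Vec<u32> {
        let mut new_tokens = Vec::with_capacity(sample_len);
        for _ in 0..sample_len {
            // No KV cache: the whole history is re-run through the model
            // on every step.
            let logits = forward(&tokens);
            let next = logits
                .iter()
                .enumerate()
                .max_by(|(_, u), (_, v)| u.total_cmp(v))
                .map(|(i, _)| i as u32)
                .unwrap();
            tokens.push(next);
            new_tokens.push(next);
        }
        new_tokens
    }

    fn main() {
        // Toy stand-in model: always prefers token (last + 1) mod 4.
        let out = generate(vec![0], 5, |toks| {
            let mut logits = vec![0.0f32; 4];
            logits[((toks[toks.len() - 1] + 1) % 4) as usize] = 1.0;
            logits
        });
        assert_eq!(out, vec![1, 2, 3, 0, 1]);
        println!("{out:?}");
    }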