From a72b50e2c0d15b14ae7d94a478b7162eacb80cfb Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Tue, 17 Oct 2023 20:41:37 +0100
Subject: [PATCH] Build alibi bias. (#1115)

* Build alibi bias.

* Apply the alibi attention bias.

* Add the replit-code example.
---
 candle-examples/examples/replit-code/main.rs | 234 +++++++++++++++++++
 candle-transformers/src/models/mpt.rs        | 100 +++++++-
 2 files changed, 328 insertions(+), 6 deletions(-)
 create mode 100644 candle-examples/examples/replit-code/main.rs

diff --git a/candle-examples/examples/replit-code/main.rs b/candle-examples/examples/replit-code/main.rs
new file mode 100644
index 00000000..862f9993
--- /dev/null
+++ b/candle-examples/examples/replit-code/main.rs
@@ -0,0 +1,234 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::{Error as E, Result};
+use clap::Parser;
+
+use candle_transformers::models::mpt::{Config, Model};
+
+use candle::{DType, Device, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::generation::LogitsProcessor;
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use tokenizers::Tokenizer;
+
+struct TextGeneration {
+    model: Model,
+    device: Device,
+    tokenizer: Tokenizer,
+    logits_processor: LogitsProcessor,
+    repeat_penalty: f32,
+    repeat_last_n: usize,
+    verbose_prompt: bool,
+}
+
+impl TextGeneration {
+    #[allow(clippy::too_many_arguments)]
+    fn new(
+        model: Model,
+        tokenizer: Tokenizer,
+        seed: u64,
+        temp: Option<f64>,
+        top_p: Option<f64>,
+        repeat_penalty: f32,
+        repeat_last_n: usize,
+        verbose_prompt: bool,
+        device: &Device,
+    ) -> Self {
+        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
+        Self {
+            model,
+            tokenizer,
+            logits_processor,
+            repeat_penalty,
+            repeat_last_n,
+            verbose_prompt,
+            device: device.clone(),
+        }
+    }
+
+    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
+        use std::io::Write;
+        println!("starting the inference loop");
+        let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
+        if tokens.is_empty() {
+            anyhow::bail!("Empty prompts are not supported in the phi model.")
+        }
+        if self.verbose_prompt {
+            for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
+                let token = token.replace('▁', " ").replace("<0x0A>", "\n");
+                println!("{id:7} -> '{token}'");
+            }
+        }
+        let mut tokens = tokens.get_ids().to_vec();
+        let mut generated_tokens = 0usize;
+        let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
+            Some(token) => *token,
+            None => anyhow::bail!("cannot find the endoftext token"),
+        };
+        print!("{prompt}");
+        std::io::stdout().flush()?;
+        let start_gen = std::time::Instant::now();
+        for index in 0..sample_len {
+            let context_size = if index > 0 { 1 } else { tokens.len() };
+            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
+            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+            let logits = self.model.forward(&input)?;
+            let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
+            let logits = if self.repeat_penalty == 1. {
+                logits
+            } else {
+                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
+                candle_transformers::utils::apply_repeat_penalty(
+                    &logits,
+                    self.repeat_penalty,
+                    &tokens[start_at..],
+                )?
+            };
+
+            let next_token = self.logits_processor.sample(&logits)?;
+            tokens.push(next_token);
+            generated_tokens += 1;
+            if next_token == eos_token {
+                break;
+            }
+            let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
+            print!("{token}");
+            std::io::stdout().flush()?;
+        }
+        let dt = start_gen.elapsed();
+        println!(
+            "\n{generated_tokens} tokens generated ({:.2} token/s)",
+            generated_tokens as f64 / dt.as_secs_f64(),
+        );
+        Ok(())
+    }
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
+    /// Display the token for the specified prompt.
+    #[arg(long)]
+    verbose_prompt: bool,
+
+    #[arg(long)]
+    prompt: String,
+
+    /// The temperature used to generate samples.
+    #[arg(long)]
+    temperature: Option<f64>,
+
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
+    /// The seed to use when generating random samples.
+    #[arg(long, default_value_t = 299792458)]
+    seed: u64,
+
+    /// The length of the sample to generate (in tokens).
+    #[arg(long, short = 'n', default_value_t = 100)]
+    sample_len: usize,
+
+    #[arg(long)]
+    model_id: Option<String>,
+
+    #[arg(long)]
+    revision: Option<String>,
+
+    #[arg(long)]
+    weight_file: Option<String>,
+
+    #[arg(long)]
+    tokenizer: Option<String>,
+
+    /// Penalty to be applied for repeating tokens, 1. means no penalty.
+    #[arg(long, default_value_t = 1.1)]
+    repeat_penalty: f32,
+
+    /// The context size to consider for the repeat penalty.
+    #[arg(long, default_value_t = 64)]
+    repeat_last_n: usize,
+}
+
+fn main() -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
+    let args = Args::parse();
+    let _guard = if args.tracing {
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+    println!(
+        "avx: {}, neon: {}, simd128: {}, f16c: {}",
+        candle::utils::with_avx(),
+        candle::utils::with_neon(),
+        candle::utils::with_simd128(),
+        candle::utils::with_f16c()
+    );
+    println!(
+        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
+        args.temperature.unwrap_or(0.),
+        args.repeat_penalty,
+        args.repeat_last_n
+    );
+
+    let start = std::time::Instant::now();
+    let api = Api::new()?;
+    let model_id = match args.model_id {
+        Some(model_id) => model_id.to_string(),
+        None => "lmz/candle-replit-code".to_string(),
+    };
+    let revision = match args.revision {
+        Some(rev) => rev.to_string(),
+        None => "main".to_string(),
+    };
+    let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
+    let tokenizer_filename = match args.tokenizer {
+        Some(file) => std::path::PathBuf::from(file),
+        None => repo.get("tokenizer.json")?,
+    };
+    let filename = match args.weight_file {
+        Some(weight_file) => std::path::PathBuf::from(weight_file),
+        None => repo.get("model.safetensors")?,
+    };
+    println!("retrieved the files in {:?}", start.elapsed());
+    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+    let start = std::time::Instant::now();
+    let config = Config::replit_code_v1_5_3b();
+    let device = candle_examples::device(args.cpu)?;
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
+    let model = Model::new(&config, vb)?;
+    println!("loaded the model in {:?}", start.elapsed());
+
+    let mut pipeline = TextGeneration::new(
+        model,
+        tokenizer,
+        args.seed,
+        args.temperature,
+        args.top_p,
+        args.repeat_penalty,
+        args.repeat_last_n,
+        args.verbose_prompt,
+        &device,
+    );
+    pipeline.run(&args.prompt, args.sample_len)?;
+    Ok(())
+}
diff --git a/candle-transformers/src/models/mpt.rs b/candle-transformers/src/models/mpt.rs
index e11a9a75..b26caa81 100644
--- a/candle-transformers/src/models/mpt.rs
+++ b/candle-transformers/src/models/mpt.rs
@@ -15,7 +15,9 @@ pub struct Config {
     pub(crate) max_seq_len: usize,
     pub(crate) vocab_size: usize,
     pub(crate) kv_n_heads: usize,
-    // pub(crate) attn_config: AttnConfig,
+    pub(crate) attn_prefix_lm: bool,
+    pub(crate) attn_alibi: bool,
+    pub(crate) attn_alibi_bias_max: usize,
 }
 
 impl Config {
@@ -28,8 +30,15 @@ impl Config {
             max_seq_len: 4096,
             vocab_size: 32768,
             kv_n_heads: 8,
+            attn_prefix_lm: false,
+            attn_alibi: true,
+            attn_alibi_bias_max: 8,
         }
     }
+
+    pub fn is_causal(&self) -> bool {
+        !self.attn_prefix_lm
+    }
 }
 
 #[derive(Debug)]
@@ -42,6 +51,7 @@ struct GroupedQueryAttention {
     d_model: usize,
     n_heads: usize,
     kv_n_heads: usize,
+    attn_bias: Tensor,
     span: tracing::Span,
 }
 
@@ -52,6 +62,7 @@ impl GroupedQueryAttention {
         let head_dim = cfg.d_model / cfg.n_heads;
         let softmax_scale = 1f64 / (head_dim as f64).sqrt();
         let out_proj = linear(cfg.d_model, cfg.d_model, vb.pp("out_proj"))?;
+        let attn_bias = build_alibi_bias(cfg)?.to_device(vb.device())?;
         Ok(Self {
             wqkv,
             out_proj,
@@ -61,6 +72,7 @@ impl GroupedQueryAttention {
             d_model: cfg.d_model,
             n_heads: cfg.n_heads,
             kv_n_heads: cfg.kv_n_heads,
+            attn_bias,
             span: tracing::span!(tracing::Level::TRACE, "gqa"),
         })
     }
@@ -94,7 +106,23 @@ impl GroupedQueryAttention {
         let key = repeat_kv(key, self.n_heads / self.kv_n_heads)?;
         let value = repeat_kv(value, self.n_heads / self.kv_n_heads)?;
         let attn_weights = (query.matmul(&key)? * self.softmax_scale)?;
-        // TODO: attn_bias, alibi
+        let attn_bias = {
+            let s_q = query.dim(D::Minus2)?;
+            let s_k = key.dim(D::Minus1)?;
+            let (_, _, a_q, a_k) = self.attn_bias.dims4()?;
+            self.attn_bias
+                .narrow(2, a_q - s_q, s_q)?
+                .narrow(3, a_k - s_k, s_k)?
+        };
+        let attn_weights = (attn_weights + attn_bias)?;
+        let attn_weights = match mask {
+            None => attn_weights,
+            Some(mask) => masked_fill(
+                &attn_weights,
+                &mask.broadcast_left(b_size * self.n_heads)?,
+                f32::NEG_INFINITY,
+            )?,
+        };
         let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
         let attn_output = attn_weights
             .matmul(&value)?
@@ -172,15 +200,49 @@ impl MPTBlock {
     }
 }
 
+fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
+    let full = !cfg.is_causal();
+    let seq_len = cfg.max_seq_len;
+    let alibi_bias = Tensor::arange(1 - seq_len as i64, 1, &Device::Cpu)?;
+    let alibi_bias = if full {
+        let a1 = alibi_bias.reshape((1, 1, 1, seq_len))?;
+        let a2 = alibi_bias.reshape((1, 1, seq_len, 1))?;
+        a1.broadcast_sub(&a2)?.abs()?.neg()?
+    } else {
+        alibi_bias.reshape((1, 1, 1, seq_len))?
+ }; + let mut n_heads2 = 1; + while 2 * n_heads2 <= cfg.n_heads { + n_heads2 *= 2 + } + let slopes = (1..=n_heads2) + .map(|v| 1f32 / 2f32.powf((v * cfg.attn_alibi_bias_max) as f32 / n_heads2 as f32)) + .collect::>(); + let slopes = if n_heads2 == cfg.n_heads { + slopes + } else { + slopes + .iter() + .skip(1) + .step_by(2) + .chain(slopes.iter().step_by(2)) + .take(cfg.n_heads) + .cloned() + .collect::>() + }; + let slopes = Tensor::new(slopes, &Device::Cpu)?; + alibi_bias.broadcast_mul(&slopes) +} + #[derive(Debug)] -struct Model { +pub struct Model { wte: candle_nn::Embedding, blocks: Vec, norm_f: LayerNorm, } impl Model { - fn new(cfg: &Config, vb: VarBuilder) -> Result { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { let wte = candle_nn::embedding(cfg.vocab_size, cfg.d_model, vb.pp("wte"))?; let vb_b = vb.pp("blocks"); let mut blocks = Vec::with_capacity(cfg.n_layers); @@ -196,7 +258,33 @@ impl Model { }) } - fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result { - todo!() + pub fn forward(&mut self, xs: &Tensor) -> Result { + let (_b_size, seq_len) = xs.dims2()?; + let mut xs = xs.apply(&self.wte)?; + let mask = if seq_len <= 1 { + None + } else { + Some(get_mask(seq_len, xs.device())?) + }; + for block in self.blocks.iter_mut() { + xs = block.forward(&xs, mask.as_ref())? + } + xs.narrow(1, seq_len - 1, 1)? + .matmul(&self.wte.embeddings().t()?)? + .squeeze(1) } } + +fn get_mask(size: usize, device: &Device) -> Result { + let mask: Vec<_> = (0..size) + .flat_map(|i| (0..size).map(move |j| u8::from(j > i))) + .collect(); + Tensor::from_slice(&mask, (size, size), device) +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result { + let shape = mask.shape(); + let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; + let m = mask.where_cond(&on_true, on_false)?; + Ok(m) +}