From deee7612da7dcda1aa1cfd4237f4858d9f5ed8c7 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 30 Sep 2023 19:25:47 +0200 Subject: [PATCH] Quantized version of mistral. (#1009) * Quantized version of mistral. * Integrate the quantized mistral variant. * Use the quantized weight files. * Tweak the quantization command. * Fix the dtype when computing the rotary embeddings. * Update the readme with the quantized version. * Fix the decoding of the remaining tokens. --- candle-core/examples/tensor-tools.rs | 36 +- candle-examples/examples/mistral/README.md | 50 +++ candle-examples/examples/mistral/main.rs | 53 ++- candle-examples/src/token_output_stream.rs | 16 +- candle-transformers/src/models/mistral.rs | 24 +- candle-transformers/src/models/mod.rs | 1 + .../src/models/quantized_mistral.rs | 364 ++++++++++++++++++ 7 files changed, 507 insertions(+), 37 deletions(-) create mode 100644 candle-transformers/src/models/quantized_mistral.rs diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs index 3982f2c3..d06b30d1 100644 --- a/candle-core/examples/tensor-tools.rs +++ b/candle-core/examples/tensor-tools.rs @@ -103,8 +103,10 @@ enum Command { Quantize { /// The input file, in gguf format. - in_file: std::path::PathBuf, + in_file: Vec, + /// The output file, in gguf format. + #[arg(long)] out_file: std::path::PathBuf, /// The quantization schema to apply. @@ -218,12 +220,16 @@ fn run_ls(file: &std::path::PathBuf, format: Option, verbose: bool) -> R } fn run_quantize_safetensors( - in_file: std::path::PathBuf, + in_files: &[std::path::PathBuf], out_file: std::path::PathBuf, q: Quantization, ) -> Result<()> { let mut out_file = std::fs::File::create(out_file)?; - let tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?; + let mut tensors = std::collections::HashMap::new(); + for in_file in in_files.iter() { + let in_tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?; + tensors.extend(in_tensors) + } println!("tensors: {}", tensors.len()); let quantize_fn = match q { @@ -280,20 +286,32 @@ fn run_quantize_safetensors( } fn run_quantize( - in_file: std::path::PathBuf, + in_files: &[std::path::PathBuf], out_file: std::path::PathBuf, q: Quantization, qmode: QuantizationMode, ) -> Result<()> { - if let Some(extension) = in_file.extension() { + if in_files.is_empty() { + candle_core::bail!("no specified input files") + } + if let Some(extension) = out_file.extension() { if extension == "safetensors" { - return run_quantize_safetensors(in_file, out_file, q); + candle_core::bail!("the generated file cannot use the safetensors extension") } } + if let Some(extension) = in_files[0].extension() { + if extension == "safetensors" { + return run_quantize_safetensors(in_files, out_file, q); + } + } + + if in_files.len() != 1 { + candle_core::bail!("only a single in-file can be used when quantizing gguf files") + } // Open the out file early so as to fail directly on missing directories etc. let mut out_file = std::fs::File::create(out_file)?; - let mut in_ = std::fs::File::open(&in_file)?; + let mut in_ = std::fs::File::open(&in_files[0])?; let content = gguf_file::Content::read(&mut in_)?; println!("tensors: {}", content.tensor_infos.len()); @@ -319,7 +337,7 @@ fn run_quantize( .par_iter() .map(|(name, _)| { println!(" quantizing {name}"); - let mut in_file = std::fs::File::open(&in_file)?; + let mut in_file = std::fs::File::open(&in_files[0])?; let tensor = content.tensor(&mut in_file, name)?; let tensor = qmode.quantize(name, tensor, quantize_fn)?; Ok((name, tensor)) @@ -360,7 +378,7 @@ fn main() -> anyhow::Result<()> { out_file, quantization, mode, - } => run_quantize(in_file, out_file, quantization, mode)?, + } => run_quantize(&in_file, out_file, quantization, mode)?, } Ok(()) } diff --git a/candle-examples/examples/mistral/README.md b/candle-examples/examples/mistral/README.md index 6a5a0424..61a6666e 100644 --- a/candle-examples/examples/mistral/README.md +++ b/candle-examples/examples/mistral/README.md @@ -6,6 +6,9 @@ as of 2023-09-28. Weights (and the original Python model code) are released unde - [Blog post](https://mistral.ai/news/announcing-mistral-7b/) from Mistral announcing the model release. - [Model card](https://huggingface.co/mistralai/Mistral-7B-v0.1) on the HuggingFace Hub. +This example supports the initial model as well as a quantized variant. + +## Running the example ```bash $ cargo run --example mistral --release --features cuda -- --prompt 'Write helloworld code in Rust' --sample-len 150 @@ -38,3 +41,50 @@ fn main() { This example is released under the terms ``` + +## Running the quantized version of the model + +```bash +$ cargo run --example mistral --features accelerate --release -- \ +$ --prompt "Here is a sample quick sort implementation in rust " --quantized -n 400 +avx: false, neon: true, simd128: false, f16c: false +temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64 +retrieved the files in 562.292µs +loaded the model in 1.100323667s +Here is a sample quick sort implementation in rust + +``rust +fn quick_sort(arr: &mut [i32]) { + if arr.len() <= 1 { + return; + } + + let pivot = arr[0]; + let mut left = vec![]; + let mut right = vec![]; + + for i in 1..arr.len() { + if arr[i] < pivot { + left.push(arr[i]); + } else { + right.push(arr[i]); + } + } + + quick_sort(&mut left); + quick_sort(&mut right); + + let mut i = 0; + for _ in &left { + arr[i] = left.pop().unwrap(); + i += 1; + } + + for _ in &right { + arr[i] = right.pop().unwrap(); + i += 1; + } +} +`` +226 tokens generated (10.91 token/s) +``` diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs index 6fe08963..18f18e5d 100644 --- a/candle-examples/examples/mistral/main.rs +++ b/candle-examples/examples/mistral/main.rs @@ -7,7 +7,8 @@ extern crate accelerate_src; use anyhow::{Error as E, Result}; use clap::Parser; -use candle_transformers::models::mistral::{Config, Model}; +use candle_transformers::models::mistral::{Config, Model as Mistral}; +use candle_transformers::models::quantized_mistral::Model as QMistral; use candle::{DType, Device, Tensor}; use candle_examples::token_output_stream::TokenOutputStream; @@ -16,6 +17,11 @@ use candle_transformers::generation::LogitsProcessor; use hf_hub::{api::sync::Api, Repo, RepoType}; use tokenizers::Tokenizer; +enum Model { + Mistral(Mistral), + Quantized(QMistral), +} + struct TextGeneration { model: Model, device: Device, @@ -76,7 +82,10 @@ impl TextGeneration { let start_pos = tokens.len().saturating_sub(context_size); let ctxt = &tokens[start_pos..]; let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; - let logits = self.model.forward(&input, start_pos)?; + let logits = match &mut self.model { + Model::Mistral(m) => m.forward(&input, start_pos)?, + Model::Quantized(m) => m.forward(&input, start_pos)?, + }; let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; let logits = if self.repeat_penalty == 1. { logits @@ -101,8 +110,9 @@ impl TextGeneration { } } let dt = start_gen.elapsed(); - let rest = self.tokenizer.decode_rest().map_err(E::msg)?; - print!("{rest}"); + if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? { + print!("{rest}"); + } std::io::stdout().flush()?; println!( "\n{generated_tokens} tokens generated ({:.2} token/s)", @@ -211,24 +221,39 @@ fn main() -> Result<()> { .split(',') .map(std::path::PathBuf::from) .collect::>(), - None => vec![ - repo.get("pytorch_model-00001-of-00002.safetensors")?, - repo.get("pytorch_model-00002-of-00002.safetensors")?, - ], + None => { + if args.quantized { + vec![repo.get("model-q4k.gguf")?] + } else { + vec![ + repo.get("pytorch_model-00001-of-00002.safetensors")?, + repo.get("pytorch_model-00002-of-00002.safetensors")?, + ] + } + } }; println!("retrieved the files in {:?}", start.elapsed()); let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let start = std::time::Instant::now(); let config = Config::config_7b_v0_1(args.use_flash_attn); - let device = candle_examples::device(args.cpu)?; - let dtype = if device.is_cuda() { - DType::BF16 + let (model, device) = if args.quantized { + let filename = &filenames[0]; + let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?; + let model = QMistral::new(&config, vb)?; + (Model::Quantized(model), Device::Cpu) } else { - DType::F32 + let device = candle_examples::device(args.cpu)?; + let dtype = if device.is_cuda() { + DType::BF16 + } else { + DType::F32 + }; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; + let model = Mistral::new(&config, vb)?; + (Model::Mistral(model), device) }; - let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; - let model = Model::new(&config, vb)?; + println!("loaded the model in {:?}", start.elapsed()); let mut pipeline = TextGeneration::new( diff --git a/candle-examples/src/token_output_stream.rs b/candle-examples/src/token_output_stream.rs index 3d975d63..907d8ddd 100644 --- a/candle-examples/src/token_output_stream.rs +++ b/candle-examples/src/token_output_stream.rs @@ -50,8 +50,20 @@ impl TokenOutputStream { } } - pub fn decode_rest(&self) -> Result { - self.decode(&self.tokens[self.prev_index..]) + pub fn decode_rest(&self) -> Result> { + let prev_text = if self.tokens.is_empty() { + String::new() + } else { + let tokens = &self.tokens[self.prev_index..self.current_index]; + self.decode(tokens)? + }; + let text = self.decode(&self.tokens[self.prev_index..])?; + if text.len() > prev_text.len() { + let text = text.split_at(prev_text.len()); + Ok(Some(text.1.to_string())) + } else { + Ok(None) + } } pub fn decode_all(&self) -> Result { diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs index a7b4c21b..e0ecee7b 100644 --- a/candle-transformers/src/models/mistral.rs +++ b/candle-transformers/src/models/mistral.rs @@ -6,18 +6,18 @@ use std::sync::Arc; #[derive(Debug, Clone, PartialEq)] pub struct Config { - vocab_size: usize, - hidden_size: usize, - intermediate_size: usize, - num_hidden_layers: usize, - num_attention_heads: usize, - num_key_value_heads: usize, - hidden_act: Activation, - max_position_embeddings: usize, - rms_norm_eps: f64, - rope_theta: f64, - sliding_window: usize, - use_flash_attn: bool, + pub(crate) vocab_size: usize, + pub(crate) hidden_size: usize, + pub(crate) intermediate_size: usize, + pub(crate) num_hidden_layers: usize, + pub(crate) num_attention_heads: usize, + pub(crate) num_key_value_heads: usize, + pub(crate) hidden_act: Activation, + pub(crate) max_position_embeddings: usize, + pub(crate) rms_norm_eps: f64, + pub(crate) rope_theta: f64, + pub(crate) sliding_window: usize, + pub(crate) use_flash_attn: bool, } impl Config { diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 15d884c6..b1544579 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -7,6 +7,7 @@ pub mod llama; pub mod mistral; pub mod mixformer; pub mod quantized_llama; +pub mod quantized_mistral; pub mod quantized_mixformer; pub mod quantized_t5; pub mod segment_anything; diff --git a/candle-transformers/src/models/quantized_mistral.rs b/candle-transformers/src/models/quantized_mistral.rs new file mode 100644 index 00000000..171e7440 --- /dev/null +++ b/candle-transformers/src/models/quantized_mistral.rs @@ -0,0 +1,364 @@ +use crate::models::quantized_t5::Embedding; +use crate::models::with_tracing::QMatMul; +pub use crate::quantized_var_builder::VarBuilder; +use candle::{DType, Device, Module, Result, Tensor, D}; +use candle_nn::Activation; +use std::sync::Arc; + +pub use crate::models::mistral::Config; + +#[derive(Debug)] +struct Linear { + weight: QMatMul, +} + +impl Module for Linear { + fn forward(&self, x: &Tensor) -> candle::Result { + x.apply(&self.weight) + } +} + +fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result { + let weight = QMatMul::new(in_dim, out_dim, vb)?; + Ok(Linear { weight }) +} + +#[derive(Debug)] +struct RmsNorm { + inner: candle_nn::RmsNorm, + span: tracing::Span, +} + +impl RmsNorm { + fn new(size: usize, eps: f64, vb: VarBuilder) -> Result { + let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); + let weight = vb.get(size, "weight")?.dequantize(vb.device())?; + let inner = candle_nn::RmsNorm::new(weight, eps); + Ok(Self { inner, span }) + } +} + +impl Module for RmsNorm { + fn forward(&self, x: &Tensor) -> Result { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} + +#[derive(Debug)] +struct RotaryEmbedding { + sin: Tensor, + cos: Tensor, +} + +fn rotate_half(xs: &Tensor) -> Result { + let last_dim = xs.dim(D::Minus1)?; + let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?; + let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?; + Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1) +} + +impl RotaryEmbedding { + fn new(cfg: &Config, dev: &Device) -> Result { + let dim = cfg.hidden_size / cfg.num_attention_heads; + let max_seq_len = cfg.max_position_embeddings; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32)) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?; + let t = Tensor::arange(0u32, max_seq_len as u32, dev)? + .to_dtype(DType::F32)? + .reshape((max_seq_len, 1))?; + let freqs = t.matmul(&inv_freq)?; + let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; + Ok(Self { + sin: freqs.sin()?, + cos: freqs.cos()?, + }) + } + + fn apply_rotary_emb_qkv( + &self, + q: &Tensor, + k: &Tensor, + seqlen_offset: usize, + ) -> Result<(Tensor, Tensor)> { + let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; + let cos = self.cos.narrow(0, seqlen_offset, seq_len)?; + let sin = self.sin.narrow(0, seqlen_offset, seq_len)?; + let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) + let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) + let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?; + let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?; + Ok((q_embed, k_embed)) + } +} + +#[derive(Debug)] +#[allow(clippy::upper_case_acronyms)] +struct MLP { + gate_proj: Linear, + up_proj: Linear, + down_proj: Linear, + act_fn: Activation, +} + +impl MLP { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let hidden_sz = cfg.hidden_size; + let intermediate_sz = cfg.intermediate_size; + let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?; + let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?; + let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?; + Ok(Self { + gate_proj, + up_proj, + down_proj, + act_fn: cfg.hidden_act, + }) + } +} + +impl Module for MLP { + fn forward(&self, xs: &Tensor) -> Result { + let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?; + let rhs = xs.apply(&self.up_proj)?; + (lhs * rhs)?.apply(&self.down_proj) + } +} + +#[derive(Debug)] +struct Attention { + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + o_proj: Linear, + num_heads: usize, + num_kv_heads: usize, + num_kv_groups: usize, + head_dim: usize, + hidden_size: usize, + rotary_emb: Arc, + kv_cache: Option<(Tensor, Tensor)>, +} + +impl Attention { + fn new(rotary_emb: Arc, cfg: &Config, vb: VarBuilder) -> Result { + let hidden_sz = cfg.hidden_size; + let num_heads = cfg.num_attention_heads; + let num_kv_heads = cfg.num_key_value_heads; + let num_kv_groups = num_heads / num_kv_heads; + let head_dim = hidden_sz / num_heads; + let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?; + let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?; + let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?; + let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + num_heads, + num_kv_heads, + num_kv_groups, + head_dim, + hidden_size: hidden_sz, + rotary_emb, + kv_cache: None, + }) + } + + fn repeat_kv(&self, xs: Tensor) -> Result { + let n_rep = self.num_kv_groups; + if n_rep == 1 { + Ok(xs) + } else { + let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?; + xs.unsqueeze(2)? + .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))? + .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim)) + } + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result { + let (b_sz, q_len, _) = xs.dims3()?; + + let query_states = self.q_proj.forward(xs)?; + let key_states = self.k_proj.forward(xs)?; + let value_states = self.v_proj.forward(xs)?; + + let query_states = query_states + .reshape((b_sz, q_len, self.num_heads, self.head_dim))? + .transpose(1, 2)?; + let key_states = key_states + .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + let value_states = value_states + .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + + let (query_states, key_states) = + self.rotary_emb + .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?; + + let (key_states, value_states) = match &self.kv_cache { + None => (key_states, value_states), + Some((prev_k, prev_v)) => { + let key_states = Tensor::cat(&[prev_k, &key_states], 2)?; + let value_states = Tensor::cat(&[prev_v, &value_states], 2)?; + (key_states, value_states) + } + }; + self.kv_cache = Some((key_states.clone(), value_states.clone())); + + let key_states = self.repeat_kv(key_states)?; + let value_states = self.repeat_kv(value_states)?; + + let attn_output = { + let scale = 1f64 / f64::sqrt(self.head_dim as f64); + let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?; + + let attn_weights = match attention_mask { + None => attn_weights, + Some(mask) => attn_weights.broadcast_add(mask)?, + }; + let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; + attn_weights.matmul(&value_states)? + }; + attn_output + .transpose(1, 2)? + .reshape((b_sz, q_len, self.hidden_size))? + .apply(&self.o_proj) + } +} + +#[derive(Debug)] +struct DecoderLayer { + self_attn: Attention, + mlp: MLP, + input_layernorm: RmsNorm, + post_attention_layernorm: RmsNorm, +} + +impl DecoderLayer { + fn new(rotary_emb: Arc, cfg: &Config, vb: VarBuilder) -> Result { + let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?; + let mlp = MLP::new(cfg, vb.pp("mlp"))?; + let input_layernorm = + RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; + let post_attention_layernorm = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_attention_layernorm"), + )?; + Ok(Self { + self_attn, + mlp, + input_layernorm, + post_attention_layernorm, + }) + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result { + let residual = xs; + let xs = self.input_layernorm.forward(xs)?; + let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?; + let xs = (xs + residual)?; + let residual = &xs; + let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?; + residual + xs + } +} + +#[derive(Debug)] +pub struct Model { + embed_tokens: Embedding, + layers: Vec, + norm: RmsNorm, + lm_head: Linear, + sliding_window: usize, + device: Device, +} + +impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let vb_m = vb.pp("model"); + let embed_tokens = + Embedding::new(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?; + let rotary_emb = Arc::new(RotaryEmbedding::new(cfg, vb_m.device())?); + let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + let vb_l = vb_m.pp("layers"); + for layer_idx in 0..cfg.num_hidden_layers { + let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?; + layers.push(layer) + } + let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; + let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; + Ok(Self { + embed_tokens, + layers, + norm, + lm_head, + sliding_window: cfg.sliding_window, + device: vb.device().clone(), + }) + } + + fn prepare_decoder_attention_mask( + &self, + b_size: usize, + tgt_len: usize, + seqlen_offset: usize, + ) -> Result { + // Sliding window mask? + let mask: Vec<_> = (0..tgt_len) + .flat_map(|i| { + (0..tgt_len).map(move |j| { + if i < j || j + self.sliding_window < i { + f32::NEG_INFINITY + } else { + 0. + } + }) + }) + .collect(); + let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?; + let mask = if seqlen_offset > 0 { + let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?; + Tensor::cat(&[&mask0, &mask], D::Minus1)? + } else { + mask + }; + mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))? + .to_dtype(DType::F32) + } + + pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result { + let (b_size, seq_len) = input_ids.dims2()?; + let attention_mask = if seq_len <= 1 { + None + } else { + let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?; + Some(mask) + }; + let mut xs = self.embed_tokens.forward(input_ids)?; + for layer in self.layers.iter_mut() { + xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)? + } + xs.narrow(1, seq_len - 1, 1)? + .apply(&self.norm)? + .apply(&self.lm_head) + } +}