diff --git a/candle-examples/examples/quantized-gemma/README.md b/candle-examples/examples/quantized-gemma/README.md new file mode 100644 index 00000000..aa65d978 --- /dev/null +++ b/candle-examples/examples/quantized-gemma/README.md @@ -0,0 +1,18 @@ +# candle-quantized-gemma + +Candle implementation of quantized Gemma. + +## Running an example + +```bash +$ cargo run --example quantized-gemma -- --prompt "Write a function to calculate fibonacci numbers. " + +> ```python +> def fibonacci(n): +> """Calculates the nth Fibonacci number using recursion.""" +> if n <= 1: +> return n +> else: +> return fibonacci(n-1) + fibonacci(n-2 +> ``` +``` \ No newline at end of file diff --git a/candle-examples/examples/quantized-gemma/main.rs b/candle-examples/examples/quantized-gemma/main.rs new file mode 100644 index 00000000..543acde5 --- /dev/null +++ b/candle-examples/examples/quantized-gemma/main.rs @@ -0,0 +1,344 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use clap::{Parser, ValueEnum}; +use std::io::Write; +use tokenizers::Tokenizer; + +use candle::quantized::gguf_file; +use candle::Tensor; +use candle_transformers::generation::{LogitsProcessor, Sampling}; + +use candle_examples::token_output_stream::TokenOutputStream; +use candle_transformers::models::quantized_gemma3::ModelWeights; + +const DEFAULT_PROMPT: &str = "Write a function to calculate fibonacci num"; + +#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] +enum Which { + #[value(name = "gemma3-4b-it")] + Gemma3_4bIt, +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// GGUF file to load, typically a .gguf file generated by quantization + #[arg(long)] + model: Option, + + /// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way + /// and 'chat' for an interactive model where history of previous prompts and generated tokens + /// is preserved. + #[arg(long)] + prompt: Option, + + /// The length of the sample to generate (in tokens). + #[arg(short = 'n', long, default_value_t = 1000)] + sample_len: usize, + + /// The tokenizer config in json format. + #[arg(long)] + tokenizer: Option, + + /// The temperature used to generate samples, use 0 for greedy sampling. + #[arg(long, default_value_t = 0.8)] + temperature: f64, + + /// Nucleus sampling probability cutoff. + #[arg(long)] + top_p: Option, + + /// Only sample among the top K samples. + #[arg(long)] + top_k: Option, + + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 299792458)] + seed: u64, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + /// Process prompt elements separately. + #[arg(long)] + split_prompt: bool, + + /// Run on CPU rather than GPU even if a GPU is available. + #[arg(long)] + cpu: bool, + + /// Penalty to be applied for repeating tokens, 1. means no penalty. + #[arg(long, default_value_t = 1.1)] + repeat_penalty: f32, + + /// The context size to consider for the repeat penalty. + #[arg(long, default_value_t = 64)] + repeat_last_n: usize, + + /// The model size to use. + #[arg(long, default_value = "gemma3-4b-it")] + which: Which, +} + +impl Args { + fn tokenizer(&self) -> anyhow::Result { + let tokenizer_path = match &self.tokenizer { + Some(config) => std::path::PathBuf::from(config), + None => { + let api = hf_hub::api::sync::Api::new()?; + let repo = "google/gemma-3-4b-it"; + println!("DEBUG: Downloading tokenizer from {}", repo); + let api = api.model(repo.to_string()); + api.get("tokenizer.json")? + } + }; + println!("DEBUG: Loading tokenizer from {:?}", tokenizer_path); + let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?; + + Ok(tokenizer) + } + + fn model(&self) -> anyhow::Result { + let model_path = match &self.model { + Some(config) => std::path::PathBuf::from(config), + None => { + let (repo, filename) = match self.which { + Which::Gemma3_4bIt => ( + "google/gemma-3-4b-it-qat-q4_0-gguf", + "gemma-3-4b-it-q4_0.gguf", + ), + }; + let api = hf_hub::api::sync::Api::new()?; + api.repo(hf_hub::Repo::with_revision( + repo.to_string(), + hf_hub::RepoType::Model, + "main".to_string(), + )) + .get(filename)? + } + }; + Ok(model_path) + } +} + +fn format_size(size_in_bytes: usize) -> String { + if size_in_bytes < 1_000 { + format!("{}B", size_in_bytes) + } else if size_in_bytes < 1_000_000 { + format!("{:.2}KB", size_in_bytes as f64 / 1e3) + } else if size_in_bytes < 1_000_000_000 { + format!("{:.2}MB", size_in_bytes as f64 / 1e6) + } else { + format!("{:.2}GB", size_in_bytes as f64 / 1e9) + } +} + +#[derive(Debug)] +enum Prompt { + Interactive, + Chat, + One(String), +} + +fn main() -> anyhow::Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + + println!( + "avx: {}, neon: {}, simd128: {}, f16c: {}", + candle::utils::with_avx(), + candle::utils::with_neon(), + candle::utils::with_simd128(), + candle::utils::with_f16c() + ); + println!( + "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}", + args.temperature, args.repeat_penalty, args.repeat_last_n + ); + + let model_path = args.model()?; + let mut file = std::fs::File::open(&model_path)?; + let start = std::time::Instant::now(); + let device = candle_examples::device(args.cpu)?; + + let mut model = { + let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(&model_path))?; + let mut total_size_in_bytes = 0; + for (_, tensor) in model.tensor_infos.iter() { + let elem_count = tensor.shape.elem_count(); + total_size_in_bytes += + elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size(); + } + println!( + "loaded {:?} tensors ({}) in {:.2}s", + model.tensor_infos.len(), + &format_size(total_size_in_bytes), + start.elapsed().as_secs_f32(), + ); + ModelWeights::from_gguf(model, &mut file, &device)? + }; + println!("model built"); + + let tokenizer = args.tokenizer()?; + + let mut tos = TokenOutputStream::new(tokenizer); + println!( + "DEBUG: Tokenizer vocabulary size: {}", + tos.tokenizer().get_vocab(true).len() + ); + + let prompt = match args.prompt.as_deref() { + Some("chat") => Prompt::Chat, + Some("interactive") => Prompt::Interactive, + Some(s) => Prompt::One(s.to_string()), + None => Prompt::One(DEFAULT_PROMPT.to_string()), + }; + + let mut pre_prompt_tokens = vec![]; + for _ in 0.. { + let prompt_str = match &prompt { + Prompt::One(prompt) => prompt.clone(), + Prompt::Interactive | Prompt::Chat => { + print!("> "); + std::io::stdout().flush()?; + let mut prompt = String::new(); + std::io::stdin().read_line(&mut prompt)?; + if prompt.ends_with('\n') { + prompt.pop(); + if prompt.ends_with('\r') { + prompt.pop(); + } + } + // Format for Gemma 3 chat/instruction format + format!("user\n{prompt}\n\nmodel\n") + } + }; + print!("{}", &prompt_str); + + let tokens = tos + .tokenizer() + .encode(prompt_str, true) + .map_err(anyhow::Error::msg)?; + let prompt_tokens = [&pre_prompt_tokens, tokens.get_ids()].concat(); + + let to_sample = args.sample_len.saturating_sub(1); + let max_seq_len = 8192; // Gemma 3 context length + let prompt_tokens = if prompt_tokens.len() + to_sample > max_seq_len - 10 { + let to_remove = prompt_tokens.len() + to_sample + 10 - max_seq_len; + prompt_tokens[prompt_tokens.len().saturating_sub(to_remove)..].to_vec() + } else { + prompt_tokens + }; + let mut all_tokens = vec![]; + let mut logits_processor = { + let temperature = args.temperature; + let sampling = if temperature <= 0. { + Sampling::ArgMax + } else { + match (args.top_k, args.top_p) { + (None, None) => Sampling::All { temperature }, + (Some(k), None) => Sampling::TopK { k, temperature }, + (None, Some(p)) => Sampling::TopP { p, temperature }, + (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature }, + } + }; + LogitsProcessor::from_sampling(args.seed, sampling) + }; + + let start_prompt_processing = std::time::Instant::now(); + let mut next_token = if !args.split_prompt { + let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?; + let logits = model.forward(&input, 0)?; + let logits = logits.squeeze(0)?; + logits_processor.sample(&logits)? + } else { + let mut next_token = 0; + for (pos, token) in prompt_tokens.iter().enumerate() { + let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?; + let logits = model.forward(&input, pos)?; + let logits = logits.squeeze(0)?; + next_token = logits_processor.sample(&logits)? + } + next_token + }; + let prompt_dt = start_prompt_processing.elapsed(); + all_tokens.push(next_token); + if let Some(t) = tos.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + + // For Gemma 3, use the correct end of sequence token + let eos_token = *tos + .tokenizer() + .get_vocab(true) + .get("") + .unwrap(); + + let start_post_prompt = std::time::Instant::now(); + let mut sampled = 0; + for index in 0..to_sample { + let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?; + let logits = model.forward(&input, prompt_tokens.len() + index)?; + let logits = logits.squeeze(0)?; + let logits = if args.repeat_penalty == 1. { + logits + } else { + let start_at = all_tokens.len().saturating_sub(args.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + args.repeat_penalty, + &all_tokens[start_at..], + )? + }; + next_token = logits_processor.sample(&logits)?; + all_tokens.push(next_token); + if let Some(t) = tos.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + sampled += 1; + if next_token == eos_token { + break; + }; + } + if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? { + print!("{rest}"); + } + std::io::stdout().flush()?; + let dt = start_post_prompt.elapsed(); + println!( + "\n\n{:4} prompt tokens processed: {:.2} token/s", + prompt_tokens.len(), + prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(), + ); + println!( + "{sampled:4} tokens generated: {:.2} token/s", + sampled as f64 / dt.as_secs_f64(), + ); + + match prompt { + Prompt::One(_) => break, + Prompt::Interactive => {} + Prompt::Chat => { + pre_prompt_tokens = [prompt_tokens.as_slice(), all_tokens.as_slice()].concat() + } + } + } + + Ok(()) +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index bdb8d267..1ac75e33 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -79,6 +79,7 @@ pub mod phi3; pub mod pixtral; pub mod quantized_blip; pub mod quantized_blip_text; +pub mod quantized_gemma3; pub mod quantized_llama; pub mod quantized_llama2_c; pub mod quantized_metavoice; diff --git a/candle-transformers/src/models/quantized_gemma3.rs b/candle-transformers/src/models/quantized_gemma3.rs new file mode 100644 index 00000000..b5cbdf89 --- /dev/null +++ b/candle-transformers/src/models/quantized_gemma3.rs @@ -0,0 +1,418 @@ +//! Gemma 3 model implementation with quantization support. +//! +//! Gemma 3 is a family of multimodal language models developed by Google. +//! This implementation provides quantization for reduced memory usage and faster inference. +//! +//! Key characteristics: +//! - Group-Query Attention (GQA) with specialized key-value heads +//! - RMSNorm for layer normalization +//! - Specialized attention patterns with separate normalization for Q/K/V +//! - Feed-forward network with SwiGLU activation +//! - Support for 2/3/4/8-bit quantization +//! +//! References: +//! - [Gemma 3 Models](https://blog.google/technology/developers/gemma-3/) +//! + +use std::collections::HashMap; + +use crate::quantized_nn::RmsNorm; +use candle::quantized::gguf_file; +use candle::quantized::QTensor; +use candle::{DType, Device, IndexOp, Result, Tensor}; +use candle_nn::{Embedding, Module}; + +pub const MAX_SEQ_LEN: usize = 131072; // Gemma 3 supports 128K context window + +#[derive(Debug, Clone)] +struct QMatMul { + inner: candle::quantized::QMatMul, + span: tracing::Span, +} + +impl QMatMul { + fn from_qtensor(qtensor: QTensor) -> Result { + let inner = candle::quantized::QMatMul::from_qtensor(qtensor)?; + let span = tracing::span!(tracing::Level::TRACE, "qmatmul"); + Ok(Self { inner, span }) + } + + fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); + self.inner.forward(xs) + } +} + +#[derive(Debug, Clone)] +struct Mlp { + feed_forward_gate: QMatMul, // ffn_gate in GGUF + feed_forward_up: QMatMul, // ffn_up in GGUF + feed_forward_down: QMatMul, // ffn_down in GGUF +} + +impl Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result { + let gate = self.feed_forward_gate.forward(xs)?; + let up = self.feed_forward_up.forward(xs)?; + let silu = candle_nn::ops::silu(&gate)?; + let gated = (silu * up)?; + self.feed_forward_down.forward(&gated) + } +} + +#[derive(Debug, Clone)] +pub struct LayerWeights { + // Attention components + attention_wq: QMatMul, + attention_wk: QMatMul, + attention_wv: QMatMul, + attention_wo: QMatMul, + + // Specialized normalization for Q and K + attention_q_norm: RmsNorm, + attention_k_norm: RmsNorm, + + // Layer normalization + attention_norm: RmsNorm, // Applied before attention + post_attention_norm: RmsNorm, // Applied after attention + ffn_norm: RmsNorm, // Applied before feedforward + post_ffn_norm: RmsNorm, // Applied after feedforward + + // Feed-forward network + mlp: Mlp, + + // Attention parameters + n_head: usize, // Number of query heads + n_kv_head: usize, // Number of key-value heads + head_dim: usize, // Dimension of each head + q_dim: usize, // Total dimension for queries + + // Rotary embedding + cos: Tensor, + sin: Tensor, + neg_inf: Tensor, + + // Cache + pub kv_cache: Option<(Tensor, Tensor)>, + + // Tracing + span_attn: tracing::Span, + span_mlp: tracing::Span, +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result { + let shape = mask.shape(); + let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?; + Ok(m) +} + +impl LayerWeights { + fn apply_rotary_emb_qkv( + &self, + q: &Tensor, + k: &Tensor, + index_pos: usize, + ) -> Result<(Tensor, Tensor)> { + let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; + let cos = self.cos.narrow(0, index_pos, seq_len)?; + let sin = self.sin.narrow(0, index_pos, seq_len)?; + let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?; + let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?; + Ok((q_embed, k_embed)) + } + + fn forward_attn( + &mut self, + x: &Tensor, + mask: Option<&Tensor>, + index_pos: usize, + ) -> Result { + let _enter = self.span_attn.enter(); + let (b_sz, seq_len, _) = x.dims3()?; + + let q = self.attention_wq.forward(x)?; + let k = self.attention_wk.forward(x)?; + let v = self.attention_wv.forward(x)?; + + let q = q + .reshape((b_sz, seq_len, self.n_head, self.head_dim))? + .transpose(1, 2)?; + let k = k + .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))? + .transpose(1, 2)?; + let v = v + .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))? + .transpose(1, 2)?; + + let q = self.attention_q_norm.forward(&q.contiguous()?)?; + let k = self.attention_k_norm.forward(&k.contiguous()?)?; + + let (q, k) = self.apply_rotary_emb_qkv(&q, &k, index_pos)?; + + let (k, v) = match &self.kv_cache { + None => (k, v), + Some((k_cache, v_cache)) => { + if index_pos == 0 { + (k, v) + } else { + let k = Tensor::cat(&[k_cache, &k], 2)?; // concat on seq dim + let v = Tensor::cat(&[v_cache, &v], 2)?; + (k, v) + } + } + }; + self.kv_cache = Some((k.clone(), v.clone())); // update cache + + // Repeat KV for GQA + let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?; + let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?; + + // Scaled Dot-Product Attention + let scale = 1.0 / (self.head_dim as f64).sqrt(); + let mut attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?; + + if let Some(mask) = mask { + let mask = mask.broadcast_as(attn_weights.shape())?; + attn_weights = masked_fill(&attn_weights, &mask, &self.neg_inf)?; + } + + let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; + let attn_output = attn_weights.matmul(&v)?; + + let attn_output = attn_output + .transpose(1, 2)? + .reshape((b_sz, seq_len, self.q_dim))?; + + self.attention_wo.forward(&attn_output) + } +} + +#[derive(Debug, Clone)] +pub struct ModelWeights { + tok_embeddings: Embedding, + embedding_length: usize, + pub layers: Vec, + norm: RmsNorm, + output: QMatMul, + masks: HashMap, + span: tracing::Span, + span_output: tracing::Span, +} + +fn precomput_freqs_cis( + head_dim: usize, + freq_base: f32, + device: &Device, +) -> Result<(Tensor, Tensor)> { + let theta: Vec<_> = (0..head_dim) + .step_by(2) + .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32)) + .collect(); + let theta = Tensor::new(theta.as_slice(), device)?; + let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)? + .to_dtype(DType::F32)? + .reshape((MAX_SEQ_LEN, 1))? + .matmul(&theta.reshape((1, theta.elem_count()))?)?; + let cos = idx_theta.cos()?; + let sin = idx_theta.sin()?; + Ok((cos, sin)) +} + +impl ModelWeights { + pub fn from_gguf( + ct: gguf_file::Content, + reader: &mut R, + device: &Device, + ) -> Result { + let md_get = |s: &str| match ct.metadata.get(s) { + None => candle::bail!("cannot find {s} in metadata"), + Some(v) => Ok(v), + }; + + let head_count = md_get("gemma3.attention.head_count")?.to_u32()? as usize; + let head_count_kv = md_get("gemma3.attention.head_count_kv")?.to_u32()? as usize; + let block_count = md_get("gemma3.block_count")?.to_u32()? as usize; + let embedding_length = md_get("gemma3.embedding_length")?.to_u32()? as usize; + let key_length = md_get("gemma3.attention.key_length")?.to_u32()? as usize; + let _value_length = md_get("gemma3.attention.value_length")?.to_u32()? as usize; + let rms_norm_eps = md_get("gemma3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64; + + let rope_freq_base = md_get("gemma3.rope.freq_base") + .and_then(|m| m.to_f32()) + .unwrap_or(1000000f32); + + // Compute the dimensions for queries, keys, and values + // These are the total dimensions when projected across all heads + let q_dim = head_count * key_length; + + // Precompute rotary embeddings + let (cos, sin) = precomput_freqs_cis(key_length, rope_freq_base, device)?; + let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?; + + // Load token embeddings and output projection + let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?; + let tok_embeddings = tok_embeddings.dequantize(device)?; + let norm = RmsNorm::from_qtensor( + ct.tensor(reader, "output_norm.weight", device)?, + rms_norm_eps, + )?; + let output = match ct.tensor(reader, "output.weight", device) { + Ok(tensor) => tensor, + Err(_) => ct.tensor(reader, "token_embd.weight", device)?, // Use tied weights if output.weight doesn't exist + }; + + let mut layers = Vec::with_capacity(block_count); + for layer_idx in 0..block_count { + let prefix = format!("blk.{layer_idx}"); + + let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?; + let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?; + let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?; + let attention_wo = + ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?; + + let attention_q_norm = RmsNorm::from_qtensor( + ct.tensor(reader, &format!("{prefix}.attn_q_norm.weight"), device)?, + rms_norm_eps, + )?; + + let attention_k_norm = RmsNorm::from_qtensor( + ct.tensor(reader, &format!("{prefix}.attn_k_norm.weight"), device)?, + rms_norm_eps, + )?; + + let attention_norm = RmsNorm::from_qtensor( + ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?, + rms_norm_eps, + )?; + + let post_attention_norm = RmsNorm::from_qtensor( + ct.tensor( + reader, + &format!("{prefix}.post_attention_norm.weight"), + device, + )?, + rms_norm_eps, + )?; + + let ffn_norm = RmsNorm::from_qtensor( + ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?, + rms_norm_eps, + )?; + + let post_ffn_norm = RmsNorm::from_qtensor( + ct.tensor(reader, &format!("{prefix}.post_ffw_norm.weight"), device)?, + rms_norm_eps, + )?; + + let feed_forward_gate = + ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?; + let feed_forward_up = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?; + let feed_forward_down = + ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?; + + let mlp = Mlp { + feed_forward_gate: QMatMul::from_qtensor(feed_forward_gate)?, + feed_forward_up: QMatMul::from_qtensor(feed_forward_up)?, + feed_forward_down: QMatMul::from_qtensor(feed_forward_down)?, + }; + + // Tracing spans + let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); + let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); + + layers.push(LayerWeights { + attention_wq: QMatMul::from_qtensor(attention_wq)?, + attention_wk: QMatMul::from_qtensor(attention_wk)?, + attention_wv: QMatMul::from_qtensor(attention_wv)?, + attention_wo: QMatMul::from_qtensor(attention_wo)?, + attention_q_norm, + attention_k_norm, + attention_norm, + post_attention_norm, + ffn_norm, + post_ffn_norm, + mlp, + n_head: head_count, + n_kv_head: head_count_kv, + head_dim: key_length, + q_dim, + cos: cos.clone(), + sin: sin.clone(), + neg_inf: neg_inf.clone(), + kv_cache: None, + span_attn, + span_mlp, + }) + } + + let span = tracing::span!(tracing::Level::TRACE, "model"); + let span_output = tracing::span!(tracing::Level::TRACE, "output"); + + Ok(Self { + tok_embeddings: Embedding::new(tok_embeddings, embedding_length), + embedding_length, + layers, + norm, + output: QMatMul::from_qtensor(output)?, + masks: HashMap::new(), + span, + span_output, + }) + } + + fn mask(&mut self, t: usize, device: &Device) -> Result { + if let Some(mask) = self.masks.get(&t) { + Ok(mask.clone()) + } else { + let mask: Vec<_> = (0..t) + .flat_map(|i| (0..t).map(move |j| u8::from(j > i))) + .collect(); + let mask = Tensor::from_slice(&mask, (t, t), device)?; + self.masks.insert(t, mask.clone()); + Ok(mask) + } + } + + pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result { + let (_b_sz, seq_len) = x.dims2()?; + + let mask = if seq_len == 1 { + None + } else { + Some(self.mask(seq_len, x.device())?) + }; + let _enter = self.span.enter(); + + let mut layer_in = self.tok_embeddings.forward(x)?; + layer_in = (layer_in * (self.embedding_length as f64).sqrt())?; + + for layer in self.layers.iter_mut() { + // Attention block + let residual = &layer_in; + let x = layer.attention_norm.forward(&layer_in)?; + let x = layer.forward_attn(&x, mask.as_ref(), index_pos)?; + let x = layer.post_attention_norm.forward(&x)?; + let x = (x + residual)?; + + // Feed-forward block + let _enter = layer.span_mlp.enter(); + let residual = &x; + let x = layer.ffn_norm.forward(&x)?; + let x = layer.mlp.forward(&x)?; + let x = layer.post_ffn_norm.forward(&x)?; + let x = (x + residual)?; + drop(_enter); + + layer_in = x; + } + + let _enter = self.span_output.enter(); + + let x = layer_in.i((.., seq_len - 1, ..))?; + let x = self.norm.forward(&x)?; + let output = self.output.forward(&x)?; + + Ok(output) + } +}