implemented quantized-gemma3 (#2902)

* implemented quantized-gemma, inference not working

* Fixed a few modeling bugs: outputing the correct tokens for a few iterations then garbage

* lint

* clippy

* quantized-gemma3 example working

* added readme

* clippy
This commit is contained in:
Kyle Birnbaum
2025-04-18 22:46:41 -07:00
committed by GitHub
parent 21055b5697
commit b2904a830b
4 changed files with 781 additions and 0 deletions

View File

@ -0,0 +1,18 @@
# candle-quantized-gemma
Candle implementation of quantized Gemma.
## Running an example
```bash
$ cargo run --example quantized-gemma -- --prompt "Write a function to calculate fibonacci numbers. "
> ```python
> def fibonacci(n):
> """Calculates the nth Fibonacci number using recursion."""
> if n <= 1:
> return n
> else:
> return fibonacci(n-1) + fibonacci(n-2
> ```
```

View File

@ -0,0 +1,344 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use std::io::Write;
use tokenizers::Tokenizer;
use candle::quantized::gguf_file;
use candle::Tensor;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_gemma3::ModelWeights;
const DEFAULT_PROMPT: &str = "Write a function to calculate fibonacci num";
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
#[value(name = "gemma3-4b-it")]
Gemma3_4bIt,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// GGUF file to load, typically a .gguf file generated by quantization
#[arg(long)]
model: Option<String>,
/// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way
/// and 'chat' for an interactive model where history of previous prompts and generated tokens
/// is preserved.
#[arg(long)]
prompt: Option<String>,
/// The length of the sample to generate (in tokens).
#[arg(short = 'n', long, default_value_t = 1000)]
sample_len: usize,
/// The tokenizer config in json format.
#[arg(long)]
tokenizer: Option<String>,
/// The temperature used to generate samples, use 0 for greedy sampling.
#[arg(long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Process prompt elements separately.
#[arg(long)]
split_prompt: bool,
/// Run on CPU rather than GPU even if a GPU is available.
#[arg(long)]
cpu: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// The model size to use.
#[arg(long, default_value = "gemma3-4b-it")]
which: Which,
}
impl Args {
fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
let tokenizer_path = match &self.tokenizer {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
let repo = "google/gemma-3-4b-it";
println!("DEBUG: Downloading tokenizer from {}", repo);
let api = api.model(repo.to_string());
api.get("tokenizer.json")?
}
};
println!("DEBUG: Loading tokenizer from {:?}", tokenizer_path);
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;
Ok(tokenizer)
}
fn model(&self) -> anyhow::Result<std::path::PathBuf> {
let model_path = match &self.model {
Some(config) => std::path::PathBuf::from(config),
None => {
let (repo, filename) = match self.which {
Which::Gemma3_4bIt => (
"google/gemma-3-4b-it-qat-q4_0-gguf",
"gemma-3-4b-it-q4_0.gguf",
),
};
let api = hf_hub::api::sync::Api::new()?;
api.repo(hf_hub::Repo::with_revision(
repo.to_string(),
hf_hub::RepoType::Model,
"main".to_string(),
))
.get(filename)?
}
};
Ok(model_path)
}
}
fn format_size(size_in_bytes: usize) -> String {
if size_in_bytes < 1_000 {
format!("{}B", size_in_bytes)
} else if size_in_bytes < 1_000_000 {
format!("{:.2}KB", size_in_bytes as f64 / 1e3)
} else if size_in_bytes < 1_000_000_000 {
format!("{:.2}MB", size_in_bytes as f64 / 1e6)
} else {
format!("{:.2}GB", size_in_bytes as f64 / 1e9)
}
}
#[derive(Debug)]
enum Prompt {
Interactive,
Chat,
One(String),
}
fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature, args.repeat_penalty, args.repeat_last_n
);
let model_path = args.model()?;
let mut file = std::fs::File::open(&model_path)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let mut model = {
let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(&model_path))?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensor_infos.iter() {
let elem_count = tensor.shape.elem_count();
total_size_in_bytes +=
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
}
println!(
"loaded {:?} tensors ({}) in {:.2}s",
model.tensor_infos.len(),
&format_size(total_size_in_bytes),
start.elapsed().as_secs_f32(),
);
ModelWeights::from_gguf(model, &mut file, &device)?
};
println!("model built");
let tokenizer = args.tokenizer()?;
let mut tos = TokenOutputStream::new(tokenizer);
println!(
"DEBUG: Tokenizer vocabulary size: {}",
tos.tokenizer().get_vocab(true).len()
);
let prompt = match args.prompt.as_deref() {
Some("chat") => Prompt::Chat,
Some("interactive") => Prompt::Interactive,
Some(s) => Prompt::One(s.to_string()),
None => Prompt::One(DEFAULT_PROMPT.to_string()),
};
let mut pre_prompt_tokens = vec![];
for _ in 0.. {
let prompt_str = match &prompt {
Prompt::One(prompt) => prompt.clone(),
Prompt::Interactive | Prompt::Chat => {
print!("> ");
std::io::stdout().flush()?;
let mut prompt = String::new();
std::io::stdin().read_line(&mut prompt)?;
if prompt.ends_with('\n') {
prompt.pop();
if prompt.ends_with('\r') {
prompt.pop();
}
}
// Format for Gemma 3 chat/instruction format
format!("<start_of_turn>user\n{prompt}\n<end_of_turn>\n<start_of_turn>model\n")
}
};
print!("{}", &prompt_str);
let tokens = tos
.tokenizer()
.encode(prompt_str, true)
.map_err(anyhow::Error::msg)?;
let prompt_tokens = [&pre_prompt_tokens, tokens.get_ids()].concat();
let to_sample = args.sample_len.saturating_sub(1);
let max_seq_len = 8192; // Gemma 3 context length
let prompt_tokens = if prompt_tokens.len() + to_sample > max_seq_len - 10 {
let to_remove = prompt_tokens.len() + to_sample + 10 - max_seq_len;
prompt_tokens[prompt_tokens.len().saturating_sub(to_remove)..].to_vec()
} else {
prompt_tokens
};
let mut all_tokens = vec![];
let mut logits_processor = {
let temperature = args.temperature;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (args.top_k, args.top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
let start_prompt_processing = std::time::Instant::now();
let mut next_token = if !args.split_prompt {
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?;
logits_processor.sample(&logits)?
} else {
let mut next_token = 0;
for (pos, token) in prompt_tokens.iter().enumerate() {
let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, pos)?;
let logits = logits.squeeze(0)?;
next_token = logits_processor.sample(&logits)?
}
next_token
};
let prompt_dt = start_prompt_processing.elapsed();
all_tokens.push(next_token);
if let Some(t) = tos.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
// For Gemma 3, use the correct end of sequence token
let eos_token = *tos
.tokenizer()
.get_vocab(true)
.get("<end_of_turn>")
.unwrap();
let start_post_prompt = std::time::Instant::now();
let mut sampled = 0;
for index in 0..to_sample {
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, prompt_tokens.len() + index)?;
let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&all_tokens[start_at..],
)?
};
next_token = logits_processor.sample(&logits)?;
all_tokens.push(next_token);
if let Some(t) = tos.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
sampled += 1;
if next_token == eos_token {
break;
};
}
if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
let dt = start_post_prompt.elapsed();
println!(
"\n\n{:4} prompt tokens processed: {:.2} token/s",
prompt_tokens.len(),
prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
);
println!(
"{sampled:4} tokens generated: {:.2} token/s",
sampled as f64 / dt.as_secs_f64(),
);
match prompt {
Prompt::One(_) => break,
Prompt::Interactive => {}
Prompt::Chat => {
pre_prompt_tokens = [prompt_tokens.as_slice(), all_tokens.as_slice()].concat()
}
}
}
Ok(())
}

View File

@ -79,6 +79,7 @@ pub mod phi3;
pub mod pixtral;
pub mod quantized_blip;
pub mod quantized_blip_text;
pub mod quantized_gemma3;
pub mod quantized_llama;
pub mod quantized_llama2_c;
pub mod quantized_metavoice;

View File

@ -0,0 +1,418 @@
//! Gemma 3 model implementation with quantization support.
//!
//! Gemma 3 is a family of multimodal language models developed by Google.
//! This implementation provides quantization for reduced memory usage and faster inference.
//!
//! Key characteristics:
//! - Group-Query Attention (GQA) with specialized key-value heads
//! - RMSNorm for layer normalization
//! - Specialized attention patterns with separate normalization for Q/K/V
//! - Feed-forward network with SwiGLU activation
//! - Support for 2/3/4/8-bit quantization
//!
//! References:
//! - [Gemma 3 Models](https://blog.google/technology/developers/gemma-3/)
//!
use std::collections::HashMap;
use crate::quantized_nn::RmsNorm;
use candle::quantized::gguf_file;
use candle::quantized::QTensor;
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{Embedding, Module};
pub const MAX_SEQ_LEN: usize = 131072; // Gemma 3 supports 128K context window
#[derive(Debug, Clone)]
struct QMatMul {
inner: candle::quantized::QMatMul,
span: tracing::Span,
}
impl QMatMul {
fn from_qtensor(qtensor: QTensor) -> Result<Self> {
let inner = candle::quantized::QMatMul::from_qtensor(qtensor)?;
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
Ok(Self { inner, span })
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(xs)
}
}
#[derive(Debug, Clone)]
struct Mlp {
feed_forward_gate: QMatMul, // ffn_gate in GGUF
feed_forward_up: QMatMul, // ffn_up in GGUF
feed_forward_down: QMatMul, // ffn_down in GGUF
}
impl Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let gate = self.feed_forward_gate.forward(xs)?;
let up = self.feed_forward_up.forward(xs)?;
let silu = candle_nn::ops::silu(&gate)?;
let gated = (silu * up)?;
self.feed_forward_down.forward(&gated)
}
}
#[derive(Debug, Clone)]
pub struct LayerWeights {
// Attention components
attention_wq: QMatMul,
attention_wk: QMatMul,
attention_wv: QMatMul,
attention_wo: QMatMul,
// Specialized normalization for Q and K
attention_q_norm: RmsNorm,
attention_k_norm: RmsNorm,
// Layer normalization
attention_norm: RmsNorm, // Applied before attention
post_attention_norm: RmsNorm, // Applied after attention
ffn_norm: RmsNorm, // Applied before feedforward
post_ffn_norm: RmsNorm, // Applied after feedforward
// Feed-forward network
mlp: Mlp,
// Attention parameters
n_head: usize, // Number of query heads
n_kv_head: usize, // Number of key-value heads
head_dim: usize, // Dimension of each head
q_dim: usize, // Total dimension for queries
// Rotary embedding
cos: Tensor,
sin: Tensor,
neg_inf: Tensor,
// Cache
pub kv_cache: Option<(Tensor, Tensor)>,
// Tracing
span_attn: tracing::Span,
span_mlp: tracing::Span,
}
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
let shape = mask.shape();
let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?;
Ok(m)
}
impl LayerWeights {
fn apply_rotary_emb_qkv(
&self,
q: &Tensor,
k: &Tensor,
index_pos: usize,
) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, index_pos, seq_len)?;
let sin = self.sin.narrow(0, index_pos, seq_len)?;
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
}
fn forward_attn(
&mut self,
x: &Tensor,
mask: Option<&Tensor>,
index_pos: usize,
) -> Result<Tensor> {
let _enter = self.span_attn.enter();
let (b_sz, seq_len, _) = x.dims3()?;
let q = self.attention_wq.forward(x)?;
let k = self.attention_wk.forward(x)?;
let v = self.attention_wv.forward(x)?;
let q = q
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
.transpose(1, 2)?;
let k = k
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
.transpose(1, 2)?;
let v = v
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
.transpose(1, 2)?;
let q = self.attention_q_norm.forward(&q.contiguous()?)?;
let k = self.attention_k_norm.forward(&k.contiguous()?)?;
let (q, k) = self.apply_rotary_emb_qkv(&q, &k, index_pos)?;
let (k, v) = match &self.kv_cache {
None => (k, v),
Some((k_cache, v_cache)) => {
if index_pos == 0 {
(k, v)
} else {
let k = Tensor::cat(&[k_cache, &k], 2)?; // concat on seq dim
let v = Tensor::cat(&[v_cache, &v], 2)?;
(k, v)
}
}
};
self.kv_cache = Some((k.clone(), v.clone())); // update cache
// Repeat KV for GQA
let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
// Scaled Dot-Product Attention
let scale = 1.0 / (self.head_dim as f64).sqrt();
let mut attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
if let Some(mask) = mask {
let mask = mask.broadcast_as(attn_weights.shape())?;
attn_weights = masked_fill(&attn_weights, &mask, &self.neg_inf)?;
}
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
let attn_output = attn_weights.matmul(&v)?;
let attn_output = attn_output
.transpose(1, 2)?
.reshape((b_sz, seq_len, self.q_dim))?;
self.attention_wo.forward(&attn_output)
}
}
#[derive(Debug, Clone)]
pub struct ModelWeights {
tok_embeddings: Embedding,
embedding_length: usize,
pub layers: Vec<LayerWeights>,
norm: RmsNorm,
output: QMatMul,
masks: HashMap<usize, Tensor>,
span: tracing::Span,
span_output: tracing::Span,
}
fn precomput_freqs_cis(
head_dim: usize,
freq_base: f32,
device: &Device,
) -> Result<(Tensor, Tensor)> {
let theta: Vec<_> = (0..head_dim)
.step_by(2)
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), device)?;
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
.to_dtype(DType::F32)?
.reshape((MAX_SEQ_LEN, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
let cos = idx_theta.cos()?;
let sin = idx_theta.sin()?;
Ok((cos, sin))
}
impl ModelWeights {
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
ct: gguf_file::Content,
reader: &mut R,
device: &Device,
) -> Result<Self> {
let md_get = |s: &str| match ct.metadata.get(s) {
None => candle::bail!("cannot find {s} in metadata"),
Some(v) => Ok(v),
};
let head_count = md_get("gemma3.attention.head_count")?.to_u32()? as usize;
let head_count_kv = md_get("gemma3.attention.head_count_kv")?.to_u32()? as usize;
let block_count = md_get("gemma3.block_count")?.to_u32()? as usize;
let embedding_length = md_get("gemma3.embedding_length")?.to_u32()? as usize;
let key_length = md_get("gemma3.attention.key_length")?.to_u32()? as usize;
let _value_length = md_get("gemma3.attention.value_length")?.to_u32()? as usize;
let rms_norm_eps = md_get("gemma3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
let rope_freq_base = md_get("gemma3.rope.freq_base")
.and_then(|m| m.to_f32())
.unwrap_or(1000000f32);
// Compute the dimensions for queries, keys, and values
// These are the total dimensions when projected across all heads
let q_dim = head_count * key_length;
// Precompute rotary embeddings
let (cos, sin) = precomput_freqs_cis(key_length, rope_freq_base, device)?;
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
// Load token embeddings and output projection
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
let tok_embeddings = tok_embeddings.dequantize(device)?;
let norm = RmsNorm::from_qtensor(
ct.tensor(reader, "output_norm.weight", device)?,
rms_norm_eps,
)?;
let output = match ct.tensor(reader, "output.weight", device) {
Ok(tensor) => tensor,
Err(_) => ct.tensor(reader, "token_embd.weight", device)?, // Use tied weights if output.weight doesn't exist
};
let mut layers = Vec::with_capacity(block_count);
for layer_idx in 0..block_count {
let prefix = format!("blk.{layer_idx}");
let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
let attention_wo =
ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
let attention_q_norm = RmsNorm::from_qtensor(
ct.tensor(reader, &format!("{prefix}.attn_q_norm.weight"), device)?,
rms_norm_eps,
)?;
let attention_k_norm = RmsNorm::from_qtensor(
ct.tensor(reader, &format!("{prefix}.attn_k_norm.weight"), device)?,
rms_norm_eps,
)?;
let attention_norm = RmsNorm::from_qtensor(
ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?,
rms_norm_eps,
)?;
let post_attention_norm = RmsNorm::from_qtensor(
ct.tensor(
reader,
&format!("{prefix}.post_attention_norm.weight"),
device,
)?,
rms_norm_eps,
)?;
let ffn_norm = RmsNorm::from_qtensor(
ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?,
rms_norm_eps,
)?;
let post_ffn_norm = RmsNorm::from_qtensor(
ct.tensor(reader, &format!("{prefix}.post_ffw_norm.weight"), device)?,
rms_norm_eps,
)?;
let feed_forward_gate =
ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
let feed_forward_up = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
let feed_forward_down =
ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
let mlp = Mlp {
feed_forward_gate: QMatMul::from_qtensor(feed_forward_gate)?,
feed_forward_up: QMatMul::from_qtensor(feed_forward_up)?,
feed_forward_down: QMatMul::from_qtensor(feed_forward_down)?,
};
// Tracing spans
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
layers.push(LayerWeights {
attention_wq: QMatMul::from_qtensor(attention_wq)?,
attention_wk: QMatMul::from_qtensor(attention_wk)?,
attention_wv: QMatMul::from_qtensor(attention_wv)?,
attention_wo: QMatMul::from_qtensor(attention_wo)?,
attention_q_norm,
attention_k_norm,
attention_norm,
post_attention_norm,
ffn_norm,
post_ffn_norm,
mlp,
n_head: head_count,
n_kv_head: head_count_kv,
head_dim: key_length,
q_dim,
cos: cos.clone(),
sin: sin.clone(),
neg_inf: neg_inf.clone(),
kv_cache: None,
span_attn,
span_mlp,
})
}
let span = tracing::span!(tracing::Level::TRACE, "model");
let span_output = tracing::span!(tracing::Level::TRACE, "output");
Ok(Self {
tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
embedding_length,
layers,
norm,
output: QMatMul::from_qtensor(output)?,
masks: HashMap::new(),
span,
span_output,
})
}
fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
if let Some(mask) = self.masks.get(&t) {
Ok(mask.clone())
} else {
let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
.collect();
let mask = Tensor::from_slice(&mask, (t, t), device)?;
self.masks.insert(t, mask.clone());
Ok(mask)
}
}
pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (_b_sz, seq_len) = x.dims2()?;
let mask = if seq_len == 1 {
None
} else {
Some(self.mask(seq_len, x.device())?)
};
let _enter = self.span.enter();
let mut layer_in = self.tok_embeddings.forward(x)?;
layer_in = (layer_in * (self.embedding_length as f64).sqrt())?;
for layer in self.layers.iter_mut() {
// Attention block
let residual = &layer_in;
let x = layer.attention_norm.forward(&layer_in)?;
let x = layer.forward_attn(&x, mask.as_ref(), index_pos)?;
let x = layer.post_attention_norm.forward(&x)?;
let x = (x + residual)?;
// Feed-forward block
let _enter = layer.span_mlp.enter();
let residual = &x;
let x = layer.ffn_norm.forward(&x)?;
let x = layer.mlp.forward(&x)?;
let x = layer.post_ffn_norm.forward(&x)?;
let x = (x + residual)?;
drop(_enter);
layer_in = x;
}
let _enter = self.span_output.enter();
let x = layer_in.i((.., seq_len - 1, ..))?;
let x = self.norm.forward(&x)?;
let output = self.output.forward(&x)?;
Ok(output)
}
}