From 2619c4307fe02db031d3a41cfbed91b12b97df31 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Thu, 21 Sep 2023 11:13:39 +0100
Subject: [PATCH] Add a quantized version of the t5 model. (#921)

---
 .gitignore                                |   1 +
 candle-core/src/quantized/mod.rs          |   2 +-
 candle-transformers/src/models/mod.rs     |   1 +
 .../src/models/quantized_t5.rs            | 869 ++++++++++++++++++
 candle-transformers/src/models/t5.rs      |   2 +-
 5 files changed, 873 insertions(+), 2 deletions(-)
 create mode 100644 candle-transformers/src/models/quantized_t5.rs

diff --git a/.gitignore b/.gitignore
index 2748d37e..d0a8c320 100644
--- a/.gitignore
+++ b/.gitignore
@@ -23,6 +23,7 @@ flamegraph.svg
 *.dylib
 *.so
 *.swp
+*.swo
 trace-*.json
 
 candle-wasm-examples/*/build
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index 5c2bb2b2..f627f0f6 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -229,7 +229,7 @@ impl QTensor {
     }
 }
 
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct QMatMul(std::sync::Arc<QTensor>);
 
 impl QMatMul {
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index a20254d9..d783a2c6 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -5,6 +5,7 @@ pub mod efficientnet;
 pub mod falcon;
 pub mod llama;
 pub mod quantized_llama;
+pub mod quantized_t5;
 pub mod segment_anything;
 pub mod stable_diffusion;
 pub mod t5;
diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs
new file mode 100644
index 00000000..c14500ba
--- /dev/null
+++ b/candle-transformers/src/models/quantized_t5.rs
@@ -0,0 +1,869 @@
+// T5 Text Model, quantized version
+// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
+
+use candle::quantized::QTensor;
+use candle::{DType, Device, Module, Result, Shape, Tensor, D};
+use candle_nn::Activation;
+use serde::Deserialize;
+use std::sync::Arc;
+
+// VarBuilder specialized for QTensors
+pub struct VarBuilder {
+    data: Arc<std::collections::HashMap<String, Arc<QTensor>>>,
+    path: Vec<String>,
+    device: Device,
+}
+
+impl VarBuilder {
+    fn pp<S: ToString>(&self, s: S) -> Self {
+        let mut path = self.path.clone();
+        path.push(s.to_string());
+        Self {
+            data: self.data.clone(),
+            path,
+            device: self.device.clone(),
+        }
+    }
+
+    fn path(&self, tensor_name: &str) -> String {
+        if self.path.is_empty() {
+            tensor_name.to_string()
+        } else {
+            [&self.path.join("."), tensor_name].join(".")
+        }
+    }
+
+    fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Arc<QTensor>> {
+        let path = self.path(name);
+        match self.data.get(&path) {
+            None => {
+                candle::bail!("cannot find tensor {name}")
+            }
+            Some(qtensor) => {
+                let shape = s.into();
+                if qtensor.shape() != &shape {
+                    candle::bail!(
+                        "shape mismatch for {name}, got {:?}, expected {shape:?}",
+                        qtensor.shape()
+                    )
+                }
+                Ok(qtensor.clone())
+            }
+        }
+    }
+}
+
+#[derive(Debug)]
+struct Embedding {
+    inner: candle_nn::Embedding,
+    span: tracing::Span,
+}
+
+impl Embedding {
+    fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
+        let embeddings = vb.get((d1, d2), "weight")?.dequantize(&vb.device)?;
+        let inner = candle_nn::Embedding::new(embeddings, d2);
+        let span = tracing::span!(tracing::Level::TRACE, "embedding");
+        Ok(Self { inner, span })
+    }
+
+    fn embeddings(&self) -> &Tensor {
+        self.inner.embeddings()
+    }
+}
+
+impl Module for Embedding {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(xs)
+    }
+}
+
+// QMatMul wrapper adding some tracing.
+struct QMatMul {
+    inner: candle::quantized::QMatMul,
+    span: tracing::Span,
+}
+
+impl QMatMul {
+    fn new(out_dim: usize, in_dim: usize, vb: VarBuilder) -> Result<Self> {
+        let ws = vb.get((out_dim, in_dim), "weight")?;
+        let inner = candle::quantized::QMatMul::from_arc(ws);
+        let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
+        Ok(Self { inner, span })
+    }
+}
+
+impl Module for QMatMul {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(xs)
+    }
+}
+
+impl std::fmt::Debug for QMatMul {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "QMatMul")
+    }
+}
+
+fn default_relative_attention_max_distance() -> usize {
+    128
+}
+
+fn default_is_decoder() -> bool {
+    false
+}
+
+fn default_use_cache() -> bool {
+    true
+}
+
+fn default_tie_word_embeddings() -> bool {
+    true
+}
+
+fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
+    let mask: Vec<_> = (0..size)
+        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
+        .collect();
+    Tensor::from_slice(&mask, (size, size), device)
+}
+
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+    let shape = mask.shape();
+    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+    let m = mask.where_cond(&on_true, on_false)?;
+    Ok(m)
+}
+
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct Config {
+    vocab_size: usize,
+    d_model: usize,
+    d_kv: usize,
+    d_ff: usize,
+    num_layers: usize,
+    num_decoder_layers: Option<usize>,
+    num_heads: usize,
+    relative_attention_num_buckets: usize,
+    #[serde(default = "default_relative_attention_max_distance")]
+    relative_attention_max_distance: usize,
+    dropout_rate: f64,
+    layer_norm_epsilon: f64,
+    initializer_factor: f64,
+    #[serde(default)]
+    feed_forward_proj: Activation,
+    #[serde(default = "default_tie_word_embeddings")]
+    tie_word_embeddings: bool,
+    #[serde(default = "default_is_decoder")]
+    is_decoder: bool,
+    is_encoder_decoder: bool,
+    #[serde(default = "default_use_cache")]
+    pub use_cache: bool,
+    pub pad_token_id: usize,
+    pub eos_token_id: usize,
+}
+
+impl Default for Config {
+    fn default() -> Self {
+        Self {
+            vocab_size: 32128,
+            d_model: 512,
+            d_kv: 64,
+            d_ff: 2048,
+            num_layers: 6,
+            num_decoder_layers: None,
+            num_heads: 8,
+            relative_attention_num_buckets: 32,
+            relative_attention_max_distance: 128,
+            dropout_rate: 0.1,
+            layer_norm_epsilon: 1e-6,
+            initializer_factor: 1.0,
+            feed_forward_proj: Activation::Relu,
+            tie_word_embeddings: true,
+            is_decoder: false,
+            is_encoder_decoder: true,
+            use_cache: true,
+            pad_token_id: 0,
+            eos_token_id: 1,
+        }
+    }
+}
+
+#[derive(Debug)]
+struct T5LayerNorm {
+    weight: Tensor,
+    variance_epsilon: f64,
+    span: tracing::Span,
+}
+
+impl T5LayerNorm {
+    fn load(h: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+        let weight = vb.get(h, "weight")?.dequantize(&vb.device)?;
+        Ok(Self {
+            weight,
+            variance_epsilon: eps,
+            span: tracing::span!(tracing::Level::TRACE, "layer-norm"),
+        })
+    }
+}
+
+impl Module for T5LayerNorm {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let dtype = xs.dtype();
+        let xs_f32 = xs.to_dtype(DType::F32)?;
+        // variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
+        let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
+        let xs = xs.to_dtype(dtype)?;
+        let xs = xs.broadcast_mul(&self.weight)?;
+        Ok(xs)
+    }
+}
+
+#[derive(Debug)]
+struct T5DenseActDense {
+    wi: QMatMul,
+    wo: QMatMul,
+    act: Activation,
+    span: tracing::Span,
+}
+
+impl T5DenseActDense {
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let wi = QMatMul::new(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
+        let wo = QMatMul::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
+        Ok(Self {
+            wi,
+            wo,
+            act: Activation::Relu,
+            span: tracing::span!(tracing::Level::TRACE, "dense-act-dense"),
+        })
+    }
+}
+
+impl Module for T5DenseActDense {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let xs = self.wi.forward(xs)?;
+        let xs = self.act.forward(&xs)?;
+        let xs = self.wo.forward(&xs)?;
+        Ok(xs)
+    }
+}
+
+#[derive(Debug)]
+struct T5DenseGatedActDense {
+    wi_0: QMatMul,
+    wi_1: QMatMul,
+    wo: QMatMul,
+    act: Activation,
+    span: tracing::Span,
+}
+
+impl T5DenseGatedActDense {
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let wi_0 = QMatMul::new(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?;
+        let wi_1 = QMatMul::new(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?;
+        let wo = QMatMul::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
+        Ok(Self {
+            wi_0,
+            wi_1,
+            wo,
+            act: Activation::NewGelu,
+            span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
+        })
+    }
+}
+
+impl Module for T5DenseGatedActDense {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let hidden_gelu = self.act.forward(&self.wi_0.forward(xs)?)?;
+        let hidden_linear = self.wi_1.forward(xs)?;
+        let xs = hidden_gelu.broadcast_mul(&hidden_linear)?;
+        let xs = self.wo.forward(&xs)?;
+        Ok(xs)
+    }
+}
+
+#[derive(Debug)]
+struct T5LayerFF {
+    dense_act: Option<T5DenseActDense>,
+    gated_dense_act: Option<T5DenseGatedActDense>,
+    layer_norm: T5LayerNorm,
+    span: tracing::Span,
+}
+
+impl T5LayerFF {
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let layer_norm =
+            T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
+        let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu {
+            (
+                None,
+                Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
+            )
+        } else {
+            (
+                Some(T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?),
+                None,
+            )
+        };
+        Ok(Self {
+            dense_act,
+            gated_dense_act,
+            layer_norm,
+            span: tracing::span!(tracing::Level::TRACE, "layer-ff"),
+        })
+    }
+}
+
+impl Module for T5LayerFF {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let ys = self.layer_norm.forward(xs)?;
+        let ys = match &self.dense_act {
+            Some(dense_act) => dense_act.forward(&ys)?,
+            None => self.gated_dense_act.as_ref().unwrap().forward(&ys)?,
+        };
+        let xs = (xs + ys)?;
+        Ok(xs)
+    }
+}
+
+#[derive(Debug)]
+struct T5Attention {
+    q: QMatMul,
+    k: QMatMul,
+    v: QMatMul,
+    o: QMatMul,
+    n_heads: usize,
+    d_kv: usize,
+    relative_attention_bias: Option<Embedding>,
+    relative_attention_num_buckets: usize,
+    relative_attention_max_distance: usize,
+    inner_dim: usize,
+    use_cache: bool,
+    kv_cache: Option<(Tensor, Tensor)>,
+    span: tracing::Span,
+    span_cache: tracing::Span,
+    span_mm: tracing::Span,
+    span_sm: tracing::Span,
+}
+
+impl T5Attention {
+    fn load(
+        has_relative_attention_bias: bool,
+        decoder: bool,
+        vb: VarBuilder,
+        cfg: &Config,
+    ) -> Result<Self> {
+        let inner_dim = cfg.num_heads * cfg.d_kv;
+        let q = QMatMul::new(cfg.d_model, inner_dim, vb.pp("q"))?;
+        let k = QMatMul::new(cfg.d_model, inner_dim, vb.pp("k"))?;
+        let v = QMatMul::new(cfg.d_model, inner_dim, vb.pp("v"))?;
+        let o = QMatMul::new(inner_dim, cfg.d_model, vb.pp("o"))?;
+        let relative_attention_bias = if has_relative_attention_bias {
+            let emb = Embedding::new(
+                cfg.relative_attention_num_buckets,
+                cfg.num_heads,
+                vb.pp("relative_attention_bias"),
+            )?;
+            Some(emb)
+        } else {
+            None
+        };
+        Ok(Self {
+            q,
+            k,
+            v,
+            o,
+            n_heads: cfg.num_heads,
+            d_kv: cfg.d_kv,
+            relative_attention_bias,
+            relative_attention_num_buckets: cfg.relative_attention_num_buckets,
+            relative_attention_max_distance: cfg.relative_attention_max_distance,
+            inner_dim,
+            use_cache: cfg.use_cache && decoder,
+            kv_cache: None,
+            span: tracing::span!(tracing::Level::TRACE, "attention"),
+            span_cache: tracing::span!(tracing::Level::TRACE, "attention-cache"),
+            span_mm: tracing::span!(tracing::Level::TRACE, "attention-mm"),
+            span_sm: tracing::span!(tracing::Level::TRACE, "attention-sm"),
+        })
+    }
+
+    fn forward(
+        &mut self,
+        xs: &Tensor,
+        position_bias: Option<&Tensor>,
+        key_value_states: Option<&Tensor>,
+        mask: Option<&Tensor>,
+    ) -> Result<(Tensor, Option<Tensor>)> {
+        // Performs Self-attention (if key_value_states is None) or attention
+        // over source sentence (provided by key_value_states).
+        let _enter = self.span.enter();
+        let kv_input = match key_value_states {
+            None => xs,
+            Some(key_value_states) => key_value_states,
+        };
+        let (b_sz, q_len) = (xs.dim(0)?, xs.dim(1)?);
+        let kv_len = kv_input.dim(1)?;
+        let q = self.q.forward(xs)?;
+        let k = self.k.forward(kv_input)?;
+        let v = self.v.forward(kv_input)?;
+        let q = q
+            .reshape((b_sz, q_len, self.n_heads, self.d_kv))?
+            .transpose(1, 2)?
+            .contiguous()?;
+        let mut k = k
+            .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
+            .transpose(1, 2)?
+            .contiguous()?;
+        let mut v = v
+            .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
+            .transpose(1, 2)?
+            .contiguous()?;
+
+        if self.use_cache {
+            let _enter = self.span_cache.enter();
+            if let Some((kv_cache_k, kv_cache_v)) = &self.kv_cache {
+                k = Tensor::cat(&[kv_cache_k, &k], 2)?.contiguous()?;
+                v = Tensor::cat(&[kv_cache_v, &v], 2)?.contiguous()?;
+            };
+            self.kv_cache = Some((k.clone(), v.clone()));
+        };
+        // TODO: Use flash_attn.
+        let scores = {
+            let _enter = self.span_mm.enter();
+            q.matmul(&k.t()?)?
+        };
+        let scores = match mask {
+            None => scores,
+            Some(mask) => masked_fill(
+                &scores,
+                &mask
+                    .unsqueeze(0)?
+                    .unsqueeze(0)?
+                    .repeat((b_sz, self.n_heads))?,
+                f32::NEG_INFINITY,
+            )?,
+        };
+
+        let (scores, position_bias) = match position_bias {
+            Some(position_bias) => (
+                scores.broadcast_add(position_bias)?,
+                Some(position_bias.clone()),
+            ),
+            None => match &self.relative_attention_bias {
+                None => (scores, None),
+                Some(relative_attention_bias) => {
+                    // This only handles the bidirectional case.
+                    let kv_len = k.dim(2)?;
+                    let (q_start, q_end) = match self.use_cache {
+                        true => ((kv_len - q_len) as u32, kv_len as u32),
+                        false => (0_u32, kv_len as u32),
+                    };
+                    let num_buckets = self.relative_attention_num_buckets as u32 / 2;
+                    let max_exact = num_buckets / 2;
+                    let relative_position = (q_start..q_end)
+                        .map(|i| {
+                            (0..kv_len as u32)
+                                .map(|j| {
+                                    if i < j {
+                                        if j - i < max_exact {
+                                            j - i + num_buckets
+                                        } else {
+                                            let b = f32::log(
+                                                (j - i) as f32 / max_exact as f32,
+                                                self.relative_attention_max_distance as f32
+                                                    / max_exact as f32,
+                                            ) * (num_buckets - max_exact) as f32;
+                                            u32::min(
+                                                max_exact + num_buckets + b as u32,
+                                                self.relative_attention_num_buckets as u32 - 1,
+                                            )
+                                        }
+                                    } else if i - j < max_exact {
+                                        i - j
+                                    } else {
+                                        let b = f32::log(
+                                            (i - j) as f32 / max_exact as f32,
+                                            self.relative_attention_max_distance as f32
+                                                / max_exact as f32,
+                                        ) * (num_buckets - max_exact) as f32;
+                                        max_exact + b as u32
+                                    }
+                                })
+                                .collect::<Vec<u32>>()
+                        })
+                        .collect::<Vec<Vec<_>>>();
+                    let relative_buckets = Tensor::new(relative_position, q.device())?;
+                    let position_bias = relative_attention_bias
+                        .forward(&relative_buckets)?
+                        .permute((2, 0, 1))?
+                        .unsqueeze(0)?;
+                    (scores.broadcast_add(&position_bias)?, Some(position_bias))
+                    // TODO: position_bias_masked?
+                }
+            },
+        };
+
+        let attn_weights = {
+            let _enter = self.span_sm.enter();
+            candle_nn::ops::softmax(&scores, D::Minus1)?
+        };
+        let attn_output = attn_weights.matmul(&v)?;
+        let attn_output = attn_output
+            .transpose(1, 2)?
+            .reshape((b_sz, q_len, self.inner_dim))?;
+        let attn_output = self.o.forward(&attn_output)?;
+        Ok((attn_output, position_bias))
+    }
+
+    fn clear_kv_cache(&mut self) {
+        self.kv_cache = None
+    }
+}
+
+#[derive(Debug)]
+struct T5LayerSelfAttention {
+    self_attention: T5Attention,
+    layer_norm: T5LayerNorm,
+    span: tracing::Span,
+}
+
+impl T5LayerSelfAttention {
+    fn load(h: bool, d: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let self_attention = T5Attention::load(h, d, vb.pp("SelfAttention"), cfg)?;
+        let layer_norm =
+            T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
+        Ok(Self {
+            self_attention,
+            layer_norm,
+            span: tracing::span!(tracing::Level::TRACE, "self-attn"),
+        })
+    }
+
+    fn forward(
+        &mut self,
+        xs: &Tensor,
+        position_bias: Option<&Tensor>,
+        mask: Option<&Tensor>,
+    ) -> Result<(Tensor, Option<Tensor>)> {
+        let _enter = self.span.enter();
+        let normed_xs = self.layer_norm.forward(xs)?;
+        let (ys, position_bias) =
+            self.self_attention
+                .forward(&normed_xs, position_bias, None, mask)?;
+        let ys = (xs + ys)?;
+        Ok((ys, position_bias))
+    }
+
+    fn clear_kv_cache(&mut self) {
+        self.self_attention.clear_kv_cache()
+    }
+}
+
+#[derive(Debug)]
+struct T5LayerCrossAttention {
+    cross_attention: T5Attention,
+    layer_norm: T5LayerNorm,
+    span: tracing::Span,
+}
+
+impl T5LayerCrossAttention {
+    fn load(decoder: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let cross_attention = T5Attention::load(false, decoder, vb.pp("EncDecAttention"), cfg)?;
+        let layer_norm =
+            T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
+        Ok(Self {
+            cross_attention,
+            layer_norm,
+            span: tracing::span!(tracing::Level::TRACE, "cross-attn"),
+        })
+    }
+
+    fn forward(
+        &mut self,
+        hidden_states: &Tensor,
+        position_bias: Option<&Tensor>,
+        key_value_states: &Tensor,
+    ) -> Result<(Tensor, Option<Tensor>)> {
+        let _enter = self.span.enter();
+        let normed_hidden_states = self.layer_norm.forward(hidden_states)?;
+        let (ys, position_bias) = self.cross_attention.forward(
+            &normed_hidden_states,
+            position_bias,
+            Some(key_value_states),
+            None,
+        )?;
+        let ys = (hidden_states + ys)?;
+        Ok((ys, position_bias))
+    }
+
+    fn clear_kv_cache(&mut self) {
+        self.cross_attention.clear_kv_cache()
+    }
+}
+
+#[derive(Debug)]
+struct T5Block {
+    self_attn: T5LayerSelfAttention,
+    cross_attn: Option<T5LayerCrossAttention>,
+    ff: T5LayerFF,
+    span: tracing::Span,
+}
+
+impl T5Block {
+    fn load(
+        has_relative_attention_bias: bool,
+        decoder: bool,
+        vb: VarBuilder,
+        cfg: &Config,
+    ) -> Result<Self> {
+        let vb = vb.pp("layer");
+        let self_attn =
+            T5LayerSelfAttention::load(has_relative_attention_bias, decoder, vb.pp("0"), cfg)?;
+        let cross_attn = if cfg.is_decoder {
+            Some(T5LayerCrossAttention::load(decoder, vb.pp("1"), cfg)?)
+        } else {
+            None
+        };
+        let ff_i = if cross_attn.is_some() { 2 } else { 1 };
+        let ff = T5LayerFF::load(vb.pp(ff_i), cfg)?;
+        Ok(Self {
+            self_attn,
+            cross_attn,
+            ff,
+            span: tracing::span!(tracing::Level::TRACE, "block"),
+        })
+    }
+
+    fn forward(
+        &mut self,
+        xs: &Tensor,
+        position_bias: Option<&Tensor>,
+        encoder_hidden_states: Option<&Tensor>,
+    ) -> Result<(Tensor, Option<Tensor>)> {
+        let _enter = self.span.enter();
+        // TODO: Cache masks
+        let mask = match self.cross_attn.is_some() {
+            true => {
+                let mask_len = xs.dim(1)?;
+                // If the input seq length is 1, no need for a mask, this is also helpful to avoid shape
+                // issues when using the KV cache in the decoder.
+                if mask_len <= 1 {
+                    None
+                } else {
+                    Some(get_mask(mask_len, xs.device())?)
+                }
+            }
+            false => None,
+        };
+        let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?;
+        // TODO: clamp for f16?
+        if let Some(cross_attn) = &mut self.cross_attn {
+            (xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?;
+            // TODO: clamp for f16?
+        }
+        let xs = self.ff.forward(&xs)?;
+        // TODO: clamp for f16?
+        Ok((xs, position_bias))
+    }
+
+    fn clear_kv_cache(&mut self) {
+        self.self_attn.clear_kv_cache();
+        self.cross_attn.iter_mut().for_each(|c| c.clear_kv_cache());
+    }
+}
+
+#[derive(Debug)]
+struct T5Stack {
+    block: Vec<T5Block>,
+    shared: Arc<Embedding>,
+    final_layer_norm: T5LayerNorm,
+    span: tracing::Span,
+}
+
+impl T5Stack {
+    fn load(decoder: bool, vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> {
+        let block = (0..cfg.num_layers)
+            .map(|i| T5Block::load(i == 0, decoder, vb.pp(format!("block.{i}")), cfg))
+            .collect::<Result<Vec<_>>>()?;
+        let final_layer_norm = T5LayerNorm::load(
+            cfg.d_model,
+            cfg.layer_norm_epsilon,
+            vb.pp("final_layer_norm"),
+        )?;
+        Ok(Self {
+            block,
+            shared: shared.clone(),
+            final_layer_norm,
+            span: tracing::span!(tracing::Level::TRACE, "stack"),
+        })
+    }
+
+    fn forward(
+        &mut self,
+        input_ids: &Tensor,
+        encoder_hidden_states: Option<&Tensor>,
+    ) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let input_embeds = self.shared.as_ref().forward(input_ids)?;
+        let mut hidden_states = input_embeds;
+        let mut position_bias = None;
+        for block in self.block.iter_mut() {
+            (hidden_states, position_bias) = block.forward(
+                &hidden_states,
+                position_bias.as_ref(),
+                encoder_hidden_states,
+            )?
+        }
+        self.final_layer_norm.forward(&hidden_states)
+    }
+
+    fn clear_kv_cache(&mut self) {
+        self.block.iter_mut().for_each(|b| b.clear_kv_cache())
+    }
+}
+
+#[derive(Debug)]
+pub struct T5EncoderModel {
+    encoder: T5Stack,
+    device: Device,
+    span: tracing::Span,
+}
+
+impl T5EncoderModel {
+    pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+        let shared = Arc::new(shared);
+        let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
+        Ok(Self {
+            encoder,
+            device: vb.device.clone(),
+            span: tracing::span!(tracing::Level::TRACE, "encoder"),
+        })
+    }
+
+    pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.encoder.forward(input_ids, None)
+    }
+
+    pub fn device(&self) -> &Device {
+        &self.device
+    }
+
+    pub fn clear_kv_cache(&mut self) {
+        self.encoder.clear_kv_cache()
+    }
+}
+
+#[derive(Debug)]
+pub struct T5ForConditionalGeneration {
+    encoder: T5Stack,
+    decoder: T5Stack,
+    d_model: usize,
+    tie_word_embeddings: bool,
+    lm_head: Option<QMatMul>,
+    shared: Arc<Embedding>,
+    device: Device,
+    span_decode: tracing::Span,
+    span_decode_head: tracing::Span,
+}
+
+impl T5ForConditionalGeneration {
+    pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        assert!(cfg.is_encoder_decoder);
+        let d_model = cfg.d_model;
+        let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+        let shared = Arc::new(shared);
+
+        let mut encoder_cfg = cfg.clone();
+        encoder_cfg.is_decoder = false;
+        encoder_cfg.use_cache = false;
+        encoder_cfg.is_encoder_decoder = false;
+        let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, &encoder_cfg)?;
+
+        let mut decoder_cfg = cfg.clone();
+        decoder_cfg.is_decoder = true;
+        decoder_cfg.is_encoder_decoder = false;
+        decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers);
+        let decoder = T5Stack::load(true, vb.pp("decoder"), &shared, &decoder_cfg)?;
+
+        let tie_word_embeddings = cfg.tie_word_embeddings;
+        let lm_head = if tie_word_embeddings {
+            None
+        } else {
+            Some(QMatMul::new(cfg.d_model, cfg.vocab_size, vb.pp("lm_head"))?)
+        };
+
+        Ok(Self {
+            encoder,
+            decoder,
+            d_model,
+            tie_word_embeddings,
+            lm_head,
+            shared,
+            device: vb.device.clone(),
+            span_decode: tracing::span!(tracing::Level::TRACE, "decode"),
+            span_decode_head: tracing::span!(tracing::Level::TRACE, "decode-head"),
+        })
+    }
+
+    pub fn encode(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+        self.encoder.forward(input_ids, None)
+    }
+
+    pub fn decode(
+        &mut self,
+        decoder_input_ids: &Tensor,
+        encoder_output: &Tensor,
+    ) -> Result<Tensor> {
+        let _enter = self.span_decode.enter();
+        let decoder_output = self
+            .decoder
+            .forward(decoder_input_ids, Some(encoder_output))?;
+
+        let scaling_factor = if self.tie_word_embeddings {
+            // Rescale output before projecting on vocab
+            // See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+            (self.d_model as f64).sqrt()
+        } else {
+            1.0
+        };
+        let sequence_output = ((decoder_output
+            .narrow(1, decoder_output.dim(1)? - 1, 1)?
+            .squeeze(1)?)
+            * scaling_factor)?;
+        let output = {
+            let _enter = self.span_decode_head.enter();
+            match self.lm_head {
+                None => sequence_output.matmul(&self.shared.embeddings().t()?)?,
+                Some(ref lm_head) => lm_head.forward(&sequence_output)?,
+            }
+        };
+        // TODO: Rescale output before projecting on vocab?
+        // * (self.model_dim**-0.5)
+        Ok(output)
+    }
+
+    pub fn forward(&mut self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> {
+        let encoder_output = self.encode(input_ids)?;
+        self.decode(decoder_input_ids, &encoder_output)
+    }
+
+    pub fn device(&self) -> &Device {
+        &self.device
+    }
+
+    pub fn clear_kv_cache(&mut self) {
+        self.encoder.clear_kv_cache();
+        self.decoder.clear_kv_cache();
+    }
+}
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index 94cf5233..539ae89b 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -1,4 +1,4 @@
-// T5 Text Encoder
+// T5 Text Model
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
 
 use candle::{DType, Device, Module, Result, Tensor, D};
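
Usage note (not part of the patch): the sketch below shows how the API added in quantized_t5.rs is intended to be driven, encoding the prompt once and then calling decode step by step while the decoder KV cache fills up. It assumes the caller has already built the model via T5ForConditionalGeneration::load, tokenized the prompt into input_ids of shape [1, src_len], and supplies its own pick_token closure (argmax or sampling); those pieces are assumptions of this sketch, not part of the diff above.

use candle::{Result, Tensor};
use candle_transformers::models::quantized_t5::T5ForConditionalGeneration;

// Run `steps` decoding steps and return the generated token ids.
// T5 starts decoding from the pad token id (0 with the default config).
fn decode_steps(
    model: &mut T5ForConditionalGeneration,
    input_ids: &Tensor,
    decoder_start_token: u32,
    steps: usize,
    pick_token: impl Fn(&Tensor) -> Result<u32>, // caller-provided argmax or sampling
) -> Result<Vec<u32>> {
    let device = model.device().clone();
    // The encoder runs once; its output is re-used by every decoder step.
    let encoder_output = model.encode(input_ids)?;
    model.clear_kv_cache();
    let mut tokens = vec![decoder_start_token];
    for _ in 0..steps {
        // With use_cache enabled only the most recent token is fed; earlier
        // keys/values are kept in the attention KV cache.
        let last = *tokens.last().unwrap();
        let decoder_ids = Tensor::new(&[last], &device)?.unsqueeze(0)?;
        let logits = model.decode(&decoder_ids, &encoder_output)?; // [1, vocab_size]
        tokens.push(pick_token(&logits)?);
    }
    Ok(tokens)
}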