diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-examples/examples/musicgen/t5_model.rs
index ad784224..23bf7f0d 100644
--- a/candle-examples/examples/musicgen/t5_model.rs
+++ b/candle-examples/examples/musicgen/t5_model.rs
@@ -3,7 +3,8 @@
 use crate::nn::{embedding, linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder};
 use anyhow::Result;
-use candle::Tensor;
+use candle::{DType, Tensor, D};
+use std::sync::Arc;
 
 #[derive(Debug, Clone, PartialEq)]
 pub struct Config {
@@ -92,6 +93,18 @@ impl T5LayerNorm {
             variance_epsilon: eps,
         })
     }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let dtype = xs.dtype();
+        let xs_f32 = xs.to_dtype(DType::F32)?;
+        let xs2_f32 = (&xs_f32 * &xs_f32)?;
+        let sum_xs2_f32 = xs2_f32.sum(&[xs.rank() - 1])?;
+        let variance = xs2_f32.broadcast_div(&sum_xs2_f32)?;
+        let xs = (xs / (variance + self.variance_epsilon)?.sqrt()?)?;
+        let xs = xs.to_dtype(dtype)?;
+        let xs = xs.broadcast_mul(&self.weight)?;
+        Ok(xs)
+    }
 }
 
 #[derive(Debug)]
@@ -114,6 +127,14 @@ impl T5DenseActDense {
             act: HiddenAct::Relu,
         })
     }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let xs = self.wi.forward(xs)?;
+        let xs = self.act.forward(&xs)?;
+        let xs = self.dropout.forward(&xs)?;
+        let xs = self.wo.forward(&xs)?;
+        Ok(xs)
+    }
 }
 
 #[derive(Debug)]
@@ -136,6 +157,13 @@ impl T5LayerFF {
             dropout,
         })
     }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let ys = self.layer_norm.forward(xs)?;
+        let ys = self.dense_relu_dense.forward(&ys)?;
+        let xs = (xs + self.dropout.forward(&ys)?)?;
+        Ok(xs)
+    }
 }
 
 #[derive(Debug)]
@@ -144,6 +172,8 @@ struct T5Attention {
     k: Linear,
     v: Linear,
     o: Linear,
+    n_heads: usize,
+    d_kv: usize,
     relative_attention_bias: Option<Embedding>,
 }
@@ -169,9 +199,33 @@ impl T5Attention {
             k,
             v,
             o,
+            n_heads: cfg.num_heads,
+            d_kv: cfg.d_kv,
             relative_attention_bias,
         })
     }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let (b_sz, seq_len) = (xs.dim(0)?, xs.dim(1)?);
+        let q = self.q.forward(xs)?;
+        let k = self.k.forward(xs)?;
+        let v = self.v.forward(xs)?;
+        let q = q
+            .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
+            .transpose(1, 2)?;
+        let k = k
+            .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
+            .transpose(1, 2)?;
+        let v = v
+            .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
+            .transpose(1, 2)?;
+        let scores = q.matmul(&k.t()?)?;
+        // position_bias_masked
+        let attn_weights = scores.softmax(D::Minus1)?;
+        let attn_output = attn_weights.matmul(&v)?;
+        let attn_output = self.o.forward(&attn_output)?;
+        Ok(attn_output)
+    }
 }
 
 #[derive(Debug)]
@@ -193,6 +247,13 @@ impl T5LayerSelfAttention {
             dropout,
         })
     }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let normed_xs = self.layer_norm.forward(xs)?;
+        let ys = self.self_attention.forward(&normed_xs)?;
+        let ys = (xs + ys)?;
+        Ok(ys)
+    }
 }
 
 #[derive(Debug)]
@@ -202,6 +263,10 @@ impl T5LayerCrossAttention {
     fn load(_vb: VarBuilder, _cfg: &Config) -> Result<Self> {
         todo!()
     }
+
+    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
+        todo!()
+    }
 }
 
 #[derive(Debug)]
@@ -228,18 +293,31 @@ impl T5Block {
             ff,
         })
     }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut xs = self.self_attn.forward(xs)?;
+        // TODO: clamp for f16?
+        if let Some(cross_attn) = &self.cross_attn {
+            xs = cross_attn.forward(&xs)?;
+            // TODO: clamp for f16?
+        }
+        let xs = self.ff.forward(&xs)?;
+        // TODO: clamp for f16?
+        Ok(xs)
+    }
 }
 
 #[derive(Debug)]
 struct T5Stack {
     // TODO: Add embed_tokens if needed (shared embedding layer).
     block: Vec<T5Block>,
+    shared: Arc<Embedding>,
     final_layer_norm: T5LayerNorm,
     dropout: Dropout,
 }
 
 impl T5Stack {
-    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+    fn load(vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> {
         let block = (0..cfg.num_layers)
             .map(|i| T5Block::load(i == 0, vb.pp(&format!("block.{i}")), cfg))
             .collect::<Result<Vec<_>>>()?;
@@ -251,22 +329,42 @@ impl T5Stack {
         let dropout = Dropout::new(cfg.dropout_rate);
         Ok(Self {
             block,
+            shared: shared.clone(),
             final_layer_norm,
             dropout,
         })
     }
+
+    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
+        let input_embeds = self.shared.as_ref().forward(input_ids)?;
+        let (_b_sz, _seq_len) = input_embeds.shape().r2()?;
+
+        let mut hidden_states = self.dropout.forward(&input_embeds)?;
+        for block in self.block.iter() {
+            hidden_states = block.forward(&hidden_states)?
+        }
+        let hidden_states = self.final_layer_norm.forward(&hidden_states)?;
+        let hidden_states = self.dropout.forward(&hidden_states)?;
+        Ok(hidden_states)
+    }
 }
 
 #[derive(Debug)]
 pub struct T5EncoderModel {
-    shared: Embedding,
+    shared: Arc<Embedding>,
     encoder: T5Stack,
 }
 
 impl T5EncoderModel {
     pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
-        let encoder = T5Stack::load(vb.pp("encoder"), cfg)?;
+        let shared = Arc::new(shared);
+        let encoder = T5Stack::load(vb.pp("encoder"), &shared, cfg)?;
         Ok(Self { shared, encoder })
     }
+
+    pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
+        let encoder_outputs = self.encoder.forward(input_ids)?;
+        Ok(encoder_outputs)
+    }
 }
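As a rough usage sketch (not part of the diff): assuming the module layout of the musicgen example (`crate::nn::VarBuilder`, `crate::t5_model::{Config, T5EncoderModel}`) and an already constructed `VarBuilder` and `Config`, the encoder forward pass added above could be driven as below. The `encode` helper and its argument names are hypothetical.

```rust
// Hypothetical helper, not part of the diff: runs the encoder forward pass
// added above on a single tokenized prompt.
use anyhow::Result;
use candle::{Device, Tensor};

use crate::nn::VarBuilder;
use crate::t5_model::{Config, T5EncoderModel};

fn encode(vb: VarBuilder, cfg: &Config, token_ids: &[u32], device: &Device) -> Result<Tensor> {
    // Build the encoder from pre-loaded weights (T5EncoderModel::load from the diff).
    let model = T5EncoderModel::load(vb, cfg)?;
    // The encoder takes rank-2 input_ids, so add a batch dimension: (1, seq_len).
    let input_ids = Tensor::new(token_ids, device)?.unsqueeze(0)?;
    // Returns the final encoder hidden states, shape (1, seq_len, d_model).
    model.forward(&input_ids)
}
```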