Add the forward pass for the T5 model. (#152)

* Add the forward pass for the T5 model.

* More t5 forward pass.
Laurent Mazare
2023-07-12 22:02:40 +01:00
committed by GitHub
parent 465fc8c0c5
commit 6c75a98ad2


@@ -3,7 +3,8 @@
use crate::nn::{embedding, linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder};
use anyhow::Result;
use candle::Tensor;
use candle::{DType, Tensor, D};
use std::sync::Arc;
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
@@ -92,6 +93,18 @@ impl T5LayerNorm {
variance_epsilon: eps,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
// T5 uses an RMS norm: no mean subtraction and no bias, computed in f32 for stability.
let dtype = xs.dtype();
let xs_f32 = xs.to_dtype(DType::F32)?;
let xs2_f32 = (&xs_f32 * &xs_f32)?;
let last_dim = xs.rank() - 1;
// variance = mean(x^2) over the hidden dimension, keeping the reduced dim for broadcasting.
let variance = (xs2_f32.sum(&[last_dim])? / xs.dim(last_dim)? as f64)?;
let xs = xs_f32.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
let xs = xs.to_dtype(dtype)?;
let xs = xs.broadcast_mul(&self.weight)?;
Ok(xs)
}
}
#[derive(Debug)]
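For reference, the layer norm implemented above is the T5 RMS norm: the activations are rescaled by the root mean square over the hidden dimension, with no mean subtraction and no bias term:

y = w \odot \frac{x}{\sqrt{\mathrm{mean}(x^2) + \epsilon}}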
@@ -114,6 +127,14 @@ impl T5DenseActDense {
act: HiddenAct::Relu,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.wi.forward(xs)?;
let xs = self.act.forward(&xs)?;
let xs = self.dropout.forward(&xs)?;
let xs = self.wo.forward(&xs)?;
Ok(xs)
}
}
#[derive(Debug)]
@@ -136,6 +157,13 @@ impl T5LayerFF {
dropout,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let ys = self.layer_norm.forward(xs)?;
let ys = self.dense_relu_dense.forward(&ys)?;
let xs = (xs + self.dropout.forward(&ys)?)?;
Ok(xs)
}
}
#[derive(Debug)]
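The feed-forward layer above follows T5's pre-norm residual pattern (the inner dropout of T5DenseActDense is omitted here for clarity):

x_{\mathrm{out}} = x + \mathrm{Dropout}\big(W_o \, \mathrm{ReLU}(W_i \, \mathrm{RMSNorm}(x))\big)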
@@ -144,6 +172,8 @@ struct T5Attention {
k: Linear,
v: Linear,
o: Linear,
n_heads: usize,
d_kv: usize,
relative_attention_bias: Option<Embedding>,
}
@@ -169,9 +199,33 @@ impl T5Attention {
k,
v,
o,
n_heads: cfg.num_heads,
d_kv: cfg.d_kv,
relative_attention_bias,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (b_sz, seq_len) = (xs.dim(0)?, xs.dim(1)?);
let q = self.q.forward(xs)?;
let k = self.k.forward(xs)?;
let v = self.v.forward(xs)?;
let q = q
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?
.contiguous()?;
let k = k
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?
.contiguous()?;
let v = v
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?
.contiguous()?;
// T5 folds the 1/sqrt(d_kv) scaling into the weight initialization, so the scores
// are intentionally left unscaled here.
let scores = q.matmul(&k.t()?)?;
// TODO: add the relative position bias (position_bias_masked) before the softmax.
let attn_weights = scores.softmax(D::Minus1)?;
let attn_output = attn_weights.matmul(&v)?;
// Merge the heads back: (b_sz, n_heads, seq_len, d_kv) -> (b_sz, seq_len, n_heads * d_kv).
let attn_output = attn_output
.transpose(1, 2)?
.contiguous()?
.reshape((b_sz, seq_len, self.n_heads * self.d_kv))?;
let attn_output = self.o.forward(&attn_output)?;
Ok(attn_output)
}
}
#[derive(Debug)]
@@ -193,6 +247,13 @@ impl T5LayerSelfAttention {
dropout,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let normed_xs = self.layer_norm.forward(xs)?;
let ys = self.self_attention.forward(&normed_xs)?;
let ys = (xs + ys)?;
Ok(ys)
}
}
#[derive(Debug)]
@@ -202,6 +263,10 @@ impl T5LayerCrossAttention {
fn load(_vb: VarBuilder, _cfg: &Config) -> Result<Self> {
todo!()
}
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
todo!()
}
}
#[derive(Debug)]
@@ -228,18 +293,31 @@ impl T5Block {
ff,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = self.self_attn.forward(xs)?;
// TODO: clamp for f16?
if let Some(cross_attn) = &self.cross_attn {
xs = cross_attn.forward(&xs)?;
// TODO: clamp for f16?
}
let xs = self.ff.forward(&xs)?;
// TODO: clamp for f16?
Ok(xs)
}
}
#[derive(Debug)]
struct T5Stack {
// TODO: Add embed_tokens if needed (shared embedding layer).
block: Vec<T5Block>,
shared: Arc<Embedding>,
final_layer_norm: T5LayerNorm,
dropout: Dropout,
}
impl T5Stack {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
fn load(vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> {
let block = (0..cfg.num_layers)
.map(|i| T5Block::load(i == 0, vb.pp(&format!("block.{i}")), cfg))
.collect::<Result<Vec<_>>>()?;
@@ -251,22 +329,42 @@ impl T5Stack {
let dropout = Dropout::new(cfg.dropout_rate);
Ok(Self {
block,
shared: shared.clone(),
final_layer_norm,
dropout,
})
}
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let input_embeds = self.shared.as_ref().forward(input_ids)?;
let (_b_sz, _seq_len) = input_ids.shape().r2()?;
let mut hidden_states = self.dropout.forward(&input_embeds)?;
for block in self.block.iter() {
hidden_states = block.forward(&hidden_states)?
}
let hidden_states = self.final_layer_norm.forward(&hidden_states)?;
let hidden_states = self.dropout.forward(&hidden_states)?;
Ok(hidden_states)
}
}
#[derive(Debug)]
pub struct T5EncoderModel {
shared: Embedding,
shared: Arc<Embedding>,
encoder: T5Stack,
}
impl T5EncoderModel {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
let encoder = T5Stack::load(vb.pp("encoder"), cfg)?;
let shared = Arc::new(shared);
let encoder = T5Stack::load(vb.pp("encoder"), &shared, cfg)?;
Ok(Self { shared, encoder })
}
pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let encoder_outputs = self.encoder.forward(input_ids)?;
Ok(encoder_outputs)
}
}
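As a rough usage sketch (not part of this commit): once a VarBuilder over a T5 checkpoint and a candle Device are available, the encoder added here can be run on a batch of token ids. The variable names `vb`, `cfg` and `device` below are assumptions about the surrounding example code, not identifiers from this diff.

// Hypothetical usage sketch, inside a function returning anyhow::Result<()>.
// `vb`, `cfg` and `device` are assumed to be set up by the surrounding example.
let model = T5EncoderModel::load(vb, &cfg)?;
// A single sequence of token ids, shape (batch = 1, seq_len = 4).
let input_ids = Tensor::new(&[[0u32, 1, 2, 3]], &device)?;
// The encoder output has shape (batch, seq_len, d_model).
let hidden_states = model.forward(&input_ids)?;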