Encodec forward pass (#153)

* Sketch the forward pass for encodec.

* Forward pass for the encodec resnet block.

* Encodec decoding.
This commit is contained in:
Laurent Mazare
2023-07-13 08:18:39 +01:00
committed by GitHub
parent 6c75a98ad2
commit a3663ce2f2
2 changed files with 76 additions and 4 deletions

View File

@ -206,6 +206,8 @@ impl T5Attention {
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
// TODO: Apply the mask(s)?
// TODO: kv caching.
let (b_sz, seq_len) = (xs.dim(0)?, xs.dim(1)?);
let q = self.q.forward(xs)?;
let k = self.k.forward(xs)?;
@ -220,7 +222,7 @@ impl T5Attention {
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?;
let scores = q.matmul(&k.t()?)?;
// position_bias_masked
// TODO: position_bias_masked
let attn_weights = scores.softmax(D::Minus1)?;
let attn_output = attn_weights.matmul(&v)?;
let attn_output = self.o.forward(&attn_output)?;
@ -309,7 +311,6 @@ impl T5Block {
#[derive(Debug)]
struct T5Stack {
// TODO: Add embed_tokens if needed (shared embedding layer).
block: Vec<T5Block>,
shared: Arc<Embedding>,
final_layer_norm: T5LayerNorm,