mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Encodec forward pass (#153)
* Sketch the forward pass for encodec.
* Forward pass for the encodec resnet block.
* Encodec decoding.
This commit is contained in:
@ -206,6 +206,8 @@ impl T5Attention {
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
// TODO: Apply the mask(s)?
|
||||
// TODO: kv caching.
|
||||
let (b_sz, seq_len) = (xs.dim(0)?, xs.dim(1)?);
|
||||
let q = self.q.forward(xs)?;
|
||||
let k = self.k.forward(xs)?;
|
||||
@ -220,7 +222,7 @@ impl T5Attention {
|
||||
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
|
||||
.transpose(1, 2)?;
|
||||
let scores = q.matmul(&k.t()?)?;
|
||||
// position_bias_masked
|
||||
// TODO: position_bias_masked
|
||||
let attn_weights = scores.softmax(D::Minus1)?;
|
||||
let attn_output = attn_weights.matmul(&v)?;
|
||||
let attn_output = self.o.forward(&attn_output)?;
|
||||
@ -309,7 +311,6 @@ impl T5Block {
|
||||
|
||||
#[derive(Debug)]
|
||||
struct T5Stack {
|
||||
// TODO: Add embed_tokens if needed (shared embedding layer).
|
||||
block: Vec<T5Block>,
|
||||
shared: Arc<Embedding>,
|
||||
final_layer_norm: T5LayerNorm,
|
||||
|
Reference in New Issue
Block a user