Mirror of https://github.com/huggingface/candle.git
Encodec forward pass (#153)
* Sketch the forward pass for encodec.
* Forward pass for the encodec resnet block.
* Encodec decoding.
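The decoding added in this commit bottoms out in an embedding lookup: EncodecEuclideanCodebook::decode selects rows of the codebook's embed table for the given code indices (Tensor::embedding in the hunks below). A minimal dependency-free sketch of that lookup, using plain Vecs rather than candle tensors; the codebook_decode helper and all values are made up for illustration:

// Toy codebook decode: each code index selects one embedding vector
// (one row of the codebook table). Values are illustrative only.
fn codebook_decode(codebook: &[Vec<f32>], codes: &[usize]) -> Vec<Vec<f32>> {
    codes.iter().map(|&i| codebook[i].clone()).collect()
}

fn main() {
    // A 4-entry codebook of 3-dimensional embeddings.
    let codebook = vec![
        vec![0.0, 0.0, 0.0],
        vec![1.0, 0.5, -0.5],
        vec![0.2, 0.2, 0.2],
        vec![-1.0, 0.0, 1.0],
    ];
    // A short sequence of code indices.
    let codes = [1usize, 3, 0, 2];
    // Each time step now carries the embedding of its code.
    println!("{:?}", codebook_decode(&codebook, &codes));
}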
@@ -1,6 +1,6 @@
 use crate::nn::{conv1d, conv1d_weight_norm, Conv1d, Conv1dConfig, VarBuilder};
 use anyhow::Result;
-use candle::Tensor;
+use candle::{DType, IndexOp, Tensor};

 // Encodec Model
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
@@ -140,6 +140,11 @@ impl EncodecEuclideanCodebook
             embed_avg,
         })
     }
+
+    fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
+        let quantize = Tensor::embedding(embed_ind, &self.embed)?;
+        Ok(quantize)
+    }
 }

 #[derive(Debug)]
@@ -152,6 +157,12 @@ impl EncodecVectorQuantization
         let codebook = EncodecEuclideanCodebook::load(vb.pp("codebook"), cfg)?;
         Ok(Self { codebook })
     }
+
+    fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
+        let quantize = self.codebook.decode(embed_ind)?;
+        let quantize = quantize.transpose(1, 2)?;
+        Ok(quantize)
+    }
 }

 #[derive(Debug)]
@@ -167,6 +178,22 @@ impl EncodecResidualVectorQuantizer
             .collect::<Result<Vec<_>>>()?;
         Ok(Self { layers })
     }
+
+    fn decode(&self, codes: &Tensor) -> Result<Tensor> {
+        let mut quantized_out = Tensor::zeros((), DType::F32, &codes.device())?;
+        if codes.dim(0)? != self.layers.len() {
+            anyhow::bail!(
+                "codes shape {:?} does not match the number of quantization layers {}",
+                codes.shape(),
+                self.layers.len()
+            )
+        }
+        for (i, layer) in self.layers.iter().enumerate() {
+            let quantized = layer.decode(&codes.i(i)?)?;
+            quantized_out = quantized.broadcast_add(&quantized_out)?;
+        }
+        Ok(quantized_out)
+    }
 }

 // https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226
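EncodecResidualVectorQuantizer::decode above accumulates the output of every quantization layer: layer i decodes its own slice of codes and the results are summed via broadcast_add. A dependency-free sketch of that accumulation; rvq_decode and all values are illustrative, not part of the candle code:

// Toy residual vector quantization decode: every layer owns a codebook,
// and the decoded embeddings of all layers are summed element-wise.
fn rvq_decode(codebooks: &[Vec<Vec<f32>>], codes: &[Vec<usize>]) -> Vec<Vec<f32>> {
    assert_eq!(codebooks.len(), codes.len(), "one code sequence per layer");
    let (seq_len, dim) = (codes[0].len(), codebooks[0][0].len());
    let mut out = vec![vec![0.0f32; dim]; seq_len];
    for (codebook, layer_codes) in codebooks.iter().zip(codes) {
        for (t, &code) in layer_codes.iter().enumerate() {
            for (o, e) in out[t].iter_mut().zip(&codebook[code]) {
                *o += *e; // add this layer's embedding to the running sum
            }
        }
    }
    out
}

fn main() {
    // Two quantizer layers, each with a 2-entry codebook of 2-d embeddings.
    let codebooks = vec![
        vec![vec![1.0, 0.0], vec![0.0, 1.0]],
        vec![vec![0.1, 0.1], vec![-0.1, 0.2]],
    ];
    // codes[layer][time_step]
    let codes = vec![vec![0, 1, 1], vec![1, 0, 1]];
    // Decoded output for three time steps: the per-layer lookups, summed.
    println!("{:?}", rvq_decode(&codebooks, &codes));
}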
@@ -188,6 +215,10 @@ impl EncodecLSTM
         }
         Ok(Self { layers })
     }
+
+    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
+        todo!()
+    }
 }

 #[derive(Debug)]
@@ -216,10 +247,15 @@ impl EncodecConvTranspose1d
             bias,
         })
     }
+
+    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
+        todo!()
+    }
 }

 #[derive(Debug)]
 struct EncodecConv1d {
+    causal: bool,
     conv: Conv1d,
 }

@@ -248,7 +284,17 @@ impl EncodecConv1d
                 vb.pp("conv"),
             )?,
         };
-        Ok(Self { conv })
+        Ok(Self {
+            causal: cfg.use_causal_conv,
+            conv,
+        })
+    }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        // TODO: padding, depending on causal.
+        let xs = self.conv.forward(xs)?;
+        // If we add support for NormType "time_group_norm", we should add some normalization here.
+        Ok(xs)
     }
 }

@@ -284,6 +330,19 @@ impl EncodecResnetBlock
             shortcut,
         })
     }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let residual = xs.clone();
+        let xs = xs.elu(1.)?;
+        let xs = self.block_conv1.forward(&xs)?;
+        let xs = xs.elu(1.)?;
+        let xs = self.block_conv2.forward(&xs)?;
+        let xs = match &self.shortcut {
+            None => (xs + residual)?,
+            Some(shortcut) => xs.add(&shortcut.forward(&residual)?)?,
+        };
+        Ok(xs)
+    }
 }

 struct Layer<'a> {
@@ -369,6 +428,10 @@ impl EncodecEncoder
             final_lstm,
         })
     }
+
+    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
+        todo!()
+    }
 }

 #[derive(Debug)]
@@ -433,6 +496,10 @@ impl EncodecDecoder
             final_conv,
         })
     }
+
+    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
+        todo!()
+    }
 }

 #[derive(Debug)]
@@ -453,4 +520,8 @@ impl EncodecModel
             quantizer,
         })
     }
+
+    pub fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
+        todo!()
+    }
 }
The remaining hunks are against the T5 model code touched by the same commit; they mostly add TODO markers rather than new behavior.

@@ -206,6 +206,8 @@ impl T5Attention
     }

     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        // TODO: Apply the mask(s)?
+        // TODO: kv caching.
         let (b_sz, seq_len) = (xs.dim(0)?, xs.dim(1)?);
         let q = self.q.forward(xs)?;
         let k = self.k.forward(xs)?;
@@ -220,7 +222,7 @@ impl T5Attention
             .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
             .transpose(1, 2)?;
         let scores = q.matmul(&k.t()?)?;
-        // position_bias_masked
+        // TODO: position_bias_masked
         let attn_weights = scores.softmax(D::Minus1)?;
         let attn_output = attn_weights.matmul(&v)?;
         let attn_output = self.o.forward(&attn_output)?;
@@ -309,7 +311,6 @@ impl T5Block

 #[derive(Debug)]
 struct T5Stack {
-    // TODO: Add embed_tokens if needed (shared embedding layer).
     block: Vec<T5Block>,
     shared: Arc<Embedding>,
     final_layer_norm: T5LayerNorm,