diff --git a/candle-examples/examples/musicgen/encodec_model.rs b/candle-examples/examples/musicgen/encodec_model.rs
index ed8a66b7..f9b883fe 100644
--- a/candle-examples/examples/musicgen/encodec_model.rs
+++ b/candle-examples/examples/musicgen/encodec_model.rs
@@ -1,6 +1,6 @@
 use crate::nn::{conv1d, conv1d_weight_norm, Conv1d, Conv1dConfig, VarBuilder};
 use anyhow::Result;
-use candle::Tensor;
+use candle::{DType, IndexOp, Tensor};
 
 // Encodec Model
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
@@ -140,6 +140,11 @@ impl EncodecEuclideanCodebook {
             embed_avg,
         })
     }
+
+    fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
+        let quantize = Tensor::embedding(embed_ind, &self.embed)?;
+        Ok(quantize)
+    }
 }
 
 #[derive(Debug)]
@@ -152,6 +157,12 @@ impl EncodecVectorQuantization {
         let codebook = EncodecEuclideanCodebook::load(vb.pp("codebook"), cfg)?;
         Ok(Self { codebook })
    }
+
+    fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
+        let quantize = self.codebook.decode(embed_ind)?;
+        let quantize = quantize.transpose(1, 2)?;
+        Ok(quantize)
+    }
 }
 
 #[derive(Debug)]
@@ -167,6 +178,22 @@ impl EncodecResidualVectorQuantizer {
             .collect::<Result<Vec<_>>>()?;
         Ok(Self { layers })
     }
+
+    fn decode(&self, codes: &Tensor) -> Result<Tensor> {
+        let mut quantized_out = Tensor::zeros((), DType::F32, &codes.device())?;
+        if codes.dim(0)? != self.layers.len() {
+            anyhow::bail!(
+                "codes shape {:?} does not match the number of quantization layers {}",
+                codes.shape(),
+                self.layers.len()
+            )
+        }
+        for (i, layer) in self.layers.iter().enumerate() {
+            let quantized = layer.decode(&codes.i(i)?)?;
+            quantized_out = quantized.broadcast_add(&quantized_out)?;
+        }
+        Ok(quantized_out)
+    }
 }
 
 // https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226
@@ -188,6 +215,10 @@ impl EncodecLSTM {
         }
         Ok(Self { layers })
     }
+
+    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
+        todo!()
+    }
 }
 
 #[derive(Debug)]
@@ -216,10 +247,15 @@ impl EncodecConvTranspose1d {
             bias,
         })
     }
+
+    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
+        todo!()
+    }
 }
 
 #[derive(Debug)]
 struct EncodecConv1d {
+    causal: bool,
     conv: Conv1d,
 }
 
@@ -248,7 +284,17 @@ impl EncodecConv1d {
                 vb.pp("conv"),
             )?,
         };
-        Ok(Self { conv })
+        Ok(Self {
+            causal: cfg.use_causal_conv,
+            conv,
+        })
+    }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        // TODO: padding, depending on causal.
+        let xs = self.conv.forward(xs)?;
+        // If we add support for NormType "time_group_norm", we should add some normalization here.
+        Ok(xs)
     }
 }
 
@@ -284,6 +330,19 @@ impl EncodecResnetBlock {
             shortcut,
         })
     }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let residual = xs.clone();
+        let xs = xs.elu(1.)?;
+        let xs = self.block_conv1.forward(&xs)?;
+        let xs = xs.elu(1.)?;
+        let xs = self.block_conv2.forward(&xs)?;
+        let xs = match &self.shortcut {
+            None => (xs + residual)?,
+            Some(shortcut) => xs.add(&shortcut.forward(&residual)?)?,
+        };
+        Ok(xs)
+    }
 }
 
 struct Layer<'a> {
@@ -369,6 +428,10 @@ impl EncodecEncoder {
             final_lstm,
         })
     }
+
+    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
+        todo!()
+    }
 }
 
 #[derive(Debug)]
@@ -433,6 +496,10 @@ impl EncodecDecoder {
             final_conv,
         })
     }
+
+    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
+        todo!()
+    }
 }
 
 #[derive(Debug)]
@@ -453,4 +520,8 @@ impl EncodecModel {
             quantizer,
         })
     }
+
+    pub fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
+        todo!()
+    }
 }
diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-examples/examples/musicgen/t5_model.rs
index 23bf7f0d..0444f360 100644
--- a/candle-examples/examples/musicgen/t5_model.rs
+++ b/candle-examples/examples/musicgen/t5_model.rs
@@ -206,6 +206,8 @@ impl T5Attention {
     }
 
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        // TODO: Apply the mask(s)?
+        // TODO: kv caching.
         let (b_sz, seq_len) = (xs.dim(0)?, xs.dim(1)?);
         let q = self.q.forward(xs)?;
         let k = self.k.forward(xs)?;
@@ -220,7 +222,7 @@
             .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
             .transpose(1, 2)?;
         let scores = q.matmul(&k.t()?)?;
-        // position_bias_masked
+        // TODO: position_bias_masked
         let attn_weights = scores.softmax(D::Minus1)?;
         let attn_output = attn_weights.matmul(&v)?;
         let attn_output = self.o.forward(&attn_output)?;
@@ -309,7 +311,6 @@ impl T5Block {
 
 #[derive(Debug)]
 struct T5Stack {
-    // TODO: Add embed_tokens if needed (shared embedding layer).
     block: Vec<T5Block>,
     shared: Arc<Embedding>,
     final_layer_norm: T5LayerNorm,
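
For the `// TODO: padding, depending on causal.` left in `EncodecConv1d::forward`: the reference `modeling_encodec.py` pads the time axis before running the convolution, entirely on the left in the causal case so that no output frame depends on future samples. A minimal sketch of that causal branch, assuming candle exposes a `pad_with_zeros` helper and ignoring the extra right-padding the reference adds to complete the last frame:

```rust
use candle::{Tensor, D};

// Causal padding for a 1d convolution: the whole pad goes on the left of
// the time axis (the last dimension of `xs`).
fn causal_pad(xs: &Tensor, kernel_size: usize, stride: usize, dilation: usize) -> anyhow::Result<Tensor> {
    // Effective kernel size once dilation is taken into account.
    let effective_kernel_size = (kernel_size - 1) * dilation + 1;
    let padding_total = effective_kernel_size - stride;
    Ok(xs.pad_with_zeros(D::Minus1, padding_total, 0)?)
}
```

In the non-causal branch the reference splits the same `padding_total` across both sides: `padding_total / 2` on the right and the remainder on the left.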
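For the `// TODO: Apply the mask(s)?` in `T5Attention::forward`, the usual approach is an additive mask folded into `scores` before the softmax. A sketch under that assumption, where `mask` is a hypothetical tensor holding `0.` at visible positions and a large negative value elsewhere, broadcastable to the `(b_sz, n_heads, seq_len, seq_len)` shape of `scores`:

```rust
use candle::{Tensor, D};

// Additive attention masking: adding a large negative value (e.g. -1e9) to
// a score before the softmax drives the corresponding attention weight to ~0.
fn masked_softmax(scores: &Tensor, mask: &Tensor) -> anyhow::Result<Tensor> {
    let scores = scores.broadcast_add(mask)?;
    Ok(scores.softmax(D::Minus1)?)
}
```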