From 15e864414982b9f6de273ab8e984f700f58bbefe Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Tue, 27 Feb 2024 23:26:35 +0100 Subject: [PATCH] Apply dilations in the encodec model. (#1772) * Apply dilations in the encodec model. * Add some encoding bits. --- candle-transformers/src/models/encodec.rs | 88 ++++++++++++++++++----- 1 file changed, 69 insertions(+), 19 deletions(-) diff --git a/candle-transformers/src/models/encodec.rs b/candle-transformers/src/models/encodec.rs index 68f01d87..d3b26e1e 100644 --- a/candle-transformers/src/models/encodec.rs +++ b/candle-transformers/src/models/encodec.rs @@ -226,6 +226,7 @@ pub struct EuclideanCodebook { cluster_size: Tensor, embed: candle_nn::Embedding, embed_avg: Tensor, + c2: Tensor, } impl EuclideanCodebook { @@ -234,15 +235,36 @@ impl EuclideanCodebook { let cluster_size = vb.get(cfg.codebook_size, "cluster_size")?; let e_shape = (cfg.codebook_size, cfg.codebook_dim()); let embed = vb.get(e_shape, "embed")?; + let c2 = ((&embed * &embed)?.sum(D::Minus1)? / 2.0)?; let embed_avg = vb.get(e_shape, "embed_avg")?; Ok(Self { inited, cluster_size, embed: candle_nn::Embedding::new(embed, cfg.codebook_dim()), embed_avg, + c2, }) } + pub fn encode_slow(&self, xs: &Tensor) -> Result { + let mut target_shape = xs.dims().to_vec(); + target_shape.pop(); + let xs = xs.flatten_to(D::Minus2)?; + let _ = xs.dims2()?; + let dot_prod = xs.matmul(&self.embed.embeddings().t()?)?; + let codes = self.c2.broadcast_sub(&dot_prod)?.argmin(D::Minus1)?; + codes.reshape(target_shape) + } + + pub fn encode(&self, xs: &Tensor) -> Result { + let mut target_shape = xs.dims().to_vec(); + target_shape.pop(); + let xs = xs.flatten_to(D::Minus2)?; + let _ = xs.dims2()?; + let codes = Tensor::apply_op2(&xs, self.embed.embeddings(), CodebookEncode)?; + codes.reshape(target_shape) + } + pub fn decode(&self, embed_ind: &Tensor) -> Result { let quantize = self.embed.forward(embed_ind)?; Ok(quantize) @@ -260,6 +282,10 @@ impl VectorQuantization { Ok(Self { codebook }) } + pub fn encode(&self, xs: &Tensor) -> Result { + self.codebook.encode_slow(xs) + } + pub fn decode(&self, embed_ind: &Tensor) -> Result { let quantize = self.codebook.decode(embed_ind)?; let quantize = quantize.transpose(1, 2)?; @@ -281,6 +307,18 @@ impl ResidualVectorQuantizer { Ok(Self { layers }) } + pub fn encode(&self, xs: &Tensor) -> Result { + let mut codes = Vec::with_capacity(self.layers.len()); + let mut residual = xs.clone(); + for layer in self.layers.iter() { + let indices = layer.encode(&residual)?; + let quantized = layer.decode(&indices)?; + residual = (residual - quantized)?; + codes.push(indices) + } + Tensor::stack(&codes, 0) + } + pub fn decode(&self, codes: &Tensor) -> Result { let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?; let ncodes = codes.dim(0)?; @@ -380,6 +418,7 @@ impl EncodecConv1d { out_c: usize, kernel_size: usize, stride: usize, + dilation: usize, cfg: &Config, vb: VarBuilder, ) -> Result { @@ -389,10 +428,9 @@ impl EncodecConv1d { out_c, kernel_size, candle_nn::Conv1dConfig { - padding: 0, stride, - groups: 1, - dilation: 1, + dilation, + ..Default::default() }, vb.pp("conv"), )?, @@ -463,20 +501,29 @@ pub struct EncodecResnetBlock { } impl EncodecResnetBlock { - pub fn new(dim: usize, dilations: &[usize], cfg: &Config, vb: VarBuilder) -> Result { + pub fn new( + dim: usize, + (dilation1, dilation2): (usize, usize), + cfg: &Config, + vb: VarBuilder, + ) -> Result { let h = dim / cfg.compress; let mut layer = Layer::new(vb.pp("block")); - if dilations.len() != 2 { - candle::bail!("expected dilations of size 2") - } // TODO: Apply dilations! layer.inc(); - let block_conv1 = - EncodecConv1d::new(dim, h, cfg.residual_kernel_size, 1, cfg, layer.next())?; + let block_conv1 = EncodecConv1d::new( + dim, + h, + cfg.residual_kernel_size, + 1, + dilation1, + cfg, + layer.next(), + )?; layer.inc(); - let block_conv2 = EncodecConv1d::new(h, dim, 1, 1, cfg, layer.next())?; + let block_conv2 = EncodecConv1d::new(h, dim, 1, 1, dilation2, cfg, layer.next())?; let shortcut = if cfg.use_conv_shortcut { - let conv = EncodecConv1d::new(dim, dim, 1, 1, cfg, vb.pp("shortcut"))?; + let conv = EncodecConv1d::new(dim, dim, 1, 1, 1, cfg, vb.pp("shortcut"))?; Some(conv) } else { None @@ -541,6 +588,7 @@ impl Encoder { cfg.num_filters, cfg.kernel_size, 1, + 1, cfg, layer.next(), )?; @@ -552,7 +600,7 @@ impl Encoder { for j in 0..(cfg.num_residual_layers as u32) { let resnet = EncodecResnetBlock::new( current_scale, - &[cfg.dilation_growth_rate.pow(j), 1], + (cfg.dilation_growth_rate.pow(j), 1), cfg, layer.next(), )?; @@ -564,6 +612,7 @@ impl Encoder { current_scale * 2, ratio * 2, ratio, + 1, cfg, layer.next(), )?; @@ -577,6 +626,7 @@ impl Encoder { cfg.hidden_size, cfg.last_kernel_size, 1, + 1, cfg, layer.next(), )?; @@ -621,6 +671,7 @@ impl Decoder { cfg.num_filters * scaling, cfg.last_kernel_size, 1, + 1, cfg, layer.next(), )?; @@ -641,7 +692,7 @@ impl Decoder { for j in 0..(cfg.num_residual_layers as u32) { let resnet = EncodecResnetBlock::new( current_scale / 2, - &[cfg.dilation_growth_rate.pow(j), 1], + (cfg.dilation_growth_rate.pow(j), 1), cfg, layer.next(), )?; @@ -656,6 +707,7 @@ impl Decoder { cfg.audio_channels, cfg.last_kernel_size, 1, + 1, cfg, layer.next(), )?; @@ -700,12 +752,10 @@ impl Model { }) } - pub fn forward(&self, _xs: &Tensor) -> Result { - todo!() - } - - pub fn encode(&self, _xs: &Tensor) -> Result { - todo!() + pub fn encode(&self, xs: &Tensor) -> Result { + let xs = self.encoder.forward(xs)?; + let codes = self.quantizer.encode(&xs)?; + codes.transpose(0, 1) } pub fn decode(&self, codes: &Tensor) -> Result {