Apply dilations in the encodec model. (#1772)

* Apply dilations in the encodec model.

* Add some encoding bits.
This commit is contained in:
Laurent Mazare
2024-02-27 23:26:35 +01:00
committed by GitHub
parent 0c49e95dfb
commit 15e8644149

View File

@ -226,6 +226,7 @@ pub struct EuclideanCodebook {
cluster_size: Tensor, cluster_size: Tensor,
embed: candle_nn::Embedding, embed: candle_nn::Embedding,
embed_avg: Tensor, embed_avg: Tensor,
c2: Tensor,
} }
impl EuclideanCodebook { impl EuclideanCodebook {
@ -234,15 +235,36 @@ impl EuclideanCodebook {
let cluster_size = vb.get(cfg.codebook_size, "cluster_size")?; let cluster_size = vb.get(cfg.codebook_size, "cluster_size")?;
let e_shape = (cfg.codebook_size, cfg.codebook_dim()); let e_shape = (cfg.codebook_size, cfg.codebook_dim());
let embed = vb.get(e_shape, "embed")?; let embed = vb.get(e_shape, "embed")?;
let c2 = ((&embed * &embed)?.sum(D::Minus1)? / 2.0)?;
let embed_avg = vb.get(e_shape, "embed_avg")?; let embed_avg = vb.get(e_shape, "embed_avg")?;
Ok(Self { Ok(Self {
inited, inited,
cluster_size, cluster_size,
embed: candle_nn::Embedding::new(embed, cfg.codebook_dim()), embed: candle_nn::Embedding::new(embed, cfg.codebook_dim()),
embed_avg, embed_avg,
c2,
}) })
} }
pub fn encode_slow(&self, xs: &Tensor) -> Result<Tensor> {
let mut target_shape = xs.dims().to_vec();
target_shape.pop();
let xs = xs.flatten_to(D::Minus2)?;
let _ = xs.dims2()?;
let dot_prod = xs.matmul(&self.embed.embeddings().t()?)?;
let codes = self.c2.broadcast_sub(&dot_prod)?.argmin(D::Minus1)?;
codes.reshape(target_shape)
}
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
let mut target_shape = xs.dims().to_vec();
target_shape.pop();
let xs = xs.flatten_to(D::Minus2)?;
let _ = xs.dims2()?;
let codes = Tensor::apply_op2(&xs, self.embed.embeddings(), CodebookEncode)?;
codes.reshape(target_shape)
}
pub fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> { pub fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
let quantize = self.embed.forward(embed_ind)?; let quantize = self.embed.forward(embed_ind)?;
Ok(quantize) Ok(quantize)
@ -260,6 +282,10 @@ impl VectorQuantization {
Ok(Self { codebook }) Ok(Self { codebook })
} }
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
self.codebook.encode_slow(xs)
}
pub fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> { pub fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
let quantize = self.codebook.decode(embed_ind)?; let quantize = self.codebook.decode(embed_ind)?;
let quantize = quantize.transpose(1, 2)?; let quantize = quantize.transpose(1, 2)?;
@ -281,6 +307,18 @@ impl ResidualVectorQuantizer {
Ok(Self { layers }) Ok(Self { layers })
} }
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
let mut codes = Vec::with_capacity(self.layers.len());
let mut residual = xs.clone();
for layer in self.layers.iter() {
let indices = layer.encode(&residual)?;
let quantized = layer.decode(&indices)?;
residual = (residual - quantized)?;
codes.push(indices)
}
Tensor::stack(&codes, 0)
}
pub fn decode(&self, codes: &Tensor) -> Result<Tensor> { pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?; let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?;
let ncodes = codes.dim(0)?; let ncodes = codes.dim(0)?;
@ -380,6 +418,7 @@ impl EncodecConv1d {
out_c: usize, out_c: usize,
kernel_size: usize, kernel_size: usize,
stride: usize, stride: usize,
dilation: usize,
cfg: &Config, cfg: &Config,
vb: VarBuilder, vb: VarBuilder,
) -> Result<Self> { ) -> Result<Self> {
@ -389,10 +428,9 @@ impl EncodecConv1d {
out_c, out_c,
kernel_size, kernel_size,
candle_nn::Conv1dConfig { candle_nn::Conv1dConfig {
padding: 0,
stride, stride,
groups: 1, dilation,
dilation: 1, ..Default::default()
}, },
vb.pp("conv"), vb.pp("conv"),
)?, )?,
@ -463,20 +501,29 @@ pub struct EncodecResnetBlock {
} }
impl EncodecResnetBlock { impl EncodecResnetBlock {
pub fn new(dim: usize, dilations: &[usize], cfg: &Config, vb: VarBuilder) -> Result<Self> { pub fn new(
dim: usize,
(dilation1, dilation2): (usize, usize),
cfg: &Config,
vb: VarBuilder,
) -> Result<Self> {
let h = dim / cfg.compress; let h = dim / cfg.compress;
let mut layer = Layer::new(vb.pp("block")); let mut layer = Layer::new(vb.pp("block"));
if dilations.len() != 2 {
candle::bail!("expected dilations of size 2")
}
// TODO: Apply dilations! // TODO: Apply dilations!
layer.inc(); layer.inc();
let block_conv1 = let block_conv1 = EncodecConv1d::new(
EncodecConv1d::new(dim, h, cfg.residual_kernel_size, 1, cfg, layer.next())?; dim,
h,
cfg.residual_kernel_size,
1,
dilation1,
cfg,
layer.next(),
)?;
layer.inc(); layer.inc();
let block_conv2 = EncodecConv1d::new(h, dim, 1, 1, cfg, layer.next())?; let block_conv2 = EncodecConv1d::new(h, dim, 1, 1, dilation2, cfg, layer.next())?;
let shortcut = if cfg.use_conv_shortcut { let shortcut = if cfg.use_conv_shortcut {
let conv = EncodecConv1d::new(dim, dim, 1, 1, cfg, vb.pp("shortcut"))?; let conv = EncodecConv1d::new(dim, dim, 1, 1, 1, cfg, vb.pp("shortcut"))?;
Some(conv) Some(conv)
} else { } else {
None None
@ -541,6 +588,7 @@ impl Encoder {
cfg.num_filters, cfg.num_filters,
cfg.kernel_size, cfg.kernel_size,
1, 1,
1,
cfg, cfg,
layer.next(), layer.next(),
)?; )?;
@ -552,7 +600,7 @@ impl Encoder {
for j in 0..(cfg.num_residual_layers as u32) { for j in 0..(cfg.num_residual_layers as u32) {
let resnet = EncodecResnetBlock::new( let resnet = EncodecResnetBlock::new(
current_scale, current_scale,
&[cfg.dilation_growth_rate.pow(j), 1], (cfg.dilation_growth_rate.pow(j), 1),
cfg, cfg,
layer.next(), layer.next(),
)?; )?;
@ -564,6 +612,7 @@ impl Encoder {
current_scale * 2, current_scale * 2,
ratio * 2, ratio * 2,
ratio, ratio,
1,
cfg, cfg,
layer.next(), layer.next(),
)?; )?;
@ -577,6 +626,7 @@ impl Encoder {
cfg.hidden_size, cfg.hidden_size,
cfg.last_kernel_size, cfg.last_kernel_size,
1, 1,
1,
cfg, cfg,
layer.next(), layer.next(),
)?; )?;
@ -621,6 +671,7 @@ impl Decoder {
cfg.num_filters * scaling, cfg.num_filters * scaling,
cfg.last_kernel_size, cfg.last_kernel_size,
1, 1,
1,
cfg, cfg,
layer.next(), layer.next(),
)?; )?;
@ -641,7 +692,7 @@ impl Decoder {
for j in 0..(cfg.num_residual_layers as u32) { for j in 0..(cfg.num_residual_layers as u32) {
let resnet = EncodecResnetBlock::new( let resnet = EncodecResnetBlock::new(
current_scale / 2, current_scale / 2,
&[cfg.dilation_growth_rate.pow(j), 1], (cfg.dilation_growth_rate.pow(j), 1),
cfg, cfg,
layer.next(), layer.next(),
)?; )?;
@ -656,6 +707,7 @@ impl Decoder {
cfg.audio_channels, cfg.audio_channels,
cfg.last_kernel_size, cfg.last_kernel_size,
1, 1,
1,
cfg, cfg,
layer.next(), layer.next(),
)?; )?;
@ -700,12 +752,10 @@ impl Model {
}) })
} }
pub fn forward(&self, _xs: &Tensor) -> Result<Tensor> { pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
todo!() let xs = self.encoder.forward(xs)?;
} let codes = self.quantizer.encode(&xs)?;
codes.transpose(0, 1)
pub fn encode(&self, _xs: &Tensor) -> Result<Tensor> {
todo!()
} }
pub fn decode(&self, codes: &Tensor) -> Result<Tensor> { pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {