Apply dilations in the encodec model. (#1772)

* Apply dilations in the encodec model.

* Add some encoding bits.
Laurent Mazare authored 2024-02-27 23:26:35 +01:00, committed by GitHub
parent 0c49e95dfb
commit 15e8644149

@@ -226,6 +226,7 @@ pub struct EuclideanCodebook {
cluster_size: Tensor,
embed: candle_nn::Embedding,
embed_avg: Tensor,
+ c2: Tensor,
}
impl EuclideanCodebook {
@@ -234,15 +235,36 @@ impl EuclideanCodebook {
let cluster_size = vb.get(cfg.codebook_size, "cluster_size")?;
let e_shape = (cfg.codebook_size, cfg.codebook_dim());
let embed = vb.get(e_shape, "embed")?;
+ let c2 = ((&embed * &embed)?.sum(D::Minus1)? / 2.0)?;
let embed_avg = vb.get(e_shape, "embed_avg")?;
Ok(Self {
inited,
cluster_size,
embed: candle_nn::Embedding::new(embed, cfg.codebook_dim()),
embed_avg,
+ c2,
})
}
+ pub fn encode_slow(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut target_shape = xs.dims().to_vec();
+ target_shape.pop();
+ let xs = xs.flatten_to(D::Minus2)?;
+ let _ = xs.dims2()?;
+ let dot_prod = xs.matmul(&self.embed.embeddings().t()?)?;
+ let codes = self.c2.broadcast_sub(&dot_prod)?.argmin(D::Minus1)?;
+ codes.reshape(target_shape)
+ }
+ pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut target_shape = xs.dims().to_vec();
+ target_shape.pop();
+ let xs = xs.flatten_to(D::Minus2)?;
+ let _ = xs.dims2()?;
+ let codes = Tensor::apply_op2(&xs, self.embed.embeddings(), CodebookEncode)?;
+ codes.reshape(target_shape)
+ }
pub fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
let quantize = self.embed.forward(embed_ind)?;
Ok(quantize)
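For reference, and not part of the patch, the search that the new encode_slow performs can be written without candle at all: argmin_i ||x - e_i||^2 equals argmin_i (||e_i||^2 / 2 - x . e_i), because ||x||^2 is the same for every candidate and halving does not move the argmin. That is exactly why the codebook precomputes c2 = sum(embed * embed, dim -1) / 2 and then takes the argmin over c2 - x @ embed^T. A minimal sketch of the same idea on plain slices:

// Illustrative sketch only (plain Rust, no candle): nearest-codebook search
// via the precomputed ||e||^2 / 2 trick used by encode_slow above.
fn nearest_code(x: &[f32], codebook: &[Vec<f32>]) -> usize {
    let mut best = (0, f32::INFINITY);
    for (i, e) in codebook.iter().enumerate() {
        // c2 term: ||e_i||^2 / 2, precomputed once per codebook in the model.
        let c2: f32 = e.iter().map(|v| v * v).sum::<f32>() / 2.0;
        // Dot-product term: x . e_i, a single matmul over all codes in the model.
        let dot: f32 = x.iter().zip(e).map(|(a, b)| a * b).sum();
        if c2 - dot < best.1 {
            best = (i, c2 - dot);
        }
    }
    best.0
}

fn main() {
    let codebook = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.7, 0.7]];
    println!("{}", nearest_code(&[0.9, 0.8], &codebook)); // prints 2
}

The non-slow encode right above delegates the same computation to the CodebookEncode custom op through Tensor::apply_op2, presumably as a faster fused kernel.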
@@ -260,6 +282,10 @@ impl VectorQuantization {
Ok(Self { codebook })
}
+ pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
+ self.codebook.encode_slow(xs)
+ }
pub fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
let quantize = self.codebook.decode(embed_ind)?;
let quantize = quantize.transpose(1, 2)?;
@@ -281,6 +307,18 @@ impl ResidualVectorQuantizer {
Ok(Self { layers })
}
+ pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut codes = Vec::with_capacity(self.layers.len());
+ let mut residual = xs.clone();
+ for layer in self.layers.iter() {
+ let indices = layer.encode(&residual)?;
+ let quantized = layer.decode(&indices)?;
+ residual = (residual - quantized)?;
+ codes.push(indices)
+ }
+ Tensor::stack(&codes, 0)
+ }
pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?;
let ncodes = codes.dim(0)?;
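The new ResidualVectorQuantizer::encode above is a textbook residual-VQ loop: each layer encodes the current residual, decodes its own codes back, and subtracts that reconstruction before handing the remainder to the next layer; the per-layer codes are then stacked along a new leading dimension. A toy scalar analogue, not the model's learned codebooks, just to show why summing the per-stage reconstructions in decode recovers the input:

// Toy residual quantization on scalars: each stage quantizes the error left by
// the stages before it; summing the per-stage reconstructions approximates x.
fn quantize(x: f32, step: f32) -> f32 {
    (x / step).round() * step
}

fn main() {
    let x = 3.37_f32;
    let steps = [1.0_f32, 0.25, 0.0625]; // one quantizer "resolution" per layer
    let mut residual = x;
    let mut stages = Vec::new();
    for step in steps {
        let q = quantize(residual, step); // stands in for layer.encode + layer.decode
        stages.push(q);
        residual -= q; // the next layer only sees the remaining error
    }
    let reconstruction: f32 = stages.iter().sum();
    println!("stages {stages:?} -> {reconstruction} (error {})", x - reconstruction);
}

In the model the stacked codes are what decode later iterates over, adding each layer's decoded contribution into quantized_out.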
@@ -380,6 +418,7 @@ impl EncodecConv1d {
out_c: usize,
kernel_size: usize,
stride: usize,
+ dilation: usize,
cfg: &Config,
vb: VarBuilder,
) -> Result<Self> {
@@ -389,10 +428,9 @@
out_c,
kernel_size,
candle_nn::Conv1dConfig {
- padding: 0,
stride,
- groups: 1,
- dilation: 1,
+ dilation,
+ ..Default::default()
},
vb.pp("conv"),
)?,
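The only functional change to EncodecConv1d is that dilation is now threaded through to candle_nn::Conv1dConfig instead of being hard-coded to 1. For intuition, a rough dependency-free sketch of what dilation means for a 1-D convolution (single channel, valid padding, stride 1; the helper name is made up and this is not the candle implementation): with dilation d and kernel size k, the taps are d samples apart, so one layer covers d * (k - 1) + 1 input samples without any extra weights.

// Naive dilated 1-D convolution: taps are `dilation` samples apart, so the
// receptive field of one layer is dilation * (kernel.len() - 1) + 1.
fn conv1d_dilated(input: &[f32], kernel: &[f32], dilation: usize) -> Vec<f32> {
    let span = dilation * (kernel.len() - 1) + 1;
    if input.len() < span {
        return Vec::new();
    }
    (0..=input.len() - span)
        .map(|i| {
            kernel
                .iter()
                .enumerate()
                .map(|(j, w)| w * input[i + j * dilation])
                .sum::<f32>()
        })
        .collect()
}

fn main() {
    let input = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let kernel = [1.0, 1.0, 1.0];
    println!("{:?}", conv1d_dilated(&input, &kernel, 1)); // [6.0, 9.0, 12.0, 15.0]
    println!("{:?}", conv1d_dilated(&input, &kernel, 2)); // [9.0, 12.0]
}

With dilation 1 the kernel sums three neighbouring samples; with dilation 2 the same three weights reach five samples wide.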
@@ -463,20 +501,29 @@ pub struct EncodecResnetBlock {
}
impl EncodecResnetBlock {
- pub fn new(dim: usize, dilations: &[usize], cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ pub fn new(
+ dim: usize,
+ (dilation1, dilation2): (usize, usize),
+ cfg: &Config,
+ vb: VarBuilder,
+ ) -> Result<Self> {
let h = dim / cfg.compress;
let mut layer = Layer::new(vb.pp("block"));
- if dilations.len() != 2 {
- candle::bail!("expected dilations of size 2")
- }
- // TODO: Apply dilations!
layer.inc();
- let block_conv1 =
- EncodecConv1d::new(dim, h, cfg.residual_kernel_size, 1, cfg, layer.next())?;
+ let block_conv1 = EncodecConv1d::new(
+ dim,
+ h,
+ cfg.residual_kernel_size,
+ 1,
+ dilation1,
+ cfg,
+ layer.next(),
+ )?;
layer.inc();
- let block_conv2 = EncodecConv1d::new(h, dim, 1, 1, cfg, layer.next())?;
+ let block_conv2 = EncodecConv1d::new(h, dim, 1, 1, dilation2, cfg, layer.next())?;
let shortcut = if cfg.use_conv_shortcut {
- let conv = EncodecConv1d::new(dim, dim, 1, 1, cfg, vb.pp("shortcut"))?;
+ let conv = EncodecConv1d::new(dim, dim, 1, 1, 1, cfg, vb.pp("shortcut"))?;
Some(conv)
} else {
None
@@ -541,6 +588,7 @@ impl Encoder {
cfg.num_filters,
cfg.kernel_size,
1,
+ 1,
cfg,
layer.next(),
)?;
@@ -552,7 +600,7 @@ impl Encoder {
for j in 0..(cfg.num_residual_layers as u32) {
let resnet = EncodecResnetBlock::new(
current_scale,
- &[cfg.dilation_growth_rate.pow(j), 1],
+ (cfg.dilation_growth_rate.pow(j), 1),
cfg,
layer.next(),
)?;
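With the tuple threaded through, the first conv of residual block j now runs at dilation cfg.dilation_growth_rate.pow(j) while the 1x1 follow-up conv and the shortcut stay at dilation 1, so stacked blocks see geometrically growing context. A small sketch of the resulting schedule; the growth rate of 2 and kernel size of 3 are example values, not read from the real config:

// Hypothetical dilation schedule for the residual blocks of one encoder stage.
fn main() {
    let dilation_growth_rate: u32 = 2; // example value, not the actual cfg
    let residual_kernel_size: u32 = 3; // example value, not the actual cfg
    for j in 0..4u32 {
        let dilation = dilation_growth_rate.pow(j);
        let span = dilation * (residual_kernel_size - 1) + 1;
        println!("block {j}: dilation {dilation}, conv span {span}"); // 1->3, 2->5, 4->9, 8->17
    }
}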
@@ -564,6 +612,7 @@ impl Encoder {
current_scale * 2,
ratio * 2,
ratio,
+ 1,
cfg,
layer.next(),
)?;
@@ -577,6 +626,7 @@ impl Encoder {
cfg.hidden_size,
cfg.last_kernel_size,
1,
+ 1,
cfg,
layer.next(),
)?;
@@ -621,6 +671,7 @@ impl Decoder {
cfg.num_filters * scaling,
cfg.last_kernel_size,
1,
+ 1,
cfg,
layer.next(),
)?;
@@ -641,7 +692,7 @@ impl Decoder {
for j in 0..(cfg.num_residual_layers as u32) {
let resnet = EncodecResnetBlock::new(
current_scale / 2,
- &[cfg.dilation_growth_rate.pow(j), 1],
+ (cfg.dilation_growth_rate.pow(j), 1),
cfg,
layer.next(),
)?;
@@ -656,6 +707,7 @@ impl Decoder {
cfg.audio_channels,
cfg.last_kernel_size,
1,
+ 1,
cfg,
layer.next(),
)?;
@@ -700,12 +752,10 @@ impl Model {
})
}
- pub fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
- todo!()
- }
- pub fn encode(&self, _xs: &Tensor) -> Result<Tensor> {
- todo!()
+ pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = self.encoder.forward(xs)?;
+ let codes = self.quantizer.encode(&xs)?;
+ codes.transpose(0, 1)
}
pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {