mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Apply dilations in the encodec model. (#1772)
* Apply dilations in the encodec model. * Add some encoding bits.
This commit is contained in:
@ -226,6 +226,7 @@ pub struct EuclideanCodebook {
|
||||
cluster_size: Tensor,
|
||||
embed: candle_nn::Embedding,
|
||||
embed_avg: Tensor,
|
||||
c2: Tensor,
|
||||
}
|
||||
|
||||
impl EuclideanCodebook {
|
||||
@ -234,15 +235,36 @@ impl EuclideanCodebook {
|
||||
let cluster_size = vb.get(cfg.codebook_size, "cluster_size")?;
|
||||
let e_shape = (cfg.codebook_size, cfg.codebook_dim());
|
||||
let embed = vb.get(e_shape, "embed")?;
|
||||
let c2 = ((&embed * &embed)?.sum(D::Minus1)? / 2.0)?;
|
||||
let embed_avg = vb.get(e_shape, "embed_avg")?;
|
||||
Ok(Self {
|
||||
inited,
|
||||
cluster_size,
|
||||
embed: candle_nn::Embedding::new(embed, cfg.codebook_dim()),
|
||||
embed_avg,
|
||||
c2,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn encode_slow(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut target_shape = xs.dims().to_vec();
|
||||
target_shape.pop();
|
||||
let xs = xs.flatten_to(D::Minus2)?;
|
||||
let _ = xs.dims2()?;
|
||||
let dot_prod = xs.matmul(&self.embed.embeddings().t()?)?;
|
||||
let codes = self.c2.broadcast_sub(&dot_prod)?.argmin(D::Minus1)?;
|
||||
codes.reshape(target_shape)
|
||||
}
|
||||
|
||||
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut target_shape = xs.dims().to_vec();
|
||||
target_shape.pop();
|
||||
let xs = xs.flatten_to(D::Minus2)?;
|
||||
let _ = xs.dims2()?;
|
||||
let codes = Tensor::apply_op2(&xs, self.embed.embeddings(), CodebookEncode)?;
|
||||
codes.reshape(target_shape)
|
||||
}
|
||||
|
||||
pub fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
|
||||
let quantize = self.embed.forward(embed_ind)?;
|
||||
Ok(quantize)
|
||||
@ -260,6 +282,10 @@ impl VectorQuantization {
|
||||
Ok(Self { codebook })
|
||||
}
|
||||
|
||||
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
self.codebook.encode_slow(xs)
|
||||
}
|
||||
|
||||
pub fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
|
||||
let quantize = self.codebook.decode(embed_ind)?;
|
||||
let quantize = quantize.transpose(1, 2)?;
|
||||
@ -281,6 +307,18 @@ impl ResidualVectorQuantizer {
|
||||
Ok(Self { layers })
|
||||
}
|
||||
|
||||
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut codes = Vec::with_capacity(self.layers.len());
|
||||
let mut residual = xs.clone();
|
||||
for layer in self.layers.iter() {
|
||||
let indices = layer.encode(&residual)?;
|
||||
let quantized = layer.decode(&indices)?;
|
||||
residual = (residual - quantized)?;
|
||||
codes.push(indices)
|
||||
}
|
||||
Tensor::stack(&codes, 0)
|
||||
}
|
||||
|
||||
pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
|
||||
let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?;
|
||||
let ncodes = codes.dim(0)?;
|
||||
@ -380,6 +418,7 @@ impl EncodecConv1d {
|
||||
out_c: usize,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
cfg: &Config,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
@ -389,10 +428,9 @@ impl EncodecConv1d {
|
||||
out_c,
|
||||
kernel_size,
|
||||
candle_nn::Conv1dConfig {
|
||||
padding: 0,
|
||||
stride,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
dilation,
|
||||
..Default::default()
|
||||
},
|
||||
vb.pp("conv"),
|
||||
)?,
|
||||
@ -463,20 +501,29 @@ pub struct EncodecResnetBlock {
|
||||
}
|
||||
|
||||
impl EncodecResnetBlock {
|
||||
pub fn new(dim: usize, dilations: &[usize], cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
pub fn new(
|
||||
dim: usize,
|
||||
(dilation1, dilation2): (usize, usize),
|
||||
cfg: &Config,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let h = dim / cfg.compress;
|
||||
let mut layer = Layer::new(vb.pp("block"));
|
||||
if dilations.len() != 2 {
|
||||
candle::bail!("expected dilations of size 2")
|
||||
}
|
||||
// TODO: Apply dilations!
|
||||
layer.inc();
|
||||
let block_conv1 =
|
||||
EncodecConv1d::new(dim, h, cfg.residual_kernel_size, 1, cfg, layer.next())?;
|
||||
let block_conv1 = EncodecConv1d::new(
|
||||
dim,
|
||||
h,
|
||||
cfg.residual_kernel_size,
|
||||
1,
|
||||
dilation1,
|
||||
cfg,
|
||||
layer.next(),
|
||||
)?;
|
||||
layer.inc();
|
||||
let block_conv2 = EncodecConv1d::new(h, dim, 1, 1, cfg, layer.next())?;
|
||||
let block_conv2 = EncodecConv1d::new(h, dim, 1, 1, dilation2, cfg, layer.next())?;
|
||||
let shortcut = if cfg.use_conv_shortcut {
|
||||
let conv = EncodecConv1d::new(dim, dim, 1, 1, cfg, vb.pp("shortcut"))?;
|
||||
let conv = EncodecConv1d::new(dim, dim, 1, 1, 1, cfg, vb.pp("shortcut"))?;
|
||||
Some(conv)
|
||||
} else {
|
||||
None
|
||||
@ -541,6 +588,7 @@ impl Encoder {
|
||||
cfg.num_filters,
|
||||
cfg.kernel_size,
|
||||
1,
|
||||
1,
|
||||
cfg,
|
||||
layer.next(),
|
||||
)?;
|
||||
@ -552,7 +600,7 @@ impl Encoder {
|
||||
for j in 0..(cfg.num_residual_layers as u32) {
|
||||
let resnet = EncodecResnetBlock::new(
|
||||
current_scale,
|
||||
&[cfg.dilation_growth_rate.pow(j), 1],
|
||||
(cfg.dilation_growth_rate.pow(j), 1),
|
||||
cfg,
|
||||
layer.next(),
|
||||
)?;
|
||||
@ -564,6 +612,7 @@ impl Encoder {
|
||||
current_scale * 2,
|
||||
ratio * 2,
|
||||
ratio,
|
||||
1,
|
||||
cfg,
|
||||
layer.next(),
|
||||
)?;
|
||||
@ -577,6 +626,7 @@ impl Encoder {
|
||||
cfg.hidden_size,
|
||||
cfg.last_kernel_size,
|
||||
1,
|
||||
1,
|
||||
cfg,
|
||||
layer.next(),
|
||||
)?;
|
||||
@ -621,6 +671,7 @@ impl Decoder {
|
||||
cfg.num_filters * scaling,
|
||||
cfg.last_kernel_size,
|
||||
1,
|
||||
1,
|
||||
cfg,
|
||||
layer.next(),
|
||||
)?;
|
||||
@ -641,7 +692,7 @@ impl Decoder {
|
||||
for j in 0..(cfg.num_residual_layers as u32) {
|
||||
let resnet = EncodecResnetBlock::new(
|
||||
current_scale / 2,
|
||||
&[cfg.dilation_growth_rate.pow(j), 1],
|
||||
(cfg.dilation_growth_rate.pow(j), 1),
|
||||
cfg,
|
||||
layer.next(),
|
||||
)?;
|
||||
@ -656,6 +707,7 @@ impl Decoder {
|
||||
cfg.audio_channels,
|
||||
cfg.last_kernel_size,
|
||||
1,
|
||||
1,
|
||||
cfg,
|
||||
layer.next(),
|
||||
)?;
|
||||
@ -700,12 +752,10 @@ impl Model {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
pub fn encode(&self, _xs: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.encoder.forward(xs)?;
|
||||
let codes = self.quantizer.encode(&xs)?;
|
||||
codes.transpose(0, 1)
|
||||
}
|
||||
|
||||
pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
|
||||
|
Reference in New Issue
Block a user