mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Apply dilations in the encodec model. (#1772)
* Apply dilations in the encodec model. * Add some encoding bits.
This commit is contained in:
@ -226,6 +226,7 @@ pub struct EuclideanCodebook {
|
|||||||
cluster_size: Tensor,
|
cluster_size: Tensor,
|
||||||
embed: candle_nn::Embedding,
|
embed: candle_nn::Embedding,
|
||||||
embed_avg: Tensor,
|
embed_avg: Tensor,
|
||||||
|
c2: Tensor,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EuclideanCodebook {
|
impl EuclideanCodebook {
|
||||||
@ -234,15 +235,36 @@ impl EuclideanCodebook {
|
|||||||
let cluster_size = vb.get(cfg.codebook_size, "cluster_size")?;
|
let cluster_size = vb.get(cfg.codebook_size, "cluster_size")?;
|
||||||
let e_shape = (cfg.codebook_size, cfg.codebook_dim());
|
let e_shape = (cfg.codebook_size, cfg.codebook_dim());
|
||||||
let embed = vb.get(e_shape, "embed")?;
|
let embed = vb.get(e_shape, "embed")?;
|
||||||
|
let c2 = ((&embed * &embed)?.sum(D::Minus1)? / 2.0)?;
|
||||||
let embed_avg = vb.get(e_shape, "embed_avg")?;
|
let embed_avg = vb.get(e_shape, "embed_avg")?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
inited,
|
inited,
|
||||||
cluster_size,
|
cluster_size,
|
||||||
embed: candle_nn::Embedding::new(embed, cfg.codebook_dim()),
|
embed: candle_nn::Embedding::new(embed, cfg.codebook_dim()),
|
||||||
embed_avg,
|
embed_avg,
|
||||||
|
c2,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn encode_slow(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let mut target_shape = xs.dims().to_vec();
|
||||||
|
target_shape.pop();
|
||||||
|
let xs = xs.flatten_to(D::Minus2)?;
|
||||||
|
let _ = xs.dims2()?;
|
||||||
|
let dot_prod = xs.matmul(&self.embed.embeddings().t()?)?;
|
||||||
|
let codes = self.c2.broadcast_sub(&dot_prod)?.argmin(D::Minus1)?;
|
||||||
|
codes.reshape(target_shape)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let mut target_shape = xs.dims().to_vec();
|
||||||
|
target_shape.pop();
|
||||||
|
let xs = xs.flatten_to(D::Minus2)?;
|
||||||
|
let _ = xs.dims2()?;
|
||||||
|
let codes = Tensor::apply_op2(&xs, self.embed.embeddings(), CodebookEncode)?;
|
||||||
|
codes.reshape(target_shape)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
|
pub fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
|
||||||
let quantize = self.embed.forward(embed_ind)?;
|
let quantize = self.embed.forward(embed_ind)?;
|
||||||
Ok(quantize)
|
Ok(quantize)
|
||||||
@ -260,6 +282,10 @@ impl VectorQuantization {
|
|||||||
Ok(Self { codebook })
|
Ok(Self { codebook })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
self.codebook.encode_slow(xs)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
|
pub fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
|
||||||
let quantize = self.codebook.decode(embed_ind)?;
|
let quantize = self.codebook.decode(embed_ind)?;
|
||||||
let quantize = quantize.transpose(1, 2)?;
|
let quantize = quantize.transpose(1, 2)?;
|
||||||
@ -281,6 +307,18 @@ impl ResidualVectorQuantizer {
|
|||||||
Ok(Self { layers })
|
Ok(Self { layers })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let mut codes = Vec::with_capacity(self.layers.len());
|
||||||
|
let mut residual = xs.clone();
|
||||||
|
for layer in self.layers.iter() {
|
||||||
|
let indices = layer.encode(&residual)?;
|
||||||
|
let quantized = layer.decode(&indices)?;
|
||||||
|
residual = (residual - quantized)?;
|
||||||
|
codes.push(indices)
|
||||||
|
}
|
||||||
|
Tensor::stack(&codes, 0)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
|
pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
|
||||||
let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?;
|
let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?;
|
||||||
let ncodes = codes.dim(0)?;
|
let ncodes = codes.dim(0)?;
|
||||||
@ -380,6 +418,7 @@ impl EncodecConv1d {
|
|||||||
out_c: usize,
|
out_c: usize,
|
||||||
kernel_size: usize,
|
kernel_size: usize,
|
||||||
stride: usize,
|
stride: usize,
|
||||||
|
dilation: usize,
|
||||||
cfg: &Config,
|
cfg: &Config,
|
||||||
vb: VarBuilder,
|
vb: VarBuilder,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
@ -389,10 +428,9 @@ impl EncodecConv1d {
|
|||||||
out_c,
|
out_c,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
candle_nn::Conv1dConfig {
|
candle_nn::Conv1dConfig {
|
||||||
padding: 0,
|
|
||||||
stride,
|
stride,
|
||||||
groups: 1,
|
dilation,
|
||||||
dilation: 1,
|
..Default::default()
|
||||||
},
|
},
|
||||||
vb.pp("conv"),
|
vb.pp("conv"),
|
||||||
)?,
|
)?,
|
||||||
@ -463,20 +501,29 @@ pub struct EncodecResnetBlock {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl EncodecResnetBlock {
|
impl EncodecResnetBlock {
|
||||||
pub fn new(dim: usize, dilations: &[usize], cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
pub fn new(
|
||||||
|
dim: usize,
|
||||||
|
(dilation1, dilation2): (usize, usize),
|
||||||
|
cfg: &Config,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
let h = dim / cfg.compress;
|
let h = dim / cfg.compress;
|
||||||
let mut layer = Layer::new(vb.pp("block"));
|
let mut layer = Layer::new(vb.pp("block"));
|
||||||
if dilations.len() != 2 {
|
|
||||||
candle::bail!("expected dilations of size 2")
|
|
||||||
}
|
|
||||||
// TODO: Apply dilations!
|
// TODO: Apply dilations!
|
||||||
layer.inc();
|
layer.inc();
|
||||||
let block_conv1 =
|
let block_conv1 = EncodecConv1d::new(
|
||||||
EncodecConv1d::new(dim, h, cfg.residual_kernel_size, 1, cfg, layer.next())?;
|
dim,
|
||||||
|
h,
|
||||||
|
cfg.residual_kernel_size,
|
||||||
|
1,
|
||||||
|
dilation1,
|
||||||
|
cfg,
|
||||||
|
layer.next(),
|
||||||
|
)?;
|
||||||
layer.inc();
|
layer.inc();
|
||||||
let block_conv2 = EncodecConv1d::new(h, dim, 1, 1, cfg, layer.next())?;
|
let block_conv2 = EncodecConv1d::new(h, dim, 1, 1, dilation2, cfg, layer.next())?;
|
||||||
let shortcut = if cfg.use_conv_shortcut {
|
let shortcut = if cfg.use_conv_shortcut {
|
||||||
let conv = EncodecConv1d::new(dim, dim, 1, 1, cfg, vb.pp("shortcut"))?;
|
let conv = EncodecConv1d::new(dim, dim, 1, 1, 1, cfg, vb.pp("shortcut"))?;
|
||||||
Some(conv)
|
Some(conv)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
@ -541,6 +588,7 @@ impl Encoder {
|
|||||||
cfg.num_filters,
|
cfg.num_filters,
|
||||||
cfg.kernel_size,
|
cfg.kernel_size,
|
||||||
1,
|
1,
|
||||||
|
1,
|
||||||
cfg,
|
cfg,
|
||||||
layer.next(),
|
layer.next(),
|
||||||
)?;
|
)?;
|
||||||
@ -552,7 +600,7 @@ impl Encoder {
|
|||||||
for j in 0..(cfg.num_residual_layers as u32) {
|
for j in 0..(cfg.num_residual_layers as u32) {
|
||||||
let resnet = EncodecResnetBlock::new(
|
let resnet = EncodecResnetBlock::new(
|
||||||
current_scale,
|
current_scale,
|
||||||
&[cfg.dilation_growth_rate.pow(j), 1],
|
(cfg.dilation_growth_rate.pow(j), 1),
|
||||||
cfg,
|
cfg,
|
||||||
layer.next(),
|
layer.next(),
|
||||||
)?;
|
)?;
|
||||||
@ -564,6 +612,7 @@ impl Encoder {
|
|||||||
current_scale * 2,
|
current_scale * 2,
|
||||||
ratio * 2,
|
ratio * 2,
|
||||||
ratio,
|
ratio,
|
||||||
|
1,
|
||||||
cfg,
|
cfg,
|
||||||
layer.next(),
|
layer.next(),
|
||||||
)?;
|
)?;
|
||||||
@ -577,6 +626,7 @@ impl Encoder {
|
|||||||
cfg.hidden_size,
|
cfg.hidden_size,
|
||||||
cfg.last_kernel_size,
|
cfg.last_kernel_size,
|
||||||
1,
|
1,
|
||||||
|
1,
|
||||||
cfg,
|
cfg,
|
||||||
layer.next(),
|
layer.next(),
|
||||||
)?;
|
)?;
|
||||||
@ -621,6 +671,7 @@ impl Decoder {
|
|||||||
cfg.num_filters * scaling,
|
cfg.num_filters * scaling,
|
||||||
cfg.last_kernel_size,
|
cfg.last_kernel_size,
|
||||||
1,
|
1,
|
||||||
|
1,
|
||||||
cfg,
|
cfg,
|
||||||
layer.next(),
|
layer.next(),
|
||||||
)?;
|
)?;
|
||||||
@ -641,7 +692,7 @@ impl Decoder {
|
|||||||
for j in 0..(cfg.num_residual_layers as u32) {
|
for j in 0..(cfg.num_residual_layers as u32) {
|
||||||
let resnet = EncodecResnetBlock::new(
|
let resnet = EncodecResnetBlock::new(
|
||||||
current_scale / 2,
|
current_scale / 2,
|
||||||
&[cfg.dilation_growth_rate.pow(j), 1],
|
(cfg.dilation_growth_rate.pow(j), 1),
|
||||||
cfg,
|
cfg,
|
||||||
layer.next(),
|
layer.next(),
|
||||||
)?;
|
)?;
|
||||||
@ -656,6 +707,7 @@ impl Decoder {
|
|||||||
cfg.audio_channels,
|
cfg.audio_channels,
|
||||||
cfg.last_kernel_size,
|
cfg.last_kernel_size,
|
||||||
1,
|
1,
|
||||||
|
1,
|
||||||
cfg,
|
cfg,
|
||||||
layer.next(),
|
layer.next(),
|
||||||
)?;
|
)?;
|
||||||
@ -700,12 +752,10 @@ impl Model {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
|
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
todo!()
|
let xs = self.encoder.forward(xs)?;
|
||||||
}
|
let codes = self.quantizer.encode(&xs)?;
|
||||||
|
codes.transpose(0, 1)
|
||||||
pub fn encode(&self, _xs: &Tensor) -> Result<Tensor> {
|
|
||||||
todo!()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
|
pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
|
||||||
|
Reference in New Issue
Block a user