Add the SNAC audio tokenizer. (#2869)

* Add the SNAC audio tokenizer.

* More snac.

* Again more snac.

* Add some example code for snac.

* Get the weights to load.

* Add to the snac model.

* Fixes.

* Get round-tripping to work.

* Save/load code files.

* Clippy fix.

* Fmt fix.
This commit is contained in:
Laurent Mazare
2025-04-06 22:15:36 +02:00
committed by GitHub
parent 338f6a102e
commit e3370c6316
7 changed files with 1265 additions and 0 deletions

View File

@ -330,6 +330,7 @@ impl ResidualVectorQuantizer {
Ok(Self { quantizers })
}
#[allow(clippy::wrong_self_convention)]
pub fn from_codes(&self, codes: &Tensor) -> Result<Tensor> {
let mut sum = None;
for (idx, quantizer) in self.quantizers.iter().enumerate() {

View File

@ -141,6 +141,20 @@ pub fn conv1d_weight_norm(
Ok(Conv1d::new(weight, Some(bias), config))
}
pub fn conv1d_weight_norm_no_bias(
in_c: usize,
out_c: usize,
kernel_size: usize,
config: candle_nn::Conv1dConfig,
vb: VarBuilder,
) -> Result<Conv1d> {
let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
Ok(Conv1d::new(weight, None, config))
}
pub fn conv_transpose1d_weight_norm(
in_c: usize,
out_c: usize,

View File

@ -104,6 +104,7 @@ pub mod rwkv_v6;
pub mod segformer;
pub mod segment_anything;
pub mod siglip;
pub mod snac;
pub mod stable_diffusion;
pub mod stable_lm;
pub mod starcoder2;

View File

@ -0,0 +1,814 @@
#![allow(unused)]
//! Implementation of the Multi-Scale Neural Audio Codec (SNAC)
//!
//! See: [SNAC](https://github.com/hubertsiuzdak/snac)
//!
/// Multi-Scale Neural Audio Codec (SNAC) compresses audio into discrete codes at a low bitrate.
/// For more information, read the paper: https://arxiv.org/abs/2410.14411
///
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{
linear_b, Conv1d, Conv1dConfig, ConvTranspose1d, ConvTranspose1dConfig, LayerNorm, Linear,
VarBuilder,
};
#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
pub sampling_rate: usize,
pub encoder_dim: usize,
pub encoder_rates: Vec<usize>,
pub decoder_dim: usize,
pub decoder_rates: Vec<usize>,
pub attn_window_size: Option<usize>,
pub codebook_size: usize,
pub codebook_dim: usize,
pub vq_strides: Vec<usize>,
pub noise: bool,
pub depthwise: bool,
}
// Equivalent to torch.repeat_interleave
pub fn repeat_interleave<D: candle::shape::Dim>(
img: &Tensor,
repeats: usize,
dim: D,
) -> Result<Tensor> {
if repeats == 1 {
return Ok(img.clone());
}
let dim = dim.to_index(img.shape(), "chunk")?;
let img = img.unsqueeze(dim + 1)?;
let mut dims = img.dims().to_vec();
dims[dim + 1] = repeats;
img.broadcast_as(dims)?.flatten(dim, dim + 1)
}
pub fn conv1d_weight_norm(
in_c: usize,
out_c: usize,
kernel_size: usize,
config: candle_nn::Conv1dConfig,
vb: VarBuilder,
) -> Result<Conv1d> {
let weight_g = vb.get((out_c, 1, 1), "parametrizations.weight.original0")?;
let weight_v = {
let name = "parametrizations.weight.original1";
match vb.get((out_c, in_c, kernel_size), name) {
Ok(v) => v,
Err(_) => vb.get((out_c, 1, kernel_size), name)?,
}
};
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
let bias = vb.get(out_c, "bias")?;
Ok(Conv1d::new(weight, Some(bias), config))
}
pub fn conv1d_weight_norm_no_bias(
in_c: usize,
out_c: usize,
kernel_size: usize,
config: candle_nn::Conv1dConfig,
vb: VarBuilder,
) -> Result<Conv1d> {
let weight_g = vb.get((out_c, 1, 1), "parametrizations.weight.original0")?;
let weight_v = {
let name = "parametrizations.weight.original1";
match vb.get((out_c, in_c, kernel_size), name) {
Ok(v) => v,
Err(_) => vb.get((out_c, 1, kernel_size), name)?,
}
};
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
Ok(Conv1d::new(weight, None, config))
}
pub fn conv_transpose1d_weight_norm(
in_c: usize,
out_c: usize,
kernel_size: usize,
bias: bool,
config: candle_nn::ConvTranspose1dConfig,
vb: VarBuilder,
) -> Result<ConvTranspose1d> {
let weight_g = vb.get((in_c, 1, 1), "parametrizations.weight.original0")?;
let weight_v = vb.get(
(in_c, out_c, kernel_size),
"parametrizations.weight.original1",
)?;
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
let bias = if bias {
Some(vb.get(out_c, "bias")?)
} else {
None
};
Ok(ConvTranspose1d::new(weight, bias, config))
}
// https://github.com/hubertsiuzdak/snac/blob/main/snac/attention.py
#[allow(unused)]
#[derive(Debug, Clone)]
struct SinusoidalEmbeddings {
inv_freq: Tensor,
scale: Tensor,
scale_base: f32,
use_xpos: bool,
}
impl SinusoidalEmbeddings {
fn new(dim: usize, scale_base: f32, use_xpos: bool, dev: &Device) -> Result<Self> {
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / 10_000f32.powf(i as f32 / dim as f32))
.collect();
let len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, len, dev)?.to_dtype(DType::F32)?;
let scale: Vec<_> = (0..dim)
.step_by(2)
.map(|i| (i as f32 + 0.4 * dim as f32) / (1.4 * dim as f32))
.collect();
let scale = Tensor::from_vec(scale, len, dev)?.to_dtype(DType::F32)?;
Ok(Self {
inv_freq,
scale,
scale_base,
use_xpos,
})
}
}
#[allow(unused)]
#[derive(Debug, Clone)]
struct LocalMHA {
norm: LayerNorm,
to_qkv: Linear,
to_out: Linear,
num_heads: usize,
head_dim: usize,
rel_pos: Option<SinusoidalEmbeddings>,
}
impl LocalMHA {
fn new(
dim: usize,
window_size: usize,
dim_head: usize,
use_rotary_pos_emb: bool,
vb: VarBuilder,
) -> Result<Self> {
let norm = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm"))?;
let to_qkv = linear_b(dim, dim * 3, false, vb.pp("to_qkv"))?;
let to_out = linear_b(dim, dim, false, vb.pp("to_out"))?;
let rel_pos = if use_rotary_pos_emb {
let rel_pos =
SinusoidalEmbeddings::new(dim_head, window_size as f32 / 2.0, false, vb.device())?;
Some(rel_pos)
} else {
None
};
Ok(Self {
norm,
to_qkv,
to_out,
rel_pos,
num_heads: dim / dim_head,
head_dim: dim_head,
})
}
}
impl Module for LocalMHA {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (b, c, t) = xs.dims3()?;
let residual = xs.clone();
let xs = xs.transpose(1, 2)?.apply(&self.norm)?;
let qkv = xs.apply(&self.to_qkv)?;
let q = qkv.narrow(D::Minus1, 0, c)?;
let k = qkv.narrow(D::Minus1, c, c)?;
let v = qkv.narrow(D::Minus1, 2 * c, c)?;
let q = q
.reshape((b, t, self.num_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
let k = k
.reshape((b, t, self.num_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
let v = v
.reshape((b, t, self.num_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
let (q, k) = match self.rel_pos {
Some(_) => todo!(),
None => (q, k),
};
let out = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
// Non-causal attention
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
attn_weights.matmul(&v)?
};
let out = out
.transpose(1, 2)?
.reshape((b, t, self.num_heads * self.head_dim))?
.apply(&self.to_out)?;
out.transpose(1, 2)? + residual
}
}
#[derive(Debug, Clone)]
struct Snake1d {
alpha: Tensor,
}
impl Snake1d {
pub fn new(channels: usize, vb: VarBuilder) -> Result<Self> {
let alpha = vb.get((1, channels, 1), "alpha")?;
Ok(Self { alpha })
}
}
impl Module for Snake1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs_shape = xs.shape();
let xs = xs.flatten_from(2)?;
let sin = self.alpha.broadcast_mul(&xs)?.sin()?;
let sin = (&sin * &sin)?;
(xs + (&self.alpha + 1e-9)?.recip()?.broadcast_mul(&sin)?)?.reshape(xs_shape)
}
}
#[derive(Debug, Clone)]
struct ResidualUnit {
snake1: Snake1d,
conv1: Conv1d,
snake2: Snake1d,
conv2: Conv1d,
}
impl ResidualUnit {
fn new(
dim: usize,
dilation: usize,
kernel: usize,
groups: usize,
vb: VarBuilder,
) -> Result<Self> {
let pad = ((kernel - 1) * dilation) / 2;
let vb = vb.pp("block");
let snake1 = Snake1d::new(dim, vb.pp(0))?;
let cfg1 = Conv1dConfig {
dilation,
padding: pad,
groups,
..Default::default()
};
let conv1 = conv1d_weight_norm(dim, dim, 7, cfg1, vb.pp(1))?;
let snake2 = Snake1d::new(dim, vb.pp(2))?;
let conv2 = conv1d_weight_norm(dim, dim, 1, Default::default(), vb.pp(3))?;
Ok(Self {
snake1,
conv1,
snake2,
conv2,
})
}
}
impl Module for ResidualUnit {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let ys = xs
.apply(&self.snake1)?
.apply(&self.conv1)?
.apply(&self.snake2)?
.apply(&self.conv2)?;
let pad = (xs.dim(D::Minus1)? - ys.dim(D::Minus1)?) / 2;
if pad > 0 {
&ys + xs.narrow(D::Minus1, pad, ys.dim(D::Minus1)?)
} else {
ys + xs
}
}
}
#[derive(Debug, Clone)]
struct NoiseBlock {
linear: Conv1d,
}
impl NoiseBlock {
fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
let linear = conv1d_weight_norm_no_bias(dim, dim, 1, Default::default(), vb.pp("linear"))?;
Ok(Self { linear })
}
}
impl Module for NoiseBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (b, _c, t) = xs.dims3()?;
let noise = Tensor::randn(0f32, 1f32, (b, 1, t), xs.device())?;
let h = xs.apply(&self.linear)?;
let n = noise.broadcast_mul(&h)?;
let xs = (xs + n)?;
Ok(xs)
}
}
#[derive(Debug, Clone)]
struct DecoderBlock {
snake1: Snake1d,
conv_tr1: ConvTranspose1d,
noise: Option<NoiseBlock>,
res1: ResidualUnit,
res2: ResidualUnit,
res3: ResidualUnit,
}
impl DecoderBlock {
fn new(
in_dim: usize,
out_dim: usize,
stride: usize,
noise: bool,
groups: usize,
vb: VarBuilder,
) -> Result<Self> {
let vb = vb.pp("block");
let snake1 = Snake1d::new(in_dim, vb.pp(0))?;
let cfg = ConvTranspose1dConfig {
stride,
padding: stride.div_ceil(2),
output_padding: stride % 2,
..Default::default()
};
let conv_tr1 =
conv_transpose1d_weight_norm(in_dim, out_dim, 2 * stride, true, cfg, vb.pp(1))?;
let (n, noise) = if noise {
let noise = NoiseBlock::new(out_dim, vb.pp(2))?;
(1, Some(noise))
} else {
(0, None)
};
let res1 = ResidualUnit::new(out_dim, 1, 7, groups, vb.pp(2 + n))?;
let res2 = ResidualUnit::new(out_dim, 3, 7, groups, vb.pp(3 + n))?;
let res3 = ResidualUnit::new(out_dim, 9, 7, groups, vb.pp(4 + n))?;
Ok(Self {
snake1,
conv_tr1,
noise,
res1,
res2,
res3,
})
}
}
impl Module for DecoderBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.snake1)?
.apply(&self.conv_tr1)?
.apply(&self.noise.as_ref())?
.apply(&self.res1)?
.apply(&self.res2)?
.apply(&self.res3)
}
}
#[derive(Debug, Clone)]
struct EncoderBlock {
res1: ResidualUnit,
res2: ResidualUnit,
res3: ResidualUnit,
snake1: Snake1d,
conv1: Conv1d,
}
impl EncoderBlock {
fn new(
out_dim: usize,
in_dim: Option<usize>,
stride: usize,
groups: usize,
vb: VarBuilder,
) -> Result<Self> {
let vb = vb.pp("block");
let in_dim = in_dim.unwrap_or(out_dim / 2);
let res1 = ResidualUnit::new(in_dim, 1, 7, groups, vb.pp(0))?;
let res2 = ResidualUnit::new(in_dim, 3, 7, groups, vb.pp(1))?;
let res3 = ResidualUnit::new(in_dim, 9, 7, groups, vb.pp(2))?;
let snake1 = Snake1d::new(in_dim, vb.pp(3))?;
let cfg1 = Conv1dConfig {
stride,
padding: stride.div_ceil(2),
..Default::default()
};
let conv1 = conv1d_weight_norm(in_dim, out_dim, 2 * stride, cfg1, vb.pp(4))?;
Ok(Self {
res1,
res2,
res3,
snake1,
conv1,
})
}
}
impl candle::Module for EncoderBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.res1)?
.apply(&self.res2)?
.apply(&self.res3)?
.apply(&self.snake1)?
.apply(&self.conv1)
}
}
#[derive(Debug, Clone)]
pub struct Encoder {
conv1: Conv1d,
blocks: Vec<EncoderBlock>,
local_mha: Option<LocalMHA>,
conv2: Conv1d,
}
impl candle::Module for Encoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = xs.apply(&self.conv1)?;
for block in self.blocks.iter() {
xs = xs.apply(block)?
}
xs.apply(&self.conv2)
}
}
impl Encoder {
fn new(
mut d_model: usize,
strides: &[usize],
depthwise: bool,
attn_window_size: Option<usize>,
vb: VarBuilder,
) -> Result<Self> {
let vb = vb.pp("block");
let mut idx = 0;
let cfg1 = Conv1dConfig {
padding: 3,
..Default::default()
};
let conv1 = conv1d_weight_norm(1, d_model, 7, cfg1, vb.pp(idx))?;
idx += 1;
let mut blocks = Vec::with_capacity(strides.len());
for &stride in strides.iter() {
d_model *= 2;
let groups = if depthwise { d_model / 2 } else { 1 };
let block = EncoderBlock::new(d_model, None, stride, groups, vb.pp(idx))?;
idx += 1;
blocks.push(block)
}
let local_mha = match attn_window_size {
Some(w) => {
let mha = LocalMHA::new(d_model, w, 64, true, vb.pp(idx))?;
idx += 1;
Some(mha)
}
None => None,
};
let groups = if depthwise { d_model } else { 1 };
let cfg2 = Conv1dConfig {
padding: 3,
groups,
..Default::default()
};
let conv2 = conv1d_weight_norm(d_model, d_model, 7, cfg2, vb.pp(idx))?;
idx += 1;
Ok(Self {
conv1,
blocks,
local_mha,
conv2,
})
}
}
#[derive(Debug, Clone)]
enum ConvInit {
Depthwise(Conv1d, Conv1d),
Standard(Conv1d),
}
#[derive(Debug, Clone)]
pub struct Decoder {
conv1: ConvInit,
local_mha: Option<LocalMHA>,
blocks: Vec<DecoderBlock>,
snake1: Snake1d,
conv2: Conv1d,
}
impl Decoder {
#[allow(clippy::too_many_arguments)]
fn new(
in_c: usize,
mut channels: usize,
rates: &[usize],
noise: bool,
depthwise: bool,
attn_window_size: Option<usize>,
d_out: usize,
vb: VarBuilder,
) -> Result<Self> {
let vb = vb.pp("model");
let mut idx = 0;
let pad3 = Conv1dConfig {
padding: 3,
..Default::default()
};
let conv1 = if depthwise {
let cfg1 = Conv1dConfig {
padding: 3,
groups: in_c,
..Default::default()
};
let conv1 = conv1d_weight_norm(in_c, in_c, 7, cfg1, vb.pp(idx))?;
idx += 1;
let conv2 = conv1d_weight_norm(in_c, channels, 1, Default::default(), vb.pp(idx))?;
idx += 1;
ConvInit::Depthwise(conv1, conv2)
} else {
let conv1 = conv1d_weight_norm(in_c, channels, 7, pad3, vb.pp(idx))?;
idx += 1;
ConvInit::Standard(conv1)
};
let mut blocks = Vec::with_capacity(rates.len());
let local_mha = match attn_window_size {
Some(w) => {
let mha = LocalMHA::new(channels, w, 64, true, vb.pp(idx))?;
idx += 1;
Some(mha)
}
None => None,
};
for stride in rates.iter() {
let groups = if depthwise { channels / 2 } else { 1 };
let block =
DecoderBlock::new(channels, channels / 2, *stride, noise, groups, vb.pp(idx))?;
idx += 1;
channels /= 2;
blocks.push(block)
}
let snake1 = Snake1d::new(channels, vb.pp(idx))?;
idx += 1;
let conv2 = conv1d_weight_norm(channels, d_out, 7, pad3, vb.pp(idx))?;
idx += 1;
Ok(Self {
conv1,
local_mha,
blocks,
snake1,
conv2,
})
}
}
impl candle::Module for Decoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = match &self.conv1 {
ConvInit::Standard(c) => xs.apply(c)?,
ConvInit::Depthwise(c1, c2) => xs.apply(c1)?.apply(c2)?,
};
for block in self.blocks.iter() {
xs = xs.apply(block)?
}
xs.apply(&self.snake1)?.apply(&self.conv2)
}
}
fn normalize(v: &Tensor) -> Result<Tensor> {
v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
}
// https://github.com/hubertsiuzdak/snac/blob/main/snac/vq.py
#[allow(unused)]
#[derive(Clone, Debug)]
struct VectorQuantizer {
in_proj: Conv1d,
out_proj: Conv1d,
codebook: candle_nn::Embedding,
stride: usize,
}
impl VectorQuantizer {
fn new(
in_dim: usize,
cb_size: usize,
cb_dim: usize,
stride: usize,
vb: VarBuilder,
) -> Result<Self> {
let in_proj = conv1d_weight_norm(in_dim, cb_dim, 1, Default::default(), vb.pp("in_proj"))?;
let out_proj =
conv1d_weight_norm(cb_dim, in_dim, 1, Default::default(), vb.pp("out_proj"))?;
let codebook = candle_nn::embedding(cb_size, cb_dim, vb.pp("codebook"))?;
Ok(Self {
in_proj,
out_proj,
codebook,
stride,
})
}
fn decode_latents(&self, latents: &Tensor) -> Result<(Tensor, Tensor)> {
let (b, d, t) = latents.dims3()?;
let encodings = latents.transpose(1, 2)?.reshape((b * t, d))?;
let encodings = normalize(&encodings)?;
let codebook = normalize(self.codebook.embeddings())?;
let dist = (encodings
.sqr()?
.sum_keepdim(1)?
.broadcast_sub(&encodings.matmul(&codebook.t()?)?)?
* 2.0)?
.broadcast_add(&codebook.sqr()?.sum_keepdim(1)?.t()?)?;
let indices = dist.argmin(1)?.reshape((b, ()))?;
let z_q = self.decode_code(&indices)?;
Ok((z_q, indices))
}
fn encode(&self, z: &Tensor) -> Result<(Tensor, Tensor)> {
let z = if self.stride > 1 {
let (b, c, t) = z.dims3()?;
z.reshape((b, c, 1, t))?
.avg_pool2d((1, self.stride))?
.squeeze(2)?
} else {
z.clone()
};
let z_e = z.apply(&self.in_proj)?;
let (z_q, indices) = self.decode_latents(&z_e)?;
let z_q = z_q.apply(&self.out_proj)?;
let z_q = if self.stride > 1 {
repeat_interleave(&z_q, self.stride, D::Minus1)?
} else {
z_q
};
Ok((z_q, indices))
}
fn embed_code(&self, embed_id: &Tensor) -> Result<Tensor> {
embed_id.apply(&self.codebook)
}
fn decode_code(&self, embed_id: &Tensor) -> Result<Tensor> {
self.embed_code(embed_id)?.transpose(1, 2)
}
}
#[derive(Clone, Debug)]
pub struct ResidualVectorQuantizer {
quantizers: Vec<VectorQuantizer>,
}
impl ResidualVectorQuantizer {
fn new(
input_dim: usize,
cb_size: usize,
cb_dim: usize,
vq_strides: &[usize],
vb: VarBuilder,
) -> Result<Self> {
let vb = &vb.pp("quantizers");
let quantizers = vq_strides
.iter()
.enumerate()
.map(|(i, stride)| VectorQuantizer::new(input_dim, cb_size, cb_dim, *stride, vb.pp(i)))
.collect::<Result<Vec<_>>>()?;
Ok(Self { quantizers })
}
fn encode(&self, z: &Tensor) -> Result<(Tensor, Vec<Tensor>)> {
let mut residual = z.clone();
let mut z_q = z.zeros_like()?;
let mut codes = Vec::with_capacity(self.quantizers.len());
for quantizer in self.quantizers.iter() {
let (z_q_i, indices_i) = quantizer.encode(&residual)?;
z_q = (z_q + &z_q_i)?;
residual = (residual - &z_q_i)?;
codes.push(indices_i)
}
Ok((z_q, codes))
}
#[allow(clippy::wrong_self_convention)]
fn from_codes(&self, codes: &[&Tensor]) -> Result<Tensor> {
let mut sum = None;
for (quantizer, codes) in self.quantizers.iter().zip(codes.iter()) {
let z_p_i = quantizer.decode_code(codes)?;
let z_q_i = z_p_i.apply(&quantizer.out_proj)?;
let z_q_i = repeat_interleave(&z_q_i, quantizer.stride, D::Minus1)?;
let s = match sum {
None => z_q_i,
Some(s) => (s + z_q_i)?,
};
sum = Some(s)
}
match sum {
Some(s) => Ok(s),
None => candle::bail!("empty codebooks"),
}
}
}
fn gcd(mut a: usize, mut b: usize) -> usize {
while b != 0 {
let t = b;
b = a % b;
a = t;
}
a
}
fn lcm(a: usize, b: usize) -> usize {
a / gcd(a, b) * b
}
// https://github.com/hubertsiuzdak/snac/blob/main/snac/snac.py
#[derive(Debug, Clone)]
pub struct Model {
pub encoder: Encoder,
pub quantizer: ResidualVectorQuantizer,
pub decoder: Decoder,
pub hop_length: usize,
pub config: Config,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let encoder = Encoder::new(
cfg.encoder_dim,
&cfg.encoder_rates,
cfg.depthwise,
cfg.attn_window_size,
vb.pp("encoder"),
)?;
let latent_dim = cfg.encoder_dim * 2usize.pow(cfg.encoder_rates.len() as u32);
let quantizer = ResidualVectorQuantizer::new(
latent_dim,
cfg.codebook_size,
cfg.codebook_dim,
&cfg.vq_strides,
vb.pp("quantizer"),
)?;
let decoder = Decoder::new(
latent_dim,
cfg.decoder_dim,
&cfg.decoder_rates,
cfg.noise,
cfg.depthwise,
cfg.attn_window_size,
/* d_out */ 1,
vb.pp("decoder"),
)?;
let hop_length = cfg.encoder_rates.iter().product::<usize>();
Ok(Self {
encoder,
decoder,
quantizer,
config: cfg.clone(),
hop_length,
})
}
fn preprocess(&self, audio_data: &Tensor) -> Result<Tensor> {
let len = audio_data.dim(D::Minus1)?;
let lcm = lcm(
self.config.vq_strides[0],
self.config.attn_window_size.unwrap_or(1),
);
let pad_to = self.hop_length * lcm;
let right_pad = len.div_ceil(pad_to) * pad_to - len;
let audio_data = audio_data.pad_with_zeros(D::Minus1, 0, right_pad)?;
Ok(audio_data)
}
pub fn encode(&self, audio_data: &Tensor) -> Result<Vec<Tensor>> {
let audio_data = self.preprocess(audio_data)?;
let z = self.encoder.forward(&audio_data)?;
let (_, codes) = self.quantizer.encode(&z)?;
Ok(codes)
}
pub fn decode(&self, audio_codes: &[&Tensor]) -> Result<Tensor> {
let audio_values = self.quantizer.from_codes(audio_codes)?;
audio_values.apply(&self.decoder)
}
pub fn config(&self) -> &Config {
&self.config
}
pub fn num_codebooks(&self) -> usize {
self.quantizer.quantizers.len()
}
}