use crate::nn::{conv1d, conv1d_weight_norm, Conv1d, Conv1dConfig, VarBuilder}; use anyhow::Result; use candle::Tensor; // Encodec Model // https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py #[derive(Debug, Clone, PartialEq)] enum NormType { WeightNorm, None, } #[derive(Debug, Clone, PartialEq)] pub struct Config { target_bandwidths: Vec, sampling_rate: usize, audio_channels: usize, normalize: bool, chunk_length_s: Option, overlap: Option, hidden_size: usize, num_filters: usize, num_residual_layers: usize, upsampling_ratios: Vec, norm_type: NormType, kernel_size: usize, last_kernel_size: usize, residual_kernel_size: usize, dilation_growth_rate: usize, use_causal_conv: bool, pad_mode: &'static str, compress: usize, num_lstm_layers: usize, trim_right_ratio: f64, codebook_size: usize, codebook_dim: Option, use_conv_shortcut: bool, } impl Default for Config { fn default() -> Self { Self { target_bandwidths: vec![1.5, 3.0, 6.0, 12.0, 24.0], sampling_rate: 24_000, audio_channels: 1, normalize: false, chunk_length_s: None, overlap: None, hidden_size: 128, num_filters: 32, num_residual_layers: 1, upsampling_ratios: vec![8, 5, 4, 2], norm_type: NormType::WeightNorm, kernel_size: 7, last_kernel_size: 7, residual_kernel_size: 3, dilation_growth_rate: 2, use_causal_conv: true, pad_mode: "reflect", compress: 2, num_lstm_layers: 2, trim_right_ratio: 1.0, codebook_size: 1024, codebook_dim: None, use_conv_shortcut: true, } } } impl Config { // https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L6 pub fn musicgen_small() -> Self { Self { audio_channels: 1, chunk_length_s: None, codebook_dim: Some(128), codebook_size: 2048, compress: 2, dilation_growth_rate: 2, hidden_size: 128, kernel_size: 7, last_kernel_size: 7, norm_type: NormType::WeightNorm, normalize: false, num_filters: 64, num_lstm_layers: 2, num_residual_layers: 1, overlap: None, pad_mode: "reflect", residual_kernel_size: 3, sampling_rate: 32_000, target_bandwidths: vec![2.2], trim_right_ratio: 1.0, upsampling_ratios: vec![8, 5, 4, 4], use_causal_conv: false, use_conv_shortcut: false, } } fn codebook_dim(&self) -> usize { self.codebook_dim.unwrap_or(self.codebook_size) } fn frame_rate(&self) -> usize { let hop_length: usize = self.upsampling_ratios.iter().product(); (self.sampling_rate + hop_length - 1) / hop_length } fn num_quantizers(&self) -> usize { let num = 1000f64 * self .target_bandwidths .last() .expect("empty target_bandwidths"); (num as usize) / (self.frame_rate() * 10) } } // https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L340 #[derive(Debug)] struct EncodecEuclideanCodebook { inited: Tensor, cluster_size: Tensor, embed: Tensor, embed_avg: Tensor, } impl EncodecEuclideanCodebook { fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result { let inited = vb.get(1, &format!("{p}.inited"))?; let cluster_size = vb.get(cfg.codebook_size, &format!("{p}.cluster_size"))?; let e_shape = (cfg.codebook_size, cfg.codebook_dim()); let embed = vb.get(e_shape, &format!("{p}.embed"))?; let embed_avg = vb.get(e_shape, &format!("{p}.embed_avg"))?; Ok(Self { inited, cluster_size, embed, embed_avg, }) } } #[derive(Debug)] struct EncodecVectorQuantization { codebook: EncodecEuclideanCodebook, } impl EncodecVectorQuantization { fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result { let codebook = EncodecEuclideanCodebook::load(&format!("{p}.codebook"), vb, cfg)?; Ok(Self { codebook }) } } #[derive(Debug)] struct EncodecResidualVectorQuantizer { layers: Vec, } impl EncodecResidualVectorQuantizer { fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result { let p = format!("{p}.layers"); let layers = (0..cfg.num_quantizers()) .map(|i| EncodecVectorQuantization::load(&format!("{p}.{i}"), vb, cfg)) .collect::>>()?; Ok(Self { layers }) } } // https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226 #[derive(Debug)] struct EncodecLSTM { layers: Vec<(Tensor, Tensor, Tensor, Tensor)>, } impl EncodecLSTM { fn load(dim: usize, p: &str, vb: &VarBuilder, cfg: &Config) -> Result { let p = format!("{p}.lstm"); let mut layers = vec![]; for i in 0..cfg.num_lstm_layers { let w_hh = vb.get((4 * dim, dim), &format!("{p}.weight_hh_l{i}"))?; let w_ih = vb.get((4 * dim, dim), &format!("{p}.weight_ih_l{i}"))?; let b_hh = vb.get(4 * dim, &format!("{p}.bias_hh_l{i}"))?; let b_ih = vb.get(4 * dim, &format!("{p}.bias_ih_l{i}"))?; layers.push((w_hh, w_ih, b_hh, b_ih)) } Ok(Self { layers }) } } #[derive(Debug)] struct EncodecConvTranspose1d { weight_g: Tensor, weight_v: Tensor, bias: Tensor, } impl EncodecConvTranspose1d { fn load( in_c: usize, out_c: usize, k: usize, _stride: usize, p: &str, vb: &VarBuilder, _cfg: &Config, ) -> Result { let p = format!("{p}.conv"); let weight_g = vb.get((in_c, 1, 1), &format!("{p}.weight_g"))?; let weight_v = vb.get((in_c, out_c, k), &format!("{p}.weight_v"))?; let bias = vb.get(out_c, &format!("{p}.bias"))?; Ok(Self { weight_g, weight_v, bias, }) } } #[derive(Debug)] struct EncodecConv1d { conv: Conv1d, } impl EncodecConv1d { fn load( in_c: usize, out_c: usize, kernel_size: usize, stride: usize, p: &str, vb: &VarBuilder, cfg: &Config, ) -> Result { let conv = match cfg.norm_type { NormType::WeightNorm => conv1d_weight_norm( in_c, out_c, kernel_size, Conv1dConfig { padding: 0, stride }, &format!("{p}.conv"), vb, )?, NormType::None => conv1d( in_c, out_c, kernel_size, Conv1dConfig { padding: 0, stride }, &format!("{p}.conv"), vb, )?, }; Ok(Self { conv }) } } #[derive(Debug)] struct EncodecResnetBlock { block_conv1: EncodecConv1d, block_conv2: EncodecConv1d, shortcut: Option, } impl EncodecResnetBlock { fn load( dim: usize, dilations: &[usize], p: &str, vb: &VarBuilder, cfg: &Config, ) -> Result { let h = dim / cfg.compress; let mut layer = Layer::new(format!("{p}.block")); if dilations.len() != 2 { anyhow::bail!("expected dilations of size 2") } // TODO: Apply dilations! layer.inc(); let block_conv1 = EncodecConv1d::load( dim, h, cfg.residual_kernel_size, 1, &layer.next_name(), vb, cfg, )?; layer.inc(); let block_conv2 = EncodecConv1d::load(h, dim, 1, 1, &layer.next_name(), vb, cfg)?; let shortcut = if cfg.use_conv_shortcut { let conv = EncodecConv1d::load(dim, dim, 1, 1, &format!("{p}.shortcut"), vb, cfg)?; Some(conv) } else { None }; Ok(Self { block_conv1, block_conv2, shortcut, }) } } #[derive(Debug)] struct Layer { prefix: String, cnt: usize, } impl Layer { fn new(prefix: String) -> Self { Self { prefix, cnt: 0 } } fn inc(&mut self) { self.cnt += 1; } fn next_name(&mut self) -> String { let name = format!("{}.{}", self.prefix, self.cnt); self.cnt += 1; name } } #[derive(Debug)] struct EncodecEncoder { init_conv: EncodecConv1d, sampling_layers: Vec<(Vec, EncodecConv1d)>, final_lstm: EncodecLSTM, final_conv: EncodecConv1d, } impl EncodecEncoder { fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result { let mut layer = Layer::new(format!("{p}.layers")); let init_conv = EncodecConv1d::load( cfg.audio_channels, cfg.num_filters, cfg.kernel_size, 1, &layer.next_name(), vb, cfg, )?; let mut sampling_layers = vec![]; let mut scaling = 1; for &ratio in cfg.upsampling_ratios.iter().rev() { let current_scale = scaling * cfg.num_filters; let mut resnets = vec![]; for j in 0..(cfg.num_residual_layers as u32) { let resnet = EncodecResnetBlock::load( current_scale, &[cfg.dilation_growth_rate.pow(j), 1], &layer.next_name(), vb, cfg, )?; resnets.push(resnet) } layer.inc(); // ELU let conv1d = EncodecConv1d::load( current_scale, current_scale * 2, ratio * 2, ratio, &layer.next_name(), vb, cfg, )?; sampling_layers.push((resnets, conv1d)); scaling *= 2; } let final_lstm = EncodecLSTM::load(cfg.num_filters * scaling, &layer.next_name(), vb, cfg)?; layer.inc(); // ELU let final_conv = EncodecConv1d::load( cfg.num_filters * scaling, cfg.hidden_size, cfg.last_kernel_size, 1, &layer.next_name(), vb, cfg, )?; Ok(Self { init_conv, sampling_layers, final_conv, final_lstm, }) } } #[derive(Debug)] struct EncodecDecoder { init_conv: EncodecConv1d, init_lstm: EncodecLSTM, sampling_layers: Vec<(EncodecConvTranspose1d, Vec)>, final_conv: EncodecConv1d, } impl EncodecDecoder { fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result { let mut layer = Layer::new(format!("{p}.layers")); let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32); let init_conv = EncodecConv1d::load( cfg.hidden_size, cfg.num_filters * scaling, cfg.last_kernel_size, 1, &layer.next_name(), vb, cfg, )?; let init_lstm = EncodecLSTM::load(cfg.num_filters * scaling, &layer.next_name(), vb, cfg)?; let mut sampling_layers = vec![]; for &ratio in cfg.upsampling_ratios.iter() { let current_scale = scaling * cfg.num_filters; layer.inc(); // ELU let conv1d = EncodecConvTranspose1d::load( current_scale, current_scale / 2, ratio * 2, ratio, &layer.next_name(), vb, cfg, )?; let mut resnets = vec![]; for j in 0..(cfg.num_residual_layers as u32) { let resnet = EncodecResnetBlock::load( current_scale / 2, &[cfg.dilation_growth_rate.pow(j), 1], &layer.next_name(), vb, cfg, )?; resnets.push(resnet) } sampling_layers.push((conv1d, resnets)); scaling /= 2; } layer.inc(); // ELU let final_conv = EncodecConv1d::load( cfg.num_filters, cfg.audio_channels, cfg.last_kernel_size, 1, &layer.next_name(), vb, cfg, )?; Ok(Self { init_conv, init_lstm, sampling_layers, final_conv, }) } } #[derive(Debug)] pub struct EncodecModel { encoder: EncodecEncoder, decoder: EncodecDecoder, quantizer: EncodecResidualVectorQuantizer, } impl EncodecModel { pub fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result { let encoder = EncodecEncoder::load(&format!("{p}.encoder"), vb, cfg)?; let decoder = EncodecDecoder::load(&format!("{p}.decoder"), vb, cfg)?; let quantizer = EncodecResidualVectorQuantizer::load(&format!("{p}.quantizer"), vb, cfg)?; Ok(Self { encoder, decoder, quantizer, }) } }