mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Support more modes in the encodec example. (#1777)
* Support more modes in the encodec example. * Remove the old encodec model from the musicgen bits.
This commit is contained in:
@ -95,3 +95,9 @@ required-features = ["candle-datasets"]
|
||||
[[example]]
|
||||
name = "llama2-c"
|
||||
required-features = ["candle-datasets"]
|
||||
|
||||
[[example]]
|
||||
name = "encodec"
|
||||
required-features = ["symphonia"]
|
||||
|
||||
|
||||
|
20
candle-examples/examples/encodec/README.md
Normal file
20
candle-examples/examples/encodec/README.md
Normal file
@ -0,0 +1,20 @@
|
||||
# candle-endocec
|
||||
|
||||
[EnCodec](https://huggingface.co/facebook/encodec_24khz) is a high-quality audio
|
||||
compression model using an encoder/decoder architecture with residual vector
|
||||
quantization.
|
||||
|
||||
## Running one example
|
||||
|
||||
```bash
|
||||
cargo run --example encodec --features symphonia --release -- code-to-audio \
|
||||
candle-examples/examples/encodec/jfk-codes.safetensors \
|
||||
jfk.wav
|
||||
```
|
||||
|
||||
This decodes the EnCodec tokens stored in `jfk-codes.safetensors` and generates
|
||||
an output wav file containing the audio data. Instead of `code-to-audio` one
|
||||
can use:
|
||||
- `audio-to-audio in.mp3 out.wav`: encodes the input audio file then decodes it to a wav file.
|
||||
- `audio-to-code in.mp3 out.safetensors`: generates a safetensors file
|
||||
containing EnCodec tokens for the input audio file.
|
@ -5,15 +5,85 @@ extern crate intel_mkl_src;
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::Result;
|
||||
use candle::{DType, IndexOp};
|
||||
use candle::{DType, IndexOp, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::encodec::{Config, Model};
|
||||
use clap::Parser;
|
||||
use clap::{Parser, ValueEnum};
|
||||
use hf_hub::api::sync::Api;
|
||||
|
||||
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
|
||||
where
|
||||
T: symphonia::core::sample::Sample,
|
||||
f32: symphonia::core::conv::FromSample<T>,
|
||||
{
|
||||
use symphonia::core::audio::Signal;
|
||||
use symphonia::core::conv::FromSample;
|
||||
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
|
||||
}
|
||||
|
||||
fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> {
|
||||
use symphonia::core::audio::{AudioBufferRef, Signal};
|
||||
|
||||
let src = std::fs::File::open(path)?;
|
||||
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
|
||||
let hint = symphonia::core::probe::Hint::new();
|
||||
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
|
||||
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
|
||||
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
|
||||
let mut format = probed.format;
|
||||
let track = format
|
||||
.tracks()
|
||||
.iter()
|
||||
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
|
||||
.expect("no supported audio tracks");
|
||||
let mut decoder = symphonia::default::get_codecs()
|
||||
.make(&track.codec_params, &Default::default())
|
||||
.expect("unsupported codec");
|
||||
let track_id = track.id;
|
||||
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
|
||||
let mut pcm_data = Vec::new();
|
||||
while let Ok(packet) = format.next_packet() {
|
||||
while !format.metadata().is_latest() {
|
||||
format.metadata().pop();
|
||||
}
|
||||
if packet.track_id() != track_id {
|
||||
continue;
|
||||
}
|
||||
match decoder.decode(&packet)? {
|
||||
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
|
||||
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
|
||||
}
|
||||
}
|
||||
Ok((pcm_data, sample_rate))
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||
enum Action {
|
||||
AudioToAudio,
|
||||
AudioToCode,
|
||||
CodeToAudio,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// The action to be performed, specifies the format for the input and output data.
|
||||
action: Action,
|
||||
|
||||
/// The input file, either an audio file or some encodec tokens stored as safetensors.
|
||||
in_file: String,
|
||||
|
||||
/// The output file, either a wave audio file or some encodec tokens stored as safetensors.
|
||||
out_file: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
@ -21,18 +91,6 @@ struct Args {
|
||||
/// The model weight file, in safetensor format.
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
/// Input file as a safetensors containing the encodec tokens.
|
||||
#[arg(long)]
|
||||
code_file: String,
|
||||
|
||||
/// Output file that will be generated in wav format.
|
||||
#[arg(long)]
|
||||
out: String,
|
||||
|
||||
/// Do another step of encoding the PCM data and and decoding the resulting codes.
|
||||
#[arg(long)]
|
||||
roundtrip: bool,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
@ -48,25 +106,36 @@ fn main() -> Result<()> {
|
||||
let config = Config::default();
|
||||
let model = Model::new(&config, vb)?;
|
||||
|
||||
let codes = candle::safetensors::load(args.code_file, &device)?;
|
||||
let codes = match args.action {
|
||||
Action::CodeToAudio => {
|
||||
let codes = candle::safetensors::load(args.in_file, &device)?;
|
||||
let codes = codes.get("codes").expect("no codes in input file").i(0)?;
|
||||
println!("codes shape: {:?}", codes.shape());
|
||||
let pcm = model.decode(&codes)?;
|
||||
println!("pcm shape: {:?}", pcm.shape());
|
||||
|
||||
let pcm = if args.roundtrip {
|
||||
let codes = model.encode(&pcm)?;
|
||||
println!("second step codes shape: {:?}", pcm.shape());
|
||||
let pcm = model.decode(&codes)?;
|
||||
println!("second step pcm shape: {:?}", pcm.shape());
|
||||
pcm
|
||||
} else {
|
||||
pcm
|
||||
codes
|
||||
}
|
||||
Action::AudioToCode | Action::AudioToAudio => {
|
||||
let (pcm, sample_rate) = pcm_decode(args.in_file)?;
|
||||
if sample_rate != 24_000 {
|
||||
println!("WARNING: encodec uses a 24khz sample rate, input uses {sample_rate}")
|
||||
}
|
||||
let pcm_len = pcm.len();
|
||||
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
|
||||
println!("input pcm shape: {:?}", pcm.shape());
|
||||
model.encode(&pcm)?
|
||||
}
|
||||
};
|
||||
println!("codes shape: {:?}", codes.shape());
|
||||
|
||||
match args.action {
|
||||
Action::AudioToCode => {
|
||||
codes.save_safetensors("codes", &args.out_file)?;
|
||||
}
|
||||
Action::AudioToAudio | Action::CodeToAudio => {
|
||||
let pcm = model.decode(&codes)?;
|
||||
println!("output pcm shape: {:?}", pcm.shape());
|
||||
let pcm = pcm.i(0)?.i(0)?.to_vec1::<f32>()?;
|
||||
let mut output = std::fs::File::create(&args.out)?;
|
||||
let mut output = std::fs::File::create(&args.out_file)?;
|
||||
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
|
||||
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1,580 +0,0 @@
|
||||
use crate::nn::conv1d_weight_norm;
|
||||
use candle::{DType, IndexOp, Module, Result, Tensor};
|
||||
use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
|
||||
|
||||
// Encodec Model
|
||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
enum NormType {
|
||||
WeightNorm,
|
||||
TimeGroupNorm,
|
||||
None,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct Config {
|
||||
target_bandwidths: Vec<f64>,
|
||||
sampling_rate: usize,
|
||||
audio_channels: usize,
|
||||
normalize: bool,
|
||||
chunk_length_s: Option<usize>,
|
||||
overlap: Option<usize>,
|
||||
hidden_size: usize,
|
||||
num_filters: usize,
|
||||
num_residual_layers: usize,
|
||||
upsampling_ratios: Vec<usize>,
|
||||
norm_type: NormType,
|
||||
kernel_size: usize,
|
||||
last_kernel_size: usize,
|
||||
residual_kernel_size: usize,
|
||||
dilation_growth_rate: usize,
|
||||
use_causal_conv: bool,
|
||||
pad_mode: &'static str,
|
||||
compress: usize,
|
||||
num_lstm_layers: usize,
|
||||
trim_right_ratio: f64,
|
||||
codebook_size: usize,
|
||||
codebook_dim: Option<usize>,
|
||||
use_conv_shortcut: bool,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
target_bandwidths: vec![1.5, 3.0, 6.0, 12.0, 24.0],
|
||||
sampling_rate: 24_000,
|
||||
audio_channels: 1,
|
||||
normalize: false,
|
||||
chunk_length_s: None,
|
||||
overlap: None,
|
||||
hidden_size: 128,
|
||||
num_filters: 32,
|
||||
num_residual_layers: 1,
|
||||
upsampling_ratios: vec![8, 5, 4, 2],
|
||||
norm_type: NormType::WeightNorm,
|
||||
kernel_size: 7,
|
||||
last_kernel_size: 7,
|
||||
residual_kernel_size: 3,
|
||||
dilation_growth_rate: 2,
|
||||
use_causal_conv: true,
|
||||
pad_mode: "reflect",
|
||||
compress: 2,
|
||||
num_lstm_layers: 2,
|
||||
trim_right_ratio: 1.0,
|
||||
codebook_size: 1024,
|
||||
codebook_dim: None,
|
||||
use_conv_shortcut: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
// https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L6
|
||||
pub fn musicgen_small() -> Self {
|
||||
Self {
|
||||
audio_channels: 1,
|
||||
chunk_length_s: None,
|
||||
codebook_dim: Some(128),
|
||||
codebook_size: 2048,
|
||||
compress: 2,
|
||||
dilation_growth_rate: 2,
|
||||
hidden_size: 128,
|
||||
kernel_size: 7,
|
||||
last_kernel_size: 7,
|
||||
norm_type: NormType::WeightNorm,
|
||||
normalize: false,
|
||||
num_filters: 64,
|
||||
num_lstm_layers: 2,
|
||||
num_residual_layers: 1,
|
||||
overlap: None,
|
||||
pad_mode: "reflect",
|
||||
residual_kernel_size: 3,
|
||||
sampling_rate: 32_000,
|
||||
target_bandwidths: vec![2.2],
|
||||
trim_right_ratio: 1.0,
|
||||
upsampling_ratios: vec![8, 5, 4, 4],
|
||||
use_causal_conv: false,
|
||||
use_conv_shortcut: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn codebook_dim(&self) -> usize {
|
||||
self.codebook_dim.unwrap_or(self.codebook_size)
|
||||
}
|
||||
|
||||
fn frame_rate(&self) -> usize {
|
||||
let hop_length: usize = self.upsampling_ratios.iter().product();
|
||||
(self.sampling_rate + hop_length - 1) / hop_length
|
||||
}
|
||||
|
||||
fn num_quantizers(&self) -> usize {
|
||||
let num = 1000f64
|
||||
* self
|
||||
.target_bandwidths
|
||||
.last()
|
||||
.expect("empty target_bandwidths");
|
||||
(num as usize) / (self.frame_rate() * 10)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L340
|
||||
#[derive(Debug)]
|
||||
struct EncodecEuclideanCodebook {
|
||||
inited: Tensor,
|
||||
cluster_size: Tensor,
|
||||
embed: Tensor,
|
||||
embed_avg: Tensor,
|
||||
}
|
||||
|
||||
impl EncodecEuclideanCodebook {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let inited = vb.get(1, "inited")?;
|
||||
let cluster_size = vb.get(cfg.codebook_size, "cluster_size")?;
|
||||
let e_shape = (cfg.codebook_size, cfg.codebook_dim());
|
||||
let embed = vb.get(e_shape, "embed")?;
|
||||
let embed_avg = vb.get(e_shape, "embed_avg")?;
|
||||
Ok(Self {
|
||||
inited,
|
||||
cluster_size,
|
||||
embed,
|
||||
embed_avg,
|
||||
})
|
||||
}
|
||||
|
||||
fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
|
||||
let quantize = self.embed.embedding(embed_ind)?;
|
||||
Ok(quantize)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EncodecVectorQuantization {
|
||||
codebook: EncodecEuclideanCodebook,
|
||||
}
|
||||
|
||||
impl EncodecVectorQuantization {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let codebook = EncodecEuclideanCodebook::load(vb.pp("codebook"), cfg)?;
|
||||
Ok(Self { codebook })
|
||||
}
|
||||
|
||||
fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
|
||||
let quantize = self.codebook.decode(embed_ind)?;
|
||||
let quantize = quantize.transpose(1, 2)?;
|
||||
Ok(quantize)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EncodecResidualVectorQuantizer {
|
||||
layers: Vec<EncodecVectorQuantization>,
|
||||
}
|
||||
|
||||
impl EncodecResidualVectorQuantizer {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let vb = &vb.pp("layers");
|
||||
let layers = (0..cfg.num_quantizers())
|
||||
.map(|i| EncodecVectorQuantization::load(vb.pp(&i.to_string()), cfg))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Ok(Self { layers })
|
||||
}
|
||||
|
||||
fn decode(&self, codes: &Tensor) -> Result<Tensor> {
|
||||
let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?;
|
||||
if codes.dim(0)? != self.layers.len() {
|
||||
candle::bail!(
|
||||
"codes shape {:?} does not match the number of quantization layers {}",
|
||||
codes.shape(),
|
||||
self.layers.len()
|
||||
)
|
||||
}
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
let quantized = layer.decode(&codes.i(i)?)?;
|
||||
quantized_out = quantized.broadcast_add(&quantized_out)?;
|
||||
}
|
||||
Ok(quantized_out)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226
|
||||
#[derive(Debug)]
|
||||
struct EncodecLSTM {
|
||||
layers: Vec<candle_nn::LSTM>,
|
||||
}
|
||||
|
||||
impl EncodecLSTM {
|
||||
fn load(dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let vb = &vb.pp("lstm");
|
||||
let mut layers = vec![];
|
||||
for layer_idx in 0..cfg.num_lstm_layers {
|
||||
let config = candle_nn::LSTMConfig {
|
||||
layer_idx,
|
||||
..Default::default()
|
||||
};
|
||||
let lstm = candle_nn::lstm(dim, dim, config, vb.clone())?;
|
||||
layers.push(lstm)
|
||||
}
|
||||
Ok(Self { layers })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for EncodecLSTM {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
use candle_nn::RNN;
|
||||
let mut xs = xs.clone();
|
||||
for layer in self.layers.iter() {
|
||||
let states = layer.seq(&xs)?;
|
||||
xs = layer.states_to_tensor(&states)?;
|
||||
}
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EncodecConvTranspose1d {
|
||||
weight_g: Tensor,
|
||||
weight_v: Tensor,
|
||||
bias: Tensor,
|
||||
}
|
||||
|
||||
impl EncodecConvTranspose1d {
|
||||
fn load(
|
||||
in_c: usize,
|
||||
out_c: usize,
|
||||
k: usize,
|
||||
_stride: usize,
|
||||
vb: VarBuilder,
|
||||
_cfg: &Config,
|
||||
) -> Result<Self> {
|
||||
let vb = &vb.pp("conv");
|
||||
let weight_g = vb.get((in_c, 1, 1), "weight_g")?;
|
||||
let weight_v = vb.get((in_c, out_c, k), "weight_v")?;
|
||||
let bias = vb.get(out_c, "bias")?;
|
||||
Ok(Self {
|
||||
weight_g,
|
||||
weight_v,
|
||||
bias,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for EncodecConvTranspose1d {
|
||||
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EncodecConv1d {
|
||||
causal: bool,
|
||||
conv: Conv1d,
|
||||
norm: Option<candle_nn::GroupNorm>,
|
||||
}
|
||||
|
||||
impl EncodecConv1d {
|
||||
fn load(
|
||||
in_c: usize,
|
||||
out_c: usize,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
vb: VarBuilder,
|
||||
cfg: &Config,
|
||||
) -> Result<Self> {
|
||||
let conv = match cfg.norm_type {
|
||||
NormType::WeightNorm => conv1d_weight_norm(
|
||||
in_c,
|
||||
out_c,
|
||||
kernel_size,
|
||||
Conv1dConfig {
|
||||
padding: 0,
|
||||
stride,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
},
|
||||
vb.pp("conv"),
|
||||
)?,
|
||||
NormType::None | NormType::TimeGroupNorm => conv1d(
|
||||
in_c,
|
||||
out_c,
|
||||
kernel_size,
|
||||
Conv1dConfig {
|
||||
padding: 0,
|
||||
stride,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
},
|
||||
vb.pp("conv"),
|
||||
)?,
|
||||
};
|
||||
let norm = match cfg.norm_type {
|
||||
NormType::None | NormType::WeightNorm => None,
|
||||
NormType::TimeGroupNorm => {
|
||||
let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
|
||||
Some(gn)
|
||||
}
|
||||
};
|
||||
Ok(Self {
|
||||
causal: cfg.use_causal_conv,
|
||||
conv,
|
||||
norm,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for EncodecConv1d {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
// TODO: padding, depending on causal.
|
||||
let xs = self.conv.forward(xs)?;
|
||||
match &self.norm {
|
||||
None => Ok(xs),
|
||||
Some(norm) => xs.apply(norm),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EncodecResnetBlock {
|
||||
block_conv1: EncodecConv1d,
|
||||
block_conv2: EncodecConv1d,
|
||||
shortcut: Option<EncodecConv1d>,
|
||||
}
|
||||
|
||||
impl EncodecResnetBlock {
|
||||
fn load(dim: usize, dilations: &[usize], vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let h = dim / cfg.compress;
|
||||
let mut layer = Layer::new(vb.pp("block"));
|
||||
if dilations.len() != 2 {
|
||||
candle::bail!("expected dilations of size 2")
|
||||
}
|
||||
// TODO: Apply dilations!
|
||||
layer.inc();
|
||||
let block_conv1 =
|
||||
EncodecConv1d::load(dim, h, cfg.residual_kernel_size, 1, layer.next(), cfg)?;
|
||||
layer.inc();
|
||||
let block_conv2 = EncodecConv1d::load(h, dim, 1, 1, layer.next(), cfg)?;
|
||||
let shortcut = if cfg.use_conv_shortcut {
|
||||
let conv = EncodecConv1d::load(dim, dim, 1, 1, vb.pp("shortcut"), cfg)?;
|
||||
Some(conv)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Self {
|
||||
block_conv1,
|
||||
block_conv2,
|
||||
shortcut,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for EncodecResnetBlock {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let residual = xs.clone();
|
||||
let xs = xs.elu(1.)?;
|
||||
let xs = self.block_conv1.forward(&xs)?;
|
||||
let xs = xs.elu(1.)?;
|
||||
let xs = self.block_conv2.forward(&xs)?;
|
||||
let xs = match &self.shortcut {
|
||||
None => (xs + residual)?,
|
||||
Some(shortcut) => xs.add(&shortcut.forward(&residual)?)?,
|
||||
};
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
struct Layer<'a> {
|
||||
vb: VarBuilder<'a>,
|
||||
cnt: usize,
|
||||
}
|
||||
|
||||
impl<'a> Layer<'a> {
|
||||
fn new(vb: VarBuilder<'a>) -> Self {
|
||||
Self { vb, cnt: 0 }
|
||||
}
|
||||
|
||||
fn inc(&mut self) {
|
||||
self.cnt += 1;
|
||||
}
|
||||
|
||||
fn next(&mut self) -> VarBuilder {
|
||||
let vb = self.vb.pp(&self.cnt.to_string());
|
||||
self.cnt += 1;
|
||||
vb
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EncodecEncoder {
|
||||
init_conv: EncodecConv1d,
|
||||
sampling_layers: Vec<(Vec<EncodecResnetBlock>, EncodecConv1d)>,
|
||||
final_lstm: EncodecLSTM,
|
||||
final_conv: EncodecConv1d,
|
||||
}
|
||||
|
||||
impl EncodecEncoder {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let mut layer = Layer::new(vb.pp("layers"));
|
||||
let init_conv = EncodecConv1d::load(
|
||||
cfg.audio_channels,
|
||||
cfg.num_filters,
|
||||
cfg.kernel_size,
|
||||
1,
|
||||
layer.next(),
|
||||
cfg,
|
||||
)?;
|
||||
let mut sampling_layers = vec![];
|
||||
let mut scaling = 1;
|
||||
for &ratio in cfg.upsampling_ratios.iter().rev() {
|
||||
let current_scale = scaling * cfg.num_filters;
|
||||
let mut resnets = vec![];
|
||||
for j in 0..(cfg.num_residual_layers as u32) {
|
||||
let resnet = EncodecResnetBlock::load(
|
||||
current_scale,
|
||||
&[cfg.dilation_growth_rate.pow(j), 1],
|
||||
layer.next(),
|
||||
cfg,
|
||||
)?;
|
||||
resnets.push(resnet)
|
||||
}
|
||||
layer.inc(); // ELU
|
||||
let conv1d = EncodecConv1d::load(
|
||||
current_scale,
|
||||
current_scale * 2,
|
||||
ratio * 2,
|
||||
ratio,
|
||||
layer.next(),
|
||||
cfg,
|
||||
)?;
|
||||
sampling_layers.push((resnets, conv1d));
|
||||
scaling *= 2;
|
||||
}
|
||||
let final_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?;
|
||||
layer.inc(); // ELU
|
||||
let final_conv = EncodecConv1d::load(
|
||||
cfg.num_filters * scaling,
|
||||
cfg.hidden_size,
|
||||
cfg.last_kernel_size,
|
||||
1,
|
||||
layer.next(),
|
||||
cfg,
|
||||
)?;
|
||||
Ok(Self {
|
||||
init_conv,
|
||||
sampling_layers,
|
||||
final_conv,
|
||||
final_lstm,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = xs.apply(&self.init_conv)?;
|
||||
for (resnets, conv) in self.sampling_layers.iter() {
|
||||
for resnet in resnets.iter() {
|
||||
xs = xs.apply(resnet)?;
|
||||
}
|
||||
xs = xs.elu(1.0)?.apply(conv)?;
|
||||
}
|
||||
xs.apply(&self.final_lstm)?
|
||||
.elu(1.0)?
|
||||
.apply(&self.final_conv)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EncodecDecoder {
|
||||
init_conv: EncodecConv1d,
|
||||
init_lstm: EncodecLSTM,
|
||||
sampling_layers: Vec<(EncodecConvTranspose1d, Vec<EncodecResnetBlock>)>,
|
||||
final_conv: EncodecConv1d,
|
||||
}
|
||||
|
||||
impl EncodecDecoder {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let mut layer = Layer::new(vb.pp("layers"));
|
||||
let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32);
|
||||
let init_conv = EncodecConv1d::load(
|
||||
cfg.hidden_size,
|
||||
cfg.num_filters * scaling,
|
||||
cfg.last_kernel_size,
|
||||
1,
|
||||
layer.next(),
|
||||
cfg,
|
||||
)?;
|
||||
let init_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?;
|
||||
let mut sampling_layers = vec![];
|
||||
for &ratio in cfg.upsampling_ratios.iter() {
|
||||
let current_scale = scaling * cfg.num_filters;
|
||||
layer.inc(); // ELU
|
||||
let conv1d = EncodecConvTranspose1d::load(
|
||||
current_scale,
|
||||
current_scale / 2,
|
||||
ratio * 2,
|
||||
ratio,
|
||||
layer.next(),
|
||||
cfg,
|
||||
)?;
|
||||
let mut resnets = vec![];
|
||||
for j in 0..(cfg.num_residual_layers as u32) {
|
||||
let resnet = EncodecResnetBlock::load(
|
||||
current_scale / 2,
|
||||
&[cfg.dilation_growth_rate.pow(j), 1],
|
||||
layer.next(),
|
||||
cfg,
|
||||
)?;
|
||||
resnets.push(resnet)
|
||||
}
|
||||
sampling_layers.push((conv1d, resnets));
|
||||
scaling /= 2;
|
||||
}
|
||||
layer.inc(); // ELU
|
||||
let final_conv = EncodecConv1d::load(
|
||||
cfg.num_filters,
|
||||
cfg.audio_channels,
|
||||
cfg.last_kernel_size,
|
||||
1,
|
||||
layer.next(),
|
||||
cfg,
|
||||
)?;
|
||||
Ok(Self {
|
||||
init_conv,
|
||||
init_lstm,
|
||||
sampling_layers,
|
||||
final_conv,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = xs.apply(&self.init_conv)?.apply(&self.init_lstm)?;
|
||||
for (conv, resnets) in self.sampling_layers.iter() {
|
||||
xs = xs.elu(1.)?.apply(conv)?;
|
||||
for resnet in resnets.iter() {
|
||||
xs = xs.apply(resnet)?
|
||||
}
|
||||
}
|
||||
xs.elu(1.)?.apply(&self.final_conv)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct EncodecModel {
|
||||
encoder: EncodecEncoder,
|
||||
decoder: EncodecDecoder,
|
||||
quantizer: EncodecResidualVectorQuantizer,
|
||||
}
|
||||
|
||||
impl EncodecModel {
|
||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let encoder = EncodecEncoder::load(vb.pp("encoder"), cfg)?;
|
||||
let decoder = EncodecDecoder::load(vb.pp("decoder"), cfg)?;
|
||||
let quantizer = EncodecResidualVectorQuantizer::load(vb.pp("quantizer"), cfg)?;
|
||||
Ok(Self {
|
||||
encoder,
|
||||
decoder,
|
||||
quantizer,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
}
|
||||
}
|
@ -10,9 +10,7 @@ extern crate intel_mkl_src;
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
mod encodec_model;
|
||||
mod musicgen_model;
|
||||
mod nn;
|
||||
|
||||
use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
|
||||
|
||||
|
@ -1,10 +1,9 @@
|
||||
use crate::encodec_model;
|
||||
use candle::{DType, Device, Result, Tensor, D};
|
||||
use candle_nn::{
|
||||
embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
|
||||
VarBuilder,
|
||||
};
|
||||
use candle_transformers::models::t5;
|
||||
use candle_transformers::models::{encodec, t5};
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
@ -372,7 +371,7 @@ impl MusicgenForCausalLM {
|
||||
#[derive(Debug)]
|
||||
pub struct MusicgenForConditionalGeneration {
|
||||
pub text_encoder: t5::T5EncoderModel,
|
||||
pub audio_encoder: crate::encodec_model::EncodecModel,
|
||||
pub audio_encoder: encodec::Model,
|
||||
pub decoder: MusicgenForCausalLM,
|
||||
cfg: GenConfig,
|
||||
}
|
||||
@ -381,15 +380,42 @@ pub struct MusicgenForConditionalGeneration {
|
||||
pub struct GenConfig {
|
||||
musicgen: Config,
|
||||
t5: t5::Config,
|
||||
encodec: crate::encodec_model::Config,
|
||||
encodec: encodec::Config,
|
||||
}
|
||||
|
||||
impl GenConfig {
|
||||
pub fn small() -> Self {
|
||||
// https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L6
|
||||
let encodec = encodec::Config {
|
||||
audio_channels: 1,
|
||||
chunk_length_s: None,
|
||||
codebook_dim: Some(128),
|
||||
codebook_size: 2048,
|
||||
compress: 2,
|
||||
dilation_growth_rate: 2,
|
||||
hidden_size: 128,
|
||||
kernel_size: 7,
|
||||
last_kernel_size: 7,
|
||||
norm_type: encodec::NormType::WeightNorm,
|
||||
normalize: false,
|
||||
num_filters: 64,
|
||||
num_lstm_layers: 2,
|
||||
num_residual_layers: 1,
|
||||
overlap: None,
|
||||
// This should be Reflect and not Replicate but Reflect does not work yet.
|
||||
pad_mode: encodec::PadMode::Replicate,
|
||||
residual_kernel_size: 3,
|
||||
sampling_rate: 32_000,
|
||||
target_bandwidths: vec![2.2],
|
||||
trim_right_ratio: 1.0,
|
||||
upsampling_ratios: vec![8, 5, 4, 4],
|
||||
use_causal_conv: false,
|
||||
use_conv_shortcut: false,
|
||||
};
|
||||
Self {
|
||||
musicgen: Config::musicgen_small(),
|
||||
t5: t5::Config::musicgen_small(),
|
||||
encodec: encodec_model::Config::musicgen_small(),
|
||||
encodec,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -401,8 +427,7 @@ impl MusicgenForConditionalGeneration {
|
||||
|
||||
pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> {
|
||||
let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
|
||||
let audio_encoder =
|
||||
encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?;
|
||||
let audio_encoder = encodec::Model::new(&cfg.encodec, vb.pp("audio_encoder"))?;
|
||||
let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?;
|
||||
Ok(Self {
|
||||
text_encoder,
|
||||
|
@ -1,20 +0,0 @@
|
||||
use candle::Result;
|
||||
use candle_nn::{Conv1d, Conv1dConfig, VarBuilder};
|
||||
|
||||
// Applies weight norm for inference by recomputing the weight tensor. This
|
||||
// does not apply to training.
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
|
||||
pub fn conv1d_weight_norm(
|
||||
in_c: usize,
|
||||
out_c: usize,
|
||||
kernel_size: usize,
|
||||
config: Conv1dConfig,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Conv1d> {
|
||||
let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
|
||||
let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
|
||||
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
|
||||
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
|
||||
let bias = vb.get(out_c, "bias")?;
|
||||
Ok(Conv1d::new(weight, Some(bias), config))
|
||||
}
|
Reference in New Issue
Block a user