mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Support more modes in the encodec example. (#1777)
* Support more modes in the encodec example. * Remove the old encodec model from the musicgen bits.
This commit is contained in:
@ -95,3 +95,9 @@ required-features = ["candle-datasets"]
|
|||||||
[[example]]
|
[[example]]
|
||||||
name = "llama2-c"
|
name = "llama2-c"
|
||||||
required-features = ["candle-datasets"]
|
required-features = ["candle-datasets"]
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "encodec"
|
||||||
|
required-features = ["symphonia"]
|
||||||
|
|
||||||
|
|
||||||
|
20
candle-examples/examples/encodec/README.md
Normal file
20
candle-examples/examples/encodec/README.md
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
# candle-encodec
|
||||||
|
|
||||||
|
[EnCodec](https://huggingface.co/facebook/encodec_24khz) is a high-quality audio
|
||||||
|
compression model using an encoder/decoder architecture with residual vector
|
||||||
|
quantization.
|
||||||
|
|
||||||
|
## Running one example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --example encodec --features symphonia --release -- code-to-audio \
|
||||||
|
candle-examples/examples/encodec/jfk-codes.safetensors \
|
||||||
|
jfk.wav
|
||||||
|
```
|
||||||
|
|
||||||
|
This decodes the EnCodec tokens stored in `jfk-codes.safetensors` and generates
|
||||||
|
an output wav file containing the audio data. Instead of `code-to-audio` one
|
||||||
|
can use:
|
||||||
|
- `audio-to-audio in.mp3 out.wav`: encodes the input audio file then decodes it to a wav file.
|
||||||
|
- `audio-to-code in.mp3 out.safetensors`: generates a safetensors file
|
||||||
|
containing EnCodec tokens for the input audio file.
|
@ -5,15 +5,85 @@ extern crate intel_mkl_src;
|
|||||||
extern crate accelerate_src;
|
extern crate accelerate_src;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use candle::{DType, IndexOp};
|
use candle::{DType, IndexOp, Tensor};
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
use candle_transformers::models::encodec::{Config, Model};
|
use candle_transformers::models::encodec::{Config, Model};
|
||||||
use clap::Parser;
|
use clap::{Parser, ValueEnum};
|
||||||
use hf_hub::api::sync::Api;
|
use hf_hub::api::sync::Api;
|
||||||
|
|
||||||
|
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
|
||||||
|
where
|
||||||
|
T: symphonia::core::sample::Sample,
|
||||||
|
f32: symphonia::core::conv::FromSample<T>,
|
||||||
|
{
|
||||||
|
use symphonia::core::audio::Signal;
|
||||||
|
use symphonia::core::conv::FromSample;
|
||||||
|
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> {
|
||||||
|
use symphonia::core::audio::{AudioBufferRef, Signal};
|
||||||
|
|
||||||
|
let src = std::fs::File::open(path)?;
|
||||||
|
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
|
||||||
|
let hint = symphonia::core::probe::Hint::new();
|
||||||
|
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
|
||||||
|
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
|
||||||
|
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
|
||||||
|
let mut format = probed.format;
|
||||||
|
let track = format
|
||||||
|
.tracks()
|
||||||
|
.iter()
|
||||||
|
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
|
||||||
|
.expect("no supported audio tracks");
|
||||||
|
let mut decoder = symphonia::default::get_codecs()
|
||||||
|
.make(&track.codec_params, &Default::default())
|
||||||
|
.expect("unsupported codec");
|
||||||
|
let track_id = track.id;
|
||||||
|
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
|
||||||
|
let mut pcm_data = Vec::new();
|
||||||
|
while let Ok(packet) = format.next_packet() {
|
||||||
|
while !format.metadata().is_latest() {
|
||||||
|
format.metadata().pop();
|
||||||
|
}
|
||||||
|
if packet.track_id() != track_id {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
match decoder.decode(&packet)? {
|
||||||
|
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
|
||||||
|
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok((pcm_data, sample_rate))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||||
|
enum Action {
|
||||||
|
AudioToAudio,
|
||||||
|
AudioToCode,
|
||||||
|
CodeToAudio,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
#[command(author, version, about, long_about = None)]
|
#[command(author, version, about, long_about = None)]
|
||||||
struct Args {
|
struct Args {
|
||||||
|
/// The action to be performed, specifies the format for the input and output data.
|
||||||
|
action: Action,
|
||||||
|
|
||||||
|
/// The input file, either an audio file or some encodec tokens stored as safetensors.
|
||||||
|
in_file: String,
|
||||||
|
|
||||||
|
/// The output file, either a wave audio file or some encodec tokens stored as safetensors.
|
||||||
|
out_file: String,
|
||||||
|
|
||||||
/// Run on CPU rather than on GPU.
|
/// Run on CPU rather than on GPU.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
cpu: bool,
|
cpu: bool,
|
||||||
@ -21,18 +91,6 @@ struct Args {
|
|||||||
/// The model weight file, in safetensor format.
|
/// The model weight file, in safetensor format.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
model: Option<String>,
|
model: Option<String>,
|
||||||
|
|
||||||
/// Input file as a safetensors containing the encodec tokens.
|
|
||||||
#[arg(long)]
|
|
||||||
code_file: String,
|
|
||||||
|
|
||||||
/// Output file that will be generated in wav format.
|
|
||||||
#[arg(long)]
|
|
||||||
out: String,
|
|
||||||
|
|
||||||
/// Do another step of encoding the PCM data and decoding the resulting codes.
|
|
||||||
#[arg(long)]
|
|
||||||
roundtrip: bool,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
@ -48,25 +106,36 @@ fn main() -> Result<()> {
|
|||||||
let config = Config::default();
|
let config = Config::default();
|
||||||
let model = Model::new(&config, vb)?;
|
let model = Model::new(&config, vb)?;
|
||||||
|
|
||||||
let codes = candle::safetensors::load(args.code_file, &device)?;
|
let codes = match args.action {
|
||||||
let codes = codes.get("codes").expect("no codes in input file").i(0)?;
|
Action::CodeToAudio => {
|
||||||
println!("codes shape: {:?}", codes.shape());
|
let codes = candle::safetensors::load(args.in_file, &device)?;
|
||||||
let pcm = model.decode(&codes)?;
|
let codes = codes.get("codes").expect("no codes in input file").i(0)?;
|
||||||
println!("pcm shape: {:?}", pcm.shape());
|
codes
|
||||||
|
}
|
||||||
let pcm = if args.roundtrip {
|
Action::AudioToCode | Action::AudioToAudio => {
|
||||||
let codes = model.encode(&pcm)?;
|
let (pcm, sample_rate) = pcm_decode(args.in_file)?;
|
||||||
println!("second step codes shape: {:?}", pcm.shape());
|
if sample_rate != 24_000 {
|
||||||
let pcm = model.decode(&codes)?;
|
println!("WARNING: encodec uses a 24khz sample rate, input uses {sample_rate}")
|
||||||
println!("second step pcm shape: {:?}", pcm.shape());
|
}
|
||||||
pcm
|
let pcm_len = pcm.len();
|
||||||
} else {
|
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
|
||||||
pcm
|
println!("input pcm shape: {:?}", pcm.shape());
|
||||||
|
model.encode(&pcm)?
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
println!("codes shape: {:?}", codes.shape());
|
||||||
|
|
||||||
let pcm = pcm.i(0)?.i(0)?.to_vec1::<f32>()?;
|
match args.action {
|
||||||
let mut output = std::fs::File::create(&args.out)?;
|
Action::AudioToCode => {
|
||||||
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
|
codes.save_safetensors("codes", &args.out_file)?;
|
||||||
|
}
|
||||||
|
Action::AudioToAudio | Action::CodeToAudio => {
|
||||||
|
let pcm = model.decode(&codes)?;
|
||||||
|
println!("output pcm shape: {:?}", pcm.shape());
|
||||||
|
let pcm = pcm.i(0)?.i(0)?.to_vec1::<f32>()?;
|
||||||
|
let mut output = std::fs::File::create(&args.out_file)?;
|
||||||
|
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1,580 +0,0 @@
|
|||||||
use crate::nn::conv1d_weight_norm;
|
|
||||||
use candle::{DType, IndexOp, Module, Result, Tensor};
|
|
||||||
use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
|
|
||||||
|
|
||||||
// Encodec Model
|
|
||||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
|
||||||
enum NormType {
|
|
||||||
WeightNorm,
|
|
||||||
TimeGroupNorm,
|
|
||||||
None,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
|
||||||
pub struct Config {
|
|
||||||
target_bandwidths: Vec<f64>,
|
|
||||||
sampling_rate: usize,
|
|
||||||
audio_channels: usize,
|
|
||||||
normalize: bool,
|
|
||||||
chunk_length_s: Option<usize>,
|
|
||||||
overlap: Option<usize>,
|
|
||||||
hidden_size: usize,
|
|
||||||
num_filters: usize,
|
|
||||||
num_residual_layers: usize,
|
|
||||||
upsampling_ratios: Vec<usize>,
|
|
||||||
norm_type: NormType,
|
|
||||||
kernel_size: usize,
|
|
||||||
last_kernel_size: usize,
|
|
||||||
residual_kernel_size: usize,
|
|
||||||
dilation_growth_rate: usize,
|
|
||||||
use_causal_conv: bool,
|
|
||||||
pad_mode: &'static str,
|
|
||||||
compress: usize,
|
|
||||||
num_lstm_layers: usize,
|
|
||||||
trim_right_ratio: f64,
|
|
||||||
codebook_size: usize,
|
|
||||||
codebook_dim: Option<usize>,
|
|
||||||
use_conv_shortcut: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for Config {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
target_bandwidths: vec![1.5, 3.0, 6.0, 12.0, 24.0],
|
|
||||||
sampling_rate: 24_000,
|
|
||||||
audio_channels: 1,
|
|
||||||
normalize: false,
|
|
||||||
chunk_length_s: None,
|
|
||||||
overlap: None,
|
|
||||||
hidden_size: 128,
|
|
||||||
num_filters: 32,
|
|
||||||
num_residual_layers: 1,
|
|
||||||
upsampling_ratios: vec![8, 5, 4, 2],
|
|
||||||
norm_type: NormType::WeightNorm,
|
|
||||||
kernel_size: 7,
|
|
||||||
last_kernel_size: 7,
|
|
||||||
residual_kernel_size: 3,
|
|
||||||
dilation_growth_rate: 2,
|
|
||||||
use_causal_conv: true,
|
|
||||||
pad_mode: "reflect",
|
|
||||||
compress: 2,
|
|
||||||
num_lstm_layers: 2,
|
|
||||||
trim_right_ratio: 1.0,
|
|
||||||
codebook_size: 1024,
|
|
||||||
codebook_dim: None,
|
|
||||||
use_conv_shortcut: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Config {
|
|
||||||
// https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L6
|
|
||||||
pub fn musicgen_small() -> Self {
|
|
||||||
Self {
|
|
||||||
audio_channels: 1,
|
|
||||||
chunk_length_s: None,
|
|
||||||
codebook_dim: Some(128),
|
|
||||||
codebook_size: 2048,
|
|
||||||
compress: 2,
|
|
||||||
dilation_growth_rate: 2,
|
|
||||||
hidden_size: 128,
|
|
||||||
kernel_size: 7,
|
|
||||||
last_kernel_size: 7,
|
|
||||||
norm_type: NormType::WeightNorm,
|
|
||||||
normalize: false,
|
|
||||||
num_filters: 64,
|
|
||||||
num_lstm_layers: 2,
|
|
||||||
num_residual_layers: 1,
|
|
||||||
overlap: None,
|
|
||||||
pad_mode: "reflect",
|
|
||||||
residual_kernel_size: 3,
|
|
||||||
sampling_rate: 32_000,
|
|
||||||
target_bandwidths: vec![2.2],
|
|
||||||
trim_right_ratio: 1.0,
|
|
||||||
upsampling_ratios: vec![8, 5, 4, 4],
|
|
||||||
use_causal_conv: false,
|
|
||||||
use_conv_shortcut: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn codebook_dim(&self) -> usize {
|
|
||||||
self.codebook_dim.unwrap_or(self.codebook_size)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn frame_rate(&self) -> usize {
|
|
||||||
let hop_length: usize = self.upsampling_ratios.iter().product();
|
|
||||||
(self.sampling_rate + hop_length - 1) / hop_length
|
|
||||||
}
|
|
||||||
|
|
||||||
fn num_quantizers(&self) -> usize {
|
|
||||||
let num = 1000f64
|
|
||||||
* self
|
|
||||||
.target_bandwidths
|
|
||||||
.last()
|
|
||||||
.expect("empty target_bandwidths");
|
|
||||||
(num as usize) / (self.frame_rate() * 10)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L340
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct EncodecEuclideanCodebook {
|
|
||||||
inited: Tensor,
|
|
||||||
cluster_size: Tensor,
|
|
||||||
embed: Tensor,
|
|
||||||
embed_avg: Tensor,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl EncodecEuclideanCodebook {
|
|
||||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
|
||||||
let inited = vb.get(1, "inited")?;
|
|
||||||
let cluster_size = vb.get(cfg.codebook_size, "cluster_size")?;
|
|
||||||
let e_shape = (cfg.codebook_size, cfg.codebook_dim());
|
|
||||||
let embed = vb.get(e_shape, "embed")?;
|
|
||||||
let embed_avg = vb.get(e_shape, "embed_avg")?;
|
|
||||||
Ok(Self {
|
|
||||||
inited,
|
|
||||||
cluster_size,
|
|
||||||
embed,
|
|
||||||
embed_avg,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
|
|
||||||
let quantize = self.embed.embedding(embed_ind)?;
|
|
||||||
Ok(quantize)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct EncodecVectorQuantization {
|
|
||||||
codebook: EncodecEuclideanCodebook,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl EncodecVectorQuantization {
|
|
||||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
|
||||||
let codebook = EncodecEuclideanCodebook::load(vb.pp("codebook"), cfg)?;
|
|
||||||
Ok(Self { codebook })
|
|
||||||
}
|
|
||||||
|
|
||||||
fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
|
|
||||||
let quantize = self.codebook.decode(embed_ind)?;
|
|
||||||
let quantize = quantize.transpose(1, 2)?;
|
|
||||||
Ok(quantize)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct EncodecResidualVectorQuantizer {
|
|
||||||
layers: Vec<EncodecVectorQuantization>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl EncodecResidualVectorQuantizer {
|
|
||||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
|
||||||
let vb = &vb.pp("layers");
|
|
||||||
let layers = (0..cfg.num_quantizers())
|
|
||||||
.map(|i| EncodecVectorQuantization::load(vb.pp(&i.to_string()), cfg))
|
|
||||||
.collect::<Result<Vec<_>>>()?;
|
|
||||||
Ok(Self { layers })
|
|
||||||
}
|
|
||||||
|
|
||||||
fn decode(&self, codes: &Tensor) -> Result<Tensor> {
|
|
||||||
let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?;
|
|
||||||
if codes.dim(0)? != self.layers.len() {
|
|
||||||
candle::bail!(
|
|
||||||
"codes shape {:?} does not match the number of quantization layers {}",
|
|
||||||
codes.shape(),
|
|
||||||
self.layers.len()
|
|
||||||
)
|
|
||||||
}
|
|
||||||
for (i, layer) in self.layers.iter().enumerate() {
|
|
||||||
let quantized = layer.decode(&codes.i(i)?)?;
|
|
||||||
quantized_out = quantized.broadcast_add(&quantized_out)?;
|
|
||||||
}
|
|
||||||
Ok(quantized_out)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct EncodecLSTM {
|
|
||||||
layers: Vec<candle_nn::LSTM>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl EncodecLSTM {
|
|
||||||
fn load(dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
|
||||||
let vb = &vb.pp("lstm");
|
|
||||||
let mut layers = vec![];
|
|
||||||
for layer_idx in 0..cfg.num_lstm_layers {
|
|
||||||
let config = candle_nn::LSTMConfig {
|
|
||||||
layer_idx,
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
let lstm = candle_nn::lstm(dim, dim, config, vb.clone())?;
|
|
||||||
layers.push(lstm)
|
|
||||||
}
|
|
||||||
Ok(Self { layers })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Module for EncodecLSTM {
|
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
|
||||||
use candle_nn::RNN;
|
|
||||||
let mut xs = xs.clone();
|
|
||||||
for layer in self.layers.iter() {
|
|
||||||
let states = layer.seq(&xs)?;
|
|
||||||
xs = layer.states_to_tensor(&states)?;
|
|
||||||
}
|
|
||||||
Ok(xs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct EncodecConvTranspose1d {
|
|
||||||
weight_g: Tensor,
|
|
||||||
weight_v: Tensor,
|
|
||||||
bias: Tensor,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl EncodecConvTranspose1d {
|
|
||||||
fn load(
|
|
||||||
in_c: usize,
|
|
||||||
out_c: usize,
|
|
||||||
k: usize,
|
|
||||||
_stride: usize,
|
|
||||||
vb: VarBuilder,
|
|
||||||
_cfg: &Config,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let vb = &vb.pp("conv");
|
|
||||||
let weight_g = vb.get((in_c, 1, 1), "weight_g")?;
|
|
||||||
let weight_v = vb.get((in_c, out_c, k), "weight_v")?;
|
|
||||||
let bias = vb.get(out_c, "bias")?;
|
|
||||||
Ok(Self {
|
|
||||||
weight_g,
|
|
||||||
weight_v,
|
|
||||||
bias,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Module for EncodecConvTranspose1d {
|
|
||||||
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct EncodecConv1d {
|
|
||||||
causal: bool,
|
|
||||||
conv: Conv1d,
|
|
||||||
norm: Option<candle_nn::GroupNorm>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl EncodecConv1d {
|
|
||||||
fn load(
|
|
||||||
in_c: usize,
|
|
||||||
out_c: usize,
|
|
||||||
kernel_size: usize,
|
|
||||||
stride: usize,
|
|
||||||
vb: VarBuilder,
|
|
||||||
cfg: &Config,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let conv = match cfg.norm_type {
|
|
||||||
NormType::WeightNorm => conv1d_weight_norm(
|
|
||||||
in_c,
|
|
||||||
out_c,
|
|
||||||
kernel_size,
|
|
||||||
Conv1dConfig {
|
|
||||||
padding: 0,
|
|
||||||
stride,
|
|
||||||
groups: 1,
|
|
||||||
dilation: 1,
|
|
||||||
},
|
|
||||||
vb.pp("conv"),
|
|
||||||
)?,
|
|
||||||
NormType::None | NormType::TimeGroupNorm => conv1d(
|
|
||||||
in_c,
|
|
||||||
out_c,
|
|
||||||
kernel_size,
|
|
||||||
Conv1dConfig {
|
|
||||||
padding: 0,
|
|
||||||
stride,
|
|
||||||
groups: 1,
|
|
||||||
dilation: 1,
|
|
||||||
},
|
|
||||||
vb.pp("conv"),
|
|
||||||
)?,
|
|
||||||
};
|
|
||||||
let norm = match cfg.norm_type {
|
|
||||||
NormType::None | NormType::WeightNorm => None,
|
|
||||||
NormType::TimeGroupNorm => {
|
|
||||||
let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
|
|
||||||
Some(gn)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
Ok(Self {
|
|
||||||
causal: cfg.use_causal_conv,
|
|
||||||
conv,
|
|
||||||
norm,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Module for EncodecConv1d {
|
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
|
||||||
// TODO: padding, depending on causal.
|
|
||||||
let xs = self.conv.forward(xs)?;
|
|
||||||
match &self.norm {
|
|
||||||
None => Ok(xs),
|
|
||||||
Some(norm) => xs.apply(norm),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct EncodecResnetBlock {
|
|
||||||
block_conv1: EncodecConv1d,
|
|
||||||
block_conv2: EncodecConv1d,
|
|
||||||
shortcut: Option<EncodecConv1d>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl EncodecResnetBlock {
|
|
||||||
fn load(dim: usize, dilations: &[usize], vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
|
||||||
let h = dim / cfg.compress;
|
|
||||||
let mut layer = Layer::new(vb.pp("block"));
|
|
||||||
if dilations.len() != 2 {
|
|
||||||
candle::bail!("expected dilations of size 2")
|
|
||||||
}
|
|
||||||
// TODO: Apply dilations!
|
|
||||||
layer.inc();
|
|
||||||
let block_conv1 =
|
|
||||||
EncodecConv1d::load(dim, h, cfg.residual_kernel_size, 1, layer.next(), cfg)?;
|
|
||||||
layer.inc();
|
|
||||||
let block_conv2 = EncodecConv1d::load(h, dim, 1, 1, layer.next(), cfg)?;
|
|
||||||
let shortcut = if cfg.use_conv_shortcut {
|
|
||||||
let conv = EncodecConv1d::load(dim, dim, 1, 1, vb.pp("shortcut"), cfg)?;
|
|
||||||
Some(conv)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
Ok(Self {
|
|
||||||
block_conv1,
|
|
||||||
block_conv2,
|
|
||||||
shortcut,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Module for EncodecResnetBlock {
|
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
|
||||||
let residual = xs.clone();
|
|
||||||
let xs = xs.elu(1.)?;
|
|
||||||
let xs = self.block_conv1.forward(&xs)?;
|
|
||||||
let xs = xs.elu(1.)?;
|
|
||||||
let xs = self.block_conv2.forward(&xs)?;
|
|
||||||
let xs = match &self.shortcut {
|
|
||||||
None => (xs + residual)?,
|
|
||||||
Some(shortcut) => xs.add(&shortcut.forward(&residual)?)?,
|
|
||||||
};
|
|
||||||
Ok(xs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct Layer<'a> {
|
|
||||||
vb: VarBuilder<'a>,
|
|
||||||
cnt: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a> Layer<'a> {
|
|
||||||
fn new(vb: VarBuilder<'a>) -> Self {
|
|
||||||
Self { vb, cnt: 0 }
|
|
||||||
}
|
|
||||||
|
|
||||||
fn inc(&mut self) {
|
|
||||||
self.cnt += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn next(&mut self) -> VarBuilder {
|
|
||||||
let vb = self.vb.pp(&self.cnt.to_string());
|
|
||||||
self.cnt += 1;
|
|
||||||
vb
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct EncodecEncoder {
|
|
||||||
init_conv: EncodecConv1d,
|
|
||||||
sampling_layers: Vec<(Vec<EncodecResnetBlock>, EncodecConv1d)>,
|
|
||||||
final_lstm: EncodecLSTM,
|
|
||||||
final_conv: EncodecConv1d,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl EncodecEncoder {
|
|
||||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
|
||||||
let mut layer = Layer::new(vb.pp("layers"));
|
|
||||||
let init_conv = EncodecConv1d::load(
|
|
||||||
cfg.audio_channels,
|
|
||||||
cfg.num_filters,
|
|
||||||
cfg.kernel_size,
|
|
||||||
1,
|
|
||||||
layer.next(),
|
|
||||||
cfg,
|
|
||||||
)?;
|
|
||||||
let mut sampling_layers = vec![];
|
|
||||||
let mut scaling = 1;
|
|
||||||
for &ratio in cfg.upsampling_ratios.iter().rev() {
|
|
||||||
let current_scale = scaling * cfg.num_filters;
|
|
||||||
let mut resnets = vec![];
|
|
||||||
for j in 0..(cfg.num_residual_layers as u32) {
|
|
||||||
let resnet = EncodecResnetBlock::load(
|
|
||||||
current_scale,
|
|
||||||
&[cfg.dilation_growth_rate.pow(j), 1],
|
|
||||||
layer.next(),
|
|
||||||
cfg,
|
|
||||||
)?;
|
|
||||||
resnets.push(resnet)
|
|
||||||
}
|
|
||||||
layer.inc(); // ELU
|
|
||||||
let conv1d = EncodecConv1d::load(
|
|
||||||
current_scale,
|
|
||||||
current_scale * 2,
|
|
||||||
ratio * 2,
|
|
||||||
ratio,
|
|
||||||
layer.next(),
|
|
||||||
cfg,
|
|
||||||
)?;
|
|
||||||
sampling_layers.push((resnets, conv1d));
|
|
||||||
scaling *= 2;
|
|
||||||
}
|
|
||||||
let final_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?;
|
|
||||||
layer.inc(); // ELU
|
|
||||||
let final_conv = EncodecConv1d::load(
|
|
||||||
cfg.num_filters * scaling,
|
|
||||||
cfg.hidden_size,
|
|
||||||
cfg.last_kernel_size,
|
|
||||||
1,
|
|
||||||
layer.next(),
|
|
||||||
cfg,
|
|
||||||
)?;
|
|
||||||
Ok(Self {
|
|
||||||
init_conv,
|
|
||||||
sampling_layers,
|
|
||||||
final_conv,
|
|
||||||
final_lstm,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
|
||||||
let mut xs = xs.apply(&self.init_conv)?;
|
|
||||||
for (resnets, conv) in self.sampling_layers.iter() {
|
|
||||||
for resnet in resnets.iter() {
|
|
||||||
xs = xs.apply(resnet)?;
|
|
||||||
}
|
|
||||||
xs = xs.elu(1.0)?.apply(conv)?;
|
|
||||||
}
|
|
||||||
xs.apply(&self.final_lstm)?
|
|
||||||
.elu(1.0)?
|
|
||||||
.apply(&self.final_conv)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct EncodecDecoder {
|
|
||||||
init_conv: EncodecConv1d,
|
|
||||||
init_lstm: EncodecLSTM,
|
|
||||||
sampling_layers: Vec<(EncodecConvTranspose1d, Vec<EncodecResnetBlock>)>,
|
|
||||||
final_conv: EncodecConv1d,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl EncodecDecoder {
|
|
||||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
|
||||||
let mut layer = Layer::new(vb.pp("layers"));
|
|
||||||
let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32);
|
|
||||||
let init_conv = EncodecConv1d::load(
|
|
||||||
cfg.hidden_size,
|
|
||||||
cfg.num_filters * scaling,
|
|
||||||
cfg.last_kernel_size,
|
|
||||||
1,
|
|
||||||
layer.next(),
|
|
||||||
cfg,
|
|
||||||
)?;
|
|
||||||
let init_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?;
|
|
||||||
let mut sampling_layers = vec![];
|
|
||||||
for &ratio in cfg.upsampling_ratios.iter() {
|
|
||||||
let current_scale = scaling * cfg.num_filters;
|
|
||||||
layer.inc(); // ELU
|
|
||||||
let conv1d = EncodecConvTranspose1d::load(
|
|
||||||
current_scale,
|
|
||||||
current_scale / 2,
|
|
||||||
ratio * 2,
|
|
||||||
ratio,
|
|
||||||
layer.next(),
|
|
||||||
cfg,
|
|
||||||
)?;
|
|
||||||
let mut resnets = vec![];
|
|
||||||
for j in 0..(cfg.num_residual_layers as u32) {
|
|
||||||
let resnet = EncodecResnetBlock::load(
|
|
||||||
current_scale / 2,
|
|
||||||
&[cfg.dilation_growth_rate.pow(j), 1],
|
|
||||||
layer.next(),
|
|
||||||
cfg,
|
|
||||||
)?;
|
|
||||||
resnets.push(resnet)
|
|
||||||
}
|
|
||||||
sampling_layers.push((conv1d, resnets));
|
|
||||||
scaling /= 2;
|
|
||||||
}
|
|
||||||
layer.inc(); // ELU
|
|
||||||
let final_conv = EncodecConv1d::load(
|
|
||||||
cfg.num_filters,
|
|
||||||
cfg.audio_channels,
|
|
||||||
cfg.last_kernel_size,
|
|
||||||
1,
|
|
||||||
layer.next(),
|
|
||||||
cfg,
|
|
||||||
)?;
|
|
||||||
Ok(Self {
|
|
||||||
init_conv,
|
|
||||||
init_lstm,
|
|
||||||
sampling_layers,
|
|
||||||
final_conv,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
|
||||||
let mut xs = xs.apply(&self.init_conv)?.apply(&self.init_lstm)?;
|
|
||||||
for (conv, resnets) in self.sampling_layers.iter() {
|
|
||||||
xs = xs.elu(1.)?.apply(conv)?;
|
|
||||||
for resnet in resnets.iter() {
|
|
||||||
xs = xs.apply(resnet)?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
xs.elu(1.)?.apply(&self.final_conv)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct EncodecModel {
|
|
||||||
encoder: EncodecEncoder,
|
|
||||||
decoder: EncodecDecoder,
|
|
||||||
quantizer: EncodecResidualVectorQuantizer,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl EncodecModel {
|
|
||||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
|
||||||
let encoder = EncodecEncoder::load(vb.pp("encoder"), cfg)?;
|
|
||||||
let decoder = EncodecDecoder::load(vb.pp("decoder"), cfg)?;
|
|
||||||
let quantizer = EncodecResidualVectorQuantizer::load(vb.pp("quantizer"), cfg)?;
|
|
||||||
Ok(Self {
|
|
||||||
encoder,
|
|
||||||
decoder,
|
|
||||||
quantizer,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
}
|
|
@ -10,9 +10,7 @@ extern crate intel_mkl_src;
|
|||||||
#[cfg(feature = "accelerate")]
|
#[cfg(feature = "accelerate")]
|
||||||
extern crate accelerate_src;
|
extern crate accelerate_src;
|
||||||
|
|
||||||
mod encodec_model;
|
|
||||||
mod musicgen_model;
|
mod musicgen_model;
|
||||||
mod nn;
|
|
||||||
|
|
||||||
use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
|
use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
|
||||||
|
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
use crate::encodec_model;
|
|
||||||
use candle::{DType, Device, Result, Tensor, D};
|
use candle::{DType, Device, Result, Tensor, D};
|
||||||
use candle_nn::{
|
use candle_nn::{
|
||||||
embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
|
embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
|
||||||
VarBuilder,
|
VarBuilder,
|
||||||
};
|
};
|
||||||
use candle_transformers::models::t5;
|
use candle_transformers::models::{encodec, t5};
|
||||||
|
|
||||||
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
|
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
@ -372,7 +371,7 @@ impl MusicgenForCausalLM {
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct MusicgenForConditionalGeneration {
|
pub struct MusicgenForConditionalGeneration {
|
||||||
pub text_encoder: t5::T5EncoderModel,
|
pub text_encoder: t5::T5EncoderModel,
|
||||||
pub audio_encoder: crate::encodec_model::EncodecModel,
|
pub audio_encoder: encodec::Model,
|
||||||
pub decoder: MusicgenForCausalLM,
|
pub decoder: MusicgenForCausalLM,
|
||||||
cfg: GenConfig,
|
cfg: GenConfig,
|
||||||
}
|
}
|
||||||
@ -381,15 +380,42 @@ pub struct MusicgenForConditionalGeneration {
|
|||||||
pub struct GenConfig {
|
pub struct GenConfig {
|
||||||
musicgen: Config,
|
musicgen: Config,
|
||||||
t5: t5::Config,
|
t5: t5::Config,
|
||||||
encodec: crate::encodec_model::Config,
|
encodec: encodec::Config,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GenConfig {
|
impl GenConfig {
|
||||||
pub fn small() -> Self {
|
pub fn small() -> Self {
|
||||||
|
// https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L6
|
||||||
|
let encodec = encodec::Config {
|
||||||
|
audio_channels: 1,
|
||||||
|
chunk_length_s: None,
|
||||||
|
codebook_dim: Some(128),
|
||||||
|
codebook_size: 2048,
|
||||||
|
compress: 2,
|
||||||
|
dilation_growth_rate: 2,
|
||||||
|
hidden_size: 128,
|
||||||
|
kernel_size: 7,
|
||||||
|
last_kernel_size: 7,
|
||||||
|
norm_type: encodec::NormType::WeightNorm,
|
||||||
|
normalize: false,
|
||||||
|
num_filters: 64,
|
||||||
|
num_lstm_layers: 2,
|
||||||
|
num_residual_layers: 1,
|
||||||
|
overlap: None,
|
||||||
|
// This should be Reflect and not Replicate but Reflect does not work yet.
|
||||||
|
pad_mode: encodec::PadMode::Replicate,
|
||||||
|
residual_kernel_size: 3,
|
||||||
|
sampling_rate: 32_000,
|
||||||
|
target_bandwidths: vec![2.2],
|
||||||
|
trim_right_ratio: 1.0,
|
||||||
|
upsampling_ratios: vec![8, 5, 4, 4],
|
||||||
|
use_causal_conv: false,
|
||||||
|
use_conv_shortcut: false,
|
||||||
|
};
|
||||||
Self {
|
Self {
|
||||||
musicgen: Config::musicgen_small(),
|
musicgen: Config::musicgen_small(),
|
||||||
t5: t5::Config::musicgen_small(),
|
t5: t5::Config::musicgen_small(),
|
||||||
encodec: encodec_model::Config::musicgen_small(),
|
encodec,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -401,8 +427,7 @@ impl MusicgenForConditionalGeneration {
|
|||||||
|
|
||||||
pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> {
|
pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> {
|
||||||
let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
|
let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
|
||||||
let audio_encoder =
|
let audio_encoder = encodec::Model::new(&cfg.encodec, vb.pp("audio_encoder"))?;
|
||||||
encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?;
|
|
||||||
let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?;
|
let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
text_encoder,
|
text_encoder,
|
||||||
|
@ -1,20 +0,0 @@
|
|||||||
use candle::Result;
|
|
||||||
use candle_nn::{Conv1d, Conv1dConfig, VarBuilder};
|
|
||||||
|
|
||||||
// Applies weight norm for inference by recomputing the weight tensor. This
|
|
||||||
// does not apply to training.
|
|
||||||
// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
|
|
||||||
pub fn conv1d_weight_norm(
|
|
||||||
in_c: usize,
|
|
||||||
out_c: usize,
|
|
||||||
kernel_size: usize,
|
|
||||||
config: Conv1dConfig,
|
|
||||||
vb: VarBuilder,
|
|
||||||
) -> Result<Conv1d> {
|
|
||||||
let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
|
|
||||||
let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
|
|
||||||
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
|
|
||||||
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
|
|
||||||
let bias = vb.get(out_c, "bias")?;
|
|
||||||
Ok(Conv1d::new(weight, Some(bias), config))
|
|
||||||
}
|
|
Reference in New Issue
Block a user