mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Support more snac variants. (#2871)
This commit is contained in:
@ -245,13 +245,14 @@ pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>
|
|||||||
Ok((pcm_data, sample_rate))
|
Ok((pcm_data, sample_rate))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn resample(pcm_in: &[f32], sr_in: usize, sr_out: usize) -> Result<Vec<f32>> {
|
pub(crate) fn resample(pcm_in: &[f32], sr_in: u32, sr_out: u32) -> Result<Vec<f32>> {
|
||||||
use rubato::Resampler;
|
use rubato::Resampler;
|
||||||
|
|
||||||
let mut pcm_out =
|
let mut pcm_out =
|
||||||
Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);
|
Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);
|
||||||
|
|
||||||
let mut resampler = rubato::FftFixedInOut::<f32>::new(sr_in, sr_out, 1024, 1)?;
|
let mut resampler =
|
||||||
|
rubato::FftFixedInOut::<f32>::new(sr_in as usize, sr_out as usize, 1024, 1)?;
|
||||||
let mut output_buffer = resampler.output_buffer_allocate(true);
|
let mut output_buffer = resampler.output_buffer_allocate(true);
|
||||||
let mut pos_in = 0;
|
let mut pos_in = 0;
|
||||||
while pos_in + resampler.input_frames_next() < pcm_in.len() {
|
while pos_in + resampler.input_frames_next() < pcm_in.len() {
|
||||||
|
@ -20,6 +20,42 @@ enum Action {
|
|||||||
CodeToAudio,
|
CodeToAudio,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||||
|
enum Which {
|
||||||
|
#[value(name = "24khz")]
|
||||||
|
S24khz,
|
||||||
|
#[value(name = "32khz")]
|
||||||
|
S32khz,
|
||||||
|
#[value(name = "44khz")]
|
||||||
|
S44khz,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Which {
|
||||||
|
fn sample_rate(&self) -> u32 {
|
||||||
|
match self {
|
||||||
|
Which::S24khz => 24000,
|
||||||
|
Which::S32khz => 32000,
|
||||||
|
Which::S44khz => 44000,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn config_repo(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Which::S24khz => "hubertsiuzdak/snac_24khz",
|
||||||
|
Which::S32khz => "hubertsiuzdak/snac_32khz",
|
||||||
|
Which::S44khz => "hubertsiuzdak/snac_44khz",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn model_file(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Which::S24khz => "snac_24khz.safetensors",
|
||||||
|
Which::S32khz => "snac_32khz.safetensors",
|
||||||
|
Which::S44khz => "snac_44khz.safetensors",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
#[command(author, version, about, long_about = None)]
|
#[command(author, version, about, long_about = None)]
|
||||||
struct Args {
|
struct Args {
|
||||||
@ -32,6 +68,10 @@ struct Args {
|
|||||||
/// The output file, either a wave audio file or some snac tokens stored as safetensors.
|
/// The output file, either a wave audio file or some snac tokens stored as safetensors.
|
||||||
out_file: String,
|
out_file: String,
|
||||||
|
|
||||||
|
/// The model size to use.
|
||||||
|
#[arg(long, default_value = "24khz")]
|
||||||
|
which: Which,
|
||||||
|
|
||||||
/// Run on CPU rather than on GPU.
|
/// Run on CPU rather than on GPU.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
cpu: bool,
|
cpu: bool,
|
||||||
@ -48,18 +88,19 @@ struct Args {
|
|||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
let model_sample_rate = args.which.sample_rate();
|
||||||
let config = match args.config {
|
let config = match args.config {
|
||||||
Some(c) => std::path::PathBuf::from(c),
|
Some(c) => std::path::PathBuf::from(c),
|
||||||
None => Api::new()?
|
None => Api::new()?
|
||||||
.model("hubertsiuzdak/snac_24khz".to_string())
|
.model(args.which.config_repo().to_string())
|
||||||
.get("config.json")?,
|
.get("config.json")?,
|
||||||
};
|
};
|
||||||
let config: Config = serde_json::from_slice(&std::fs::read(config)?)?;
|
let config: Config = serde_json::from_slice(&std::fs::read(config)?)?;
|
||||||
let model = match args.model {
|
let model = match args.model {
|
||||||
Some(model) => std::path::PathBuf::from(model),
|
Some(model) => std::path::PathBuf::from(model),
|
||||||
None => Api::new()?
|
None => Api::new()?
|
||||||
.model("lmz/candle_snac_24khz".to_string())
|
.model("lmz/candle-snac".to_string())
|
||||||
.get("model.safetensors")?,
|
.get(args.which.model_file())?,
|
||||||
};
|
};
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
|
||||||
let model = Model::new(&config, vb)?;
|
let model = Model::new(&config, vb)?;
|
||||||
@ -98,9 +139,9 @@ fn main() -> Result<()> {
|
|||||||
pcms.concat()
|
pcms.concat()
|
||||||
} else {
|
} else {
|
||||||
let (pcm, sample_rate) = audio_io::pcm_decode(args.in_file)?;
|
let (pcm, sample_rate) = audio_io::pcm_decode(args.in_file)?;
|
||||||
if sample_rate != 24_000 {
|
if sample_rate != model_sample_rate {
|
||||||
println!("WARNING: snac uses a 24khz sample rate, input uses {sample_rate}, resampling...");
|
println!("WARNING: snac uses a {model_sample_rate} sample rate, input uses {sample_rate}, resampling...");
|
||||||
audio_io::resample(&pcm, sample_rate as usize, 24_000)?
|
audio_io::resample(&pcm, sample_rate, model_sample_rate)?
|
||||||
} else {
|
} else {
|
||||||
pcm
|
pcm
|
||||||
}
|
}
|
||||||
@ -128,7 +169,7 @@ fn main() -> Result<()> {
|
|||||||
let pcm = model.decode(&codes)?;
|
let pcm = model.decode(&codes)?;
|
||||||
println!("output pcm shape: {:?}", pcm.shape());
|
println!("output pcm shape: {:?}", pcm.shape());
|
||||||
let pcm = pcm.i(0)?.i(0)?;
|
let pcm = pcm.i(0)?.i(0)?;
|
||||||
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
|
let pcm = candle_examples::audio::normalize_loudness(&pcm, model_sample_rate, true)?;
|
||||||
let pcm = pcm.to_vec1::<f32>()?;
|
let pcm = pcm.to_vec1::<f32>()?;
|
||||||
if args.out_file == "-" {
|
if args.out_file == "-" {
|
||||||
let (stream, ad) = audio_io::setup_output_stream()?;
|
let (stream, ad) = audio_io::setup_output_stream()?;
|
||||||
@ -148,7 +189,7 @@ fn main() -> Result<()> {
|
|||||||
drop(stream)
|
drop(stream)
|
||||||
} else {
|
} else {
|
||||||
let mut output = std::fs::File::create(&args.out_file)?;
|
let mut output = std::fs::File::create(&args.out_file)?;
|
||||||
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
|
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, model_sample_rate)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user