Encodec model. (#1771)

* Encodec model.

* Fixes.

* Add the padding functions.

* Get the LSTM bit to work.

* Get the encodec model to generate some tokens (decoder only for now).

* Minor tweak.

* Minor tweak.
This commit is contained in:
Laurent Mazare
2024-02-27 22:59:40 +01:00
committed by GitHub
parent 205767f9de
commit 0c49e95dfb
7 changed files with 834 additions and 1 deletions

Binary file not shown.

View File

@ -0,0 +1,57 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Result;
use candle::{DType, IndexOp};
use candle_nn::VarBuilder;
use candle_transformers::models::encodec::{Config, Model};
use clap::Parser;
use hf_hub::api::sync::Api;
// Command-line arguments for the encodec decoding example, parsed with clap's derive API.
// NOTE: the `///` comments on the fields below are used by clap as the `--help` text of
// each flag, so they are part of the program's output and are kept verbatim.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
// Optional: when absent, `main` downloads `model.safetensors` from the
// `facebook/encodec_24khz` hub repository instead.
/// The model weight file, in safetensor format.
#[arg(long)]
model: Option<String>,
// `main` loads this file and reads the tensor stored under the "codes" key.
// Non-optional `String`, so presumably a required flag — confirm against clap's rules.
/// Input file as a safetensors containing the encodec tokens.
#[arg(long)]
code_file: String,
// Path that `main` creates (or truncates) and writes the decoded audio to.
/// Output file that will be generated in wav format.
#[arg(long)]
out: String,
}
fn main() -> Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => Api::new()?
.model("facebook/encodec_24khz".to_string())
.get("model.safetensors")?,
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
let config = Config::default();
let model = Model::new(&config, vb)?;
let codes = candle::safetensors::load(args.code_file, &device)?;
let codes = codes.get("codes").expect("no codes in input file").i(0)?;
println!("codes shape: {:?}", codes.shape());
let pcm = model.decode(&codes)?;
let pcm = pcm.i(0)?.i(0)?.to_vec1::<f32>()?;
let mut output = std::fs::File::create(&args.out)?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
Ok(())
}