mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Add a flag to select the dtype used in metavoice. (#1805)
This commit is contained in:
@ -19,6 +19,13 @@ use rand::{distributions::Distribution, SeedableRng};
|
||||
|
||||
pub const ENCODEC_NTOKENS: u32 = 1024;
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||
enum ArgDType {
|
||||
F32,
|
||||
F16,
|
||||
Bf16,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
@ -67,6 +74,9 @@ struct Args {
|
||||
|
||||
#[arg(long)]
|
||||
spk_emb: Option<String>,
|
||||
|
||||
#[arg(long, default_value = "f32")]
|
||||
dtype: ArgDType,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
@ -120,15 +130,18 @@ fn main() -> Result<()> {
|
||||
.model("facebook/encodec_24khz".to_string())
|
||||
.get("model.safetensors")?,
|
||||
};
|
||||
let first_stage_vb = unsafe {
|
||||
VarBuilder::from_mmaped_safetensors(&[first_stage_weights], DType::F32, &device)?
|
||||
let dtype = match args.dtype {
|
||||
ArgDType::F32 => DType::F32,
|
||||
ArgDType::F16 => DType::F16,
|
||||
ArgDType::Bf16 => DType::BF16,
|
||||
};
|
||||
let first_stage_vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[first_stage_weights], dtype, &device)? };
|
||||
let first_stage_config = transformer::Config::cfg1b_v0_1();
|
||||
let mut first_stage_model = transformer::Model::new(&first_stage_config, first_stage_vb)?;
|
||||
|
||||
let second_stage_vb = unsafe {
|
||||
VarBuilder::from_mmaped_safetensors(&[second_stage_weights], DType::F32, &device)?
|
||||
};
|
||||
let second_stage_vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[second_stage_weights], dtype, &device)? };
|
||||
let second_stage_config = gpt::Config::cfg1b_v0_1();
|
||||
let second_stage_model = gpt::Model::new(second_stage_config.clone(), second_stage_vb)?;
|
||||
|
||||
@ -137,9 +150,8 @@ fn main() -> Result<()> {
|
||||
} else {
|
||||
&device
|
||||
};
|
||||
let encodec_vb = unsafe {
|
||||
VarBuilder::from_mmaped_safetensors(&[encodec_weights], DType::F32, encodec_device)?
|
||||
};
|
||||
let encodec_vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[encodec_weights], dtype, encodec_device)? };
|
||||
let encodec_config = encodec::Config::default();
|
||||
let encodec_model = encodec::Model::new(&encodec_config, encodec_vb)?;
|
||||
|
||||
@ -154,7 +166,7 @@ fn main() -> Result<()> {
|
||||
let spk_emb = candle::safetensors::load(&spk_emb_file, &candle::Device::Cpu)?;
|
||||
let spk_emb = match spk_emb.get("spk_emb") {
|
||||
None => anyhow::bail!("missing spk_emb tensor in {spk_emb_file:?}"),
|
||||
Some(spk_emb) => spk_emb.to_dtype(DType::F32)?,
|
||||
Some(spk_emb) => spk_emb.to_dtype(dtype)?,
|
||||
};
|
||||
let spk_emb = spk_emb.to_device(&device)?;
|
||||
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), Some(0.95));
|
||||
@ -228,7 +240,7 @@ fn main() -> Result<()> {
|
||||
println!("audio_ids shape: {:?}", audio_ids.shape());
|
||||
let pcm = encodec_model.decode(&audio_ids)?;
|
||||
println!("output pcm shape: {:?}", pcm.shape());
|
||||
let pcm = pcm.i(0)?.i(0)?;
|
||||
let pcm = pcm.i(0)?.i(0)?.to_dtype(DType::F32)?;
|
||||
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
|
||||
let pcm = pcm.to_vec1::<f32>()?;
|
||||
let mut output = std::fs::File::create(&args.out_file)?;
|
||||
|
Reference in New Issue
Block a user