Tweaks to run metavoice on metal (#1792)

* Enable tanh + tweak conv-transpose.

* Run the encodec decoding on cpu.

* Clippy fixes.
This commit is contained in:
Laurent Mazare
2024-03-03 07:46:44 +01:00
committed by GitHub
parent de11623752
commit 09e0148cce
3 changed files with 17 additions and 4 deletions

View File

@@ -131,8 +131,14 @@ fn main() -> Result<()> {
 let second_stage_config = gpt::Config::cfg1b_v0_1();
 let second_stage_model = gpt::Model::new(second_stage_config, second_stage_vb)?;
-let encodec_vb =
-unsafe { VarBuilder::from_mmaped_safetensors(&[encodec_weights], DType::F32, &device)? };
+let encodec_device = if device.is_metal() {
+&candle::Device::Cpu
+} else {
+&device
+};
+let encodec_vb = unsafe {
+VarBuilder::from_mmaped_safetensors(&[encodec_weights], DType::F32, encodec_device)?
+};
 let encodec_config = encodec::Config::default();
 let encodec_model = encodec::Model::new(&encodec_config, encodec_vb)?;
@@ -144,11 +150,12 @@ fn main() -> Result<()> {
 Some(w) => std::path::PathBuf::from(w),
 None => repo.get("spk_emb.safetensors")?,
 };
-let spk_emb = candle::safetensors::load(&spk_emb_file, &device)?;
+let spk_emb = candle::safetensors::load(&spk_emb_file, &candle::Device::Cpu)?;
 let spk_emb = match spk_emb.get("spk_emb") {
 None => anyhow::bail!("missing spk_emb tensor in {spk_emb_file:?}"),
 Some(spk_emb) => spk_emb.to_dtype(DType::F32)?,
 };
+let spk_emb = spk_emb.to_device(&device)?;
 let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), None);
 // First stage generation.
@@ -210,7 +217,7 @@ fn main() -> Result<()> {
 let codes = codes.i(0)?.to_vec2::<u32>()?;
 let (text_ids, audio_ids) = tilted_encodec.decode(&codes);
 println!("text_ids len: {:?}", text_ids.len());
-let audio_ids = Tensor::new(audio_ids, &device)?.unsqueeze(0)?;
+let audio_ids = Tensor::new(audio_ids, encodec_device)?.unsqueeze(0)?;
 println!("audio_ids shape: {:?}", audio_ids.shape());
 let pcm = encodec_model.decode(&audio_ids)?;
 println!("output pcm shape: {:?}", pcm.shape());