mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Tweaks to run metavoice on metal (#1792)
* Enable tanh + tweak conv-transpose.
* Run the encodec decoding on cpu.
* Clippy fixes.
This commit is contained in:
@@ -131,8 +131,14 @@ fn main() -> Result<()> {
|
||||
let second_stage_config = gpt::Config::cfg1b_v0_1();
|
||||
let second_stage_model = gpt::Model::new(second_stage_config, second_stage_vb)?;
|
||||
|
||||
let encodec_vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[encodec_weights], DType::F32, &device)? };
|
||||
let encodec_device = if device.is_metal() {
|
||||
&candle::Device::Cpu
|
||||
} else {
|
||||
&device
|
||||
};
|
||||
let encodec_vb = unsafe {
|
||||
VarBuilder::from_mmaped_safetensors(&[encodec_weights], DType::F32, encodec_device)?
|
||||
};
|
||||
let encodec_config = encodec::Config::default();
|
||||
let encodec_model = encodec::Model::new(&encodec_config, encodec_vb)?;
|
||||
|
||||
@@ -144,11 +150,12 @@ fn main() -> Result<()> {
|
||||
Some(w) => std::path::PathBuf::from(w),
|
||||
None => repo.get("spk_emb.safetensors")?,
|
||||
};
|
||||
let spk_emb = candle::safetensors::load(&spk_emb_file, &device)?;
|
||||
let spk_emb = candle::safetensors::load(&spk_emb_file, &candle::Device::Cpu)?;
|
||||
let spk_emb = match spk_emb.get("spk_emb") {
|
||||
None => anyhow::bail!("missing spk_emb tensor in {spk_emb_file:?}"),
|
||||
Some(spk_emb) => spk_emb.to_dtype(DType::F32)?,
|
||||
};
|
||||
let spk_emb = spk_emb.to_device(&device)?;
|
||||
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), None);
|
||||
|
||||
// First stage generation.
|
||||
@@ -210,7 +217,7 @@ fn main() -> Result<()> {
|
||||
let codes = codes.i(0)?.to_vec2::<u32>()?;
|
||||
let (text_ids, audio_ids) = tilted_encodec.decode(&codes);
|
||||
println!("text_ids len: {:?}", text_ids.len());
|
||||
let audio_ids = Tensor::new(audio_ids, &device)?.unsqueeze(0)?;
|
||||
let audio_ids = Tensor::new(audio_ids, encodec_device)?.unsqueeze(0)?;
|
||||
println!("audio_ids shape: {:?}", audio_ids.shape());
|
||||
let pcm = encodec_model.decode(&audio_ids)?;
|
||||
println!("output pcm shape: {:?}", pcm.shape());
|
||||
|
Reference in New Issue
Block a user