Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00).
Tweaks to run metavoice on metal (#1792)
* Enable tanh + tweak conv-transpose.
* Run the encodec decoding on cpu.
* Clippy fixes.
This commit is contained in:
@ -738,6 +738,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
("ufloor", DType::F32) => strided::floor::FLOAT,
|
("ufloor", DType::F32) => strided::floor::FLOAT,
|
||||||
("urelu", DType::F32) => strided::relu::FLOAT,
|
("urelu", DType::F32) => strided::relu::FLOAT,
|
||||||
("uround", DType::F32) => strided::round::FLOAT,
|
("uround", DType::F32) => strided::round::FLOAT,
|
||||||
|
("utanh", DType::F32) => strided::tanh::FLOAT,
|
||||||
("ucos", DType::F16) => strided::cos::HALF,
|
("ucos", DType::F16) => strided::cos::HALF,
|
||||||
("usin", DType::F16) => strided::sin::HALF,
|
("usin", DType::F16) => strided::sin::HALF,
|
||||||
("usqr", DType::F16) => strided::sqr::HALF,
|
("usqr", DType::F16) => strided::sqr::HALF,
|
||||||
@ -754,6 +755,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
("ufloor", DType::F16) => strided::floor::HALF,
|
("ufloor", DType::F16) => strided::floor::HALF,
|
||||||
("urelu", DType::F16) => strided::relu::HALF,
|
("urelu", DType::F16) => strided::relu::HALF,
|
||||||
("uround", DType::F16) => strided::round::HALF,
|
("uround", DType::F16) => strided::round::HALF,
|
||||||
|
("utanh", DType::F16) => strided::tanh::HALF,
|
||||||
(name, dtype) => {
|
(name, dtype) => {
|
||||||
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
|
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
|
||||||
}
|
}
|
||||||
|
@ -352,6 +352,10 @@ impl Storage {
|
|||||||
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
|
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
|
||||||
Ok(Self::Cuda(s))
|
Ok(Self::Cuda(s))
|
||||||
}
|
}
|
||||||
|
(Storage::Metal(inp), Storage::Metal(kernel)) => {
|
||||||
|
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
|
||||||
|
Ok(Self::Metal(s))
|
||||||
|
}
|
||||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||||
lhs: lhs.device().location(),
|
lhs: lhs.device().location(),
|
||||||
rhs: rhs.device().location(),
|
rhs: rhs.device().location(),
|
||||||
|
@ -131,8 +131,14 @@ fn main() -> Result<()> {
|
|||||||
let second_stage_config = gpt::Config::cfg1b_v0_1();
|
let second_stage_config = gpt::Config::cfg1b_v0_1();
|
||||||
let second_stage_model = gpt::Model::new(second_stage_config, second_stage_vb)?;
|
let second_stage_model = gpt::Model::new(second_stage_config, second_stage_vb)?;
|
||||||
|
|
||||||
let encodec_vb =
|
let encodec_device = if device.is_metal() {
|
||||||
unsafe { VarBuilder::from_mmaped_safetensors(&[encodec_weights], DType::F32, &device)? };
|
&candle::Device::Cpu
|
||||||
|
} else {
|
||||||
|
&device
|
||||||
|
};
|
||||||
|
let encodec_vb = unsafe {
|
||||||
|
VarBuilder::from_mmaped_safetensors(&[encodec_weights], DType::F32, encodec_device)?
|
||||||
|
};
|
||||||
let encodec_config = encodec::Config::default();
|
let encodec_config = encodec::Config::default();
|
||||||
let encodec_model = encodec::Model::new(&encodec_config, encodec_vb)?;
|
let encodec_model = encodec::Model::new(&encodec_config, encodec_vb)?;
|
||||||
|
|
||||||
@ -144,11 +150,12 @@ fn main() -> Result<()> {
|
|||||||
Some(w) => std::path::PathBuf::from(w),
|
Some(w) => std::path::PathBuf::from(w),
|
||||||
None => repo.get("spk_emb.safetensors")?,
|
None => repo.get("spk_emb.safetensors")?,
|
||||||
};
|
};
|
||||||
let spk_emb = candle::safetensors::load(&spk_emb_file, &device)?;
|
let spk_emb = candle::safetensors::load(&spk_emb_file, &candle::Device::Cpu)?;
|
||||||
let spk_emb = match spk_emb.get("spk_emb") {
|
let spk_emb = match spk_emb.get("spk_emb") {
|
||||||
None => anyhow::bail!("missing spk_emb tensor in {spk_emb_file:?}"),
|
None => anyhow::bail!("missing spk_emb tensor in {spk_emb_file:?}"),
|
||||||
Some(spk_emb) => spk_emb.to_dtype(DType::F32)?,
|
Some(spk_emb) => spk_emb.to_dtype(DType::F32)?,
|
||||||
};
|
};
|
||||||
|
let spk_emb = spk_emb.to_device(&device)?;
|
||||||
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), None);
|
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), None);
|
||||||
|
|
||||||
// First stage generation.
|
// First stage generation.
|
||||||
@ -210,7 +217,7 @@ fn main() -> Result<()> {
|
|||||||
let codes = codes.i(0)?.to_vec2::<u32>()?;
|
let codes = codes.i(0)?.to_vec2::<u32>()?;
|
||||||
let (text_ids, audio_ids) = tilted_encodec.decode(&codes);
|
let (text_ids, audio_ids) = tilted_encodec.decode(&codes);
|
||||||
println!("text_ids len: {:?}", text_ids.len());
|
println!("text_ids len: {:?}", text_ids.len());
|
||||||
let audio_ids = Tensor::new(audio_ids, &device)?.unsqueeze(0)?;
|
let audio_ids = Tensor::new(audio_ids, encodec_device)?.unsqueeze(0)?;
|
||||||
println!("audio_ids shape: {:?}", audio_ids.shape());
|
println!("audio_ids shape: {:?}", audio_ids.shape());
|
||||||
let pcm = encodec_model.decode(&audio_ids)?;
|
let pcm = encodec_model.decode(&audio_ids)?;
|
||||||
println!("output pcm shape: {:?}", pcm.shape());
|
println!("output pcm shape: {:?}", pcm.shape());
|
||||||
|
Reference in New Issue
Block a user