mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 20:22:49 +00:00
TP sharding v2
This commit is contained in:
@ -15,6 +15,7 @@ candle = { path = "../../candle-core" }
|
||||
candle-nn = { path = "../../candle-nn" }
|
||||
num-traits = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||
safetensors = { workspace = true }
|
||||
|
||||
# App crates.
|
||||
anyhow = { workspace = true }
|
||||
|
@ -236,11 +236,11 @@ impl Decoder {
|
||||
let device = Device::Cpu;
|
||||
let tokenizer = Tokenizer::from_bytes(&md.tokenizer).map_err(anyhow::Error::msg)?;
|
||||
|
||||
let mel_filters = candle::safetensors::SafeTensors::deserialize(&md.mel_filters)?;
|
||||
let mel_filters = safetensors::tensor::SafeTensors::deserialize(&md.mel_filters)?;
|
||||
let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
|
||||
console_log!("loaded mel filters {:?}", mel_filters.shape());
|
||||
let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
|
||||
let weights = candle::safetensors::SafeTensors::deserialize(&md.weights)?;
|
||||
let weights = safetensors::tensor::SafeTensors::deserialize(&md.weights)?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
|
||||
let config = Config::tiny_en();
|
||||
let whisper = Whisper::load(&vb, config)?;
|
||||
|
Reference in New Issue
Block a user