TP sharding v2

This commit is contained in:
Nicolas Patry
2023-07-21 15:10:51 +00:00
parent 209f06d7c3
commit 1735e4831e
9 changed files with 833 additions and 18 deletions

View File

@@ -236,11 +236,11 @@ impl Decoder {
let device = Device::Cpu;
let tokenizer = Tokenizer::from_bytes(&md.tokenizer).map_err(anyhow::Error::msg)?;
-let mel_filters = candle::safetensors::SafeTensors::deserialize(&md.mel_filters)?;
+let mel_filters = safetensors::tensor::SafeTensors::deserialize(&md.mel_filters)?;
let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
console_log!("loaded mel filters {:?}", mel_filters.shape());
let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
-let weights = candle::safetensors::SafeTensors::deserialize(&md.weights)?;
+let weights = safetensors::tensor::SafeTensors::deserialize(&md.weights)?;
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
let config = Config::tiny_en();
let whisper = Whisper::load(&vb, config)?;