Pass directly the buffer ownership. (#949)

This commit is contained in:
Laurent Mazare
2023-09-24 06:34:44 +01:00
committed by GitHub
parent e32c89d90c
commit 7edd755756
5 changed files with 12 additions and 18 deletions

View File

@ -29,8 +29,7 @@ impl ModelConditionalGeneration {
console_error_panic_hook::set_once();
console_log!("loading model");
let device = &Device::Cpu;
let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, device);
let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, device)?;
let mut config: Config = serde_json::from_slice(&config)?;
let tokenizer =
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
@ -128,8 +127,7 @@ impl ModelEncoder {
console_error_panic_hook::set_once();
console_log!("loading model");
let device = &Device::Cpu;
let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, device);
let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, device)?;
let mut config: Config = serde_json::from_slice(&config)?;
config.use_cache = false;
let tokenizer =