From 7edd755756ca1b28ad73ddbc1ea7aa1b2c21938b Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 24 Sep 2023 06:34:44 +0100 Subject: [PATCH] Pass directly the buffer ownership. (#949) --- candle-wasm-examples/bert/src/bin/m.rs | 3 +-- candle-wasm-examples/t5/src/bin/m.rs | 6 ++---- candle-wasm-examples/whisper/src/worker.rs | 3 +-- candle-wasm-examples/yolo/src/bin/m.rs | 4 ++-- candle-wasm-examples/yolo/src/worker.rs | 14 ++++++-------- 5 files changed, 12 insertions(+), 18 deletions(-) diff --git a/candle-wasm-examples/bert/src/bin/m.rs b/candle-wasm-examples/bert/src/bin/m.rs index f5521abd..67d85e71 100644 --- a/candle-wasm-examples/bert/src/bin/m.rs +++ b/candle-wasm-examples/bert/src/bin/m.rs @@ -18,8 +18,7 @@ impl Model { console_error_panic_hook::set_once(); console_log!("loading model"); let device = &Device::Cpu; - let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?; - let vb = VarBuilder::from_safetensors(vec![weights], DType::F64, device); + let vb = VarBuilder::from_buffered_safetensors(weights, DType::F64, device)?; let config: Config = serde_json::from_slice(&config)?; let tokenizer = Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?; diff --git a/candle-wasm-examples/t5/src/bin/m.rs b/candle-wasm-examples/t5/src/bin/m.rs index c82e00cd..acb9e40a 100644 --- a/candle-wasm-examples/t5/src/bin/m.rs +++ b/candle-wasm-examples/t5/src/bin/m.rs @@ -29,8 +29,7 @@ impl ModelConditionalGeneration { console_error_panic_hook::set_once(); console_log!("loading model"); let device = &Device::Cpu; - let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?; - let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, device); + let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, device)?; let mut config: Config = serde_json::from_slice(&config)?; let tokenizer = Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?; @@ -128,8 +127,7 @@ impl ModelEncoder { console_error_panic_hook::set_once(); console_log!("loading model"); let device = &Device::Cpu; - let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?; - let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, device); + let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, device)?; let mut config: Config = serde_json::from_slice(&config)?; config.use_cache = false; let tokenizer = diff --git a/candle-wasm-examples/whisper/src/worker.rs b/candle-wasm-examples/whisper/src/worker.rs index 4fb2223a..6ea0954c 100644 --- a/candle-wasm-examples/whisper/src/worker.rs +++ b/candle-wasm-examples/whisper/src/worker.rs @@ -253,8 +253,7 @@ impl Decoder { let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?; console_log!("loaded mel filters {:?}", mel_filters.shape()); let mel_filters = mel_filters.flatten_all()?.to_vec1::()?; - let weights = safetensors::tensor::SafeTensors::deserialize(&md.weights)?; - let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device); + let vb = VarBuilder::from_buffered_safetensors(md.weights, DTYPE, &device)?; let config = Config::tiny_en(); let whisper = Whisper::load(&vb, config)?; console_log!("done loading model"); diff --git a/candle-wasm-examples/yolo/src/bin/m.rs b/candle-wasm-examples/yolo/src/bin/m.rs index 800188f6..77e427cc 100644 --- a/candle-wasm-examples/yolo/src/bin/m.rs +++ b/candle-wasm-examples/yolo/src/bin/m.rs @@ -13,7 +13,7 @@ pub struct Model { impl Model { #[wasm_bindgen(constructor)] pub fn new(data: Vec, model_size: &str) -> Result { - let inner = M::load_(&data, model_size)?; + let inner = M::load_(data, model_size)?; Ok(Self { inner }) } @@ -46,7 +46,7 @@ pub struct ModelPose { impl ModelPose { #[wasm_bindgen(constructor)] pub fn new(data: Vec, model_size: &str) -> Result { - let inner = P::load_(&data, model_size)?; + let inner = P::load_(data, model_size)?; Ok(Self { inner }) } diff --git a/candle-wasm-examples/yolo/src/worker.rs b/candle-wasm-examples/yolo/src/worker.rs index 11d41c53..5733a3fd 100644 --- a/candle-wasm-examples/yolo/src/worker.rs +++ b/candle-wasm-examples/yolo/src/worker.rs @@ -92,7 +92,7 @@ impl Model { Ok(bboxes) } - pub fn load_(weights: &[u8], model_size: &str) -> Result { + pub fn load_(weights: Vec, model_size: &str) -> Result { let multiples = match model_size { "n" => Multiples::n(), "s" => Multiples::s(), @@ -104,14 +104,13 @@ impl Model { ))?, }; let dev = &Device::Cpu; - let weights = safetensors::tensor::SafeTensors::deserialize(weights)?; - let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev); + let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, dev)?; let model = YoloV8::load(vb, multiples, 80)?; Ok(Self { model }) } pub fn load(md: ModelData) -> Result { - Self::load_(&md.weights, &md.model_size.to_string()) + Self::load_(md.weights, &md.model_size.to_string()) } } @@ -172,7 +171,7 @@ impl ModelPose { Ok(bboxes) } - pub fn load_(weights: &[u8], model_size: &str) -> Result { + pub fn load_(weights: Vec, model_size: &str) -> Result { let multiples = match model_size { "n" => Multiples::n(), "s" => Multiples::s(), @@ -184,14 +183,13 @@ impl ModelPose { ))?, }; let dev = &Device::Cpu; - let weights = safetensors::tensor::SafeTensors::deserialize(weights)?; - let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev); + let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, dev)?; let model = YoloV8Pose::load(vb, multiples, 1, (17, 3))?; Ok(Self { model }) } pub fn load(md: ModelData) -> Result { - Self::load_(&md.weights, &md.model_size.to_string()) + Self::load_(md.weights, &md.model_size.to_string()) } }