mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Pass directly the buffer ownership. (#949)
This commit is contained in:
@ -13,7 +13,7 @@ pub struct Model {
|
||||
impl Model {
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(data: Vec<u8>, model_size: &str) -> Result<Model, JsError> {
|
||||
let inner = M::load_(&data, model_size)?;
|
||||
let inner = M::load_(data, model_size)?;
|
||||
Ok(Self { inner })
|
||||
}
|
||||
|
||||
@ -46,7 +46,7 @@ pub struct ModelPose {
|
||||
impl ModelPose {
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(data: Vec<u8>, model_size: &str) -> Result<ModelPose, JsError> {
|
||||
let inner = P::load_(&data, model_size)?;
|
||||
let inner = P::load_(data, model_size)?;
|
||||
Ok(Self { inner })
|
||||
}
|
||||
|
||||
|
@ -92,7 +92,7 @@ impl Model {
|
||||
Ok(bboxes)
|
||||
}
|
||||
|
||||
pub fn load_(weights: &[u8], model_size: &str) -> Result<Self> {
|
||||
pub fn load_(weights: Vec<u8>, model_size: &str) -> Result<Self> {
|
||||
let multiples = match model_size {
|
||||
"n" => Multiples::n(),
|
||||
"s" => Multiples::s(),
|
||||
@ -104,14 +104,13 @@ impl Model {
|
||||
))?,
|
||||
};
|
||||
let dev = &Device::Cpu;
|
||||
let weights = safetensors::tensor::SafeTensors::deserialize(weights)?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev);
|
||||
let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, dev)?;
|
||||
let model = YoloV8::load(vb, multiples, 80)?;
|
||||
Ok(Self { model })
|
||||
}
|
||||
|
||||
pub fn load(md: ModelData) -> Result<Self> {
|
||||
Self::load_(&md.weights, &md.model_size.to_string())
|
||||
Self::load_(md.weights, &md.model_size.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
@ -172,7 +171,7 @@ impl ModelPose {
|
||||
Ok(bboxes)
|
||||
}
|
||||
|
||||
pub fn load_(weights: &[u8], model_size: &str) -> Result<Self> {
|
||||
pub fn load_(weights: Vec<u8>, model_size: &str) -> Result<Self> {
|
||||
let multiples = match model_size {
|
||||
"n" => Multiples::n(),
|
||||
"s" => Multiples::s(),
|
||||
@ -184,14 +183,13 @@ impl ModelPose {
|
||||
))?,
|
||||
};
|
||||
let dev = &Device::Cpu;
|
||||
let weights = safetensors::tensor::SafeTensors::deserialize(weights)?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev);
|
||||
let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, dev)?;
|
||||
let model = YoloV8Pose::load(vb, multiples, 1, (17, 3))?;
|
||||
Ok(Self { model })
|
||||
}
|
||||
|
||||
pub fn load(md: ModelData) -> Result<Self> {
|
||||
Self::load_(&md.weights, &md.model_size.to_string())
|
||||
Self::load_(md.weights, &md.model_size.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user