Pass directly the buffer ownership. (#949)

This commit is contained in:
Laurent Mazare
2023-09-24 06:34:44 +01:00
committed by GitHub
parent e32c89d90c
commit 7edd755756
5 changed files with 12 additions and 18 deletions

View File

@ -92,7 +92,7 @@ impl Model {
Ok(bboxes)
}
pub fn load_(weights: &[u8], model_size: &str) -> Result<Self> {
pub fn load_(weights: Vec<u8>, model_size: &str) -> Result<Self> {
let multiples = match model_size {
"n" => Multiples::n(),
"s" => Multiples::s(),
@ -104,14 +104,13 @@ impl Model {
))?,
};
let dev = &Device::Cpu;
let weights = safetensors::tensor::SafeTensors::deserialize(weights)?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev);
let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, dev)?;
let model = YoloV8::load(vb, multiples, 80)?;
Ok(Self { model })
}
pub fn load(md: ModelData) -> Result<Self> {
Self::load_(&md.weights, &md.model_size.to_string())
Self::load_(md.weights, &md.model_size.to_string())
}
}
@ -172,7 +171,7 @@ impl ModelPose {
Ok(bboxes)
}
pub fn load_(weights: &[u8], model_size: &str) -> Result<Self> {
pub fn load_(weights: Vec<u8>, model_size: &str) -> Result<Self> {
let multiples = match model_size {
"n" => Multiples::n(),
"s" => Multiples::s(),
@ -184,14 +183,13 @@ impl ModelPose {
))?,
};
let dev = &Device::Cpu;
let weights = safetensors::tensor::SafeTensors::deserialize(weights)?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev);
let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, dev)?;
let model = YoloV8Pose::load(vb, multiples, 1, (17, 3))?;
Ok(Self { model })
}
pub fn load(md: ModelData) -> Result<Self> {
Self::load_(&md.weights, &md.model_size.to_string())
Self::load_(md.weights, &md.model_size.to_string())
}
}