From e3b71851e6d2c9eb51bb9978e7b025386d336f61 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 21 Aug 2023 10:55:09 +0100 Subject: [PATCH] Retrieve the yolo-v3 weights from the hub. (#537) --- candle-examples/examples/yolo-v3/main.rs | 36 +++++++++++++++++++++--- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/candle-examples/examples/yolo-v3/main.rs b/candle-examples/examples/yolo-v3/main.rs index a52f731c..0c7bdd7b 100644 --- a/candle-examples/examples/yolo-v3/main.rs +++ b/candle-examples/examples/yolo-v3/main.rs @@ -130,22 +130,50 @@ pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result, #[arg(long)] - config: String, + config: Option, images: Vec, } +impl Args { + fn config(&self) -> anyhow::Result { + let path = match &self.config { + Some(config) => std::path::PathBuf::from(config), + None => { + let api = hf_hub::api::sync::Api::new()?; + let api = api.model("lmz/candle-yolo-v3".to_string()); + api.get("yolo-v3.cfg")? + } + }; + Ok(path) + } + + fn model(&self) -> anyhow::Result { + let path = match &self.model { + Some(model) => std::path::PathBuf::from(model), + None => { + let api = hf_hub::api::sync::Api::new()?; + let api = api.model("lmz/candle-yolo-v3".to_string()); + api.get("yolo-v3.safetensors")? + } + }; + Ok(path) + } +} + pub fn main() -> Result<()> { let args = Args::parse(); // Create the model and load the weights from the file. - let weights = unsafe { candle::safetensors::MmapedFile::new(&args.model)? }; + let model = args.model()?; + let weights = unsafe { candle::safetensors::MmapedFile::new(model)? }; let weights = weights.deserialize()?; let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu); - let darknet = darknet::parse_config(&args.config)?; + let config = args.config()?; + let darknet = darknet::parse_config(config)?; let model = darknet.build_model(vb)?; for image_name in args.images.iter() {