Use the yolo-v8 weights from the hub. (#544)

* Use the weights from the hub.

* Add to the readme.
This commit is contained in:
Laurent Mazare
2023-08-21 22:07:36 +01:00
committed by GitHub
parent 3507e14c0c
commit f16bb97401
2 changed files with 34 additions and 5 deletions

View File

@ -12,7 +12,7 @@ use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{
batch_norm, conv2d, conv2d_no_bias, BatchNorm, Conv2d, Conv2dConfig, Module, VarBuilder,
};
use clap::Parser;
use clap::{Parser, ValueEnum};
use image::{DynamicImage, ImageBuffer};
const CONFIDENCE_THRESHOLD: f32 = 0.5;
@ -719,6 +719,15 @@ pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<Dy
Ok(DynamicImage::ImageRgb8(img))
}
#[derive(Clone, Copy, ValueEnum, Debug)]
enum Which {
N,
S,
M,
L,
X,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@ -726,6 +735,10 @@ struct Args {
#[arg(long)]
model: Option<String>,
/// Which model variant to use.
#[arg(long, value_enum, default_value_t = Which::S)]
which: Which,
images: Vec<String>,
}
@ -735,8 +748,15 @@ impl Args {
Some(model) => std::path::PathBuf::from(model),
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("lmz/candle-yolo-v3".to_string());
api.get("yolo-v3.safetensors")?
let api = api.model("lmz/candle-yolo-v8".to_string());
let filename = match self.which {
Which::N => "yolov8n.safetensors",
Which::S => "yolov8s.safetensors",
Which::M => "yolov8m.safetensors",
Which::L => "yolov8l.safetensors",
Which::X => "yolov8x.safetensors",
};
api.get(filename)?
}
};
Ok(path)
@ -747,11 +767,17 @@ pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
// Create the model and load the weights from the file.
let multiples = match args.which {
Which::N => Multiples::n(),
Which::S => Multiples::s(),
Which::M => Multiples::m(),
Which::L => Multiples::l(),
Which::X => Multiples::x(),
};
let model = args.model()?;
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
let multiples = Multiples::s();
let model = YoloV8::load(vb, multiples, /* num_classes=*/ 80)?;
println!("model loaded");
for image_name in args.images.iter() {