mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Use the yolo-v8 weights from the hub. (#544)
* Use the weights from the hub. * Add to the readme.
This commit is contained in:
@ -12,7 +12,7 @@ use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{
|
||||
batch_norm, conv2d, conv2d_no_bias, BatchNorm, Conv2d, Conv2dConfig, Module, VarBuilder,
|
||||
};
|
||||
use clap::Parser;
|
||||
use clap::{Parser, ValueEnum};
|
||||
use image::{DynamicImage, ImageBuffer};
|
||||
|
||||
const CONFIDENCE_THRESHOLD: f32 = 0.5;
|
||||
@ -719,6 +719,15 @@ pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<Dy
|
||||
Ok(DynamicImage::ImageRgb8(img))
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, ValueEnum, Debug)]
|
||||
enum Which {
|
||||
N,
|
||||
S,
|
||||
M,
|
||||
L,
|
||||
X,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
@ -726,6 +735,10 @@ struct Args {
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
/// Which model variant to use.
|
||||
#[arg(long, value_enum, default_value_t = Which::S)]
|
||||
which: Which,
|
||||
|
||||
images: Vec<String>,
|
||||
}
|
||||
|
||||
@ -735,8 +748,15 @@ impl Args {
|
||||
Some(model) => std::path::PathBuf::from(model),
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model("lmz/candle-yolo-v3".to_string());
|
||||
api.get("yolo-v3.safetensors")?
|
||||
let api = api.model("lmz/candle-yolo-v8".to_string());
|
||||
let filename = match self.which {
|
||||
Which::N => "yolov8n.safetensors",
|
||||
Which::S => "yolov8s.safetensors",
|
||||
Which::M => "yolov8m.safetensors",
|
||||
Which::L => "yolov8l.safetensors",
|
||||
Which::X => "yolov8x.safetensors",
|
||||
};
|
||||
api.get(filename)?
|
||||
}
|
||||
};
|
||||
Ok(path)
|
||||
@ -747,11 +767,17 @@ pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
// Create the model and load the weights from the file.
|
||||
let multiples = match args.which {
|
||||
Which::N => Multiples::n(),
|
||||
Which::S => Multiples::s(),
|
||||
Which::M => Multiples::m(),
|
||||
Which::L => Multiples::l(),
|
||||
Which::X => Multiples::x(),
|
||||
};
|
||||
let model = args.model()?;
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
|
||||
let multiples = Multiples::s();
|
||||
let model = YoloV8::load(vb, multiples, /* num_classes=*/ 80)?;
|
||||
println!("model loaded");
|
||||
for image_name in args.images.iter() {
|
||||
|
Reference in New Issue
Block a user