Retrieve the yolo-v3 weights from the hub. (#537)

This commit is contained in:
Laurent Mazare
2023-08-21 10:55:09 +01:00
committed by GitHub
parent 4300864ce9
commit e3b71851e6

View File

@ -130,22 +130,50 @@ pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<Dy
struct Args { struct Args {
/// Model weights, in safetensors format. /// Model weights, in safetensors format.
#[arg(long)] #[arg(long)]
model: String, model: Option<String>,
#[arg(long)] #[arg(long)]
config: String, config: Option<String>,
images: Vec<String>, images: Vec<String>,
} }
impl Args {
fn config(&self) -> anyhow::Result<std::path::PathBuf> {
let path = match &self.config {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("lmz/candle-yolo-v3".to_string());
api.get("yolo-v3.cfg")?
}
};
Ok(path)
}
fn model(&self) -> anyhow::Result<std::path::PathBuf> {
let path = match &self.model {
Some(model) => std::path::PathBuf::from(model),
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("lmz/candle-yolo-v3".to_string());
api.get("yolo-v3.safetensors")?
}
};
Ok(path)
}
}
pub fn main() -> Result<()> { pub fn main() -> Result<()> {
let args = Args::parse(); let args = Args::parse();
// Create the model and load the weights from the file. // Create the model and load the weights from the file.
let weights = unsafe { candle::safetensors::MmapedFile::new(&args.model)? }; let model = args.model()?;
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
let weights = weights.deserialize()?; let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu); let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
let darknet = darknet::parse_config(&args.config)?; let config = args.config()?;
let darknet = darknet::parse_config(config)?;
let model = darknet.build_model(vb)?; let model = darknet.build_model(vb)?;
for image_name in args.images.iter() { for image_name in args.images.iter() {