mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Retrieve the yolo-v3 weights from the hub. (#537)
This commit is contained in:
@ -130,22 +130,50 @@ pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<Dy
|
|||||||
struct Args {
|
struct Args {
|
||||||
/// Model weights, in safetensors format.
|
/// Model weights, in safetensors format.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
model: String,
|
model: Option<String>,
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
config: String,
|
config: Option<String>,
|
||||||
|
|
||||||
images: Vec<String>,
|
images: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Args {
|
||||||
|
fn config(&self) -> anyhow::Result<std::path::PathBuf> {
|
||||||
|
let path = match &self.config {
|
||||||
|
Some(config) => std::path::PathBuf::from(config),
|
||||||
|
None => {
|
||||||
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
|
let api = api.model("lmz/candle-yolo-v3".to_string());
|
||||||
|
api.get("yolo-v3.cfg")?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn model(&self) -> anyhow::Result<std::path::PathBuf> {
|
||||||
|
let path = match &self.model {
|
||||||
|
Some(model) => std::path::PathBuf::from(model),
|
||||||
|
None => {
|
||||||
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
|
let api = api.model("lmz/candle-yolo-v3".to_string());
|
||||||
|
api.get("yolo-v3.safetensors")?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn main() -> Result<()> {
|
pub fn main() -> Result<()> {
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
|
||||||
// Create the model and load the weights from the file.
|
// Create the model and load the weights from the file.
|
||||||
let weights = unsafe { candle::safetensors::MmapedFile::new(&args.model)? };
|
let model = args.model()?;
|
||||||
|
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
|
||||||
let weights = weights.deserialize()?;
|
let weights = weights.deserialize()?;
|
||||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
|
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
|
||||||
let darknet = darknet::parse_config(&args.config)?;
|
let config = args.config()?;
|
||||||
|
let darknet = darknet::parse_config(config)?;
|
||||||
let model = darknet.build_model(vb)?;
|
let model = darknet.build_model(vb)?;
|
||||||
|
|
||||||
for image_name in args.images.iter() {
|
for image_name in args.images.iter() {
|
||||||
|
Reference in New Issue
Block a user