mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Expose the larger resnets (50/101/152) in the example. (#1131)
This commit is contained in:
@ -166,7 +166,7 @@ If you have an addition to this list, please submit a pull request.
|
|||||||
- DINOv2.
|
- DINOv2.
|
||||||
- ConvMixer.
|
- ConvMixer.
|
||||||
- EfficientNet.
|
- EfficientNet.
|
||||||
- ResNet-18/34.
|
- ResNet-18/34/50/101/152.
|
||||||
- yolo-v3.
|
- yolo-v3.
|
||||||
- yolo-v8.
|
- yolo-v8.
|
||||||
- Segment-Anything Model (SAM).
|
- Segment-Anything Model (SAM).
|
||||||
|
12
candle-examples/examples/resnet/export_models.py
Normal file
12
candle-examples/examples/resnet/export_models.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
# This script exports pre-trained model weights in the safetensors format.
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
from safetensors import torch as stt
|
||||||
|
|
||||||
|
m = torchvision.models.resnet50(pretrained=True)
|
||||||
|
stt.save_file(m.state_dict(), 'resnet50.safetensors')
|
||||||
|
m = torchvision.models.resnet101(pretrained=True)
|
||||||
|
stt.save_file(m.state_dict(), 'resnet101.safetensors')
|
||||||
|
m = torchvision.models.resnet152(pretrained=True)
|
||||||
|
stt.save_file(m.state_dict(), 'resnet152.safetensors')
|
@ -11,8 +11,16 @@ use clap::{Parser, ValueEnum};
|
|||||||
|
|
||||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||||
enum Which {
|
enum Which {
|
||||||
|
#[value(name = "18")]
|
||||||
Resnet18,
|
Resnet18,
|
||||||
|
#[value(name = "34")]
|
||||||
Resnet34,
|
Resnet34,
|
||||||
|
#[value(name = "50")]
|
||||||
|
Resnet50,
|
||||||
|
#[value(name = "101")]
|
||||||
|
Resnet101,
|
||||||
|
#[value(name = "152")]
|
||||||
|
Resnet152,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
@ -47,6 +55,9 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
let filename = match args.which {
|
let filename = match args.which {
|
||||||
Which::Resnet18 => "resnet18.safetensors",
|
Which::Resnet18 => "resnet18.safetensors",
|
||||||
Which::Resnet34 => "resnet34.safetensors",
|
Which::Resnet34 => "resnet34.safetensors",
|
||||||
|
Which::Resnet50 => "resnet50.safetensors",
|
||||||
|
Which::Resnet101 => "resnet101.safetensors",
|
||||||
|
Which::Resnet152 => "resnet152.safetensors",
|
||||||
};
|
};
|
||||||
api.get(filename)?
|
api.get(filename)?
|
||||||
}
|
}
|
||||||
@ -57,6 +68,9 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
let model = match args.which {
|
let model = match args.which {
|
||||||
Which::Resnet18 => resnet::resnet18(class_count, vb)?,
|
Which::Resnet18 => resnet::resnet18(class_count, vb)?,
|
||||||
Which::Resnet34 => resnet::resnet34(class_count, vb)?,
|
Which::Resnet34 => resnet::resnet34(class_count, vb)?,
|
||||||
|
Which::Resnet50 => resnet::resnet50(class_count, vb)?,
|
||||||
|
Which::Resnet101 => resnet::resnet101(class_count, vb)?,
|
||||||
|
Which::Resnet152 => resnet::resnet152(class_count, vb)?,
|
||||||
};
|
};
|
||||||
println!("model built");
|
println!("model built");
|
||||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||||
|
Reference in New Issue
Block a user