Expose the larger resnets (50/101/152) in the example. (#1131)

This commit is contained in:
Laurent Mazare
2023-10-19 13:48:28 +01:00
committed by GitHub
parent cd53c472df
commit 93c25e8844
3 changed files with 27 additions and 1 deletions

View File

@ -166,7 +166,7 @@ If you have an addition to this list, please submit a pull request.
- DINOv2.
- ConvMixer.
- EfficientNet.
- ResNet-18/34.
- ResNet-18/34/50/101/152.
- yolo-v3.
- yolo-v8.
- Segment-Anything Model (SAM).

View File

@ -0,0 +1,12 @@
# This script exports pre-trained model weights in the safetensors format.
import numpy as np
import torch
import torchvision
from safetensors import torch as stt
m = torchvision.models.resnet50(pretrained=True)
stt.save_file(m.state_dict(), 'resnet50.safetensors')
m = torchvision.models.resnet101(pretrained=True)
stt.save_file(m.state_dict(), 'resnet101.safetensors')
m = torchvision.models.resnet152(pretrained=True)
stt.save_file(m.state_dict(), 'resnet152.safetensors')

View File

@ -11,8 +11,16 @@ use clap::{Parser, ValueEnum};
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
#[value(name = "18")]
Resnet18,
#[value(name = "34")]
Resnet34,
#[value(name = "50")]
Resnet50,
#[value(name = "101")]
Resnet101,
#[value(name = "152")]
Resnet152,
}
#[derive(Parser)]
@ -47,6 +55,9 @@ pub fn main() -> anyhow::Result<()> {
let filename = match args.which {
Which::Resnet18 => "resnet18.safetensors",
Which::Resnet34 => "resnet34.safetensors",
Which::Resnet50 => "resnet50.safetensors",
Which::Resnet101 => "resnet101.safetensors",
Which::Resnet152 => "resnet152.safetensors",
};
api.get(filename)?
}
@ -57,6 +68,9 @@ pub fn main() -> anyhow::Result<()> {
let model = match args.which {
Which::Resnet18 => resnet::resnet18(class_count, vb)?,
Which::Resnet34 => resnet::resnet34(class_count, vb)?,
Which::Resnet50 => resnet::resnet50(class_count, vb)?,
Which::Resnet101 => resnet::resnet101(class_count, vb)?,
Which::Resnet152 => resnet::resnet152(class_count, vb)?,
};
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;