From 93c25e8844e8db2c697f1e6e9d4a06dac1ca3569 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 19 Oct 2023 13:48:28 +0100 Subject: [PATCH] Expose the larger resnets (50/101/152) in the example. (#1131) --- README.md | 2 +- candle-examples/examples/resnet/export_models.py | 12 ++++++++++++ candle-examples/examples/resnet/main.rs | 14 ++++++++++++++ 3 files changed, 27 insertions(+), 1 deletion(-) create mode 100644 candle-examples/examples/resnet/export_models.py diff --git a/README.md b/README.md index ac9f82c7..fd3a9fbf 100644 --- a/README.md +++ b/README.md @@ -166,7 +166,7 @@ If you have an addition to this list, please submit a pull request. - DINOv2. - ConvMixer. - EfficientNet. - - ResNet-18/34. + - ResNet-18/34/50/101/152. - yolo-v3. - yolo-v8. - Segment-Anything Model (SAM). diff --git a/candle-examples/examples/resnet/export_models.py b/candle-examples/examples/resnet/export_models.py new file mode 100644 index 00000000..74ef6e7d --- /dev/null +++ b/candle-examples/examples/resnet/export_models.py @@ -0,0 +1,12 @@ +# This script exports pre-trained model weights in the safetensors format. +import numpy as np +import torch +import torchvision +from safetensors import torch as stt + +m = torchvision.models.resnet50(pretrained=True) +stt.save_file(m.state_dict(), 'resnet50.safetensors') +m = torchvision.models.resnet101(pretrained=True) +stt.save_file(m.state_dict(), 'resnet101.safetensors') +m = torchvision.models.resnet152(pretrained=True) +stt.save_file(m.state_dict(), 'resnet152.safetensors') diff --git a/candle-examples/examples/resnet/main.rs b/candle-examples/examples/resnet/main.rs index 3badc48e..4a4592ad 100644 --- a/candle-examples/examples/resnet/main.rs +++ b/candle-examples/examples/resnet/main.rs @@ -11,8 +11,16 @@ use clap::{Parser, ValueEnum}; #[derive(Clone, Copy, Debug, ValueEnum)] enum Which { + #[value(name = "18")] Resnet18, + #[value(name = "34")] Resnet34, + #[value(name = "50")] + Resnet50, + #[value(name = "101")] + Resnet101, + #[value(name = "152")] + Resnet152, } #[derive(Parser)] @@ -47,6 +55,9 @@ pub fn main() -> anyhow::Result<()> { let filename = match args.which { Which::Resnet18 => "resnet18.safetensors", Which::Resnet34 => "resnet34.safetensors", + Which::Resnet50 => "resnet50.safetensors", + Which::Resnet101 => "resnet101.safetensors", + Which::Resnet152 => "resnet152.safetensors", }; api.get(filename)? } @@ -57,6 +68,9 @@ pub fn main() -> anyhow::Result<()> { let model = match args.which { Which::Resnet18 => resnet::resnet18(class_count, vb)?, Which::Resnet34 => resnet::resnet34(class_count, vb)?, + Which::Resnet50 => resnet::resnet50(class_count, vb)?, + Which::Resnet101 => resnet::resnet101(class_count, vb)?, + Which::Resnet152 => resnet::resnet152(class_count, vb)?, }; println!("model built"); let logits = model.forward(&image.unsqueeze(0)?)?;