mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Expose the larger resnets (50/101/152) in the example. (#1131)
This commit is contained in:
12
candle-examples/examples/resnet/export_models.py
Normal file
12
candle-examples/examples/resnet/export_models.py
Normal file
@ -0,0 +1,12 @@
|
||||
# This script exports pre-trained model weights in the safetensors format.
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
from safetensors import torch as stt
|
||||
|
||||
m = torchvision.models.resnet50(pretrained=True)
|
||||
stt.save_file(m.state_dict(), 'resnet50.safetensors')
|
||||
m = torchvision.models.resnet101(pretrained=True)
|
||||
stt.save_file(m.state_dict(), 'resnet101.safetensors')
|
||||
m = torchvision.models.resnet152(pretrained=True)
|
||||
stt.save_file(m.state_dict(), 'resnet152.safetensors')
|
Reference in New Issue
Block a user