Expose the larger resnets (50/101/152) in the example. (#1131)

This commit is contained in:
Laurent Mazare
2023-10-19 13:48:28 +01:00
committed by GitHub
parent cd53c472df
commit 93c25e8844
3 changed files with 27 additions and 1 deletions

View File

@ -0,0 +1,12 @@
# This script exports pre-trained model weights in the safetensors format.
import numpy as np
import torch
import torchvision
from safetensors import torch as stt
m = torchvision.models.resnet50(pretrained=True)
stt.save_file(m.state_dict(), 'resnet50.safetensors')
m = torchvision.models.resnet101(pretrained=True)
stt.save_file(m.state_dict(), 'resnet101.safetensors')
m = torchvision.models.resnet152(pretrained=True)
stt.save_file(m.state_dict(), 'resnet152.safetensors')