mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Add a weight extraction script.
This commit is contained in:
13
candle-examples/examples/whisper/extract_weights.py
Normal file
13
candle-examples/examples/whisper/extract_weights.py
Normal file
@ -0,0 +1,13 @@
|
||||
# Get the checkpoint from
|
||||
# https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
data = torch.load("tiny.en.pt")
|
||||
weights = {}
|
||||
for k, v in data["model_state_dict"].items():
|
||||
weights[k] = v.contiguous()
|
||||
print(k, v.shape)
|
||||
save_file(weights, "tiny.en.safetensors")
|
||||
print(data["dims"])
|
@ -96,6 +96,23 @@ struct Config {
|
||||
n_text_layer: usize,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
fn tiny() -> Self {
|
||||
Self {
|
||||
n_mels: 80,
|
||||
n_vocab: 51864,
|
||||
n_audio_ctx: 1500,
|
||||
n_audio_state: 384,
|
||||
n_audio_head: 6,
|
||||
n_audio_layer: 4,
|
||||
n_text_ctx: 448,
|
||||
n_text_state: 384,
|
||||
n_text_head: 6,
|
||||
n_text_layer: 4,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Embedding {
|
||||
embeddings: Tensor,
|
||||
hidden_size: usize,
|
||||
|
Reference in New Issue
Block a user