Add a weight extraction script.

This commit is contained in:
laurent
2023-07-04 09:29:19 +01:00
parent c09aa4b0f4
commit d71b31144d
2 changed files with 30 additions and 0 deletions

View File

@ -0,0 +1,13 @@
# Get the checkpoint from
# https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt
import torch
from safetensors.torch import save_file
# Load the original PyTorch checkpoint. map_location="cpu" remaps every tensor
# to the CPU, so the load does not depend on the device the checkpoint was
# serialized from being available on this machine.
data = torch.load("tiny.en.pt", map_location="cpu")

# Collect the state dict into a plain {name: tensor} mapping. Each tensor is
# made contiguous before being handed to save_file, and its name and shape are
# printed so the exported layout can be inspected.
weights = {}
for k, v in data["model_state_dict"].items():
    weights[k] = v.contiguous()
    print(k, v.shape)
save_file(weights, "tiny.en.safetensors")

# Print the checkpoint's "dims" entry, which is not written to the
# safetensors file.
print(data["dims"])

View File

@ -96,6 +96,23 @@ struct Config {
n_text_layer: usize,
}
impl Config {
fn tiny() -> Self {
Self {
n_mels: 80,
n_vocab: 51864,
n_audio_ctx: 1500,
n_audio_state: 384,
n_audio_head: 6,
n_audio_layer: 4,
n_text_ctx: 448,
n_text_state: 384,
n_text_head: 6,
n_text_layer: 4,
}
}
}
struct Embedding {
embeddings: Tensor,
hidden_size: usize,