From 599160605c0294c94c33f64aeca0ac9f388d03c7 Mon Sep 17 00:00:00 2001 From: laurent Date: Tue, 4 Jul 2023 13:13:28 +0100 Subject: [PATCH] Use the stored embeddings. --- candle-examples/examples/whisper/main.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index 6341c5ee..1b6f4bfe 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -424,6 +424,7 @@ impl AudioEncoder { fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result { let n_state = cfg.n_audio_state; let n_head = cfg.n_audio_head; + let n_ctx = cfg.n_audio_ctx; let cfg1 = ConvConfig { padding: 1, stride: 1, @@ -434,7 +435,12 @@ impl AudioEncoder { }; let conv1 = Conv1D::load(cfg.n_mels, n_state, 3, cfg1, &format!("{p}.conv1"), vb)?; let conv2 = Conv1D::load(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?; - let positional_embedding = sinusoids(cfg.n_audio_ctx, n_state)?.to_device(&vb.device)?; + /* The positional embeddings could be regenerated via the following. */ + let positional_embedding = if true { + vb.get((n_ctx, n_state), &format!("{p}.positional_embedding"))? + } else { + sinusoids(n_ctx, n_state)?.to_device(&vb.device)? + }; let blocks = (0..cfg.n_audio_layer) .map(|i| { ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb)