diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 4eb57bc7..b2345756 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -238,9 +238,10 @@ impl<'a> Map2 for Conv1D<'a> { let dst_idx = dst_idx + dst_l; let mut d = T::zero(); for offset in 0..p.k_size { + let src_l_plus = p.stride * dst_l + offset; // inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset] - if k_over_2 <= dst_l + offset && dst_l + offset < k_over_2 + p.l_in { - let src_l = dst_l + offset - k_over_2; + if k_over_2 <= src_l_plus && src_l_plus < k_over_2 + p.l_in { + let src_l = src_l_plus - k_over_2; for src_c_idx in 0..p.c_in { let inp_idx = inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1]; diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index 839dfc13..d119b6a7 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -435,12 +435,7 @@ impl AudioEncoder { }; let conv1 = Conv1D::load(cfg.n_mels, n_state, 3, cfg1, &format!("{p}.conv1"), vb)?; let conv2 = Conv1D::load(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?; - let positional_embedding = if true { - vb.get((n_ctx, n_state), &format!("{p}.positional_embedding"))? - } else { - /* The positional embeddings could be regenerated via the following. */ - sinusoids(n_ctx, n_state)?.to_device(&vb.device)? - }; + let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(&vb.device)?; let blocks = (0..cfg.n_audio_layer) .map(|i| { ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb) @@ -567,7 +562,11 @@ fn main() -> Result<()> { let model = Whisper::load(&vb, &cfg)?; let logits = model.forward(&mel, &tokens)?; - println!("{logits}"); + println!("tokens\n{tokens}"); + println!("logits:\n{logits}"); println!("python logits: {}", input.tensor("dec", &device)?); + let enc = model.encoder.forward(&mel)?; + println!("encoder:\n{enc}"); + println!("python enc: {}", input.tensor("enc", &device)?); Ok(()) }