Properly handle the stride in conv1d.

This commit is contained in:
laurent
2023-07-04 15:05:04 +01:00
parent 29a0330d6d
commit 459e2e1ae3
2 changed files with 9 additions and 9 deletions

View File

@ -238,9 +238,10 @@ impl<'a> Map2 for Conv1D<'a> {
let dst_idx = dst_idx + dst_l;
let mut d = T::zero();
for offset in 0..p.k_size {
let src_l_plus = p.stride * dst_l + offset;
// inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset]
if k_over_2 <= dst_l + offset && dst_l + offset < k_over_2 + p.l_in {
let src_l = dst_l + offset - k_over_2;
if k_over_2 <= src_l_plus && src_l_plus < k_over_2 + p.l_in {
let src_l = src_l_plus - k_over_2;
for src_c_idx in 0..p.c_in {
let inp_idx =
inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1];

View File

@ -435,12 +435,7 @@ impl AudioEncoder {
};
let conv1 = Conv1D::load(cfg.n_mels, n_state, 3, cfg1, &format!("{p}.conv1"), vb)?;
let conv2 = Conv1D::load(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?;
let positional_embedding = if true {
vb.get((n_ctx, n_state), &format!("{p}.positional_embedding"))?
} else {
/* The positional embeddings could be regenerated via the following. */
sinusoids(n_ctx, n_state)?.to_device(&vb.device)?
};
let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(&vb.device)?;
let blocks = (0..cfg.n_audio_layer)
.map(|i| {
ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb)
@ -567,7 +562,11 @@ fn main() -> Result<()> {
let model = Whisper::load(&vb, &cfg)?;
let logits = model.forward(&mel, &tokens)?;
println!("{logits}");
println!("tokens\n{tokens}");
println!("logits:\n{logits}");
println!("python logits: {}", input.tensor("dec", &device)?);
let enc = model.encoder.forward(&mel)?;
println!("encoder:\n{enc}");
println!("python enc: {}", input.tensor("enc", &device)?);
Ok(())
}