mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Properly handle the stride in conv1d.
This commit is contained in:
@ -435,12 +435,7 @@ impl AudioEncoder {
|
||||
};
|
||||
let conv1 = Conv1D::load(cfg.n_mels, n_state, 3, cfg1, &format!("{p}.conv1"), vb)?;
|
||||
let conv2 = Conv1D::load(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?;
|
||||
let positional_embedding = if true {
|
||||
vb.get((n_ctx, n_state), &format!("{p}.positional_embedding"))?
|
||||
} else {
|
||||
/* The positional embeddings could be regenerated via the following. */
|
||||
sinusoids(n_ctx, n_state)?.to_device(&vb.device)?
|
||||
};
|
||||
let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(&vb.device)?;
|
||||
let blocks = (0..cfg.n_audio_layer)
|
||||
.map(|i| {
|
||||
ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb)
|
||||
@ -567,7 +562,11 @@ fn main() -> Result<()> {
|
||||
|
||||
let model = Whisper::load(&vb, &cfg)?;
|
||||
let logits = model.forward(&mel, &tokens)?;
|
||||
println!("{logits}");
|
||||
println!("tokens\n{tokens}");
|
||||
println!("logits:\n{logits}");
|
||||
println!("python logits: {}", input.tensor("dec", &device)?);
|
||||
let enc = model.encoder.forward(&mel)?;
|
||||
println!("encoder:\n{enc}");
|
||||
println!("python enc: {}", input.tensor("enc", &device)?);
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user