mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Add a conv1d benchmark based on the whisper sizes. (#377)
* Add a conv1d benchmark based on the whisper sizes. * Enforce the batch-dim in conv1d.
This commit is contained in:
@ -1037,10 +1037,10 @@ impl<'a> Map2 for Conv1D<'a> {
|
||||
let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
|
||||
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
|
||||
let l_out = p.l_out();
|
||||
let dst_elems = p.c_out * l_out * p.b_size.unwrap_or(1);
|
||||
let dst_elems = p.c_out * l_out * p.b_size;
|
||||
let mut dst = vec![T::zero(); dst_elems];
|
||||
// The output shape is [b_size, c_out, l_out]
|
||||
for b_idx in 0..p.b_size.unwrap_or(1) {
|
||||
for b_idx in 0..p.b_size {
|
||||
let inp_idx = b_idx * inp_s0;
|
||||
let dst_idx = b_idx * p.c_out * l_out;
|
||||
for dst_c_idx in 0..p.c_out {
|
||||
|
Reference in New Issue
Block a user