mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Add a conv1d benchmark based on the whisper sizes. (#377)
* Add a conv1d benchmark based on the whisper sizes. * Enforce the batch-dim in conv1d.
This commit is contained in:
@ -773,18 +773,7 @@ impl Tensor {
|
||||
/// Applies a 1D convolution over the input tensor.
|
||||
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
|
||||
let (c_out, c_in_k, k_size) = kernel.dims3()?;
|
||||
let (b_size, c_in, l_in) = match *self.dims() {
|
||||
[b_size, c_in, l_in] => (Some(b_size), c_in, l_in),
|
||||
[c_in, l_in] => (None, c_in, l_in),
|
||||
_ => Err(Error::Conv1dInvalidArgs {
|
||||
inp_shape: self.shape().clone(),
|
||||
k_shape: kernel.shape().clone(),
|
||||
padding,
|
||||
stride,
|
||||
msg: "input rank is not 2 or 3",
|
||||
}
|
||||
.bt())?,
|
||||
};
|
||||
let (b_size, c_in, l_in) = self.dims3()?;
|
||||
if c_in != c_in_k {
|
||||
Err(Error::Conv1dInvalidArgs {
|
||||
inp_shape: self.shape().clone(),
|
||||
|
Reference in New Issue
Block a user