Add more of the conv1d op.

This commit is contained in:
laurent
2023-07-04 11:15:45 +01:00
parent 3aac1047fe
commit a424d95473
8 changed files with 52 additions and 19 deletions

24
candle-core/src/conv.rs Normal file
View File

@ -0,0 +1,24 @@
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct ParamsConv1D {
pub(crate) b_size: Option<usize>,
pub(crate) c_out: usize,
pub(crate) c_in: usize,
pub(crate) k_size: usize,
pub(crate) padding: usize,
pub(crate) stride: usize,
}
impl ParamsConv1D {
pub(crate) fn l_out(&self, l_in: usize) -> usize {
let dilation = 1;
(l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
}
pub(crate) fn out_dims(&self, l_in: usize) -> Vec<usize> {
let l_out = self.l_out(l_in);
match self.b_size {
None => vec![self.c_out, l_out],
Some(n) => vec![n, self.c_out, l_out],
}
}
}