mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Proper conv1d dispatch.
This commit is contained in:
@ -1,6 +1,9 @@
|
|||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub(crate) struct ParamsConv1D {
|
pub(crate) struct ParamsConv1D {
|
||||||
pub(crate) b_size: Option<usize>,
|
pub(crate) b_size: Option<usize>,
|
||||||
|
// Maybe we should have a version without l_in as this bit depends on the input and not only on
|
||||||
|
// the weights.
|
||||||
|
pub(crate) l_in: usize,
|
||||||
pub(crate) c_out: usize,
|
pub(crate) c_out: usize,
|
||||||
pub(crate) c_in: usize,
|
pub(crate) c_in: usize,
|
||||||
pub(crate) k_size: usize,
|
pub(crate) k_size: usize,
|
||||||
@ -9,13 +12,13 @@ pub(crate) struct ParamsConv1D {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl ParamsConv1D {
|
impl ParamsConv1D {
|
||||||
pub(crate) fn l_out(&self, l_in: usize) -> usize {
|
pub(crate) fn l_out(&self) -> usize {
|
||||||
let dilation = 1;
|
let dilation = 1;
|
||||||
(l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
|
(self.l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn out_dims(&self, l_in: usize) -> Vec<usize> {
|
pub(crate) fn out_dims(&self) -> Vec<usize> {
|
||||||
let l_out = self.l_out(l_in);
|
let l_out = self.l_out();
|
||||||
match self.b_size {
|
match self.b_size {
|
||||||
None => vec![self.c_out, l_out],
|
None => vec![self.c_out, l_out],
|
||||||
Some(n) => vec![n, self.c_out, l_out],
|
Some(n) => vec![n, self.c_out, l_out],
|
||||||
|
@ -202,6 +202,26 @@ fn copy_strided_src_<T: Copy + std::fmt::Display>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
|
||||||
|
|
||||||
|
impl<'a> Map2 for Conv1D<'a> {
|
||||||
|
const OP: &'static str = "conv1d";
|
||||||
|
fn f<T: 'static + num_traits::Num + Copy>(
|
||||||
|
&self,
|
||||||
|
_inp: &[T],
|
||||||
|
_inp_l: &Layout,
|
||||||
|
_k: &[T],
|
||||||
|
_k_l: &Layout,
|
||||||
|
) -> Result<Vec<T>> {
|
||||||
|
let p = self.0;
|
||||||
|
let l_out = p.l_out();
|
||||||
|
let out_elems = p.c_out * l_out * p.b_size.unwrap_or(1);
|
||||||
|
let dst = vec![T::zero(); out_elems];
|
||||||
|
// TODO: actually implement the ops.
|
||||||
|
Ok(dst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct MatMul((usize, usize, usize, usize));
|
struct MatMul((usize, usize, usize, usize));
|
||||||
|
|
||||||
impl Map2 for MatMul {
|
impl Map2 for MatMul {
|
||||||
@ -629,12 +649,12 @@ impl CpuStorage {
|
|||||||
|
|
||||||
pub(crate) fn conv1d(
|
pub(crate) fn conv1d(
|
||||||
&self,
|
&self,
|
||||||
_l: &Layout,
|
l: &Layout,
|
||||||
_kernel: &Self,
|
kernel: &Self,
|
||||||
_kernel_l: &Layout,
|
kernel_l: &Layout,
|
||||||
_params: &crate::conv::ParamsConv1D,
|
params: &crate::conv::ParamsConv1D,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
todo!()
|
Conv1D(params).map(self, l, kernel, kernel_l)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
pub(crate) fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
||||||
|
@ -444,6 +444,7 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
let params = crate::conv::ParamsConv1D {
|
let params = crate::conv::ParamsConv1D {
|
||||||
b_size,
|
b_size,
|
||||||
|
l_in,
|
||||||
c_out,
|
c_out,
|
||||||
c_in,
|
c_in,
|
||||||
k_size,
|
k_size,
|
||||||
@ -463,7 +464,7 @@ impl Tensor {
|
|||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
let out_dims = params.out_dims(l_in);
|
let out_dims = params.out_dims();
|
||||||
Ok(from_storage(storage, out_dims, op, false))
|
Ok(from_storage(storage, out_dims, op, false))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user