Add more of the conv1d op.

Author: laurent
Date: 2023-07-04 11:15:45 +01:00
parent 3aac1047fe
commit a424d95473
8 changed files with 52 additions and 19 deletions

candle-core/src/conv.rs (new file, +24)

@@ -0,0 +1,24 @@
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub(crate) struct ParamsConv1D {
+    pub(crate) b_size: Option<usize>,
+    pub(crate) c_out: usize,
+    pub(crate) c_in: usize,
+    pub(crate) k_size: usize,
+    pub(crate) padding: usize,
+    pub(crate) stride: usize,
+}
+
+impl ParamsConv1D {
+    pub(crate) fn l_out(&self, l_in: usize) -> usize {
+        let dilation = 1;
+        (l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
+    }
+
+    pub(crate) fn out_dims(&self, l_in: usize) -> Vec<usize> {
+        let l_out = self.l_out(l_in);
+        match self.b_size {
+            None => vec![self.c_out, l_out],
+            Some(n) => vec![n, self.c_out, l_out],
+        }
+    }
+}
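
For reference, l_out is the standard convolution output-length formula with dilation pinned to 1 for now: l_out = (l_in + 2*padding - dilation*(k_size - 1) - 1) / stride + 1, using integer (floor) division. A quick sketch of the struct in use; the values are illustrative, not from the commit:

    // Illustrative values: a "same" convolution with k=5, p=2, s=1
    // leaves the length dimension unchanged.
    let params = ParamsConv1D {
        b_size: Some(16),
        c_out: 8,
        c_in: 3,
        k_size: 5,
        padding: 2,
        stride: 1,
    };
    assert_eq!(params.l_out(100), 100); // (100 + 2*2 - 1*(5 - 1) - 1) / 1 + 1
    assert_eq!(params.out_dims(100), vec![16, 8, 100]);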

@@ -632,8 +632,7 @@ impl CpuStorage {
         _l: &Layout,
         _kernel: &Self,
         _kernel_l: &Layout,
-        _padding: usize,
-        _stride: usize,
+        _params: &crate::conv::ParamsConv1D,
     ) -> Result<Self> {
         todo!()
     }
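
The CPU kernel itself is left as todo!(). For orientation, a naive sketch of the computation it will eventually have to perform, assuming contiguous f32 buffers, no batch dimension, and ignoring Layout handling; the helper and its name are illustrative, not part of the commit:

    // Hypothetical reference implementation, not the commit's code.
    // inp is (c_in, l_in) row-major, kernel is (c_out, c_in, k_size) row-major.
    fn conv1d_ref(inp: &[f32], kernel: &[f32], p: &ParamsConv1D, l_in: usize) -> Vec<f32> {
        let l_out = p.l_out(l_in);
        let mut out = vec![0f32; p.c_out * l_out];
        for dst_c in 0..p.c_out {
            for dst_l in 0..l_out {
                let mut acc = 0f32;
                for src_c in 0..p.c_in {
                    for k in 0..p.k_size {
                        // Index into the (virtually) zero-padded input.
                        let pos = dst_l * p.stride + k;
                        if pos >= p.padding && pos < l_in + p.padding {
                            let src_l = pos - p.padding;
                            acc += inp[src_c * l_in + src_l]
                                * kernel[(dst_c * p.c_in + src_c) * p.k_size + k];
                        }
                    }
                }
                out[dst_c * l_out + dst_l] = acc;
            }
        }
        out
    }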

@@ -806,8 +806,7 @@ impl CudaStorage {
         _l: &Layout,
         _kernel: &Self,
         _kernel_l: &Layout,
-        _padding: usize,
-        _stride: usize,
+        _params: &crate::conv::ParamsConv1D,
     ) -> Result<Self> {
         todo!()
     }

@@ -105,8 +105,7 @@ impl CudaStorage {
         _l: &Layout,
         _kernel: &Self,
         _kernel_l: &Layout,
-        _padding: usize,
-        _stride: usize,
+        _params: &crate::conv::ParamsConv1D,
     ) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }

@@ -1,4 +1,5 @@
 mod backprop;
+mod conv;
 mod cpu_backend;
 #[cfg(feature = "cuda")]
 mod cuda_backend;

@@ -149,18 +149,17 @@ impl Storage {
         l: &Layout,
         kernel: &Self,
         kernel_l: &Layout,
-        padding: usize,
-        stride: usize,
+        params: &crate::conv::ParamsConv1D,
     ) -> Result<Self> {
         self.same_device(kernel, "conv1d")?;
         self.same_dtype(kernel, "conv1d")?;
         match (self, &kernel) {
             (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
-                let s = inp.conv1d(l, kernel, kernel_l, padding, stride)?;
+                let s = inp.conv1d(l, kernel, kernel_l, params)?;
                 Ok(Self::Cpu(s))
             }
             (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
-                let s = inp.conv1d(l, kernel, kernel_l, padding, stride)?;
+                let s = inp.conv1d(l, kernel, kernel_l, params)?;
                 Ok(Self::Cuda(s))
             }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
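Threading the params struct through keeps the device dispatch down to a single match; supporting another backend would mean one more arm of the same shape. A sketch with a hypothetical Metal variant, which does not exist at this commit:

    // Hypothetical additional arm, mirroring the Cpu/Cuda ones:
    (Storage::Metal(inp), Storage::Metal(kernel)) => {
        let s = inp.conv1d(l, kernel, kernel_l, params)?;
        Ok(Self::Metal(s))
    }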

@@ -433,13 +433,26 @@ impl Tensor {
     }

     pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
-        let storage = self.storage.conv1d(
-            self.layout(),
-            &kernel.storage,
-            kernel.layout(),
-            padding,
-            stride,
-        )?;
+        let (c_out, c_in_k, k_size) = kernel.shape().r3()?;
+        let (b_size, c_in, l_in) = match *self.dims() {
+            [b_size, c_in, l_in] => (Some(b_size), c_in, l_in),
+            [c_in, l_in] => (None, c_in, l_in),
+            _ => todo!("proper error message"),
+        };
+        if c_in != c_in_k {
+            todo!("proper error message")
+        }
+        let params = crate::conv::ParamsConv1D {
+            b_size,
+            c_out,
+            c_in,
+            k_size,
+            padding,
+            stride,
+        };
+        let storage =
+            self.storage
+                .conv1d(self.layout(), &kernel.storage, kernel.layout(), &params)?;
         let op = if self.track_op() || kernel.track_op() {
             Some(Op::Conv1D {
                 arg: self.clone(),
@@ -450,8 +463,8 @@ impl Tensor {
         } else {
             None
         };
-        let dims = self.dims();
-        Ok(from_storage(storage, dims, op, false))
+        let out_dims = params.out_dims(l_in);
+        Ok(from_storage(storage, out_dims, op, false))
     }

     pub fn matmul(&self, rhs: &Self) -> Result<Self> {
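
The output shape now comes from params.out_dims(l_in) rather than being copied from the input. A usage sketch of the intended shape contract, with Tensor::zeros constructors assumed; the storage kernels are still todo!() at this commit, so this documents shapes rather than runnable behavior:

    // Input (b=2, c_in=3, l_in=100), kernel (c_out=8, c_in=3, k_size=5).
    let inp = Tensor::zeros((2, 3, 100), DType::F32, &Device::Cpu)?;
    let k = Tensor::zeros((8, 3, 5), DType::F32, &Device::Cpu)?;
    let out = inp.conv1d(&k, /* padding */ 2, /* stride */ 1)?;
    assert_eq!(out.dims(), &[2, 8, 100]); // l_out = (100 + 4 - 4 - 1) / 1 + 1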

@@ -236,8 +236,7 @@ impl Conv1D {
     fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
         let (bsize, _, _) = x.shape().r3()?;
         let w = self.weight.broadcast_left(bsize)?.t()?;
-        // TODO: Add the conv1d operation
-        let x = x.matmul(&w)?;
+        let x = x.conv1d(&w, self.config.padding, self.config.stride)?;
         match &self.bias {
             None => Ok(x),
             Some(bias) => x.broadcast_add(bias),
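
This swaps the matmul placeholder for the real op. Note that the broadcast_left(bsize)?.t()? preparation of w is carried over unchanged from the matmul version, while per the r3()? destructuring in Tensor::conv1d above, the op itself assumes:

    // Shape contract of Tensor::conv1d after this commit:
    //   x:      (b_size, c_in, l_in) or (c_in, l_in)
    //   kernel: (c_out, c_in, k_size)
    //   output: (b_size, c_out, l_out)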