Sketch the conv1d op.
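
A minimal sketch of how the new entry point is meant to be called. This assumes the crate is imported as candle and the conventional (batch, in_channels, length) input / (out_channels, in_channels, kernel_size) kernel layouts; none of these shapes are fixed by this commit, and the backends still todo!(), so the example is illustrative only:

    use candle::{DType, Device, Tensor};

    fn main() -> candle::Result<()> {
        let dev = Device::Cpu;
        // (batch = 1, in_channels = 4, length = 10) input,
        // (out_channels = 8, in_channels = 4, kernel_size = 3) kernel.
        let inp = Tensor::zeros((1, 4, 10), DType::F32, &dev)?;
        let k = Tensor::zeros((8, 4, 3), DType::F32, &dev)?;
        let out = inp.conv1d(&k, /* padding */ 0, /* stride */ 1)?;
        // With l_out = (l_in + 2 * padding - k_size) / stride + 1 this would be
        // (1, 8, 8); at this commit the CPU kernel still panics on todo!().
        println!("{:?}", out.shape());
        Ok(())
    }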

laurent
2023-07-04 10:52:34 +01:00
parent e6b01d0c18
commit 3aac1047fe
7 changed files with 97 additions and 1 deletion

@@ -33,7 +33,12 @@ impl Tensor {
             track_grad |= tg;
             nodes
         }
-        Op::Add(lhs, rhs)
+        Op::Conv1D {
+            arg: lhs,
+            kernel: rhs,
+            ..
+        }
+        | Op::Add(lhs, rhs)
         | Op::Mul(lhs, rhs)
         | Op::Sub(lhs, rhs)
         | Op::Div(lhs, rhs)
@@ -147,6 +152,7 @@ impl Tensor {
                 let f_grad = pred.where_cond(&zeros, &grad)?;
                 *f_sum_grad = f_sum_grad.add(&f_grad)?;
             }
+            Op::Conv1D { .. } => return Err(Error::BackwardNotSupported { op: "conv1d" }),
             Op::Embedding(_lhs, _rhs) => {
                 return Err(Error::BackwardNotSupported { op: "embedding" })
             }

@@ -627,6 +627,17 @@ impl CpuStorage {
         WCond(pred, layout).map(t, t_l, f, f_l)
     }
 
+    pub(crate) fn conv1d(
+        &self,
+        _l: &Layout,
+        _kernel: &Self,
+        _kernel_l: &Layout,
+        _padding: usize,
+        _stride: usize,
+    ) -> Result<Self> {
+        todo!()
+    }
+
     pub(crate) fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
         let ids = self.as_slice::<u32>()?;
         let (vocab_size, hidden_size) = rhs_l.shape().r2()?;

@@ -801,6 +801,17 @@ impl CudaStorage {
         Ok(Self { slice, device })
     }
 
+    pub(crate) fn conv1d(
+        &self,
+        _l: &Layout,
+        _kernel: &Self,
+        _kernel_l: &Layout,
+        _padding: usize,
+        _stride: usize,
+    ) -> Result<Self> {
+        todo!()
+    }
+
     pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
         let device = self.device().clone();
         let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?;

@@ -100,6 +100,17 @@ impl CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
+    pub(crate) fn conv1d(
+        &self,
+        _l: &Layout,
+        _kernel: &Self,
+        _kernel_l: &Layout,
+        _padding: usize,
+        _stride: usize,
+    ) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     pub(crate) fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }

@@ -12,6 +12,14 @@ pub(crate) enum Op {
     Embedding(Tensor, Tensor),
     WhereCond(Tensor, Tensor, Tensor),
+    #[allow(dead_code)]
+    Conv1D {
+        arg: Tensor,
+        kernel: Tensor,
+        padding: usize,
+        stride: usize,
+    },
+
     Cat(Vec<Tensor>, usize),
     #[allow(dead_code)] // add is currently unused.

@@ -144,6 +144,33 @@ impl Storage {
         }
     }
 
+    pub(crate) fn conv1d(
+        &self,
+        l: &Layout,
+        kernel: &Self,
+        kernel_l: &Layout,
+        padding: usize,
+        stride: usize,
+    ) -> Result<Self> {
+        self.same_device(kernel, "conv1d")?;
+        self.same_dtype(kernel, "conv1d")?;
+        match (self, &kernel) {
+            (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
+                let s = inp.conv1d(l, kernel, kernel_l, padding, stride)?;
+                Ok(Self::Cpu(s))
+            }
+            (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
+                let s = inp.conv1d(l, kernel, kernel_l, padding, stride)?;
+                Ok(Self::Cuda(s))
+            }
+            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
+                lhs: lhs.device().location(),
+                rhs: rhs.device().location(),
+                op: "conv1d",
+            }),
+        }
+    }
+
     pub(crate) fn where_cond(
         &self,
         layout: &Layout,

@@ -432,6 +432,28 @@ impl Tensor {
         Ok(from_storage(storage, dims, op, false))
     }
 
+    pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
+        let storage = self.storage.conv1d(
+            self.layout(),
+            &kernel.storage,
+            kernel.layout(),
+            padding,
+            stride,
+        )?;
+        let op = if self.track_op() || kernel.track_op() {
+            Some(Op::Conv1D {
+                arg: self.clone(),
+                kernel: kernel.clone(),
+                padding,
+                stride,
+            })
+        } else {
+            None
+        };
+        let dims = self.dims();
+        Ok(from_storage(storage, dims, op, false))
+    }
+
     pub fn matmul(&self, rhs: &Self) -> Result<Self> {
         let a_dims = self.shape().dims();
         let b_dims = rhs.shape().dims();
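
Both backends leave the actual computation as todo!() and Tensor::conv1d reuses self.dims() for the output, so the op cannot run yet. For reference, a naive, self-contained sketch of what the CPU kernel could eventually compute over contiguous f32 buffers (plain Rust; every name below is illustrative and not part of this crate), using the usual output length l_out = (l_in + 2 * padding - k_size) / stride + 1:

    /// Naive 1d convolution (cross-correlation, as in most ML frameworks).
    /// inp is (b, c_in, l_in), kernel is (c_out, c_in, k_size), both row-major.
    fn conv1d_naive(
        inp: &[f32],
        kernel: &[f32],
        (b, c_in, l_in): (usize, usize, usize),
        (c_out, k_size): (usize, usize),
        padding: usize,
        stride: usize,
    ) -> Vec<f32> {
        let l_out = (l_in + 2 * padding - k_size) / stride + 1;
        let mut out = vec![0f32; b * c_out * l_out];
        for bi in 0..b {
            for co in 0..c_out {
                for lo in 0..l_out {
                    let mut acc = 0f32;
                    for ci in 0..c_in {
                        for ki in 0..k_size {
                            // Index into the virtually zero-padded input.
                            let pos = lo * stride + ki;
                            if pos < padding || pos >= l_in + padding {
                                continue; // zero padding contributes nothing
                            }
                            acc += inp[(bi * c_in + ci) * l_in + (pos - padding)]
                                * kernel[(co * c_in + ci) * k_size + ki];
                        }
                    }
                    out[(bi * c_out + co) * l_out + lo] = acc;
                }
            }
        }
        out
    }

A matching implementation would also derive the output shape from that formula in Tensor::conv1d instead of reusing the input's dims.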