mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Sketch the conv1d op.
This commit is contained in:
@ -33,7 +33,12 @@ impl Tensor {
|
||||
track_grad |= tg;
|
||||
nodes
|
||||
}
|
||||
Op::Add(lhs, rhs)
|
||||
Op::Conv1D {
|
||||
arg: lhs,
|
||||
kernel: rhs,
|
||||
..
|
||||
}
|
||||
| Op::Add(lhs, rhs)
|
||||
| Op::Mul(lhs, rhs)
|
||||
| Op::Sub(lhs, rhs)
|
||||
| Op::Div(lhs, rhs)
|
||||
@ -147,6 +152,7 @@ impl Tensor {
|
||||
let f_grad = pred.where_cond(&zeros, &grad)?;
|
||||
*f_sum_grad = f_sum_grad.add(&f_grad)?;
|
||||
}
|
||||
Op::Conv1D { .. } => return Err(Error::BackwardNotSupported { op: "conv1d" }),
|
||||
Op::Embedding(_lhs, _rhs) => {
|
||||
return Err(Error::BackwardNotSupported { op: "embedding" })
|
||||
}
|
||||
|
@ -627,6 +627,17 @@ impl CpuStorage {
|
||||
WCond(pred, layout).map(t, t_l, f, f_l)
|
||||
}
|
||||
|
||||
pub(crate) fn conv1d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
_kernel: &Self,
|
||||
_kernel_l: &Layout,
|
||||
_padding: usize,
|
||||
_stride: usize,
|
||||
) -> Result<Self> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
pub(crate) fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
||||
let ids = self.as_slice::<u32>()?;
|
||||
let (vocab_size, hidden_size) = rhs_l.shape().r2()?;
|
||||
|
@ -801,6 +801,17 @@ impl CudaStorage {
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
pub(crate) fn conv1d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
_kernel: &Self,
|
||||
_kernel_l: &Layout,
|
||||
_padding: usize,
|
||||
_stride: usize,
|
||||
) -> Result<Self> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?;
|
||||
|
@ -100,6 +100,17 @@ impl CudaStorage {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub(crate) fn conv1d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
_kernel: &Self,
|
||||
_kernel_l: &Layout,
|
||||
_padding: usize,
|
||||
_stride: usize,
|
||||
) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub(crate) fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
@ -12,6 +12,14 @@ pub(crate) enum Op {
|
||||
Embedding(Tensor, Tensor),
|
||||
WhereCond(Tensor, Tensor, Tensor),
|
||||
|
||||
#[allow(dead_code)]
|
||||
Conv1D {
|
||||
arg: Tensor,
|
||||
kernel: Tensor,
|
||||
padding: usize,
|
||||
stride: usize,
|
||||
},
|
||||
|
||||
Cat(Vec<Tensor>, usize),
|
||||
|
||||
#[allow(dead_code)] // add is currently unused.
|
||||
|
@ -144,6 +144,33 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn conv1d(
|
||||
&self,
|
||||
l: &Layout,
|
||||
kernel: &Self,
|
||||
kernel_l: &Layout,
|
||||
padding: usize,
|
||||
stride: usize,
|
||||
) -> Result<Self> {
|
||||
self.same_device(kernel, "conv1d")?;
|
||||
self.same_dtype(kernel, "conv1d")?;
|
||||
match (self, &kernel) {
|
||||
(Storage::Cpu(inp), Storage::Cpu(kernel)) => {
|
||||
let s = inp.conv1d(l, kernel, kernel_l, padding, stride)?;
|
||||
Ok(Self::Cpu(s))
|
||||
}
|
||||
(Storage::Cuda(inp), Storage::Cuda(kernel)) => {
|
||||
let s = inp.conv1d(l, kernel, kernel_l, padding, stride)?;
|
||||
Ok(Self::Cuda(s))
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
op: "conv1d",
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn where_cond(
|
||||
&self,
|
||||
layout: &Layout,
|
||||
|
@ -432,6 +432,28 @@ impl Tensor {
|
||||
Ok(from_storage(storage, dims, op, false))
|
||||
}
|
||||
|
||||
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
|
||||
let storage = self.storage.conv1d(
|
||||
self.layout(),
|
||||
&kernel.storage,
|
||||
kernel.layout(),
|
||||
padding,
|
||||
stride,
|
||||
)?;
|
||||
let op = if self.track_op() || kernel.track_op() {
|
||||
Some(Op::Conv1D {
|
||||
arg: self.clone(),
|
||||
kernel: kernel.clone(),
|
||||
padding,
|
||||
stride,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let dims = self.dims();
|
||||
Ok(from_storage(storage, dims, op, false))
|
||||
}
|
||||
|
||||
pub fn matmul(&self, rhs: &Self) -> Result<Self> {
|
||||
let a_dims = self.shape().dims();
|
||||
let b_dims = rhs.shape().dims();
|
||||
|
Reference in New Issue
Block a user