mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Sketch the conv1d op.
This commit is contained in:
@ -33,7 +33,12 @@ impl Tensor {
|
|||||||
track_grad |= tg;
|
track_grad |= tg;
|
||||||
nodes
|
nodes
|
||||||
}
|
}
|
||||||
Op::Add(lhs, rhs)
|
Op::Conv1D {
|
||||||
|
arg: lhs,
|
||||||
|
kernel: rhs,
|
||||||
|
..
|
||||||
|
}
|
||||||
|
| Op::Add(lhs, rhs)
|
||||||
| Op::Mul(lhs, rhs)
|
| Op::Mul(lhs, rhs)
|
||||||
| Op::Sub(lhs, rhs)
|
| Op::Sub(lhs, rhs)
|
||||||
| Op::Div(lhs, rhs)
|
| Op::Div(lhs, rhs)
|
||||||
@ -147,6 +152,7 @@ impl Tensor {
|
|||||||
let f_grad = pred.where_cond(&zeros, &grad)?;
|
let f_grad = pred.where_cond(&zeros, &grad)?;
|
||||||
*f_sum_grad = f_sum_grad.add(&f_grad)?;
|
*f_sum_grad = f_sum_grad.add(&f_grad)?;
|
||||||
}
|
}
|
||||||
|
Op::Conv1D { .. } => return Err(Error::BackwardNotSupported { op: "conv1d" }),
|
||||||
Op::Embedding(_lhs, _rhs) => {
|
Op::Embedding(_lhs, _rhs) => {
|
||||||
return Err(Error::BackwardNotSupported { op: "embedding" })
|
return Err(Error::BackwardNotSupported { op: "embedding" })
|
||||||
}
|
}
|
||||||
|
@ -627,6 +627,17 @@ impl CpuStorage {
|
|||||||
WCond(pred, layout).map(t, t_l, f, f_l)
|
WCond(pred, layout).map(t, t_l, f, f_l)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn conv1d(
|
||||||
|
&self,
|
||||||
|
_l: &Layout,
|
||||||
|
_kernel: &Self,
|
||||||
|
_kernel_l: &Layout,
|
||||||
|
_padding: usize,
|
||||||
|
_stride: usize,
|
||||||
|
) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
pub(crate) fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
||||||
let ids = self.as_slice::<u32>()?;
|
let ids = self.as_slice::<u32>()?;
|
||||||
let (vocab_size, hidden_size) = rhs_l.shape().r2()?;
|
let (vocab_size, hidden_size) = rhs_l.shape().r2()?;
|
||||||
|
@ -801,6 +801,17 @@ impl CudaStorage {
|
|||||||
Ok(Self { slice, device })
|
Ok(Self { slice, device })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn conv1d(
|
||||||
|
&self,
|
||||||
|
_l: &Layout,
|
||||||
|
_kernel: &Self,
|
||||||
|
_kernel_l: &Layout,
|
||||||
|
_padding: usize,
|
||||||
|
_stride: usize,
|
||||||
|
) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
||||||
let device = self.device().clone();
|
let device = self.device().clone();
|
||||||
let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?;
|
let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?;
|
||||||
|
@ -100,6 +100,17 @@ impl CudaStorage {
|
|||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn conv1d(
|
||||||
|
&self,
|
||||||
|
_l: &Layout,
|
||||||
|
_kernel: &Self,
|
||||||
|
_kernel_l: &Layout,
|
||||||
|
_padding: usize,
|
||||||
|
_stride: usize,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
|
pub(crate) fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
@ -12,6 +12,14 @@ pub(crate) enum Op {
|
|||||||
Embedding(Tensor, Tensor),
|
Embedding(Tensor, Tensor),
|
||||||
WhereCond(Tensor, Tensor, Tensor),
|
WhereCond(Tensor, Tensor, Tensor),
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
Conv1D {
|
||||||
|
arg: Tensor,
|
||||||
|
kernel: Tensor,
|
||||||
|
padding: usize,
|
||||||
|
stride: usize,
|
||||||
|
},
|
||||||
|
|
||||||
Cat(Vec<Tensor>, usize),
|
Cat(Vec<Tensor>, usize),
|
||||||
|
|
||||||
#[allow(dead_code)] // add is currently unused.
|
#[allow(dead_code)] // add is currently unused.
|
||||||
|
@ -144,6 +144,33 @@ impl Storage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn conv1d(
|
||||||
|
&self,
|
||||||
|
l: &Layout,
|
||||||
|
kernel: &Self,
|
||||||
|
kernel_l: &Layout,
|
||||||
|
padding: usize,
|
||||||
|
stride: usize,
|
||||||
|
) -> Result<Self> {
|
||||||
|
self.same_device(kernel, "conv1d")?;
|
||||||
|
self.same_dtype(kernel, "conv1d")?;
|
||||||
|
match (self, &kernel) {
|
||||||
|
(Storage::Cpu(inp), Storage::Cpu(kernel)) => {
|
||||||
|
let s = inp.conv1d(l, kernel, kernel_l, padding, stride)?;
|
||||||
|
Ok(Self::Cpu(s))
|
||||||
|
}
|
||||||
|
(Storage::Cuda(inp), Storage::Cuda(kernel)) => {
|
||||||
|
let s = inp.conv1d(l, kernel, kernel_l, padding, stride)?;
|
||||||
|
Ok(Self::Cuda(s))
|
||||||
|
}
|
||||||
|
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||||
|
lhs: lhs.device().location(),
|
||||||
|
rhs: rhs.device().location(),
|
||||||
|
op: "conv1d",
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn where_cond(
|
pub(crate) fn where_cond(
|
||||||
&self,
|
&self,
|
||||||
layout: &Layout,
|
layout: &Layout,
|
||||||
|
@ -432,6 +432,28 @@ impl Tensor {
|
|||||||
Ok(from_storage(storage, dims, op, false))
|
Ok(from_storage(storage, dims, op, false))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
|
||||||
|
let storage = self.storage.conv1d(
|
||||||
|
self.layout(),
|
||||||
|
&kernel.storage,
|
||||||
|
kernel.layout(),
|
||||||
|
padding,
|
||||||
|
stride,
|
||||||
|
)?;
|
||||||
|
let op = if self.track_op() || kernel.track_op() {
|
||||||
|
Some(Op::Conv1D {
|
||||||
|
arg: self.clone(),
|
||||||
|
kernel: kernel.clone(),
|
||||||
|
padding,
|
||||||
|
stride,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let dims = self.dims();
|
||||||
|
Ok(from_storage(storage, dims, op, false))
|
||||||
|
}
|
||||||
|
|
||||||
pub fn matmul(&self, rhs: &Self) -> Result<Self> {
|
pub fn matmul(&self, rhs: &Self) -> Result<Self> {
|
||||||
let a_dims = self.shape().dims();
|
let a_dims = self.shape().dims();
|
||||||
let b_dims = rhs.shape().dims();
|
let b_dims = rhs.shape().dims();
|
||||||
|
Reference in New Issue
Block a user