mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add the conv-transpose1d op. (#1251)
* Skeleton structure for conv-transpose1d. * CPU implementation for conv-transpose1d.
This commit is contained in:
@ -279,6 +279,33 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn conv_transpose1d(
|
||||
&self,
|
||||
l: &Layout,
|
||||
kernel: &Self,
|
||||
kernel_l: &Layout,
|
||||
params: &crate::conv::ParamsConvTranspose1D,
|
||||
) -> Result<Self> {
|
||||
self.same_device(kernel, "conv-transpose1d")?;
|
||||
self.same_dtype(kernel, "conv-transpose1d")?;
|
||||
match (self, &kernel) {
|
||||
(Storage::Cpu(inp), Storage::Cpu(kernel)) => {
|
||||
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
|
||||
Ok(Self::Cpu(s))
|
||||
}
|
||||
(Storage::Cuda(inp), Storage::Cuda(kernel)) => {
|
||||
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
|
||||
Ok(Self::Cuda(s))
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
op: "conv-transpose1d",
|
||||
}
|
||||
.bt()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn conv2d(
|
||||
&self,
|
||||
l: &Layout,
|
||||
|
Reference in New Issue
Block a user