Add more of the conv1d op.
candle-core/src/conv.rs (new file, +24)
@@ -0,0 +1,24 @@
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub(crate) struct ParamsConv1D {
+    pub(crate) b_size: Option<usize>,
+    pub(crate) c_out: usize,
+    pub(crate) c_in: usize,
+    pub(crate) k_size: usize,
+    pub(crate) padding: usize,
+    pub(crate) stride: usize,
+}
+
+impl ParamsConv1D {
+    pub(crate) fn l_out(&self, l_in: usize) -> usize {
+        let dilation = 1;
+        (l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
+    }
+
+    pub(crate) fn out_dims(&self, l_in: usize) -> Vec<usize> {
+        let l_out = self.l_out(l_in);
+        match self.b_size {
+            None => vec![self.c_out, l_out],
+            Some(n) => vec![n, self.c_out, l_out],
+        }
+    }
+}
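For reference, l_out above is the standard convolution output-length formula with dilation pinned to 1: l_out = (l_in + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1, using integer (floor) division. A quick worked example with made-up values (ParamsConv1D is pub(crate), so this would only run as an in-crate test):

    // Worked example (hypothetical values): l_in = 7, k_size = 3, padding = 1, stride = 2.
    // l_out = (7 + 2*1 - 1*(3 - 1) - 1) / 2 + 1 = 6 / 2 + 1 = 4
    let params = ParamsConv1D {
        b_size: Some(2),
        c_out: 8,
        c_in: 4,
        k_size: 3,
        padding: 1,
        stride: 2,
    };
    assert_eq!(params.l_out(7), 4);
    // Batched input, so out_dims is [b_size, c_out, l_out].
    assert_eq!(params.out_dims(7), vec![2, 8, 4]);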
@@ -632,8 +632,7 @@ impl CpuStorage {
         _l: &Layout,
         _kernel: &Self,
         _kernel_l: &Layout,
-        _padding: usize,
-        _stride: usize,
+        _params: &crate::conv::ParamsConv1D,
     ) -> Result<Self> {
         todo!()
     }
@@ -806,8 +806,7 @@ impl CudaStorage {
         _l: &Layout,
         _kernel: &Self,
         _kernel_l: &Layout,
-        _padding: usize,
-        _stride: usize,
+        _params: &crate::conv::ParamsConv1D,
     ) -> Result<Self> {
         todo!()
     }
@@ -105,8 +105,7 @@ impl CudaStorage {
         _l: &Layout,
         _kernel: &Self,
         _kernel_l: &Layout,
-        _padding: usize,
-        _stride: usize,
+        _params: &crate::conv::ParamsConv1D,
     ) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
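The third stub, returning Error::NotCompiledWithCudaSupport, belongs to the dummy CUDA backend that stands in when the crate is built without the cuda feature: downstream code keeps compiling against the same signatures and only fails at runtime. A minimal sketch of that feature-gating pattern, under assumed module names rather than the crate's exact wiring:

    // With the `cuda` feature, the real backend is compiled in; without it,
    // a stub module with identical signatures is aliased into its place.
    #[cfg(feature = "cuda")]
    mod cuda_backend;
    #[cfg(not(feature = "cuda"))]
    mod dummy_cuda_backend;
    #[cfg(not(feature = "cuda"))]
    use crate::dummy_cuda_backend as cuda_backend;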
@@ -1,4 +1,5 @@
 mod backprop;
+mod conv;
 mod cpu_backend;
 #[cfg(feature = "cuda")]
 mod cuda_backend;
@@ -149,18 +149,17 @@ impl Storage {
         l: &Layout,
         kernel: &Self,
         kernel_l: &Layout,
-        padding: usize,
-        stride: usize,
+        params: &crate::conv::ParamsConv1D,
     ) -> Result<Self> {
         self.same_device(kernel, "conv1d")?;
         self.same_dtype(kernel, "conv1d")?;
         match (self, &kernel) {
             (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
-                let s = inp.conv1d(l, kernel, kernel_l, padding, stride)?;
+                let s = inp.conv1d(l, kernel, kernel_l, params)?;
                 Ok(Self::Cpu(s))
             }
             (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
-                let s = inp.conv1d(l, kernel, kernel_l, padding, stride)?;
+                let s = inp.conv1d(l, kernel, kernel_l, params)?;
                 Ok(Self::Cuda(s))
             }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
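The match above dispatches on the (input, kernel) storage pair after the same_device and same_dtype checks, so a mixed-device call fails instead of silently computing on one side. A hedged sketch of that failure mode, assuming candle's Tensor::zeros and Device::new_cuda constructors (inside a Result-returning test):

    use candle::{DType, Device, Tensor};

    // Hypothetical: a CPU input against a CUDA kernel must error out of
    // conv1d's same-device check rather than reach either backend.
    let x = Tensor::zeros((4, 10), DType::F32, &Device::Cpu)?;
    let k = Tensor::zeros((8, 4, 3), DType::F32, &Device::new_cuda(0)?)?;
    assert!(x.conv1d(&k, 1, 1).is_err());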
@@ -433,13 +433,26 @@ impl Tensor {
     }

     pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
-        let storage = self.storage.conv1d(
-            self.layout(),
-            &kernel.storage,
-            kernel.layout(),
+        let (c_out, c_in_k, k_size) = kernel.shape().r3()?;
+        let (b_size, c_in, l_in) = match *self.dims() {
+            [b_size, c_in, l_in] => (Some(b_size), c_in, l_in),
+            [c_in, l_in] => (None, c_in, l_in),
+            _ => todo!("proper error message"),
+        };
+        if c_in != c_in_k {
+            todo!("proper error message")
+        }
+        let params = crate::conv::ParamsConv1D {
+            b_size,
+            c_out,
+            c_in,
+            k_size,
             padding,
             stride,
-        )?;
+        };
+        let storage =
+            self.storage
+                .conv1d(self.layout(), &kernel.storage, kernel.layout(), &params)?;
         let op = if self.track_op() || kernel.track_op() {
             Some(Op::Conv1D {
                 arg: self.clone(),
@@ -450,8 +463,8 @@ impl Tensor {
         } else {
             None
         };
-        let dims = self.dims();
-        Ok(from_storage(storage, dims, op, false))
+        let out_dims = params.out_dims(l_in);
+        Ok(from_storage(storage, out_dims, op, false))
     }

     pub fn matmul(&self, rhs: &Self) -> Result<Self> {
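Taken together, Tensor::conv1d now validates the kernel as (c_out, c_in, k_size), accepts either a batched (b_size, c_in, l_in) or unbatched (c_in, l_in) input, and sizes its output through ParamsConv1D::out_dims. A usage sketch under that API with made-up shapes; note that the CPU kernel is still todo!() at this commit, so the call would panic at runtime rather than return:

    use candle::{DType, Device, Tensor};

    fn demo() -> candle::Result<()> {
        let dev = Device::Cpu;
        // Input: batch 2, 4 input channels, length 10.
        let x = Tensor::zeros((2, 4, 10), DType::F32, &dev)?;
        // Kernel: 8 output channels, 4 input channels, kernel size 3.
        let k = Tensor::zeros((8, 4, 3), DType::F32, &dev)?;
        // With padding = 1, stride = 1: l_out = (10 + 2 - 2 - 1) / 1 + 1 = 10.
        let y = x.conv1d(&k, 1, 1)?;
        assert_eq!(y.dims(), &[2, 8, 10]);
        Ok(())
    }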
@@ -236,8 +236,7 @@ impl Conv1D {
     fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
         let (bsize, _, _) = x.shape().r3()?;
         let w = self.weight.broadcast_left(bsize)?.t()?;
-        // TODO: Add the conv1d operation
-        let x = x.matmul(&w)?;
+        let x = x.conv1d(&w, self.config.padding, self.config.stride)?;
         match &self.bias {
             None => Ok(x),
             Some(bias) => x.broadcast_add(bias),