mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Merge pull request #1461 from huggingface/metal-conv
Adding the convolutions (1d + 2d) to candle on metal.
This commit is contained in:
@ -782,12 +782,72 @@ impl BackendStorage for MetalStorage {
|
||||
|
||||
fn conv1d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
_kernel: &Self,
|
||||
_kernel_l: &Layout,
|
||||
_params: &ParamsConv1D,
|
||||
layout: &Layout,
|
||||
kernel: &Self,
|
||||
kernel_l: &Layout,
|
||||
params: &ParamsConv1D,
|
||||
) -> Result<Self> {
|
||||
crate::bail!("conv1d metal")
|
||||
let device = self.device().clone();
|
||||
let shape = layout.shape();
|
||||
let dims = shape.dims();
|
||||
let strides = layout.stride();
|
||||
|
||||
let stride = params.stride;
|
||||
let dilation = params.dilation;
|
||||
let padding = params.padding;
|
||||
let k_size = params.k_size;
|
||||
let l_out = (dims[2] + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1;
|
||||
let dst_el = dims[0] * l_out * dims[1] * k_size;
|
||||
let dst = self
|
||||
.device
|
||||
.new_buffer(dst_el, self.dtype, "conv1d_im2col")?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "im2col1d_f32",
|
||||
dtype => crate::bail!("conv1d metal {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_im2col1d_strided(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
name,
|
||||
layout.shape().dims(),
|
||||
strides,
|
||||
(k_size, stride, padding, dilation),
|
||||
&self.buffer,
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
&dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
let col = Self {
|
||||
buffer: dst,
|
||||
device,
|
||||
dtype: self.dtype,
|
||||
};
|
||||
let l_out = params.l_out();
|
||||
let b = params.b_size;
|
||||
let n = params.c_out;
|
||||
let k = params.k_size * params.c_in;
|
||||
let m = l_out;
|
||||
let col_l = Layout::contiguous((b, m, k));
|
||||
let res = if kernel_l.is_contiguous() {
|
||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||
.transpose(1, 2)?
|
||||
.broadcast_as((b, k, n))?;
|
||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||
} else {
|
||||
// Make the kernel contiguous if not already the case.
|
||||
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
|
||||
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||
.transpose(1, 2)?
|
||||
.broadcast_as((b, k, n))?;
|
||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||
};
|
||||
let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
|
||||
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
|
||||
res.copy_strided_src(&mut res_t, 0, &res_l)?;
|
||||
Ok(res_t)
|
||||
}
|
||||
|
||||
fn conv_transpose1d(
|
||||
@ -802,12 +862,79 @@ impl BackendStorage for MetalStorage {
|
||||
|
||||
fn conv2d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
_kernel: &Self,
|
||||
_kernel_l: &Layout,
|
||||
_params: &ParamsConv2D,
|
||||
layout: &Layout,
|
||||
kernel: &Self,
|
||||
kernel_l: &Layout,
|
||||
params: &ParamsConv2D,
|
||||
) -> Result<Self> {
|
||||
crate::bail!("conv2d metal")
|
||||
let device = self.device().clone();
|
||||
let shape = layout.shape();
|
||||
let dims = shape.dims();
|
||||
|
||||
let stride = params.stride;
|
||||
let dilation = params.dilation;
|
||||
let padding = params.padding;
|
||||
let h_k = params.k_h;
|
||||
let w_k = params.k_w;
|
||||
let h = dims[2];
|
||||
let w = dims[3];
|
||||
let h_out = (h + 2 * padding - dilation * (h_k - 1) - 1) / stride + 1;
|
||||
let w_out = (w + 2 * padding - dilation * (w_k - 1) - 1) / stride + 1;
|
||||
let dst_el = dims[0] * h_out * w_out * dims[1] * h_k * w_k;
|
||||
|
||||
let dst = self
|
||||
.device
|
||||
.new_buffer(dst_el, self.dtype, "conv2d_im2col")?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "im2col_f32",
|
||||
dtype => crate::bail!("conv1d metal {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_im2col_strided(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
name,
|
||||
layout.shape().dims(),
|
||||
layout.stride(),
|
||||
(h_k, w_k, stride, padding, dilation),
|
||||
&self.buffer,
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
&dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
let col = Self {
|
||||
buffer: dst,
|
||||
device,
|
||||
dtype: self.dtype,
|
||||
};
|
||||
let h_out = params.out_h();
|
||||
let w_out = params.out_w();
|
||||
let b = params.b_size;
|
||||
let n = params.c_out;
|
||||
let k = params.k_h * params.k_w * params.c_in;
|
||||
let m = h_out * w_out;
|
||||
let col_l = Layout::contiguous((b, m, k));
|
||||
let res = if kernel_l.is_contiguous() {
|
||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||
.transpose(1, 2)?
|
||||
.broadcast_as((b, k, n))?;
|
||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||
} else {
|
||||
// Make the kernel contiguous if not already the case.
|
||||
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
|
||||
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||
.transpose(1, 2)?
|
||||
.broadcast_as((b, k, n))?;
|
||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||
};
|
||||
let res_l = Layout::contiguous((b, h_out, w_out, n))
|
||||
.transpose(1, 2)?
|
||||
.transpose(1, 3)?;
|
||||
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
|
||||
res.copy_strided_src(&mut res_t, 0, &res_l)?;
|
||||
Ok(res_t)
|
||||
}
|
||||
|
||||
fn conv_transpose2d(
|
||||
|
Reference in New Issue
Block a user