Merge pull request #1461 from huggingface/metal-conv

Adding the convolutions (1d + 2d) to candle on metal.
This commit is contained in:
Nicolas Patry
2023-12-25 12:48:09 +01:00
committed by GitHub
5 changed files with 399 additions and 86 deletions

View File

@ -782,12 +782,72 @@ impl BackendStorage for MetalStorage {
fn conv1d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &ParamsConv1D,
layout: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &ParamsConv1D,
) -> Result<Self> {
crate::bail!("conv1d metal")
let device = self.device().clone();
let shape = layout.shape();
let dims = shape.dims();
let strides = layout.stride();
let stride = params.stride;
let dilation = params.dilation;
let padding = params.padding;
let k_size = params.k_size;
let l_out = (dims[2] + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1;
let dst_el = dims[0] * l_out * dims[1] * k_size;
let dst = self
.device
.new_buffer(dst_el, self.dtype, "conv1d_im2col")?;
let command_buffer = self.device.command_buffer()?;
let name = match self.dtype {
DType::F32 => "im2col1d_f32",
dtype => crate::bail!("conv1d metal {dtype:?} not implemented"),
};
candle_metal_kernels::call_im2col1d_strided(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
layout.shape().dims(),
strides,
(k_size, stride, padding, dilation),
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
&dst,
)
.map_err(MetalError::from)?;
let col = Self {
buffer: dst,
device,
dtype: self.dtype,
};
let l_out = params.l_out();
let b = params.b_size;
let n = params.c_out;
let k = params.k_size * params.c_in;
let m = l_out;
let col_l = Layout::contiguous((b, m, k));
let res = if kernel_l.is_contiguous() {
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
}
fn conv_transpose1d(
@ -802,12 +862,79 @@ impl BackendStorage for MetalStorage {
fn conv2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &ParamsConv2D,
layout: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &ParamsConv2D,
) -> Result<Self> {
crate::bail!("conv2d metal")
let device = self.device().clone();
let shape = layout.shape();
let dims = shape.dims();
let stride = params.stride;
let dilation = params.dilation;
let padding = params.padding;
let h_k = params.k_h;
let w_k = params.k_w;
let h = dims[2];
let w = dims[3];
let h_out = (h + 2 * padding - dilation * (h_k - 1) - 1) / stride + 1;
let w_out = (w + 2 * padding - dilation * (w_k - 1) - 1) / stride + 1;
let dst_el = dims[0] * h_out * w_out * dims[1] * h_k * w_k;
let dst = self
.device
.new_buffer(dst_el, self.dtype, "conv2d_im2col")?;
let command_buffer = self.device.command_buffer()?;
let name = match self.dtype {
DType::F32 => "im2col_f32",
dtype => crate::bail!("conv1d metal {dtype:?} not implemented"),
};
candle_metal_kernels::call_im2col_strided(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
layout.shape().dims(),
layout.stride(),
(h_k, w_k, stride, padding, dilation),
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
&dst,
)
.map_err(MetalError::from)?;
let col = Self {
buffer: dst,
device,
dtype: self.dtype,
};
let h_out = params.out_h();
let w_out = params.out_w();
let b = params.b_size;
let n = params.c_out;
let k = params.k_h * params.k_w * params.c_in;
let m = h_out * w_out;
let col_l = Layout::contiguous((b, m, k));
let res = if kernel_l.is_contiguous() {
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, h_out, w_out, n))
.transpose(1, 2)?
.transpose(1, 3)?;
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
}
fn conv_transpose2d(