Merge pull request #1461 from huggingface/metal-conv

Adding the convolutions (1d + 2d) to candle on metal.
This commit is contained in:
Nicolas Patry
2023-12-25 12:48:09 +01:00
committed by GitHub
5 changed files with 399 additions and 86 deletions

View File

@ -782,12 +782,72 @@ impl BackendStorage for MetalStorage {
fn conv1d( fn conv1d(
&self, &self,
_l: &Layout, layout: &Layout,
_kernel: &Self, kernel: &Self,
_kernel_l: &Layout, kernel_l: &Layout,
_params: &ParamsConv1D, params: &ParamsConv1D,
) -> Result<Self> { ) -> Result<Self> {
crate::bail!("conv1d metal") let device = self.device().clone();
let shape = layout.shape();
let dims = shape.dims();
let strides = layout.stride();
let stride = params.stride;
let dilation = params.dilation;
let padding = params.padding;
let k_size = params.k_size;
let l_out = (dims[2] + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1;
let dst_el = dims[0] * l_out * dims[1] * k_size;
let dst = self
.device
.new_buffer(dst_el, self.dtype, "conv1d_im2col")?;
let command_buffer = self.device.command_buffer()?;
let name = match self.dtype {
DType::F32 => "im2col1d_f32",
dtype => crate::bail!("conv1d metal {dtype:?} not implemented"),
};
candle_metal_kernels::call_im2col1d_strided(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
layout.shape().dims(),
strides,
(k_size, stride, padding, dilation),
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
&dst,
)
.map_err(MetalError::from)?;
let col = Self {
buffer: dst,
device,
dtype: self.dtype,
};
let l_out = params.l_out();
let b = params.b_size;
let n = params.c_out;
let k = params.k_size * params.c_in;
let m = l_out;
let col_l = Layout::contiguous((b, m, k));
let res = if kernel_l.is_contiguous() {
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
} }
fn conv_transpose1d( fn conv_transpose1d(
@ -802,12 +862,79 @@ impl BackendStorage for MetalStorage {
fn conv2d( fn conv2d(
&self, &self,
_l: &Layout, layout: &Layout,
_kernel: &Self, kernel: &Self,
_kernel_l: &Layout, kernel_l: &Layout,
_params: &ParamsConv2D, params: &ParamsConv2D,
) -> Result<Self> { ) -> Result<Self> {
crate::bail!("conv2d metal") let device = self.device().clone();
let shape = layout.shape();
let dims = shape.dims();
let stride = params.stride;
let dilation = params.dilation;
let padding = params.padding;
let h_k = params.k_h;
let w_k = params.k_w;
let h = dims[2];
let w = dims[3];
let h_out = (h + 2 * padding - dilation * (h_k - 1) - 1) / stride + 1;
let w_out = (w + 2 * padding - dilation * (w_k - 1) - 1) / stride + 1;
let dst_el = dims[0] * h_out * w_out * dims[1] * h_k * w_k;
let dst = self
.device
.new_buffer(dst_el, self.dtype, "conv2d_im2col")?;
let command_buffer = self.device.command_buffer()?;
let name = match self.dtype {
DType::F32 => "im2col_f32",
dtype => crate::bail!("conv1d metal {dtype:?} not implemented"),
};
candle_metal_kernels::call_im2col_strided(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
layout.shape().dims(),
layout.stride(),
(h_k, w_k, stride, padding, dilation),
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
&dst,
)
.map_err(MetalError::from)?;
let col = Self {
buffer: dst,
device,
dtype: self.dtype,
};
let h_out = params.out_h();
let w_out = params.out_w();
let b = params.b_size;
let n = params.c_out;
let k = params.k_h * params.k_w * params.c_in;
let m = h_out * w_out;
let col_l = Layout::contiguous((b, m, k));
let res = if kernel_l.is_contiguous() {
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, h_out, w_out, n))
.transpose(1, 2)?
.transpose(1, 3)?;
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
} }
fn conv_transpose2d( fn conv_transpose2d(

View File

@ -0,0 +1,153 @@
template <typename T>
METAL_FUNC void im2col(
constant size_t &dst_numel,
constant size_t &h_out,
constant size_t &w_out,
constant size_t &h_k,
constant size_t &w_k,
constant size_t &stride,
constant size_t &padding,
constant size_t &dilation,
constant size_t *src_dims,
constant size_t *src_strides,
device const T *src,
device T *dst,
uint tid [[ thread_position_in_grid ]]
) {
// dst: (b_size, h_out, w_out, c_in, h_k, w_k)
// src: (b_size, c_in, h_in, w_in)
if (tid >= dst_numel) {
return;
}
const size_t b_in = src_dims[0];
const size_t c_in = src_dims[1];
const size_t h_in = src_dims[2];
const size_t w_in = src_dims[3];
const size_t dst_s4 = w_k;
const size_t dst_s3 = h_k * dst_s4;
const size_t dst_s2 = c_in * dst_s3;
const size_t dst_s1 = w_out * dst_s2;
const size_t dst_s0 = h_out * dst_s1;
size_t tmp_tid = tid;
const size_t b_idx = tmp_tid / dst_s0;
tmp_tid -= b_idx * dst_s0;
const size_t h_idx = tmp_tid / dst_s1;
tmp_tid -= h_idx * dst_s1;
const size_t w_idx = tmp_tid / dst_s2;
tmp_tid -= w_idx * dst_s2;
const size_t c_idx = tmp_tid / dst_s3;
tmp_tid -= c_idx * dst_s3;
const size_t h_k_idx = tmp_tid / dst_s4;
tmp_tid -= h_k_idx * dst_s4;
const size_t w_k_idx = tmp_tid;
size_t src_h_idx = h_idx * stride + h_k_idx * dilation;
size_t src_w_idx = w_idx * stride + w_k_idx * dilation;
if (src_h_idx < padding || src_h_idx >= h_in + padding) {
dst[tid] = static_cast<T>(0);
}
else if (src_w_idx < padding || src_w_idx >= w_in + padding) {
dst[tid] = static_cast<T>(0);
}
else {
src_h_idx -= padding;
src_w_idx -= padding;
const size_t src_i =
b_idx * src_strides[0]
+ c_idx * src_strides[1]
+ src_h_idx * src_strides[2]
+ src_w_idx * src_strides[3];
dst[tid] = src[src_i];
}
}
template <typename T>
METAL_FUNC void im2col1d(
constant size_t &dst_numel,
constant size_t &l_out,
constant size_t &l_k,
constant size_t &stride,
constant size_t &padding,
constant size_t &dilation,
constant size_t *src_dims,
constant size_t *src_strides,
device const T *src,
device T *dst,
uint tid [[ thread_position_in_grid ]]
) {
// dst: (b_size, l_out, c_in, l_k)
// src: (b_size, c_in, l_in)
if (tid >= dst_numel) {
return;
}
const size_t b_in = src_dims[0];
const size_t c_in = src_dims[1];
const size_t l_in = src_dims[2];
const size_t dst_s2 = l_k;
const size_t dst_s1 = c_in * dst_s2;
const size_t dst_s0 = l_out * dst_s1;
size_t tmp_dst_i = tid;
const size_t b_idx = tmp_dst_i / dst_s0;
tmp_dst_i -= b_idx * dst_s0;
const size_t l_idx = tmp_dst_i / dst_s1;
tmp_dst_i -= l_idx * dst_s1;
const size_t c_idx = tmp_dst_i / dst_s2;
tmp_dst_i -= c_idx * dst_s2;
const size_t l_k_idx = tmp_dst_i;
size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
if (src_l_idx < padding || src_l_idx >= l_in + padding) {
dst[tid] = static_cast<T>(0);
}
else {
src_l_idx -= padding;
const size_t src_i = b_idx * src_strides[0] + c_idx * src_strides[1] + src_l_idx * src_strides[2];
dst[tid] = src[src_i];
}
}
#define IM2COL_OP(T, FN_NAME) \
kernel void FN_NAME( \
constant size_t &dst_numel, \
constant size_t &h_out, \
constant size_t &w_out, \
constant size_t &h_k, \
constant size_t &w_k, \
constant size_t &stride, \
constant size_t &padding, \
constant size_t &dilation, \
constant size_t *src_dims, \
constant size_t *src_strides, \
device const T *src, \
device T *dst, \
uint tid [[ thread_position_in_grid ]] \
) { \
im2col<T>(dst_numel, h_out, w_out, h_k, w_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \
} \
#define IM2COL1D_OP(T, FN_NAME) \
kernel void FN_NAME( \
constant size_t &dst_numel, \
constant size_t &l_out, \
constant size_t &l_k, \
constant size_t &stride, \
constant size_t &padding, \
constant size_t &dilation, \
constant size_t *src_dims, \
constant size_t *src_strides, \
device const T *src, \
device T *dst, \
uint tid [[ thread_position_in_grid ]] \
) { \
im2col1d<T>(dst_numel, l_out, l_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \
} \
IM2COL_OP(float, im2col_f32)
IM2COL_OP(uint8_t, im2col_u8)
IM2COL_OP(uint32_t, im2col_u32)
IM2COL1D_OP(float, im2col1d_f32)
IM2COL1D_OP(uint8_t, im2col1d_u8)
IM2COL1D_OP(uint32_t, im2col1d_u32)

View File

@ -13,6 +13,7 @@ const BINARY: &str = include_str!("binary.metal");
const TERNARY: &str = include_str!("ternary.metal"); const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal"); const CAST: &str = include_str!("cast.metal");
const REDUCE: &str = include_str!("reduce.metal"); const REDUCE: &str = include_str!("reduce.metal");
const CONV: &str = include_str!("conv.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib"); const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
/// Most kernels apply similarly across the tensors /// Most kernels apply similarly across the tensors
@ -115,6 +116,7 @@ pub enum Source {
Cast, Cast,
Reduce, Reduce,
Mfa, Mfa,
Conv,
} }
macro_rules! ops{ macro_rules! ops{
@ -225,6 +227,7 @@ impl Kernels {
Source::Indexing => INDEXING, Source::Indexing => INDEXING,
Source::Cast => CAST, Source::Cast => CAST,
Source::Reduce => REDUCE, Source::Reduce => REDUCE,
Source::Conv => CONV,
Source::Mfa => panic!("Invalid lib"), Source::Mfa => panic!("Invalid lib"),
} }
} }
@ -1298,7 +1301,7 @@ pub fn call_gemm(
let fused_activation = false; let fused_activation = false;
let fused_bias = false; let fused_bias = false;
let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 { let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
let m_simd = 16; let m_simd = 8;
let n_simd = 8; let n_simd = 8;
let k_simd = 64; let k_simd = 64;
let m_splits = 1; let m_splits = 1;
@ -1307,7 +1310,7 @@ pub fn call_gemm(
} else { } else {
let m_simd = 40; let m_simd = 40;
let n_simd = 40; let n_simd = 40;
let k_simd = 8; let k_simd = 32;
let m_splits = 1; let m_splits = 1;
let n_splits = 1; let n_splits = 1;
(m_simd, n_simd, k_simd, m_splits, n_splits) (m_simd, n_simd, k_simd, m_splits, n_splits)
@ -1418,6 +1421,103 @@ pub fn call_gemm(
Ok(()) Ok(())
} }
#[allow(clippy::too_many_arguments)]
pub fn call_im2col1d_strided(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
strides: &[usize],
(k_size, stride, padding, dilation): (usize, usize, usize, usize),
input: &Buffer,
input_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
let l_out = (shape[2] + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1;
let dst_el = shape[0] * l_out * shape[1] * k_size;
let encoder = command_buffer.new_compute_command_encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
dst_el,
l_out,
k_size,
stride,
padding,
dilation,
shape,
strides,
(input, input_offset),
output
)
);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_im2col_strided(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
strides: &[usize],
(h_k, w_k, stride, padding, dilation): (usize, usize, usize, usize, usize),
input: &Buffer,
input_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
let h = shape[2];
let w = shape[3];
let h_out = (h + 2 * padding - dilation * (h_k - 1) - 1) / stride + 1;
let w_out = (w + 2 * padding - dilation * (w_k - 1) - 1) / stride + 1;
let dst_el = shape[0] * h_out * w_out * shape[1] * h_k * w_k;
let encoder = command_buffer.new_compute_command_encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
dst_el,
h_out,
w_out,
h_k,
w_k,
stride,
padding,
dilation,
shape,
strides,
(input, input_offset),
output
)
);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
fn divide(m: usize, b: usize) -> NSUInteger { fn divide(m: usize, b: usize) -> NSUInteger {
((m + b - 1) / b) as NSUInteger ((m + b - 1) / b) as NSUInteger
} }

View File

@ -1,6 +1,6 @@
use super::*; use super::*;
use half::{bf16, f16}; use half::{bf16, f16};
use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger}; use metal::{Device, MTLResourceOptions};
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> { fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
let ptr = buffer.contents() as *const T; let ptr = buffer.contents() as *const T;
@ -485,73 +485,6 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
read_to_vec(&dst_buffer, dst_el) read_to_vec(&dst_buffer, dst_el)
} }
#[test]
fn index_add() {
let device = Device::system_default().expect("no device found");
let options = CompileOptions::new();
let library = device.new_library_with_source(INDEXING, &options).unwrap();
let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
let right = [1.0f32; 15];
let index = [0u32, 4, 2];
let ids_dim_size = index.len() as u32;
let dst_dim_size: u32 = 15;
let left_size: u32 = 3;
let right_size: u32 = 3;
let function = library.get_function("ia_u32_f32", None).unwrap();
let pipeline = device
.new_compute_pipeline_state_with_function(&function)
.unwrap();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
let index_buffer = new_buffer(&device, &index);
let inputs_buffer = new_buffer(&device, &left);
let outputs_buffer = new_buffer(&device, &right);
set_params!(
encoder,
(
&index_buffer,
&inputs_buffer,
&outputs_buffer,
ids_dim_size,
left_size,
dst_dim_size,
right_size
)
);
let grid_size = MTLSize {
width: right.len() as NSUInteger,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width: pipeline.max_total_threads_per_threadgroup(),
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(grid_size, thread_group_size);
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
let expected = vec![
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
];
let result: Vec<f32> = read_to_vec(&outputs_buffer, right.len());
assert_eq!(result, expected);
}
#[test] #[test]
fn cos_f16() { fn cos_f16() {
let v: Vec<f16> = [1.0f32, 2.0, 3.0] let v: Vec<f16> = [1.0f32, 2.0, 3.0]

View File

@ -64,12 +64,12 @@ kernel void FN_NAME( \
constant size_t &dim, \ constant size_t &dim, \
device const TYPENAME *input, \ device const TYPENAME *input, \
device TYPENAME *output, \ device TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \ uint tid [[ thread_position_in_grid ]] \
) { \ ) { \
if (thread_position_in_grid >= dim) { \ if (tid >= dim) { \
return; \ return; \
} \ } \
output[thread_position_in_grid] = TYPENAME(FN(float(input[thread_position_in_grid]))); \ output[tid] = TYPENAME(FN(float(input[tid]))); \
}\ }\
kernel void FN_NAME_STRIDED( \ kernel void FN_NAME_STRIDED( \
constant size_t &dim, \ constant size_t &dim, \
@ -78,12 +78,12 @@ kernel void FN_NAME_STRIDED( \
constant size_t *strides, \ constant size_t *strides, \
device const TYPENAME *input, \ device const TYPENAME *input, \
device TYPENAME *output, \ device TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \ uint tid [[ thread_position_in_grid ]] \
) { \ ) { \
if (thread_position_in_grid >= dim) { \ if (tid >= dim) { \
return; \ return; \
} \ } \
output[thread_position_in_grid] = TYPENAME(FN(float(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)]))); \ output[tid] = TYPENAME(FN(float(input[get_strided_index(tid, num_dims, dims, strides)]))); \
} }
#define UNARY_OP(NAME) \ #define UNARY_OP(NAME) \