Adding the convolutions (1d + 2d) to candle on metal.

This commit is contained in:
Nicolas Patry
2023-12-21 10:39:24 +01:00
parent 9fc210fae8
commit 10d94659c3
5 changed files with 396 additions and 84 deletions

View File

@ -13,6 +13,7 @@ const BINARY: &str = include_str!("binary.metal");
const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal");
const REDUCE: &str = include_str!("reduce.metal");
const CONV: &str = include_str!("conv.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
/// Most kernels apply similarly across the tensors
@ -115,6 +116,7 @@ pub enum Source {
Cast,
Reduce,
Mfa,
Conv,
}
macro_rules! ops{
@ -225,6 +227,7 @@ impl Kernels {
Source::Indexing => INDEXING,
Source::Cast => CAST,
Source::Reduce => REDUCE,
Source::Conv => CONV,
Source::Mfa => panic!("Invalid lib"),
}
}
@ -1418,6 +1421,103 @@ pub fn call_gemm(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_im2col1d_strided(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
strides: &[usize],
(k_size, stride, padding, dilation): (usize, usize, usize, usize),
input: &Buffer,
input_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
let l_out = (shape[2] + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1;
let dst_el = shape[0] * l_out * shape[1] * k_size;
let encoder = command_buffer.new_compute_command_encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
dst_el,
l_out,
k_size,
stride,
padding,
dilation,
shape,
strides,
(input, input_offset),
output
)
);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_im2col_strided(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
strides: &[usize],
(h_k, w_k, stride, padding, dilation): (usize, usize, usize, usize, usize),
input: &Buffer,
input_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
let h = shape[2];
let w = shape[3];
let h_out = (h + 2 * padding - dilation * (h_k - 1) - 1) / stride + 1;
let w_out = (w + 2 * padding - dilation * (w_k - 1) - 1) / stride + 1;
let dst_el = shape[0] * h_out * w_out * shape[1] * h_k * w_k;
let encoder = command_buffer.new_compute_command_encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
dst_el,
h_out,
w_out,
h_k,
w_k,
stride,
padding,
dilation,
shape,
strides,
(input, input_offset),
output
)
);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
fn divide(m: usize, b: usize) -> NSUInteger {
((m + b - 1) / b) as NSUInteger
}