mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 20:09:50 +00:00
Adding the convolutions (1d + 2d) to candle on metal.
This commit is contained in:
@ -13,6 +13,7 @@ const BINARY: &str = include_str!("binary.metal");
|
||||
// Metal shader sources embedded into the binary as text at compile time via `include_str!`.
const TERNARY: &str = include_str!("ternary.metal");
|
||||
const CAST: &str = include_str!("cast.metal");
|
||||
const REDUCE: &str = include_str!("reduce.metal");
|
||||
// Source text of `conv.metal`, returned for `Source::Conv` when a library source is looked up.
const CONV: &str = include_str!("conv.metal");
|
||||
// Unlike the entries above, this one is embedded as raw bytes (`include_bytes!`) from a
// `.metallib` file rather than as shader source text.
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||
|
||||
/// Most kernels apply similarly across the tensors
|
||||
@ -115,6 +116,7 @@ pub enum Source {
|
||||
Cast,
|
||||
Reduce,
|
||||
Mfa,
|
||||
Conv,
|
||||
}
|
||||
|
||||
macro_rules! ops{
|
||||
@ -225,6 +227,7 @@ impl Kernels {
|
||||
Source::Indexing => INDEXING,
|
||||
Source::Cast => CAST,
|
||||
Source::Reduce => REDUCE,
|
||||
Source::Conv => CONV,
|
||||
Source::Mfa => panic!("Invalid lib"),
|
||||
}
|
||||
}
|
||||
@ -1418,6 +1421,103 @@ pub fn call_gemm(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_im2col1d_strided(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
shape: &[usize],
|
||||
strides: &[usize],
|
||||
(k_size, stride, padding, dilation): (usize, usize, usize, usize),
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
|
||||
let l_out = (shape[2] + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1;
|
||||
let dst_el = shape[0] * l_out * shape[1] * k_size;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
dst_el,
|
||||
l_out,
|
||||
k_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
shape,
|
||||
strides,
|
||||
(input, input_offset),
|
||||
output
|
||||
)
|
||||
);
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_im2col_strided(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
shape: &[usize],
|
||||
strides: &[usize],
|
||||
(h_k, w_k, stride, padding, dilation): (usize, usize, usize, usize, usize),
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
|
||||
|
||||
let h = shape[2];
|
||||
let w = shape[3];
|
||||
let h_out = (h + 2 * padding - dilation * (h_k - 1) - 1) / stride + 1;
|
||||
let w_out = (w + 2 * padding - dilation * (w_k - 1) - 1) / stride + 1;
|
||||
|
||||
let dst_el = shape[0] * h_out * w_out * shape[1] * h_k * w_k;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
dst_el,
|
||||
h_out,
|
||||
w_out,
|
||||
h_k,
|
||||
w_k,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
shape,
|
||||
strides,
|
||||
(input, input_offset),
|
||||
output
|
||||
)
|
||||
);
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Ceiling division: how many chunks of size `b` are needed to cover `m` items, returned as an
/// `NSUInteger`.
///
/// The result is the quotient plus one when there is a remainder. Unlike the `(m + b - 1) / b`
/// form, this never computes `m + b`, so it cannot overflow `usize` for large `m`.
///
/// # Panics
///
/// Panics if `b` is zero.
fn divide(m: usize, b: usize) -> NSUInteger {
    // A non-zero remainder means one extra, partially filled chunk.
    let chunks = m / b + usize::from(m % b != 0);
    chunks as NSUInteger
}
|
||||
|
Reference in New Issue
Block a user