Merge branch 'main' into ivarflakstad/metal-prng

This commit is contained in:
Ivar Flakstad
2024-01-17 18:02:01 +01:00
37 changed files with 6934 additions and 552 deletions

View File

@ -16,6 +16,7 @@ const CONV: &str = include_str!("conv.metal");
const REDUCE: &str = include_str!("reduce.metal");
const RANDOM: &str = include_str!("random.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const QUANTIZED: &str = include_str!("quantized.metal");
/// Most kernels apply similarly across the tensors
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
@ -64,6 +65,8 @@ macro_rules! primitive {
}
primitive!(bool);
primitive!(usize);
primitive!(i32);
primitive!(i64);
primitive!(u32);
primitive!(u64);
primitive!(f32);
@ -121,6 +124,7 @@ pub enum Source {
Mfa,
Conv,
Random,
Quantized,
}
macro_rules! ops{
@ -219,17 +223,15 @@ type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipeline
pub struct Kernels {
libraries: RwLock<Libraries>,
pipelines: RwLock<Pipelines>,
fence: metal::Fence,
}
impl Kernels {
pub fn new(fence: metal::Fence) -> Self {
pub fn new() -> Self {
let libraries = RwLock::new(Libraries::new());
let pipelines = RwLock::new(Pipelines::new());
Self {
libraries,
pipelines,
fence,
}
}
@ -244,6 +246,7 @@ impl Kernels {
Source::Reduce => REDUCE,
Source::Conv => CONV,
Source::Random => RANDOM,
Source::Quantized => QUANTIZED,
Source::Mfa => panic!("Invalid lib"),
}
}
@ -350,7 +353,6 @@ pub fn call_unary_contiguous(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, input, output));
@ -359,7 +361,6 @@ pub fn call_unary_contiguous(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -381,7 +382,6 @@ pub fn call_unary_strided(
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
@ -403,7 +403,6 @@ pub fn call_unary_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -422,7 +421,6 @@ pub fn call_binary_contiguous(
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, left, right, output));
@ -433,7 +431,6 @@ pub fn call_binary_contiguous(
encoder.use_resource(right, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -458,7 +455,6 @@ pub fn call_binary_strided(
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
let width: usize = shape.iter().product();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
@ -483,7 +479,6 @@ pub fn call_binary_strided(
encoder.use_resource(right_input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -502,7 +497,6 @@ pub fn call_cast_contiguous(
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, (input, input_offset), output));
@ -511,7 +505,6 @@ pub fn call_cast_contiguous(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -531,7 +524,6 @@ pub fn call_cast_strided(
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
@ -553,7 +545,6 @@ pub fn call_cast_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -573,7 +564,6 @@ pub fn call_reduce_contiguous(
let elements_to_sum = length / out_length;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -602,7 +592,6 @@ pub fn call_reduce_contiguous(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -624,7 +613,6 @@ pub fn call_reduce_strided(
let elements_to_sum = length / out_length;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -660,7 +648,6 @@ pub fn call_reduce_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -679,7 +666,6 @@ pub fn call_last_softmax(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -710,7 +696,6 @@ pub fn call_last_softmax(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -730,7 +715,6 @@ pub fn call_affine(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, add, input, output));
@ -739,7 +723,6 @@ pub fn call_affine(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -762,7 +745,6 @@ pub fn call_affine_strided(
let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -783,7 +765,6 @@ pub fn call_affine_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -802,7 +783,6 @@ pub fn call_powf(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, input, output));
@ -811,7 +791,6 @@ pub fn call_powf(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -833,7 +812,6 @@ pub fn call_powf_strided(
let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -853,7 +831,6 @@ pub fn call_powf_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -872,7 +849,6 @@ pub fn call_elu(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, input, output));
@ -881,7 +857,6 @@ pub fn call_elu(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -903,7 +878,6 @@ pub fn call_elu_strided(
let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -923,7 +897,6 @@ pub fn call_elu_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -945,7 +918,6 @@ pub fn call_where_cond_strided(
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
let size: usize = shape.iter().product();
@ -974,7 +946,6 @@ pub fn call_where_cond_strided(
encoder.use_resource(right, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -1001,7 +972,6 @@ pub fn call_index_select(
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -1024,7 +994,6 @@ pub fn call_index_select(
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -1053,7 +1022,6 @@ pub fn call_gather(
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -1076,7 +1044,6 @@ pub fn call_gather(
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -1105,7 +1072,6 @@ pub fn call_scatter_add(
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -1128,7 +1094,6 @@ pub fn call_scatter_add(
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -1158,7 +1123,6 @@ pub fn call_index_add(
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -1182,7 +1146,6 @@ pub fn call_index_add(
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -1386,7 +1349,6 @@ pub fn call_gemm(
let block_bytes = block_elements * bytes;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_threadgroup_memory_length(0, block_bytes.into());
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
@ -1430,7 +1392,6 @@ pub fn call_gemm(
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_size, group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
@ -1455,7 +1416,6 @@ pub fn call_im2col1d_strided(
let encoder = command_buffer.new_compute_command_encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
@ -1475,7 +1435,6 @@ pub fn call_im2col1d_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
@ -1505,7 +1464,6 @@ pub fn call_im2col_strided(
let encoder = command_buffer.new_compute_command_encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
@ -1527,7 +1485,6 @@ pub fn call_im2col_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
@ -1553,7 +1510,6 @@ pub fn call_upsample_nearest_2d(
let scale_h = shape[3] as f32 / out_h as f32;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
@ -1571,16 +1527,11 @@ pub fn call_upsample_nearest_2d(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
fn divide(m: usize, b: usize) -> NSUInteger {
((m + b - 1) / b) as NSUInteger
}
#[allow(clippy::too_many_arguments)]
pub fn call_random_uniform(
device: &Device,
@ -1604,7 +1555,6 @@ pub fn call_random_uniform(
let odd = (length % 2 != 0) as usize;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, min, max, seed, buffer));
@ -1613,7 +1563,6 @@ pub fn call_random_uniform(
encoder.use_resource(seed, metal::MTLResourceUsage::Write);
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
@ -1637,7 +1586,6 @@ pub fn call_random_normal(
let odd = (length % 2 != 0) as usize;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, mean, stddev, seed, buffer));
@ -1646,11 +1594,184 @@ pub fn call_random_normal(
encoder.use_resource(seed, metal::MTLResourceUsage::Write);
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
#[derive(Debug, Clone, Copy)]
pub enum GgmlDType {
Q4_0,
Q4_1,
Q5_0,
Q5_1,
Q8_0,
Q8_1,
Q2K,
Q3K,
Q4K,
Q5K,
Q6K,
Q8K,
F16,
F32,
}
pub fn call_quantized_matmul_t(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
dtype: GgmlDType,
(b, m, n, k): (usize, usize, usize, usize),
lhs: &Buffer,
lhs_offset: usize,
rhs: &Buffer,
output: &Buffer,
) -> Result<(), MetalKernelError> {
// Everything is in reverse
let ne00 = k as i64;
let ne01 = n as i64;
let ne02 = b as i64;
let ne03 = 1 as i64;
let nb00 = 0i64;
let nb01 = 0 as i64;
let nb02 = 0 as i64;
let ne10 = k as i64;
let ne11 = m as i64;
let ne12 = b as i64;
let ne13 = 1 as i64;
let nb10 = 0i64;
let nb11 = 0i64;
let nb12 = 0i64;
let ne0 = n as i64;
let ne1 = m as i64;
let r2: u32 = (ne12 / ne02) as u32;
let r3: u32 = (ne13 / ne03) as u32;
let (nth0, nth1, align) = match dtype {
GgmlDType::Q4_0
| GgmlDType::Q4_1
| GgmlDType::Q5_0
| GgmlDType::Q5_1
| GgmlDType::Q8_0
| GgmlDType::Q8_1 => {
let nth0 = 8;
let nth1 = 8;
let align = 8;
(nth0, nth1, align)
}
GgmlDType::Q2K => {
// Fixing a bug in Metal for GGML
let nth0 = 4;
let nth1 = 8;
let align = 4;
(nth0, nth1, align)
}
GgmlDType::Q4K => {
let nth0 = 4;
let nth1 = 8;
let align = 4;
(nth0, nth1, align)
}
GgmlDType::Q3K | GgmlDType::Q5K => {
let nth0 = 2;
let nth1 = 32;
let align = 4;
(nth0, nth1, align)
}
GgmlDType::Q6K => {
let nth0 = 2;
let nth1 = 32;
let align = 2;
(nth0, nth1, align)
}
GgmlDType::F16 | GgmlDType::Q8K => {
// Original implem uses rows
let nth0 = 32;
let nth1 = 1;
let align = 8;
(nth0, nth1, align)
}
GgmlDType::F32 => {
let nth0 = 32;
let nth1 = 1;
let align = 8;
(nth0, nth1, align)
}
};
let thread_groups_count = MTLSize {
width: divide(ne01 as usize, align),
height: ne11 as u64,
depth: (ne12 * ne13) as u64,
};
let threads_per_threadgroup = MTLSize {
width: nth0,
height: nth1,
depth: 1,
};
let name = match dtype {
GgmlDType::Q4_0 => "kernel_mul_mv_q4_0_f32",
GgmlDType::Q4_1 => "kernel_mul_mv_q4_1_f32",
GgmlDType::Q5_0 => "kernel_mul_mv_q5_0_f32",
GgmlDType::Q5_1 => "kernel_mul_mv_q5_1_f32",
GgmlDType::Q8_0 => "kernel_mul_mv_q8_0_f32",
GgmlDType::Q8_1 => "kernel_mul_mv_q8_1_f32",
GgmlDType::Q2K => "kernel_mul_mv_q2_K_f32",
GgmlDType::Q3K => "kernel_mul_mv_q3_K_f32",
GgmlDType::Q4K => "kernel_mul_mv_q4_K_f32",
GgmlDType::Q5K => "kernel_mul_mv_q5_K_f32",
GgmlDType::Q6K => "kernel_mul_mv_q6_K_f32",
GgmlDType::Q8K => "kernel_mul_mv_q8_K_f32",
GgmlDType::F16 => "kernel_mul_mv_f16_f32",
GgmlDType::F32 => "kernel_mul_mv_f32_f32",
};
let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
rhs,
(lhs, lhs_offset),
output,
ne00,
ne01,
ne02,
nb00,
nb01,
nb02,
ne10,
ne11,
ne12,
nb10,
nb11,
nb12,
ne0,
ne1,
r2,
r3
)
);
encoder.set_threadgroup_memory_length(0, 8192);
encoder.use_resource(lhs, metal::MTLResourceUsage::Read);
encoder.use_resource(rhs, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
encoder.end_encoding();
Ok(())
}
fn divide(m: usize, b: usize) -> NSUInteger {
((m + b - 1) / b) as NSUInteger
}
#[cfg(test)]
mod tests;