diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 91b99ebc..0a3826d8 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -19,7 +19,7 @@ struct fault_counter { }; constant uint IDS_DIM_SIZE [[function_constant(0)]]; -constant uint SRC_DIM_SIZE [[function_constant(1)]]; +constant uint SRC_DIM_SIZE [[function_constant(1)]]; // Not needed constant uint DST_DIM_SIZE [[function_constant(2)]]; constant uint LEFT_SIZE [[function_constant(3)]]; constant uint RIGHT_SIZE [[function_constant(4)]]; @@ -29,21 +29,16 @@ kernel void index_add( device uint *ids [[buffer(0)]], device float *inp [[buffer(1)]], device float *out [[buffer(2)]], - - uint grid_size [[threadgroups_per_grid]], // gridDim - uint gid [[thread_position_in_grid]], // blockIdx - uint num_threads [[threads_per_grid]], // blockDim - uint thread_index [[thread_index_in_threadgroup]] // threadIdx + uint thread_index [[thread_index_in_threadgroup]] ) { - for (uint i = gid * num_threads + thread_index; i < NUMEL; i += num_threads * grid_size) { - const uint pre = i / RIGHT_SIZE; - const uint post = i % RIGHT_SIZE; + const uint i = thread_index; + const uint pre = i / RIGHT_SIZE; + const uint post = i % RIGHT_SIZE; - for (uint j = 0; j < IDS_DIM_SIZE; j++) { - const uint idx = ids[j]; - const uint src_i = (pre * IDS_DIM_SIZE + j) * RIGHT_SIZE + post; - const uint dst_i = (pre * DST_DIM_SIZE + idx) * RIGHT_SIZE + post; - out[dst_i] += inp[src_i]; - } + for (uint j = 0; j < IDS_DIM_SIZE; ++j) { + const uint idx = ids[j]; + const uint src_i = (pre * IDS_DIM_SIZE + j) * RIGHT_SIZE + post; + const uint dst_i = (pre * DST_DIM_SIZE + idx) * RIGHT_SIZE + post; + out[dst_i] += inp[src_i]; } } \ No newline at end of file diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 6d40e5cd..d4c9cfc3 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1,4 +1,4 @@ -use metal::{Buffer, CompileOptions, Device, Function, Library, NSUInteger}; +use metal::{Buffer, CompileOptions, Device, Function, Library}; use std::collections::HashMap; use std::sync::RwLock; @@ -56,7 +56,7 @@ impl Kernels { } } -fn call_unary(func: &Function, input: &Buffer, output: &Buffer, length: usize) { +fn call_unary(_func: &Function, _input: &Buffer, _output: &Buffer, _length: usize) { todo!("Call unary"); } @@ -154,12 +154,14 @@ mod tests { let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; let right = [1.0f32; 15]; let index = [0u32, 4, 2]; + let ids_dim_size = index.len() as u32; - let src_dim_size = 2u32; - let dst_dim_size = 2u32; - let left_size = left.len() as u32; - let right_size = right.len() as u32; - let numel = left_size * right_size; + + // Are these reversed? + let src_dim_size: u32 = 9; + let dst_dim_size: u32 = 15; + let left_size: u32 = 3; + let right_size: u32 = 3; let fcv = FunctionConstantValues::new(); fcv.set_constant_value_at_index(void_ptr(&ids_dim_size), MTLDataType::UInt, 0); @@ -200,18 +202,16 @@ mod tests { let width = 16; let thread_group_count = MTLSize { - width, + width: 1, height: 1, depth: 1, }; let thread_group_size = MTLSize { - width: (numel as NSUInteger + width) / width, + width, height: 1, depth: 1, }; - println!("{:?}", thread_group_count); - println!("{:?}", thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -222,8 +222,6 @@ mod tests { 2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0, ]; let result = outputs.read_to_vec::(right.len()); - println!("{:?}", result); - assert_eq!(result, expected); } }