diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index e1bdda63..f877b5e3 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -153,16 +153,15 @@ mod tests { fn run_cos(v: &[T], name: &str) -> Vec { let device = device(); - let option = metal::MTLResourceOptions::CPUCacheModeDefaultCache; - let option = metal::MTLResourceOptions::StorageModeManaged; + let options = MTLResourceOptions::StorageModeManaged; let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = device.new_buffer_with_data( v.as_ptr() as *const core::ffi::c_void, (v.len() * core::mem::size_of::()) as u64, - option, + options, ); - let output = device.new_buffer((v.len() * core::mem::size_of::()) as u64, option); + let output = device.new_buffer((v.len() * core::mem::size_of::()) as u64, options); let library = device .new_library_with_source(UNARY, &CompileOptions::new()) .expect("Failed to load unary library"); @@ -184,8 +183,6 @@ mod tests { encoder.set_compute_pipeline_state(&pipeline); encoder.set_bytes(0, 4, void_ptr(&dim)); - // encoder.set_bytes(1, 4, void_ptr(&num_dims)); - // encoder.set_bytes(2, 4, void_ptr(&info)); encoder.set_buffer(1, Some(&input), 0); encoder.set_buffer(2, Some(&output), 0); @@ -239,8 +236,7 @@ mod tests { let pipeline = device .new_compute_pipeline_state_with_function(&function) .unwrap(); - // let options = MTLResourceOptions::StorageModeShared; - let options = metal::MTLResourceOptions::StorageModeManaged; + let options = MTLResourceOptions::StorageModeManaged; let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); @@ -284,6 +280,7 @@ mod tests { let expected = vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]; let result = outputs_buffer.read_to_vec::(output.len()); + println!("Result {:?}", result.as_ptr()); assert_eq!(result, expected); } @@ -306,7 +303,7 @@ mod tests { let pipeline = device 
.new_compute_pipeline_state_with_function(&function) .unwrap(); - let options = metal::MTLResourceOptions::StorageModeManaged; + let options = MTLResourceOptions::StorageModeManaged; let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); @@ -353,6 +350,7 @@ mod tests { 2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0, ]; let result = outputs_buffer.read_to_vec::(right.len()); + println!("Result {:?}", result.as_ptr()); assert_eq!(result, expected); } diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 715dcced..fd9011ba 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -44,7 +44,10 @@ kernel void FN_NAME( \ uint thread_index [[thread_index_in_threadgroup]] \ ) { \ const uint i = thread_index + (threadgroup_position_in_grid * threadgroup_size); \ - output[i] = TYPENAME(FN(input[i])); \ + if (i >= dim){ /* valid element indices are 0..dim-1, so a thread with i == dim must bail out as well */ \ + return; \ + } \ + output[i] = FN(input[i]); \ }\ kernel void FN_NAME_STRIDED( \ constant size_t &dim, \ @@ -61,8 +64,7 @@ kernel void FN_NAME_STRIDED( \ const uint start = thread_index + (threadgroup_position_in_grid * threadgroup_size); \ const uint stop = min(thread_index + (threadgroup_position_in_grid * threadgroup_size), (uint) dim); \ for (size_t i = start; i < stop; i++) { \ - output[i] = TYPENAME(FN(input[get_strided_index(i, num_dims, dims, strides)])); \ - output[i] = 1; \ + output[i] = FN(input[get_strided_index(i, num_dims, dims, strides)]); \ } \ }