Going overbounds will break other kernels running from other threads.

This commit is contained in:
Nicolas Patry
2023-11-06 17:29:58 +01:00
parent 4d87305c48
commit cd68c96803
2 changed files with 12 additions and 12 deletions

View File

@ -153,16 +153,15 @@ mod tests {
fn run_cos<T: Clone>(v: &[T], name: &str) -> Vec<T> { fn run_cos<T: Clone>(v: &[T], name: &str) -> Vec<T> {
let device = device(); let device = device();
let option = metal::MTLResourceOptions::CPUCacheModeDefaultCache; let options = MTLResourceOptions::StorageModeManaged;
let option = metal::MTLResourceOptions::StorageModeManaged;
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer(); let command_buffer = command_queue.new_command_buffer();
let input = device.new_buffer_with_data( let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void, v.as_ptr() as *const core::ffi::c_void,
(v.len() * core::mem::size_of::<T>()) as u64, (v.len() * core::mem::size_of::<T>()) as u64,
option, options,
); );
let output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, option); let output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, options);
let library = device let library = device
.new_library_with_source(UNARY, &CompileOptions::new()) .new_library_with_source(UNARY, &CompileOptions::new())
.expect("Failed to load unary library"); .expect("Failed to load unary library");
@ -184,8 +183,6 @@ mod tests {
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
encoder.set_bytes(0, 4, void_ptr(&dim)); encoder.set_bytes(0, 4, void_ptr(&dim));
// encoder.set_bytes(1, 4, void_ptr(&num_dims));
// encoder.set_bytes(2, 4, void_ptr(&info));
encoder.set_buffer(1, Some(&input), 0); encoder.set_buffer(1, Some(&input), 0);
encoder.set_buffer(2, Some(&output), 0); encoder.set_buffer(2, Some(&output), 0);
@ -239,8 +236,7 @@ mod tests {
let pipeline = device let pipeline = device
.new_compute_pipeline_state_with_function(&function) .new_compute_pipeline_state_with_function(&function)
.unwrap(); .unwrap();
// let options = MTLResourceOptions::StorageModeShared; let options = MTLResourceOptions::StorageModeManaged;
let options = metal::MTLResourceOptions::StorageModeManaged;
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer(); let command_buffer = command_queue.new_command_buffer();
@ -284,6 +280,7 @@ mod tests {
let expected = vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]; let expected = vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1];
let result = outputs_buffer.read_to_vec::<f32>(output.len()); let result = outputs_buffer.read_to_vec::<f32>(output.len());
println!("Result {:?}", result.as_ptr());
assert_eq!(result, expected); assert_eq!(result, expected);
} }
@ -306,7 +303,7 @@ mod tests {
let pipeline = device let pipeline = device
.new_compute_pipeline_state_with_function(&function) .new_compute_pipeline_state_with_function(&function)
.unwrap(); .unwrap();
let options = metal::MTLResourceOptions::StorageModeManaged; let options = MTLResourceOptions::StorageModeManaged;
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer(); let command_buffer = command_queue.new_command_buffer();
@ -353,6 +350,7 @@ mod tests {
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0, 2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
]; ];
let result = outputs_buffer.read_to_vec::<f32>(right.len()); let result = outputs_buffer.read_to_vec::<f32>(right.len());
println!("Result {:?}", result.as_ptr());
assert_eq!(result, expected); assert_eq!(result, expected);
} }

View File

@ -44,7 +44,10 @@ kernel void FN_NAME( \
uint thread_index [[thread_index_in_threadgroup]] \ uint thread_index [[thread_index_in_threadgroup]] \
) { \ ) { \
const uint i = thread_index + (threadgroup_position_in_grid * threadgroup_size); \ const uint i = thread_index + (threadgroup_position_in_grid * threadgroup_size); \
output[i] = TYPENAME(FN(input[i])); \ if (i > dim){ \
return; \
} \
output[i] = FN(input[i]); \
}\ }\
kernel void FN_NAME_STRIDED( \ kernel void FN_NAME_STRIDED( \
constant size_t &dim, \ constant size_t &dim, \
@ -61,8 +64,7 @@ kernel void FN_NAME_STRIDED( \
const uint start = thread_index + (threadgroup_position_in_grid * threadgroup_size); \ const uint start = thread_index + (threadgroup_position_in_grid * threadgroup_size); \
const uint stop = min(thread_index + (threadgroup_position_in_grid * threadgroup_size), (uint) dim); \ const uint stop = min(thread_index + (threadgroup_position_in_grid * threadgroup_size), (uint) dim); \
for (size_t i = start; i < stop; i++) { \ for (size_t i = start; i < stop; i++) { \
output[i] = TYPENAME(FN(input[get_strided_index(i, num_dims, dims, strides)])); \ output[i] = FN(input[get_strided_index(i, num_dims, dims, strides)]); \
output[i] = 1; \
} \ } \
} }