mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Going overbounds will break other kernels running from other threads.
This commit is contained in:
@ -153,16 +153,15 @@ mod tests {
|
|||||||
|
|
||||||
fn run_cos<T: Clone>(v: &[T], name: &str) -> Vec<T> {
|
fn run_cos<T: Clone>(v: &[T], name: &str) -> Vec<T> {
|
||||||
let device = device();
|
let device = device();
|
||||||
let option = metal::MTLResourceOptions::CPUCacheModeDefaultCache;
|
let options = MTLResourceOptions::StorageModeManaged;
|
||||||
let option = metal::MTLResourceOptions::StorageModeManaged;
|
|
||||||
let command_queue = device.new_command_queue();
|
let command_queue = device.new_command_queue();
|
||||||
let command_buffer = command_queue.new_command_buffer();
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
let input = device.new_buffer_with_data(
|
let input = device.new_buffer_with_data(
|
||||||
v.as_ptr() as *const core::ffi::c_void,
|
v.as_ptr() as *const core::ffi::c_void,
|
||||||
(v.len() * core::mem::size_of::<T>()) as u64,
|
(v.len() * core::mem::size_of::<T>()) as u64,
|
||||||
option,
|
options,
|
||||||
);
|
);
|
||||||
let output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, option);
|
let output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, options);
|
||||||
let library = device
|
let library = device
|
||||||
.new_library_with_source(UNARY, &CompileOptions::new())
|
.new_library_with_source(UNARY, &CompileOptions::new())
|
||||||
.expect("Failed to load unary library");
|
.expect("Failed to load unary library");
|
||||||
@ -184,8 +183,6 @@ mod tests {
|
|||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
encoder.set_bytes(0, 4, void_ptr(&dim));
|
encoder.set_bytes(0, 4, void_ptr(&dim));
|
||||||
// encoder.set_bytes(1, 4, void_ptr(&num_dims));
|
|
||||||
// encoder.set_bytes(2, 4, void_ptr(&info));
|
|
||||||
|
|
||||||
encoder.set_buffer(1, Some(&input), 0);
|
encoder.set_buffer(1, Some(&input), 0);
|
||||||
encoder.set_buffer(2, Some(&output), 0);
|
encoder.set_buffer(2, Some(&output), 0);
|
||||||
@ -239,8 +236,7 @@ mod tests {
|
|||||||
let pipeline = device
|
let pipeline = device
|
||||||
.new_compute_pipeline_state_with_function(&function)
|
.new_compute_pipeline_state_with_function(&function)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
// let options = MTLResourceOptions::StorageModeShared;
|
let options = MTLResourceOptions::StorageModeManaged;
|
||||||
let options = metal::MTLResourceOptions::StorageModeManaged;
|
|
||||||
|
|
||||||
let command_queue = device.new_command_queue();
|
let command_queue = device.new_command_queue();
|
||||||
let command_buffer = command_queue.new_command_buffer();
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
@ -284,6 +280,7 @@ mod tests {
|
|||||||
|
|
||||||
let expected = vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1];
|
let expected = vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1];
|
||||||
let result = outputs_buffer.read_to_vec::<f32>(output.len());
|
let result = outputs_buffer.read_to_vec::<f32>(output.len());
|
||||||
|
println!("Result {:?}", result.as_ptr());
|
||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -306,7 +303,7 @@ mod tests {
|
|||||||
let pipeline = device
|
let pipeline = device
|
||||||
.new_compute_pipeline_state_with_function(&function)
|
.new_compute_pipeline_state_with_function(&function)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let options = metal::MTLResourceOptions::StorageModeManaged;
|
let options = MTLResourceOptions::StorageModeManaged;
|
||||||
|
|
||||||
let command_queue = device.new_command_queue();
|
let command_queue = device.new_command_queue();
|
||||||
let command_buffer = command_queue.new_command_buffer();
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
@ -353,6 +350,7 @@ mod tests {
|
|||||||
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
|
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
|
||||||
];
|
];
|
||||||
let result = outputs_buffer.read_to_vec::<f32>(right.len());
|
let result = outputs_buffer.read_to_vec::<f32>(right.len());
|
||||||
|
println!("Result {:?}", result.as_ptr());
|
||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -44,7 +44,10 @@ kernel void FN_NAME( \
|
|||||||
uint thread_index [[thread_index_in_threadgroup]] \
|
uint thread_index [[thread_index_in_threadgroup]] \
|
||||||
) { \
|
) { \
|
||||||
const uint i = thread_index + (threadgroup_position_in_grid * threadgroup_size); \
|
const uint i = thread_index + (threadgroup_position_in_grid * threadgroup_size); \
|
||||||
output[i] = TYPENAME(FN(input[i])); \
|
if (i > dim){ \
|
||||||
|
return; \
|
||||||
|
} \
|
||||||
|
output[i] = FN(input[i]); \
|
||||||
}\
|
}\
|
||||||
kernel void FN_NAME_STRIDED( \
|
kernel void FN_NAME_STRIDED( \
|
||||||
constant size_t &dim, \
|
constant size_t &dim, \
|
||||||
@ -61,8 +64,7 @@ kernel void FN_NAME_STRIDED( \
|
|||||||
const uint start = thread_index + (threadgroup_position_in_grid * threadgroup_size); \
|
const uint start = thread_index + (threadgroup_position_in_grid * threadgroup_size); \
|
||||||
const uint stop = min(thread_index + (threadgroup_position_in_grid * threadgroup_size), (uint) dim); \
|
const uint stop = min(thread_index + (threadgroup_position_in_grid * threadgroup_size), (uint) dim); \
|
||||||
for (size_t i = start; i < stop; i++) { \
|
for (size_t i = start; i < stop; i++) { \
|
||||||
output[i] = TYPENAME(FN(input[get_strided_index(i, num_dims, dims, strides)])); \
|
output[i] = FN(input[get_strided_index(i, num_dims, dims, strides)]); \
|
||||||
output[i] = 1; \
|
|
||||||
} \
|
} \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user