mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Going overbounds will break other kernels running from other threads.
This commit is contained in:
@ -153,16 +153,15 @@ mod tests {
|
||||
|
||||
fn run_cos<T: Clone>(v: &[T], name: &str) -> Vec<T> {
|
||||
let device = device();
|
||||
let option = metal::MTLResourceOptions::CPUCacheModeDefaultCache;
|
||||
let option = metal::MTLResourceOptions::StorageModeManaged;
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let input = device.new_buffer_with_data(
|
||||
v.as_ptr() as *const core::ffi::c_void,
|
||||
(v.len() * core::mem::size_of::<T>()) as u64,
|
||||
option,
|
||||
options,
|
||||
);
|
||||
let output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, option);
|
||||
let output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, options);
|
||||
let library = device
|
||||
.new_library_with_source(UNARY, &CompileOptions::new())
|
||||
.expect("Failed to load unary library");
|
||||
@ -184,8 +183,6 @@ mod tests {
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
encoder.set_bytes(0, 4, void_ptr(&dim));
|
||||
// encoder.set_bytes(1, 4, void_ptr(&num_dims));
|
||||
// encoder.set_bytes(2, 4, void_ptr(&info));
|
||||
|
||||
encoder.set_buffer(1, Some(&input), 0);
|
||||
encoder.set_buffer(2, Some(&output), 0);
|
||||
@ -239,8 +236,7 @@ mod tests {
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(&function)
|
||||
.unwrap();
|
||||
// let options = MTLResourceOptions::StorageModeShared;
|
||||
let options = metal::MTLResourceOptions::StorageModeManaged;
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
@ -284,6 +280,7 @@ mod tests {
|
||||
|
||||
let expected = vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1];
|
||||
let result = outputs_buffer.read_to_vec::<f32>(output.len());
|
||||
println!("Result {:?}", result.as_ptr());
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
@ -306,7 +303,7 @@ mod tests {
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(&function)
|
||||
.unwrap();
|
||||
let options = metal::MTLResourceOptions::StorageModeManaged;
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
@ -353,6 +350,7 @@ mod tests {
|
||||
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
|
||||
];
|
||||
let result = outputs_buffer.read_to_vec::<f32>(right.len());
|
||||
println!("Result {:?}", result.as_ptr());
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
|
@ -44,7 +44,10 @@ kernel void FN_NAME( \
|
||||
uint thread_index [[thread_index_in_threadgroup]] \
|
||||
) { \
|
||||
const uint i = thread_index + (threadgroup_position_in_grid * threadgroup_size); \
|
||||
output[i] = TYPENAME(FN(input[i])); \
|
||||
if (i > dim){ \
|
||||
return; \
|
||||
} \
|
||||
output[i] = FN(input[i]); \
|
||||
}\
|
||||
kernel void FN_NAME_STRIDED( \
|
||||
constant size_t &dim, \
|
||||
@ -61,8 +64,7 @@ kernel void FN_NAME_STRIDED( \
|
||||
const uint start = thread_index + (threadgroup_position_in_grid * threadgroup_size); \
|
||||
const uint stop = min(thread_index + (threadgroup_position_in_grid * threadgroup_size), (uint) dim); \
|
||||
for (size_t i = start; i < stop; i++) { \
|
||||
output[i] = TYPENAME(FN(input[get_strided_index(i, num_dims, dims, strides)])); \
|
||||
output[i] = 1; \
|
||||
output[i] = FN(input[get_strided_index(i, num_dims, dims, strides)]); \
|
||||
} \
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user