mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Working but failing tests because of threadgroup.
This commit is contained in:
@ -59,4 +59,4 @@ kernel void affine(
|
||||
out[strided_i] = x * mul + add;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -137,6 +137,10 @@ mod tests {
|
||||
};
|
||||
use std::mem;
|
||||
|
||||
fn device() -> Device {
|
||||
Device::system_default().unwrap()
|
||||
}
|
||||
|
||||
fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
|
||||
let b = 10f32.powi(digits);
|
||||
v.iter().map(|t| f32::round(t * b) / b).collect()
|
||||
@ -148,8 +152,9 @@ mod tests {
|
||||
}
|
||||
|
||||
fn run_cos<T: Clone>(v: &[T], name: &str) -> Vec<T> {
|
||||
let device = device();
|
||||
let option = metal::MTLResourceOptions::CPUCacheModeDefaultCache;
|
||||
let device = Device::system_default().unwrap();
|
||||
let option = metal::MTLResourceOptions::StorageModeManaged;
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let input = device.new_buffer_with_data(
|
||||
@ -165,18 +170,27 @@ mod tests {
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
let pipeline_state = device
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(
|
||||
pipeline_state_descriptor.compute_function().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline_state);
|
||||
encoder.set_buffer(0, Some(&input), 0);
|
||||
encoder.set_buffer(1, Some(&output), 0);
|
||||
let dim: u32 = v.len() as u32;
|
||||
// let num_dims: u32 = 1;
|
||||
// let info = [v.len() as u32, 1];
|
||||
|
||||
let width = 16;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
encoder.set_bytes(0, 4, void_ptr(&dim));
|
||||
// encoder.set_bytes(1, 4, void_ptr(&num_dims));
|
||||
// encoder.set_bytes(2, 4, void_ptr(&info));
|
||||
|
||||
encoder.set_buffer(1, Some(&input), 0);
|
||||
encoder.set_buffer(2, Some(&output), 0);
|
||||
|
||||
let width = v.len() as NSUInteger;
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
width,
|
||||
@ -185,7 +199,7 @@ mod tests {
|
||||
};
|
||||
|
||||
let thread_group_size = MTLSize {
|
||||
width: (v.len() as u64 + width) / width,
|
||||
width: pipeline.max_total_threads_per_threadgroup(),
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
@ -208,7 +222,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn affine() {
|
||||
let device = Device::system_default().expect("no device found");
|
||||
let device = device();
|
||||
|
||||
let options = CompileOptions::new();
|
||||
let library = device.new_library_with_source(AFFINE, &options).unwrap();
|
||||
@ -225,7 +239,8 @@ mod tests {
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(&function)
|
||||
.unwrap();
|
||||
let options = MTLResourceOptions::StorageModeShared;
|
||||
// let options = MTLResourceOptions::StorageModeShared;
|
||||
let options = metal::MTLResourceOptions::StorageModeManaged;
|
||||
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
@ -291,7 +306,7 @@ mod tests {
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(&function)
|
||||
.unwrap();
|
||||
let options = MTLResourceOptions::StorageModeShared;
|
||||
let options = metal::MTLResourceOptions::StorageModeManaged;
|
||||
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
|
@ -1,22 +1,74 @@
|
||||
#include <metal_stdlib>
|
||||
#
|
||||
// Reports whether `strides` matches a densely packed layout for the shape in
// `dims`, where the last dimension varies fastest.
//
// The dimensions are visited from last to first while tracking the stride a
// packed layout would have at that axis (the product of the sizes of all
// later dimensions). Returns false at the first axis whose stride differs and
// true otherwise, including when `num_dims` is 0 and nothing is visited.
METAL_FUNC bool is_contiguous(
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides
) {
    size_t expected_stride = 1;
    uint visited = 0;
    while (visited < num_dims) {
        const uint axis = num_dims - 1 - visited;
        if (strides[axis] != expected_stride) {
            return false;
        }
        expected_stride *= dims[axis];
        visited++;
    }
    return true;
}
|
||||
|
||||
// Converts a linear element index into an offset inside a strided buffer.
//
// `idx` is decomposed from the last dimension to the first: the remainder
// modulo the dimension size is the coordinate along that axis and is scaled
// by the axis stride, while the quotient is carried on to the next axis.
// The scaled coordinates are accumulated into a `uint`, which is returned.
METAL_FUNC uint get_strided_index(
    uint idx,
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides
) {
    uint offset = 0;
    uint visited = 0;
    while (visited < num_dims) {
        const uint axis = num_dims - 1 - visited;
        const size_t extent = dims[axis];
        offset += (idx % extent) * strides[axis];
        idx /= extent;
        visited++;
    }
    return offset;
}
|
||||
|
||||
|
||||
using namespace metal;
|
||||
|
||||
template <typename T>
|
||||
kernel void unary_cos(device const T *input, device T *output, uint index [[thread_position_in_grid]])
|
||||
{
|
||||
output[index] = cos(input[index]);
|
||||
/*
 * UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED)
 *
 * Expands to two element-wise compute kernels applying FN to TYPENAME data:
 *
 *   FN_NAME          output[i] = FN(input[i])
 *   FN_NAME_STRIDED  output[i] = FN(input[get_strided_index(i, ...)])
 *                    `info` holds `num_dims` dimension sizes followed by
 *                    `num_dims` strides.
 *
 * Every thread handles exactly one element. Its index `i` is rebuilt from
 * the threadgroup position, the threadgroup size and the thread's index
 * inside its threadgroup. `dim` is the number of elements to produce:
 * threads whose index is not below `dim` return before touching memory, so
 * the host may dispatch more threads than there are elements.
 *
 * NOTE(review): `dim` and `num_dims` are read as `size_t`; the host side has
 * to bind a value of that width for them -- confirm against the caller.
 */
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
    constant size_t &dim, \
    device const TYPENAME *input, \
    device TYPENAME *output, \
    uint threadgroup_size [[threads_per_threadgroup]], \
    uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
    uint thread_index [[thread_index_in_threadgroup]] \
) { \
    const uint i = thread_index + (threadgroup_position_in_grid * threadgroup_size); \
    if (i >= dim) { \
        return; \
    } \
    output[i] = FN(input[i]); \
} \
kernel void FN_NAME_STRIDED( \
    constant size_t &dim, \
    constant size_t &num_dims, \
    constant size_t *info, \
    device const TYPENAME *input, \
    device TYPENAME *output, \
    uint threadgroup_size [[threads_per_threadgroup]], \
    uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
    uint thread_index [[thread_index_in_threadgroup]] \
) { \
    const uint i = thread_index + (threadgroup_position_in_grid * threadgroup_size); \
    if (i >= dim) { \
        return; \
    } \
    constant size_t *dims = info; \
    constant size_t *strides = info + num_dims; \
    output[i] = FN(input[get_strided_index(i, num_dims, dims, strides)]); \
}
|
||||
|
||||
/*
 * UNARY(FN, TYPENAME, FN_NAME)
 *
 * Expands to a compute kernel named FN_NAME in which each thread reads the
 * input element at its position in the grid, applies FN, and stores the
 * result at the same position in `output`. The kernel receives no element
 * count and performs no bound check.
 */
#define UNARY(FN, TYPENAME, FN_NAME) \
kernel void FN_NAME( \
    device const TYPENAME *input, \
    device TYPENAME *output, \
    uint gid [[thread_position_in_grid]] \
) { \
    output[gid] = FN(input[gid]); \
}
|
||||
|
||||
UNARY(cos, float, cos_float);
|
||||
UNARY(cos, half, cos_half);
|
||||
UNARY(cos, float, cos_float, cos_float_strided);
|
||||
UNARY(cos, half, cos_half, cos_half_strided);
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
UNARY(cos, half, cos_half);
|
||||
UNARY(cos, bfloat, cos_bfloat, cos_bfloat_strided);
|
||||
#endif
|
||||
|
Reference in New Issue
Block a user