mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Impl index_add via template for all types
This commit is contained in:
@ -1,44 +1,71 @@
|
||||
#include <metal_stdlib>
|
||||
#include <metal_config>
|
||||
#define METAL_FUNC inline __attribute__((__always_inline__))
|
||||
using namespace metal;
|
||||
|
||||
struct fault_counter {
|
||||
uint counter;
|
||||
uint tolerance;
|
||||
template <typename T, typename I>
|
||||
void index_add(
|
||||
device I *ids [[buffer(0)]],
|
||||
device T *inp [[buffer(1)]],
|
||||
device T *out [[buffer(2)]],
|
||||
|
||||
fault_counter(uint tolerance) {
|
||||
this->counter = 0;
|
||||
this->tolerance = tolerance;
|
||||
}
|
||||
constant uint &ids_dim_size,
|
||||
constant uint &left_size,
|
||||
constant uint &dst_dim_size,
|
||||
constant uint &right_size,
|
||||
|
||||
bool quit() {
|
||||
counter += 1;
|
||||
return (counter > tolerance);
|
||||
}
|
||||
};
|
||||
|
||||
constant uint IDS_DIM_SIZE [[function_constant(0)]];
|
||||
constant uint SRC_DIM_SIZE [[function_constant(1)]]; // Not needed
|
||||
constant uint DST_DIM_SIZE [[function_constant(2)]];
|
||||
constant uint LEFT_SIZE [[function_constant(3)]];
|
||||
constant uint RIGHT_SIZE [[function_constant(4)]];
|
||||
constant uint NUMEL = LEFT_SIZE * RIGHT_SIZE;
|
||||
|
||||
kernel void index_add(
|
||||
device uint *ids [[buffer(0)]],
|
||||
device float *inp [[buffer(1)]],
|
||||
device float *out [[buffer(2)]],
|
||||
uint threadgroup_size [[threads_per_threadgroup]],
|
||||
uint threadgroup_position_in_grid [[threadgroup_position_in_grid]],
|
||||
uint thread_index [[thread_index_in_threadgroup]]
|
||||
) {
|
||||
const uint i = thread_index;
|
||||
const uint pre = i / RIGHT_SIZE;
|
||||
const uint post = i % RIGHT_SIZE;
|
||||
|
||||
for (uint j = 0; j < IDS_DIM_SIZE; ++j) {
|
||||
const uint gid = thread_index + (threadgroup_position_in_grid * threadgroup_size);
|
||||
if (gid >= left_size * right_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint i = gid;
|
||||
const uint pre = i / right_size;
|
||||
const uint post = i % right_size;
|
||||
|
||||
for (uint j = 0; j < ids_dim_size; j++) {
|
||||
const uint idx = ids[j];
|
||||
const uint src_i = (pre * IDS_DIM_SIZE + j) * RIGHT_SIZE + post;
|
||||
const uint dst_i = (pre * DST_DIM_SIZE + idx) * RIGHT_SIZE + post;
|
||||
const uint src_i = (pre * ids_dim_size + j) * right_size + post;
|
||||
const uint dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
||||
out[dst_i] += inp[src_i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
|
||||
kernel void FN_NAME( \
|
||||
device INDEX_TYPENAME *ids [[buffer(0)]], \
|
||||
device TYPENAME *inp [[buffer(1)]], \
|
||||
device TYPENAME *out [[buffer(2)]], \
|
||||
constant uint &ids_dim_size, \
|
||||
constant uint &left_size, \
|
||||
constant uint &dst_dim_size, \
|
||||
constant uint &right_size, \
|
||||
uint threadgroup_size [[threads_per_threadgroup]], \
|
||||
uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
|
||||
uint thread_index [[thread_index_in_threadgroup]] \
|
||||
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, threadgroup_size, threadgroup_position_in_grid, thread_index); } \
|
||||
|
||||
IA_OP(bfloat, int64_t, ia_i64_bf16)
|
||||
IA_OP(bfloat, uint32_t, ia_u32_bf16)
|
||||
IA_OP(bfloat, uint8_t, ia_u8_bf16)
|
||||
|
||||
IA_OP(half, uint32_t, ia_u32_f16)
|
||||
IA_OP(half, uint8_t, ia_u8_f16)
|
||||
|
||||
IA_OP(float, int64_t, ia_i64_f32)
|
||||
IA_OP(uint8_t, int64_t, ia_i64_u8)
|
||||
IA_OP(int64_t, int64_t, ia_i64_i64)
|
||||
IA_OP(uint32_t, int64_t, ia_i64_u32)
|
||||
|
||||
IA_OP(float, uint32_t, ia_u32_f32)
|
||||
IA_OP(uint8_t, uint32_t, ia_u32_u8)
|
||||
IA_OP(int64_t, uint32_t, ia_u32_i64)
|
||||
IA_OP(uint32_t, uint32_t, ia_u32_u32)
|
||||
|
||||
IA_OP(float, uint8_t, ia_u8_f32)
|
||||
IA_OP(uint8_t, uint8_t, ia_u8_u8)
|
||||
IA_OP(uint32_t, uint8_t, ia_u8_u32)
|
||||
IA_OP(int64_t, uint8_t, ia_u8_i64)
|
@ -64,8 +64,8 @@ fn call_unary(_func: &Function, _input: &Buffer, _output: &Buffer, _length: usiz
|
||||
mod tests {
|
||||
use super::*;
|
||||
use metal::{
|
||||
CompileOptions, ComputePipelineDescriptor, Device, FunctionConstantValues, MTLDataType,
|
||||
MTLResourceOptions, MTLResourceUsage, MTLSize, NSUInteger,
|
||||
CompileOptions, ComputePipelineDescriptor, Device, MTLResourceOptions, MTLResourceUsage,
|
||||
MTLSize, NSUInteger,
|
||||
};
|
||||
use std::ffi::c_void;
|
||||
use std::mem;
|
||||
@ -99,7 +99,7 @@ mod tests {
|
||||
let argument_encoder = func.new_argument_encoder(0);
|
||||
let arg_buffer = device.new_buffer(
|
||||
argument_encoder.encoded_length(),
|
||||
MTLResourceOptions::empty(),
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
argument_encoder.set_argument_buffer(&arg_buffer, 0);
|
||||
argument_encoder.set_buffer(0, &input, 0);
|
||||
@ -154,66 +154,54 @@ mod tests {
|
||||
let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
|
||||
let right = [1.0f32; 15];
|
||||
let index = [0u32, 4, 2];
|
||||
|
||||
let ids_dim_size = index.len() as u32;
|
||||
|
||||
// Are these reversed?
|
||||
let src_dim_size: u32 = 9;
|
||||
let dst_dim_size: u32 = 15;
|
||||
let left_size: u32 = 3;
|
||||
let right_size: u32 = 3;
|
||||
|
||||
let fcv = FunctionConstantValues::new();
|
||||
fcv.set_constant_value_at_index(void_ptr(&ids_dim_size), MTLDataType::UInt, 0);
|
||||
fcv.set_constant_value_at_index(void_ptr(&src_dim_size), MTLDataType::UInt, 1);
|
||||
fcv.set_constant_value_at_index(void_ptr(&dst_dim_size), MTLDataType::UInt, 2);
|
||||
fcv.set_constant_value_at_index(void_ptr(&left_size), MTLDataType::UInt, 3);
|
||||
fcv.set_constant_value_at_index(void_ptr(&right_size), MTLDataType::UInt, 4);
|
||||
|
||||
let function = library.get_function("index_add", Some(fcv)).unwrap();
|
||||
let function = library.get_function("ia_u32_f32", None).unwrap();
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(&function)
|
||||
.unwrap();
|
||||
let options = MTLResourceOptions::StorageModeShared;
|
||||
|
||||
let ids_size = (index.len() * mem::size_of::<u32>()) as NSUInteger;
|
||||
let input_size = (left.len() * mem::size_of::<f32>()) as NSUInteger;
|
||||
let output_size = (right.len() * mem::size_of::<f32>()) as NSUInteger;
|
||||
|
||||
let ids = device.new_buffer_with_data(void_ptr(&index), ids_size, options);
|
||||
let inputs = device.new_buffer_with_data(void_ptr(&left), input_size, options);
|
||||
let outputs = device.new_buffer_with_data(void_ptr(&right), output_size, options);
|
||||
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
let ids_size = (index.len() * mem::size_of::<u32>()) as NSUInteger;
|
||||
let input_size = (left.len() * mem::size_of::<f32>()) as NSUInteger;
|
||||
let output_size = (right.len() * mem::size_of::<f32>()) as NSUInteger;
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
let thread_group_memory_length = output_size;
|
||||
encoder.set_threadgroup_memory_length(0, thread_group_memory_length as NSUInteger);
|
||||
encoder.set_threadgroup_memory_length(0, output_size as NSUInteger);
|
||||
|
||||
encoder.use_resource(&ids, MTLResourceUsage::Read);
|
||||
encoder.use_resource(&inputs, MTLResourceUsage::Read);
|
||||
encoder.use_resource(&outputs, MTLResourceUsage::Write);
|
||||
let index_buffer = device.new_buffer_with_data(void_ptr(&index), ids_size, options);
|
||||
let inputs_buffer = device.new_buffer_with_data(void_ptr(&left), input_size, options);
|
||||
let outputs_buffer = device.new_buffer_with_data(void_ptr(&right), output_size, options);
|
||||
|
||||
encoder.set_buffer(0, Some(&ids), 0);
|
||||
encoder.set_buffer(1, Some(&inputs), 0);
|
||||
encoder.set_buffer(2, Some(&outputs), 0);
|
||||
let width = 16;
|
||||
encoder.set_buffer(0, Some(&index_buffer), 0);
|
||||
encoder.set_buffer(1, Some(&inputs_buffer), 0);
|
||||
encoder.set_buffer(2, Some(&outputs_buffer), 0);
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
width: 1,
|
||||
encoder.set_bytes(3, 4, void_ptr(&ids_dim_size));
|
||||
encoder.set_bytes(4, 4, void_ptr(&left_size));
|
||||
encoder.set_bytes(5, 4, void_ptr(&dst_dim_size));
|
||||
encoder.set_bytes(6, 4, void_ptr(&right_size));
|
||||
|
||||
let grid_size = MTLSize {
|
||||
width: right.len() as NSUInteger,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
let thread_group_size = MTLSize {
|
||||
width,
|
||||
width: pipeline.max_total_threads_per_threadgroup(),
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.dispatch_threads(grid_size, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
@ -221,7 +209,7 @@ mod tests {
|
||||
let expected = vec![
|
||||
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
|
||||
];
|
||||
let result = outputs.read_to_vec::<f32>(right.len());
|
||||
let result = outputs_buffer.read_to_vec::<f32>(right.len());
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
}
|
||||
|
18
candle-metal-kernels/src/utils.metal
Normal file
18
candle-metal-kernels/src/utils.metal
Normal file
@ -0,0 +1,18 @@
|
||||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
struct fault_counter {
|
||||
uint counter;
|
||||
uint tolerance;
|
||||
|
||||
fault_counter(uint tolerance) {
|
||||
this->counter = 0;
|
||||
this->tolerance = tolerance;
|
||||
}
|
||||
|
||||
bool quit() {
|
||||
counter += 1;
|
||||
return (counter > tolerance);
|
||||
}
|
||||
};
|
Reference in New Issue
Block a user