mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Impl index_add via template for all types
This commit is contained in:
@ -1,44 +1,71 @@
|
|||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
#include <metal_config>
|
|
||||||
#define METAL_FUNC inline __attribute__((__always_inline__))
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
struct fault_counter {
|
template <typename T, typename I>
|
||||||
uint counter;
|
void index_add(
|
||||||
uint tolerance;
|
device I *ids [[buffer(0)]],
|
||||||
|
device T *inp [[buffer(1)]],
|
||||||
|
device T *out [[buffer(2)]],
|
||||||
|
|
||||||
fault_counter(uint tolerance) {
|
constant uint &ids_dim_size,
|
||||||
this->counter = 0;
|
constant uint &left_size,
|
||||||
this->tolerance = tolerance;
|
constant uint &dst_dim_size,
|
||||||
}
|
constant uint &right_size,
|
||||||
|
|
||||||
bool quit() {
|
uint threadgroup_size [[threads_per_threadgroup]],
|
||||||
counter += 1;
|
uint threadgroup_position_in_grid [[threadgroup_position_in_grid]],
|
||||||
return (counter > tolerance);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
constant uint IDS_DIM_SIZE [[function_constant(0)]];
|
|
||||||
constant uint SRC_DIM_SIZE [[function_constant(1)]]; // Not needed
|
|
||||||
constant uint DST_DIM_SIZE [[function_constant(2)]];
|
|
||||||
constant uint LEFT_SIZE [[function_constant(3)]];
|
|
||||||
constant uint RIGHT_SIZE [[function_constant(4)]];
|
|
||||||
constant uint NUMEL = LEFT_SIZE * RIGHT_SIZE;
|
|
||||||
|
|
||||||
kernel void index_add(
|
|
||||||
device uint *ids [[buffer(0)]],
|
|
||||||
device float *inp [[buffer(1)]],
|
|
||||||
device float *out [[buffer(2)]],
|
|
||||||
uint thread_index [[thread_index_in_threadgroup]]
|
uint thread_index [[thread_index_in_threadgroup]]
|
||||||
) {
|
) {
|
||||||
const uint i = thread_index;
|
|
||||||
const uint pre = i / RIGHT_SIZE;
|
|
||||||
const uint post = i % RIGHT_SIZE;
|
|
||||||
|
|
||||||
for (uint j = 0; j < IDS_DIM_SIZE; ++j) {
|
const uint gid = thread_index + (threadgroup_position_in_grid * threadgroup_size);
|
||||||
|
if (gid >= left_size * right_size) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint i = gid;
|
||||||
|
const uint pre = i / right_size;
|
||||||
|
const uint post = i % right_size;
|
||||||
|
|
||||||
|
for (uint j = 0; j < ids_dim_size; j++) {
|
||||||
const uint idx = ids[j];
|
const uint idx = ids[j];
|
||||||
const uint src_i = (pre * IDS_DIM_SIZE + j) * RIGHT_SIZE + post;
|
const uint src_i = (pre * ids_dim_size + j) * right_size + post;
|
||||||
const uint dst_i = (pre * DST_DIM_SIZE + idx) * RIGHT_SIZE + post;
|
const uint dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
||||||
out[dst_i] += inp[src_i];
|
out[dst_i] += inp[src_i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
|
||||||
|
kernel void FN_NAME( \
|
||||||
|
device INDEX_TYPENAME *ids [[buffer(0)]], \
|
||||||
|
device TYPENAME *inp [[buffer(1)]], \
|
||||||
|
device TYPENAME *out [[buffer(2)]], \
|
||||||
|
constant uint &ids_dim_size, \
|
||||||
|
constant uint &left_size, \
|
||||||
|
constant uint &dst_dim_size, \
|
||||||
|
constant uint &right_size, \
|
||||||
|
uint threadgroup_size [[threads_per_threadgroup]], \
|
||||||
|
uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
|
||||||
|
uint thread_index [[thread_index_in_threadgroup]] \
|
||||||
|
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, threadgroup_size, threadgroup_position_in_grid, thread_index); } \
|
||||||
|
|
||||||
|
IA_OP(bfloat, int64_t, ia_i64_bf16)
|
||||||
|
IA_OP(bfloat, uint32_t, ia_u32_bf16)
|
||||||
|
IA_OP(bfloat, uint8_t, ia_u8_bf16)
|
||||||
|
|
||||||
|
IA_OP(half, uint32_t, ia_u32_f16)
|
||||||
|
IA_OP(half, uint8_t, ia_u8_f16)
|
||||||
|
|
||||||
|
IA_OP(float, int64_t, ia_i64_f32)
|
||||||
|
IA_OP(uint8_t, int64_t, ia_i64_u8)
|
||||||
|
IA_OP(int64_t, int64_t, ia_i64_i64)
|
||||||
|
IA_OP(uint32_t, int64_t, ia_i64_u32)
|
||||||
|
|
||||||
|
IA_OP(float, uint32_t, ia_u32_f32)
|
||||||
|
IA_OP(uint8_t, uint32_t, ia_u32_u8)
|
||||||
|
IA_OP(int64_t, uint32_t, ia_u32_i64)
|
||||||
|
IA_OP(uint32_t, uint32_t, ia_u32_u32)
|
||||||
|
|
||||||
|
IA_OP(float, uint8_t, ia_u8_f32)
|
||||||
|
IA_OP(uint8_t, uint8_t, ia_u8_u8)
|
||||||
|
IA_OP(uint32_t, uint8_t, ia_u8_u32)
|
||||||
|
IA_OP(int64_t, uint8_t, ia_u8_i64)
|
@ -64,8 +64,8 @@ fn call_unary(_func: &Function, _input: &Buffer, _output: &Buffer, _length: usiz
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use metal::{
|
use metal::{
|
||||||
CompileOptions, ComputePipelineDescriptor, Device, FunctionConstantValues, MTLDataType,
|
CompileOptions, ComputePipelineDescriptor, Device, MTLResourceOptions, MTLResourceUsage,
|
||||||
MTLResourceOptions, MTLResourceUsage, MTLSize, NSUInteger,
|
MTLSize, NSUInteger,
|
||||||
};
|
};
|
||||||
use std::ffi::c_void;
|
use std::ffi::c_void;
|
||||||
use std::mem;
|
use std::mem;
|
||||||
@ -99,7 +99,7 @@ mod tests {
|
|||||||
let argument_encoder = func.new_argument_encoder(0);
|
let argument_encoder = func.new_argument_encoder(0);
|
||||||
let arg_buffer = device.new_buffer(
|
let arg_buffer = device.new_buffer(
|
||||||
argument_encoder.encoded_length(),
|
argument_encoder.encoded_length(),
|
||||||
MTLResourceOptions::empty(),
|
MTLResourceOptions::StorageModeShared,
|
||||||
);
|
);
|
||||||
argument_encoder.set_argument_buffer(&arg_buffer, 0);
|
argument_encoder.set_argument_buffer(&arg_buffer, 0);
|
||||||
argument_encoder.set_buffer(0, &input, 0);
|
argument_encoder.set_buffer(0, &input, 0);
|
||||||
@ -154,66 +154,54 @@ mod tests {
|
|||||||
let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
|
let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
|
||||||
let right = [1.0f32; 15];
|
let right = [1.0f32; 15];
|
||||||
let index = [0u32, 4, 2];
|
let index = [0u32, 4, 2];
|
||||||
|
|
||||||
let ids_dim_size = index.len() as u32;
|
let ids_dim_size = index.len() as u32;
|
||||||
|
|
||||||
// Are these reversed?
|
|
||||||
let src_dim_size: u32 = 9;
|
|
||||||
let dst_dim_size: u32 = 15;
|
let dst_dim_size: u32 = 15;
|
||||||
let left_size: u32 = 3;
|
let left_size: u32 = 3;
|
||||||
let right_size: u32 = 3;
|
let right_size: u32 = 3;
|
||||||
|
|
||||||
let fcv = FunctionConstantValues::new();
|
let function = library.get_function("ia_u32_f32", None).unwrap();
|
||||||
fcv.set_constant_value_at_index(void_ptr(&ids_dim_size), MTLDataType::UInt, 0);
|
|
||||||
fcv.set_constant_value_at_index(void_ptr(&src_dim_size), MTLDataType::UInt, 1);
|
|
||||||
fcv.set_constant_value_at_index(void_ptr(&dst_dim_size), MTLDataType::UInt, 2);
|
|
||||||
fcv.set_constant_value_at_index(void_ptr(&left_size), MTLDataType::UInt, 3);
|
|
||||||
fcv.set_constant_value_at_index(void_ptr(&right_size), MTLDataType::UInt, 4);
|
|
||||||
|
|
||||||
let function = library.get_function("index_add", Some(fcv)).unwrap();
|
|
||||||
let pipeline = device
|
let pipeline = device
|
||||||
.new_compute_pipeline_state_with_function(&function)
|
.new_compute_pipeline_state_with_function(&function)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let options = MTLResourceOptions::StorageModeShared;
|
let options = MTLResourceOptions::StorageModeShared;
|
||||||
|
|
||||||
let ids_size = (index.len() * mem::size_of::<u32>()) as NSUInteger;
|
|
||||||
let input_size = (left.len() * mem::size_of::<f32>()) as NSUInteger;
|
|
||||||
let output_size = (right.len() * mem::size_of::<f32>()) as NSUInteger;
|
|
||||||
|
|
||||||
let ids = device.new_buffer_with_data(void_ptr(&index), ids_size, options);
|
|
||||||
let inputs = device.new_buffer_with_data(void_ptr(&left), input_size, options);
|
|
||||||
let outputs = device.new_buffer_with_data(void_ptr(&right), output_size, options);
|
|
||||||
|
|
||||||
let command_queue = device.new_command_queue();
|
let command_queue = device.new_command_queue();
|
||||||
let command_buffer = command_queue.new_command_buffer();
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
|
||||||
|
let ids_size = (index.len() * mem::size_of::<u32>()) as NSUInteger;
|
||||||
|
let input_size = (left.len() * mem::size_of::<f32>()) as NSUInteger;
|
||||||
|
let output_size = (right.len() * mem::size_of::<f32>()) as NSUInteger;
|
||||||
|
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
let thread_group_memory_length = output_size;
|
encoder.set_threadgroup_memory_length(0, output_size as NSUInteger);
|
||||||
encoder.set_threadgroup_memory_length(0, thread_group_memory_length as NSUInteger);
|
|
||||||
|
|
||||||
encoder.use_resource(&ids, MTLResourceUsage::Read);
|
let index_buffer = device.new_buffer_with_data(void_ptr(&index), ids_size, options);
|
||||||
encoder.use_resource(&inputs, MTLResourceUsage::Read);
|
let inputs_buffer = device.new_buffer_with_data(void_ptr(&left), input_size, options);
|
||||||
encoder.use_resource(&outputs, MTLResourceUsage::Write);
|
let outputs_buffer = device.new_buffer_with_data(void_ptr(&right), output_size, options);
|
||||||
|
|
||||||
encoder.set_buffer(0, Some(&ids), 0);
|
encoder.set_buffer(0, Some(&index_buffer), 0);
|
||||||
encoder.set_buffer(1, Some(&inputs), 0);
|
encoder.set_buffer(1, Some(&inputs_buffer), 0);
|
||||||
encoder.set_buffer(2, Some(&outputs), 0);
|
encoder.set_buffer(2, Some(&outputs_buffer), 0);
|
||||||
let width = 16;
|
|
||||||
|
|
||||||
let thread_group_count = MTLSize {
|
encoder.set_bytes(3, 4, void_ptr(&ids_dim_size));
|
||||||
width: 1,
|
encoder.set_bytes(4, 4, void_ptr(&left_size));
|
||||||
|
encoder.set_bytes(5, 4, void_ptr(&dst_dim_size));
|
||||||
|
encoder.set_bytes(6, 4, void_ptr(&right_size));
|
||||||
|
|
||||||
|
let grid_size = MTLSize {
|
||||||
|
width: right.len() as NSUInteger,
|
||||||
height: 1,
|
height: 1,
|
||||||
depth: 1,
|
depth: 1,
|
||||||
};
|
};
|
||||||
|
|
||||||
let thread_group_size = MTLSize {
|
let thread_group_size = MTLSize {
|
||||||
width,
|
width: pipeline.max_total_threads_per_threadgroup(),
|
||||||
height: 1,
|
height: 1,
|
||||||
depth: 1,
|
depth: 1,
|
||||||
};
|
};
|
||||||
|
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_threads(grid_size, thread_group_size);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
command_buffer.commit();
|
command_buffer.commit();
|
||||||
command_buffer.wait_until_completed();
|
command_buffer.wait_until_completed();
|
||||||
@ -221,7 +209,7 @@ mod tests {
|
|||||||
let expected = vec![
|
let expected = vec![
|
||||||
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
|
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
|
||||||
];
|
];
|
||||||
let result = outputs.read_to_vec::<f32>(right.len());
|
let result = outputs_buffer.read_to_vec::<f32>(right.len());
|
||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
18
candle-metal-kernels/src/utils.metal
Normal file
18
candle-metal-kernels/src/utils.metal
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
#include <metal_stdlib>
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
struct fault_counter {
|
||||||
|
uint counter;
|
||||||
|
uint tolerance;
|
||||||
|
|
||||||
|
fault_counter(uint tolerance) {
|
||||||
|
this->counter = 0;
|
||||||
|
this->tolerance = tolerance;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool quit() {
|
||||||
|
counter += 1;
|
||||||
|
return (counter > tolerance);
|
||||||
|
}
|
||||||
|
};
|
Reference in New Issue
Block a user