mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Writing unary as macro instead, protecting bfloat type with proper metal version.
This commit is contained in:
@ -12,3 +12,6 @@ license.workspace = true
|
||||
metal = { workspace = true }
|
||||
once_cell = "1.18.0"
|
||||
thiserror = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
half = { workspace = true }
|
||||
|
@ -48,9 +48,13 @@ kernel void FN_NAME( \
|
||||
uint thread_index [[thread_index_in_threadgroup]] \
|
||||
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, threadgroup_size, threadgroup_position_in_grid, thread_index); } \
|
||||
|
||||
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
IA_OP(bfloat, int64_t, ia_i64_bf16)
|
||||
IA_OP(bfloat, uint32_t, ia_u32_bf16)
|
||||
IA_OP(bfloat, uint8_t, ia_u8_bf16)
|
||||
#endif
|
||||
|
||||
IA_OP(half, uint32_t, ia_u32_f16)
|
||||
IA_OP(half, uint8_t, ia_u8_f16)
|
||||
@ -68,4 +72,4 @@ IA_OP(uint32_t, uint32_t, ia_u32_u32)
|
||||
IA_OP(float, uint8_t, ia_u8_f32)
|
||||
IA_OP(uint8_t, uint8_t, ia_u8_u8)
|
||||
IA_OP(uint32_t, uint8_t, ia_u8_u32)
|
||||
IA_OP(int64_t, uint8_t, ia_u8_i64)
|
||||
IA_OP(int64_t, uint8_t, ia_u8_i64)
|
||||
|
@ -131,11 +131,10 @@ pub fn void_ptr<T>(v: &T) -> *const c_void {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use half::f16;
|
||||
use metal::{
|
||||
CompileOptions, ComputePipelineDescriptor, Device, MTLResourceOptions, MTLResourceUsage,
|
||||
MTLSize, NSUInteger,
|
||||
CompileOptions, ComputePipelineDescriptor, Device, MTLResourceOptions, MTLSize, NSUInteger,
|
||||
};
|
||||
use std::ffi::c_void;
|
||||
use std::mem;
|
||||
|
||||
fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
|
||||
@ -143,32 +142,26 @@ mod tests {
|
||||
v.iter().map(|t| f32::round(t * b) / b).collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cos() {
|
||||
let v = vec![1.0f32, 2.0, 3.0];
|
||||
fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> {
|
||||
let b = 10f32.powi(digits);
|
||||
v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
|
||||
}
|
||||
|
||||
fn run_cos<T: Clone>(v: &[T], name: &str) -> Vec<T> {
|
||||
let option = metal::MTLResourceOptions::CPUCacheModeDefaultCache;
|
||||
let device = Device::system_default().unwrap();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
let input = device.new_buffer_with_data(
|
||||
v.as_ptr() as *const c_void,
|
||||
(v.len() * mem::size_of::<f32>()) as u64,
|
||||
v.as_ptr() as *const core::ffi::c_void,
|
||||
(v.len() * core::mem::size_of::<T>()) as u64,
|
||||
option,
|
||||
);
|
||||
let output = device.new_buffer((v.len() * mem::size_of::<f32>()) as u64, option);
|
||||
let output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, option);
|
||||
let library = device
|
||||
.new_library_with_source(UNARY, &CompileOptions::new())
|
||||
.expect("Failed to load unary library");
|
||||
let func = library.get_function("cos", None).unwrap();
|
||||
let argument_encoder = func.new_argument_encoder(0);
|
||||
let arg_buffer = device.new_buffer(
|
||||
argument_encoder.encoded_length(),
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
argument_encoder.set_argument_buffer(&arg_buffer, 0);
|
||||
argument_encoder.set_buffer(0, &input, 0);
|
||||
argument_encoder.set_buffer(1, &output, 0);
|
||||
let func = library.get_function(&format!("cos_{name}"), None).unwrap();
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
@ -178,11 +171,10 @@ mod tests {
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline_state);
|
||||
encoder.set_buffer(0, Some(&arg_buffer), 0);
|
||||
|
||||
encoder.use_resource(&input, MTLResourceUsage::Read);
|
||||
encoder.use_resource(&output, MTLResourceUsage::Write);
|
||||
encoder.set_buffer(0, Some(&input), 0);
|
||||
encoder.set_buffer(1, Some(&output), 0);
|
||||
|
||||
let width = 16;
|
||||
|
||||
@ -202,9 +194,14 @@ mod tests {
|
||||
encoder.end_encoding();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
output.read_to_vec::<T>(v.len())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cos_f32() {
|
||||
let v = vec![1.0f32, 2.0, 3.0];
|
||||
let results = run_cos(&v, "float");
|
||||
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
|
||||
let results = output.read_to_vec::<f32>(v.len());
|
||||
assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]);
|
||||
assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]);
|
||||
}
|
||||
@ -343,4 +340,16 @@ mod tests {
|
||||
let result = outputs_buffer.read_to_vec::<f32>(right.len());
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cos_f16() {
|
||||
let v: Vec<f16> = [1.0f32, 2.0, 3.0]
|
||||
.iter()
|
||||
.map(|v| f16::from_f32(*v))
|
||||
.collect();
|
||||
let results = run_cos(&v, "half");
|
||||
let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect();
|
||||
assert_eq!(approx_f16(results, 4), vec![0.54, -0.4165, -0.9902]);
|
||||
assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]);
|
||||
}
|
||||
}
|
||||
|
@ -2,13 +2,21 @@
|
||||
|
||||
using namespace metal;
|
||||
|
||||
struct Input {
|
||||
device float *input;
|
||||
device float *output;
|
||||
};
|
||||
|
||||
kernel void cos(device Input& args [[ buffer(0) ]], uint index [[thread_position_in_grid]])
|
||||
template <typename T>
|
||||
kernel void unary_cos(device const T *input, device T *output, uint index [[thread_position_in_grid]])
|
||||
{
|
||||
args.output[index] = cos(args.input[index]);
|
||||
output[index] = cos(input[index]);
|
||||
}
|
||||
|
||||
#define UNARY(FN, TYPENAME, FN_NAME) \
|
||||
kernel void FN_NAME(device const TYPENAME *input, device TYPENAME *output, uint index [[thread_position_in_grid]]) \
|
||||
{ \
|
||||
output[index] = FN(input[index]);\
|
||||
}
|
||||
|
||||
UNARY(cos, float, cos_float);
|
||||
UNARY(cos, half, cos_half);
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
UNARY(cos, half, cos_half);
|
||||
#endif
|
||||
|
Reference in New Issue
Block a user