mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 20:09:50 +00:00
369 lines
12 KiB
Rust
369 lines
12 KiB
Rust
use metal::{Buffer, CompileOptions, Device, Function, Library};
|
|
use once_cell::sync::Lazy;
|
|
use std::collections::HashMap;
|
|
use std::ffi::c_void;
|
|
use std::sync::RwLock;
|
|
|
|
/// Metal source of the affine library, embedded at compile time from `affine.metal`
/// (the tests look up an `affine` function in it).
pub const AFFINE: &str = include_str!("affine.metal");
|
|
/// Metal source of the indexing library, embedded at compile time from `indexing.metal`
/// (the tests look up an `ia_u32_f32` function in it).
pub const INDEXING: &str = include_str!("indexing.metal");
|
|
/// Metal source of the unary library, embedded at compile time from `unary.metal`
/// (the tests look up `cos_float` / `cos_half` functions in it).
pub const UNARY: &str = include_str!("unary.metal");
|
|
|
|
static LIBRARY_SOURCES: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(|| {
|
|
let mut l = HashMap::new();
|
|
l.insert("affine", AFFINE);
|
|
l.insert("indexing", INDEXING);
|
|
l.insert("unary", UNARY);
|
|
l
|
|
});
|
|
|
|
#[derive(thiserror::Error, Debug)]
|
|
pub enum MetalKernelError {
|
|
#[error("Could not lock kernel map: {0}")]
|
|
LockError(String),
|
|
#[error("Error while loading library: {0}")]
|
|
LoadLibraryError(String),
|
|
#[error("Error while loading function: {0}")]
|
|
LoadFunctionError(String),
|
|
}
|
|
|
|
impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
|
|
fn from(e: std::sync::PoisonError<T>) -> Self {
|
|
Self::LockError(e.to_string())
|
|
}
|
|
}
|
|
|
|
/// Cache map keyed by a `'static` name (a library name or a function name).
type KernelMap<T> = HashMap<&'static str, T>;
|
|
/// Compiled Metal libraries, keyed by library name.
type Libraries = KernelMap<Library>;
|
|
/// Metal functions fetched from compiled libraries, keyed by function name.
type Functions = KernelMap<Function>;
|
|
|
|
#[derive(Debug)]
|
|
pub struct Kernels {
|
|
libraries: RwLock<Libraries>,
|
|
funcs: RwLock<Functions>,
|
|
}
|
|
|
|
impl Kernels {
|
|
pub fn new() -> Self {
|
|
let libraries = RwLock::new(Libraries::new());
|
|
let funcs = RwLock::new(Functions::new());
|
|
Self { libraries, funcs }
|
|
}
|
|
|
|
pub fn init(device: &Device) -> Result<Self, MetalKernelError> {
|
|
let kernels = Self::new();
|
|
kernels.load_libraries(device)?;
|
|
Ok(kernels)
|
|
}
|
|
|
|
fn load_libraries(&self, device: &Device) -> Result<(), MetalKernelError> {
|
|
for name in LIBRARY_SOURCES.keys() {
|
|
self.load_library(device, name)?;
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn get_library_source(&self, name: &'static str) -> Option<&'static str> {
|
|
LIBRARY_SOURCES.get(name).cloned()
|
|
}
|
|
|
|
pub fn load_library(
|
|
&self,
|
|
device: &Device,
|
|
name: &'static str,
|
|
) -> Result<Library, MetalKernelError> {
|
|
let mut libraries = self.libraries.write()?;
|
|
if let Some(lib) = libraries.get(name) {
|
|
Ok(lib.clone())
|
|
} else {
|
|
let source = self.get_library_source(name).ok_or_else(|| {
|
|
MetalKernelError::LoadLibraryError(format!("No source found for {}", name))
|
|
})?;
|
|
let lib = device
|
|
.new_library_with_source(source, &CompileOptions::new())
|
|
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?;
|
|
libraries.insert(name, lib.clone());
|
|
Ok(lib)
|
|
}
|
|
}
|
|
|
|
pub fn load_function(
|
|
&self,
|
|
device: &Device,
|
|
library_name: &'static str,
|
|
name: &'static str,
|
|
) -> Result<Function, MetalKernelError> {
|
|
let mut funcs = self.funcs.write()?;
|
|
if let Some(func) = funcs.get(name) {
|
|
Ok(func.clone())
|
|
} else {
|
|
let func = self
|
|
.load_library(device, library_name)?
|
|
.get_function(name, None)
|
|
.map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
|
|
funcs.insert(name, func.clone());
|
|
Ok(func)
|
|
}
|
|
}
|
|
|
|
pub fn call_unary(
|
|
&self,
|
|
device: &Device,
|
|
library_name: &'static str,
|
|
name: &'static str,
|
|
input: &Buffer,
|
|
output: &mut Buffer,
|
|
length: usize,
|
|
) -> Result<(), MetalKernelError> {
|
|
let func = self.load_function(device, library_name, name)?;
|
|
call_unary(&func, input, output, length);
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
/// Intended to encode and dispatch a unary kernel over `_length` elements of
/// `_input` into `_output`; not implemented yet.
///
/// # Panics
/// Always panics via `todo!`.
fn call_unary(_func: &Function, _input: &Buffer, _output: &Buffer, _length: usize) {
    todo!("Call unary")
}
|
|
|
|
/// Returns `v`'s address as an untyped `*const c_void`, the form Metal APIs such as
/// `set_bytes` and `new_buffer_with_data` expect.
///
/// `T` may be unsized (e.g. `[f32]`): the pointer then addresses the first element
/// and the length metadata is discarded, so pass the byte size separately.
/// The pointer is only valid while the borrow of `v` is alive.
pub fn void_ptr<T: ?Sized>(v: &T) -> *const c_void {
    (v as *const T).cast()
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use half::f16;
    use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
    use std::mem;

    /// Returns the system's default Metal device, panicking if there is none.
    fn device() -> Device {
        Device::system_default().expect("no device found")
    }

    /// Rounds every value to `digits` decimal places so GPU results compare stably.
    fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
        let b = 10f32.powi(digits);
        v.iter().map(|t| f32::round(t * b) / b).collect()
    }

    /// Like [`approx`], but widens `f16` inputs to `f32` first.
    fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> {
        let b = 10f32.powi(digits);
        v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
    }

    /// Runs the `cos_{name}` kernel from the unary library over `v` and returns the
    /// contents of the output buffer.
    fn run_cos<T: Clone>(v: &[T], name: &str) -> Vec<T> {
        let device = device();
        let options = MTLResourceOptions::StorageModeManaged;
        let command_queue = device.new_command_queue();
        let command_buffer = command_queue.new_command_buffer();
        let input = device.new_buffer_with_data(
            v.as_ptr() as *const core::ffi::c_void,
            (v.len() * core::mem::size_of::<T>()) as u64,
            options,
        );
        let output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, options);
        let library = device
            .new_library_with_source(UNARY, &CompileOptions::new())
            .expect("Failed to load unary library");
        let func = library.get_function(&format!("cos_{name}"), None).unwrap();
        let pipeline = device
            .new_compute_pipeline_state_with_function(&func)
            .unwrap();

        let dim: u32 = v.len() as u32;

        let encoder = command_buffer.new_compute_command_encoder();
        encoder.set_compute_pipeline_state(&pipeline);
        encoder.set_bytes(0, 4, void_ptr(&dim));
        encoder.set_buffer(1, Some(&input), 0);
        encoder.set_buffer(2, Some(&output), 0);

        let width = v.len() as NSUInteger;
        let thread_group_count = MTLSize {
            width,
            height: 1,
            depth: 1,
        };
        let thread_group_size = MTLSize {
            width: pipeline.max_total_threads_per_threadgroup(),
            height: 1,
            depth: 1,
        };

        // NOTE(review): this launches `v.len()` threadgroups of the maximum size, i.e.
        // far more threads than elements; presumably the kernel bounds its work by
        // `dim` — confirm against unary.metal.
        encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
        encoder.end_encoding();
        command_buffer.commit();
        command_buffer.wait_until_completed();
        output.read_to_vec::<T>(v.len())
    }

    #[test]
    fn cos_f32() {
        let v = vec![1.0f32, 2.0, 3.0];
        let results = run_cos(&v, "float");
        let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
        assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]);
        assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]);
    }

    #[test]
    fn affine() {
        let device = device();

        let options = CompileOptions::new();
        let library = device.new_library_with_source(AFFINE, &options).unwrap();

        let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let output = [2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let dim: u32 = 8;
        let num_dims: u32 = 4;
        let info = [1u32, 2, 3];
        let mul: f32 = 1.5;
        let add: f32 = 1.1;

        let function = library.get_function("affine", None).unwrap();
        let pipeline = device
            .new_compute_pipeline_state_with_function(&function)
            .unwrap();
        let options = MTLResourceOptions::StorageModeManaged;

        let command_queue = device.new_command_queue();
        let command_buffer = command_queue.new_command_buffer();
        let encoder = command_buffer.new_compute_command_encoder();

        let input_size = (input.len() * mem::size_of::<f32>()) as NSUInteger;
        let output_size = (output.len() * mem::size_of::<f32>()) as NSUInteger;

        encoder.set_compute_pipeline_state(&pipeline);
        encoder.set_threadgroup_memory_length(0, output_size as NSUInteger);

        let inputs_buffer = device.new_buffer_with_data(void_ptr(&input), input_size, options);
        let outputs_buffer = device.new_buffer_with_data(void_ptr(&output), output_size, options);

        encoder.set_bytes(0, 4, void_ptr(&dim));
        encoder.set_bytes(1, 4, void_ptr(&num_dims));
        // NOTE(review): only the first 4 of `info`'s 12 bytes are passed here (and
        // `num_dims` is 4 while `info` has 3 entries); confirm against affine.metal.
        encoder.set_bytes(2, 4, void_ptr(&info));

        encoder.set_buffer(3, Some(&inputs_buffer), 0);
        encoder.set_buffer(4, Some(&outputs_buffer), 0);

        encoder.set_bytes(5, 4, void_ptr(&mul));
        encoder.set_bytes(6, 4, void_ptr(&add));

        let grid_size = MTLSize {
            width: output.len() as NSUInteger,
            height: 1,
            depth: 1,
        };
        let thread_group_size = MTLSize {
            width: pipeline.max_total_threads_per_threadgroup(),
            height: 1,
            depth: 1,
        };

        encoder.dispatch_threads(grid_size, thread_group_size);
        encoder.end_encoding();
        command_buffer.commit();
        command_buffer.wait_until_completed();

        // input * 1.5 + 1.1, element-wise.
        let expected = vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1];
        let result = outputs_buffer.read_to_vec::<f32>(output.len());
        println!("Result {result:?}");
        assert_eq!(result, expected);
    }

    #[test]
    fn index_add() {
        let device = device();

        let options = CompileOptions::new();
        let library = device.new_library_with_source(INDEXING, &options).unwrap();

        let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let right = [1.0f32; 15];
        let index = [0u32, 4, 2];
        let ids_dim_size = index.len() as u32;
        let dst_dim_size: u32 = 15;
        let left_size: u32 = 3;
        let right_size: u32 = 3;

        let function = library.get_function("ia_u32_f32", None).unwrap();
        let pipeline = device
            .new_compute_pipeline_state_with_function(&function)
            .unwrap();
        let options = MTLResourceOptions::StorageModeManaged;

        let command_queue = device.new_command_queue();
        let command_buffer = command_queue.new_command_buffer();
        let encoder = command_buffer.new_compute_command_encoder();

        let ids_size = (index.len() * mem::size_of::<u32>()) as NSUInteger;
        let input_size = (left.len() * mem::size_of::<f32>()) as NSUInteger;
        let output_size = (right.len() * mem::size_of::<f32>()) as NSUInteger;

        encoder.set_compute_pipeline_state(&pipeline);
        encoder.set_threadgroup_memory_length(0, output_size as NSUInteger);

        let index_buffer = device.new_buffer_with_data(void_ptr(&index), ids_size, options);
        let inputs_buffer = device.new_buffer_with_data(void_ptr(&left), input_size, options);
        let outputs_buffer = device.new_buffer_with_data(void_ptr(&right), output_size, options);

        encoder.set_buffer(0, Some(&index_buffer), 0);
        encoder.set_buffer(1, Some(&inputs_buffer), 0);
        encoder.set_buffer(2, Some(&outputs_buffer), 0);

        encoder.set_bytes(3, 4, void_ptr(&ids_dim_size));
        encoder.set_bytes(4, 4, void_ptr(&left_size));
        encoder.set_bytes(5, 4, void_ptr(&dst_dim_size));
        encoder.set_bytes(6, 4, void_ptr(&right_size));

        let grid_size = MTLSize {
            width: right.len() as NSUInteger,
            height: 1,
            depth: 1,
        };
        let thread_group_size = MTLSize {
            width: pipeline.max_total_threads_per_threadgroup(),
            height: 1,
            depth: 1,
        };

        encoder.dispatch_threads(grid_size, thread_group_size);
        encoder.end_encoding();
        command_buffer.commit();
        command_buffer.wait_until_completed();

        // Rows of `left` ([1,2,3], [4,5,6], [7,8,9]) added onto rows 0, 4 and 2 of a
        // 5x3 tensor of ones.
        let expected = vec![
            2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
        ];
        let result = outputs_buffer.read_to_vec::<f32>(right.len());
        println!("Result {result:?}");
        assert_eq!(result, expected);
    }

    #[test]
    fn cos_f16() {
        let v: Vec<f16> = [1.0f32, 2.0, 3.0]
            .iter()
            .map(|v| f16::from_f32(*v))
            .collect();
        let results = run_cos(&v, "half");
        let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect();
        assert_eq!(approx_f16(results, 4), vec![0.54, -0.4165, -0.9902]);
        assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]);
    }
}
|