Miscellaneous fixes + cache the compute pipeline state (compute_pipeline_state).

This commit is contained in:
Nicolas Patry
2023-11-13 14:33:16 +01:00
parent 79845bd93b
commit dd4a40f1c0
3 changed files with 68 additions and 133 deletions

View File

@ -86,7 +86,6 @@ impl MetalDevice {
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer { pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
let size = (element_count * dtype.size_in_bytes()) as NSUInteger; let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
// debug!("Allocate 1 - buffer size {size}");
self.device self.device
.new_buffer(size, MTLResourceOptions::StorageModeManaged) .new_buffer(size, MTLResourceOptions::StorageModeManaged)
} }
@ -115,7 +114,9 @@ impl BackendStorage for MetalStorage {
} }
fn to_cpu_storage(&self) -> Result<CpuStorage> { fn to_cpu_storage(&self) -> Result<CpuStorage> {
let start = std::time::Instant::now();
self.device.wait_until_completed(); self.device.wait_until_completed();
println!("Wait took {:?}", start.elapsed());
match self.dtype { match self.dtype {
DType::U8 => Ok(CpuStorage::U8( DType::U8 => Ok(CpuStorage::U8(
@ -414,6 +415,7 @@ impl BackendStorage for MetalStorage {
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
} }
self.device.wait_until_completed();
Ok(Self { Ok(Self {
buffer, buffer,
device: device.clone(), device: device.clone(),
@ -899,9 +901,12 @@ impl BackendDevice for MetalDevice {
} }
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> { fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
// TODO Is there a faster way ? let buffer = self.new_buffer(shape.elem_count(), dtype);
let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?; Ok(MetalStorage {
self.storage_from_cpu_storage(&cpu_storage) buffer,
device: self.clone(),
dtype,
})
} }
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> { fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {

View File

@ -1,7 +1,7 @@
#![allow(clippy::too_many_arguments)] #![allow(clippy::too_many_arguments)]
use metal::{ use metal::{
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor, Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
ComputePipelineState, Device, Function, Library, MTLSize, Device, Function, Library, MTLSize,
}; };
use std::collections::HashMap; use std::collections::HashMap;
use std::ffi::c_void; use std::ffi::c_void;
@ -174,6 +174,10 @@ pub enum MetalKernelError {
LoadLibraryError(String), LoadLibraryError(String),
#[error("Error while loading function: {0:?}")] #[error("Error while loading function: {0:?}")]
LoadFunctionError(String), LoadFunctionError(String),
#[error("Failed to create compute function")]
FailedToCreateComputeFunction,
#[error("Failed to create pipeline")]
FailedToCreatePipeline(String),
} }
impl<T> From<std::sync::PoisonError<T>> for MetalKernelError { impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
@ -184,19 +188,22 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
type KernelMap<T> = HashMap<&'static str, T>; type KernelMap<T> = HashMap<&'static str, T>;
type Libraries = HashMap<Source, Library>; type Libraries = HashMap<Source, Library>;
type Functions = KernelMap<Function>; type Pipelines = KernelMap<ComputePipelineState>;
#[derive(Debug, Default)] #[derive(Debug)]
pub struct Kernels { pub struct Kernels {
libraries: RwLock<Libraries>, libraries: RwLock<Libraries>,
funcs: RwLock<Functions>, pipelines: RwLock<Pipelines>,
} }
impl Kernels { impl Kernels {
pub fn new() -> Self { pub fn new() -> Self {
let libraries = RwLock::new(Libraries::new()); let libraries = RwLock::new(Libraries::new());
let funcs = RwLock::new(Functions::new()); let pipelines = RwLock::new(Pipelines::new());
Self { libraries, funcs } Self {
libraries,
pipelines,
}
} }
// pub fn init(device: &Device) -> Result<Self, MetalKernelError> { // pub fn init(device: &Device) -> Result<Self, MetalKernelError> {
@ -243,22 +250,43 @@ impl Kernels {
} }
} }
pub fn load_function( fn load_function(
&self, &self,
device: &Device, device: &Device,
source: Source, source: Source,
name: &'static str, name: &'static str,
) -> Result<Function, MetalKernelError> { ) -> Result<Function, MetalKernelError> {
let mut funcs = self.funcs.write()?; let func = self
if let Some(func) = funcs.get(name) { .load_library(device, source)?
Ok(func.clone()) .get_function(name, None)
.map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
Ok(func)
// let mut funcs = self.funcs.write()?;
// if let Some(func) = funcs.get(name) {
// Ok(func.clone())
// } else {
// funcs.insert(name, func.clone());
// Ok(func)
// }
}
pub fn load_pipeline(
&self,
device: &Device,
source: Source,
name: &'static str,
) -> Result<ComputePipelineState, MetalKernelError> {
let mut pipelines = self.pipelines.write()?;
if let Some(pipeline) = pipelines.get(name) {
Ok(pipeline.clone())
} else { } else {
let func = self let func = self.load_function(device, source, name)?;
.load_library(device, source)? let pipeline = device
.get_function(name, None) .new_compute_pipeline_state_with_function(&func)
.map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?; .map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?;
funcs.insert(name, func.clone()); pipelines.insert(name, pipeline.clone());
Ok(func)
Ok(pipeline)
} }
} }
} }
@ -274,16 +302,8 @@ pub fn call_unary_contiguous(
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
// println!("Kernel {:?}", kernel_name.0); // println!("Kernel {:?}", kernel_name.0);
// assert_eq!(input.length(), output.length()); // assert_eq!(input.length(), output.length());
let func = kernels.load_function(device, Source::Unary, kernel_name.0)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
@ -306,15 +326,7 @@ pub fn call_unary_strided(
output: &mut Buffer, output: &mut Buffer,
output_offset: usize, output_offset: usize,
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let func = kernels.load_function(device, Source::Unary, name.0)?; let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let num_dims: usize = shape.len(); let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
@ -351,17 +363,7 @@ pub fn call_binary_contiguous(
right: &Buffer, right: &Buffer,
output: &mut Buffer, output: &mut Buffer,
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
// println!("Kernel {:?}", kernel_name.0); let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
// assert_eq!(input.length(), output.length());
let func = kernels.load_function(device, Source::Binary, kernel_name.0)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
@ -389,15 +391,7 @@ pub fn call_binary_strided(
right_offset: usize, right_offset: usize,
output: &mut Buffer, output: &mut Buffer,
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let func = kernels.load_function(device, Source::Binary, name.0)?; let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let num_dims: usize = shape.len(); let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
@ -436,17 +430,7 @@ pub fn call_cast_contiguous(
input: &Buffer, input: &Buffer,
output: &mut Buffer, output: &mut Buffer,
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
// println!("Kernel {:?}", kernel_name.0); let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
// assert_eq!(input.length(), output.length());
let func = kernels.load_function(device, Source::Cast, kernel_name)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
@ -473,15 +457,7 @@ pub fn call_cast_strided(
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
// println!("Kernel {:?}", kernel_name.0); // println!("Kernel {:?}", kernel_name.0);
// assert_eq!(input.length(), output.length()); // assert_eq!(input.length(), output.length());
let func = kernels.load_function(device, Source::Cast, kernel_name)?; let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
@ -510,16 +486,7 @@ pub fn call_reduce_contiguous(
input: &Buffer, input: &Buffer,
output: &mut Buffer, output: &mut Buffer,
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let func = kernels.load_function(device, Source::Reduce, kernel_name)?; let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let elements_to_sum = length / out_length; let elements_to_sum = length / out_length;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
@ -560,16 +527,7 @@ pub fn call_last_softmax(
input: &Buffer, input: &Buffer,
output: &mut Buffer, output: &mut Buffer,
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let func = kernels.load_function(device, Source::Reduce, kernel_name)?; let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
@ -612,15 +570,7 @@ pub fn call_affine(
mul: f32, mul: f32,
add: f32, add: f32,
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let func = kernels.load_function(device, Source::Affine, name)?; let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
@ -646,18 +596,9 @@ pub fn call_affine_strided(
mul: f32, mul: f32,
add: f32, add: f32,
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let func = kernels.load_function(device, Source::Affine, name)?; let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let size: usize = shape.iter().product(); let size: usize = shape.iter().product();
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
@ -695,15 +636,7 @@ pub fn call_where_cond_strided(
(right_stride, right_offset): (&[usize], usize), (right_stride, right_offset): (&[usize], usize),
output: &mut Buffer, output: &mut Buffer,
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let func = kernels.load_function(device, Source::Ternary, name)?; let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
@ -751,10 +684,7 @@ pub fn call_index_select(
let src_dim_size = shape[dim]; let src_dim_size = shape[dim];
let dst_el = ids_size * left_size * right_size; let dst_el = ids_size * left_size * right_size;
let func = kernels.load_function(device, Source::Indexing, name)?; let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
let pipeline = device
.new_compute_pipeline_state_with_function(&func)
.unwrap();
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();

View File

@ -40,7 +40,7 @@ template <typename T> METAL_FUNC T erf(T in){
float t = 1.0/(1.0 + p*x); float t = 1.0/(1.0 + p*x);
float y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x); float y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x);
return (T) sign*y; return T(sign*y);
} }
template <typename T> METAL_FUNC T id(T in){ return in; } template <typename T> METAL_FUNC T id(T in){ return in; }
template <typename T> METAL_FUNC T gelu_erf(T x){ return x * (1 + erf(x * M_SQRT1_2_F)) / 2; } template <typename T> METAL_FUNC T gelu_erf(T x){ return x * (1 + erf(x * M_SQRT1_2_F)) / 2; }
@ -49,7 +49,7 @@ template <typename T> METAL_FUNC T gelu(T x){
T x_cube = x_sq * x; T x_cube = x_sq * x;
T alpha = x + static_cast<T>(0.044715) * x_cube; T alpha = x + static_cast<T>(0.044715) * x_cube;
T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha); T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + tanh(beta)); return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
} }