Fixes + cache compute_pipeline_state.

This commit is contained in:
Nicolas Patry
2023-11-13 14:33:16 +01:00
parent 79845bd93b
commit dd4a40f1c0
3 changed files with 68 additions and 133 deletions

View File

@ -1,7 +1,7 @@
#![allow(clippy::too_many_arguments)]
use metal::{
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor,
ComputePipelineState, Device, Function, Library, MTLSize,
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
Device, Function, Library, MTLSize,
};
use std::collections::HashMap;
use std::ffi::c_void;
@ -174,6 +174,10 @@ pub enum MetalKernelError {
LoadLibraryError(String),
#[error("Error while loading function: {0:?}")]
LoadFunctionError(String),
#[error("Failed to create compute function")]
FailedToCreateComputeFunction,
#[error("Failed to create pipeline")]
FailedToCreatePipeline(String),
}
impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
@ -184,19 +188,22 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
type KernelMap<T> = HashMap<&'static str, T>;
type Libraries = HashMap<Source, Library>;
type Functions = KernelMap<Function>;
type Pipelines = KernelMap<ComputePipelineState>;
#[derive(Debug, Default)]
#[derive(Debug)]
pub struct Kernels {
libraries: RwLock<Libraries>,
funcs: RwLock<Functions>,
pipelines: RwLock<Pipelines>,
}
impl Kernels {
pub fn new() -> Self {
let libraries = RwLock::new(Libraries::new());
let funcs = RwLock::new(Functions::new());
Self { libraries, funcs }
let pipelines = RwLock::new(Pipelines::new());
Self {
libraries,
pipelines,
}
}
// pub fn init(device: &Device) -> Result<Self, MetalKernelError> {
@ -243,22 +250,43 @@ impl Kernels {
}
}
pub fn load_function(
fn load_function(
&self,
device: &Device,
source: Source,
name: &'static str,
) -> Result<Function, MetalKernelError> {
let mut funcs = self.funcs.write()?;
if let Some(func) = funcs.get(name) {
Ok(func.clone())
let func = self
.load_library(device, source)?
.get_function(name, None)
.map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
Ok(func)
// let mut funcs = self.funcs.write()?;
// if let Some(func) = funcs.get(name) {
// Ok(func.clone())
// } else {
// funcs.insert(name, func.clone());
// Ok(func)
// }
}
pub fn load_pipeline(
&self,
device: &Device,
source: Source,
name: &'static str,
) -> Result<ComputePipelineState, MetalKernelError> {
let mut pipelines = self.pipelines.write()?;
if let Some(pipeline) = pipelines.get(name) {
Ok(pipeline.clone())
} else {
let func = self
.load_library(device, source)?
.get_function(name, None)
.map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
funcs.insert(name, func.clone());
Ok(func)
let func = self.load_function(device, source, name)?;
let pipeline = device
.new_compute_pipeline_state_with_function(&func)
.map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?;
pipelines.insert(name, pipeline.clone());
Ok(pipeline)
}
}
}
@ -274,16 +302,8 @@ pub fn call_unary_contiguous(
) -> Result<(), MetalKernelError> {
// println!("Kernel {:?}", kernel_name.0);
// assert_eq!(input.length(), output.length());
let func = kernels.load_function(device, Source::Unary, kernel_name.0)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
@ -306,15 +326,7 @@ pub fn call_unary_strided(
output: &mut Buffer,
output_offset: usize,
) -> Result<(), MetalKernelError> {
let func = kernels.load_function(device, Source::Unary, name.0)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
@ -351,17 +363,7 @@ pub fn call_binary_contiguous(
right: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
// println!("Kernel {:?}", kernel_name.0);
// assert_eq!(input.length(), output.length());
let func = kernels.load_function(device, Source::Binary, kernel_name.0)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
@ -389,15 +391,7 @@ pub fn call_binary_strided(
right_offset: usize,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
let func = kernels.load_function(device, Source::Binary, name.0)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
@ -436,17 +430,7 @@ pub fn call_cast_contiguous(
input: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
// println!("Kernel {:?}", kernel_name.0);
// assert_eq!(input.length(), output.length());
let func = kernels.load_function(device, Source::Cast, kernel_name)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
@ -473,15 +457,7 @@ pub fn call_cast_strided(
) -> Result<(), MetalKernelError> {
// println!("Kernel {:?}", kernel_name.0);
// assert_eq!(input.length(), output.length());
let func = kernels.load_function(device, Source::Cast, kernel_name)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
@ -510,16 +486,7 @@ pub fn call_reduce_contiguous(
input: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
let func = kernels.load_function(device, Source::Reduce, kernel_name)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let elements_to_sum = length / out_length;
let encoder = command_buffer.new_compute_command_encoder();
@ -560,16 +527,7 @@ pub fn call_last_softmax(
input: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
let func = kernels.load_function(device, Source::Reduce, kernel_name)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
@ -612,15 +570,7 @@ pub fn call_affine(
mul: f32,
add: f32,
) -> Result<(), MetalKernelError> {
let func = kernels.load_function(device, Source::Affine, name)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
@ -646,18 +596,9 @@ pub fn call_affine_strided(
mul: f32,
add: f32,
) -> Result<(), MetalKernelError> {
let func = kernels.load_function(device, Source::Affine, name)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let size: usize = shape.iter().product();
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
@ -695,15 +636,7 @@ pub fn call_where_cond_strided(
(right_stride, right_offset): (&[usize], usize),
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
let func = kernels.load_function(device, Source::Ternary, name)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
@ -751,10 +684,7 @@ pub fn call_index_select(
let src_dim_size = shape[dim];
let dst_el = ids_size * left_size * right_size;
let func = kernels.load_function(device, Source::Indexing, name)?;
let pipeline = device
.new_compute_pipeline_state_with_function(&func)
.unwrap();
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
let encoder = command_buffer.new_compute_command_encoder();