mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Fixes + cache compute_pipeline_state.
This commit is contained in:
@ -86,7 +86,6 @@ impl MetalDevice {
|
||||
|
||||
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
|
||||
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
|
||||
// debug!("Allocate 1 - buffer size {size}");
|
||||
self.device
|
||||
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
|
||||
}
|
||||
@ -115,7 +114,9 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||
let start = std::time::Instant::now();
|
||||
self.device.wait_until_completed();
|
||||
println!("Wait took {:?}", start.elapsed());
|
||||
|
||||
match self.dtype {
|
||||
DType::U8 => Ok(CpuStorage::U8(
|
||||
@ -414,6 +415,7 @@ impl BackendStorage for MetalStorage {
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
self.device.wait_until_completed();
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
@ -899,9 +901,12 @@ impl BackendDevice for MetalDevice {
|
||||
}
|
||||
|
||||
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
|
||||
// TODO Is there a faster way ?
|
||||
let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?;
|
||||
self.storage_from_cpu_storage(&cpu_storage)
|
||||
let buffer = self.new_buffer(shape.elem_count(), dtype);
|
||||
Ok(MetalStorage {
|
||||
buffer,
|
||||
device: self.clone(),
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
|
||||
|
@ -1,7 +1,7 @@
|
||||
#![allow(clippy::too_many_arguments)]
|
||||
use metal::{
|
||||
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor,
|
||||
ComputePipelineState, Device, Function, Library, MTLSize,
|
||||
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
|
||||
Device, Function, Library, MTLSize,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
@ -174,6 +174,10 @@ pub enum MetalKernelError {
|
||||
LoadLibraryError(String),
|
||||
#[error("Error while loading function: {0:?}")]
|
||||
LoadFunctionError(String),
|
||||
#[error("Failed to create compute function")]
|
||||
FailedToCreateComputeFunction,
|
||||
#[error("Failed to create pipeline")]
|
||||
FailedToCreatePipeline(String),
|
||||
}
|
||||
|
||||
impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
|
||||
@ -184,19 +188,22 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
|
||||
|
||||
type KernelMap<T> = HashMap<&'static str, T>;
|
||||
type Libraries = HashMap<Source, Library>;
|
||||
type Functions = KernelMap<Function>;
|
||||
type Pipelines = KernelMap<ComputePipelineState>;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
#[derive(Debug)]
|
||||
pub struct Kernels {
|
||||
libraries: RwLock<Libraries>,
|
||||
funcs: RwLock<Functions>,
|
||||
pipelines: RwLock<Pipelines>,
|
||||
}
|
||||
|
||||
impl Kernels {
|
||||
pub fn new() -> Self {
|
||||
let libraries = RwLock::new(Libraries::new());
|
||||
let funcs = RwLock::new(Functions::new());
|
||||
Self { libraries, funcs }
|
||||
let pipelines = RwLock::new(Pipelines::new());
|
||||
Self {
|
||||
libraries,
|
||||
pipelines,
|
||||
}
|
||||
}
|
||||
|
||||
// pub fn init(device: &Device) -> Result<Self, MetalKernelError> {
|
||||
@ -243,22 +250,43 @@ impl Kernels {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_function(
|
||||
fn load_function(
|
||||
&self,
|
||||
device: &Device,
|
||||
source: Source,
|
||||
name: &'static str,
|
||||
) -> Result<Function, MetalKernelError> {
|
||||
let mut funcs = self.funcs.write()?;
|
||||
if let Some(func) = funcs.get(name) {
|
||||
Ok(func.clone())
|
||||
} else {
|
||||
let func = self
|
||||
.load_library(device, source)?
|
||||
.get_function(name, None)
|
||||
.map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
|
||||
funcs.insert(name, func.clone());
|
||||
Ok(func)
|
||||
// let mut funcs = self.funcs.write()?;
|
||||
// if let Some(func) = funcs.get(name) {
|
||||
// Ok(func.clone())
|
||||
// } else {
|
||||
// funcs.insert(name, func.clone());
|
||||
// Ok(func)
|
||||
// }
|
||||
}
|
||||
|
||||
pub fn load_pipeline(
|
||||
&self,
|
||||
device: &Device,
|
||||
source: Source,
|
||||
name: &'static str,
|
||||
) -> Result<ComputePipelineState, MetalKernelError> {
|
||||
let mut pipelines = self.pipelines.write()?;
|
||||
if let Some(pipeline) = pipelines.get(name) {
|
||||
Ok(pipeline.clone())
|
||||
} else {
|
||||
let func = self.load_function(device, source, name)?;
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(&func)
|
||||
.map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?;
|
||||
pipelines.insert(name, pipeline.clone());
|
||||
|
||||
Ok(pipeline)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -274,16 +302,8 @@ pub fn call_unary_contiguous(
|
||||
) -> Result<(), MetalKernelError> {
|
||||
// println!("Kernel {:?}", kernel_name.0);
|
||||
// assert_eq!(input.length(), output.length());
|
||||
let func = kernels.load_function(device, Source::Unary, kernel_name.0)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(
|
||||
pipeline_state_descriptor.compute_function().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
@ -306,15 +326,7 @@ pub fn call_unary_strided(
|
||||
output: &mut Buffer,
|
||||
output_offset: usize,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let func = kernels.load_function(device, Source::Unary, name.0)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(
|
||||
pipeline_state_descriptor.compute_function().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
|
||||
|
||||
let num_dims: usize = shape.len();
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
@ -351,17 +363,7 @@ pub fn call_binary_contiguous(
|
||||
right: &Buffer,
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
// println!("Kernel {:?}", kernel_name.0);
|
||||
// assert_eq!(input.length(), output.length());
|
||||
let func = kernels.load_function(device, Source::Binary, kernel_name.0)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(
|
||||
pipeline_state_descriptor.compute_function().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
@ -389,15 +391,7 @@ pub fn call_binary_strided(
|
||||
right_offset: usize,
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let func = kernels.load_function(device, Source::Binary, name.0)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(
|
||||
pipeline_state_descriptor.compute_function().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
|
||||
|
||||
let num_dims: usize = shape.len();
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
@ -436,17 +430,7 @@ pub fn call_cast_contiguous(
|
||||
input: &Buffer,
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
// println!("Kernel {:?}", kernel_name.0);
|
||||
// assert_eq!(input.length(), output.length());
|
||||
let func = kernels.load_function(device, Source::Cast, kernel_name)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(
|
||||
pipeline_state_descriptor.compute_function().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
@ -473,15 +457,7 @@ pub fn call_cast_strided(
|
||||
) -> Result<(), MetalKernelError> {
|
||||
// println!("Kernel {:?}", kernel_name.0);
|
||||
// assert_eq!(input.length(), output.length());
|
||||
let func = kernels.load_function(device, Source::Cast, kernel_name)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(
|
||||
pipeline_state_descriptor.compute_function().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
@ -510,16 +486,7 @@ pub fn call_reduce_contiguous(
|
||||
input: &Buffer,
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let func = kernels.load_function(device, Source::Reduce, kernel_name)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(
|
||||
pipeline_state_descriptor.compute_function().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
let elements_to_sum = length / out_length;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
@ -560,16 +527,7 @@ pub fn call_last_softmax(
|
||||
input: &Buffer,
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let func = kernels.load_function(device, Source::Reduce, kernel_name)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(
|
||||
pipeline_state_descriptor.compute_function().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
@ -612,15 +570,7 @@ pub fn call_affine(
|
||||
mul: f32,
|
||||
add: f32,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let func = kernels.load_function(device, Source::Affine, name)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(
|
||||
pipeline_state_descriptor.compute_function().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
@ -646,18 +596,9 @@ pub fn call_affine_strided(
|
||||
mul: f32,
|
||||
add: f32,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let func = kernels.load_function(device, Source::Affine, name)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||
let size: usize = shape.iter().product();
|
||||
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(
|
||||
pipeline_state_descriptor.compute_function().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
@ -695,15 +636,7 @@ pub fn call_where_cond_strided(
|
||||
(right_stride, right_offset): (&[usize], usize),
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let func = kernels.load_function(device, Source::Ternary, name)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(
|
||||
pipeline_state_descriptor.compute_function().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
@ -751,10 +684,7 @@ pub fn call_index_select(
|
||||
let src_dim_size = shape[dim];
|
||||
let dst_el = ids_size * left_size * right_size;
|
||||
|
||||
let func = kernels.load_function(device, Source::Indexing, name)?;
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(&func)
|
||||
.unwrap();
|
||||
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
|
@ -40,7 +40,7 @@ template <typename T> METAL_FUNC T erf(T in){
|
||||
float t = 1.0/(1.0 + p*x);
|
||||
float y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x);
|
||||
|
||||
return (T) sign*y;
|
||||
return T(sign*y);
|
||||
}
|
||||
template <typename T> METAL_FUNC T id(T in){ return in; }
|
||||
template <typename T> METAL_FUNC T gelu_erf(T x){ return x * (1 + erf(x * M_SQRT1_2_F)) / 2; }
|
||||
@ -49,7 +49,7 @@ template <typename T> METAL_FUNC T gelu(T x){
|
||||
T x_cube = x_sq * x;
|
||||
T alpha = x + static_cast<T>(0.044715) * x_cube;
|
||||
T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
|
||||
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + tanh(beta));
|
||||
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
|
||||
}
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user