From 76d3116f5d5c71c05d7575ff651c9f09e9bab9e1 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 7 Nov 2023 14:20:13 +0100 Subject: [PATCH] Broken metal ? --- candle-core/src/metal_backend.rs | 89 ++++++++++++++++++-------- candle-metal-kernels/src/lib.rs | 96 +++++++++++++++------------- candle-metal-kernels/src/unary.metal | 29 +++++++-- 3 files changed, 139 insertions(+), 75 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index a3326e6e..75efb0cc 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -3,7 +3,7 @@ use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape}; use candle_metal_kernels; -use candle_metal_kernels::{void_ptr, Kernels, AFFINE}; +use candle_metal_kernels::{void_ptr, Kernels}; use core::mem; use half::{bf16, f16}; use metal; @@ -36,8 +36,7 @@ impl MetalError { #[derive(Clone)] pub struct MetalDevice { device: metal::Device, - _command_queue: metal::CommandQueue, - command_buffer: metal::CommandBuffer, + command_queue: metal::CommandQueue, kernels: Arc, } @@ -66,13 +65,14 @@ impl MetalDevice { fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer { let size = (element_count * dtype.size_in_bytes()) as u64; - self.device.new_buffer(size, MTLResourceOptions::empty()) + self.device + .new_buffer(size, MTLResourceOptions::StorageModeManaged) } } #[derive(Debug, Clone)] pub struct MetalStorage { - buffer: Arc, + buffer: metal::Buffer, device: MetalDevice, dtype: DType, } @@ -103,6 +103,7 @@ impl BackendStorage for MetalStorage { fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result { let device = self.device().clone(); + let command_buffer = self.device.command_queue.new_owned_command_buffer(); let shape = layout.shape(); let dims = shape.dims(); @@ -123,7 +124,7 @@ impl BackendStorage for MetalStorage { let src_length = self.buffer.length() as usize - layout.start_offset(); let src = self.device.new_buffer(src_length, self.dtype); - let blit_encoder = self.device.command_buffer.new_blit_command_encoder(); + let blit_encoder = command_buffer.new_blit_command_encoder(); blit_encoder.copy_from_buffer( self.buffer.as_ref(), layout.start_offset() as NSUInteger, @@ -133,7 +134,7 @@ impl BackendStorage for MetalStorage { ); blit_encoder.end_encoding(); - let encoder = device.command_buffer.new_compute_command_encoder(); + let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); encoder.set_threadgroup_memory_length(0, output_size as NSUInteger); @@ -164,6 +165,10 @@ impl BackendStorage for MetalStorage { encoder.dispatch_threads(grid_size, thread_group_size); encoder.end_encoding(); + command_buffer.commit(); + // command_buffer.wait_until_completed(); + println!("Affine"); + Ok(self.clone()) } @@ -190,17 +195,45 @@ impl BackendStorage for MetalStorage { } fn unary_impl(&self, layout: &Layout) -> Result { - let device = self.device().clone(); + let device = self.device(); let dtype = self.dtype; let shape = layout.shape(); let dims = shape.dims(); let el_count = shape.elem_count(); let mut buffer = device.new_buffer(el_count, dtype); - //todo!("Implement the kernel calling"); - // device.kernels.call_unary(U::KERNEL, &self.buffer, &mut buffer, el_count, dtype); + let command_buffer = device.command_queue.new_command_buffer(); + if layout.is_contiguous() { + use candle_metal_kernels::unary::contiguous; + + let kernel_name = match (B::KERNEL, dtype) { + ("ucos", DType::F32) => contiguous::cos::FLOAT, + ("usin", DType::F32) => contiguous::sin::FLOAT, + ("usqr", DType::F32) => contiguous::sqr::FLOAT, + ("usqrt", DType::F32) => contiguous::sqrt::FLOAT, + ("uneg", DType::F32) => contiguous::neg::FLOAT, + ("uexp", DType::F32) => contiguous::exp::FLOAT, + (name, dtype) => todo!("Match {name} - {dtype:?}"), + }; + candle_metal_kernels::call_unary_contiguous( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + el_count, + &self.buffer, + &mut buffer, + ) + .map_err(MetalError::from)?; + } else { + todo!("TODO Implement the kernel calling {}", B::KERNEL); + } + command_buffer.commit(); + // command_buffer.wait_until_completed(); + println!("Unary {:?}", B::KERNEL); + Ok(Self { - buffer: Arc::new(buffer), - device, + buffer, + device: device.clone(), dtype, }) } @@ -368,7 +401,7 @@ impl MetalStorage { println!("TODO implement batched matmul for B={b}"); // bail!("Didn't implemented strided matmul yet"); return Ok(Self { - buffer: Arc::new(out_buffer), + buffer: out_buffer, device: self.device.clone(), dtype: self.dtype(), }); @@ -380,15 +413,17 @@ impl MetalStorage { rhs_l.is_contiguous() ); return Ok(Self { - buffer: Arc::new(out_buffer), + buffer: out_buffer, device: self.device.clone(), dtype: self.dtype(), }); } + println!("GEMM"); + let command_buffer = self.device.command_queue.new_command_buffer(); encode_gemm::( &self.device, - &self.device.command_buffer, + &command_buffer, transpose_left, transpose_right, &self.buffer, @@ -402,13 +437,15 @@ impl MetalStorage { ) .map_err(MetalError::from)?; - println!("lhs {:?} {m} {k}", self.buffer.length()); - println!("rhs {:?} {k} {n}", rhs.buffer.length()); - println!("out {:?} {m} {n}", out_buffer.length()); - println!("lhs {:?}", lhs_l.shape()); + command_buffer.commit(); + + // println!("lhs {:?} {m} {k}", self.buffer.length()); + // println!("rhs {:?} {k} {n}", rhs.buffer.length()); + // println!("out {:?} {m} {n}", out_buffer.length()); + // println!("lhs {:?}", lhs_l.shape()); Ok(Self { - buffer: Arc::new(out_buffer), + buffer: out_buffer, device: self.device.clone(), dtype: self.dtype(), }) @@ -423,13 +460,13 @@ impl BackendDevice for MetalDevice { fn new(ordinal: usize) -> Result { let device = metal::Device::all().swap_remove(ordinal); - let _command_queue = device.new_command_queue(); - let command_buffer = _command_queue.new_owned_command_buffer(); - let kernels = Arc::new(Kernels::init(&device).map_err(MetalError::from)?); + let command_queue = device.new_command_queue(); + // let command_buffer = _command_queue.new_owned_command_buffer(); + let kernels = Arc::new(Kernels::new()); Ok(Self { device, - _command_queue, - command_buffer, + command_queue, + // command_buffer, kernels, }) } @@ -498,7 +535,7 @@ impl BackendDevice for MetalDevice { ), }; Ok(Self::Storage { - buffer: Arc::new(buffer), + buffer, device: self.clone(), dtype: storage.dtype(), }) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 01dd309b..bf3a7bcd 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1,8 +1,7 @@ use metal::{ - Buffer, CommandBuffer, CompileOptions, ComputePipelineDescriptor, Device, Function, Library, + Buffer, CommandBufferRef, CompileOptions, ComputePipelineDescriptor, Device, Function, Library, MTLSize, }; -use once_cell::sync::Lazy; use std::collections::HashMap; use std::ffi::c_void; use std::sync::RwLock; @@ -11,11 +10,18 @@ const AFFINE: &str = include_str!("affine.metal"); const INDEXING: &str = include_str!("indexing.metal"); const UNARY: &str = include_str!("unary.metal"); +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum Source { + Affine, + Indexing, + Unary, +} + macro_rules! unary{ ($($name:ident),+) => { pub mod contiguous { - pub struct Kernel(pub &'static str); + pub struct Kernel(pub(crate) &'static str); $( pub mod $name { use super::Kernel; @@ -27,7 +33,7 @@ macro_rules! unary{ } pub mod strided { - pub struct Kernel(pub &'static str); + pub struct Kernel(pub(crate) &'static str); $( pub mod $name { use super::Kernel; @@ -41,17 +47,17 @@ macro_rules! unary{ } pub mod unary { - unary!(cos, sin, exp); + unary!(cos, sin, exp, sqr, sqrt, neg); } -static LIBRARY_SOURCES: Lazy> = Lazy::new(|| { - let mut l = HashMap::new(); - l.insert("affine", AFFINE); - l.insert("indexing", INDEXING); - l.insert("unary", UNARY); - l -}); - +// static LIBRARY_SOURCES: Lazy> = Lazy::new(|| { +// let mut l = HashMap::new(); +// l.insert("affine", AFFINE); +// l.insert("indexing", INDEXING); +// l.insert("unary", UNARY); +// l +// }); +// #[derive(thiserror::Error, Debug)] pub enum MetalKernelError { #[error("Could not lock kernel map: {0}")] @@ -69,7 +75,7 @@ impl From> for MetalKernelError { } type KernelMap = HashMap<&'static str, T>; -type Libraries = KernelMap; +type Libraries = HashMap; type Functions = KernelMap; #[derive(Debug)] @@ -85,39 +91,42 @@ impl Kernels { Self { libraries, funcs } } - pub fn init(device: &Device) -> Result { - let kernels = Self::new(); - kernels.load_libraries(device)?; - Ok(kernels) - } + // pub fn init(device: &Device) -> Result { + // let kernels = Self::new(); + // kernels.load_libraries(device)?; + // Ok(kernels) + // } - fn load_libraries(&self, device: &Device) -> Result<(), MetalKernelError> { - for name in LIBRARY_SOURCES.keys() { - self.load_library(device, name)?; + // fn load_libraries(&self, device: &Device) -> Result<(), MetalKernelError> { + // for name in LIBRARY_SOURCES.keys() { + // self.load_library(device, name)?; + // } + // Ok(()) + // } + + fn get_library_source(&self, source: Source) -> &'static str { + // LIBRARY_SOURCES.get(name).cloned() + match source { + Source::Affine => AFFINE, + Source::Unary => UNARY, + Source::Indexing => INDEXING, } - Ok(()) - } - - fn get_library_source(&self, name: &'static str) -> Option<&'static str> { - LIBRARY_SOURCES.get(name).cloned() } pub fn load_library( &self, device: &Device, - name: &'static str, + source: Source, ) -> Result { let mut libraries = self.libraries.write()?; - if let Some(lib) = libraries.get(name) { + if let Some(lib) = libraries.get(&source) { Ok(lib.clone()) } else { - let source = self.get_library_source(name).ok_or_else(|| { - MetalKernelError::LoadLibraryError(format!("No source found for {}", name)) - })?; + let source_content = self.get_library_source(source); let lib = device - .new_library_with_source(source, &CompileOptions::new()) + .new_library_with_source(source_content, &CompileOptions::new()) .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?; - libraries.insert(name, lib.clone()); + libraries.insert(source, lib.clone()); Ok(lib) } } @@ -125,7 +134,7 @@ impl Kernels { pub fn load_function( &self, device: &Device, - library_name: &'static str, + source: Source, name: &'static str, ) -> Result { let mut funcs = self.funcs.write()?; @@ -133,7 +142,7 @@ impl Kernels { Ok(func.clone()) } else { let func = self - .load_library(device, library_name)? + .load_library(device, source)? .get_function(name, None) .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?; funcs.insert(name, func.clone()); @@ -144,15 +153,16 @@ impl Kernels { pub fn call_unary_contiguous( device: &Device, - command_buffer: &CommandBuffer, + command_buffer: &CommandBufferRef, kernels: &Kernels, kernel_name: unary::contiguous::Kernel, length: usize, input: &Buffer, output: &mut Buffer, ) -> Result<(), MetalKernelError> { - assert_eq!(input.length(), output.length()); - let func = kernels.load_function(device, "unary", kernel_name.0)?; + // println!("Kernel {:?}", kernel_name.0); + // assert_eq!(input.length(), output.length()); + let func = kernels.load_function(device, Source::Unary, kernel_name.0)?; let pipeline_state_descriptor = ComputePipelineDescriptor::new(); pipeline_state_descriptor.set_compute_function(Some(&func)); @@ -188,7 +198,7 @@ pub fn call_unary_contiguous( } pub fn call_unary_strided( device: &Device, - command_buffer: &CommandBuffer, + command_buffer: &CommandBufferRef, kernels: &Kernels, name: unary::strided::Kernel, input: &Buffer, @@ -197,7 +207,7 @@ pub fn call_unary_strided( offset: usize, output: &mut Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, "unary", name.0)?; + let func = kernels.load_function(device, Source::Unary, name.0)?; let pipeline_state_descriptor = ComputePipelineDescriptor::new(); pipeline_state_descriptor.set_compute_function(Some(&func)); @@ -277,7 +287,7 @@ mod tests { let device = device(); let kernels = Kernels::new(); let command_queue = device.new_command_queue(); - let command_buffer = command_queue.new_owned_command_buffer(); + let command_buffer = command_queue.new_command_buffer(); let options = MTLResourceOptions::StorageModeManaged; let input = device.new_buffer_with_data( v.as_ptr() as *const core::ffi::c_void, @@ -310,7 +320,7 @@ mod tests { let device = device(); let options = MTLResourceOptions::StorageModeManaged; let command_queue = device.new_command_queue(); - let command_buffer = command_queue.new_owned_command_buffer(); + let command_buffer = command_queue.new_command_buffer(); let input = device.new_buffer_with_data( v.as_ptr() as *const core::ffi::c_void, (v.len() * core::mem::size_of::()) as u64, diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 81171fb2..f30fb929 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -22,6 +22,9 @@ METAL_FUNC uint get_strided_index( return strided_i; } +template METAL_FUNC T sqr(T in){ return in * in; } +template METAL_FUNC T neg(T in){ return -in; } + using namespace metal; @@ -59,12 +62,26 @@ kernel void FN_NAME_STRIDED( \ } \ } -UNARY(cos, float, cos_float, cos_float_strided); -UNARY(cos, half, cos_half, cos_half_strided); -UNARY(sin, float, sin_float, sin_float_strided); -UNARY(sin, half, sin_half, sin_half_strided); +#define UNARY_OP(NAME) \ +UNARY(NAME, float, NAME##_float, NAME##_float_strided); \ +UNARY(NAME, half, NAME##_half, NAME##_half_strided); + +#define BFLOAT_UNARY_OP(NAME) \ +UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided); + + +UNARY_OP(cos) +UNARY_OP(sin) +UNARY_OP(sqr) +UNARY_OP(sqrt) +UNARY_OP(neg) +UNARY_OP(exp) #if __METAL_VERSION__ >= 310 -UNARY(cos, bfloat, cos_bfloat, cos_bfloat_strided); -UNARY(sin, bfloat, sin_bfloat, sin_bfloat_strided); +BFLOAT_UNARY_OP(cos) +BFLOAT_UNARY_OP(sin) +BFLOAT_UNARY_OP(sqr) +BFLOAT_UNARY_OP(sqrt) +BFLOAT_UNARY_OP(neg) +BFLOAT_UNARY_OP(exp) #endif