diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 349edc49..00301352 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -58,6 +58,10 @@ impl MetalDevice {
         self.registry_id()
     }
 
+    pub fn metal_device(&self) -> &metal::Device {
+        &self.device
+    }
+
     pub fn command_queue(&self) -> &CommandQueue {
         &self.command_queue
     }
@@ -215,8 +219,9 @@ impl BackendStorage for MetalStorage {
     fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
         assert!(sum_dims.len() == 1);
         assert!(sum_dims[0] == layout.shape().rank() - 1);
-        assert!(layout.is_contiguous());
-        assert!(layout.start_offset() == 0);
+        assert!(layout.stride()[sum_dims[0]] == 1);
+        // assert!(layout.is_contiguous());
+        // assert!(layout.start_offset() == 0);
         let device = self.device.clone();
         let src_stride = layout.stride();
         let src_dims = layout.shape().dims();
@@ -251,6 +256,9 @@ impl BackendStorage for MetalStorage {
             Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
         }
         let dtype = if return_index { DType::U32 } else { self.dtype };
+        if dtype == DType::U32 {
+            todo!("Implement this");
+        }
         let mut buffer = device.new_buffer(dst_el, dtype);
         let command_buffer = self.device.command_buffer();
         candle_metal_kernels::call_reduce_contiguous(
@@ -261,6 +269,7 @@ impl BackendStorage for MetalStorage {
             src_el,
             dst_el,
             &self.buffer,
+            layout.start_offset() * self.dtype.size_in_bytes(),
             &mut buffer,
         )
         .map_err(MetalError::from)?;
@@ -727,26 +736,26 @@ impl BackendStorage for MetalStorage {
                 mnk: (m, n, k),
             })?
         };
-        let stride_left: u64 = match lhs_stride[..lhs_stride.len() - 2] {
-            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
-            [stride] => stride,
-            [] => m * k,
-            _ => Err(MetalError::MatMulNonContiguous {
-                lhs_stride: lhs_stride.to_vec(),
-                rhs_stride: rhs_stride.to_vec(),
-                mnk: (m, n, k),
-            })?,
-        } as u64;
-        let stride_right: u64 = match rhs_stride[..rhs_stride.len() - 2] {
-            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
-            [stride] => stride,
-            [] => n * k,
-            _ => Err(MetalError::MatMulNonContiguous {
-                lhs_stride: lhs_stride.to_vec(),
-                rhs_stride: rhs_stride.to_vec(),
-                mnk: (m, n, k),
-            })?,
-        } as u64;
+        // let stride_left: u64 = match lhs_stride[..lhs_stride.len() - 2] {
+        //     [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
+        //     [stride] => stride,
+        //     [] => m * k,
+        //     _ => Err(MetalError::MatMulNonContiguous {
+        //         lhs_stride: lhs_stride.to_vec(),
+        //         rhs_stride: rhs_stride.to_vec(),
+        //         mnk: (m, n, k),
+        //     })?,
+        // } as u64;
+        // let stride_right: u64 = match rhs_stride[..rhs_stride.len() - 2] {
+        //     [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
+        //     [stride] => stride,
+        //     [] => n * k,
+        //     _ => Err(MetalError::MatMulNonContiguous {
+        //         lhs_stride: lhs_stride.to_vec(),
+        //         rhs_stride: rhs_stride.to_vec(),
+        //         mnk: (m, n, k),
+        //     })?,
+        // } as u64;
 
         let b = b as NSUInteger;
         let m = m as NSUInteger;
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 9c839b95..88e86eac 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -57,6 +57,7 @@ flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
 mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
 nccl = ["cuda", "cudarc/nccl", "dep:half"]
 onnx = ["candle-onnx"]
+metal = ["candle/metal", "candle-nn/metal"]
 
 [[example]]
 name = "llama_multiprocess"
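Note on the `reduce_op` change above: the new argument to `call_reduce_contiguous` is a byte offset, obtained by scaling the layout's element offset by the dtype width (`layout.start_offset() * self.dtype.size_in_bytes()`), so reductions over views that don't start at the beginning of their buffer read from the right place. A minimal standalone sketch of that conversion, with a hypothetical helper name and values:

```rust
/// Hypothetical helper mirroring the call site above: turn a layout's
/// element offset into the byte offset handed to the Metal kernel.
fn buffer_byte_offset(start_offset_elems: usize, dtype_size_in_bytes: usize) -> usize {
    start_offset_elems * dtype_size_in_bytes
}

fn main() {
    // e.g. a tensor view starting 128 f32 elements into its storage;
    // f32 is 4 bytes wide, so the kernel reads from byte 512.
    assert_eq!(buffer_byte_offset(128, 4), 512);
}
```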
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index a9d108f4..e4220286 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -300,9 +300,6 @@ pub fn call_unary_contiguous(
     input: &Buffer,
     output: &mut Buffer,
 ) -> Result<(), MetalKernelError> {
-    // println!("Kernel {:?}", kernel_name.0);
-    // assert_eq!(input.length(), output.length());
-
     let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
     let encoder = command_buffer.new_compute_command_encoder();
     encoder.set_compute_pipeline_state(&pipeline);
@@ -484,6 +481,7 @@ pub fn call_reduce_contiguous(
     length: usize,
     out_length: usize,
     input: &Buffer,
+    input_offset: usize,
     output: &mut Buffer,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
@@ -492,7 +490,7 @@ pub fn call_reduce_contiguous(
     let encoder = command_buffer.new_compute_command_encoder();
     encoder.set_compute_pipeline_state(&pipeline);
 
-    set_params!(encoder, (length, elements_to_sum, input, output));
+    set_params!(encoder, (length, elements_to_sum, (input, input_offset), output));
 
     let thread_group_count = MTLSize {
         width: out_length as u64,
@@ -1228,6 +1226,7 @@ mod tests {
             v.len(),
             out_length,
             &input,
+            0,
             &mut output,
         )
         .unwrap();
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index c6984474..96c28687 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -16,7 +16,7 @@ METAL_FUNC uint get_strided_index(
     return strided_i;
 }
 
-constant int THREADGROUP_SIZE = 256;
+constant int THREADGROUP_SIZE = 1024;
 
 # define REDUCE(FN, NAME, TYPENAME) \
 kernel void NAME( \
diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml
index 4b1f7917..546f392b 100644
--- a/candle-nn/Cargo.toml
+++ b/candle-nn/Cargo.toml
@@ -19,6 +19,7 @@ num-traits = { workspace = true }
 rayon = { workspace = true }
 safetensors = { workspace = true }
 serde = { workspace = true }
+candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
 
 [dev-dependencies]
 anyhow = { workspace = true }
@@ -29,3 +30,4 @@ default = []
 accelerate = ["dep:accelerate-src", "candle/accelerate"]
 cuda = ["candle/cuda"]
 mkl = ["dep:intel-mkl-src", "candle/mkl"]
+metal = ["candle/metal", "dep:candle-metal-kernels"]
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index a0269e59..1c8251b1 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -201,6 +201,37 @@ impl candle::CustomOp1 for SoftmaxLastDim {
         };
         Ok((dst, layout.shape().clone()))
     }
+
+    #[cfg(feature = "metal")]
+    fn metal_fwd(
+        &self,
+        storage: &candle::MetalStorage,
+        layout: &Layout,
+    ) -> Result<(candle::MetalStorage, Shape)> {
+        use candle::backend::BackendStorage;
+        let device = storage.device();
+        let command_buffer = device.command_buffer();
+        let kernels = device.kernels();
+        let name = "softmax_float";
+        assert!(layout.is_contiguous());
+        assert!(layout.start_offset() == 0);
+        let last_dim = layout.dims()[layout.shape().rank() - 1];
+        let elem_count = layout.shape().elem_count();
+        let mut output = device.new_buffer(elem_count, storage.dtype());
+        candle_metal_kernels::call_last_softmax(
+            device.metal_device(),
+            &command_buffer,
+            &kernels,
+            name,
+            elem_count,
+            last_dim,
+            storage.buffer(),
+            &mut output,
+        )
+        .unwrap();
+        let new_storage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
+        Ok((new_storage, layout.shape().clone()))
+    }
 }
 
 pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
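To exercise the new Metal softmax path end to end, a minimal sketch (assuming a Metal-capable macOS host and a build with the `metal` features added above; `Device::new_metal`, `Tensor::rand`, and `softmax_last_dim` are existing candle APIs):

```rust
use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    // Requires the `metal` feature and an Apple GPU.
    let device = Device::new_metal(0)?;
    let xs = Tensor::rand(0f32, 1f32, (4, 32), &device)?;
    // Dispatches the `softmax_float` kernel via the `metal_fwd` impl above.
    let ys = candle_nn::ops::softmax_last_dim(&xs)?;
    assert_eq!(ys.dims(), &[4, 32]);
    Ok(())
}
```

Each row of `ys` should sum to 1 up to floating-point error.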