diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 377e1406..d911fe32 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -4,12 +4,13 @@ use crate::error::Error; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape}; use candle_metal_kernels; +use candle_metal_kernels::{void_ptr, AFFINE}; use core::mem; use half::{bf16, f16}; use metal; use metal::mps::matrix::encode_gemm; use metal::mps::Float32; -use metal::{Buffer, MTLResourceOptions, NSUInteger}; +use metal::{Buffer, CompileOptions, MTLResourceOptions, MTLSize, NSUInteger}; use std::sync::Arc; /// Metal related errors @@ -86,10 +87,58 @@ impl BackendStorage for MetalStorage { } } - fn affine(&self, _: &Layout, _: f64, _: f64) -> Result { - println!("TODO Affine"); + fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result { + let device = self.device().clone(); + + /* + let shape = layout.shape(); + let dims = shape.dims(); + let el = shape.elem_count(); + + // TODO: Don't load library every time + let library = device.new_library_with_source(AFFINE, &CompileOptions::new()).unwrap(); + let function = library.get_function("affine", None).unwrap(); + let pipeline = device + .new_compute_pipeline_state_with_function(&function) + .unwrap(); + + let encoder = device.command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + let output_size = el * self.dtype.size_in_bytes(); + encoder.set_threadgroup_memory_length(0, output_size as NSUInteger); + + let output_buffer = device.new_buffer(output_size, self.dtype); + + encoder.set_bytes(0, 4, void_ptr(&el)); + encoder.set_bytes(1, 4, void_ptr(&dims)); + let info = [dims, layout.stride()].concat(); + let info_len = (info.len() * mem::size_of::()) as NSUInteger; + encoder.set_bytes(2, info_len, info.as_slice().as_ptr().cast()); + + encoder.set_buffer(3, Some(&self.buffer), 0); + encoder.set_buffer(4, Some(&output_buffer), 0); + + encoder.set_bytes(5, 4, void_ptr(&(mul as f32))); + encoder.set_bytes(6, 4, void_ptr(&(add as f32))); + + let grid_size = MTLSize { + width: output_size as NSUInteger, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + encoder.dispatch_threads(grid_size, thread_group_size); + encoder.end_encoding(); + */ + Ok(self.clone()) - // todo!() } fn powf(&self, _: &Layout, _: f64) -> Result { diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal new file mode 100644 index 00000000..4111e799 --- /dev/null +++ b/candle-metal-kernels/src/affine.metal @@ -0,0 +1,62 @@ +#include +using namespace metal; + +METAL_FUNC bool is_contiguous( + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + size_t acc = 1; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + if (acc != strides[dim_idx]) { + return false; + } + acc *= dims[dim_idx]; + } + return true; +} + +METAL_FUNC uint get_strided_index( + uint idx, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + +kernel void affine( + constant size_t &dim, + constant size_t &num_dims, + constant size_t *info, + + device float *inp [[buffer(3)]], + device float *out [[buffer(4)]], + + constant float &mul, + constant float &add +) { + + constant size_t *dims = info; + constant size_t *strides = info + num_dims; + + if (is_contiguous(num_dims, dims, strides)) { + for (size_t i = 0; i < dim; i++) { + float x = inp ? inp[i] : out[i]; + out[i] = x * mul + add; + } + } else { + for (size_t i = 0; i < dim; i++) { + uint strided_i = get_strided_index(i, num_dims, dims, strides); + float x = inp ? inp[strided_i] : out[strided_i]; + out[strided_i] = x * mul + add; + } + } +} \ No newline at end of file diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 8625de3b..766bf261 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1,7 +1,9 @@ use metal::{Buffer, CompileOptions, Device, Function, Library}; use std::collections::HashMap; +use std::ffi::c_void; use std::sync::RwLock; +pub const AFFINE: &str = include_str!("affine.metal"); pub const INDEXING: &str = include_str!("indexing.metal"); pub const UNARY: &str = include_str!("unary.metal"); @@ -60,6 +62,10 @@ fn call_unary(_func: &Function, _input: &Buffer, _output: &Buffer, _length: usiz todo!("Call unary"); } +pub fn void_ptr(v: &T) -> *const c_void { + (v as *const T).cast() +} + #[cfg(test)] mod tests { use super::*; @@ -70,9 +76,6 @@ mod tests { use std::ffi::c_void; use std::mem; - pub fn void_ptr(v: &T) -> *const c_void { - (v as *const T).cast() - } fn approx(v: Vec, digits: i32) -> Vec { let b = 10f32.powi(digits); v.iter().map(|t| f32::round(t * b) / b).collect() @@ -144,6 +147,72 @@ mod tests { assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]); } + #[test] + fn affine() { + let device = Device::system_default().expect("no device found"); + + let options = CompileOptions::new(); + let library = device.new_library_with_source(AFFINE, &options).unwrap(); + + let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; + let output = [2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; + let dim: u32 = 8; + let num_dims: u32 = 4; + let info = [1u32, 2, 3]; + let mul: f32 = 1.5; + let add: f32 = 1.1; + + let function = library.get_function("affine", None).unwrap(); + let pipeline = device + .new_compute_pipeline_state_with_function(&function) + .unwrap(); + let options = MTLResourceOptions::StorageModeShared; + + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let encoder = command_buffer.new_compute_command_encoder(); + + let input_size = (input.len() * mem::size_of::()) as NSUInteger; + let output_size = (output.len() * mem::size_of::()) as NSUInteger; + + encoder.set_compute_pipeline_state(&pipeline); + encoder.set_threadgroup_memory_length(0, output_size as NSUInteger); + + let inputs_buffer = device.new_buffer_with_data(void_ptr(&input), input_size, options); + let outputs_buffer = device.new_buffer_with_data(void_ptr(&output), output_size, options); + + encoder.set_bytes(0, 4, void_ptr(&dim)); + encoder.set_bytes(1, 4, void_ptr(&num_dims)); + encoder.set_bytes(2, 4, void_ptr(&info)); + + encoder.set_buffer(3, Some(&inputs_buffer), 0); + encoder.set_buffer(4, Some(&outputs_buffer), 0); + + encoder.set_bytes(5, 4, void_ptr(&mul)); + encoder.set_bytes(6, 4, void_ptr(&add)); + + let grid_size = MTLSize { + width: output.len() as NSUInteger, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width: pipeline.max_total_threads_per_threadgroup(), + height: 1, + depth: 1, + }; + + encoder.dispatch_threads(grid_size, thread_group_size); + encoder.end_encoding(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + let expected = vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]; + let result = outputs_buffer.read_to_vec::(output.len()); + assert_eq!(result, expected); + } + #[test] fn index_add() { let device = Device::system_default().expect("no device found"); diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 3a6fdd39..fb8bc21f 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -190,7 +190,7 @@ impl candle::CustomOp1 for SoftmaxLastDim { device: dev.clone(), }; Ok((dst, layout.shape().clone())) - } + } #[cfg(feature = "metal")] fn metal_fwd(