diff --git a/Cargo.toml b/Cargo.toml index c827507d..d3130105 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,8 +55,7 @@ tracing-subscriber = "0.3.7" wav = "1.0.0" yoke = { version = "0.7.2", features = ["derive"] } zip = { version = "0.6.6", default-features = false } -# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } -metal = { path = "../metal-rs", features = ["mps"] } +metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } [profile.release-with-debug] inherits = "release" diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 7056e500..ecfce77b 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -111,89 +111,28 @@ impl BackendStorage for MetalStorage { let el = shape.elem_count(); let dtype = self.dtype; - debug!("{shape:?} {el:?} {:?}", layout.stride()); - let output_buffer = device.new_buffer(el, self.dtype); + assert!(layout.is_contiguous()); + assert_eq!(dtype, DType::F32); + + let mut buffer = device.new_buffer(el, self.dtype); + let command_buffer = self.device.command_queue.new_command_buffer(); + candle_metal_kernels::call_affine( + &device.device, + &command_buffer, + &device.kernels, + el, + &self.buffer, + &mut buffer, + mul as f32, + add as f32, + ) + .unwrap(); + command_buffer.commit(); return Ok(Self { - buffer: output_buffer, + buffer, device: device.clone(), dtype, }); - let function = self - .device - .kernels - .load_function(&device.device, Source::Affine, "affine") - .map_err(MetalError::from)?; - - let pipeline = device - .new_compute_pipeline_state_with_function(&function) - .map_err(MetalError::msg)?; - let command_buffer = self.device.command_queue.new_command_buffer(); - - assert_eq!(output_buffer.length(), self.buffer.length()); - - let length = el; - let encoder = command_buffer.new_compute_command_encoder(); - encoder.set_compute_pipeline_state(&pipeline); - // encoder.set_threadgroup_memory_length(0, output_size as NSUInteger); - - encoder.set_bytes(0, 4, void_ptr(&el)); - encoder.set_bytes(1, 4, void_ptr(&dims)); - encoder.set_bytes( - 2, - (mem::size_of::() * dims.len()) as u64, - dims.as_ptr() as *const core::ffi::c_void, - ); - encoder.set_bytes( - 3, - (mem::size_of::() * layout.stride().len()) as u64, - layout.stride().as_ptr() as *const core::ffi::c_void, - ); - encoder.set_buffer(4, Some(&self.buffer), 0); - encoder.set_buffer(5, Some(&output_buffer), 0); - - encoder.set_bytes(6, mem::size_of::() as u64, void_ptr(&(mul as f32))); - encoder.set_bytes(7, mem::size_of::() as u64, void_ptr(&(add as f32))); - - let grid_size = MTLSize { - width: 1, - height: 1, - depth: 1, - }; - - let thread_group_size = MTLSize { - width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), el as u64), - height: 1, - depth: 1, - }; - - encoder.dispatch_thread_groups(grid_size, thread_group_size); - encoder.end_encoding(); - - let start = std::time::Instant::now(); - command_buffer.commit(); - // debug!( - // "Affine {:?}({:?}, {:?}) - {:?}", - // command_buffer.status(), - // self.buffer.length(), - // output_buffer.length(), - // start.elapsed() - // ); - // command_buffer.wait_until_completed(); - debug!( - "Affine {:?} - {:?}", - command_buffer.status(), - start.elapsed() - ); - - // let capture = metal::CaptureManager::shared(); - // capture.stop_capture(); - // panic!("Done"); - - Ok(Self { - buffer: output_buffer, - device: device.clone(), - dtype, - }) } fn powf(&self, _: &Layout, _: f64) -> Result { @@ -288,12 +227,6 @@ impl BackendStorage for MetalStorage { let dims = shape.dims(); let el_count = shape.elem_count(); let mut buffer = device.new_buffer(el_count, dtype); - // TODO remove - // return Ok(Self { - // buffer, - // device: device.clone(), - // dtype, - // }); let command_buffer = device.command_queue.new_command_buffer(); if layout.is_contiguous() { use candle_metal_kernels::unary::contiguous; @@ -547,7 +480,11 @@ impl BackendStorage for MetalStorage { } fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result { - // todo!("TODO Index select {:?} {ids:?} {l:?} {ids_l:?} {dim:?}", self.buffer.length()); + debug!( + "TODO Index select {:?} {:?} {src_l:?} {ids_l:?} {dim:?}", + self.buffer.length(), + ids.buffer.length(), + ); let src = self; let ids_shape = ids_l.shape(); let ids_dims = ids_shape.dims(); @@ -607,8 +544,46 @@ impl BackendStorage for MetalStorage { ) } - fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> { - debug!("TODO Copy strided"); + fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { + let src_shape = src_l.shape(); + let dims = src_shape.dims(); + let el_count = src_shape.elem_count(); + if el_count == 0 { + return Ok(()); + } + if src_l.is_contiguous() { + let command_buffer = self.device.command_queue.new_command_buffer(); + let blip = command_buffer.new_blit_command_encoder(); + blip.copy_from_buffer( + &self.buffer, + src_l.start_offset() as u64, + &dst.buffer, + dst_offset as u64, + self.buffer.length(), + ); + } else { + let command_buffer = self.device.command_queue.new_command_buffer(); + let kernel_name = match self.dtype { + DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT, + DType::F16 => candle_metal_kernels::unary::strided::copy::HALF, + DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT, + dtype => todo!("copy_strided not implemented for {dtype:?}"), + }; + candle_metal_kernels::call_unary_strided( + &self.device.device, + &command_buffer, + &self.device.kernels, + kernel_name, + src_l.dims(), + &self.buffer, + &src_l.stride(), + src_l.start_offset(), + &mut dst.buffer, + dst_offset, + ) + .map_err(MetalError::from)?; + command_buffer.commit(); + } Ok(()) } } @@ -662,7 +637,7 @@ impl MetalStorage { } if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() { debug!( - "Didn't implemented non contiguous matmul yet {:?} {:?}", + "TODO non contiguous matmul yet {:?} {:?}", lhs_l.is_contiguous(), rhs_l.is_contiguous() ); @@ -674,31 +649,27 @@ impl MetalStorage { } debug!("GEMM"); - // let command_buffer = self.device.command_queue.new_command_buffer(); - // encode_gemm::( - // &self.device, - // &command_buffer, - // transpose_left, - // transpose_right, - // &self.buffer, - // &rhs.buffer, - // &mut out_buffer, - // m as NSUInteger, - // n as NSUInteger, - // k as NSUInteger, - // alpha, - // beta, - // ) - // .map_err(MetalError::from)?; + let command_buffer = self.device.command_queue.new_command_buffer(); + encode_gemm::( + &self.device, + &command_buffer, + transpose_left, + transpose_right, + &self.buffer, + &rhs.buffer, + &mut out_buffer, + m as NSUInteger, + n as NSUInteger, + k as NSUInteger, + alpha as f32, + beta as f32, + Some(b as NSUInteger), + ) + .map_err(MetalError::from)?; - // command_buffer.commit(); + command_buffer.commit(); // command_buffer.wait_until_scheduled(); - // println!("lhs {:?} {m} {k}", self.buffer.length()); - // println!("rhs {:?} {k} {n}", rhs.buffer.length()); - // println!("out {:?} {m} {n}", out_buffer.length()); - // println!("lhs {:?}", lhs_l.shape()); - Ok(Self { buffer: out_buffer, device: self.device.clone(), @@ -719,7 +690,6 @@ impl BackendDevice for MetalDevice { // let capture = metal::CaptureManager::shared(); // let descriptor = metal::CaptureDescriptor::new(); // descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument); - // println!("{:?}", std::env::current_dir()?); // descriptor.set_capture_device(&device); // let mut dir = std::env::current_dir()?; // dir.push("out.gputrace"); diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal index b8bebcb9..f801a03f 100644 --- a/candle-metal-kernels/src/affine.metal +++ b/candle-metal-kernels/src/affine.metal @@ -1,21 +1,4 @@ #include -using namespace metal; - -METAL_FUNC bool is_contiguous( - constant size_t &num_dims, - constant size_t *dims, - constant size_t *strides -) { - size_t acc = 1; - for (uint d = 0; d < num_dims; d++) { - uint dim_idx = num_dims - 1 - d; - if (acc != strides[dim_idx]) { - return false; - } - acc *= dims[dim_idx]; - } - return true; -} METAL_FUNC uint get_strided_index( uint idx, @@ -32,33 +15,30 @@ METAL_FUNC uint get_strided_index( return strided_i; } -kernel void affine( - constant size_t &dim, - constant size_t &num_dims, - constant size_t *dims, - constant size_t *strides, +using namespace metal; - device float *inp [[buffer(4)]], - device float *out [[buffer(5)]], - - constant float &mul, - constant float &add, +#define AFFINE(FN_NAME, TYPENAME) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + constant float &mul, \ + constant float &add, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ uint threadgroup_size [[threads_per_threadgroup]], \ - uint thread_index [[thread_index_in_threadgroup]] -) { - const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; - const size_t start = thread_index * length; - const size_t stop = min(start + length, dim); - if (is_contiguous(num_dims, dims, strides)) { - for (size_t i = start; i < stop; i++) { - float x = inp ? inp[i] : out[i]; - out[i] = x * mul + add; - } - } else { - for (size_t i = start; i < stop; i++) { - uint strided_i = get_strided_index(i, num_dims, dims, strides); - float x = inp ? inp[strided_i] : out[strided_i]; - out[strided_i] = x * mul + add; - } - } -} + uint thread_index [[thread_index_in_threadgroup]] \ +) { \ + const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \ + const size_t start = thread_index * length; \ + const size_t stop = min(start + length, dim); \ + for (size_t i = start; i < stop; i++){ \ + output[i] = input[i] * mul + add; \ + } \ +} \ + +AFFINE(affine_float, float) +AFFINE(affine_half, half) + + +#if __METAL_VERSION__ >= 310 +AFFINE(affine_bfloat, bfloat); +#endif diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal index 7e0aa5d6..cfd34416 100644 --- a/candle-metal-kernels/src/binary.metal +++ b/candle-metal-kernels/src/binary.metal @@ -62,7 +62,7 @@ BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \ BINARY(FN, half, half, NAME##_half, NAME##_half_strided); #define BFLOAT_BINARY_OP(FN, NAME) \ -BINARY(NAME, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided); +BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided); BINARY_OP(x + y, add) @@ -71,8 +71,8 @@ BINARY_OP(x * y, mul) BINARY_OP(x / y, div) #if __METAL_VERSION__ >= 310 -BFLOAT_BINARY_OP(x + y, badd) -BFLOAT_BINARY_OP(x - y, bsub) -BFLOAT_BINARY_OP(x * y, bmul) -BFLOAT_BINARY_OP(x / y, bdiv) +BFLOAT_BINARY_OP(x + y, add) +BFLOAT_BINARY_OP(x - y, sub) +BFLOAT_BINARY_OP(x * y, mul) +BFLOAT_BINARY_OP(x / y, div) #endif diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index ba818819..189fd508 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -51,7 +51,7 @@ macro_rules! ops{ } pub mod unary { - ops!(cos, sin, exp, sqr, sqrt, neg); + ops!(cos, sin, exp, sqr, sqrt, neg, copy); } pub mod binary { ops!(add, sub, mul, div); @@ -210,11 +210,12 @@ pub fn call_unary_strided( command_buffer: &CommandBufferRef, kernels: &Kernels, name: unary::strided::Kernel, - input: &Buffer, shape: &[usize], + input: &Buffer, strides: &[usize], offset: usize, output: &mut Buffer, + output_offset: usize, ) -> Result<(), MetalKernelError> { let func = kernels.load_function(device, Source::Unary, name.0)?; let pipeline_state_descriptor = ComputePipelineDescriptor::new(); @@ -245,7 +246,7 @@ pub fn call_unary_strided( ); encoder.set_buffer(4, Some(&input), offset as u64); - encoder.set_buffer(5, Some(&output), 0); + encoder.set_buffer(5, Some(&output), output_offset as u64); let width = output.length(); @@ -434,6 +435,53 @@ pub fn void_ptr(v: &T) -> *const c_void { (v as *const T).cast() } +pub fn call_affine( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + size: usize, + input: &Buffer, + output: &mut Buffer, + mul: f32, + add: f32, +) -> Result<(), MetalKernelError> { + let func = kernels.load_function(device, Source::Affine, "affine_float")?; + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&func)); + + let pipeline = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + encoder.set_bytes(0, core::mem::size_of::() as u64, void_ptr(&size)); + encoder.set_bytes(1, core::mem::size_of::() as u64, void_ptr(&mul)); + encoder.set_bytes(2, core::mem::size_of::() as u64, void_ptr(&add)); + encoder.set_buffer(3, Some(&input), 0); + encoder.set_buffer(4, Some(&output), 0); + + let thread_group_count = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size as u64); + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + #[cfg(test)] mod tests { use super::*; @@ -538,11 +586,12 @@ mod tests { &command_buffer, &kernels, kernel, - &input, shape, + &input, strides, offset, &mut output, + 0, ) .unwrap(); command_buffer.commit(); @@ -682,82 +731,52 @@ mod tests { assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); } - #[test] - fn affine() { + fn run_affine(v: &[T], mul: f64, add: f64) -> Vec { let device = device(); - let options = CompileOptions::new(); - let library = device.new_library_with_source(AFFINE, &options).unwrap(); - - let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; - let output = [2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; - let shape = vec![4usize, 2]; - let strides = vec![2usize, 1]; - let mul: f32 = 1.5; - let add: f32 = 1.1; - - let function = library.get_function("affine", None).unwrap(); - let pipeline = device - .new_compute_pipeline_state_with_function(&function) - .unwrap(); - let options = MTLResourceOptions::StorageModeManaged; - + let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); - let encoder = command_buffer.new_compute_command_encoder(); + let options = MTLResourceOptions::StorageModeManaged; - let input_size = (input.len() * mem::size_of::()) as NSUInteger; - let output_size = (output.len() * mem::size_of::()) as NSUInteger; - - encoder.set_compute_pipeline_state(&pipeline); - encoder.set_threadgroup_memory_length(0, output_size as NSUInteger); - - let inputs_buffer = device.new_buffer_with_data(void_ptr(&input), input_size, options); - let outputs_buffer = device.new_buffer_with_data(void_ptr(&output), output_size, options); - - let dim: usize = shape.iter().product(); - let num_dims = shape.len(); - encoder.set_bytes(0, core::mem::size_of::() as u64, void_ptr(&dim)); - encoder.set_bytes(1, core::mem::size_of::() as u64, void_ptr(&num_dims)); - encoder.set_bytes( - 2, - (core::mem::size_of::() * shape.len()) as u64, - shape.as_ptr() as *const c_void, - ); - encoder.set_bytes( - 3, - (core::mem::size_of::() * strides.len()) as u64, - strides.as_ptr() as *const c_void, + let input = device.new_buffer_with_data( + v.as_ptr() as *const core::ffi::c_void, + (v.len() * core::mem::size_of::()) as u64, + options, ); + let mut output = device.new_buffer((v.len() * core::mem::size_of::()) as u64, options); - encoder.set_buffer(4, Some(&inputs_buffer), 0); - encoder.set_buffer(5, Some(&outputs_buffer), 0); + let size = v.len(); - encoder.set_bytes(6, core::mem::size_of::() as u64, void_ptr(&mul)); - encoder.set_bytes(7, core::mem::size_of::() as u64, void_ptr(&add)); - - let thread_group_count = MTLSize { - width: 1, - height: 1, - depth: 1, - }; - - let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), dim as u64); - println!("WIDTH {width}"); - let thread_group_size = MTLSize { - width, - height: 1, - depth: 1, - }; - - encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + call_affine( + &device, + &command_buffer, + &kernels, + size, + &input, + &mut output, + mul as f32, + add as f32, + ) + .unwrap(); command_buffer.commit(); command_buffer.wait_until_completed(); - let expected = vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]; - let result = outputs_buffer.read_to_vec::(output.len()); - println!("Result {:?}", result.as_ptr()); - assert_eq!(result, expected); + output.read_to_vec::(v.len()) + } + + #[test] + fn affine() { + let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; + let mul = 1.5; + let add = 1.1; + let result = run_affine(&input, mul, add); + assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]); + + let input = [1.0f32; 40_000]; + let mul = 1.5; + let add = 1.1; + let result = run_affine(&input, mul, add); + assert_eq!(result, vec![2.6; 40_000]); } #[test] @@ -826,7 +845,6 @@ mod tests { 2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0, ]; let result = outputs_buffer.read_to_vec::(right.len()); - println!("Result {:?}", result.as_ptr()); assert_eq!(result, expected); } diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 03f88779..77de214e 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -17,6 +17,7 @@ METAL_FUNC uint get_strided_index( template METAL_FUNC T sqr(T in){ return in * in; } template METAL_FUNC T neg(T in){ return -in; } +template METAL_FUNC T id(T in){ return in; } using namespace metal; @@ -68,6 +69,8 @@ UNARY_OP(sqr) UNARY_OP(sqrt) UNARY_OP(neg) UNARY_OP(exp) +UNARY(id, float, copy_float, copy_float_strided) +UNARY(id, half, copy_half, copy_half_strided) #if __METAL_VERSION__ >= 310 BFLOAT_UNARY_OP(cos)