diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index 69bf47cf..0afa210f 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -12,6 +12,7 @@ readme = "README.md"
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
 byteorder = { workspace = true }
+tracing = { workspace = true }
 candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
 candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
 metal = { workspace = true, optional = true}
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 75efb0cc..c400b59c 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -3,7 +3,7 @@ use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D};
 use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
 use crate::{CpuStorage, DType, Layout, Result, Shape};
 use candle_metal_kernels;
-use candle_metal_kernels::{void_ptr, Kernels};
+use candle_metal_kernels::{void_ptr, Kernels, Source};
 use core::mem;
 use half::{bf16, f16};
 use metal;
@@ -11,6 +11,7 @@ use metal::mps::matrix::encode_gemm;
 use metal::mps::Float32;
 use metal::{Buffer, CompileOptions, MTLResourceOptions, MTLSize, NSUInteger};
 use std::sync::Arc;
+use tracing::debug;
 
 /// Metal related errors
 #[derive(thiserror::Error, Debug)]
@@ -55,9 +56,9 @@ impl std::ops::Deref for MetalDevice {
 }
 
 impl MetalDevice {
-    pub fn metal_device(&self) -> &metal::DeviceRef {
-        self.device.as_ref()
-    }
+    // pub fn metal_device(&self) -> &metal::DeviceRef {
+    //     self.device.as_ref()
+    // }
 
     pub fn id(&self) -> u64 {
         self.registry_id()
@@ -65,6 +66,7 @@ impl MetalDevice {
 
     fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
         let size = (element_count * dtype.size_in_bytes()) as u64;
+        // debug!("Allocate 1 - buffer size {size}");
         self.device
             .new_buffer(size, MTLResourceOptions::StorageModeManaged)
     }
@@ -103,73 +105,95 @@ impl BackendStorage for MetalStorage {
 
     fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
         let device = self.device().clone();
-        let command_buffer = self.device.command_queue.new_owned_command_buffer();
         let shape = layout.shape();
         let dims = shape.dims();
         let el = shape.elem_count();
+        let dtype = self.dtype;
+        debug!("{shape:?} {el:?} {:?}", layout.stride());
+        let output_buffer = device.new_buffer(el, self.dtype);
+        // return Ok(Self {
+        //     buffer: output_buffer,
+        //     device: device.clone(),
+        //     dtype,
+        // });
         let function = self
             .device
             .kernels
-            .load_function(&device.device, "affine", "affine")
+            .load_function(&device.device, Source::Affine, "affine")
            .map_err(MetalError::from)?;
         let pipeline = device
             .new_compute_pipeline_state_with_function(&function)
             .map_err(MetalError::msg)?;
+        let command_buffer = self.device.command_queue.new_command_buffer();
 
-        let output_size = el * self.dtype.size_in_bytes();
-        let output_buffer = device.new_buffer(output_size, self.dtype);
-
-        let src_length = self.buffer.length() as usize - layout.start_offset();
-        let src = self.device.new_buffer(src_length, self.dtype);
-        let blit_encoder = command_buffer.new_blit_command_encoder();
-        blit_encoder.copy_from_buffer(
-            self.buffer.as_ref(),
-            layout.start_offset() as NSUInteger,
-            output_buffer.as_ref(),
-            0,
-            (src_length * self.dtype.size_in_bytes()) as NSUInteger,
-        );
-        blit_encoder.end_encoding();
+        assert_eq!(output_buffer.length(), self.buffer.length());
+        let length = el;
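+        // The encoder calls below mirror the argument table in affine.metal:
+        // scalars (`el`, the number of dims, `mul`, `add`) go through
+        // set_bytes, the dims and stride arrays are passed as raw pointers,
+        // and the input/output buffers sit at indices 4 and 5 to match the
+        // [[buffer(4)]]/[[buffer(5)]] attributes in the kernel.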
 
         let encoder = command_buffer.new_compute_command_encoder();
         encoder.set_compute_pipeline_state(&pipeline);
-        encoder.set_threadgroup_memory_length(0, output_size as NSUInteger);
+        // encoder.set_threadgroup_memory_length(0, output_size as NSUInteger);
-        encoder.set_bytes(0, 4, void_ptr(&el));
-        encoder.set_bytes(1, 4, void_ptr(&dims));
+        let num_dims = dims.len();
+        encoder.set_bytes(0, mem::size_of::<usize>() as u64, void_ptr(&el));
+        encoder.set_bytes(1, mem::size_of::<usize>() as u64, void_ptr(&num_dims));
-        let info = [dims, layout.stride()].concat();
-        let info_len = (info.len() * mem::size_of::<usize>()) as NSUInteger;
-        encoder.set_bytes(2, info_len, info.as_slice().as_ptr().cast());
+        encoder.set_bytes(
+            2,
+            (mem::size_of::<usize>() * dims.len()) as u64,
+            dims.as_ptr() as *const core::ffi::c_void,
+        );
+        encoder.set_bytes(
+            3,
+            (mem::size_of::<usize>() * layout.stride().len()) as u64,
+            layout.stride().as_ptr() as *const core::ffi::c_void,
+        );
+        encoder.set_buffer(4, Some(&self.buffer), 0);
+        encoder.set_buffer(5, Some(&output_buffer), 0);
 
-        encoder.set_buffer(3, Some(&src), 0);
-        encoder.set_buffer(4, Some(&output_buffer), 0);
-
-        encoder.set_bytes(5, 4, void_ptr(&(mul as f32)));
-        encoder.set_bytes(6, 4, void_ptr(&(add as f32)));
+        encoder.set_bytes(6, mem::size_of::<f32>() as u64, void_ptr(&(mul as f32)));
+        encoder.set_bytes(7, mem::size_of::<f32>() as u64, void_ptr(&(add as f32)));
 
         let grid_size = MTLSize {
-            width: output_size as NSUInteger,
-            height: 1,
-            depth: 1,
-        };
-
-        let thread_group_size = MTLSize {
             width: 1,
             height: 1,
             depth: 1,
         };
 
-        encoder.dispatch_threads(grid_size, thread_group_size);
+        let thread_group_size = MTLSize {
+            width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), el as u64),
+            height: 1,
+            depth: 1,
+        };
+
+        encoder.dispatch_thread_groups(grid_size, thread_group_size);
         encoder.end_encoding();
 
+        let start = std::time::Instant::now();
         command_buffer.commit();
+        // debug!(
+        //     "Affine {:?}({:?}, {:?}) - {:?}",
+        //     command_buffer.status(),
+        //     self.buffer.length(),
+        //     output_buffer.length(),
+        //     start.elapsed()
+        // );
         // command_buffer.wait_until_completed();
-        println!("Affine");
+        debug!(
+            "Affine {:?} - {:?}",
+            command_buffer.status(),
+            start.elapsed()
+        );
 
-        Ok(self.clone())
+        let capture = metal::CaptureManager::shared();
+        capture.stop_capture();
+        panic!("Done");
+
+        Ok(Self {
+            buffer: output_buffer,
+            device: device.clone(),
+            dtype,
+        })
     }
 
     fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
@@ -180,10 +204,78 @@ impl BackendStorage for MetalStorage {
         todo!()
     }
 
-    fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
-        println!("TODO reduce_op");
-        Ok(self.clone())
+    fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
+        debug!("TODO reduce_op");
+        let src_stride = layout.stride();
+        let src_dims = layout.shape().dims();
+        let src_el: usize = src_dims.iter().product();
+        // Source dims and strides with the sum dims at the end.
+        let mut dims = vec![];
+        let mut stride = vec![];
+        let mut dst_el: usize = 1;
+        for (dim_idx, &d) in src_dims.iter().enumerate() {
+            if !sum_dims.contains(&dim_idx) {
+                dst_el *= d;
+                dims.push(d);
+                stride.push(src_stride[dim_idx]);
+            }
+        }
+        for &dim_idx in sum_dims.iter() {
+            dims.push(src_dims[dim_idx]);
+            stride.push(src_stride[dim_idx]);
+        }
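+        // Example: a contiguous (2, 3, 4) tensor summed over dim 1 yields
+        // dims = [2, 4, 3] and stride = [12, 1, 4], i.e. the reduced
+        // dimension is moved to the end, with dst_el = 8.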
+        // let el_to_sum_per_block = src_el / dst_el;
+        // // The reduction loop requires the shared array to be properly initialized and for
+        // // this we want the number of threads to be a power of two.
+        // let block_dim = usize::min(1024, el_to_sum_per_block).next_power_of_two();
+        // let cfg = LaunchConfig {
+        //     // TODO: Maybe use grid_y if the output is too large?
+        //     // TODO: Specialized implementation when reducing on no or all dimensions or when
+        //     //       reducing only aggregate a small number of elements together.
+        //     grid_dim: (dst_el as u32, 1, 1),
+        //     block_dim: (block_dim as u32, 1, 1),
+        //     shared_mem_bytes: 0,
+        // };
+        // let ds = dev
+        //     .htod_copy([dims.as_slice(), stride.as_slice()].concat())
+        //     .w()?;
+        // let src = &src.slice(layout.start_offset()..);
+        // let (name, check_empty, return_index) = match self.1 {
+        //     ReduceOp::Sum => ("fast_sum", false, false),
+        //     ReduceOp::Min => ("fast_min", true, false),
+        //     ReduceOp::Max => ("fast_max", true, false),
+        //     ReduceOp::ArgMin => ("fast_argmin", true, true),
+        //     ReduceOp::ArgMax => ("fast_argmax", true, true),
+        // };
+        // if check_empty && layout.shape().elem_count() == 0 {
+        //     Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
+        // }
+        // let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
+        // if return_index {
+        //     // SAFETY: filled in by the follow up kernel.
+        //     let out = unsafe { dev.alloc::<u32>(dst_el) }.w()?;
+        //     let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
+        //     // SAFETY: ffi.
+        //     unsafe { func.launch(cfg, params) }.w()?;
+        //     Ok(S::U32(out))
+        // } else {
+        //     // SAFETY: filled in by the follow up kernel.
+        //     let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+        //     let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
+        //     // SAFETY: ffi.
+        //     unsafe { func.launch(cfg, params) }.w()?;
+        //     Ok(wrap(out))
+        // }
+        // Ok(self.clone())
         // todo!()
+        let dtype = self.dtype;
+        let device = self.device();
+        let buffer = device.new_buffer(dst_el, dtype);
+        Ok(Self {
+            buffer,
+            device: device.clone(),
+            dtype,
+        })
     }
 
     fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
@@ -201,6 +293,12 @@ impl BackendStorage for MetalStorage {
         let dims = shape.dims();
         let el_count = shape.elem_count();
         let mut buffer = device.new_buffer(el_count, dtype);
+        // TODO remove
+        return Ok(Self {
+            buffer,
+            device: device.clone(),
+            dtype,
+        });
         let command_buffer = device.command_queue.new_command_buffer();
         if layout.is_contiguous() {
             use candle_metal_kernels::unary::contiguous;
@@ -227,9 +325,17 @@ impl BackendStorage for MetalStorage {
         } else {
             todo!("TODO Implement the kernel calling {}", B::KERNEL);
         }
+
+        let start = std::time::Instant::now();
         command_buffer.commit();
-        // command_buffer.wait_until_completed();
-        println!("Unary {:?}", B::KERNEL);
+        command_buffer.wait_until_completed();
+        debug!(
+            "Unary {:?} - {:?} - {:?} - {:?}",
+            B::KERNEL,
+            start.elapsed(),
+            self.buffer.length(),
+            buffer.length()
+        );
 
         Ok(Self {
             buffer,
@@ -239,13 +345,13 @@ impl BackendStorage for MetalStorage {
     }
 
     fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
-        println!("TODO Binary {:?}", B::NAME);
+        debug!("TODO Binary {:?}", B::NAME);
         Ok(self.clone())
         // todo!()
     }
 
     fn where_cond(&self, _: &Layout, rhs: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
-        println!("TODO where_cond");
+        debug!("TODO where_cond");
         Ok(rhs.clone())
         // todo!()
     }
@@ -312,9 +418,29 @@ impl BackendStorage for MetalStorage {
         todo!()
     }
 
-    fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
-        println!("TODO Index select");
-        Ok(self.clone())
+    fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
+        // todo!("TODO Index select {:?} {ids:?} {l:?} {ids_l:?} {dim:?}", self.buffer.length());
+        let src = self;
+        let ids_shape = ids_l.shape();
+        let ids_dims = ids_shape.dims();
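+        // Sizing follows the CUDA backend's layout: the output keeps
+        // left_size (product of dims before `dim`) and right_size (product of
+        // dims after `dim`) and replaces dim `dim` by the number of ids.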
+        // let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?;
+        // let src = match src_l.contiguous_offsets() {
+        //     Some((o1, o2)) => src.slice(o1..o2),
+        //     None => Err(crate::Error::RequiresContiguous { op: "index-select" }.bt())?,
+        // };
+        let left_size: usize = src_l.dims()[..dim].iter().product();
+        let right_size: usize = src_l.dims()[dim + 1..].iter().product();
+        let src_dim_size = src_l.dims()[dim];
+        let ids_dim_size = ids_shape.elem_count();
+        let dst_el = ids_shape.elem_count() * left_size * right_size;
+        let dtype = self.dtype;
+        let device = self.device();
+        let buffer = device.new_buffer(dst_el, dtype);
+        Ok(Self {
+            buffer,
+            device: device.clone(),
+            dtype,
+        })
         // todo!()
     }
 
@@ -354,7 +480,7 @@ impl BackendStorage for MetalStorage {
     }
 
     fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
-        println!("TODO Copy strided");
+        debug!("TODO Copy strided");
         Ok(())
     }
 }
@@ -398,7 +524,7 @@ impl MetalStorage {
             (DType::F32, DType::F32) => {
                 let mut out_buffer = self.device.new_buffer(elem_count, self.dtype);
                 if b != 1 {
-                    println!("TODO implement batched matmul for B={b}");
+                    debug!("TODO implement batched matmul for B={b}");
                     // bail!("Didn't implemented strided matmul yet");
                     return Ok(Self {
                         buffer: out_buffer,
@@ -407,7 +533,7 @@ impl MetalStorage {
                     });
                 }
                 if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
-                    println!(
+                    debug!(
                         "Didn't implemented non contiguous matmul yet {:?} {:?}",
                         lhs_l.is_contiguous(),
                         rhs_l.is_contiguous()
@@ -419,7 +545,7 @@ impl MetalStorage {
                    });
                }
 
-                println!("GEMM");
+                debug!("GEMM");
                 let command_buffer = self.device.command_queue.new_command_buffer();
                 encode_gemm::<Float32, Float32, Float32>(
                     &self.device,
@@ -438,6 +564,7 @@ impl MetalStorage {
                     .map_err(MetalError::from)?;
 
                 command_buffer.commit();
+                command_buffer.wait_until_scheduled();
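+                // `wait_until_scheduled` only blocks until the command buffer
+                // has been scheduled on the GPU, not until the GEMM has
+                // finished; readers of `out_buffer` still need completion.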
 
                 // println!("lhs {:?} {m} {k}", self.buffer.length());
                 // println!("rhs {:?} {k} {n}", rhs.buffer.length());
@@ -460,6 +587,19 @@ impl BackendDevice for MetalDevice {
     fn new(ordinal: usize) -> Result<Self> {
         let device = metal::Device::all().swap_remove(ordinal);
+
+        let capture = metal::CaptureManager::shared();
+        let descriptor = metal::CaptureDescriptor::new();
+        descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
+        println!("{:?}", std::env::current_dir()?);
+        descriptor.set_capture_device(&device);
+        let mut dir = std::env::current_dir()?;
+        dir.push("out.gputrace");
+        descriptor.set_output_url(dir);
+
+        capture
+            .start_capture(&descriptor)
+            .map_err(MetalError::from)?;
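+        // The capture runs from device creation until `stop_capture` in
+        // `affine` above; the resulting out.gputrace document can be opened
+        // in Xcode's Metal debugger.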
         let command_queue = device.new_command_queue();
         // let command_buffer = _command_queue.new_owned_command_buffer();
         let kernels = Arc::new(Kernels::new());
@@ -496,7 +636,7 @@ impl BackendDevice for MetalDevice {
     }
 
     fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
-        let option = metal::MTLResourceOptions::CPUCacheModeDefaultCache;
+        let option = metal::MTLResourceOptions::StorageModeManaged;
         let buffer = match storage {
             CpuStorage::U8(storage) => self.device.new_buffer_with_data(
                 storage.as_ptr() as *const core::ffi::c_void,
@@ -534,6 +674,7 @@ impl BackendDevice for MetalDevice {
                 option,
             ),
         };
+        // debug!("Allocate 2 - buffer size {}", buffer.length());
         Ok(Self::Storage {
             buffer,
             device: self.clone(),
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index cdd461c0..9e3306ed 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -1,4 +1,5 @@
 use crate::{Device, Result, Shape, Tensor};
+use tracing::debug;
 
 #[cfg(target_feature = "avx")]
 pub mod avx;
@@ -321,7 +322,7 @@ impl crate::CustomOp1 for QTensor {
         storage: &crate::MetalStorage,
         layout: &crate::Layout,
     ) -> Result<(crate::MetalStorage, Shape)> {
-        println!("TODO qmatmul");
+        debug!("TODO qmatmul");
         if !layout.is_contiguous() {
             crate::bail!("input tensor is not contiguous {layout:?}")
         }
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index 23f0053f..f5812536 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -238,13 +238,16 @@ fn main() -> anyhow::Result<()> {
     } else {
         Some(args.temperature)
     };
-    let _guard = if args.tracing {
-        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
-        tracing_subscriber::registry().with(chrome_layer).init();
-        Some(guard)
-    } else {
-        None
-    };
+    tracing_subscriber::fmt::init();
+    // let _guard = if args.tracing {
+    //     // let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+    //     // tracing_subscriber::registry().with(chrome_layer).init();
+    //     tracing_subscriber::fmt::init();
+    //     None
+    //     // Some(guard)
+    // } else {
+    //     None
+    // };
 
     println!(
         "avx: {}, neon: {}, simd128: {}, f16c: {}",
diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal
index 7bd98adc..b8bebcb9 100644
--- a/candle-metal-kernels/src/affine.metal
+++ b/candle-metal-kernels/src/affine.metal
@@ -35,25 +35,27 @@ METAL_FUNC uint get_strided_index(
 kernel void affine(
     constant size_t &dim,
     constant size_t &num_dims,
-    constant size_t *info,
+    constant size_t *dims,
+    constant size_t *strides,
-    device float *inp [[buffer(3)]],
-    device float *out [[buffer(4)]],
+    device float *inp [[buffer(4)]],
+    device float *out [[buffer(5)]],
     constant float &mul,
-    constant float &add
+    constant float &add,
+    uint threadgroup_size [[threads_per_threadgroup]],
+    uint thread_index [[thread_index_in_threadgroup]]
 ) {
-
-    constant size_t *dims = info;
-    constant size_t *strides = info + num_dims;
-
+    const size_t length = (dim + threadgroup_size - 1) / threadgroup_size;
+    const size_t start = thread_index * length;
+    const size_t stop = min(start + length, dim);
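+    // Each thread covers a contiguous chunk of ceil(dim / threadgroup_size)
+    // elements; e.g. with dim == 8 and 8 threads per group, each thread
+    // updates a single element.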
     if (is_contiguous(num_dims, dims, strides)) {
-        for (size_t i = 0; i < dim; i++) {
+        for (size_t i = start; i < stop; i++) {
             float x = inp ? inp[i] : out[i];
             out[i] = x * mul + add;
         }
     } else {
-        for (size_t i = 0; i < dim; i++) {
+        for (size_t i = start; i < stop; i++) {
             uint strided_i = get_strided_index(i, num_dims, dims, strides);
             float x = inp ? inp[strided_i] : out[strided_i];
             out[strided_i] = x * mul + add;
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index bf3a7bcd..ce98334a 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -424,15 +424,13 @@ mod tests {
     #[test]
     fn affine() {
         let device = device();
-
         let options = CompileOptions::new();
         let library = device.new_library_with_source(AFFINE, &options).unwrap();
 
         let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
         let output = [2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
-        let dim: u32 = 8;
-        let num_dims: u32 = 4;
-        let info = [1u32, 2, 3];
+        let shape = vec![4usize, 2];
+        let strides = vec![2usize, 1];
         let mul: f32 = 1.5;
         let add: f32 = 1.1;
@@ -455,29 +453,42 @@ mod tests {
         let inputs_buffer = device.new_buffer_with_data(void_ptr(&input), input_size, options);
         let outputs_buffer = device.new_buffer_with_data(void_ptr(&output), output_size, options);
 
-        encoder.set_bytes(0, 4, void_ptr(&dim));
-        encoder.set_bytes(1, 4, void_ptr(&num_dims));
-        encoder.set_bytes(2, 4, void_ptr(&info));
+        let dim: usize = shape.iter().product();
+        let num_dims = shape.len();
+        encoder.set_bytes(0, core::mem::size_of::<usize>() as u64, void_ptr(&dim));
+        encoder.set_bytes(1, core::mem::size_of::<usize>() as u64, void_ptr(&num_dims));
+        encoder.set_bytes(
+            2,
+            (core::mem::size_of::<usize>() * shape.len()) as u64,
+            shape.as_ptr() as *const c_void,
+        );
+        encoder.set_bytes(
+            3,
+            (core::mem::size_of::<usize>() * strides.len()) as u64,
+            strides.as_ptr() as *const c_void,
+        );
 
-        encoder.set_buffer(3, Some(&inputs_buffer), 0);
-        encoder.set_buffer(4, Some(&outputs_buffer), 0);
+        encoder.set_buffer(4, Some(&inputs_buffer), 0);
+        encoder.set_buffer(5, Some(&outputs_buffer), 0);
 
-        encoder.set_bytes(5, 4, void_ptr(&mul));
-        encoder.set_bytes(6, 4, void_ptr(&add));
+        encoder.set_bytes(6, core::mem::size_of::<f32>() as u64, void_ptr(&mul));
+        encoder.set_bytes(7, core::mem::size_of::<f32>() as u64, void_ptr(&add));
 
-        let grid_size = MTLSize {
-            width: output.len() as NSUInteger,
+        let thread_group_count = MTLSize {
+            width: 1,
             height: 1,
             depth: 1,
         };
 
+        let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), dim as u64);
+        println!("WIDTH {width}");
         let thread_group_size = MTLSize {
-            width: pipeline.max_total_threads_per_threadgroup(),
+            width,
             height: 1,
             depth: 1,
         };
 
-        encoder.dispatch_threads(grid_size, thread_group_size);
+        encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
         encoder.end_encoding();
         command_buffer.commit();
         command_buffer.wait_until_completed();
@@ -545,7 +556,7 @@ mod tests {
             depth: 1,
         };
 
-        encoder.dispatch_threads(grid_size, thread_group_size);
+        encoder.dispatch_thread_groups(grid_size, thread_group_size);
         encoder.end_encoding();
         command_buffer.commit();
         command_buffer.wait_until_completed();
diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml
index d4324e65..6cc45b26 100644
--- a/candle-nn/Cargo.toml
+++ b/candle-nn/Cargo.toml
@@ -14,6 +14,7 @@ accelerate-src = { workspace = true, optional = true }
 candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
 half = { workspace = true }
 thiserror = { workspace = true }
+tracing = { workspace = true }
 intel-mkl-src = { workspace = true, optional = true }
 num-traits = { workspace = true }
 rayon = { workspace = true }
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index fb8bc21f..ba28358a 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -1,5 +1,6 @@
 use candle::{CpuStorage, Layout, Result, Shape, Tensor};
 use rayon::prelude::*;
+use tracing::debug;
 
 /// Applies the softmax function to the input tensor, rescaling the element so that elements on
 /// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
@@ -198,7 +199,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
         storage: &candle::MetalStorage,
         layout: &Layout,
     ) -> Result<(candle::MetalStorage, Shape)> {
-        println!("TODO softmax-last-dim");
+        debug!("TODO softmax-last-dim");
         Ok((storage.clone(), layout.shape().clone()))
     }
 }