From a9d06574320591f5cd966c8840237dc4a1e72ab3 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 13 Dec 2023 12:09:20 +0100 Subject: [PATCH] Better version ? --- candle-core/src/metal_backend.rs | 68 ++++++++++++++------- candle-transformers/src/models/mixformer.rs | 9 +-- 2 files changed, 52 insertions(+), 25 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 4354422c..f745342d 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -96,6 +96,7 @@ impl MetalDevice { .map(|i| { // println!("Creating command buffer {i}"); let command_buffer = self.command_queue.new_command_buffer().to_owned(); + command_buffer.set_label(&format!("num {i}")); command_buffer.enqueue(); command_buffer }) @@ -157,7 +158,7 @@ impl MetalDevice { for sub in &mut *subbuffers { if Arc::strong_count(sub) == 1 { // println!("Reusing tensor {size} {name}"); - // return sub.clone(); + return sub.clone(); } } let new_buffer = self.device.new_buffer(size as NSUInteger, option); @@ -177,7 +178,7 @@ impl MetalDevice { } pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc { - self._new_buffer(size, MTLResourceOptions::StorageModeManaged, "managed") + self._new_buffer(size, MTLResourceOptions::StorageModeShared, "managed") } pub fn new_buffer_with_data(&self, data: &[T]) -> Arc { @@ -185,19 +186,22 @@ impl MetalDevice { let tmp = self.device.new_buffer_with_data( data.as_ptr() as *const core::ffi::c_void, size, - metal::MTLResourceOptions::StorageModeManaged, + metal::MTLResourceOptions::StorageModeShared, ); let real = self._new_buffer( size, metal::MTLResourceOptions::StorageModePrivate, "with_data", ); - let command = self.command_buffer(); - let blit = command.new_blit_command_encoder(); + let command_buffer = self.command_buffer(); + command_buffer.set_label("with_data"); + let blit = command_buffer.new_blit_command_encoder(); + blit.set_label("with_data_blit"); blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length()); blit.end_encoding(); - command.commit(); - real.did_modify_range(metal::NSRange::new(0, real.length())); + command_buffer.commit(); + drop(command_buffer); + // real.did_modify_range(metal::NSRange::new(0, real.length())); // println!("Command {:?}", command.status()); // self.commit(); @@ -220,15 +224,29 @@ impl MetalDevice { dtype: DType, ) -> Result<(Matrix, Arc)> { let elem_count = (b * m * n) as usize; - let out_buffer = self.new_buffer(elem_count, dtype, "matrix"); + let buffer = self.new_buffer(elem_count, dtype, "matrix"); + let command_buffer = self.command_buffer(); + command_buffer.set_label("zeros_matmul"); + let blit = command_buffer.new_blit_command_encoder(); + blit.fill_buffer( + &buffer, + metal::NSRange { + location: 0, + length: buffer.length(), + }, + 0, + ); + blit.end_encoding(); + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); let result_descriptor = MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id); - let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, 0, &result_descriptor) + let result_matrix = Matrix::init_with_buffer_descriptor(&buffer, 0, &result_descriptor) .ok_or_else(|| { MetalError::from("Failed to create matrix multiplication kernel".to_string()) })?; - Ok((result_matrix, out_buffer)) + Ok((result_matrix, buffer)) } pub fn capture>(&self, path: P) -> Result<()> { @@ -298,11 +316,13 @@ impl BackendStorage for MetalStorage { let buffer = self.device.new_buffer_managed(self.buffer.length()); { let command_buffer = self.device.command_buffer(); + command_buffer.set_label("to_cpu"); let blit = command_buffer.new_blit_command_encoder(); + blit.set_label("blit_to_cpu"); blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); blit.end_encoding(); + command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); } self.device.wait_until_completed(); @@ -550,8 +570,9 @@ impl BackendStorage for MetalStorage { let shape = layout.shape(); let el_count = shape.elem_count(); let buffer = device.new_buffer(el_count, dtype, "todtype"); + device.wait_until_completed(); let command_buffer = device.command_buffer(); - if layout.is_contiguous() { + if layout.is_contiguous() && layout.start_offset() == 0 { let kernel_name = match (self.dtype, dtype) { (DType::U32, DType::F32) => "cast_u32_f32", (DType::U32, DType::U8) => "cast_u32_u8", @@ -593,8 +614,10 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } + command_buffer.set_label("to_dtype"); command_buffer.commit(); buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); + device.wait_until_completed(); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -606,6 +629,7 @@ impl BackendStorage for MetalStorage { let el_count = shape.elem_count(); let buffer = device.new_buffer(el_count, dtype, B::KERNEL); let command_buffer = device.command_buffer(); + command_buffer.set_label(B::KERNEL); if layout.is_contiguous() && layout.start_offset() == 0 { use candle_metal_kernels::unary::contiguous; @@ -695,7 +719,6 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } - command_buffer.set_label("unary"); command_buffer.commit(); buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) @@ -962,7 +985,6 @@ impl BackendStorage for MetalStorage { rhs_l: &Layout, ) -> Result { // Create descriptors - let (type_id, size) = match self.dtype { DType::F32 => ( metal::mps::MPS_FLOATBIT_ENCODING | 32, @@ -1028,9 +1050,11 @@ impl BackendStorage for MetalStorage { .new_matrix((b, m, n), size, type_id, self.dtype)?; let command_buffer = self.device.command_buffer(); + command_buffer.set_label("matmul"); let alpha = 1.0f64; - let beta = 0.0f64; + // let beta = f64::MIN; + let beta = 1.0; // Create kernel let matrix_multiplication = MatrixMultiplication::init( &self.device, @@ -1045,6 +1069,8 @@ impl BackendStorage for MetalStorage { .ok_or_else(|| { MetalError::from("Failed to create matrix multiplication kernel".to_string()) })?; + matrix_multiplication.set_batch_size(b); + matrix_multiplication.set_batch_start(0); // Encode kernel to command buffer matrix_multiplication.encode_to_command_buffer( @@ -1053,7 +1079,6 @@ impl BackendStorage for MetalStorage { &right_matrix, &result_matrix, ); - command_buffer.set_label("matmul"); command_buffer.commit(); out_buffer.did_modify_range(metal::NSRange::new(0, out_buffer.length())); // println!("========= MATMUL {:?}", Arc::strong_count(&out_buffer)); @@ -1062,9 +1087,11 @@ impl BackendStorage for MetalStorage { fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { let command_buffer = self.device.command_buffer(); + // println!("Copy strided"); if src_l.is_contiguous() && self.dtype == dst.dtype() { command_buffer.set_label("copy_contiguous"); let blit = command_buffer.new_blit_command_encoder(); + blit.set_label("copy_contiguous"); let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger; let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger; let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger; @@ -1100,8 +1127,6 @@ impl BackendStorage for MetalStorage { command_buffer.set_label("copy_strided"); } command_buffer.commit(); - dst.buffer - .did_modify_range(metal::NSRange::new(0, dst.buffer.length())); Ok(()) } } @@ -1157,13 +1182,14 @@ impl BackendDevice for MetalDevice { // println!("CREATING DEVICE"); let device = metal::Device::all().swap_remove(ordinal); - let n = 50; + let n = 64; let command_queue = device.new_command_queue(); let command_buffers = (0..n) - .map(|_| { + .map(|i| { let command_buffer = command_queue.new_command_buffer().to_owned(); command_buffer.enqueue(); + command_buffer.set_label(&format!("num {i}")); command_buffer }) .collect(); @@ -1198,6 +1224,7 @@ impl BackendDevice for MetalDevice { fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result { let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros"); let command_buffer = self.command_buffer(); + command_buffer.set_label("zeros"); let blit = command_buffer.new_blit_command_encoder(); blit.fill_buffer( &buffer, @@ -1208,7 +1235,6 @@ impl BackendDevice for MetalDevice { 0, ); blit.end_encoding(); - command_buffer.commit(); buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(MetalStorage::new(buffer, self.clone(), dtype)) diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs index c8dae511..8e16e6a9 100644 --- a/candle-transformers/src/models/mixformer.rs +++ b/candle-transformers/src/models/mixformer.rs @@ -144,6 +144,7 @@ impl RotaryEmbedding { let freqs = t.matmul(&inv_freq)?; let sin = freqs.sin()?; let cos = freqs.cos()?; + // todo!("{}", sin); Ok(Self { sin, cos }) } @@ -272,10 +273,10 @@ impl MHA { } fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result { - let view = xs.to_string(); - if view.contains("NaN") { - panic!("NaN"); - } + // let view = xs.to_string(); + // if view.contains("NaN") { + // panic!("NaN"); + // } let _enter = self.span.enter(); let (b_size, seq_len, _n_embd) = xs.dims3()?; let qkv = self