From 634a4e716881dd82bc4cb66cff02d8e4e0f6e3b4 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Mon, 6 Nov 2023 08:23:36 +0100 Subject: [PATCH] BlitEncoder added to affine for copying buffer contents quickly. --- candle-core/src/metal_backend.rs | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index d911fe32..d6bed286 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -90,33 +90,45 @@ impl BackendStorage for MetalStorage { fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result { let device = self.device().clone(); - /* let shape = layout.shape(); let dims = shape.dims(); let el = shape.elem_count(); // TODO: Don't load library every time - let library = device.new_library_with_source(AFFINE, &CompileOptions::new()).unwrap(); + let library = device + .new_library_with_source(AFFINE, &CompileOptions::new()) + .unwrap(); let function = library.get_function("affine", None).unwrap(); let pipeline = device .new_compute_pipeline_state_with_function(&function) .unwrap(); + let output_size = el * self.dtype.size_in_bytes(); + let output_buffer = device.new_buffer(output_size, self.dtype); + + let src_length = self.buffer.length() as usize - layout.start_offset(); + let src = self.device.new_buffer(src_length, self.dtype); + let blit_encoder = self.device.command_buffer.new_blit_command_encoder(); + blit_encoder.copy_from_buffer( + self.buffer.as_ref(), + layout.start_offset() as NSUInteger, + output_buffer.as_ref(), + 0, + (src_length * self.dtype.size_in_bytes()) as NSUInteger, + ); + blit_encoder.end_encoding(); + let encoder = device.command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - - let output_size = el * self.dtype.size_in_bytes(); encoder.set_threadgroup_memory_length(0, output_size as NSUInteger); - let output_buffer = device.new_buffer(output_size, self.dtype); - encoder.set_bytes(0, 4, void_ptr(&el)); encoder.set_bytes(1, 4, void_ptr(&dims)); let info = [dims, layout.stride()].concat(); let info_len = (info.len() * mem::size_of::()) as NSUInteger; encoder.set_bytes(2, info_len, info.as_slice().as_ptr().cast()); - encoder.set_buffer(3, Some(&self.buffer), 0); + encoder.set_buffer(3, Some(&src), 0); encoder.set_buffer(4, Some(&output_buffer), 0); encoder.set_bytes(5, 4, void_ptr(&(mul as f32))); @@ -136,7 +148,6 @@ impl BackendStorage for MetalStorage { encoder.dispatch_threads(grid_size, thread_group_size); encoder.end_encoding(); - */ Ok(self.clone()) }