From 71fcb31873ea5cf5a692296976a951d111370121 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 1 Nov 2023 18:03:53 +0100 Subject: [PATCH] Owned command buffer now. --- Cargo.toml | 3 ++- candle-core/src/metal_backend.rs | 12 ++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d3130105..c827507d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,7 +55,8 @@ tracing-subscriber = "0.3.7" wav = "1.0.0" yoke = { version = "0.7.2", features = ["derive"] } zip = { version = "0.6.6", default-features = false } -metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } +# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } +metal = { path = "../metal-rs", features = ["mps"] } [profile.release-with-debug] inherits = "release" diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index ae49239a..fae1b341 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -21,7 +21,8 @@ pub enum MetalError { #[derive(Clone)] pub struct MetalDevice { device: metal::Device, - command_queue: metal::CommandQueue, + _command_queue: metal::CommandQueue, + command_buffer: metal::CommandBuffer, } impl std::fmt::Debug for MetalDevice { @@ -250,15 +251,13 @@ impl BackendStorage for MetalStorage { ) .expect("Failed to create matrix multiplication kernel"); - let buffer = self.device.command_queue.new_command_buffer(); // Encode kernel to command buffer matrix_multiplication.encode_to_command_buffer( - &buffer, + &self.device.command_buffer, &left_matrix, &right_matrix, &result_matrix, ); - buffer.commit(); Ok(Self{ buffer: out_buffer, device: self.device.clone(), @@ -280,8 +279,9 @@ impl BackendDevice for MetalDevice { fn new(ordinal: usize) -> Result { let device = metal::Device::all().swap_remove(ordinal); - let command_queue = device.new_command_queue(); - Ok(Self { device, command_queue }) + let _command_queue = device.new_command_queue(); + let command_buffer = _command_queue.new_owned_command_buffer(); + Ok(Self { device, _command_queue, command_buffer }) } fn set_seed(&self, _seed: u64) -> Result<()> {