From c921cc37847e5bdbd149264d919de73a14dffb6b Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Sat, 4 Nov 2023 09:03:23 +0100 Subject: [PATCH] Add Arc to metalstorage buffer for quick cloning --- candle-core/src/metal_backend.rs | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 34c74622..1850cc8f 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -10,6 +10,7 @@ use metal; use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication}; use metal::mps::{Float32, MPSDataType}; use metal::{Buffer, MTLResourceOptions}; +use std::sync::Arc; /// Metal related errors #[derive(thiserror::Error, Debug)] @@ -56,7 +57,7 @@ impl MetalDevice { #[derive(Debug, Clone)] pub struct MetalStorage { - buffer: metal::Buffer, + buffer: Arc, device: MetalDevice, dtype: DType, } @@ -120,10 +121,10 @@ impl BackendStorage for MetalStorage { let dims = shape.dims(); let el_count = shape.elem_count(); let mut buffer = device.new_buffer(el_count, dtype); - todo!("Implement the kernel calling"); + //todo!("Implement the kernel calling"); // device.kernels.call_unary(U::KERNEL, &self.buffer, &mut buffer, el_count, dtype); Ok(Self { - buffer, + buffer: Arc::new(buffer), device, dtype, }) @@ -292,7 +293,7 @@ impl MetalStorage { println!("TODO implement batched matmul for B={b}"); // bail!("Didn't implemented strided matmul yet"); return Ok(Self { - buffer: out_buffer, + buffer: Arc::new(out_buffer), device: self.device.clone(), dtype: self.dtype(), }); @@ -304,7 +305,7 @@ impl MetalStorage { rhs_l.is_contiguous() ); return Ok(Self { - buffer: out_buffer, + buffer: Arc::new(out_buffer), device: self.device.clone(), dtype: self.dtype(), }); @@ -358,7 +359,7 @@ impl MetalStorage { &result_matrix, ); Ok(Self { - buffer: out_buffer, + buffer: Arc::new(out_buffer), device: self.device.clone(), dtype: self.dtype(), }) @@ -446,7 +447,7 @@ impl BackendDevice for MetalDevice { ), }; Ok(Self::Storage { - buffer, + buffer: Arc::new(buffer), device: self.clone(), dtype: storage.dtype(), })