Add Arc to metalstorage buffer for quick cloning

This commit is contained in:
Ivar Flakstad
2023-11-04 09:03:23 +01:00
parent d4d6850c78
commit c921cc3784

View File

@ -10,6 +10,7 @@ use metal;
use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication}; use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
use metal::mps::{Float32, MPSDataType}; use metal::mps::{Float32, MPSDataType};
use metal::{Buffer, MTLResourceOptions}; use metal::{Buffer, MTLResourceOptions};
use std::sync::Arc;
/// Metal related errors /// Metal related errors
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
@ -56,7 +57,7 @@ impl MetalDevice {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct MetalStorage { pub struct MetalStorage {
buffer: metal::Buffer, buffer: Arc<metal::Buffer>,
device: MetalDevice, device: MetalDevice,
dtype: DType, dtype: DType,
} }
@ -120,10 +121,10 @@ impl BackendStorage for MetalStorage {
let dims = shape.dims(); let dims = shape.dims();
let el_count = shape.elem_count(); let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype); let mut buffer = device.new_buffer(el_count, dtype);
todo!("Implement the kernel calling"); //todo!("Implement the kernel calling");
// device.kernels.call_unary(U::KERNEL, &self.buffer, &mut buffer, el_count, dtype); // device.kernels.call_unary(U::KERNEL, &self.buffer, &mut buffer, el_count, dtype);
Ok(Self { Ok(Self {
buffer, buffer: Arc::new(buffer),
device, device,
dtype, dtype,
}) })
@ -292,7 +293,7 @@ impl MetalStorage {
println!("TODO implement batched matmul for B={b}"); println!("TODO implement batched matmul for B={b}");
// bail!("Didn't implemented strided matmul yet"); // bail!("Didn't implemented strided matmul yet");
return Ok(Self { return Ok(Self {
buffer: out_buffer, buffer: Arc::new(out_buffer),
device: self.device.clone(), device: self.device.clone(),
dtype: self.dtype(), dtype: self.dtype(),
}); });
@ -304,7 +305,7 @@ impl MetalStorage {
rhs_l.is_contiguous() rhs_l.is_contiguous()
); );
return Ok(Self { return Ok(Self {
buffer: out_buffer, buffer: Arc::new(out_buffer),
device: self.device.clone(), device: self.device.clone(),
dtype: self.dtype(), dtype: self.dtype(),
}); });
@ -358,7 +359,7 @@ impl MetalStorage {
&result_matrix, &result_matrix,
); );
Ok(Self { Ok(Self {
buffer: out_buffer, buffer: Arc::new(out_buffer),
device: self.device.clone(), device: self.device.clone(),
dtype: self.dtype(), dtype: self.dtype(),
}) })
@ -446,7 +447,7 @@ impl BackendDevice for MetalDevice {
), ),
}; };
Ok(Self::Storage { Ok(Self::Storage {
buffer, buffer: Arc::new(buffer),
device: self.clone(), device: self.clone(),
dtype: storage.dtype(), dtype: storage.dtype(),
}) })