Add Arc to metalstorage buffer for quick cloning

This commit is contained in:
Ivar Flakstad
2023-11-04 09:03:23 +01:00
parent d4d6850c78
commit c921cc3784

View File

@ -10,6 +10,7 @@ use metal;
use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
use metal::mps::{Float32, MPSDataType};
use metal::{Buffer, MTLResourceOptions};
use std::sync::Arc;
/// Metal related errors
#[derive(thiserror::Error, Debug)]
@ -56,7 +57,7 @@ impl MetalDevice {
#[derive(Debug, Clone)]
pub struct MetalStorage {
buffer: metal::Buffer,
buffer: Arc<metal::Buffer>,
device: MetalDevice,
dtype: DType,
}
@ -120,10 +121,10 @@ impl BackendStorage for MetalStorage {
let dims = shape.dims();
let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype);
todo!("Implement the kernel calling");
//todo!("Implement the kernel calling");
// device.kernels.call_unary(U::KERNEL, &self.buffer, &mut buffer, el_count, dtype);
Ok(Self {
buffer,
buffer: Arc::new(buffer),
device,
dtype,
})
@ -292,7 +293,7 @@ impl MetalStorage {
println!("TODO implement batched matmul for B={b}");
// bail!("Didn't implemented strided matmul yet");
return Ok(Self {
buffer: out_buffer,
buffer: Arc::new(out_buffer),
device: self.device.clone(),
dtype: self.dtype(),
});
@ -304,7 +305,7 @@ impl MetalStorage {
rhs_l.is_contiguous()
);
return Ok(Self {
buffer: out_buffer,
buffer: Arc::new(out_buffer),
device: self.device.clone(),
dtype: self.dtype(),
});
@ -358,7 +359,7 @@ impl MetalStorage {
&result_matrix,
);
Ok(Self {
buffer: out_buffer,
buffer: Arc::new(out_buffer),
device: self.device.clone(),
dtype: self.dtype(),
})
@ -446,7 +447,7 @@ impl BackendDevice for MetalDevice {
),
};
Ok(Self::Storage {
buffer,
buffer: Arc::new(buffer),
device: self.clone(),
dtype: storage.dtype(),
})