mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Add Arc to metalstorage buffer for quick cloning
This commit is contained in:
@ -10,6 +10,7 @@ use metal;
|
||||
use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
|
||||
use metal::mps::{Float32, MPSDataType};
|
||||
use metal::{Buffer, MTLResourceOptions};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Metal related errors
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
@ -56,7 +57,7 @@ impl MetalDevice {
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MetalStorage {
|
||||
buffer: metal::Buffer,
|
||||
buffer: Arc<metal::Buffer>,
|
||||
device: MetalDevice,
|
||||
dtype: DType,
|
||||
}
|
||||
@ -120,10 +121,10 @@ impl BackendStorage for MetalStorage {
|
||||
let dims = shape.dims();
|
||||
let el_count = shape.elem_count();
|
||||
let mut buffer = device.new_buffer(el_count, dtype);
|
||||
todo!("Implement the kernel calling");
|
||||
//todo!("Implement the kernel calling");
|
||||
// device.kernels.call_unary(U::KERNEL, &self.buffer, &mut buffer, el_count, dtype);
|
||||
Ok(Self {
|
||||
buffer,
|
||||
buffer: Arc::new(buffer),
|
||||
device,
|
||||
dtype,
|
||||
})
|
||||
@ -292,7 +293,7 @@ impl MetalStorage {
|
||||
println!("TODO implement batched matmul for B={b}");
|
||||
// bail!("Didn't implemented strided matmul yet");
|
||||
return Ok(Self {
|
||||
buffer: out_buffer,
|
||||
buffer: Arc::new(out_buffer),
|
||||
device: self.device.clone(),
|
||||
dtype: self.dtype(),
|
||||
});
|
||||
@ -304,7 +305,7 @@ impl MetalStorage {
|
||||
rhs_l.is_contiguous()
|
||||
);
|
||||
return Ok(Self {
|
||||
buffer: out_buffer,
|
||||
buffer: Arc::new(out_buffer),
|
||||
device: self.device.clone(),
|
||||
dtype: self.dtype(),
|
||||
});
|
||||
@ -358,7 +359,7 @@ impl MetalStorage {
|
||||
&result_matrix,
|
||||
);
|
||||
Ok(Self {
|
||||
buffer: out_buffer,
|
||||
buffer: Arc::new(out_buffer),
|
||||
device: self.device.clone(),
|
||||
dtype: self.dtype(),
|
||||
})
|
||||
@ -446,7 +447,7 @@ impl BackendDevice for MetalDevice {
|
||||
),
|
||||
};
|
||||
Ok(Self::Storage {
|
||||
buffer,
|
||||
buffer: Arc::new(buffer),
|
||||
device: self.clone(),
|
||||
dtype: storage.dtype(),
|
||||
})
|
||||
|
Reference in New Issue
Block a user