mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Add Arc to metalstorage buffer for quick cloning
This commit is contained in:
@ -10,6 +10,7 @@ use metal;
|
|||||||
use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
|
use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
|
||||||
use metal::mps::{Float32, MPSDataType};
|
use metal::mps::{Float32, MPSDataType};
|
||||||
use metal::{Buffer, MTLResourceOptions};
|
use metal::{Buffer, MTLResourceOptions};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
/// Metal related errors
|
/// Metal related errors
|
||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
@ -56,7 +57,7 @@ impl MetalDevice {
|
|||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct MetalStorage {
|
pub struct MetalStorage {
|
||||||
buffer: metal::Buffer,
|
buffer: Arc<metal::Buffer>,
|
||||||
device: MetalDevice,
|
device: MetalDevice,
|
||||||
dtype: DType,
|
dtype: DType,
|
||||||
}
|
}
|
||||||
@ -120,10 +121,10 @@ impl BackendStorage for MetalStorage {
|
|||||||
let dims = shape.dims();
|
let dims = shape.dims();
|
||||||
let el_count = shape.elem_count();
|
let el_count = shape.elem_count();
|
||||||
let mut buffer = device.new_buffer(el_count, dtype);
|
let mut buffer = device.new_buffer(el_count, dtype);
|
||||||
todo!("Implement the kernel calling");
|
//todo!("Implement the kernel calling");
|
||||||
// device.kernels.call_unary(U::KERNEL, &self.buffer, &mut buffer, el_count, dtype);
|
// device.kernels.call_unary(U::KERNEL, &self.buffer, &mut buffer, el_count, dtype);
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
buffer,
|
buffer: Arc::new(buffer),
|
||||||
device,
|
device,
|
||||||
dtype,
|
dtype,
|
||||||
})
|
})
|
||||||
@ -292,7 +293,7 @@ impl MetalStorage {
|
|||||||
println!("TODO implement batched matmul for B={b}");
|
println!("TODO implement batched matmul for B={b}");
|
||||||
// bail!("Didn't implemented strided matmul yet");
|
// bail!("Didn't implemented strided matmul yet");
|
||||||
return Ok(Self {
|
return Ok(Self {
|
||||||
buffer: out_buffer,
|
buffer: Arc::new(out_buffer),
|
||||||
device: self.device.clone(),
|
device: self.device.clone(),
|
||||||
dtype: self.dtype(),
|
dtype: self.dtype(),
|
||||||
});
|
});
|
||||||
@ -304,7 +305,7 @@ impl MetalStorage {
|
|||||||
rhs_l.is_contiguous()
|
rhs_l.is_contiguous()
|
||||||
);
|
);
|
||||||
return Ok(Self {
|
return Ok(Self {
|
||||||
buffer: out_buffer,
|
buffer: Arc::new(out_buffer),
|
||||||
device: self.device.clone(),
|
device: self.device.clone(),
|
||||||
dtype: self.dtype(),
|
dtype: self.dtype(),
|
||||||
});
|
});
|
||||||
@ -358,7 +359,7 @@ impl MetalStorage {
|
|||||||
&result_matrix,
|
&result_matrix,
|
||||||
);
|
);
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
buffer: out_buffer,
|
buffer: Arc::new(out_buffer),
|
||||||
device: self.device.clone(),
|
device: self.device.clone(),
|
||||||
dtype: self.dtype(),
|
dtype: self.dtype(),
|
||||||
})
|
})
|
||||||
@ -446,7 +447,7 @@ impl BackendDevice for MetalDevice {
|
|||||||
),
|
),
|
||||||
};
|
};
|
||||||
Ok(Self::Storage {
|
Ok(Self::Storage {
|
||||||
buffer,
|
buffer: Arc::new(buffer),
|
||||||
device: self.clone(),
|
device: self.clone(),
|
||||||
dtype: storage.dtype(),
|
dtype: storage.dtype(),
|
||||||
})
|
})
|
||||||
|
Reference in New Issue
Block a user