mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Fixes + cache compute_pipeline_state.
This commit is contained in:
@ -86,7 +86,6 @@ impl MetalDevice {
|
||||
|
||||
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
|
||||
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
|
||||
// debug!("Allocate 1 - buffer size {size}");
|
||||
self.device
|
||||
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
|
||||
}
|
||||
@ -115,7 +114,9 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||
let start = std::time::Instant::now();
|
||||
self.device.wait_until_completed();
|
||||
println!("Wait took {:?}", start.elapsed());
|
||||
|
||||
match self.dtype {
|
||||
DType::U8 => Ok(CpuStorage::U8(
|
||||
@ -414,6 +415,7 @@ impl BackendStorage for MetalStorage {
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
self.device.wait_until_completed();
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
@ -899,9 +901,12 @@ impl BackendDevice for MetalDevice {
|
||||
}
|
||||
|
||||
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
|
||||
// TODO Is there a faster way ?
|
||||
let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?;
|
||||
self.storage_from_cpu_storage(&cpu_storage)
|
||||
let buffer = self.new_buffer(shape.elem_count(), dtype);
|
||||
Ok(MetalStorage {
|
||||
buffer,
|
||||
device: self.clone(),
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
|
||||
|
Reference in New Issue
Block a user