Fixes + cache compute_pipeline_state.

This commit is contained in:
Nicolas Patry
2023-11-13 14:33:16 +01:00
parent 79845bd93b
commit dd4a40f1c0
3 changed files with 68 additions and 133 deletions

View File

@ -86,7 +86,6 @@ impl MetalDevice {
/// Allocates a buffer on this device large enough for `element_count` elements of `dtype`.
///
/// The byte length is `element_count * dtype.size_in_bytes()`, cast to `NSUInteger`.
/// The multiplication is a plain `*` with no overflow check, so callers are expected to
/// pass sizes whose byte count fits in `usize`.
///
/// The buffer is requested with `MTLResourceOptions::StorageModeManaged`.
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
// Disabled allocation trace, kept for debugging buffer sizes:
// debug!("Allocate 1 - buffer size {size}");
self.device
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
}
@ -115,7 +114,9 @@ impl BackendStorage for MetalStorage {
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
let start = std::time::Instant::now();
self.device.wait_until_completed();
println!("Wait took {:?}", start.elapsed());
match self.dtype {
DType::U8 => Ok(CpuStorage::U8(
@ -414,6 +415,7 @@ impl BackendStorage for MetalStorage {
)
.map_err(MetalError::from)?;
}
self.device.wait_until_completed();
Ok(Self {
buffer,
device: device.clone(),
@ -899,9 +901,12 @@ impl BackendDevice for MetalDevice {
}
/// Builds a `MetalStorage` for a tensor of the given `shape` and `dtype`.
// NOTE(review): this hunk appears to show both sides of the change with the +/- markers
// stripped. The header says 9 lines became 12, which matches the three CPU-path lines
// below being removed and the six `new_buffer` lines being added — confirm against the
// actual commit before reading this as a single function body.
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
// TODO Is there a faster way ?
// Presumably the previous implementation: build the zeros on the CPU backend, then
// convert that CPU storage via `storage_from_cpu_storage`.
let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?;
self.storage_from_cpu_storage(&cpu_storage)
// Presumably the replacement: allocate `shape.elem_count()` elements directly on the
// device. Nothing here writes zeros explicitly, so this relies on a freshly created
// buffer already being zero-filled — TODO confirm against the Metal documentation.
let buffer = self.new_buffer(shape.elem_count(), dtype);
Ok(MetalStorage {
buffer,
device: self.clone(),
dtype,
})
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {