diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index d62bb159..6159ffc0 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -6,7 +6,7 @@ use candle_metal_kernels; use candle_metal_kernels::Kernels; use half::f16; use metal; -use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; +use metal::{Buffer, CommandBuffer, CommandQueue, HeapDescriptor, MTLResourceOptions, NSUInteger}; use std::sync::{Arc, RwLock}; /// Metal related errors @@ -35,6 +35,7 @@ impl From for MetalError { pub struct MetalDevice { device: metal::Device, command_queue: metal::CommandQueue, + heap: metal::Heap, command_buffer: Arc>, kernels: Arc, } @@ -85,12 +86,13 @@ impl MetalDevice { pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer { let size = (element_count * dtype.size_in_bytes()) as NSUInteger; - self.device - .new_buffer(size, MTLResourceOptions::StorageModeManaged) + self.heap + .new_buffer(size, MTLResourceOptions::StorageModeShared) + .expect(" New buffer") } pub fn new_buffer_with_data(&self, data: &[T]) -> Buffer { - let option = metal::MTLResourceOptions::StorageModeManaged; + let option = metal::MTLResourceOptions::StorageModeShared; self.device.new_buffer_with_data( data.as_ptr() as *const core::ffi::c_void, core::mem::size_of_val(data) as NSUInteger, @@ -881,10 +883,19 @@ impl BackendDevice for MetalDevice { let device = metal::Device::all().swap_remove(ordinal); let command_queue = device.new_command_queue(); + + let descriptor = HeapDescriptor::new(); + let mut size = + device.heap_buffer_size_and_align(100_000_000, MTLResourceOptions::StorageModeShared); + size.size += (size.size & (size.align - 1)) + size.align; + descriptor.set_size(size.size); + descriptor.set_storage_mode(metal::MTLStorageMode::Shared); + let heap = device.new_heap(&descriptor); let command_buffer = Arc::new(RwLock::new(command_queue.new_owned_command_buffer())); let kernels = Arc::new(Kernels::new()); Ok(Self { device, + heap, command_queue, command_buffer, kernels, diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index a9d108f4..a0227119 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -300,9 +300,6 @@ pub fn call_unary_contiguous( input: &Buffer, output: &mut Buffer, ) -> Result<(), MetalKernelError> { - // println!("Kernel {:?}", kernel_name.0); - // assert_eq!(input.length(), output.length()); - let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline);