Heap buffers for metal ?

This commit is contained in:
Nicolas Patry
2023-11-13 18:56:46 +01:00
parent 4289984d32
commit 51f05e997d
2 changed files with 15 additions and 7 deletions

View File

@ -6,7 +6,7 @@ use candle_metal_kernels;
use candle_metal_kernels::Kernels; use candle_metal_kernels::Kernels;
use half::f16; use half::f16;
use metal; use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; use metal::{Buffer, CommandBuffer, CommandQueue, HeapDescriptor, MTLResourceOptions, NSUInteger};
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
/// Metal related errors /// Metal related errors
@ -35,6 +35,7 @@ impl From<String> for MetalError {
pub struct MetalDevice { pub struct MetalDevice {
device: metal::Device, device: metal::Device,
command_queue: metal::CommandQueue, command_queue: metal::CommandQueue,
heap: metal::Heap,
command_buffer: Arc<RwLock<metal::CommandBuffer>>, command_buffer: Arc<RwLock<metal::CommandBuffer>>,
kernels: Arc<candle_metal_kernels::Kernels>, kernels: Arc<candle_metal_kernels::Kernels>,
} }
@ -85,12 +86,13 @@ impl MetalDevice {
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer { pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
let size = (element_count * dtype.size_in_bytes()) as NSUInteger; let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
self.device self.heap
.new_buffer(size, MTLResourceOptions::StorageModeManaged) .new_buffer(size, MTLResourceOptions::StorageModeShared)
.expect(" New buffer")
} }
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Buffer { pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Buffer {
let option = metal::MTLResourceOptions::StorageModeManaged; let option = metal::MTLResourceOptions::StorageModeShared;
self.device.new_buffer_with_data( self.device.new_buffer_with_data(
data.as_ptr() as *const core::ffi::c_void, data.as_ptr() as *const core::ffi::c_void,
core::mem::size_of_val(data) as NSUInteger, core::mem::size_of_val(data) as NSUInteger,
@ -881,10 +883,19 @@ impl BackendDevice for MetalDevice {
let device = metal::Device::all().swap_remove(ordinal); let device = metal::Device::all().swap_remove(ordinal);
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let descriptor = HeapDescriptor::new();
let mut size =
device.heap_buffer_size_and_align(100_000_000, MTLResourceOptions::StorageModeShared);
size.size += (size.size & (size.align - 1)) + size.align;
descriptor.set_size(size.size);
descriptor.set_storage_mode(metal::MTLStorageMode::Shared);
let heap = device.new_heap(&descriptor);
let command_buffer = Arc::new(RwLock::new(command_queue.new_owned_command_buffer())); let command_buffer = Arc::new(RwLock::new(command_queue.new_owned_command_buffer()));
let kernels = Arc::new(Kernels::new()); let kernels = Arc::new(Kernels::new());
Ok(Self { Ok(Self {
device, device,
heap,
command_queue, command_queue,
command_buffer, command_buffer,
kernels, kernels,

View File

@ -300,9 +300,6 @@ pub fn call_unary_contiguous(
input: &Buffer, input: &Buffer,
output: &mut Buffer, output: &mut Buffer,
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
// println!("Kernel {:?}", kernel_name.0);
// assert_eq!(input.length(), output.length());
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);