mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Heap buffers for metal ?
This commit is contained in:
@ -6,7 +6,7 @@ use candle_metal_kernels;
|
||||
use candle_metal_kernels::Kernels;
|
||||
use half::f16;
|
||||
use metal;
|
||||
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||
use metal::{Buffer, CommandBuffer, CommandQueue, HeapDescriptor, MTLResourceOptions, NSUInteger};
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
/// Metal related errors
|
||||
@ -35,6 +35,7 @@ impl From<String> for MetalError {
|
||||
pub struct MetalDevice {
|
||||
device: metal::Device,
|
||||
command_queue: metal::CommandQueue,
|
||||
heap: metal::Heap,
|
||||
command_buffer: Arc<RwLock<metal::CommandBuffer>>,
|
||||
kernels: Arc<candle_metal_kernels::Kernels>,
|
||||
}
|
||||
@ -85,12 +86,13 @@ impl MetalDevice {
|
||||
|
||||
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
|
||||
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
|
||||
self.device
|
||||
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
|
||||
self.heap
|
||||
.new_buffer(size, MTLResourceOptions::StorageModeShared)
|
||||
.expect(" New buffer")
|
||||
}
|
||||
|
||||
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Buffer {
|
||||
let option = metal::MTLResourceOptions::StorageModeManaged;
|
||||
let option = metal::MTLResourceOptions::StorageModeShared;
|
||||
self.device.new_buffer_with_data(
|
||||
data.as_ptr() as *const core::ffi::c_void,
|
||||
core::mem::size_of_val(data) as NSUInteger,
|
||||
@ -881,10 +883,19 @@ impl BackendDevice for MetalDevice {
|
||||
let device = metal::Device::all().swap_remove(ordinal);
|
||||
|
||||
let command_queue = device.new_command_queue();
|
||||
|
||||
let descriptor = HeapDescriptor::new();
|
||||
let mut size =
|
||||
device.heap_buffer_size_and_align(100_000_000, MTLResourceOptions::StorageModeShared);
|
||||
size.size += (size.size & (size.align - 1)) + size.align;
|
||||
descriptor.set_size(size.size);
|
||||
descriptor.set_storage_mode(metal::MTLStorageMode::Shared);
|
||||
let heap = device.new_heap(&descriptor);
|
||||
let command_buffer = Arc::new(RwLock::new(command_queue.new_owned_command_buffer()));
|
||||
let kernels = Arc::new(Kernels::new());
|
||||
Ok(Self {
|
||||
device,
|
||||
heap,
|
||||
command_queue,
|
||||
command_buffer,
|
||||
kernels,
|
||||
|
@ -300,9 +300,6 @@ pub fn call_unary_contiguous(
|
||||
input: &Buffer,
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
// println!("Kernel {:?}", kernel_name.0);
|
||||
// assert_eq!(input.length(), output.length());
|
||||
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
Reference in New Issue
Block a user