mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Heap buffers for metal ?
This commit is contained in:
@ -6,7 +6,7 @@ use candle_metal_kernels;
|
|||||||
use candle_metal_kernels::Kernels;
|
use candle_metal_kernels::Kernels;
|
||||||
use half::f16;
|
use half::f16;
|
||||||
use metal;
|
use metal;
|
||||||
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
use metal::{Buffer, CommandBuffer, CommandQueue, HeapDescriptor, MTLResourceOptions, NSUInteger};
|
||||||
use std::sync::{Arc, RwLock};
|
use std::sync::{Arc, RwLock};
|
||||||
|
|
||||||
/// Metal related errors
|
/// Metal related errors
|
||||||
@ -35,6 +35,7 @@ impl From<String> for MetalError {
|
|||||||
pub struct MetalDevice {
|
pub struct MetalDevice {
|
||||||
device: metal::Device,
|
device: metal::Device,
|
||||||
command_queue: metal::CommandQueue,
|
command_queue: metal::CommandQueue,
|
||||||
|
heap: metal::Heap,
|
||||||
command_buffer: Arc<RwLock<metal::CommandBuffer>>,
|
command_buffer: Arc<RwLock<metal::CommandBuffer>>,
|
||||||
kernels: Arc<candle_metal_kernels::Kernels>,
|
kernels: Arc<candle_metal_kernels::Kernels>,
|
||||||
}
|
}
|
||||||
@ -85,12 +86,13 @@ impl MetalDevice {
|
|||||||
|
|
||||||
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
|
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
|
||||||
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
|
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
|
||||||
self.device
|
self.heap
|
||||||
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
|
.new_buffer(size, MTLResourceOptions::StorageModeShared)
|
||||||
|
.expect(" New buffer")
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Buffer {
|
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Buffer {
|
||||||
let option = metal::MTLResourceOptions::StorageModeManaged;
|
let option = metal::MTLResourceOptions::StorageModeShared;
|
||||||
self.device.new_buffer_with_data(
|
self.device.new_buffer_with_data(
|
||||||
data.as_ptr() as *const core::ffi::c_void,
|
data.as_ptr() as *const core::ffi::c_void,
|
||||||
core::mem::size_of_val(data) as NSUInteger,
|
core::mem::size_of_val(data) as NSUInteger,
|
||||||
@ -881,10 +883,19 @@ impl BackendDevice for MetalDevice {
|
|||||||
let device = metal::Device::all().swap_remove(ordinal);
|
let device = metal::Device::all().swap_remove(ordinal);
|
||||||
|
|
||||||
let command_queue = device.new_command_queue();
|
let command_queue = device.new_command_queue();
|
||||||
|
|
||||||
|
let descriptor = HeapDescriptor::new();
|
||||||
|
let mut size =
|
||||||
|
device.heap_buffer_size_and_align(100_000_000, MTLResourceOptions::StorageModeShared);
|
||||||
|
size.size += (size.size & (size.align - 1)) + size.align;
|
||||||
|
descriptor.set_size(size.size);
|
||||||
|
descriptor.set_storage_mode(metal::MTLStorageMode::Shared);
|
||||||
|
let heap = device.new_heap(&descriptor);
|
||||||
let command_buffer = Arc::new(RwLock::new(command_queue.new_owned_command_buffer()));
|
let command_buffer = Arc::new(RwLock::new(command_queue.new_owned_command_buffer()));
|
||||||
let kernels = Arc::new(Kernels::new());
|
let kernels = Arc::new(Kernels::new());
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
device,
|
device,
|
||||||
|
heap,
|
||||||
command_queue,
|
command_queue,
|
||||||
command_buffer,
|
command_buffer,
|
||||||
kernels,
|
kernels,
|
||||||
|
@ -300,9 +300,6 @@ pub fn call_unary_contiguous(
|
|||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
output: &mut Buffer,
|
output: &mut Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
// println!("Kernel {:?}", kernel_name.0);
|
|
||||||
// assert_eq!(input.length(), output.length());
|
|
||||||
|
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
Reference in New Issue
Block a user