mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Reusing a single buffer (for now) to speed things up.
This commit is contained in:
@ -7,8 +7,9 @@ use candle_metal_kernels::Kernels;
|
|||||||
use core::mem;
|
use core::mem;
|
||||||
use half::{bf16, f16};
|
use half::{bf16, f16};
|
||||||
use metal;
|
use metal;
|
||||||
use metal::{Buffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||||
use std::sync::Arc;
|
use std::ops::Deref;
|
||||||
|
use std::sync::{Arc, RwLock};
|
||||||
|
|
||||||
/// Metal related errors
|
/// Metal related errors
|
||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
@ -36,6 +37,7 @@ impl From<String> for MetalError {
|
|||||||
pub struct MetalDevice {
|
pub struct MetalDevice {
|
||||||
device: metal::Device,
|
device: metal::Device,
|
||||||
command_queue: metal::CommandQueue,
|
command_queue: metal::CommandQueue,
|
||||||
|
command_buffer: Arc<RwLock<metal::CommandBuffer>>,
|
||||||
kernels: Arc<candle_metal_kernels::Kernels>,
|
kernels: Arc<candle_metal_kernels::Kernels>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -54,10 +56,6 @@ impl std::ops::Deref for MetalDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl MetalDevice {
|
impl MetalDevice {
|
||||||
// pub fn metal_device(&self) -> &metal::DeviceRef {
|
|
||||||
// self.device.as_ref()
|
|
||||||
// }
|
|
||||||
|
|
||||||
pub fn id(&self) -> NSUInteger {
|
pub fn id(&self) -> NSUInteger {
|
||||||
self.registry_id()
|
self.registry_id()
|
||||||
}
|
}
|
||||||
@ -66,6 +64,19 @@ impl MetalDevice {
|
|||||||
&self.command_queue
|
&self.command_queue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns a read guard over the device's shared command buffer.
///
/// The guard comes from the `RwLock` held in `self.command_buffer`, so the
/// buffer can be borrowed for encoding for as long as the guard is alive.
/// `wait_until_completed` takes the write side of the same lock.
///
/// # Panics
/// Panics if the lock is poisoned (`read().unwrap()`).
pub fn command_buffer(&self) -> std::sync::RwLockReadGuard<CommandBuffer> {
|
||||||
|
self.command_buffer.read().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Flushes the shared command buffer and replaces it with a fresh one.
///
/// Commits the current command buffer, calls `wait_until_completed` on it,
/// and then stores a new buffer obtained from `command_queue` in its place.
/// `to_cpu_storage` calls this before reading buffer contents.
///
/// NOTE(review): `RwLock::write` may deadlock or panic if the calling thread
/// still holds a guard returned by `command_buffer()` — confirm no caller
/// does so.
///
/// # Panics
/// Panics if the lock is poisoned (`write().unwrap()`).
pub fn wait_until_completed(&self) {
|
||||||
|
// Write lock is held until the end of the function.
let mut old = self.command_buffer.write().unwrap();
|
||||||
|
old.commit();
|
||||||
|
// Presumably blocks until the committed work has finished — confirm against the Metal docs.
old.wait_until_completed();
|
||||||
|
let command_buffer = self.command_queue.new_owned_command_buffer();
|
||||||
|
// Swap the new buffer in behind the lock; later `command_buffer()` calls will see it.
*old = command_buffer;
|
||||||
|
// self.command_buffer.replace_with(|_| command_buffer)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn kernels(&self) -> &Kernels {
|
pub fn kernels(&self) -> &Kernels {
|
||||||
&self.kernels
|
&self.kernels
|
||||||
}
|
}
|
||||||
@ -105,6 +116,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||||
|
self.device.wait_until_completed();
|
||||||
|
|
||||||
match self.dtype {
|
match self.dtype {
|
||||||
DType::U8 => Ok(CpuStorage::U8(
|
DType::U8 => Ok(CpuStorage::U8(
|
||||||
self.buffer.read_to_vec(self.buffer.length() as usize / 1),
|
self.buffer.read_to_vec(self.buffer.length() as usize / 1),
|
||||||
@ -138,7 +151,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
let dtype = self.dtype;
|
let dtype = self.dtype;
|
||||||
|
|
||||||
let mut buffer = device.new_buffer(el, self.dtype);
|
let mut buffer = device.new_buffer(el, self.dtype);
|
||||||
let command_buffer = self.device.command_queue.new_command_buffer();
|
let command_buffer = self.device.command_buffer();
|
||||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||||
assert_eq!(dtype, DType::F32);
|
assert_eq!(dtype, DType::F32);
|
||||||
candle_metal_kernels::call_affine(
|
candle_metal_kernels::call_affine(
|
||||||
@ -168,8 +181,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
command_buffer.commit();
|
// command_buffer.commit();
|
||||||
command_buffer.wait_until_completed();
|
// command_buffer.wait_until_scheduled();
|
||||||
return Ok(Self {
|
return Ok(Self {
|
||||||
buffer,
|
buffer,
|
||||||
device: device.clone(),
|
device: device.clone(),
|
||||||
@ -225,7 +238,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
}
|
}
|
||||||
let dtype = if return_index { DType::U32 } else { self.dtype };
|
let dtype = if return_index { DType::U32 } else { self.dtype };
|
||||||
let mut buffer = device.new_buffer(dst_el, dtype);
|
let mut buffer = device.new_buffer(dst_el, dtype);
|
||||||
let command_buffer = self.device.command_queue.new_command_buffer();
|
let command_buffer = self.device.command_buffer();
|
||||||
candle_metal_kernels::call_reduce_contiguous(
|
candle_metal_kernels::call_reduce_contiguous(
|
||||||
&device.device,
|
&device.device,
|
||||||
&command_buffer,
|
&command_buffer,
|
||||||
@ -237,8 +250,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
&mut buffer,
|
&mut buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
command_buffer.commit();
|
// command_buffer.commit();
|
||||||
command_buffer.wait_until_completed();
|
// command_buffer.wait_until_scheduled();
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
buffer,
|
buffer,
|
||||||
@ -256,7 +269,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
let shape = layout.shape();
|
let shape = layout.shape();
|
||||||
let el_count = shape.elem_count();
|
let el_count = shape.elem_count();
|
||||||
let mut buffer = device.new_buffer(el_count, dtype);
|
let mut buffer = device.new_buffer(el_count, dtype);
|
||||||
let command_buffer = device.command_queue.new_command_buffer();
|
let command_buffer = device.command_buffer();
|
||||||
if layout.is_contiguous() {
|
if layout.is_contiguous() {
|
||||||
let kernel_name = match (self.dtype, dtype) {
|
let kernel_name = match (self.dtype, dtype) {
|
||||||
(DType::U32, DType::F32) => "cast_u32_f32",
|
(DType::U32, DType::F32) => "cast_u32_f32",
|
||||||
@ -280,8 +293,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
command_buffer.commit();
|
// command_buffer.commit();
|
||||||
command_buffer.wait_until_completed();
|
// command_buffer.wait_until_scheduled();
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
buffer,
|
buffer,
|
||||||
device: device.clone(),
|
device: device.clone(),
|
||||||
@ -295,7 +308,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
let shape = layout.shape();
|
let shape = layout.shape();
|
||||||
let el_count = shape.elem_count();
|
let el_count = shape.elem_count();
|
||||||
let mut buffer = device.new_buffer(el_count, dtype);
|
let mut buffer = device.new_buffer(el_count, dtype);
|
||||||
let command_buffer = device.command_queue.new_command_buffer();
|
let command_buffer = device.command_buffer();
|
||||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||||
use candle_metal_kernels::unary::contiguous;
|
use candle_metal_kernels::unary::contiguous;
|
||||||
|
|
||||||
@ -329,8 +342,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
} else {
|
} else {
|
||||||
todo!("TODO Implement the kernel calling {}", B::KERNEL);
|
todo!("TODO Implement the kernel calling {}", B::KERNEL);
|
||||||
}
|
}
|
||||||
command_buffer.commit();
|
// command_buffer.commit();
|
||||||
command_buffer.wait_until_completed();
|
// command_buffer.wait_until_scheduled();
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
buffer,
|
buffer,
|
||||||
@ -350,7 +363,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
let shape = lhs_l.shape();
|
let shape = lhs_l.shape();
|
||||||
let el_count = shape.elem_count();
|
let el_count = shape.elem_count();
|
||||||
let mut buffer = device.new_buffer(el_count, dtype);
|
let mut buffer = device.new_buffer(el_count, dtype);
|
||||||
let command_buffer = device.command_queue.new_command_buffer();
|
let command_buffer = device.command_buffer();
|
||||||
if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
|
if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
|
||||||
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
|
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
|
||||||
{
|
{
|
||||||
@ -404,8 +417,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
}
|
}
|
||||||
command_buffer.commit();
|
// command_buffer.commit();
|
||||||
command_buffer.wait_until_completed();
|
// command_buffer.wait_until_scheduled();
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
buffer,
|
buffer,
|
||||||
@ -428,7 +441,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
let el = shape.elem_count();
|
let el = shape.elem_count();
|
||||||
let dtype = t.dtype;
|
let dtype = t.dtype;
|
||||||
let mut buffer = self.device.new_buffer(el, dtype);
|
let mut buffer = self.device.new_buffer(el, dtype);
|
||||||
let command_buffer = self.device.command_queue.new_command_buffer();
|
let command_buffer = self.device.command_buffer();
|
||||||
candle_metal_kernels::call_where_cond_strided(
|
candle_metal_kernels::call_where_cond_strided(
|
||||||
&device.device,
|
&device.device,
|
||||||
&command_buffer,
|
&command_buffer,
|
||||||
@ -447,8 +460,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
&mut buffer,
|
&mut buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
command_buffer.commit();
|
// command_buffer.commit();
|
||||||
command_buffer.wait_until_completed();
|
// command_buffer.wait_until_scheduled();
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
buffer,
|
buffer,
|
||||||
device,
|
device,
|
||||||
@ -544,7 +557,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
(DType::U32, DType::F32) => "is_u32_f32",
|
(DType::U32, DType::F32) => "is_u32_f32",
|
||||||
(left, right) => todo!("index select metal {left:?} {right:?}"),
|
(left, right) => todo!("index select metal {left:?} {right:?}"),
|
||||||
};
|
};
|
||||||
let command_buffer = self.device.command_queue.new_command_buffer();
|
let command_buffer = self.device.command_buffer();
|
||||||
candle_metal_kernels::call_index_select(
|
candle_metal_kernels::call_index_select(
|
||||||
&device.device,
|
&device.device,
|
||||||
&command_buffer,
|
&command_buffer,
|
||||||
@ -558,8 +571,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
&mut buffer,
|
&mut buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
command_buffer.commit();
|
// command_buffer.commit();
|
||||||
command_buffer.wait_until_completed();
|
// command_buffer.wait_until_scheduled();
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
buffer,
|
buffer,
|
||||||
device: device.clone(),
|
device: device.clone(),
|
||||||
@ -641,7 +654,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
|
let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
|
||||||
|
|
||||||
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
|
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
|
||||||
let command_buffer = self.device.command_queue.new_command_buffer();
|
let command_buffer = self.device.command_buffer();
|
||||||
for bi in 0..b {
|
for bi in 0..b {
|
||||||
// Create matrix objects
|
// Create matrix objects
|
||||||
let left_matrix = Matrix::init_with_buffer_descriptor(
|
let left_matrix = Matrix::init_with_buffer_descriptor(
|
||||||
@ -689,14 +702,14 @@ impl BackendStorage for MetalStorage {
|
|||||||
|
|
||||||
// Encode kernel to command buffer
|
// Encode kernel to command buffer
|
||||||
matrix_multiplication.encode_to_command_buffer(
|
matrix_multiplication.encode_to_command_buffer(
|
||||||
command_buffer,
|
command_buffer.deref(),
|
||||||
&left_matrix,
|
&left_matrix,
|
||||||
&right_matrix,
|
&right_matrix,
|
||||||
&result_matrix,
|
&result_matrix,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
command_buffer.commit();
|
// command_buffer.commit();
|
||||||
command_buffer.wait_until_completed();
|
// command_buffer.wait_until_scheduled();
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
buffer: out_buffer,
|
buffer: out_buffer,
|
||||||
@ -712,7 +725,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
if el_count == 0 {
|
if el_count == 0 {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
let command_buffer = self.device.command_queue.new_command_buffer();
|
let command_buffer = self.device.command_buffer();
|
||||||
let kernel_name = match self.dtype {
|
let kernel_name = match self.dtype {
|
||||||
DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
|
DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
|
||||||
DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
|
DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
|
||||||
@ -733,8 +746,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
dst_offset,
|
dst_offset,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
command_buffer.commit();
|
// command_buffer.commit();
|
||||||
command_buffer.wait_until_completed();
|
// command_buffer.wait_until_scheduled();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -760,10 +773,12 @@ impl BackendDevice for MetalDevice {
|
|||||||
let device = metal::Device::all().swap_remove(ordinal);
|
let device = metal::Device::all().swap_remove(ordinal);
|
||||||
|
|
||||||
let command_queue = device.new_command_queue();
|
let command_queue = device.new_command_queue();
|
||||||
|
let command_buffer = Arc::new(RwLock::new(command_queue.new_owned_command_buffer()));
|
||||||
let kernels = Arc::new(Kernels::new());
|
let kernels = Arc::new(Kernels::new());
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
device,
|
device,
|
||||||
command_queue,
|
command_queue,
|
||||||
|
command_buffer,
|
||||||
kernels,
|
kernels,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user