mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Improve metal buffer usage (#1807)
* Improve metal buffer usage * Clone cpu storage when loading to reduce wait_until_complete calls * Use powers of two for buffer sizes so reuse is more likely. * Select best available buffer by size. * Add count to MetalStorage -> can use buffer with different size Co-authored-by: Chris Fleetwood <christopher.fleetwood@huggingface.co> * Simplify new buffer creation without blit copy. Revert &[] -> Vec * Add documentation on newBufferWithBytes safety / synchronization * Drop unused buffers after command buffer is done syncing. --------- Co-authored-by: Chris Fleetwood <christopher.fleetwood@huggingface.co>
This commit is contained in:
@ -9,7 +9,7 @@ use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::ffi::c_void;
|
use std::ffi::c_void;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::sync::{Arc, Mutex, RwLock, TryLockError};
|
use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard, TryLockError};
|
||||||
|
|
||||||
/// Simple way to catch lock error without
|
/// Simple way to catch lock error without
|
||||||
/// depending on T
|
/// depending on T
|
||||||
@ -60,7 +60,8 @@ impl From<String> for MetalError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type AllocatedBuffers = Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>;
|
type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
|
||||||
|
type AllocatedBuffers = Arc<RwLock<BufferMap>>;
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct MetalDevice {
|
pub struct MetalDevice {
|
||||||
@ -68,7 +69,7 @@ pub struct MetalDevice {
|
|||||||
device: metal::Device,
|
device: metal::Device,
|
||||||
|
|
||||||
/// Single command queue for the entire device.
|
/// Single command queue for the entire device.
|
||||||
command_queue: metal::CommandQueue,
|
command_queue: CommandQueue,
|
||||||
/// One command buffer at a time.
|
/// One command buffer at a time.
|
||||||
/// The scheduler works by allowing multiple
|
/// The scheduler works by allowing multiple
|
||||||
/// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
|
/// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
|
||||||
@ -78,7 +79,7 @@ pub struct MetalDevice {
|
|||||||
/// Despite what the documentation says, command buffers are NOT ordered. They are ordered
|
/// Despite what the documentation says, command buffers are NOT ordered. They are ordered
|
||||||
/// for their START time, but there's no guarantee that command buffer1 will finish before
|
/// for their START time, but there's no guarantee that command buffer1 will finish before
|
||||||
/// command buffer2 starts (or there are metal bugs there)
|
/// command buffer2 starts (or there are metal bugs there)
|
||||||
command_buffer: Arc<RwLock<metal::CommandBuffer>>,
|
command_buffer: Arc<RwLock<CommandBuffer>>,
|
||||||
/// Keeps track of the current amount of compute command encoders on the current
|
/// Keeps track of the current amount of compute command encoders on the current
|
||||||
/// command buffer
|
/// command buffer
|
||||||
/// Arc, RwLock because of the interior mutability.
|
/// Arc, RwLock because of the interior mutability.
|
||||||
@ -87,7 +88,7 @@ pub struct MetalDevice {
|
|||||||
compute_per_buffer: usize,
|
compute_per_buffer: usize,
|
||||||
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
|
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
|
||||||
/// Heavily used by [`candle_metal_kernels`]
|
/// Heavily used by [`candle_metal_kernels`]
|
||||||
kernels: Arc<candle_metal_kernels::Kernels>,
|
kernels: Arc<Kernels>,
|
||||||
/// Simple allocator struct.
|
/// Simple allocator struct.
|
||||||
/// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
|
/// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
|
||||||
/// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
|
/// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
|
||||||
@ -99,7 +100,7 @@ pub struct MetalDevice {
|
|||||||
/// operation, so that this buffer is not being used by another kernel at the same time.
|
/// operation, so that this buffer is not being used by another kernel at the same time.
|
||||||
/// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things.
|
/// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things.
|
||||||
///
|
///
|
||||||
/// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers
|
/// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
|
||||||
/// (strong_count = 1).
|
/// (strong_count = 1).
|
||||||
buffers: AllocatedBuffers,
|
buffers: AllocatedBuffers,
|
||||||
/// Seed for random number generation.
|
/// Seed for random number generation.
|
||||||
@ -145,6 +146,8 @@ impl MetalDevice {
|
|||||||
command_buffer = self.command_queue.new_command_buffer().to_owned();
|
command_buffer = self.command_queue.new_command_buffer().to_owned();
|
||||||
*command_buffer_lock = command_buffer.clone();
|
*command_buffer_lock = command_buffer.clone();
|
||||||
*index = 0;
|
*index = 0;
|
||||||
|
|
||||||
|
self.drop_unused_buffers()?;
|
||||||
}
|
}
|
||||||
*index += 1;
|
*index += 1;
|
||||||
Ok(command_buffer)
|
Ok(command_buffer)
|
||||||
@ -163,6 +166,7 @@ impl MetalDevice {
|
|||||||
command_buffer.commit();
|
command_buffer.commit();
|
||||||
command_buffer.wait_until_completed();
|
command_buffer.wait_until_completed();
|
||||||
*command_buffer = self.command_queue.new_command_buffer().to_owned();
|
*command_buffer = self.command_queue.new_command_buffer().to_owned();
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -199,39 +203,25 @@ impl MetalDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a new buffer from data.
|
/// Creates a new buffer from data.
|
||||||
/// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode)
|
/// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
|
||||||
///
|
///
|
||||||
/// This method will block the computation because of the
|
/// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes)
|
||||||
/// lack of lifetime management through the GPU.
|
/// allocates the buffer and copies over the existing data before returning the MTLBuffer.
|
||||||
/// Internal comment for technical details.
|
|
||||||
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
|
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
|
||||||
let size = core::mem::size_of_val(data) as NSUInteger;
|
let size = core::mem::size_of_val(data) as NSUInteger;
|
||||||
let tmp = self.device.new_buffer_with_data(
|
let new_buffer = self.device.new_buffer_with_data(
|
||||||
data.as_ptr() as *const core::ffi::c_void,
|
data.as_ptr() as *const c_void,
|
||||||
size,
|
size,
|
||||||
metal::MTLResourceOptions::StorageModeManaged,
|
MTLResourceOptions::StorageModeManaged,
|
||||||
);
|
);
|
||||||
let real = self.allocate_buffer(
|
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
|
||||||
size,
|
let subbuffers = buffers
|
||||||
metal::MTLResourceOptions::StorageModePrivate,
|
.entry((size, MTLResourceOptions::StorageModeManaged))
|
||||||
"with_data",
|
.or_insert(vec![]);
|
||||||
)?;
|
|
||||||
let command_buffer = self.command_buffer()?;
|
|
||||||
command_buffer.set_label("with_data");
|
|
||||||
let blit = command_buffer.new_blit_command_encoder();
|
|
||||||
blit.set_label("with_data_blit");
|
|
||||||
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
|
|
||||||
blit.end_encoding();
|
|
||||||
|
|
||||||
// This is necessary, for mmaped safetensors
|
let new_buffer = Arc::new(new_buffer);
|
||||||
// Because of the unsafe slice cast we're doing.
|
subbuffers.push(new_buffer.clone());
|
||||||
// The slice might not live long enough for metal
|
Ok(new_buffer)
|
||||||
// To actually fill the GPU buffer.
|
|
||||||
// Putting this wait forces the GPU buffer to be filled
|
|
||||||
// with the actual data allowing the CPU storage to do
|
|
||||||
// deallocate properly.
|
|
||||||
self.wait_until_completed()?;
|
|
||||||
Ok(real)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
|
pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
|
||||||
@ -255,6 +245,40 @@ impl MetalDevice {
|
|||||||
Ok(buffer)
|
Ok(buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn find_available_buffer(
|
||||||
|
&self,
|
||||||
|
size: NSUInteger,
|
||||||
|
option: MTLResourceOptions,
|
||||||
|
buffers: &RwLockWriteGuard<BufferMap>,
|
||||||
|
) -> Option<Arc<Buffer>> {
|
||||||
|
let mut best_buffer: Option<&Arc<Buffer>> = None;
|
||||||
|
let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
|
||||||
|
for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
|
||||||
|
if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
|
||||||
|
for sub in subbuffers {
|
||||||
|
if Arc::strong_count(sub) == 1 {
|
||||||
|
best_buffer = Some(sub);
|
||||||
|
best_buffer_size = *buffer_size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return best_buffer.map(|b| b.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
fn drop_unused_buffers(&self) -> Result<()> {
|
||||||
|
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
|
||||||
|
for subbuffers in buffers.values_mut() {
|
||||||
|
let newbuffers = subbuffers
|
||||||
|
.iter()
|
||||||
|
.filter(|s| Arc::strong_count(*s) > 1)
|
||||||
|
.map(Arc::clone)
|
||||||
|
.collect();
|
||||||
|
*subbuffers = newbuffers;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// The critical allocator algorithm
|
/// The critical allocator algorithm
|
||||||
fn allocate_buffer(
|
fn allocate_buffer(
|
||||||
&self,
|
&self,
|
||||||
@ -263,24 +287,18 @@ impl MetalDevice {
|
|||||||
_name: &str,
|
_name: &str,
|
||||||
) -> Result<Arc<Buffer>> {
|
) -> Result<Arc<Buffer>> {
|
||||||
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
|
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
|
||||||
|
if let Some(b) = self.find_available_buffer(size, option, &buffers) {
|
||||||
|
// Cloning also ensures we increment the strong count
|
||||||
|
return Ok(b.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
let size = buf_size(size);
|
||||||
let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
|
let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
|
||||||
|
|
||||||
for sub in &mut *subbuffers {
|
|
||||||
if Arc::strong_count(sub) == 1 {
|
|
||||||
return Ok(sub.clone());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let new_buffer = self.device.new_buffer(size as NSUInteger, option);
|
let new_buffer = self.device.new_buffer(size as NSUInteger, option);
|
||||||
let new_buffer = Arc::new(new_buffer);
|
let new_buffer = Arc::new(new_buffer);
|
||||||
subbuffers.push(new_buffer.clone());
|
subbuffers.push(new_buffer.clone());
|
||||||
for subbuffers in buffers.values_mut() {
|
|
||||||
let newbuffers = subbuffers
|
|
||||||
.iter()
|
|
||||||
.filter(|s| Arc::strong_count(s) > 1)
|
|
||||||
.map(Arc::clone)
|
|
||||||
.collect();
|
|
||||||
*subbuffers = newbuffers;
|
|
||||||
}
|
|
||||||
Ok(new_buffer)
|
Ok(new_buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -305,6 +323,8 @@ pub struct MetalStorage {
|
|||||||
buffer: Arc<metal::Buffer>,
|
buffer: Arc<metal::Buffer>,
|
||||||
/// a reference to the device owning this buffer
|
/// a reference to the device owning this buffer
|
||||||
device: MetalDevice,
|
device: MetalDevice,
|
||||||
|
/// The count of allocated elements in the buffer
|
||||||
|
count: usize,
|
||||||
/// The dtype is kept since buffers are untyped.
|
/// The dtype is kept since buffers are untyped.
|
||||||
dtype: DType,
|
dtype: DType,
|
||||||
}
|
}
|
||||||
@ -386,7 +406,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
}
|
}
|
||||||
Ok(Self::new(buffer, device.clone(), dtype))
|
Ok(Self::new(buffer, device.clone(), el, dtype))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn powf(&self, layout: &Layout, pow: f64) -> Result<Self> {
|
fn powf(&self, layout: &Layout, pow: f64) -> Result<Self> {
|
||||||
@ -435,7 +455,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
}
|
}
|
||||||
Ok(Self::new(buffer, device.clone(), dtype))
|
Ok(Self::new(buffer, device.clone(), el, dtype))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
|
fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
|
||||||
@ -484,7 +504,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
}
|
}
|
||||||
Ok(Self::new(buffer, device.clone(), dtype))
|
Ok(Self::new(buffer, device.clone(), el, dtype))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
||||||
@ -562,7 +582,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
|
|
||||||
Ok(Self::new(buffer, device, dtype))
|
Ok(Self::new(buffer, device, dst_el, dtype))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
|
fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
|
||||||
@ -654,7 +674,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
}
|
}
|
||||||
command_buffer.set_label("to_dtype");
|
command_buffer.set_label("to_dtype");
|
||||||
Ok(Self::new(buffer, device.clone(), dtype))
|
Ok(Self::new(buffer, device.clone(), el_count, dtype))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
||||||
@ -774,7 +794,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
}
|
}
|
||||||
Ok(Self::new(buffer, device.clone(), dtype))
|
Ok(Self::new(buffer, device.clone(), el_count, dtype))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn binary_impl<B: BinaryOpT>(
|
fn binary_impl<B: BinaryOpT>(
|
||||||
@ -835,7 +855,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
&buffer,
|
&buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
Ok(Self::new(buffer, device, dtype))
|
Ok(Self::new(buffer, device, el, dtype))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn conv1d(
|
fn conv1d(
|
||||||
@ -880,6 +900,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
let col = Self {
|
let col = Self {
|
||||||
buffer: dst,
|
buffer: dst,
|
||||||
device,
|
device,
|
||||||
|
count: dst_el,
|
||||||
dtype: self.dtype,
|
dtype: self.dtype,
|
||||||
};
|
};
|
||||||
let l_out = params.l_out();
|
let l_out = params.l_out();
|
||||||
@ -964,6 +985,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
let col = Self {
|
let col = Self {
|
||||||
buffer: dst,
|
buffer: dst,
|
||||||
device,
|
device,
|
||||||
|
count: dst_el,
|
||||||
dtype: self.dtype,
|
dtype: self.dtype,
|
||||||
};
|
};
|
||||||
let h_out = params.out_h();
|
let h_out = params.out_h();
|
||||||
@ -1049,7 +1071,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
&buffer,
|
&buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
Ok(Self::new(buffer, self.device.clone(), self.dtype))
|
Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
|
fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||||
@ -1083,7 +1105,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
&buffer,
|
&buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
Ok(Self::new(buffer, device.clone(), dtype))
|
Ok(Self::new(buffer, device.clone(), dst_el, dtype))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn scatter_add(
|
fn scatter_add(
|
||||||
@ -1172,7 +1194,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
&buffer,
|
&buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
Ok(Self::new(buffer, device.clone(), dtype))
|
Ok(Self::new(buffer, device.clone(), dst_el, dtype))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn index_add(
|
fn index_add(
|
||||||
@ -1254,7 +1276,12 @@ impl BackendStorage for MetalStorage {
|
|||||||
&buffer,
|
&buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
Ok(Self::new(buffer, self.device.clone(), self.dtype()))
|
Ok(Self::new(
|
||||||
|
buffer,
|
||||||
|
self.device.clone(),
|
||||||
|
b * m * n,
|
||||||
|
self.dtype(),
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
||||||
@ -1303,10 +1330,11 @@ impl BackendStorage for MetalStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl MetalStorage {
|
impl MetalStorage {
|
||||||
pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: DType) -> Self {
|
pub fn new(buffer: Arc<Buffer>, device: MetalDevice, count: usize, dtype: DType) -> Self {
|
||||||
Self {
|
Self {
|
||||||
buffer,
|
buffer,
|
||||||
device,
|
device,
|
||||||
|
count,
|
||||||
dtype,
|
dtype,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1521,29 +1549,23 @@ impl MetalStorage {
|
|||||||
(buffer, dtype)
|
(buffer, dtype)
|
||||||
};
|
};
|
||||||
command_buffer.set_label("binary");
|
command_buffer.set_label("binary");
|
||||||
Ok(Self::new(buffer, device.clone(), dtype))
|
Ok(Self::new(buffer, device.clone(), el_count, dtype))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn to_cpu<T: Clone>(&self) -> Result<Vec<T>> {
|
pub(crate) fn to_cpu<T: Clone>(&self) -> Result<Vec<T>> {
|
||||||
let length = self.buffer.length() as usize;
|
let size = (self.count * self.dtype.size_in_bytes()) as NSUInteger;
|
||||||
let size = self.dtype.size_in_bytes();
|
|
||||||
if length % size != 0 {
|
let buffer = self.device.new_buffer_managed(size)?;
|
||||||
crate::bail!(
|
|
||||||
"The Metal buffer length is not aligned with dtype {:?}",
|
|
||||||
self.dtype
|
|
||||||
);
|
|
||||||
}
|
|
||||||
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
|
|
||||||
{
|
{
|
||||||
let command_buffer = self.device.command_buffer()?;
|
let command_buffer = self.device.command_buffer()?;
|
||||||
command_buffer.set_label("to_cpu");
|
command_buffer.set_label("to_cpu");
|
||||||
let blit = command_buffer.new_blit_command_encoder();
|
let blit = command_buffer.new_blit_command_encoder();
|
||||||
blit.set_label("blit_to_cpu");
|
blit.set_label("blit_to_cpu");
|
||||||
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
|
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, size);
|
||||||
blit.end_encoding();
|
blit.end_encoding();
|
||||||
}
|
}
|
||||||
self.device.wait_until_completed()?;
|
self.device.wait_until_completed()?;
|
||||||
Ok(read_to_vec(&buffer, length / size))
|
Ok(read_to_vec(&buffer, self.count))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1561,7 +1583,7 @@ impl BackendDevice for MetalDevice {
|
|||||||
let buffers = Arc::new(RwLock::new(HashMap::new()));
|
let buffers = Arc::new(RwLock::new(HashMap::new()));
|
||||||
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
|
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
|
||||||
Ok(val) => val.parse()?,
|
Ok(val) => val.parse()?,
|
||||||
_ => 10,
|
_ => 50,
|
||||||
};
|
};
|
||||||
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
|
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
|
||||||
[299792458].as_ptr() as *const c_void,
|
[299792458].as_ptr() as *const c_void,
|
||||||
@ -1593,7 +1615,12 @@ impl BackendDevice for MetalDevice {
|
|||||||
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
|
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
|
||||||
let size = shape.elem_count() * dtype.size_in_bytes();
|
let size = shape.elem_count() * dtype.size_in_bytes();
|
||||||
let buffer = self.allocate_zeros(size)?;
|
let buffer = self.allocate_zeros(size)?;
|
||||||
Ok(MetalStorage::new(buffer, self.clone(), dtype))
|
Ok(MetalStorage::new(
|
||||||
|
buffer,
|
||||||
|
self.clone(),
|
||||||
|
shape.elem_count(),
|
||||||
|
dtype,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
|
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
|
||||||
@ -1603,16 +1630,21 @@ impl BackendDevice for MetalDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
|
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
|
||||||
let buffer = match storage {
|
let (count, buffer) = match storage {
|
||||||
CpuStorage::U8(storage) => self.new_buffer_with_data(storage),
|
CpuStorage::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||||
CpuStorage::U32(storage) => self.new_buffer_with_data(storage),
|
CpuStorage::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||||
CpuStorage::I64(storage) => self.new_buffer_with_data(storage),
|
CpuStorage::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||||
CpuStorage::BF16(storage) => self.new_buffer_with_data(storage),
|
CpuStorage::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||||
CpuStorage::F16(storage) => self.new_buffer_with_data(storage),
|
CpuStorage::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||||
CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
|
CpuStorage::F32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||||
CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
|
CpuStorage::F64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||||
}?;
|
};
|
||||||
Ok(Self::Storage::new(buffer, self.clone(), storage.dtype()))
|
Ok(Self::Storage::new(
|
||||||
|
buffer?,
|
||||||
|
self.clone(),
|
||||||
|
count,
|
||||||
|
storage.dtype(),
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn rand_uniform(
|
fn rand_uniform(
|
||||||
@ -1643,7 +1675,12 @@ impl BackendDevice for MetalDevice {
|
|||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
|
|
||||||
Ok(Self::Storage::new(buffer, self.clone(), dtype))
|
Ok(Self::Storage::new(
|
||||||
|
buffer,
|
||||||
|
self.clone(),
|
||||||
|
shape.elem_count(),
|
||||||
|
dtype,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn rand_normal(
|
fn rand_normal(
|
||||||
@ -1674,7 +1711,12 @@ impl BackendDevice for MetalDevice {
|
|||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
|
|
||||||
Ok(Self::Storage::new(buffer, self.clone(), dtype))
|
Ok(Self::Storage::new(
|
||||||
|
buffer,
|
||||||
|
self.clone(),
|
||||||
|
shape.elem_count(),
|
||||||
|
dtype,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn set_seed(&self, seed: u64) -> Result<()> {
|
fn set_seed(&self, seed: u64) -> Result<()> {
|
||||||
@ -1693,6 +1735,10 @@ impl BackendDevice for MetalDevice {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn buf_size(size: NSUInteger) -> NSUInteger {
|
||||||
|
(size - 1).next_power_of_two() as NSUInteger
|
||||||
|
}
|
||||||
|
|
||||||
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
|
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
|
||||||
let ptr = buffer.contents() as *const T;
|
let ptr = buffer.contents() as *const T;
|
||||||
assert!(!ptr.is_null());
|
assert!(!ptr.is_null());
|
||||||
|
@ -106,7 +106,12 @@ impl QMetalStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let buffer = self.device.new_buffer_with_data(&out)?;
|
let buffer = self.device.new_buffer_with_data(&out)?;
|
||||||
Ok(MetalStorage::new(buffer, self.device.clone(), DType::F32))
|
Ok(MetalStorage::new(
|
||||||
|
buffer,
|
||||||
|
self.device.clone(),
|
||||||
|
elem_count,
|
||||||
|
DType::F32,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
|
pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
|
||||||
@ -170,7 +175,7 @@ impl QMetalStorage {
|
|||||||
&dst,
|
&dst,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
|
let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32);
|
||||||
Ok((dst_storage, dst_shape))
|
Ok((dst_storage, dst_shape))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -238,7 +238,8 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
|||||||
&output,
|
&output,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
|
let newstorage =
|
||||||
|
candle::MetalStorage::new(output, device.clone(), elem_count, storage.dtype());
|
||||||
Ok((newstorage, layout.shape().clone()))
|
Ok((newstorage, layout.shape().clone()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user