Reuse buffers on our own reference counts.

Author: Nicolas Patry
Date: 2023-11-18 23:28:59 +01:00
parent 251c65f9f1
commit eed1631ee2
2 changed files with 77 additions and 46 deletions
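
The scheme: `MetalDevice` keeps a per-size pool of `Arc<Buffer>`, and a buffer counts as free whenever the pool holds the only remaining reference to it (`Arc::strong_count == 1`), i.e. no `MetalStorage` is still using it. A minimal standalone sketch of that pattern; the `BufferPool` type and the `main` driver are illustrative stand-ins, not code from this commit:

    use std::collections::HashMap;
    use std::sync::{Arc, RwLock};

    use metal::{Buffer, Device, MTLResourceOptions, NSUInteger};

    /// Illustrative pool: buffers are grouped by byte size and handed out as
    /// `Arc<Buffer>`. A buffer may be reused once the pool holds the only
    /// remaining reference to it.
    struct BufferPool {
        device: Device,
        buffers: RwLock<HashMap<usize, Vec<Arc<Buffer>>>>,
    }

    impl BufferPool {
        fn new(device: Device) -> Self {
            Self { device, buffers: RwLock::new(HashMap::new()) }
        }

        fn new_buffer(&self, size: usize) -> Arc<Buffer> {
            let mut buffers = self.buffers.write().unwrap();
            let subbuffers = buffers.entry(size).or_insert_with(Vec::new);
            // Reuse any same-sized buffer that no storage still references.
            for sub in subbuffers.iter() {
                if Arc::strong_count(sub) == 1 {
                    return sub.clone();
                }
            }
            // Otherwise allocate a new GPU-private buffer and remember it.
            let buffer = Arc::new(self.device.new_buffer(
                size as NSUInteger,
                MTLResourceOptions::StorageModePrivate,
            ));
            subbuffers.push(buffer.clone());
            buffer
        }
    }

    fn main() {
        let device = Device::system_default().expect("no Metal device");
        let pool = BufferPool::new(device);
        let a = pool.new_buffer(1024); // fresh allocation
        let b = pool.new_buffer(1024); // `a` is still alive, so a second buffer is created
        drop(a);                       // the first buffer is now held only by the pool
        let c = pool.new_buffer(1024); // hands the first buffer back out
        drop((b, c));
    }

Note that nothing is ever freed here (nor in the hunks below): buffers are only recycled between tensors of the same byte size for the lifetime of the device.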

Changed file 1 of 2:

@@ -8,6 +8,7 @@ use half::f16;
 use metal;
 use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
 use std::sync::{Arc, RwLock};
+use std::collections::HashMap;
 
 /// Metal related errors
 #[derive(thiserror::Error, Debug)]
@@ -37,6 +38,7 @@ pub struct MetalDevice {
     command_queue: metal::CommandQueue,
     command_buffer: Arc<RwLock<metal::CommandBuffer>>,
     kernels: Arc<candle_metal_kernels::Kernels>,
+    buffers: Arc<RwLock<HashMap<usize, Vec<Arc<Buffer>>>>>,
 }
 
 impl std::fmt::Debug for MetalDevice {
@@ -87,8 +89,26 @@ impl MetalDevice {
         &self.device
     }
 
-    pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
-        let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
+    pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Arc<Buffer> {
+        let size = element_count * dtype.size_in_bytes();
+        let mut buffers = self.buffers.try_write().unwrap();
+        let subbuffers = buffers.entry(size).or_insert(vec![]);
+        for sub in &mut *subbuffers {
+            // if sub.retain_count() == 1 {
+            // println!("{size} {:?}", );
+            if Arc::strong_count(sub) == 1 {
+                return sub.clone();
+            }
+        }
+        let new_buffer = self.device
+            .new_buffer(size as NSUInteger, MTLResourceOptions::StorageModePrivate);
+        let new_buffer = Arc::new(new_buffer);
+        subbuffers.push(new_buffer.clone());
+        new_buffer
+    }
+
+    pub fn new_buffer_managed(&self, size: NSUInteger) -> Buffer {
         self.device
             .new_buffer(size, MTLResourceOptions::StorageModeManaged)
     }
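
The reuse test relies only on `Arc`'s reference counting, nothing Metal-specific; a tiny self-contained illustration of why `strong_count == 1` means "only the pool still knows about this buffer":

    use std::sync::Arc;

    fn main() {
        let pooled = Arc::new([0u8; 16]); // stand-in for a pooled Metal buffer
        assert_eq!(Arc::strong_count(&pooled), 1); // only the pool holds it: reusable

        let handed_out = Arc::clone(&pooled); // a storage takes a handle
        assert_eq!(Arc::strong_count(&pooled), 2); // in use: must not be handed out again

        drop(handed_out); // the storage goes away
        assert_eq!(Arc::strong_count(&pooled), 1); // reusable again
    }

The commented-out `retain_count()` line hints at the alternative the commit title alludes to: asking Metal for its Objective-C retain count instead of tracking our own `Arc` counts.
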
@@ -105,7 +125,7 @@ impl MetalDevice {
 
 #[derive(Debug, Clone)]
 pub struct MetalStorage {
-    buffer: metal::Buffer,
+    buffer: Arc<metal::Buffer>,
     device: MetalDevice,
     dtype: DType,
 }
@@ -126,29 +146,38 @@ impl BackendStorage for MetalStorage {
     }
 
     fn to_cpu_storage(&self) -> Result<CpuStorage> {
+        let buffer = self.device.new_buffer_managed(self.buffer.length());
+        {
+            let command = self.device.command_buffer();
+            let blit = command.new_blit_command_encoder();
+            blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
+            blit.end_encoding();
+        }
         self.device.wait_until_completed();
         match self.dtype {
             DType::U8 => Ok(CpuStorage::U8(
-                self.buffer.read_to_vec(self.buffer.length() as usize),
+                buffer.read_to_vec(buffer.length() as usize),
             )),
             DType::U32 => Ok(CpuStorage::U32(
-                self.buffer.read_to_vec(self.buffer.length() as usize / 4),
+                buffer.read_to_vec(buffer.length() as usize / 4),
             )),
             DType::I64 => Ok(CpuStorage::I64(
-                self.buffer.read_to_vec(self.buffer.length() as usize / 8),
+                buffer.read_to_vec(buffer.length() as usize / 8),
             )),
             DType::F16 => Ok(CpuStorage::F16(
-                self.buffer.read_to_vec(self.buffer.length() as usize / 2),
+                buffer.read_to_vec(buffer.length() as usize / 2),
            )),
             DType::BF16 => Ok(CpuStorage::BF16(
-                self.buffer.read_to_vec(self.buffer.length() as usize / 2),
+                buffer.read_to_vec(buffer.length() as usize / 2),
            )),
             DType::F32 => Ok(CpuStorage::F32(
-                self.buffer.read_to_vec(self.buffer.length() as usize / 4),
+                buffer.read_to_vec(buffer.length() as usize / 4),
            )),
             DType::F64 => Ok(CpuStorage::F64(
-                self.buffer.read_to_vec(self.buffer.length() as usize / 8),
+                buffer.read_to_vec(buffer.length() as usize / 8),
            )),
         }
     }
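
Because pooled buffers now live in `MTLResourceOptions::StorageModePrivate` memory, the CPU can no longer read them directly; the hunk above first blits into a managed staging buffer (allocated by the new `new_buffer_managed`) and reads that instead. A hedged sketch of the same staging pattern in isolation; `read_private_buffer_f32` is an illustrative helper, and candle routes the blit through its shared command buffer rather than creating one per read:

    use metal::{Buffer, CommandQueue, Device, MTLResourceOptions};

    /// Copy a GPU-private buffer into a CPU-visible managed buffer, wait for the
    /// blit to finish, then read the bytes back as f32 values.
    fn read_private_buffer_f32(device: &Device, queue: &CommandQueue, src: &Buffer) -> Vec<f32> {
        let staging = device.new_buffer(src.length(), MTLResourceOptions::StorageModeManaged);

        let command_buffer = queue.new_command_buffer();
        let blit = command_buffer.new_blit_command_encoder();
        blit.copy_from_buffer(src, 0, &staging, 0, src.length());
        blit.end_encoding();
        command_buffer.commit();
        command_buffer.wait_until_completed();

        // On unified-memory Macs the managed buffer is host-visible once the blit
        // completes; a discrete GPU may additionally need a blit
        // `synchronize_resource` before the CPU read.
        let count = staging.length() as usize / std::mem::size_of::<f32>();
        unsafe { std::slice::from_raw_parts(staging.contents() as *const f32, count).to_vec() }
    }
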
@@ -175,7 +204,7 @@ impl BackendStorage for MetalStorage {
                 name,
                 el,
                 &self.buffer,
-                &mut buffer,
+                &buffer,
                 mul as f32,
                 add as f32,
             )
@@ -195,7 +224,7 @@ impl BackendStorage for MetalStorage {
                 &self.buffer,
                 layout.stride(),
                 layout.start_offset() * dtype.size_in_bytes(),
-                &mut buffer,
+                &buffer,
                 mul as f32,
                 add as f32,
             )
@@ -270,7 +299,7 @@ impl BackendStorage for MetalStorage {
             dst_el,
             &self.buffer,
             layout.start_offset() * self.dtype.size_in_bytes(),
-            &mut buffer,
+            &buffer,
         )
         .map_err(MetalError::from)?;
@@ -305,7 +334,7 @@ impl BackendStorage for MetalStorage {
                 kernel_name,
                 el_count,
                 &self.buffer,
-                &mut buffer,
+                &buffer,
             )
             .map_err(MetalError::from)?;
         } else {
@@ -324,7 +353,7 @@ impl BackendStorage for MetalStorage {
                 &self.buffer,
                 layout.stride(),
                 layout.start_offset() * self.dtype.size_in_bytes(),
-                &mut buffer,
+                &buffer,
             )
             .map_err(MetalError::from)?;
         }
@@ -382,7 +411,7 @@ impl BackendStorage for MetalStorage {
                 kernel_name,
                 el_count,
                 &self.buffer,
-                &mut buffer,
+                &buffer,
             )
             .map_err(MetalError::from)?;
         } else {
@@ -425,7 +454,7 @@ impl BackendStorage for MetalStorage {
                 &self.buffer,
                 layout.stride(),
                 layout.start_offset() * self.dtype.size_in_bytes(),
-                &mut buffer,
+                &buffer,
                 0,
             )
             .map_err(MetalError::from)?;
@@ -481,7 +510,7 @@ impl BackendStorage for MetalStorage {
                 el_count,
                 &self.buffer,
                 &rhs.buffer,
-                &mut buffer,
+                &buffer,
             )
             .map_err(MetalError::from)?;
         } else {
@@ -510,7 +539,7 @@ impl BackendStorage for MetalStorage {
                 &rhs.buffer,
                 rhs_l.stride(),
                 rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
-                &mut buffer,
+                &buffer,
             )
             .map_err(MetalError::from)?;
         }
@@ -551,7 +580,7 @@ impl BackendStorage for MetalStorage {
             (&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
             &f.buffer,
             (&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
-            &mut buffer,
+            &buffer,
         )
         .map_err(MetalError::from)?;
         Ok(Self {
@@ -661,7 +690,7 @@ impl BackendStorage for MetalStorage {
             dim,
             &self.buffer,
             &ids.buffer,
-            &mut buffer,
+            &buffer,
         )
         .map_err(MetalError::from)?;
         Ok(Self {
@@ -860,7 +889,7 @@ impl BackendStorage for MetalStorage {
             &self.buffer,
             src_l.stride(),
             src_l.start_offset() * self.dtype.size_in_bytes(),
-            &mut dst.buffer,
+            &dst.buffer,
             dst_offset * dst.dtype.size_in_bytes(),
         )
         .map_err(MetalError::from)?;
@@ -869,7 +898,7 @@
 }
 
 impl MetalStorage {
-    pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self {
+    pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: DType) -> Self {
         Self {
             buffer,
             device,
@@ -904,10 +933,12 @@ impl BackendDevice for MetalDevice {
         let command_queue = device.new_command_queue();
         let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned()));
         let kernels = Arc::new(Kernels::new());
+        let buffers = Arc::new(RwLock::new(HashMap::new()));
         Ok(Self {
             device,
             command_queue,
             command_buffer,
+            buffers,
             kernels,
         })
     }
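
Because the map sits behind `Arc<RwLock<...>>`, every clone of `MetalDevice` (and candle stores a device clone inside each `MetalStorage`) shares one pool rather than getting its own copy. A pure-Rust illustration of that sharing, with `Vec<u8>` standing in for Metal buffers:

    use std::collections::HashMap;
    use std::sync::{Arc, RwLock};

    fn main() {
        // Stand-in for the `buffers` field: byte size -> pooled allocations.
        let pool: Arc<RwLock<HashMap<usize, Vec<Arc<Vec<u8>>>>>> =
            Arc::new(RwLock::new(HashMap::new()));

        // Two clones of the "device" point at the same underlying map.
        let device_a = Arc::clone(&pool);
        let device_b = Arc::clone(&pool);

        device_a
            .write()
            .unwrap()
            .entry(1024)
            .or_default()
            .push(Arc::new(vec![0u8; 1024]));

        // An entry added through one clone is visible through the other.
        assert_eq!(device_b.read().unwrap()[&1024].len(), 1);
    }
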
@@ -952,7 +983,7 @@ impl BackendDevice for MetalDevice {
             CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
         };
         Ok(Self::Storage {
-            buffer,
+            buffer: buffer.into(),
             device: self.clone(),
             dtype: storage.dtype(),
         })

Changed file 2 of 2:

@@ -298,7 +298,7 @@ pub fn call_unary_contiguous(
     kernel_name: unary::contiguous::Kernel,
     length: usize,
     input: &Buffer,
-    output: &mut Buffer,
+    output: &Buffer,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
     let encoder = command_buffer.new_compute_command_encoder();
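
Every kernel entry point below gets the same signature change: `output: &mut Buffer` becomes `output: &Buffer`. A shared reference is enough because binding a buffer to a compute encoder takes `Option<&BufferRef>` and the writes happen on the GPU, outside Rust's aliasing rules; with outputs now shared through `Arc<Buffer>`, an exclusive `&mut` could not be produced anyway. A hedged sketch; `bind_io` is an illustrative helper, not part of the crate:

    use metal::{BufferRef, ComputeCommandEncoderRef};

    /// Bind an input and an output buffer to a compute encoder. Both take shared
    /// references: Metal mutates the output on the GPU, not through Rust.
    fn bind_io(encoder: &ComputeCommandEncoderRef, input: &BufferRef, output: &BufferRef) {
        encoder.set_buffer(0, Some(input), 0);
        encoder.set_buffer(1, Some(output), 0);
    }

At a call site an `&Arc<Buffer>` coerces to `&Buffer` (and on to `&BufferRef`) through deref, which is why the callers in the first file can keep passing `&buffer`.
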
@@ -320,7 +320,7 @@ pub fn call_unary_strided(
     input: &Buffer,
     strides: &[usize],
     offset: usize,
-    output: &mut Buffer,
+    output: &Buffer,
     output_offset: usize,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
@@ -358,7 +358,7 @@ pub fn call_binary_contiguous(
     length: usize,
     left: &Buffer,
     right: &Buffer,
-    output: &mut Buffer,
+    output: &Buffer,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
@@ -386,7 +386,7 @@ pub fn call_binary_strided(
     right_input: &Buffer,
     right_strides: &[usize],
     right_offset: usize,
-    output: &mut Buffer,
+    output: &Buffer,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
@@ -425,7 +425,7 @@ pub fn call_cast_contiguous(
     kernel_name: &'static str,
     length: usize,
     input: &Buffer,
-    output: &mut Buffer,
+    output: &Buffer,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
@@ -450,7 +450,7 @@ pub fn call_cast_strided(
     input: &Buffer,
     input_strides: &[usize],
     input_offset: usize,
-    output: &mut Buffer,
+    output: &Buffer,
 ) -> Result<(), MetalKernelError> {
     // println!("Kernel {:?}", kernel_name.0);
     // assert_eq!(input.length(), output.length());
@@ -482,7 +482,7 @@ pub fn call_reduce_contiguous(
     out_length: usize,
     input: &Buffer,
     input_offset: usize,
-    output: &mut Buffer,
+    output: &Buffer,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
     let elements_to_sum = length / out_length;
@@ -523,7 +523,7 @@ pub fn call_last_softmax(
     length: usize,
     elements_to_sum: usize,
     input: &Buffer,
-    output: &mut Buffer,
+    output: &Buffer,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
     let encoder = command_buffer.new_compute_command_encoder();
@@ -564,7 +564,7 @@ pub fn call_affine(
     name: &'static str,
     size: usize,
     input: &Buffer,
-    output: &mut Buffer,
+    output: &Buffer,
     mul: f32,
     add: f32,
 ) -> Result<(), MetalKernelError> {
@@ -590,7 +590,7 @@ pub fn call_affine_strided(
     input: &Buffer,
     input_stride: &[usize],
     input_offset: usize,
-    output: &mut Buffer,
+    output: &Buffer,
     mul: f32,
     add: f32,
 ) -> Result<(), MetalKernelError> {
@@ -632,7 +632,7 @@ pub fn call_where_cond_strided(
     (left_stride, left_offset): (&[usize], usize),
     right: &Buffer,
     (right_stride, right_offset): (&[usize], usize),
-    output: &mut Buffer,
+    output: &Buffer,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
@@ -675,7 +675,7 @@ pub fn call_index_select(
     dim: usize,
     input: &Buffer,
     ids: &Buffer,
-    output: &mut Buffer,
+    output: &Buffer,
 ) -> Result<(), MetalKernelError> {
     let left_size: usize = shape[..dim].iter().product();
     let right_size: usize = shape[dim + 1..].iter().product();
@@ -750,7 +750,7 @@ mod tests {
         name,
         v.len(),
         &input,
-        &mut output,
+        &output,
     )
     .unwrap();
     command_buffer.commit();
@@ -775,7 +775,7 @@ mod tests {
         x.len(),
         &left,
         &right,
-        &mut output,
+        &output,
     )
     .unwrap();
     command_buffer.commit();
@@ -805,7 +805,7 @@ mod tests {
         &input,
         strides,
         offset,
-        &mut output,
+        &output,
         0,
     )
     .unwrap();
@@ -943,7 +943,7 @@ mod tests {
         name,
         v.len(),
         &input,
-        &mut output,
+        &output,
    )
    .unwrap();
    command_buffer.commit();
@@ -984,7 +984,7 @@ mod tests {
         "affine_float",
         size,
         &input,
-        &mut output,
+        &output,
         mul as f32,
         add as f32,
     )
@@ -1021,7 +1021,7 @@ mod tests {
         &input,
         strides,
         0,
-        &mut output,
+        &output,
         mul as f32,
         add as f32,
     )
@@ -1119,7 +1119,7 @@ mod tests {
         dim,
         &embeddings_buffer,
         &ids_buffer,
-        &mut dst_buffer,
+        &dst_buffer,
     )
     .unwrap();
@@ -1227,7 +1227,7 @@ mod tests {
         out_length,
         &input,
         0,
-        &mut output,
+        &output,
     )
     .unwrap();
     command_buffer.commit();
@@ -1255,7 +1255,7 @@ mod tests {
         v.len(),
         last_dim,
         &input,
-        &mut output,
+        &output,
     )
     .unwrap();
     command_buffer.commit();
@@ -1355,7 +1355,7 @@ mod tests {
         (&left_stride, left_offset),
         &right,
         (&cond_stride, cond_offset),
-        &mut output,
+        &output,
     )
     .unwrap();
     command_buffer.commit();