Reuse buffers on our own reference counts.

This commit is contained in:
Nicolas Patry
2023-11-18 23:28:59 +01:00
parent 251c65f9f1
commit eed1631ee2
2 changed files with 77 additions and 46 deletions

View File

@@ -298,7 +298,7 @@ pub fn call_unary_contiguous(
kernel_name: unary::contiguous::Kernel,
length: usize,
input: &Buffer,
-output: &mut Buffer,
+output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
@@ -320,7 +320,7 @@ pub fn call_unary_strided(
input: &Buffer,
strides: &[usize],
offset: usize,
-output: &mut Buffer,
+output: &Buffer,
output_offset: usize,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
@@ -358,7 +358,7 @@ pub fn call_binary_contiguous(
length: usize,
left: &Buffer,
right: &Buffer,
-output: &mut Buffer,
+output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
@@ -386,7 +386,7 @@ pub fn call_binary_strided(
right_input: &Buffer,
right_strides: &[usize],
right_offset: usize,
-output: &mut Buffer,
+output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
@@ -425,7 +425,7 @@ pub fn call_cast_contiguous(
kernel_name: &'static str,
length: usize,
input: &Buffer,
-output: &mut Buffer,
+output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
@@ -450,7 +450,7 @@ pub fn call_cast_strided(
input: &Buffer,
input_strides: &[usize],
input_offset: usize,
-output: &mut Buffer,
+output: &Buffer,
) -> Result<(), MetalKernelError> {
// println!("Kernel {:?}", kernel_name.0);
// assert_eq!(input.length(), output.length());
@@ -482,7 +482,7 @@ pub fn call_reduce_contiguous(
out_length: usize,
input: &Buffer,
input_offset: usize,
-output: &mut Buffer,
+output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let elements_to_sum = length / out_length;
@@ -523,7 +523,7 @@ pub fn call_last_softmax(
length: usize,
elements_to_sum: usize,
input: &Buffer,
-output: &mut Buffer,
+output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
@@ -564,7 +564,7 @@ pub fn call_affine(
name: &'static str,
size: usize,
input: &Buffer,
-output: &mut Buffer,
+output: &Buffer,
mul: f32,
add: f32,
) -> Result<(), MetalKernelError> {
@@ -590,7 +590,7 @@ pub fn call_affine_strided(
input: &Buffer,
input_stride: &[usize],
input_offset: usize,
-output: &mut Buffer,
+output: &Buffer,
mul: f32,
add: f32,
) -> Result<(), MetalKernelError> {
@@ -632,7 +632,7 @@ pub fn call_where_cond_strided(
(left_stride, left_offset): (&[usize], usize),
right: &Buffer,
(right_stride, right_offset): (&[usize], usize),
-output: &mut Buffer,
+output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
@@ -675,7 +675,7 @@ pub fn call_index_select(
dim: usize,
input: &Buffer,
ids: &Buffer,
-output: &mut Buffer,
+output: &Buffer,
) -> Result<(), MetalKernelError> {
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
@@ -750,7 +750,7 @@ mod tests {
name,
v.len(),
&input,
-&mut output,
+&output,
)
.unwrap();
command_buffer.commit();
@@ -775,7 +775,7 @@ mod tests {
x.len(),
&left,
&right,
-&mut output,
+&output,
)
.unwrap();
command_buffer.commit();
@@ -805,7 +805,7 @@ mod tests {
&input,
strides,
offset,
-&mut output,
+&output,
0,
)
.unwrap();
@@ -943,7 +943,7 @@ mod tests {
name,
v.len(),
&input,
-&mut output,
+&output,
)
.unwrap();
command_buffer.commit();
@@ -984,7 +984,7 @@ mod tests {
"affine_float",
size,
&input,
-&mut output,
+&output,
mul as f32,
add as f32,
)
@@ -1021,7 +1021,7 @@ mod tests {
&input,
strides,
0,
-&mut output,
+&output,
mul as f32,
add as f32,
)
@@ -1119,7 +1119,7 @@ mod tests {
dim,
&embeddings_buffer,
&ids_buffer,
-&mut dst_buffer,
+&dst_buffer,
)
.unwrap();
@@ -1227,7 +1227,7 @@ mod tests {
out_length,
&input,
0,
-&mut output,
+&output,
)
.unwrap();
command_buffer.commit();
@@ -1255,7 +1255,7 @@ mod tests {
v.len(),
last_dim,
&input,
-&mut output,
+&output,
)
.unwrap();
command_buffer.commit();
@@ -1355,7 +1355,7 @@ mod tests {
(&left_stride, left_offset),
&right,
(&cond_stride, cond_offset),
-&mut output,
+&output,
)
.unwrap();
command_buffer.commit();