Use BufferOffset in metal backend ops. (#2029)

* Use BufferOffset in the metal backend.

* More BufferOffset usage.

* Use in where-cond.
This commit is contained in:
Laurent Mazare
2024-04-08 09:37:25 +02:00
committed by GitHub
parent c5fe4a7f89
commit 718671a0d5
3 changed files with 117 additions and 178 deletions

View File

@ -503,8 +503,7 @@ pub fn call_reduce_contiguous(
kernel_name: &'static str,
length: usize,
out_length: usize,
input: &Buffer,
input_offset: usize,
input: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
@ -513,10 +512,7 @@ pub fn call_reduce_contiguous(
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(length, elements_to_sum, (input, input_offset), output)
);
set_params!(encoder, (length, elements_to_sum, &input, output));
let thread_group_count = MTLSize {
width: out_length as u64,
@ -536,7 +532,7 @@ pub fn call_reduce_contiguous(
depth: 1,
};
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -552,8 +548,7 @@ pub fn call_reduce_strided(
shape: &[usize],
strides: &[usize],
out_length: usize,
input: &Buffer,
input_offset: usize,
input: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let length: usize = shape.iter().product();
@ -565,14 +560,7 @@ pub fn call_reduce_strided(
set_params!(
encoder,
(
shape.len(),
shape,
strides,
elements_to_sum,
(input, input_offset),
output
)
(shape.len(), shape, strides, elements_to_sum, &input, output)
);
let thread_group_count = MTLSize {
@ -593,7 +581,7 @@ pub fn call_reduce_strided(
depth: 1,
};
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -1024,12 +1012,12 @@ pub fn call_where_cond_strided(
kernels: &Kernels,
name: &'static str,
shape: &[usize],
cond: &Buffer,
(cond_stride, cond_offset): (&[usize], usize),
left: &Buffer,
(left_stride, left_offset): (&[usize], usize),
right: &Buffer,
(right_stride, right_offset): (&[usize], usize),
cond: BufferOffset,
cond_stride: &[usize],
left: BufferOffset,
left_stride: &[usize],
right: BufferOffset,
right_stride: &[usize],
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
@ -1049,18 +1037,18 @@ pub fn call_where_cond_strided(
cond_stride,
left_stride,
right_stride,
(cond, cond_offset),
(left, left_offset),
(right, right_offset),
&cond,
&left,
&right,
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.use_resource(cond, metal::MTLResourceUsage::Read);
encoder.use_resource(left, metal::MTLResourceUsage::Read);
encoder.use_resource(right, metal::MTLResourceUsage::Read);
encoder.use_resource(cond.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(left.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -1079,10 +1067,8 @@ pub fn call_index_select(
contiguous: bool,
src_dims: &[usize],
src_strides: &[usize],
input: &Buffer,
src_offset: usize,
ids: &Buffer,
ids_offset: usize,
input: BufferOffset,
ids: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let left_size: usize = shape[..dim].iter().product();
@ -1107,16 +1093,16 @@ pub fn call_index_select(
contiguous,
src_dims,
src_strides,
(input, src_offset),
(ids, ids_offset),
&input,
&ids,
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -1132,10 +1118,8 @@ pub fn call_gather(
shape: &[usize],
ids_size: usize,
dim: usize,
input: &Buffer,
input_offset: usize,
ids: &Buffer,
ids_offset: usize,
input: BufferOffset,
ids: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let left_size: usize = shape[..dim].iter().product();
@ -1157,16 +1141,16 @@ pub fn call_gather(
src_dim_size,
right_size,
ids_size,
(input, input_offset),
(ids, ids_offset),
&input,
&ids,
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -1182,10 +1166,8 @@ pub fn call_scatter_add(
src_shape: &[usize],
dst_shape: &[usize],
dim: usize,
input: &Buffer,
input_offset: usize,
ids: &Buffer,
ids_offset: usize,
input: BufferOffset,
ids: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let left_size: usize = src_shape[..dim].iter().product();
@ -1208,16 +1190,16 @@ pub fn call_scatter_add(
src_dim_size,
right_size,
dst_dim_size,
(input, input_offset),
(ids, ids_offset),
&input,
&ids,
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -1234,10 +1216,8 @@ pub fn call_index_add(
dst_shape: &[usize],
ids_shape: &[usize],
dim: usize,
input: &Buffer,
input_offset: usize,
ids: &Buffer,
ids_offset: usize,
input: BufferOffset,
ids: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let left_size: usize = src_shape[..dim].iter().product();
@ -1261,16 +1241,16 @@ pub fn call_index_add(
right_size,
dst_dim_size,
ids_dim_size,
(input, input_offset),
(ids, ids_offset),
&input,
&ids,
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -1536,8 +1516,7 @@ pub fn call_im2col1d_strided(
shape: &[usize],
strides: &[usize],
(k_size, stride, padding, dilation): (usize, usize, usize, usize),
input: &Buffer,
input_offset: usize,
input: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
@ -1549,20 +1528,9 @@ pub fn call_im2col1d_strided(
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
dst_el,
l_out,
k_size,
stride,
padding,
dilation,
shape,
strides,
(input, input_offset),
output
)
(dst_el, l_out, k_size, stride, padding, dilation, shape, strides, &input, output)
);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -1579,8 +1547,7 @@ pub fn call_im2col_strided(
shape: &[usize],
strides: &[usize],
(h_k, w_k, stride, padding, dilation): (usize, usize, usize, usize, usize),
input: &Buffer,
input_offset: usize,
input: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
@ -1598,21 +1565,11 @@ pub fn call_im2col_strided(
set_params!(
encoder,
(
dst_el,
h_out,
w_out,
h_k,
w_k,
stride,
padding,
dilation,
shape,
strides,
(input, input_offset),
dst_el, h_out, w_out, h_k, w_k, stride, padding, dilation, shape, strides, &input,
output
)
);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -1630,8 +1587,7 @@ pub fn call_upsample_nearest_2d(
strides: &[usize],
out_w: usize,
out_h: usize,
input: &Buffer,
input_offset: usize,
input: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
@ -1643,18 +1599,9 @@ pub fn call_upsample_nearest_2d(
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
out_w,
out_h,
scale_w,
scale_h,
shape,
strides,
(input, input_offset),
output
)
(out_w, out_h, scale_w, scale_h, shape, strides, &input, output)
);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();