Use BufferOffset in metal backend ops. (#2029)

* Use BufferOffset in the metal backend.

* More BufferOffset usage.

* Use BufferOffset in where-cond.
This commit is contained in:
Laurent Mazare
2024-04-08 09:37:25 +02:00
committed by GitHub
parent c5fe4a7f89
commit 718671a0d5
3 changed files with 117 additions and 178 deletions

View File

@ -314,6 +314,7 @@ impl BackendStorage for MetalStorage {
let dtype = if return_index { DType::U32 } else { self.dtype }; let dtype = if return_index { DType::U32 } else { self.dtype };
let buffer = device.new_buffer(dst_el, dtype, "reduce")?; let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
let command_buffer = self.device.command_buffer()?; let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, layout, self.dtype);
candle_metal_kernels::call_reduce_strided( candle_metal_kernels::call_reduce_strided(
&device.device, &device.device,
&command_buffer, &command_buffer,
@ -322,8 +323,7 @@ impl BackendStorage for MetalStorage {
&dims, &dims,
&stride, &stride,
dst_el, dst_el,
&self.buffer, src,
layout.start_offset() * self.dtype.size_in_bytes(),
&buffer, &buffer,
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
@ -617,21 +617,21 @@ impl BackendStorage for MetalStorage {
(DType::U8, DType::U8) => "where_u8_u8", (DType::U8, DType::U8) => "where_u8_u8",
(left, right) => crate::bail!("Metal where_cond {left:?} {right:?} not implemented"), (left, right) => crate::bail!("Metal where_cond {left:?} {right:?} not implemented"),
}; };
let src = buffer_o(&self.buffer, layout, self.dtype);
let t = buffer_o(&t.buffer, t_l, t.dtype);
let f = buffer_o(&f.buffer, f_l, f.dtype);
candle_metal_kernels::call_where_cond_strided( candle_metal_kernels::call_where_cond_strided(
&device.device, &device.device,
&command_buffer, &command_buffer,
&device.kernels, &device.kernels,
name, name,
dims, dims,
&self.buffer, src,
(
layout.stride(), layout.stride(),
layout.start_offset() * self.dtype.size_in_bytes(), t,
), t_l.stride(),
&t.buffer, f,
(t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()), f_l.stride(),
&f.buffer,
(f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
&buffer, &buffer,
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
@ -664,6 +664,7 @@ impl BackendStorage for MetalStorage {
DType::F32 => "im2col1d_f32", DType::F32 => "im2col1d_f32",
dtype => crate::bail!("Metal conv1d {dtype:?} not implemented"), dtype => crate::bail!("Metal conv1d {dtype:?} not implemented"),
}; };
let src = buffer_o(&self.buffer, layout, self.dtype);
candle_metal_kernels::call_im2col1d_strided( candle_metal_kernels::call_im2col1d_strided(
&self.device.device, &self.device.device,
&command_buffer, &command_buffer,
@ -672,8 +673,7 @@ impl BackendStorage for MetalStorage {
layout.shape().dims(), layout.shape().dims(),
strides, strides,
(k_size, stride, padding, dilation), (k_size, stride, padding, dilation),
&self.buffer, src,
layout.start_offset() * self.dtype.size_in_bytes(),
&dst, &dst,
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
@ -791,6 +791,7 @@ impl BackendStorage for MetalStorage {
DType::U32 => "im2col_u32", DType::U32 => "im2col_u32",
dtype => crate::bail!("Metal conv2d {dtype:?} not implemented"), dtype => crate::bail!("Metal conv2d {dtype:?} not implemented"),
}; };
let src = buffer_o(&self.buffer, layout, self.dtype);
candle_metal_kernels::call_im2col_strided( candle_metal_kernels::call_im2col_strided(
&self.device.device, &self.device.device,
&command_buffer, &command_buffer,
@ -799,8 +800,7 @@ impl BackendStorage for MetalStorage {
layout.shape().dims(), layout.shape().dims(),
layout.stride(), layout.stride(),
(h_k, w_k, stride, padding, dilation), (h_k, w_k, stride, padding, dilation),
&self.buffer, src,
layout.start_offset() * self.dtype.size_in_bytes(),
&dst, &dst,
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
@ -1013,6 +1013,7 @@ impl BackendStorage for MetalStorage {
.device .device
.new_buffer(dst_el, self.dtype, "upsample_nearest2d")?; .new_buffer(dst_el, self.dtype, "upsample_nearest2d")?;
let command_buffer = self.device.command_buffer()?; let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, inp_l, self.dtype);
candle_metal_kernels::call_upsample_nearest_2d( candle_metal_kernels::call_upsample_nearest_2d(
&self.device.device, &self.device.device,
&command_buffer, &command_buffer,
@ -1022,8 +1023,7 @@ impl BackendStorage for MetalStorage {
strides, strides,
out_w, out_w,
out_h, out_h,
&self.buffer, src,
inp_l.start_offset() * self.dtype.size_in_bytes(),
&buffer, &buffer,
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
@ -1031,9 +1031,8 @@ impl BackendStorage for MetalStorage {
} }
fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> { fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
let (ids_o1, _) = match ids_l.contiguous_offsets() { if !ids_l.is_contiguous() {
Some(o12) => o12, return Err(crate::Error::RequiresContiguous { op: "gather" }.bt());
None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
}; };
let ids_el = ids_l.dims()[dim]; let ids_el = ids_l.dims()[dim];
let dst_el = ids_l.shape().elem_count(); let dst_el = ids_l.shape().elem_count();
@ -1046,6 +1045,8 @@ impl BackendStorage for MetalStorage {
(left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"), (left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"),
}; };
let command_buffer = self.device.command_buffer()?; let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, src_l, dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_gather( candle_metal_kernels::call_gather(
&device.device, &device.device,
&command_buffer, &command_buffer,
@ -1054,10 +1055,8 @@ impl BackendStorage for MetalStorage {
src_l.dims(), src_l.dims(),
ids_el, ids_el,
dim, dim,
&self.buffer, src,
src_l.start_offset() * dtype.size_in_bytes(), ids,
&ids.buffer,
ids_o1 * ids.dtype.size_in_bytes(),
&buffer, &buffer,
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
@ -1075,13 +1074,8 @@ impl BackendStorage for MetalStorage {
) -> Result<Self> { ) -> Result<Self> {
let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?; let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
self.copy_strided_src(&mut acc, 0, l)?; self.copy_strided_src(&mut acc, 0, l)?;
let (ids_offset, _) = match ids_l.contiguous_offsets() { if !ids_l.is_contiguous() || !src_l.is_contiguous() {
Some(o12) => o12, return Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt());
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
};
let src_offset = match src_l.contiguous_offsets() {
Some((o1, _)) => o1,
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
}; };
let name = match (ids.dtype, self.dtype) { let name = match (ids.dtype, self.dtype) {
(DType::U8, DType::F32) => "sa_u8_f32", (DType::U8, DType::F32) => "sa_u8_f32",
@ -1100,6 +1094,8 @@ impl BackendStorage for MetalStorage {
})?, })?,
}; };
let command_buffer = self.device.command_buffer()?; let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&src.buffer, src_l, src.dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_scatter_add( candle_metal_kernels::call_scatter_add(
&self.device.device, &self.device.device,
&command_buffer, &command_buffer,
@ -1108,10 +1104,8 @@ impl BackendStorage for MetalStorage {
src_l.dims(), src_l.dims(),
l.dims(), l.dims(),
dim, dim,
&src.buffer, src,
src_offset * src.dtype.size_in_bytes(), ids,
&ids.buffer,
ids_offset * ids.dtype.size_in_bytes(),
&acc.buffer, &acc.buffer,
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
@ -1147,6 +1141,8 @@ impl BackendStorage for MetalStorage {
} }
}; };
let command_buffer = self.device.command_buffer()?; let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, src_l, dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_index_select( candle_metal_kernels::call_index_select(
&device.device, &device.device,
&command_buffer, &command_buffer,
@ -1158,10 +1154,8 @@ impl BackendStorage for MetalStorage {
src_l.is_contiguous(), src_l.is_contiguous(),
src_l.dims(), src_l.dims(),
src_l.stride(), src_l.stride(),
&self.buffer, src,
src_l.start_offset() * dtype.size_in_bytes(), ids,
&ids.buffer,
ids_l.start_offset() * ids.dtype.size_in_bytes(),
&buffer, &buffer,
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
@ -1179,13 +1173,8 @@ impl BackendStorage for MetalStorage {
) -> Result<Self> { ) -> Result<Self> {
let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?; let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
self.copy_strided_src(&mut acc, 0, l)?; self.copy_strided_src(&mut acc, 0, l)?;
let (ids_offset, _) = match ids_l.contiguous_offsets() { if !ids_l.is_contiguous() || !src_l.is_contiguous() {
Some(o12) => o12, return Err(crate::Error::RequiresContiguous { op: "index-add" }.bt());
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
};
let src_offset = match src_l.contiguous_offsets() {
Some((o1, _)) => o1,
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
}; };
let name = match (ids.dtype, self.dtype) { let name = match (ids.dtype, self.dtype) {
(DType::I64, DType::BF16) => "ia_i64_bf16", (DType::I64, DType::BF16) => "ia_i64_bf16",
@ -1216,6 +1205,8 @@ impl BackendStorage for MetalStorage {
})?, })?,
}; };
let command_buffer = self.device.command_buffer()?; let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&src.buffer, src_l, src.dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_index_add( candle_metal_kernels::call_index_add(
&self.device.device, &self.device.device,
&command_buffer, &command_buffer,
@ -1225,10 +1216,8 @@ impl BackendStorage for MetalStorage {
l.dims(), l.dims(),
ids_l.dims(), ids_l.dims(),
dim, dim,
&src.buffer, src,
src_offset * src.dtype.size_in_bytes(), ids,
&ids.buffer,
ids_offset * ids.dtype.size_in_bytes(),
&acc.buffer, &acc.buffer,
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;

View File

@ -503,8 +503,7 @@ pub fn call_reduce_contiguous(
kernel_name: &'static str, kernel_name: &'static str,
length: usize, length: usize,
out_length: usize, out_length: usize,
input: &Buffer, input: BufferOffset,
input_offset: usize,
output: &Buffer, output: &Buffer,
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
@ -513,10 +512,7 @@ pub fn call_reduce_contiguous(
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(encoder, (length, elements_to_sum, &input, output));
encoder,
(length, elements_to_sum, (input, input_offset), output)
);
let thread_group_count = MTLSize { let thread_group_count = MTLSize {
width: out_length as u64, width: out_length as u64,
@ -536,7 +532,7 @@ pub fn call_reduce_contiguous(
depth: 1, depth: 1,
}; };
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding(); encoder.end_encoding();
@ -552,8 +548,7 @@ pub fn call_reduce_strided(
shape: &[usize], shape: &[usize],
strides: &[usize], strides: &[usize],
out_length: usize, out_length: usize,
input: &Buffer, input: BufferOffset,
input_offset: usize,
output: &Buffer, output: &Buffer,
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let length: usize = shape.iter().product(); let length: usize = shape.iter().product();
@ -565,14 +560,7 @@ pub fn call_reduce_strided(
set_params!( set_params!(
encoder, encoder,
( (shape.len(), shape, strides, elements_to_sum, &input, output)
shape.len(),
shape,
strides,
elements_to_sum,
(input, input_offset),
output
)
); );
let thread_group_count = MTLSize { let thread_group_count = MTLSize {
@ -593,7 +581,7 @@ pub fn call_reduce_strided(
depth: 1, depth: 1,
}; };
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding(); encoder.end_encoding();
@ -1024,12 +1012,12 @@ pub fn call_where_cond_strided(
kernels: &Kernels, kernels: &Kernels,
name: &'static str, name: &'static str,
shape: &[usize], shape: &[usize],
cond: &Buffer, cond: BufferOffset,
(cond_stride, cond_offset): (&[usize], usize), cond_stride: &[usize],
left: &Buffer, left: BufferOffset,
(left_stride, left_offset): (&[usize], usize), left_stride: &[usize],
right: &Buffer, right: BufferOffset,
(right_stride, right_offset): (&[usize], usize), right_stride: &[usize],
output: &Buffer, output: &Buffer,
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?; let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
@ -1049,18 +1037,18 @@ pub fn call_where_cond_strided(
cond_stride, cond_stride,
left_stride, left_stride,
right_stride, right_stride,
(cond, cond_offset), &cond,
(left, left_offset), &left,
(right, right_offset), &right,
output output
) )
); );
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.use_resource(cond, metal::MTLResourceUsage::Read); encoder.use_resource(cond.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(left, metal::MTLResourceUsage::Read); encoder.use_resource(left.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(right, metal::MTLResourceUsage::Read); encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding(); encoder.end_encoding();
@ -1079,10 +1067,8 @@ pub fn call_index_select(
contiguous: bool, contiguous: bool,
src_dims: &[usize], src_dims: &[usize],
src_strides: &[usize], src_strides: &[usize],
input: &Buffer, input: BufferOffset,
src_offset: usize, ids: BufferOffset,
ids: &Buffer,
ids_offset: usize,
output: &Buffer, output: &Buffer,
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let left_size: usize = shape[..dim].iter().product(); let left_size: usize = shape[..dim].iter().product();
@ -1107,16 +1093,16 @@ pub fn call_index_select(
contiguous, contiguous,
src_dims, src_dims,
src_strides, src_strides,
(input, src_offset), &input,
(ids, ids_offset), &ids,
output output
) )
); );
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding(); encoder.end_encoding();
@ -1132,10 +1118,8 @@ pub fn call_gather(
shape: &[usize], shape: &[usize],
ids_size: usize, ids_size: usize,
dim: usize, dim: usize,
input: &Buffer, input: BufferOffset,
input_offset: usize, ids: BufferOffset,
ids: &Buffer,
ids_offset: usize,
output: &Buffer, output: &Buffer,
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let left_size: usize = shape[..dim].iter().product(); let left_size: usize = shape[..dim].iter().product();
@ -1157,16 +1141,16 @@ pub fn call_gather(
src_dim_size, src_dim_size,
right_size, right_size,
ids_size, ids_size,
(input, input_offset), &input,
(ids, ids_offset), &ids,
output output
) )
); );
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding(); encoder.end_encoding();
@ -1182,10 +1166,8 @@ pub fn call_scatter_add(
src_shape: &[usize], src_shape: &[usize],
dst_shape: &[usize], dst_shape: &[usize],
dim: usize, dim: usize,
input: &Buffer, input: BufferOffset,
input_offset: usize, ids: BufferOffset,
ids: &Buffer,
ids_offset: usize,
output: &Buffer, output: &Buffer,
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let left_size: usize = src_shape[..dim].iter().product(); let left_size: usize = src_shape[..dim].iter().product();
@ -1208,16 +1190,16 @@ pub fn call_scatter_add(
src_dim_size, src_dim_size,
right_size, right_size,
dst_dim_size, dst_dim_size,
(input, input_offset), &input,
(ids, ids_offset), &ids,
output output
) )
); );
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding(); encoder.end_encoding();
@ -1234,10 +1216,8 @@ pub fn call_index_add(
dst_shape: &[usize], dst_shape: &[usize],
ids_shape: &[usize], ids_shape: &[usize],
dim: usize, dim: usize,
input: &Buffer, input: BufferOffset,
input_offset: usize, ids: BufferOffset,
ids: &Buffer,
ids_offset: usize,
output: &Buffer, output: &Buffer,
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let left_size: usize = src_shape[..dim].iter().product(); let left_size: usize = src_shape[..dim].iter().product();
@ -1261,16 +1241,16 @@ pub fn call_index_add(
right_size, right_size,
dst_dim_size, dst_dim_size,
ids_dim_size, ids_dim_size,
(input, input_offset), &input,
(ids, ids_offset), &ids,
output output
) )
); );
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding(); encoder.end_encoding();
@ -1536,8 +1516,7 @@ pub fn call_im2col1d_strided(
shape: &[usize], shape: &[usize],
strides: &[usize], strides: &[usize],
(k_size, stride, padding, dilation): (usize, usize, usize, usize), (k_size, stride, padding, dilation): (usize, usize, usize, usize),
input: &Buffer, input: BufferOffset,
input_offset: usize,
output: &Buffer, output: &Buffer,
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?; let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
@ -1549,20 +1528,9 @@ pub fn call_im2col1d_strided(
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
encoder, encoder,
( (dst_el, l_out, k_size, stride, padding, dilation, shape, strides, &input, output)
dst_el,
l_out,
k_size,
stride,
padding,
dilation,
shape,
strides,
(input, input_offset),
output
)
); );
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding(); encoder.end_encoding();
@ -1579,8 +1547,7 @@ pub fn call_im2col_strided(
shape: &[usize], shape: &[usize],
strides: &[usize], strides: &[usize],
(h_k, w_k, stride, padding, dilation): (usize, usize, usize, usize, usize), (h_k, w_k, stride, padding, dilation): (usize, usize, usize, usize, usize),
input: &Buffer, input: BufferOffset,
input_offset: usize,
output: &Buffer, output: &Buffer,
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?; let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
@ -1598,21 +1565,11 @@ pub fn call_im2col_strided(
set_params!( set_params!(
encoder, encoder,
( (
dst_el, dst_el, h_out, w_out, h_k, w_k, stride, padding, dilation, shape, strides, &input,
h_out,
w_out,
h_k,
w_k,
stride,
padding,
dilation,
shape,
strides,
(input, input_offset),
output output
) )
); );
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding(); encoder.end_encoding();
@ -1630,8 +1587,7 @@ pub fn call_upsample_nearest_2d(
strides: &[usize], strides: &[usize],
out_w: usize, out_w: usize,
out_h: usize, out_h: usize,
input: &Buffer, input: BufferOffset,
input_offset: usize,
output: &Buffer, output: &Buffer,
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?; let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
@ -1643,18 +1599,9 @@ pub fn call_upsample_nearest_2d(
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
encoder, encoder,
( (out_w, out_h, scale_w, scale_h, shape, strides, &input, output)
out_w,
out_h,
scale_w,
scale_h,
shape,
strides,
(input, input_offset),
output
)
); );
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding(); encoder.end_encoding();

View File

@ -728,10 +728,8 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
true, true,
shape, shape,
stride, stride,
&embeddings_buffer, BufferOffset::zero_offset(&embeddings_buffer),
0, BufferOffset::zero_offset(&ids_buffer),
&ids_buffer,
0,
&dst_buffer, &dst_buffer,
) )
.unwrap(); .unwrap();
@ -774,10 +772,8 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>(
false, false,
shape, shape,
stride, stride,
&embeddings_buffer, BufferOffset::zero_offset(&embeddings_buffer),
0, BufferOffset::zero_offset(&ids_buffer),
&ids_buffer,
0,
&dst_buffer, &dst_buffer,
) )
.unwrap(); .unwrap();
@ -819,8 +815,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
&dims, &dims,
&strides, &strides,
out_length, out_length,
&input, BufferOffset::zero_offset(&input),
0,
&output, &output,
) )
.unwrap(); .unwrap();
@ -974,18 +969,30 @@ fn run_where_cond<I: Clone, T: Clone>(
); );
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options); let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
let cond = BufferOffset {
buffer: &cond,
offset_in_bytes: cond_offset,
};
let left = BufferOffset {
buffer: &left,
offset_in_bytes: left_offset,
};
let right = BufferOffset {
buffer: &right,
offset_in_bytes: cond_offset,
};
call_where_cond_strided( call_where_cond_strided(
&device, &device,
command_buffer, command_buffer,
&kernels, &kernels,
name, name,
shape, shape,
&cond, cond,
(&cond_stride, cond_offset), &cond_stride,
&left, left,
(&left_stride, left_offset), &left_stride,
&right, right,
(&cond_stride, cond_offset), &cond_stride,
&output, &output,
) )
.unwrap(); .unwrap();
@ -1250,10 +1257,8 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
shape, shape,
shape, shape,
dim, dim,
&input_buffer, BufferOffset::zero_offset(&input_buffer),
0, BufferOffset::zero_offset(&ids_buffer),
&ids_buffer,
0,
&output, &output,
) )
.unwrap(); .unwrap();
@ -1355,10 +1360,8 @@ fn run_index_add<T: Clone, I: Clone + std::fmt::Debug>(
shape, shape,
shape, shape,
dim, dim,
&input_buffer, BufferOffset::zero_offset(&input_buffer),
0, BufferOffset::zero_offset(&indices_buffer),
&indices_buffer,
0,
&output, &output,
) )
.unwrap(); .unwrap();