mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Use BufferOffset in metal backend ops. (#2029)
* Use BufferOffset in the metal backend. * More BufferOffset usage. * Use in where-cond.
This commit is contained in:
@ -314,6 +314,7 @@ impl BackendStorage for MetalStorage {
|
||||
let dtype = if return_index { DType::U32 } else { self.dtype };
|
||||
let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let src = buffer_o(&self.buffer, layout, self.dtype);
|
||||
candle_metal_kernels::call_reduce_strided(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
@ -322,8 +323,7 @@ impl BackendStorage for MetalStorage {
|
||||
&dims,
|
||||
&stride,
|
||||
dst_el,
|
||||
&self.buffer,
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
src,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
@ -617,21 +617,21 @@ impl BackendStorage for MetalStorage {
|
||||
(DType::U8, DType::U8) => "where_u8_u8",
|
||||
(left, right) => crate::bail!("Metal where_cond {left:?} {right:?} not implemented"),
|
||||
};
|
||||
let src = buffer_o(&self.buffer, layout, self.dtype);
|
||||
let t = buffer_o(&t.buffer, t_l, t.dtype);
|
||||
let f = buffer_o(&f.buffer, f_l, f.dtype);
|
||||
candle_metal_kernels::call_where_cond_strided(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
name,
|
||||
dims,
|
||||
&self.buffer,
|
||||
(
|
||||
src,
|
||||
layout.stride(),
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
),
|
||||
&t.buffer,
|
||||
(t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
|
||||
&f.buffer,
|
||||
(f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
|
||||
t,
|
||||
t_l.stride(),
|
||||
f,
|
||||
f_l.stride(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
@ -664,6 +664,7 @@ impl BackendStorage for MetalStorage {
|
||||
DType::F32 => "im2col1d_f32",
|
||||
dtype => crate::bail!("Metal conv1d {dtype:?} not implemented"),
|
||||
};
|
||||
let src = buffer_o(&self.buffer, layout, self.dtype);
|
||||
candle_metal_kernels::call_im2col1d_strided(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
@ -672,8 +673,7 @@ impl BackendStorage for MetalStorage {
|
||||
layout.shape().dims(),
|
||||
strides,
|
||||
(k_size, stride, padding, dilation),
|
||||
&self.buffer,
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
src,
|
||||
&dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
@ -791,6 +791,7 @@ impl BackendStorage for MetalStorage {
|
||||
DType::U32 => "im2col_u32",
|
||||
dtype => crate::bail!("Metal conv2d {dtype:?} not implemented"),
|
||||
};
|
||||
let src = buffer_o(&self.buffer, layout, self.dtype);
|
||||
candle_metal_kernels::call_im2col_strided(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
@ -799,8 +800,7 @@ impl BackendStorage for MetalStorage {
|
||||
layout.shape().dims(),
|
||||
layout.stride(),
|
||||
(h_k, w_k, stride, padding, dilation),
|
||||
&self.buffer,
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
src,
|
||||
&dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
@ -1013,6 +1013,7 @@ impl BackendStorage for MetalStorage {
|
||||
.device
|
||||
.new_buffer(dst_el, self.dtype, "upsample_nearest2d")?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let src = buffer_o(&self.buffer, inp_l, self.dtype);
|
||||
candle_metal_kernels::call_upsample_nearest_2d(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
@ -1022,8 +1023,7 @@ impl BackendStorage for MetalStorage {
|
||||
strides,
|
||||
out_w,
|
||||
out_h,
|
||||
&self.buffer,
|
||||
inp_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
src,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
@ -1031,9 +1031,8 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
|
||||
fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
let (ids_o1, _) = match ids_l.contiguous_offsets() {
|
||||
Some(o12) => o12,
|
||||
None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
|
||||
if !ids_l.is_contiguous() {
|
||||
return Err(crate::Error::RequiresContiguous { op: "gather" }.bt());
|
||||
};
|
||||
let ids_el = ids_l.dims()[dim];
|
||||
let dst_el = ids_l.shape().elem_count();
|
||||
@ -1046,6 +1045,8 @@ impl BackendStorage for MetalStorage {
|
||||
(left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"),
|
||||
};
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let src = buffer_o(&self.buffer, src_l, dtype);
|
||||
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
|
||||
candle_metal_kernels::call_gather(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
@ -1054,10 +1055,8 @@ impl BackendStorage for MetalStorage {
|
||||
src_l.dims(),
|
||||
ids_el,
|
||||
dim,
|
||||
&self.buffer,
|
||||
src_l.start_offset() * dtype.size_in_bytes(),
|
||||
&ids.buffer,
|
||||
ids_o1 * ids.dtype.size_in_bytes(),
|
||||
src,
|
||||
ids,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
@ -1075,13 +1074,8 @@ impl BackendStorage for MetalStorage {
|
||||
) -> Result<Self> {
|
||||
let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
|
||||
self.copy_strided_src(&mut acc, 0, l)?;
|
||||
let (ids_offset, _) = match ids_l.contiguous_offsets() {
|
||||
Some(o12) => o12,
|
||||
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
||||
};
|
||||
let src_offset = match src_l.contiguous_offsets() {
|
||||
Some((o1, _)) => o1,
|
||||
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
||||
if !ids_l.is_contiguous() || !src_l.is_contiguous() {
|
||||
return Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt());
|
||||
};
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
(DType::U8, DType::F32) => "sa_u8_f32",
|
||||
@ -1100,6 +1094,8 @@ impl BackendStorage for MetalStorage {
|
||||
})?,
|
||||
};
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let src = buffer_o(&src.buffer, src_l, src.dtype);
|
||||
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
|
||||
candle_metal_kernels::call_scatter_add(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
@ -1108,10 +1104,8 @@ impl BackendStorage for MetalStorage {
|
||||
src_l.dims(),
|
||||
l.dims(),
|
||||
dim,
|
||||
&src.buffer,
|
||||
src_offset * src.dtype.size_in_bytes(),
|
||||
&ids.buffer,
|
||||
ids_offset * ids.dtype.size_in_bytes(),
|
||||
src,
|
||||
ids,
|
||||
&acc.buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
@ -1147,6 +1141,8 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
};
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let src = buffer_o(&self.buffer, src_l, dtype);
|
||||
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
|
||||
candle_metal_kernels::call_index_select(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
@ -1158,10 +1154,8 @@ impl BackendStorage for MetalStorage {
|
||||
src_l.is_contiguous(),
|
||||
src_l.dims(),
|
||||
src_l.stride(),
|
||||
&self.buffer,
|
||||
src_l.start_offset() * dtype.size_in_bytes(),
|
||||
&ids.buffer,
|
||||
ids_l.start_offset() * ids.dtype.size_in_bytes(),
|
||||
src,
|
||||
ids,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
@ -1179,13 +1173,8 @@ impl BackendStorage for MetalStorage {
|
||||
) -> Result<Self> {
|
||||
let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
|
||||
self.copy_strided_src(&mut acc, 0, l)?;
|
||||
let (ids_offset, _) = match ids_l.contiguous_offsets() {
|
||||
Some(o12) => o12,
|
||||
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
|
||||
};
|
||||
let src_offset = match src_l.contiguous_offsets() {
|
||||
Some((o1, _)) => o1,
|
||||
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
|
||||
if !ids_l.is_contiguous() || !src_l.is_contiguous() {
|
||||
return Err(crate::Error::RequiresContiguous { op: "index-add" }.bt());
|
||||
};
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
(DType::I64, DType::BF16) => "ia_i64_bf16",
|
||||
@ -1216,6 +1205,8 @@ impl BackendStorage for MetalStorage {
|
||||
})?,
|
||||
};
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let src = buffer_o(&src.buffer, src_l, src.dtype);
|
||||
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
|
||||
candle_metal_kernels::call_index_add(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
@ -1225,10 +1216,8 @@ impl BackendStorage for MetalStorage {
|
||||
l.dims(),
|
||||
ids_l.dims(),
|
||||
dim,
|
||||
&src.buffer,
|
||||
src_offset * src.dtype.size_in_bytes(),
|
||||
&ids.buffer,
|
||||
ids_offset * ids.dtype.size_in_bytes(),
|
||||
src,
|
||||
ids,
|
||||
&acc.buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
|
@ -503,8 +503,7 @@ pub fn call_reduce_contiguous(
|
||||
kernel_name: &'static str,
|
||||
length: usize,
|
||||
out_length: usize,
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
input: BufferOffset,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
@ -513,10 +512,7 @@ pub fn call_reduce_contiguous(
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
encoder,
|
||||
(length, elements_to_sum, (input, input_offset), output)
|
||||
);
|
||||
set_params!(encoder, (length, elements_to_sum, &input, output));
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
width: out_length as u64,
|
||||
@ -536,7 +532,7 @@ pub fn call_reduce_contiguous(
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
@ -552,8 +548,7 @@ pub fn call_reduce_strided(
|
||||
shape: &[usize],
|
||||
strides: &[usize],
|
||||
out_length: usize,
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
input: BufferOffset,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let length: usize = shape.iter().product();
|
||||
@ -565,14 +560,7 @@ pub fn call_reduce_strided(
|
||||
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
shape.len(),
|
||||
shape,
|
||||
strides,
|
||||
elements_to_sum,
|
||||
(input, input_offset),
|
||||
output
|
||||
)
|
||||
(shape.len(), shape, strides, elements_to_sum, &input, output)
|
||||
);
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
@ -593,7 +581,7 @@ pub fn call_reduce_strided(
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
@ -1024,12 +1012,12 @@ pub fn call_where_cond_strided(
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
shape: &[usize],
|
||||
cond: &Buffer,
|
||||
(cond_stride, cond_offset): (&[usize], usize),
|
||||
left: &Buffer,
|
||||
(left_stride, left_offset): (&[usize], usize),
|
||||
right: &Buffer,
|
||||
(right_stride, right_offset): (&[usize], usize),
|
||||
cond: BufferOffset,
|
||||
cond_stride: &[usize],
|
||||
left: BufferOffset,
|
||||
left_stride: &[usize],
|
||||
right: BufferOffset,
|
||||
right_stride: &[usize],
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
|
||||
@ -1049,18 +1037,18 @@ pub fn call_where_cond_strided(
|
||||
cond_stride,
|
||||
left_stride,
|
||||
right_stride,
|
||||
(cond, cond_offset),
|
||||
(left, left_offset),
|
||||
(right, right_offset),
|
||||
&cond,
|
||||
&left,
|
||||
&right,
|
||||
output
|
||||
)
|
||||
);
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||
|
||||
encoder.use_resource(cond, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(left, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(right, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(cond.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(left.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
@ -1079,10 +1067,8 @@ pub fn call_index_select(
|
||||
contiguous: bool,
|
||||
src_dims: &[usize],
|
||||
src_strides: &[usize],
|
||||
input: &Buffer,
|
||||
src_offset: usize,
|
||||
ids: &Buffer,
|
||||
ids_offset: usize,
|
||||
input: BufferOffset,
|
||||
ids: BufferOffset,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let left_size: usize = shape[..dim].iter().product();
|
||||
@ -1107,16 +1093,16 @@ pub fn call_index_select(
|
||||
contiguous,
|
||||
src_dims,
|
||||
src_strides,
|
||||
(input, src_offset),
|
||||
(ids, ids_offset),
|
||||
&input,
|
||||
&ids,
|
||||
output
|
||||
)
|
||||
);
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
@ -1132,10 +1118,8 @@ pub fn call_gather(
|
||||
shape: &[usize],
|
||||
ids_size: usize,
|
||||
dim: usize,
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
ids: &Buffer,
|
||||
ids_offset: usize,
|
||||
input: BufferOffset,
|
||||
ids: BufferOffset,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let left_size: usize = shape[..dim].iter().product();
|
||||
@ -1157,16 +1141,16 @@ pub fn call_gather(
|
||||
src_dim_size,
|
||||
right_size,
|
||||
ids_size,
|
||||
(input, input_offset),
|
||||
(ids, ids_offset),
|
||||
&input,
|
||||
&ids,
|
||||
output
|
||||
)
|
||||
);
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
@ -1182,10 +1166,8 @@ pub fn call_scatter_add(
|
||||
src_shape: &[usize],
|
||||
dst_shape: &[usize],
|
||||
dim: usize,
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
ids: &Buffer,
|
||||
ids_offset: usize,
|
||||
input: BufferOffset,
|
||||
ids: BufferOffset,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let left_size: usize = src_shape[..dim].iter().product();
|
||||
@ -1208,16 +1190,16 @@ pub fn call_scatter_add(
|
||||
src_dim_size,
|
||||
right_size,
|
||||
dst_dim_size,
|
||||
(input, input_offset),
|
||||
(ids, ids_offset),
|
||||
&input,
|
||||
&ids,
|
||||
output
|
||||
)
|
||||
);
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
@ -1234,10 +1216,8 @@ pub fn call_index_add(
|
||||
dst_shape: &[usize],
|
||||
ids_shape: &[usize],
|
||||
dim: usize,
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
ids: &Buffer,
|
||||
ids_offset: usize,
|
||||
input: BufferOffset,
|
||||
ids: BufferOffset,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let left_size: usize = src_shape[..dim].iter().product();
|
||||
@ -1261,16 +1241,16 @@ pub fn call_index_add(
|
||||
right_size,
|
||||
dst_dim_size,
|
||||
ids_dim_size,
|
||||
(input, input_offset),
|
||||
(ids, ids_offset),
|
||||
&input,
|
||||
&ids,
|
||||
output
|
||||
)
|
||||
);
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
@ -1536,8 +1516,7 @@ pub fn call_im2col1d_strided(
|
||||
shape: &[usize],
|
||||
strides: &[usize],
|
||||
(k_size, stride, padding, dilation): (usize, usize, usize, usize),
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
input: BufferOffset,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
|
||||
@ -1549,20 +1528,9 @@ pub fn call_im2col1d_strided(
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
dst_el,
|
||||
l_out,
|
||||
k_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
shape,
|
||||
strides,
|
||||
(input, input_offset),
|
||||
output
|
||||
)
|
||||
(dst_el, l_out, k_size, stride, padding, dilation, shape, strides, &input, output)
|
||||
);
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
@ -1579,8 +1547,7 @@ pub fn call_im2col_strided(
|
||||
shape: &[usize],
|
||||
strides: &[usize],
|
||||
(h_k, w_k, stride, padding, dilation): (usize, usize, usize, usize, usize),
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
input: BufferOffset,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
|
||||
@ -1598,21 +1565,11 @@ pub fn call_im2col_strided(
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
dst_el,
|
||||
h_out,
|
||||
w_out,
|
||||
h_k,
|
||||
w_k,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
shape,
|
||||
strides,
|
||||
(input, input_offset),
|
||||
dst_el, h_out, w_out, h_k, w_k, stride, padding, dilation, shape, strides, &input,
|
||||
output
|
||||
)
|
||||
);
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
@ -1630,8 +1587,7 @@ pub fn call_upsample_nearest_2d(
|
||||
strides: &[usize],
|
||||
out_w: usize,
|
||||
out_h: usize,
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
input: BufferOffset,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
|
||||
@ -1643,18 +1599,9 @@ pub fn call_upsample_nearest_2d(
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
out_w,
|
||||
out_h,
|
||||
scale_w,
|
||||
scale_h,
|
||||
shape,
|
||||
strides,
|
||||
(input, input_offset),
|
||||
output
|
||||
)
|
||||
(out_w, out_h, scale_w, scale_h, shape, strides, &input, output)
|
||||
);
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
|
@ -728,10 +728,8 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
|
||||
true,
|
||||
shape,
|
||||
stride,
|
||||
&embeddings_buffer,
|
||||
0,
|
||||
&ids_buffer,
|
||||
0,
|
||||
BufferOffset::zero_offset(&embeddings_buffer),
|
||||
BufferOffset::zero_offset(&ids_buffer),
|
||||
&dst_buffer,
|
||||
)
|
||||
.unwrap();
|
||||
@ -774,10 +772,8 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>(
|
||||
false,
|
||||
shape,
|
||||
stride,
|
||||
&embeddings_buffer,
|
||||
0,
|
||||
&ids_buffer,
|
||||
0,
|
||||
BufferOffset::zero_offset(&embeddings_buffer),
|
||||
BufferOffset::zero_offset(&ids_buffer),
|
||||
&dst_buffer,
|
||||
)
|
||||
.unwrap();
|
||||
@ -819,8 +815,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
|
||||
&dims,
|
||||
&strides,
|
||||
out_length,
|
||||
&input,
|
||||
0,
|
||||
BufferOffset::zero_offset(&input),
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
@ -974,18 +969,30 @@ fn run_where_cond<I: Clone, T: Clone>(
|
||||
);
|
||||
|
||||
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
|
||||
let cond = BufferOffset {
|
||||
buffer: &cond,
|
||||
offset_in_bytes: cond_offset,
|
||||
};
|
||||
let left = BufferOffset {
|
||||
buffer: &left,
|
||||
offset_in_bytes: left_offset,
|
||||
};
|
||||
let right = BufferOffset {
|
||||
buffer: &right,
|
||||
offset_in_bytes: cond_offset,
|
||||
};
|
||||
call_where_cond_strided(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
name,
|
||||
shape,
|
||||
&cond,
|
||||
(&cond_stride, cond_offset),
|
||||
&left,
|
||||
(&left_stride, left_offset),
|
||||
&right,
|
||||
(&cond_stride, cond_offset),
|
||||
cond,
|
||||
&cond_stride,
|
||||
left,
|
||||
&left_stride,
|
||||
right,
|
||||
&cond_stride,
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
@ -1250,10 +1257,8 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
|
||||
shape,
|
||||
shape,
|
||||
dim,
|
||||
&input_buffer,
|
||||
0,
|
||||
&ids_buffer,
|
||||
0,
|
||||
BufferOffset::zero_offset(&input_buffer),
|
||||
BufferOffset::zero_offset(&ids_buffer),
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
@ -1355,10 +1360,8 @@ fn run_index_add<T: Clone, I: Clone + std::fmt::Debug>(
|
||||
shape,
|
||||
shape,
|
||||
dim,
|
||||
&input_buffer,
|
||||
0,
|
||||
&indices_buffer,
|
||||
0,
|
||||
BufferOffset::zero_offset(&input_buffer),
|
||||
BufferOffset::zero_offset(&indices_buffer),
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
|
Reference in New Issue
Block a user