mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Use BufferOffset in metal backend ops. (#2029)
* Use BufferOffset in the metal backend. * More BufferOffset usage. * Use in where-cond.
This commit is contained in:
@ -503,8 +503,7 @@ pub fn call_reduce_contiguous(
|
||||
kernel_name: &'static str,
|
||||
length: usize,
|
||||
out_length: usize,
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
input: BufferOffset,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
@ -513,10 +512,7 @@ pub fn call_reduce_contiguous(
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
encoder,
|
||||
(length, elements_to_sum, (input, input_offset), output)
|
||||
);
|
||||
set_params!(encoder, (length, elements_to_sum, &input, output));
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
width: out_length as u64,
|
||||
@ -536,7 +532,7 @@ pub fn call_reduce_contiguous(
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
@ -552,8 +548,7 @@ pub fn call_reduce_strided(
|
||||
shape: &[usize],
|
||||
strides: &[usize],
|
||||
out_length: usize,
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
input: BufferOffset,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let length: usize = shape.iter().product();
|
||||
@ -565,14 +560,7 @@ pub fn call_reduce_strided(
|
||||
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
shape.len(),
|
||||
shape,
|
||||
strides,
|
||||
elements_to_sum,
|
||||
(input, input_offset),
|
||||
output
|
||||
)
|
||||
(shape.len(), shape, strides, elements_to_sum, &input, output)
|
||||
);
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
@ -593,7 +581,7 @@ pub fn call_reduce_strided(
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
@ -1024,12 +1012,12 @@ pub fn call_where_cond_strided(
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
shape: &[usize],
|
||||
cond: &Buffer,
|
||||
(cond_stride, cond_offset): (&[usize], usize),
|
||||
left: &Buffer,
|
||||
(left_stride, left_offset): (&[usize], usize),
|
||||
right: &Buffer,
|
||||
(right_stride, right_offset): (&[usize], usize),
|
||||
cond: BufferOffset,
|
||||
cond_stride: &[usize],
|
||||
left: BufferOffset,
|
||||
left_stride: &[usize],
|
||||
right: BufferOffset,
|
||||
right_stride: &[usize],
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
|
||||
@ -1049,18 +1037,18 @@ pub fn call_where_cond_strided(
|
||||
cond_stride,
|
||||
left_stride,
|
||||
right_stride,
|
||||
(cond, cond_offset),
|
||||
(left, left_offset),
|
||||
(right, right_offset),
|
||||
&cond,
|
||||
&left,
|
||||
&right,
|
||||
output
|
||||
)
|
||||
);
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||
|
||||
encoder.use_resource(cond, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(left, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(right, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(cond.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(left.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
@ -1079,10 +1067,8 @@ pub fn call_index_select(
|
||||
contiguous: bool,
|
||||
src_dims: &[usize],
|
||||
src_strides: &[usize],
|
||||
input: &Buffer,
|
||||
src_offset: usize,
|
||||
ids: &Buffer,
|
||||
ids_offset: usize,
|
||||
input: BufferOffset,
|
||||
ids: BufferOffset,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let left_size: usize = shape[..dim].iter().product();
|
||||
@ -1107,16 +1093,16 @@ pub fn call_index_select(
|
||||
contiguous,
|
||||
src_dims,
|
||||
src_strides,
|
||||
(input, src_offset),
|
||||
(ids, ids_offset),
|
||||
&input,
|
||||
&ids,
|
||||
output
|
||||
)
|
||||
);
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
@ -1132,10 +1118,8 @@ pub fn call_gather(
|
||||
shape: &[usize],
|
||||
ids_size: usize,
|
||||
dim: usize,
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
ids: &Buffer,
|
||||
ids_offset: usize,
|
||||
input: BufferOffset,
|
||||
ids: BufferOffset,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let left_size: usize = shape[..dim].iter().product();
|
||||
@ -1157,16 +1141,16 @@ pub fn call_gather(
|
||||
src_dim_size,
|
||||
right_size,
|
||||
ids_size,
|
||||
(input, input_offset),
|
||||
(ids, ids_offset),
|
||||
&input,
|
||||
&ids,
|
||||
output
|
||||
)
|
||||
);
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
@ -1182,10 +1166,8 @@ pub fn call_scatter_add(
|
||||
src_shape: &[usize],
|
||||
dst_shape: &[usize],
|
||||
dim: usize,
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
ids: &Buffer,
|
||||
ids_offset: usize,
|
||||
input: BufferOffset,
|
||||
ids: BufferOffset,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let left_size: usize = src_shape[..dim].iter().product();
|
||||
@ -1208,16 +1190,16 @@ pub fn call_scatter_add(
|
||||
src_dim_size,
|
||||
right_size,
|
||||
dst_dim_size,
|
||||
(input, input_offset),
|
||||
(ids, ids_offset),
|
||||
&input,
|
||||
&ids,
|
||||
output
|
||||
)
|
||||
);
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
@ -1234,10 +1216,8 @@ pub fn call_index_add(
|
||||
dst_shape: &[usize],
|
||||
ids_shape: &[usize],
|
||||
dim: usize,
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
ids: &Buffer,
|
||||
ids_offset: usize,
|
||||
input: BufferOffset,
|
||||
ids: BufferOffset,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let left_size: usize = src_shape[..dim].iter().product();
|
||||
@ -1261,16 +1241,16 @@ pub fn call_index_add(
|
||||
right_size,
|
||||
dst_dim_size,
|
||||
ids_dim_size,
|
||||
(input, input_offset),
|
||||
(ids, ids_offset),
|
||||
&input,
|
||||
&ids,
|
||||
output
|
||||
)
|
||||
);
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
@ -1536,8 +1516,7 @@ pub fn call_im2col1d_strided(
|
||||
shape: &[usize],
|
||||
strides: &[usize],
|
||||
(k_size, stride, padding, dilation): (usize, usize, usize, usize),
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
input: BufferOffset,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
|
||||
@ -1549,20 +1528,9 @@ pub fn call_im2col1d_strided(
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
dst_el,
|
||||
l_out,
|
||||
k_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
shape,
|
||||
strides,
|
||||
(input, input_offset),
|
||||
output
|
||||
)
|
||||
(dst_el, l_out, k_size, stride, padding, dilation, shape, strides, &input, output)
|
||||
);
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
@ -1579,8 +1547,7 @@ pub fn call_im2col_strided(
|
||||
shape: &[usize],
|
||||
strides: &[usize],
|
||||
(h_k, w_k, stride, padding, dilation): (usize, usize, usize, usize, usize),
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
input: BufferOffset,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
|
||||
@ -1598,21 +1565,11 @@ pub fn call_im2col_strided(
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
dst_el,
|
||||
h_out,
|
||||
w_out,
|
||||
h_k,
|
||||
w_k,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
shape,
|
||||
strides,
|
||||
(input, input_offset),
|
||||
dst_el, h_out, w_out, h_k, w_k, stride, padding, dilation, shape, strides, &input,
|
||||
output
|
||||
)
|
||||
);
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
@ -1630,8 +1587,7 @@ pub fn call_upsample_nearest_2d(
|
||||
strides: &[usize],
|
||||
out_w: usize,
|
||||
out_h: usize,
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
input: BufferOffset,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
|
||||
@ -1643,18 +1599,9 @@ pub fn call_upsample_nearest_2d(
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
out_w,
|
||||
out_h,
|
||||
scale_w,
|
||||
scale_h,
|
||||
shape,
|
||||
strides,
|
||||
(input, input_offset),
|
||||
output
|
||||
)
|
||||
(out_w, out_h, scale_w, scale_h, shape, strides, &input, output)
|
||||
);
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
|
Reference in New Issue
Block a user