Use BufferOffset in metal backend ops. (#2029)

* Use BufferOffset in the metal backend.

* More BufferOffset usage.

* Use in where-cond.
This commit is contained in:
Laurent Mazare
2024-04-08 09:37:25 +02:00
committed by GitHub
parent c5fe4a7f89
commit 718671a0d5
3 changed files with 117 additions and 178 deletions

View File

@@ -728,10 +728,8 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
true,
shape,
stride,
&embeddings_buffer,
0,
&ids_buffer,
0,
BufferOffset::zero_offset(&embeddings_buffer),
BufferOffset::zero_offset(&ids_buffer),
&dst_buffer,
)
.unwrap();
@@ -774,10 +772,8 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>(
false,
shape,
stride,
&embeddings_buffer,
0,
&ids_buffer,
0,
BufferOffset::zero_offset(&embeddings_buffer),
BufferOffset::zero_offset(&ids_buffer),
&dst_buffer,
)
.unwrap();
@@ -819,8 +815,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
&dims,
&strides,
out_length,
&input,
0,
BufferOffset::zero_offset(&input),
&output,
)
.unwrap();
@@ -974,18 +969,30 @@ fn run_where_cond<I: Clone, T: Clone>(
);
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
let cond = BufferOffset {
buffer: &cond,
offset_in_bytes: cond_offset,
};
let left = BufferOffset {
buffer: &left,
offset_in_bytes: left_offset,
};
let right = BufferOffset {
buffer: &right,
offset_in_bytes: cond_offset,
};
call_where_cond_strided(
&device,
command_buffer,
&kernels,
name,
shape,
&cond,
(&cond_stride, cond_offset),
&left,
(&left_stride, left_offset),
&right,
(&cond_stride, cond_offset),
cond,
&cond_stride,
left,
&left_stride,
right,
&cond_stride,
&output,
)
.unwrap();
@@ -1250,10 +1257,8 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
shape,
shape,
dim,
&input_buffer,
0,
&ids_buffer,
0,
BufferOffset::zero_offset(&input_buffer),
BufferOffset::zero_offset(&ids_buffer),
&output,
)
.unwrap();
@@ -1355,10 +1360,8 @@ fn run_index_add<T: Clone, I: Clone + std::fmt::Debug>(
shape,
shape,
dim,
&input_buffer,
0,
&indices_buffer,
0,
BufferOffset::zero_offset(&input_buffer),
BufferOffset::zero_offset(&indices_buffer),
&output,
)
.unwrap();