mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Use BufferOffset in metal backend ops. (#2029)
* Use BufferOffset in the metal backend. * More BufferOffset usage. * Use in where-cond.
This commit is contained in:
@ -728,10 +728,8 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
|
||||
true,
|
||||
shape,
|
||||
stride,
|
||||
&embeddings_buffer,
|
||||
0,
|
||||
&ids_buffer,
|
||||
0,
|
||||
BufferOffset::zero_offset(&embeddings_buffer),
|
||||
BufferOffset::zero_offset(&ids_buffer),
|
||||
&dst_buffer,
|
||||
)
|
||||
.unwrap();
|
||||
@ -774,10 +772,8 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>(
|
||||
false,
|
||||
shape,
|
||||
stride,
|
||||
&embeddings_buffer,
|
||||
0,
|
||||
&ids_buffer,
|
||||
0,
|
||||
BufferOffset::zero_offset(&embeddings_buffer),
|
||||
BufferOffset::zero_offset(&ids_buffer),
|
||||
&dst_buffer,
|
||||
)
|
||||
.unwrap();
|
||||
@ -819,8 +815,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
|
||||
&dims,
|
||||
&strides,
|
||||
out_length,
|
||||
&input,
|
||||
0,
|
||||
BufferOffset::zero_offset(&input),
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
@ -974,18 +969,30 @@ fn run_where_cond<I: Clone, T: Clone>(
|
||||
);
|
||||
|
||||
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
|
||||
let cond = BufferOffset {
|
||||
buffer: &cond,
|
||||
offset_in_bytes: cond_offset,
|
||||
};
|
||||
let left = BufferOffset {
|
||||
buffer: &left,
|
||||
offset_in_bytes: left_offset,
|
||||
};
|
||||
let right = BufferOffset {
|
||||
buffer: &right,
|
||||
offset_in_bytes: cond_offset,
|
||||
};
|
||||
call_where_cond_strided(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
name,
|
||||
shape,
|
||||
&cond,
|
||||
(&cond_stride, cond_offset),
|
||||
&left,
|
||||
(&left_stride, left_offset),
|
||||
&right,
|
||||
(&cond_stride, cond_offset),
|
||||
cond,
|
||||
&cond_stride,
|
||||
left,
|
||||
&left_stride,
|
||||
right,
|
||||
&cond_stride,
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
@ -1250,10 +1257,8 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
|
||||
shape,
|
||||
shape,
|
||||
dim,
|
||||
&input_buffer,
|
||||
0,
|
||||
&ids_buffer,
|
||||
0,
|
||||
BufferOffset::zero_offset(&input_buffer),
|
||||
BufferOffset::zero_offset(&ids_buffer),
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
@ -1355,10 +1360,8 @@ fn run_index_add<T: Clone, I: Clone + std::fmt::Debug>(
|
||||
shape,
|
||||
shape,
|
||||
dim,
|
||||
&input_buffer,
|
||||
0,
|
||||
&indices_buffer,
|
||||
0,
|
||||
BufferOffset::zero_offset(&input_buffer),
|
||||
BufferOffset::zero_offset(&indices_buffer),
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
|
Reference in New Issue
Block a user