Use BufferOffset in metal backend ops. (#2029)

* Use BufferOffset in the metal backend.

* More BufferOffset usage.

* Use in where-cond.
This commit is contained in:
Laurent Mazare
2024-04-08 09:37:25 +02:00
committed by GitHub
parent c5fe4a7f89
commit 718671a0d5
3 changed files with 117 additions and 178 deletions

View File

@ -314,6 +314,7 @@ impl BackendStorage for MetalStorage {
let dtype = if return_index { DType::U32 } else { self.dtype };
let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, layout, self.dtype);
candle_metal_kernels::call_reduce_strided(
&device.device,
&command_buffer,
@ -322,8 +323,7 @@ impl BackendStorage for MetalStorage {
&dims,
&stride,
dst_el,
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
src,
&buffer,
)
.map_err(MetalError::from)?;
@ -617,21 +617,21 @@ impl BackendStorage for MetalStorage {
(DType::U8, DType::U8) => "where_u8_u8",
(left, right) => crate::bail!("Metal where_cond {left:?} {right:?} not implemented"),
};
let src = buffer_o(&self.buffer, layout, self.dtype);
let t = buffer_o(&t.buffer, t_l, t.dtype);
let f = buffer_o(&f.buffer, f_l, f.dtype);
candle_metal_kernels::call_where_cond_strided(
&device.device,
&command_buffer,
&device.kernels,
name,
dims,
&self.buffer,
(
layout.stride(),
layout.start_offset() * self.dtype.size_in_bytes(),
),
&t.buffer,
(t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
&f.buffer,
(f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
src,
layout.stride(),
t,
t_l.stride(),
f,
f_l.stride(),
&buffer,
)
.map_err(MetalError::from)?;
@ -664,6 +664,7 @@ impl BackendStorage for MetalStorage {
DType::F32 => "im2col1d_f32",
dtype => crate::bail!("Metal conv1d {dtype:?} not implemented"),
};
let src = buffer_o(&self.buffer, layout, self.dtype);
candle_metal_kernels::call_im2col1d_strided(
&self.device.device,
&command_buffer,
@ -672,8 +673,7 @@ impl BackendStorage for MetalStorage {
layout.shape().dims(),
strides,
(k_size, stride, padding, dilation),
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
src,
&dst,
)
.map_err(MetalError::from)?;
@ -791,6 +791,7 @@ impl BackendStorage for MetalStorage {
DType::U32 => "im2col_u32",
dtype => crate::bail!("Metal conv2d {dtype:?} not implemented"),
};
let src = buffer_o(&self.buffer, layout, self.dtype);
candle_metal_kernels::call_im2col_strided(
&self.device.device,
&command_buffer,
@ -799,8 +800,7 @@ impl BackendStorage for MetalStorage {
layout.shape().dims(),
layout.stride(),
(h_k, w_k, stride, padding, dilation),
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
src,
&dst,
)
.map_err(MetalError::from)?;
@ -1013,6 +1013,7 @@ impl BackendStorage for MetalStorage {
.device
.new_buffer(dst_el, self.dtype, "upsample_nearest2d")?;
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, inp_l, self.dtype);
candle_metal_kernels::call_upsample_nearest_2d(
&self.device.device,
&command_buffer,
@ -1022,8 +1023,7 @@ impl BackendStorage for MetalStorage {
strides,
out_w,
out_h,
&self.buffer,
inp_l.start_offset() * self.dtype.size_in_bytes(),
src,
&buffer,
)
.map_err(MetalError::from)?;
@ -1031,9 +1031,8 @@ impl BackendStorage for MetalStorage {
}
fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
let (ids_o1, _) = match ids_l.contiguous_offsets() {
Some(o12) => o12,
None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
if !ids_l.is_contiguous() {
return Err(crate::Error::RequiresContiguous { op: "gather" }.bt());
};
let ids_el = ids_l.dims()[dim];
let dst_el = ids_l.shape().elem_count();
@ -1046,6 +1045,8 @@ impl BackendStorage for MetalStorage {
(left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"),
};
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, src_l, dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_gather(
&device.device,
&command_buffer,
@ -1054,10 +1055,8 @@ impl BackendStorage for MetalStorage {
src_l.dims(),
ids_el,
dim,
&self.buffer,
src_l.start_offset() * dtype.size_in_bytes(),
&ids.buffer,
ids_o1 * ids.dtype.size_in_bytes(),
src,
ids,
&buffer,
)
.map_err(MetalError::from)?;
@ -1075,13 +1074,8 @@ impl BackendStorage for MetalStorage {
) -> Result<Self> {
let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
self.copy_strided_src(&mut acc, 0, l)?;
let (ids_offset, _) = match ids_l.contiguous_offsets() {
Some(o12) => o12,
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
};
let src_offset = match src_l.contiguous_offsets() {
Some((o1, _)) => o1,
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
if !ids_l.is_contiguous() || !src_l.is_contiguous() {
return Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt());
};
let name = match (ids.dtype, self.dtype) {
(DType::U8, DType::F32) => "sa_u8_f32",
@ -1100,6 +1094,8 @@ impl BackendStorage for MetalStorage {
})?,
};
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&src.buffer, src_l, src.dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_scatter_add(
&self.device.device,
&command_buffer,
@ -1108,10 +1104,8 @@ impl BackendStorage for MetalStorage {
src_l.dims(),
l.dims(),
dim,
&src.buffer,
src_offset * src.dtype.size_in_bytes(),
&ids.buffer,
ids_offset * ids.dtype.size_in_bytes(),
src,
ids,
&acc.buffer,
)
.map_err(MetalError::from)?;
@ -1147,6 +1141,8 @@ impl BackendStorage for MetalStorage {
}
};
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, src_l, dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_index_select(
&device.device,
&command_buffer,
@ -1158,10 +1154,8 @@ impl BackendStorage for MetalStorage {
src_l.is_contiguous(),
src_l.dims(),
src_l.stride(),
&self.buffer,
src_l.start_offset() * dtype.size_in_bytes(),
&ids.buffer,
ids_l.start_offset() * ids.dtype.size_in_bytes(),
src,
ids,
&buffer,
)
.map_err(MetalError::from)?;
@ -1179,13 +1173,8 @@ impl BackendStorage for MetalStorage {
) -> Result<Self> {
let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
self.copy_strided_src(&mut acc, 0, l)?;
let (ids_offset, _) = match ids_l.contiguous_offsets() {
Some(o12) => o12,
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
};
let src_offset = match src_l.contiguous_offsets() {
Some((o1, _)) => o1,
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
if !ids_l.is_contiguous() || !src_l.is_contiguous() {
return Err(crate::Error::RequiresContiguous { op: "index-add" }.bt());
};
let name = match (ids.dtype, self.dtype) {
(DType::I64, DType::BF16) => "ia_i64_bf16",
@ -1216,6 +1205,8 @@ impl BackendStorage for MetalStorage {
})?,
};
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&src.buffer, src_l, src.dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_index_add(
&self.device.device,
&command_buffer,
@ -1225,10 +1216,8 @@ impl BackendStorage for MetalStorage {
l.dims(),
ids_l.dims(),
dim,
&src.buffer,
src_offset * src.dtype.size_in_bytes(),
&ids.buffer,
ids_offset * ids.dtype.size_in_bytes(),
src,
ids,
&acc.buffer,
)
.map_err(MetalError::from)?;