mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Use BufferOffset in metal backend ops. (#2029)
* Use BufferOffset in the metal backend.
* More BufferOffset usage.
* Use in where-cond.
This commit is contained in:
@ -314,6 +314,7 @@ impl BackendStorage for MetalStorage {
|
||||
let dtype = if return_index { DType::U32 } else { self.dtype };
|
||||
let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let src = buffer_o(&self.buffer, layout, self.dtype);
|
||||
candle_metal_kernels::call_reduce_strided(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
@ -322,8 +323,7 @@ impl BackendStorage for MetalStorage {
|
||||
&dims,
|
||||
&stride,
|
||||
dst_el,
|
||||
&self.buffer,
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
src,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
@ -617,21 +617,21 @@ impl BackendStorage for MetalStorage {
|
||||
(DType::U8, DType::U8) => "where_u8_u8",
|
||||
(left, right) => crate::bail!("Metal where_cond {left:?} {right:?} not implemented"),
|
||||
};
|
||||
let src = buffer_o(&self.buffer, layout, self.dtype);
|
||||
let t = buffer_o(&t.buffer, t_l, t.dtype);
|
||||
let f = buffer_o(&f.buffer, f_l, f.dtype);
|
||||
candle_metal_kernels::call_where_cond_strided(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
name,
|
||||
dims,
|
||||
&self.buffer,
|
||||
(
|
||||
layout.stride(),
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
),
|
||||
&t.buffer,
|
||||
(t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
|
||||
&f.buffer,
|
||||
(f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
|
||||
src,
|
||||
layout.stride(),
|
||||
t,
|
||||
t_l.stride(),
|
||||
f,
|
||||
f_l.stride(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
@ -664,6 +664,7 @@ impl BackendStorage for MetalStorage {
|
||||
DType::F32 => "im2col1d_f32",
|
||||
dtype => crate::bail!("Metal conv1d {dtype:?} not implemented"),
|
||||
};
|
||||
let src = buffer_o(&self.buffer, layout, self.dtype);
|
||||
candle_metal_kernels::call_im2col1d_strided(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
@ -672,8 +673,7 @@ impl BackendStorage for MetalStorage {
|
||||
layout.shape().dims(),
|
||||
strides,
|
||||
(k_size, stride, padding, dilation),
|
||||
&self.buffer,
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
src,
|
||||
&dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
@ -791,6 +791,7 @@ impl BackendStorage for MetalStorage {
|
||||
DType::U32 => "im2col_u32",
|
||||
dtype => crate::bail!("Metal conv2d {dtype:?} not implemented"),
|
||||
};
|
||||
let src = buffer_o(&self.buffer, layout, self.dtype);
|
||||
candle_metal_kernels::call_im2col_strided(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
@ -799,8 +800,7 @@ impl BackendStorage for MetalStorage {
|
||||
layout.shape().dims(),
|
||||
layout.stride(),
|
||||
(h_k, w_k, stride, padding, dilation),
|
||||
&self.buffer,
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
src,
|
||||
&dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
@ -1013,6 +1013,7 @@ impl BackendStorage for MetalStorage {
|
||||
.device
|
||||
.new_buffer(dst_el, self.dtype, "upsample_nearest2d")?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let src = buffer_o(&self.buffer, inp_l, self.dtype);
|
||||
candle_metal_kernels::call_upsample_nearest_2d(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
@ -1022,8 +1023,7 @@ impl BackendStorage for MetalStorage {
|
||||
strides,
|
||||
out_w,
|
||||
out_h,
|
||||
&self.buffer,
|
||||
inp_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
src,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
@ -1031,9 +1031,8 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
|
||||
fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
let (ids_o1, _) = match ids_l.contiguous_offsets() {
|
||||
Some(o12) => o12,
|
||||
None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
|
||||
if !ids_l.is_contiguous() {
|
||||
return Err(crate::Error::RequiresContiguous { op: "gather" }.bt());
|
||||
};
|
||||
let ids_el = ids_l.dims()[dim];
|
||||
let dst_el = ids_l.shape().elem_count();
|
||||
@ -1046,6 +1045,8 @@ impl BackendStorage for MetalStorage {
|
||||
(left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"),
|
||||
};
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let src = buffer_o(&self.buffer, src_l, dtype);
|
||||
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
|
||||
candle_metal_kernels::call_gather(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
@ -1054,10 +1055,8 @@ impl BackendStorage for MetalStorage {
|
||||
src_l.dims(),
|
||||
ids_el,
|
||||
dim,
|
||||
&self.buffer,
|
||||
src_l.start_offset() * dtype.size_in_bytes(),
|
||||
&ids.buffer,
|
||||
ids_o1 * ids.dtype.size_in_bytes(),
|
||||
src,
|
||||
ids,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
@ -1075,13 +1074,8 @@ impl BackendStorage for MetalStorage {
|
||||
) -> Result<Self> {
|
||||
let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
|
||||
self.copy_strided_src(&mut acc, 0, l)?;
|
||||
let (ids_offset, _) = match ids_l.contiguous_offsets() {
|
||||
Some(o12) => o12,
|
||||
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
||||
};
|
||||
let src_offset = match src_l.contiguous_offsets() {
|
||||
Some((o1, _)) => o1,
|
||||
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
||||
if !ids_l.is_contiguous() || !src_l.is_contiguous() {
|
||||
return Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt());
|
||||
};
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
(DType::U8, DType::F32) => "sa_u8_f32",
|
||||
@ -1100,6 +1094,8 @@ impl BackendStorage for MetalStorage {
|
||||
})?,
|
||||
};
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let src = buffer_o(&src.buffer, src_l, src.dtype);
|
||||
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
|
||||
candle_metal_kernels::call_scatter_add(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
@ -1108,10 +1104,8 @@ impl BackendStorage for MetalStorage {
|
||||
src_l.dims(),
|
||||
l.dims(),
|
||||
dim,
|
||||
&src.buffer,
|
||||
src_offset * src.dtype.size_in_bytes(),
|
||||
&ids.buffer,
|
||||
ids_offset * ids.dtype.size_in_bytes(),
|
||||
src,
|
||||
ids,
|
||||
&acc.buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
@ -1147,6 +1141,8 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
};
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let src = buffer_o(&self.buffer, src_l, dtype);
|
||||
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
|
||||
candle_metal_kernels::call_index_select(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
@ -1158,10 +1154,8 @@ impl BackendStorage for MetalStorage {
|
||||
src_l.is_contiguous(),
|
||||
src_l.dims(),
|
||||
src_l.stride(),
|
||||
&self.buffer,
|
||||
src_l.start_offset() * dtype.size_in_bytes(),
|
||||
&ids.buffer,
|
||||
ids_l.start_offset() * ids.dtype.size_in_bytes(),
|
||||
src,
|
||||
ids,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
@ -1179,13 +1173,8 @@ impl BackendStorage for MetalStorage {
|
||||
) -> Result<Self> {
|
||||
let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
|
||||
self.copy_strided_src(&mut acc, 0, l)?;
|
||||
let (ids_offset, _) = match ids_l.contiguous_offsets() {
|
||||
Some(o12) => o12,
|
||||
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
|
||||
};
|
||||
let src_offset = match src_l.contiguous_offsets() {
|
||||
Some((o1, _)) => o1,
|
||||
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
|
||||
if !ids_l.is_contiguous() || !src_l.is_contiguous() {
|
||||
return Err(crate::Error::RequiresContiguous { op: "index-add" }.bt());
|
||||
};
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
(DType::I64, DType::BF16) => "ia_i64_bf16",
|
||||
@ -1216,6 +1205,8 @@ impl BackendStorage for MetalStorage {
|
||||
})?,
|
||||
};
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let src = buffer_o(&src.buffer, src_l, src.dtype);
|
||||
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
|
||||
candle_metal_kernels::call_index_add(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
@ -1225,10 +1216,8 @@ impl BackendStorage for MetalStorage {
|
||||
l.dims(),
|
||||
ids_l.dims(),
|
||||
dim,
|
||||
&src.buffer,
|
||||
src_offset * src.dtype.size_in_bytes(),
|
||||
&ids.buffer,
|
||||
ids_offset * ids.dtype.size_in_bytes(),
|
||||
src,
|
||||
ids,
|
||||
&acc.buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
|
Reference in New Issue
Block a user