mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Scatter add.
This commit is contained in:
@ -45,6 +45,12 @@ pub enum MetalError {
|
|||||||
},
|
},
|
||||||
#[error("{0:?}")]
|
#[error("{0:?}")]
|
||||||
LockError(LockError),
|
LockError(LockError),
|
||||||
|
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
|
||||||
|
UnexpectedDType {
|
||||||
|
msg: &'static str,
|
||||||
|
expected: DType,
|
||||||
|
got: DType,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<String> for MetalError {
|
impl From<String> for MetalError {
|
||||||
@ -827,12 +833,10 @@ impl BackendStorage for MetalStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
|
fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||||
let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() {
|
let (ids_o1, _) = match ids_l.contiguous_offsets() {
|
||||||
Some(o12) => o12,
|
Some(o12) => o12,
|
||||||
None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
|
None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
|
||||||
};
|
};
|
||||||
let left_size: usize = src_l.dims()[..dim].iter().product();
|
|
||||||
let right_size: usize = src_l.dims()[dim + 1..].iter().product();
|
|
||||||
let ids_el = ids_l.dims()[dim];
|
let ids_el = ids_l.dims()[dim];
|
||||||
let dst_el = ids_l.shape().elem_count();
|
let dst_el = ids_l.shape().elem_count();
|
||||||
let dtype = self.dtype;
|
let dtype = self.dtype;
|
||||||
@ -853,7 +857,9 @@ impl BackendStorage for MetalStorage {
|
|||||||
ids_el,
|
ids_el,
|
||||||
dim,
|
dim,
|
||||||
&self.buffer,
|
&self.buffer,
|
||||||
|
src_l.start_offset() * dtype.size_in_bytes(),
|
||||||
&ids.buffer,
|
&ids.buffer,
|
||||||
|
ids_o1 * ids.dtype.size_in_bytes(),
|
||||||
&buffer,
|
&buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
@ -862,14 +868,48 @@ impl BackendStorage for MetalStorage {
|
|||||||
|
|
||||||
fn scatter_add(
|
fn scatter_add(
|
||||||
&self,
|
&self,
|
||||||
_: &Layout,
|
l: &Layout,
|
||||||
_: &Self,
|
ids: &Self,
|
||||||
_: &Layout,
|
ids_l: &Layout,
|
||||||
_: &Self,
|
src: &Self,
|
||||||
_: &Layout,
|
src_l: &Layout,
|
||||||
_: usize,
|
dim: usize,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
crate::bail!("scatter_add metal")
|
let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
|
||||||
|
self.copy_strided_src(&mut acc, 0, l)?;
|
||||||
|
let (ids_offset, _) = match ids_l.contiguous_offsets() {
|
||||||
|
Some(o12) => o12,
|
||||||
|
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
||||||
|
};
|
||||||
|
let src_offset = match src_l.contiguous_offsets() {
|
||||||
|
Some((o1, _)) => o1,
|
||||||
|
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
||||||
|
};
|
||||||
|
let name = match (ids.dtype, self.dtype) {
|
||||||
|
(DType::U32, DType::F32) => "sa_u32_f32",
|
||||||
|
_ => Err(MetalError::UnexpectedDType {
|
||||||
|
msg: "scatter-add ids should be u8/u32/i64",
|
||||||
|
expected: DType::U32,
|
||||||
|
got: ids.dtype(),
|
||||||
|
})?,
|
||||||
|
};
|
||||||
|
let command_buffer = self.device.command_buffer()?;
|
||||||
|
candle_metal_kernels::call_scatter_add(
|
||||||
|
&self.device.device,
|
||||||
|
&command_buffer,
|
||||||
|
&self.device.kernels,
|
||||||
|
name,
|
||||||
|
src_l.dims(),
|
||||||
|
l.dims(),
|
||||||
|
dim,
|
||||||
|
&src.buffer,
|
||||||
|
src_offset * src.dtype.size_in_bytes(),
|
||||||
|
&ids.buffer,
|
||||||
|
ids_offset * ids.dtype.size_in_bytes(),
|
||||||
|
&acc.buffer,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
Ok(acc)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||||
|
@ -63,11 +63,6 @@ METAL_FUNC void gather(
|
|||||||
const INDEX_TYPENAME input_i = input_ids[tid];
|
const INDEX_TYPENAME input_i = input_ids[tid];
|
||||||
const size_t right_rank_i = tid % right_size;
|
const size_t right_rank_i = tid % right_size;
|
||||||
const size_t left_rank_i = tid / right_size / ids_size;
|
const size_t left_rank_i = tid / right_size / ids_size;
|
||||||
/*
|
|
||||||
// Force prevent out of bounds indexing
|
|
||||||
// since there doesn't seem to be a good way to force crash
|
|
||||||
// No need to check for zero we're only allowing unsized.
|
|
||||||
*/
|
|
||||||
const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
|
const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
|
||||||
output[tid] = input[src_i];
|
output[tid] = input[src_i];
|
||||||
}
|
}
|
||||||
@ -87,6 +82,45 @@ kernel void NAME( \
|
|||||||
gather<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
|
gather<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<typename TYPENAME, typename INDEX_TYPENAME>
|
||||||
|
METAL_FUNC void scatter_add(
|
||||||
|
constant size_t &dst_size,
|
||||||
|
constant size_t &left_size,
|
||||||
|
constant size_t &src_dim_size,
|
||||||
|
constant size_t &right_size,
|
||||||
|
constant size_t &dst_dim_size,
|
||||||
|
const device TYPENAME *input,
|
||||||
|
const device INDEX_TYPENAME *input_ids,
|
||||||
|
device TYPENAME *output,
|
||||||
|
uint tid [[ thread_position_in_grid ]]
|
||||||
|
) {
|
||||||
|
if (tid >= dst_size) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const size_t right_rank_i = tid % right_size;
|
||||||
|
const size_t left_rank_i = tid / right_size;
|
||||||
|
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
||||||
|
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
|
||||||
|
const INDEX_TYPENAME idx = input_ids[src_i];
|
||||||
|
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
|
||||||
|
output[dst_i] += input[src_i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# define SCATTER_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
||||||
|
kernel void NAME( \
|
||||||
|
constant size_t &dst_size, \
|
||||||
|
constant size_t &left_size, \
|
||||||
|
constant size_t &src_dim_size, \
|
||||||
|
constant size_t &right_size, \
|
||||||
|
constant size_t &dst_dim_size, \
|
||||||
|
const device TYPENAME *input, \
|
||||||
|
const device INDEX_TYPENAME *input_ids, \
|
||||||
|
device TYPENAME *output, \
|
||||||
|
uint tid [[ thread_position_in_grid ]] \
|
||||||
|
) { \
|
||||||
|
scatter_add<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, input, input_ids, output, tid); \
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
template <typename T, typename I>
|
template <typename T, typename I>
|
||||||
@ -136,6 +170,8 @@ INDEX_OP(is_u32_f32, uint, float)
|
|||||||
INDEX_OP(is_u32_f16, uint, half)
|
INDEX_OP(is_u32_f16, uint, half)
|
||||||
GATHER_OP(gather_u32_f32, uint, float)
|
GATHER_OP(gather_u32_f32, uint, float)
|
||||||
GATHER_OP(gather_u32_f16, uint, half)
|
GATHER_OP(gather_u32_f16, uint, half)
|
||||||
|
SCATTER_ADD_OP(sa_u32_f32, uint, float)
|
||||||
|
SCATTER_ADD_OP(sa_u32_f16, uint, half)
|
||||||
|
|
||||||
|
|
||||||
#if __METAL_VERSION__ >= 310
|
#if __METAL_VERSION__ >= 310
|
||||||
|
@ -1020,7 +1020,9 @@ pub fn call_gather(
|
|||||||
ids_size: usize,
|
ids_size: usize,
|
||||||
dim: usize,
|
dim: usize,
|
||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
|
input_offset: usize,
|
||||||
ids: &Buffer,
|
ids: &Buffer,
|
||||||
|
ids_offset: usize,
|
||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let left_size: usize = shape[..dim].iter().product();
|
let left_size: usize = shape[..dim].iter().product();
|
||||||
@ -1043,8 +1045,60 @@ pub fn call_gather(
|
|||||||
src_dim_size,
|
src_dim_size,
|
||||||
right_size,
|
right_size,
|
||||||
ids_size,
|
ids_size,
|
||||||
input,
|
(input, input_offset),
|
||||||
ids,
|
(ids, ids_offset),
|
||||||
|
output
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||||
|
|
||||||
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
encoder.update_fence(&kernels.fence);
|
||||||
|
encoder.end_encoding();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_scatter_add(
|
||||||
|
device: &Device,
|
||||||
|
command_buffer: &CommandBufferRef,
|
||||||
|
kernels: &Kernels,
|
||||||
|
name: &'static str,
|
||||||
|
src_shape: &[usize],
|
||||||
|
dst_shape: &[usize],
|
||||||
|
dim: usize,
|
||||||
|
input: &Buffer,
|
||||||
|
input_offset: usize,
|
||||||
|
ids: &Buffer,
|
||||||
|
ids_offset: usize,
|
||||||
|
output: &Buffer,
|
||||||
|
) -> Result<(), MetalKernelError> {
|
||||||
|
let left_size: usize = src_shape[..dim].iter().product();
|
||||||
|
let right_size: usize = src_shape[dim + 1..].iter().product();
|
||||||
|
let src_dim_size = src_shape[dim];
|
||||||
|
let dst_el = left_size * right_size;
|
||||||
|
let dst_dim_size = dst_shape[dim];
|
||||||
|
|
||||||
|
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
|
||||||
|
|
||||||
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
|
||||||
|
encoder.wait_for_fence(&kernels.fence);
|
||||||
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
|
set_params!(
|
||||||
|
encoder,
|
||||||
|
(
|
||||||
|
dst_el,
|
||||||
|
left_size,
|
||||||
|
src_dim_size,
|
||||||
|
right_size,
|
||||||
|
dst_dim_size,
|
||||||
|
(input, input_offset),
|
||||||
|
(ids, ids_offset),
|
||||||
output
|
output
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
Reference in New Issue
Block a user