Add support for strided index-select on Metal (#1909)

* initial implementation

* use correct index, but still not breaking like it should have...

* fix test
Author: Thomas Santerre
Date: 2024-03-22 02:30:02 -04:00
Committed by: GitHub
Parent: 6708870e63
Commit: fee33b45c2
4 changed files with 129 additions and 23 deletions

View File

@@ -2,9 +2,8 @@ use crate::backend::{BackendDevice, BackendStorage};
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels::CallConvTranspose2dCfg;
use candle_metal_kernels::Kernels;
use candle_metal_kernels::{self, CallConvTranspose2dCfg};
use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
use std::ffi::c_void;
@@ -1348,12 +1347,8 @@ impl BackendStorage for MetalStorage {
}
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
if !(src_l.is_contiguous()
&& src_l.start_offset() == 0
&& ids_l.is_contiguous()
&& ids_l.start_offset() == 0)
{
crate::bail!("Metal strided index_select not implemented");
if !ids_l.is_contiguous() {
crate::bail!("Metal index_select requires contiguous ids")
}
let left_size: usize = src_l.dims()[..dim].iter().product();
let right_size: usize = src_l.dims()[dim + 1..].iter().product();
@@ -1364,6 +1359,8 @@ impl BackendStorage for MetalStorage {
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
let name = match (ids.dtype, self.dtype) {
(DType::U8, DType::BF16) => "is_u8_bf16",
(DType::U8, DType::F32) => "is_u8_f32",
(DType::U8, DType::F16) => "is_u8_f16",
(DType::U32, DType::F32) => "is_u32_f32",
(DType::U32, DType::F16) => "is_u32_f16",
@@ -1382,8 +1379,13 @@ impl BackendStorage for MetalStorage {
src_l.dims(),
ids_el,
dim,
src_l.is_contiguous(),
src_l.dims(),
src_l.stride(),
&self.buffer,
src_l.start_offset() * dtype.size_in_bytes(),
&ids.buffer,
ids_l.start_offset() * ids.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
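
Note: before this change the Metal backend bailed with "Metal strided index_select not implemented" whenever the source layout was non-contiguous or started at a non-zero offset; after it, only the ids still have to be contiguous. A minimal, hypothetical sketch (not part of the commit, assuming candle's public Tensor API with the `metal` feature enabled) of the kind of call this is meant to support:

```rust
// Hypothetical illustration: index_select on a transposed, hence
// non-contiguous, source view. Before this commit such a call hit the
// "Metal strided index_select not implemented" bail on the Metal backend.
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::new_metal(0)?;
    let t = Tensor::arange(0f32, 16., &dev)?.reshape((4, 4))?;
    let strided = t.t()?; // same buffer, swapped strides: non-contiguous view
    let ids = Tensor::new(&[0u32, 2], &dev)?;
    let picked = strided.index_select(&ids, 0)?; // the ids stay contiguous
    println!("{picked}");
    Ok(())
}
```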

View File

@@ -1,20 +1,38 @@
#include <metal_stdlib>
using namespace metal;
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void index(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
constant size_t &ids_size,
const device TYPENAME *input,
constant size_t &ids_size,
constant bool &contiguous,
constant size_t *src_dims,
constant size_t *src_strides,
const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids,
device TYPENAME *output,
uint tid [[ thread_position_in_grid ]]
) {
if (tid >= dst_size) {
return;
return;
}
const size_t id_i = (tid / right_size) % ids_size;
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
@@ -26,7 +44,8 @@ METAL_FUNC void index(
// No need to check for zero we're only allowing unsized.
*/
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
output[tid] = input[src_i];
const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides);
output[tid] = input[strided_src_i];
}
# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
@@ -36,12 +55,15 @@ kernel void NAME( \
constant size_t &src_dim_size, \
constant size_t &right_size, \
constant size_t &ids_size, \
constant bool &contiguous, \
constant size_t *src_dims, \
constant size_t *src_strides, \
const device TYPENAME *input, \
const device INDEX_TYPENAME *input_ids, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
index<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
index<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, contiguous, src_dims, src_strides, input, input_ids, output, tid); \
}
@@ -165,10 +187,15 @@ kernel void NAME( \
}
INDEX_OP(is_u32_f32, uint, float)
INDEX_OP(is_u32_f16, uint, half)
INDEX_OP(is_u32_f32, uint32_t, float)
INDEX_OP(is_u32_f16, uint32_t, half)
#if defined(__HAVE_BFLOAT__)
INDEX_OP(is_u32_bf16, uint32_t, bfloat)
#endif
INDEX_OP(is_u8_f32, uint8_t, float)
INDEX_OP(is_u8_f16, uint8_t, half)
#if defined(__HAVE_BFLOAT__)
INDEX_OP(is_u8_bf16, uint8_t, bfloat)
#endif
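
Note: the new get_strided_index helper maps a linear element index to a buffer offset by walking dims and strides from the innermost dimension outwards, and the index kernel only applies it when the source is flagged as non-contiguous. A small Rust mirror of that mapping, for illustration only, checked against the shape [2, 2] / stride [2, 4] case from the index_select_strided test added in this commit:

```rust
// Rust mirror of the Metal `get_strided_index` helper (illustrative only):
// convert a linear index over the logical shape into a buffer offset.
fn get_strided_index(mut idx: usize, dims: &[usize], strides: &[usize]) -> usize {
    let mut strided_i = 0;
    for d in (0..dims.len()).rev() {
        strided_i += (idx % dims[d]) * strides[d];
        idx /= dims[d];
    }
    strided_i
}

fn main() {
    // Buffer 0..16 viewed as a [2, 2] tensor with strides [2, 4]; row 0 lives
    // at buffer offsets 0 and 4, matching the test's expected [0.0, 4.0].
    let data: Vec<f32> = (0..16).map(|x| x as f32).collect();
    let (dims, strides) = ([2usize, 2], [2usize, 4]);
    let row0: Vec<f32> = (0..dims[1])
        .map(|col| data[get_strided_index(col, &dims, &strides)])
        .collect();
    assert_eq!(row0, vec![0.0, 4.0]);
}
```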

View File

@@ -1067,8 +1067,13 @@ pub fn call_index_select(
shape: &[usize],
ids_size: usize,
dim: usize,
contiguous: bool,
src_dims: &[usize],
src_strides: &[usize],
input: &Buffer,
src_offset: usize,
ids: &Buffer,
ids_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let left_size: usize = shape[..dim].iter().product();
@@ -1090,8 +1095,11 @@ pub fn call_index_select(
src_dim_size,
right_size,
ids_size,
input,
ids,
contiguous,
src_dims,
src_strides,
(input, src_offset),
(ids, ids_offset),
output
)
);
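
Note: call_index_select now also receives the contiguity flag, the source dims and strides, and (buffer, byte offset) pairs for the input and ids; the output-size bookkeeping is unchanged. A short sketch of that bookkeeping, using hypothetical helper names (dst_elem_count and byte_offset are not candle functions), matching the left_size/right_size logic above and the src_l.start_offset() * dtype.size_in_bytes() conversion in the backend:

```rust
// Sketch of the caller-side bookkeeping around `call_index_select`, with
// hypothetical helper names: the destination holds `n_ids` slices of the
// source taken along `dim`, and start offsets are converted to bytes.
fn dst_elem_count(shape: &[usize], dim: usize, n_ids: usize) -> usize {
    let left_size: usize = shape[..dim].iter().product();
    let right_size: usize = shape[dim + 1..].iter().product();
    n_ids * left_size * right_size
}

fn byte_offset(start_offset_in_elems: usize, dtype_size_in_bytes: usize) -> usize {
    // mirrors `src_l.start_offset() * dtype.size_in_bytes()` in the backend
    start_offset_in_elems * dtype_size_in_bytes
}

fn main() {
    // shape [5, 2], dim 0, 3 ids -> 6 output elements, as in the tests below.
    assert_eq!(dst_elem_count(&[5, 2], 0, 3), 6);
    // an f32 view starting 4 elements into its buffer -> byte offset 16.
    assert_eq!(byte_offset(4, std::mem::size_of::<f32>()), 16);
}
```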

View File

@@ -600,22 +600,35 @@ fn affine_strided() {
fn index_select() {
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [5, 2];
let stride = [2, 1];
let ids = [0u32, 4, 2];
let dim = 0;
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_f32");
assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]);
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [2, 5];
let stride = [1, 2];
let ids = [0u32, 1, 0];
let dim = 0;
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_f32");
assert_eq!(
result,
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
);
}
#[test]
fn index_select_strided() {
let embedding = (0..16).map(|x| x as f32).collect::<Vec<_>>();
let shape = [2, 2];
let stride = [2, 4];
let ids = [0u32];
let dim = 0;
let result = run_index_select_strided(&embedding, &shape, &stride, &ids, dim, "is_u32_f32");
assert_eq!(result, vec![0.0, 4.0]);
}
#[test]
fn index_select_f16() {
let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
@@ -623,9 +636,10 @@ fn index_select_f16() {
.map(|x| f16::from_f32(x))
.collect();
let shape = [5, 2];
let stride = [2, 1];
let ids = [0u32, 4, 2];
let dim = 0;
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f16");
let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_f16");
assert_eq!(
approx_f16(result, 4),
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
@@ -636,9 +650,10 @@ fn index_select_f16() {
fn index_select_is_u32_bf16() {
let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
let shape = [5, 2];
let stride = [2, 1];
let ids = [0u32, 4, 2];
let dim = 0;
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_bf16");
let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_bf16");
assert_eq!(
approx_bf16(result, 4),
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
@@ -649,9 +664,10 @@ fn index_select_is_u32_bf16() {
fn index_select_is_u8_bf16() {
let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
let shape = [5, 2];
let stride = [2, 1];
let ids = [0u8, 4, 2];
let dim = 0;
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u8_bf16");
let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u8_bf16");
assert_eq!(
approx_bf16(result, 4),
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
@@ -662,9 +678,10 @@ fn index_select_is_u8_bf16() {
fn index_select_dim1() {
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [5, 2];
let stride = [2, 1];
let ids = [0u32, 1, 0];
let dim = 1;
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_f32");
assert_eq!(
result,
vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
@@ -674,6 +691,7 @@ fn index_select_dim1() {
fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
embeddings: &[T],
shape: &[usize],
stride: &[usize],
ids: &[I],
dim: usize,
name: &'static str,
@@ -699,8 +717,59 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
shape,
ids.len(),
dim,
true,
shape,
stride,
&embeddings_buffer,
0,
&ids_buffer,
0,
&dst_buffer,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec(&dst_buffer, dst_el)
}
fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>(
embeddings: &[T],
shape: &[usize],
stride: &[usize],
ids: &[I],
dim: usize,
name: &'static str,
) -> Vec<T> {
let device = Device::system_default().expect("no device found");
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let embeddings_buffer = new_buffer(&device, &embeddings);
let ids_buffer = new_buffer(&device, &ids);
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
let dst_el = ids.len() * left_size * right_size;
let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
let kernels = Kernels::new();
call_index_select(
&device,
&command_buffer,
&kernels,
name,
shape,
ids.len(),
dim,
false,
shape,
stride,
&embeddings_buffer,
0,
&ids_buffer,
0,
&dst_buffer,
)
.unwrap();