mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Add support for strided index-select on Metal (#1909)
* initial implementation * use correct index, but still not breaking like it should have... * fix test
This commit is contained in:
@ -2,9 +2,8 @@ use crate::backend::{BackendDevice, BackendStorage};
|
|||||||
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
|
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
|
||||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||||
|
use candle_metal_kernels::CallConvTranspose2dCfg;
|
||||||
use candle_metal_kernels::Kernels;
|
use candle_metal_kernels::Kernels;
|
||||||
use candle_metal_kernels::{self, CallConvTranspose2dCfg};
|
|
||||||
use metal;
|
|
||||||
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::ffi::c_void;
|
use std::ffi::c_void;
|
||||||
@ -1348,12 +1347,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||||
if !(src_l.is_contiguous()
|
if !ids_l.is_contiguous() {
|
||||||
&& src_l.start_offset() == 0
|
crate::bail!("Metal index_select requires contiguous ids")
|
||||||
&& ids_l.is_contiguous()
|
|
||||||
&& ids_l.start_offset() == 0)
|
|
||||||
{
|
|
||||||
crate::bail!("Metal strided index_select not implemented");
|
|
||||||
}
|
}
|
||||||
let left_size: usize = src_l.dims()[..dim].iter().product();
|
let left_size: usize = src_l.dims()[..dim].iter().product();
|
||||||
let right_size: usize = src_l.dims()[dim + 1..].iter().product();
|
let right_size: usize = src_l.dims()[dim + 1..].iter().product();
|
||||||
@ -1364,6 +1359,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
|
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
|
||||||
let name = match (ids.dtype, self.dtype) {
|
let name = match (ids.dtype, self.dtype) {
|
||||||
(DType::U8, DType::BF16) => "is_u8_bf16",
|
(DType::U8, DType::BF16) => "is_u8_bf16",
|
||||||
|
(DType::U8, DType::F32) => "is_u8_f32",
|
||||||
|
(DType::U8, DType::F16) => "is_u8_f16",
|
||||||
|
|
||||||
(DType::U32, DType::F32) => "is_u32_f32",
|
(DType::U32, DType::F32) => "is_u32_f32",
|
||||||
(DType::U32, DType::F16) => "is_u32_f16",
|
(DType::U32, DType::F16) => "is_u32_f16",
|
||||||
@ -1382,8 +1379,13 @@ impl BackendStorage for MetalStorage {
|
|||||||
src_l.dims(),
|
src_l.dims(),
|
||||||
ids_el,
|
ids_el,
|
||||||
dim,
|
dim,
|
||||||
|
src_l.is_contiguous(),
|
||||||
|
src_l.dims(),
|
||||||
|
src_l.stride(),
|
||||||
&self.buffer,
|
&self.buffer,
|
||||||
|
src_l.start_offset() * dtype.size_in_bytes(),
|
||||||
&ids.buffer,
|
&ids.buffer,
|
||||||
|
ids_l.start_offset() * ids.dtype.size_in_bytes(),
|
||||||
&buffer,
|
&buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
|
@ -1,6 +1,21 @@
|
|||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
|
METAL_FUNC uint get_strided_index(
|
||||||
|
uint idx,
|
||||||
|
constant size_t &num_dims,
|
||||||
|
constant size_t *dims,
|
||||||
|
constant size_t *strides
|
||||||
|
) {
|
||||||
|
uint strided_i = 0;
|
||||||
|
for (uint d = 0; d < num_dims; d++) {
|
||||||
|
uint dim_idx = num_dims - 1 - d;
|
||||||
|
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||||
|
idx /= dims[dim_idx];
|
||||||
|
}
|
||||||
|
return strided_i;
|
||||||
|
}
|
||||||
|
|
||||||
template<typename TYPENAME, typename INDEX_TYPENAME>
|
template<typename TYPENAME, typename INDEX_TYPENAME>
|
||||||
METAL_FUNC void index(
|
METAL_FUNC void index(
|
||||||
constant size_t &dst_size,
|
constant size_t &dst_size,
|
||||||
@ -8,6 +23,9 @@ METAL_FUNC void index(
|
|||||||
constant size_t &src_dim_size,
|
constant size_t &src_dim_size,
|
||||||
constant size_t &right_size,
|
constant size_t &right_size,
|
||||||
constant size_t &ids_size,
|
constant size_t &ids_size,
|
||||||
|
constant bool &contiguous,
|
||||||
|
constant size_t *src_dims,
|
||||||
|
constant size_t *src_strides,
|
||||||
const device TYPENAME *input,
|
const device TYPENAME *input,
|
||||||
const device INDEX_TYPENAME *input_ids,
|
const device INDEX_TYPENAME *input_ids,
|
||||||
device TYPENAME *output,
|
device TYPENAME *output,
|
||||||
@ -26,7 +44,8 @@ METAL_FUNC void index(
|
|||||||
// No need to check for zero we're only allowing unsized.
|
// No need to check for zero we're only allowing unsized.
|
||||||
*/
|
*/
|
||||||
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
|
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
|
||||||
output[tid] = input[src_i];
|
const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides);
|
||||||
|
output[tid] = input[strided_src_i];
|
||||||
}
|
}
|
||||||
|
|
||||||
# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
||||||
@ -36,12 +55,15 @@ kernel void NAME( \
|
|||||||
constant size_t &src_dim_size, \
|
constant size_t &src_dim_size, \
|
||||||
constant size_t &right_size, \
|
constant size_t &right_size, \
|
||||||
constant size_t &ids_size, \
|
constant size_t &ids_size, \
|
||||||
|
constant bool &contiguous, \
|
||||||
|
constant size_t *src_dims, \
|
||||||
|
constant size_t *src_strides, \
|
||||||
const device TYPENAME *input, \
|
const device TYPENAME *input, \
|
||||||
const device INDEX_TYPENAME *input_ids, \
|
const device INDEX_TYPENAME *input_ids, \
|
||||||
device TYPENAME *output, \
|
device TYPENAME *output, \
|
||||||
uint tid [[ thread_position_in_grid ]] \
|
uint tid [[ thread_position_in_grid ]] \
|
||||||
) { \
|
) { \
|
||||||
index<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
|
index<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, contiguous, src_dims, src_strides, input, input_ids, output, tid); \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -165,10 +187,15 @@ kernel void NAME( \
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
INDEX_OP(is_u32_f32, uint, float)
|
INDEX_OP(is_u32_f32, uint32_t, float)
|
||||||
INDEX_OP(is_u32_f16, uint, half)
|
INDEX_OP(is_u32_f16, uint32_t, half)
|
||||||
#if defined(__HAVE_BFLOAT__)
|
#if defined(__HAVE_BFLOAT__)
|
||||||
INDEX_OP(is_u32_bf16, uint32_t, bfloat)
|
INDEX_OP(is_u32_bf16, uint32_t, bfloat)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
INDEX_OP(is_u8_f32, uint8_t, float)
|
||||||
|
INDEX_OP(is_u8_f16, uint8_t, half)
|
||||||
|
#if defined(__HAVE_BFLOAT__)
|
||||||
INDEX_OP(is_u8_bf16, uint8_t, bfloat)
|
INDEX_OP(is_u8_bf16, uint8_t, bfloat)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -1067,8 +1067,13 @@ pub fn call_index_select(
|
|||||||
shape: &[usize],
|
shape: &[usize],
|
||||||
ids_size: usize,
|
ids_size: usize,
|
||||||
dim: usize,
|
dim: usize,
|
||||||
|
contiguous: bool,
|
||||||
|
src_dims: &[usize],
|
||||||
|
src_strides: &[usize],
|
||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
|
src_offset: usize,
|
||||||
ids: &Buffer,
|
ids: &Buffer,
|
||||||
|
ids_offset: usize,
|
||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let left_size: usize = shape[..dim].iter().product();
|
let left_size: usize = shape[..dim].iter().product();
|
||||||
@ -1090,8 +1095,11 @@ pub fn call_index_select(
|
|||||||
src_dim_size,
|
src_dim_size,
|
||||||
right_size,
|
right_size,
|
||||||
ids_size,
|
ids_size,
|
||||||
input,
|
contiguous,
|
||||||
ids,
|
src_dims,
|
||||||
|
src_strides,
|
||||||
|
(input, src_offset),
|
||||||
|
(ids, ids_offset),
|
||||||
output
|
output
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
@ -600,22 +600,35 @@ fn affine_strided() {
|
|||||||
fn index_select() {
|
fn index_select() {
|
||||||
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
||||||
let shape = [5, 2];
|
let shape = [5, 2];
|
||||||
|
let stride = [2, 1];
|
||||||
let ids = [0u32, 4, 2];
|
let ids = [0u32, 4, 2];
|
||||||
let dim = 0;
|
let dim = 0;
|
||||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
|
let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_f32");
|
||||||
assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]);
|
assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]);
|
||||||
|
|
||||||
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
||||||
let shape = [2, 5];
|
let shape = [2, 5];
|
||||||
|
let stride = [1, 2];
|
||||||
let ids = [0u32, 1, 0];
|
let ids = [0u32, 1, 0];
|
||||||
let dim = 0;
|
let dim = 0;
|
||||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
|
let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_f32");
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
result,
|
result,
|
||||||
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
|
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn index_select_strided() {
|
||||||
|
let embedding = (0..16).map(|x| x as f32).collect::<Vec<_>>();
|
||||||
|
let shape = [2, 2];
|
||||||
|
let stride = [2, 4];
|
||||||
|
let ids = [0u32];
|
||||||
|
let dim = 0;
|
||||||
|
let result = run_index_select_strided(&embedding, &shape, &stride, &ids, dim, "is_u32_f32");
|
||||||
|
assert_eq!(result, vec![0.0, 4.0]);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn index_select_f16() {
|
fn index_select_f16() {
|
||||||
let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
|
let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
|
||||||
@ -623,9 +636,10 @@ fn index_select_f16() {
|
|||||||
.map(|x| f16::from_f32(x))
|
.map(|x| f16::from_f32(x))
|
||||||
.collect();
|
.collect();
|
||||||
let shape = [5, 2];
|
let shape = [5, 2];
|
||||||
|
let stride = [2, 1];
|
||||||
let ids = [0u32, 4, 2];
|
let ids = [0u32, 4, 2];
|
||||||
let dim = 0;
|
let dim = 0;
|
||||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f16");
|
let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_f16");
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
approx_f16(result, 4),
|
approx_f16(result, 4),
|
||||||
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
|
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
|
||||||
@ -636,9 +650,10 @@ fn index_select_f16() {
|
|||||||
fn index_select_is_u32_bf16() {
|
fn index_select_is_u32_bf16() {
|
||||||
let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
|
let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
|
||||||
let shape = [5, 2];
|
let shape = [5, 2];
|
||||||
|
let stride = [2, 1];
|
||||||
let ids = [0u32, 4, 2];
|
let ids = [0u32, 4, 2];
|
||||||
let dim = 0;
|
let dim = 0;
|
||||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_bf16");
|
let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_bf16");
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
approx_bf16(result, 4),
|
approx_bf16(result, 4),
|
||||||
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
|
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
|
||||||
@ -649,9 +664,10 @@ fn index_select_is_u32_bf16() {
|
|||||||
fn index_select_is_u8_bf16() {
|
fn index_select_is_u8_bf16() {
|
||||||
let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
|
let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
|
||||||
let shape = [5, 2];
|
let shape = [5, 2];
|
||||||
|
let stride = [2, 1];
|
||||||
let ids = [0u8, 4, 2];
|
let ids = [0u8, 4, 2];
|
||||||
let dim = 0;
|
let dim = 0;
|
||||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u8_bf16");
|
let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u8_bf16");
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
approx_bf16(result, 4),
|
approx_bf16(result, 4),
|
||||||
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
|
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
|
||||||
@ -662,9 +678,10 @@ fn index_select_is_u8_bf16() {
|
|||||||
fn index_select_dim1() {
|
fn index_select_dim1() {
|
||||||
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
||||||
let shape = [5, 2];
|
let shape = [5, 2];
|
||||||
|
let stride = [2, 1];
|
||||||
let ids = [0u32, 1, 0];
|
let ids = [0u32, 1, 0];
|
||||||
let dim = 1;
|
let dim = 1;
|
||||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
|
let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_f32");
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
result,
|
result,
|
||||||
vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
|
vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
|
||||||
@ -674,6 +691,7 @@ fn index_select_dim1() {
|
|||||||
fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
|
fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
|
||||||
embeddings: &[T],
|
embeddings: &[T],
|
||||||
shape: &[usize],
|
shape: &[usize],
|
||||||
|
stride: &[usize],
|
||||||
ids: &[I],
|
ids: &[I],
|
||||||
dim: usize,
|
dim: usize,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
@ -699,8 +717,59 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
|
|||||||
shape,
|
shape,
|
||||||
ids.len(),
|
ids.len(),
|
||||||
dim,
|
dim,
|
||||||
|
true,
|
||||||
|
shape,
|
||||||
|
stride,
|
||||||
&embeddings_buffer,
|
&embeddings_buffer,
|
||||||
|
0,
|
||||||
&ids_buffer,
|
&ids_buffer,
|
||||||
|
0,
|
||||||
|
&dst_buffer,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
|
||||||
|
read_to_vec(&dst_buffer, dst_el)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>(
|
||||||
|
embeddings: &[T],
|
||||||
|
shape: &[usize],
|
||||||
|
stride: &[usize],
|
||||||
|
ids: &[I],
|
||||||
|
dim: usize,
|
||||||
|
name: &'static str,
|
||||||
|
) -> Vec<T> {
|
||||||
|
let device = Device::system_default().expect("no device found");
|
||||||
|
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
|
let embeddings_buffer = new_buffer(&device, &embeddings);
|
||||||
|
let ids_buffer = new_buffer(&device, &ids);
|
||||||
|
|
||||||
|
let left_size: usize = shape[..dim].iter().product();
|
||||||
|
let right_size: usize = shape[dim + 1..].iter().product();
|
||||||
|
let dst_el = ids.len() * left_size * right_size;
|
||||||
|
let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
|
||||||
|
|
||||||
|
let kernels = Kernels::new();
|
||||||
|
call_index_select(
|
||||||
|
&device,
|
||||||
|
&command_buffer,
|
||||||
|
&kernels,
|
||||||
|
name,
|
||||||
|
shape,
|
||||||
|
ids.len(),
|
||||||
|
dim,
|
||||||
|
false,
|
||||||
|
shape,
|
||||||
|
stride,
|
||||||
|
&embeddings_buffer,
|
||||||
|
0,
|
||||||
|
&ids_buffer,
|
||||||
|
0,
|
||||||
&dst_buffer,
|
&dst_buffer,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
Reference in New Issue
Block a user