mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Metal: f16 and bf16 where_cond + benchmark (#1545)
* Use cfg to seperate benchmark results based on features * Add metal where_cond for f16 and bf16. Add benchmark * Remove allow pragma * Avoid some unnecessary returns. * Improve benchmarks layout * Updated feature separated benchmarks --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
mod benchmarks;
|
||||
|
||||
use criterion::criterion_main;
|
||||
criterion_main!(benchmarks::matmul::benches);
|
||||
criterion_main!(benchmarks::matmul::benches, benchmarks::where_cond::benches);
|
||||
|
@ -1,4 +1,5 @@
|
||||
pub(crate) mod matmul;
|
||||
pub(crate) mod where_cond;
|
||||
|
||||
use candle_core::{Device, Result};
|
||||
|
||||
|
64
candle-core/benches/benchmarks/where_cond.rs
Normal file
64
candle-core/benches/benchmarks/where_cond.rs
Normal file
@ -0,0 +1,64 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
fn run(a: &Tensor, b: &Tensor, c: &Tensor) {
|
||||
a.where_cond(b, c).unwrap();
|
||||
}
|
||||
|
||||
const fn create_cond_arr<const N: usize>() -> [u8; N] {
|
||||
let mut arr = [0u8; N];
|
||||
let mut i = 0;
|
||||
while i < N {
|
||||
arr[i] = (i % 2) as u8;
|
||||
i += 1;
|
||||
}
|
||||
arr
|
||||
}
|
||||
|
||||
const B: usize = 1;
|
||||
const M: usize = 1024;
|
||||
const K: usize = 1024;
|
||||
const SIZE: usize = B * M * K;
|
||||
|
||||
const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
|
||||
|
||||
fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
|
||||
let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), &device).unwrap();
|
||||
let on_true = Tensor::ones((B, M, K), dtype, &device).unwrap();
|
||||
let on_false = Tensor::zeros((B, M, K), dtype, &device).unwrap();
|
||||
|
||||
let elements = B * M * K;
|
||||
// E.g. 2 f32 tensors + 1 u8 tensor
|
||||
let flops = (2 * elements * dtype.size_in_bytes()) + elements;
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
run(
|
||||
black_box(&tensor),
|
||||
black_box(&on_true),
|
||||
black_box(&on_false),
|
||||
);
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let device = BenchDeviceHandler::new().unwrap();
|
||||
for d in device.devices {
|
||||
run_where_cond_benchmark(c, &d, DType::F32, "where_cond_f32");
|
||||
run_where_cond_benchmark(c, &d, DType::BF16, "where_cond_bf16");
|
||||
run_where_cond_benchmark(c, &d, DType::F16, "where_cond_f16");
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
@ -806,6 +806,7 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
let name = match (self.dtype, t.dtype()) {
|
||||
(DType::U8, DType::F32) => "where_u8_f32",
|
||||
(DType::U8, DType::BF16) => "where_u8_bf16",
|
||||
(DType::U8, DType::F16) => "where_u8_f16",
|
||||
(DType::U8, DType::I64) => "where_u8_i64",
|
||||
(DType::U8, DType::U32) => "where_u8_u32",
|
||||
|
@ -1 +0,0 @@
|
||||
pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/layernorm_kernels.ptx"));
|
||||
|
@ -17,29 +17,45 @@ METAL_FUNC uint get_strided_index(
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
template<typename T, typename ID>
|
||||
METAL_FUNC void where_cond(
|
||||
constant size_t &numel,
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides,
|
||||
constant size_t *strides_t,
|
||||
constant size_t *strides_f,
|
||||
device const ID *ids,
|
||||
device const T *t,
|
||||
device const T *f,
|
||||
device T *out,
|
||||
uint i [[ thread_position_in_grid ]]
|
||||
) {
|
||||
if (i >= numel){
|
||||
return;
|
||||
}
|
||||
uint strided_i = get_strided_index(i, num_dims, dims, strides);
|
||||
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t);
|
||||
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f);
|
||||
out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f];
|
||||
}
|
||||
|
||||
#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &numel, \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t *strides_t, \
|
||||
constant size_t *strides_f, \
|
||||
device const ID_TYPENAME *ids, \
|
||||
device const TYPENAME *t, \
|
||||
device const TYPENAME *f, \
|
||||
device TYPENAME *out ,\
|
||||
uint i [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (i >= numel){ \
|
||||
return; \
|
||||
} \
|
||||
uint strided_i = get_strided_index(i, num_dims, dims, strides); \
|
||||
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
|
||||
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
|
||||
out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; \
|
||||
} \
|
||||
#define WHERE_OP(T, ID, FN_NAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &numel, \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t *strides_t, \
|
||||
constant size_t *strides_f, \
|
||||
device const ID *ids, \
|
||||
device const T *t, \
|
||||
device const T *f, \
|
||||
device T *out, \
|
||||
uint i [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
where_cond<T, ID>(numel, num_dims, dims, strides, strides_t, strides_f, ids, t, f, out, i); \
|
||||
} \
|
||||
|
||||
// WHERE_OP(float, int64_t, where_i64_f32)
|
||||
// WHERE_OP(double, int64_t, where_i64_f64)
|
||||
@ -54,10 +70,14 @@ kernel void FN_NAME( \
|
||||
// WHERE_OP(int64_t, uint32_t, where_u32_i64)
|
||||
|
||||
WHERE_OP(float, uint8_t, where_u8_f32)
|
||||
// WHERE_OP(double, uint8_t, where_u8_f64)
|
||||
WHERE_OP(half, uint8_t, where_u8_f16)
|
||||
WHERE_OP(uint8_t, uint8_t, where_u8_u8)
|
||||
WHERE_OP(uint32_t, uint8_t, where_u8_u32)
|
||||
|
||||
#if __METAL_VERSION__ >= 220
|
||||
WHERE_OP(int64_t, uint8_t, where_u8_i64)
|
||||
#endif
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
WHERE_OP(bfloat, uint8_t, where_u8_bf16)
|
||||
#endif
|
Reference in New Issue
Block a user