Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)

Compare commits: 0.8.1 ... — 13 commits, author ivarflakst

Commits (SHA1):
ceaf7f1e2d
f8abfee854
b9ce263e4d
5c6d5c3d0e
36ce0988c0
45936a18f8
4462198bc1
e8e24f1284
6eb44d1bce
7fc26764b6
0a29d2e9b8
fd9bf3bcdd
90c74e199c
candle-core/benches/bench_main.rs
@@ -3,6 +3,7 @@ mod benchmarks;
 use criterion::criterion_main;
 criterion_main!(
     benchmarks::affine::benches,
+    benchmarks::fill::benches,
     benchmarks::matmul::benches,
     benchmarks::random::benches,
     benchmarks::where_cond::benches
candle-core/benches/benchmarks/fill.rs (new file, 44 lines)
@@ -0,0 +1,44 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use candle_core::{DType, Device, Tensor};
+use criterion::{black_box, criterion_group, Criterion, Throughput};
+use std::time::Instant;
+
+fn run(shape: (usize, usize, usize), dtype: DType, device: &Device) {
+    Tensor::ones(shape, dtype, device).unwrap();
+}
+
+fn run_fill_benchmark(c: &mut Criterion, device: &Device, name: &str, dtype: DType) {
+    let b = 1;
+    let rows = 1024;
+    let columns = 1024;
+
+    let flops = b * rows * columns * dtype.size_in_bytes();
+
+    let mut group = c.benchmark_group(device.bench_name(name));
+    group.throughput(Throughput::Bytes(flops as u64));
+    group.bench_function("iter", move |bencher| {
+        bencher.iter_custom(|iters| {
+            let start = Instant::now();
+            for _i in 0..iters {
+                run(
+                    black_box((b, rows, columns)),
+                    black_box(dtype),
+                    black_box(&device),
+                );
+            }
+            device.sync().unwrap();
+            start.elapsed()
+        })
+    });
+    group.finish();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let handler = BenchDeviceHandler::new().unwrap();
+    for device in handler.devices {
+        run_fill_benchmark(c, &device, "fill_u8", DType::U8);
+        run_fill_benchmark(c, &device, "fill_f32", DType::F32);
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
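Note that the `flops` variable, despite its name, counts bytes written per iteration, which is what the `Throughput::Bytes` setting expects. A quick worked check of the two configurations benchmarked here (a standalone sketch, not part of the diff):

```rust
// Bytes written per iteration for the two benchmarked dtypes
// (the shape is fixed at 1 x 1024 x 1024 in run_fill_benchmark).
fn main() {
    let (b, rows, columns) = (1usize, 1024, 1024);
    let f32_bytes = b * rows * columns * 4; // DType::F32.size_in_bytes() == 4
    let u8_bytes = b * rows * columns * 1; // DType::U8.size_in_bytes() == 1
    assert_eq!(f32_bytes, 4 * 1024 * 1024); // 4 MiB per iteration
    assert_eq!(u8_bytes, 1024 * 1024); // 1 MiB per iteration
    println!("f32: {f32_bytes} bytes, u8: {u8_bytes} bytes");
}
```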
candle-core/benches/benchmarks/mod.rs
@@ -1,4 +1,5 @@
 pub(crate) mod affine;
+pub(crate) mod fill;
 pub(crate) mod matmul;
 pub(crate) mod random;
 pub(crate) mod where_cond;
candle-core/src/metal_backend.rs
@@ -4,6 +4,7 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
 use crate::{CpuStorage, DType, Layout, Result, Shape};
 use candle_metal_kernels;
 use candle_metal_kernels::Kernels;
+use half::{bf16, f16};
 use metal;
 use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
 use std::collections::HashMap;
@@ -1591,9 +1592,41 @@ impl BackendDevice for MetalDevice {
     }

     fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
-        // TODO Is there a faster way ?
-        let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
-        self.storage_from_cpu_storage(&cpu_storage)
+        let buffer = self.new_buffer(shape.elem_count(), dtype, "ones")?;
+        let command_buffer = self.command_buffer()?;
+        command_buffer.set_label("ones");
+
+        macro_rules! fill {
+            ($value:expr) => {
+                candle_metal_kernels::call_fill(
+                    &self.device,
+                    &command_buffer,
+                    &self.kernels,
+                    shape.elem_count(),
+                    &buffer,
+                    $value,
+                )
+                .map_err(MetalError::from)?
+            };
+        }
+        match dtype {
+            DType::U8 => candle_metal_kernels::call_fill_u8(
+                &command_buffer,
+                shape.elem_count(),
+                &buffer,
+                1u8,
+            )
+            .map_err(MetalError::from)?,
+            DType::U32 => fill!(1u32),
+            DType::I64 => fill!(1i64),
+            DType::BF16 => fill!(bf16::ONE),
+            DType::F16 => fill!(f16::ONE),
+            DType::F32 => fill!(1f32),
+            DType::F64 => {
+                return Err(MetalError::Message("Metal doesn't support double".to_string()).into())
+            }
+        }
+        Ok(MetalStorage::new(buffer, self.clone(), dtype))
     }

     fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
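With this change, `Tensor::ones` on a Metal device fills the buffer on the GPU (a compute kernel for most dtypes, the blit encoder's byte fill for u8) instead of building the tensor on the CPU and copying it over. A hedged usage sketch, assuming a Metal-capable machine with the `metal` feature enabled; not part of the diff:

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let metal = Device::new_metal(0)?; // assumes a Metal device is available
    // These dtypes now go through the new fill path on the GPU.
    let a = Tensor::ones((2, 3), DType::F32, &metal)?;
    let b = Tensor::ones((2, 3), DType::U8, &metal)?; // u8 uses the blit fill
    assert_eq!(a.to_vec2::<f32>()?, vec![vec![1f32; 3]; 2]);
    assert_eq!(b.to_vec2::<u8>()?, vec![vec![1u8; 3]; 2]);
    Ok(())
}
```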
candle-metal-kernels/Cargo.toml
@@ -9,17 +9,13 @@ keywords = ["blas", "tensor", "machine-learning"]
 categories = ["science"]
 license = "MIT OR Apache-2.0"
-

 [dependencies]
 metal = { version = "0.27.0", features = ["mps"] }
 once_cell = "1.18.0"
 thiserror = "1"
 tracing = "0.1.37"
+half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
+num-traits = "0.2.17"

 [dev-dependencies]
-half = { version = "2.3.1", features = [
-    "num-traits",
-    "use-intrinsics",
-    "rand_distr",
-] }
 rand = "0.8.5"
candle-metal-kernels/src/fill.metal (new file, 34 lines)
@@ -0,0 +1,34 @@
+#include <metal_stdlib>
+using namespace metal;
+
+template<typename T>
+void fill(
+    device T *buffer [[buffer(0)]],
+    constant T &value,
+    constant size_t &numel,
+    uint gid [[thread_position_in_grid]]
+) {
+    if (gid >= numel) return;
+    buffer[gid] = value;
+}
+
+#define FILL_OP(T, FN_NAME) \
+kernel void FN_NAME( \
+    device T *buffer [[buffer(0)]], \
+    constant T &value, \
+    constant size_t &numel, \
+    uint gid [[thread_position_in_grid]] \
+) { fill<T>(buffer, value, numel, gid); } \
+
+FILL_OP(uint8_t, fill_u8)
+FILL_OP(uint32_t, fill_u32)
+FILL_OP(half, fill_f16)
+FILL_OP(float, fill_f32)
+
+#if __METAL_VERSION__ >= 220
+FILL_OP(int64_t, fill_i64)
+#endif
+
+#if defined(__HAVE_BFLOAT__)
+FILL_OP(bfloat, fill_bf16)
+#endif
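Each GPU thread writes `value` into exactly one element, guarded by the `numel` bound. A CPU-side reference of the per-thread semantics in Rust (a sketch for illustration only, not part of the diff; the loop index plays the role of the thread grid):

```rust
/// Reference semantics of the `fill` kernel: thread `gid` writes `value`
/// into `buffer[gid]` whenever `gid < numel`.
fn fill_reference<T: Copy>(buffer: &mut [T], value: T, numel: usize) {
    for gid in 0..numel.min(buffer.len()) {
        buffer[gid] = value;
    }
}

fn main() {
    let mut buf = [0u32; 8];
    fill_reference(&mut buf, 1, buf.len());
    assert_eq!(buf, [1u32; 8]);
}
```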
candle-metal-kernels/src/lib.rs
@@ -1,3 +1,4 @@
+use half::{bf16, f16};
 use metal::{
     Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
     Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
@@ -12,6 +13,7 @@ const UNARY: &str = include_str!("unary.metal");
 const BINARY: &str = include_str!("binary.metal");
 const TERNARY: &str = include_str!("ternary.metal");
 const CAST: &str = include_str!("cast.metal");
+const FILL: &str = include_str!("fill.metal");
 const CONV: &str = include_str!("conv.metal");
 const REDUCE: &str = include_str!("reduce.metal");
 const RANDOM: &str = include_str!("random.metal");
@@ -47,29 +49,26 @@ fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64,
 /// Helper functions to create the various objects on the compute command encoder
 /// on a single line.
 /// Prevents getting wrong some arguments number and mixing length and size in bytes.
-trait EncoderParam {
+pub trait EncoderParam: private::Sealed {
     fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
 }
-macro_rules! primitive {
-    ($type:ty) => {
-        impl EncoderParam for $type {
-            fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
-                encoder.set_bytes(
-                    position,
-                    core::mem::size_of::<$type>() as u64,
-                    &data as *const $type as *const c_void,
-                );
-            }
-        }
-    };
-}
-primitive!(bool);
-primitive!(usize);
-primitive!(i32);
-primitive!(i64);
-primitive!(u32);
-primitive!(u64);
-primitive!(f32);
+
+macro_rules! primitives {
+    ($($type:ty),+) => {
+        $(
+            impl EncoderParam for $type {
+                fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
+                    encoder.set_bytes(
+                        position,
+                        core::mem::size_of::<$type>() as u64,
+                        &data as *const $type as *const c_void,
+                    );
+                }
+            }
+        )+
+    };
+}
+primitives!(bool, usize, u8, u32, u64, i32, i64, f16, bf16, f32);

 impl<T> EncoderParam for &[T] {
     fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
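The refactor replaces one `primitive!` invocation per type with a single variadic `primitives!` call, and adds `u8`, `f16`, and `bf16` to the list, which the fill path needs. For a single type the macro expands to roughly the following hand-written fragment (a sketch, assuming the `metal` crate is a dependency and using an unsealed copy of the trait so it compiles standalone):

```rust
use core::ffi::c_void;
use metal::ComputeCommandEncoderRef;

// Simplified, unsealed copy of the trait declared in the diff above.
pub trait EncoderParam {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}

// Roughly what `primitives!(f32)` expands to: the value is passed to the
// encoder as raw bytes of the correct size.
impl EncoderParam for f32 {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_bytes(
            position,
            core::mem::size_of::<f32>() as u64,
            &data as *const f32 as *const c_void,
        );
    }
}
```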
@@ -112,6 +111,22 @@ macro_rules! set_params {
     );
 }

+// Seal for EncoderParam so that only the types we want can implement it
+mod private {
+    use super::*;
+
+    pub trait Sealed {}
+
+    macro_rules! sealed {
+        ($($type:ty),+) => {
+            $(impl Sealed for $type {})+
+        };
+    }
+    sealed!(usize, u8, u32, u64, i32, i64, f16, bf16, f32, bool);
+    sealed!(&Buffer, (&Buffer, usize), &mut Buffer, (&mut Buffer, usize));
+    impl<T> Sealed for &[T] {}
+}
+
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 pub enum Source {
     Affine,
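The `private::Sealed` bound lets `EncoderParam` become public (so it can appear in public signatures such as `call_fill<T: FillOp>`) while preventing implementations outside the crate. A minimal sketch of the pattern in isolation, with hypothetical names rather than candle's actual trait:

```rust
// Sealed-trait pattern sketch (hypothetical `Param` trait, not candle's API).
mod private {
    pub trait Sealed {}
    impl Sealed for f32 {}
    impl Sealed for u32 {}
}

pub trait Param: private::Sealed {
    fn describe(&self) -> String;
}

impl Param for f32 {
    fn describe(&self) -> String {
        format!("f32 = {self}")
    }
}

impl Param for u32 {
    fn describe(&self) -> String {
        format!("u32 = {self}")
    }
}

fn main() {
    // Downstream crates can call Param methods but cannot implement Param,
    // because they cannot implement the private `Sealed` supertrait.
    println!("{}", 1.5f32.describe());
    println!("{}", 7u32.describe());
}
```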
@@ -123,6 +138,7 @@ pub enum Source {
     Reduce,
     Mfa,
     Conv,
+    Fill,
     Random,
     Quantized,
 }
@@ -192,6 +208,8 @@ pub mod binary {

 #[derive(thiserror::Error, Debug)]
 pub enum MetalKernelError {
+    #[error("Invalid usage of kernel: {0}")]
+    InvalidUsage(String),
     #[error("Could not lock kernel map: {0}")]
     LockError(String),
     #[error("Error while loading library: {0}")]
@@ -244,6 +262,7 @@ impl Kernels {
             Source::Indexing => INDEXING,
             Source::Cast => CAST,
             Source::Reduce => REDUCE,
+            Source::Fill => FILL,
             Source::Conv => CONV,
             Source::Random => RANDOM,
             Source::Quantized => QUANTIZED,
@@ -1769,9 +1788,68 @@ pub fn call_quantized_matmul_t(
     Ok(())
 }

+#[inline(always)]
 fn divide(m: usize, b: usize) -> NSUInteger {
     ((m + b - 1) / b) as NSUInteger
 }
+
+pub fn call_fill<T: FillOp>(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    elem_count: usize,
+    buffer: &Buffer,
+    value: T,
+) -> Result<(), MetalKernelError> {
+    let pipeline = kernels.load_pipeline(device, Source::Fill, T::FILL_KERNEL)?;
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+    encoder.set_threadgroup_memory_length(0, elem_count as NSUInteger);
+
+    set_params!(encoder, (buffer, value, elem_count));
+
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, elem_count);
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
+    encoder.end_encoding();
+
+    Ok(())
+}
+
+pub fn call_fill_u8(
+    command_buffer: &CommandBufferRef,
+    elem_count: usize,
+    buffer: &Buffer,
+    value: u8,
+) -> Result<(), MetalKernelError> {
+    let blit = command_buffer.new_blit_command_encoder();
+    blit.fill_buffer(
+        buffer,
+        metal::NSRange {
+            location: 0,
+            length: elem_count as NSUInteger,
+        },
+        value,
+    );
+    blit.end_encoding();
+
+    Ok(())
+}
+
+pub trait FillOp: EncoderParam {
+    const FILL_KERNEL: &'static str;
+}
+
+macro_rules! impl_call_fill {
+    ($($t:ty),*) => {
+        $(
+            impl FillOp for $t {
+                const FILL_KERNEL: &'static str = concat!("fill_", stringify!($t));
+            }
+        )*
+    };
+}
+impl_call_fill!(u8, u32, i64, f16, bf16, f32);

 #[cfg(test)]
 mod tests;
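`impl_call_fill!` maps each Rust scalar type to the matching Metal entry point name via `concat!`/`stringify!`; the scheme relies on the `half` types being literally named `f16` and `bf16`, so the generated names line up with the `FILL_OP` entry points in fill.metal. A standalone sketch of the mechanism (illustrative copy, not the crate's own trait):

```rust
// Standalone sketch of the name mapping done by `impl_call_fill!`.
trait FillOp {
    const FILL_KERNEL: &'static str;
}

macro_rules! impl_call_fill {
    ($($t:ty),*) => {
        $(
            impl FillOp for $t {
                // concat!("fill_", stringify!($t)) yields the kernel entry point name.
                const FILL_KERNEL: &'static str = concat!("fill_", stringify!($t));
            }
        )*
    };
}
impl_call_fill!(u8, f32);

fn main() {
    assert_eq!(<u8 as FillOp>::FILL_KERNEL, "fill_u8");
    assert_eq!(<f32 as FillOp>::FILL_KERNEL, "fill_f32");
}
```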
candle-metal-kernels/src/tests.rs
@@ -927,6 +927,42 @@ fn gemm() {
     );
 }

+fn run_fill<T: FillOp + Clone>(elem_count: usize, value: T) -> Vec<T> {
+    let device = device();
+    let kernels = Kernels::new();
+    let command_queue = device.new_command_queue();
+    let command_buffer = command_queue.new_command_buffer();
+    let buffer = new_buffer(&device, &vec![0.0f32; elem_count]);
+    call_fill(
+        &device,
+        command_buffer,
+        &kernels,
+        elem_count,
+        &buffer,
+        value,
+    )
+    .unwrap();
+    command_buffer.commit();
+    command_buffer.wait_until_completed();
+
+    read_to_vec(&buffer, elem_count)
+}
+
+#[test]
+fn fill() {
+    fn assert_fill<T: FillOp + Copy + std::fmt::Debug + PartialEq>(value: T) {
+        for i in 0..4 {
+            assert_eq!(run_fill(8 ^ i, value), vec![value; 8 ^ i]);
+        }
+    }
+    assert_fill(123u8);
+    assert_fill(456u32);
+    assert_fill(789i64);
+    assert_fill(f16::from_f32(1.23));
+    assert_fill(bf16::from_f32(4.56));
+    assert_fill(7.89f32);
+}
+
 fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
     let device = device();
     let kernels = Kernels::new();
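One detail worth noting when reading the test: in Rust `^` is bitwise XOR, not exponentiation, so `8 ^ i` for `i` in `0..4` yields buffer sizes 8, 9, 10, and 11 rather than powers of 8. A quick check, with the power-based form (`8usize.pow(i)`) shown for contrast:

```rust
fn main() {
    // `8 ^ i` is XOR: 8, 9, 10, 11 — still valid sizes, just not powers of 8.
    let xor_sizes: Vec<usize> = (0..4).map(|i| 8 ^ i).collect();
    assert_eq!(xor_sizes, vec![8, 9, 10, 11]);

    // Exponentiation would be written with `pow`:
    let pow_sizes: Vec<usize> = (0..4u32).map(|i| 8usize.pow(i)).collect();
    assert_eq!(pow_sizes, vec![1, 8, 64, 512]);
}
```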
candle-metal-kernels/src/random.metal
@@ -1,6 +1,5 @@
 #include <metal_stdlib>
 #include <metal_math>
-#
 using namespace metal;

 METAL_FUNC uint get_strided_index(