mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Compare commits
13 Commits
0.7.0
...
ivarflakst
Author | SHA1 | Date | |
---|---|---|---|
ceaf7f1e2d | |||
f8abfee854 | |||
b9ce263e4d | |||
5c6d5c3d0e | |||
36ce0988c0 | |||
45936a18f8 | |||
4462198bc1 | |||
e8e24f1284 | |||
6eb44d1bce | |||
7fc26764b6 | |||
0a29d2e9b8 | |||
fd9bf3bcdd | |||
90c74e199c |
@ -3,6 +3,7 @@ mod benchmarks;
|
|||||||
use criterion::criterion_main;
|
use criterion::criterion_main;
|
||||||
criterion_main!(
|
criterion_main!(
|
||||||
benchmarks::affine::benches,
|
benchmarks::affine::benches,
|
||||||
|
benchmarks::fill::benches,
|
||||||
benchmarks::matmul::benches,
|
benchmarks::matmul::benches,
|
||||||
benchmarks::random::benches,
|
benchmarks::random::benches,
|
||||||
benchmarks::where_cond::benches
|
benchmarks::where_cond::benches
|
||||||
|
44
candle-core/benches/benchmarks/fill.rs
Normal file
44
candle-core/benches/benchmarks/fill.rs
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||||
|
use candle_core::{DType, Device, Tensor};
|
||||||
|
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
fn run(shape: (usize, usize, usize), dtype: DType, device: &Device) {
|
||||||
|
Tensor::ones(shape, dtype, device).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_fill_benchmark(c: &mut Criterion, device: &Device, name: &str, dtype: DType) {
|
||||||
|
let b = 1;
|
||||||
|
let rows = 1024;
|
||||||
|
let columns = 1024;
|
||||||
|
|
||||||
|
let flops = b * rows * columns * dtype.size_in_bytes();
|
||||||
|
|
||||||
|
let mut group = c.benchmark_group(device.bench_name(name));
|
||||||
|
group.throughput(Throughput::Bytes(flops as u64));
|
||||||
|
group.bench_function("iter", move |bencher| {
|
||||||
|
bencher.iter_custom(|iters| {
|
||||||
|
let start = Instant::now();
|
||||||
|
for _i in 0..iters {
|
||||||
|
run(
|
||||||
|
black_box((b, rows, columns)),
|
||||||
|
black_box(dtype),
|
||||||
|
black_box(&device),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
device.sync().unwrap();
|
||||||
|
start.elapsed()
|
||||||
|
})
|
||||||
|
});
|
||||||
|
group.finish();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn criterion_benchmark(c: &mut Criterion) {
|
||||||
|
let handler = BenchDeviceHandler::new().unwrap();
|
||||||
|
for device in handler.devices {
|
||||||
|
run_fill_benchmark(c, &device, "fill_u8", DType::U8);
|
||||||
|
run_fill_benchmark(c, &device, "fill_f32", DType::F32);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
criterion_group!(benches, criterion_benchmark);
|
@ -1,4 +1,5 @@
|
|||||||
pub(crate) mod affine;
|
pub(crate) mod affine;
|
||||||
|
pub(crate) mod fill;
|
||||||
pub(crate) mod matmul;
|
pub(crate) mod matmul;
|
||||||
pub(crate) mod random;
|
pub(crate) mod random;
|
||||||
pub(crate) mod where_cond;
|
pub(crate) mod where_cond;
|
||||||
|
@ -4,6 +4,7 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
|||||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||||
use candle_metal_kernels;
|
use candle_metal_kernels;
|
||||||
use candle_metal_kernels::Kernels;
|
use candle_metal_kernels::Kernels;
|
||||||
|
use half::{bf16, f16};
|
||||||
use metal;
|
use metal;
|
||||||
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
@ -1591,9 +1592,41 @@ impl BackendDevice for MetalDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
|
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
|
||||||
// TODO Is there a faster way ?
|
let buffer = self.new_buffer(shape.elem_count(), dtype, "ones")?;
|
||||||
let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
|
let command_buffer = self.command_buffer()?;
|
||||||
self.storage_from_cpu_storage(&cpu_storage)
|
command_buffer.set_label("ones");
|
||||||
|
|
||||||
|
macro_rules! fill {
|
||||||
|
($value:expr) => {
|
||||||
|
candle_metal_kernels::call_fill(
|
||||||
|
&self.device,
|
||||||
|
&command_buffer,
|
||||||
|
&self.kernels,
|
||||||
|
shape.elem_count(),
|
||||||
|
&buffer,
|
||||||
|
$value,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?
|
||||||
|
};
|
||||||
|
}
|
||||||
|
match dtype {
|
||||||
|
DType::U8 => candle_metal_kernels::call_fill_u8(
|
||||||
|
&command_buffer,
|
||||||
|
shape.elem_count(),
|
||||||
|
&buffer,
|
||||||
|
1u8,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?,
|
||||||
|
DType::U32 => fill!(1u32),
|
||||||
|
DType::I64 => fill!(1i64),
|
||||||
|
DType::BF16 => fill!(bf16::ONE),
|
||||||
|
DType::F16 => fill!(f16::ONE),
|
||||||
|
DType::F32 => fill!(1f32),
|
||||||
|
DType::F64 => {
|
||||||
|
return Err(MetalError::Message("Metal doesn't support double".to_string()).into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(MetalStorage::new(buffer, self.clone(), dtype))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
|
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
|
||||||
|
@ -9,17 +9,13 @@ keywords = ["blas", "tensor", "machine-learning"]
|
|||||||
categories = ["science"]
|
categories = ["science"]
|
||||||
license = "MIT OR Apache-2.0"
|
license = "MIT OR Apache-2.0"
|
||||||
|
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
metal = { version = "0.27.0", features = ["mps"] }
|
metal = { version = "0.27.0", features = ["mps"] }
|
||||||
once_cell = "1.18.0"
|
once_cell = "1.18.0"
|
||||||
thiserror = "1"
|
thiserror = "1"
|
||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
|
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
||||||
|
num-traits = "0.2.17"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
half = { version = "2.3.1", features = [
|
|
||||||
"num-traits",
|
|
||||||
"use-intrinsics",
|
|
||||||
"rand_distr",
|
|
||||||
] }
|
|
||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
|
34
candle-metal-kernels/src/fill.metal
Normal file
34
candle-metal-kernels/src/fill.metal
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
#include <metal_stdlib>
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
void fill(
|
||||||
|
device T *buffer [[buffer(0)]],
|
||||||
|
constant T &value,
|
||||||
|
constant size_t &numel,
|
||||||
|
uint gid [[thread_position_in_grid]]
|
||||||
|
) {
|
||||||
|
if (gid >= numel) return;
|
||||||
|
buffer[gid] = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
#define FILL_OP(T, FN_NAME) \
|
||||||
|
kernel void FN_NAME( \
|
||||||
|
device T *buffer [[buffer(0)]], \
|
||||||
|
constant T &value, \
|
||||||
|
constant size_t &numel, \
|
||||||
|
uint gid [[thread_position_in_grid]] \
|
||||||
|
) { fill<T>(buffer, value, numel, gid); } \
|
||||||
|
|
||||||
|
FILL_OP(uint8_t, fill_u8)
|
||||||
|
FILL_OP(uint32_t, fill_u32)
|
||||||
|
FILL_OP(half, fill_f16)
|
||||||
|
FILL_OP(float, fill_f32)
|
||||||
|
|
||||||
|
#if __METAL_VERSION__ >= 220
|
||||||
|
FILL_OP(int64_t, fill_i64)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(__HAVE_BFLOAT__)
|
||||||
|
FILL_OP(bfloat, fill_bf16)
|
||||||
|
#endif
|
@ -1,3 +1,4 @@
|
|||||||
|
use half::{bf16, f16};
|
||||||
use metal::{
|
use metal::{
|
||||||
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
|
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
|
||||||
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
|
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
|
||||||
@ -12,6 +13,7 @@ const UNARY: &str = include_str!("unary.metal");
|
|||||||
const BINARY: &str = include_str!("binary.metal");
|
const BINARY: &str = include_str!("binary.metal");
|
||||||
const TERNARY: &str = include_str!("ternary.metal");
|
const TERNARY: &str = include_str!("ternary.metal");
|
||||||
const CAST: &str = include_str!("cast.metal");
|
const CAST: &str = include_str!("cast.metal");
|
||||||
|
const FILL: &str = include_str!("fill.metal");
|
||||||
const CONV: &str = include_str!("conv.metal");
|
const CONV: &str = include_str!("conv.metal");
|
||||||
const REDUCE: &str = include_str!("reduce.metal");
|
const REDUCE: &str = include_str!("reduce.metal");
|
||||||
const RANDOM: &str = include_str!("random.metal");
|
const RANDOM: &str = include_str!("random.metal");
|
||||||
@ -47,29 +49,26 @@ fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64,
|
|||||||
/// Helper functions to create the various objects on the compute command encoder
|
/// Helper functions to create the various objects on the compute command encoder
|
||||||
/// on a single line.
|
/// on a single line.
|
||||||
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
|
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
|
||||||
trait EncoderParam {
|
pub trait EncoderParam: private::Sealed {
|
||||||
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
|
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
|
||||||
}
|
}
|
||||||
macro_rules! primitive {
|
|
||||||
($type:ty) => {
|
macro_rules! primitives {
|
||||||
impl EncoderParam for $type {
|
($($type:ty),+) => {
|
||||||
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
$(
|
||||||
encoder.set_bytes(
|
impl EncoderParam for $type {
|
||||||
position,
|
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
||||||
core::mem::size_of::<$type>() as u64,
|
encoder.set_bytes(
|
||||||
&data as *const $type as *const c_void,
|
position,
|
||||||
);
|
core::mem::size_of::<$type>() as u64,
|
||||||
|
&data as *const $type as *const c_void,
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
)+
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
primitive!(bool);
|
primitives!(bool, usize, u8, u32, u64, i32, i64, f16, bf16, f32);
|
||||||
primitive!(usize);
|
|
||||||
primitive!(i32);
|
|
||||||
primitive!(i64);
|
|
||||||
primitive!(u32);
|
|
||||||
primitive!(u64);
|
|
||||||
primitive!(f32);
|
|
||||||
|
|
||||||
impl<T> EncoderParam for &[T] {
|
impl<T> EncoderParam for &[T] {
|
||||||
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
||||||
@ -112,6 +111,22 @@ macro_rules! set_params {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Seal for EncoderParam so that only the types we want can implement it
|
||||||
|
mod private {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
pub trait Sealed {}
|
||||||
|
|
||||||
|
macro_rules! sealed {
|
||||||
|
($($type:ty),+) => {
|
||||||
|
$(impl Sealed for $type {})+
|
||||||
|
};
|
||||||
|
}
|
||||||
|
sealed!(usize, u8, u32, u64, i32, i64, f16, bf16, f32, bool);
|
||||||
|
sealed!(&Buffer, (&Buffer, usize), &mut Buffer, (&mut Buffer, usize));
|
||||||
|
impl<T> Sealed for &[T] {}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
pub enum Source {
|
pub enum Source {
|
||||||
Affine,
|
Affine,
|
||||||
@ -123,6 +138,7 @@ pub enum Source {
|
|||||||
Reduce,
|
Reduce,
|
||||||
Mfa,
|
Mfa,
|
||||||
Conv,
|
Conv,
|
||||||
|
Fill,
|
||||||
Random,
|
Random,
|
||||||
Quantized,
|
Quantized,
|
||||||
}
|
}
|
||||||
@ -192,6 +208,8 @@ pub mod binary {
|
|||||||
|
|
||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
pub enum MetalKernelError {
|
pub enum MetalKernelError {
|
||||||
|
#[error("Invalid usage of kernel: {0}")]
|
||||||
|
InvalidUsage(String),
|
||||||
#[error("Could not lock kernel map: {0}")]
|
#[error("Could not lock kernel map: {0}")]
|
||||||
LockError(String),
|
LockError(String),
|
||||||
#[error("Error while loading library: {0}")]
|
#[error("Error while loading library: {0}")]
|
||||||
@ -244,6 +262,7 @@ impl Kernels {
|
|||||||
Source::Indexing => INDEXING,
|
Source::Indexing => INDEXING,
|
||||||
Source::Cast => CAST,
|
Source::Cast => CAST,
|
||||||
Source::Reduce => REDUCE,
|
Source::Reduce => REDUCE,
|
||||||
|
Source::Fill => FILL,
|
||||||
Source::Conv => CONV,
|
Source::Conv => CONV,
|
||||||
Source::Random => RANDOM,
|
Source::Random => RANDOM,
|
||||||
Source::Quantized => QUANTIZED,
|
Source::Quantized => QUANTIZED,
|
||||||
@ -1769,9 +1788,68 @@ pub fn call_quantized_matmul_t(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
fn divide(m: usize, b: usize) -> NSUInteger {
|
fn divide(m: usize, b: usize) -> NSUInteger {
|
||||||
((m + b - 1) / b) as NSUInteger
|
((m + b - 1) / b) as NSUInteger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn call_fill<T: FillOp>(
|
||||||
|
device: &Device,
|
||||||
|
command_buffer: &CommandBufferRef,
|
||||||
|
kernels: &Kernels,
|
||||||
|
elem_count: usize,
|
||||||
|
buffer: &Buffer,
|
||||||
|
value: T,
|
||||||
|
) -> Result<(), MetalKernelError> {
|
||||||
|
let pipeline = kernels.load_pipeline(device, Source::Fill, T::FILL_KERNEL)?;
|
||||||
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
encoder.set_threadgroup_memory_length(0, elem_count as NSUInteger);
|
||||||
|
|
||||||
|
set_params!(encoder, (buffer, value, elem_count));
|
||||||
|
|
||||||
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, elem_count);
|
||||||
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
|
||||||
|
encoder.end_encoding();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_fill_u8(
|
||||||
|
command_buffer: &CommandBufferRef,
|
||||||
|
elem_count: usize,
|
||||||
|
buffer: &Buffer,
|
||||||
|
value: u8,
|
||||||
|
) -> Result<(), MetalKernelError> {
|
||||||
|
let blit = command_buffer.new_blit_command_encoder();
|
||||||
|
blit.fill_buffer(
|
||||||
|
buffer,
|
||||||
|
metal::NSRange {
|
||||||
|
location: 0,
|
||||||
|
length: elem_count as NSUInteger,
|
||||||
|
},
|
||||||
|
value,
|
||||||
|
);
|
||||||
|
blit.end_encoding();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait FillOp: EncoderParam {
|
||||||
|
const FILL_KERNEL: &'static str;
|
||||||
|
}
|
||||||
|
|
||||||
|
macro_rules ! impl_call_fill {
|
||||||
|
($($t:ty),*) => {
|
||||||
|
$(
|
||||||
|
impl FillOp for $t {
|
||||||
|
const FILL_KERNEL: &'static str = concat!("fill_", stringify!($t));
|
||||||
|
}
|
||||||
|
)*
|
||||||
|
};
|
||||||
|
}
|
||||||
|
impl_call_fill!(u8, u32, i64, f16, bf16, f32);
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests;
|
mod tests;
|
||||||
|
@ -927,6 +927,42 @@ fn gemm() {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn run_fill<T: FillOp + Clone>(elem_count: usize, value: T) -> Vec<T> {
|
||||||
|
let device = device();
|
||||||
|
let kernels = Kernels::new();
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
|
let buffer = new_buffer(&device, &vec![0.0f32; elem_count]);
|
||||||
|
call_fill(
|
||||||
|
&device,
|
||||||
|
command_buffer,
|
||||||
|
&kernels,
|
||||||
|
elem_count,
|
||||||
|
&buffer,
|
||||||
|
value,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
|
||||||
|
read_to_vec(&buffer, elem_count)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn fill() {
|
||||||
|
fn assert_fill<T: FillOp + Copy + std::fmt::Debug + PartialEq>(value: T) {
|
||||||
|
for i in 0..4 {
|
||||||
|
assert_eq!(run_fill(8 ^ i, value), vec![value; 8 ^ i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert_fill(123u8);
|
||||||
|
assert_fill(456u32);
|
||||||
|
assert_fill(789i64);
|
||||||
|
assert_fill(f16::from_f32(1.23));
|
||||||
|
assert_fill(bf16::from_f32(4.56));
|
||||||
|
assert_fill(7.89f32);
|
||||||
|
}
|
||||||
|
|
||||||
fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
|
fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
|
||||||
let device = device();
|
let device = device();
|
||||||
let kernels = Kernels::new();
|
let kernels = Kernels::new();
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
#include <metal_math>
|
#include <metal_math>
|
||||||
#
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
METAL_FUNC uint get_strided_index(
|
METAL_FUNC uint get_strided_index(
|
||||||
|
Reference in New Issue
Block a user