From 9dbaf958dc47198cd365dc46b431f8123fe527ef Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 18 Apr 2025 22:13:38 +0200 Subject: [PATCH] Add an enum for scalar values. (#2909) * Add a scalar enum type. * Add a bit more to the scalar type. * Small tweak. * More scalar usage. --- candle-core/src/cuda_backend/device.rs | 27 ++++------ candle-core/src/dtype.rs | 5 ++ candle-core/src/metal_backend/device.rs | 40 ++++++++++++++ candle-core/src/metal_backend/mod.rs | 40 ++++---------- candle-core/src/scalar.rs | 70 ++++++++++++++++++++++++- candle-metal-kernels/Cargo.toml | 1 + candle-metal-kernels/src/fill.metal | 6 +-- candle-metal-kernels/src/lib.rs | 2 +- candle-metal-kernels/src/tests.rs | 10 ++-- candle-metal-kernels/src/utils.rs | 4 ++ 10 files changed, 150 insertions(+), 55 deletions(-) diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs index 7dd18b7a..1d270116 100644 --- a/candle-core/src/cuda_backend/device.rs +++ b/candle-core/src/cuda_backend/device.rs @@ -1,4 +1,5 @@ use crate::backend::BackendDevice; +use crate::scalar::Scalar; use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape}; pub use candle_kernels as kernels; pub use cudarc; @@ -188,83 +189,77 @@ impl CudaDevice { self.id } - fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result { + fn const_impl(&self, v: Scalar, shape: &Shape) -> Result { let elem_count = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(elem_count as u32); - let slice = match dtype { - DType::U8 => { + let slice = match v { + Scalar::U8(v) => { // SAFETY: Set later by running the fill kernel. let data = unsafe { self.alloc::(elem_count)? }; let func = self.get_or_load_func("fill_u8", &kernels::FILL)?; let mut builder = self.stream.launch_builder(&func); - let v = v as u8; builder.arg(&data); builder.arg(&v); builder.arg(&elem_count); unsafe { builder.launch(cfg) }.w()?; CudaStorageSlice::U8(data) } - DType::U32 => { + Scalar::U32(v) => { // SAFETY: Set later by running the fill kernel. let data = unsafe { self.alloc::(elem_count)? }; let func = self.get_or_load_func("fill_u32", &kernels::FILL)?; let mut builder = self.stream.launch_builder(&func); - let v = v as u32; builder.arg(&data); builder.arg(&v); builder.arg(&elem_count); unsafe { builder.launch(cfg) }.w()?; CudaStorageSlice::U32(data) } - DType::I64 => { + Scalar::I64(v) => { // SAFETY: Set later by running the fill kernel. let data = unsafe { self.alloc::(elem_count)? }; let func = self.get_or_load_func("fill_i64", &kernels::FILL)?; let mut builder = self.stream.launch_builder(&func); - let v = v as i64; builder.arg(&data); builder.arg(&v); builder.arg(&elem_count); unsafe { builder.launch(cfg) }.w()?; CudaStorageSlice::I64(data) } - DType::BF16 => { + Scalar::BF16(v) => { // SAFETY: Set later by running the fill kernel. let data = unsafe { self.alloc::(elem_count)? }; let func = self.get_or_load_func("fill_bf16", &kernels::FILL)?; let mut builder = self.stream.launch_builder(&func); - let v = bf16::from_f64(v); builder.arg(&data); builder.arg(&v); builder.arg(&elem_count); unsafe { builder.launch(cfg) }.w()?; CudaStorageSlice::BF16(data) } - DType::F16 => { + Scalar::F16(v) => { // SAFETY: Set later by running the fill kernel. let data = unsafe { self.alloc::(elem_count)? }; let func = self.get_or_load_func("fill_f16", &kernels::FILL)?; let mut builder = self.stream.launch_builder(&func); - let v = f16::from_f64(v); builder.arg(&data); builder.arg(&v); builder.arg(&elem_count); unsafe { builder.launch(cfg) }.w()?; CudaStorageSlice::F16(data) } - DType::F32 => { + Scalar::F32(v) => { // SAFETY: Set later by running the fill kernel. let data = unsafe { self.alloc::(elem_count)? }; let func = self.get_or_load_func("fill_f32", &kernels::FILL)?; let mut builder = self.stream.launch_builder(&func); - let v = v as f32; builder.arg(&data); builder.arg(&v); builder.arg(&elem_count); unsafe { builder.launch(cfg) }.w()?; CudaStorageSlice::F32(data) } - DType::F64 => { + Scalar::F64(v) => { // SAFETY: Set later by running the fill kernel. let data = unsafe { self.alloc::(elem_count) }?; let func = self.get_or_load_func("fill_f64", &kernels::FILL)?; @@ -505,7 +500,7 @@ impl BackendDevice for CudaDevice { } fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result { - self.const_impl(1., shape, dtype) + self.const_impl(Scalar::one(dtype), shape) } unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result { diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs index de6cddc3..1908e600 100644 --- a/candle-core/src/dtype.rs +++ b/candle-core/src/dtype.rs @@ -107,6 +107,7 @@ pub trait WithDType: fn from_f64(v: f64) -> Self; fn to_f64(self) -> f64; + fn to_scalar(self) -> crate::scalar::Scalar; fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>; fn to_cpu_storage_owned(data: Vec) -> CpuStorage; @@ -131,6 +132,10 @@ macro_rules! with_dtype { $to_f64(self) } + fn to_scalar(self) -> crate::scalar::Scalar { + crate::scalar::Scalar::$dtype(self) + } + fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> { CpuStorageRef::$dtype(data) } diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs index 43869a0c..38e5b528 100644 --- a/candle-core/src/metal_backend/device.rs +++ b/candle-core/src/metal_backend/device.rs @@ -313,6 +313,46 @@ impl MetalDevice { .map_err(MetalError::from)?; Ok(()) } + + pub(crate) fn const_impl( + &self, + v: T, + shape: &crate::Shape, + ) -> Result { + use crate::backend::BackendDevice; + let dtype = T::DTYPE; + let name = match dtype { + DType::U8 => "fill_u8", + DType::U32 => "fill_u32", + DType::I64 => "fill_i64", + DType::F16 => "fill_f16", + DType::BF16 => "fill_bf16", + DType::F32 => "fill_f32", + DType::F64 => { + let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?; + return self.storage_from_cpu_storage(&cpu_storage); + } + }; + let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-ones")?; + let command_buffer = self.command_buffer()?; + candle_metal_kernels::call_const_fill( + &self.device, + &command_buffer, + &self.kernels, + name, + shape.elem_count(), + &buffer, + v, + ) + .map_err(MetalError::from)?; + + Ok(super::MetalStorage::new( + buffer, + self.clone(), + shape.elem_count(), + dtype, + )) + } } fn buf_size(size: NSUInteger) -> NSUInteger { diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 433188cf..92d267ce 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1966,37 +1966,15 @@ impl BackendDevice for MetalDevice { } fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result { - let name = match dtype { - DType::U8 => "fill_u8", - DType::U32 => "fill_u32", - DType::I64 => "fill_i64", - DType::F16 => "fill_f16", - DType::BF16 => "fill_bf16", - DType::F32 => "fill_f32", - DType::F64 => { - let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?; - return self.storage_from_cpu_storage(&cpu_storage); - } - }; - let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-ones")?; - let command_buffer = self.command_buffer()?; - candle_metal_kernels::call_const_fill( - &self.device, - &command_buffer, - &self.kernels, - name, - shape.elem_count(), - &buffer, - 1., - ) - .map_err(MetalError::from)?; - - Ok(MetalStorage::new( - buffer, - self.clone(), - shape.elem_count(), - dtype, - )) + match dtype { + DType::U8 => self.const_impl(1u8, shape), + DType::U32 => self.const_impl(1u32, shape), + DType::I64 => self.const_impl(1i64, shape), + DType::F16 => self.const_impl(half::f16::ONE, shape), + DType::BF16 => self.const_impl(half::bf16::ONE, shape), + DType::F32 => self.const_impl(1f32, shape), + DType::F64 => self.const_impl(1f64, shape), + } } fn storage_from_slice(&self, s: &[T]) -> Result { diff --git a/candle-core/src/scalar.rs b/candle-core/src/scalar.rs index 30308d11..b86d885f 100644 --- a/candle-core/src/scalar.rs +++ b/candle-core/src/scalar.rs @@ -1,6 +1,74 @@ //! TensorScalar Enum and Trait //! -use crate::{Result, Tensor, WithDType}; +use crate::{DType, Result, Tensor, WithDType}; +use half::{bf16, f16}; + +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum Scalar { + U8(u8), + U32(u32), + I64(i64), + BF16(bf16), + F16(f16), + F32(f32), + F64(f64), +} + +impl From for Scalar { + fn from(value: T) -> Self { + value.to_scalar() + } +} + +impl Scalar { + pub fn zero(dtype: DType) -> Self { + match dtype { + DType::U8 => Scalar::U8(0), + DType::U32 => Scalar::U32(0), + DType::I64 => Scalar::I64(0), + DType::BF16 => Scalar::BF16(bf16::ZERO), + DType::F16 => Scalar::F16(f16::ZERO), + DType::F32 => Scalar::F32(0.0), + DType::F64 => Scalar::F64(0.0), + } + } + + pub fn one(dtype: DType) -> Self { + match dtype { + DType::U8 => Scalar::U8(1), + DType::U32 => Scalar::U32(1), + DType::I64 => Scalar::I64(1), + DType::BF16 => Scalar::BF16(bf16::ONE), + DType::F16 => Scalar::F16(f16::ONE), + DType::F32 => Scalar::F32(1.0), + DType::F64 => Scalar::F64(1.0), + } + } + + pub fn dtype(&self) -> DType { + match self { + Scalar::U8(_) => DType::U8, + Scalar::U32(_) => DType::U32, + Scalar::I64(_) => DType::I64, + Scalar::BF16(_) => DType::BF16, + Scalar::F16(_) => DType::F16, + Scalar::F32(_) => DType::F32, + Scalar::F64(_) => DType::F64, + } + } + + pub fn to_f64(&self) -> f64 { + match self { + Scalar::U8(v) => *v as f64, + Scalar::U32(v) => *v as f64, + Scalar::I64(v) => *v as f64, + Scalar::BF16(v) => v.to_f64(), + Scalar::F16(v) => v.to_f64(), + Scalar::F32(v) => *v as f64, + Scalar::F64(v) => *v, + } + } +} pub enum TensorScalar { Tensor(Tensor), diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index d84f6824..b00e7ca0 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -12,6 +12,7 @@ license = "MIT OR Apache-2.0" [dependencies] metal = { version = "0.27.0", features = ["mps"] } +half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"] } once_cell = "1.18.0" thiserror = "1" tracing = "0.1.37" diff --git a/candle-metal-kernels/src/fill.metal b/candle-metal-kernels/src/fill.metal index 35c3fe7a..dfb24a26 100644 --- a/candle-metal-kernels/src/fill.metal +++ b/candle-metal-kernels/src/fill.metal @@ -4,20 +4,20 @@ using namespace metal; template METAL_FUNC void fill_with( device T *out, - constant float &value, + constant T &value, constant size_t &numel, uint tid [[thread_position_in_grid]] ) { if (tid >= numel) { return; } - out[tid] = static_cast(value); + out[tid] = value; } #define FILL_OP(NAME, T) \ kernel void fill_##NAME( \ device T *out, \ - constant float &value, \ + constant T &value, \ constant size_t &numel, \ uint tid [[thread_position_in_grid]] \ ) { \ diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 6de44f9c..2a898b54 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -2570,7 +2570,7 @@ pub fn call_const_fill( name: &'static str, length: usize, output: &Buffer, - v: f32, + v: impl EncoderParam, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Fill, name)?; let encoder = ep.encoder(); diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 21ade21c..9121f671 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -2343,7 +2343,7 @@ fn conv_transpose1d_u32() { #[test] fn const_fill() { - fn constant_fill(name: &'static str, len: usize, value: f32) -> Vec { + fn constant_fill(name: &'static str, len: usize, value: T) -> Vec { let dev = device(); let kernels = Kernels::new(); let command_queue = dev.new_command_queue(); @@ -2357,11 +2357,15 @@ fn const_fill() { command_buffer.wait_until_completed(); read_to_vec::(&buffer, len) } - fn test T>(name: &'static str, f: F) { + fn test T>( + name: &'static str, + f: F, + ) { let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16); let value = rand::thread_rng().gen_range(1. ..19.); + let value = f(value); let v = constant_fill::(name, len, value); - assert_eq!(v, vec![f(value); len]) + assert_eq!(v, vec![value; len]) } test::("fill_u8", |v| v as u8); test::("fill_u32", |v| v as u32); diff --git a/candle-metal-kernels/src/utils.rs b/candle-metal-kernels/src/utils.rs index 025808d7..c8f1a2d9 100644 --- a/candle-metal-kernels/src/utils.rs +++ b/candle-metal-kernels/src/utils.rs @@ -88,9 +88,13 @@ primitive!(bool); primitive!(usize); primitive!(i32); primitive!(i64); +primitive!(u8); primitive!(u32); primitive!(u64); primitive!(f32); +primitive!(f64); +primitive!(half::bf16); +primitive!(half::f16); pub struct BufferOffset<'a> { pub buffer: &'a Buffer,