Add an enum for scalar values. (#2909)

* Add a scalar enum type.

* Add a bit more to the scalar type.

* Small tweak.

* More scalar usage.
This commit is contained in:
Laurent Mazare
2025-04-18 22:13:38 +02:00
committed by GitHub
parent ce5f8dd129
commit 9dbaf958dc
10 changed files with 150 additions and 55 deletions

View File

@ -1,4 +1,5 @@
use crate::backend::BackendDevice;
use crate::scalar::Scalar;
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
pub use candle_kernels as kernels;
pub use cudarc;
@ -188,83 +189,77 @@ impl CudaDevice {
self.id
}
fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
fn const_impl(&self, v: Scalar, shape: &Shape) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let slice = match dtype {
DType::U8 => {
let slice = match v {
Scalar::U8(v) => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u8>(elem_count)? };
let func = self.get_or_load_func("fill_u8", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as u8;
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
Scalar::U32(v) => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u32>(elem_count)? };
let func = self.get_or_load_func("fill_u32", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as u32;
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
Scalar::I64(v) => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<i64>(elem_count)? };
let func = self.get_or_load_func("fill_i64", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as i64;
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
Scalar::BF16(v) => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<bf16>(elem_count)? };
let func = self.get_or_load_func("fill_bf16", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = bf16::from_f64(v);
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
Scalar::F16(v) => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f16>(elem_count)? };
let func = self.get_or_load_func("fill_f16", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = f16::from_f64(v);
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
Scalar::F32(v) => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f32>(elem_count)? };
let func = self.get_or_load_func("fill_f32", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as f32;
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
Scalar::F64(v) => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f64>(elem_count) }?;
let func = self.get_or_load_func("fill_f64", &kernels::FILL)?;
@ -505,7 +500,7 @@ impl BackendDevice for CudaDevice {
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
self.const_impl(1., shape, dtype)
self.const_impl(Scalar::one(dtype), shape)
}
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {

View File

@ -107,6 +107,7 @@ pub trait WithDType:
fn from_f64(v: f64) -> Self;
fn to_f64(self) -> f64;
fn to_scalar(self) -> crate::scalar::Scalar;
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>;
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
@ -131,6 +132,10 @@ macro_rules! with_dtype {
$to_f64(self)
}
// Wraps `self` in the `Scalar` variant named by this macro's `$dtype` argument.
fn to_scalar(self) -> crate::scalar::Scalar {
crate::scalar::Scalar::$dtype(self)
}
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> {
CpuStorageRef::$dtype(data)
}

View File

@ -313,6 +313,46 @@ impl MetalDevice {
.map_err(MetalError::from)?;
Ok(())
}
/// Creates a Metal storage of the given `shape` in which every element is set to `v`.
///
/// The element dtype is `T::DTYPE`. Every dtype that has an entry in the kernel-name table
/// below is filled through `candle_metal_kernels::call_const_fill`. `F64` has no entry in
/// that table, so its data is built as a CPU storage and handed to
/// `storage_from_cpu_storage` instead.
///
/// # Errors
///
/// Propagates any error from buffer allocation, command-buffer retrieval, the fill kernel
/// call, or the CPU-storage conversion.
pub(crate) fn const_impl<T: crate::WithDType + candle_metal_kernels::utils::EncoderParam>(
    &self,
    v: T,
    shape: &crate::Shape,
) -> Result<super::MetalStorage> {
    use crate::backend::BackendDevice;
    let dtype = T::DTYPE;
    let elem_count = shape.elem_count();
    let name = match dtype {
        DType::U8 => "fill_u8",
        DType::U32 => "fill_u32",
        DType::I64 => "fill_i64",
        DType::F16 => "fill_f16",
        DType::BF16 => "fill_bf16",
        DType::F32 => "fill_f32",
        DType::F64 => {
            // Fill with the requested value `v`. The previous implementation delegated to
            // `CpuDevice.ones_impl`, which produced ones regardless of `v`.
            // `T` is the F64 dtype on this branch, so going through f64 loses nothing.
            let value = <T as crate::WithDType>::to_f64(v);
            let data: Vec<T> =
                std::iter::repeat_with(|| <T as crate::WithDType>::from_f64(value))
                    .take(elem_count)
                    .collect();
            let cpu_storage = <T as crate::WithDType>::to_cpu_storage_owned(data);
            return self.storage_from_cpu_storage(&cpu_storage);
        }
    };
    let buffer = self.new_buffer(elem_count, dtype, "alloc-ones")?;
    let command_buffer = self.command_buffer()?;
    candle_metal_kernels::call_const_fill(
        &self.device,
        &command_buffer,
        &self.kernels,
        name,
        elem_count,
        &buffer,
        v,
    )
    .map_err(MetalError::from)?;
    Ok(super::MetalStorage::new(
        buffer,
        self.clone(),
        elem_count,
        dtype,
    ))
}
}
fn buf_size(size: NSUInteger) -> NSUInteger {

View File

@ -1966,37 +1966,15 @@ impl BackendDevice for MetalDevice {
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
let name = match dtype {
DType::U8 => "fill_u8",
DType::U32 => "fill_u32",
DType::I64 => "fill_i64",
DType::F16 => "fill_f16",
DType::BF16 => "fill_bf16",
DType::F32 => "fill_f32",
DType::F64 => {
let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
return self.storage_from_cpu_storage(&cpu_storage);
}
};
let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-ones")?;
let command_buffer = self.command_buffer()?;
candle_metal_kernels::call_const_fill(
&self.device,
&command_buffer,
&self.kernels,
name,
shape.elem_count(),
&buffer,
1.,
)
.map_err(MetalError::from)?;
Ok(MetalStorage::new(
buffer,
self.clone(),
shape.elem_count(),
dtype,
))
match dtype {
DType::U8 => self.const_impl(1u8, shape),
DType::U32 => self.const_impl(1u32, shape),
DType::I64 => self.const_impl(1i64, shape),
DType::F16 => self.const_impl(half::f16::ONE, shape),
DType::BF16 => self.const_impl(half::bf16::ONE, shape),
DType::F32 => self.const_impl(1f32, shape),
DType::F64 => self.const_impl(1f64, shape),
}
}
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {

View File

@ -1,6 +1,74 @@
//! TensorScalar Enum and Trait
//!
use crate::{Result, Tensor, WithDType};
use crate::{DType, Result, Tensor, WithDType};
use half::{bf16, f16};
/// A single scalar value tagged with its element type.
///
/// Each variant carries a value of the primitive (or `half`) type it is named after.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Scalar {
    /// Unsigned 8-bit integer value.
    U8(u8),
    /// Unsigned 32-bit integer value.
    U32(u32),
    /// Signed 64-bit integer value.
    I64(i64),
    /// `half::bf16` floating-point value.
    BF16(bf16),
    /// `half::f16` floating-point value.
    F16(f16),
    /// 32-bit floating-point value.
    F32(f32),
    /// 64-bit floating-point value.
    F64(f64),
}
/// Converts any `WithDType` value into a [`Scalar`] by delegating to `WithDType::to_scalar`.
impl<T: WithDType> From<T> for Scalar {
    fn from(value: T) -> Self {
        T::to_scalar(value)
    }
}
impl Scalar {
pub fn zero(dtype: DType) -> Self {
match dtype {
DType::U8 => Scalar::U8(0),
DType::U32 => Scalar::U32(0),
DType::I64 => Scalar::I64(0),
DType::BF16 => Scalar::BF16(bf16::ZERO),
DType::F16 => Scalar::F16(f16::ZERO),
DType::F32 => Scalar::F32(0.0),
DType::F64 => Scalar::F64(0.0),
}
}
pub fn one(dtype: DType) -> Self {
match dtype {
DType::U8 => Scalar::U8(1),
DType::U32 => Scalar::U32(1),
DType::I64 => Scalar::I64(1),
DType::BF16 => Scalar::BF16(bf16::ONE),
DType::F16 => Scalar::F16(f16::ONE),
DType::F32 => Scalar::F32(1.0),
DType::F64 => Scalar::F64(1.0),
}
}
pub fn dtype(&self) -> DType {
match self {
Scalar::U8(_) => DType::U8,
Scalar::U32(_) => DType::U32,
Scalar::I64(_) => DType::I64,
Scalar::BF16(_) => DType::BF16,
Scalar::F16(_) => DType::F16,
Scalar::F32(_) => DType::F32,
Scalar::F64(_) => DType::F64,
}
}
pub fn to_f64(&self) -> f64 {
match self {
Scalar::U8(v) => *v as f64,
Scalar::U32(v) => *v as f64,
Scalar::I64(v) => *v as f64,
Scalar::BF16(v) => v.to_f64(),
Scalar::F16(v) => v.to_f64(),
Scalar::F32(v) => *v as f64,
Scalar::F64(v) => *v,
}
}
}
pub enum TensorScalar {
Tensor(Tensor),

View File

@ -12,6 +12,7 @@ license = "MIT OR Apache-2.0"
[dependencies]
metal = { version = "0.27.0", features = ["mps"] }
half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"] }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"

View File

@ -4,20 +4,20 @@ using namespace metal;
template<typename T> METAL_FUNC void fill_with(
device T *out,
constant float &value,
constant T &value,
constant size_t &numel,
uint tid [[thread_position_in_grid]]
) {
if (tid >= numel) {
return;
}
out[tid] = static_cast<T>(value);
out[tid] = value;
}
#define FILL_OP(NAME, T) \
kernel void fill_##NAME( \
device T *out, \
constant float &value, \
constant T &value, \
constant size_t &numel, \
uint tid [[thread_position_in_grid]] \
) { \

View File

@ -2570,7 +2570,7 @@ pub fn call_const_fill(
name: &'static str,
length: usize,
output: &Buffer,
v: f32,
v: impl EncoderParam,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
let encoder = ep.encoder();

View File

@ -2343,7 +2343,7 @@ fn conv_transpose1d_u32() {
#[test]
fn const_fill() {
fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> {
fn constant_fill<T: Clone + EncoderParam>(name: &'static str, len: usize, value: T) -> Vec<T> {
let dev = device();
let kernels = Kernels::new();
let command_queue = dev.new_command_queue();
@ -2357,11 +2357,15 @@ fn const_fill() {
command_buffer.wait_until_completed();
read_to_vec::<T>(&buffer, len)
}
fn test<T: Clone + PartialEq + std::fmt::Debug, F: FnOnce(f32) -> T>(name: &'static str, f: F) {
fn test<T: Clone + Copy + EncoderParam + PartialEq + std::fmt::Debug, F: FnOnce(f32) -> T>(
name: &'static str,
f: F,
) {
let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16);
let value = rand::thread_rng().gen_range(1. ..19.);
let value = f(value);
let v = constant_fill::<T>(name, len, value);
assert_eq!(v, vec![f(value); len])
assert_eq!(v, vec![value; len])
}
test::<u8, _>("fill_u8", |v| v as u8);
test::<u32, _>("fill_u32", |v| v as u32);

View File

@ -88,9 +88,13 @@ primitive!(bool);
primitive!(usize);
primitive!(i32);
primitive!(i64);
primitive!(u8);
primitive!(u32);
primitive!(u64);
primitive!(f32);
primitive!(f64);
primitive!(half::bf16);
primitive!(half::f16);
pub struct BufferOffset<'a> {
pub buffer: &'a Buffer,