Mirror of https://github.com/huggingface/candle.git, synced 2025-06-15 02:16:37 +00:00
Add an enum for scalar values. (#2909)

* Add a scalar enum type.
* Add a bit more to the scalar type.
* Small tweak.
* More scalar usage.
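As an orientation aid before the diff: the commit replaces the old convention of passing a constant as an (f64, DType) pair with a single Scalar enum that carries the value together with its dtype. The following standalone sketch (simplified names, only two dtypes, not the crate's actual definitions) models that idea: backend code matches on the variant and receives a value of the native Rust type instead of casting from f64 in every arm.

#[derive(Debug, Clone, Copy, PartialEq)]
enum DType { U32, F32 }

#[derive(Debug, Clone, Copy, PartialEq)]
enum Scalar { U32(u32), F32(f32) }

impl Scalar {
    // Mirrors Scalar::one(dtype) from the diff below, abridged to two dtypes.
    fn one(dtype: DType) -> Self {
        match dtype {
            DType::U32 => Scalar::U32(1),
            DType::F32 => Scalar::F32(1.0),
        }
    }
}

// Stand-in for a backend storage enum such as CudaStorageSlice.
#[derive(Debug, PartialEq)]
enum Storage { U32(Vec<u32>), F32(Vec<f32>) }

// Stand-in for const_impl: every arm already has `v` at the right type.
fn const_impl(v: Scalar, elem_count: usize) -> Storage {
    match v {
        Scalar::U32(v) => Storage::U32(vec![v; elem_count]),
        Scalar::F32(v) => Storage::F32(vec![v; elem_count]),
    }
}

fn main() {
    assert_eq!(const_impl(Scalar::one(DType::F32), 3), Storage::F32(vec![1.0; 3]));
    assert_eq!(const_impl(Scalar::one(DType::U32), 2), Storage::U32(vec![1, 1]));
}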
@@ -1,4 +1,5 @@
 use crate::backend::BackendDevice;
+use crate::scalar::Scalar;
 use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
 pub use candle_kernels as kernels;
 pub use cudarc;
@@ -188,83 +189,77 @@ impl CudaDevice {
         self.id
     }

-    fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
+    fn const_impl(&self, v: Scalar, shape: &Shape) -> Result<CudaStorage> {
         let elem_count = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(elem_count as u32);
-        let slice = match dtype {
-            DType::U8 => {
+        let slice = match v {
+            Scalar::U8(v) => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<u8>(elem_count)? };
                 let func = self.get_or_load_func("fill_u8", &kernels::FILL)?;
                 let mut builder = self.stream.launch_builder(&func);
-                let v = v as u8;
                 builder.arg(&data);
                 builder.arg(&v);
                 builder.arg(&elem_count);
                 unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::U8(data)
             }
-            DType::U32 => {
+            Scalar::U32(v) => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<u32>(elem_count)? };
                 let func = self.get_or_load_func("fill_u32", &kernels::FILL)?;
                 let mut builder = self.stream.launch_builder(&func);
-                let v = v as u32;
                 builder.arg(&data);
                 builder.arg(&v);
                 builder.arg(&elem_count);
                 unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::U32(data)
             }
-            DType::I64 => {
+            Scalar::I64(v) => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<i64>(elem_count)? };
                 let func = self.get_or_load_func("fill_i64", &kernels::FILL)?;
                 let mut builder = self.stream.launch_builder(&func);
-                let v = v as i64;
                 builder.arg(&data);
                 builder.arg(&v);
                 builder.arg(&elem_count);
                 unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::I64(data)
             }
-            DType::BF16 => {
+            Scalar::BF16(v) => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<bf16>(elem_count)? };
                 let func = self.get_or_load_func("fill_bf16", &kernels::FILL)?;
                 let mut builder = self.stream.launch_builder(&func);
-                let v = bf16::from_f64(v);
                 builder.arg(&data);
                 builder.arg(&v);
                 builder.arg(&elem_count);
                 unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::BF16(data)
             }
-            DType::F16 => {
+            Scalar::F16(v) => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<f16>(elem_count)? };
                 let func = self.get_or_load_func("fill_f16", &kernels::FILL)?;
                 let mut builder = self.stream.launch_builder(&func);
-                let v = f16::from_f64(v);
                 builder.arg(&data);
                 builder.arg(&v);
                 builder.arg(&elem_count);
                 unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::F16(data)
             }
-            DType::F32 => {
+            Scalar::F32(v) => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<f32>(elem_count)? };
                 let func = self.get_or_load_func("fill_f32", &kernels::FILL)?;
                 let mut builder = self.stream.launch_builder(&func);
-                let v = v as f32;
                 builder.arg(&data);
                 builder.arg(&v);
                 builder.arg(&elem_count);
                 unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::F32(data)
             }
-            DType::F64 => {
+            Scalar::F64(v) => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<f64>(elem_count) }?;
                 let func = self.get_or_load_func("fill_f64", &kernels::FILL)?;
@@ -505,7 +500,7 @@ impl BackendDevice for CudaDevice {
     }

     fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
-        self.const_impl(1., shape, dtype)
+        self.const_impl(Scalar::one(dtype), shape)
     }

     unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
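For context, a hedged usage sketch of how this CUDA path is reached from candle's public API (assuming a candle-core build with the cuda feature and a visible GPU): Tensor::ones routes through the backend's ones_impl, which after this change forwards Scalar::one(dtype) to const_impl.

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::new_cuda(0)?;
    // Tensor::ones reaches the CUDA backend's ones_impl, which now calls
    // const_impl(Scalar::one(DType::F16), shape).
    let ones = Tensor::ones((2, 3), DType::F16, &device)?;
    println!("{ones}");
    Ok(())
}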
@@ -107,6 +107,7 @@ pub trait WithDType:

     fn from_f64(v: f64) -> Self;
     fn to_f64(self) -> f64;
+    fn to_scalar(self) -> crate::scalar::Scalar;
     fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>;
     fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;

@@ -131,6 +132,10 @@ macro_rules! with_dtype {
                 $to_f64(self)
             }

+            fn to_scalar(self) -> crate::scalar::Scalar {
+                crate::scalar::Scalar::$dtype(self)
+            }
+
             fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> {
                 CpuStorageRef::$dtype(data)
             }
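The with_dtype! macro now also emits a to_scalar method for every supported element type. Below is a standalone, simplified sketch (not the crate's macro; the trait and enum are abridged) of how such a macro pairs a Rust type with the matching Scalar variant.

#[derive(Debug, PartialEq)]
enum Scalar { U8(u8), F32(f32) } // abridged

trait WithDType {
    fn to_scalar(self) -> Scalar;
}

// Each invocation generates the impl for one concrete type, naming the
// Scalar variant that wraps it.
macro_rules! with_dtype {
    ($ty:ty, $variant:ident) => {
        impl WithDType for $ty {
            fn to_scalar(self) -> Scalar {
                Scalar::$variant(self)
            }
        }
    };
}

with_dtype!(u8, U8);
with_dtype!(f32, F32);

fn main() {
    assert_eq!(1u8.to_scalar(), Scalar::U8(1));
    assert_eq!(2.5f32.to_scalar(), Scalar::F32(2.5));
}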
@@ -313,6 +313,46 @@ impl MetalDevice {
             .map_err(MetalError::from)?;
         Ok(())
     }

+    pub(crate) fn const_impl<T: crate::WithDType + candle_metal_kernels::utils::EncoderParam>(
+        &self,
+        v: T,
+        shape: &crate::Shape,
+    ) -> Result<super::MetalStorage> {
+        use crate::backend::BackendDevice;
+        let dtype = T::DTYPE;
+        let name = match dtype {
+            DType::U8 => "fill_u8",
+            DType::U32 => "fill_u32",
+            DType::I64 => "fill_i64",
+            DType::F16 => "fill_f16",
+            DType::BF16 => "fill_bf16",
+            DType::F32 => "fill_f32",
+            DType::F64 => {
+                let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
+                return self.storage_from_cpu_storage(&cpu_storage);
+            }
+        };
+        let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-ones")?;
+        let command_buffer = self.command_buffer()?;
+        candle_metal_kernels::call_const_fill(
+            &self.device,
+            &command_buffer,
+            &self.kernels,
+            name,
+            shape.elem_count(),
+            &buffer,
+            v,
+        )
+        .map_err(MetalError::from)?;
+
+        Ok(super::MetalStorage::new(
+            buffer,
+            self.clone(),
+            shape.elem_count(),
+            dtype,
+        ))
+    }
 }

 fn buf_size(size: NSUInteger) -> NSUInteger {
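The new Metal const_impl is generic over the element type, which exposes its dtype through an associated constant, so one function can pick the right "fill_*" kernel name while still passing `v` with its native type. A standalone sketch of that pattern (simplified; not the crate's WithDType or EncoderParam traits):

#[derive(Clone, Copy, Debug)]
enum DType { U8, F32 }

trait WithDType: Copy {
    const DTYPE: DType;
}
impl WithDType for u8 { const DTYPE: DType = DType::U8; }
impl WithDType for f32 { const DTYPE: DType = DType::F32; }

// Returns the chosen kernel name plus the filled data; the Vec is a stand-in
// for encoding a Metal fill kernel with `v` as its argument.
fn const_impl<T: WithDType>(v: T, elem_count: usize) -> (&'static str, Vec<T>) {
    let name = match T::DTYPE {
        DType::U8 => "fill_u8",
        DType::F32 => "fill_f32",
    };
    (name, vec![v; elem_count])
}

fn main() {
    let (name, data) = const_impl(1f32, 4);
    assert_eq!(name, "fill_f32");
    assert_eq!(data, vec![1.0; 4]);
}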
@@ -1966,37 +1966,15 @@ impl BackendDevice for MetalDevice {
     }

     fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
-        let name = match dtype {
-            DType::U8 => "fill_u8",
-            DType::U32 => "fill_u32",
-            DType::I64 => "fill_i64",
-            DType::F16 => "fill_f16",
-            DType::BF16 => "fill_bf16",
-            DType::F32 => "fill_f32",
-            DType::F64 => {
-                let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
-                return self.storage_from_cpu_storage(&cpu_storage);
-            }
-        };
-        let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-ones")?;
-        let command_buffer = self.command_buffer()?;
-        candle_metal_kernels::call_const_fill(
-            &self.device,
-            &command_buffer,
-            &self.kernels,
-            name,
-            shape.elem_count(),
-            &buffer,
-            1.,
-        )
-        .map_err(MetalError::from)?;
-
-        Ok(MetalStorage::new(
-            buffer,
-            self.clone(),
-            shape.elem_count(),
-            dtype,
-        ))
+        match dtype {
+            DType::U8 => self.const_impl(1u8, shape),
+            DType::U32 => self.const_impl(1u32, shape),
+            DType::I64 => self.const_impl(1i64, shape),
+            DType::F16 => self.const_impl(half::f16::ONE, shape),
+            DType::BF16 => self.const_impl(half::bf16::ONE, shape),
+            DType::F32 => self.const_impl(1f32, shape),
+            DType::F64 => self.const_impl(1f64, shape),
+        }
     }

     fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
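The Metal ones_impl is reduced to a dtype dispatch: each arm picks the typed constant one (1u8, half::f16::ONE, ...) and hands it to the generic const_impl above. The following standalone sketch (stand-in types, not MetalDevice/MetalStorage; requires the half crate as a dependency) models that dispatch pattern.

use half::{bf16, f16};

// Some variants are only here to mirror the dispatch shape.
#[allow(dead_code)]
#[derive(Clone, Copy)]
enum DType { U8, U32, I64, BF16, F16, F32, F64 }

// Stand-in for encoding a fill kernel with `v` as its argument.
fn const_impl<T: Copy>(v: T, len: usize) -> Vec<T> {
    vec![v; len]
}

// Returns just the element count so every arm has the same type in this toy
// example; the real code returns a MetalStorage.
fn ones_impl(dtype: DType, len: usize) -> usize {
    match dtype {
        DType::U8 => const_impl(1u8, len).len(),
        DType::U32 => const_impl(1u32, len).len(),
        DType::I64 => const_impl(1i64, len).len(),
        DType::BF16 => const_impl(bf16::ONE, len).len(),
        DType::F16 => const_impl(f16::ONE, len).len(),
        DType::F32 => const_impl(1f32, len).len(),
        DType::F64 => const_impl(1f64, len).len(),
    }
}

fn main() {
    assert_eq!(ones_impl(DType::F16, 8), 8);
}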
@@ -1,6 +1,74 @@
 //! TensorScalar Enum and Trait
 //!
-use crate::{Result, Tensor, WithDType};
+use crate::{DType, Result, Tensor, WithDType};
+use half::{bf16, f16};
+
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub enum Scalar {
+    U8(u8),
+    U32(u32),
+    I64(i64),
+    BF16(bf16),
+    F16(f16),
+    F32(f32),
+    F64(f64),
+}
+
+impl<T: WithDType> From<T> for Scalar {
+    fn from(value: T) -> Self {
+        value.to_scalar()
+    }
+}
+
+impl Scalar {
+    pub fn zero(dtype: DType) -> Self {
+        match dtype {
+            DType::U8 => Scalar::U8(0),
+            DType::U32 => Scalar::U32(0),
+            DType::I64 => Scalar::I64(0),
+            DType::BF16 => Scalar::BF16(bf16::ZERO),
+            DType::F16 => Scalar::F16(f16::ZERO),
+            DType::F32 => Scalar::F32(0.0),
+            DType::F64 => Scalar::F64(0.0),
+        }
+    }
+
+    pub fn one(dtype: DType) -> Self {
+        match dtype {
+            DType::U8 => Scalar::U8(1),
+            DType::U32 => Scalar::U32(1),
+            DType::I64 => Scalar::I64(1),
+            DType::BF16 => Scalar::BF16(bf16::ONE),
+            DType::F16 => Scalar::F16(f16::ONE),
+            DType::F32 => Scalar::F32(1.0),
+            DType::F64 => Scalar::F64(1.0),
+        }
+    }
+
+    pub fn dtype(&self) -> DType {
+        match self {
+            Scalar::U8(_) => DType::U8,
+            Scalar::U32(_) => DType::U32,
+            Scalar::I64(_) => DType::I64,
+            Scalar::BF16(_) => DType::BF16,
+            Scalar::F16(_) => DType::F16,
+            Scalar::F32(_) => DType::F32,
+            Scalar::F64(_) => DType::F64,
+        }
+    }
+
+    pub fn to_f64(&self) -> f64 {
+        match self {
+            Scalar::U8(v) => *v as f64,
+            Scalar::U32(v) => *v as f64,
+            Scalar::I64(v) => *v as f64,
+            Scalar::BF16(v) => v.to_f64(),
+            Scalar::F16(v) => v.to_f64(),
+            Scalar::F32(v) => *v as f64,
+            Scalar::F64(v) => *v,
+        }
+    }
+}
+
 pub enum TensorScalar {
     Tensor(Tensor),
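A hedged usage sketch for the new Scalar type, assuming it is reachable as candle_core::scalar::Scalar after this commit (the module path is inferred from the diff's crate::scalar::Scalar references; treat it as an assumption rather than documented API):

use candle_core::scalar::Scalar;
use candle_core::DType;

fn main() {
    // Scalar::one builds the typed "1" for a given dtype.
    let one = Scalar::one(DType::BF16);
    assert_eq!(one.dtype(), DType::BF16);
    assert_eq!(one.to_f64(), 1.0);

    // The blanket impl<T: WithDType> From<T> for Scalar converts plain Rust
    // numbers via WithDType::to_scalar.
    let s = Scalar::from(3u32);
    assert_eq!(s.dtype(), DType::U32);
    assert_eq!(s.to_f64(), 3.0);
}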