From a4c56a958e6151c6cb8cf4790d6b2595ff4e7809 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 19 Apr 2025 10:07:02 +0200 Subject: [PATCH] Add the const-set op. (#2910) * Add the const-set op. * Cuda implementation. * Bugfix. * Metal cleanup. * Add the metal kernels. * Add some testing. * Finish the metal implementation. * Bump the version. --- Cargo.toml | 18 ++-- candle-core/src/backend.rs | 4 +- candle-core/src/cpu_backend/mod.rs | 56 +++++++++---- candle-core/src/cuda_backend/device.rs | 95 +-------------------- candle-core/src/cuda_backend/mod.rs | 45 ++++++++++ candle-core/src/device.rs | 17 ---- candle-core/src/dummy_cuda_backend.rs | 8 +- candle-core/src/dummy_metal_backend.rs | 8 +- candle-core/src/metal_backend/device.rs | 40 --------- candle-core/src/metal_backend/mod.rs | 106 +++++++++++++++++++++--- candle-core/src/storage.rs | 9 ++ candle-core/src/tensor.rs | 16 +++- candle-core/tests/tensor_tests.rs | 31 ++++++- candle-flash-attn/Cargo.toml | 4 +- candle-kernels/Cargo.toml | 2 +- candle-kernels/src/fill.cu | 33 ++++++++ candle-metal-kernels/Cargo.toml | 2 +- candle-metal-kernels/src/lib.rs | 78 ++++++++++++++++- candle-metal-kernels/src/unary.metal | 45 ++++++++++ candle-onnx/Cargo.toml | 6 +- 20 files changed, 414 insertions(+), 209 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 316d9e75..ea643d3e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ exclude = [ resolver = "2" [workspace.package] -version = "0.9.0-alpha.4" +version = "0.9.0-alpha.5" edition = "2021" description = "Minimalist ML framework." repository = "https://github.com/huggingface/candle" @@ -33,14 +33,14 @@ ab_glyph = "0.2.23" accelerate-src = { version = "0.3.2" } anyhow = { version = "1", features = ["backtrace"] } byteorder = "1.4.3" -candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.4" } -candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.4" } -candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.4" } -candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.4" } -candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.4" } -candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.4" } -candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.4" } -candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.4" } +candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.5" } +candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.5" } +candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.5" } +candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.5" } +candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.5" } +candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.5" } +candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.5" } +candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.5" } clap = { version = "4.2.4", features = ["derive"] } criterion = { version = "0.5.1", default-features=false } cudarc = { version = "0.16.0", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false } diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index f98cb4f4..8ab59f4a 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -113,6 +113,8 @@ pub trait BackendStorage: Sized { 
         _src_offset: usize,
         _dst_offset: usize,
     ) -> Result<()>;
+
+    fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()>;
 }
 
 pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
@@ -127,8 +129,6 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
 
     fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
 
-    fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
-
     /// # Safety
     /// This function is unsafe as it doesn't initialize the underlying data store.
     /// The caller should ensure that the data is properly initialized as early as possible
diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs
index 7e4675f7..a405320c 100644
--- a/candle-core/src/cpu_backend/mod.rs
+++ b/candle-core/src/cpu_backend/mod.rs
@@ -2454,6 +2454,48 @@ impl BackendStorage for CpuStorage {
     fn to_cpu_storage(&self) -> Result<CpuStorage> {
         Ok(self.clone())
     }
+
+    fn const_set(&mut self, s: crate::scalar::Scalar, l: &Layout) -> Result<()> {
+        use crate::scalar::Scalar;
+        fn set<T: Copy>(src: &mut [T], l: &Layout, s: T) {
+            match l.strided_blocks() {
+                crate::StridedBlocks::SingleBlock { start_offset, len } => {
+                    src[start_offset..start_offset + len].fill(s)
+                }
+                crate::StridedBlocks::MultipleBlocks {
+                    block_start_index,
+                    block_len: 1,
+                } => {
+                    for src_index in block_start_index {
+                        src[src_index] = s
+                    }
+                }
+                crate::StridedBlocks::MultipleBlocks {
+                    block_start_index,
+                    block_len,
+                } => {
+                    for src_index in block_start_index {
+                        src[src_index..src_index + block_len].fill(s)
+                    }
+                }
+            }
+        }
+        match (self, s) {
+            (Self::BF16(storage), Scalar::BF16(v)) => set(storage, l, v),
+            (Self::F16(storage), Scalar::F16(v)) => set(storage, l, v),
+            (Self::F32(storage), Scalar::F32(v)) => set(storage, l, v),
+            (Self::F64(storage), Scalar::F64(v)) => set(storage, l, v),
+            (Self::U8(storage), Scalar::U8(v)) => set(storage, l, v),
+            (Self::U32(storage), Scalar::U32(v)) => set(storage, l, v),
+            (Self::I64(storage), Scalar::I64(v)) => set(storage, l, v),
+            (st, s) => crate::bail!(
+                "const_set dtype mismatch, expected {:?} but got {:?}",
+                st.dtype(),
+                s
+            ),
+        }
+        Ok(())
+    }
 }
 
 impl BackendDevice for CpuDevice {
@@ -2628,20 +2670,6 @@ impl BackendDevice for CpuDevice {
         Ok(storage)
     }
 
-    fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
-        let elem_count = shape.elem_count();
-        let storage = match dtype {
-            DType::U8 => CpuStorage::U8(vec![1u8; elem_count]),
-            DType::U32 => CpuStorage::U32(vec![1u32; elem_count]),
-            DType::I64 => CpuStorage::I64(vec![1i64; elem_count]),
-            DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]),
-            DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]),
-            DType::F32 => CpuStorage::F32(vec![1f32; elem_count]),
-            DType::F64 => CpuStorage::F64(vec![1f64; elem_count]),
-        };
-        Ok(storage)
-    }
-
     fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
         let elem_count = shape.elem_count();
         let storage = match dtype {
diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs
index 1d270116..ba3267e0 100644
--- a/candle-core/src/cuda_backend/device.rs
+++ b/candle-core/src/cuda_backend/device.rs
@@ -1,9 +1,8 @@
 use crate::backend::BackendDevice;
-use crate::scalar::Scalar;
 use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
 pub use candle_kernels as kernels;
 pub use cudarc;
-use cudarc::driver::{CudaFunction, LaunchConfig, PushKernelArg};
+use cudarc::driver::CudaFunction;
 use half::{bf16, f16};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
@@ -189,94 +188,6 @@ impl
CudaDevice { self.id } - fn const_impl(&self, v: Scalar, shape: &Shape) -> Result { - let elem_count = shape.elem_count(); - let cfg = LaunchConfig::for_num_elems(elem_count as u32); - let slice = match v { - Scalar::U8(v) => { - // SAFETY: Set later by running the fill kernel. - let data = unsafe { self.alloc::(elem_count)? }; - let func = self.get_or_load_func("fill_u8", &kernels::FILL)?; - let mut builder = self.stream.launch_builder(&func); - builder.arg(&data); - builder.arg(&v); - builder.arg(&elem_count); - unsafe { builder.launch(cfg) }.w()?; - CudaStorageSlice::U8(data) - } - Scalar::U32(v) => { - // SAFETY: Set later by running the fill kernel. - let data = unsafe { self.alloc::(elem_count)? }; - let func = self.get_or_load_func("fill_u32", &kernels::FILL)?; - let mut builder = self.stream.launch_builder(&func); - builder.arg(&data); - builder.arg(&v); - builder.arg(&elem_count); - unsafe { builder.launch(cfg) }.w()?; - CudaStorageSlice::U32(data) - } - Scalar::I64(v) => { - // SAFETY: Set later by running the fill kernel. - let data = unsafe { self.alloc::(elem_count)? }; - let func = self.get_or_load_func("fill_i64", &kernels::FILL)?; - let mut builder = self.stream.launch_builder(&func); - builder.arg(&data); - builder.arg(&v); - builder.arg(&elem_count); - unsafe { builder.launch(cfg) }.w()?; - CudaStorageSlice::I64(data) - } - Scalar::BF16(v) => { - // SAFETY: Set later by running the fill kernel. - let data = unsafe { self.alloc::(elem_count)? }; - let func = self.get_or_load_func("fill_bf16", &kernels::FILL)?; - let mut builder = self.stream.launch_builder(&func); - builder.arg(&data); - builder.arg(&v); - builder.arg(&elem_count); - unsafe { builder.launch(cfg) }.w()?; - CudaStorageSlice::BF16(data) - } - Scalar::F16(v) => { - // SAFETY: Set later by running the fill kernel. - let data = unsafe { self.alloc::(elem_count)? }; - let func = self.get_or_load_func("fill_f16", &kernels::FILL)?; - let mut builder = self.stream.launch_builder(&func); - builder.arg(&data); - builder.arg(&v); - builder.arg(&elem_count); - unsafe { builder.launch(cfg) }.w()?; - CudaStorageSlice::F16(data) - } - Scalar::F32(v) => { - // SAFETY: Set later by running the fill kernel. - let data = unsafe { self.alloc::(elem_count)? }; - let func = self.get_or_load_func("fill_f32", &kernels::FILL)?; - let mut builder = self.stream.launch_builder(&func); - builder.arg(&data); - builder.arg(&v); - builder.arg(&elem_count); - unsafe { builder.launch(cfg) }.w()?; - CudaStorageSlice::F32(data) - } - Scalar::F64(v) => { - // SAFETY: Set later by running the fill kernel. 
- let data = unsafe { self.alloc::(elem_count) }?; - let func = self.get_or_load_func("fill_f64", &kernels::FILL)?; - let mut builder = self.stream.launch_builder(&func); - builder.arg(&data); - builder.arg(&v); - builder.arg(&elem_count); - unsafe { builder.launch(cfg) }.w()?; - CudaStorageSlice::F64(data) - } - }; - Ok(CudaStorage { - slice, - device: self.clone(), - }) - } - pub fn get_or_load_custom_func( &self, fn_name: &str, @@ -499,10 +410,6 @@ impl BackendDevice for CudaDevice { }) } - fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result { - self.const_impl(Scalar::one(dtype), shape) - } - unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result { let elem_count = shape.elem_count(); let slice = match dtype { diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index bbbe5faf..00765af9 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -34,6 +34,21 @@ impl SlicePtrOrNull { } } +impl crate::scalar::Scalar { + pub fn builder_arg<'a, 'b: 'a>(&'b self, builder: &mut cudarc::driver::LaunchArgs<'a>) { + use crate::scalar::Scalar; + match self { + Scalar::U8(v) => builder.arg(v), + Scalar::U32(v) => builder.arg(v), + Scalar::I64(v) => builder.arg(v), + Scalar::F32(v) => builder.arg(v), + Scalar::F64(v) => builder.arg(v), + Scalar::F16(v) => builder.arg(v), + Scalar::BF16(v) => builder.arg(v), + }; + } +} + impl SlicePtrOrNull { pub fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result { let ds = if l.is_contiguous() { @@ -1235,6 +1250,36 @@ impl BackendStorage for CudaStorage { &self.device } + fn const_set(&mut self, s: crate::scalar::Scalar, layout: &Layout) -> Result<()> { + let dev = &self.device; + let shape = layout.shape(); + let dims = shape.dims(); + let el_count = shape.elem_count(); + let cfg = LaunchConfig::for_num_elems(el_count as u32); + let ds = SlicePtrOrNull::params_from_layout(dev, layout)?; + let src_o = layout.start_offset(); + let ((src, _guard_src), kernel_name) = match &mut self.slice { + S::U8(s) => (slice_ptr(s, src_o), "const_set_u8"), + S::U32(s) => (slice_ptr(s, src_o), "const_set_u32"), + S::I64(s) => (slice_ptr(s, src_o), "const_set_i64"), + S::BF16(s) => (slice_ptr(s, src_o), "const_set_bf16"), + S::F16(s) => (slice_ptr(s, src_o), "const_set_f16"), + S::F32(s) => (slice_ptr(s, src_o), "const_set_f32"), + S::F64(s) => (slice_ptr(s, src_o), "const_set_f64"), + }; + + let func = dev.get_or_load_func(kernel_name, &kernels::FILL)?; + let mut builder = func.builder(); + barg!(builder, el_count); + barg!(builder, dims.len()); + ds.builder_arg(&mut builder); + s.builder_arg(&mut builder); + barg!(builder, src); + // SAFETY: ffi. 
+ unsafe { builder.launch(cfg) }.w()?; + Ok(()) + } + fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result { let shape = layout.shape(); let dims = shape.dims(); diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 9b1fb9ee..130be7e0 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -292,23 +292,6 @@ impl Device { self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE) } - pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result { - match self { - Device::Cpu => { - let storage = CpuDevice.ones_impl(shape, dtype)?; - Ok(Storage::Cpu(storage)) - } - Device::Cuda(device) => { - let storage = device.ones_impl(shape, dtype)?; - Ok(Storage::Cuda(storage)) - } - Device::Metal(device) => { - let storage = device.ones_impl(shape, dtype)?; - Ok(Storage::Metal(storage)) - } - } - } - pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result { match self { Device::Cpu => { diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index 9d30d821..358081a0 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -37,6 +37,10 @@ impl crate::backend::BackendStorage for CudaStorage { fail!() } + fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()> { + Err(Error::NotCompiledWithCudaSupport) + } + fn to_cpu_storage(&self) -> Result { Err(Error::NotCompiledWithCudaSupport) } @@ -214,10 +218,6 @@ impl crate::backend::BackendDevice for CudaDevice { Err(Error::NotCompiledWithCudaSupport) } - fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result { - Err(Error::NotCompiledWithCudaSupport) - } - unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result { Err(Error::NotCompiledWithCudaSupport) } diff --git a/candle-core/src/dummy_metal_backend.rs b/candle-core/src/dummy_metal_backend.rs index a1c2394d..434e8d7b 100644 --- a/candle-core/src/dummy_metal_backend.rs +++ b/candle-core/src/dummy_metal_backend.rs @@ -41,6 +41,10 @@ impl crate::backend::BackendStorage for MetalStorage { fail!() } + fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()> { + Err(Error::NotCompiledWithMetalSupport) + } + fn to_cpu_storage(&self) -> Result { Err(Error::NotCompiledWithMetalSupport) } @@ -218,10 +222,6 @@ impl crate::backend::BackendDevice for MetalDevice { Err(Error::NotCompiledWithMetalSupport) } - fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result { - Err(Error::NotCompiledWithMetalSupport) - } - unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result { Err(Error::NotCompiledWithMetalSupport) } diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs index 38e5b528..43869a0c 100644 --- a/candle-core/src/metal_backend/device.rs +++ b/candle-core/src/metal_backend/device.rs @@ -313,46 +313,6 @@ impl MetalDevice { .map_err(MetalError::from)?; Ok(()) } - - pub(crate) fn const_impl( - &self, - v: T, - shape: &crate::Shape, - ) -> Result { - use crate::backend::BackendDevice; - let dtype = T::DTYPE; - let name = match dtype { - DType::U8 => "fill_u8", - DType::U32 => "fill_u32", - DType::I64 => "fill_i64", - DType::F16 => "fill_f16", - DType::BF16 => "fill_bf16", - DType::F32 => "fill_f32", - DType::F64 => { - let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?; - return self.storage_from_cpu_storage(&cpu_storage); - } - }; - let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-ones")?; - let command_buffer = 
self.command_buffer()?; - candle_metal_kernels::call_const_fill( - &self.device, - &command_buffer, - &self.kernels, - name, - shape.elem_count(), - &buffer, - v, - ) - .map_err(MetalError::from)?; - - Ok(super::MetalStorage::new( - buffer, - self.clone(), - shape.elem_count(), - dtype, - )) - } } fn buf_size(size: NSUInteger) -> NSUInteger { diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 92d267ce..e529c3f5 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -413,6 +413,100 @@ impl BackendStorage for MetalStorage { self.binary(name, rhs, lhs_l, rhs_l) } + fn const_set(&mut self, s: crate::scalar::Scalar, l: &Layout) -> Result<()> { + use crate::scalar::Scalar; + fn set( + self_: &mut MetalStorage, + s: S, + l: &Layout, + ) -> Result<()> { + let device = self_.device(); + let dtype = self_.dtype; + let shape = l.shape(); + let el_count = shape.elem_count(); + let command_buffer = device.command_buffer()?; + command_buffer.set_label("const-set"); + let dst = buffer_o(&self_.buffer, l, self_.dtype); + + match (el_count % 2, dtype, l.is_contiguous()) { + (0, DType::BF16 | DType::F16, true) => { + use candle_metal_kernels::unary::contiguous_tiled; + let kernel_name = match dtype { + DType::F16 => contiguous_tiled::const_set::HALF, + DType::BF16 => contiguous_tiled::const_set::BFLOAT, + _ => crate::bail!("internal bug in const_set"), + }; + candle_metal_kernels::call_const_set_contiguous_tiled( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + el_count, + s, + dst, + ) + .map_err(MetalError::from)?; + } + (_, _, true) => { + use candle_metal_kernels::unary::contiguous; + let kernel_name = match dtype { + DType::F16 => contiguous::const_set::HALF, + DType::BF16 => contiguous::const_set::BFLOAT, + DType::F32 => contiguous::const_set::FLOAT, + DType::I64 => contiguous::const_set::I64, + DType::U32 => contiguous::const_set::U32, + DType::U8 => contiguous::const_set::U8, + DType::F64 => crate::bail!("unsupported const-set f64"), + }; + candle_metal_kernels::call_const_set_contiguous( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + el_count, + s, + dst, + ) + .map_err(MetalError::from)?; + } + (_, _, false) => { + use candle_metal_kernels::unary::strided; + let kernel_name = match dtype { + DType::F16 => strided::const_set::HALF, + DType::BF16 => strided::const_set::BFLOAT, + DType::F32 => strided::const_set::FLOAT, + DType::I64 => strided::const_set::I64, + DType::U32 => strided::const_set::U32, + DType::U8 => strided::const_set::U8, + DType::F64 => crate::bail!("unsupported const-set f64"), + }; + candle_metal_kernels::call_const_set_strided( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + l.dims(), + s, + l.stride(), + dst, + ) + .map_err(MetalError::from)?; + } + } + Ok(()) + } + match (self.dtype, s) { + (DType::U8, Scalar::U8(s)) => set(self, s, l), + (DType::U32, Scalar::U32(s)) => set(self, s, l), + (DType::I64, Scalar::I64(s)) => set(self, s, l), + (DType::F16, Scalar::F16(s)) => set(self, s, l), + (DType::BF16, Scalar::BF16(s)) => set(self, s, l), + (DType::F32, Scalar::F32(s)) => set(self, s, l), + (DType::F64, Scalar::F64(s)) => set(self, s, l), + _ => crate::bail!("dtype mismatch, expected {:?}, got {:?}", self.dtype, s), + } + } + fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result { let device = self.device(); let shape = layout.shape(); @@ -1965,18 +2059,6 @@ impl BackendDevice for MetalDevice { )) } - fn 
ones_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
-        match dtype {
-            DType::U8 => self.const_impl(1u8, shape),
-            DType::U32 => self.const_impl(1u32, shape),
-            DType::I64 => self.const_impl(1i64, shape),
-            DType::F16 => self.const_impl(half::f16::ONE, shape),
-            DType::BF16 => self.const_impl(half::bf16::ONE, shape),
-            DType::F32 => self.const_impl(1f32, shape),
-            DType::F64 => self.const_impl(1f64, shape),
-        }
-    }
-
     fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
         let (count, buffer) = match T::cpu_storage_ref(s) {
             CpuStorageRef::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 8a0637e3..3148a00a 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -1,5 +1,6 @@
 use crate::backend::BackendStorage;
 use crate::op::{self, CmpOp, ReduceOp};
+use crate::scalar::Scalar;
 use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
 use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
 
@@ -73,6 +74,14 @@ impl Storage {
         }
     }
 
+    pub(crate) fn const_set(&mut self, v: Scalar, l: &Layout) -> Result<()> {
+        match self {
+            Storage::Cpu(storage) => storage.const_set(v, l),
+            Storage::Cuda(storage) => storage.const_set(v, l),
+            Storage::Metal(storage) => storage.const_set(v, l),
+        }
+    }
+
     pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
         match self {
             Storage::Cpu(storage) => {
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 3fdcbcc6..cd51ccbc 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -185,7 +185,9 @@ impl Tensor {
     ) -> Result<Self> {
         let none = BackpropOp::none();
         let shape = shape.into();
-        let storage = device.ones(&shape, dtype)?;
+        let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
+        let layout = Layout::contiguous(shape.clone());
+        storage.const_set(crate::scalar::Scalar::one(dtype), &layout)?;
         Ok(from_storage(storage, shape, none, is_variable))
     }
 
@@ -202,6 +204,18 @@ impl Tensor {
         Self::ones_impl(shape, dtype, device, false)
     }
 
+    pub fn const_set(&self, value: crate::scalar::Scalar) -> Result<()> {
+        self.storage_mut().const_set(value, self.layout())
+    }
+
+    pub fn zero_set(&self) -> Result<()> {
+        self.const_set(crate::scalar::Scalar::zero(self.dtype()))
+    }
+
+    pub fn one_set(&self) -> Result<()> {
+        self.const_set(crate::scalar::Scalar::one(self.dtype()))
+    }
+
     /// Creates a new tensor filled with ones with same shape, dtype, and device as the other tensor.
     ///
     /// ```rust
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 168012c5..7d33f9d7 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -25,10 +25,12 @@ fn ones(device: &Device) -> Result<()> {
         Tensor::ones((2, 3), DType::F32, device)?.to_vec2::<f32>()?,
         [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
     );
-    assert_eq!(
-        Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
-        [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
-    );
+    if !device.is_metal() {
+        assert_eq!(
+            Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
+            [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
+        );
+    }
     assert_eq!(
         Tensor::ones((2, 3), DType::F16, device)?.to_vec2::<half::f16>()?,
         [
@@ -63,6 +65,26 @@ fn ones(device: &Device) -> Result<()> {
 }
 
 fn full(device: &Device) -> Result<()> {
+    let tensor = Tensor::zeros((3, 4), DType::U32, device)?;
+    tensor.const_set(42u32.into())?;
+    assert_eq!(
+        tensor.to_vec2::<u32>()?,
+        [[42, 42, 42, 42], [42, 42, 42, 42], [42, 42, 42, 42]]
+    );
+    tensor.i((.., 2))?.const_set(1337u32.into())?;
+    assert_eq!(
+        tensor.to_vec2::<u32>()?,
+        [[42, 42, 1337, 42], [42, 42, 1337, 42], [42, 42, 1337, 42]]
+    );
+    tensor.i((2, ..))?.const_set(1u32.into())?;
+    assert_eq!(
+        tensor.to_vec2::<u32>()?,
+        [[42, 42, 1337, 42], [42, 42, 1337, 42], [1, 1, 1, 1]]
+    );
+    Ok(())
+}
+
+fn const_set(device: &Device) -> Result<()> {
     assert_eq!(
         Tensor::full(42u32, (2, 3), device)?.to_vec2::<u32>()?,
         [[42, 42, 42], [42, 42, 42]],
@@ -1509,6 +1531,7 @@ fn zero_dim(device: &Device) -> Result<()> {
 test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
 test_device!(ones, ones_cpu, ones_gpu, ones_metal);
 test_device!(full, full_cpu, full_gpu, full_metal);
+test_device!(const_set, cs_cpu, cs_gpu, cs_metal);
 test_device!(arange, arange_cpu, arange_gpu, arange_metal);
 test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
 test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
diff --git a/candle-flash-attn/Cargo.toml b/candle-flash-attn/Cargo.toml
index 40063ba9..ca46186f 100644
--- a/candle-flash-attn/Cargo.toml
+++ b/candle-flash-attn/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "candle-flash-attn"
-version = "0.9.0-alpha.4"
+version = "0.9.0-alpha.5"
 edition = "2021"
 
 description = "Flash attention layer for the candle ML framework."
@@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
 readme = "README.md"
 
 [dependencies]
-candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.4" }
+candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.5" }
 half = { version = "2.3.1", features = ["num-traits"] }
 
 [build-dependencies]
diff --git a/candle-kernels/Cargo.toml b/candle-kernels/Cargo.toml
index f786aaa4..c0860d0f 100644
--- a/candle-kernels/Cargo.toml
+++ b/candle-kernels/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "candle-kernels"
-version = "0.9.0-alpha.4"
+version = "0.9.0-alpha.5"
 edition = "2021"
 
 description = "CUDA kernels for Candle"
diff --git a/candle-kernels/src/fill.cu b/candle-kernels/src/fill.cu
index ca448d98..f9ab68fe 100644
--- a/candle-kernels/src/fill.cu
+++ b/candle-kernels/src/fill.cu
@@ -1,5 +1,6 @@
 #include<stdint.h>
 #include "cuda_fp16.h"
+#include "cuda_utils.cuh"
 
 template<typename T>
 __device__ void fill_with(T *buf, T value, const size_t numel) {
@@ -36,13 +37,45 @@ COPY2D_OP(uint8_t, copy2d_u8)
 COPY2D_OP(uint32_t, copy2d_u32)
 COPY2D_OP(int64_t, copy2d_i64)
 
+#define CONST_SET_OP(TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+    const size_t numel, \
+    const size_t num_dims, \
+    const size_t *info, \
+    const TYPENAME inp, \
+    TYPENAME *out \
+) { \
+    const size_t *dims = info; \
+    const size_t *strides = info + num_dims; \
+    if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+            out[i] = inp; \
+        } \
+    } \
+    else { \
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+            unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
+            out[strided_i] = inp; \
+        } \
+    } \
+} \
+
+CONST_SET_OP(float, const_set_f32)
+CONST_SET_OP(double, const_set_f64)
+CONST_SET_OP(uint8_t, const_set_u8)
+CONST_SET_OP(uint32_t, const_set_u32)
+CONST_SET_OP(int64_t, const_set_i64)
+
+
 #if __CUDA_ARCH__ >= 530
 extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
 COPY2D_OP(__half, copy2d_f16)
+CONST_SET_OP(__half, const_set_f16)
 #endif
 
 #if __CUDA_ARCH__ >= 800
 #include <cuda_bf16.h>
 extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); }
 COPY2D_OP(__nv_bfloat16, copy2d_bf16)
+CONST_SET_OP(__nv_bfloat16, const_set_bf16)
 #endif
diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml
index b00e7ca0..0e796968 100644
--- a/candle-metal-kernels/Cargo.toml
+++ b/candle-metal-kernels/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "candle-metal-kernels"
-version = "0.9.0-alpha.4"
+version = "0.9.0-alpha.5"
 edition = "2021"
 
 description = "Metal kernels for Candle"
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 2a898b54..be31f824 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -161,7 +161,7 @@ macro_rules!
ops{ pub mod unary { ops!( cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf, - tanh, recip, silu, sign, sigmoid + tanh, recip, silu, sign, sigmoid, const_set ); } pub mod binary { @@ -419,6 +419,82 @@ pub fn call_copy2d( Ok(()) } +#[allow(clippy::too_many_arguments)] +pub fn call_const_set_contiguous_tiled( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + kernel_name: unary::contiguous_tiled::Kernel, + length: usize, + input: impl EncoderParam, + output: BufferOffset, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + let tile_size = 2; + let tiles = length.div_ceil(tile_size); + + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (length, input, &output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles); + encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_const_set_contiguous( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + kernel_name: unary::contiguous::Kernel, + length: usize, + input: impl EncoderParam, + output: BufferOffset, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (length, input, &output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_const_set_strided( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + name: unary::strided::Kernel, + shape: &[usize], + input: impl EncoderParam, + strides: &[usize], + output: BufferOffset, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?; + + let length: usize = shape.iter().product(); + let num_dims: usize = shape.len(); + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + + encoder.set_compute_pipeline_state(&pipeline); + set_params!(encoder, (length, num_dims, shape, strides, input, &output)); + encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + #[allow(clippy::too_many_arguments)] pub fn call_unary_contiguous_tiled( device: &Device, diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index e3a18cfe..ae286f36 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -73,6 +73,44 @@ template METAL_FUNC T sigmoid(T in) { #define TILE_SIZE 2 +#define CONST_SET(TYPENAME, FN_NAME) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + constant TYPENAME &input, \ + device TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + if (tid >= dim) { \ + return; \ + } \ + output[tid] = input; \ +} \ +kernel void FN_NAME##_##strided( \ + constant 
size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant TYPENAME &input, \ + device TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + if (tid >= dim) { \ + return; \ + } \ + output[get_strided_index(tid, num_dims, dims, strides)] = input; \ +} \ +kernel void FN_NAME##_##tiled( \ + constant size_t &dim, \ + constant TYPENAME &input, \ + device TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + for (uint i = 0; i < TILE_SIZE; i++) { \ + const uint idx = tid * TILE_SIZE + i; \ + output[idx] = input; \ + } \ +} + #define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \ kernel void FN_NAME( \ constant size_t &dim, \ @@ -139,6 +177,11 @@ COPY2D(copy2d_f16, half) COPY2D(copy2d_u8, uint8_t) COPY2D(copy2d_u32, uint32_t) +CONST_SET(float, const_set_f32) +CONST_SET(half, const_set_f16) +CONST_SET(uint8_t, const_set_u8) +CONST_SET(uint32_t, const_set_u32) + UNARY_OP(cos) UNARY_OP(sin) UNARY_OP(sqr) @@ -171,6 +214,7 @@ UNARY(precise::tanh, half, tanh_f16, tanh_f16_strided); #if __METAL_VERSION__ >= 220 UNARY(id, int64_t, copy_i64, copy_i64_strided) COPY2D(copy2d_i64, int64_t) +CONST_SET(int64_t, const_set_i64) #endif #if defined(__HAVE_BFLOAT__) @@ -199,4 +243,5 @@ UNARY(id, bfloat, copy_bf16, copy_bf16_strided) UNARY(precise::tanh, bfloat, tanh_bf16, tanh_bf16_strided); COPY2D(copy2d_bf16, bfloat) +CONST_SET(bfloat, const_set_bf16) #endif diff --git a/candle-onnx/Cargo.toml b/candle-onnx/Cargo.toml index 6954257d..ea2c39d1 100644 --- a/candle-onnx/Cargo.toml +++ b/candle-onnx/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "candle-onnx" -version = "0.9.0-alpha.4" +version = "0.9.0-alpha.5" edition = "2021" description = "ONNX support for Candle" @@ -10,8 +10,8 @@ categories = ["science"] license = "MIT OR Apache-2.0" [dependencies] -candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.4" } -candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.4" } +candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.5" } +candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.5" } prost = "0.12.1" [build-dependencies]
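
Usage note: the snippet below is a minimal sketch of the in-place API introduced by this patch (Tensor::const_set, Tensor::zero_set, Tensor::one_set), mirroring the new tests in candle-core/tests/tensor_tests.rs. The CPU device, the main wrapper, and the exact assertions are illustrative only and are not part of the change.

use candle_core::{DType, Device, IndexOp, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;

    // Allocate a tensor once, then overwrite its contents in place with a constant.
    let t = Tensor::zeros((3, 4), DType::U32, &device)?;
    t.const_set(42u32.into())?; // plain values convert to a Scalar via `.into()`
    assert_eq!(t.to_vec2::<u32>()?, [[42u32; 4]; 3]);

    // const_set follows the layout of a view, so a slice can be filled in place
    // (the strided kernels handle the non-contiguous case).
    t.i((.., 2))?.const_set(1337u32.into())?;
    assert_eq!(t.i((0, 2))?.to_scalar::<u32>()?, 1337);

    // zero_set / one_set are shorthands for the matching scalar of the tensor's dtype.
    t.zero_set()?;
    t.one_set()?;
    assert_eq!(t.to_vec2::<u32>()?, [[1u32; 4]; 3]);
    Ok(())
}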