//! Tensor ops.
//!

use candle::{CpuStorage, DType, Layout, Module, Result, Shape, Tensor, D};
use rayon::prelude::*;

/// Applies the softmax function to the input tensor, rescaling the elements so that all elements
/// on a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
///
/// ```rust
/// use candle::{Tensor, Device, test_utils::to_vec2_round};
/// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?;
/// let a = candle_nn::ops::softmax(&a, 1)?;
/// assert_eq!(
///     to_vec2_round(&a, 4)?,
///     &[
///         [0.1345, 0.3655, 0.1345, 0.3655],
///         [0.0049, 0.2671, 0.7262, 0.0018]
///     ]);
/// # Ok::<(), candle::Error>(())
/// ```
pub fn softmax<D: candle::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
    let dim = dim.to_index(xs.shape(), "softmax")?;
    let max = xs.max_keepdim(dim)?;
    let diff = xs.broadcast_sub(&max)?;
    let num = diff.exp()?;
    let den = num.sum_keepdim(dim)?;
    num.broadcast_div(&den)
}
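/// Applies the log-softmax function to the input tensor along dimension `d`, i.e. computes
/// `log(softmax(xs, d))` in a numerically stable way (subtracting the max before exponentiating).
///
/// A usage sketch mirroring the `softmax` example above (values rounded to 4 decimals):
///
/// ```rust
/// use candle::{Tensor, Device, test_utils::to_vec2_round};
/// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?;
/// let a = candle_nn::ops::log_softmax(&a, 1)?;
/// assert_eq!(
///     to_vec2_round(&a, 4)?,
///     &[
///         [-2.0064, -1.0064, -2.0064, -1.0064],
///         [-5.32, -1.32, -0.32, -6.32]
///     ]);
/// # Ok::<(), candle::Error>(())
/// ```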
pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
    let d = d.to_index(xs.shape(), "log-softmax")?;
    let max = xs.max_keepdim(d)?;
    let diff = xs.broadcast_sub(&max)?;
    let sum_exp = diff.exp()?.sum_keepdim(d)?;
    let log_sm = diff.broadcast_sub(&sum_exp.log()?)?;
    Ok(log_sm)
}

pub fn silu(xs: &Tensor) -> Result<Tensor> {
    xs.silu()
}

pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
    let xs = xs.chunk(2, D::Minus1)?;
    &xs[0].silu()? * &xs[1]
}

struct Sigmoid;

impl candle::CustomOp1 for Sigmoid {
    fn name(&self) -> &'static str {
        "sigmoid"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        use candle::backend::BackendStorage;

        fn fwd<T: num_traits::Float>(v: T) -> T {
            (v.neg().exp() + T::one()).recip()
        }

        // FIXME: using `candle::map_dtype` causes compilation errors.
        let storage = match storage {
            CpuStorage::BF16(slice) => {
                CpuStorage::BF16(candle::cpu_backend::unary_map(slice, layout, fwd))
            }
            CpuStorage::F16(slice) => {
                CpuStorage::F16(candle::cpu_backend::unary_map(slice, layout, fwd))
            }
            CpuStorage::F32(slice) => {
                CpuStorage::F32(candle::cpu_backend::unary_map(slice, layout, fwd))
            }
            CpuStorage::F64(slice) => {
                CpuStorage::F64(candle::cpu_backend::unary_map(slice, layout, fwd))
            }
            _ => Err(candle::Error::UnsupportedDTypeForOp(
                storage.dtype(),
                self.name(),
            ))?,
        };
        Ok((storage, layout.shape().clone()))
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &candle::CudaStorage,
        layout: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::backend::BackendStorage;
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
        };
        use candle::cuda_backend::SlicePtrOrNull;
        use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
        use candle::{CudaDevice, WithDType};

        struct S;
        impl Map1 for S {
            fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
                &self,
                src: &CudaSlice<T>,
                dev: &CudaDevice,
                layout: &Layout,
            ) -> Result<CudaSlice<T>> {
                let shape = layout.shape();
                let dims = shape.dims();
                let el_count = shape.elem_count();
                let cfg = LaunchConfig::for_num_elems(el_count as u32);
                let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
                let src = &src.slice(layout.start_offset()..);
                let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), kernels::UNARY)?;
                // SAFETY: Set later by running the kernel.
                let out = unsafe { dev.alloc::<T>(el_count) }.w()?;

                let params = (el_count, dims.len(), &ds, src, &out);
                // SAFETY: ffi.
                unsafe { func.launch(cfg, params) }.w()?;
                Ok(out)
            }
        }

        let dev = storage.device();
        let slice = S.map(&storage.slice, dev, layout)?;
        let dst = candle::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, layout.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &candle::MetalStorage,
        layout: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        use candle::MetalError;

        let device = storage.device();
        let dtype = storage.dtype();
        let shape = layout.shape();
        let el_count = shape.elem_count();
        let buffer = device.new_buffer(el_count, dtype, "sigmoid")?;
        let command_buffer = device.command_buffer()?;
        command_buffer.set_label("sigmoid");
        let src = candle_metal_kernels::BufferOffset {
            buffer: storage.buffer(),
            offset_in_bytes: layout.start_offset() * storage.dtype().size_in_bytes(),
        };

        match (el_count % 2, dtype, layout.is_contiguous()) {
            (0, DType::BF16 | DType::F16, true) => {
                use candle_metal_kernels::unary::contiguous_tiled;
                let kernel_name = match dtype {
                    DType::F16 => contiguous_tiled::sigmoid::HALF,
                    DType::F32 => contiguous_tiled::sigmoid::FLOAT,
                    DType::BF16 => contiguous_tiled::sigmoid::BFLOAT,
                    dtype => {
                        candle::bail!(
                            "Metal contiguous_tiled unary sigmoid {dtype:?} not implemented"
                        )
                    }
                };
                candle_metal_kernels::call_unary_contiguous_tiled(
                    device.metal_device(),
                    &command_buffer,
                    device.kernels(),
                    kernel_name,
                    el_count,
                    src,
                    &buffer,
                )
                .map_err(MetalError::from)?;
            }
            (_, _, true) => {
                use candle_metal_kernels::unary::contiguous;
                let kernel_name = match dtype {
                    DType::F16 => contiguous::sigmoid::HALF,
                    DType::F32 => contiguous::sigmoid::FLOAT,
                    DType::BF16 => contiguous::sigmoid::BFLOAT,
                    dtype => {
                        candle::bail!("Metal contiguous unary sigmoid {dtype:?} not implemented")
                    }
                };
                candle_metal_kernels::call_unary_contiguous(
                    device.metal_device(),
                    &command_buffer,
                    device.kernels(),
                    kernel_name,
                    el_count,
                    src,
                    &buffer,
                )
                .map_err(MetalError::from)?;
            }
            (_, _, false) => {
                use candle_metal_kernels::unary::strided;
                let kernel_name = match dtype {
                    DType::F16 => strided::sigmoid::HALF,
                    DType::F32 => strided::sigmoid::FLOAT,
                    DType::BF16 => strided::sigmoid::BFLOAT,
                    dtype => {
                        candle::bail!("Metal strided unary sigmoid {dtype:?} not implemented")
                    }
                };
                let dst = candle_metal_kernels::BufferOffset::zero_offset(&buffer);
                candle_metal_kernels::call_unary_strided(
                    device.metal_device(),
                    &command_buffer,
                    device.kernels(),
                    kernel_name,
                    layout.dims(),
                    src,
                    layout.stride(),
                    dst,
                )
                .map_err(MetalError::from)?;
            }
        }

        let new_storage = candle::MetalStorage::new(buffer, device.clone(), el_count, dtype);
        Ok((new_storage, layout.shape().clone()))
    }

    fn bwd(&self, _arg: &Tensor, res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
        // d/dx sigmoid(x) = (1 - sigmoid(x)) * sigmoid(x)
        let d_dx_sigmoid = res.ones_like()?.sub(res)?.mul(res)?;
        Ok(Some(grad_res.mul(&d_dx_sigmoid)?))
    }
}

pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
    xs.apply_op1(Sigmoid)
}

pub fn hard_sigmoid(xs: &Tensor) -> Result<Tensor> {
    // TODO: Should we have a specialized op for this?
    ((xs + 3.0)? / 6.0)?.clamp(0f32, 1f32)
}

pub fn leaky_relu(xs: &Tensor, negative_slope: f64) -> Result<Tensor> {
    let zeros = xs.zeros_like()?;
    xs.maximum(&zeros)? + xs.minimum(&zeros)? * negative_slope
}
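/// Applies (inverted) dropout: each element is zeroed out with probability `drop_p` and the
/// remaining elements are scaled by `1 / (1 - drop_p)` so that the expected value is preserved.
///
/// A usage sketch; the mask is random so only the shape is checked here:
///
/// ```rust
/// use candle::{Tensor, Device, DType};
/// let xs = Tensor::ones((2, 8), DType::F32, &Device::Cpu)?;
/// let ys = candle_nn::ops::dropout(&xs, 0.5)?;
/// assert_eq!(ys.dims(), &[2, 8]);
/// # Ok::<(), candle::Error>(())
/// ```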
pub fn dropout(xs: &Tensor, drop_p: f32) -> Result<Tensor> {
    // This implementation is inefficient as it stores the full mask for the backward pass.
    // Instead we could just store the seed and have a specialized kernel that would both
    // generate the random mask and apply it.
    // Another easier optimization would be to generate the boolean mask using just a bit of
    // entropy per element rather than generating a full float per element.
    if !(0. ..1.).contains(&drop_p) {
        candle::bail!("dropout probability has to be in [0, 1), got {drop_p}")
    }
    let rand = Tensor::rand(0f32, 1f32, xs.shape(), xs.device())?;
    let scale = 1.0 / (1.0 - drop_p as f64);
    let drop_p = Tensor::new(drop_p, xs.device())?.broadcast_as(xs.shape())?;
    let mask = (rand.ge(&drop_p)?.to_dtype(xs.dtype())? * scale)?;
    xs * mask
}
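/// Dropout as a layer: wraps the `dropout` function and only applies it when `train` is true, so
/// it can be plugged into models through `candle::ModuleT::forward_t`.
///
/// A usage sketch; in evaluation mode the input is passed through unchanged:
///
/// ```rust
/// use candle::{Tensor, Device, DType};
/// let drop = candle_nn::ops::Dropout::new(0.3);
/// let xs = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
/// let ys = drop.forward(&xs, /* train = */ false)?;
/// assert_eq!(ys.to_vec2::<f32>()?, xs.to_vec2::<f32>()?);
/// # Ok::<(), candle::Error>(())
/// ```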
#[derive(Clone, Debug)]
pub struct Dropout {
    drop_p: f32,
}

impl Dropout {
    pub fn new(drop_p: f32) -> Dropout {
        Self { drop_p }
    }

    pub fn forward(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
        if train {
            dropout(xs, self.drop_p)
        } else {
            Ok(xs.clone())
        }
    }
}

impl candle::ModuleT for Dropout {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
        self.forward(xs, train)
    }
}

struct SoftmaxLastDim;

impl candle::CustomOp1 for SoftmaxLastDim {
    fn name(&self) -> &'static str {
        "softmax-last-dim"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        fn softmax<T: candle::WithDType + num_traits::Float>(
            src: &[T],
            layout: &Layout,
        ) -> Result<(CpuStorage, Shape)> {
            let src = match layout.contiguous_offsets() {
                None => candle::bail!("input has to be contiguous"),
                Some((o1, o2)) => &src[o1..o2],
            };
            let el_count = layout.shape().elem_count();
            let dims = layout.shape().dims();
            let dim_m1 = dims[dims.len() - 1];
            let mut dst = vec![T::zero(); el_count];
            src.par_chunks(dim_m1)
                .zip(dst.par_chunks_mut(dim_m1))
                .for_each(|(src, dst)| {
                    let mut max = T::neg_infinity();
                    unsafe { T::vec_reduce_max(src.as_ptr(), &mut max, dim_m1) };
                    for (s, d) in src.iter().zip(dst.iter_mut()) {
                        *d = (*s - max).exp();
                    }
                    let mut sum_exp = T::zero();
                    unsafe { T::vec_reduce_sum(dst.as_ptr(), &mut sum_exp, dim_m1) };
                    for d in dst.iter_mut() {
                        *d /= sum_exp
                    }
                });
            let storage = candle::WithDType::to_cpu_storage_owned(dst);
            Ok((storage, Shape::from_dims(dims)))
        }

        match storage {
            CpuStorage::BF16(slice) => softmax::<half::bf16>(slice, layout),
            CpuStorage::F16(slice) => softmax::<half::f16>(slice, layout),
            CpuStorage::F32(slice) => softmax::<f32>(slice, layout),
            CpuStorage::F64(slice) => softmax::<f64>(slice, layout),
            _ => candle::bail!("unsupported dtype for softmax {:?}", storage),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &candle::CudaStorage,
        layout: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
        };
        use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
        use candle::{CudaDevice, WithDType};

        struct S;
        impl Map1 for S {
            fn f<T: DeviceRepr + WithDType>(
                &self,
                src: &CudaSlice<T>,
                dev: &CudaDevice,
                layout: &Layout,
            ) -> Result<CudaSlice<T>> {
                let src = match layout.contiguous_offsets() {
                    None => candle::bail!("input has to be contiguous"),
                    Some((o1, o2)) => src.slice(o1..o2),
                };
                let el = layout.shape().elem_count();
                let dims = layout.shape().dims();
                let dim_m1 = dims[dims.len() - 1];
                let (n_rows, n_cols) = (el / dim_m1, dim_m1);

                let cfg = LaunchConfig {
                    grid_dim: (n_rows as u32, 1, 1),
                    block_dim: (1, 32, 1),
                    shared_mem_bytes: 0,
                };
                let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), kernels::REDUCE)?;
                // SAFETY: Set later by running the kernel.
                let dst = unsafe { dev.alloc::<T>(el) }.w()?;
                let params = (&src, &dst, n_cols as i32);
                // SAFETY: ffi.
                unsafe { func.launch(cfg, params) }.w()?;
                Ok(dst)
            }
        }

        use candle::backend::BackendStorage;
        let dev = storage.device();
        let slice = S.map(&storage.slice, dev, layout)?;
        let dst = candle::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, layout.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &candle::MetalStorage,
        layout: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        let device = storage.device();
        let command_buffer = device.command_buffer()?;
        let kernels = device.kernels();
        let name = match storage.dtype() {
            DType::F32 => "softmax_f32",
            DType::F16 => "softmax_f16",
            DType::BF16 => "softmax_bf16",
            dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
        };

        let n = layout.stride().len();
        if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) {
            candle::bail!("Non contiguous softmax-last-dim is not implemented");
        }

        let last_dim = layout.dims()[layout.shape().rank() - 1];
        let elem_count = layout.shape().elem_count();
        let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
        candle_metal_kernels::call_last_softmax(
            device.metal_device(),
            &command_buffer,
            kernels,
            name,
            elem_count,
            last_dim,
            storage.buffer(),
            layout.start_offset() * storage.dtype().size_in_bytes(),
            &output,
        )
        .map_err(candle::Error::wrap)?;
        let newstorage =
            candle::MetalStorage::new(output, device.clone(), elem_count, storage.dtype());
        Ok((newstorage, layout.shape().clone()))
    }
}
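/// Applies a numerically stable softmax over the last dimension using the fused
/// `SoftmaxLastDim` custom op above (CPU, CUDA and Metal kernels). Equivalent to
/// `softmax(xs, D::Minus1)` but without the intermediate broadcast ops.
///
/// The example below mirrors the `softmax` doc above, so the expected values are the same:
///
/// ```rust
/// use candle::{Tensor, Device, test_utils::to_vec2_round};
/// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?;
/// let a = candle_nn::ops::softmax_last_dim(&a)?;
/// assert_eq!(
///     to_vec2_round(&a, 4)?,
///     &[
///         [0.1345, 0.3655, 0.1345, 0.3655],
///         [0.0049, 0.2671, 0.7262, 0.0018]
///     ]);
/// # Ok::<(), candle::Error>(())
/// ```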
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
    xs.apply_op1_no_bwd(&SoftmaxLastDim)
}

#[derive(Debug, Clone)]
struct RmsNorm {
    eps: f32,
}

impl candle::CustomOp2 for RmsNorm {
    fn name(&self) -> &'static str {
        "rms-norm"
    }

    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        use candle::backend::BackendStorage;
        let eps = self.eps;
        fn inner<
            T: candle::WithDType
                + num_traits::Float
                + num_traits::AsPrimitive<f32>
                + num_traits::FromPrimitive,
        >(
            src: &[T],
            layout: &Layout,
            alpha: &[T],
            alpha_layout: &Layout,
            eps: f32,
        ) -> Result<(CpuStorage, Shape)> {
            let src = match layout.contiguous_offsets() {
                None => candle::bail!("input has to be contiguous"),
                Some((o1, o2)) => &src[o1..o2],
            };
            let alpha = match alpha_layout.contiguous_offsets() {
                None => candle::bail!("alpha has to be contiguous"),
                Some((o1, o2)) => &alpha[o1..o2],
            };
            let el_count = layout.shape().elem_count();
            let dims = layout.shape().dims();
            let dim_m1 = dims[dims.len() - 1];
            let mut dst = vec![T::zero(); el_count];
            src.par_chunks(dim_m1)
                .zip(dst.par_chunks_mut(dim_m1))
                .for_each(|(src, dst)| {
                    let sum2 = src
                        .iter()
                        .map(|&v| {
                            let v = v.as_();
                            v * v
                        })
                        .sum::<f32>();
                    let m = (sum2 / dim_m1 as f32 + eps).sqrt();
                    let m = T::from_f32(m).unwrap_or_else(T::nan);
                    for ((d, s), alpha) in dst.iter_mut().zip(src.iter()).zip(alpha) {
                        *d = *s / m * *alpha
                    }
                });
            let storage = candle::WithDType::to_cpu_storage_owned(dst);
            Ok((storage, Shape::from_dims(dims)))
        }

        use CpuStorage as C;
        match (s1, s2) {
            (C::BF16(s1), C::BF16(s2)) => inner::<half::bf16>(s1, l1, s2, l2, eps),
            (C::F16(s1), C::F16(s2)) => inner::<half::f16>(s1, l1, s2, l2, eps),
            (C::F32(s1), C::F32(s2)) => inner::<f32>(s1, l1, s2, l2, eps),
            _ => candle::bail!("unsupported dtype for rmsnorm {:?}", s1.dtype()),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        s1: &candle::CudaStorage,
        l1: &Layout,
        s2: &candle::CudaStorage,
        l2: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
        };
        use candle::cuda_backend::{kernel_name, kernels, Map2, WrapErr};
        use candle::{CudaDevice, WithDType};

        struct S {
            eps: f32,
        }
        impl Map2 for S {
            fn f<T: DeviceRepr + WithDType>(
                &self,
                src: &CudaSlice<T>,
                layout: &Layout,
                alpha: &CudaSlice<T>,
                alpha_layout: &Layout,
                dev: &CudaDevice,
            ) -> Result<CudaSlice<T>> {
                let src = match layout.contiguous_offsets() {
                    None => candle::bail!("input has to be contiguous"),
                    Some((o1, o2)) => src.slice(o1..o2),
                };
                let alpha = match alpha_layout.contiguous_offsets() {
                    None => candle::bail!("alpha has to be contiguous"),
                    Some((o1, o2)) => alpha.slice(o1..o2),
                };
                let el = layout.shape().elem_count();
                let dims = layout.shape().dims();
                let dim_m1 = dims[dims.len() - 1];
                let (n_rows, n_cols) = (el / dim_m1, dim_m1);

                let block_size = if n_cols < 1024 { 32 } else { 1024 };
                let cfg = LaunchConfig {
                    grid_dim: (n_rows as u32, 1, 1),
                    block_dim: (block_size, 1, 1),
                    shared_mem_bytes: 0,
                };
                let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), kernels::REDUCE)?;
                // SAFETY: Set later by running the kernel.
                let dst = unsafe { dev.alloc::<T>(el) }.w()?;
                let params = (
                    &src,
                    &dst,
                    &alpha,
                    n_cols as i32,
                    block_size as i32,
                    self.eps,
                );
                // SAFETY: ffi.
                unsafe { func.launch(cfg, params) }.w()?;
                Ok(dst)
            }
        }

        use candle::backend::BackendStorage;
        let dev = s1.device();
        let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, dev)?;
        let dst = candle::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, l1.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        s1: &candle::MetalStorage,
        l1: &Layout,
        s2: &candle::MetalStorage,
        l2: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        let device = s1.device();
        let command_buffer = device.command_buffer()?;
        let kernels = device.kernels();
        let name = match (s1.dtype(), s2.dtype()) {
            (DType::F32, DType::F32) => "rmsnorm_f32",
            (DType::F16, DType::F16) => "rmsnorm_f16",
            (DType::BF16, DType::BF16) => "rmsnorm_bf16",
            (dt1, dt2) => candle::bail!("rmsnorm is not implemented for {dt1:?} {dt2:?}"),
        };

        if !(l1.is_contiguous() && l2.is_contiguous()) {
            candle::bail!("Non contiguous rmsnorm is not implemented");
        }

        let last_dim = l1.dims()[l1.shape().rank() - 1];
        let elem_count = l1.shape().elem_count();
        let output = device.new_buffer(elem_count, s1.dtype(), "rmsnorm")?;
        candle_metal_kernels::call_rms_norm(
            device.metal_device(),
            &command_buffer,
            kernels,
            name,
            elem_count,
            last_dim,
            self.eps,
            s1.buffer(),
            l1.start_offset() * s1.dtype().size_in_bytes(),
            s2.buffer(),
            l2.start_offset() * s2.dtype().size_in_bytes(),
            &output,
        )
        .map_err(candle::Error::wrap)?;
        let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
        Ok((newstorage, l1.shape().clone()))
    }
}
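/// Reference implementation of RMS normalization built from plain tensor ops:
/// `alpha * x / sqrt(mean(x^2, last_dim) + eps)`, with the reduction carried out in f32 for
/// f16/bf16 inputs. Slower than `rms_norm` but handy as a baseline to check the fused kernels
/// against.
///
/// A small usage sketch:
///
/// ```rust
/// use candle::{Tensor, Device, test_utils::to_vec2_round};
/// let xs = Tensor::new(&[[0f32, 4., 0., 4.]], &Device::Cpu)?;
/// let alpha = Tensor::new(&[0.5f32, 0.5, 0.5, 0.5], &Device::Cpu)?;
/// let ys = candle_nn::ops::rms_norm_slow(&xs, &alpha, 1e-5)?;
/// assert_eq!(to_vec2_round(&ys, 4)?, &[[0.0, 0.7071, 0.0, 0.7071]]);
/// # Ok::<(), candle::Error>(())
/// ```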
pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
    let x_dtype = x.dtype();
    let internal_dtype = match x_dtype {
        DType::F16 | DType::BF16 => DType::F32,
        d => d,
    };
    let hidden_size = x.dim(D::Minus1)?;
    let x = x.to_dtype(internal_dtype)?;
    let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
    let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
    x_normed.to_dtype(x_dtype)?.broadcast_mul(alpha)
}
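/// RMS normalization over the last dimension with a per-channel scale `alpha`, backed by the
/// fused `RmsNorm` kernels above. `alpha` must be a 1d tensor whose length matches the last
/// dimension of `xs`.
///
/// A small usage sketch:
///
/// ```rust
/// use candle::{Tensor, Device, test_utils::to_vec2_round};
/// let xs = Tensor::new(&[[0f32, 4., 0., 4.]], &Device::Cpu)?;
/// let alpha = Tensor::new(&[1f32, 1., 1., 1.], &Device::Cpu)?;
/// let ys = candle_nn::ops::rms_norm(&xs, &alpha, 1e-5)?;
/// assert_eq!(to_vec2_round(&ys, 4)?, &[[0.0, 1.4142, 0.0, 1.4142]]);
/// # Ok::<(), candle::Error>(())
/// ```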
pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
    let hidden_size_xs = xs.dim(D::Minus1)?;
    let hidden_size_alpha = alpha.dims1()?;
    if hidden_size_xs != hidden_size_alpha {
        candle::bail!(
            "shape mismatch in rms-norm {:?} {:?}",
            xs.shape(),
            alpha.shape()
        )
    }
    xs.apply_op2_no_bwd(alpha, &RmsNorm { eps })
}

#[derive(Debug, Clone)]
struct LayerNorm {
    eps: f32,
}

impl candle::CustomOp3 for LayerNorm {
    fn name(&self) -> &'static str {
        "layer-norm"
    }

    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
        s3: &CpuStorage,
        l3: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        use candle::backend::BackendStorage;
        let eps = self.eps;
        fn inner<
            T: candle::WithDType
                + num_traits::Float
                + num_traits::AsPrimitive<f32>
                + num_traits::FromPrimitive,
        >(
            src: &[T],
            layout: &Layout,
            alpha: &[T],
            alpha_layout: &Layout,
            beta: &[T],
            beta_layout: &Layout,
            eps: f32,
        ) -> Result<(CpuStorage, Shape)> {
            let src = match layout.contiguous_offsets() {
                None => candle::bail!("input has to be contiguous"),
                Some((o1, o2)) => &src[o1..o2],
            };
            let alpha = match alpha_layout.contiguous_offsets() {
                None => candle::bail!("alpha has to be contiguous"),
                Some((o1, o2)) => &alpha[o1..o2],
            };
            let beta = match beta_layout.contiguous_offsets() {
                None => candle::bail!("beta has to be contiguous"),
                Some((o1, o2)) => &beta[o1..o2],
            };
            let el_count = layout.shape().elem_count();
            let dims = layout.shape().dims();
            let dim_m1 = dims[dims.len() - 1];
            let mut dst = vec![T::zero(); el_count];
            src.par_chunks(dim_m1)
                .zip(dst.par_chunks_mut(dim_m1))
                .for_each(|(src, dst)| {
                    let mut sum = 0f32;
                    let mut sum2 = 0f32;
                    for v in src {
                        let v = v.as_();
                        sum += v;
                        sum2 += v * v;
                    }
                    let mean = sum / dim_m1 as f32;
                    let var = sum2 / dim_m1 as f32 - mean * mean;
                    let inv_std = (var + eps).sqrt().recip();
                    for ((d, s), (alpha, beta)) in
                        dst.iter_mut().zip(src.iter()).zip(alpha.iter().zip(beta))
                    {
                        let alpha = alpha.as_();
                        let beta = beta.as_();
                        let d_ = (s.as_() - mean) * inv_std * alpha + beta;
                        *d = T::from_f32(d_).unwrap_or_else(T::nan);
                    }
                });
            let storage = candle::WithDType::to_cpu_storage_owned(dst);
            Ok((storage, Shape::from_dims(dims)))
        }

        use CpuStorage as C;
        match (s1, s2, s3) {
            (C::BF16(s1), C::BF16(s2), C::BF16(s3)) => {
                inner::<half::bf16>(s1, l1, s2, l2, s3, l3, eps)
            }
            (C::F16(s1), C::F16(s2), C::F16(s3)) => inner::<half::f16>(s1, l1, s2, l2, s3, l3, eps),
            (C::F32(s1), C::F32(s2), C::F32(s3)) => inner::<f32>(s1, l1, s2, l2, s3, l3, eps),
            _ => candle::bail!("unsupported dtype for layernorm {:?}", s1.dtype()),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        s1: &candle::CudaStorage,
        l1: &Layout,
        s2: &candle::CudaStorage,
        l2: &Layout,
        s3: &candle::CudaStorage,
        l3: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
        };
        use candle::cuda_backend::{kernel_name, kernels, Map3, WrapErr};
        use candle::{CudaDevice, WithDType};

        struct S {
            eps: f32,
        }
        impl Map3 for S {
            fn f<T: DeviceRepr + WithDType>(
                &self,
                src: &CudaSlice<T>,
                layout: &Layout,
                alpha: &CudaSlice<T>,
                alpha_layout: &Layout,
                beta: &CudaSlice<T>,
                beta_layout: &Layout,
                dev: &CudaDevice,
            ) -> Result<CudaSlice<T>> {
                let src = match layout.contiguous_offsets() {
                    None => candle::bail!("input has to be contiguous"),
                    Some((o1, o2)) => src.slice(o1..o2),
                };
                let alpha = match alpha_layout.contiguous_offsets() {
                    None => candle::bail!("alpha has to be contiguous"),
                    Some((o1, o2)) => alpha.slice(o1..o2),
                };
                let beta = match beta_layout.contiguous_offsets() {
                    None => candle::bail!("beta has to be contiguous"),
                    Some((o1, o2)) => beta.slice(o1..o2),
                };
                let el = layout.shape().elem_count();
                let dims = layout.shape().dims();
                let dim_m1 = dims[dims.len() - 1];
                let (n_rows, n_cols) = (el / dim_m1, dim_m1);

                let block_size = if n_cols < 1024 { 32 } else { 1024 };
                let cfg = LaunchConfig {
                    grid_dim: (n_rows as u32, 1, 1),
                    block_dim: (block_size, 1, 1),
                    shared_mem_bytes: 0,
                };
                let func = dev.get_or_load_func(&kernel_name::<T>("layernorm"), kernels::REDUCE)?;
                // SAFETY: Set later by running the kernel.
                let dst = unsafe { dev.alloc::<T>(el) }.w()?;
                let params = (
                    &src,
                    &dst,
                    &alpha,
                    &beta,
                    n_cols as i32,
                    block_size as i32,
                    self.eps,
                );
                // SAFETY: ffi.
                unsafe { func.launch(cfg, params) }.w()?;
                Ok(dst)
            }
        }

        use candle::backend::BackendStorage;
        let dev = s1.device();
        let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, &s3.slice, l3, dev)?;
        let dst = candle::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, l1.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        s1: &candle::MetalStorage,
        l1: &Layout,
        s2: &candle::MetalStorage,
        l2: &Layout,
        s3: &candle::MetalStorage,
        l3: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        let device = s1.device();
        let command_buffer = device.command_buffer()?;
        let kernels = device.kernels();
        let name = match (s1.dtype(), s2.dtype(), s3.dtype()) {
            (DType::F32, DType::F32, DType::F32) => "layernorm_f32",
            (DType::F16, DType::F16, DType::F16) => "layernorm_f16",
            (DType::BF16, DType::BF16, DType::BF16) => "layernorm_bf16",
            (dt1, dt2, dt3) => {
                candle::bail!("layernorm is not implemented for {dt1:?} {dt2:?} {dt3:?}")
            }
        };

        if !(l1.is_contiguous() && l2.is_contiguous() && l3.is_contiguous()) {
            candle::bail!("Non contiguous layernorm is not implemented");
        }

        let last_dim = l1.dims()[l1.shape().rank() - 1];
        let elem_count = l1.shape().elem_count();
        let output = device.new_buffer(elem_count, s1.dtype(), "layernorm")?;
        candle_metal_kernels::call_layer_norm(
            device.metal_device(),
            &command_buffer,
            kernels,
            name,
            elem_count,
            last_dim,
            self.eps,
            s1.buffer(),
            l1.start_offset() * s1.dtype().size_in_bytes(),
            s2.buffer(),
            l2.start_offset() * s2.dtype().size_in_bytes(),
            s3.buffer(),
            l3.start_offset() * s3.dtype().size_in_bytes(),
            &output,
        )
        .map_err(candle::Error::wrap)?;
        let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
        Ok((newstorage, l1.shape().clone()))
    }
}

pub fn layer_norm_slow(x: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
    let x_dtype = x.dtype();
    let internal_dtype = match x_dtype {
        DType::F16 | DType::BF16 => DType::F32,
        d => d,
    };
    let hidden_size = x.dim(D::Minus1)?;
    let x = x.to_dtype(internal_dtype)?;
    let x = {
        let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
        x.broadcast_sub(&mean_x)?
    };
    let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
    let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
    x_normed
        .to_dtype(x_dtype)?
        .broadcast_mul(alpha)?
        .broadcast_add(beta)
}
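/// Layer normalization over the last dimension with a per-channel scale `alpha` and bias `beta`,
/// backed by the fused `LayerNorm` kernels above. Both `alpha` and `beta` must be 1d tensors
/// whose length matches the last dimension of `xs`.
///
/// A small usage sketch:
///
/// ```rust
/// use candle::{Tensor, Device, test_utils::to_vec2_round};
/// let xs = Tensor::new(&[[1f32, 3.]], &Device::Cpu)?;
/// let alpha = Tensor::new(&[1f32, 1.], &Device::Cpu)?;
/// let beta = Tensor::new(&[0f32, 0.], &Device::Cpu)?;
/// let ys = candle_nn::ops::layer_norm(&xs, &alpha, &beta, 1e-5)?;
/// assert_eq!(to_vec2_round(&ys, 4)?, &[[-1.0, 1.0]]);
/// # Ok::<(), candle::Error>(())
/// ```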
pub fn layer_norm(xs: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
    let hidden_size_xs = xs.dim(D::Minus1)?;
    let hidden_size_alpha = alpha.dims1()?;
    let hidden_size_beta = beta.dims1()?;
    if hidden_size_xs != hidden_size_alpha || hidden_size_xs != hidden_size_beta {
        candle::bail!(
            "shape mismatch in layer-norm src: {:?} alpha: {:?} beta: {:?}",
            xs.shape(),
            alpha.shape(),
            beta.shape()
        )
    }
    xs.apply_op3_no_bwd(alpha, beta, &LayerNorm { eps })
}

// https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> {
    let (b_size, c, h, w) = xs.dims4()?;
    let out_c = c / upscale_factor / upscale_factor;
    xs.reshape((b_size, out_c, upscale_factor, upscale_factor, h, w))?
        .permute((0, 1, 4, 2, 5, 3))?
        .reshape((b_size, out_c, h * upscale_factor, w * upscale_factor))
}
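/// The inverse of `pixel_shuffle`: rearranges each `downscale_factor x downscale_factor` spatial
/// block into extra channels, mapping `(b, c, h, w)` to
/// `(b, c * factor^2, h / factor, w / factor)`.
///
/// A usage sketch checking only the output shape:
///
/// ```rust
/// use candle::{Tensor, Device, DType};
/// let xs = Tensor::zeros((1, 3, 4, 6), DType::F32, &Device::Cpu)?;
/// let ys = candle_nn::ops::pixel_unshuffle(&xs, 2)?;
/// assert_eq!(ys.dims(), &[1, 12, 2, 3]);
/// # Ok::<(), candle::Error>(())
/// ```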
pub fn pixel_unshuffle(xs: &Tensor, downscale_factor: usize) -> Result<Tensor> {
    let (b_size, c, h, w) = xs.dims4()?;
    let out_c = c * downscale_factor * downscale_factor;
    xs.reshape((
        b_size,
        c,
        h / downscale_factor,
        downscale_factor,
        w / downscale_factor,
        downscale_factor,
    ))?
    .permute((0, 1, 3, 5, 2, 4))?
    .reshape((b_size, out_c, h / downscale_factor, w / downscale_factor))
}

// https://pytorch.org/docs/stable/generated/torch.nn.ReplicationPad2d.html
pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result<Tensor> {
    match pad {
        0 => Ok(xs.clone()),
        1 => {
            let (_b_size, _c, h, w) = xs.dims4()?;
            let (first, last) = (xs.narrow(3, 0, 1)?, xs.narrow(3, w - 1, 1)?);
            let xs = Tensor::cat(&[&first, xs, &last], 3)?;
            let (first, last) = (xs.narrow(2, 0, 1)?, xs.narrow(2, h - 1, 1)?);
            Tensor::cat(&[&first, &xs, &last], 2)
        }
        n => candle::bail!("replication-pad with a size of {n} is not supported"),
    }
}

#[derive(Clone, Debug)]
pub struct Identity;

impl Identity {
    pub fn new() -> Identity {
        Self
    }
}

impl Default for Identity {
    fn default() -> Self {
        Self
    }
}

impl Module for Identity {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        Ok(xs.clone())
    }
}

#[allow(dead_code)]
struct Sdpa {
    scale: f32,
    softcapping: f32,
}

impl candle::CustomOp3 for Sdpa {
    fn name(&self) -> &'static str {
        "metal-sdpa"
    }

    fn cpu_fwd(
        &self,
        _s1: &CpuStorage,
        _l1: &Layout,
        _s2: &CpuStorage,
        _l2: &Layout,
        _s3: &CpuStorage,
        _l3: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        candle::bail!("SDPA has no cpu impl")
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        q: &candle::MetalStorage,
        q_l: &Layout,
        k: &candle::MetalStorage,
        k_l: &Layout,
        v: &candle::MetalStorage,
        v_l: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        use candle_metal_kernels::SdpaDType;

        let device = q.device();

        let out_dims = vec![q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, v_l.dim(3)?];
        let elem_count: usize = out_dims.iter().product();

        let output = device.new_buffer(elem_count, q.dtype(), "sdpa_o")?;

        // q,k must have matching emb dim
        if q_l.dim(D::Minus1)? != k_l.dim(D::Minus1)? {
            candle::bail!("`q` and `k` last dims must match");
        }

        // k,v must have matching n kv heads
        if v_l.dim(D::Minus(3))? != k_l.dim(D::Minus(3))? {
            candle::bail!("`k` and `v` head dims must match");
        }

        // n_heads % n_kv_heads == 0; n_heads >= 1, n_kv_heads >= 1.
        if q_l.dim(D::Minus(3))? % k_l.dim(D::Minus(3))? != 0 {
            candle::bail!("query `n_heads` must be a multiple of `n_kv_heads`");
        }

        let k_head = k_l.dim(D::Minus1)?;
        let q_head = q_l.dim(D::Minus1)?;
        let q_seq = q_l.dim(2)?;

        let mut implementation_supports_use_case = q_head == k_head;
        let supported_head_dim =
            q_head == 32 || q_head == 64 || q_head == 96 || q_head == 128 || q_head == 256;

        const SDPA_FULL_THRESHOLD: usize = 2;

        let supports_sdpa_full =
            q_seq >= SDPA_FULL_THRESHOLD && supported_head_dim && q_head == k_head;
        let supports_sdpa_vector = q_seq == 1 && supported_head_dim;

        implementation_supports_use_case &= supports_sdpa_full || supports_sdpa_vector;

        if !supported_head_dim {
            candle::bail!(
                "Metal SDPA does not support q head dim {q_head}: q dims {:?}, k dims {:?}, v dims {:?}.",
                q_l.dims(),
                k_l.dims(),
                v_l.dims()
            );
        }
        if !implementation_supports_use_case {
            candle::bail!(
                "Metal SDPA does not support q dims {:?}, k dims {:?}, v dims {:?}.",
                q_l.dims(),
                k_l.dims(),
                v_l.dims()
            );
        }

        for t in [k.dtype(), v.dtype()] {
            if q.dtype() != t {
                candle::bail!("all q, k, v dtypes must match.");
            }
        }

        let itype = match q.dtype() {
            DType::BF16 => SdpaDType::BF16,
            DType::F16 => SdpaDType::F16,
            DType::F32 => SdpaDType::F32,
            other => candle::bail!("unsupported sdpa type {other:?}"),
        };

        let command_buffer = q.device().command_buffer()?;
        if supports_sdpa_vector {
            command_buffer.set_label("vector_attention");
            candle_metal_kernels::call_sdpa_vector(
                q.device().device(),
                &command_buffer,
                q.device().kernels(),
                q_l.start_offset(),
                q_l.dims(),
                q.buffer(),
                k_l.start_offset(),
                k_l.dims(),
                k_l.stride(),
                k.buffer(),
                v_l.start_offset(),
                v_l.stride(),
                v.buffer(),
                &output,
                self.scale,
                self.softcapping,
                itype,
            )
            .map_err(candle::Error::wrap)?;
        } else if supports_sdpa_full {
            if q_l.dim(2)? != k_l.dim(2)? {
                candle::bail!(
                    "query and key sequence length must be equal if using full metal sdpa"
                )
            }

            command_buffer.set_label("full_attention");
            candle_metal_kernels::call_sdpa_full(
                q.device().device(),
                &command_buffer,
                q.device().kernels(),
                q_l.start_offset(),
                q_l.dims(),
                q.buffer(),
                k_l.start_offset(),
                k.buffer(),
                v_l.start_offset(),
                v.buffer(),
                &output,
                self.scale,
                self.softcapping,
                itype,
            )
            .map_err(candle::Error::wrap)?;
        } else {
            candle::bail!("must be vector or full sdpa kernel");
        }

        let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, q.dtype());
        Ok((newstorage, Shape::from_dims(&out_dims)))
    }
}

/// Scaled dot product attention with a fused kernel.
///
/// Computes softmax(qk^T*scale)v.
///
/// **Input shapes:**
/// - `q`: (bs, qhead, seq, hidden)
/// - `k`: (bs, kv_head, kv_seq, hidden)
/// - `v`: (bs, kv_head, kv_seq, v_hidden)
/// - `scale` is applied before softmax.
/// - If `softcapping` != 1.0:
///   - Computation is: softmax(tanh(qk^T*scale/cap)*cap)v
///
/// **Output shape:** (bs, qhead, seq, v_hidden)
///
/// **Supported head dims:** 32, 64, 96, 128, 256.
///
/// ## On Metal:
/// - If `seq` == 1:
///   - Use a vectorized kernel
///   - Supports `seq` != `kv_seq` (cross attn. support)
///   - Supports GQA when `qhead` is a multiple of `kv_head`
/// - Otherwise:
///   - Use an alternate kernel
///   - Requires `seq` == `kv_seq`
///   - GQA is not supported (requires `qhead` == `kv_head`)
pub fn sdpa(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32, softcapping: f32) -> Result<Tensor> {
    q.apply_op3_no_bwd(k, v, &Sdpa { scale, softcapping })
}
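// A couple of illustrative sanity checks for the ops above. This is only a sketch of the
// invariants (softmax rows summing to one, sigmoid matching its closed form); the crate's
// broader test-suite lives outside this file.
#[cfg(test)]
mod sanity_checks {
    use super::*;
    use candle::Device;

    #[test]
    fn softmax_last_dim_rows_sum_to_one() -> Result<()> {
        let xs = Tensor::new(&[[0f32, 1., 2.], [3., 4., 5.]], &Device::Cpu)?;
        let ys = softmax_last_dim(&xs)?;
        for row_sum in ys.sum(D::Minus1)?.to_vec1::<f32>()? {
            assert!((row_sum - 1.0).abs() < 1e-5);
        }
        Ok(())
    }

    #[test]
    fn sigmoid_matches_closed_form() -> Result<()> {
        let xs = Tensor::new(&[-2f32, 0., 3.], &Device::Cpu)?;
        let ys = sigmoid(&xs)?.to_vec1::<f32>()?;
        for (x, y) in xs.to_vec1::<f32>()?.into_iter().zip(ys) {
            assert!((y - 1.0 / (1.0 + (-x).exp())).abs() < 1e-6);
        }
        Ok(())
    }
}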