//! Tensor ops.
//!
use candle::{CpuStorage, DType, Layout, Module, Result, Shape, Tensor, D};
use rayon::prelude::*;
/// Applies the softmax function to the input tensor, rescaling the elements so that elements on
/// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
///
/// ```rust
/// use candle::{Tensor, Device, test_utils::to_vec2_round};
/// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?;
/// let a = candle_nn::ops::softmax(&a, 1)?;
/// assert_eq!(
///     to_vec2_round(&a, 4)?,
///     &[
///         [0.1345, 0.3655, 0.1345, 0.3655],
///         [0.0049, 0.2671, 0.7262, 0.0018]
///     ]);
/// # Ok::<(), candle::Error>(())
/// ```
pub fn softmax<D: candle::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
    let dim = dim.to_index(xs.shape(), "softmax")?;
    let max = xs.max_keepdim(dim)?;
    let diff = xs.broadcast_sub(&max)?;
    let num = diff.exp()?;
    let den = num.sum_keepdim(dim)?;
    num.broadcast_div(&den)
}
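/// Computes the log-softmax of the input tensor along dimension `d`, i.e. the logarithm of
/// [`softmax`] on that dimension, evaluated in a numerically stable way.
///
/// The example below is a small illustrative sketch: exponentiating the result recovers the
/// regular softmax (up to rounding).
///
/// ```rust
/// use candle::{Tensor, Device, test_utils::to_vec2_round};
/// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?;
/// let log_sm = candle_nn::ops::log_softmax(&a, 1)?;
/// let sm = candle_nn::ops::softmax(&a, 1)?;
/// assert_eq!(to_vec2_round(&log_sm.exp()?, 4)?, to_vec2_round(&sm, 4)?);
/// # Ok::<(), candle::Error>(())
/// ```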
pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
    let d = d.to_index(xs.shape(), "log-softmax")?;
    let max = xs.max_keepdim(d)?;
    let diff = xs.broadcast_sub(&max)?;
    let sum_exp = diff.exp()?.sum_keepdim(d)?;
    let log_sm = diff.broadcast_sub(&sum_exp.log()?)?;
    Ok(log_sm)
}
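/// Applies the SiLU (sigmoid linear unit) activation, `silu(x) = x * sigmoid(x)`, element-wise.
/// This is a thin wrapper around [`Tensor::silu`].
///
/// A small illustrative sketch checking the defining identity against [`sigmoid`]:
///
/// ```rust
/// use candle::{Tensor, Device, test_utils::to_vec1_round};
/// let xs = Tensor::new(&[-2f32, 0., 3.], &Device::Cpu)?;
/// let ys = candle_nn::ops::silu(&xs)?;
/// let expected = (&xs * candle_nn::ops::sigmoid(&xs)?)?;
/// assert_eq!(to_vec1_round(&ys, 4)?, to_vec1_round(&expected, 4)?);
/// # Ok::<(), candle::Error>(())
/// ```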
pub fn silu(xs: &Tensor) -> Result<Tensor> {
    xs.silu()
}
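/// SwiGLU activation: the input is split into two halves along the last dimension and the
/// result is `silu(a) * b`, so the output's last dimension is half of the input's.
///
/// A minimal shape-oriented sketch (the input values here are arbitrary):
///
/// ```rust
/// use candle::{Tensor, Device};
/// let xs = Tensor::randn(0f32, 1., (2, 8), &Device::Cpu)?;
/// let ys = candle_nn::ops::swiglu(&xs)?;
/// assert_eq!(ys.dims(), &[2, 4]);
/// # Ok::<(), candle::Error>(())
/// ```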
pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
    let xs = xs.chunk(2, D::Minus1)?;
    &xs[0].silu()? * &xs[1]
}
struct Sigmoid;

impl candle::CustomOp1 for Sigmoid {
    fn name(&self) -> &'static str {
        "sigmoid"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        use candle::backend::BackendStorage;

        fn fwd<T: num_traits::Float>(v: T) -> T {
            (v.neg().exp() + T::one()).recip()
        }

        // FIXME: using `candle::map_dtype` causes compilation errors.
        let storage = match storage {
            CpuStorage::BF16(slice) => {
                CpuStorage::BF16(candle::cpu_backend::unary_map(slice, layout, fwd))
            }
            CpuStorage::F16(slice) => {
                CpuStorage::F16(candle::cpu_backend::unary_map(slice, layout, fwd))
            }
            CpuStorage::F32(slice) => {
                CpuStorage::F32(candle::cpu_backend::unary_map(slice, layout, fwd))
            }
            CpuStorage::F64(slice) => {
                CpuStorage::F64(candle::cpu_backend::unary_map(slice, layout, fwd))
            }
            _ => Err(candle::Error::UnsupportedDTypeForOp(
                storage.dtype(),
                self.name(),
            ))?,
        };
        Ok((storage, layout.shape().clone()))
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &candle::CudaStorage,
        layout: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::backend::BackendStorage;
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
        };
        use candle::cuda_backend::SlicePtrOrNull;
        use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
        use candle::{CudaDevice, WithDType};

        struct S;
        impl Map1 for S {
            fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
                &self,
                src: &CudaSlice<T>,
                dev: &CudaDevice,
                layout: &Layout,
            ) -> Result<CudaSlice<T>> {
                let shape = layout.shape();
                let dims = shape.dims();
                let el_count = shape.elem_count();
                let cfg = LaunchConfig::for_num_elems(el_count as u32);
                let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
                let src = &src.slice(layout.start_offset()..);
                let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), kernels::UNARY)?;
                // SAFETY: Set later by running the kernel.
                let out = unsafe { dev.alloc::<T>(el_count) }.w()?;

                let params = (el_count, dims.len(), &ds, src, &out);
                // SAFETY: ffi.
                unsafe { func.launch(cfg, params) }.w()?;
                Ok(out)
            }
        }

        let dev = storage.device();
        let slice = S.map(&storage.slice, dev, layout)?;
        let dst = candle::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, layout.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &candle::MetalStorage,
        layout: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        use candle::MetalError;
        let device = storage.device();
        let dtype = storage.dtype();
        let shape = layout.shape();
        let el_count = shape.elem_count();
        let buffer = device.new_buffer(el_count, dtype, "sigmoid")?;
        let command_buffer = device.command_buffer()?;
        command_buffer.set_label("sigmoid");
        let src = candle_metal_kernels::BufferOffset {
            buffer: storage.buffer(),
            offset_in_bytes: layout.start_offset() * storage.dtype().size_in_bytes(),
        };

        match (el_count % 2, dtype, layout.is_contiguous()) {
            (0, DType::BF16 | DType::F16, true) => {
                use candle_metal_kernels::unary::contiguous_tiled;
                let kernel_name = match dtype {
                    DType::F16 => contiguous_tiled::sigmoid::HALF,
                    DType::F32 => contiguous_tiled::sigmoid::FLOAT,
                    DType::BF16 => contiguous_tiled::sigmoid::BFLOAT,
                    dtype => {
                        candle::bail!(
                            "Metal contiguous_tiled unary sigmoid {dtype:?} not implemented"
                        )
                    }
                };
                candle_metal_kernels::call_unary_contiguous_tiled(
                    device.metal_device(),
                    &command_buffer,
                    device.kernels(),
                    kernel_name,
                    el_count,
                    src,
                    &buffer,
                )
                .map_err(MetalError::from)?;
            }
            (_, _, true) => {
                use candle_metal_kernels::unary::contiguous;
                let kernel_name = match dtype {
                    DType::F16 => contiguous::sigmoid::HALF,
                    DType::F32 => contiguous::sigmoid::FLOAT,
                    DType::BF16 => contiguous::sigmoid::BFLOAT,
                    dtype => {
                        candle::bail!("Metal contiguous unary sigmoid {dtype:?} not implemented")
                    }
                };
                candle_metal_kernels::call_unary_contiguous(
                    device.metal_device(),
                    &command_buffer,
                    device.kernels(),
                    kernel_name,
                    el_count,
                    src,
                    &buffer,
                )
                .map_err(MetalError::from)?;
            }
            (_, _, false) => {
                use candle_metal_kernels::unary::strided;
                let kernel_name = match dtype {
                    DType::F16 => strided::sigmoid::HALF,
                    DType::F32 => strided::sigmoid::FLOAT,
                    DType::BF16 => strided::sigmoid::BFLOAT,
                    dtype => {
                        candle::bail!("Metal strided unary sigmoid {dtype:?} not implemented")
                    }
                };
                let dst = candle_metal_kernels::BufferOffset::zero_offset(&buffer);
                candle_metal_kernels::call_unary_strided(
                    device.metal_device(),
                    &command_buffer,
                    device.kernels(),
                    kernel_name,
                    layout.dims(),
                    src,
                    layout.stride(),
                    dst,
                )
                .map_err(MetalError::from)?;
            }
        }

        let new_storage = candle::MetalStorage::new(buffer, device.clone(), el_count, dtype);
        Ok((new_storage, layout.shape().clone()))
    }

    fn bwd(&self, _arg: &Tensor, res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
        // d/dx sigmoid(x) = (1 - sigmoid(x)) * sigmoid(x)
        let d_dx_sigmoid = res.ones_like()?.sub(res)?.mul(res)?;
        Ok(Some(grad_res.mul(&d_dx_sigmoid)?))
    }
}
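/// Applies the element-wise sigmoid function, `1 / (1 + exp(-x))`, implemented as a custom op
/// with a backward pass.
///
/// A small illustrative sketch (values rounded to 4 decimals):
///
/// ```rust
/// use candle::{Tensor, Device, test_utils::to_vec1_round};
/// let xs = Tensor::new(&[0f32, 10., -10.], &Device::Cpu)?;
/// let ys = candle_nn::ops::sigmoid(&xs)?;
/// assert_eq!(to_vec1_round(&ys, 4)?, &[0.5, 1.0, 0.0]);
/// # Ok::<(), candle::Error>(())
/// ```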
pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
    xs.apply_op1(Sigmoid)
}
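/// Hard sigmoid: the piecewise-linear approximation `clamp((x + 3) / 6, 0, 1)`.
///
/// A small illustrative sketch:
///
/// ```rust
/// use candle::{Tensor, Device, test_utils::to_vec1_round};
/// let xs = Tensor::new(&[-4f32, 0., 4.], &Device::Cpu)?;
/// let ys = candle_nn::ops::hard_sigmoid(&xs)?;
/// assert_eq!(to_vec1_round(&ys, 4)?, &[0.0, 0.5, 1.0]);
/// # Ok::<(), candle::Error>(())
/// ```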
pub fn hard_sigmoid(xs: &Tensor) -> Result<Tensor> {
    // TODO: Should we have a specialized op for this?
    ((xs + 3.0)? / 6.0)?.clamp(0f32, 1f32)
}
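/// Leaky ReLU: the identity for positive values and `negative_slope * x` for negative values.
///
/// A small illustrative sketch with a slope of 0.1:
///
/// ```rust
/// use candle::{Tensor, Device, test_utils::to_vec1_round};
/// let xs = Tensor::new(&[-2f32, 0., 3.], &Device::Cpu)?;
/// let ys = candle_nn::ops::leaky_relu(&xs, 0.1)?;
/// assert_eq!(to_vec1_round(&ys, 4)?, &[-0.2, 0.0, 3.0]);
/// # Ok::<(), candle::Error>(())
/// ```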
pub fn leaky_relu(xs: &Tensor, negative_slope: f64) -> Result<Tensor> {
    let zeros = xs.zeros_like()?;
    xs.maximum(&zeros)? + xs.minimum(&zeros)? * negative_slope
}
pub fn dropout(xs: &Tensor, drop_p: f32) -> Result<Tensor> {
    // This implementation is inefficient as it stores the full mask for the backward pass.
    // Instead we could just store the seed and have a specialized kernel that would both
    // generate the random mask and apply it.
    // Another easier optimization would be to be able to generate a boolean mask using just a bit
    // of entropy per element rather than generating a full float per element.
    if !(0. ..1.).contains(&drop_p) {
        candle::bail!("dropout probability has to be in [0, 1), got {drop_p}")
    }
    let rand = Tensor::rand(0f32, 1f32, xs.shape(), xs.device())?;
    let scale = 1.0 / (1.0 - drop_p as f64);
    let drop_p = Tensor::new(drop_p, xs.device())?.broadcast_as(xs.shape())?;
    let mask = (rand.ge(&drop_p)?.to_dtype(xs.dtype())? * scale)?;
    xs * mask
}
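/// Dropout layer wrapping the [`dropout`] function: at training time each element is zeroed
/// with probability `drop_p` and the remaining elements are rescaled by `1 / (1 - drop_p)`; at
/// evaluation time it is the identity.
///
/// A small usage sketch, only exercising the evaluation path (`train = false`) so that the
/// output is deterministic:
///
/// ```rust
/// use candle::{Tensor, Device, test_utils::to_vec2_round};
/// let xs = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
/// let drop = candle_nn::ops::Dropout::new(0.5);
/// let ys = drop.forward(&xs, false)?;
/// assert_eq!(to_vec2_round(&ys, 4)?, to_vec2_round(&xs, 4)?);
/// # Ok::<(), candle::Error>(())
/// ```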
#[derive(Clone, Debug)]
pub struct Dropout {
    drop_p: f32,
}

impl Dropout {
    pub fn new(drop_p: f32) -> Dropout {
        Self { drop_p }
    }

    pub fn forward(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
        if train {
            dropout(xs, self.drop_p)
        } else {
            Ok(xs.clone())
        }
    }
}

impl candle::ModuleT for Dropout {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
        self.forward(xs, train)
    }
}
struct SoftmaxLastDim;

impl candle::CustomOp1 for SoftmaxLastDim {
    fn name(&self) -> &'static str {
        "softmax-last-dim"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        fn softmax<T: candle::WithDType + num_traits::Float>(
            src: &[T],
            layout: &Layout,
        ) -> Result<(CpuStorage, Shape)> {
            let src = match layout.contiguous_offsets() {
                None => candle::bail!("input has to be contiguous"),
                Some((o1, o2)) => &src[o1..o2],
            };
            let el_count = layout.shape().elem_count();
            let dims = layout.shape().dims();
            let dim_m1 = dims[dims.len() - 1];
            let mut dst = vec![T::zero(); el_count];
            src.par_chunks(dim_m1)
                .zip(dst.par_chunks_mut(dim_m1))
                .for_each(|(src, dst)| {
                    let mut max = T::neg_infinity();
                    unsafe { T::vec_reduce_max(src.as_ptr(), &mut max, dim_m1) };
                    for (s, d) in src.iter().zip(dst.iter_mut()) {
                        *d = (*s - max).exp();
                    }
                    let mut sum_exp = T::zero();
                    unsafe { T::vec_reduce_sum(dst.as_ptr(), &mut sum_exp, dim_m1) };
                    for d in dst.iter_mut() {
                        *d /= sum_exp
                    }
                });
            let storage = candle::WithDType::to_cpu_storage_owned(dst);
            Ok((storage, Shape::from_dims(dims)))
        }

        match storage {
            CpuStorage::BF16(slice) => softmax::<half::bf16>(slice, layout),
            CpuStorage::F16(slice) => softmax::<half::f16>(slice, layout),
            CpuStorage::F32(slice) => softmax::<f32>(slice, layout),
            CpuStorage::F64(slice) => softmax::<f64>(slice, layout),
            _ => candle::bail!("unsupported dtype for softmax {:?}", storage),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &candle::CudaStorage,
        layout: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
        };
        use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
        use candle::{CudaDevice, WithDType};

        struct S;
        impl Map1 for S {
            fn f<T: DeviceRepr + WithDType>(
                &self,
                src: &CudaSlice<T>,
                dev: &CudaDevice,
                layout: &Layout,
            ) -> Result<CudaSlice<T>> {
                let src = match layout.contiguous_offsets() {
                    None => candle::bail!("input has to be contiguous"),
                    Some((o1, o2)) => src.slice(o1..o2),
                };
                let el = layout.shape().elem_count();
                let dims = layout.shape().dims();
                let dim_m1 = dims[dims.len() - 1];
                let (n_rows, n_cols) = (el / dim_m1, dim_m1);

                let cfg = LaunchConfig {
                    grid_dim: (n_rows as u32, 1, 1),
                    block_dim: (1, 32, 1),
                    shared_mem_bytes: 0,
                };
                let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), kernels::REDUCE)?;
                // SAFETY: Set later by running the kernel.
                let dst = unsafe { dev.alloc::<T>(el) }.w()?;
                let params = (&src, &dst, n_cols as i32);
                // SAFETY: ffi.
                unsafe { func.launch(cfg, params) }.w()?;
                Ok(dst)
            }
        }

        use candle::backend::BackendStorage;
        let dev = storage.device();
        let slice = S.map(&storage.slice, dev, layout)?;
        let dst = candle::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, layout.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &candle::MetalStorage,
        layout: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        let device = storage.device();
        let command_buffer = device.command_buffer()?;
        let kernels = device.kernels();
        let name = match storage.dtype() {
            DType::F32 => "softmax_f32",
            DType::F16 => "softmax_f16",
            DType::BF16 => "softmax_bf16",
            dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
        };

        let n = layout.stride().len();
        if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) {
            candle::bail!("Non contiguous softmax-last-dim is not implemented");
        }

        let last_dim = layout.dims()[layout.shape().rank() - 1];
        let elem_count = layout.shape().elem_count();
        let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
        candle_metal_kernels::call_last_softmax(
            device.metal_device(),
            &command_buffer,
            kernels,
            name,
            elem_count,
            last_dim,
            storage.buffer(),
            layout.start_offset() * storage.dtype().size_in_bytes(),
            &output,
        )
        .map_err(candle::Error::wrap)?;
        let newstorage =
            candle::MetalStorage::new(output, device.clone(), elem_count, storage.dtype());
        Ok((newstorage, layout.shape().clone()))
    }
}
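/// Softmax over the last dimension, using a fused custom op (no backward pass).
///
/// A small illustrative sketch: for a contiguous input this matches [`softmax`] applied on the
/// last dimension.
///
/// ```rust
/// use candle::{Tensor, Device, D, test_utils::to_vec2_round};
/// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?;
/// let fused = candle_nn::ops::softmax_last_dim(&a)?;
/// let reference = candle_nn::ops::softmax(&a, D::Minus1)?;
/// assert_eq!(to_vec2_round(&fused, 4)?, to_vec2_round(&reference, 4)?);
/// # Ok::<(), candle::Error>(())
/// ```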
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
    xs.apply_op1_no_bwd(&SoftmaxLastDim)
}
#[derive(Debug, Clone)]
struct RmsNorm {
    eps: f32,
}

impl candle::CustomOp2 for RmsNorm {
    fn name(&self) -> &'static str {
        "rms-norm"
    }

    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        use candle::backend::BackendStorage;

        let eps = self.eps;
        fn inner<
            T: candle::WithDType
                + num_traits::Float
                + num_traits::AsPrimitive<f32>
                + num_traits::FromPrimitive,
        >(
            src: &[T],
            layout: &Layout,
            alpha: &[T],
            alpha_layout: &Layout,
            eps: f32,
        ) -> Result<(CpuStorage, Shape)> {
            let src = match layout.contiguous_offsets() {
                None => candle::bail!("input has to be contiguous"),
                Some((o1, o2)) => &src[o1..o2],
            };
            let alpha = match alpha_layout.contiguous_offsets() {
                None => candle::bail!("alpha has to be contiguous"),
                Some((o1, o2)) => &alpha[o1..o2],
            };
            let el_count = layout.shape().elem_count();
            let dims = layout.shape().dims();
            let dim_m1 = dims[dims.len() - 1];
            let mut dst = vec![T::zero(); el_count];
            src.par_chunks(dim_m1)
                .zip(dst.par_chunks_mut(dim_m1))
                .for_each(|(src, dst)| {
                    let sum2 = src
                        .iter()
                        .map(|&v| {
                            let v = v.as_();
                            v * v
                        })
                        .sum::<f32>();
                    let m = (sum2 / dim_m1 as f32 + eps).sqrt();
                    let m = T::from_f32(m).unwrap_or_else(T::nan);
                    for ((d, s), alpha) in dst.iter_mut().zip(src.iter()).zip(alpha) {
                        *d = *s / m * *alpha
                    }
                });
            let storage = candle::WithDType::to_cpu_storage_owned(dst);
            Ok((storage, Shape::from_dims(dims)))
        }

        use CpuStorage as C;
        match (s1, s2) {
            (C::BF16(s1), C::BF16(s2)) => inner::<half::bf16>(s1, l1, s2, l2, eps),
            (C::F16(s1), C::F16(s2)) => inner::<half::f16>(s1, l1, s2, l2, eps),
            (C::F32(s1), C::F32(s2)) => inner::<f32>(s1, l1, s2, l2, eps),
            _ => candle::bail!("unsupported dtype for rmsnorm {:?}", s1.dtype()),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        s1: &candle::CudaStorage,
        l1: &Layout,
        s2: &candle::CudaStorage,
        l2: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
        };
        use candle::cuda_backend::{kernel_name, kernels, Map2, WrapErr};
        use candle::{CudaDevice, WithDType};

        struct S {
            eps: f32,
        }
        impl Map2 for S {
            fn f<T: DeviceRepr + WithDType>(
                &self,
                src: &CudaSlice<T>,
                layout: &Layout,
                alpha: &CudaSlice<T>,
                alpha_layout: &Layout,
                dev: &CudaDevice,
            ) -> Result<CudaSlice<T>> {
                let src = match layout.contiguous_offsets() {
                    None => candle::bail!("input has to be contiguous"),
                    Some((o1, o2)) => src.slice(o1..o2),
                };
                let alpha = match alpha_layout.contiguous_offsets() {
                    None => candle::bail!("alpha has to be contiguous"),
                    Some((o1, o2)) => alpha.slice(o1..o2),
                };
                let el = layout.shape().elem_count();
                let dims = layout.shape().dims();
                let dim_m1 = dims[dims.len() - 1];
                let (n_rows, n_cols) = (el / dim_m1, dim_m1);

                let block_size = if n_cols < 1024 { 32 } else { 1024 };
                let cfg = LaunchConfig {
                    grid_dim: (n_rows as u32, 1, 1),
                    block_dim: (block_size, 1, 1),
                    shared_mem_bytes: 0,
                };
                let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), kernels::REDUCE)?;
                // SAFETY: Set later by running the kernel.
                let dst = unsafe { dev.alloc::<T>(el) }.w()?;
                let params = (
                    &src,
                    &dst,
                    &alpha,
                    n_cols as i32,
                    block_size as i32,
                    self.eps,
                );
                // SAFETY: ffi.
                unsafe { func.launch(cfg, params) }.w()?;
                Ok(dst)
            }
        }

        use candle::backend::BackendStorage;
        let dev = s1.device();
        let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, dev)?;
        let dst = candle::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, l1.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        s1: &candle::MetalStorage,
        l1: &Layout,
        s2: &candle::MetalStorage,
        l2: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        let device = s1.device();
        let command_buffer = device.command_buffer()?;
        let kernels = device.kernels();
        let name = match (s1.dtype(), s2.dtype()) {
            (DType::F32, DType::F32) => "rmsnorm_f32",
            (DType::F16, DType::F16) => "rmsnorm_f16",
            (DType::BF16, DType::BF16) => "rmsnorm_bf16",
            (dt1, dt2) => candle::bail!("rmsnorm is not implemented for {dt1:?} {dt2:?}"),
        };

        if !(l1.is_contiguous() && l2.is_contiguous()) {
            candle::bail!("Non contiguous rmsnorm is not implemented");
        }

        let last_dim = l1.dims()[l1.shape().rank() - 1];
        let elem_count = l1.shape().elem_count();
        let output = device.new_buffer(elem_count, s1.dtype(), "rmsnorm")?;
        candle_metal_kernels::call_rms_norm(
            device.metal_device(),
            &command_buffer,
            kernels,
            name,
            elem_count,
            last_dim,
            self.eps,
            s1.buffer(),
            l1.start_offset() * s1.dtype().size_in_bytes(),
            s2.buffer(),
            l2.start_offset() * s2.dtype().size_in_bytes(),
            &output,
        )
        .map_err(candle::Error::wrap)?;
        let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
        Ok((newstorage, l1.shape().clone()))
    }
}

pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
    let x_dtype = x.dtype();
    let internal_dtype = match x_dtype {
        DType::F16 | DType::BF16 => DType::F32,
        d => d,
    };
    let hidden_size = x.dim(D::Minus1)?;
    let x = x.to_dtype(internal_dtype)?;
    let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
    let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
    x_normed.to_dtype(x_dtype)?.broadcast_mul(alpha)
}
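/// RMS normalization over the last dimension with a learned scale `alpha`, using a fused custom
/// op (no backward pass). `alpha` must be a 1d tensor whose length matches the last dimension of
/// `xs`.
///
/// A small illustrative sketch comparing the fused op against the reference
/// [`rms_norm_slow`] implementation:
///
/// ```rust
/// use candle::{Tensor, Device, DType, test_utils::to_vec2_round};
/// let xs = Tensor::new(&[[1f32, 2., 3., 4.], [-1., 0., 1., 2.]], &Device::Cpu)?;
/// let alpha = Tensor::ones(4, DType::F32, &Device::Cpu)?;
/// let fused = candle_nn::ops::rms_norm(&xs, &alpha, 1e-5)?;
/// let reference = candle_nn::ops::rms_norm_slow(&xs, &alpha, 1e-5)?;
/// assert_eq!(to_vec2_round(&fused, 4)?, to_vec2_round(&reference, 4)?);
/// # Ok::<(), candle::Error>(())
/// ```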
pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
    let hidden_size_xs = xs.dim(D::Minus1)?;
    let hidden_size_alpha = alpha.dims1()?;
    if hidden_size_xs != hidden_size_alpha {
        candle::bail!(
            "shape mismatch in rms-norm {:?} {:?}",
            xs.shape(),
            alpha.shape()
        )
    }
    xs.apply_op2_no_bwd(alpha, &RmsNorm { eps })
}
#[derive(Debug, Clone)]
struct LayerNorm {
    eps: f32,
}

impl candle::CustomOp3 for LayerNorm {
    fn name(&self) -> &'static str {
        "layer-norm"
    }

    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
        s3: &CpuStorage,
        l3: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        use candle::backend::BackendStorage;

        let eps = self.eps;
        fn inner<
            T: candle::WithDType
                + num_traits::Float
                + num_traits::AsPrimitive<f32>
                + num_traits::FromPrimitive,
        >(
            src: &[T],
            layout: &Layout,
            alpha: &[T],
            alpha_layout: &Layout,
            beta: &[T],
            beta_layout: &Layout,
            eps: f32,
        ) -> Result<(CpuStorage, Shape)> {
            let src = match layout.contiguous_offsets() {
                None => candle::bail!("input has to be contiguous"),
                Some((o1, o2)) => &src[o1..o2],
            };
            let alpha = match alpha_layout.contiguous_offsets() {
                None => candle::bail!("alpha has to be contiguous"),
                Some((o1, o2)) => &alpha[o1..o2],
            };
            let beta = match beta_layout.contiguous_offsets() {
                None => candle::bail!("beta has to be contiguous"),
                Some((o1, o2)) => &beta[o1..o2],
            };
            let el_count = layout.shape().elem_count();
            let dims = layout.shape().dims();
            let dim_m1 = dims[dims.len() - 1];
            let mut dst = vec![T::zero(); el_count];
            src.par_chunks(dim_m1)
                .zip(dst.par_chunks_mut(dim_m1))
                .for_each(|(src, dst)| {
                    let mut sum = 0f32;
                    let mut sum2 = 0f32;
                    for v in src {
                        let v = v.as_();
                        sum += v;
                        sum2 += v * v;
                    }
                    let mean = sum / dim_m1 as f32;
                    let var = sum2 / dim_m1 as f32 - mean * mean;
                    let inv_std = (var + eps).sqrt().recip();
                    for ((d, s), (alpha, beta)) in
                        dst.iter_mut().zip(src.iter()).zip(alpha.iter().zip(beta))
                    {
                        let alpha = alpha.as_();
                        let beta = beta.as_();
                        let d_ = (s.as_() - mean) * inv_std * alpha + beta;
                        *d = T::from_f32(d_).unwrap_or_else(T::nan);
                    }
                });
            let storage = candle::WithDType::to_cpu_storage_owned(dst);
            Ok((storage, Shape::from_dims(dims)))
        }

        use CpuStorage as C;
        match (s1, s2, s3) {
            (C::BF16(s1), C::BF16(s2), C::BF16(s3)) => {
                inner::<half::bf16>(s1, l1, s2, l2, s3, l3, eps)
            }
            (C::F16(s1), C::F16(s2), C::F16(s3)) => inner::<half::f16>(s1, l1, s2, l2, s3, l3, eps),
            (C::F32(s1), C::F32(s2), C::F32(s3)) => inner::<f32>(s1, l1, s2, l2, s3, l3, eps),
            _ => candle::bail!("unsupported dtype for layernorm {:?}", s1.dtype()),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        s1: &candle::CudaStorage,
        l1: &Layout,
        s2: &candle::CudaStorage,
        l2: &Layout,
        s3: &candle::CudaStorage,
        l3: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
        };
        use candle::cuda_backend::{kernel_name, kernels, Map3, WrapErr};
        use candle::{CudaDevice, WithDType};

        struct S {
            eps: f32,
        }
        impl Map3 for S {
            fn f<T: DeviceRepr + WithDType>(
                &self,
                src: &CudaSlice<T>,
                layout: &Layout,
                alpha: &CudaSlice<T>,
                alpha_layout: &Layout,
                beta: &CudaSlice<T>,
                beta_layout: &Layout,
                dev: &CudaDevice,
            ) -> Result<CudaSlice<T>> {
                let src = match layout.contiguous_offsets() {
                    None => candle::bail!("input has to be contiguous"),
                    Some((o1, o2)) => src.slice(o1..o2),
                };
                let alpha = match alpha_layout.contiguous_offsets() {
                    None => candle::bail!("alpha has to be contiguous"),
                    Some((o1, o2)) => alpha.slice(o1..o2),
                };
                let beta = match beta_layout.contiguous_offsets() {
                    None => candle::bail!("beta has to be contiguous"),
                    Some((o1, o2)) => beta.slice(o1..o2),
                };
                let el = layout.shape().elem_count();
                let dims = layout.shape().dims();
                let dim_m1 = dims[dims.len() - 1];
                let (n_rows, n_cols) = (el / dim_m1, dim_m1);

                let block_size = if n_cols < 1024 { 32 } else { 1024 };
                let cfg = LaunchConfig {
                    grid_dim: (n_rows as u32, 1, 1),
                    block_dim: (block_size, 1, 1),
                    shared_mem_bytes: 0,
                };
                let func = dev.get_or_load_func(&kernel_name::<T>("layernorm"), kernels::REDUCE)?;
                // SAFETY: Set later by running the kernel.
                let dst = unsafe { dev.alloc::<T>(el) }.w()?;
                let params = (
                    &src,
                    &dst,
                    &alpha,
                    &beta,
                    n_cols as i32,
                    block_size as i32,
                    self.eps,
                );
                // SAFETY: ffi.
                unsafe { func.launch(cfg, params) }.w()?;
                Ok(dst)
            }
        }

        use candle::backend::BackendStorage;
        let dev = s1.device();
        let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, &s3.slice, l3, dev)?;
        let dst = candle::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, l1.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        s1: &candle::MetalStorage,
        l1: &Layout,
        s2: &candle::MetalStorage,
        l2: &Layout,
        s3: &candle::MetalStorage,
        l3: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        let device = s1.device();
        let command_buffer = device.command_buffer()?;
        let kernels = device.kernels();
        let name = match (s1.dtype(), s2.dtype(), s3.dtype()) {
            (DType::F32, DType::F32, DType::F32) => "layernorm_f32",
            (DType::F16, DType::F16, DType::F16) => "layernorm_f16",
            (DType::BF16, DType::BF16, DType::BF16) => "layernorm_bf16",
            (dt1, dt2, dt3) => {
                candle::bail!("layernorm is not implemented for {dt1:?} {dt2:?} {dt3:?}")
            }
        };

        if !(l1.is_contiguous() && l2.is_contiguous() && l3.is_contiguous()) {
            candle::bail!("Non contiguous layernorm is not implemented");
        }

        let last_dim = l1.dims()[l1.shape().rank() - 1];
        let elem_count = l1.shape().elem_count();
        let output = device.new_buffer(elem_count, s1.dtype(), "layernorm")?;
        candle_metal_kernels::call_layer_norm(
            device.metal_device(),
            &command_buffer,
            kernels,
            name,
            elem_count,
            last_dim,
            self.eps,
            s1.buffer(),
            l1.start_offset() * s1.dtype().size_in_bytes(),
            s2.buffer(),
            l2.start_offset() * s2.dtype().size_in_bytes(),
            s3.buffer(),
            l3.start_offset() * s3.dtype().size_in_bytes(),
            &output,
        )
        .map_err(candle::Error::wrap)?;
        let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
        Ok((newstorage, l1.shape().clone()))
    }
}

pub fn layer_norm_slow(x: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
    let x_dtype = x.dtype();
    let internal_dtype = match x_dtype {
        DType::F16 | DType::BF16 => DType::F32,
        d => d,
    };
    let hidden_size = x.dim(D::Minus1)?;
    let x = x.to_dtype(internal_dtype)?;
    let x = {
        let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
        x.broadcast_sub(&mean_x)?
    };
    let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
    let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
    x_normed
        .to_dtype(x_dtype)?
        .broadcast_mul(alpha)?
        .broadcast_add(beta)
}
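/// Layer normalization over the last dimension with a learned scale `alpha` and bias `beta`,
/// using a fused custom op (no backward pass). Both `alpha` and `beta` must be 1d tensors whose
/// length matches the last dimension of `xs`.
///
/// A small illustrative sketch comparing the fused op against the reference
/// [`layer_norm_slow`] implementation:
///
/// ```rust
/// use candle::{Tensor, Device, DType, test_utils::to_vec2_round};
/// let xs = Tensor::new(&[[1f32, 2., 3., 4.], [-1., 0., 1., 2.]], &Device::Cpu)?;
/// let alpha = Tensor::ones(4, DType::F32, &Device::Cpu)?;
/// let beta = Tensor::zeros(4, DType::F32, &Device::Cpu)?;
/// let fused = candle_nn::ops::layer_norm(&xs, &alpha, &beta, 1e-5)?;
/// let reference = candle_nn::ops::layer_norm_slow(&xs, &alpha, &beta, 1e-5)?;
/// assert_eq!(to_vec2_round(&fused, 4)?, to_vec2_round(&reference, 4)?);
/// # Ok::<(), candle::Error>(())
/// ```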
pub fn layer_norm(xs: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
    let hidden_size_xs = xs.dim(D::Minus1)?;
    let hidden_size_alpha = alpha.dims1()?;
    let hidden_size_beta = beta.dims1()?;
    if hidden_size_xs != hidden_size_alpha || hidden_size_xs != hidden_size_beta {
        candle::bail!(
            "shape mismatch in layer-norm src: {:?} alpha: {:?} beta: {:?}",
            xs.shape(),
            alpha.shape(),
            beta.shape()
        )
    }
    xs.apply_op3_no_bwd(alpha, beta, &LayerNorm { eps })
}

// https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
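/// Pixel shuffle, following `torch.nn.PixelShuffle`: rearranges a `(batch, c * r * r, h, w)`
/// tensor into `(batch, c, h * r, w * r)` for an upscale factor `r`.
///
/// A small shape-oriented sketch (the values are arbitrary):
///
/// ```rust
/// use candle::{Tensor, Device};
/// let xs = Tensor::randn(0f32, 1., (1, 8, 2, 3), &Device::Cpu)?;
/// let ys = candle_nn::ops::pixel_shuffle(&xs, 2)?;
/// assert_eq!(ys.dims(), &[1, 2, 4, 6]);
/// # Ok::<(), candle::Error>(())
/// ```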
pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> {
    let (b_size, c, h, w) = xs.dims4()?;
    let out_c = c / upscale_factor / upscale_factor;
    xs.reshape((b_size, out_c, upscale_factor, upscale_factor, h, w))?
        .permute((0, 1, 4, 2, 5, 3))?
        .reshape((b_size, out_c, h * upscale_factor, w * upscale_factor))
}

pub fn pixel_unshuffle(xs: &Tensor, downscale_factor: usize) -> Result<Tensor> {
    let (b_size, c, h, w) = xs.dims4()?;
    let out_c = c * downscale_factor * downscale_factor;
    xs.reshape((
        b_size,
        c,
        h / downscale_factor,
        downscale_factor,
        w / downscale_factor,
        downscale_factor,
    ))?
    .permute((0, 1, 3, 5, 2, 4))?
    .reshape((b_size, out_c, h / downscale_factor, w / downscale_factor))
}

// https://pytorch.org/docs/stable/generated/torch.nn.ReplicationPad2d.html
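/// Replication padding of 2d (NCHW) tensors, following `torch.nn.ReplicationPad2d`: the border
/// rows and columns are repeated `pad` times on each side. Only `pad` values of 0 and 1 are
/// currently supported.
///
/// A small shape-oriented sketch (the values are arbitrary):
///
/// ```rust
/// use candle::{Tensor, Device};
/// let xs = Tensor::randn(0f32, 1., (1, 1, 2, 2), &Device::Cpu)?;
/// let ys = candle_nn::ops::replication_pad2d(&xs, 1)?;
/// assert_eq!(ys.dims(), &[1, 1, 4, 4]);
/// # Ok::<(), candle::Error>(())
/// ```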
pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result<Tensor> {
    match pad {
        0 => Ok(xs.clone()),
        1 => {
            let (_b_size, _c, h, w) = xs.dims4()?;
            let (first, last) = (xs.narrow(3, 0, 1)?, xs.narrow(3, w - 1, 1)?);
            let xs = Tensor::cat(&[&first, xs, &last], 3)?;
            let (first, last) = (xs.narrow(2, 0, 1)?, xs.narrow(2, h - 1, 1)?);
            Tensor::cat(&[&first, &xs, &last], 2)
        }
        n => candle::bail!("replication-pad with a size of {n} is not supported"),
    }
}

#[derive(Clone, Debug)]
pub struct Identity;

impl Identity {
    pub fn new() -> Identity {
        Self
    }
}

impl Default for Identity {
    fn default() -> Self {
        Self
    }
}

impl Module for Identity {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        Ok(xs.clone())
    }
}
#[allow(dead_code)]
struct Sdpa {
    scale: f32,
    softcapping: f32,
}

impl candle::CustomOp3 for Sdpa {
    fn name(&self) -> &'static str {
        "metal-sdpa"
    }

    fn cpu_fwd(
        &self,
        _s1: &CpuStorage,
        _l1: &Layout,
        _s2: &CpuStorage,
        _l2: &Layout,
        _s3: &CpuStorage,
        _l3: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        candle::bail!("SDPA has no cpu impl")
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        q: &candle::MetalStorage,
        q_l: &Layout,
        k: &candle::MetalStorage,
        k_l: &Layout,
        v: &candle::MetalStorage,
        v_l: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        use candle_metal_kernels::SdpaDType;

        let device = q.device();

        let out_dims = vec![q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, v_l.dim(3)?];
        let elem_count: usize = out_dims.iter().product();

        let output = device.new_buffer(elem_count, q.dtype(), "sdpa_o")?;

        // q,k must have matching emb dim
        if q_l.dim(D::Minus1)? != k_l.dim(D::Minus1)? {
            candle::bail!("`q` and `k` last dims must match");
        }

        // k,v must have matching n kv heads
        if v_l.dim(D::Minus(3))? != k_l.dim(D::Minus(3))? {
            candle::bail!("`k` and `v` head dims must match");
        }

        // n_heads % n_kv_heads == 0; n_heads >= 1, n_kv_heads >= 1.
        if q_l.dim(D::Minus(3))? % k_l.dim(D::Minus(3))? != 0 {
            candle::bail!("query `n_heads` must be a multiple of `n_kv_heads`");
        }

        let k_head = k_l.dim(D::Minus1)?;
        let q_head = q_l.dim(D::Minus1)?;
        let q_seq = q_l.dim(2)?;

        let mut implementation_supports_use_case = q_head == k_head;
        let supported_head_dim =
            q_head == 32 || q_head == 64 || q_head == 96 || q_head == 128 || q_head == 256;

        const SDPA_FULL_THRESHOLD: usize = 2;

        let supports_sdpa_full =
            q_seq >= SDPA_FULL_THRESHOLD && supported_head_dim && q_head == k_head;
        let supports_sdpa_vector = q_seq == 1 && supported_head_dim;

        implementation_supports_use_case &= supports_sdpa_full || supports_sdpa_vector;

        if !supported_head_dim {
            candle::bail!(
                "Metal SDPA does not support q head dim {q_head}: q dims {:?}, k dims {:?}, v dims {:?}.",
                q_l.dims(),
                k_l.dims(),
                v_l.dims()
            );
        }
        if !implementation_supports_use_case {
            candle::bail!(
                "Metal SDPA does not support q dims {:?}, k dims {:?}, v dims {:?}.",
                q_l.dims(),
                k_l.dims(),
                v_l.dims()
            );
        }

        for t in [k.dtype(), v.dtype()] {
            if q.dtype() != t {
                candle::bail!("all q, k, v dtypes must match.");
            }
        }

        let itype = match q.dtype() {
            DType::BF16 => SdpaDType::BF16,
            DType::F16 => SdpaDType::F16,
            DType::F32 => SdpaDType::F32,
            other => candle::bail!("unsupported sdpa type {other:?}"),
        };

        let command_buffer = q.device().command_buffer()?;
        if supports_sdpa_vector {
            command_buffer.set_label("vector_attention");
            candle_metal_kernels::call_sdpa_vector(
                q.device().device(),
                &command_buffer,
                q.device().kernels(),
                q_l.start_offset(),
                q_l.dims(),
                q.buffer(),
                k_l.start_offset(),
                k_l.dims(),
                k_l.stride(),
                k.buffer(),
                v_l.start_offset(),
                v_l.stride(),
                v.buffer(),
                &output,
                self.scale,
                self.softcapping,
                itype,
            )
            .map_err(candle::Error::wrap)?;
        } else if supports_sdpa_full {
            if q_l.dim(2)? != k_l.dim(2)? {
                candle::bail!(
                    "query and key sequence length must be equal if using full metal sdpa"
                )
            }

            command_buffer.set_label("full_attention");
            candle_metal_kernels::call_sdpa_full(
                q.device().device(),
                &command_buffer,
                q.device().kernels(),
                q_l.start_offset(),
                q_l.dims(),
                q.buffer(),
                k_l.start_offset(),
                k.buffer(),
                v_l.start_offset(),
                v.buffer(),
                &output,
                self.scale,
                self.softcapping,
                itype,
            )
            .map_err(candle::Error::wrap)?;
        } else {
            candle::bail!("must be vector or full sdpa kernel");
        }

        let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, q.dtype());
        Ok((newstorage, Shape::from_dims(&out_dims)))
    }
}
/// Scaled dot product attention with a fused kernel.
///
/// Computes softmax(qk^T*scale)v.
///
/// **Input shapes:**
/// - `q`: (bs, qhead, seq, hidden)
/// - `k`: (bs, kv_head, kv_seq, hidden)
/// - `v`: (bs, kv_head, kv_seq, v_hidden)
/// - `scale` is applied before softmax.
/// - If `softcapping` != 1.0:
///     - Computation is: softmax(tanh(qk^T*scale/cap)*cap)v
///
/// **Output shape:** (bs, qhead, seq, v_hidden)
///
/// **Supported head dims:** 32, 64, 96, 128, 256.
///
/// ## On Metal:
/// - If `seq` == 1:
///     - Use a vectorized kernel
///     - Supports `seq` != `kv_seq` (cross attn. support)
///     - Supports GQA when `qhead` is a multiple of `kv_head`
/// - Otherwise:
///     - Use an alternate kernel
///     - Requires `seq` == `kv_seq`
///     - GQA is not supported (requires `qhead` == `kv_head`)
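///
/// The sketch below is illustrative only: it assumes a Metal device (the op has no CPU
/// implementation) and uses arbitrary shapes with 8 heads and a head dim of 64.
///
/// ```rust,no_run
/// use candle::{Tensor, Device};
/// let device = Device::new_metal(0)?;
/// // Single-query decoding step: seq == 1, kv_seq == 7.
/// let q = Tensor::randn(0f32, 1., (1, 8, 1, 64), &device)?;
/// let k = Tensor::randn(0f32, 1., (1, 8, 7, 64), &device)?;
/// let v = Tensor::randn(0f32, 1., (1, 8, 7, 64), &device)?;
/// let scale = 1.0 / (64f32).sqrt();
/// let out = candle_nn::ops::sdpa(&q, &k, &v, scale, 1.0)?;
/// assert_eq!(out.dims(), &[1, 8, 1, 64]);
/// # Ok::<(), candle::Error>(())
/// ```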
pub fn sdpa(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32, softcapping: f32) -> Result<Tensor> {
    q.apply_op3_no_bwd(k, v, &Sdpa { scale, softcapping })
}