mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Add the rope THD kernel. (#2014)
* Add the rope THD kernel. * Cuda kernel for rope-thd. * Add the metal kernels. * Add a dedicated test.
This commit is contained in:
@ -497,3 +497,234 @@ pub fn rope_slow(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
|
||||
let sin = sin.unsqueeze(0)?.unsqueeze(0)?;
|
||||
x.broadcast_mul(&cos)? + rotate_half(x)?.broadcast_mul(&sin)?
|
||||
}
|
||||
|
||||
/// T (seqlen)/H (num-heads)/D (head-dim) contiguous variant of rope embeddings.
|
||||
#[derive(Debug, Clone)]
|
||||
struct RotaryEmbThd;
|
||||
|
||||
impl candle::CustomOp3 for RotaryEmbThd {
|
||||
fn name(&self) -> &'static str {
|
||||
"rotary-emb"
|
||||
}
|
||||
|
||||
fn cpu_fwd(
|
||||
&self,
|
||||
s1: &CpuStorage,
|
||||
l1: &Layout,
|
||||
s2: &CpuStorage,
|
||||
l2: &Layout,
|
||||
s3: &CpuStorage,
|
||||
l3: &Layout,
|
||||
) -> Result<(CpuStorage, Shape)> {
|
||||
fn inner<T: candle::WithDType + num_traits::Float>(
|
||||
src: &[T],
|
||||
l_src: &Layout,
|
||||
cos: &[T],
|
||||
l_cos: &Layout,
|
||||
sin: &[T],
|
||||
l_sin: &Layout,
|
||||
) -> Result<(CpuStorage, Shape)> {
|
||||
let src = match l_src.contiguous_offsets() {
|
||||
None => candle::bail!("input src has to be contiguous"),
|
||||
Some((o1, o2)) => &src[o1..o2],
|
||||
};
|
||||
let cos = match l_cos.contiguous_offsets() {
|
||||
None => candle::bail!("input cos has to be contiguous"),
|
||||
Some((o1, o2)) => &cos[o1..o2],
|
||||
};
|
||||
let sin = match l_sin.contiguous_offsets() {
|
||||
None => candle::bail!("input sin has to be contiguous"),
|
||||
Some((o1, o2)) => &sin[o1..o2],
|
||||
};
|
||||
let (b, t, h, d) = l_src.shape().dims4()?;
|
||||
let el_count = b * h * t * d;
|
||||
let mut dst = vec![T::zero(); el_count];
|
||||
src.par_chunks(t * h * d)
|
||||
.zip(dst.par_chunks_mut(t * h * d))
|
||||
.for_each(|(src, dst)| {
|
||||
for i_t in 0..t {
|
||||
for i_d in 0..d / 2 {
|
||||
let i_cs = i_t * (d / 2) + i_d;
|
||||
for i_h in 0..h {
|
||||
let i1 = i_t * h * d + i_h * d + i_d;
|
||||
let i2 = i1 + d / 2;
|
||||
dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs];
|
||||
dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs];
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
let storage = candle::WithDType::to_cpu_storage_owned(dst);
|
||||
Ok((storage, (b, t, h, d).into()))
|
||||
}
|
||||
|
||||
use candle::backend::BackendStorage;
|
||||
use CpuStorage::{BF16, F16, F32, F64};
|
||||
match (s1, s2, s3) {
|
||||
(BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3),
|
||||
(F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3),
|
||||
(F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3),
|
||||
(F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3),
|
||||
_ => candle::bail!(
|
||||
"unsupported dtype for rope {:?} {:?} {:?}",
|
||||
s1.dtype(),
|
||||
s2.dtype(),
|
||||
s3.dtype()
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn cuda_fwd(
|
||||
&self,
|
||||
s1: &candle::CudaStorage,
|
||||
l1: &Layout,
|
||||
s2: &candle::CudaStorage,
|
||||
l2: &Layout,
|
||||
s3: &candle::CudaStorage,
|
||||
l3: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
use candle::cuda_backend::cudarc::driver::{
|
||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
||||
};
|
||||
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
|
||||
use candle::{CudaDevice, WithDType};
|
||||
|
||||
fn inner<T: DeviceRepr + WithDType>(
|
||||
src: &CudaSlice<T>,
|
||||
l_src: &Layout,
|
||||
cos: &CudaSlice<T>,
|
||||
l_cos: &Layout,
|
||||
sin: &CudaSlice<T>,
|
||||
l_sin: &Layout,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
let src = match l_src.contiguous_offsets() {
|
||||
None => candle::bail!("src input has to be contiguous"),
|
||||
Some((o1, o2)) => src.slice(o1..o2),
|
||||
};
|
||||
let cos = match l_cos.contiguous_offsets() {
|
||||
None => candle::bail!("cos input has to be contiguous"),
|
||||
Some((o1, o2)) => cos.slice(o1..o2),
|
||||
};
|
||||
let sin = match l_sin.contiguous_offsets() {
|
||||
None => candle::bail!("sin input has to be contiguous"),
|
||||
Some((o1, o2)) => sin.slice(o1..o2),
|
||||
};
|
||||
let (b, t, h, d) = l_src.shape().dims4()?;
|
||||
let el = b * h * t * d;
|
||||
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), kernels::REDUCE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let params = (
|
||||
&src, &cos, &sin, &dst, b as u32, t as u32, h as u32, d as u32,
|
||||
);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(dst)
|
||||
}
|
||||
|
||||
use candle::backend::BackendStorage;
|
||||
use candle::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64};
|
||||
let dev = s1.device();
|
||||
let slice = match (&s1.slice, &s2.slice, &s3.slice) {
|
||||
(BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?),
|
||||
(F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?),
|
||||
(F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?),
|
||||
(F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?),
|
||||
_ => candle::bail!(
|
||||
"unsupported dtype for rope {:?} {:?} {:?}",
|
||||
s1.dtype(),
|
||||
s2.dtype(),
|
||||
s3.dtype()
|
||||
),
|
||||
};
|
||||
let dst = candle::cuda_backend::CudaStorage {
|
||||
slice,
|
||||
device: dev.clone(),
|
||||
};
|
||||
Ok((dst, l1.shape().clone()))
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
src: &candle::MetalStorage,
|
||||
l_src: &Layout,
|
||||
cos: &candle::MetalStorage,
|
||||
l_cos: &Layout,
|
||||
sin: &candle::MetalStorage,
|
||||
l_sin: &Layout,
|
||||
) -> Result<(candle::MetalStorage, Shape)> {
|
||||
use candle::backend::BackendStorage;
|
||||
let device = src.device();
|
||||
let command_buffer = device.command_buffer()?;
|
||||
let kernels = device.kernels();
|
||||
if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() {
|
||||
candle::bail!(
|
||||
"dtype mismatch in rope {:?} {:?} {:?}",
|
||||
src.dtype(),
|
||||
cos.dtype(),
|
||||
sin.dtype()
|
||||
)
|
||||
}
|
||||
let name = match src.dtype() {
|
||||
candle::DType::F32 => "rope_thd_f32",
|
||||
candle::DType::F16 => "rope_thd_f16",
|
||||
candle::DType::BF16 => "rope_thd_bf16",
|
||||
dtype => candle::bail!("rope_thd is not implemented for {dtype:?}"),
|
||||
};
|
||||
let (b, t, h, d) = l_src.shape().dims4()?;
|
||||
let el = b * h * t * d;
|
||||
let output = device.new_buffer(el, src.dtype(), "rope-thd")?;
|
||||
candle_metal_kernels::call_rope_thd(
|
||||
device.metal_device(),
|
||||
&command_buffer,
|
||||
kernels,
|
||||
name,
|
||||
b,
|
||||
t,
|
||||
h,
|
||||
d,
|
||||
src.buffer(),
|
||||
l_src.start_offset() * src.dtype().size_in_bytes(),
|
||||
cos.buffer(),
|
||||
l_cos.start_offset() * cos.dtype().size_in_bytes(),
|
||||
sin.buffer(),
|
||||
l_sin.start_offset() * sin.dtype().size_in_bytes(),
|
||||
&output,
|
||||
)
|
||||
.map_err(candle::Error::wrap)?;
|
||||
let out = candle::MetalStorage::new(output, device.clone(), el, src.dtype());
|
||||
Ok((out, l_src.shape().clone()))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rope_thd(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
|
||||
let (_b_sz, seq_len, _n_head, n_embd) = xs.dims4()?;
|
||||
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
|
||||
let (sin_seq_len, sin_n_embd) = sin.dims2()?;
|
||||
if cos_n_embd * 2 != n_embd
|
||||
|| sin_n_embd * 2 != n_embd
|
||||
|| seq_len > cos_seq_len
|
||||
|| seq_len > sin_seq_len
|
||||
{
|
||||
candle::bail!(
|
||||
"inconsistent last dim size in rope {:?} {:?} {:?}",
|
||||
xs.shape(),
|
||||
cos.shape(),
|
||||
sin.shape()
|
||||
)
|
||||
}
|
||||
if !xs.is_contiguous() {
|
||||
candle::bail!("xs has to be contiguous in rope")
|
||||
}
|
||||
if !cos.is_contiguous() {
|
||||
candle::bail!("cos has to be contiguous in rope")
|
||||
}
|
||||
if !sin.is_contiguous() {
|
||||
candle::bail!("sin has to be contiguous in rope")
|
||||
}
|
||||
xs.apply_op3_no_bwd(cos, sin, &RotaryEmbThd)
|
||||
}
|
||||
|
@ -140,7 +140,38 @@ fn rope(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn rope_thd(device: &Device) -> Result<()> {
|
||||
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||
|
||||
let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16);
|
||||
let el_count = b_size * num_head * seq_len * head_dim;
|
||||
let mut rng = StdRng::seed_from_u64(299792458);
|
||||
let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
|
||||
let cos: Vec<f32> = (0..seq_len * head_dim / 2)
|
||||
.map(|_| rng.gen::<f32>())
|
||||
.collect();
|
||||
let sin: Vec<f32> = (0..seq_len * head_dim / 2)
|
||||
.map(|_| rng.gen::<f32>())
|
||||
.collect();
|
||||
let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?;
|
||||
let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?;
|
||||
let sin = Tensor::from_vec(sin, (seq_len, head_dim / 2), device)?;
|
||||
let rope1 = {
|
||||
let src = src.transpose(1, 2)?.contiguous()?;
|
||||
candle_nn::rotary_emb::rope_thd(&src, &cos, &sin)?.transpose(1, 2)?
|
||||
};
|
||||
let rope2 = candle_nn::rotary_emb::rope_slow(&src, &cos, &sin)?;
|
||||
let sum_diff = (rope1 - rope2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
if device.is_cpu() {
|
||||
assert_eq!(sum_diff, 0.);
|
||||
} else {
|
||||
assert!(sum_diff < 1e-4);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(ropei, ropei_cpu, ropei_gpu, ropei_metal);
|
||||
test_device!(rope, rope_cpu, rope_gpu, rope_metal);
|
||||
test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal);
|
||||
test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
|
||||
test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
|
||||
|
Reference in New Issue
Block a user