Support for "unbatched" rope. (#2926)

* Support for (un)-batched rope.

* Use 3d rope in the rope/ropei/rope_thd functions.

* Get the CPU versions to work.

* Fix the cuda version.

* Adapt the metal side.

* Fix the metal tests.
This commit is contained in:
Laurent Mazare
2025-04-27 15:12:02 +02:00
committed by GitHub
parent 6e0646c208
commit e3db30021f
6 changed files with 217 additions and 33 deletions

View File

@ -46,15 +46,23 @@ impl candle::CustomOp3 for RotaryEmbI {
Some((o1, o2)) => &sin[o1..o2],
};
let (b, h, t, d) = l_src.shape().dims4()?;
let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3;
let el_count = b * h * t * d;
let mut dst = vec![T::zero(); el_count];
src.par_chunks(t * d)
.zip(dst.par_chunks_mut(t * d))
.for_each(|(src, dst)| {
.enumerate()
.for_each(|(bh_i, (src, dst))| {
for i_over_2 in 0..t * d / 2 {
let i = 2 * i_over_2;
dst[i] = src[i] * cos[i_over_2] - src[i + 1] * sin[i_over_2];
dst[i + 1] = src[i] * sin[i_over_2] + src[i + 1] * cos[i_over_2];
let rope_i = if unbatched_rope {
let b_i = bh_i / h;
i_over_2 + b_i * t * d / 2
} else {
i_over_2
};
dst[i] = src[i] * cos[rope_i] - src[i + 1] * sin[rope_i];
dst[i + 1] = src[i] * sin[rope_i] + src[i + 1] * cos[rope_i];
}
});
let storage = candle::WithDType::to_cpu_storage_owned(dst);
@ -115,6 +123,11 @@ impl candle::CustomOp3 for RotaryEmbI {
Some((o1, o2)) => sin.slice(o1..o2),
};
let (b, h, t, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
(h * t * d) as u32
} else {
0u32
};
let el = b * h * t * d;
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), &kernels::REDUCE)?;
@ -125,7 +138,7 @@ impl candle::CustomOp3 for RotaryEmbI {
builder.arg(&cos);
builder.arg(&sin);
builder.arg(&dst);
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32);
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, stride_b);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
@ -182,6 +195,11 @@ impl candle::CustomOp3 for RotaryEmbI {
dtype => candle::bail!("rope-i is not implemented for {dtype:?}"),
};
let (b, h, t, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
h * t * d
} else {
0usize
};
let el = b * h * t * d;
let output = device.new_buffer(el, src.dtype(), "rope-i")?;
candle_metal_kernels::call_rope_i(
@ -191,6 +209,7 @@ impl candle::CustomOp3 for RotaryEmbI {
name,
b * h,
t * d,
stride_b,
src.buffer(),
l_src.start_offset() * src.dtype().size_in_bytes(),
cos.buffer(),
@ -205,10 +224,23 @@ impl candle::CustomOp3 for RotaryEmbI {
}
}
fn rope_check_cs(cs: &Tensor, b_sz: usize) -> Result<(usize, usize)> {
match *cs.dims() {
[t, d] => Ok((t, d)),
[b, t, d] => {
if b != b_sz {
candle::bail!("inconsistent batch size in rope {b_sz} {cs:?}",)
}
Ok((t, d))
}
_ => candle::bail!("cos/sin has to be 2D or 3D in rope {b_sz} {cs:?}"),
}
}
pub fn rope_i(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
let (sin_seq_len, sin_n_embd) = cos.dims2()?;
let (b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?;
let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?;
if cos_n_embd * 2 != n_embd
|| sin_n_embd * 2 != n_embd
|| seq_len > cos_seq_len
@ -292,16 +324,24 @@ impl candle::CustomOp3 for RotaryEmb {
Some((o1, o2)) => &sin[o1..o2],
};
let (b, h, t, d) = l_src.shape().dims4()?;
let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3;
let el_count = b * h * t * d;
let mut dst = vec![T::zero(); el_count];
src.par_chunks(t * d)
.zip(dst.par_chunks_mut(t * d))
.for_each(|(src, dst)| {
.enumerate()
.for_each(|(bh_i, (src, dst))| {
for i_t in 0..t {
for i_d in 0..d / 2 {
let i1 = i_t * d + i_d;
let i2 = i1 + d / 2;
let i_cs = i_t * (d / 2) + i_d;
let i_cs = if unbatched_rope {
let b_i = bh_i / h;
i_cs + b_i * t * d / 2
} else {
i_cs
};
dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs];
dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs];
}
@ -365,6 +405,11 @@ impl candle::CustomOp3 for RotaryEmb {
Some((o1, o2)) => sin.slice(o1..o2),
};
let (b, h, t, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
(h * t * d) as u32
} else {
0u32
};
let el = b * h * t * d;
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), &kernels::REDUCE)?;
@ -375,7 +420,7 @@ impl candle::CustomOp3 for RotaryEmb {
builder.arg(&cos);
builder.arg(&sin);
builder.arg(&dst);
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, d as u32);
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, d as u32, stride_b);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
@ -432,6 +477,11 @@ impl candle::CustomOp3 for RotaryEmb {
dtype => candle::bail!("rope is not implemented for {dtype:?}"),
};
let (b, h, t, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
h * t * d
} else {
0usize
};
let el = b * h * t * d;
let output = device.new_buffer(el, src.dtype(), "rope-i")?;
candle_metal_kernels::call_rope(
@ -442,6 +492,7 @@ impl candle::CustomOp3 for RotaryEmb {
b * h,
t * d,
d,
stride_b,
src.buffer(),
l_src.start_offset() * src.dtype().size_in_bytes(),
cos.buffer(),
@ -457,9 +508,9 @@ impl candle::CustomOp3 for RotaryEmb {
}
pub fn rope(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
let (sin_seq_len, sin_n_embd) = sin.dims2()?;
let (b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?;
let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?;
if cos_n_embd * 2 != n_embd
|| sin_n_embd * 2 != n_embd
|| seq_len > cos_seq_len
@ -541,14 +592,21 @@ impl candle::CustomOp3 for RotaryEmbThd {
Some((o1, o2)) => &sin[o1..o2],
};
let (b, t, h, d) = l_src.shape().dims4()?;
let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3;
let el_count = b * h * t * d;
let mut dst = vec![T::zero(); el_count];
src.par_chunks(t * h * d)
.zip(dst.par_chunks_mut(t * h * d))
.for_each(|(src, dst)| {
.enumerate()
.for_each(|(b_i, (src, dst))| {
for i_t in 0..t {
for i_d in 0..d / 2 {
let i_cs = i_t * (d / 2) + i_d;
let i_cs = if unbatched_rope {
i_cs + b_i * t * d / 2
} else {
i_cs
};
for i_h in 0..h {
let i1 = i_t * h * d + i_h * d + i_d;
let i2 = i1 + d / 2;
@ -616,6 +674,11 @@ impl candle::CustomOp3 for RotaryEmbThd {
Some((o1, o2)) => sin.slice(o1..o2),
};
let (b, t, h, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
(h * t * d) as u32
} else {
0u32
};
let el = b * h * t * d;
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), &kernels::REDUCE)?;
@ -626,7 +689,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
builder.arg(&cos);
builder.arg(&sin);
builder.arg(&dst);
candle::builder_arg!(builder, b as u32, t as u32, h as u32, d as u32);
candle::builder_arg!(builder, b as u32, t as u32, h as u32, d as u32, stride_b);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
@ -683,6 +746,11 @@ impl candle::CustomOp3 for RotaryEmbThd {
dtype => candle::bail!("rope_thd is not implemented for {dtype:?}"),
};
let (b, t, h, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
h * t * d
} else {
0usize
};
let el = b * h * t * d;
let output = device.new_buffer(el, src.dtype(), "rope-thd")?;
candle_metal_kernels::call_rope_thd(
@ -694,6 +762,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
t,
h,
d,
stride_b,
src.buffer(),
l_src.start_offset() * src.dtype().size_in_bytes(),
cos.buffer(),
@ -709,9 +778,9 @@ impl candle::CustomOp3 for RotaryEmbThd {
}
pub fn rope_thd(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let (_b_sz, seq_len, _n_head, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
let (sin_seq_len, sin_n_embd) = sin.dims2()?;
let (b_sz, seq_len, _n_head, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?;
let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?;
if cos_n_embd * 2 != n_embd
|| sin_n_embd * 2 != n_embd
|| seq_len > cos_seq_len

View File

@ -4,7 +4,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::{test_device, test_utils::to_vec3_round, Device, Result, Tensor};
use candle::{test_device, test_utils::to_vec3_round, Device, IndexOp, Result, Tensor};
fn softmax(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
@ -179,6 +179,28 @@ fn ropei(device: &Device) -> Result<()> {
} else {
assert!(sum_diff < 1e-4);
}
// Test with a 3d cos/sin
let cos2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let sin2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?;
let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?;
let rope1 = candle_nn::rotary_emb::rope_i(&src.i(0..1)?, &cos, &sin)?;
let rope2 = candle_nn::rotary_emb::rope_i(&src.i(1..2)?, &cos2, &sin2)?;
let both_cos = Tensor::stack(&[cos, cos2], 0)?;
let both_sin = Tensor::stack(&[sin, sin2], 0)?;
let both_rope = candle_nn::rotary_emb::rope_i(&src, &both_cos, &both_sin)?;
let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?;
let sum_diff = (both_rope - both_rope2)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(sum_diff, 0.);
Ok(())
}
@ -206,6 +228,28 @@ fn rope(device: &Device) -> Result<()> {
} else {
assert!(sum_diff < 1e-4);
}
// Test with a 3d cos/sin
let cos2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let sin2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?;
let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?;
let rope1 = candle_nn::rotary_emb::rope(&src.i(0..1)?, &cos, &sin)?;
let rope2 = candle_nn::rotary_emb::rope(&src.i(1..2)?, &cos2, &sin2)?;
let both_cos = Tensor::stack(&[cos, cos2], 0)?;
let both_sin = Tensor::stack(&[sin, sin2], 0)?;
let both_rope = candle_nn::rotary_emb::rope(&src, &both_cos, &both_sin)?;
let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?;
let sum_diff = (both_rope - both_rope2)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(sum_diff, 0.);
Ok(())
}
@ -236,6 +280,37 @@ fn rope_thd(device: &Device) -> Result<()> {
} else {
assert!(sum_diff < 1e-4);
}
// Test with a 3d cos/sin
let cos2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let sin2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?;
let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?;
let rope1 = {
let src = src.transpose(1, 2)?.contiguous()?;
candle_nn::rotary_emb::rope_thd(&src.i(0..1)?, &cos, &sin)?
};
let rope2 = {
let src = src.transpose(1, 2)?.contiguous()?;
candle_nn::rotary_emb::rope_thd(&src.i(1..2)?, &cos2, &sin2)?
};
let both_cos = Tensor::stack(&[cos, cos2], 0)?;
let both_sin = Tensor::stack(&[sin, sin2], 0)?;
let both_rope = {
let src = src.transpose(1, 2)?.contiguous()?;
candle_nn::rotary_emb::rope_thd(&src, &both_cos, &both_sin)?
};
let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?;
let sum_diff = (both_rope - both_rope2)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(sum_diff, 0.);
Ok(())
}