Support for "unbatched" rope. (#2926)

* Support for (un)-batched rope.

* Use 3d rope in the rope/ropei/rope_thd functions.

* Get the CPU versions to work.

* Fix the cuda version.

* Adapt the metal side.

* Fix the metal tests.
This commit is contained in:
Laurent Mazare
2025-04-27 15:12:02 +02:00
committed by GitHub
parent 6e0646c208
commit e3db30021f
6 changed files with 217 additions and 33 deletions

View File

@ -4,7 +4,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::{test_device, test_utils::to_vec3_round, Device, Result, Tensor};
use candle::{test_device, test_utils::to_vec3_round, Device, IndexOp, Result, Tensor};
fn softmax(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
@ -179,6 +179,28 @@ fn ropei(device: &Device) -> Result<()> {
} else {
assert!(sum_diff < 1e-4);
}
// Test with a 3d cos/sin
let cos2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let sin2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?;
let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?;
let rope1 = candle_nn::rotary_emb::rope_i(&src.i(0..1)?, &cos, &sin)?;
let rope2 = candle_nn::rotary_emb::rope_i(&src.i(1..2)?, &cos2, &sin2)?;
let both_cos = Tensor::stack(&[cos, cos2], 0)?;
let both_sin = Tensor::stack(&[sin, sin2], 0)?;
let both_rope = candle_nn::rotary_emb::rope_i(&src, &both_cos, &both_sin)?;
let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?;
let sum_diff = (both_rope - both_rope2)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(sum_diff, 0.);
Ok(())
}
@ -206,6 +228,28 @@ fn rope(device: &Device) -> Result<()> {
} else {
assert!(sum_diff < 1e-4);
}
// Test with a 3d cos/sin
let cos2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let sin2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?;
let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?;
let rope1 = candle_nn::rotary_emb::rope(&src.i(0..1)?, &cos, &sin)?;
let rope2 = candle_nn::rotary_emb::rope(&src.i(1..2)?, &cos2, &sin2)?;
let both_cos = Tensor::stack(&[cos, cos2], 0)?;
let both_sin = Tensor::stack(&[sin, sin2], 0)?;
let both_rope = candle_nn::rotary_emb::rope(&src, &both_cos, &both_sin)?;
let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?;
let sum_diff = (both_rope - both_rope2)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(sum_diff, 0.);
Ok(())
}
@ -236,6 +280,37 @@ fn rope_thd(device: &Device) -> Result<()> {
} else {
assert!(sum_diff < 1e-4);
}
// Test with a 3d cos/sin
let cos2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let sin2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?;
let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?;
let rope1 = {
let src = src.transpose(1, 2)?.contiguous()?;
candle_nn::rotary_emb::rope_thd(&src.i(0..1)?, &cos, &sin)?
};
let rope2 = {
let src = src.transpose(1, 2)?.contiguous()?;
candle_nn::rotary_emb::rope_thd(&src.i(1..2)?, &cos2, &sin2)?
};
let both_cos = Tensor::stack(&[cos, cos2], 0)?;
let both_sin = Tensor::stack(&[sin, sin2], 0)?;
let both_rope = {
let src = src.transpose(1, 2)?.contiguous()?;
candle_nn::rotary_emb::rope_thd(&src, &both_cos, &both_sin)?
};
let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?;
let sum_diff = (both_rope - both_rope2)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(sum_diff, 0.);
Ok(())
}