mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 12:20:46 +00:00
Support for "unbatched" rope. (#2926)
* Support for (un)-batched rope.
* Use 3d rope in the rope/ropei/rope_thd functions.
* Get the CPU versions to work.
* Fix the cuda version.
* Adapt the metal side.
* Fix the metal tests.
This commit is contained in:
@@ -4,7 +4,7 @@ extern crate intel_mkl_src;
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use candle::{test_device, test_utils::to_vec3_round, Device, Result, Tensor};
|
||||
use candle::{test_device, test_utils::to_vec3_round, Device, IndexOp, Result, Tensor};
|
||||
|
||||
fn softmax(device: &Device) -> Result<()> {
|
||||
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
|
||||
@@ -179,6 +179,28 @@ fn ropei(device: &Device) -> Result<()> {
|
||||
} else {
|
||||
assert!(sum_diff < 1e-4);
|
||||
}
|
||||
|
||||
// Test with a 3d cos/sin
|
||||
let cos2: Vec<f32> = (0..seq_len * head_dim / 2)
|
||||
.map(|_| rng.random::<f32>())
|
||||
.collect();
|
||||
let sin2: Vec<f32> = (0..seq_len * head_dim / 2)
|
||||
.map(|_| rng.random::<f32>())
|
||||
.collect();
|
||||
let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?;
|
||||
let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?;
|
||||
let rope1 = candle_nn::rotary_emb::rope_i(&src.i(0..1)?, &cos, &sin)?;
|
||||
let rope2 = candle_nn::rotary_emb::rope_i(&src.i(1..2)?, &cos2, &sin2)?;
|
||||
|
||||
let both_cos = Tensor::stack(&[cos, cos2], 0)?;
|
||||
let both_sin = Tensor::stack(&[sin, sin2], 0)?;
|
||||
let both_rope = candle_nn::rotary_emb::rope_i(&src, &both_cos, &both_sin)?;
|
||||
let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?;
|
||||
let sum_diff = (both_rope - both_rope2)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(sum_diff, 0.);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -206,6 +228,28 @@ fn rope(device: &Device) -> Result<()> {
|
||||
} else {
|
||||
assert!(sum_diff < 1e-4);
|
||||
}
|
||||
|
||||
// Test with a 3d cos/sin
|
||||
let cos2: Vec<f32> = (0..seq_len * head_dim / 2)
|
||||
.map(|_| rng.random::<f32>())
|
||||
.collect();
|
||||
let sin2: Vec<f32> = (0..seq_len * head_dim / 2)
|
||||
.map(|_| rng.random::<f32>())
|
||||
.collect();
|
||||
let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?;
|
||||
let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?;
|
||||
let rope1 = candle_nn::rotary_emb::rope(&src.i(0..1)?, &cos, &sin)?;
|
||||
let rope2 = candle_nn::rotary_emb::rope(&src.i(1..2)?, &cos2, &sin2)?;
|
||||
|
||||
let both_cos = Tensor::stack(&[cos, cos2], 0)?;
|
||||
let both_sin = Tensor::stack(&[sin, sin2], 0)?;
|
||||
let both_rope = candle_nn::rotary_emb::rope(&src, &both_cos, &both_sin)?;
|
||||
let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?;
|
||||
let sum_diff = (both_rope - both_rope2)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(sum_diff, 0.);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -236,6 +280,37 @@ fn rope_thd(device: &Device) -> Result<()> {
|
||||
} else {
|
||||
assert!(sum_diff < 1e-4);
|
||||
}
|
||||
|
||||
// Test with a 3d cos/sin
|
||||
let cos2: Vec<f32> = (0..seq_len * head_dim / 2)
|
||||
.map(|_| rng.random::<f32>())
|
||||
.collect();
|
||||
let sin2: Vec<f32> = (0..seq_len * head_dim / 2)
|
||||
.map(|_| rng.random::<f32>())
|
||||
.collect();
|
||||
let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?;
|
||||
let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?;
|
||||
let rope1 = {
|
||||
let src = src.transpose(1, 2)?.contiguous()?;
|
||||
candle_nn::rotary_emb::rope_thd(&src.i(0..1)?, &cos, &sin)?
|
||||
};
|
||||
let rope2 = {
|
||||
let src = src.transpose(1, 2)?.contiguous()?;
|
||||
candle_nn::rotary_emb::rope_thd(&src.i(1..2)?, &cos2, &sin2)?
|
||||
};
|
||||
|
||||
let both_cos = Tensor::stack(&[cos, cos2], 0)?;
|
||||
let both_sin = Tensor::stack(&[sin, sin2], 0)?;
|
||||
let both_rope = {
|
||||
let src = src.transpose(1, 2)?.contiguous()?;
|
||||
candle_nn::rotary_emb::rope_thd(&src, &both_cos, &both_sin)?
|
||||
};
|
||||
let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?;
|
||||
let sum_diff = (both_rope - both_rope2)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(sum_diff, 0.);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user