mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Contiguous variant of the rope kernel. (#1929)
* Contiguous variant of the rope kernel. * Add the cuda kernel. * Metal kernel.
This commit is contained in:
@ -86,7 +86,7 @@ fn softmax_numerical_stability() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn rope(device: &Device) -> Result<()> {
|
||||
fn ropei(device: &Device) -> Result<()> {
|
||||
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||
|
||||
let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16);
|
||||
@ -107,12 +107,40 @@ fn rope(device: &Device) -> Result<()> {
|
||||
let sum_diff = (rope1 - rope2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
if device.is_cpu() {
|
||||
assert_eq!(sum_diff, 0.);
|
||||
} else if device.is_cuda() {
|
||||
} else {
|
||||
assert!(sum_diff < 1e-4);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn rope(device: &Device) -> Result<()> {
|
||||
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||
|
||||
let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16);
|
||||
let el_count = b_size * num_head * seq_len * head_dim;
|
||||
let mut rng = StdRng::seed_from_u64(299792458);
|
||||
let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
|
||||
let cos: Vec<f32> = (0..seq_len * head_dim / 2)
|
||||
.map(|_| rng.gen::<f32>())
|
||||
.collect();
|
||||
let sin: Vec<f32> = (0..seq_len * head_dim / 2)
|
||||
.map(|_| rng.gen::<f32>())
|
||||
.collect();
|
||||
let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?;
|
||||
let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?;
|
||||
let sin = Tensor::from_vec(sin, (seq_len, head_dim / 2), device)?;
|
||||
let rope1 = candle_nn::rotary_emb::rope(&src, &cos, &sin)?;
|
||||
let rope2 = candle_nn::rotary_emb::rope_slow(&src, &cos, &sin)?;
|
||||
let sum_diff = (rope1 - rope2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
if device.is_cpu() {
|
||||
assert_eq!(sum_diff, 0.);
|
||||
} else {
|
||||
assert!(sum_diff < 1e-4);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(ropei, ropei_cpu, ropei_gpu, ropei_metal);
|
||||
test_device!(rope, rope_cpu, rope_gpu, rope_metal);
|
||||
test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
|
||||
test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
|
||||
|
Reference in New Issue
Block a user