diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu index 079c3708..5627c0c1 100644 --- a/candle-kernels/src/reduce.cu +++ b/candle-kernels/src/reduce.cu @@ -219,11 +219,15 @@ __device__ void softmax(const T * x, T * dst, const int ncols) { } template -__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) { +__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t stride_b) { const int idx = blockIdx.x * blockDim.x + threadIdx.x; if (2 * idx >= bh * td) return; uint32_t rope_idx = idx % (td / 2); + if (stride_b > 0) { + uint32_t b_idx = (2 * idx) / stride_b; + rope_idx += b_idx * (td / 2); + } T c = cos[rope_idx]; T s = sin[rope_idx]; @@ -232,7 +236,7 @@ __device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, cons } template -__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d) { +__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d, const uint32_t stride_b) { const int idx = blockIdx.x * blockDim.x + threadIdx.x; if (2 * idx >= bh * td) return; @@ -243,6 +247,10 @@ __device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t i1 = i_bh * td + i_t * d + i_d; uint32_t i2 = i1 + d / 2; uint32_t i_cs = i_t * (d / 2) + i_d; + if (stride_b > 0) { + uint32_t b_idx = (2 * idx) / stride_b; + i_cs += b_idx * (td / 2); + } T c = cos[i_cs]; T s = sin[i_cs]; @@ -259,7 +267,8 @@ __device__ void rope_thd( const uint32_t b, const uint32_t t, const uint32_t h, - const uint32_t d + const uint32_t d, + const uint32_t stride_b ) { const int idx = blockIdx.x * blockDim.x + threadIdx.x; if (2 * idx >= b * t * h * d) return; @@ -270,6 +279,10 @@ __device__ void rope_thd( uint32_t i1 = i_bth * d + i_d; uint32_t i2 = i1 + d / 2; uint32_t i_cs = i_t * (d / 2) + i_d; + if (stride_b > 0) { + uint32_t b_idx = (2 * idx) / stride_b; + i_cs += b_idx * ((t * d) / 2); + } T c = cos[i_cs]; T s = sin[i_cs]; @@ -546,8 +559,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block, const TYPENAME *sin, \ TYPENAME *dst, \ const uint32_t bh, \ - const uint32_t td) { \ - ropei(src, cos, sin, dst, bh, td); \ + const uint32_t td, \ + const uint32_t stride_b) { \ + ropei(src, cos, sin, dst, bh, td, stride_b); \ } \ extern "C" __global__ void FN_NAME( \ const TYPENAME *src, \ @@ -556,8 +570,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block, TYPENAME *dst, \ const uint32_t bh, \ const uint32_t td, \ - const uint32_t d) { \ - rope(src, cos, sin, dst, bh, td, d); \ + const uint32_t d, \ + const uint32_t stride_b) { \ + rope(src, cos, sin, dst, bh, td, d, stride_b); \ } \ extern "C" __global__ void FN_NAME_THD( \ const TYPENAME *src, \ @@ -567,8 +582,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block, const uint32_t b, \ const uint32_t t, \ const uint32_t h, \ - const uint32_t d) { \ - rope_thd(src, cos, sin, dst, b, t, h, d); \ + const uint32_t d, \ + const uint32_t stride_b) { \ + rope_thd(src, cos, sin, dst, b, t, h, d, stride_b); \ } \ #if __CUDA_ARCH__ >= 800 diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index de1b1053..939990da 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -991,6 +991,7 @@ pub fn call_rope_i( kernel_name: &'static str, bh: usize, td: usize, + stride_b: usize, src: &Buffer, src_offset: usize, cos: &Buffer, @@ -1009,6 +1010,7 @@ pub fn call_rope_i( ( bh, td, + stride_b, (src, src_offset), (cos, cos_offset), (sin, sin_offset), @@ -1034,6 +1036,7 @@ pub fn call_rope_thd( t: usize, h: usize, d: usize, + stride_b: usize, src: &Buffer, src_offset: usize, cos: &Buffer, @@ -1054,6 +1057,7 @@ pub fn call_rope_thd( t, h, d, + stride_b, (src, src_offset), (cos, cos_offset), (sin, sin_offset), @@ -1078,6 +1082,7 @@ pub fn call_rope( bh: usize, td: usize, d: usize, + stride_b: usize, src: &Buffer, src_offset: usize, cos: &Buffer, @@ -1097,6 +1102,7 @@ pub fn call_rope( bh, td, d, + stride_b, (src, src_offset), (cos, cos_offset), (sin, sin_offset), diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index 291c81e6..c134218c 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -1097,6 +1097,7 @@ template METAL_FUNC void ropei( constant size_t &bh, constant size_t &td, + constant size_t &stride_b, device const T *src, device const T *cos, device const T *sin, @@ -1107,6 +1108,10 @@ METAL_FUNC void ropei( return; } size_t rope_idx = tid % (td / 2); + if (stride_b > 0) { + size_t b_idx = (2 * tid) / stride_b; + rope_idx += b_idx * (td / 2); + } T c = cos[rope_idx]; T s = sin[rope_idx]; dst[2 * tid] = src[2 * tid] * c - src[2 * tid + 1] * s; @@ -1118,6 +1123,7 @@ METAL_FUNC void rope( constant size_t &bh, constant size_t &td, constant size_t &d, + constant size_t &stride_b, device const T *src, device const T *cos, device const T *sin, @@ -1134,6 +1140,10 @@ METAL_FUNC void rope( size_t i1 = i_bh * td + i_t * d + i_d; size_t i2 = i1 + d / 2; size_t i_cs = i_t * (d / 2) + i_d; + if (stride_b > 0) { + size_t b_idx = (2 * idx) / stride_b; + i_cs += b_idx * (td / 2); + } T c = cos[i_cs]; T s = sin[i_cs]; dst[i1] = src[i1] * c - src[i2] * s; @@ -1146,6 +1156,7 @@ METAL_FUNC void rope_thd( constant size_t &t, constant size_t &h, constant size_t &d, + constant size_t &stride_b, device const T *src, device const T *cos, device const T *sin, @@ -1160,8 +1171,12 @@ METAL_FUNC void rope_thd( const size_t i_t = (i_bth / h) % t; const size_t i1 = i_bth * d + i_d; const size_t i2 = i1 + d / 2; - const size_t i_cs = i_t * (d / 2) + i_d; - T c = cos[i_cs]; + size_t i_cs = i_t * (d / 2) + i_d; + if (stride_b > 0) { + const size_t b_idx = (2 * idx) / stride_b; + i_cs += b_idx * ((t * d) / 2); + } + T c = cos[i_cs]; T s = sin[i_cs]; dst[i1] = src[i1] * c - src[i2] * s; dst[i2] = src[i1] * s + src[i2] * c; @@ -1171,38 +1186,41 @@ METAL_FUNC void rope_thd( kernel void FN_NAME_I( \ constant size_t &bh, \ constant size_t &td, \ + constant size_t &stride_b, \ device const TYPENAME *src, \ device const TYPENAME *cos, \ device const TYPENAME *sin, \ device TYPENAME *dst, \ uint tid [[ thread_position_in_grid ]] \ ) { \ - ropei(bh, td, src, cos, sin, dst, tid); \ + ropei(bh, td, stride_b, src, cos, sin, dst, tid); \ }\ kernel void FN_NAME( \ constant size_t &bh, \ constant size_t &td, \ constant size_t &d, \ + constant size_t &stride_b, \ device const TYPENAME *src, \ device const TYPENAME *cos, \ device const TYPENAME *sin, \ device TYPENAME *dst, \ uint idx [[ thread_position_in_grid ]] \ ) { \ - rope(bh, td, d, src, cos, sin, dst, idx); \ + rope(bh, td, d, stride_b, src, cos, sin, dst, idx); \ }\ kernel void FN_NAME_THD( \ constant size_t &b, \ constant size_t &t, \ constant size_t &h, \ constant size_t &d, \ + constant size_t &stride_b, \ device const TYPENAME *src, \ device const TYPENAME *cos, \ device const TYPENAME *sin, \ device TYPENAME *dst, \ uint idx [[ thread_position_in_grid ]] \ ) { \ - rope_thd(b, t, h, d, src, cos, sin, dst, idx); \ + rope_thd(b, t, h, d, stride_b, src, cos, sin, dst, idx); \ }\ RMSNORM(rmsnorm_f32, float) diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index ee130d6b..5934cffb 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1584,7 +1584,7 @@ fn run_scatter_add( dim, BufferOffset::zero_offset(&input_buffer), BufferOffset::zero_offset(&ids_buffer), - &output, + BufferOffset::zero_offset(&output), ) .unwrap(); command_buffer.commit(); diff --git a/candle-nn/src/rotary_emb.rs b/candle-nn/src/rotary_emb.rs index e9fa24ce..bfb541f0 100644 --- a/candle-nn/src/rotary_emb.rs +++ b/candle-nn/src/rotary_emb.rs @@ -46,15 +46,23 @@ impl candle::CustomOp3 for RotaryEmbI { Some((o1, o2)) => &sin[o1..o2], }; let (b, h, t, d) = l_src.shape().dims4()?; + let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3; let el_count = b * h * t * d; let mut dst = vec![T::zero(); el_count]; src.par_chunks(t * d) .zip(dst.par_chunks_mut(t * d)) - .for_each(|(src, dst)| { + .enumerate() + .for_each(|(bh_i, (src, dst))| { for i_over_2 in 0..t * d / 2 { let i = 2 * i_over_2; - dst[i] = src[i] * cos[i_over_2] - src[i + 1] * sin[i_over_2]; - dst[i + 1] = src[i] * sin[i_over_2] + src[i + 1] * cos[i_over_2]; + let rope_i = if unbatched_rope { + let b_i = bh_i / h; + i_over_2 + b_i * t * d / 2 + } else { + i_over_2 + }; + dst[i] = src[i] * cos[rope_i] - src[i + 1] * sin[rope_i]; + dst[i + 1] = src[i] * sin[rope_i] + src[i + 1] * cos[rope_i]; } }); let storage = candle::WithDType::to_cpu_storage_owned(dst); @@ -115,6 +123,11 @@ impl candle::CustomOp3 for RotaryEmbI { Some((o1, o2)) => sin.slice(o1..o2), }; let (b, h, t, d) = l_src.shape().dims4()?; + let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 { + (h * t * d) as u32 + } else { + 0u32 + }; let el = b * h * t * d; let cfg = LaunchConfig::for_num_elems((el / 2) as u32); let func = dev.get_or_load_func(&kernel_name::("rope_i"), &kernels::REDUCE)?; @@ -125,7 +138,7 @@ impl candle::CustomOp3 for RotaryEmbI { builder.arg(&cos); builder.arg(&sin); builder.arg(&dst); - candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32); + candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, stride_b); // SAFETY: ffi. unsafe { builder.launch(cfg) }.w()?; Ok(dst) @@ -182,6 +195,11 @@ impl candle::CustomOp3 for RotaryEmbI { dtype => candle::bail!("rope-i is not implemented for {dtype:?}"), }; let (b, h, t, d) = l_src.shape().dims4()?; + let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 { + h * t * d + } else { + 0usize + }; let el = b * h * t * d; let output = device.new_buffer(el, src.dtype(), "rope-i")?; candle_metal_kernels::call_rope_i( @@ -191,6 +209,7 @@ impl candle::CustomOp3 for RotaryEmbI { name, b * h, t * d, + stride_b, src.buffer(), l_src.start_offset() * src.dtype().size_in_bytes(), cos.buffer(), @@ -205,10 +224,23 @@ impl candle::CustomOp3 for RotaryEmbI { } } +fn rope_check_cs(cs: &Tensor, b_sz: usize) -> Result<(usize, usize)> { + match *cs.dims() { + [t, d] => Ok((t, d)), + [b, t, d] => { + if b != b_sz { + candle::bail!("inconsistent batch size in rope {b_sz} {cs:?}",) + } + Ok((t, d)) + } + _ => candle::bail!("cos/sin has to be 2D or 3D in rope {b_sz} {cs:?}"), + } +} + pub fn rope_i(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result { - let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?; - let (cos_seq_len, cos_n_embd) = cos.dims2()?; - let (sin_seq_len, sin_n_embd) = cos.dims2()?; + let (b_sz, _n_head, seq_len, n_embd) = xs.dims4()?; + let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?; + let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?; if cos_n_embd * 2 != n_embd || sin_n_embd * 2 != n_embd || seq_len > cos_seq_len @@ -292,16 +324,24 @@ impl candle::CustomOp3 for RotaryEmb { Some((o1, o2)) => &sin[o1..o2], }; let (b, h, t, d) = l_src.shape().dims4()?; + let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3; let el_count = b * h * t * d; let mut dst = vec![T::zero(); el_count]; src.par_chunks(t * d) .zip(dst.par_chunks_mut(t * d)) - .for_each(|(src, dst)| { + .enumerate() + .for_each(|(bh_i, (src, dst))| { for i_t in 0..t { for i_d in 0..d / 2 { let i1 = i_t * d + i_d; let i2 = i1 + d / 2; let i_cs = i_t * (d / 2) + i_d; + let i_cs = if unbatched_rope { + let b_i = bh_i / h; + i_cs + b_i * t * d / 2 + } else { + i_cs + }; dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs]; dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs]; } @@ -365,6 +405,11 @@ impl candle::CustomOp3 for RotaryEmb { Some((o1, o2)) => sin.slice(o1..o2), }; let (b, h, t, d) = l_src.shape().dims4()?; + let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 { + (h * t * d) as u32 + } else { + 0u32 + }; let el = b * h * t * d; let cfg = LaunchConfig::for_num_elems((el / 2) as u32); let func = dev.get_or_load_func(&kernel_name::("rope"), &kernels::REDUCE)?; @@ -375,7 +420,7 @@ impl candle::CustomOp3 for RotaryEmb { builder.arg(&cos); builder.arg(&sin); builder.arg(&dst); - candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, d as u32); + candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, d as u32, stride_b); // SAFETY: ffi. unsafe { builder.launch(cfg) }.w()?; Ok(dst) @@ -432,6 +477,11 @@ impl candle::CustomOp3 for RotaryEmb { dtype => candle::bail!("rope is not implemented for {dtype:?}"), }; let (b, h, t, d) = l_src.shape().dims4()?; + let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 { + h * t * d + } else { + 0usize + }; let el = b * h * t * d; let output = device.new_buffer(el, src.dtype(), "rope-i")?; candle_metal_kernels::call_rope( @@ -442,6 +492,7 @@ impl candle::CustomOp3 for RotaryEmb { b * h, t * d, d, + stride_b, src.buffer(), l_src.start_offset() * src.dtype().size_in_bytes(), cos.buffer(), @@ -457,9 +508,9 @@ impl candle::CustomOp3 for RotaryEmb { } pub fn rope(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result { - let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?; - let (cos_seq_len, cos_n_embd) = cos.dims2()?; - let (sin_seq_len, sin_n_embd) = sin.dims2()?; + let (b_sz, _n_head, seq_len, n_embd) = xs.dims4()?; + let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?; + let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?; if cos_n_embd * 2 != n_embd || sin_n_embd * 2 != n_embd || seq_len > cos_seq_len @@ -541,14 +592,21 @@ impl candle::CustomOp3 for RotaryEmbThd { Some((o1, o2)) => &sin[o1..o2], }; let (b, t, h, d) = l_src.shape().dims4()?; + let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3; let el_count = b * h * t * d; let mut dst = vec![T::zero(); el_count]; src.par_chunks(t * h * d) .zip(dst.par_chunks_mut(t * h * d)) - .for_each(|(src, dst)| { + .enumerate() + .for_each(|(b_i, (src, dst))| { for i_t in 0..t { for i_d in 0..d / 2 { let i_cs = i_t * (d / 2) + i_d; + let i_cs = if unbatched_rope { + i_cs + b_i * t * d / 2 + } else { + i_cs + }; for i_h in 0..h { let i1 = i_t * h * d + i_h * d + i_d; let i2 = i1 + d / 2; @@ -616,6 +674,11 @@ impl candle::CustomOp3 for RotaryEmbThd { Some((o1, o2)) => sin.slice(o1..o2), }; let (b, t, h, d) = l_src.shape().dims4()?; + let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 { + (h * t * d) as u32 + } else { + 0u32 + }; let el = b * h * t * d; let cfg = LaunchConfig::for_num_elems((el / 2) as u32); let func = dev.get_or_load_func(&kernel_name::("rope_thd"), &kernels::REDUCE)?; @@ -626,7 +689,7 @@ impl candle::CustomOp3 for RotaryEmbThd { builder.arg(&cos); builder.arg(&sin); builder.arg(&dst); - candle::builder_arg!(builder, b as u32, t as u32, h as u32, d as u32); + candle::builder_arg!(builder, b as u32, t as u32, h as u32, d as u32, stride_b); // SAFETY: ffi. unsafe { builder.launch(cfg) }.w()?; Ok(dst) @@ -683,6 +746,11 @@ impl candle::CustomOp3 for RotaryEmbThd { dtype => candle::bail!("rope_thd is not implemented for {dtype:?}"), }; let (b, t, h, d) = l_src.shape().dims4()?; + let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 { + h * t * d + } else { + 0usize + }; let el = b * h * t * d; let output = device.new_buffer(el, src.dtype(), "rope-thd")?; candle_metal_kernels::call_rope_thd( @@ -694,6 +762,7 @@ impl candle::CustomOp3 for RotaryEmbThd { t, h, d, + stride_b, src.buffer(), l_src.start_offset() * src.dtype().size_in_bytes(), cos.buffer(), @@ -709,9 +778,9 @@ impl candle::CustomOp3 for RotaryEmbThd { } pub fn rope_thd(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result { - let (_b_sz, seq_len, _n_head, n_embd) = xs.dims4()?; - let (cos_seq_len, cos_n_embd) = cos.dims2()?; - let (sin_seq_len, sin_n_embd) = sin.dims2()?; + let (b_sz, seq_len, _n_head, n_embd) = xs.dims4()?; + let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?; + let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?; if cos_n_embd * 2 != n_embd || sin_n_embd * 2 != n_embd || seq_len > cos_seq_len diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs index 6c66f39f..6287aa24 100644 --- a/candle-nn/tests/ops.rs +++ b/candle-nn/tests/ops.rs @@ -4,7 +4,7 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; -use candle::{test_device, test_utils::to_vec3_round, Device, Result, Tensor}; +use candle::{test_device, test_utils::to_vec3_round, Device, IndexOp, Result, Tensor}; fn softmax(device: &Device) -> Result<()> { let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]]; @@ -179,6 +179,28 @@ fn ropei(device: &Device) -> Result<()> { } else { assert!(sum_diff < 1e-4); } + + // Test with a 3d cos/sin + let cos2: Vec = (0..seq_len * head_dim / 2) + .map(|_| rng.random::()) + .collect(); + let sin2: Vec = (0..seq_len * head_dim / 2) + .map(|_| rng.random::()) + .collect(); + let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?; + let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?; + let rope1 = candle_nn::rotary_emb::rope_i(&src.i(0..1)?, &cos, &sin)?; + let rope2 = candle_nn::rotary_emb::rope_i(&src.i(1..2)?, &cos2, &sin2)?; + + let both_cos = Tensor::stack(&[cos, cos2], 0)?; + let both_sin = Tensor::stack(&[sin, sin2], 0)?; + let both_rope = candle_nn::rotary_emb::rope_i(&src, &both_cos, &both_sin)?; + let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?; + let sum_diff = (both_rope - both_rope2)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(sum_diff, 0.); Ok(()) } @@ -206,6 +228,28 @@ fn rope(device: &Device) -> Result<()> { } else { assert!(sum_diff < 1e-4); } + + // Test with a 3d cos/sin + let cos2: Vec = (0..seq_len * head_dim / 2) + .map(|_| rng.random::()) + .collect(); + let sin2: Vec = (0..seq_len * head_dim / 2) + .map(|_| rng.random::()) + .collect(); + let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?; + let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?; + let rope1 = candle_nn::rotary_emb::rope(&src.i(0..1)?, &cos, &sin)?; + let rope2 = candle_nn::rotary_emb::rope(&src.i(1..2)?, &cos2, &sin2)?; + + let both_cos = Tensor::stack(&[cos, cos2], 0)?; + let both_sin = Tensor::stack(&[sin, sin2], 0)?; + let both_rope = candle_nn::rotary_emb::rope(&src, &both_cos, &both_sin)?; + let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?; + let sum_diff = (both_rope - both_rope2)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(sum_diff, 0.); Ok(()) } @@ -236,6 +280,37 @@ fn rope_thd(device: &Device) -> Result<()> { } else { assert!(sum_diff < 1e-4); } + + // Test with a 3d cos/sin + let cos2: Vec = (0..seq_len * head_dim / 2) + .map(|_| rng.random::()) + .collect(); + let sin2: Vec = (0..seq_len * head_dim / 2) + .map(|_| rng.random::()) + .collect(); + let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?; + let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?; + let rope1 = { + let src = src.transpose(1, 2)?.contiguous()?; + candle_nn::rotary_emb::rope_thd(&src.i(0..1)?, &cos, &sin)? + }; + let rope2 = { + let src = src.transpose(1, 2)?.contiguous()?; + candle_nn::rotary_emb::rope_thd(&src.i(1..2)?, &cos2, &sin2)? + }; + + let both_cos = Tensor::stack(&[cos, cos2], 0)?; + let both_sin = Tensor::stack(&[sin, sin2], 0)?; + let both_rope = { + let src = src.transpose(1, 2)?.contiguous()?; + candle_nn::rotary_emb::rope_thd(&src, &both_cos, &both_sin)? + }; + let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?; + let sum_diff = (both_rope - both_rope2)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(sum_diff, 0.); Ok(()) }