mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Fix detail in new RoPE implementation (#1935)
This commit is contained in:
@ -455,7 +455,7 @@ impl candle::CustomOp3 for RotaryEmb {
|
|||||||
pub fn rope(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
|
pub fn rope(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
|
||||||
let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
|
let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
|
||||||
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
|
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
|
||||||
let (sin_seq_len, sin_n_embd) = cos.dims2()?;
|
let (sin_seq_len, sin_n_embd) = sin.dims2()?;
|
||||||
if cos_n_embd * 2 != n_embd
|
if cos_n_embd * 2 != n_embd
|
||||||
|| sin_n_embd * 2 != n_embd
|
|| sin_n_embd * 2 != n_embd
|
||||||
|| seq_len > cos_seq_len
|
|| seq_len > cos_seq_len
|
||||||
|
Reference in New Issue
Block a user