Fix detail in new RoPE implementation (#1935)

This commit is contained in:
Hugo Abonizio
2024-03-25 14:20:09 -03:00
committed by GitHub
parent d3a8d291d5
commit 60676780a9

View File

@ -455,7 +455,7 @@ impl candle::CustomOp3 for RotaryEmb {
pub fn rope(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> { pub fn rope(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?; let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = cos.dims2()?; let (cos_seq_len, cos_n_embd) = cos.dims2()?;
let (sin_seq_len, sin_n_embd) = cos.dims2()?; let (sin_seq_len, sin_n_embd) = sin.dims2()?;
if cos_n_embd * 2 != n_embd if cos_n_embd * 2 != n_embd
|| sin_n_embd * 2 != n_embd || sin_n_embd * 2 != n_embd
|| seq_len > cos_seq_len || seq_len > cos_seq_len