Revert candle-transformers.

This commit is contained in:
Nicolas Patry
2023-12-15 11:15:21 +01:00
parent 243e83f2b9
commit 916a8c5464

View File

@ -142,9 +142,10 @@ impl RotaryEmbedding {
.to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
let sin = freqs.sin()?;
let cos = freqs.cos()?;
Ok(Self { sin, cos })
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
})
}
fn apply_rotary_emb_qkv(
@ -407,38 +408,3 @@ impl MixFormerSequentialForCausalLM {
self.blocks.iter_mut().for_each(|b| b.clear_kv_cache())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_rotary() {
let dev = Device::new_metal(0).unwrap();
for i in 0..10000 {
let dim = 8;
let max_seq_len = 12;
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), &dev).unwrap();
let t = Tensor::arange(0u32, max_seq_len as u32, &dev)
.unwrap()
.to_dtype(DType::F32)
.unwrap()
.reshape((max_seq_len, 1))
.unwrap();
let x: f32 = t.i((1, 0)).unwrap().to_scalar().unwrap();
assert_eq!(x, 1.0);
let x: f32 = inv_freq.i((0, 1)).unwrap().to_scalar().unwrap();
assert_eq!(x, 0.1);
let freqs = t.matmul(&inv_freq).unwrap();
let x: f32 = freqs.i((1, 1)).unwrap().to_scalar().unwrap();
assert_eq!(x, 0.1);
let sin = freqs.sin().unwrap().contiguous().unwrap();
let x: f32 = sin.i((1, 1)).unwrap().to_scalar().unwrap();
assert_eq!(x, 0.099833414);
}
}
}