From 916a8c54646fab67f3d886717a48abfe55d89e39 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 15 Dec 2023 11:15:21 +0100 Subject: [PATCH] Revert candle-transformers. --- candle-transformers/src/models/mixformer.rs | 42 ++------------------- 1 file changed, 4 insertions(+), 38 deletions(-) diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs index e4e4f619..e822ca14 100644 --- a/candle-transformers/src/models/mixformer.rs +++ b/candle-transformers/src/models/mixformer.rs @@ -142,9 +142,10 @@ impl RotaryEmbedding { .to_dtype(DType::F32)? .reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; - let sin = freqs.sin()?; - let cos = freqs.cos()?; - Ok(Self { sin, cos }) + Ok(Self { + sin: freqs.sin()?, + cos: freqs.cos()?, + }) } fn apply_rotary_emb_qkv( @@ -407,38 +408,3 @@ impl MixFormerSequentialForCausalLM { self.blocks.iter_mut().for_each(|b| b.clear_kv_cache()) } } - -#[cfg(test)] -mod tests { - use super::*; - #[test] - fn test_rotary() { - let dev = Device::new_metal(0).unwrap(); - for i in 0..10000 { - let dim = 8; - let max_seq_len = 12; - let inv_freq: Vec<_> = (0..dim) - .step_by(2) - .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32)) - .collect(); - let inv_freq_len = inv_freq.len(); - let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), &dev).unwrap(); - let t = Tensor::arange(0u32, max_seq_len as u32, &dev) - .unwrap() - .to_dtype(DType::F32) - .unwrap() - .reshape((max_seq_len, 1)) - .unwrap(); - let x: f32 = t.i((1, 0)).unwrap().to_scalar().unwrap(); - assert_eq!(x, 1.0); - let x: f32 = inv_freq.i((0, 1)).unwrap().to_scalar().unwrap(); - assert_eq!(x, 0.1); - let freqs = t.matmul(&inv_freq).unwrap(); - let x: f32 = freqs.i((1, 1)).unwrap().to_scalar().unwrap(); - assert_eq!(x, 0.1); - let sin = freqs.sin().unwrap().contiguous().unwrap(); - let x: f32 = sin.i((1, 1)).unwrap().to_scalar().unwrap(); - assert_eq!(x, 0.099833414); - } - } -}