mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Lots of updates including some stack of command buffers.
This commit is contained in:
@ -31,3 +31,4 @@ accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
|
||||
cuda = ["candle/cuda", "candle-nn/cuda"]
|
||||
flash-attn = ["cuda", "dep:candle-flash-attn"]
|
||||
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]
|
||||
metal = ["candle/metal", "candle-nn/metal"]
|
||||
|
@ -142,10 +142,9 @@ impl RotaryEmbedding {
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((max_seq_len, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
Ok(Self {
|
||||
sin: freqs.sin()?,
|
||||
cos: freqs.cos()?,
|
||||
})
|
||||
let sin = freqs.sin()?;
|
||||
let cos = freqs.cos()?;
|
||||
Ok(Self { sin, cos })
|
||||
}
|
||||
|
||||
fn apply_rotary_emb_qkv(
|
||||
@ -273,6 +272,10 @@ impl MHA {
|
||||
}
|
||||
|
||||
fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
let view = xs.to_string();
|
||||
if view.contains("NaN") {
|
||||
panic!("NaN");
|
||||
}
|
||||
let _enter = self.span.enter();
|
||||
let (b_size, seq_len, _n_embd) = xs.dims3()?;
|
||||
let qkv = self
|
||||
@ -408,3 +411,38 @@ impl MixFormerSequentialForCausalLM {
|
||||
self.blocks.iter_mut().for_each(|b| b.clear_kv_cache())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
#[test]
|
||||
fn test_rotary() {
|
||||
let dev = Device::new_metal(0).unwrap();
|
||||
for i in 0..10000 {
|
||||
let dim = 8;
|
||||
let max_seq_len = 12;
|
||||
let inv_freq: Vec<_> = (0..dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
|
||||
.collect();
|
||||
let inv_freq_len = inv_freq.len();
|
||||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), &dev).unwrap();
|
||||
let t = Tensor::arange(0u32, max_seq_len as u32, &dev)
|
||||
.unwrap()
|
||||
.to_dtype(DType::F32)
|
||||
.unwrap()
|
||||
.reshape((max_seq_len, 1))
|
||||
.unwrap();
|
||||
let x: f32 = t.i((1, 0)).unwrap().to_scalar().unwrap();
|
||||
assert_eq!(x, 1.0);
|
||||
let x: f32 = inv_freq.i((0, 1)).unwrap().to_scalar().unwrap();
|
||||
assert_eq!(x, 0.1);
|
||||
let freqs = t.matmul(&inv_freq).unwrap();
|
||||
let x: f32 = freqs.i((1, 1)).unwrap().to_scalar().unwrap();
|
||||
assert_eq!(x, 0.1);
|
||||
let sin = freqs.sin().unwrap().contiguous().unwrap();
|
||||
let x: f32 = sin.i((1, 1)).unwrap().to_scalar().unwrap();
|
||||
assert_eq!(x, 0.099833414);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user