#[cfg(feature = "metal")] mod metal_sdpa_tests { use candle::{DType, Device, Result, Shape, Tensor}; use rand::SeedableRng; use rand_distr::Distribution; use std::ops::{Div, Mul}; fn randn>( rng: &mut rand::rngs::StdRng, shape: S, dev: &Device, ) -> Result { let shape = shape.into(); let elem_count = shape.elem_count(); let normal = rand_distr::Normal::new(0.0, 1.0).unwrap(); let vs: Vec = (0..elem_count).map(|_| normal.sample(rng)).collect(); Tensor::from_vec(vs, &shape, dev) } #[test] fn sdpa_full() -> Result<()> { // Force seqlen = 100 const BS: usize = 4; const R: usize = 4; const L: usize = 4; const DK: usize = 64; const H: usize = 3; let scale: f64 = f64::from(DK as u32).sqrt().recip(); let device = Device::new_metal(0)?; let mut rng = rand::rngs::StdRng::seed_from_u64(42); let q = randn(&mut rng, (BS, H, R, DK), &device)?; let k = randn(&mut rng, (BS, H, L, DK), &device)?; let v = randn(&mut rng, (BS, H, L, DK), &device)?; let ground_truth = { let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?; let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)? .to_dtype(q.dtype())?; att.matmul(&v.clone())? }; let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?; assert_eq!(ground_truth.shape(), sdpa_output.shape()); let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? .sum_all()? .to_scalar()?; assert!(error <= 0.0004, "{}", error); Ok(()) } #[test] fn sdpa_vector() -> Result<()> { // Allow vectorized, seqlen = 1 const BS: usize = 4; const R: usize = 1; const L: usize = 1; const DK: usize = 64; const H: usize = 3; let scale: f64 = f64::from(DK as u32).sqrt().recip(); let device = Device::new_metal(0)?; let mut rng = rand::rngs::StdRng::seed_from_u64(4242); let q = randn(&mut rng, (BS, H, R, DK), &device)?; let k = randn(&mut rng, (BS, H, L, DK), &device)?; let v = randn(&mut rng, (BS, H, L, DK), &device)?; let ground_truth = { let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?; let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)? .to_dtype(q.dtype())?; att.matmul(&v.clone())? }; let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?; assert_eq!(ground_truth.shape(), sdpa_output.shape()); let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? .sum_all()? .to_scalar()?; assert!(error <= 0.000, "{}", error); Ok(()) } #[test] fn sdpa_full_softcapping() -> Result<()> { // Allow vectorized, seqlen = 1 const BS: usize = 4; const R: usize = 4; const L: usize = 4; const DK: usize = 64; const H: usize = 3; const SOFTCAP: f64 = 50.; let scale: f64 = f64::from(DK as u32).sqrt().recip(); let device = Device::new_metal(0)?; let mut rng = rand::rngs::StdRng::seed_from_u64(424242); let q = randn(&mut rng, (BS, H, R, DK), &device)?; let k = randn(&mut rng, (BS, H, L, DK), &device)?; let v = randn(&mut rng, (BS, H, L, DK), &device)?; let ground_truth = { let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?; let att = candle_nn::ops::softmax_last_dim( &att.to_dtype(DType::F32)? .div(SOFTCAP)? .tanh()? .mul(SOFTCAP)?, )? .to_dtype(q.dtype())?; att.matmul(&v.clone())? }; let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?; assert_eq!(ground_truth.shape(), sdpa_output.shape()); let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? .sum_all()? 
.to_scalar()?; assert!(error <= 0.0005, "{}", error); Ok(()) } #[test] fn sdpa_vector_softcapping() -> Result<()> { // Allow vectorized, seqlen = 1 const BS: usize = 4; const R: usize = 1; const L: usize = 1; const DK: usize = 64; const H: usize = 3; const SOFTCAP: f64 = 50.; let scale: f64 = f64::from(DK as u32).sqrt().recip(); let device = Device::new_metal(0)?; let mut rng = rand::rngs::StdRng::seed_from_u64(42424242); let q = randn(&mut rng, (BS, H, R, DK), &device)?; let k = randn(&mut rng, (BS, H, L, DK), &device)?; let v = randn(&mut rng, (BS, H, L, DK), &device)?; let ground_truth = { let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?; let att = candle_nn::ops::softmax_last_dim( &att.to_dtype(DType::F32)? .div(SOFTCAP)? .tanh()? .mul(SOFTCAP)?, )? .to_dtype(q.dtype())?; att.matmul(&v.clone())? }; let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?; assert_eq!(ground_truth.shape(), sdpa_output.shape()); let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? .sum_all()? .to_scalar()?; assert!(error <= 0.0001, "{}", error); Ok(()) } #[test] fn sdpa_vector_cross() -> Result<()> { // Allow vectorized, seqlen = 1. Simulat cross attention case where R != L, R = 1 const BS: usize = 4; const R: usize = 1; const L: usize = 24; const DK: usize = 64; const H: usize = 3; let scale: f64 = f64::from(DK as u32).sqrt().recip(); let device = Device::new_metal(0)?; let mut rng = rand::rngs::StdRng::seed_from_u64(4242424242); let q = randn(&mut rng, (BS, H, R, DK), &device)?; let k = randn(&mut rng, (BS, H, L, DK), &device)?; let v = randn(&mut rng, (BS, H, L, DK), &device)?; let ground_truth = { let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?; let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)? .to_dtype(q.dtype())?; att.matmul(&v.clone())? }; let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?; assert_eq!(ground_truth.shape(), sdpa_output.shape()); let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? .sum_all()? .to_scalar()?; assert!(error <= 0.0013, "{}", error); Ok(()) } }
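
// A possible extension, sketched here rather than part of the suite above: the
// same relative-error harness applied to a decode-style shape where a single
// query row attends over a longer KV sequence (R = 1, L = 256). The module
// name, the L = 256 choice, and the 0.002 tolerance are illustrative
// assumptions, not measured bounds. It uses candle's built-in `Tensor::randn`
// instead of the seeded helper above, so it is not deterministic across runs.
#[cfg(feature = "metal")]
mod metal_sdpa_sketches {
    use candle::{DType, Device, Result, Tensor};

    #[test]
    fn sdpa_vector_long_kv_sketch() -> Result<()> {
        const BS: usize = 4;
        const R: usize = 1;
        const L: usize = 256;
        const DK: usize = 64;
        const H: usize = 3;
        let scale: f64 = f64::from(DK as u32).sqrt().recip();

        let device = Device::new_metal(0)?;
        let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
        let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
        let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;

        // Reference implementation, as in the tests above: softmax in f32.
        let ground_truth = {
            let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
            let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
                .to_dtype(q.dtype())?;
            att.matmul(&v.clone())?
        };
        let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;

        assert_eq!(ground_truth.shape(), sdpa_output.shape());
        // Summed elementwise relative error, matching the metric used above.
        let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
            .sum_all()?
            .to_scalar()?;
        assert!(error <= 0.002, "{}", error);
        Ok(())
    }
}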