mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 04:10:46 +00:00
feat: enhance linear algebra operations (#2972)
- Add `dot()` for vector/matrix products - Implement the `Frobenius` norm - Add `mv()` for matrix-vector multiply
This commit is contained in:
@ -82,6 +82,26 @@ fn broadcast_matmul(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tensor_dot() -> Result<()> {
|
||||
let lhs = Tensor::new(&[1., 2., 3.], &Device::Cpu)?;
|
||||
let rhs = Tensor::new(&[4., 5., 6.], &Device::Cpu)?;
|
||||
let expected = Tensor::new(32., &Device::Cpu)?;
|
||||
let dot_ret = lhs.dot(&rhs)?;
|
||||
candle_core::test_utils::assert_tensor_eq(&dot_ret, &expected)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tensor_mv() -> Result<()> {
|
||||
let mat = Tensor::new(&[[1., 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
|
||||
let vec = Tensor::new(&[1., 1., 1.], &Device::Cpu)?;
|
||||
let expected = Tensor::new(&[6., 15.], &Device::Cpu)?;
|
||||
let mv_ret = mat.mv(&vec)?;
|
||||
candle_core::test_utils::assert_tensor_eq(&mv_ret, &expected)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/candle/issues/1948
|
||||
fn squeeze_mm(device: &Device) -> Result<()> {
|
||||
let seq_len = 8_usize;
|
||||
|
@ -1880,3 +1880,11 @@ fn tensor_new() -> Result<()> {
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tensor_norm() -> Result<()> {
|
||||
let t = Tensor::new(&[[3., 4.], [0., 0.]], &Device::Cpu)?;
|
||||
let norm = t.norm()?;
|
||||
assert_eq!(norm.to_scalar::<f64>()?, 5.);
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user