feat: enhance linear algebra operations (#2972)

- Add `dot()` for vector/matrix products
- Implement the `Frobenius` norm
- Add `mv()` for matrix-vector multiply
This commit is contained in:
飘尘
2025-05-29 15:41:01 +08:00
committed by GitHub
parent 1a183c988a
commit 5aed817f1b
3 changed files with 105 additions and 0 deletions

View File

@ -82,6 +82,26 @@ fn broadcast_matmul(device: &Device) -> Result<()> {
Ok(())
}
#[test]
fn tensor_dot() -> Result<()> {
    // Inner product of two 1-D tensors: 1*4 + 2*5 + 3*6 = 32.
    let a = Tensor::new(&[1., 2., 3.], &Device::Cpu)?;
    let b = Tensor::new(&[4., 5., 6.], &Device::Cpu)?;
    let want = Tensor::new(32., &Device::Cpu)?;
    let got = a.dot(&b)?;
    candle_core::test_utils::assert_tensor_eq(&got, &want)?;
    Ok(())
}
#[test]
fn tensor_mv() -> Result<()> {
    // Matrix-vector product: each row summed against the all-ones vector,
    // so [[1,2,3],[4,5,6]] * [1,1,1] = [6, 15].
    let matrix = Tensor::new(&[[1., 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    let vector = Tensor::new(&[1., 1., 1.], &Device::Cpu)?;
    let want = Tensor::new(&[6., 15.], &Device::Cpu)?;
    let got = matrix.mv(&vector)?;
    candle_core::test_utils::assert_tensor_eq(&got, &want)?;
    Ok(())
}
// https://github.com/huggingface/candle/issues/1948
fn squeeze_mm(device: &Device) -> Result<()> {
let seq_len = 8_usize;

View File

@ -1880,3 +1880,11 @@ fn tensor_new() -> Result<()> {
);
Ok(())
}
#[test]
fn tensor_norm() -> Result<()> {
    // Frobenius norm: sqrt(3^2 + 4^2 + 0 + 0) = 5, which is exactly
    // representable in f64, so an exact equality check is safe here.
    let input = Tensor::new(&[[3., 4.], [0., 0.]], &Device::Cpu)?;
    let frobenius = input.norm()?.to_scalar::<f64>()?;
    assert_eq!(frobenius, 5.);
    Ok(())
}