feat: enhance linear algebra operations (#2972)

- Add `dot()` for the dot product of two 1D vectors
- Add `norm()` for the Frobenius norm
- Add `mv()` for strict matrix-vector multiplication
Authored by 飘尘 on 2025-05-29 15:41:01 +08:00; committed by GitHub
parent 1a183c988a
commit 5aed817f1b
3 changed files with 105 additions and 0 deletions
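
The three methods compose naturally. Below is a minimal usage sketch assembled from the doc examples in the diff that follows (assuming only `candle_core` and the CPU device; this driver program is not part of the commit):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // dot: 1*4 + 2*5 + 3*6 = 32.
    let a = Tensor::new(&[1.0, 2.0, 3.0], &Device::Cpu)?;
    let b = Tensor::new(&[4.0, 5.0, 6.0], &Device::Cpu)?;
    assert_eq!(a.dot(&b)?.to_scalar::<f64>()?, 32.);

    // Frobenius norm: sqrt(3^2 + 4^2) = 5.
    let m = Tensor::new(&[[3.0, 4.0], [0.0, 0.0]], &Device::Cpu)?;
    assert_eq!(m.norm()?.to_scalar::<f64>()?, 5.);

    // mv: [2, 3] * [3] -> [2].
    let mat = Tensor::new(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &Device::Cpu)?;
    let v = Tensor::new(&[1.0, 1.0, 1.0], &Device::Cpu)?;
    assert_eq!(mat.mv(&v)?.to_vec1::<f64>()?, [6.0, 15.0]);
    Ok(())
}
```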

@@ -1235,6 +1235,83 @@ impl Tensor {
        Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
    }

    /// Computes the dot product of two 1D tensors.
    ///
    /// - Both inputs must be vectors of the same length (`[n]`); the result is a scalar.
    /// - Returns an error if the shapes are incompatible or the dtype is an integer type.
    ///
    /// # Example
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t1 = Tensor::new(&[1.0, 2.0, 3.0], &Device::Cpu)?;
    /// let t2 = Tensor::new(&[4.0, 5.0, 6.0], &Device::Cpu)?;
    /// let res = t1.dot(&t2)?;
    /// assert_eq!(res.to_scalar::<f64>()?, 32.);
    /// # Ok::<(), candle_core::Error>(())
    /// ```
    pub fn dot(&self, rhs: &Self) -> Result<Self> {
        if self.dims().len() != 1 || rhs.dims().len() != 1 {
            return Err(Error::ShapeMismatchBinaryOp {
                lhs: self.shape().clone(),
                rhs: rhs.shape().clone(),
                op: "dot",
            });
        }
        // Match the documented behavior: integer dtypes are rejected.
        if self.dtype().is_int() {
            bail!("dot not supported for integer dtypes");
        }
        // Elementwise multiply, then sum over all elements to get the scalar product.
        (self * rhs).and_then(|ret| ret.sum_all())
    }

    /// Computes the **Frobenius norm** (the L2 norm of all elements) of the tensor.
    ///
    /// - The result is `sqrt(sum(x^2))`, always a scalar (`[]`-shaped) tensor.
    /// - Returns an error for integer dtypes.
    ///
    /// # Example
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[[3., 4.], [0., 0.]], &Device::Cpu)?;
    /// let norm = t.norm()?;
    /// assert_eq!(norm.to_scalar::<f64>()?, 5.);
    /// # Ok::<(), candle_core::Error>(())
    /// ```
    pub fn norm(&self) -> Result<Self> {
        if self.dtype().is_int() {
            bail!("norm not supported for integer dtypes");
        }
        // Square elementwise, sum over all elements, then take the square root.
        self.sqr().and_then(|x| x.sum_all()).and_then(|x| x.sqrt())
    }

    /// Performs strict matrix-vector multiplication (`[m, n] * [n] = [m]`).
    ///
    /// - If `self` is a matrix (`[m, n]`) and `rhs` is a vector (`[n]`), returns a vector (`[m]`).
    /// - **No broadcasting**: returns an error if `self` is not 2D or if `rhs` is not a 1D tensor
    ///   of matching length.
    ///
    /// # Example
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let mat = Tensor::new(&[[1., 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    /// let vec = Tensor::new(&[1., 1., 1.], &Device::Cpu)?;
    /// let res = mat.mv(&vec)?;
    /// assert_eq!(res.to_vec1::<f64>()?, [6., 15.]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```
    pub fn mv(&self, rhs: &Self) -> Result<Self> {
        // Strict shape checks: lhs must be [m, n] and rhs must be [n].
        let lhs_dims = self.dims();
        let rhs_dims = rhs.dims();
        if lhs_dims.len() != 2 || rhs_dims.len() != 1 || lhs_dims[1] != rhs_dims[0] {
            return Err(Error::ShapeMismatchBinaryOp {
                lhs: self.shape().clone(),
                rhs: rhs.shape().clone(),
                op: "mv",
            });
        }
        // Reshape rhs from [n] to a column vector [n, 1], matmul to get [m, 1],
        // then squeeze the trailing dimension back down to [m].
        self.matmul(&rhs.unsqueeze(1)?)?.squeeze(1)
    }

    /// Returns the matrix-multiplication of the input tensor with the other provided tensor.
    ///
    /// # Arguments
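
Worth noting for callers: both `dot` and `mv` report bad shapes through `Error::ShapeMismatchBinaryOp` rather than panicking, so mismatches can be handled like any other candle error. A small sketch of that error path (hypothetical driver code, not part of the commit):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let a = Tensor::new(&[1.0, 2.0, 3.0], &Device::Cpu)?;
    let b = Tensor::new(&[1.0, 2.0], &Device::Cpu)?;
    // Length mismatch: dot returns a recoverable Err, it does not panic.
    assert!(a.dot(&b).is_err());

    // mv is strict: rhs must be 1D with length equal to the matrix's second dim.
    let mat = Tensor::new(&[[1.0, 2.0], [3.0, 4.0]], &Device::Cpu)?;
    assert!(mat.mv(&a).is_err());
    Ok(())
}
```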

@@ -82,6 +82,26 @@ fn broadcast_matmul(device: &Device) -> Result<()> {
    Ok(())
}

#[test]
fn tensor_dot() -> Result<()> {
    let lhs = Tensor::new(&[1., 2., 3.], &Device::Cpu)?;
    let rhs = Tensor::new(&[4., 5., 6.], &Device::Cpu)?;
    let expected = Tensor::new(32., &Device::Cpu)?;
    let dot_ret = lhs.dot(&rhs)?;
    candle_core::test_utils::assert_tensor_eq(&dot_ret, &expected)?;
    Ok(())
}

#[test]
fn tensor_mv() -> Result<()> {
    let mat = Tensor::new(&[[1., 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    let vec = Tensor::new(&[1., 1., 1.], &Device::Cpu)?;
    let expected = Tensor::new(&[6., 15.], &Device::Cpu)?;
    let mv_ret = mat.mv(&vec)?;
    candle_core::test_utils::assert_tensor_eq(&mv_ret, &expected)?;
    Ok(())
}

// https://github.com/huggingface/candle/issues/1948
fn squeeze_mm(device: &Device) -> Result<()> {
    let seq_len = 8_usize;
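
The strictness of `mv` is deliberate; when batching is needed, the existing `broadcast_matmul` (exercised just above in this test file) can emulate a batched matrix-vector product. A sketch under that assumption (not part of the commit):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // Batched matrix-vector product: [2, 2, 2] x [2, 1] broadcasts to [2, 2, 1].
    let mats = Tensor::new(
        &[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]],
        &Device::Cpu,
    )?;
    let v = Tensor::new(&[1.0, 1.0], &Device::Cpu)?;
    let out = mats.broadcast_matmul(&v.unsqueeze(1)?)?.squeeze(2)?;
    assert_eq!(out.to_vec2::<f64>()?, [[3.0, 7.0], [11.0, 15.0]]);
    Ok(())
}
```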

@@ -1880,3 +1880,11 @@ fn tensor_new() -> Result<()> {
    );
    Ok(())
}

#[test]
fn tensor_norm() -> Result<()> {
    let t = Tensor::new(&[[3., 4.], [0., 0.]], &Device::Cpu)?;
    let norm = t.norm()?;
    assert_eq!(norm.to_scalar::<f64>()?, 5.);
    Ok(())
}
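
One coverage gap: the integer-dtype rejection in `norm()` goes untested here. A possible follow-up test, hypothetical and not part of this commit, assuming the same imports as this test file:

```rust
#[test]
fn tensor_norm_int_dtype() -> Result<()> {
    // norm() bails on integer dtypes such as u32.
    let t = Tensor::new(&[1u32, 2, 3], &Device::Cpu)?;
    assert!(t.norm().is_err());
    Ok(())
}
```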