mirror of
https://github.com/huggingface/candle.git
synced 2025-06-14 01:48:08 +00:00
feat: enhance linear algebra operations (#2972)
- Add `dot()` for vector/matrix products - Implement the `Frobenius` norm - Add `mv()` for matrix-vector multiply
This commit is contained in:
@ -1235,6 +1235,83 @@ impl Tensor {
|
||||
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
|
||||
}
|
||||
|
||||
/// Computes the dot product of two 1D tensors.
|
||||
///
|
||||
/// - If inputs are 1D vectors (`[n]`), returns their scalar dot product.
|
||||
/// - Panics if shapes are not compatible
|
||||
/// - Not supported for integer dtypes
|
||||
///
|
||||
/// # Example (vectors)
|
||||
/// ```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let t1 = Tensor::new(&[1.0, 2.0, 3.0], &Device::Cpu)?;
|
||||
/// let t2 = Tensor::new(&[4.0, 5.0, 6.0], &Device::Cpu)?;
|
||||
/// let res = t1.dot(&t2)?;
|
||||
/// assert_eq!(res.to_scalar::<f64>()?, 32.);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn dot(&self, rhs: &Self) -> Result<Self> {
|
||||
if self.dims().len() != 1 || rhs.dims().len() != 1 {
|
||||
return Err(Error::ShapeMismatchBinaryOp {
|
||||
lhs: self.shape().clone(),
|
||||
rhs: rhs.shape().clone(),
|
||||
op: "dot",
|
||||
});
|
||||
}
|
||||
|
||||
(self * rhs).and_then(|ret| ret.sum_all())
|
||||
}
|
||||
|
||||
/// Computes the **Frobenius norm** (L2 norm of all elements) of the tensor.
|
||||
/// - Output is `sqrt(sum(x^2))`.
|
||||
/// - Always returns a scalar (`[]` shape).
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let t = Tensor::new(&[[3., 4.], [0., 0.]], &Device::Cpu)?;
|
||||
/// let norm = t.norm()?;
|
||||
/// assert_eq!(norm.to_scalar::<f64>()?, 5.);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn norm(&self) -> Result<Self> {
|
||||
if self.dtype().is_int() {
|
||||
bail!("norm not supported for integer dtypes");
|
||||
}
|
||||
|
||||
self.sqr().and_then(|x| x.sum_all()).and_then(|x| x.sqrt())
|
||||
}
|
||||
|
||||
/// Performs strict matrix-vector multiplication (`[m, n] * [n] = [m]`).
|
||||
///
|
||||
/// - If `self` is a matrix (`[m, n]`) and `rhs` is a vector (`[n]`), returns a vector (`[m]`).
|
||||
/// - **No broadcasting**: Panics if `self` is not 2D or if `rhs` is not 1D with matching size.
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let mat = Tensor::new(&[[1., 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
|
||||
/// let vec = Tensor::new(&[1., 1., 1.], &Device::Cpu)?;
|
||||
/// let res = mat.mv(&vec)?;
|
||||
/// assert_eq!(res.to_vec1::<f64>()?, [6., 15.]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn mv(&self, rhs: &Self) -> Result<Self> {
|
||||
// Strict shape checks
|
||||
let lhs_dims = self.dims();
|
||||
let rhs_dims = rhs.dims();
|
||||
if lhs_dims.len() != 2 || rhs_dims.len() != 1 || lhs_dims[1] != rhs_dims[0] {
|
||||
return Err(Error::ShapeMismatchBinaryOp {
|
||||
lhs: self.shape().clone(),
|
||||
rhs: rhs.shape().clone(),
|
||||
op: "mv",
|
||||
});
|
||||
}
|
||||
|
||||
// Direct matmul after ensuring rhs is column vector
|
||||
self.matmul(&rhs.unsqueeze(1)?)?.squeeze(1)
|
||||
}
|
||||
|
||||
/// Returns the matrix-multiplication of the input tensor with the other provided tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
|
@ -82,6 +82,26 @@ fn broadcast_matmul(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tensor_dot() -> Result<()> {
|
||||
let lhs = Tensor::new(&[1., 2., 3.], &Device::Cpu)?;
|
||||
let rhs = Tensor::new(&[4., 5., 6.], &Device::Cpu)?;
|
||||
let expected = Tensor::new(32., &Device::Cpu)?;
|
||||
let dot_ret = lhs.dot(&rhs)?;
|
||||
candle_core::test_utils::assert_tensor_eq(&dot_ret, &expected)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tensor_mv() -> Result<()> {
|
||||
let mat = Tensor::new(&[[1., 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
|
||||
let vec = Tensor::new(&[1., 1., 1.], &Device::Cpu)?;
|
||||
let expected = Tensor::new(&[6., 15.], &Device::Cpu)?;
|
||||
let mv_ret = mat.mv(&vec)?;
|
||||
candle_core::test_utils::assert_tensor_eq(&mv_ret, &expected)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/candle/issues/1948
|
||||
fn squeeze_mm(device: &Device) -> Result<()> {
|
||||
let seq_len = 8_usize;
|
||||
|
@ -1880,3 +1880,11 @@ fn tensor_new() -> Result<()> {
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tensor_norm() -> Result<()> {
|
||||
let t = Tensor::new(&[[3., 4.], [0., 0.]], &Device::Cpu)?;
|
||||
let norm = t.norm()?;
|
||||
assert_eq!(norm.to_scalar::<f64>()?, 5.);
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user