diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 5cebe498..952374c2 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1235,6 +1235,83 @@ impl Tensor {
         Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
     }
 
+    /// Computes the dot product of two 1D tensors.
+    ///
+    /// - If inputs are 1D vectors (`[n]`), returns their scalar dot product.
+    /// - Returns a shape-mismatch error if the inputs are not both 1D with equal length.
+    /// - Not supported for integer dtypes
+    ///
+    /// # Example (vectors)
+    /// ```rust
+    /// use candle_core::{Tensor, Device};
+    /// let t1 = Tensor::new(&[1.0, 2.0, 3.0], &Device::Cpu)?;
+    /// let t2 = Tensor::new(&[4.0, 5.0, 6.0], &Device::Cpu)?;
+    /// let res = t1.dot(&t2)?;
+    /// assert_eq!(res.to_scalar::<f64>()?, 32.);
+    /// # Ok::<(), candle_core::Error>(())
+    /// ```
+    pub fn dot(&self, rhs: &Self) -> Result<Self> {
+        if self.dims().len() != 1 || rhs.dims().len() != 1 {
+            return Err(Error::ShapeMismatchBinaryOp {
+                lhs: self.shape().clone(),
+                rhs: rhs.shape().clone(),
+                op: "dot",
+            });
+        }
+
+        (self * rhs).and_then(|ret| ret.sum_all())
+    }
+
+    /// Computes the **Frobenius norm** (L2 norm of all elements) of the tensor.
+    /// - Output is `sqrt(sum(x^2))`.
+    /// - Always returns a scalar (`[]` shape).
+    ///
+    /// # Example
+    /// ```rust
+    /// use candle_core::{Tensor, Device};
+    /// let t = Tensor::new(&[[3., 4.], [0., 0.]], &Device::Cpu)?;
+    /// let norm = t.norm()?;
+    /// assert_eq!(norm.to_scalar::<f64>()?, 5.);
+    /// # Ok::<(), candle_core::Error>(())
+    /// ```
+    pub fn norm(&self) -> Result<Self> {
+        if self.dtype().is_int() {
+            bail!("norm not supported for integer dtypes");
+        }
+
+        self.sqr().and_then(|x| x.sum_all()).and_then(|x| x.sqrt())
+    }
+
+    /// Performs strict matrix-vector multiplication (`[m, n] * [n] = [m]`).
+    ///
+    /// - If `self` is a matrix (`[m, n]`) and `rhs` is a vector (`[n]`), returns a vector (`[m]`).
+    /// - **No broadcasting**: returns a shape-mismatch error if `self` is not 2D or if `rhs` is not 1D with matching size.
+    ///
+    /// # Example
+    /// ```rust
+    /// use candle_core::{Tensor, Device};
+    /// let mat = Tensor::new(&[[1., 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
+    /// let vec = Tensor::new(&[1., 1., 1.], &Device::Cpu)?;
+    /// let res = mat.mv(&vec)?;
+    /// assert_eq!(res.to_vec1::<f64>()?, [6., 15.]);
+    /// # Ok::<(), candle_core::Error>(())
+    /// ```
+    pub fn mv(&self, rhs: &Self) -> Result<Self> {
+        // Strict shape checks
+        let lhs_dims = self.dims();
+        let rhs_dims = rhs.dims();
+        if lhs_dims.len() != 2 || rhs_dims.len() != 1 || lhs_dims[1] != rhs_dims[0] {
+            return Err(Error::ShapeMismatchBinaryOp {
+                lhs: self.shape().clone(),
+                rhs: rhs.shape().clone(),
+                op: "mv",
+            });
+        }
+
+        // Direct matmul after ensuring rhs is column vector
+        self.matmul(&rhs.unsqueeze(1)?)?.squeeze(1)
+    }
+
     /// Returns the matrix-multiplication of the input tensor with the other provided tensor.
     ///
     /// # Arguments
diff --git a/candle-core/tests/matmul_tests.rs b/candle-core/tests/matmul_tests.rs
index c1c16401..aa2189f7 100644
--- a/candle-core/tests/matmul_tests.rs
+++ b/candle-core/tests/matmul_tests.rs
@@ -82,6 +82,26 @@ fn broadcast_matmul(device: &Device) -> Result<()> {
     Ok(())
 }
 
+#[test]
+fn tensor_dot() -> Result<()> {
+    let lhs = Tensor::new(&[1., 2., 3.], &Device::Cpu)?;
+    let rhs = Tensor::new(&[4., 5., 6.], &Device::Cpu)?;
+    let expected = Tensor::new(32., &Device::Cpu)?;
+    let dot_ret = lhs.dot(&rhs)?;
+    candle_core::test_utils::assert_tensor_eq(&dot_ret, &expected)?;
+    Ok(())
+}
+
+#[test]
+fn tensor_mv() -> Result<()> {
+    let mat = Tensor::new(&[[1., 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
+    let vec = Tensor::new(&[1., 1., 1.], &Device::Cpu)?;
+    let expected = Tensor::new(&[6., 15.], &Device::Cpu)?;
+    let mv_ret = mat.mv(&vec)?;
+    candle_core::test_utils::assert_tensor_eq(&mv_ret, &expected)?;
+    Ok(())
+}
+
 // https://github.com/huggingface/candle/issues/1948
 fn squeeze_mm(device: &Device) -> Result<()> {
     let seq_len = 8_usize;
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index c443ad2a..85c524f0 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1880,3 +1880,11 @@ fn tensor_new() -> Result<()> {
     );
     Ok(())
 }
+
+#[test]
+fn tensor_norm() -> Result<()> {
+    let t = Tensor::new(&[[3., 4.], [0., 0.]], &Device::Cpu)?;
+    let norm = t.norm()?;
+    assert_eq!(norm.to_scalar::<f64>()?, 5.);
+    Ok(())
+}