From 347e31c9ff08f52574c0158e2a48cab52e224b4d Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Wed, 15 Nov 2023 20:34:37 +0000 Subject: [PATCH] Add the tril/triu/eye ops. (#1333) * Add tril/triu/eye. * Revert the metal crate tweak. --- candle-core/src/tensor.rs | 24 +++++++++++++++++++++ candle-core/tests/tensor_tests.rs | 35 +++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index d51a3db7..f6b1698c 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2450,6 +2450,30 @@ impl Tensor { Ok(naxis as usize) } } + + /// Returns a lower triangular matrix of ones of size n by n. + pub fn tril2(n: usize, dtype: DType, device: &Device) -> Result { + let t = Tensor::arange(0u32, n as u32, device)?; + let t1 = t.reshape((1, n))?.broadcast_as((n, n))?; + let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?; + t1.le(&t2)?.to_dtype(dtype) + } + + /// Returns an upper triangular matrix of ones of size n by n. + pub fn triu2(n: usize, dtype: DType, device: &Device) -> Result { + let t = Tensor::arange(0u32, n as u32, device)?; + let t1 = t.reshape((1, n))?.broadcast_as((n, n))?; + let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?; + t1.ge(&t2)?.to_dtype(dtype) + } + + /// Returns a matrix with a diagonal of ones of size n by n. + pub fn eye(n: usize, dtype: DType, device: &Device) -> Result { + let t = Tensor::arange(0u32, n as u32, device)?; + let t1 = t.reshape((1, n))?.broadcast_as((n, n))?; + let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?; + t1.eq(&t2)?.to_dtype(dtype) + } } macro_rules! bin_trait { diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index cc44ce94..c8b255dd 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -1134,3 +1134,38 @@ fn i64_abs() -> Result<()> { assert_eq!(t.to_vec1::()?, [42, 1337]); Ok(()) } + +#[test] +fn tril_triu_eye() -> Result<()> { + let t = Tensor::tril2(4, DType::F32, &Device::Cpu)?; + assert_eq!( + t.to_vec2::()?, + [ + [1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 1.0] + ], + ); + let t = Tensor::triu2(4, DType::F32, &Device::Cpu)?; + assert_eq!( + t.to_vec2::()?, + [ + [1.0, 1.0, 1.0, 1.0], + [0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 1.0] + ] + ); + let t = Tensor::eye(4, DType::F32, &Device::Cpu)?; + assert_eq!( + t.to_vec2::()?, + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0] + ] + ); + Ok(()) +}