mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Add the tril/triu/eye ops. (#1333)
* Add tril/triu/eye. * Revert the metal crate tweak.
This commit is contained in:
@ -2450,6 +2450,30 @@ impl Tensor {
|
||||
Ok(naxis as usize)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a lower triangular matrix of ones of size n by n.
|
||||
pub fn tril2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
|
||||
let t = Tensor::arange(0u32, n as u32, device)?;
|
||||
let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
|
||||
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
|
||||
t1.le(&t2)?.to_dtype(dtype)
|
||||
}
|
||||
|
||||
/// Returns an upper triangular matrix of ones of size n by n.
|
||||
pub fn triu2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
|
||||
let t = Tensor::arange(0u32, n as u32, device)?;
|
||||
let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
|
||||
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
|
||||
t1.ge(&t2)?.to_dtype(dtype)
|
||||
}
|
||||
|
||||
/// Returns a matrix with a diagonal of ones of size n by n.
|
||||
pub fn eye(n: usize, dtype: DType, device: &Device) -> Result<Self> {
|
||||
let t = Tensor::arange(0u32, n as u32, device)?;
|
||||
let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
|
||||
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
|
||||
t1.eq(&t2)?.to_dtype(dtype)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! bin_trait {
|
||||
|
@ -1134,3 +1134,38 @@ fn i64_abs() -> Result<()> {
|
||||
assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tril_triu_eye() -> Result<()> {
|
||||
let t = Tensor::tril2(4, DType::F32, &Device::Cpu)?;
|
||||
assert_eq!(
|
||||
t.to_vec2::<f32>()?,
|
||||
[
|
||||
[1.0, 0.0, 0.0, 0.0],
|
||||
[1.0, 1.0, 0.0, 0.0],
|
||||
[1.0, 1.0, 1.0, 0.0],
|
||||
[1.0, 1.0, 1.0, 1.0]
|
||||
],
|
||||
);
|
||||
let t = Tensor::triu2(4, DType::F32, &Device::Cpu)?;
|
||||
assert_eq!(
|
||||
t.to_vec2::<f32>()?,
|
||||
[
|
||||
[1.0, 1.0, 1.0, 1.0],
|
||||
[0.0, 1.0, 1.0, 1.0],
|
||||
[0.0, 0.0, 1.0, 1.0],
|
||||
[0.0, 0.0, 0.0, 1.0]
|
||||
]
|
||||
);
|
||||
let t = Tensor::eye(4, DType::F32, &Device::Cpu)?;
|
||||
assert_eq!(
|
||||
t.to_vec2::<f32>()?,
|
||||
[
|
||||
[1.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 1.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 1.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 1.0]
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user