mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add the tril/triu/eye ops. (#1333)
* Add tril/triu/eye. * Revert the metal crate tweak.
This commit is contained in:
@ -2450,6 +2450,30 @@ impl Tensor {
|
|||||||
Ok(naxis as usize)
|
Ok(naxis as usize)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns a lower triangular matrix of ones of size n by n.
|
||||||
|
pub fn tril2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
|
||||||
|
let t = Tensor::arange(0u32, n as u32, device)?;
|
||||||
|
let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
|
||||||
|
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
|
||||||
|
t1.le(&t2)?.to_dtype(dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns an upper triangular matrix of ones of size n by n.
|
||||||
|
pub fn triu2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
|
||||||
|
let t = Tensor::arange(0u32, n as u32, device)?;
|
||||||
|
let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
|
||||||
|
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
|
||||||
|
t1.ge(&t2)?.to_dtype(dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a matrix with a diagonal of ones of size n by n.
|
||||||
|
pub fn eye(n: usize, dtype: DType, device: &Device) -> Result<Self> {
|
||||||
|
let t = Tensor::arange(0u32, n as u32, device)?;
|
||||||
|
let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
|
||||||
|
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
|
||||||
|
t1.eq(&t2)?.to_dtype(dtype)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! bin_trait {
|
macro_rules! bin_trait {
|
||||||
|
@ -1134,3 +1134,38 @@ fn i64_abs() -> Result<()> {
|
|||||||
assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
|
assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn tril_triu_eye() -> Result<()> {
|
||||||
|
let t = Tensor::tril2(4, DType::F32, &Device::Cpu)?;
|
||||||
|
assert_eq!(
|
||||||
|
t.to_vec2::<f32>()?,
|
||||||
|
[
|
||||||
|
[1.0, 0.0, 0.0, 0.0],
|
||||||
|
[1.0, 1.0, 0.0, 0.0],
|
||||||
|
[1.0, 1.0, 1.0, 0.0],
|
||||||
|
[1.0, 1.0, 1.0, 1.0]
|
||||||
|
],
|
||||||
|
);
|
||||||
|
let t = Tensor::triu2(4, DType::F32, &Device::Cpu)?;
|
||||||
|
assert_eq!(
|
||||||
|
t.to_vec2::<f32>()?,
|
||||||
|
[
|
||||||
|
[1.0, 1.0, 1.0, 1.0],
|
||||||
|
[0.0, 1.0, 1.0, 1.0],
|
||||||
|
[0.0, 0.0, 1.0, 1.0],
|
||||||
|
[0.0, 0.0, 0.0, 1.0]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
let t = Tensor::eye(4, DType::F32, &Device::Cpu)?;
|
||||||
|
assert_eq!(
|
||||||
|
t.to_vec2::<f32>()?,
|
||||||
|
[
|
||||||
|
[1.0, 0.0, 0.0, 0.0],
|
||||||
|
[0.0, 1.0, 0.0, 0.0],
|
||||||
|
[0.0, 0.0, 1.0, 0.0],
|
||||||
|
[0.0, 0.0, 0.0, 1.0]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user