mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Add the tril/triu/eye ops. (#1333)
* Add tril/triu/eye. * Revert the metal crate tweak.
This commit is contained in:
@ -1134,3 +1134,38 @@ fn i64_abs() -> Result<()> {
|
||||
assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tril_triu_eye() -> Result<()> {
|
||||
let t = Tensor::tril2(4, DType::F32, &Device::Cpu)?;
|
||||
assert_eq!(
|
||||
t.to_vec2::<f32>()?,
|
||||
[
|
||||
[1.0, 0.0, 0.0, 0.0],
|
||||
[1.0, 1.0, 0.0, 0.0],
|
||||
[1.0, 1.0, 1.0, 0.0],
|
||||
[1.0, 1.0, 1.0, 1.0]
|
||||
],
|
||||
);
|
||||
let t = Tensor::triu2(4, DType::F32, &Device::Cpu)?;
|
||||
assert_eq!(
|
||||
t.to_vec2::<f32>()?,
|
||||
[
|
||||
[1.0, 1.0, 1.0, 1.0],
|
||||
[0.0, 1.0, 1.0, 1.0],
|
||||
[0.0, 0.0, 1.0, 1.0],
|
||||
[0.0, 0.0, 0.0, 1.0]
|
||||
]
|
||||
);
|
||||
let t = Tensor::eye(4, DType::F32, &Device::Cpu)?;
|
||||
assert_eq!(
|
||||
t.to_vec2::<f32>()?,
|
||||
[
|
||||
[1.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 1.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 1.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 1.0]
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user