mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Efficient implementation of Tensor::ones()
for metal
(#2512)
* WIP: hopefully better const impl * with GPU * More tests on * Reverting primitive for * Incorporating review changes - added elem count check in kernel, using for call strategy * rustfmt ran
This commit is contained in:

committed by
GitHub

parent
def4c6cdee
commit
a2bcc227df
@@ -29,6 +29,36 @@ fn ones(device: &Device) -> Result<()> {
|
||||
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
|
||||
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
|
||||
);
|
||||
assert_eq!(
|
||||
Tensor::ones((2, 3), DType::F16, device)?.to_vec2::<half::f16>()?,
|
||||
[
|
||||
[
|
||||
half::f16::from_f32(1.0),
|
||||
half::f16::from_f32(1.0),
|
||||
half::f16::from_f32(1.0)
|
||||
],
|
||||
[
|
||||
half::f16::from_f32(1.0),
|
||||
half::f16::from_f32(1.0),
|
||||
half::f16::from_f32(1.0)
|
||||
]
|
||||
],
|
||||
);
|
||||
assert_eq!(
|
||||
Tensor::ones((2, 3), DType::BF16, device)?.to_vec2::<half::bf16>()?,
|
||||
[
|
||||
[
|
||||
half::bf16::from_f32(1.0),
|
||||
half::bf16::from_f32(1.0),
|
||||
half::bf16::from_f32(1.0)
|
||||
],
|
||||
[
|
||||
half::bf16::from_f32(1.0),
|
||||
half::bf16::from_f32(1.0),
|
||||
half::bf16::from_f32(1.0)
|
||||
]
|
||||
],
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user