mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
add where_cond f32 for metal (#2236)
This commit is contained in:
@ -718,6 +718,7 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
let name = match (self.dtype, t.dtype()) {
|
||||
(DType::U8, DType::F32) => "where_u8_f32",
|
||||
(DType::U32, DType::F32) => "where_u32_f32",
|
||||
(DType::U8, DType::BF16) => "where_u8_bf16",
|
||||
(DType::U8, DType::F16) => "where_u8_f16",
|
||||
(DType::U8, DType::I64) => "where_u8_i64",
|
||||
|
@ -1023,6 +1023,27 @@ fn where_cond() {
|
||||
);
|
||||
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
|
||||
}
|
||||
#[test]
|
||||
fn where_cond_u32_f32() {
|
||||
let shape = vec![6];
|
||||
let cond = vec![0u32, 1, 0, 0, 1, 1];
|
||||
let cond_l = (vec![1], 0);
|
||||
let left_true = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||
let left_l = (vec![1], 0);
|
||||
let right_false = vec![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0];
|
||||
let right_l = (vec![1], 0);
|
||||
let results = run_where_cond(
|
||||
&shape,
|
||||
&cond,
|
||||
cond_l,
|
||||
&left_true,
|
||||
left_l,
|
||||
&right_false,
|
||||
right_l,
|
||||
"where_u32_f32",
|
||||
);
|
||||
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
|
||||
}
|
||||
|
||||
fn run_gemm<T: Clone>(
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
|
@ -5,7 +5,7 @@ use criterion::{black_box, criterion_group, Criterion};
|
||||
use std::time::Instant;
|
||||
|
||||
fn run(input: &Tensor, weight: &Tensor, bias: &Tensor) {
|
||||
let _ = LayerNorm::new(weight.clone(), bias.clone(), 1e-5).forward(&input);
|
||||
let _ = LayerNorm::new(weight.clone(), bias.clone(), 1e-5).forward(input);
|
||||
}
|
||||
|
||||
const B: usize = 1;
|
||||
|
Reference in New Issue
Block a user