mirror of
https://github.com/huggingface/candle.git
synced 2025-06-22 04:22:50 +00:00
Fixed all bugs. Improved code quality. Added tests.
This commit is contained in:
@ -61,13 +61,21 @@ fn criterion_benchmark(c: &mut Criterion) {
|
||||
run_softmax(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
|
||||
run_softmax(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
|
||||
|
||||
run_reduce(c, &device, (lo, up));
|
||||
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
|
||||
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
|
||||
run_reduce(c, &device, (lo, up), false);
|
||||
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
|
||||
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
|
||||
|
||||
run_arg_reduce(c, &device, (lo, up));
|
||||
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
|
||||
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
|
||||
run_arg_reduce(c, &device, (lo, up), false);
|
||||
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
|
||||
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
|
||||
|
||||
run_reduce(c, &device, (lo, up), true);
|
||||
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
|
||||
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
|
||||
|
||||
run_arg_reduce(c, &device, (lo, up), true);
|
||||
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
|
||||
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
|
||||
}
|
||||
}
|
||||
|
||||
@ -89,6 +97,7 @@ fn run_softmax<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (
|
||||
DType::BF16 => "softmax_bf16",
|
||||
_ => "softmax",
|
||||
};
|
||||
softmax(&a).unwrap();
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
@ -105,19 +114,49 @@ fn run_softmax<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn run_reduce<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (lo, up): (T, T)) {
|
||||
fn run_reduce<T: candle_core::FloatDType>(
|
||||
c: &mut Criterion,
|
||||
device: &Device,
|
||||
(lo, up): (T, T),
|
||||
strided: bool,
|
||||
) {
|
||||
let b = 1;
|
||||
let m = 1024;
|
||||
let k = 1024;
|
||||
|
||||
let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();
|
||||
let a = if strided {
|
||||
Tensor::rand(lo, up, (b, m, k), &device)
|
||||
.unwrap()
|
||||
.transpose(0, 2)
|
||||
.unwrap()
|
||||
} else {
|
||||
Tensor::rand(lo, up, (b, m, k), &device).unwrap()
|
||||
};
|
||||
|
||||
let flops = b * m * k * T::DTYPE.size_in_bytes();
|
||||
|
||||
let name = match T::DTYPE {
|
||||
DType::F32 => "reduce_f32",
|
||||
DType::F16 => "reduce_f16",
|
||||
DType::BF16 => "reduce_bf16",
|
||||
DType::F32 => {
|
||||
if strided {
|
||||
"reduce_f32_strided"
|
||||
} else {
|
||||
"reduce_f32"
|
||||
}
|
||||
}
|
||||
DType::F16 => {
|
||||
if strided {
|
||||
"reduce_f16_strided"
|
||||
} else {
|
||||
"reduce_f16"
|
||||
}
|
||||
}
|
||||
DType::BF16 => {
|
||||
if strided {
|
||||
"reduce_bf16_strided"
|
||||
} else {
|
||||
"reduce_bf16"
|
||||
}
|
||||
}
|
||||
_ => "reduce",
|
||||
};
|
||||
|
||||
@ -140,20 +179,46 @@ fn run_arg_reduce<T: candle_core::FloatDType>(
|
||||
c: &mut Criterion,
|
||||
device: &Device,
|
||||
(lo, up): (T, T),
|
||||
strided: bool,
|
||||
) {
|
||||
let b = 1;
|
||||
let m = 1024;
|
||||
let k = 1024;
|
||||
|
||||
let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();
|
||||
let a = if strided {
|
||||
Tensor::rand(lo, up, (b, m, k), &device)
|
||||
.unwrap()
|
||||
.transpose(0, 2)
|
||||
.unwrap()
|
||||
} else {
|
||||
Tensor::rand(lo, up, (b, m, k), &device).unwrap()
|
||||
};
|
||||
|
||||
let flops = b * m * k * T::DTYPE.size_in_bytes();
|
||||
let flops = b * m * k * (DType::U32.size_in_bytes() + T::DTYPE.size_in_bytes());
|
||||
|
||||
let name = match T::DTYPE {
|
||||
DType::F32 => "arg_reduce_f32",
|
||||
DType::F16 => "arg_reduce_f16",
|
||||
DType::BF16 => "arg_reduce_bf16",
|
||||
_ => "reduce",
|
||||
DType::F32 => {
|
||||
if strided {
|
||||
"arg_reduce_f32_strided"
|
||||
} else {
|
||||
"arg_reduce_f32"
|
||||
}
|
||||
}
|
||||
DType::F16 => {
|
||||
if strided {
|
||||
"arg_reduce_f16_strided"
|
||||
} else {
|
||||
"arg_reduce_f16"
|
||||
}
|
||||
}
|
||||
DType::BF16 => {
|
||||
if strided {
|
||||
"arg_reduce_bf16_strided"
|
||||
} else {
|
||||
"arg_reduce_bf16"
|
||||
}
|
||||
}
|
||||
_ => "unknown",
|
||||
};
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
|
Reference in New Issue
Block a user