mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Add dtype size to benchmark throughput calculation
This commit is contained in:
@ -17,16 +17,17 @@ fn run_bench(c: &mut Criterion, device: &Device, dtype: DType) {
|
|||||||
let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap();
|
let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap();
|
||||||
|
|
||||||
let flops = b * m * n * k;
|
let flops = b * m * n * k;
|
||||||
|
let bytes = flops * dtype.size_in_bytes();
|
||||||
|
|
||||||
let name = match dtype {
|
let name = match dtype {
|
||||||
DType::F32 => "matmul_f32",
|
DType::F32 => "matmul_f32",
|
||||||
|
DType::U8 => "matmul_fp8",
|
||||||
DType::F16 => "matmul_f16",
|
DType::F16 => "matmul_f16",
|
||||||
DType::BF16 => "matmul_bf16",
|
DType::BF16 => "matmul_bf16",
|
||||||
DType::U8 => "matmul_fp8",
|
|
||||||
_ => unimplemented!("{dtype:?} matmul bench not implemented"),
|
_ => unimplemented!("{dtype:?} matmul bench not implemented"),
|
||||||
};
|
};
|
||||||
let mut group = c.benchmark_group(device.bench_name(name));
|
let mut group = c.benchmark_group(device.bench_name(name));
|
||||||
group.throughput(Throughput::Bytes(flops as u64));
|
group.throughput(Throughput::Bytes(bytes as u64));
|
||||||
group.bench_function("iter", move |b| {
|
group.bench_function("iter", move |b| {
|
||||||
b.iter_custom(|iters| {
|
b.iter_custom(|iters| {
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
@ -42,7 +43,7 @@ fn run_bench(c: &mut Criterion, device: &Device, dtype: DType) {
|
|||||||
|
|
||||||
fn criterion_benchmark(c: &mut Criterion) {
|
fn criterion_benchmark(c: &mut Criterion) {
|
||||||
let handler = BenchDeviceHandler::new().unwrap();
|
let handler = BenchDeviceHandler::new().unwrap();
|
||||||
let dtypes = vec![DType::F32, DType::F16, DType::BF16, DType::U8];
|
let dtypes = vec![DType::F32, DType::U8, DType::F16, DType::BF16];
|
||||||
for device in handler.devices {
|
for device in handler.devices {
|
||||||
for dtype in dtypes.clone() {
|
for dtype in dtypes.clone() {
|
||||||
run_bench(c, &device, dtype);
|
run_bench(c, &device, dtype);
|
||||||
|
Reference in New Issue
Block a user