diff --git a/candle-core/benches/benchmarks/matmul.rs b/candle-core/benches/benchmarks/matmul.rs
index b8073485..e967a8bd 100644
--- a/candle-core/benches/benchmarks/matmul.rs
+++ b/candle-core/benches/benchmarks/matmul.rs
@@ -17,16 +17,17 @@ fn run_bench(c: &mut Criterion, device: &Device, dtype: DType) {
     let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap();
 
     let flops = b * m * n * k;
+    let bytes = flops * dtype.size_in_bytes();
 
     let name = match dtype {
         DType::F32 => "matmul_f32",
+        DType::U8 => "matmul_fp8",
         DType::F16 => "matmul_f16",
         DType::BF16 => "matmul_bf16",
-        DType::U8 => "matmul_fp8",
         _ => unimplemented!("{dtype:?} matmul bench not implemented"),
     };
     let mut group = c.benchmark_group(device.bench_name(name));
-    group.throughput(Throughput::Bytes(flops as u64));
+    group.throughput(Throughput::Bytes(bytes as u64));
     group.bench_function("iter", move |b| {
         b.iter_custom(|iters| {
             let start = Instant::now();
@@ -42,7 +43,7 @@ fn run_bench(c: &mut Criterion, device: &Device, dtype: DType) {
 
 fn criterion_benchmark(c: &mut Criterion) {
     let handler = BenchDeviceHandler::new().unwrap();
-    let dtypes = vec![DType::F32, DType::F16, DType::BF16, DType::U8];
+    let dtypes = vec![DType::F32, DType::U8, DType::F16, DType::BF16];
     for device in handler.devices {
         for dtype in dtypes.clone() {
             run_bench(c, &device, dtype);