mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Add a matmul benchmark. (#429)
This commit is contained in:
@ -55,6 +55,23 @@ impl Benchmark for Conv2d {
|
||||
const ITERS: usize = 1;
|
||||
}
|
||||
|
||||
struct Matmul;
|
||||
impl Benchmark for Matmul {
|
||||
type PreProcessData = (Tensor, Tensor);
|
||||
type RunResult = Tensor;
|
||||
fn preprocess() -> Result<Self::PreProcessData> {
|
||||
let lhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
|
||||
let rhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
|
||||
Ok((lhs, rhs))
|
||||
}
|
||||
|
||||
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
|
||||
d.0.matmul(&d.1)
|
||||
}
|
||||
|
||||
const ITERS: usize = 100;
|
||||
}
|
||||
|
||||
fn run<B: Benchmark>(iters: Option<usize>) -> Result<()> {
|
||||
use std::hint::black_box;
|
||||
|
||||
@ -72,6 +89,7 @@ fn run<B: Benchmark>(iters: Option<usize>) -> Result<()> {
|
||||
enum Task {
|
||||
Conv1d,
|
||||
Conv2d,
|
||||
Matmul,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
@ -90,6 +108,7 @@ fn main() -> Result<()> {
|
||||
match args.task {
|
||||
Task::Conv1d => run::<Conv1d>(args.iters)?,
|
||||
Task::Conv2d => run::<Conv2d>(args.iters)?,
|
||||
Task::Matmul => run::<Matmul>(args.iters)?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user