mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Improve benchmarks layout
This commit is contained in:
55
candle-core/benches/benchmarks/mod.rs
Normal file
55
candle-core/benches/benchmarks/mod.rs
Normal file
@ -0,0 +1,55 @@
|
||||
pub(crate) mod matmul;
|
||||
|
||||
use candle_core::{Device, Result};
|
||||
|
||||
pub(crate) trait BenchDevice {
|
||||
fn sync(&self) -> Result<()>;
|
||||
}
|
||||
|
||||
impl BenchDevice for Device {
|
||||
fn sync(&self) -> Result<()> {
|
||||
match self {
|
||||
Device::Cpu => Ok(()),
|
||||
Device::Cuda(device) => {
|
||||
#[cfg(feature = "cuda")]
|
||||
return Ok(device.synchronize()?);
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
panic!("Cuda device without cuda feature enabled: {:?}", device)
|
||||
}
|
||||
Device::Metal(device) => {
|
||||
#[cfg(feature = "metal")]
|
||||
return Ok(device.wait_until_completed()?);
|
||||
#[cfg(not(feature = "metal"))]
|
||||
panic!("Metal device without metal feature enabled: {:?}", device)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn device() -> Result<Device> {
|
||||
if cfg!(feature = "metal") {
|
||||
Device::new_metal(0)
|
||||
} else if cfg!(feature = "cuda") {
|
||||
Device::new_cuda(0)
|
||||
} else {
|
||||
Ok(Device::Cpu)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn bench_name<S: Into<String>>(name: S) -> String {
|
||||
format!("{}_{}", device_variant(), name.into())
|
||||
}
|
||||
|
||||
const fn device_variant() -> &'static str {
|
||||
if cfg!(feature = "metal") {
|
||||
"metal"
|
||||
} else if cfg!(feature = "cuda") {
|
||||
"cuda"
|
||||
} else if cfg!(feature = "accelerate") {
|
||||
"accelerate"
|
||||
} else if cfg!(feature = "mkl") {
|
||||
"mkl"
|
||||
} else {
|
||||
"cpu"
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user