Merge pull request #15 from LaurentMazare/num-cpus

Use num-cpus to enable parallelism in matmul's cpu version.
This commit is contained in:
Laurent Mazare
2023-06-27 14:45:08 +01:00
committed by GitHub
4 changed files with 17 additions and 4 deletions

View File

@ -20,6 +20,7 @@ zip = { version = "0.6.6", default-features=false }
byteorder = "1.4.3"
half = { version = "2.3.1", features = ["num-traits"] }
num-traits = "0.2.15"
num_cpus = "1.15.0"
[dev-dependencies]
anyhow = "1"

View File

@ -669,7 +669,7 @@ impl CpuStorage {
// conj_rhs: bool,
true,
// parallelism: Parallelism
Parallelism::None,
Parallelism::Rayon(crate::utils::get_num_threads()),
)
}
}
@ -721,11 +721,10 @@ impl CpuStorage {
// conj_rhs: bool,
true,
// parallelism: Parallelism
Parallelism::None,
Parallelism::Rayon(crate::utils::get_num_threads()),
)
}
}
Ok(Self::F32(dst))
}
(CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
@ -773,7 +772,7 @@ impl CpuStorage {
// conj_rhs: bool,
true,
// parallelism: Parallelism
Parallelism::None,
Parallelism::Rayon(crate::utils::get_num_threads()),
)
}
}

View File

@ -12,6 +12,7 @@ mod shape;
mod storage;
mod strided_index;
mod tensor;
mod utils;
pub use cpu_backend::CpuStorage;
pub use device::{Device, DeviceLocation};

12
candle-core/src/utils.rs Normal file
View File

@ -0,0 +1,12 @@
use std::str::FromStr;
/// Returns the number of threads to hand to the parallel matmul kernels.
///
/// The `RAYON_NUM_THREADS` environment variable takes precedence when it is
/// set to a strictly positive integer. If it is unset, not valid unicode,
/// not parseable as a `usize`, or equal to zero, the value reported by
/// `num_cpus::get()` is used instead.
pub(crate) fn get_num_threads() -> usize {
    // Respond to the same environment variable as rayon.
    let requested = std::env::var("RAYON_NUM_THREADS")
        .ok()
        .and_then(|raw| usize::from_str(&raw).ok())
        // A value of zero is treated the same as "not set".
        .filter(|&count| count > 0);

    requested.unwrap_or_else(num_cpus::get)
}