From ca6aa8ff12d4f263ea7869f8d1da6e6b32fa4e5a Mon Sep 17 00:00:00 2001 From: laurent Date: Tue, 27 Jun 2023 14:42:26 +0100 Subject: [PATCH] Use num-cpus to enable parallelism. --- candle-core/Cargo.toml | 1 + candle-core/src/cpu_backend.rs | 7 +++---- candle-core/src/lib.rs | 1 + candle-core/src/utils.rs | 12 ++++++++++++ 4 files changed, 17 insertions(+), 4 deletions(-) create mode 100644 candle-core/src/utils.rs diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 97215953..75b48df8 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -20,6 +20,7 @@ zip = { version = "0.6.6", default-features=false } byteorder = "1.4.3" half = { version = "2.3.1", features = ["num-traits"] } num-traits = "0.2.15" +num_cpus = "1.15.0" [dev-dependencies] anyhow = "1" diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 72599afc..56ff08ae 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -669,7 +669,7 @@ impl CpuStorage { // conj_rhs: bool, true, // parallelism: Parallelism - Parallelism::None, + Parallelism::Rayon(crate::utils::get_num_threads()), ) } } @@ -721,11 +721,10 @@ impl CpuStorage { // conj_rhs: bool, true, // parallelism: Parallelism - Parallelism::None, + Parallelism::Rayon(crate::utils::get_num_threads()), ) } } - Ok(Self::F32(dst)) } (CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => { @@ -773,7 +772,7 @@ impl CpuStorage { // conj_rhs: bool, true, // parallelism: Parallelism - Parallelism::None, + Parallelism::Rayon(crate::utils::get_num_threads()), ) } } diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index d34c5983..b220dfb9 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -12,6 +12,7 @@ mod shape; mod storage; mod strided_index; mod tensor; +mod utils; pub use cpu_backend::CpuStorage; pub use device::{Device, DeviceLocation}; diff --git a/candle-core/src/utils.rs b/candle-core/src/utils.rs new file mode 100644 index 00000000..0be63c66 --- /dev/null +++ b/candle-core/src/utils.rs @@ -0,0 +1,12 @@ +use std::str::FromStr; + +pub(crate) fn get_num_threads() -> usize { + // Respond to the same environment variable as rayon. + match std::env::var("RAYON_NUM_THREADS") + .ok() + .and_then(|s| usize::from_str(&s).ok()) + { + Some(x) if x > 0 => x, + Some(_) | None => num_cpus::get(), + } +}