Merge pull request #15 from LaurentMazare/num-cpus

Use num-cpus to enable parallelism in matmul's cpu version.
This commit is contained in:
Laurent Mazare
2023-06-27 14:45:08 +01:00
committed by GitHub
4 changed files with 17 additions and 4 deletions

View File

@ -20,6 +20,7 @@ zip = { version = "0.6.6", default-features=false }
byteorder = "1.4.3" byteorder = "1.4.3"
half = { version = "2.3.1", features = ["num-traits"] } half = { version = "2.3.1", features = ["num-traits"] }
num-traits = "0.2.15" num-traits = "0.2.15"
num_cpus = "1.15.0"
[dev-dependencies] [dev-dependencies]
anyhow = "1" anyhow = "1"

View File

@ -669,7 +669,7 @@ impl CpuStorage {
// conj_rhs: bool, // conj_rhs: bool,
true, true,
// parallelism: Parallelism // parallelism: Parallelism
Parallelism::None, Parallelism::Rayon(crate::utils::get_num_threads()),
) )
} }
} }
@ -721,11 +721,10 @@ impl CpuStorage {
// conj_rhs: bool, // conj_rhs: bool,
true, true,
// parallelism: Parallelism // parallelism: Parallelism
Parallelism::None, Parallelism::Rayon(crate::utils::get_num_threads()),
) )
} }
} }
Ok(Self::F32(dst)) Ok(Self::F32(dst))
} }
(CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => { (CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
@ -773,7 +772,7 @@ impl CpuStorage {
// conj_rhs: bool, // conj_rhs: bool,
true, true,
// parallelism: Parallelism // parallelism: Parallelism
Parallelism::None, Parallelism::Rayon(crate::utils::get_num_threads()),
) )
} }
} }

View File

@ -12,6 +12,7 @@ mod shape;
mod storage; mod storage;
mod strided_index; mod strided_index;
mod tensor; mod tensor;
mod utils;
pub use cpu_backend::CpuStorage; pub use cpu_backend::CpuStorage;
pub use device::{Device, DeviceLocation}; pub use device::{Device, DeviceLocation};

12
candle-core/src/utils.rs Normal file
View File

@ -0,0 +1,12 @@
use std::str::FromStr;
pub(crate) fn get_num_threads() -> usize {
// Respond to the same environment variable as rayon.
match std::env::var("RAYON_NUM_THREADS")
.ok()
.and_then(|s| usize::from_str(&s).ok())
{
Some(x) if x > 0 => x,
Some(_) | None => num_cpus::get(),
}
}