mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Merge pull request #15 from LaurentMazare/num-cpus
Use num-cpus to enable parallelism in matmul's cpu version.
This commit is contained in:
@ -20,6 +20,7 @@ zip = { version = "0.6.6", default-features=false }
|
|||||||
byteorder = "1.4.3"
|
byteorder = "1.4.3"
|
||||||
half = { version = "2.3.1", features = ["num-traits"] }
|
half = { version = "2.3.1", features = ["num-traits"] }
|
||||||
num-traits = "0.2.15"
|
num-traits = "0.2.15"
|
||||||
|
num_cpus = "1.15.0"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
anyhow = "1"
|
anyhow = "1"
|
||||||
|
@ -669,7 +669,7 @@ impl CpuStorage {
|
|||||||
// conj_rhs: bool,
|
// conj_rhs: bool,
|
||||||
true,
|
true,
|
||||||
// parallelism: Parallelism
|
// parallelism: Parallelism
|
||||||
Parallelism::None,
|
Parallelism::Rayon(crate::utils::get_num_threads()),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -721,11 +721,10 @@ impl CpuStorage {
|
|||||||
// conj_rhs: bool,
|
// conj_rhs: bool,
|
||||||
true,
|
true,
|
||||||
// parallelism: Parallelism
|
// parallelism: Parallelism
|
||||||
Parallelism::None,
|
Parallelism::Rayon(crate::utils::get_num_threads()),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Self::F32(dst))
|
Ok(Self::F32(dst))
|
||||||
}
|
}
|
||||||
(CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
|
(CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
|
||||||
@ -773,7 +772,7 @@ impl CpuStorage {
|
|||||||
// conj_rhs: bool,
|
// conj_rhs: bool,
|
||||||
true,
|
true,
|
||||||
// parallelism: Parallelism
|
// parallelism: Parallelism
|
||||||
Parallelism::None,
|
Parallelism::Rayon(crate::utils::get_num_threads()),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -12,6 +12,7 @@ mod shape;
|
|||||||
mod storage;
|
mod storage;
|
||||||
mod strided_index;
|
mod strided_index;
|
||||||
mod tensor;
|
mod tensor;
|
||||||
|
mod utils;
|
||||||
|
|
||||||
pub use cpu_backend::CpuStorage;
|
pub use cpu_backend::CpuStorage;
|
||||||
pub use device::{Device, DeviceLocation};
|
pub use device::{Device, DeviceLocation};
|
||||||
|
12
candle-core/src/utils.rs
Normal file
12
candle-core/src/utils.rs
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
use std::str::FromStr;
|
||||||
|
|
||||||
|
pub(crate) fn get_num_threads() -> usize {
|
||||||
|
// Respond to the same environment variable as rayon.
|
||||||
|
match std::env::var("RAYON_NUM_THREADS")
|
||||||
|
.ok()
|
||||||
|
.and_then(|s| usize::from_str(&s).ok())
|
||||||
|
{
|
||||||
|
Some(x) if x > 0 => x,
|
||||||
|
Some(_) | None => num_cpus::get(),
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user