mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
feat: parse Cuda compute cap from env (#1066)
* feat: add support for multiple compute caps * Revert to one compute cap * fmt * fix
This commit is contained in:
@ -84,12 +84,19 @@ fn main() -> Result<()> {
|
||||
(kernel_dir.join(f), obj_file)
|
||||
})
|
||||
.collect();
|
||||
let out_modified: Result<_, _> = out_file.metadata().and_then(|m| m.modified());
|
||||
let should_compile = if out_file.exists() {
|
||||
cu_files.iter().any(|(cu_file, _)| {
|
||||
let out_modified = out_file.metadata().unwrap().modified().unwrap();
|
||||
let in_modified = cu_file.metadata().unwrap().modified().unwrap();
|
||||
in_modified.duration_since(out_modified).is_ok()
|
||||
})
|
||||
kernel_dir
|
||||
.read_dir()
|
||||
.expect("kernels folder should exist")
|
||||
.any(|entry| {
|
||||
if let (Ok(entry), Ok(out_modified)) = (entry, &out_modified) {
|
||||
let in_modified = entry.metadata().unwrap().modified().unwrap();
|
||||
in_modified.duration_since(*out_modified).is_ok()
|
||||
} else {
|
||||
true
|
||||
}
|
||||
})
|
||||
} else {
|
||||
true
|
||||
};
|
||||
@ -100,12 +107,19 @@ fn main() -> Result<()> {
|
||||
let mut command = std::process::Command::new("nvcc");
|
||||
command
|
||||
.arg("-std=c++17")
|
||||
.arg("-O3")
|
||||
.arg("-U__CUDA_NO_HALF_OPERATORS__")
|
||||
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
|
||||
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
|
||||
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
|
||||
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
|
||||
.arg("-c")
|
||||
.args(["-o", obj_file.to_str().unwrap()])
|
||||
.args(["--default-stream", "per-thread"])
|
||||
.arg("-Icutlass/include")
|
||||
.arg("--expt-relaxed-constexpr")
|
||||
.arg("--expt-extended-lambda")
|
||||
.arg("--use_fast_math")
|
||||
.arg("--verbose");
|
||||
if let Ok(ccbin_path) = &ccbin_env {
|
||||
command
|
||||
@ -203,13 +217,21 @@ fn set_cuda_include_dir() -> Result<()> {
|
||||
|
||||
#[allow(unused)]
|
||||
fn compute_cap() -> Result<usize> {
|
||||
// Grab compute code from nvidia-smi
|
||||
let mut compute_cap = {
|
||||
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
|
||||
|
||||
// Try to parse compute caps from env
|
||||
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
|
||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
|
||||
compute_cap_str
|
||||
.parse::<usize>()
|
||||
.context("Could not parse compute cap")?
|
||||
} else {
|
||||
// Use nvidia-smi to get the current compute cap
|
||||
let out = std::process::Command::new("nvidia-smi")
|
||||
.arg("--query-gpu=compute_cap")
|
||||
.arg("--format=csv")
|
||||
.output()
|
||||
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
|
||||
.arg("--query-gpu=compute_cap")
|
||||
.arg("--format=csv")
|
||||
.output()
|
||||
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
|
||||
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
|
||||
let mut lines = out.lines();
|
||||
assert_eq!(
|
||||
@ -220,16 +242,19 @@ fn compute_cap() -> Result<usize> {
|
||||
.next()
|
||||
.context("missing line in stdout")?
|
||||
.replace('.', "");
|
||||
cap.parse::<usize>()
|
||||
.with_context(|| format!("cannot parse as int {cap}"))?
|
||||
let cap = cap
|
||||
.parse::<usize>()
|
||||
.with_context(|| format!("cannot parse as int {cap}"))?;
|
||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
|
||||
cap
|
||||
};
|
||||
|
||||
// Grab available GPU codes from nvcc and select the highest one
|
||||
let max_nvcc_code = {
|
||||
let (supported_nvcc_codes, max_nvcc_code) = {
|
||||
let out = std::process::Command::new("nvcc")
|
||||
.arg("--list-gpu-code")
|
||||
.output()
|
||||
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
|
||||
.arg("--list-gpu-code")
|
||||
.output()
|
||||
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
|
||||
let out = std::str::from_utf8(&out.stdout).unwrap();
|
||||
|
||||
let out = out.lines().collect::<Vec<&str>>();
|
||||
@ -243,30 +268,21 @@ fn compute_cap() -> Result<usize> {
|
||||
}
|
||||
}
|
||||
codes.sort();
|
||||
if !codes.contains(&compute_cap) {
|
||||
anyhow::bail!(
|
||||
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
|
||||
);
|
||||
}
|
||||
*codes.last().unwrap()
|
||||
let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?;
|
||||
(codes, max_nvcc_code)
|
||||
};
|
||||
|
||||
// If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
|
||||
// then choose the highest gpu code in nvcc
|
||||
if compute_cap > max_nvcc_code {
|
||||
println!(
|
||||
"cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
|
||||
// Check that nvcc supports the asked compute caps
|
||||
if !supported_nvcc_codes.contains(&compute_cap) {
|
||||
anyhow::bail!(
|
||||
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
|
||||
);
|
||||
}
|
||||
if compute_cap > max_nvcc_code {
|
||||
anyhow::bail!(
|
||||
"CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
|
||||
);
|
||||
compute_cap = max_nvcc_code;
|
||||
}
|
||||
|
||||
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
|
||||
if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
|
||||
compute_cap = compute_cap_str
|
||||
.parse::<usize>()
|
||||
.with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
|
||||
println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
|
||||
}
|
||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
|
||||
Ok(compute_cap)
|
||||
}
|
||||
|
Reference in New Issue
Block a user