mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Moving to a proper build crate bindgen_cuda
. (#1531)
* Moving to a proper build crate `bindgen_cuda`. * Fmt.
This commit is contained in:
@ -15,9 +15,9 @@ candle = { path = "../candle-core", features = ["cuda"], package = "candle-core"
|
|||||||
half = { version = "2.3.1", features = ["num-traits"] }
|
half = { version = "2.3.1", features = ["num-traits"] }
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
|
bindgen_cuda = "0.1.1"
|
||||||
anyhow = { version = "1", features = ["backtrace"] }
|
anyhow = { version = "1", features = ["backtrace"] }
|
||||||
num_cpus = "1.15.0"
|
|
||||||
rayon = "1.7.0"
|
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
anyhow = { version = "1", features = ["backtrace"] }
|
anyhow = { version = "1", features = ["backtrace"] }
|
||||||
|
@ -2,44 +2,32 @@
|
|||||||
// The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment
|
// The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment
|
||||||
// variable in order to cache the compiled artifacts and avoid recompiling too often.
|
// variable in order to cache the compiled artifacts and avoid recompiling too often.
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use rayon::prelude::*;
|
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::str::FromStr;
|
|
||||||
|
|
||||||
const KERNEL_FILES: [&str; 17] = [
|
const KERNEL_FILES: [&str; 17] = [
|
||||||
"flash_api.cu",
|
"kernels/flash_api.cu",
|
||||||
"flash_fwd_hdim128_fp16_sm80.cu",
|
"kernels/flash_fwd_hdim128_fp16_sm80.cu",
|
||||||
"flash_fwd_hdim160_fp16_sm80.cu",
|
"kernels/flash_fwd_hdim160_fp16_sm80.cu",
|
||||||
"flash_fwd_hdim192_fp16_sm80.cu",
|
"kernels/flash_fwd_hdim192_fp16_sm80.cu",
|
||||||
"flash_fwd_hdim224_fp16_sm80.cu",
|
"kernels/flash_fwd_hdim224_fp16_sm80.cu",
|
||||||
"flash_fwd_hdim256_fp16_sm80.cu",
|
"kernels/flash_fwd_hdim256_fp16_sm80.cu",
|
||||||
"flash_fwd_hdim32_fp16_sm80.cu",
|
"kernels/flash_fwd_hdim32_fp16_sm80.cu",
|
||||||
"flash_fwd_hdim64_fp16_sm80.cu",
|
"kernels/flash_fwd_hdim64_fp16_sm80.cu",
|
||||||
"flash_fwd_hdim96_fp16_sm80.cu",
|
"kernels/flash_fwd_hdim96_fp16_sm80.cu",
|
||||||
"flash_fwd_hdim128_bf16_sm80.cu",
|
"kernels/flash_fwd_hdim128_bf16_sm80.cu",
|
||||||
"flash_fwd_hdim160_bf16_sm80.cu",
|
"kernels/flash_fwd_hdim160_bf16_sm80.cu",
|
||||||
"flash_fwd_hdim192_bf16_sm80.cu",
|
"kernels/flash_fwd_hdim192_bf16_sm80.cu",
|
||||||
"flash_fwd_hdim224_bf16_sm80.cu",
|
"kernels/flash_fwd_hdim224_bf16_sm80.cu",
|
||||||
"flash_fwd_hdim256_bf16_sm80.cu",
|
"kernels/flash_fwd_hdim256_bf16_sm80.cu",
|
||||||
"flash_fwd_hdim32_bf16_sm80.cu",
|
"kernels/flash_fwd_hdim32_bf16_sm80.cu",
|
||||||
"flash_fwd_hdim64_bf16_sm80.cu",
|
"kernels/flash_fwd_hdim64_bf16_sm80.cu",
|
||||||
"flash_fwd_hdim96_bf16_sm80.cu",
|
"kernels/flash_fwd_hdim96_bf16_sm80.cu",
|
||||||
];
|
];
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let num_cpus = std::env::var("RAYON_NUM_THREADS").map_or_else(
|
|
||||||
|_| num_cpus::get_physical(),
|
|
||||||
|s| usize::from_str(&s).unwrap(),
|
|
||||||
);
|
|
||||||
|
|
||||||
rayon::ThreadPoolBuilder::new()
|
|
||||||
.num_threads(num_cpus)
|
|
||||||
.build_global()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
println!("cargo:rerun-if-changed=build.rs");
|
println!("cargo:rerun-if-changed=build.rs");
|
||||||
for kernel_file in KERNEL_FILES.iter() {
|
for kernel_file in KERNEL_FILES.iter() {
|
||||||
println!("cargo:rerun-if-changed=kernels/{kernel_file}");
|
println!("cargo:rerun-if-changed={kernel_file}");
|
||||||
}
|
}
|
||||||
println!("cargo:rerun-if-changed=kernels/flash_fwd_kernel.h");
|
println!("cargo:rerun-if-changed=kernels/flash_fwd_kernel.h");
|
||||||
println!("cargo:rerun-if-changed=kernels/flash_fwd_launch_template.h");
|
println!("cargo:rerun-if-changed=kernels/flash_fwd_launch_template.h");
|
||||||
@ -66,223 +54,30 @@ fn main() -> Result<()> {
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
set_cuda_include_dir()?;
|
|
||||||
|
|
||||||
let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
|
let kernels = KERNEL_FILES.iter().collect();
|
||||||
println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
|
let builder = bindgen_cuda::Builder::default()
|
||||||
|
.kernel_paths(kernels)
|
||||||
let compute_cap = compute_cap()?;
|
.out_dir(build_dir.clone())
|
||||||
|
.arg("-std=c++17")
|
||||||
|
.arg("-O3")
|
||||||
|
.arg("-U__CUDA_NO_HALF_OPERATORS__")
|
||||||
|
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
|
||||||
|
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
|
||||||
|
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
|
||||||
|
.arg("-Icutlass/include")
|
||||||
|
.arg("--expt-relaxed-constexpr")
|
||||||
|
.arg("--expt-extended-lambda")
|
||||||
|
.arg("--use_fast_math")
|
||||||
|
.arg("--verbose");
|
||||||
|
|
||||||
let out_file = build_dir.join("libflashattention.a");
|
let out_file = build_dir.join("libflashattention.a");
|
||||||
|
builder.build_lib(out_file);
|
||||||
|
|
||||||
let kernel_dir = PathBuf::from("kernels");
|
|
||||||
let cu_files: Vec<_> = KERNEL_FILES
|
|
||||||
.iter()
|
|
||||||
.map(|f| {
|
|
||||||
let mut obj_file = out_dir.join(f);
|
|
||||||
obj_file.set_extension("o");
|
|
||||||
(kernel_dir.join(f), obj_file)
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
let out_modified: Result<_, _> = out_file.metadata().and_then(|m| m.modified());
|
|
||||||
let should_compile = if out_file.exists() {
|
|
||||||
kernel_dir
|
|
||||||
.read_dir()
|
|
||||||
.expect("kernels folder should exist")
|
|
||||||
.any(|entry| {
|
|
||||||
if let (Ok(entry), Ok(out_modified)) = (entry, &out_modified) {
|
|
||||||
let in_modified = entry.metadata().unwrap().modified().unwrap();
|
|
||||||
in_modified.duration_since(*out_modified).is_ok()
|
|
||||||
} else {
|
|
||||||
true
|
|
||||||
}
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
true
|
|
||||||
};
|
|
||||||
if should_compile {
|
|
||||||
cu_files
|
|
||||||
.par_iter()
|
|
||||||
.map(|(cu_file, obj_file)| {
|
|
||||||
let mut command = std::process::Command::new("nvcc");
|
|
||||||
command
|
|
||||||
.arg("-std=c++17")
|
|
||||||
.arg("-O3")
|
|
||||||
.arg("-U__CUDA_NO_HALF_OPERATORS__")
|
|
||||||
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
|
|
||||||
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
|
|
||||||
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
|
|
||||||
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
|
|
||||||
.arg("-c")
|
|
||||||
.args(["-o", obj_file.to_str().unwrap()])
|
|
||||||
.args(["--default-stream", "per-thread"])
|
|
||||||
.arg("-Icutlass/include")
|
|
||||||
.arg("--expt-relaxed-constexpr")
|
|
||||||
.arg("--expt-extended-lambda")
|
|
||||||
.arg("--use_fast_math")
|
|
||||||
.arg("--verbose");
|
|
||||||
if let Ok(ccbin_path) = &ccbin_env {
|
|
||||||
command
|
|
||||||
.arg("-allow-unsupported-compiler")
|
|
||||||
.args(["-ccbin", ccbin_path]);
|
|
||||||
}
|
|
||||||
command.arg(cu_file);
|
|
||||||
let output = command
|
|
||||||
.spawn()
|
|
||||||
.context("failed spawning nvcc")?
|
|
||||||
.wait_with_output()?;
|
|
||||||
if !output.status.success() {
|
|
||||||
anyhow::bail!(
|
|
||||||
"nvcc error while executing compiling: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
|
|
||||||
&command,
|
|
||||||
String::from_utf8_lossy(&output.stdout),
|
|
||||||
String::from_utf8_lossy(&output.stderr)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
})
|
|
||||||
.collect::<Result<()>>()?;
|
|
||||||
let obj_files = cu_files.iter().map(|c| c.1.clone()).collect::<Vec<_>>();
|
|
||||||
let mut command = std::process::Command::new("nvcc");
|
|
||||||
command
|
|
||||||
.arg("--lib")
|
|
||||||
.args(["-o", out_file.to_str().unwrap()])
|
|
||||||
.args(obj_files);
|
|
||||||
let output = command
|
|
||||||
.spawn()
|
|
||||||
.context("failed spawning nvcc")?
|
|
||||||
.wait_with_output()?;
|
|
||||||
if !output.status.success() {
|
|
||||||
anyhow::bail!(
|
|
||||||
"nvcc error while linking: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
|
|
||||||
&command,
|
|
||||||
String::from_utf8_lossy(&output.stdout),
|
|
||||||
String::from_utf8_lossy(&output.stderr)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
println!("cargo:rustc-link-search={}", build_dir.display());
|
println!("cargo:rustc-link-search={}", build_dir.display());
|
||||||
println!("cargo:rustc-link-lib=flashattention");
|
println!("cargo:rustc-link-lib=flashattention");
|
||||||
println!("cargo:rustc-link-lib=dylib=cudart");
|
println!("cargo:rustc-link-lib=dylib=cudart");
|
||||||
println!("cargo:rustc-link-lib=dylib=stdc++");
|
println!("cargo:rustc-link-lib=dylib=stdc++");
|
||||||
|
|
||||||
/* laurent: I tried using the cc cuda integration as below but this lead to ptaxs never
|
|
||||||
finishing to run for some reason. Calling nvcc manually worked fine.
|
|
||||||
cc::Build::new()
|
|
||||||
.cuda(true)
|
|
||||||
.include("cutlass/include")
|
|
||||||
.flag("--expt-relaxed-constexpr")
|
|
||||||
.flag("--default-stream")
|
|
||||||
.flag("per-thread")
|
|
||||||
.flag(&format!("--gpu-architecture=sm_{compute_cap}"))
|
|
||||||
.file("kernels/flash_fwd_hdim32_fp16_sm80.cu")
|
|
||||||
.compile("flashattn");
|
|
||||||
*/
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn set_cuda_include_dir() -> Result<()> {
|
|
||||||
// NOTE: copied from cudarc build.rs.
|
|
||||||
let env_vars = [
|
|
||||||
"CUDA_PATH",
|
|
||||||
"CUDA_ROOT",
|
|
||||||
"CUDA_TOOLKIT_ROOT_DIR",
|
|
||||||
"CUDNN_LIB",
|
|
||||||
];
|
|
||||||
let env_vars = env_vars
|
|
||||||
.into_iter()
|
|
||||||
.map(std::env::var)
|
|
||||||
.filter_map(Result::ok)
|
|
||||||
.map(Into::<PathBuf>::into);
|
|
||||||
|
|
||||||
let roots = [
|
|
||||||
"/usr",
|
|
||||||
"/usr/local/cuda",
|
|
||||||
"/opt/cuda",
|
|
||||||
"/usr/lib/cuda",
|
|
||||||
"C:/Program Files/NVIDIA GPU Computing Toolkit",
|
|
||||||
"C:/CUDA",
|
|
||||||
];
|
|
||||||
let roots = roots.into_iter().map(Into::<PathBuf>::into);
|
|
||||||
let root = env_vars
|
|
||||||
.chain(roots)
|
|
||||||
.find(|path| path.join("include").join("cuda.h").is_file())
|
|
||||||
.context("cannot find include/cuda.h")?;
|
|
||||||
println!(
|
|
||||||
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
|
|
||||||
root.join("include").display()
|
|
||||||
);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(unused)]
|
|
||||||
fn compute_cap() -> Result<usize> {
|
|
||||||
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
|
|
||||||
|
|
||||||
// Try to parse compute caps from env
|
|
||||||
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
|
|
||||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
|
|
||||||
compute_cap_str
|
|
||||||
.parse::<usize>()
|
|
||||||
.context("Could not parse compute cap")?
|
|
||||||
} else {
|
|
||||||
// Use nvidia-smi to get the current compute cap
|
|
||||||
let out = std::process::Command::new("nvidia-smi")
|
|
||||||
.arg("--query-gpu=compute_cap")
|
|
||||||
.arg("--format=csv")
|
|
||||||
.output()
|
|
||||||
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
|
|
||||||
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
|
|
||||||
let mut lines = out.lines();
|
|
||||||
assert_eq!(
|
|
||||||
lines.next().context("missing line in stdout")?,
|
|
||||||
"compute_cap"
|
|
||||||
);
|
|
||||||
let cap = lines
|
|
||||||
.next()
|
|
||||||
.context("missing line in stdout")?
|
|
||||||
.replace('.', "");
|
|
||||||
let cap = cap
|
|
||||||
.parse::<usize>()
|
|
||||||
.with_context(|| format!("cannot parse as int {cap}"))?;
|
|
||||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
|
|
||||||
cap
|
|
||||||
};
|
|
||||||
|
|
||||||
// Grab available GPU codes from nvcc and select the highest one
|
|
||||||
let (supported_nvcc_codes, max_nvcc_code) = {
|
|
||||||
let out = std::process::Command::new("nvcc")
|
|
||||||
.arg("--list-gpu-code")
|
|
||||||
.output()
|
|
||||||
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
|
|
||||||
let out = std::str::from_utf8(&out.stdout).unwrap();
|
|
||||||
|
|
||||||
let out = out.lines().collect::<Vec<&str>>();
|
|
||||||
let mut codes = Vec::with_capacity(out.len());
|
|
||||||
for code in out {
|
|
||||||
let code = code.split('_').collect::<Vec<&str>>();
|
|
||||||
if !code.is_empty() && code.contains(&"sm") {
|
|
||||||
if let Ok(num) = code[1].parse::<usize>() {
|
|
||||||
codes.push(num);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
codes.sort();
|
|
||||||
let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?;
|
|
||||||
(codes, max_nvcc_code)
|
|
||||||
};
|
|
||||||
|
|
||||||
// Check that nvcc supports the asked compute caps
|
|
||||||
if !supported_nvcc_codes.contains(&compute_cap) {
|
|
||||||
anyhow::bail!(
|
|
||||||
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
|
|
||||||
);
|
|
||||||
}
|
|
||||||
if compute_cap > max_nvcc_code {
|
|
||||||
anyhow::bail!(
|
|
||||||
"CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(compute_cap)
|
|
||||||
}
|
|
||||||
|
@ -12,6 +12,4 @@ license = "MIT OR Apache-2.0"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
anyhow = { version = "1", features = ["backtrace"] }
|
bindgen_cuda = "0.1.1"
|
||||||
glob = "0.3.1"
|
|
||||||
rayon = "1.7.0"
|
|
||||||
|
@ -1,243 +1,8 @@
|
|||||||
use std::io::Write;
|
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
println!("cargo:rerun-if-changed=build.rs");
|
println!("cargo:rerun-if-changed=build.rs");
|
||||||
|
|
||||||
cuda::set_include_dir();
|
let builder = bindgen_cuda::Builder::default();
|
||||||
let (write, kernel_paths) = cuda::build_ptx();
|
println!("cargo:info={builder:?}");
|
||||||
if write {
|
let bindings = builder.build_ptx().unwrap();
|
||||||
let mut file = std::fs::File::create("src/lib.rs").unwrap();
|
bindings.write("src/lib.rs").unwrap();
|
||||||
for kernel_path in kernel_paths {
|
|
||||||
let name = kernel_path.file_stem().unwrap().to_str().unwrap();
|
|
||||||
file.write_all(
|
|
||||||
format!(
|
|
||||||
r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}.ptx"));"#,
|
|
||||||
name.to_uppercase().replace('.', "_"),
|
|
||||||
name
|
|
||||||
)
|
|
||||||
.as_bytes(),
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
file.write_all(&[b'\n']).unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
mod cuda {
|
|
||||||
use anyhow::{Context, Result};
|
|
||||||
|
|
||||||
pub fn set_include_dir() {
|
|
||||||
use std::path::PathBuf;
|
|
||||||
// NOTE: copied from cudarc build.rs.
|
|
||||||
// We can't actually set a env!() value from another crate,
|
|
||||||
// so we have to do that here.
|
|
||||||
|
|
||||||
// use PathBuf;
|
|
||||||
|
|
||||||
let env_vars = [
|
|
||||||
"CUDA_PATH",
|
|
||||||
"CUDA_ROOT",
|
|
||||||
"CUDA_TOOLKIT_ROOT_DIR",
|
|
||||||
"CUDNN_LIB",
|
|
||||||
];
|
|
||||||
#[allow(unused)]
|
|
||||||
let env_vars = env_vars
|
|
||||||
.into_iter()
|
|
||||||
.map(std::env::var)
|
|
||||||
.filter_map(Result::ok)
|
|
||||||
.map(Into::<PathBuf>::into);
|
|
||||||
|
|
||||||
let roots = [
|
|
||||||
"/usr",
|
|
||||||
"/usr/local/cuda",
|
|
||||||
"/opt/cuda",
|
|
||||||
"/usr/lib/cuda",
|
|
||||||
"C:/Program Files/NVIDIA GPU Computing Toolkit",
|
|
||||||
"C:/CUDA",
|
|
||||||
];
|
|
||||||
#[allow(unused)]
|
|
||||||
let roots = roots.into_iter().map(Into::<PathBuf>::into);
|
|
||||||
|
|
||||||
#[cfg(feature = "ci-check")]
|
|
||||||
let root: PathBuf = "ci".into();
|
|
||||||
|
|
||||||
#[cfg(not(feature = "ci-check"))]
|
|
||||||
let root = env_vars
|
|
||||||
.chain(roots)
|
|
||||||
.find(|path| path.join("include").join("cuda.h").is_file())
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
println!(
|
|
||||||
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
|
|
||||||
root.join("include").display()
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn build_ptx() -> (bool, Vec<std::path::PathBuf>) {
|
|
||||||
use rayon::prelude::*;
|
|
||||||
use std::path::PathBuf;
|
|
||||||
let out_dir = std::env::var("OUT_DIR").unwrap();
|
|
||||||
let kernel_paths: Vec<PathBuf> = glob::glob("src/*.cu")
|
|
||||||
.unwrap()
|
|
||||||
.map(|p| p.unwrap())
|
|
||||||
.collect();
|
|
||||||
let mut include_directories: Vec<PathBuf> = glob::glob("src/**/*.cuh")
|
|
||||||
.unwrap()
|
|
||||||
.map(|p| p.unwrap())
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
println!("cargo:rerun-if-changed=src/");
|
|
||||||
// for path in &kernel_paths {
|
|
||||||
// println!("cargo:rerun-if-changed={}", path.display());
|
|
||||||
// }
|
|
||||||
|
|
||||||
for path in &mut include_directories {
|
|
||||||
// println!("cargo:rerun-if-changed={}", path.display());
|
|
||||||
let destination =
|
|
||||||
std::format!("{out_dir}/{}", path.file_name().unwrap().to_str().unwrap());
|
|
||||||
std::fs::copy(path.clone(), destination).unwrap();
|
|
||||||
// remove the filename from the path so it's just the directory
|
|
||||||
path.pop();
|
|
||||||
}
|
|
||||||
|
|
||||||
include_directories.sort();
|
|
||||||
include_directories.dedup();
|
|
||||||
|
|
||||||
let compute_cap = compute_cap().expect("Could not get Cuda compute cap");
|
|
||||||
|
|
||||||
#[allow(unused)]
|
|
||||||
let include_options: Vec<String> = include_directories
|
|
||||||
.into_iter()
|
|
||||||
.map(|s| "-I".to_string() + &s.into_os_string().into_string().unwrap())
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
|
|
||||||
println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
|
|
||||||
let children = kernel_paths
|
|
||||||
.par_iter()
|
|
||||||
.flat_map(|p| {
|
|
||||||
let mut output = p.clone();
|
|
||||||
output.set_extension("ptx");
|
|
||||||
let output_filename = std::path::Path::new(&out_dir).to_path_buf().join("out").with_file_name(output.file_name().unwrap());
|
|
||||||
|
|
||||||
let ignore = if output_filename.exists() {
|
|
||||||
let out_modified = output_filename.metadata().unwrap().modified().unwrap();
|
|
||||||
let in_modified = p.metadata().unwrap().modified().unwrap();
|
|
||||||
out_modified.duration_since(in_modified).is_ok()
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
};
|
|
||||||
if ignore {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
let mut command = std::process::Command::new("nvcc");
|
|
||||||
command.arg(format!("--gpu-architecture=sm_{compute_cap}"))
|
|
||||||
.arg("--ptx")
|
|
||||||
.args(["--default-stream", "per-thread"])
|
|
||||||
.args(["--output-directory", &out_dir])
|
|
||||||
// Flash attention only
|
|
||||||
// .arg("--expt-relaxed-constexpr")
|
|
||||||
.args(&include_options);
|
|
||||||
if let Ok(ccbin_path) = &ccbin_env {
|
|
||||||
command
|
|
||||||
.arg("-allow-unsupported-compiler")
|
|
||||||
.args(["-ccbin", ccbin_path]);
|
|
||||||
}
|
|
||||||
command.arg(p);
|
|
||||||
Some((p, command.spawn()
|
|
||||||
.expect("nvcc failed to start. Ensure that you have CUDA installed and that `nvcc` is in your PATH.").wait_with_output()))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
let ptx_paths: Vec<PathBuf> = glob::glob(&format!("{out_dir}/**/*.ptx"))
|
|
||||||
.unwrap()
|
|
||||||
.map(|p| p.unwrap())
|
|
||||||
.collect();
|
|
||||||
// We should rewrite `src/lib.rs` only if there are some newly compiled kernels, or removed
|
|
||||||
// some old ones
|
|
||||||
let write = !children.is_empty() || kernel_paths.len() < ptx_paths.len();
|
|
||||||
for (kernel_path, child) in children {
|
|
||||||
let output = child.expect("nvcc failed to run. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
|
|
||||||
assert!(
|
|
||||||
output.status.success(),
|
|
||||||
"nvcc error while compiling {kernel_path:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
|
|
||||||
String::from_utf8_lossy(&output.stdout),
|
|
||||||
String::from_utf8_lossy(&output.stderr)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
(write, kernel_paths)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(unused)]
|
|
||||||
fn compute_cap() -> Result<usize> {
|
|
||||||
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
|
|
||||||
|
|
||||||
// Try to parse compute caps from env
|
|
||||||
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
|
|
||||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
|
|
||||||
compute_cap_str
|
|
||||||
.parse::<usize>()
|
|
||||||
.context("Could not parse code")?
|
|
||||||
} else {
|
|
||||||
// Use nvidia-smi to get the current compute cap
|
|
||||||
let out = std::process::Command::new("nvidia-smi")
|
|
||||||
.arg("--query-gpu=compute_cap")
|
|
||||||
.arg("--format=csv")
|
|
||||||
.output()
|
|
||||||
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
|
|
||||||
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
|
|
||||||
let mut lines = out.lines();
|
|
||||||
assert_eq!(
|
|
||||||
lines.next().context("missing line in stdout")?,
|
|
||||||
"compute_cap"
|
|
||||||
);
|
|
||||||
let cap = lines
|
|
||||||
.next()
|
|
||||||
.context("missing line in stdout")?
|
|
||||||
.replace('.', "");
|
|
||||||
let cap = cap
|
|
||||||
.parse::<usize>()
|
|
||||||
.with_context(|| format!("cannot parse as int {cap}"))?;
|
|
||||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
|
|
||||||
cap
|
|
||||||
};
|
|
||||||
|
|
||||||
// Grab available GPU codes from nvcc and select the highest one
|
|
||||||
let (supported_nvcc_codes, max_nvcc_code) = {
|
|
||||||
let out = std::process::Command::new("nvcc")
|
|
||||||
.arg("--list-gpu-code")
|
|
||||||
.output()
|
|
||||||
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
|
|
||||||
let out = std::str::from_utf8(&out.stdout).unwrap();
|
|
||||||
|
|
||||||
let out = out.lines().collect::<Vec<&str>>();
|
|
||||||
let mut codes = Vec::with_capacity(out.len());
|
|
||||||
for code in out {
|
|
||||||
let code = code.split('_').collect::<Vec<&str>>();
|
|
||||||
if !code.is_empty() && code.contains(&"sm") {
|
|
||||||
if let Ok(num) = code[1].parse::<usize>() {
|
|
||||||
codes.push(num);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
codes.sort();
|
|
||||||
let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?;
|
|
||||||
(codes, max_nvcc_code)
|
|
||||||
};
|
|
||||||
|
|
||||||
// Check that nvcc supports the asked compute caps
|
|
||||||
if !supported_nvcc_codes.contains(&compute_cap) {
|
|
||||||
anyhow::bail!(
|
|
||||||
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
|
|
||||||
);
|
|
||||||
}
|
|
||||||
if compute_cap > max_nvcc_code {
|
|
||||||
anyhow::bail!(
|
|
||||||
"CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(compute_cap)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user