mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Again set a few extra params in flash-attn. (#245)
* Again set a few extra params. * Use the appropriate kernel sizes. * Add all the kernel sizes. * Parallel compiling. * Reduce the amount of parallelism. * Add the missing kernel. * Fix a typo. * Remove bf16 support for now.
This commit is contained in:
@ -2,11 +2,45 @@
|
||||
// The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment
|
||||
// variable in order to cache the compiled artifacts and avoid recompiling too often.
|
||||
use anyhow::{Context, Result};
|
||||
use rayon::prelude::*;
|
||||
use std::path::PathBuf;
|
||||
use std::str::FromStr;
|
||||
|
||||
const KERNEL_FILES: [&'static str; 9] = [
|
||||
"flash_api.cu",
|
||||
"flash_fwd_hdim128_fp16_sm80.cu",
|
||||
"flash_fwd_hdim160_fp16_sm80.cu",
|
||||
"flash_fwd_hdim192_fp16_sm80.cu",
|
||||
"flash_fwd_hdim224_fp16_sm80.cu",
|
||||
"flash_fwd_hdim256_fp16_sm80.cu",
|
||||
"flash_fwd_hdim32_fp16_sm80.cu",
|
||||
"flash_fwd_hdim64_fp16_sm80.cu",
|
||||
"flash_fwd_hdim96_fp16_sm80.cu",
|
||||
// "flash_fwd_hdim128_bf16_sm80.cu",
|
||||
// "flash_fwd_hdim160_bf16_sm80.cu",
|
||||
// "flash_fwd_hdim192_bf16_sm80.cu",
|
||||
// "flash_fwd_hdim224_bf16_sm80.cu",
|
||||
// "flash_fwd_hdim256_bf16_sm80.cu",
|
||||
// "flash_fwd_hdim32_bf16_sm80.cu",
|
||||
// "flash_fwd_hdim64_bf16_sm80.cu",
|
||||
// "flash_fwd_hdim96_bf16_sm80.cu",
|
||||
];
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let num_cpus = std::env::var("RAYON_NUM_THREADS").map_or_else(
|
||||
|_| num_cpus::get_physical(),
|
||||
|s| usize::from_str(&s).unwrap(),
|
||||
);
|
||||
|
||||
rayon::ThreadPoolBuilder::new()
|
||||
.num_threads(num_cpus)
|
||||
.build_global()
|
||||
.unwrap();
|
||||
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
println!("cargo:rerun-if-changed=kernels/flash_fwd_hdim32_fp16_sm80.cu");
|
||||
for kernel_file in KERNEL_FILES.iter() {
|
||||
println!("cargo:rerun-if-changed=kernels/{kernel_file}");
|
||||
}
|
||||
println!("cargo:rerun-if-changed=kernels/flash_fwd_kernel.h");
|
||||
println!("cargo:rerun-if-changed=kernels/flash_fwd_launch_template.h");
|
||||
println!("cargo:rerun-if-changed=kernels/flash.h");
|
||||
@ -16,42 +50,74 @@ fn main() -> Result<()> {
|
||||
println!("cargo:rerun-if-changed=kernels/kernel_traits.h");
|
||||
println!("cargo:rerun-if-changed=kernels/block_info.h");
|
||||
println!("cargo:rerun-if-changed=kernels/static_switch.h");
|
||||
|
||||
let out_dir = PathBuf::from(std::env::var("OUT_DIR").context("OUT_DIR not set")?);
|
||||
let build_dir = match std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR") {
|
||||
Err(_) => std::env::var("OUT_DIR").context("OUT_DIR not set")?,
|
||||
Ok(build_dir) => build_dir,
|
||||
Err(_) => out_dir.clone(),
|
||||
Ok(build_dir) => PathBuf::from(build_dir),
|
||||
};
|
||||
let build_dir = PathBuf::from(build_dir);
|
||||
set_cuda_include_dir()?;
|
||||
let compute_cap = compute_cap()?;
|
||||
|
||||
let mut command = std::process::Command::new("nvcc");
|
||||
let out_file = build_dir.join("libflashattention.a");
|
||||
|
||||
let cu_file = PathBuf::from("kernels/flash_fwd_hdim32_fp16_sm80.cu");
|
||||
let kernel_dir = PathBuf::from("kernels");
|
||||
let cu_files: Vec<_> = KERNEL_FILES
|
||||
.iter()
|
||||
.map(|f| {
|
||||
let mut obj_file = out_dir.join(f);
|
||||
obj_file.set_extension("o");
|
||||
(kernel_dir.join(f), obj_file)
|
||||
})
|
||||
.collect();
|
||||
let should_compile = if out_file.exists() {
|
||||
let out_modified = out_file.metadata()?.modified()?;
|
||||
let in_modified = cu_file.metadata()?.modified()?;
|
||||
in_modified.duration_since(out_modified).is_ok()
|
||||
cu_files.iter().any(|(cu_file, _)| {
|
||||
let out_modified = out_file.metadata().unwrap().modified().unwrap();
|
||||
let in_modified = cu_file.metadata().unwrap().modified().unwrap();
|
||||
in_modified.duration_since(out_modified).is_ok()
|
||||
})
|
||||
} else {
|
||||
true
|
||||
};
|
||||
if should_compile {
|
||||
cu_files
|
||||
.par_iter()
|
||||
.map(|(cu_file, obj_file)| {
|
||||
let mut command = std::process::Command::new("nvcc");
|
||||
command
|
||||
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
|
||||
.arg("-c")
|
||||
.args(["-o", obj_file.to_str().unwrap()])
|
||||
.args(["--default-stream", "per-thread"])
|
||||
.arg("-Icutlass/include")
|
||||
.arg("--expt-relaxed-constexpr")
|
||||
.arg(cu_file);
|
||||
let output = command
|
||||
.spawn()
|
||||
.context("failed spawning nvcc")?
|
||||
.wait_with_output()?;
|
||||
if !output.status.success() {
|
||||
anyhow::bail!(
|
||||
"nvcc error while compiling:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
)
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<()>>()?;
|
||||
let obj_files = cu_files.iter().map(|c| c.1.clone()).collect::<Vec<_>>();
|
||||
let mut command = std::process::Command::new("nvcc");
|
||||
command
|
||||
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
|
||||
.arg("--lib")
|
||||
.args(["-o", out_file.to_str().unwrap()])
|
||||
.args(["--default-stream", "per-thread"])
|
||||
.arg("-Icutlass/include")
|
||||
.arg("--expt-relaxed-constexpr")
|
||||
.arg(cu_file);
|
||||
.args(obj_files);
|
||||
let output = command
|
||||
.spawn()
|
||||
.context("failed spawning nvcc")?
|
||||
.wait_with_output()?;
|
||||
if !output.status.success() {
|
||||
anyhow::bail!(
|
||||
"nvcc error while compiling:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
|
||||
"nvcc error while linking:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
)
|
||||
|
Reference in New Issue
Block a user