Specific cache dir for the flash attn build artifacts. (#242)

This commit is contained in:
Laurent Mazare
2023-07-26 08:04:02 +01:00
committed by GitHub
parent d9f9c859af
commit 471855e2ee

View File

@@ -1,6 +1,7 @@
#![allow(unused)] // Build script to run nvcc and generate the C glue code for launching the flash-attention kernel.
// The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment
// variable in order to cache the compiled artifacts and avoid recompiling too often.
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use std::io::Write;
use std::path::PathBuf; use std::path::PathBuf;
fn main() -> Result<()> { fn main() -> Result<()> {
@@ -16,17 +17,16 @@ fn main() -> Result<()> {
println!("cargo:rerun-if-changed=kernels/block_info.h"); println!("cargo:rerun-if-changed=kernels/block_info.h");
println!("cargo:rerun-if-changed=kernels/static_switch.h"); println!("cargo:rerun-if-changed=kernels/static_switch.h");
let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?; let build_dir = match std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR") {
let mut out_dir = PathBuf::from(out_dir); Err(_) => std::env::var("OUT_DIR").context("OUT_DIR not set")?,
// TODO: Getting up two levels avoid having to recompile this too often, however it's likely Ok(build_dir) => build_dir,
// not a safe assumption. };
out_dir.pop(); let build_dir = PathBuf::from(build_dir);
out_dir.pop();
set_cuda_include_dir()?; set_cuda_include_dir()?;
let compute_cap = compute_cap()?; let compute_cap = compute_cap()?;
let mut command = std::process::Command::new("nvcc"); let mut command = std::process::Command::new("nvcc");
let out_file = out_dir.join("libflashattention.a"); let out_file = build_dir.join("libflashattention.a");
let cu_file = PathBuf::from("kernels/flash_fwd_hdim32_fp16_sm80.cu"); let cu_file = PathBuf::from("kernels/flash_fwd_hdim32_fp16_sm80.cu");
let should_compile = if out_file.exists() { let should_compile = if out_file.exists() {
@@ -57,7 +57,7 @@ fn main() -> Result<()> {
) )
} }
} }
println!("cargo:rustc-link-search={}", out_dir.display()); println!("cargo:rustc-link-search={}", build_dir.display());
println!("cargo:rustc-link-lib=flashattention"); println!("cargo:rustc-link-lib=flashattention");
println!("cargo:rustc-link-lib=dylib=cudart"); println!("cargo:rustc-link-lib=dylib=cudart");
println!("cargo:rustc-link-lib=dylib=stdc++"); println!("cargo:rustc-link-lib=dylib=stdc++");