mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Specific cache dir for the flash attn build artifacts. (#242)
This commit is contained in:
@ -1,6 +1,7 @@
|
|||||||
#![allow(unused)]
|
// Build script to run nvcc and generate the C glue code for launching the flash-attention kernel.
|
||||||
|
// The CUDA build time is very long, so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment
|
||||||
|
// variable in order to cache the compiled artifacts and avoid recompiling too often.
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use std::io::Write;
|
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
@ -16,17 +17,16 @@ fn main() -> Result<()> {
|
|||||||
println!("cargo:rerun-if-changed=kernels/block_info.h");
|
println!("cargo:rerun-if-changed=kernels/block_info.h");
|
||||||
println!("cargo:rerun-if-changed=kernels/static_switch.h");
|
println!("cargo:rerun-if-changed=kernels/static_switch.h");
|
||||||
|
|
||||||
let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?;
|
let build_dir = match std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR") {
|
||||||
let mut out_dir = PathBuf::from(out_dir);
|
Err(_) => std::env::var("OUT_DIR").context("OUT_DIR not set")?,
|
||||||
// TODO: Going up two levels avoids having to recompile this too often; however, it's likely
|
Ok(build_dir) => build_dir,
|
||||||
// not a safe assumption.
|
};
|
||||||
out_dir.pop();
|
let build_dir = PathBuf::from(build_dir);
|
||||||
out_dir.pop();
|
|
||||||
set_cuda_include_dir()?;
|
set_cuda_include_dir()?;
|
||||||
let compute_cap = compute_cap()?;
|
let compute_cap = compute_cap()?;
|
||||||
|
|
||||||
let mut command = std::process::Command::new("nvcc");
|
let mut command = std::process::Command::new("nvcc");
|
||||||
let out_file = out_dir.join("libflashattention.a");
|
let out_file = build_dir.join("libflashattention.a");
|
||||||
|
|
||||||
let cu_file = PathBuf::from("kernels/flash_fwd_hdim32_fp16_sm80.cu");
|
let cu_file = PathBuf::from("kernels/flash_fwd_hdim32_fp16_sm80.cu");
|
||||||
let should_compile = if out_file.exists() {
|
let should_compile = if out_file.exists() {
|
||||||
@ -57,7 +57,7 @@ fn main() -> Result<()> {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
println!("cargo:rustc-link-search={}", out_dir.display());
|
println!("cargo:rustc-link-search={}", build_dir.display());
|
||||||
println!("cargo:rustc-link-lib=flashattention");
|
println!("cargo:rustc-link-lib=flashattention");
|
||||||
println!("cargo:rustc-link-lib=dylib=cudart");
|
println!("cargo:rustc-link-lib=dylib=cudart");
|
||||||
println!("cargo:rustc-link-lib=dylib=stdc++");
|
println!("cargo:rustc-link-lib=dylib=stdc++");
|
||||||
|
Reference in New Issue
Block a user