mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Add CANDLE_NVCC_CCBIN
support for candle-kernels
, and eliminate warning. (#836)
This commit is contained in:
@ -12,6 +12,7 @@ use half::{bf16, f16};
|
|||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
const USE_IM2COL_CONV1D: bool = true;
|
const USE_IM2COL_CONV1D: bool = true;
|
||||||
|
#[cfg(not(feature = "cudnn"))]
|
||||||
const USE_IM2COL_CONV2D: bool = true;
|
const USE_IM2COL_CONV2D: bool = true;
|
||||||
|
|
||||||
/// cudarc related errors
|
/// cudarc related errors
|
||||||
|
@ -164,6 +164,8 @@ mod cuda {
|
|||||||
|
|
||||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
|
println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
|
||||||
|
|
||||||
|
let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
|
||||||
|
println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
|
||||||
let children = kernel_paths
|
let children = kernel_paths
|
||||||
.par_iter()
|
.par_iter()
|
||||||
.flat_map(|p| {
|
.flat_map(|p| {
|
||||||
@ -188,8 +190,13 @@ mod cuda {
|
|||||||
.args(["--output-directory", &out_dir])
|
.args(["--output-directory", &out_dir])
|
||||||
// Flash attention only
|
// Flash attention only
|
||||||
// .arg("--expt-relaxed-constexpr")
|
// .arg("--expt-relaxed-constexpr")
|
||||||
.args(&include_options)
|
.args(&include_options);
|
||||||
.arg(p);
|
if let Ok(ccbin_path) = &ccbin_env {
|
||||||
|
command
|
||||||
|
.arg("-allow-unsupported-compiler")
|
||||||
|
.args(["-ccbin", ccbin_path]);
|
||||||
|
}
|
||||||
|
command.arg(p);
|
||||||
Some((p, command.spawn()
|
Some((p, command.spawn()
|
||||||
.expect("nvcc failed to start. Ensure that you have CUDA installed and that `nvcc` is in your PATH.").wait_with_output()))
|
.expect("nvcc failed to start. Ensure that you have CUDA installed and that `nvcc` is in your PATH.").wait_with_output()))
|
||||||
}})
|
}})
|
||||||
|
Reference in New Issue
Block a user