From 1c0916402179f8d5da849065219cfcf657bf2714 Mon Sep 17 00:00:00 2001 From: Charles Lew Date: Wed, 13 Sep 2023 18:39:22 +0800 Subject: [PATCH] Add `CANDLE_NVCC_CCBIN` support for `candle-kernels`, and eliminate warning. (#836) --- candle-core/src/cuda_backend.rs | 1 + candle-kernels/build.rs | 11 +++++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 55443068..07c5dfa8 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -12,6 +12,7 @@ use half::{bf16, f16}; use std::sync::{Arc, Mutex}; const USE_IM2COL_CONV1D: bool = true; +#[cfg(not(feature = "cudnn"))] const USE_IM2COL_CONV2D: bool = true; /// cudarc related errors diff --git a/candle-kernels/build.rs b/candle-kernels/build.rs index 3c8e96a9..ad084671 100644 --- a/candle-kernels/build.rs +++ b/candle-kernels/build.rs @@ -164,6 +164,8 @@ mod cuda { println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}"); + let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN"); + println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN"); let children = kernel_paths .par_iter() .flat_map(|p| { @@ -188,8 +190,13 @@ mod cuda { .args(["--output-directory", &out_dir]) // Flash attention only // .arg("--expt-relaxed-constexpr") - .args(&include_options) - .arg(p); + .args(&include_options); + if let Ok(ccbin_path) = &ccbin_env { + command + .arg("-allow-unsupported-compiler") + .args(["-ccbin", ccbin_path]); + } + command.arg(p); Some((p, command.spawn() .expect("nvcc failed to start. Ensure that you have CUDA installed and that `nvcc` is in your PATH.").wait_with_output())) }})