Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 02:58:50 +00:00

* Start updating to cudarc 0.14.
* Adapt a couple more things.
* And a couple more fixes.
* More tweaks.
* And a couple more fixes.
* Bump the major version number.
* Proper module system for the cuda kernels.
* Proper ptx loading.
* Launch the sort kernel.
* Custom op.
* Start using the builder pattern.
* More builder.
* More builder.
* Get candle-core to compile.
* Get the tests to pass.
* Get candle-nn to work too.
* Support for custom cuda functions.
* cudnn fixes.
* Get flash attn to run.
* Switch the crate versions to be alpha.
* Bump the ug dependency.
125 lines
4.2 KiB
Rust
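The commit message above mentions moving kernel launches over to a builder pattern as part of the cudarc 0.14 update. As a rough sketch of what that builder-style launch looks like (assuming cudarc 0.14's `launch_builder` API; the kernel, names, and sizes below are made up for illustration and are not part of candle):

// Illustrative sketch only: the kernel source, function name, and sizes are
// invented for this example; they do not come from the candle codebase.
use cudarc::driver::{CudaContext, LaunchConfig, PushKernelArg};
use cudarc::nvrtc::compile_ptx;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // A toy kernel compiled at runtime through nvrtc.
    let ptx = compile_ptx(
        r#"extern "C" __global__ void scale(float *buf, int n) {
            int i = blockIdx.x * blockDim.x + threadIdx.x;
            if (i < n) { buf[i] *= 2.0f; }
        }"#,
    )?;
    let ctx = CudaContext::new(0)?; // first visible GPU
    let stream = ctx.default_stream();
    let module = ctx.load_module(ptx)?;
    let func = module.load_function("scale")?;

    let n = 1024usize;
    let mut buf = stream.alloc_zeros::<f32>(n)?;
    let n_i32 = n as i32;

    // cudarc 0.14 builder-style launch: push the arguments one by one,
    // then launch with a config derived from the element count.
    let mut builder = stream.launch_builder(&func);
    builder.arg(&mut buf);
    builder.arg(&n_i32);
    unsafe { builder.launch(LaunchConfig::for_num_elems(n as u32))? };
    Ok(())
}

The same stream-centric 0.14 API shows up in the file below, where dev.cuda_stream() is threaded through both cudnn handle creation and workspace allocation.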
use crate::WithDType;
use cudarc;
use cudarc::cudnn::safe::{ConvForward, Cudnn};
use cudarc::driver::{CudaSlice, CudaView, DeviceRepr, ValidAsZeroBits};
use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::Arc;

// The cudnn handles are stored per thread here rather than on the CudaDevice as they are neither
// Send nor Sync.
thread_local! {
    static CUDNN: RefCell<HashMap<crate::cuda_backend::DeviceId, Arc<Cudnn>>> = HashMap::new().into();
}

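// Map cudarc's cudnn and driver errors into candle's unified error type.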
impl From<cudarc::cudnn::CudnnError> for crate::Error {
    fn from(err: cudarc::cudnn::CudnnError) -> Self {
        crate::Error::wrap(err)
    }
}

impl From<cudarc::driver::DriverError> for crate::Error {
    fn from(err: cudarc::driver::DriverError) -> Self {
        crate::Error::wrap(err)
    }
}

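// Run a conv2d forward pass through cudnn. `T` is the element type of the
// buffers and `Y` the data type used for the convolution descriptor, i.e. the
// compute precision.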
pub(crate) fn launch_conv2d<
    T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType,
    Y: cudarc::cudnn::CudnnDataType,
>(
    src: &CudaView<T>,
    src_l: &crate::Layout,
    filter: &CudaView<T>,
    dst: &mut CudaSlice<T>,
    params: &crate::conv::ParamsConv2D,
    dev: &crate::cuda_backend::CudaDevice,
) -> crate::Result<()> {
    use crate::conv::CudnnFwdAlgo as CandleAlgo;
    use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A;

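    // Fetch the cached cudnn handle for this device, creating it on first use.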
    let device_id = dev.id();
    let cudnn = CUDNN.with(|cudnn| {
        if let Some(cudnn) = cudnn.borrow().get(&device_id) {
            return Ok(cudnn.clone());
        }
        let c = Cudnn::new(dev.cuda_stream());
        if let Ok(c) = &c {
            cudnn.borrow_mut().insert(device_id, c.clone());
        }
        c
    })?;
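    // Describe the convolution: padding, stride and dilation, in
    // cross-correlation mode (the usual deep-learning convention).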
    let conv = cudnn.create_conv2d::<Y>(
        /* pad */ [params.padding as i32, params.padding as i32],
        /* stride */ [params.stride as i32, params.stride as i32],
        /* dilation */ [params.dilation as i32, params.dilation as i32],
        cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
    )?;
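    // Input tensor descriptor in NCHW order; a non-contiguous input is
    // described through its explicit strides instead.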
    let x_shape = [
        params.b_size as i32,
        params.c_in as i32,
        params.i_h as i32,
        params.i_w as i32,
    ];
    // Note that `src` already starts at the proper offset.
    let x = if src_l.is_contiguous() {
        cudnn.create_4d_tensor::<T>(
            cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
            x_shape,
        )?
    } else {
        let s = src_l.stride();
        cudnn.create_4d_tensor_ex::<T>(
            x_shape,
            [s[0] as i32, s[1] as i32, s[2] as i32, s[3] as i32],
        )?
    };
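    // Filter (weight) descriptor: [c_out, c_in, k_h, k_w] in NCHW order.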
    let w = cudnn.create_4d_filter::<T>(
        cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
        [
            params.c_out as i32,
            params.c_in as i32,
            params.k_h as i32,
            params.k_w as i32,
        ],
    )?;
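    // Output tensor descriptor, sized from the computed output dimensions.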
    let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32);
    let y = cudnn.create_4d_tensor::<T>(
        cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
        [params.b_size as i32, params.c_out as i32, h_out, w_out],
    )?;
    let conv2d = ConvForward {
        conv: &conv,
        x: &x,
        w: &w,
        y: &y,
    };
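    // Use the algorithm requested in the params if any, otherwise let cudnn
    // pick one itself.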
    let alg = match params.cudnn_fwd_algo {
        None => conv2d.pick_algorithm()?,
        Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
        Some(CandleAlgo::ImplicitPrecompGemm) => {
            A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
        }
        Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
        Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
        Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT,
        Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
        Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
        Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
        Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
    };
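    // Allocate the scratch workspace required by the chosen algorithm, then
    // launch. The (one, zero) pair are the alpha/beta scaling factors, so the
    // output is overwritten rather than accumulated into.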
    let workspace_size = conv2d.get_workspace_size(alg)?;
    let mut workspace = dev.cuda_stream().alloc_zeros::<u8>(workspace_size)?;
    unsafe {
        conv2d.launch::<CudaSlice<u8>, _, _, _>(
            alg,
            Some(&mut workspace),
            (T::one(), T::zero()),
            src,
            filter,
            dst,
        )?;
    }
    Ok(())
}