Update to cudarc 0.14 (breaking change). (#2858)

* Start updating to cudarc 0.14.

* Adapt a couple more things.

* And a couple more fixes.

* More tweaks.

* And a couple more fixes.

* Bump the major version number.

* Proper module system for the cuda kernels.

* Proper ptx loading.

* Launch the sort kernel.

* Custom op.

* Start using the builder pattern.

* More builder.

* More builder.

* Get candle-core to compile.

* Get the tests to pass.

* Get candle-nn to work too.

* Support for custom cuda functions.

* cudnn fixes.

* Get flash attn to run.

* Switch the crate versions to be alpha.

* Bump the ug dependency.
This commit is contained in:
Laurent Mazare
2025-04-03 09:12:19 +02:00
committed by GitHub
parent d6db305829
commit d9904a3baf
18 changed files with 924 additions and 590 deletions

View File

@ -43,7 +43,7 @@ pub(crate) fn launch_conv2d<
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
return Ok(cudnn.clone());
}
let c = Cudnn::new(dev.cuda_device());
let c = Cudnn::new(dev.cuda_stream());
if let Ok(c) = &c {
cudnn.borrow_mut().insert(device_id, c.clone());
}
@ -109,7 +109,7 @@ pub(crate) fn launch_conv2d<
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
};
let workspace_size = conv2d.get_workspace_size(alg)?;
let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?;
let mut workspace = dev.cuda_stream().alloc_zeros::<u8>(workspace_size)?;
unsafe {
conv2d.launch::<CudaSlice<u8>, _, _, _>(
alg,

View File

@ -2,8 +2,9 @@ use crate::backend::BackendDevice;
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
pub use candle_kernels as kernels;
pub use cudarc;
use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
use cudarc::driver::{CudaFunction, LaunchConfig, PushKernelArg};
use half::{bf16, f16};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};
@ -24,10 +25,17 @@ impl DeviceId {
struct CudaRng(cudarc::curand::CudaRng);
unsafe impl Send for CudaRng {}
pub struct ModuleStore {
mdls: [Option<Arc<cudarc::driver::CudaModule>>; kernels::ALL_IDS.len()],
}
#[derive(Clone)]
pub struct CudaDevice {
id: DeviceId,
device: Arc<cudarc::driver::CudaDevice>,
context: Arc<cudarc::driver::CudaContext>,
modules: Arc<std::sync::RwLock<ModuleStore>>,
custom_modules: Arc<std::sync::RwLock<HashMap<String, Arc<cudarc::driver::CudaModule>>>>,
stream: Arc<cudarc::driver::CudaStream>,
pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
curand: Arc<Mutex<CudaRng>>,
}
@ -39,16 +47,51 @@ impl std::fmt::Debug for CudaDevice {
}
impl std::ops::Deref for CudaDevice {
type Target = Arc<cudarc::driver::CudaDevice>;
type Target = Arc<cudarc::driver::CudaStream>;
fn deref(&self) -> &Self::Target {
&self.device
&self.stream
}
}
pub struct CudaFunc {
func: CudaFunction,
stream: Arc<cudarc::driver::CudaStream>,
}
impl std::ops::Deref for CudaFunc {
type Target = CudaFunction;
fn deref(&self) -> &Self::Target {
&self.func
}
}
impl CudaFunc {
pub fn into_cuda_function(self) -> CudaFunction {
self.func
}
}
#[macro_export]
macro_rules! builder_arg {
($b:ident, $($arg:expr),*) => {
$(
let __arg = $arg;
$b.arg(&__arg);
)*
};
}
impl CudaFunc {
pub fn builder(&self) -> cudarc::driver::LaunchArgs<'_> {
self.stream.launch_builder(&self.func)
}
}
impl CudaDevice {
pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
self.device.clone()
pub fn cuda_stream(&self) -> Arc<cudarc::driver::CudaStream> {
self.stream.clone()
}
#[cfg(not(target_arch = "wasm32"))]
@ -56,7 +99,7 @@ impl CudaDevice {
&self,
func_name: &'static str,
kernel: ug::lang::ssa::Kernel,
) -> Result<CudaFunction> {
) -> Result<CudaFunc> {
let mut buf = vec![];
ug_cuda::code_gen::gen(&mut buf, func_name, &kernel)?;
let cuda_code = String::from_utf8(buf)?;
@ -65,12 +108,12 @@ impl CudaDevice {
..Default::default()
};
let ptx = cudarc::nvrtc::safe::compile_ptx_with_opts(cuda_code, opts).w()?;
self.device.load_ptx(ptx, "ug", &[func_name]).w()?;
let func = match self.device.get_func("ug", func_name) {
Some(func) => func,
None => crate::bail!("unknown function ug::{func_name}"),
};
Ok(func)
let module = self.context.load_module(ptx).w()?;
let func = module.load_function(func_name).w()?;
Ok(CudaFunc {
func,
stream: self.stream.clone(),
})
}
pub fn id(&self) -> DeviceId {
@ -84,57 +127,84 @@ impl CudaDevice {
DType::U8 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
let params = (&data, v as u8, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
let func = self.get_or_load_func("fill_u8", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as u8;
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
let params = (&data, v as u32, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
let func = self.get_or_load_func("fill_u32", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as u32;
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
let params = (&data, v as i64, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
let func = self.get_or_load_func("fill_i64", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as i64;
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
let params = (&data, bf16::from_f64(v), elem_count);
unsafe { func.launch(cfg, params) }.w()?;
let func = self.get_or_load_func("fill_bf16", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = bf16::from_f64(v);
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
let params = (&data, f16::from_f64(v), elem_count);
unsafe { func.launch(cfg, params) }.w()?;
let func = self.get_or_load_func("fill_f16", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = f16::from_f64(v);
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
let params = (&data, v as f32, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
let func = self.get_or_load_func("fill_f32", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as f32;
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
let params = (&data, v, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
let func = self.get_or_load_func("fill_f64", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::F64(data)
}
};
@ -144,38 +214,69 @@ impl CudaDevice {
})
}
pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
if !self.has_func(module_name, module_name) {
// Leaking the string here is a bit sad but we need a &'static str and this is only
// done once per kernel name.
let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
self.load_ptx(ptx.into(), module_name, &[static_module_name])
.map_err(|cuda| CudaError::Load {
cuda,
module_name: module_name.to_string(),
})
.w()?;
pub fn get_or_load_custom_func(
&self,
fn_name: &str,
module_name: &str,
ptx: &str,
) -> Result<CudaFunc> {
let ms = self.custom_modules.read().unwrap();
if let Some(mdl) = ms.get(module_name).as_ref() {
let func = mdl.load_function(fn_name).w()?;
return Ok(CudaFunc {
func,
stream: self.stream.clone(),
});
}
self.get_func(module_name, module_name)
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
// able to only build the error value if needed.
.ok_or(CudaError::MissingKernel {
module_name: module_name.to_string(),
})
.w()
drop(ms);
let mut ms = self.custom_modules.write().unwrap();
let cuda_module = self.context.load_module(ptx.into()).w()?;
ms.insert(module_name.to_string(), cuda_module.clone());
let func = cuda_module.load_function(fn_name).w()?;
Ok(CudaFunc {
func,
stream: self.stream.clone(),
})
}
pub fn get_or_load_func(&self, fn_name: &str, mdl: &kernels::Module) -> Result<CudaFunc> {
let ms = self.modules.read().unwrap();
if let Some(mdl) = ms.mdls[mdl.index()].as_ref() {
let func = mdl.load_function(fn_name).w()?;
return Ok(CudaFunc {
func,
stream: self.stream.clone(),
});
}
drop(ms);
let mut ms = self.modules.write().unwrap();
let cuda_module = self.context.load_module(mdl.ptx().into()).w()?;
ms.mdls[mdl.index()] = Some(cuda_module.clone());
let func = cuda_module.load_function(fn_name).w()?;
Ok(CudaFunc {
func,
stream: self.stream.clone(),
})
}
}
impl CudaDevice {
pub fn new_with_stream(ordinal: usize) -> Result<Self> {
let device = cudarc::driver::CudaDevice::new_with_stream(ordinal).w()?;
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
let context = cudarc::driver::CudaContext::new(ordinal).w()?;
let stream = context.new_stream().w()?;
let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
let curand = cudarc::curand::CudaRng::new(299792458, stream.clone()).w()?;
let module_store = ModuleStore {
mdls: [const { None }; kernels::ALL_IDS.len()],
};
Ok(Self {
id: DeviceId::new(),
device,
context,
stream,
blas: Arc::new(blas),
curand: Arc::new(Mutex::new(CudaRng(curand))),
modules: Arc::new(std::sync::RwLock::new(module_store)),
custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())),
})
}
}
@ -184,14 +285,21 @@ impl BackendDevice for CudaDevice {
type Storage = CudaStorage;
fn new(ordinal: usize) -> Result<Self> {
let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
let context = cudarc::driver::CudaContext::new(ordinal).w()?;
let stream = context.default_stream();
let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
let curand = cudarc::curand::CudaRng::new(299792458, stream.clone()).w()?;
let module_store = ModuleStore {
mdls: [const { None }; kernels::ALL_IDS.len()],
};
Ok(Self {
id: DeviceId::new(),
device,
context,
stream,
blas: Arc::new(blas),
curand: Arc::new(Mutex::new(CudaRng(curand))),
modules: Arc::new(std::sync::RwLock::new(module_store)),
custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())),
})
}
@ -199,13 +307,13 @@ impl BackendDevice for CudaDevice {
// We do not call set_seed but instead create a new curand object. This ensures that the
// state will be identical and the same random numbers will be generated.
let mut curand = self.curand.lock().unwrap();
curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
curand.0 = cudarc::curand::CudaRng::new(seed, self.stream.clone()).w()?;
Ok(())
}
fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Cuda {
gpu_id: self.device.ordinal(),
gpu_id: self.context.ordinal(),
}
}
@ -373,31 +481,31 @@ impl BackendDevice for CudaDevice {
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
let slice = match T::cpu_storage_ref(s) {
CpuStorageRef::U8(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::U8(data)
}
CpuStorageRef::U32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorageRef::I64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::I64(data)
}
CpuStorageRef::BF16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::BF16(data)
}
CpuStorageRef::F16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::F16(data)
}
CpuStorageRef::F32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::F32(data)
}
CpuStorageRef::F64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::F64(data)
}
};
@ -410,31 +518,31 @@ impl BackendDevice for CudaDevice {
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U8(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::U8(data)
}
CpuStorage::U32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorage::I64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::I64(data)
}
CpuStorage::BF16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::BF16(data)
}
CpuStorage::F16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::F16(data)
}
CpuStorage::F32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::F32(data)
}
CpuStorage::F64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::F64(data)
}
};
@ -447,31 +555,31 @@ impl BackendDevice for CudaDevice {
fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U8(storage) => {
let data = self.htod_copy(storage).w()?;
let data = self.memcpy_stod(&storage).w()?;
CudaStorageSlice::U8(data)
}
CpuStorage::U32(storage) => {
let data = self.htod_copy(storage).w()?;
let data = self.memcpy_stod(&storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorage::I64(storage) => {
let data = self.htod_copy(storage).w()?;
let data = self.memcpy_stod(&storage).w()?;
CudaStorageSlice::I64(data)
}
CpuStorage::BF16(storage) => {
let data = self.htod_copy(storage).w()?;
let data = self.memcpy_stod(&storage).w()?;
CudaStorageSlice::BF16(data)
}
CpuStorage::F16(storage) => {
let data = self.htod_copy(storage).w()?;
let data = self.memcpy_stod(&storage).w()?;
CudaStorageSlice::F16(data)
}
CpuStorage::F32(storage) => {
let data = self.htod_copy(storage).w()?;
let data = self.memcpy_stod(&storage).w()?;
CudaStorageSlice::F32(data)
}
CpuStorage::F64(storage) => {
let data = self.htod_copy(storage).w()?;
let data = self.memcpy_stod(&storage).w()?;
CudaStorageSlice::F64(data)
}
};
@ -482,7 +590,7 @@ impl BackendDevice for CudaDevice {
}
fn synchronize(&self) -> Result<()> {
self.device.synchronize().map_err(crate::Error::wrap)?;
self.stream.synchronize().map_err(crate::Error::wrap)?;
Ok(())
}
}

File diff suppressed because it is too large Load Diff