// Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 02:38:10 +00:00).
// 467 lines, 17 KiB, Rust.
use crate::backend::BackendDevice;
|
|
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
|
|
pub use candle_kernels as kernels;
|
|
pub use cudarc;
|
|
use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
|
|
use half::{bf16, f16};
|
|
use std::sync::{Arc, Mutex};
|
|
|
|
use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};
|
|
|
|
/// Unique identifier for cuda devices.
///
/// Values are drawn from a process-wide atomic counter that starts at 1, so two identifiers
/// obtained from separate `new` calls compare unequal (barring counter wrap-around).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DeviceId(usize);

impl DeviceId {
    /// Hands out the next identifier from the shared counter.
    fn new() -> Self {
        // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
        use std::sync::atomic::{AtomicUsize, Ordering};

        static NEXT_ID: AtomicUsize = AtomicUsize::new(1);
        // Only the uniqueness of the returned value matters here, hence the relaxed ordering.
        let id = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        Self(id)
    }
}
|
|
|
|
struct CudaRng(cudarc::curand::CudaRng);
|
|
unsafe impl Send for CudaRng {}
|
|
|
|
#[derive(Clone)]
|
|
pub struct CudaDevice {
|
|
id: DeviceId,
|
|
device: Arc<cudarc::driver::CudaDevice>,
|
|
pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
|
|
curand: Arc<Mutex<CudaRng>>,
|
|
}
|
|
|
|
impl std::fmt::Debug for CudaDevice {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
write!(f, "CudaDevice({:?})", self.id)
|
|
}
|
|
}
|
|
|
|
impl std::ops::Deref for CudaDevice {
|
|
type Target = Arc<cudarc::driver::CudaDevice>;
|
|
|
|
fn deref(&self) -> &Self::Target {
|
|
&self.device
|
|
}
|
|
}
|
|
|
|
impl CudaDevice {
|
|
/// Returns another `Arc` handle to the wrapped cudarc device.
pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
    Arc::clone(&self.device)
}
|
|
|
|
pub fn id(&self) -> DeviceId {
|
|
self.id
|
|
}
|
|
|
|
fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
|
let elem_count = shape.elem_count();
|
|
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
|
let slice = match dtype {
|
|
DType::U8 => {
|
|
// SAFETY: Set later by running the fill kernel.
|
|
let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
|
|
let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
|
|
let params = (&data, v as u8, elem_count);
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
CudaStorageSlice::U8(data)
|
|
}
|
|
DType::U32 => {
|
|
// SAFETY: Set later by running the fill kernel.
|
|
let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
|
|
let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
|
|
let params = (&data, v as u32, elem_count);
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
CudaStorageSlice::U32(data)
|
|
}
|
|
DType::I64 => {
|
|
// SAFETY: Set later by running the fill kernel.
|
|
let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
|
|
let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
|
|
let params = (&data, v as i64, elem_count);
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
CudaStorageSlice::I64(data)
|
|
}
|
|
DType::BF16 => {
|
|
// SAFETY: Set later by running the fill kernel.
|
|
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
|
|
let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
|
|
let params = (&data, bf16::from_f64(v), elem_count);
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
CudaStorageSlice::BF16(data)
|
|
}
|
|
DType::F16 => {
|
|
// SAFETY: Set later by running the fill kernel.
|
|
let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
|
|
let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
|
|
let params = (&data, f16::from_f64(v), elem_count);
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
CudaStorageSlice::F16(data)
|
|
}
|
|
DType::F32 => {
|
|
// SAFETY: Set later by running the fill kernel.
|
|
let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
|
|
let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
|
|
let params = (&data, v as f32, elem_count);
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
CudaStorageSlice::F32(data)
|
|
}
|
|
DType::F64 => {
|
|
// SAFETY: Set later by running the fill kernel.
|
|
let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
|
|
let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
|
|
let params = (&data, v, elem_count);
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
CudaStorageSlice::F64(data)
|
|
}
|
|
};
|
|
Ok(CudaStorage {
|
|
slice,
|
|
device: self.clone(),
|
|
})
|
|
}
|
|
|
|
pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
|
|
if !self.has_func(module_name, module_name) {
|
|
// Leaking the string here is a bit sad but we need a &'static str and this is only
|
|
// done once per kernel name.
|
|
let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
|
|
self.load_ptx(ptx.into(), module_name, &[static_module_name])
|
|
.map_err(|cuda| CudaError::Load {
|
|
cuda,
|
|
module_name: module_name.to_string(),
|
|
})
|
|
.w()?;
|
|
}
|
|
self.get_func(module_name, module_name)
|
|
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
|
|
// able to only build the error value if needed.
|
|
.ok_or(CudaError::MissingKernel {
|
|
module_name: module_name.to_string(),
|
|
})
|
|
.w()
|
|
}
|
|
}
|
|
|
|
impl CudaDevice {
|
|
pub fn new_with_stream(ordinal: usize) -> Result<Self> {
|
|
let device = cudarc::driver::CudaDevice::new_with_stream(ordinal).w()?;
|
|
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
|
|
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
|
|
Ok(Self {
|
|
id: DeviceId::new(),
|
|
device,
|
|
blas: Arc::new(blas),
|
|
curand: Arc::new(Mutex::new(CudaRng(curand))),
|
|
})
|
|
}
|
|
}
|
|
|
|
impl BackendDevice for CudaDevice {
|
|
type Storage = CudaStorage;
|
|
|
|
fn new(ordinal: usize) -> Result<Self> {
|
|
let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
|
|
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
|
|
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
|
|
Ok(Self {
|
|
id: DeviceId::new(),
|
|
device,
|
|
blas: Arc::new(blas),
|
|
curand: Arc::new(Mutex::new(CudaRng(curand))),
|
|
})
|
|
}
|
|
|
|
fn set_seed(&self, seed: u64) -> Result<()> {
|
|
// We do not call set_seed but instead create a new curand object. This ensures that the
|
|
// state will be identical and the same random numbers will be generated.
|
|
let mut curand = self.curand.lock().unwrap();
|
|
curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
|
|
Ok(())
|
|
}
|
|
|
|
fn location(&self) -> crate::DeviceLocation {
|
|
crate::DeviceLocation::Cuda {
|
|
gpu_id: self.device.ordinal(),
|
|
}
|
|
}
|
|
|
|
/// Returns `true` when both handles carry the same `DeviceId`.
///
/// The comparison is on the id assigned at construction, not on the GPU ordinal.
fn same_device(&self, rhs: &Self) -> bool {
    let (lhs_id, rhs_id) = (self.id, rhs.id);
    lhs_id == rhs_id
}
|
|
|
|
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
|
let elem_count = shape.elem_count();
|
|
let slice = match dtype {
|
|
DType::U8 => {
|
|
let data = self.alloc_zeros::<u8>(elem_count).w()?;
|
|
CudaStorageSlice::U8(data)
|
|
}
|
|
DType::U32 => {
|
|
let data = self.alloc_zeros::<u32>(elem_count).w()?;
|
|
CudaStorageSlice::U32(data)
|
|
}
|
|
DType::I64 => {
|
|
let data = self.alloc_zeros::<i64>(elem_count).w()?;
|
|
CudaStorageSlice::I64(data)
|
|
}
|
|
DType::BF16 => {
|
|
let data = self.alloc_zeros::<bf16>(elem_count).w()?;
|
|
CudaStorageSlice::BF16(data)
|
|
}
|
|
DType::F16 => {
|
|
let data = self.alloc_zeros::<f16>(elem_count).w()?;
|
|
CudaStorageSlice::F16(data)
|
|
}
|
|
DType::F32 => {
|
|
let data = self.alloc_zeros::<f32>(elem_count).w()?;
|
|
CudaStorageSlice::F32(data)
|
|
}
|
|
DType::F64 => {
|
|
let data = self.alloc_zeros::<f64>(elem_count).w()?;
|
|
CudaStorageSlice::F64(data)
|
|
}
|
|
};
|
|
Ok(CudaStorage {
|
|
slice,
|
|
device: self.clone(),
|
|
})
|
|
}
|
|
|
|
fn rand_uniform(&self, shape: &Shape, dtype: DType, lo: f64, up: f64) -> Result<CudaStorage> {
|
|
let elem_count = shape.elem_count();
|
|
let curand = self.curand.lock().unwrap();
|
|
let slice = match dtype {
|
|
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
|
|
// cudarc changes.
|
|
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
|
|
Err(CudaError::UnsupportedDtype {
|
|
dtype,
|
|
op: "rand_uniform",
|
|
})
|
|
.w()?
|
|
}
|
|
DType::F32 => {
|
|
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
|
|
curand.0.fill_with_uniform(&mut data).w()?;
|
|
CudaStorageSlice::F32(data)
|
|
}
|
|
DType::F64 => {
|
|
let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
|
|
curand.0.fill_with_uniform(&mut data).w()?;
|
|
CudaStorageSlice::F64(data)
|
|
}
|
|
};
|
|
let slice = if lo == 0. && up == 1.0 {
|
|
slice
|
|
} else {
|
|
use super::utils::Map1;
|
|
let layout = Layout::contiguous(shape);
|
|
super::Affine(up - lo, lo).map(&slice, self, &layout)?
|
|
};
|
|
Ok(CudaStorage {
|
|
slice,
|
|
device: self.clone(),
|
|
})
|
|
}
|
|
|
|
fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
|
|
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
|
|
// cudarc changes.
|
|
let elem_count = shape.elem_count();
|
|
let curand = self.curand.lock().unwrap();
|
|
// curand can only generate an odd number of values.
|
|
// https://github.com/huggingface/candle/issues/734
|
|
let elem_count_round = if elem_count % 2 == 1 {
|
|
elem_count + 1
|
|
} else {
|
|
elem_count
|
|
};
|
|
let slice = match dtype {
|
|
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
|
|
Err(CudaError::UnsupportedDtype {
|
|
dtype,
|
|
op: "rand_normal",
|
|
})
|
|
.w()?
|
|
}
|
|
DType::F32 => {
|
|
let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
|
|
curand
|
|
.0
|
|
.fill_with_normal(&mut data, mean as f32, std as f32)
|
|
.w()?;
|
|
CudaStorageSlice::F32(data)
|
|
}
|
|
DType::F64 => {
|
|
let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
|
|
curand.0.fill_with_normal(&mut data, mean, std).w()?;
|
|
CudaStorageSlice::F64(data)
|
|
}
|
|
};
|
|
Ok(CudaStorage {
|
|
slice,
|
|
device: self.clone(),
|
|
})
|
|
}
|
|
|
|
/// Allocates storage for `shape` with every element set to one, by delegating to
/// `const_impl` with the value `1.0`.
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
    const ONE: f64 = 1.;
    self.const_impl(ONE, shape, dtype)
}
|
|
|
|
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
|
|
let elem_count = shape.elem_count();
|
|
let slice = match dtype {
|
|
DType::U8 => {
|
|
let data = self.alloc::<u8>(elem_count).w()?;
|
|
CudaStorageSlice::U8(data)
|
|
}
|
|
DType::U32 => {
|
|
let data = self.alloc::<u32>(elem_count).w()?;
|
|
CudaStorageSlice::U32(data)
|
|
}
|
|
DType::I64 => {
|
|
let data = self.alloc::<i64>(elem_count).w()?;
|
|
CudaStorageSlice::I64(data)
|
|
}
|
|
DType::BF16 => {
|
|
let data = self.alloc::<bf16>(elem_count).w()?;
|
|
CudaStorageSlice::BF16(data)
|
|
}
|
|
DType::F16 => {
|
|
let data = self.alloc::<f16>(elem_count).w()?;
|
|
CudaStorageSlice::F16(data)
|
|
}
|
|
DType::F32 => {
|
|
let data = self.alloc::<f32>(elem_count).w()?;
|
|
CudaStorageSlice::F32(data)
|
|
}
|
|
DType::F64 => {
|
|
let data = self.alloc::<f64>(elem_count).w()?;
|
|
CudaStorageSlice::F64(data)
|
|
}
|
|
};
|
|
Ok(CudaStorage {
|
|
slice,
|
|
device: self.clone(),
|
|
})
|
|
}
|
|
|
|
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
|
|
let slice = match T::cpu_storage_ref(s) {
|
|
CpuStorageRef::U8(storage) => {
|
|
let data = self.htod_sync_copy(storage).w()?;
|
|
CudaStorageSlice::U8(data)
|
|
}
|
|
CpuStorageRef::U32(storage) => {
|
|
let data = self.htod_sync_copy(storage).w()?;
|
|
CudaStorageSlice::U32(data)
|
|
}
|
|
CpuStorageRef::I64(storage) => {
|
|
let data = self.htod_sync_copy(storage).w()?;
|
|
CudaStorageSlice::I64(data)
|
|
}
|
|
CpuStorageRef::BF16(storage) => {
|
|
let data = self.htod_sync_copy(storage).w()?;
|
|
CudaStorageSlice::BF16(data)
|
|
}
|
|
CpuStorageRef::F16(storage) => {
|
|
let data = self.htod_sync_copy(storage).w()?;
|
|
CudaStorageSlice::F16(data)
|
|
}
|
|
CpuStorageRef::F32(storage) => {
|
|
let data = self.htod_sync_copy(storage).w()?;
|
|
CudaStorageSlice::F32(data)
|
|
}
|
|
CpuStorageRef::F64(storage) => {
|
|
let data = self.htod_sync_copy(storage).w()?;
|
|
CudaStorageSlice::F64(data)
|
|
}
|
|
};
|
|
Ok(CudaStorage {
|
|
slice,
|
|
device: self.clone(),
|
|
})
|
|
}
|
|
|
|
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
|
|
let slice = match storage {
|
|
CpuStorage::U8(storage) => {
|
|
let data = self.htod_sync_copy(storage).w()?;
|
|
CudaStorageSlice::U8(data)
|
|
}
|
|
CpuStorage::U32(storage) => {
|
|
let data = self.htod_sync_copy(storage).w()?;
|
|
CudaStorageSlice::U32(data)
|
|
}
|
|
CpuStorage::I64(storage) => {
|
|
let data = self.htod_sync_copy(storage).w()?;
|
|
CudaStorageSlice::I64(data)
|
|
}
|
|
CpuStorage::BF16(storage) => {
|
|
let data = self.htod_sync_copy(storage).w()?;
|
|
CudaStorageSlice::BF16(data)
|
|
}
|
|
CpuStorage::F16(storage) => {
|
|
let data = self.htod_sync_copy(storage).w()?;
|
|
CudaStorageSlice::F16(data)
|
|
}
|
|
CpuStorage::F32(storage) => {
|
|
let data = self.htod_sync_copy(storage).w()?;
|
|
CudaStorageSlice::F32(data)
|
|
}
|
|
CpuStorage::F64(storage) => {
|
|
let data = self.htod_sync_copy(storage).w()?;
|
|
CudaStorageSlice::F64(data)
|
|
}
|
|
};
|
|
Ok(CudaStorage {
|
|
slice,
|
|
device: self.clone(),
|
|
})
|
|
}
|
|
|
|
fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
|
|
let slice = match storage {
|
|
CpuStorage::U8(storage) => {
|
|
let data = self.htod_copy(storage).w()?;
|
|
CudaStorageSlice::U8(data)
|
|
}
|
|
CpuStorage::U32(storage) => {
|
|
let data = self.htod_copy(storage).w()?;
|
|
CudaStorageSlice::U32(data)
|
|
}
|
|
CpuStorage::I64(storage) => {
|
|
let data = self.htod_copy(storage).w()?;
|
|
CudaStorageSlice::I64(data)
|
|
}
|
|
CpuStorage::BF16(storage) => {
|
|
let data = self.htod_copy(storage).w()?;
|
|
CudaStorageSlice::BF16(data)
|
|
}
|
|
CpuStorage::F16(storage) => {
|
|
let data = self.htod_copy(storage).w()?;
|
|
CudaStorageSlice::F16(data)
|
|
}
|
|
CpuStorage::F32(storage) => {
|
|
let data = self.htod_copy(storage).w()?;
|
|
CudaStorageSlice::F32(data)
|
|
}
|
|
CpuStorage::F64(storage) => {
|
|
let data = self.htod_copy(storage).w()?;
|
|
CudaStorageSlice::F64(data)
|
|
}
|
|
};
|
|
Ok(CudaStorage {
|
|
slice,
|
|
device: self.clone(),
|
|
})
|
|
}
|
|
|
|
fn synchronize(&self) -> Result<()> {
|
|
self.device.synchronize().map_err(crate::Error::wrap)?;
|
|
Ok(())
|
|
}
|
|
}
|