mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add the cublas handle to the cuda device.
This commit is contained in:
@ -1,6 +1,7 @@
|
|||||||
use crate::{CpuStorage, DType, Shape};
|
use crate::{CpuStorage, DType, Shape};
|
||||||
use candle_kernels as kernels;
|
use candle_kernels as kernels;
|
||||||
use cudarc::driver::{CudaFunction, CudaSlice, LaunchAsync, LaunchConfig};
|
use cudarc::driver::{CudaFunction, CudaSlice, LaunchAsync, LaunchConfig};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
/// cudarc related errors
|
/// cudarc related errors
|
||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
@ -11,6 +12,9 @@ pub enum CudaError {
|
|||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Compiler(#[from] cudarc::nvrtc::CompileError),
|
Compiler(#[from] cudarc::nvrtc::CompileError),
|
||||||
|
|
||||||
|
#[error(transparent)]
|
||||||
|
Cublas(#[from] cudarc::cublas::result::CublasError),
|
||||||
|
|
||||||
#[error("{op} only supports contiguous tensors")]
|
#[error("{op} only supports contiguous tensors")]
|
||||||
RequiresContiguous { op: &'static str },
|
RequiresContiguous { op: &'static str },
|
||||||
|
|
||||||
@ -24,54 +28,77 @@ pub enum CudaError {
|
|||||||
type Result<T> = std::result::Result<T, CudaError>;
|
type Result<T> = std::result::Result<T, CudaError>;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct CudaDevice(std::sync::Arc<cudarc::driver::CudaDevice>);
|
pub struct CudaDevice {
|
||||||
|
device: Arc<cudarc::driver::CudaDevice>,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
blas: Arc<cudarc::cublas::CudaBlas>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::ops::Deref for CudaDevice {
|
||||||
|
type Target = Arc<cudarc::driver::CudaDevice>;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.device
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl CudaDevice {
|
impl CudaDevice {
|
||||||
pub(crate) fn new(ordinal: usize) -> Result<Self> {
|
pub(crate) fn new(ordinal: usize) -> Result<Self> {
|
||||||
let device = cudarc::driver::CudaDevice::new(ordinal)?;
|
let device = cudarc::driver::CudaDevice::new(ordinal)?;
|
||||||
Ok(Self(device))
|
let blas = cudarc::cublas::CudaBlas::new(device.clone())?;
|
||||||
|
Ok(Self {
|
||||||
|
device,
|
||||||
|
blas: Arc::new(blas),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn ordinal(&self) -> usize {
|
pub(crate) fn ordinal(&self) -> usize {
|
||||||
self.0.ordinal()
|
self.device.ordinal()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
pub(crate) fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||||
let elem_count = shape.elem_count();
|
let elem_count = shape.elem_count();
|
||||||
match dtype {
|
let slice = match dtype {
|
||||||
DType::F32 => {
|
DType::F32 => {
|
||||||
let data = self.0.alloc_zeros::<f32>(elem_count)?;
|
let data = self.alloc_zeros::<f32>(elem_count)?;
|
||||||
Ok(CudaStorage::F32(data))
|
CudaStorageSlice::F32(data)
|
||||||
}
|
}
|
||||||
DType::F64 => {
|
DType::F64 => {
|
||||||
let data = self.0.alloc_zeros::<f64>(elem_count)?;
|
let data = self.alloc_zeros::<f64>(elem_count)?;
|
||||||
Ok(CudaStorage::F64(data))
|
CudaStorageSlice::F64(data)
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
Ok(CudaStorage {
|
||||||
|
slice,
|
||||||
|
device: self.clone(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
pub(crate) fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||||
let elem_count = shape.elem_count();
|
let elem_count = shape.elem_count();
|
||||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||||
let dev = &self.0;
|
let slice = match dtype {
|
||||||
match dtype {
|
|
||||||
DType::F32 => {
|
DType::F32 => {
|
||||||
// SAFETY: Set later by running the fill kernel.
|
// SAFETY: Set later by running the fill kernel.
|
||||||
let data = unsafe { dev.alloc::<f32>(elem_count) }?;
|
let data = unsafe { self.alloc::<f32>(elem_count) }?;
|
||||||
let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
|
let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
|
||||||
let params = (&data, v as f32, elem_count);
|
let params = (&data, v as f32, elem_count);
|
||||||
unsafe { func.launch(cfg, params) }?;
|
unsafe { func.launch(cfg, params) }?;
|
||||||
Ok(CudaStorage::F32(data))
|
CudaStorageSlice::F32(data)
|
||||||
}
|
}
|
||||||
DType::F64 => {
|
DType::F64 => {
|
||||||
// SAFETY: Set later by running the fill kernel.
|
// SAFETY: Set later by running the fill kernel.
|
||||||
let data = unsafe { dev.alloc::<f64>(elem_count) }?;
|
let data = unsafe { self.alloc::<f64>(elem_count) }?;
|
||||||
let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
|
let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
|
||||||
let params = (&data, v, elem_count);
|
let params = (&data, v, elem_count);
|
||||||
unsafe { func.launch(cfg, params) }?;
|
unsafe { func.launch(cfg, params) }?;
|
||||||
Ok(CudaStorage::F64(data))
|
CudaStorageSlice::F64(data)
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
Ok(CudaStorage {
|
||||||
|
slice,
|
||||||
|
device: self.clone(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
pub(crate) fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||||
@ -79,16 +106,20 @@ impl CudaDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn cuda_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
|
pub(crate) fn cuda_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
|
||||||
match storage {
|
let slice = match storage {
|
||||||
CpuStorage::F32(storage) => {
|
CpuStorage::F32(storage) => {
|
||||||
let data = self.0.htod_sync_copy(storage)?;
|
let data = self.htod_sync_copy(storage)?;
|
||||||
Ok(CudaStorage::F32(data))
|
CudaStorageSlice::F32(data)
|
||||||
}
|
}
|
||||||
CpuStorage::F64(storage) => {
|
CpuStorage::F64(storage) => {
|
||||||
let data = self.0.htod_sync_copy(storage)?;
|
let data = self.htod_sync_copy(storage)?;
|
||||||
Ok(CudaStorage::F64(data))
|
CudaStorageSlice::F64(data)
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
Ok(CudaStorage {
|
||||||
|
slice,
|
||||||
|
device: self.clone(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_or_load_func(
|
fn get_or_load_func(
|
||||||
@ -96,11 +127,10 @@ impl CudaDevice {
|
|||||||
module_name: &'static str,
|
module_name: &'static str,
|
||||||
ptx: &'static str,
|
ptx: &'static str,
|
||||||
) -> Result<CudaFunction> {
|
) -> Result<CudaFunction> {
|
||||||
let dev = &self.0;
|
if !self.has_func(module_name, module_name) {
|
||||||
if !dev.has_func(module_name, module_name) {
|
self.load_ptx(ptx.into(), module_name, &[module_name])?;
|
||||||
dev.load_ptx(ptx.into(), module_name, &[module_name])?;
|
|
||||||
}
|
}
|
||||||
dev.get_func(module_name, module_name)
|
self.get_func(module_name, module_name)
|
||||||
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
|
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
|
||||||
// able to only build the error value if needed.
|
// able to only build the error value if needed.
|
||||||
.ok_or(CudaError::MissingKernel { module_name })
|
.ok_or(CudaError::MissingKernel { module_name })
|
||||||
@ -108,31 +138,36 @@ impl CudaDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum CudaStorage {
|
enum CudaStorageSlice {
|
||||||
F32(CudaSlice<f32>),
|
F32(CudaSlice<f32>),
|
||||||
F64(CudaSlice<f64>),
|
F64(CudaSlice<f64>),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct CudaStorage {
|
||||||
|
slice: CudaStorageSlice,
|
||||||
|
device: CudaDevice,
|
||||||
|
}
|
||||||
|
|
||||||
impl CudaStorage {
|
impl CudaStorage {
|
||||||
pub fn try_clone(&self) -> Result<Self> {
|
pub fn try_clone(&self) -> Result<Self> {
|
||||||
match self {
|
let slice = match &self.slice {
|
||||||
Self::F32(slice) => Ok(Self::F32(slice.try_clone()?)),
|
CudaStorageSlice::F32(slice) => CudaStorageSlice::F32(slice.try_clone()?),
|
||||||
Self::F64(slice) => Ok(Self::F64(slice.try_clone()?)),
|
CudaStorageSlice::F64(slice) => CudaStorageSlice::F64(slice.try_clone()?),
|
||||||
}
|
};
|
||||||
|
let device = self.device.clone();
|
||||||
|
Ok(Self { slice, device })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn dtype(&self) -> DType {
|
pub fn dtype(&self) -> DType {
|
||||||
match self {
|
match self.slice {
|
||||||
Self::F32(_) => DType::F32,
|
CudaStorageSlice::F32(_) => DType::F32,
|
||||||
Self::F64(_) => DType::F64,
|
CudaStorageSlice::F64(_) => DType::F64,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn device(&self) -> CudaDevice {
|
pub fn device(&self) -> &CudaDevice {
|
||||||
match self {
|
&self.device
|
||||||
Self::F32(slice) => CudaDevice(slice.device()),
|
|
||||||
Self::F64(slice) => CudaDevice(slice.device()),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn affine_impl(
|
pub(crate) fn affine_impl(
|
||||||
@ -146,27 +181,29 @@ impl CudaStorage {
|
|||||||
let el_count = shape.elem_count();
|
let el_count = shape.elem_count();
|
||||||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||||
let dev = self.device();
|
let dev = self.device();
|
||||||
let ds = dev.0.htod_copy([dims, stride].concat())?;
|
let ds = dev.htod_copy([dims, stride].concat())?;
|
||||||
match self {
|
let slice = match &self.slice {
|
||||||
Self::F32(arg) => {
|
CudaStorageSlice::F32(arg) => {
|
||||||
let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
|
let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let out = unsafe { dev.0.alloc::<f32>(el_count) }?;
|
let out = unsafe { dev.alloc::<f32>(el_count) }?;
|
||||||
let params = (el_count, dims.len(), &ds, arg, &out, mul as f32, add as f32);
|
let params = (el_count, dims.len(), &ds, arg, &out, mul as f32, add as f32);
|
||||||
// SAFETY: ffi.
|
// SAFETY: ffi.
|
||||||
unsafe { func.launch(cfg, params) }?;
|
unsafe { func.launch(cfg, params) }?;
|
||||||
Ok(Self::F32(out))
|
CudaStorageSlice::F32(out)
|
||||||
}
|
}
|
||||||
Self::F64(arg) => {
|
CudaStorageSlice::F64(arg) => {
|
||||||
let func = dev.get_or_load_func("affine_f64", kernels::AFFINE)?;
|
let func = dev.get_or_load_func("affine_f64", kernels::AFFINE)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let out = unsafe { dev.0.alloc::<f64>(el_count) }?;
|
let out = unsafe { dev.alloc::<f64>(el_count) }?;
|
||||||
let params = (el_count, dims.len(), &ds, arg, &out, mul, add);
|
let params = (el_count, dims.len(), &ds, arg, &out, mul, add);
|
||||||
// SAFETY: ffi.
|
// SAFETY: ffi.
|
||||||
unsafe { func.launch(cfg, params) }?;
|
unsafe { func.launch(cfg, params) }?;
|
||||||
Ok(Self::F64(out))
|
CudaStorageSlice::F64(out)
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
let device = dev.clone();
|
||||||
|
Ok(Self { slice, device })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn unary_impl<U: crate::op::UnaryOp>(
|
pub(crate) fn unary_impl<U: crate::op::UnaryOp>(
|
||||||
@ -177,28 +214,30 @@ impl CudaStorage {
|
|||||||
let dims = shape.dims();
|
let dims = shape.dims();
|
||||||
let el_count = shape.elem_count();
|
let el_count = shape.elem_count();
|
||||||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||||
let dev = self.device();
|
let dev = &self.device;
|
||||||
let ds = dev.0.htod_copy([dims, stride].concat())?;
|
let ds = dev.htod_copy([dims, stride].concat())?;
|
||||||
match self {
|
let slice = match &self.slice {
|
||||||
Self::F32(arg) => {
|
CudaStorageSlice::F32(arg) => {
|
||||||
let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?;
|
let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let out = unsafe { dev.0.alloc::<f32>(el_count) }?;
|
let out = unsafe { dev.alloc::<f32>(el_count) }?;
|
||||||
let params = (el_count, dims.len(), &ds, arg, &out);
|
let params = (el_count, dims.len(), &ds, arg, &out);
|
||||||
// SAFETY: ffi.
|
// SAFETY: ffi.
|
||||||
unsafe { func.launch(cfg, params) }?;
|
unsafe { func.launch(cfg, params) }?;
|
||||||
Ok(Self::F32(out))
|
CudaStorageSlice::F32(out)
|
||||||
}
|
}
|
||||||
Self::F64(arg) => {
|
CudaStorageSlice::F64(arg) => {
|
||||||
let func = dev.get_or_load_func(U::KERNEL_F64, kernels::UNARY)?;
|
let func = dev.get_or_load_func(U::KERNEL_F64, kernels::UNARY)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let out = unsafe { dev.0.alloc::<f64>(el_count) }?;
|
let out = unsafe { dev.alloc::<f64>(el_count) }?;
|
||||||
let params = (el_count, dims.len(), &ds, arg, &out);
|
let params = (el_count, dims.len(), &ds, arg, &out);
|
||||||
// SAFETY: ffi.
|
// SAFETY: ffi.
|
||||||
unsafe { func.launch(cfg, params) }?;
|
unsafe { func.launch(cfg, params) }?;
|
||||||
Ok(Self::F64(out))
|
CudaStorageSlice::F64(out)
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
let device = dev.clone();
|
||||||
|
Ok(Self { slice, device })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn binary_impl<B: crate::op::BinaryOp>(
|
pub(crate) fn binary_impl<B: crate::op::BinaryOp>(
|
||||||
@ -212,39 +251,41 @@ impl CudaStorage {
|
|||||||
let dims = shape.dims();
|
let dims = shape.dims();
|
||||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||||
let dev = self.device();
|
let dev = self.device();
|
||||||
let dims_and_strides = dev.0.htod_copy([dims, lhs_stride, rhs_stride].concat())?;
|
let dims_and_strides = dev.htod_copy([dims, lhs_stride, rhs_stride].concat())?;
|
||||||
match (self, rhs) {
|
let slice = match (&self.slice, &rhs.slice) {
|
||||||
(Self::F32(lhs), Self::F32(rhs)) => {
|
(CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
|
||||||
let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?;
|
let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let out = unsafe { dev.0.alloc::<f32>(elem_count) }?;
|
let out = unsafe { dev.alloc::<f32>(elem_count) }?;
|
||||||
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
|
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
|
||||||
// SAFETY: ffi
|
// SAFETY: ffi
|
||||||
unsafe { func.launch(cfg, params) }?;
|
unsafe { func.launch(cfg, params) }?;
|
||||||
Ok(Self::F32(out))
|
CudaStorageSlice::F32(out)
|
||||||
}
|
}
|
||||||
(Self::F64(lhs), Self::F64(rhs)) => {
|
(CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let func = dev.get_or_load_func(B::KERNEL_F64, kernels::BINARY)?;
|
let func = dev.get_or_load_func(B::KERNEL_F64, kernels::BINARY)?;
|
||||||
let out = unsafe { dev.0.alloc::<f64>(elem_count) }?;
|
let out = unsafe { dev.alloc::<f64>(elem_count) }?;
|
||||||
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
|
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
|
||||||
// SAFETY: ffi
|
// SAFETY: ffi
|
||||||
unsafe { func.launch(cfg, params) }?;
|
unsafe { func.launch(cfg, params) }?;
|
||||||
Ok(Self::F64(out))
|
CudaStorageSlice::F64(out)
|
||||||
}
|
}
|
||||||
// The dtypes should have been checked at this point so this is an internal error.
|
// The dtypes should have been checked at this point so this is an internal error.
|
||||||
_ => Err(CudaError::InternalError("dtype mismatch in binary op")),
|
_ => return Err(CudaError::InternalError("dtype mismatch in binary op")),
|
||||||
}
|
};
|
||||||
|
let device = dev.clone();
|
||||||
|
Ok(Self { slice, device })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||||
match self {
|
match &self.slice {
|
||||||
Self::F32(slice) => {
|
CudaStorageSlice::F32(slice) => {
|
||||||
let dev = slice.device();
|
let dev = slice.device();
|
||||||
let cpu_storage = dev.dtoh_sync_copy(slice)?;
|
let cpu_storage = dev.dtoh_sync_copy(slice)?;
|
||||||
Ok(CpuStorage::F32(cpu_storage))
|
Ok(CpuStorage::F32(cpu_storage))
|
||||||
}
|
}
|
||||||
Self::F64(slice) => {
|
CudaStorageSlice::F64(slice) => {
|
||||||
let dev = slice.device();
|
let dev = slice.device();
|
||||||
let cpu_storage = dev.dtoh_sync_copy(slice)?;
|
let cpu_storage = dev.dtoh_sync_copy(slice)?;
|
||||||
Ok(CpuStorage::F64(cpu_storage))
|
Ok(CpuStorage::F64(cpu_storage))
|
||||||
|
@ -46,7 +46,7 @@ impl CudaStorage {
|
|||||||
fail!()
|
fail!()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn device(&self) -> CudaDevice {
|
pub fn device(&self) -> &CudaDevice {
|
||||||
fail!()
|
fail!()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -22,7 +22,7 @@ impl Storage {
|
|||||||
pub fn device(&self) -> Device {
|
pub fn device(&self) -> Device {
|
||||||
match self {
|
match self {
|
||||||
Self::Cpu(_) => Device::Cpu,
|
Self::Cpu(_) => Device::Cpu,
|
||||||
Self::Cuda(storage) => Device::Cuda(storage.device()),
|
Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user