mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Move the data between the host and the device.
This commit is contained in:
80
src/cuda_backend.rs
Normal file
80
src/cuda_backend.rs
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
use crate::{CpuStorage, DType, Result, Shape};
|
||||||
|
use cudarc::driver::CudaSlice;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct CudaDevice(std::sync::Arc<cudarc::driver::CudaDevice>);
|
||||||
|
|
||||||
|
impl CudaDevice {
|
||||||
|
pub(crate) fn new(ordinal: usize) -> Result<Self> {
|
||||||
|
let device = cudarc::driver::CudaDevice::new(ordinal)?;
|
||||||
|
Ok(Self(device))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn ordinal(&self) -> usize {
|
||||||
|
self.0.ordinal()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||||
|
let elem_count = shape.elem_count();
|
||||||
|
match dtype {
|
||||||
|
DType::F32 => {
|
||||||
|
let data = self.0.alloc_zeros::<f32>(elem_count)?;
|
||||||
|
Ok(CudaStorage::F32(data))
|
||||||
|
}
|
||||||
|
DType::F64 => {
|
||||||
|
let data = self.0.alloc_zeros::<f64>(elem_count)?;
|
||||||
|
Ok(CudaStorage::F64(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn cuda_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
|
||||||
|
match storage {
|
||||||
|
CpuStorage::F32(storage) => {
|
||||||
|
let data = self.0.htod_sync_copy(storage)?;
|
||||||
|
Ok(CudaStorage::F32(data))
|
||||||
|
}
|
||||||
|
CpuStorage::F64(storage) => {
|
||||||
|
let data = self.0.htod_sync_copy(storage)?;
|
||||||
|
Ok(CudaStorage::F64(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum CudaStorage {
|
||||||
|
F32(CudaSlice<f32>),
|
||||||
|
F64(CudaSlice<f64>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CudaStorage {
|
||||||
|
pub fn dtype(&self) -> DType {
|
||||||
|
match self {
|
||||||
|
Self::F32(_) => DType::F32,
|
||||||
|
Self::F64(_) => DType::F64,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn device(&self) -> CudaDevice {
|
||||||
|
match self {
|
||||||
|
Self::F32(slice) => CudaDevice(slice.device()),
|
||||||
|
Self::F64(slice) => CudaDevice(slice.device()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||||
|
match self {
|
||||||
|
Self::F32(slice) => {
|
||||||
|
let dev = slice.device();
|
||||||
|
let cpu_storage = dev.dtoh_sync_copy(slice)?;
|
||||||
|
Ok(CpuStorage::F32(cpu_storage))
|
||||||
|
}
|
||||||
|
Self::F64(slice) => {
|
||||||
|
let dev = slice.device();
|
||||||
|
let cpu_storage = dev.dtoh_sync_copy(slice)?;
|
||||||
|
Ok(CpuStorage::F64(cpu_storage))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -11,7 +11,7 @@ pub enum DeviceLocation {
|
|||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum Device {
|
pub enum Device {
|
||||||
Cpu,
|
Cpu,
|
||||||
Cuda(std::sync::Arc<cudarc::driver::CudaDevice>),
|
Cuda(crate::CudaDevice),
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Should we back the cpu implementation using the NdArray crate or similar?
|
// TODO: Should we back the cpu implementation using the NdArray crate or similar?
|
||||||
@ -63,8 +63,7 @@ impl<S: crate::WithDType, const N: usize, const M: usize> NdArray for &[[S; N];
|
|||||||
|
|
||||||
impl Device {
|
impl Device {
|
||||||
pub fn new_cuda(ordinal: usize) -> Result<Self> {
|
pub fn new_cuda(ordinal: usize) -> Result<Self> {
|
||||||
let device = cudarc::driver::CudaDevice::new(ordinal)?;
|
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
|
||||||
Ok(Self::Cuda(device))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn location(&self) -> DeviceLocation {
|
pub fn location(&self) -> DeviceLocation {
|
||||||
@ -85,7 +84,8 @@ impl Device {
|
|||||||
Device::Cuda(device) => {
|
Device::Cuda(device) => {
|
||||||
// TODO: Instead of allocating memory on the host and transfering it,
|
// TODO: Instead of allocating memory on the host and transfering it,
|
||||||
// allocate some zeros on the device and use a shader to set them to 1.
|
// allocate some zeros on the device and use a shader to set them to 1.
|
||||||
let storage = device.htod_copy(vec![1f32; shape.elem_count()])?;
|
let storage = CpuStorage::ones_impl(shape, dtype);
|
||||||
|
let storage = device.cuda_from_cpu_storage(&storage)?;
|
||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -98,7 +98,7 @@ impl Device {
|
|||||||
Ok(Storage::Cpu(storage))
|
Ok(Storage::Cpu(storage))
|
||||||
}
|
}
|
||||||
Device::Cuda(device) => {
|
Device::Cuda(device) => {
|
||||||
let storage = device.alloc_zeros::<f32>(shape.elem_count())?;
|
let storage = device.zeros_impl(shape, dtype)?;
|
||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -108,16 +108,9 @@ impl Device {
|
|||||||
match self {
|
match self {
|
||||||
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
|
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
|
||||||
Device::Cuda(device) => {
|
Device::Cuda(device) => {
|
||||||
// TODO: Avoid making a copy through the cpu.
|
let storage = array.to_cpu_storage();
|
||||||
match array.to_cpu_storage() {
|
let storage = device.cuda_from_cpu_storage(&storage)?;
|
||||||
CpuStorage::F64(_) => {
|
Ok(Storage::Cuda(storage))
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
CpuStorage::F32(data) => {
|
|
||||||
let storage = device.htod_copy(data)?;
|
|
||||||
Ok(Storage::Cuda(storage))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
mod cpu_backend;
|
mod cpu_backend;
|
||||||
|
mod cuda_backend;
|
||||||
mod device;
|
mod device;
|
||||||
mod dtype;
|
mod dtype;
|
||||||
mod error;
|
mod error;
|
||||||
@ -9,6 +10,7 @@ mod strided_index;
|
|||||||
mod tensor;
|
mod tensor;
|
||||||
|
|
||||||
pub use cpu_backend::CpuStorage;
|
pub use cpu_backend::CpuStorage;
|
||||||
|
pub use cuda_backend::{CudaDevice, CudaStorage};
|
||||||
pub use device::{Device, DeviceLocation};
|
pub use device::{Device, DeviceLocation};
|
||||||
pub use dtype::{DType, WithDType};
|
pub use dtype::{DType, WithDType};
|
||||||
pub use error::{Error, Result};
|
pub use error::{Error, Result};
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
use crate::{CpuStorage, DType, Device, Error, Result, Shape};
|
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Result, Shape};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum Storage {
|
pub enum Storage {
|
||||||
Cpu(CpuStorage),
|
Cpu(CpuStorage),
|
||||||
Cuda(cudarc::driver::CudaSlice<f32>),
|
Cuda(CudaStorage),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) trait UnaryOp {
|
pub(crate) trait UnaryOp {
|
||||||
@ -100,14 +100,14 @@ impl Storage {
|
|||||||
pub fn device(&self) -> Device {
|
pub fn device(&self) -> Device {
|
||||||
match self {
|
match self {
|
||||||
Self::Cpu(_) => Device::Cpu,
|
Self::Cpu(_) => Device::Cpu,
|
||||||
Self::Cuda(slice) => Device::Cuda(slice.device()),
|
Self::Cuda(storage) => Device::Cuda(storage.device()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn dtype(&self) -> DType {
|
pub fn dtype(&self) -> DType {
|
||||||
match self {
|
match self {
|
||||||
Self::Cpu(storage) => storage.dtype(),
|
Self::Cpu(storage) => storage.dtype(),
|
||||||
Self::Cuda { .. } => todo!(),
|
Self::Cuda(storage) => storage.dtype(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -250,7 +250,12 @@ impl Tensor {
|
|||||||
let data = S::cpu_storage_as_slice(cpu_storage)?;
|
let data = S::cpu_storage_as_slice(cpu_storage)?;
|
||||||
Ok(self.strided_index().map(|i| data[i]).collect())
|
Ok(self.strided_index().map(|i| data[i]).collect())
|
||||||
}
|
}
|
||||||
Storage::Cuda(_) => todo!(),
|
Storage::Cuda(slice) => {
|
||||||
|
// TODO: Would it be possible to only fetch the necessary data?
|
||||||
|
let cpu_storage = slice.to_cpu_storage()?;
|
||||||
|
let data = S::cpu_storage_as_slice(&cpu_storage)?;
|
||||||
|
Ok(self.strided_index().map(|i| data[i]).collect())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user