mirror of https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00

Modular backends (#138)

* Add some trait to formalize backends.
* Use the generic backend trait.

candle-core/src/backend.rs  (new file, 71 lines)
@@ -0,0 +1,71 @@
+use crate::{CpuStorage, DType, Layout, Result, Shape};
+
+pub(crate) trait BackendStorage: Sized {
+    type Device: BackendDevice;
+
+    fn try_clone(&self, _: &Layout) -> Result<Self>;
+
+    fn dtype(&self) -> DType;
+
+    fn device(&self) -> &Self::Device;
+
+    fn to_cpu_storage(&self) -> Result<CpuStorage>;
+
+    fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self>;
+
+    fn elu(&self, _: &Layout, _: f64) -> Result<Self>;
+
+    fn sum(&self, _: &Layout, _: &[usize]) -> Result<Self>;
+
+    fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()>;
+
+    fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self>;
+
+    fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Layout) -> Result<Self>;
+
+    fn binary_impl<B: crate::op::BinaryOp>(&self, _: &Self, _: &Layout, _: &Layout)
+        -> Result<Self>;
+
+    fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
+
+    fn conv1d(
+        &self,
+        _l: &Layout,
+        _kernel: &Self,
+        _kernel_l: &Layout,
+        _params: &crate::conv::ParamsConv1D,
+    ) -> Result<Self>;
+
+    fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
+
+    fn matmul(
+        &self,
+        _: &Self,
+        _: (usize, usize, usize, usize),
+        _: &Layout,
+        _: &Layout,
+    ) -> Result<Self>;
+
+    fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>;
+}
+
+pub(crate) trait BackendDevice: Sized + std::fmt::Debug + Clone {
+    type Storage: BackendStorage;
+
+    // TODO: Make the usize generic and part of a generic DeviceLocation.
+    fn new(_: usize) -> Result<Self>;
+
+    fn location(&self) -> crate::DeviceLocation;
+
+    fn same_device(&self, _: &Self) -> bool;
+
+    fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
+
+    fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
+
+    fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
+
+    fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
+
+    fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
+}

candle-core/src/conv.rs
@@ -1,5 +1,5 @@
 #[derive(Debug, Clone, PartialEq, Eq)]
-pub(crate) struct ParamsConv1D {
+pub struct ParamsConv1D {
     pub(crate) b_size: Option<usize>,
     // Maybe we should have a version without l_in as this bit depends on the input and not only on
     // the weights.

(File diff suppressed because it is too large.)

candle-core/src/device.rs
@@ -1,3 +1,4 @@
+use crate::backend::BackendDevice;
 use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
 
 /// A `DeviceLocation` represents a physical device whereas multiple `Device`
@@ -85,10 +86,10 @@ impl Device {
         Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
     }
 
-    pub fn same_id(&self, rhs: &Self) -> bool {
+    pub fn same_device(&self, rhs: &Self) -> bool {
         match (self, rhs) {
             (Self::Cpu, Self::Cpu) => true,
-            (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_id(rhs),
+            (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
             _ => false,
         }
     }
@@ -96,9 +97,7 @@ impl Device {
     pub fn location(&self) -> DeviceLocation {
         match self {
             Self::Cpu => DeviceLocation::Cpu,
-            Self::Cuda(device) => DeviceLocation::Cuda {
-                gpu_id: device.ordinal(),
-            },
+            Self::Cuda(device) => device.location(),
         }
     }
 
@@ -178,7 +177,7 @@ impl Device {
             Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
             Device::Cuda(device) => {
                 let storage = array.to_cpu_storage();
-                let storage = device.cuda_from_cpu_storage(&storage)?;
+                let storage = device.storage_from_cpu_storage(&storage)?;
                 Ok(Storage::Cuda(storage))
             }
         }
@@ -189,7 +188,7 @@ impl Device {
             Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
             Device::Cuda(device) => {
                 let storage = S::to_cpu_storage_owned(data);
-                let storage = device.cuda_from_cpu_storage(&storage)?;
+                let storage = device.storage_from_cpu_storage(&storage)?;
                 Ok(Storage::Cuda(storage))
             }
         }
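
The `same_id` → `same_device` rename makes the semantics clearer: the check is about physical placement, not handle identity. A hypothetical usage sketch, assuming the crate is imported as `candle_core` and that `Device::new_cuda` wraps the constructor shown above:

```rust
use candle_core::Device;

fn main() -> candle_core::Result<()> {
    let cpu = Device::Cpu;
    // Two Cpu handles refer to the same physical device.
    assert!(cpu.same_device(&Device::Cpu));

    // A GPU (only constructible with the cuda feature) is a different placement.
    if let Ok(gpu) = Device::new_cuda(0) {
        assert!(!cpu.same_device(&gpu));
    }
    Ok(())
}
```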
candle-core/src/dummy_cuda_backend.rs
@@ -1,98 +1,62 @@
 #![allow(dead_code)]
 use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
 
-#[derive(thiserror::Error, Debug)]
-pub enum DummyError {}
-pub type CudaError = DummyError;
-
 #[derive(Debug, Clone)]
 pub struct CudaDevice;
 
+#[derive(Debug)]
+pub struct CudaStorage;
+
 macro_rules! fail {
     () => {
         unimplemented!("cuda support has not been enabled")
     };
 }
 
-impl CudaDevice {
-    pub(crate) fn new(_: usize) -> Result<Self> {
+impl crate::backend::BackendStorage for CudaStorage {
+    type Device = CudaDevice;
+
+    fn try_clone(&self, _: &Layout) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn same_id(&self, _: &Self) -> bool {
-        true
-    }
-
-    pub(crate) fn ordinal(&self) -> usize {
+    fn dtype(&self) -> DType {
        fail!()
     }
 
-    pub(crate) fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<CudaStorage> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-
-    pub(crate) fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<CudaStorage> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-
-    pub(crate) fn cuda_from_cpu_storage(&self, _: &CpuStorage) -> Result<CudaStorage> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-
-    pub(crate) fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<CudaStorage> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-
-    pub(crate) fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<CudaStorage> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-}
-
-#[derive(Debug)]
-pub struct CudaStorage;
-
-impl CudaStorage {
-    pub fn try_clone(&self, _: &Layout) -> Result<Self> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-
-    pub fn dtype(&self) -> DType {
+    fn device(&self) -> &Self::Device {
         fail!()
     }
 
-    pub fn device(&self) -> &CudaDevice {
-        fail!()
-    }
-
-    pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
+    fn to_cpu_storage(&self) -> Result<CpuStorage> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
+    fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
+    fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn sum(&self, _: &Layout, _: &[usize]) -> Result<Self> {
+    fn sum(&self, _: &Layout, _: &[usize]) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
+    fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
+    fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Layout) -> Result<Self> {
+    fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Layout) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn binary_impl<B: crate::op::BinaryOp>(
+    fn binary_impl<B: crate::op::BinaryOp>(
         &self,
         _: &Self,
         _: &Layout,
@@ -101,32 +65,25 @@ impl CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn where_cond(
-        &self,
-        _: &Layout,
-        _: &Self,
-        _: &Layout,
-        _: &Self,
-        _: &Layout,
-    ) -> Result<Self> {
+    fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn conv1d(
+    fn conv1d(
         &self,
-        _l: &Layout,
-        _kernel: &Self,
-        _kernel_l: &Layout,
-        _params: &crate::conv::ParamsConv1D,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &crate::conv::ParamsConv1D,
     ) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
+    fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn matmul(
+    fn matmul(
        &self,
         _: &Self,
         _: (usize, usize, usize, usize),
@@ -136,7 +93,42 @@ impl CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
+    fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 }
+
+impl crate::backend::BackendDevice for CudaDevice {
+    type Storage = CudaStorage;
+    fn new(_: usize) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
+    fn location(&self) -> crate::DeviceLocation {
+        fail!()
+    }
+
+    fn same_device(&self, _: &Self) -> bool {
+        fail!()
+    }
+
+    fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
+    fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
+    fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
+    fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
+    fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+}
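
The dummy backend exists so that code mentioning `CudaDevice`/`CudaStorage` still compiles when the `cuda` feature is off; every entry point fails at runtime instead. A toy sketch of that pattern (the names here are made up, not candle's):

```rust
// A stand-in device whose methods all fail, mirroring the fail!() trick above.
#[derive(Debug, Clone)]
pub struct GpuDevice;

macro_rules! fail {
    () => {
        unimplemented!("gpu support has not been enabled")
    };
}

impl GpuDevice {
    pub fn ordinal(&self) -> usize {
        fail!()
    }
}

fn main() {
    let dev = GpuDevice;
    // Merely naming the type compiles fine; calling into it panics.
    let result = std::panic::catch_unwind(|| dev.ordinal());
    assert!(result.is_err());
}
```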

candle-core/src/error.rs
@@ -100,7 +100,7 @@ pub enum Error {
     },
 
     #[error(transparent)]
-    Cuda(#[from] crate::CudaError),
+    Cuda(Box<dyn std::error::Error + Send + Sync>),
 
     #[error(transparent)]
     TryFromIntError(#[from] core::num::TryFromIntError),
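
Boxing the payload as `Box<dyn std::error::Error + Send + Sync>` removes the core error type's compile-time dependency on a concrete `CudaError`, which is what lets `CudaError` drop out of the re-exports below. A minimal sketch of the decoupling; `MyBackendError` is hypothetical:

```rust
use std::fmt;

#[derive(Debug)]
enum Error {
    // Any backend error can be stored without this enum naming its type.
    Cuda(Box<dyn std::error::Error + Send + Sync>),
}

#[derive(Debug)]
struct MyBackendError(String);

impl fmt::Display for MyBackendError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "backend error: {}", self.0)
    }
}

impl std::error::Error for MyBackendError {}

fn main() {
    let e = Error::Cuda(Box::new(MyBackendError("device lost".into())));
    println!("{e:?}");
}
```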

candle-core/src/lib.rs
@@ -33,6 +33,7 @@
 //!
 //! Rust is cool, and a lot of the HF ecosystem already has Rust crates [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers)
 
+mod backend;
 mod backprop;
 mod conv;
 mod cpu_backend;
@@ -68,10 +69,10 @@ use strided_index::StridedIndex;
 pub use tensor::{Tensor, TensorId};
 
 #[cfg(feature = "cuda")]
-pub use cuda_backend::{CudaDevice, CudaError, CudaStorage};
+pub use cuda_backend::{CudaDevice, CudaStorage};
 
 #[cfg(not(feature = "cuda"))]
-pub use dummy_cuda_backend::{CudaDevice, CudaError, CudaStorage};
+pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
 
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
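
A toy illustration of the cfg-gated re-export trick used here (the `gpu` feature name is made up): the same path resolves to either the real or the dummy module, so the rest of the crate is written once against a single set of names.

```rust
#[allow(dead_code)]
mod real_backend {
    #[derive(Debug)]
    pub struct Device;
}

#[allow(dead_code)]
mod dummy_backend {
    #[derive(Debug)]
    pub struct Device;
}

// Exactly one of these aliases is compiled in.
#[cfg(feature = "gpu")]
use self::real_backend::Device;
#[cfg(not(feature = "gpu"))]
use self::dummy_backend::Device;

fn main() {
    println!("{:?}", Device); // whichever module the build selected
}
```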

candle-core/src/storage.rs
@@ -1,3 +1,4 @@
+use crate::backend::BackendStorage;
 use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
 
 // We do not want to implement Clone on Storage as cloning may fail because of

candle-core/src/tensor.rs
@@ -1,3 +1,4 @@
+use crate::backend::{BackendDevice, BackendStorage};
 use crate::shape::Dim;
 use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
 use std::sync::Arc;
@@ -963,19 +964,19 @@ impl Tensor {
 
     /// If the target device is the same as the tensor device, only a shallow copy is performed.
     pub fn to_device(&self, device: &Device) -> Result<Tensor> {
-        if self.device().same_id(device) {
+        if self.device().same_device(device) {
             Ok(self.clone())
         } else {
             let storage = match (self.storage.as_ref(), device) {
                 (Storage::Cpu(storage), Device::Cuda(cuda)) => {
-                    Storage::Cuda(cuda.cuda_from_cpu_storage(storage)?)
+                    Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
                 }
                 (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
                 (Storage::Cuda(storage), Device::Cuda(cuda)) => {
                     // TODO: Avoid passing through the cpu storage here, especially if the gpu ids
                     // are the same.
                     let cpu_storage = storage.to_cpu_storage()?;
-                    Storage::Cuda(cuda.cuda_from_cpu_storage(&cpu_storage)?)
+                    Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
                 }
                 (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
             };
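
A hypothetical usage sketch of `to_device` after the rename, assuming the crate is imported as `candle_core`: moving a tensor to the device it already lives on is only a shallow copy, per the doc comment above.

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    // Same placement: `to_device` takes the shallow-copy fast path.
    let moved = t.to_device(&Device::Cpu)?;
    assert!(t.device().same_device(moved.device()));
    Ok(())
}
```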