mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Implement the backend trait for the cpu backend. (#143)
This commit is contained in:
@ -9,6 +9,7 @@ pub(crate) trait BackendStorage: Sized {
|
||||
|
||||
fn device(&self) -> &Self::Device;
|
||||
|
||||
// Maybe this should return a Cow instead so that no copy is done on the cpu case.
|
||||
fn to_cpu_storage(&self) -> Result<CpuStorage>;
|
||||
|
||||
fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self>;
|
||||
|
@ -1,3 +1,4 @@
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::op::{BinaryOp, UnaryOp};
|
||||
use crate::{DType, Error, Layout, Result, Shape, WithDType};
|
||||
use half::{bf16, f16};
|
||||
@ -14,6 +15,9 @@ pub enum CpuStorage {
|
||||
F64(Vec<f64>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CpuDevice;
|
||||
|
||||
trait Map1 {
|
||||
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;
|
||||
|
||||
@ -519,7 +523,15 @@ fn elu<T: num_traits::Float>(v: T, alpha: T) -> T {
|
||||
}
|
||||
|
||||
impl CpuStorage {
|
||||
pub fn dtype(&self) -> DType {
|
||||
pub fn as_slice<D: WithDType>(&self) -> Result<&[D]> {
|
||||
D::cpu_storage_as_slice(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendStorage for CpuStorage {
|
||||
type Device = CpuDevice;
|
||||
|
||||
fn dtype(&self) -> DType {
|
||||
match self {
|
||||
Self::U8(_) => DType::U8,
|
||||
Self::U32(_) => DType::U32,
|
||||
@ -530,11 +542,7 @@ impl CpuStorage {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_slice<D: WithDType>(&self) -> Result<&[D]> {
|
||||
D::cpu_storage_as_slice(self)
|
||||
}
|
||||
|
||||
pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
||||
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
||||
// TODO: find a way around the quadratic number of cases below.
|
||||
match (self, dtype) {
|
||||
(Self::U8(storage), DType::BF16) => {
|
||||
@ -684,7 +692,7 @@ impl CpuStorage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
||||
fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
||||
let src_dims = layout.dims();
|
||||
let mut dst_dims = src_dims.to_vec();
|
||||
for &sum_dim in sum_dims.iter() {
|
||||
@ -706,7 +714,7 @@ impl CpuStorage {
|
||||
.map(self, layout)
|
||||
}
|
||||
|
||||
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
|
||||
fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
|
||||
// [self] stores data in a contiguous way starting at offset 0.
|
||||
match self {
|
||||
Self::BF16(s) => divide_by_sum_over_dim(s, shape, dim),
|
||||
@ -717,11 +725,11 @@ impl CpuStorage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
||||
fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
||||
Affine(mul, add).map(self, layout)
|
||||
}
|
||||
|
||||
pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
|
||||
fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
|
||||
// TODO: Have some generic map for functions that apply on num_traits::Float elements.
|
||||
match self {
|
||||
Self::BF16(storage) => {
|
||||
@ -745,7 +753,7 @@ impl CpuStorage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn unary_impl<B: UnaryOp>(&self, layout: &Layout) -> Result<Self> {
|
||||
fn unary_impl<B: UnaryOp>(&self, layout: &Layout) -> Result<Self> {
|
||||
match self {
|
||||
Self::BF16(storage) => {
|
||||
let data = unary_map(storage, layout, B::bf16);
|
||||
@ -774,12 +782,7 @@ impl CpuStorage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn binary_impl<B: BinaryOp>(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
fn binary_impl<B: BinaryOp>(&self, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
|
||||
match (self, rhs) {
|
||||
(Self::BF16(lhs), Self::BF16(rhs)) => {
|
||||
let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::bf16);
|
||||
@ -816,12 +819,7 @@ impl CpuStorage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn copy_strided_src(
|
||||
&self,
|
||||
dst: &mut Self,
|
||||
dst_offset: usize,
|
||||
src_l: &Layout,
|
||||
) -> Result<()> {
|
||||
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
||||
match (self, dst) {
|
||||
(Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
|
||||
(Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
|
||||
@ -841,7 +839,7 @@ impl CpuStorage {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn where_cond(
|
||||
fn where_cond(
|
||||
&self,
|
||||
layout: &Layout,
|
||||
t: &Self,
|
||||
@ -854,7 +852,7 @@ impl CpuStorage {
|
||||
WCond(pred, layout).map(t, t_l, f, f_l)
|
||||
}
|
||||
|
||||
pub(crate) fn conv1d(
|
||||
fn conv1d(
|
||||
&self,
|
||||
l: &Layout,
|
||||
kernel: &Self,
|
||||
@ -864,7 +862,7 @@ impl CpuStorage {
|
||||
Conv1D(params).map(self, l, kernel, kernel_l)
|
||||
}
|
||||
|
||||
pub(crate) fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
||||
fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
||||
let ids = self.as_slice::<u32>()?;
|
||||
let (vocab_size, hidden_size) = rhs_l.shape().r2()?;
|
||||
Embedding {
|
||||
@ -876,7 +874,7 @@ impl CpuStorage {
|
||||
.map(rhs, rhs_l)
|
||||
}
|
||||
|
||||
pub(crate) fn matmul(
|
||||
fn matmul(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
bmnk: (usize, usize, usize, usize),
|
||||
@ -886,7 +884,39 @@ impl CpuStorage {
|
||||
MatMul(bmnk).map(self, lhs_l, rhs, rhs_l)
|
||||
}
|
||||
|
||||
pub(crate) fn rand_uniform(shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<Self> {
|
||||
fn device(&self) -> &Self::Device {
|
||||
&CpuDevice
|
||||
}
|
||||
|
||||
fn try_clone(&self, _: &Layout) -> Result<Self> {
|
||||
Ok(self.clone())
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||
Ok(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendDevice for CpuDevice {
|
||||
type Storage = CpuStorage;
|
||||
|
||||
fn location(&self) -> crate::DeviceLocation {
|
||||
crate::DeviceLocation::Cpu
|
||||
}
|
||||
|
||||
fn same_device(&self, _: &Self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn storage_from_cpu_storage(&self, s: &CpuStorage) -> Result<Self::Storage> {
|
||||
Ok(s.clone())
|
||||
}
|
||||
|
||||
fn new(_: usize) -> Result<Self> {
|
||||
Ok(Self)
|
||||
}
|
||||
|
||||
fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> {
|
||||
use rand::prelude::*;
|
||||
|
||||
let elem_count = shape.elem_count();
|
||||
@ -902,7 +932,7 @@ impl CpuStorage {
|
||||
for _i in 0..elem_count {
|
||||
data.push(rng.sample::<f32, _>(uniform))
|
||||
}
|
||||
Ok(Self::F32(data))
|
||||
Ok(CpuStorage::F32(data))
|
||||
}
|
||||
DType::F64 => {
|
||||
let mut data = Vec::new();
|
||||
@ -911,12 +941,12 @@ impl CpuStorage {
|
||||
for _i in 0..elem_count {
|
||||
data.push(rng.sample::<f64, _>(uniform))
|
||||
}
|
||||
Ok(Self::F64(data))
|
||||
Ok(CpuStorage::F64(data))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn rand_normal(shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<Self> {
|
||||
fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CpuStorage> {
|
||||
use rand::prelude::*;
|
||||
|
||||
let elem_count = shape.elem_count();
|
||||
@ -933,7 +963,7 @@ impl CpuStorage {
|
||||
for _i in 0..elem_count {
|
||||
data.push(rng.sample::<f32, _>(rand::distributions::Standard) * std + mean)
|
||||
}
|
||||
Ok(Self::F32(data))
|
||||
Ok(CpuStorage::F32(data))
|
||||
}
|
||||
DType::F64 => {
|
||||
let mut data = Vec::new();
|
||||
@ -941,32 +971,34 @@ impl CpuStorage {
|
||||
for _i in 0..elem_count {
|
||||
data.push(rng.sample::<f64, _>(rand::distributions::Standard) * std + mean)
|
||||
}
|
||||
Ok(Self::F64(data))
|
||||
Ok(CpuStorage::F64(data))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
|
||||
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
|
||||
let elem_count = shape.elem_count();
|
||||
match dtype {
|
||||
DType::U8 => Self::U8(vec![1u8; elem_count]),
|
||||
DType::U32 => Self::U32(vec![1u32; elem_count]),
|
||||
DType::BF16 => Self::BF16(vec![bf16::ONE; elem_count]),
|
||||
DType::F16 => Self::F16(vec![f16::ONE; elem_count]),
|
||||
DType::F32 => Self::F32(vec![1f32; elem_count]),
|
||||
DType::F64 => Self::F64(vec![1f64; elem_count]),
|
||||
}
|
||||
let storage = match dtype {
|
||||
DType::U8 => CpuStorage::U8(vec![1u8; elem_count]),
|
||||
DType::U32 => CpuStorage::U32(vec![1u32; elem_count]),
|
||||
DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]),
|
||||
DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]),
|
||||
DType::F32 => CpuStorage::F32(vec![1f32; elem_count]),
|
||||
DType::F64 => CpuStorage::F64(vec![1f64; elem_count]),
|
||||
};
|
||||
Ok(storage)
|
||||
}
|
||||
|
||||
pub(crate) fn zeros_impl(shape: &Shape, dtype: DType) -> Self {
|
||||
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
|
||||
let elem_count = shape.elem_count();
|
||||
match dtype {
|
||||
DType::U8 => Self::U8(vec![0u8; elem_count]),
|
||||
DType::U32 => Self::U32(vec![0u32; elem_count]),
|
||||
DType::BF16 => Self::BF16(vec![bf16::ZERO; elem_count]),
|
||||
DType::F16 => Self::F16(vec![f16::ZERO; elem_count]),
|
||||
DType::F32 => Self::F32(vec![0f32; elem_count]),
|
||||
DType::F64 => Self::F64(vec![0f64; elem_count]),
|
||||
}
|
||||
let storage = match dtype {
|
||||
DType::U8 => CpuStorage::U8(vec![0u8; elem_count]),
|
||||
DType::U32 => CpuStorage::U32(vec![0u32; elem_count]),
|
||||
DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]),
|
||||
DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]),
|
||||
DType::F32 => CpuStorage::F32(vec![0f32; elem_count]),
|
||||
DType::F64 => CpuStorage::F64(vec![0f64; elem_count]),
|
||||
};
|
||||
Ok(storage)
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,5 @@
|
||||
use crate::backend::BackendDevice;
|
||||
use crate::cpu_backend::CpuDevice;
|
||||
use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
|
||||
|
||||
/// A `DeviceLocation` represents a physical device whereas multiple `Device`
|
||||
@ -117,7 +118,7 @@ impl Device {
|
||||
) -> Result<Storage> {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let storage = CpuStorage::rand_uniform(shape, dtype, lo, up)?;
|
||||
let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?;
|
||||
Ok(Storage::Cpu(storage))
|
||||
}
|
||||
Device::Cuda(device) => {
|
||||
@ -136,7 +137,7 @@ impl Device {
|
||||
) -> Result<Storage> {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let storage = CpuStorage::rand_normal(shape, dtype, mean, std)?;
|
||||
let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?;
|
||||
Ok(Storage::Cpu(storage))
|
||||
}
|
||||
Device::Cuda(device) => {
|
||||
@ -149,7 +150,7 @@ impl Device {
|
||||
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let storage = CpuStorage::ones_impl(shape, dtype);
|
||||
let storage = CpuDevice.ones_impl(shape, dtype)?;
|
||||
Ok(Storage::Cpu(storage))
|
||||
}
|
||||
Device::Cuda(device) => {
|
||||
@ -162,7 +163,7 @@ impl Device {
|
||||
pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let storage = CpuStorage::zeros_impl(shape, dtype);
|
||||
let storage = CpuDevice.zeros_impl(shape, dtype)?;
|
||||
Ok(Storage::Cpu(storage))
|
||||
}
|
||||
Device::Cuda(device) => {
|
||||
|
@ -1,3 +1,4 @@
|
||||
use crate::backend::BackendStorage;
|
||||
use crate::{CpuStorage, Error, Result};
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||
|
Reference in New Issue
Block a user