Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 11:08:52 +00:00)
Move the cpu backend specific bits apart.
src/cpu_backend.rs (new file)
@@ -0,0 +1,99 @@
+use crate::storage::{BinaryOp, UnaryOp};
+use crate::{DType, Error, Result, Shape, StridedIndex};
+
+// TODO: Think about whether we would be better off with a dtype and
+// a buffer as an owned slice of bytes.
+#[derive(Debug, Clone)]
+pub enum CpuStorage {
+    F32(Vec<f32>),
+    F64(Vec<f64>),
+}
+
+impl CpuStorage {
+    pub fn dtype(&self) -> DType {
+        match self {
+            Self::F32(_) => DType::F32,
+            Self::F64(_) => DType::F64,
+        }
+    }
+
+    pub(crate) fn affine_impl(
+        &self,
+        shape: &Shape,
+        stride: &[usize],
+        mul: f64,
+        add: f64,
+    ) -> Result<Self> {
+        match self {
+            Self::F32(storage) => {
+                let index = StridedIndex::new(shape.dims(), stride);
+                let mul = mul as f32;
+                let add = add as f32;
+                let data = index.map(|i| storage[i] * mul + add).collect();
+                Ok(Self::F32(data))
+            }
+            Self::F64(storage) => {
+                let index = StridedIndex::new(shape.dims(), stride);
+                let data = index.map(|i| storage[i] * mul + add).collect();
+                Ok(Self::F64(data))
+            }
+        }
+    }
+
+    pub(crate) fn unary_impl<B: UnaryOp>(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
+        // TODO: Different code path for the contiguous case?
+        match self {
+            Self::F32(storage) => {
+                let index = StridedIndex::new(shape.dims(), stride);
+                let data = index.map(|i| B::f32(storage[i])).collect();
+                Ok(Self::F32(data))
+            }
+            Self::F64(storage) => {
+                let index = StridedIndex::new(shape.dims(), stride);
+                let data = index.map(|i| B::f64(storage[i])).collect();
+                Ok(Self::F64(data))
+            }
+        }
+    }
+
+    pub(crate) fn binary_impl<B: BinaryOp>(
+        &self,
+        rhs: &Self,
+        shape: &Shape,
+        lhs_stride: &[usize],
+        rhs_stride: &[usize],
+    ) -> Result<Self> {
+        // The ggml implementation has different paths based on whether the rhs is contiguous
+        // or not, for now we only consider the general case but we should benchmark and do the
+        // same if it helps.
+        // https://github.com/ggerganov/llama.cpp/blob/aacdbd40562684665b6f7b8ba6695b7a2088bbb0/ggml.c#L7895
+        match (self, rhs) {
+            (CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
+                let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
+                let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
+                let data = lhs_index
+                    .zip(rhs_index)
+                    .map(|(lhs_i, rhs_i)| B::f32(lhs[lhs_i], rhs[rhs_i]))
+                    .collect();
+                Ok(Self::F32(data))
+            }
+            (CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
+                let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
+                let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
+                let data = lhs_index
+                    .zip(rhs_index)
+                    .map(|(lhs_i, rhs_i)| B::f64(lhs[lhs_i], rhs[rhs_i]))
+                    .collect();
+                Ok(Self::F64(data))
+            }
+            _ => {
+                // This should be covered by the dtype check above.
+                Err(Error::DTypeMismatchBinaryOp {
+                    lhs: self.dtype(),
+                    rhs: rhs.dtype(),
+                    op: B::NAME,
+                })
+            }
+        }
+    }
+}
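All three methods walk the tensor through StridedIndex, which turns a shape and a stride into the sequence of flat buffer offsets, so the same code handles contiguous and non-contiguous (e.g. transposed) layouts. The real iterator lives in src/strided_index.rs and is not shown in this commit; the following is a minimal self-contained sketch of the idea, with hypothetical names and assumed semantics, not the crate's actual implementation:

// Sketch only: enumerate the logical elements of `dims` in row-major order
// and yield, for each one, the flat buffer offset obtained by dotting the
// current multi-dimensional index with `stride`.
struct StridedIndexSketch {
    // The next multi-dimensional index to visit, or None once exhausted.
    next_index: Option<Vec<usize>>,
    dims: Vec<usize>,
    stride: Vec<usize>,
}

impl StridedIndexSketch {
    fn new(dims: &[usize], stride: &[usize]) -> Self {
        // An empty dimension means there are no elements at all.
        let next_index = if dims.iter().any(|&d| d == 0) {
            None
        } else {
            Some(vec![0; dims.len()])
        };
        Self { next_index, dims: dims.to_vec(), stride: stride.to_vec() }
    }
}

impl Iterator for StridedIndexSketch {
    type Item = usize;

    fn next(&mut self) -> Option<usize> {
        let index = self.next_index.take()?;
        // Flat offset = sum over dimensions of index[d] * stride[d].
        let offset = index.iter().zip(&self.stride).map(|(i, s)| i * s).sum();
        // Advance the multi-index, rightmost dimension first.
        let mut next = index;
        for d in (0..self.dims.len()).rev() {
            if next[d] + 1 < self.dims[d] {
                next[d] += 1;
                self.next_index = Some(next);
                break;
            }
            next[d] = 0;
        }
        Some(offset)
    }
}

fn main() {
    // A 2x3 view over a buffer stored column-major (stride [1, 2]) visits
    // offsets 0, 2, 4, 1, 3, 5 rather than the contiguous 0..6.
    let offsets: Vec<usize> = StridedIndexSketch::new(&[2, 3], &[1, 2]).collect();
    assert_eq!(offsets, vec![0, 2, 4, 1, 3, 5]);
}

With the offset bookkeeping abstracted away like this, affine_impl, unary_impl, and binary_impl each reduce to a single map/collect per dtype.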
src/device.rs
@@ -1,7 +1,4 @@
-use crate::{
-    storage::{CpuStorage, Storage},
-    DType, Result, Shape,
-};
+use crate::{CpuStorage, DType, Result, Shape, Storage};
 
 #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
 pub enum Device {
src/lib.rs
@@ -1,3 +1,4 @@
+mod cpu_backend;
 mod device;
 mod dtype;
 mod error;
@@ -7,10 +8,11 @@ mod storage;
 mod strided_index;
 mod tensor;
 
+pub use cpu_backend::CpuStorage;
 pub use device::Device;
 pub use dtype::{DType, WithDType};
 pub use error::{Error, Result};
 pub use shape::Shape;
-pub use storage::{CpuStorage, Storage};
+pub use storage::Storage;
 use strided_index::StridedIndex;
 pub use tensor::{Tensor, TensorId};
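Re-exporting CpuStorage from the crate root keeps the public API stable even though the type moved from storage.rs to cpu_backend.rs. A hypothetical downstream consumer (the crate name is assumed here) imports it from the same path before and after this commit:

// The import path is unchanged; only the defining module moved.
use candle::CpuStorage;

fn main() {
    let s = CpuStorage::F32(vec![1.0, 2.0, 3.0]);
    // dtype() is now also callable from outside the crate, since the new
    // cpu_backend.rs marks it `pub` rather than `pub(crate)`.
    println!("{:?}", s.dtype());
}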
src/storage.rs
@@ -1,21 +1,4 @@
-use crate::{DType, Device, Error, Result, Shape, StridedIndex};
+use crate::{CpuStorage, DType, Device, Error, Result, Shape};
 
-// TODO: Think about whether we would be better off with a dtype and
-// a buffer as an owned slice of bytes.
-#[derive(Debug, Clone)]
-pub enum CpuStorage {
-    F32(Vec<f32>),
-    F64(Vec<f64>),
-}
-
-impl CpuStorage {
-    pub(crate) fn dtype(&self) -> DType {
-        match self {
-            Self::F32(_) => DType::F32,
-            Self::F64(_) => DType::F64,
-        }
-    }
-}
-
 #[derive(Debug, Clone)]
 pub enum Storage {
@@ -23,13 +6,13 @@ pub enum Storage {
     Cuda { gpu_id: usize }, // TODO: Actually add the storage.
 }
 
-trait UnaryOp {
+pub(crate) trait UnaryOp {
     const NAME: &'static str;
     fn f32(v1: f32) -> f32;
     fn f64(v1: f64) -> f64;
 }
 
-trait BinaryOp {
+pub(crate) trait BinaryOp {
     const NAME: &'static str;
     fn f32(v1: f32, v2: f32) -> f32;
     fn f64(v1: f64, v2: f64) -> f64;
@@ -157,20 +140,10 @@ impl Storage {
     ) -> Result<Self> {
         // TODO: Different code path for the contiguous case?
         match self {
-            Storage::Cpu(storage) => match storage {
-                CpuStorage::F32(storage) => {
-                    let index = StridedIndex::new(shape.dims(), stride);
-                    let mul = mul as f32;
-                    let add = add as f32;
-                    let data = index.map(|i| storage[i] * mul + add).collect();
-                    Ok(Storage::Cpu(CpuStorage::F32(data)))
-                }
-                CpuStorage::F64(storage) => {
-                    let index = StridedIndex::new(shape.dims(), stride);
-                    let data = index.map(|i| storage[i] * mul + add).collect();
-                    Ok(Storage::Cpu(CpuStorage::F64(data)))
-                }
-            },
+            Storage::Cpu(storage) => {
+                let storage = storage.affine_impl(shape, stride, mul, add)?;
+                Ok(Self::Cpu(storage))
+            }
             Self::Cuda { .. } => todo!(),
         }
     }
@@ -178,18 +151,10 @@ impl Storage {
     fn unary_impl<B: UnaryOp>(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
         // TODO: Different code path for the contiguous case?
         match self {
-            Storage::Cpu(storage) => match storage {
-                CpuStorage::F32(storage) => {
-                    let index = StridedIndex::new(shape.dims(), stride);
-                    let data = index.map(|i| B::f32(storage[i])).collect();
-                    Ok(Storage::Cpu(CpuStorage::F32(data)))
-                }
-                CpuStorage::F64(storage) => {
-                    let index = StridedIndex::new(shape.dims(), stride);
-                    let data = index.map(|i| B::f64(storage[i])).collect();
-                    Ok(Storage::Cpu(CpuStorage::F64(data)))
-                }
-            },
+            Storage::Cpu(storage) => {
+                let storage = storage.unary_impl::<B>(shape, stride)?;
+                Ok(Self::Cpu(storage))
+            }
             Self::Cuda { .. } => todo!(),
         }
     }
@@ -204,39 +169,11 @@ impl Storage {
     ) -> Result<Self> {
         self.same_device(rhs, B::NAME)?;
         self.same_dtype(rhs, B::NAME)?;
-        // The ggml implementation has different paths based on whether the rhs is contiguous
-        // or not, for now we only consider the general case but we should benchmark and do the
-        // same if it helps.
-        // https://github.com/ggerganov/llama.cpp/blob/aacdbd40562684665b6f7b8ba6695b7a2088bbb0/ggml.c#L7895
         match (self, rhs) {
-            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => match (lhs, rhs) {
-                (CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
-                    let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
-                    let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
-                    let data = lhs_index
-                        .zip(rhs_index)
-                        .map(|(lhs_i, rhs_i)| B::f32(lhs[lhs_i], rhs[rhs_i]))
-                        .collect();
-                    Ok(Storage::Cpu(CpuStorage::F32(data)))
-                }
-                (CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
-                    let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
-                    let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
-                    let data = lhs_index
-                        .zip(rhs_index)
-                        .map(|(lhs_i, rhs_i)| B::f64(lhs[lhs_i], rhs[rhs_i]))
-                        .collect();
-                    Ok(Storage::Cpu(CpuStorage::F64(data)))
-                }
-                _ => {
-                    // This should be covered by the dtype check above.
-                    Err(Error::DTypeMismatchBinaryOp {
-                        lhs: lhs.dtype(),
-                        rhs: rhs.dtype(),
-                        op: B::NAME,
-                    })
-                }
-            },
+            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
+                let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?;
+                Ok(Self::Cpu(storage))
+            }
             (Self::Cuda { .. }, Self::Cuda { .. }) => todo!(),
             (lhs, rhs) => {
                 // Should not happen because of the same device check above but we're defensive
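After this refactor, Storage dispatches only on the device while CpuStorage dispatches on the dtype; making UnaryOp and BinaryOp pub(crate) is what lets cpu_backend.rs name them. As a rough illustration of how an operation plugs into these traits, here is a sketch: the trait is copied from the diff above, but `Add` is a hypothetical implementor for illustration only, not one of the crate's actual ops.

// The BinaryOp trait exactly as shown in the diff above.
pub(crate) trait BinaryOp {
    const NAME: &'static str;
    fn f32(v1: f32, v2: f32) -> f32;
    fn f64(v1: f64, v2: f64) -> f64;
}

// Hypothetical element-wise addition. Storage::binary_impl::<Add> would
// route to CpuStorage::binary_impl::<Add>, which applies Add::f32 or
// Add::f64 at each pair of strided offsets and reports Add::NAME in any
// dtype-mismatch error.
struct Add;

impl BinaryOp for Add {
    const NAME: &'static str = "add";
    fn f32(v1: f32, v2: f32) -> f32 {
        v1 + v2
    }
    fn f64(v1: f64, v2: f64) -> f64 {
        v1 + v2
    }
}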