Files
candle/src/cpu_backend.rs

438 lines
15 KiB
Rust

use crate::op::{BinaryOp, UnaryOp};
use crate::{DType, Error, Result, Shape, StridedIndex};
use gemm::{gemm, Parallelism};
// TODO: Think about whether we would be better off with a dtype and
// a buffer as an owned slice of bytes.
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
// intercept the oom errors to avoid panicking and provide a proper error.
/// Typed element buffer used by the CPU backend.
///
/// There is one variant per element type; each variant owns its elements in a `Vec`.
/// Cloning a value clones the underlying `Vec`, i.e. copies every element.
#[derive(Debug, Clone)]
pub enum CpuStorage {
/// Buffer of `u32` elements.
U32(Vec<u32>),
/// Buffer of `f32` elements.
F32(Vec<f32>),
/// Buffer of `f64` elements.
F64(Vec<f64>),
}
/// Builds a `Vec` by calling `f` once per element index of a (possibly strided) layout.
///
/// When `shape.is_contiguous(stride)` holds, the indices handed to `f` are simply
/// `0..shape.elem_count()`. Otherwise they are the values produced by
/// `StridedIndex::new(shape.dims(), stride)`, in the order that iterator yields them.
fn unary_map<T, F: FnMut(usize) -> T>(shape: &Shape, stride: &[usize], f: F) -> Vec<T> {
    // Strided layouts need the explicit index iterator.
    if !shape.is_contiguous(stride) {
        let strided = StridedIndex::new(shape.dims(), stride);
        return strided.map(f).collect();
    }
    // Contiguous layouts can be walked with a plain counter.
    let total = shape.elem_count();
    (0..total).map(f).collect()
}
/// Applies `f` pairwise over two strided operands and collects the results.
///
/// `shape` is the shape iterated over. Broadcasting is supported when exactly one of
/// `lhs_stride` / `rhs_stride` has fewer entries than `shape` has dimensions: the values of
/// that operand are gathered once through `StridedIndex::new(dims, stride)` and then reused
/// cyclically while the other operand is walked through its own `StridedIndex`.
///
/// # Panics
///
/// Panics with an "unexpected broadcasting dims" message when both strides are shorter than
/// the shape's rank, or when either stride is longer than the shape's rank. Out-of-range
/// indices into `lhs` / `rhs` panic as usual for slice indexing.
fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
    shape: &Shape,
    lhs_stride: &[usize],
    rhs_stride: &[usize],
    lhs: &[T],
    rhs: &[T],
    mut f: F,
) -> Vec<T> {
    let dims = shape.dims();
    let elem_count = shape.elem_count();
    // `checked_sub` rather than `-`: a stride longer than the shape must reach the explicit
    // panic below instead of underflowing (which wraps in release builds and would select
    // one of the broadcasting branches by accident).
    let broadcast_ldims = dims.len().checked_sub(lhs_stride.len());
    let broadcast_rdims = dims.len().checked_sub(rhs_stride.len());
    match (broadcast_ldims, broadcast_rdims) {
        // No broadcasting on either side.
        (Some(0), Some(0)) => {
            if shape.is_contiguous(lhs_stride) && shape.is_contiguous(rhs_stride) {
                (0..elem_count).map(|i| f(lhs[i], rhs[i])).collect()
            } else {
                let lhs_index = StridedIndex::new(dims, lhs_stride);
                let rhs_index = StridedIndex::new(dims, rhs_stride);
                lhs_index
                    .zip(rhs_index)
                    .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                    .collect()
            }
        }
        // Only the left operand is broadcast: gather it once, then cycle through it.
        (Some(_), Some(0)) => {
            let lhs_v: Vec<T> = StridedIndex::new(dims, lhs_stride)
                .map(|i| lhs[i])
                .collect();
            let mut res = Vec::with_capacity(elem_count);
            let mut i = 0;
            for rhs_i in StridedIndex::new(dims, rhs_stride) {
                res.push(f(lhs_v[i], rhs[rhs_i]));
                i += 1;
                // Wrap around so the gathered values repeat.
                if i >= lhs_v.len() {
                    i = 0
                }
            }
            res
        }
        // Only the right operand is broadcast: mirror image of the branch above.
        (Some(0), Some(_)) => {
            let rhs_v: Vec<T> = StridedIndex::new(dims, rhs_stride)
                .map(|i| rhs[i])
                .collect();
            let mut res = Vec::with_capacity(elem_count);
            let mut i = 0;
            for lhs_i in StridedIndex::new(dims, lhs_stride) {
                res.push(f(lhs[lhs_i], rhs_v[i]));
                i += 1;
                // Wrap around so the gathered values repeat.
                if i >= rhs_v.len() {
                    i = 0
                }
            }
            res
        }
        // Both sides broadcast, or a stride with more entries than the shape has dims.
        _ => {
            panic!("unexpected broadcasting dims: {shape:?} {lhs_stride:?} {rhs_stride:?}")
        }
    }
}
impl CpuStorage {
pub fn dtype(&self) -> DType {
match self {
Self::U32(_) => DType::U32,
Self::F32(_) => DType::F32,
Self::F64(_) => DType::F64,
}
}
/// Returns the stored elements as a shared slice of `D`.
///
/// This only forwards to `D::cpu_storage_as_slice`; any error comes from that
/// implementation (presumably raised when `D` does not match the stored dtype —
/// confirm against the `WithDType` implementations).
pub fn as_slice<D: crate::WithDType>(&self) -> Result<&[D]> {
D::cpu_storage_as_slice(self)
}
/// Returns the stored elements as a mutable slice of `D`.
///
/// This only forwards to `D::cpu_storage_as_mut_slice`; any error comes from that
/// implementation (presumably raised when `D` does not match the stored dtype —
/// confirm against the `WithDType` implementations).
pub fn as_mut_slice<D: crate::WithDType>(&mut self) -> Result<&mut [D]> {
D::cpu_storage_as_mut_slice(self)
}
/// Computes `x * mul + add` for every element addressed by (`shape`, `stride`) and returns
/// the results as new storage of the same variant.
///
/// `mul` and `add` are given as `f64` and converted with `as` to the element type before
/// use (so for `U32` storage any fractional part is discarded); `F64` storage uses them
/// unchanged. Element indices are produced by `unary_map`.
pub(crate) fn affine_impl(
    &self,
    shape: &Shape,
    stride: &[usize],
    mul: f64,
    add: f64,
) -> Result<Self> {
    let transformed = match self {
        Self::U32(values) => {
            let (scale, offset) = (mul as u32, add as u32);
            Self::U32(unary_map(shape, stride, |idx| values[idx] * scale + offset))
        }
        Self::F32(values) => {
            let (scale, offset) = (mul as f32, add as f32);
            Self::F32(unary_map(shape, stride, |idx| values[idx] * scale + offset))
        }
        Self::F64(values) => {
            // Already the right precision: no conversion of the coefficients needed.
            Self::F64(unary_map(shape, stride, |idx| values[idx] * mul + add))
        }
    };
    Ok(transformed)
}
pub(crate) fn unary_impl<B: UnaryOp>(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
// TODO: Different code path for the contiguous case?
match self {
Self::F32(storage) => {
let index = StridedIndex::new(shape.dims(), stride);
let data = index.map(|i| B::f32(storage[i])).collect();
Ok(Self::F32(data))
}
Self::F64(storage) => {
let index = StridedIndex::new(shape.dims(), stride);
let data = index.map(|i| B::f64(storage[i])).collect();
Ok(Self::F64(data))
}
Self::U32(_storage) => {
todo!("No unary for u32 because of neg, sqrt")
}
}
}
pub(crate) fn binary_impl<B: BinaryOp>(
&self,
rhs: &Self,
shape: &Shape,
lhs_stride: &[usize],
rhs_stride: &[usize],
) -> Result<Self> {
match (self, rhs) {
(Self::F32(lhs), Self::F32(rhs)) => {
let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f32);
Ok(Self::F32(data))
}
(Self::F64(lhs), Self::F64(rhs)) => {
let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f64);
Ok(Self::F64(data))
}
(Self::U32(lhs), Self::U32(rhs)) => {
let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::u32);
Ok(Self::U32(data))
}
_ => {
// This should be covered by the dtype check above.
Err(Error::DTypeMismatchBinaryOp {
lhs: self.dtype(),
rhs: rhs.dtype(),
op: B::NAME,
})
}
}
}
/// Copies the elements of `self`, read through (`src_shape`, `src_stride`), into `dst`
/// starting at element index `dst_offset`.
///
/// For a contiguous source the whole source buffer is copied in one go into
/// `dst[dst_offset..dst_offset + src.len()]`; `dst` may extend past that range. Otherwise
/// the source is walked with `StridedIndex` and the elements are written one by one to
/// consecutive positions of `dst`.
///
/// # Errors
///
/// Returns `Error::DTypeMismatchBinaryOp` (with `op: "copy_strided"`) when `self` and
/// `dst` hold different variants.
///
/// # Panics
///
/// Panics if `src_shape.rank()` differs from `src_stride.len()`, or if `dst` is too short
/// to receive the copied elements at `dst_offset`.
pub(crate) fn copy_strided_src(
    &self,
    dst: &mut Self,
    src_shape: &Shape,
    src_stride: &[usize],
    dst_offset: usize,
) -> Result<()> {
    // Element-type-agnostic copy shared by every storage variant.
    fn copy_elems<T: Copy>(
        src: &[T],
        dst: &mut [T],
        src_shape: &Shape,
        src_stride: &[usize],
        dst_offset: usize,
    ) {
        if src_shape.is_contiguous(src_stride) {
            // `copy_from_slice` requires both slices to have the same length, so bound the
            // destination window explicitly: `dst` is allowed to be larger than the source
            // (e.g. when several sources are written into one buffer at different offsets).
            dst[dst_offset..dst_offset + src.len()].copy_from_slice(src)
        } else {
            let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
            for (dst_index, src_index) in src_indexes.enumerate() {
                dst[dst_index + dst_offset] = src[src_index]
            }
        }
    }

    if src_shape.rank() != src_stride.len() {
        panic!("incoherent shape and strides {src_shape:?} {src_stride:?}")
    }
    match (self, dst) {
        (Self::U32(src), Self::U32(dst)) => {
            copy_elems(src, dst, src_shape, src_stride, dst_offset)
        }
        (Self::F32(src), Self::F32(dst)) => {
            copy_elems(src, dst, src_shape, src_stride, dst_offset)
        }
        (Self::F64(src), Self::F64(dst)) => {
            copy_elems(src, dst, src_shape, src_stride, dst_offset)
        }
        (_, dst) => {
            // Source and destination hold different element types.
            return Err(Error::DTypeMismatchBinaryOp {
                lhs: self.dtype(),
                rhs: dst.dtype(),
                op: "copy_strided",
            });
        }
    }
    Ok(())
}
pub(crate) fn embedding_impl(
&self,
rhs: &Self,
hidden_size: usize,
vocab_size: usize,
) -> Result<Self> {
match self {
CpuStorage::U32(lhs) => match rhs {
CpuStorage::F32(rhs) => {
let mut weights = Vec::with_capacity(lhs.len() * hidden_size);
for &index in lhs {
let index: usize = index.try_into()?;
if index >= vocab_size {
return Err(Error::InvalidIndex {
index,
vocab_size,
op: "embedding",
});
} else {
weights.extend(&rhs[hidden_size * index..hidden_size * (index + 1)]);
}
}
Ok(CpuStorage::F32(weights))
}
CpuStorage::F64(rhs) => {
let mut weights = Vec::with_capacity(lhs.len() * hidden_size);
for &index in lhs {
let index: usize = index.try_into()?;
if index >= vocab_size {
return Err(Error::InvalidIndex {
index,
vocab_size,
op: "embedding",
});
} else {
weights.extend(&rhs[hidden_size * index..hidden_size * (index + 1)]);
}
}
Ok(CpuStorage::F64(weights))
}
rhs => Err(Error::UnexpectedDType {
expected: DType::F32,
got: rhs.dtype(),
}),
},
lhs => Err(Error::UnexpectedDType {
expected: DType::U32,
got: lhs.dtype(),
}),
}
}
/// Batched `f32` matrix product computed with the `gemm` crate.
///
/// The dimension tuple is `(b, m, n, k)`: `b` independent products are computed, each
/// writing an `m x n` block into a freshly allocated, contiguous output of `b * m * n`
/// elements. Consecutive batch entries are assumed to be `m * k` elements apart in `self`
/// and `n * k` elements apart in `rhs`. The last two entries of each stride slice are
/// passed to `gemm` as the row / column strides of the operand.
///
/// Only `F32` storage is supported: both operands are viewed through `as_slice::<f32>()`
/// and any error it returns is propagated before any computation starts.
///
/// # Errors
///
/// Returns `Error::UnexpectedStriding` when
/// * the strides have fewer than two entries, or `lhs_stride` and `rhs_stride` differ in
///   length, or
/// * there are batch dimensions and the leading strides are not exactly `[m * k]` for the
///   left operand and `[n * k]` for the right one (arbitrary batch striding is not
///   supported yet).
///
/// # Panics
///
/// Panics if an operand holds fewer elements than `step * skip` for some batch `step`.
pub(crate) fn matmul_impl(
    &self,
    rhs: &Self,
    (b, m, n, k): (usize, usize, usize, usize),
    lhs_stride: &[usize],
    rhs_stride: &[usize],
) -> Result<Self> {
    let rank = lhs_stride.len();
    // The row/column strides are read from the last two entries of *both* stride slices
    // using the same rank, so reject anything where that indexing would underflow or
    // would pick the wrong entries of `rhs_stride`.
    if rank < 2 || rhs_stride.len() != rank {
        return Err(Error::UnexpectedStriding);
    }

    // Number of elements separating two consecutive batch entries.
    let a_skip: usize = m * k;
    let b_skip: usize = n * k;
    let c_skip: usize = m * n;

    let lhs_cs = lhs_stride[rank - 1];
    let lhs_rs = lhs_stride[rank - 2];
    let rhs_cs = rhs_stride[rank - 1];
    let rhs_rs = rhs_stride[rank - 2];

    if rank > 2 {
        let lhs_batch_stride = &lhs_stride[..rank - 2];
        let rhs_batch_stride = &rhs_stride[..rank - 2];
        if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
            // Temporary error before we support arbitrary striding.
            return Err(Error::UnexpectedStriding);
        }
    }

    // The typed views do not change between batch steps: resolve them once.
    let lhs = self.as_slice::<f32>()?;
    let rhs = rhs.as_slice::<f32>()?;

    let mut dst = vec![0f32; b * m * n];
    let dst_shape: Shape = (m, n).into();
    let dst_strides = dst_shape.stride_contiguous();
    let dst_rs = dst_strides[0];
    let dst_cs = dst_strides[1];
    for step in 0..b {
        let lhs_p = &lhs[step * a_skip..];
        let rhs_p = &rhs[step * b_skip..];
        let dst_p = &mut dst[step * c_skip..];
        // SAFETY: `gemm` receives raw pointers into `lhs_p`, `rhs_p` and `dst_p` together
        // with the dims and strides computed above. `dst_p` starts at `step * m * n`
        // inside a buffer of `b * m * n` elements and is addressed with the contiguous
        // strides of an `(m, n)` shape. Nothing here verifies that the offsets implied by
        // `(m, n, k)` and the operand strides stay inside `lhs_p` / `rhs_p`: the caller
        // must pass dims and strides that are consistent with the storage.
        // NOTE(review): consider validating the operand extents before this call.
        unsafe {
            gemm(
                // m: usize,
                m,
                // n: usize,
                n,
                // k: usize,
                k,
                // dst: *mut T,
                dst_p.as_mut_ptr(),
                // dst_cs: isize,
                dst_cs as isize,
                // dst_rs: isize,
                dst_rs as isize,
                // read_dst: bool,
                false,
                // lhs: *const T,
                lhs_p.as_ptr(),
                // lhs_cs: isize,
                lhs_cs as isize,
                // lhs_rs: isize,
                lhs_rs as isize,
                // rhs: *const T,
                rhs_p.as_ptr(),
                // rhs_cs: isize,
                rhs_cs as isize,
                // rhs_rs: isize,
                rhs_rs as isize,
                // alpha: T,
                1.0,
                // beta: T,
                1.0,
                // conj_dst: bool,
                false,
                // conj_lhs: bool,
                false,
                // conj_rhs: bool,
                true,
                // parallelism: Parallelism
                Parallelism::None,
            )
        }
    }
    Ok(Self::F32(dst))
}
/// Creates storage of the requested `dtype` holding `shape.elem_count()` elements, all
/// equal to one.
pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
    let len = shape.elem_count();
    match dtype {
        DType::U32 => Self::U32(vec![1; len]),
        DType::F32 => Self::F32(vec![1.0; len]),
        DType::F64 => Self::F64(vec![1.0; len]),
    }
}
/// Creates storage of the requested `dtype` holding `shape.elem_count()` elements, all
/// equal to zero.
pub(crate) fn zeros_impl(shape: &Shape, dtype: DType) -> Self {
    let len = shape.elem_count();
    match dtype {
        DType::U32 => Self::U32(vec![0; len]),
        DType::F32 => Self::F32(vec![0.0; len]),
        DType::F64 => Self::F64(vec![0.0; len]),
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Device, Tensor};

    /// Checks `Tensor::matmul` on the CPU device against hand-computed products for a
    /// square case, an outer product, a rectangular case and a batched case.
    #[test]
    fn simple_matmul() -> Result<()> {
        // (2, 2) x (2, 2): a matrix multiplied by itself.
        {
            let lhs_values = vec![1.0f32, 2.0, 3.0, 4.0];
            let lhs = Tensor::from_slice(&lhs_values, (2, 2), &Device::Cpu)?;
            let rhs_values = vec![1.0f32, 2.0, 3.0, 4.0];
            let rhs = Tensor::from_slice(&rhs_values, (2, 2), &Device::Cpu)?;
            let c = lhs.matmul(&rhs)?;
            assert_eq!(c.to_vec2::<f32>()?, &[&[7.0f32, 10.0], &[15.0, 22.0]]);
        }

        // (2, 1) x (1, 2): outer product of two vectors.
        {
            let lhs_values = vec![1.0f32, 2.0];
            let lhs = Tensor::from_slice(&lhs_values, (2, 1), &Device::Cpu)?;
            let rhs_values = vec![3.0f32, 4.0];
            let rhs = Tensor::from_slice(&rhs_values, (1, 2), &Device::Cpu)?;
            let c = lhs.matmul(&rhs)?;
            assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);
        }

        // (2, 3) x (3, 2): rectangular operands.
        {
            let lhs_values: Vec<_> = (0..6).map(|i| i as f32).collect();
            let lhs = Tensor::from_slice(&lhs_values, (2, 3), &Device::Cpu)?;
            let rhs_values: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
            let rhs = Tensor::from_slice(&rhs_values, (3, 2), &Device::Cpu)?;
            let c = lhs.matmul(&rhs)?;
            assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);
        }

        // (2, 2, 3) x (2, 3, 2): two independent products in one batch.
        {
            let lhs_values: Vec<_> = (0..12).map(|i| i as f32).collect();
            let lhs = Tensor::from_slice(&lhs_values, (2, 2, 3), &Device::Cpu)?;
            let rhs_values: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
            let rhs = Tensor::from_slice(&rhs_values, (2, 3, 2), &Device::Cpu)?;
            let c = lhs.matmul(&rhs)?;
            assert_eq!(
                c.to_vec3::<f32>()?,
                &[&[&[16., 19.], &[52., 64.]], &[&[214., 235.], &[304., 334.]]]
            );
        }

        Ok(())
    }
}