use crate::op::{BinaryOp, UnaryOp};
use crate::{DType, Error, Result, Shape, StridedIndex};
use gemm::{gemm, Parallelism};

// TODO: Think about whether we would be better off with a dtype and
// a buffer as an owned slice of bytes.
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
// intercept the oom errors to avoid panicking and provide a proper error.
#[derive(Debug, Clone)]
pub enum CpuStorage {
    U32(Vec<u32>),
    F32(Vec<f32>),
    F64(Vec<f64>),
}

// Maps `f` over every element index of a strided tensor, with a fast path for
// contiguous layouts where the indices are simply `0..elem_count`.
fn unary_map<T: Copy, F: FnMut(usize) -> T>(shape: &Shape, stride: &[usize], f: F) -> Vec<T> {
    if shape.is_contiguous(stride) {
        (0..shape.elem_count()).map(f).collect()
    } else {
        StridedIndex::new(shape.dims(), stride).map(f).collect()
    }
}

// This function maps over two strided index sequences. It supports broadcasting in case
// `lhs_stride` or `rhs_stride` has a length shorter than `shape`.
fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
    shape: &Shape,
    lhs_stride: &[usize],
    rhs_stride: &[usize],
    lhs: &[T],
    rhs: &[T],
    mut f: F,
) -> Vec<T> {
    let dims = shape.dims();
    let broadcast_ldims = dims.len() - lhs_stride.len();
    let broadcast_rdims = dims.len() - rhs_stride.len();
    let elem_count = shape.elem_count();
    if broadcast_ldims == 0 && broadcast_rdims == 0 {
        if shape.is_contiguous(lhs_stride) && shape.is_contiguous(rhs_stride) {
            (0..shape.elem_count()).map(|i| f(lhs[i], rhs[i])).collect()
        } else {
            let lhs_index = StridedIndex::new(dims, lhs_stride);
            let rhs_index = StridedIndex::new(dims, rhs_stride);
            lhs_index
                .zip(rhs_index)
                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                .collect()
        }
    } else if broadcast_rdims == 0 {
        // Only the lhs is broadcast: materialize its strided values once, then cycle
        // through them while walking the rhs.
        let mut res = Vec::with_capacity(elem_count);
        let lhs_v: Vec<T> = StridedIndex::new(dims, lhs_stride)
            .map(|i| lhs[i])
            .collect();
        let mut i = 0;
        for rhs_i in StridedIndex::new(dims, rhs_stride) {
            res.push(f(lhs_v[i], rhs[rhs_i]));
            i += 1;
            if i >= lhs_v.len() {
                i = 0
            }
        }
        res
    } else if broadcast_ldims == 0 {
        // Symmetric case: only the rhs is broadcast.
        let mut res = Vec::with_capacity(elem_count);
        let rhs_v: Vec<T> = StridedIndex::new(dims, rhs_stride)
            .map(|i| rhs[i])
            .collect();
        let mut i = 0;
        for lhs_i in StridedIndex::new(dims, lhs_stride) {
            res.push(f(lhs[lhs_i], rhs_v[i]));
            i += 1;
            if i >= rhs_v.len() {
                i = 0
            }
        }
        res
    } else {
        panic!("unexpected broadcasting dims: {shape:?} {lhs_stride:?} {rhs_stride:?}")
    }
}

impl CpuStorage {
    pub fn dtype(&self) -> DType {
        match self {
            Self::U32(_) => DType::U32,
            Self::F32(_) => DType::F32,
            Self::F64(_) => DType::F64,
        }
    }

    pub fn as_slice<D: crate::WithDType>(&self) -> Result<&[D]> {
        D::cpu_storage_as_slice(self)
    }

    pub fn as_mut_slice<D: crate::WithDType>(&mut self) -> Result<&mut [D]> {
        D::cpu_storage_as_mut_slice(self)
    }

    pub(crate) fn affine_impl(
        &self,
        shape: &Shape,
        stride: &[usize],
        mul: f64,
        add: f64,
    ) -> Result<Self> {
        match self {
            Self::U32(storage) => {
                let mul = mul as u32;
                let add = add as u32;
                let data = unary_map(shape, stride, |i| storage[i] * mul + add);
                Ok(Self::U32(data))
            }
            Self::F32(storage) => {
                let mul = mul as f32;
                let add = add as f32;
                let data = unary_map(shape, stride, |i| storage[i] * mul + add);
                Ok(Self::F32(data))
            }
            Self::F64(storage) => {
                let data = unary_map(shape, stride, |i| storage[i] * mul + add);
                Ok(Self::F64(data))
            }
        }
    }

    pub(crate) fn unary_impl<B: UnaryOp>(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
        // TODO: Different code path for the contiguous case?
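        // Note: reusing `unary_map` here (as `affine_impl` does above) would answer
        // this TODO, since it already special-cases contiguous layouts, e.g. (sketch):
        //     let data = unary_map(shape, stride, |i| B::f32(storage[i]));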
        match self {
            Self::F32(storage) => {
                let index = StridedIndex::new(shape.dims(), stride);
                let data = index.map(|i| B::f32(storage[i])).collect();
                Ok(Self::F32(data))
            }
            Self::F64(storage) => {
                let index = StridedIndex::new(shape.dims(), stride);
                let data = index.map(|i| B::f64(storage[i])).collect();
                Ok(Self::F64(data))
            }
            Self::U32(_storage) => {
                todo!("No unary for u32 because of neg, sqrt")
            }
        }
    }

    pub(crate) fn binary_impl<B: BinaryOp>(
        &self,
        rhs: &Self,
        shape: &Shape,
        lhs_stride: &[usize],
        rhs_stride: &[usize],
    ) -> Result<Self> {
        match (self, rhs) {
            (Self::F32(lhs), Self::F32(rhs)) => {
                let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f32);
                Ok(Self::F32(data))
            }
            (Self::F64(lhs), Self::F64(rhs)) => {
                let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f64);
                Ok(Self::F64(data))
            }
            (Self::U32(lhs), Self::U32(rhs)) => {
                let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::u32);
                Ok(Self::U32(data))
            }
            _ => {
                // This should be covered by the dtype check above.
                Err(Error::DTypeMismatchBinaryOp {
                    lhs: self.dtype(),
                    rhs: rhs.dtype(),
                    op: B::NAME,
                })
            }
        }
    }

    pub(crate) fn copy_strided_src(
        &self,
        dst: &mut Self,
        src_shape: &Shape,
        src_stride: &[usize],
        dst_offset: usize,
    ) -> Result<()> {
        if src_shape.rank() != src_stride.len() {
            panic!("incoherent shape and strides {src_shape:?} {src_stride:?}")
        }
        match (self, dst) {
            (Self::F32(src), Self::F32(dst)) => {
                if src_shape.is_contiguous(src_stride) {
                    dst[dst_offset..].copy_from_slice(src)
                } else {
                    let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
                    for (dst_index, src_index) in src_indexes.enumerate() {
                        dst[dst_index + dst_offset] = src[src_index]
                    }
                }
            }
            (Self::F64(src), Self::F64(dst)) => {
                if src_shape.is_contiguous(src_stride) {
                    dst[dst_offset..].copy_from_slice(src)
                } else {
                    let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
                    for (dst_index, src_index) in src_indexes.enumerate() {
                        dst[dst_index + dst_offset] = src[src_index]
                    }
                }
            }
            (_, dst) => {
                // This should be covered by the dtype check above.
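                // Note that there is no `(U32, U32)` arm above, so a same-dtype u32
                // copy also lands here and reports a (misleading) dtype mismatch.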
                return Err(Error::DTypeMismatchBinaryOp {
                    lhs: self.dtype(),
                    rhs: dst.dtype(),
                    op: "copy_strided",
                });
            }
        }
        Ok(())
    }

    pub(crate) fn embedding_impl(
        &self,
        rhs: &Self,
        hidden_size: usize,
        vocab_size: usize,
    ) -> Result<Self> {
        match self {
            CpuStorage::U32(lhs) => match rhs {
                CpuStorage::F32(rhs) => {
                    let mut weights = Vec::with_capacity(lhs.len() * hidden_size);
                    for &index in lhs {
                        let index: usize = index.try_into()?;
                        if index >= vocab_size {
                            return Err(Error::InvalidIndex {
                                index,
                                vocab_size,
                                op: "embedding",
                            });
                        } else {
                            weights.extend(&rhs[hidden_size * index..hidden_size * (index + 1)]);
                        }
                    }
                    Ok(CpuStorage::F32(weights))
                }
                CpuStorage::F64(rhs) => {
                    let mut weights = Vec::with_capacity(lhs.len() * hidden_size);
                    for &index in lhs {
                        let index: usize = index.try_into()?;
                        if index >= vocab_size {
                            return Err(Error::InvalidIndex {
                                index,
                                vocab_size,
                                op: "embedding",
                            });
                        } else {
                            weights.extend(&rhs[hidden_size * index..hidden_size * (index + 1)]);
                        }
                    }
                    Ok(CpuStorage::F64(weights))
                }
                rhs => Err(Error::UnexpectedDType {
                    expected: DType::F32,
                    got: rhs.dtype(),
                }),
            },
            lhs => Err(Error::UnexpectedDType {
                expected: DType::U32,
                got: lhs.dtype(),
            }),
        }
    }

    pub(crate) fn matmul_impl(
        &self,
        rhs: &Self,
        (b, m, n, k): (usize, usize, usize, usize),
        lhs_stride: &[usize],
        rhs_stride: &[usize],
    ) -> Result<Self> {
        let a_skip: usize = m * k;
        let b_skip: usize = n * k;
        let c_skip: usize = m * n;
        let rank = lhs_stride.len();
        let lhs_cs = lhs_stride[rank - 1];
        let lhs_rs = lhs_stride[rank - 2];
        let rhs_cs = rhs_stride[rank - 1];
        let rhs_rs = rhs_stride[rank - 2];
        if lhs_stride.len() > 2 {
            let lhs_batch_stride = &lhs_stride[..rank - 2];
            let rhs_batch_stride = &rhs_stride[..rank - 2];
            if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
                // Temporary error before we support arbitrary striding.
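                // Only a single batch dimension with contiguously packed matrices is
                // accepted for now: the batch stride must equal one matrix worth of
                // elements (`m * k` for the lhs, `n * k` for the rhs).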
                return Err(Error::UnexpectedStriding);
            }
        }
        let mut dst = vec![0.0; b * m * n];
        let dst_shape: Shape = (m, n).into();
        let dst_strides = dst_shape.stride_contiguous();
        let dst_rs = dst_strides[0];
        let dst_cs = dst_strides[1];
        for step in 0..b {
            let lhs_p = &self.as_slice::<f32>()?[step * a_skip..];
            let rhs_p = &rhs.as_slice::<f32>()?[step * b_skip..];
            let dst_p = &mut dst[step * c_skip..];
            unsafe {
                gemm(
                    // m: usize,
                    m,
                    // n: usize,
                    n,
                    // k: usize,
                    k,
                    // dst: *mut T,
                    dst_p.as_mut_ptr(),
                    // dst_cs: isize,
                    dst_cs as isize,
                    // dst_rs: isize,
                    dst_rs as isize,
                    // read_dst: bool,
                    false,
                    // lhs: *const T,
                    lhs_p.as_ptr(),
                    // lhs_cs: isize,
                    lhs_cs as isize,
                    // lhs_rs: isize,
                    lhs_rs as isize,
                    // rhs: *const T,
                    rhs_p.as_ptr(),
                    // rhs_cs: isize,
                    rhs_cs as isize,
                    // rhs_rs: isize,
                    rhs_rs as isize,
                    // alpha: T,
                    1.0,
                    // beta: T,
                    1.0,
                    // conj_dst: bool,
                    false,
                    // conj_lhs: bool,
                    false,
                    // conj_rhs: bool,
                    true,
                    // parallelism: Parallelism
                    Parallelism::None,
                )
            }
        }
        let c = Self::F32(dst);
        Ok(c)
    }

    pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
        let elem_count = shape.elem_count();
        match dtype {
            DType::U32 => {
                let data = vec![1u32; elem_count];
                Self::U32(data)
            }
            DType::F32 => {
                let data = vec![1f32; elem_count];
                Self::F32(data)
            }
            DType::F64 => {
                let data = vec![1f64; elem_count];
                Self::F64(data)
            }
        }
    }

    pub(crate) fn zeros_impl(shape: &Shape, dtype: DType) -> Self {
        let elem_count = shape.elem_count();
        match dtype {
            DType::U32 => {
                let data = vec![0u32; elem_count];
                Self::U32(data)
            }
            DType::F32 => {
                let data = vec![0f32; elem_count];
                Self::F32(data)
            }
            DType::F64 => {
                let data = vec![0f64; elem_count];
                Self::F64(data)
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Device, Tensor};

    #[test]
    fn simple_matmul() -> Result<()> {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let a = Tensor::from_slice(&data, (2, 2), &Device::Cpu)?;
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let b = Tensor::from_slice(&data, (2, 2), &Device::Cpu)?;

        let c = a.matmul(&b)?;
        assert_eq!(c.to_vec2::<f32>()?, &[&[7.0f32, 10.0], &[15.0, 22.0]]);

        let data = vec![1.0f32, 2.0];
        let a = Tensor::from_slice(&data, (2, 1), &Device::Cpu)?;
        let data = vec![3.0f32, 4.0];
        let b = Tensor::from_slice(&data, (1, 2), &Device::Cpu)?;
        let c = a.matmul(&b)?;
        assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);

        let data: Vec<_> = (0..6).map(|i| i as f32).collect();
        let a = Tensor::from_slice(&data, (2, 3), &Device::Cpu)?;
        let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
        let b = Tensor::from_slice(&data, (3, 2), &Device::Cpu)?;
        let c = a.matmul(&b)?;
        assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);

        let data: Vec<_> = (0..12).map(|i| i as f32).collect();
        let a = Tensor::from_slice(&data, (2, 2, 3), &Device::Cpu)?;
        let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
        let b = Tensor::from_slice(&data, (2, 3, 2), &Device::Cpu)?;
        let c = a.matmul(&b)?;
        assert_eq!(
            c.to_vec3::<f32>()?,
            &[&[&[16., 19.], &[52., 64.]], &[&[214., 235.], &[304., 334.]]]
        );
        Ok(())
    }
}
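
// A small additional sanity check for the constructors above. This is only a
// sketch: it assumes `Shape` converts from tuples (as `matmul_impl` relies on
// via `(m, n).into()`) and uses `as_slice` to round-trip each dtype.
#[cfg(test)]
mod storage_tests {
    use super::*;

    #[test]
    fn ones_and_zeros() -> Result<()> {
        let shape: Shape = (2, 3).into();
        // `ones_impl` should fill all 6 elements with the dtype's one value.
        let ones = CpuStorage::ones_impl(&shape, DType::F32);
        assert_eq!(ones.as_slice::<f32>()?, [1f32; 6].as_slice());
        // Same for `zeros_impl` with a u32 storage.
        let zeros = CpuStorage::zeros_impl(&shape, DType::U32);
        assert_eq!(zeros.as_slice::<u32>()?, [0u32; 6].as_slice());
        Ok(())
    }
}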