diff --git a/src/error.rs b/src/error.rs index 7416ed76..0114a86c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,4 +1,4 @@ -use crate::{DType, Shape}; +use crate::{DType, Device, Shape}; /// Main library error type. #[derive(thiserror::Error, Debug)] @@ -13,6 +13,20 @@ pub enum Error { op: &'static str, }, + #[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")] + DeviceMismatchBinaryOp { + lhs: Device, + rhs: Device, + op: &'static str, + }, + + #[error("dtype mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")] + DTypeMismatchBinaryOp { + lhs: DType, + rhs: DType, + op: &'static str, + }, + #[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")] UnexpectedNumberOfDims { expected: usize, diff --git a/src/shape.rs b/src/shape.rs index 4b186ca0..680460e8 100644 --- a/src/shape.rs +++ b/src/shape.rs @@ -144,6 +144,7 @@ impl Shape { pub(crate) fn stride_contiguous(&self) -> Vec<usize> { self.0 .iter() + .rev() .scan(1, |prod, u| { let prod_pre_mult = *prod; *prod *= u; diff --git a/src/storage.rs b/src/storage.rs index bcd65ba3..859b3f76 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -1,4 +1,4 @@ -use crate::{DType, Device}; +use crate::{DType, Device, Error, Result, Shape}; // TODO: Think about whether we would be better off with a dtype and // a buffer as an owned slice of bytes. 
@@ -35,4 +35,75 @@ impl Storage { Self::Cpu(storage) => storage.dtype(), } } + + pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> { + let lhs = self.device(); + let rhs = rhs.device(); + if lhs != rhs { + Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }) + } else { + Ok(()) + } + } + + pub(crate) fn same_dtype(&self, rhs: &Self, op: &'static str) -> Result<()> { + let lhs = self.dtype(); + let rhs = rhs.dtype(); + if lhs != rhs { + Err(Error::DTypeMismatchBinaryOp { lhs, rhs, op }) + } else { + Ok(()) + } + } + + pub(crate) fn add_impl( + &self, + rhs: &Self, + shape: &Shape, + _lhs_stride: &[usize], + _rhs_stride: &[usize], + ) -> Result<Self> { + self.same_device(rhs, "add")?; + self.same_dtype(rhs, "add")?; + // The ggml implementation has different paths based on whether the rhs is contiguous + // or not, for now we only consider the general case but we should benchmark and do the + // same if it helps. + // https://github.com/ggerganov/llama.cpp/blob/aacdbd40562684665b6f7b8ba6695b7a2088bbb0/ggml.c#L7895 + match (self, rhs) { + (Storage::Cpu(lhs), Storage::Cpu(rhs)) => match (lhs, rhs) { + (CpuStorage::F32(_), CpuStorage::F32(_)) => { + let elem_count = shape.elem_count(); + let data = vec![0f32; elem_count]; + // TODO: properly fill data with the sum + Ok(Storage::Cpu(CpuStorage::F32(data))) + } + (CpuStorage::F64(_), CpuStorage::F64(_)) => { + let elem_count = shape.elem_count(); + let data = vec![0f64; elem_count]; + // TODO: properly fill data with the sum + Ok(Storage::Cpu(CpuStorage::F64(data))) + } + _ => { + // This should be covered by the dtype check above. 
+ Err(Error::DTypeMismatchBinaryOp { + lhs: lhs.dtype(), + rhs: rhs.dtype(), + op: "add", + }) + } + }, + } + } + + pub(crate) fn mul_impl( + &self, + rhs: &Self, + _shape: &Shape, + _lhs_stride: &[usize], + _rhs_stride: &[usize], + ) -> Result<Self> { + self.same_device(rhs, "mul")?; + self.same_dtype(rhs, "mul")?; + todo!() + } } diff --git a/src/tensor.rs b/src/tensor.rs index 83aa00d7..8881cad0 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -2,7 +2,7 @@ use crate::{op::Op, storage::Storage, DType, Device, Error, Result, Shape}; use std::sync::Arc; #[allow(dead_code)] -pub(crate) struct Tensor_ { +pub struct Tensor_ { storage: Storage, shape: Shape, // The strides are given in number of elements and not in bytes. @@ -10,8 +10,17 @@ op: Option<Op>, } +#[derive(Clone)] pub struct Tensor(Arc<Tensor_>); +impl std::ops::Deref for Tensor { + type Target = Tensor_; + + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } +} + impl std::fmt::Debug for Tensor { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "[{:?}, {:?}]", &self.shape().dims(), self.device()) @@ -45,7 +54,7 @@ impl Tensor { Ok(Self(Arc::new(tensor_))) } - pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<()> { + pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> { let lhs = self.shape(); let rhs = rhs.shape(); if lhs != rhs { @@ -55,18 +64,38 @@ impl Tensor { op, }) } else { - Ok(()) + Ok(lhs) } } + // TODO: Also make an inplace version or a pre-allocated? This could be tricky + // if this can create cycles in the compute graph. 
pub fn add(&self, rhs: &Self) -> Result<Self> { - self.same_shape_binary_op(rhs, "add")?; - todo!() + let shape = self.same_shape_binary_op(rhs, "add")?; + let storage = self + .storage + .add_impl(&rhs.storage, shape, self.stride(), rhs.stride())?; + let tensor_ = Tensor_ { + storage, + shape: shape.clone(), + stride: shape.stride_contiguous(), + op: Some(Op::Add(self.clone(), rhs.clone())), + }; + Ok(Self(Arc::new(tensor_))) } pub fn mul(&self, rhs: &Self) -> Result<Self> { - self.same_shape_binary_op(rhs, "mul")?; - todo!() + let shape = self.same_shape_binary_op(rhs, "mul")?; + let storage = self + .storage + .mul_impl(&rhs.storage, shape, self.stride(), rhs.stride())?; + let tensor_ = Tensor_ { + storage, + shape: shape.clone(), + stride: shape.stride_contiguous(), + op: Some(Op::Mul(self.clone(), rhs.clone())), + }; + Ok(Self(Arc::new(tensor_))) } pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> { @@ -77,7 +106,7 @@ impl Tensor { shape: self.shape().clone(), }); } - match &self.0.storage { + match &self.storage { Storage::Cpu(cpu_storage) => { let data = S::cpu_storage_as_slice(cpu_storage)?; Ok(data[0]) @@ -96,15 +125,15 @@ } pub fn dtype(&self) -> DType { - self.0.storage.dtype() + self.storage.dtype() } pub fn device(&self) -> Device { - self.0.storage.device() + self.storage.device() } pub fn shape(&self) -> &Shape { - &self.0.shape + &self.shape } pub fn dims(&self) -> &[usize] { @@ -112,7 +141,7 @@ } pub fn stride(&self) -> &[usize] { - &self.0.stride + &self.stride } pub fn rank(&self) -> usize {