diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs index bf372589..ca2848cf 100644 --- a/src/cpu_backend.rs +++ b/src/cpu_backend.rs @@ -8,6 +8,7 @@ use gemm::{gemm, Parallelism}; // intercept the oom errors to avoid panicking and provide a proper error. #[derive(Debug, Clone)] pub enum CpuStorage { + U32(Vec), F32(Vec), F64(Vec), } @@ -15,6 +16,7 @@ pub enum CpuStorage { impl CpuStorage { pub fn dtype(&self) -> DType { match self { + Self::U32(_) => DType::U32, Self::F32(_) => DType::F32, Self::F64(_) => DType::F64, } @@ -36,6 +38,13 @@ impl CpuStorage { add: f64, ) -> Result { match self { + Self::U32(storage) => { + let index = StridedIndex::new(shape.dims(), stride); + let mul = mul as u32; + let add = add as u32; + let data = index.map(|i| storage[i] * mul + add).collect(); + Ok(Self::U32(data)) + } Self::F32(storage) => { let index = StridedIndex::new(shape.dims(), stride); let mul = mul as f32; @@ -64,6 +73,9 @@ impl CpuStorage { let data = index.map(|i| B::f64(storage[i])).collect(); Ok(Self::F64(data)) } + Self::U32(_storage) => { + todo!("No unary for u32 because of neg, sqrt") + } } } @@ -138,6 +150,57 @@ impl CpuStorage { } } Ok(()) + + pub(crate) fn embedding_impl( + &self, + rhs: &Self, + hidden_size: usize, + vocab_size: usize, + ) -> Result { + match self { + CpuStorage::U32(lhs) => match rhs { + CpuStorage::F32(rhs) => { + let mut weights = Vec::with_capacity(lhs.len() * hidden_size); + for &index in lhs { + let index: usize = index.try_into()?; + if index >= vocab_size { + return Err(Error::InvalidIndex { + index, + vocab_size, + op: "embedding", + }); + } else { + weights.extend(&rhs[hidden_size * index..hidden_size * (index + 1)]); + } + } + Ok(CpuStorage::F32(weights)) + } + CpuStorage::F64(rhs) => { + let mut weights = Vec::with_capacity(lhs.len() * hidden_size); + for &index in lhs { + let index: usize = index.try_into()?; + if index >= vocab_size { + return Err(Error::InvalidIndex { + index, + vocab_size, + op: "embedding", + }); + } else { + weights.extend(&rhs[hidden_size * index..hidden_size * (index + 1)]); + } + } + Ok(CpuStorage::F64(weights)) + } + rhs => Err(Error::UnexpectedDType { + expected: DType::F32, + got: rhs.dtype(), + }), + }, + lhs => Err(Error::UnexpectedDType { + expected: DType::U32, + got: lhs.dtype(), + }), + } } pub(crate) fn matmul_impl( @@ -230,6 +293,10 @@ impl CpuStorage { pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self { let elem_count = shape.elem_count(); match dtype { + DType::U32 => { + let data = vec![1u32; elem_count]; + Self::U32(data) + } DType::F32 => { let data = vec![1f32; elem_count]; Self::F32(data) @@ -244,6 +311,10 @@ impl CpuStorage { pub(crate) fn zeros_impl(shape: &Shape, dtype: DType) -> Self { let elem_count = shape.elem_count(); match dtype { + DType::U32 => { + let data = vec![0u32; elem_count]; + Self::U32(data) + } DType::F32 => { let data = vec![0f32; elem_count]; Self::F32(data) diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs index edbb7700..50b8c7ff 100644 --- a/src/cuda_backend.rs +++ b/src/cuda_backend.rs @@ -346,6 +346,15 @@ impl CudaStorage { } } + pub(crate) fn embedding_impl( + &self, + rhs: &Self, + hidden_size: usize, + vocab_size: usize, + ) -> Result { + todo!("Implement embedding for gpu"); + } + pub(crate) fn matmul_impl( &self, rhs: &Self, diff --git a/src/dtype.rs b/src/dtype.rs index 1b348175..a11ab350 100644 --- a/src/dtype.rs +++ b/src/dtype.rs @@ -2,6 +2,7 @@ use crate::{CpuStorage, Error, Result}; #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum DType { + U32, F32, F64, } @@ -9,6 +10,7 @@ pub enum DType { impl DType { pub fn size_in_bytes(&self) -> usize { match self { + Self::U32 => 4, Self::F32 => 4, Self::F64 => 8, } @@ -70,5 +72,6 @@ macro_rules! with_dtype { } }; } +with_dtype!(u32, U32); with_dtype!(f32, F32); with_dtype!(f64, F64); diff --git a/src/dummy_cuda_backend.rs b/src/dummy_cuda_backend.rs index a12bafe3..5a98b015 100644 --- a/src/dummy_cuda_backend.rs +++ b/src/dummy_cuda_backend.rs @@ -76,6 +76,10 @@ impl CudaStorage { Err(Error::NotCompiledWithCudaSupport) } + pub(crate) fn embedding_impl(&self, _: &Self, _: usize, _: usize) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + pub(crate) fn matmul_impl( &self, _: &Self, diff --git a/src/error.rs b/src/error.rs index c8c338ea..8a75cd38 100644 --- a/src/error.rs +++ b/src/error.rs @@ -15,6 +15,13 @@ pub enum Error { #[error("backward is not supported for {op}")] BackwardNotSupported { op: &'static str }, + #[error("{op} invalid index {index} with vocab {vocab_size}")] + InvalidIndex { + op: &'static str, + index: usize, + vocab_size: usize, + }, + #[error("the candle crate has not been built with cuda support")] NotCompiledWithCudaSupport, @@ -65,6 +72,9 @@ pub enum Error { #[error(transparent)] Cuda(#[from] crate::CudaError), + + #[error(transparent)] + TryFromIntError(#[from] core::num::TryFromIntError), } pub type Result = std::result::Result; diff --git a/src/op.rs b/src/op.rs index 45fe97a4..e495d60d 100644 --- a/src/op.rs +++ b/src/op.rs @@ -7,6 +7,7 @@ pub(crate) enum Op { Sub(Tensor, Tensor), Div(Tensor, Tensor), Matmul(Tensor, Tensor), + Embedding(Tensor, Tensor), Cat(Vec, usize), diff --git a/src/storage.rs b/src/storage.rs index bb2aad4c..c4938fa3 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -122,6 +122,31 @@ impl Storage { } } + pub(crate) fn embedding_impl( + &self, + rhs: &Self, + hidden_size: usize, + vocab_size: usize, + ) -> Result { + self.same_device(rhs, "matmul")?; + self.same_dtype(rhs, "matmul")?; + match (self, rhs) { + (Storage::Cpu(lhs), Storage::Cpu(rhs)) => { + let storage = lhs.embedding_impl(rhs, hidden_size, vocab_size)?; + Ok(Self::Cpu(storage)) + } + (Self::Cuda(lhs), Self::Cuda(rhs)) => { + let storage = lhs.embedding_impl(rhs, hidden_size, vocab_size)?; + Ok(Self::Cuda(storage)) + } + (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { + lhs: lhs.device().location(), + rhs: rhs.device().location(), + op: "embedding", + }), + } + } + pub(crate) fn matmul_impl( &self, rhs: &Self, diff --git a/src/tensor.rs b/src/tensor.rs index 1e411dcc..65a90d7e 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -344,6 +344,33 @@ impl Tensor { Ok(Self(Arc::new(tensor_))) } + pub fn embedding(ids: &Self, rhs: &Self) -> Result { + if !rhs.is_contiguous() { + return Err(Error::RequiresContiguous { op: "embedding" }); + } else if rhs.shape().rank() != 2 || ids.shape().rank() != 1 { + return Err(Error::ShapeMismatchBinaryOp { + lhs: ids.shape.clone(), + rhs: rhs.shape.clone(), + op: "embedding", + }); + } + let seq_len = ids.shape().r1()?; + let (vocab_size, hidden_size) = rhs.shape().r2()?; + let storage = ids + .storage + .embedding_impl(&rhs.storage, hidden_size, vocab_size)?; + let shape: Shape = (seq_len, hidden_size).into(); + let tensor_ = Tensor_ { + id: TensorId::new(), + storage, + shape: shape.clone(), + stride: shape.stride_contiguous(), + op: Some(Op::Embedding(ids.clone(), rhs.clone())), + is_variable: false, + }; + Ok(Self(Arc::new(tensor_))) + } + pub(crate) fn strided_index(&self) -> crate::StridedIndex { crate::StridedIndex::new(self.dims(), self.stride()) } @@ -740,6 +767,7 @@ impl Tensor { | Op::Mul(lhs, rhs) | Op::Sub(lhs, rhs) | Op::Div(lhs, rhs) + | Op::Embedding(lhs, rhs) | Op::Matmul(lhs, rhs) => { let (tg, nodes) = walk(lhs, nodes, already_seen); track_grad |= tg; @@ -830,6 +858,9 @@ impl Tensor { let rhs_sum_grad = grads.or_insert(rhs)?; *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?; } + Op::Embedding(_lhs, _rhs) => { + todo!("Backward for embedding not implemented"); + } Op::Matmul(lhs, rhs) => { // Skipping checks, the op went ok, we can skip // the matmul size checks for now.