diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs
index 494c41ee..1b618551 100644
--- a/src/cpu_backend.rs
+++ b/src/cpu_backend.rs
@@ -13,12 +13,12 @@ pub enum CpuStorage {
     F64(Vec<f64>),
 }
 
-fn unary_map<T: Copy, F: FnMut(T) -> T>(
+fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
     shape: &Shape,
     stride: &[usize],
     vs: &[T],
     mut f: F,
-) -> Vec<T> {
+) -> Vec<U> {
     if shape.is_contiguous(stride) {
         vs[..shape.elem_count()].iter().map(|&v| f(v)).collect()
     } else {
@@ -105,6 +105,48 @@ impl CpuStorage {
         D::cpu_storage_as_mut_slice(self)
     }
 
+    pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
+        // TODO: find a way around the quadratic number of cases below.
+        match (self, dtype) {
+            (Self::U32(storage), DType::F32) => {
+                let data = unary_map(shape, stride, storage, |v| v as f32);
+                Ok(Self::F32(data))
+            }
+            (Self::F32(storage), DType::F32) => {
+                let data = unary_map(shape, stride, storage, |v| v);
+                Ok(Self::F32(data))
+            }
+            (Self::F64(storage), DType::F32) => {
+                let data = unary_map(shape, stride, storage, |v| v as f32);
+                Ok(Self::F32(data))
+            }
+            (Self::U32(storage), DType::U32) => {
+                let data = unary_map(shape, stride, storage, |v| v);
+                Ok(Self::U32(data))
+            }
+            (Self::F32(storage), DType::U32) => {
+                let data = unary_map(shape, stride, storage, |v| v as u32);
+                Ok(Self::U32(data))
+            }
+            (Self::F64(storage), DType::U32) => {
+                let data = unary_map(shape, stride, storage, |v| v as u32);
+                Ok(Self::U32(data))
+            }
+            (Self::U32(storage), DType::F64) => {
+                let data = unary_map(shape, stride, storage, |v| v as f64);
+                Ok(Self::F64(data))
+            }
+            (Self::F32(storage), DType::F64) => {
+                let data = unary_map(shape, stride, storage, |v| v as f64);
+                Ok(Self::F64(data))
+            }
+            (Self::F64(storage), DType::F64) => {
+                let data = unary_map(shape, stride, storage, |v| v);
+                Ok(Self::F64(data))
+            }
+        }
+    }
+
     pub(crate) fn affine_impl(
         &self,
         shape: &Shape,
diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
index 426d4387..345b84d8 100644
--- a/src/cuda_backend.rs
+++ b/src/cuda_backend.rs
@@ -242,6 +242,10 @@ impl CudaStorage {
         &self.device
     }
 
+    pub(crate) fn to_dtype(&self, _: &Shape, _: &[usize], _: DType) -> Result<Self> {
+        Err(CudaError::InternalError("TODO: implement to_dtype"))
+    }
+
     pub(crate) fn affine_impl(
         &self,
         shape: &Shape,
@@ -400,7 +404,7 @@ impl CudaStorage {
         _hidden_size: usize,
         _vocab_size: usize,
     ) -> Result<Self> {
-        todo!("Implement embedding for gpu");
+        Err(CudaError::InternalError("TODO: implement embedding"))
     }
 
     pub(crate) fn matmul_impl(
diff --git a/src/dummy_cuda_backend.rs b/src/dummy_cuda_backend.rs
index 5a98b015..f2e0d36c 100644
--- a/src/dummy_cuda_backend.rs
+++ b/src/dummy_cuda_backend.rs
@@ -62,6 +62,10 @@ impl CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
+    pub(crate) fn to_dtype(&self, _: &Shape, _: &[usize], _: DType) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     pub(crate) fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Shape, _: &[usize]) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
diff --git a/src/op.rs b/src/op.rs
index c4283d34..7c2539f9 100644
--- a/src/op.rs
+++ b/src/op.rs
@@ -21,6 +21,7 @@ pub(crate) enum Op {
         mul: f64,
         add: f64,
     },
+    ToDType(Tensor),
     Exp(Tensor),
     Log(Tensor),
     Sin(Tensor),
diff --git a/src/storage.rs b/src/storage.rs
index 2a590d4e..ebbdcbb2 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -60,7 +60,6 @@ impl Storage {
         mul: f64,
         add: f64,
     ) -> Result<Self> {
-        // TODO: Different code path for the contiguous case?
         match self {
             Storage::Cpu(storage) => {
                 let storage = storage.affine_impl(shape, stride, mul, add)?;
@@ -73,6 +72,19 @@ impl Storage {
         }
     }
 
+    pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
+        match self {
+            Storage::Cpu(storage) => {
+                let storage = storage.to_dtype(shape, stride, dtype)?;
+                Ok(Self::Cpu(storage))
+            }
+            Self::Cuda(storage) => {
+                let storage = storage.to_dtype(shape, stride, dtype)?;
+                Ok(Self::Cuda(storage))
+            }
+        }
+    }
+
     pub(crate) fn unary_impl<B: op::UnaryOp>(
         &self,
         shape: &Shape,
diff --git a/src/tensor.rs b/src/tensor.rs
index c1ebaae0..264df5f6 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -602,6 +602,17 @@ impl Tensor {
         }
     }
 
+    pub fn to_dtype(&self, dtype: DType) -> Result<Self> {
+        let shape = self.shape();
+        let storage = self.storage.to_dtype(shape, self.stride(), dtype)?;
+        let op = if self.track_op() {
+            Some(Op::ToDType(self.clone()))
+        } else {
+            None
+        };
+        Ok(from_storage(storage, shape.clone(), op, false))
+    }
+
     pub fn contiguous(&self) -> Result<Self> {
         if self.is_contiguous() {
             Ok(self.clone())
@@ -773,6 +784,7 @@ impl Tensor {
                 }
             }
             Op::Reshape(node)
+            | Op::ToDType(node)
             | Op::ToDevice(node)
             | Op::Transpose(node, _, _)
             | Op::Softmax(node, _)
@@ -892,6 +904,10 @@ impl Tensor {
                         *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                     }
                     Op::Cat(_args, _dim) => return Err(Error::BackwardNotSupported { op: "cat" }),
+                    Op::ToDType(arg) => {
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)?
+                    }
                     Op::Affine { arg, mul, .. } => {
                         let arg_grad = grad.affine(*mul, 0.)?;
                         let sum_grad = grads.or_insert(arg)?;
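
A note on the `unary_map` change for reviewers: the hunk only shows the contiguous fast path, and the strided branch it falls back to is elided. The snippet below is a self-contained approximation of that kind of strided traversal, using plain slices instead of the crate's `Shape` type; the index-increment scheme is an assumption for illustration, not the crate's actual implementation.

```rust
// Apply `f` to every logical element of a strided view, in row-major order.
fn strided_unary_map<T: Copy, U, F: FnMut(T) -> U>(
    dims: &[usize],
    stride: &[usize],
    vs: &[T],
    mut f: F,
) -> Vec<U> {
    let elem_count: usize = dims.iter().product();
    let mut out = Vec::with_capacity(elem_count);
    let mut index = vec![0usize; dims.len()];
    for _ in 0..elem_count {
        // Translate the multi-index into a storage offset via the strides.
        let offset: usize = index.iter().zip(stride).map(|(i, s)| i * s).sum();
        out.push(f(vs[offset]));
        // Increment the multi-index, rightmost dimension first.
        for d in (0..dims.len()).rev() {
            index[d] += 1;
            if index[d] < dims[d] {
                break;
            }
            index[d] = 0;
        }
    }
    out
}

fn main() {
    // A 2x2 transposed view of [[1, 2], [3, 4]]: dims (2, 2), strides (1, 2).
    let vs = [1u32, 2, 3, 4];
    let cast: Vec<f32> = strided_unary_map(&[2, 2], &[1, 2], &vs, |v| v as f32);
    assert_eq!(cast, vec![1.0, 3.0, 2.0, 4.0]);
}
```

The `is_contiguous` check in the diff exists precisely to skip this per-element offset arithmetic when the layout is already plain row-major.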
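
On the `// TODO: find a way around the quadratic number of cases` in `CpuStorage::to_dtype`: since `as` casts cover every (u32, f32, f64) pair, one option is a small macro so each (source, target) pair costs one line instead of four. The sketch below is standalone, with stand-in `Storage`/`DType` types and a contiguous-only map rather than the crate's; it shrinks the arms but does not remove the quadratic arm count, which would need a macro expanding over both axes.

```rust
#[derive(Debug)]
enum Storage {
    U32(Vec<u32>),
    F32(Vec<f32>),
    F64(Vec<f64>),
}

#[derive(Clone, Copy)]
enum DType {
    U32,
    F32,
    F64,
}

// One cast arm: map every element with `as` and wrap it in the target variant.
macro_rules! cast {
    ($vs:expr, $variant:ident, $ty:ty) => {
        Storage::$variant($vs.iter().map(|&v| v as $ty).collect())
    };
}

impl Storage {
    fn to_dtype(&self, dtype: DType) -> Storage {
        match (self, dtype) {
            (Storage::U32(v), DType::U32) => cast!(v, U32, u32),
            (Storage::U32(v), DType::F32) => cast!(v, F32, f32),
            (Storage::U32(v), DType::F64) => cast!(v, F64, f64),
            (Storage::F32(v), DType::U32) => cast!(v, U32, u32),
            (Storage::F32(v), DType::F32) => cast!(v, F32, f32),
            (Storage::F32(v), DType::F64) => cast!(v, F64, f64),
            (Storage::F64(v), DType::U32) => cast!(v, U32, u32),
            (Storage::F64(v), DType::F32) => cast!(v, F32, f32),
            (Storage::F64(v), DType::F64) => cast!(v, F64, f64),
        }
    }
}

fn main() {
    let s = Storage::F64(vec![1.5, 2.5]);
    println!("{:?}", s.to_dtype(DType::U32)); // prints: U32([1, 2])
}
```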
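
Finally, a minimal end-to-end usage sketch of the new API. Only `to_dtype`, `dtype`, and `Op::ToDType` come from this diff; `Tensor::var`, `Device::Cpu`, `backward`, and the `get` accessor on the returned gradient store are assumed from the surrounding crate and may differ at this revision.

```rust
use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    // `var` (assumed constructor) creates a tensor whose ops are tracked.
    let x = Tensor::var(&[1f32, 2.0, 3.0], &Device::Cpu)?;

    // New in this diff: cast the storage to another dtype. The CPU path goes
    // through CpuStorage::to_dtype; the CUDA path currently returns
    // InternalError until the kernel is implemented.
    let y = x.to_dtype(DType::F64)?;
    assert_eq!(y.dtype(), DType::F64);

    // Backward through the cast: the Op::ToDType arm converts the incoming
    // f64 gradient back to f32 before accumulating it into x's gradient.
    let grads = y.backward()?; // assumed to seed a gradient of ones
    let dx = grads.get(&x).expect("no grad for x"); // accessor name assumed
    assert_eq!(dx.dtype(), DType::F32);
    Ok(())
}
```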