use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Result, Shape};

// We do not want to implement Clone on Storage as cloning may fail because of
// out of memory. Instead try_clone should be used.
#[derive(Debug)]
pub enum Storage {
    Cpu(CpuStorage),
    Cuda(CudaStorage),
}

impl Storage {
    pub fn try_clone(&self) -> Result<Self> {
        match self {
            // Cloning a CPU storage cannot fail; only the CUDA path can run
            // out of device memory, hence the Result.
            Self::Cpu(storage) => Ok(Self::Cpu(storage.clone())),
            Self::Cuda(storage) => {
                let storage = storage.try_clone()?;
                Ok(Self::Cuda(storage))
            }
        }
    }

    pub fn device(&self) -> Device {
        match self {
            Self::Cpu(_) => Device::Cpu,
            Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
        }
    }

    pub fn dtype(&self) -> DType {
        match self {
            Self::Cpu(storage) => storage.dtype(),
            Self::Cuda(storage) => storage.dtype(),
        }
    }

    // Returns an error if the two storages do not live on the same device.
    pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
        let lhs = self.device().location();
        let rhs = rhs.device().location();
        if lhs != rhs {
            Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op })
        } else {
            Ok(())
        }
    }

    // Returns an error if the two storages do not have the same dtype.
    pub(crate) fn same_dtype(&self, rhs: &Self, op: &'static str) -> Result<()> {
        let lhs = self.dtype();
        let rhs = rhs.dtype();
        if lhs != rhs {
            Err(Error::DTypeMismatchBinaryOp { lhs, rhs, op })
        } else {
            Ok(())
        }
    }

    pub(crate) fn affine_impl(
        &self,
        shape: &Shape,
        stride: &[usize],
        mul: f64,
        add: f64,
    ) -> Result<Self> {
        match self {
            Storage::Cpu(storage) => {
                let storage = storage.affine_impl(shape, stride, mul, add)?;
                Ok(Self::Cpu(storage))
            }
            Self::Cuda(storage) => {
                let storage = storage.affine_impl(shape, stride, mul, add)?;
                Ok(Self::Cuda(storage))
            }
        }
    }

    pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], s: &[usize]) -> Result<Self> {
        match self {
            Storage::Cpu(storage) => {
                let storage = storage.sum(shape, stride, s)?;
                Ok(Self::Cpu(storage))
            }
            Self::Cuda(storage) => {
                let storage = storage.sum(shape, stride, s)?;
                Ok(Self::Cuda(storage))
            }
        }
    }

    pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
        match self {
            Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
            Self::Cuda(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
        }
        Ok(())
    }

    pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
        match self {
            Storage::Cpu(storage) => {
                let storage = storage.to_dtype(shape, stride, dtype)?;
                Ok(Self::Cpu(storage))
            }
            Self::Cuda(storage) => {
                let storage = storage.to_dtype(shape, stride, dtype)?;
                Ok(Self::Cuda(storage))
            }
        }
    }

    pub(crate) fn unary_impl<B: op::UnaryOp>(
        &self,
        shape: &Shape,
        stride: &[usize],
    ) -> Result<Self> {
        // TODO: Different code path for the contiguous case?
        match self {
            Storage::Cpu(storage) => {
                let storage = storage.unary_impl::<B>(shape, stride)?;
                Ok(Self::Cpu(storage))
            }
            Self::Cuda(storage) => {
                let storage = storage.unary_impl::<B>(shape, stride)?;
                Ok(Self::Cuda(storage))
            }
        }
    }

    pub(crate) fn binary_impl<B: op::BinaryOp>(
        &self,
        rhs: &Self,
        shape: &Shape,
        lhs_stride: &[usize],
        rhs_stride: &[usize],
    ) -> Result<Self> {
        self.same_device(rhs, B::NAME)?;
        self.same_dtype(rhs, B::NAME)?;
        match (self, rhs) {
            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
                let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?;
                Ok(Self::Cpu(storage))
            }
            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
                let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?;
                Ok(Self::Cuda(storage))
            }
            (lhs, rhs) => {
                // Should not happen because of the same device check above but we're defensive
                // anyway.
                Err(Error::DeviceMismatchBinaryOp {
                    lhs: lhs.device().location(),
                    rhs: rhs.device().location(),
                    op: B::NAME,
                })
            }
        }
    }

    pub(crate) fn where_cond(
        &self,
        shape: &Shape,
        stride: &[usize],
        t: &Self,
        stride_t: &[usize],
        f: &Self,
        stride_f: &[usize],
    ) -> Result<Self> {
        self.same_device(t, "where")?;
        self.same_device(f, "where")?;
        t.same_dtype(f, "where")?;
        match (self, t, f) {
            (Storage::Cpu(cond), Storage::Cpu(t), Storage::Cpu(f)) => {
                let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
                Ok(Self::Cpu(storage))
            }
            (Self::Cuda(cond), Self::Cuda(t), Self::Cuda(f)) => {
                let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
                Ok(Self::Cuda(storage))
            }
            (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),
                rhs: rhs.device().location(),
                op: "where",
            }),
        }
    }

    pub(crate) fn embedding_impl(
        &self,
        shape: &Shape,
        stride: &[usize],
        rhs: &Self,
        hidden_size: usize,
        vocab_size: usize,
    ) -> Result<Self> {
        self.same_device(rhs, "embedding")?;
        match (self, rhs) {
            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
                let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?;
                Ok(Self::Cpu(storage))
            }
            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
                let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?;
                Ok(Self::Cuda(storage))
            }
            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),
                rhs: rhs.device().location(),
                op: "embedding",
            }),
        }
    }

    pub(crate) fn matmul_impl(
        &self,
        rhs: &Self,
        bmnk: (usize, usize, usize, usize),
        lhs_stride: &[usize],
        rhs_stride: &[usize],
    ) -> Result<Self> {
        self.same_device(rhs, "matmul")?;
        self.same_dtype(rhs, "matmul")?;
        match (self, rhs) {
            (Self::Cpu(lhs), Self::Cpu(rhs)) => {
                let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?;
                Ok(Self::Cpu(storage))
            }
            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
                let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?;
                Ok(Self::Cuda(storage))
            }
            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),
                rhs: rhs.device().location(),
                op: "matmul",
            }),
        }
    }

    // self, the source, can be strided whereas dst is contiguous.
    pub(crate) fn copy_strided_src(
        &self,
        dst: &mut Self,
        dst_offset: usize,
        src_shape: &Shape,
        src_stride: &[usize],
        src_offset: usize,
    ) -> Result<()> {
        match (self, dst) {
            (Self::Cpu(src), Self::Cpu(dst)) => {
                src.copy_strided_src(dst, dst_offset, src_shape, src_stride, src_offset)
            }
            (Self::Cuda(src), Self::Cuda(dst)) => {
                Ok(src.copy_strided_src(dst, dst_offset, src_shape, src_stride, src_offset)?)
            }
            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),
                rhs: rhs.device().location(),
                op: "copy",
            }),
        }
    }
}
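
// A minimal usage sketch, not part of the original file: it exercises the
// try_clone / same_device / same_dtype contract described in the comment at
// the top of this module. It assumes `CpuStorage` exposes an `F32(Vec<f32>)`
// variant and that `DType` implements `Debug` and `PartialEq`; adjust to the
// actual definitions if they differ.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn try_clone_preserves_device_and_dtype() {
        // Hypothetical constructor: assumes an F32 variant on CpuStorage.
        let storage = Storage::Cpu(CpuStorage::F32(vec![1.0, 2.0, 3.0]));
        // Cloning a CPU storage cannot fail; only the CUDA path can hit an
        // out-of-memory error, which is why try_clone returns a Result.
        let cloned = storage.try_clone().unwrap();
        assert_eq!(cloned.dtype(), storage.dtype());
        assert!(storage.same_device(&cloned, "try_clone").is_ok());
        assert!(storage.same_dtype(&cloned, "try_clone").is_ok());
    }
}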