Add the casting operation.

This commit is contained in:
laurent
2023-06-23 21:22:07 +01:00
parent 8ed350dc94
commit 5d44e76e3f
6 changed files with 83 additions and 4 deletions

View File

@ -13,12 +13,12 @@ pub enum CpuStorage {
F64(Vec<f64>),
}
fn unary_map<T: Copy, F: FnMut(T) -> T>(
fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
shape: &Shape,
stride: &[usize],
vs: &[T],
mut f: F,
) -> Vec<T> {
) -> Vec<U> {
if shape.is_contiguous(stride) {
vs[..shape.elem_count()].iter().map(|&v| f(v)).collect()
} else {
@ -105,6 +105,48 @@ impl CpuStorage {
D::cpu_storage_as_mut_slice(self)
}
pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
// TODO: find a way around the quadratic number of cases below.
match (self, dtype) {
(Self::U32(storage), DType::F32) => {
let data = unary_map(shape, stride, storage, |v| v as f32);
Ok(Self::F32(data))
}
(Self::F32(storage), DType::F32) => {
let data = unary_map(shape, stride, storage, |v| v);
Ok(Self::F32(data))
}
(Self::F64(storage), DType::F32) => {
let data = unary_map(shape, stride, storage, |v| v as f32);
Ok(Self::F32(data))
}
(Self::U32(storage), DType::U32) => {
let data = unary_map(shape, stride, storage, |v| v);
Ok(Self::U32(data))
}
(Self::F32(storage), DType::U32) => {
let data = unary_map(shape, stride, storage, |v| v as u32);
Ok(Self::U32(data))
}
(Self::F64(storage), DType::U32) => {
let data = unary_map(shape, stride, storage, |v| v as u32);
Ok(Self::U32(data))
}
(Self::U32(storage), DType::F64) => {
let data = unary_map(shape, stride, storage, |v| v as f64);
Ok(Self::F64(data))
}
(Self::F32(storage), DType::F64) => {
let data = unary_map(shape, stride, storage, |v| v as f64);
Ok(Self::F64(data))
}
(Self::F64(storage), DType::F64) => {
let data = unary_map(shape, stride, storage, |v| v);
Ok(Self::F64(data))
}
}
}
pub(crate) fn affine_impl(
&self,
shape: &Shape,

View File

@ -242,6 +242,10 @@ impl CudaStorage {
&self.device
}
pub(crate) fn to_dtype(&self, _: &Shape, _: &[usize], _: DType) -> Result<Self> {
Err(CudaError::InternalError("TODO: implement embedding"))
}
pub(crate) fn affine_impl(
&self,
shape: &Shape,
@ -400,7 +404,7 @@ impl CudaStorage {
_hidden_size: usize,
_vocab_size: usize,
) -> Result<Self> {
todo!("Implement embedding for gpu");
Err(CudaError::InternalError("TODO: implement embedding"))
}
pub(crate) fn matmul_impl(

View File

@ -62,6 +62,10 @@ impl CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
pub(crate) fn to_dtype(&self, _: &Shape, _: &[usize], _: DType) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
pub(crate) fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Shape, _: &[usize]) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}

View File

@ -21,6 +21,7 @@ pub(crate) enum Op {
mul: f64,
add: f64,
},
ToDType(Tensor),
Exp(Tensor),
Log(Tensor),
Sin(Tensor),

View File

@ -60,7 +60,6 @@ impl Storage {
mul: f64,
add: f64,
) -> Result<Self> {
// TODO: Different code path for the contiguous case?
match self {
Storage::Cpu(storage) => {
let storage = storage.affine_impl(shape, stride, mul, add)?;
@ -73,6 +72,19 @@ impl Storage {
}
}
pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
let storage = storage.to_dtype(shape, stride, dtype)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
let storage = storage.to_dtype(shape, stride, dtype)?;
Ok(Self::Cuda(storage))
}
}
}
pub(crate) fn unary_impl<B: op::UnaryOp>(
&self,
shape: &Shape,

View File

@ -602,6 +602,17 @@ impl Tensor {
}
}
pub fn to_dtype(&self, dtype: DType) -> Result<Self> {
let shape = self.shape();
let storage = self.storage.to_dtype(shape, self.stride(), dtype)?;
let op = if self.track_op() {
Some(Op::ToDType(self.clone()))
} else {
None
};
Ok(from_storage(storage, shape.clone(), op, false))
}
pub fn contiguous(&self) -> Result<Tensor> {
if self.is_contiguous() {
Ok(self.clone())
@ -773,6 +784,7 @@ impl Tensor {
}
}
Op::Reshape(node)
| Op::ToDType(node)
| Op::ToDevice(node)
| Op::Transpose(node, _, _)
| Op::Softmax(node, _)
@ -892,6 +904,10 @@ impl Tensor {
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
}
Op::Cat(_args, _dim) => return Err(Error::BackwardNotSupported { op: "cat" }),
Op::ToDType(arg) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
}
Op::Affine { arg, mul, .. } => {
let arg_grad = grad.affine(*mul, 0.)?;
let sum_grad = grads.or_insert(arg)?;