From bcae61b7f24ca37f7d76c2943371d8ce55559558 Mon Sep 17 00:00:00 2001
From: laurent
Date: Mon, 19 Jun 2023 21:30:03 +0100
Subject: [PATCH] Cosmetic changes.

---
 src/shape.rs   |  6 ++++++
 src/storage.rs |  2 ++
 src/tensor.rs  | 29 +++++++++++++++++++++++++----
 3 files changed, 33 insertions(+), 4 deletions(-)

diff --git a/src/shape.rs b/src/shape.rs
index d9112aa7..b1965872 100644
--- a/src/shape.rs
+++ b/src/shape.rs
@@ -1,6 +1,12 @@
 use crate::{Error, Result};
 
 pub struct Shape(pub(crate) Vec<usize>);
+impl std::fmt::Debug for Shape {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{:?}", &self.dims())
+    }
+}
+
 impl From<&[usize; 1]> for Shape {
     fn from(dims: &[usize; 1]) -> Self {
         Self(dims.to_vec())
diff --git a/src/storage.rs b/src/storage.rs
index 10502d43..bcd65ba3 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -2,6 +2,7 @@ use crate::{DType, Device};
 
 // TODO: Think about whether we would be better off with a dtype and
 // a buffer as an owned slice of bytes.
+#[derive(Debug, Clone)]
 pub enum CpuStorage {
     F32(Vec<f32>),
     F64(Vec<f64>),
@@ -17,6 +18,7 @@ impl CpuStorage {
 }
 
 #[allow(dead_code)]
+#[derive(Debug, Clone)]
 pub enum Storage {
     Cpu(CpuStorage),
 }
diff --git a/src/tensor.rs b/src/tensor.rs
index 64438ce6..37f010da 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -1,18 +1,24 @@
-use crate::{op::Op, shape, storage::Storage, DType, Device, Result};
+use crate::{op::Op, storage::Storage, DType, Device, Result, Shape};
 use std::sync::Arc;
 
 #[allow(dead_code)]
 pub(crate) struct Tensor_ {
     storage: Storage,
-    shape: shape::Shape,
+    shape: Shape,
     stride: Vec<usize>,
     op: Option<Op>,
 }
 
 pub struct Tensor(Arc<Tensor_>);
 
+impl std::fmt::Debug for Tensor {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "[{:?}, {:?}]", &self.shape().dims(), self.device())
+    }
+}
+
 impl Tensor {
-    pub fn zeros<S: Into<shape::Shape>>(shape: S, dtype: DType, device: Device) -> Self {
+    pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
         let shape = shape.into();
         let storage = device.zeros(&shape, dtype);
         let rank = shape.rank();
@@ -38,6 +44,21 @@ impl Tensor {
         Ok(Self(Arc::new(tensor_)))
     }
 
+    pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
+        // TODO: properly use the strides here.
+        todo!()
+    }
+
+    pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
+        // TODO: properly use the strides here.
+        todo!()
+    }
+
+    pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> {
+        // TODO: properly use the strides here.
+        todo!()
+    }
+
     pub fn dtype(&self) -> DType {
         self.0.storage.dtype()
     }
@@ -46,7 +67,7 @@ impl Tensor {
         self.0.storage.device()
     }
 
-    pub fn shape(&self) -> &shape::Shape {
+    pub fn shape(&self) -> &Shape {
         &self.0.shape
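
Usage sketch (not part of the patch above): a rough idea of what the new Debug impls print. It assumes the crate is consumed under the hypothetical name `candle`, and that `Device::Cpu` and `DType::F32` variants exist, mirroring the `Storage::Cpu` and `CpuStorage::F32` variants in the diff; the output shown in the comments is illustrative only.

    // Hypothetical example exercising the Debug impls added by this patch.
    use candle::{DType, Device, Tensor};

    fn main() {
        // `zeros` accepts anything convertible into a Shape; the diff shows a
        // From<&[usize; 1]> impl, so a one-element array works here.
        let t = Tensor::zeros(&[3usize], DType::F32, Device::Cpu);
        // Tensor's Debug impl writes the dims followed by the device,
        // so this should print something like: [[3], Cpu]
        println!("{:?}", t);
        // Shape's Debug impl prints just the dims, e.g. [3]
        println!("{:?}", t.shape());
    }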