Add an easy way to create tensor objects.

This commit is contained in:
laurent
2023-06-19 20:59:26 +01:00
parent 01eeb0e72f
commit 26d6288eb6
6 changed files with 77 additions and 14 deletions

View File

@ -1,6 +1,6 @@
use crate::{
storage::{CpuStorage, Storage},
DType,
DType, Result, Shape,
};
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
@ -8,11 +8,38 @@ pub enum Device {
Cpu,
}
// TODO: Should we back the cpu implementation using the NdArray crate or similar?
pub trait NdArray {
fn shape(&self) -> Result<Shape>;
fn to_cpu_storage(&self) -> CpuStorage;
}
impl<S: crate::WithDType> NdArray for S {
fn shape(&self) -> Result<Shape> {
Ok(Shape::from(()))
}
fn to_cpu_storage(&self) -> CpuStorage {
S::to_cpu_storage(&[*self])
}
}
impl<S: crate::WithDType> NdArray for &[S] {
fn shape(&self) -> Result<Shape> {
Ok(Shape::from(self.len()))
}
fn to_cpu_storage(&self) -> CpuStorage {
S::to_cpu_storage(self)
}
}
impl Device {
pub(crate) fn zeros(&self, shape: &[usize], dtype: DType) -> Storage {
pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Storage {
match self {
Device::Cpu => {
let elem_count: usize = shape.iter().product();
let elem_count = shape.elem_count();
let storage = match dtype {
DType::F32 => {
let data = vec![0f32; elem_count];
@ -27,4 +54,10 @@ impl Device {
}
}
}
pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Storage {
match self {
Device::Cpu => Storage::Cpu(array.to_cpu_storage()),
}
}
}

View File

@ -1,3 +1,5 @@
use crate::CpuStorage;
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DType {
F32,
@ -13,14 +15,24 @@ impl DType {
}
}
pub trait WithDType {
pub trait WithDType: Sized + Copy {
const DTYPE: DType;
fn to_cpu_storage(data: &[Self]) -> CpuStorage;
}
impl WithDType for f32 {
const DTYPE: DType = DType::F32;
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
CpuStorage::F32(data.to_vec())
}
}
impl WithDType for f64 {
const DTYPE: DType = DType::F64;
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
CpuStorage::F64(data.to_vec())
}
}

View File

@ -9,4 +9,6 @@ mod tensor;
pub use device::Device;
pub use dtype::{DType, WithDType};
pub use error::{Error, Result};
pub use shape::Shape;
pub use storage::{CpuStorage, Storage};
pub use tensor::Tensor;

View File

@ -2,7 +2,7 @@ use crate::{DType, Device};
// TODO: Think about whether we would be better off with a dtype and
// a buffer as an owned slice of bytes.
pub(crate) enum CpuStorage {
pub enum CpuStorage {
F32(Vec<f32>),
F64(Vec<f64>),
}
@ -17,18 +17,18 @@ impl CpuStorage {
}
#[allow(dead_code)]
pub(crate) enum Storage {
pub enum Storage {
Cpu(CpuStorage),
}
impl Storage {
pub(crate) fn device(&self) -> Device {
pub fn device(&self) -> Device {
match self {
Self::Cpu(_) => Device::Cpu,
}
}
pub(crate) fn dtype(&self) -> DType {
pub fn dtype(&self) -> DType {
match self {
Self::Cpu(storage) => storage.dtype(),
}

View File

@ -1,4 +1,4 @@
use crate::{op::Op, shape, storage::Storage, DType, Device};
use crate::{op::Op, shape, storage::Storage, DType, Device, Result};
use std::sync::Arc;
#[allow(dead_code)]
@ -14,15 +14,28 @@ pub struct Tensor(Arc<Tensor_>);
impl Tensor {
pub fn zeros<S: Into<shape::Shape>>(shape: S, dtype: DType, device: Device) -> Self {
let shape = shape.into();
let storage = device.zeros(&shape.0, dtype);
let rank = shape.0.len();
let storage = device.zeros(&shape, dtype);
let rank = shape.rank();
let tensor_ = Tensor_ {
storage,
shape,
stride: vec![1; rank],
op: None,
};
Tensor(Arc::new(tensor_))
Self(Arc::new(tensor_))
}
pub fn new<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
let shape = array.shape()?;
let storage = device.tensor(array);
let rank = shape.rank();
let tensor_ = Tensor_ {
storage,
shape,
stride: vec![1; rank],
op: None,
};
Ok(Self(Arc::new(tensor_)))
}
pub fn dtype(&self) -> DType {
@ -38,7 +51,7 @@ impl Tensor {
}
pub fn dims(&self) -> &[usize] {
&self.shape().dims()
self.shape().dims()
}
pub fn stride(&self) -> &[usize] {

View File

@ -2,9 +2,12 @@ use candle::{DType, Device, Result, Tensor};
#[test]
fn add() -> Result<()> {
let tensor = Tensor::zeros(&[5, 2], DType::F32, Device::Cpu);
let tensor = Tensor::zeros((5, 2), DType::F32, Device::Cpu);
let (dim1, dim2) = tensor.shape().r2()?;
assert_eq!(dim1, 5);
assert_eq!(dim2, 2);
let tensor = Tensor::new([3., 1., 4.].as_slice(), Device::Cpu)?;
let dim1 = tensor.shape().r1()?;
assert_eq!(dim1, 3);
Ok(())
}