Add an easy way to create tensor objects.

This commit is contained in:
laurent
2023-06-19 20:59:26 +01:00
parent 01eeb0e72f
commit 26d6288eb6
6 changed files with 77 additions and 14 deletions

View File

@ -1,4 +1,4 @@
use crate::{op::Op, shape, storage::Storage, DType, Device};
use crate::{op::Op, shape, storage::Storage, DType, Device, Result};
use std::sync::Arc;
#[allow(dead_code)]
@ -14,15 +14,28 @@ pub struct Tensor(Arc<Tensor_>);
impl Tensor {
pub fn zeros<S: Into<shape::Shape>>(shape: S, dtype: DType, device: Device) -> Self {
let shape = shape.into();
let storage = device.zeros(&shape.0, dtype);
let rank = shape.0.len();
let storage = device.zeros(&shape, dtype);
let rank = shape.rank();
let tensor_ = Tensor_ {
storage,
shape,
stride: vec![1; rank],
op: None,
};
Tensor(Arc::new(tensor_))
Self(Arc::new(tensor_))
}
pub fn new<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
let shape = array.shape()?;
let storage = device.tensor(array);
let rank = shape.rank();
let tensor_ = Tensor_ {
storage,
shape,
stride: vec![1; rank],
op: None,
};
Ok(Self(Arc::new(tensor_)))
}
pub fn dtype(&self) -> DType {
@ -38,7 +51,7 @@ impl Tensor {
}
pub fn dims(&self) -> &[usize] {
&self.shape().dims()
self.shape().dims()
}
pub fn stride(&self) -> &[usize] {