Add some unique identifier.

This commit is contained in:
laurent
2023-06-20 13:00:04 +01:00
parent d9cb1917ce
commit d922ff97f2

View File

@ -1,8 +1,22 @@
use crate::{op::Op, storage::Storage, DType, Device, Error, Result, Shape};
use std::sync::Arc;
/// Unique identifier for tensors.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct Id(usize);
impl Id {
fn new() -> Self {
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
use std::sync::atomic;
static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
}
}
#[allow(dead_code)]
pub struct Tensor_ {
id: Id,
storage: Storage,
shape: Shape,
// The strides are given in number of elements and not in bytes.
@ -33,6 +47,7 @@ macro_rules! unary_op {
let shape = self.shape();
let storage = self.storage.$impl_name(self.shape(), self.stride())?;
let tensor_ = Tensor_ {
id: Id::new(),
storage,
shape: shape.clone(),
stride: shape.stride_contiguous(),
@ -51,6 +66,7 @@ macro_rules! binary_op {
self.storage
.$impl_name(&rhs.storage, shape, self.stride(), rhs.stride())?;
let tensor_ = Tensor_ {
id: Id::new(),
storage,
shape: shape.clone(),
stride: shape.stride_contiguous(),
@ -67,6 +83,7 @@ impl Tensor {
let storage = device.zeros(&shape, dtype);
let stride = shape.stride_contiguous();
let tensor_ = Tensor_ {
id: Id::new(),
storage,
shape,
stride,
@ -80,6 +97,7 @@ impl Tensor {
let storage = device.tensor(array);
let stride = shape.stride_contiguous();
let tensor_ = Tensor_ {
id: Id::new(),
storage,
shape,
stride,