Expose the tensor ids.

This commit is contained in:
laurent
2023-06-20 14:22:04 +01:00
parent 98b423145a
commit 671bcf060e
2 changed files with 12 additions and 8 deletions

View File

@ -11,4 +11,4 @@ pub use dtype::{DType, WithDType};
pub use error::{Error, Result};
pub use shape::Shape;
pub use storage::{CpuStorage, Storage};
pub use tensor::Tensor;
pub use tensor::{Tensor, TensorId};

View File

@ -3,9 +3,9 @@ use std::sync::Arc;
/// Unique identifier for tensors.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct Id(usize);
pub struct TensorId(usize);
impl Id {
impl TensorId {
fn new() -> Self {
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
use std::sync::atomic;
@ -16,7 +16,7 @@ impl Id {
#[allow(dead_code)]
pub struct Tensor_ {
id: Id,
id: TensorId,
storage: Storage,
shape: Shape,
// The strides are given in number of elements and not in bytes.
@ -47,7 +47,7 @@ macro_rules! unary_op {
let shape = self.shape();
let storage = self.storage.$impl_name(self.shape(), self.stride())?;
let tensor_ = Tensor_ {
id: Id::new(),
id: TensorId::new(),
storage,
shape: shape.clone(),
stride: shape.stride_contiguous(),
@ -66,7 +66,7 @@ macro_rules! binary_op {
self.storage
.$impl_name(&rhs.storage, shape, self.stride(), rhs.stride())?;
let tensor_ = Tensor_ {
id: Id::new(),
id: TensorId::new(),
storage,
shape: shape.clone(),
stride: shape.stride_contiguous(),
@ -83,7 +83,7 @@ impl Tensor {
let storage = device.zeros(&shape, dtype);
let stride = shape.stride_contiguous();
let tensor_ = Tensor_ {
id: Id::new(),
id: TensorId::new(),
storage,
shape,
stride,
@ -97,7 +97,7 @@ impl Tensor {
let storage = device.tensor(array);
let stride = shape.stride_contiguous();
let tensor_ = Tensor_ {
id: Id::new(),
id: TensorId::new(),
storage,
shape,
stride,
@ -207,4 +207,8 @@ impl Tensor {
pub fn elem_count(&self) -> usize {
self.shape().elem_count()
}
pub fn id(&self) -> TensorId {
self.id
}
}