mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Expose the tensor ids.
This commit is contained in:
@ -11,4 +11,4 @@ pub use dtype::{DType, WithDType};
|
||||
pub use error::{Error, Result};
|
||||
pub use shape::Shape;
|
||||
pub use storage::{CpuStorage, Storage};
|
||||
pub use tensor::Tensor;
|
||||
pub use tensor::{Tensor, TensorId};
|
||||
|
@ -3,9 +3,9 @@ use std::sync::Arc;
|
||||
|
||||
/// Unique identifier for tensors.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||
struct Id(usize);
|
||||
pub struct TensorId(usize);
|
||||
|
||||
impl Id {
|
||||
impl TensorId {
|
||||
fn new() -> Self {
|
||||
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
|
||||
use std::sync::atomic;
|
||||
@ -16,7 +16,7 @@ impl Id {
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub struct Tensor_ {
|
||||
id: Id,
|
||||
id: TensorId,
|
||||
storage: Storage,
|
||||
shape: Shape,
|
||||
// The strides are given in number of elements and not in bytes.
|
||||
@ -47,7 +47,7 @@ macro_rules! unary_op {
|
||||
let shape = self.shape();
|
||||
let storage = self.storage.$impl_name(self.shape(), self.stride())?;
|
||||
let tensor_ = Tensor_ {
|
||||
id: Id::new(),
|
||||
id: TensorId::new(),
|
||||
storage,
|
||||
shape: shape.clone(),
|
||||
stride: shape.stride_contiguous(),
|
||||
@ -66,7 +66,7 @@ macro_rules! binary_op {
|
||||
self.storage
|
||||
.$impl_name(&rhs.storage, shape, self.stride(), rhs.stride())?;
|
||||
let tensor_ = Tensor_ {
|
||||
id: Id::new(),
|
||||
id: TensorId::new(),
|
||||
storage,
|
||||
shape: shape.clone(),
|
||||
stride: shape.stride_contiguous(),
|
||||
@ -83,7 +83,7 @@ impl Tensor {
|
||||
let storage = device.zeros(&shape, dtype);
|
||||
let stride = shape.stride_contiguous();
|
||||
let tensor_ = Tensor_ {
|
||||
id: Id::new(),
|
||||
id: TensorId::new(),
|
||||
storage,
|
||||
shape,
|
||||
stride,
|
||||
@ -97,7 +97,7 @@ impl Tensor {
|
||||
let storage = device.tensor(array);
|
||||
let stride = shape.stride_contiguous();
|
||||
let tensor_ = Tensor_ {
|
||||
id: Id::new(),
|
||||
id: TensorId::new(),
|
||||
storage,
|
||||
shape,
|
||||
stride,
|
||||
@ -207,4 +207,8 @@ impl Tensor {
|
||||
pub fn elem_count(&self) -> usize {
|
||||
self.shape().elem_count()
|
||||
}
|
||||
|
||||
pub fn id(&self) -> TensorId {
|
||||
self.id
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user