mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Add some unique identifier.
This commit is contained in:
@ -1,8 +1,22 @@
|
||||
use crate::{op::Op, storage::Storage, DType, Device, Error, Result, Shape};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Unique identifier for tensors.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||
struct Id(usize);
|
||||
|
||||
impl Id {
|
||||
fn new() -> Self {
|
||||
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
|
||||
use std::sync::atomic;
|
||||
static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
|
||||
Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub struct Tensor_ {
|
||||
id: Id,
|
||||
storage: Storage,
|
||||
shape: Shape,
|
||||
// The strides are given in number of elements and not in bytes.
|
||||
@ -33,6 +47,7 @@ macro_rules! unary_op {
|
||||
let shape = self.shape();
|
||||
let storage = self.storage.$impl_name(self.shape(), self.stride())?;
|
||||
let tensor_ = Tensor_ {
|
||||
id: Id::new(),
|
||||
storage,
|
||||
shape: shape.clone(),
|
||||
stride: shape.stride_contiguous(),
|
||||
@ -51,6 +66,7 @@ macro_rules! binary_op {
|
||||
self.storage
|
||||
.$impl_name(&rhs.storage, shape, self.stride(), rhs.stride())?;
|
||||
let tensor_ = Tensor_ {
|
||||
id: Id::new(),
|
||||
storage,
|
||||
shape: shape.clone(),
|
||||
stride: shape.stride_contiguous(),
|
||||
@ -67,6 +83,7 @@ impl Tensor {
|
||||
let storage = device.zeros(&shape, dtype);
|
||||
let stride = shape.stride_contiguous();
|
||||
let tensor_ = Tensor_ {
|
||||
id: Id::new(),
|
||||
storage,
|
||||
shape,
|
||||
stride,
|
||||
@ -80,6 +97,7 @@ impl Tensor {
|
||||
let storage = device.tensor(array);
|
||||
let stride = shape.stride_contiguous();
|
||||
let tensor_ = Tensor_ {
|
||||
id: Id::new(),
|
||||
storage,
|
||||
shape,
|
||||
stride,
|
||||
|
Reference in New Issue
Block a user