mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Add some unique identifier.
This commit is contained in:
@ -1,8 +1,22 @@
|
|||||||
use crate::{op::Op, storage::Storage, DType, Device, Error, Result, Shape};
|
use crate::{op::Op, storage::Storage, DType, Device, Error, Result, Shape};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// Unique identifier for tensors.
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||||
|
struct Id(usize);
|
||||||
|
|
||||||
|
impl Id {
|
||||||
|
fn new() -> Self {
|
||||||
|
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
|
||||||
|
use std::sync::atomic;
|
||||||
|
static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
|
||||||
|
Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
pub struct Tensor_ {
|
pub struct Tensor_ {
|
||||||
|
id: Id,
|
||||||
storage: Storage,
|
storage: Storage,
|
||||||
shape: Shape,
|
shape: Shape,
|
||||||
// The strides are given in number of elements and not in bytes.
|
// The strides are given in number of elements and not in bytes.
|
||||||
@ -33,6 +47,7 @@ macro_rules! unary_op {
|
|||||||
let shape = self.shape();
|
let shape = self.shape();
|
||||||
let storage = self.storage.$impl_name(self.shape(), self.stride())?;
|
let storage = self.storage.$impl_name(self.shape(), self.stride())?;
|
||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
|
id: Id::new(),
|
||||||
storage,
|
storage,
|
||||||
shape: shape.clone(),
|
shape: shape.clone(),
|
||||||
stride: shape.stride_contiguous(),
|
stride: shape.stride_contiguous(),
|
||||||
@ -51,6 +66,7 @@ macro_rules! binary_op {
|
|||||||
self.storage
|
self.storage
|
||||||
.$impl_name(&rhs.storage, shape, self.stride(), rhs.stride())?;
|
.$impl_name(&rhs.storage, shape, self.stride(), rhs.stride())?;
|
||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
|
id: Id::new(),
|
||||||
storage,
|
storage,
|
||||||
shape: shape.clone(),
|
shape: shape.clone(),
|
||||||
stride: shape.stride_contiguous(),
|
stride: shape.stride_contiguous(),
|
||||||
@ -67,6 +83,7 @@ impl Tensor {
|
|||||||
let storage = device.zeros(&shape, dtype);
|
let storage = device.zeros(&shape, dtype);
|
||||||
let stride = shape.stride_contiguous();
|
let stride = shape.stride_contiguous();
|
||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
|
id: Id::new(),
|
||||||
storage,
|
storage,
|
||||||
shape,
|
shape,
|
||||||
stride,
|
stride,
|
||||||
@ -80,6 +97,7 @@ impl Tensor {
|
|||||||
let storage = device.tensor(array);
|
let storage = device.tensor(array);
|
||||||
let stride = shape.stride_contiguous();
|
let stride = shape.stride_contiguous();
|
||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
|
id: Id::new(),
|
||||||
storage,
|
storage,
|
||||||
shape,
|
shape,
|
||||||
stride,
|
stride,
|
||||||
|
Reference in New Issue
Block a user