mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Shuffle the shape bits around.
This commit is contained in:
@ -1,10 +1,10 @@
|
||||
use crate::{op::Op, storage::Storage, DType, Device, Error, Result};
|
||||
use crate::{op::Op, shape, storage::Storage, DType, Device};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) struct Tensor_ {
|
||||
storage: Storage,
|
||||
shape: Vec<usize>,
|
||||
shape: shape::Shape,
|
||||
stride: Vec<usize>,
|
||||
op: Option<Op>,
|
||||
}
|
||||
@ -12,12 +12,14 @@ pub(crate) struct Tensor_ {
|
||||
pub struct Tensor(Arc<Tensor_>);
|
||||
|
||||
impl Tensor {
|
||||
pub fn zeros(shape: &[usize], dtype: DType, device: Device) -> Self {
|
||||
let storage = device.zeros(shape, dtype);
|
||||
pub fn zeros<S: Into<shape::Shape>>(shape: S, dtype: DType, device: Device) -> Self {
|
||||
let shape = shape.into();
|
||||
let storage = device.zeros(&shape.0, dtype);
|
||||
let rank = shape.0.len();
|
||||
let tensor_ = Tensor_ {
|
||||
storage,
|
||||
shape: shape.to_vec(),
|
||||
stride: vec![1; shape.len()],
|
||||
shape,
|
||||
stride: vec![1; rank],
|
||||
op: None,
|
||||
};
|
||||
Tensor(Arc::new(tensor_))
|
||||
@ -31,71 +33,23 @@ impl Tensor {
|
||||
self.0.storage.device()
|
||||
}
|
||||
|
||||
pub fn shape(&self) -> &[usize] {
|
||||
pub fn shape(&self) -> &shape::Shape {
|
||||
&self.0.shape
|
||||
}
|
||||
|
||||
pub fn dims(&self) -> &[usize] {
|
||||
&self.shape().dims()
|
||||
}
|
||||
|
||||
pub fn stride(&self) -> &[usize] {
|
||||
&self.0.stride
|
||||
}
|
||||
|
||||
pub fn rank(&self) -> usize {
|
||||
self.0.shape.len()
|
||||
self.shape().rank()
|
||||
}
|
||||
|
||||
pub fn elem_count(&self) -> usize {
|
||||
self.0.shape.iter().product()
|
||||
}
|
||||
|
||||
pub fn shape1(&self) -> Result<usize> {
|
||||
let shape = self.shape();
|
||||
if shape.len() == 1 {
|
||||
Ok(shape[0])
|
||||
} else {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: 1,
|
||||
got: shape.len(),
|
||||
shape: shape.to_vec(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn shape2(&self) -> Result<(usize, usize)> {
|
||||
let shape = self.shape();
|
||||
if shape.len() == 2 {
|
||||
Ok((shape[0], shape[1]))
|
||||
} else {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: 2,
|
||||
got: shape.len(),
|
||||
shape: shape.to_vec(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn shape3(&self) -> Result<(usize, usize, usize)> {
|
||||
let shape = self.shape();
|
||||
if shape.len() == 3 {
|
||||
Ok((shape[0], shape[1], shape[2]))
|
||||
} else {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: 3,
|
||||
got: shape.len(),
|
||||
shape: shape.to_vec(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn shape4(&self) -> Result<(usize, usize, usize, usize)> {
|
||||
let shape = self.shape();
|
||||
if shape.len() == 4 {
|
||||
Ok((shape[0], shape[1], shape[2], shape[4]))
|
||||
} else {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: 4,
|
||||
got: shape.len(),
|
||||
shape: shape.to_vec(),
|
||||
})
|
||||
}
|
||||
self.shape().elem_count()
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user