mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Shuffle the shape bits around.
This commit is contained in:
12
src/dtype.rs
12
src/dtype.rs
@ -12,3 +12,15 @@ impl DType {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub trait WithDType {
|
||||||
|
const DTYPE: DType;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WithDType for f32 {
|
||||||
|
const DTYPE: DType = DType::F32;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WithDType for f64 {
|
||||||
|
const DTYPE: DType = DType::F64;
|
||||||
|
}
|
||||||
|
@ -2,10 +2,11 @@ mod device;
|
|||||||
mod dtype;
|
mod dtype;
|
||||||
mod error;
|
mod error;
|
||||||
mod op;
|
mod op;
|
||||||
|
mod shape;
|
||||||
mod storage;
|
mod storage;
|
||||||
mod tensor;
|
mod tensor;
|
||||||
|
|
||||||
pub use device::Device;
|
pub use device::Device;
|
||||||
pub use dtype::DType;
|
pub use dtype::{DType, WithDType};
|
||||||
pub use error::{Error, Result};
|
pub use error::{Error, Result};
|
||||||
pub use tensor::Tensor;
|
pub use tensor::Tensor;
|
||||||
|
129
src/shape.rs
Normal file
129
src/shape.rs
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
use crate::{Error, Result};
|
||||||
|
pub struct Shape(pub(crate) Vec<usize>);
|
||||||
|
|
||||||
|
impl From<&[usize; 1]> for Shape {
|
||||||
|
fn from(dims: &[usize; 1]) -> Self {
|
||||||
|
Self(dims.to_vec())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&[usize; 2]> for Shape {
|
||||||
|
fn from(dims: &[usize; 2]) -> Self {
|
||||||
|
Self(dims.to_vec())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&[usize; 3]> for Shape {
|
||||||
|
fn from(dims: &[usize; 3]) -> Self {
|
||||||
|
Self(dims.to_vec())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&[usize]> for Shape {
|
||||||
|
fn from(dims: &[usize]) -> Self {
|
||||||
|
Self(dims.to_vec())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<()> for Shape {
|
||||||
|
fn from(_: ()) -> Self {
|
||||||
|
Self(vec![])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<usize> for Shape {
|
||||||
|
fn from(d1: usize) -> Self {
|
||||||
|
Self(vec![d1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<(usize, usize)> for Shape {
|
||||||
|
fn from(d12: (usize, usize)) -> Self {
|
||||||
|
Self(vec![d12.0, d12.1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<(usize, usize, usize)> for Shape {
|
||||||
|
fn from(d123: (usize, usize, usize)) -> Self {
|
||||||
|
Self(vec![d123.0, d123.1, d123.2])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Shape {
|
||||||
|
pub fn rank(&self) -> usize {
|
||||||
|
self.0.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dims(&self) -> &[usize] {
|
||||||
|
&self.0
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn elem_count(&self) -> usize {
|
||||||
|
self.0.iter().product()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn r0(&self) -> Result<()> {
|
||||||
|
let shape = &self.0;
|
||||||
|
if shape.is_empty() {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(Error::UnexpectedNumberOfDims {
|
||||||
|
expected: 0,
|
||||||
|
got: shape.len(),
|
||||||
|
shape: shape.to_vec(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn r1(&self) -> Result<usize> {
|
||||||
|
let shape = &self.0;
|
||||||
|
if shape.len() == 1 {
|
||||||
|
Ok(shape[0])
|
||||||
|
} else {
|
||||||
|
Err(Error::UnexpectedNumberOfDims {
|
||||||
|
expected: 1,
|
||||||
|
got: shape.len(),
|
||||||
|
shape: shape.to_vec(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn r2(&self) -> Result<(usize, usize)> {
|
||||||
|
let shape = &self.0;
|
||||||
|
if shape.len() == 2 {
|
||||||
|
Ok((shape[0], shape[1]))
|
||||||
|
} else {
|
||||||
|
Err(Error::UnexpectedNumberOfDims {
|
||||||
|
expected: 2,
|
||||||
|
got: shape.len(),
|
||||||
|
shape: shape.to_vec(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn r3(&self) -> Result<(usize, usize, usize)> {
|
||||||
|
let shape = &self.0;
|
||||||
|
if shape.len() == 3 {
|
||||||
|
Ok((shape[0], shape[1], shape[2]))
|
||||||
|
} else {
|
||||||
|
Err(Error::UnexpectedNumberOfDims {
|
||||||
|
expected: 3,
|
||||||
|
got: shape.len(),
|
||||||
|
shape: shape.to_vec(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn r4(&self) -> Result<(usize, usize, usize, usize)> {
|
||||||
|
let shape = &self.0;
|
||||||
|
if shape.len() == 4 {
|
||||||
|
Ok((shape[0], shape[1], shape[2], shape[4]))
|
||||||
|
} else {
|
||||||
|
Err(Error::UnexpectedNumberOfDims {
|
||||||
|
expected: 4,
|
||||||
|
got: shape.len(),
|
||||||
|
shape: shape.to_vec(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -1,10 +1,10 @@
|
|||||||
use crate::{op::Op, storage::Storage, DType, Device, Error, Result};
|
use crate::{op::Op, shape, storage::Storage, DType, Device};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
pub(crate) struct Tensor_ {
|
pub(crate) struct Tensor_ {
|
||||||
storage: Storage,
|
storage: Storage,
|
||||||
shape: Vec<usize>,
|
shape: shape::Shape,
|
||||||
stride: Vec<usize>,
|
stride: Vec<usize>,
|
||||||
op: Option<Op>,
|
op: Option<Op>,
|
||||||
}
|
}
|
||||||
@ -12,12 +12,14 @@ pub(crate) struct Tensor_ {
|
|||||||
pub struct Tensor(Arc<Tensor_>);
|
pub struct Tensor(Arc<Tensor_>);
|
||||||
|
|
||||||
impl Tensor {
|
impl Tensor {
|
||||||
pub fn zeros(shape: &[usize], dtype: DType, device: Device) -> Self {
|
pub fn zeros<S: Into<shape::Shape>>(shape: S, dtype: DType, device: Device) -> Self {
|
||||||
let storage = device.zeros(shape, dtype);
|
let shape = shape.into();
|
||||||
|
let storage = device.zeros(&shape.0, dtype);
|
||||||
|
let rank = shape.0.len();
|
||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
storage,
|
storage,
|
||||||
shape: shape.to_vec(),
|
shape,
|
||||||
stride: vec![1; shape.len()],
|
stride: vec![1; rank],
|
||||||
op: None,
|
op: None,
|
||||||
};
|
};
|
||||||
Tensor(Arc::new(tensor_))
|
Tensor(Arc::new(tensor_))
|
||||||
@ -31,71 +33,23 @@ impl Tensor {
|
|||||||
self.0.storage.device()
|
self.0.storage.device()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn shape(&self) -> &[usize] {
|
pub fn shape(&self) -> &shape::Shape {
|
||||||
&self.0.shape
|
&self.0.shape
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn dims(&self) -> &[usize] {
|
||||||
|
&self.shape().dims()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn stride(&self) -> &[usize] {
|
pub fn stride(&self) -> &[usize] {
|
||||||
&self.0.stride
|
&self.0.stride
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn rank(&self) -> usize {
|
pub fn rank(&self) -> usize {
|
||||||
self.0.shape.len()
|
self.shape().rank()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn elem_count(&self) -> usize {
|
pub fn elem_count(&self) -> usize {
|
||||||
self.0.shape.iter().product()
|
self.shape().elem_count()
|
||||||
}
|
|
||||||
|
|
||||||
pub fn shape1(&self) -> Result<usize> {
|
|
||||||
let shape = self.shape();
|
|
||||||
if shape.len() == 1 {
|
|
||||||
Ok(shape[0])
|
|
||||||
} else {
|
|
||||||
Err(Error::UnexpectedNumberOfDims {
|
|
||||||
expected: 1,
|
|
||||||
got: shape.len(),
|
|
||||||
shape: shape.to_vec(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn shape2(&self) -> Result<(usize, usize)> {
|
|
||||||
let shape = self.shape();
|
|
||||||
if shape.len() == 2 {
|
|
||||||
Ok((shape[0], shape[1]))
|
|
||||||
} else {
|
|
||||||
Err(Error::UnexpectedNumberOfDims {
|
|
||||||
expected: 2,
|
|
||||||
got: shape.len(),
|
|
||||||
shape: shape.to_vec(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn shape3(&self) -> Result<(usize, usize, usize)> {
|
|
||||||
let shape = self.shape();
|
|
||||||
if shape.len() == 3 {
|
|
||||||
Ok((shape[0], shape[1], shape[2]))
|
|
||||||
} else {
|
|
||||||
Err(Error::UnexpectedNumberOfDims {
|
|
||||||
expected: 3,
|
|
||||||
got: shape.len(),
|
|
||||||
shape: shape.to_vec(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn shape4(&self) -> Result<(usize, usize, usize, usize)> {
|
|
||||||
let shape = self.shape();
|
|
||||||
if shape.len() == 4 {
|
|
||||||
Ok((shape[0], shape[1], shape[2], shape[4]))
|
|
||||||
} else {
|
|
||||||
Err(Error::UnexpectedNumberOfDims {
|
|
||||||
expected: 4,
|
|
||||||
got: shape.len(),
|
|
||||||
shape: shape.to_vec(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,7 @@ use candle::{DType, Device, Result, Tensor};
|
|||||||
#[test]
|
#[test]
|
||||||
fn add() -> Result<()> {
|
fn add() -> Result<()> {
|
||||||
let tensor = Tensor::zeros(&[5, 2], DType::F32, Device::Cpu);
|
let tensor = Tensor::zeros(&[5, 2], DType::F32, Device::Cpu);
|
||||||
let (dim1, dim2) = tensor.shape2()?;
|
let (dim1, dim2) = tensor.shape().r2()?;
|
||||||
assert_eq!(dim1, 5);
|
assert_eq!(dim1, 5);
|
||||||
assert_eq!(dim2, 2);
|
assert_eq!(dim2, 2);
|
||||||
Ok(())
|
Ok(())
|
||||||
|
Reference in New Issue
Block a user