mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Add an easy way to create tensor objects.
This commit is contained in:
@ -1,6 +1,6 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
storage::{CpuStorage, Storage},
|
storage::{CpuStorage, Storage},
|
||||||
DType,
|
DType, Result, Shape,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||||
@ -8,11 +8,38 @@ pub enum Device {
|
|||||||
Cpu,
|
Cpu,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Should we back the cpu implementation using the NdArray crate or similar?
/// Values that can be turned into a tensor: scalars, slices, etc.
///
/// An implementor reports its own shape and copies its data into a
/// `CpuStorage`; `Tensor::new` uses both to build a tensor from it.
pub trait NdArray {
    /// The shape of this value, e.g. `()` for a scalar or `(len,)` for a slice.
    fn shape(&self) -> Result<Shape>;

    /// Copy the data into a freshly allocated cpu-side storage buffer.
    fn to_cpu_storage(&self) -> CpuStorage;
}
|
||||||
|
|
||||||
|
impl<S: crate::WithDType> NdArray for S {
|
||||||
|
fn shape(&self) -> Result<Shape> {
|
||||||
|
Ok(Shape::from(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_cpu_storage(&self) -> CpuStorage {
|
||||||
|
S::to_cpu_storage(&[*self])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: crate::WithDType> NdArray for &[S] {
|
||||||
|
fn shape(&self) -> Result<Shape> {
|
||||||
|
Ok(Shape::from(self.len()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_cpu_storage(&self) -> CpuStorage {
|
||||||
|
S::to_cpu_storage(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Device {
|
impl Device {
|
||||||
pub(crate) fn zeros(&self, shape: &[usize], dtype: DType) -> Storage {
|
pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Storage {
|
||||||
match self {
|
match self {
|
||||||
Device::Cpu => {
|
Device::Cpu => {
|
||||||
let elem_count: usize = shape.iter().product();
|
let elem_count = shape.elem_count();
|
||||||
let storage = match dtype {
|
let storage = match dtype {
|
||||||
DType::F32 => {
|
DType::F32 => {
|
||||||
let data = vec![0f32; elem_count];
|
let data = vec![0f32; elem_count];
|
||||||
@ -27,4 +54,10 @@ impl Device {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Storage {
|
||||||
|
match self {
|
||||||
|
Device::Cpu => Storage::Cpu(array.to_cpu_storage()),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
14
src/dtype.rs
14
src/dtype.rs
@ -1,3 +1,5 @@
|
|||||||
|
use crate::CpuStorage;
|
||||||
|
|
||||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||||
pub enum DType {
|
pub enum DType {
|
||||||
F32,
|
F32,
|
||||||
@ -13,14 +15,24 @@ impl DType {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Implemented by the primitive element types tensors can hold (`f32`, `f64`);
/// ties each Rust type to its runtime `DType` tag.
pub trait WithDType: Sized + Copy {
    /// The runtime dtype tag matching `Self`.
    const DTYPE: DType;

    /// Copy `data` into the matching `CpuStorage` variant.
    fn to_cpu_storage(data: &[Self]) -> CpuStorage;
}
|
||||||
|
|
||||||
impl WithDType for f32 {
|
impl WithDType for f32 {
|
||||||
const DTYPE: DType = DType::F32;
|
const DTYPE: DType = DType::F32;
|
||||||
|
|
||||||
|
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
|
||||||
|
CpuStorage::F32(data.to_vec())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WithDType for f64 {
|
impl WithDType for f64 {
|
||||||
const DTYPE: DType = DType::F64;
|
const DTYPE: DType = DType::F64;
|
||||||
|
|
||||||
|
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
|
||||||
|
CpuStorage::F64(data.to_vec())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -9,4 +9,6 @@ mod tensor;
|
|||||||
pub use device::Device;
|
pub use device::Device;
|
||||||
pub use dtype::{DType, WithDType};
|
pub use dtype::{DType, WithDType};
|
||||||
pub use error::{Error, Result};
|
pub use error::{Error, Result};
|
||||||
|
pub use shape::Shape;
|
||||||
|
pub use storage::{CpuStorage, Storage};
|
||||||
pub use tensor::Tensor;
|
pub use tensor::Tensor;
|
||||||
|
@ -2,7 +2,7 @@ use crate::{DType, Device};
|
|||||||
|
|
||||||
// TODO: Think about whether we would be better off with a dtype and
// a buffer as an owned slice of bytes.
/// Cpu-side tensor storage: one owned buffer per supported dtype.
pub enum CpuStorage {
    F32(Vec<f32>),
    F64(Vec<f64>),
}
|
||||||
@ -17,18 +17,18 @@ impl CpuStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Device-tagged tensor storage; only a cpu backend exists so far.
#[allow(dead_code)]
pub enum Storage {
    Cpu(CpuStorage),
}
|
||||||
|
|
||||||
impl Storage {
|
impl Storage {
|
||||||
pub(crate) fn device(&self) -> Device {
|
pub fn device(&self) -> Device {
|
||||||
match self {
|
match self {
|
||||||
Self::Cpu(_) => Device::Cpu,
|
Self::Cpu(_) => Device::Cpu,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn dtype(&self) -> DType {
|
pub fn dtype(&self) -> DType {
|
||||||
match self {
|
match self {
|
||||||
Self::Cpu(storage) => storage.dtype(),
|
Self::Cpu(storage) => storage.dtype(),
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use crate::{op::Op, shape, storage::Storage, DType, Device};
|
use crate::{op::Op, shape, storage::Storage, DType, Device, Result};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
@ -14,15 +14,28 @@ pub struct Tensor(Arc<Tensor_>);
|
|||||||
impl Tensor {
|
impl Tensor {
|
||||||
pub fn zeros<S: Into<shape::Shape>>(shape: S, dtype: DType, device: Device) -> Self {
|
pub fn zeros<S: Into<shape::Shape>>(shape: S, dtype: DType, device: Device) -> Self {
|
||||||
let shape = shape.into();
|
let shape = shape.into();
|
||||||
let storage = device.zeros(&shape.0, dtype);
|
let storage = device.zeros(&shape, dtype);
|
||||||
let rank = shape.0.len();
|
let rank = shape.rank();
|
||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
storage,
|
storage,
|
||||||
shape,
|
shape,
|
||||||
stride: vec![1; rank],
|
stride: vec![1; rank],
|
||||||
op: None,
|
op: None,
|
||||||
};
|
};
|
||||||
Tensor(Arc::new(tensor_))
|
Self(Arc::new(tensor_))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
|
||||||
|
let shape = array.shape()?;
|
||||||
|
let storage = device.tensor(array);
|
||||||
|
let rank = shape.rank();
|
||||||
|
let tensor_ = Tensor_ {
|
||||||
|
storage,
|
||||||
|
shape,
|
||||||
|
stride: vec![1; rank],
|
||||||
|
op: None,
|
||||||
|
};
|
||||||
|
Ok(Self(Arc::new(tensor_)))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn dtype(&self) -> DType {
|
pub fn dtype(&self) -> DType {
|
||||||
@ -38,7 +51,7 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn dims(&self) -> &[usize] {
|
pub fn dims(&self) -> &[usize] {
|
||||||
&self.shape().dims()
|
self.shape().dims()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn stride(&self) -> &[usize] {
|
pub fn stride(&self) -> &[usize] {
|
||||||
|
@ -2,9 +2,12 @@ use candle::{DType, Device, Result, Tensor};
|
|||||||
|
|
||||||
#[test]
fn add() -> Result<()> {
    // A zero-filled 5x2 tensor reports its dims through `r2`.
    let tensor = Tensor::zeros((5, 2), DType::F32, Device::Cpu);
    let (dim1, dim2) = tensor.shape().r2()?;
    assert_eq!((dim1, dim2), (5, 2));
    // A tensor built from a slice is rank-1 with the slice length.
    let tensor = Tensor::new([3., 1., 4.].as_slice(), Device::Cpu)?;
    assert_eq!(tensor.shape().r1()?, 3);
    Ok(())
}
|
||||||
|
Reference in New Issue
Block a user