mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Cosmetic changes.
This commit is contained in:
@ -1,6 +1,12 @@
|
|||||||
use crate::{Error, Result};
|
use crate::{Error, Result};
|
||||||
pub struct Shape(pub(crate) Vec<usize>);
|
pub struct Shape(pub(crate) Vec<usize>);
|
||||||
|
|
||||||
|
impl std::fmt::Debug for Shape {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "{:?}", &self.dims())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl From<&[usize; 1]> for Shape {
|
impl From<&[usize; 1]> for Shape {
|
||||||
fn from(dims: &[usize; 1]) -> Self {
|
fn from(dims: &[usize; 1]) -> Self {
|
||||||
Self(dims.to_vec())
|
Self(dims.to_vec())
|
||||||
|
@ -2,6 +2,7 @@ use crate::{DType, Device};
|
|||||||
|
|
||||||
// TODO: Think about whether we would be better off with a dtype and
|
// TODO: Think about whether we would be better off with a dtype and
|
||||||
// a buffer as an owned slice of bytes.
|
// a buffer as an owned slice of bytes.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub enum CpuStorage {
|
pub enum CpuStorage {
|
||||||
F32(Vec<f32>),
|
F32(Vec<f32>),
|
||||||
F64(Vec<f64>),
|
F64(Vec<f64>),
|
||||||
@ -17,6 +18,7 @@ impl CpuStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub enum Storage {
|
pub enum Storage {
|
||||||
Cpu(CpuStorage),
|
Cpu(CpuStorage),
|
||||||
}
|
}
|
||||||
|
@ -1,18 +1,24 @@
|
|||||||
use crate::{op::Op, shape, storage::Storage, DType, Device, Result};
|
use crate::{op::Op, storage::Storage, DType, Device, Result, Shape};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
pub(crate) struct Tensor_ {
|
pub(crate) struct Tensor_ {
|
||||||
storage: Storage,
|
storage: Storage,
|
||||||
shape: shape::Shape,
|
shape: Shape,
|
||||||
stride: Vec<usize>,
|
stride: Vec<usize>,
|
||||||
op: Option<Op>,
|
op: Option<Op>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Tensor(Arc<Tensor_>);
|
pub struct Tensor(Arc<Tensor_>);
|
||||||
|
|
||||||
|
impl std::fmt::Debug for Tensor {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "[{:?}, {:?}]", &self.shape().dims(), self.device())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Tensor {
|
impl Tensor {
|
||||||
pub fn zeros<S: Into<shape::Shape>>(shape: S, dtype: DType, device: Device) -> Self {
|
pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
|
||||||
let shape = shape.into();
|
let shape = shape.into();
|
||||||
let storage = device.zeros(&shape, dtype);
|
let storage = device.zeros(&shape, dtype);
|
||||||
let rank = shape.rank();
|
let rank = shape.rank();
|
||||||
@ -38,6 +44,21 @@ impl Tensor {
|
|||||||
Ok(Self(Arc::new(tensor_)))
|
Ok(Self(Arc::new(tensor_)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
|
||||||
|
// TODO: properly use the strides here.
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
|
||||||
|
// TODO: properly use the strides here.
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> {
|
||||||
|
// TODO: properly use the strides here.
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn dtype(&self) -> DType {
|
pub fn dtype(&self) -> DType {
|
||||||
self.0.storage.dtype()
|
self.0.storage.dtype()
|
||||||
}
|
}
|
||||||
@ -46,7 +67,7 @@ impl Tensor {
|
|||||||
self.0.storage.device()
|
self.0.storage.device()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn shape(&self) -> &shape::Shape {
|
pub fn shape(&self) -> &Shape {
|
||||||
&self.0.shape
|
&self.0.shape
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user