mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Add an easy way to create tensor objects.
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
use crate::{op::Op, shape, storage::Storage, DType, Device};
|
||||
use crate::{op::Op, shape, storage::Storage, DType, Device, Result};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[allow(dead_code)]
|
||||
@ -14,15 +14,28 @@ pub struct Tensor(Arc<Tensor_>);
|
||||
impl Tensor {
|
||||
pub fn zeros<S: Into<shape::Shape>>(shape: S, dtype: DType, device: Device) -> Self {
|
||||
let shape = shape.into();
|
||||
let storage = device.zeros(&shape.0, dtype);
|
||||
let rank = shape.0.len();
|
||||
let storage = device.zeros(&shape, dtype);
|
||||
let rank = shape.rank();
|
||||
let tensor_ = Tensor_ {
|
||||
storage,
|
||||
shape,
|
||||
stride: vec![1; rank],
|
||||
op: None,
|
||||
};
|
||||
Tensor(Arc::new(tensor_))
|
||||
Self(Arc::new(tensor_))
|
||||
}
|
||||
|
||||
pub fn new<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
|
||||
let shape = array.shape()?;
|
||||
let storage = device.tensor(array);
|
||||
let rank = shape.rank();
|
||||
let tensor_ = Tensor_ {
|
||||
storage,
|
||||
shape,
|
||||
stride: vec![1; rank],
|
||||
op: None,
|
||||
};
|
||||
Ok(Self(Arc::new(tensor_)))
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> DType {
|
||||
@ -38,7 +51,7 @@ impl Tensor {
|
||||
}
|
||||
|
||||
pub fn dims(&self) -> &[usize] {
|
||||
&self.shape().dims()
|
||||
self.shape().dims()
|
||||
}
|
||||
|
||||
pub fn stride(&self) -> &[usize] {
|
||||
|
Reference in New Issue
Block a user