mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Add the op.
This commit is contained in:
@ -1,40 +1,49 @@
|
||||
use crate::{DType, Device, Storage};
|
||||
use crate::{op::Op, storage::Storage, DType, Device};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct Tensor {
|
||||
#[allow(dead_code)]
|
||||
pub(crate) struct Tensor_ {
|
||||
storage: Storage,
|
||||
shape: Vec<usize>,
|
||||
stride: Vec<usize>,
|
||||
op: Option<Op>,
|
||||
}
|
||||
|
||||
pub struct Tensor(Arc<Tensor_>);
|
||||
|
||||
impl Tensor {
|
||||
pub fn zeros(shape: &[usize], dtype: DType, device: Device) -> Self {
|
||||
let storage = device.zeros(shape, dtype);
|
||||
Self {
|
||||
let tensor_ = Tensor_ {
|
||||
storage,
|
||||
shape: shape.to_vec(),
|
||||
stride: vec![1; shape.len()],
|
||||
}
|
||||
op: None,
|
||||
};
|
||||
Tensor(Arc::new(tensor_))
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> DType {
|
||||
self.0.storage.dtype()
|
||||
}
|
||||
|
||||
pub fn device(&self) -> Device {
|
||||
match self.storage {
|
||||
Storage::Cpu { .. } => Device::Cpu,
|
||||
}
|
||||
self.0.storage.device()
|
||||
}
|
||||
|
||||
pub fn shape(&self) -> &[usize] {
|
||||
&self.shape
|
||||
&self.0.shape
|
||||
}
|
||||
|
||||
pub fn stride(&self) -> &[usize] {
|
||||
&self.stride
|
||||
&self.0.stride
|
||||
}
|
||||
|
||||
pub fn rank(&self) -> usize {
|
||||
self.shape.len()
|
||||
self.0.shape.len()
|
||||
}
|
||||
|
||||
pub fn elem_count(&self) -> usize {
|
||||
self.shape.iter().product()
|
||||
self.0.shape.iter().product()
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user