mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Add the op.
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
use crate::{DType, Storage};
|
||||
use crate::{storage::Storage, DType};
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum Device {
|
||||
|
@ -1,9 +1,9 @@
|
||||
mod device;
|
||||
mod dtype;
|
||||
mod op;
|
||||
mod storage;
|
||||
mod tensor;
|
||||
|
||||
pub use device::Device;
|
||||
pub use dtype::DType;
|
||||
use storage::Storage;
|
||||
pub use tensor::Tensor;
|
||||
|
7
src/op.rs
Normal file
7
src/op.rs
Normal file
@ -0,0 +1,7 @@
|
||||
use crate::Tensor;
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) enum Op {
|
||||
Add(Tensor, Tensor),
|
||||
Mul(Tensor, Tensor),
|
||||
}
|
@ -1,3 +1,5 @@
|
||||
use crate::{DType, Device};
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) enum Storage {
|
||||
Cpu {
|
||||
@ -5,3 +7,17 @@ pub(crate) enum Storage {
|
||||
buffer: Vec<u8>,
|
||||
},
|
||||
}
|
||||
|
||||
impl Storage {
|
||||
pub(crate) fn device(&self) -> Device {
|
||||
match self {
|
||||
Self::Cpu { .. } => Device::Cpu,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn dtype(&self) -> DType {
|
||||
match self {
|
||||
Self::Cpu { dtype, .. } => *dtype,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,40 +1,49 @@
|
||||
use crate::{DType, Device, Storage};
|
||||
use crate::{op::Op, storage::Storage, DType, Device};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct Tensor {
|
||||
#[allow(dead_code)]
|
||||
pub(crate) struct Tensor_ {
|
||||
storage: Storage,
|
||||
shape: Vec<usize>,
|
||||
stride: Vec<usize>,
|
||||
op: Option<Op>,
|
||||
}
|
||||
|
||||
pub struct Tensor(Arc<Tensor_>);
|
||||
|
||||
impl Tensor {
|
||||
pub fn zeros(shape: &[usize], dtype: DType, device: Device) -> Self {
|
||||
let storage = device.zeros(shape, dtype);
|
||||
Self {
|
||||
let tensor_ = Tensor_ {
|
||||
storage,
|
||||
shape: shape.to_vec(),
|
||||
stride: vec![1; shape.len()],
|
||||
}
|
||||
op: None,
|
||||
};
|
||||
Tensor(Arc::new(tensor_))
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> DType {
|
||||
self.0.storage.dtype()
|
||||
}
|
||||
|
||||
pub fn device(&self) -> Device {
|
||||
match self.storage {
|
||||
Storage::Cpu { .. } => Device::Cpu,
|
||||
}
|
||||
self.0.storage.device()
|
||||
}
|
||||
|
||||
pub fn shape(&self) -> &[usize] {
|
||||
&self.shape
|
||||
&self.0.shape
|
||||
}
|
||||
|
||||
pub fn stride(&self) -> &[usize] {
|
||||
&self.stride
|
||||
&self.0.stride
|
||||
}
|
||||
|
||||
pub fn rank(&self) -> usize {
|
||||
self.shape.len()
|
||||
self.0.shape.len()
|
||||
}
|
||||
|
||||
pub fn elem_count(&self) -> usize {
|
||||
self.shape.iter().product()
|
||||
self.0.shape.iter().product()
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user