mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Add the op.
This commit is contained in:
@ -1,4 +1,4 @@
|
|||||||
use crate::{DType, Storage};
|
use crate::{storage::Storage, DType};
|
||||||
|
|
||||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||||
pub enum Device {
|
pub enum Device {
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
mod device;
|
mod device;
|
||||||
mod dtype;
|
mod dtype;
|
||||||
|
mod op;
|
||||||
mod storage;
|
mod storage;
|
||||||
mod tensor;
|
mod tensor;
|
||||||
|
|
||||||
pub use device::Device;
|
pub use device::Device;
|
||||||
pub use dtype::DType;
|
pub use dtype::DType;
|
||||||
use storage::Storage;
|
|
||||||
pub use tensor::Tensor;
|
pub use tensor::Tensor;
|
||||||
|
7
src/op.rs
Normal file
7
src/op.rs
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
use crate::Tensor;
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub(crate) enum Op {
|
||||||
|
Add(Tensor, Tensor),
|
||||||
|
Mul(Tensor, Tensor),
|
||||||
|
}
|
@ -1,3 +1,5 @@
|
|||||||
|
use crate::{DType, Device};
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
pub(crate) enum Storage {
|
pub(crate) enum Storage {
|
||||||
Cpu {
|
Cpu {
|
||||||
@ -5,3 +7,17 @@ pub(crate) enum Storage {
|
|||||||
buffer: Vec<u8>,
|
buffer: Vec<u8>,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Storage {
|
||||||
|
pub(crate) fn device(&self) -> Device {
|
||||||
|
match self {
|
||||||
|
Self::Cpu { .. } => Device::Cpu,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn dtype(&self) -> DType {
|
||||||
|
match self {
|
||||||
|
Self::Cpu { dtype, .. } => *dtype,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,40 +1,49 @@
|
|||||||
use crate::{DType, Device, Storage};
|
use crate::{op::Op, storage::Storage, DType, Device};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
pub struct Tensor {
|
#[allow(dead_code)]
|
||||||
|
pub(crate) struct Tensor_ {
|
||||||
storage: Storage,
|
storage: Storage,
|
||||||
shape: Vec<usize>,
|
shape: Vec<usize>,
|
||||||
stride: Vec<usize>,
|
stride: Vec<usize>,
|
||||||
|
op: Option<Op>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct Tensor(Arc<Tensor_>);
|
||||||
|
|
||||||
impl Tensor {
|
impl Tensor {
|
||||||
pub fn zeros(shape: &[usize], dtype: DType, device: Device) -> Self {
|
pub fn zeros(shape: &[usize], dtype: DType, device: Device) -> Self {
|
||||||
let storage = device.zeros(shape, dtype);
|
let storage = device.zeros(shape, dtype);
|
||||||
Self {
|
let tensor_ = Tensor_ {
|
||||||
storage,
|
storage,
|
||||||
shape: shape.to_vec(),
|
shape: shape.to_vec(),
|
||||||
stride: vec![1; shape.len()],
|
stride: vec![1; shape.len()],
|
||||||
}
|
op: None,
|
||||||
|
};
|
||||||
|
Tensor(Arc::new(tensor_))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dtype(&self) -> DType {
|
||||||
|
self.0.storage.dtype()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn device(&self) -> Device {
|
pub fn device(&self) -> Device {
|
||||||
match self.storage {
|
self.0.storage.device()
|
||||||
Storage::Cpu { .. } => Device::Cpu,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn shape(&self) -> &[usize] {
|
pub fn shape(&self) -> &[usize] {
|
||||||
&self.shape
|
&self.0.shape
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn stride(&self) -> &[usize] {
|
pub fn stride(&self) -> &[usize] {
|
||||||
&self.stride
|
&self.0.stride
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn rank(&self) -> usize {
|
pub fn rank(&self) -> usize {
|
||||||
self.shape.len()
|
self.0.shape.len()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn elem_count(&self) -> usize {
|
pub fn elem_count(&self) -> usize {
|
||||||
self.shape.iter().product()
|
self.0.shape.iter().product()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user