mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Split the tensor file.
This commit is contained in:
18
src/device.rs
Normal file
18
src/device.rs
Normal file
@ -0,0 +1,18 @@
|
||||
use crate::{DType, Storage};
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum Device {
|
||||
Cpu,
|
||||
}
|
||||
|
||||
impl Device {
|
||||
pub(crate) fn zeros(&self, shape: &[usize], dtype: DType) -> Storage {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let elem_count: usize = shape.iter().product();
|
||||
let buffer = vec![0; elem_count * dtype.size_in_bytes()];
|
||||
Storage::Cpu { dtype, buffer }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
14
src/dtype.rs
Normal file
14
src/dtype.rs
Normal file
@ -0,0 +1,14 @@
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum DType {
|
||||
F32,
|
||||
F64,
|
||||
}
|
||||
|
||||
impl DType {
|
||||
pub fn size_in_bytes(&self) -> usize {
|
||||
match self {
|
||||
Self::F32 => 4,
|
||||
Self::F64 => 8,
|
||||
}
|
||||
}
|
||||
}
|
@ -1,3 +1,9 @@
|
||||
mod device;
|
||||
mod dtype;
|
||||
mod storage;
|
||||
mod tensor;
|
||||
|
||||
pub use tensor::{DType, Tensor};
|
||||
pub use device::Device;
|
||||
pub use dtype::DType;
|
||||
use storage::Storage;
|
||||
pub use tensor::Tensor;
|
||||
|
7
src/storage.rs
Normal file
7
src/storage.rs
Normal file
@ -0,0 +1,7 @@
|
||||
#[allow(dead_code)]
|
||||
pub(crate) enum Storage {
|
||||
Cpu {
|
||||
dtype: crate::DType,
|
||||
buffer: Vec<u8>,
|
||||
},
|
||||
}
|
@ -1,17 +1,4 @@
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum DType {
|
||||
F32,
|
||||
F64,
|
||||
}
|
||||
|
||||
impl DType {
|
||||
pub fn size_in_bytes(&self) -> usize {
|
||||
match self {
|
||||
Self::F32 => 4,
|
||||
Self::F64 => 8,
|
||||
}
|
||||
}
|
||||
}
|
||||
use crate::{DType, Device, Storage};
|
||||
|
||||
pub struct Tensor {
|
||||
storage: Storage,
|
||||
@ -19,25 +6,9 @@ pub struct Tensor {
|
||||
stride: Vec<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum Device {
|
||||
Cpu,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
enum Storage {
|
||||
Cpu { dtype: DType, buffer: Vec<u8> },
|
||||
}
|
||||
|
||||
impl Tensor {
|
||||
pub fn zeros(shape: &[usize], dtype: DType, device: Device) -> Self {
|
||||
let storage = match device {
|
||||
Device::Cpu => {
|
||||
let elem_count: usize = shape.iter().product();
|
||||
let buffer = vec![0; elem_count * dtype.size_in_bytes()];
|
||||
Storage::Cpu { dtype, buffer }
|
||||
}
|
||||
};
|
||||
let storage = device.zeros(shape, dtype);
|
||||
Self {
|
||||
storage,
|
||||
shape: shape.to_vec(),
|
||||
|
Reference in New Issue
Block a user