mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Start refactoring the stride.
This commit is contained in:
@ -1,5 +1,5 @@
|
|||||||
use crate::op::{BinaryOp, UnaryOp};
|
use crate::op::{BinaryOp, UnaryOp};
|
||||||
use crate::{DType, Error, Result, Shape, StridedIndex};
|
use crate::{DType, Error, Layout, Result, Shape, StridedIndex};
|
||||||
use gemm::{gemm, Parallelism};
|
use gemm::{gemm, Parallelism};
|
||||||
use half::{bf16, f16};
|
use half::{bf16, f16};
|
||||||
|
|
||||||
@ -18,12 +18,11 @@ pub enum CpuStorage {
|
|||||||
|
|
||||||
fn wcond<T: Copy>(
|
fn wcond<T: Copy>(
|
||||||
pred: &[u32],
|
pred: &[u32],
|
||||||
shape: &Shape,
|
layout: &Layout,
|
||||||
stride: &[usize],
|
|
||||||
t: &[T],
|
t: &[T],
|
||||||
stride_t: &[usize],
|
layout_t: &Layout,
|
||||||
f: &[T],
|
f: &[T],
|
||||||
stride_f: &[usize],
|
layout_f: &Layout,
|
||||||
) -> Vec<T> {
|
) -> Vec<T> {
|
||||||
if shape.is_contiguous(stride) && shape.is_contiguous(stride_t) && shape.is_contiguous(stride_f)
|
if shape.is_contiguous(stride) && shape.is_contiguous(stride_t) && shape.is_contiguous(stride_f)
|
||||||
{
|
{
|
||||||
@ -73,12 +72,7 @@ fn sum_impl1<T: Copy + num_traits::NumAssign>(
|
|||||||
Ok(dst)
|
Ok(dst)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
|
fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut f: F) -> Vec<U> {
|
||||||
vs: &[T],
|
|
||||||
shape: &Shape,
|
|
||||||
stride: &[usize],
|
|
||||||
mut f: F,
|
|
||||||
) -> Vec<U> {
|
|
||||||
if shape.is_contiguous(stride) {
|
if shape.is_contiguous(stride) {
|
||||||
vs[..shape.elem_count()].iter().map(|&v| f(v)).collect()
|
vs[..shape.elem_count()].iter().map(|&v| f(v)).collect()
|
||||||
} else {
|
} else {
|
||||||
@ -461,65 +455,59 @@ impl CpuStorage {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn affine_impl(
|
pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
||||||
&self,
|
|
||||||
shape: &Shape,
|
|
||||||
stride: &[usize],
|
|
||||||
mul: f64,
|
|
||||||
add: f64,
|
|
||||||
) -> Result<Self> {
|
|
||||||
match self {
|
match self {
|
||||||
Self::U32(storage) => {
|
Self::U32(storage) => {
|
||||||
let mul = mul as u32;
|
let mul = mul as u32;
|
||||||
let add = add as u32;
|
let add = add as u32;
|
||||||
let data = unary_map(storage, shape, stride, |v| v * mul + add);
|
let data = unary_map(storage, layout, |v| v * mul + add);
|
||||||
Ok(Self::U32(data))
|
Ok(Self::U32(data))
|
||||||
}
|
}
|
||||||
Self::BF16(storage) => {
|
Self::BF16(storage) => {
|
||||||
let mul = bf16::from_f64(mul);
|
let mul = bf16::from_f64(mul);
|
||||||
let add = bf16::from_f64(add);
|
let add = bf16::from_f64(add);
|
||||||
let data = unary_map(storage, shape, stride, |v| v * mul + add);
|
let data = unary_map(storage, layout, |v| v * mul + add);
|
||||||
Ok(Self::BF16(data))
|
Ok(Self::BF16(data))
|
||||||
}
|
}
|
||||||
Self::F16(storage) => {
|
Self::F16(storage) => {
|
||||||
let mul = f16::from_f64(mul);
|
let mul = f16::from_f64(mul);
|
||||||
let add = f16::from_f64(add);
|
let add = f16::from_f64(add);
|
||||||
let data = unary_map(storage, shape, stride, |v| v * mul + add);
|
let data = unary_map(storage, layout, |v| v * mul + add);
|
||||||
Ok(Self::F16(data))
|
Ok(Self::F16(data))
|
||||||
}
|
}
|
||||||
Self::F32(storage) => {
|
Self::F32(storage) => {
|
||||||
let mul = mul as f32;
|
let mul = mul as f32;
|
||||||
let add = add as f32;
|
let add = add as f32;
|
||||||
let data = unary_map(storage, shape, stride, |v| v * mul + add);
|
let data = unary_map(storage, layout, |v| v * mul + add);
|
||||||
Ok(Self::F32(data))
|
Ok(Self::F32(data))
|
||||||
}
|
}
|
||||||
Self::F64(storage) => {
|
Self::F64(storage) => {
|
||||||
let data = unary_map(storage, shape, stride, |v| v * mul + add);
|
let data = unary_map(storage, layout, |v| v * mul + add);
|
||||||
Ok(Self::F64(data))
|
Ok(Self::F64(data))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn unary_impl<B: UnaryOp>(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
|
pub(crate) fn unary_impl<B: UnaryOp>(&self, layout: &Layout) -> Result<Self> {
|
||||||
match self {
|
match self {
|
||||||
Self::BF16(storage) => {
|
Self::BF16(storage) => {
|
||||||
let data = unary_map(storage, shape, stride, B::bf16);
|
let data = unary_map(storage, layout, B::bf16);
|
||||||
Ok(Self::BF16(data))
|
Ok(Self::BF16(data))
|
||||||
}
|
}
|
||||||
Self::F16(storage) => {
|
Self::F16(storage) => {
|
||||||
let data = unary_map(storage, shape, stride, B::f16);
|
let data = unary_map(storage, layout, B::f16);
|
||||||
Ok(Self::F16(data))
|
Ok(Self::F16(data))
|
||||||
}
|
}
|
||||||
Self::F32(storage) => {
|
Self::F32(storage) => {
|
||||||
let data = unary_map(storage, shape, stride, B::f32);
|
let data = unary_map(storage, layout, B::f32);
|
||||||
Ok(Self::F32(data))
|
Ok(Self::F32(data))
|
||||||
}
|
}
|
||||||
Self::F64(storage) => {
|
Self::F64(storage) => {
|
||||||
let data = unary_map(storage, shape, stride, B::f64);
|
let data = unary_map(storage, layout, B::f64);
|
||||||
Ok(Self::F64(data))
|
Ok(Self::F64(data))
|
||||||
}
|
}
|
||||||
Self::U32(storage) => {
|
Self::U32(storage) => {
|
||||||
let data = unary_map(storage, shape, stride, B::u32);
|
let data = unary_map(storage, layout, B::u32);
|
||||||
Ok(Self::U32(data))
|
Ok(Self::U32(data))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
47
candle-core/src/layout.rs
Normal file
47
candle-core/src/layout.rs
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
use crate::Shape;
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||||
|
pub struct Layout {
|
||||||
|
shape: Shape,
|
||||||
|
// The strides are given in number of elements and not in bytes.
|
||||||
|
stride: Vec<usize>,
|
||||||
|
start_offset: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Layout {
|
||||||
|
pub fn contiguous<S: Into<Shape>>(shape: S) -> Self {
|
||||||
|
let shape = shape.into();
|
||||||
|
let stride = shape.stride_contiguous();
|
||||||
|
Self {
|
||||||
|
shape,
|
||||||
|
stride,
|
||||||
|
start_offset: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dims(&self) -> &[usize] {
|
||||||
|
self.shape.dims()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn shape(&self) -> &Shape {
|
||||||
|
&self.shape
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn stride(&self) -> &[usize] {
|
||||||
|
&self.stride
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn start_offset(&self) -> usize {
|
||||||
|
self.start_offset
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true if the data is stored in a C contiguous (aka row major) way.
|
||||||
|
pub fn is_contiguous(&self) -> bool {
|
||||||
|
self.shape.is_contiguous(&self.stride)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true if the data is stored in a Fortran contiguous (aka column major) way.
|
||||||
|
pub fn is_fortran_contiguous(&self) -> bool {
|
||||||
|
self.shape.is_fortran_contiguous(&self.stride)
|
||||||
|
}
|
||||||
|
}
|
@ -7,6 +7,7 @@ pub mod display;
|
|||||||
mod dtype;
|
mod dtype;
|
||||||
mod dummy_cuda_backend;
|
mod dummy_cuda_backend;
|
||||||
mod error;
|
mod error;
|
||||||
|
mod layout;
|
||||||
mod npy;
|
mod npy;
|
||||||
mod op;
|
mod op;
|
||||||
mod shape;
|
mod shape;
|
||||||
@ -19,6 +20,7 @@ pub use cpu_backend::CpuStorage;
|
|||||||
pub use device::{Device, DeviceLocation};
|
pub use device::{Device, DeviceLocation};
|
||||||
pub use dtype::{DType, WithDType};
|
pub use dtype::{DType, WithDType};
|
||||||
pub use error::{Error, Result};
|
pub use error::{Error, Result};
|
||||||
|
pub use layout::Layout;
|
||||||
pub use shape::Shape;
|
pub use shape::Shape;
|
||||||
pub use storage::Storage;
|
pub use storage::Storage;
|
||||||
use strided_index::StridedIndex;
|
use strided_index::StridedIndex;
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Result, Shape};
|
use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
|
||||||
|
|
||||||
// We do not want to implement Clone on Storage as cloning may fail because of
|
// We do not want to implement Clone on Storage as cloning may fail because of
|
||||||
// out of memory. Instead try_clone should be used.
|
// out of memory. Instead try_clone should be used.
|
||||||
@ -53,33 +53,27 @@ impl Storage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn affine_impl(
|
pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
||||||
&self,
|
|
||||||
shape: &Shape,
|
|
||||||
stride: &[usize],
|
|
||||||
mul: f64,
|
|
||||||
add: f64,
|
|
||||||
) -> Result<Self> {
|
|
||||||
match self {
|
match self {
|
||||||
Storage::Cpu(storage) => {
|
Storage::Cpu(storage) => {
|
||||||
let storage = storage.affine_impl(shape, stride, mul, add)?;
|
let storage = storage.affine(layout, mul, add)?;
|
||||||
Ok(Self::Cpu(storage))
|
Ok(Self::Cpu(storage))
|
||||||
}
|
}
|
||||||
Self::Cuda(storage) => {
|
Self::Cuda(storage) => {
|
||||||
let storage = storage.affine_impl(shape, stride, mul, add)?;
|
let storage = storage.affine(layout, mul, add)?;
|
||||||
Ok(Self::Cuda(storage))
|
Ok(Self::Cuda(storage))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], s: &[usize]) -> Result<Self> {
|
pub(crate) fn sum(&self, layout: &Layout, s: &[usize]) -> Result<Self> {
|
||||||
match self {
|
match self {
|
||||||
Storage::Cpu(storage) => {
|
Storage::Cpu(storage) => {
|
||||||
let storage = storage.sum(shape, stride, s)?;
|
let storage = storage.sum(layout, s)?;
|
||||||
Ok(Self::Cpu(storage))
|
Ok(Self::Cpu(storage))
|
||||||
}
|
}
|
||||||
Self::Cuda(storage) => {
|
Self::Cuda(storage) => {
|
||||||
let storage = storage.sum(shape, stride, s)?;
|
let storage = storage.sum(layout, s)?;
|
||||||
Ok(Self::Cuda(storage))
|
Ok(Self::Cuda(storage))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -93,32 +87,28 @@ impl Storage {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
|
pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
||||||
match self {
|
match self {
|
||||||
Storage::Cpu(storage) => {
|
Storage::Cpu(storage) => {
|
||||||
let storage = storage.to_dtype(shape, stride, dtype)?;
|
let storage = storage.to_dtype(layout, dtype)?;
|
||||||
Ok(Self::Cpu(storage))
|
Ok(Self::Cpu(storage))
|
||||||
}
|
}
|
||||||
Self::Cuda(storage) => {
|
Self::Cuda(storage) => {
|
||||||
let storage = storage.to_dtype(shape, stride, dtype)?;
|
let storage = storage.to_dtype(layout, dtype)?;
|
||||||
Ok(Self::Cuda(storage))
|
Ok(Self::Cuda(storage))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn unary_impl<B: op::UnaryOp>(
|
pub(crate) fn unary_impl<B: op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
|
||||||
&self,
|
|
||||||
shape: &Shape,
|
|
||||||
stride: &[usize],
|
|
||||||
) -> Result<Self> {
|
|
||||||
// TODO: Different code path for the contiguous case?
|
// TODO: Different code path for the contiguous case?
|
||||||
match self {
|
match self {
|
||||||
Storage::Cpu(storage) => {
|
Storage::Cpu(storage) => {
|
||||||
let storage = storage.unary_impl::<B>(shape, stride)?;
|
let storage = storage.unary_impl::<B>(layout)?;
|
||||||
Ok(Self::Cpu(storage))
|
Ok(Self::Cpu(storage))
|
||||||
}
|
}
|
||||||
Self::Cuda(storage) => {
|
Self::Cuda(storage) => {
|
||||||
let storage = storage.unary_impl::<B>(shape, stride)?;
|
let storage = storage.unary_impl::<B>(layout)?;
|
||||||
Ok(Self::Cuda(storage))
|
Ok(Self::Cuda(storage))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -127,19 +117,18 @@ impl Storage {
|
|||||||
pub(crate) fn binary_impl<B: op::BinaryOp>(
|
pub(crate) fn binary_impl<B: op::BinaryOp>(
|
||||||
&self,
|
&self,
|
||||||
rhs: &Self,
|
rhs: &Self,
|
||||||
shape: &Shape,
|
lhs_layout: &Layout,
|
||||||
lhs_stride: &[usize],
|
rhs_layout: &Layout,
|
||||||
rhs_stride: &[usize],
|
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
self.same_device(rhs, B::NAME)?;
|
self.same_device(rhs, B::NAME)?;
|
||||||
self.same_dtype(rhs, B::NAME)?;
|
self.same_dtype(rhs, B::NAME)?;
|
||||||
match (self, rhs) {
|
match (self, rhs) {
|
||||||
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
|
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
|
||||||
let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?;
|
let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
|
||||||
Ok(Self::Cpu(storage))
|
Ok(Self::Cpu(storage))
|
||||||
}
|
}
|
||||||
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
|
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
|
||||||
let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?;
|
let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
|
||||||
Ok(Self::Cuda(storage))
|
Ok(Self::Cuda(storage))
|
||||||
}
|
}
|
||||||
(lhs, rhs) => {
|
(lhs, rhs) => {
|
||||||
@ -156,23 +145,22 @@ impl Storage {
|
|||||||
|
|
||||||
pub(crate) fn where_cond(
|
pub(crate) fn where_cond(
|
||||||
&self,
|
&self,
|
||||||
shape: &Shape,
|
layout: &Shape,
|
||||||
stride: &[usize],
|
|
||||||
t: &Self,
|
t: &Self,
|
||||||
stride_t: &[usize],
|
layout_t: &Layout,
|
||||||
f: &Self,
|
f: &Self,
|
||||||
stride_f: &[usize],
|
layout_f: &Layout,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
self.same_device(t, "where")?;
|
self.same_device(t, "where")?;
|
||||||
self.same_device(f, "where")?;
|
self.same_device(f, "where")?;
|
||||||
t.same_dtype(f, "where")?;
|
t.same_dtype(f, "where")?;
|
||||||
match (self, t, f) {
|
match (self, t, f) {
|
||||||
(Storage::Cpu(cond), Storage::Cpu(t), Storage::Cpu(f)) => {
|
(Storage::Cpu(cond), Storage::Cpu(t), Storage::Cpu(f)) => {
|
||||||
let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
|
let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
|
||||||
Ok(Self::Cpu(storage))
|
Ok(Self::Cpu(storage))
|
||||||
}
|
}
|
||||||
(Self::Cuda(cond), Self::Cuda(t), Self::Cuda(f)) => {
|
(Self::Cuda(cond), Self::Cuda(t), Self::Cuda(f)) => {
|
||||||
let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
|
let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
|
||||||
Ok(Self::Cuda(storage))
|
Ok(Self::Cuda(storage))
|
||||||
}
|
}
|
||||||
(_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
(_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||||
@ -185,8 +173,7 @@ impl Storage {
|
|||||||
|
|
||||||
pub(crate) fn embedding_impl(
|
pub(crate) fn embedding_impl(
|
||||||
&self,
|
&self,
|
||||||
shape: &Shape,
|
layout: &Layout,
|
||||||
stride: &[usize],
|
|
||||||
rhs: &Self,
|
rhs: &Self,
|
||||||
hidden_size: usize,
|
hidden_size: usize,
|
||||||
vocab_size: usize,
|
vocab_size: usize,
|
||||||
@ -194,11 +181,11 @@ impl Storage {
|
|||||||
self.same_device(rhs, "embedding")?;
|
self.same_device(rhs, "embedding")?;
|
||||||
match (self, rhs) {
|
match (self, rhs) {
|
||||||
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
|
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
|
||||||
let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?;
|
let storage = lhs.embedding_impl(layout, rhs, hidden_size, vocab_size)?;
|
||||||
Ok(Self::Cpu(storage))
|
Ok(Self::Cpu(storage))
|
||||||
}
|
}
|
||||||
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
|
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
|
||||||
let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?;
|
let storage = lhs.embedding_impl(layout, rhs, hidden_size, vocab_size)?;
|
||||||
Ok(Self::Cuda(storage))
|
Ok(Self::Cuda(storage))
|
||||||
}
|
}
|
||||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use crate::{op::Op, storage::Storage, DType, Device, Error, Result, Shape};
|
use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
/// Unique identifier for tensors.
|
/// Unique identifier for tensors.
|
||||||
@ -17,9 +17,7 @@ impl TensorId {
|
|||||||
pub struct Tensor_ {
|
pub struct Tensor_ {
|
||||||
id: TensorId,
|
id: TensorId,
|
||||||
storage: Arc<Storage>,
|
storage: Arc<Storage>,
|
||||||
shape: Shape,
|
layout: Layout,
|
||||||
// The strides are given in number of elements and not in bytes.
|
|
||||||
stride: Vec<usize>,
|
|
||||||
op: Option<Op>,
|
op: Option<Op>,
|
||||||
is_variable: bool,
|
is_variable: bool,
|
||||||
}
|
}
|
||||||
@ -50,7 +48,7 @@ macro_rules! unary_op {
|
|||||||
let shape = self.shape();
|
let shape = self.shape();
|
||||||
let storage = self
|
let storage = self
|
||||||
.storage
|
.storage
|
||||||
.unary_impl::<crate::op::$op_name>(self.shape(), self.stride())?;
|
.unary_impl::<crate::op::$op_name>(self.layout())?;
|
||||||
let op = if self.track_op() {
|
let op = if self.track_op() {
|
||||||
Some(Op::$op_name(self.clone()))
|
Some(Op::$op_name(self.clone()))
|
||||||
} else {
|
} else {
|
||||||
@ -67,9 +65,8 @@ macro_rules! binary_op {
|
|||||||
let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
|
let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
|
||||||
let storage = self.storage.binary_impl::<crate::op::$op_name>(
|
let storage = self.storage.binary_impl::<crate::op::$op_name>(
|
||||||
&rhs.storage,
|
&rhs.storage,
|
||||||
shape,
|
self.layout(),
|
||||||
self.stride(),
|
rhs.layout(),
|
||||||
rhs.stride(),
|
|
||||||
)?;
|
)?;
|
||||||
let op = if self.track_op() || rhs.track_op() {
|
let op = if self.track_op() || rhs.track_op() {
|
||||||
Some(Op::$op_name(self.clone(), rhs.clone()))
|
Some(Op::$op_name(self.clone(), rhs.clone()))
|
||||||
@ -107,13 +104,10 @@ fn from_storage<S: Into<Shape>>(
|
|||||||
op: Option<Op>,
|
op: Option<Op>,
|
||||||
is_variable: bool,
|
is_variable: bool,
|
||||||
) -> Tensor {
|
) -> Tensor {
|
||||||
let shape = shape.into();
|
|
||||||
let stride = shape.stride_contiguous();
|
|
||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
id: TensorId::new(),
|
id: TensorId::new(),
|
||||||
storage: Arc::new(storage),
|
storage: Arc::new(storage),
|
||||||
shape,
|
layout: Layout::contiguous(shape),
|
||||||
stride,
|
|
||||||
op,
|
op,
|
||||||
is_variable,
|
is_variable,
|
||||||
};
|
};
|
||||||
@ -342,8 +336,7 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
|
pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
|
||||||
let shape = self.shape();
|
let storage = self.storage.affine(self.layout(), mul, add)?;
|
||||||
let storage = self.storage.affine_impl(shape, self.stride(), mul, add)?;
|
|
||||||
let op = if self.track_op() {
|
let op = if self.track_op() {
|
||||||
Some(Op::Affine {
|
Some(Op::Affine {
|
||||||
arg: self.clone(),
|
arg: self.clone(),
|
||||||
@ -353,7 +346,7 @@ impl Tensor {
|
|||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
Ok(from_storage(storage, shape.clone(), op, false))
|
Ok(from_storage(storage, self.shape(), op, false))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
|
/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
|
||||||
@ -401,9 +394,7 @@ impl Tensor {
|
|||||||
exp.broadcast_div(&sum_exp)
|
exp.broadcast_div(&sum_exp)
|
||||||
} else {
|
} else {
|
||||||
let shape = self.shape();
|
let shape = self.shape();
|
||||||
let mut storage = self
|
let mut storage = self.storage.unary_impl::<crate::op::Exp>(self.layout())?;
|
||||||
.storage
|
|
||||||
.unary_impl::<crate::op::Exp>(shape, self.stride())?;
|
|
||||||
// The resulting storage is contiguous.
|
// The resulting storage is contiguous.
|
||||||
storage.divide_by_sum_over_dim(shape, dim)?;
|
storage.divide_by_sum_over_dim(shape, dim)?;
|
||||||
let op = if self.track_op() {
|
let op = if self.track_op() {
|
||||||
@ -416,7 +407,7 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
|
pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
|
||||||
let storage = self.storage.sum(self.shape(), &self.stride, sum_dims)?;
|
let storage = self.storage.sum(self.layout(), sum_dims)?;
|
||||||
let op = if self.track_op() {
|
let op = if self.track_op() {
|
||||||
Some(Op::Sum(self.clone(), sum_dims.to_vec()))
|
Some(Op::Sum(self.clone(), sum_dims.to_vec()))
|
||||||
} else {
|
} else {
|
||||||
@ -461,8 +452,8 @@ impl Tensor {
|
|||||||
let storage = self.storage.matmul_impl(
|
let storage = self.storage.matmul_impl(
|
||||||
&rhs.storage,
|
&rhs.storage,
|
||||||
(batching, m, n, k),
|
(batching, m, n, k),
|
||||||
self.stride(),
|
self.layout(),
|
||||||
rhs.stride(),
|
rhs.layout(),
|
||||||
)?;
|
)?;
|
||||||
let op = if self.track_op() || rhs.track_op() {
|
let op = if self.track_op() || rhs.track_op() {
|
||||||
Some(Op::Matmul(self.clone(), rhs.clone()))
|
Some(Op::Matmul(self.clone(), rhs.clone()))
|
||||||
@ -476,12 +467,11 @@ impl Tensor {
|
|||||||
let _shap = self.same_shape_binary_op(on_true, "where_cond")?;
|
let _shap = self.same_shape_binary_op(on_true, "where_cond")?;
|
||||||
let shape = self.same_shape_binary_op(on_false, "where_cond")?;
|
let shape = self.same_shape_binary_op(on_false, "where_cond")?;
|
||||||
let storage = self.storage.where_cond(
|
let storage = self.storage.where_cond(
|
||||||
shape,
|
self.layout(),
|
||||||
self.stride(),
|
|
||||||
&on_true.storage,
|
&on_true.storage,
|
||||||
on_true.stride(),
|
on_true.layout(),
|
||||||
&on_false.storage,
|
&on_false.storage,
|
||||||
on_false.stride(),
|
on_false.layout(),
|
||||||
)?;
|
)?;
|
||||||
let op = if self.track_op() || on_true.track_op() || on_false.track_op() {
|
let op = if self.track_op() || on_true.track_op() || on_false.track_op() {
|
||||||
Some(Op::WhereCond(
|
Some(Op::WhereCond(
|
||||||
@ -498,10 +488,10 @@ impl Tensor {
|
|||||||
pub fn embedding(ids: &Self, rhs: &Self) -> Result<Self> {
|
pub fn embedding(ids: &Self, rhs: &Self) -> Result<Self> {
|
||||||
if !rhs.is_contiguous() {
|
if !rhs.is_contiguous() {
|
||||||
return Err(Error::RequiresContiguous { op: "embedding" });
|
return Err(Error::RequiresContiguous { op: "embedding" });
|
||||||
} else if rhs.shape().rank() != 2 || ids.shape().rank() != 1 {
|
} else if rhs.rank() != 2 || ids.rank() != 1 {
|
||||||
return Err(Error::ShapeMismatchBinaryOp {
|
return Err(Error::ShapeMismatchBinaryOp {
|
||||||
lhs: ids.shape.clone(),
|
lhs: ids.shape().clone(),
|
||||||
rhs: rhs.shape.clone(),
|
rhs: rhs.shape().clone(),
|
||||||
op: "embedding",
|
op: "embedding",
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -509,7 +499,7 @@ impl Tensor {
|
|||||||
let seq_len = ids_shape.r1()?;
|
let seq_len = ids_shape.r1()?;
|
||||||
let (vocab_size, hidden_size) = rhs.shape().r2()?;
|
let (vocab_size, hidden_size) = rhs.shape().r2()?;
|
||||||
let storage = ids.storage.embedding_impl(
|
let storage = ids.storage.embedding_impl(
|
||||||
ids_shape,
|
ids.layout(),
|
||||||
&ids.stride,
|
&ids.stride,
|
||||||
&rhs.storage,
|
&rhs.storage,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
@ -625,8 +615,13 @@ impl Tensor {
|
|||||||
self.shape().dims()
|
self.shape().dims()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn stride(&self) -> &[usize] {
|
pub fn layout(&self) -> &Layout {
|
||||||
&self.stride
|
&self.layout
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Rename to `stride` once the PR that introduced the layout has been merged.
|
||||||
|
pub fn stride_tmp(&self) -> &[usize] {
|
||||||
|
&self.layout.stride()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn rank(&self) -> usize {
|
pub fn rank(&self) -> usize {
|
||||||
@ -734,12 +729,12 @@ impl Tensor {
|
|||||||
|
|
||||||
/// Returns true if the data is stored in a C contiguous (aka row major) way.
|
/// Returns true if the data is stored in a C contiguous (aka row major) way.
|
||||||
pub fn is_contiguous(&self) -> bool {
|
pub fn is_contiguous(&self) -> bool {
|
||||||
self.shape.is_contiguous(&self.stride)
|
self.layout.is_contiguous()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns true if the data is stored in a Fortran contiguous (aka column major) way.
|
/// Returns true if the data is stored in a Fortran contiguous (aka column major) way.
|
||||||
pub fn is_fortran_contiguous(&self) -> bool {
|
pub fn is_fortran_contiguous(&self) -> bool {
|
||||||
self.shape.is_fortran_contiguous(&self.stride)
|
self.layout.is_fortran_contiguous()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Compared to clone, this copies the actual storage but may fail because of running out of
|
/// Compared to clone, this copies the actual storage but may fail because of running out of
|
||||||
@ -748,8 +743,7 @@ impl Tensor {
|
|||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
id: TensorId::new(),
|
id: TensorId::new(),
|
||||||
storage: Arc::new(self.storage.try_clone()?),
|
storage: Arc::new(self.storage.try_clone()?),
|
||||||
shape: self.shape.clone(),
|
layout: self.layout.clone(),
|
||||||
stride: self.stride.clone(),
|
|
||||||
op: None, // TODO
|
op: None, // TODO
|
||||||
is_variable: false,
|
is_variable: false,
|
||||||
};
|
};
|
||||||
@ -762,8 +756,7 @@ impl Tensor {
|
|||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
id: TensorId::new(),
|
id: TensorId::new(),
|
||||||
storage: self.storage.clone(),
|
storage: self.storage.clone(),
|
||||||
shape: self.shape.clone(),
|
layout: self.layout.clone(),
|
||||||
stride: self.stride.clone(),
|
|
||||||
op: None,
|
op: None,
|
||||||
is_variable: false,
|
is_variable: false,
|
||||||
};
|
};
|
||||||
@ -796,8 +789,7 @@ impl Tensor {
|
|||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
id: TensorId::new(),
|
id: TensorId::new(),
|
||||||
storage: Arc::new(storage),
|
storage: Arc::new(storage),
|
||||||
shape: self.shape.clone(),
|
layout: self.layout.clone(),
|
||||||
stride: self.stride.clone(),
|
|
||||||
op,
|
op,
|
||||||
is_variable: false,
|
is_variable: false,
|
||||||
};
|
};
|
||||||
@ -810,7 +802,7 @@ impl Tensor {
|
|||||||
pub fn broadcast_left<S: Into<Shape>>(&self, left_shape: S) -> Result<Self> {
|
pub fn broadcast_left<S: Into<Shape>>(&self, left_shape: S) -> Result<Self> {
|
||||||
let left_shape = left_shape.into();
|
let left_shape = left_shape.into();
|
||||||
let mut dims = left_shape.into_dims();
|
let mut dims = left_shape.into_dims();
|
||||||
dims.extend(self.shape.dims());
|
dims.extend(self.dims());
|
||||||
self.broadcast_as(dims)
|
self.broadcast_as(dims)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -866,7 +858,7 @@ impl Tensor {
|
|||||||
Ok(self.clone())
|
Ok(self.clone())
|
||||||
} else {
|
} else {
|
||||||
let shape = self.shape();
|
let shape = self.shape();
|
||||||
let storage = self.storage.to_dtype(shape, self.stride(), dtype)?;
|
let storage = self.storage.to_dtype(self.layout(), dtype)?;
|
||||||
let op = if self.track_op() {
|
let op = if self.track_op() {
|
||||||
Some(Op::ToDType(self.clone()))
|
Some(Op::ToDType(self.clone()))
|
||||||
} else {
|
} else {
|
||||||
|
Reference in New Issue
Block a user