mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Start refactoring the stride.
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
use crate::{op::Op, storage::Storage, DType, Device, Error, Result, Shape};
|
||||
use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Unique identifier for tensors.
|
||||
@ -17,9 +17,7 @@ impl TensorId {
|
||||
pub struct Tensor_ {
|
||||
id: TensorId,
|
||||
storage: Arc<Storage>,
|
||||
shape: Shape,
|
||||
// The strides are given in number of elements and not in bytes.
|
||||
stride: Vec<usize>,
|
||||
layout: Layout,
|
||||
op: Option<Op>,
|
||||
is_variable: bool,
|
||||
}
|
||||
@ -50,7 +48,7 @@ macro_rules! unary_op {
|
||||
let shape = self.shape();
|
||||
let storage = self
|
||||
.storage
|
||||
.unary_impl::<crate::op::$op_name>(self.shape(), self.stride())?;
|
||||
.unary_impl::<crate::op::$op_name>(self.layout())?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::$op_name(self.clone()))
|
||||
} else {
|
||||
@ -67,9 +65,8 @@ macro_rules! binary_op {
|
||||
let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
|
||||
let storage = self.storage.binary_impl::<crate::op::$op_name>(
|
||||
&rhs.storage,
|
||||
shape,
|
||||
self.stride(),
|
||||
rhs.stride(),
|
||||
self.layout(),
|
||||
rhs.layout(),
|
||||
)?;
|
||||
let op = if self.track_op() || rhs.track_op() {
|
||||
Some(Op::$op_name(self.clone(), rhs.clone()))
|
||||
@ -107,13 +104,10 @@ fn from_storage<S: Into<Shape>>(
|
||||
op: Option<Op>,
|
||||
is_variable: bool,
|
||||
) -> Tensor {
|
||||
let shape = shape.into();
|
||||
let stride = shape.stride_contiguous();
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: Arc::new(storage),
|
||||
shape,
|
||||
stride,
|
||||
layout: Layout::contiguous(shape),
|
||||
op,
|
||||
is_variable,
|
||||
};
|
||||
@ -342,8 +336,7 @@ impl Tensor {
|
||||
}
|
||||
|
||||
pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
|
||||
let shape = self.shape();
|
||||
let storage = self.storage.affine_impl(shape, self.stride(), mul, add)?;
|
||||
let storage = self.storage.affine(self.layout(), mul, add)?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Affine {
|
||||
arg: self.clone(),
|
||||
@ -353,7 +346,7 @@ impl Tensor {
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(from_storage(storage, shape.clone(), op, false))
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
}
|
||||
|
||||
/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
|
||||
@ -401,9 +394,7 @@ impl Tensor {
|
||||
exp.broadcast_div(&sum_exp)
|
||||
} else {
|
||||
let shape = self.shape();
|
||||
let mut storage = self
|
||||
.storage
|
||||
.unary_impl::<crate::op::Exp>(shape, self.stride())?;
|
||||
let mut storage = self.storage.unary_impl::<crate::op::Exp>(self.layout())?;
|
||||
// The resulting storage is contiguous.
|
||||
storage.divide_by_sum_over_dim(shape, dim)?;
|
||||
let op = if self.track_op() {
|
||||
@ -416,7 +407,7 @@ impl Tensor {
|
||||
}
|
||||
|
||||
pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
|
||||
let storage = self.storage.sum(self.shape(), &self.stride, sum_dims)?;
|
||||
let storage = self.storage.sum(self.layout(), sum_dims)?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Sum(self.clone(), sum_dims.to_vec()))
|
||||
} else {
|
||||
@ -461,8 +452,8 @@ impl Tensor {
|
||||
let storage = self.storage.matmul_impl(
|
||||
&rhs.storage,
|
||||
(batching, m, n, k),
|
||||
self.stride(),
|
||||
rhs.stride(),
|
||||
self.layout(),
|
||||
rhs.layout(),
|
||||
)?;
|
||||
let op = if self.track_op() || rhs.track_op() {
|
||||
Some(Op::Matmul(self.clone(), rhs.clone()))
|
||||
@ -476,12 +467,11 @@ impl Tensor {
|
||||
let _shap = self.same_shape_binary_op(on_true, "where_cond")?;
|
||||
let shape = self.same_shape_binary_op(on_false, "where_cond")?;
|
||||
let storage = self.storage.where_cond(
|
||||
shape,
|
||||
self.stride(),
|
||||
self.layout(),
|
||||
&on_true.storage,
|
||||
on_true.stride(),
|
||||
on_true.layout(),
|
||||
&on_false.storage,
|
||||
on_false.stride(),
|
||||
on_false.layout(),
|
||||
)?;
|
||||
let op = if self.track_op() || on_true.track_op() || on_false.track_op() {
|
||||
Some(Op::WhereCond(
|
||||
@ -498,10 +488,10 @@ impl Tensor {
|
||||
pub fn embedding(ids: &Self, rhs: &Self) -> Result<Self> {
|
||||
if !rhs.is_contiguous() {
|
||||
return Err(Error::RequiresContiguous { op: "embedding" });
|
||||
} else if rhs.shape().rank() != 2 || ids.shape().rank() != 1 {
|
||||
} else if rhs.rank() != 2 || ids.rank() != 1 {
|
||||
return Err(Error::ShapeMismatchBinaryOp {
|
||||
lhs: ids.shape.clone(),
|
||||
rhs: rhs.shape.clone(),
|
||||
lhs: ids.shape().clone(),
|
||||
rhs: rhs.shape().clone(),
|
||||
op: "embedding",
|
||||
});
|
||||
}
|
||||
@ -509,7 +499,7 @@ impl Tensor {
|
||||
let seq_len = ids_shape.r1()?;
|
||||
let (vocab_size, hidden_size) = rhs.shape().r2()?;
|
||||
let storage = ids.storage.embedding_impl(
|
||||
ids_shape,
|
||||
ids.layout(),
|
||||
&ids.stride,
|
||||
&rhs.storage,
|
||||
hidden_size,
|
||||
@ -625,8 +615,13 @@ impl Tensor {
|
||||
self.shape().dims()
|
||||
}
|
||||
|
||||
pub fn stride(&self) -> &[usize] {
|
||||
&self.stride
|
||||
pub fn layout(&self) -> &Layout {
|
||||
&self.layout
|
||||
}
|
||||
|
||||
// TODO: Rename to `stride` once the PR that introduced the layout has been merged.
|
||||
pub fn stride_tmp(&self) -> &[usize] {
|
||||
&self.layout.stride()
|
||||
}
|
||||
|
||||
pub fn rank(&self) -> usize {
|
||||
@ -734,12 +729,12 @@ impl Tensor {
|
||||
|
||||
/// Returns true if the data is stored in a C contiguous (aka row major) way.
|
||||
pub fn is_contiguous(&self) -> bool {
|
||||
self.shape.is_contiguous(&self.stride)
|
||||
self.layout.is_contiguous()
|
||||
}
|
||||
|
||||
/// Returns true if the data is stored in a Fortran contiguous (aka column major) way.
|
||||
pub fn is_fortran_contiguous(&self) -> bool {
|
||||
self.shape.is_fortran_contiguous(&self.stride)
|
||||
self.layout.is_fortran_contiguous()
|
||||
}
|
||||
|
||||
/// Compared to clone, this copies the actual storage but may fail because of running out of
|
||||
@ -748,8 +743,7 @@ impl Tensor {
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: Arc::new(self.storage.try_clone()?),
|
||||
shape: self.shape.clone(),
|
||||
stride: self.stride.clone(),
|
||||
layout: self.layout.clone(),
|
||||
op: None, // TODO
|
||||
is_variable: false,
|
||||
};
|
||||
@ -762,8 +756,7 @@ impl Tensor {
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: self.storage.clone(),
|
||||
shape: self.shape.clone(),
|
||||
stride: self.stride.clone(),
|
||||
layout: self.layout.clone(),
|
||||
op: None,
|
||||
is_variable: false,
|
||||
};
|
||||
@ -796,8 +789,7 @@ impl Tensor {
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: Arc::new(storage),
|
||||
shape: self.shape.clone(),
|
||||
stride: self.stride.clone(),
|
||||
layout: self.layout.clone(),
|
||||
op,
|
||||
is_variable: false,
|
||||
};
|
||||
@ -810,7 +802,7 @@ impl Tensor {
|
||||
pub fn broadcast_left<S: Into<Shape>>(&self, left_shape: S) -> Result<Self> {
|
||||
let left_shape = left_shape.into();
|
||||
let mut dims = left_shape.into_dims();
|
||||
dims.extend(self.shape.dims());
|
||||
dims.extend(self.dims());
|
||||
self.broadcast_as(dims)
|
||||
}
|
||||
|
||||
@ -866,7 +858,7 @@ impl Tensor {
|
||||
Ok(self.clone())
|
||||
} else {
|
||||
let shape = self.shape();
|
||||
let storage = self.storage.to_dtype(shape, self.stride(), dtype)?;
|
||||
let storage = self.storage.to_dtype(self.layout(), dtype)?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::ToDType(self.clone()))
|
||||
} else {
|
||||
|
Reference in New Issue
Block a user