Propagate the layout refactoring.

This commit is contained in:
laurent
2023-06-28 13:42:23 +01:00
parent 30b355ccd2
commit 303b853098
5 changed files with 130 additions and 129 deletions

View File

@ -481,13 +481,9 @@ impl Tensor {
let ids_shape = ids.shape();
let seq_len = ids_shape.r1()?;
let (vocab_size, hidden_size) = rhs.shape().r2()?;
let storage = ids.storage.embedding_impl(
ids.layout(),
&ids.stride,
&rhs.storage,
hidden_size,
vocab_size,
)?;
let storage = ids
.storage
.embedding(ids.layout(), &rhs.storage, hidden_size, vocab_size)?;
let shape: Shape = (seq_len, hidden_size).into();
let op = if ids.track_op() || rhs.track_op() {
Some(Op::Embedding(ids.clone(), rhs.clone()))
@ -498,7 +494,7 @@ impl Tensor {
}
pub(crate) fn strided_index(&self) -> crate::StridedIndex {
crate::StridedIndex::new(self.dims(), self.stride())
self.layout.strided_index()
}
/// Returns data from the underlying storage, this does not take the strides
@ -591,7 +587,7 @@ impl Tensor {
}
pub fn shape(&self) -> &Shape {
&self.shape
&self.layout().shape()
}
pub fn dims(&self) -> &[usize] {
@ -682,18 +678,6 @@ impl Tensor {
/// Returns a tensor that is a transposed version of the input, the given dimensions are
/// swapped.
pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Tensor> {
let rank = self.rank();
if rank <= dim1 || rank <= dim2 {
return Err(Error::UnexpectedNumberOfDims {
expected: usize::max(dim1, dim2),
got: rank,
shape: self.shape().clone(),
});
}
let mut stride = self.stride().to_vec();
let mut dims = self.shape().dims().to_vec();
dims.swap(dim1, dim2);
stride.swap(dim1, dim2);
let op = if self.track_op() {
Some(Op::Transpose(self.clone(), dim1, dim2))
} else {
@ -702,8 +686,7 @@ impl Tensor {
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
shape: Shape::from(dims),
stride,
layout: self.layout.transpose(dim1, dim2)?,
op,
is_variable: false,
};
@ -795,36 +778,10 @@ impl Tensor {
} else {
None
};
let shape = shape.into();
if shape.rank() < self.rank() {
return Err(Error::BroadcastIncompatibleShapes {
src_shape: self.shape().clone(),
dst_shape: shape,
});
}
let added_dims = shape.rank() - self.rank();
let mut stride = vec![0; added_dims];
for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..]
.iter()
.zip(self.dims().iter().zip(self.stride()))
{
let s = if dst_dim == src_dim {
src_stride
} else if src_dim != 1 {
return Err(Error::BroadcastIncompatibleShapes {
src_shape: self.shape().clone(),
dst_shape: shape,
});
} else {
0
};
stride.push(s)
}
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
shape,
stride,
layout: self.layout.broadcast_as(shape)?,
op,
is_variable: false,
};
@ -888,12 +845,10 @@ impl Tensor {
None
};
if self.is_contiguous() {
let stride = shape.stride_contiguous();
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
shape,
stride,
layout: Layout::contiguous_with_offset(shape, self.layout.start_offset()),
op,
is_variable: false,
};