mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 19:18:50 +00:00)
Propagate the layout refactoring.
@@ -481,13 +481,9 @@ impl Tensor {
         let ids_shape = ids.shape();
         let seq_len = ids_shape.r1()?;
         let (vocab_size, hidden_size) = rhs.shape().r2()?;
-        let storage = ids.storage.embedding_impl(
-            ids.layout(),
-            &ids.stride,
-            &rhs.storage,
-            hidden_size,
-            vocab_size,
-        )?;
+        let storage = ids
+            .storage
+            .embedding(ids.layout(), &rhs.storage, hidden_size, vocab_size)?;
         let shape: Shape = (seq_len, hidden_size).into();
         let op = if ids.track_op() || rhs.track_op() {
             Some(Op::Embedding(ids.clone(), rhs.clone()))
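For context: the refactoring replaces the separate shape/stride arguments threaded through the backends with a single `Layout` value, so `embedding_impl` becomes `embedding(ids.layout(), ...)`. A minimal sketch of what such a layout type could bundle, assuming the crate's `Shape` type; the field and method names below are illustrative, not necessarily the crate's exact definitions:

// Sketch only: assumed shape of the Layout type the diff threads through.
// It bundles everything needed to index into the flat storage buffer.
pub struct Layout {
    shape: Shape,        // logical dimensions
    stride: Vec<usize>,  // elements to skip per dimension
    start_offset: usize, // index of the first element in storage
}

impl Layout {
    pub fn shape(&self) -> &Shape {
        &self.shape
    }

    pub fn stride(&self) -> &[usize] {
        &self.stride
    }

    pub fn start_offset(&self) -> usize {
        self.start_offset
    }
}

With one value carrying all three pieces of indexing state, call sites such as the embedding kernel take a single `&Layout` instead of a shape/stride pair, so later additions (like the start offset used further down) don't ripple through every backend signature.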
@@ -498,7 +494,7 @@ impl Tensor {
     }
 
     pub(crate) fn strided_index(&self) -> crate::StridedIndex {
-        crate::StridedIndex::new(self.dims(), self.stride())
+        self.layout.strided_index()
     }
 
     /// Returns data from the underlying storage, this does not take the strides
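The `StridedIndex` construction moves behind the layout as well. As a rough model of what a strided index computes, here is a standalone sketch (not the crate's implementation) that enumerates storage offsets in logical row-major order:

// Sketch only (not the crate's StridedIndex): list the storage offsets
// of a strided tensor, rightmost dimension moving fastest.
fn strided_offsets(dims: &[usize], stride: &[usize], start_offset: usize) -> Vec<usize> {
    if dims.iter().any(|&d| d == 0) {
        return Vec::new(); // empty tensor: nothing to visit
    }
    let mut offsets = Vec::with_capacity(dims.iter().product());
    let mut index = vec![0usize; dims.len()];
    loop {
        // Dot product of the multi-index with the strides gives the offset.
        let offset = start_offset + index.iter().zip(stride).map(|(i, s)| i * s).sum::<usize>();
        offsets.push(offset);
        // Advance the multi-index, carrying into the next dimension on overflow.
        let mut d = dims.len();
        loop {
            if d == 0 {
                return offsets; // wrapped past the leftmost dim: done
            }
            d -= 1;
            index[d] += 1;
            if index[d] < dims[d] {
                break;
            }
            index[d] = 0;
        }
    }
}

For example, a transposed 2x3 tensor viewed as dims [3, 2] with stride [1, 3] yields offsets 0, 3, 1, 4, 2, 5: the data is never moved, only visited in a different order.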
@@ -591,7 +587,7 @@ impl Tensor {
     }
 
     pub fn shape(&self) -> &Shape {
-        &self.shape
+        self.layout().shape()
     }
 
     pub fn dims(&self) -> &[usize] {
@@ -682,18 +678,6 @@ impl Tensor {
     /// Returns a tensor that is a transposed version of the input, the given dimensions are
     /// swapped.
     pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Tensor> {
-        let rank = self.rank();
-        if rank <= dim1 || rank <= dim2 {
-            return Err(Error::UnexpectedNumberOfDims {
-                expected: usize::max(dim1, dim2),
-                got: rank,
-                shape: self.shape().clone(),
-            });
-        }
-        let mut stride = self.stride().to_vec();
-        let mut dims = self.shape().dims().to_vec();
-        dims.swap(dim1, dim2);
-        stride.swap(dim1, dim2);
         let op = if self.track_op() {
             Some(Op::Transpose(self.clone(), dim1, dim2))
         } else {
@@ -702,8 +686,7 @@ impl Tensor {
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: self.storage.clone(),
-            shape: Shape::from(dims),
-            stride,
+            layout: self.layout.transpose(dim1, dim2)?,
            op,
             is_variable: false,
         };
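The rank check and the dim/stride swap removed above do not disappear; they move into `Layout::transpose`. A plausible body for that method, reusing the exact logic this diff deletes from `Tensor::transpose` (struct fields as assumed in the earlier sketch):

// Sketch: the validation and swap relocated onto the assumed Layout type.
impl Layout {
    pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
        let rank = self.shape.rank();
        if rank <= dim1 || rank <= dim2 {
            return Err(Error::UnexpectedNumberOfDims {
                expected: usize::max(dim1, dim2),
                got: rank,
                shape: self.shape.clone(),
            });
        }
        let mut stride = self.stride.clone();
        let mut dims = self.shape.dims().to_vec();
        dims.swap(dim1, dim2);
        stride.swap(dim1, dim2);
        Ok(Self {
            shape: Shape::from(dims),
            stride,
            start_offset: self.start_offset,
        })
    }
}

A transpose never touches the data; only the bookkeeping is swapped, which is why the storage offset is carried over unchanged.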
@@ -795,36 +778,10 @@ impl Tensor {
         } else {
             None
         };
-        let shape = shape.into();
-        if shape.rank() < self.rank() {
-            return Err(Error::BroadcastIncompatibleShapes {
-                src_shape: self.shape().clone(),
-                dst_shape: shape,
-            });
-        }
-        let added_dims = shape.rank() - self.rank();
-        let mut stride = vec![0; added_dims];
-        for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..]
-            .iter()
-            .zip(self.dims().iter().zip(self.stride()))
-        {
-            let s = if dst_dim == src_dim {
-                src_stride
-            } else if src_dim != 1 {
-                return Err(Error::BroadcastIncompatibleShapes {
-                    src_shape: self.shape().clone(),
-                    dst_shape: shape,
-                });
-            } else {
-                0
-            };
-            stride.push(s)
-        }
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: self.storage.clone(),
-            shape,
-            stride,
+            layout: self.layout.broadcast_as(shape)?,
             op,
             is_variable: false,
         };
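Likewise, the broadcasting logic deleted here maps directly onto `Layout::broadcast_as`: missing leading dimensions and size-1 source dimensions get a stride of 0, so reads repeat the same elements instead of copying them. A sketch reusing the removed code under the same assumed fields (one small change: `shape` is cloned in the loop-body error since it is still borrowed by the iterator there):

// Sketch: the broadcast checks from the deleted Tensor code, on Layout.
impl Layout {
    pub fn broadcast_as(&self, shape: impl Into<Shape>) -> Result<Self> {
        let shape = shape.into();
        if shape.rank() < self.shape.rank() {
            return Err(Error::BroadcastIncompatibleShapes {
                src_shape: self.shape.clone(),
                dst_shape: shape,
            });
        }
        // Dimensions added on the left broadcast trivially: stride 0.
        let added_dims = shape.rank() - self.shape.rank();
        let mut stride = vec![0; added_dims];
        for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..]
            .iter()
            .zip(self.shape.dims().iter().zip(self.stride.iter()))
        {
            let s = if dst_dim == src_dim {
                src_stride // matching dims keep their stride
            } else if src_dim != 1 {
                return Err(Error::BroadcastIncompatibleShapes {
                    src_shape: self.shape.clone(),
                    dst_shape: shape.clone(),
                });
            } else {
                0 // size-1 source dim: repeat the single element
            };
            stride.push(s)
        }
        Ok(Self {
            shape,
            stride,
            start_offset: self.start_offset,
        })
    }
}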
@@ -888,12 +845,10 @@ impl Tensor {
             None
         };
         if self.is_contiguous() {
-            let stride = shape.stride_contiguous();
             let tensor_ = Tensor_ {
                 id: TensorId::new(),
                 storage: self.storage.clone(),
-                shape,
-                stride,
+                layout: Layout::contiguous_with_offset(shape, self.layout.start_offset()),
                 op,
                 is_variable: false,
             };
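Finally, for the contiguous reshape path, `Layout::contiguous_with_offset` can be read as deriving the contiguous strides from the shape, exactly what the deleted `shape.stride_contiguous()` line did, while preserving the source tensor's storage offset. A sketch under the same assumptions:

// Sketch: a contiguous layout computed from the shape, keeping the
// caller-supplied storage offset.
impl Layout {
    pub fn contiguous_with_offset(shape: impl Into<Shape>, start_offset: usize) -> Self {
        let shape = shape.into();
        let stride = shape.stride_contiguous();
        Self {
            shape,
            stride,
            start_offset,
        }
    }
}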