mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Reshape can now return a view.
This commit is contained in:
@ -648,8 +648,9 @@ impl Tensor {
|
||||
|
||||
// TODO: Do we want to allow target shape using -1 on some dimensions?
|
||||
/// Reshape returns a tensor with the target shape provided that the number of elements of the
|
||||
/// original tensor is the same. This uses a new storage and copies the data over, the returned
|
||||
/// tensor is always contiguous.
|
||||
/// original tensor is the same.
|
||||
/// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses
|
||||
/// a new storage and copies the data over, the returned tensor is always contiguous.
|
||||
pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
|
||||
let shape = shape.into();
|
||||
if shape.elem_count() != self.elem_count() {
|
||||
@ -659,15 +660,28 @@ impl Tensor {
|
||||
op: "reshape",
|
||||
});
|
||||
}
|
||||
let mut storage = self.device().zeros(&shape, self.dtype())?;
|
||||
self.storage
|
||||
.copy_strided_src(&mut storage, &self.shape, &self.stride, 0)?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Reshape(self.clone()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
if self.is_contiguous() {
|
||||
let stride = shape.stride_contiguous();
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: self.storage.clone(),
|
||||
shape,
|
||||
stride,
|
||||
op,
|
||||
is_variable: false,
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
} else {
|
||||
let mut storage = self.device().zeros(&shape, self.dtype())?;
|
||||
self.storage
|
||||
.copy_strided_src(&mut storage, &self.shape, &self.stride, 0)?;
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cat(args: &[Self], dim: usize) -> Result<Self> {
|
||||
|
Reference in New Issue
Block a user