Add the reshape method and operation (without grad for now).
@@ -17,7 +17,6 @@ pub(crate) enum Op {
         add: f64,
     },
     Neg(Tensor),
-    #[allow(dead_code)]
     Reshape(Tensor),
     Sqr(Tensor),
     Sqrt(Tensor),
@@ -595,6 +595,36 @@ impl Tensor {
         }
     }

+    // TODO: Do we want to allow target shape using -1 on some dimensions?
+    pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
+        let shape = shape.into();
+        if shape.elem_count() != self.elem_count() {
+            return Err(Error::ShapeMismatchBinaryOp {
+                lhs: self.shape().clone(),
+                rhs: shape,
+                op: "reshape",
+            });
+        }
+        let mut storage = self.device().zeros(&shape, self.dtype())?;
+        self.storage
+            .copy_strided_src(&mut storage, &shape, &self.stride, 0)?;
+        let op = if self.track_op() {
+            Some(Op::Reshape(self.clone()))
+        } else {
+            None
+        };
+        let stride = shape.stride_contiguous();
+        let tensor_ = Tensor_ {
+            id: TensorId::new(),
+            storage,
+            shape,
+            stride,
+            op,
+            is_variable: false,
+        };
+        Ok(Tensor(Arc::new(tensor_)))
+    }
+
     pub fn cat(args: &[Self], dim: usize) -> Result<Self> {
         if args.is_empty() {
             return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
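For context, a minimal usage sketch of the new method. This is not part of the commit; the constructors and accessors used here (Tensor::new, Device::Cpu, Shape::dims, and reshaping from a tuple via Into<Shape>) are assumptions about the surrounding crate at this point in its history.

    // Hypothetical usage sketch (assumed helpers: Tensor::new, Device::Cpu, Shape::dims).
    fn reshape_example() -> Result<()> {
        // A (2, 3) tensor of f32 values.
        let t = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
        // Same element count (6), so the reshape succeeds and copies into a
        // fresh contiguous buffer, recording Op::Reshape when gradients are tracked.
        let r = t.reshape((3, 2))?;
        assert_eq!(r.shape().dims(), &[3, 2]);
        // Mismatched element counts (6 vs 8) hit the check above and return
        // Err(Error::ShapeMismatchBinaryOp { .. }).
        assert!(t.reshape((4, 2)).is_err());
        Ok(())
    }

Note the design choice visible in the diff: rather than aliasing the existing storage with new strides, this first version allocates a zeroed buffer and copies the data via copy_strided_src, and the op is recorded without a backward pass yet, matching the "without grad for now" in the commit title.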