mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Abstract the implementation of Shape.
This commit is contained in:
@ -283,9 +283,8 @@ impl Tensor {
|
||||
});
|
||||
}
|
||||
|
||||
let mut c_shape: Vec<_> = a_dims[..dim - 2].into();
|
||||
c_shape.extend(&[m, n]);
|
||||
let c_shape = Shape(c_shape);
|
||||
let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
|
||||
let c_stride = c_shape.stride_contiguous();
|
||||
let batching: usize = a_dims[..dim - 2].iter().product();
|
||||
|
||||
let storage = self.storage.matmul_impl(
|
||||
@ -297,8 +296,8 @@ impl Tensor {
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage,
|
||||
shape: c_shape.clone(),
|
||||
stride: c_shape.stride_contiguous(),
|
||||
shape: c_shape,
|
||||
stride: c_stride,
|
||||
op: Some(Op::Matmul(self.clone(), rhs.clone())),
|
||||
is_variable: false,
|
||||
};
|
||||
@ -414,7 +413,6 @@ impl Tensor {
|
||||
|
||||
pub fn t(&self) -> Result<Tensor> {
|
||||
let mut stride = self.stride().to_vec();
|
||||
let mut shape = self.shape().clone();
|
||||
let n = stride.len();
|
||||
if n < 2 {
|
||||
return Err(Error::UnexpectedNumberOfDims {
|
||||
@ -423,12 +421,13 @@ impl Tensor {
|
||||
shape: self.shape().clone(),
|
||||
});
|
||||
}
|
||||
(shape.0[n - 2], shape.0[n - 1]) = (shape.0[n - 1], shape.0[n - 2]);
|
||||
let mut dims = self.shape().dims().to_vec();
|
||||
(dims[n - 2], dims[n - 1]) = (dims[n - 1], dims[n - 2]);
|
||||
(stride[n - 2], stride[n - 1]) = (stride[n - 1], stride[n - 2]);
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: self.storage.clone(),
|
||||
shape,
|
||||
shape: Shape::from(dims),
|
||||
stride,
|
||||
// TODO The op should have a backward
|
||||
op: None,
|
||||
|
Reference in New Issue
Block a user