mirror of https://github.com/huggingface/candle.git
Abstract the implementation of Shape.
src/shape.rs (13 changed lines)
@@ -1,7 +1,7 @@
 use crate::{Error, Result};
 
 #[derive(Clone, PartialEq, Eq)]
-pub struct Shape(pub(crate) Vec<usize>);
+pub struct Shape(Vec<usize>);
 
 impl std::fmt::Debug for Shape {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -63,6 +63,12 @@ impl From<(usize, usize, usize)> for Shape {
     }
 }
 
+impl From<Vec<usize>> for Shape {
+    fn from(dims: Vec<usize>) -> Self {
+        Self(dims)
+    }
+}
+
 macro_rules! extract_dims {
     ($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
         pub fn $fn_name(&self) -> Result<$out_type> {
@@ -142,6 +148,11 @@ impl Shape {
         }
         true
     }
+
+    pub fn extend(mut self, additional_dims: &[usize]) -> Self {
+        self.0.extend(additional_dims);
+        self
+    }
 }
 
 #[cfg(test)]
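The two additions above make Shape constructible from an owned Vec<usize> and extendable by value, so callers no longer need to reach into the now-private field. A minimal self-contained sketch of how the new pieces compose (pared down from the diff so it compiles on its own; the derived Debug and the main function are scaffolding, not part of the commit):

// Mirror of the Shape pieces touched by this commit.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Shape(Vec<usize>);

impl From<Vec<usize>> for Shape {
    fn from(dims: Vec<usize>) -> Self {
        Self(dims)
    }
}

impl Shape {
    // Accessor used by the tensor hunks below (self.shape().dims()).
    pub fn dims(&self) -> &[usize] {
        &self.0
    }

    // The new consuming builder: append dims, return the extended shape.
    pub fn extend(mut self, additional_dims: &[usize]) -> Self {
        self.0.extend(additional_dims);
        self
    }
}

fn main() {
    // Build a batched matmul output shape the way the tensor hunks below do.
    let c_shape = Shape::from(vec![2, 3]).extend(&[4, 5]);
    assert_eq!(c_shape.dims(), &[2, 3, 4, 5]);
}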
src/tensor.rs

@@ -283,9 +283,8 @@ impl Tensor {
             });
         }
 
-        let mut c_shape: Vec<_> = a_dims[..dim - 2].into();
-        c_shape.extend(&[m, n]);
-        let c_shape = Shape(c_shape);
+        let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
+        let c_stride = c_shape.stride_contiguous();
         let batching: usize = a_dims[..dim - 2].iter().product();
 
         let storage = self.storage.matmul_impl(
@@ -297,8 +296,8 @@ impl Tensor {
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage,
-            shape: c_shape.clone(),
-            stride: c_shape.stride_contiguous(),
+            shape: c_shape,
+            stride: c_stride,
             op: Some(Op::Matmul(self.clone(), rhs.clone())),
             is_variable: false,
         };
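c_stride is now computed once, up front, so c_shape can be moved into Tensor_ without the clone() the old code needed. stride_contiguous itself is not shown in this diff; a hypothetical row-major implementation, consistent with how the result is used, might look like:

// Hypothetical stand-in for Shape::stride_contiguous (not in this diff):
// row-major strides, where each stride is the product of all later dims.
fn stride_contiguous(dims: &[usize]) -> Vec<usize> {
    let mut stride: Vec<usize> = dims
        .iter()
        .rev()
        .scan(1, |prod, &d| {
            let s = *prod; // stride for this dim = product of the dims after it
            *prod *= d;
            Some(s)
        })
        .collect();
    stride.reverse();
    stride
}

fn main() {
    // A contiguous [batch=2, m=3, n=4] result: strides [12, 4, 1].
    assert_eq!(stride_contiguous(&[2, 3, 4]), vec![12, 4, 1]);
}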
@@ -414,7 +413,6 @@ impl Tensor {
 
     pub fn t(&self) -> Result<Tensor> {
         let mut stride = self.stride().to_vec();
-        let mut shape = self.shape().clone();
         let n = stride.len();
         if n < 2 {
             return Err(Error::UnexpectedNumberOfDims {
@@ -423,12 +421,13 @@ impl Tensor {
                 shape: self.shape().clone(),
             });
         }
-        (shape.0[n - 2], shape.0[n - 1]) = (shape.0[n - 1], shape.0[n - 2]);
+        let mut dims = self.shape().dims().to_vec();
+        (dims[n - 2], dims[n - 1]) = (dims[n - 1], dims[n - 2]);
         (stride[n - 2], stride[n - 1]) = (stride[n - 1], stride[n - 2]);
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: self.storage.clone(),
-            shape,
+            shape: Shape::from(dims),
             stride,
             // TODO The op should have a backward
             op: None,
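The rewrite of t() keeps the same zero-copy transpose trick: swap the last two dims and the last two strides so the same storage is read in transposed order, with the dims now flowing through Shape::from(dims) instead of mutating the tuple field directly. A small standalone sketch of the metadata swap (an illustration, not the crate's API):

// Standalone illustration of the dim/stride swap performed in t() above.
fn transpose_metadata(dims: &mut [usize], stride: &mut [usize]) {
    let n = dims.len();
    assert!(n >= 2 && stride.len() == n, "need at least two dims");
    dims.swap(n - 2, n - 1);
    stride.swap(n - 2, n - 1);
}

fn main() {
    // A contiguous 2x3 matrix has strides [3, 1]; after the swap the same
    // buffer reads as its 3x2 transpose with strides [1, 3].
    let mut dims = vec![2, 3];
    let mut stride = vec![3, 1];
    transpose_metadata(&mut dims, &mut stride);
    assert_eq!((dims, stride), (vec![3, 2], vec![1, 3]));
}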