mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Cleanup the tensor creation code.
This commit is contained in:
@ -23,6 +23,8 @@ pub(crate) enum Op {
|
|||||||
},
|
},
|
||||||
Neg(Tensor),
|
Neg(Tensor),
|
||||||
Reshape(Tensor),
|
Reshape(Tensor),
|
||||||
|
#[allow(dead_code)]
|
||||||
|
Softmax(Tensor, usize),
|
||||||
Sqr(Tensor),
|
Sqr(Tensor),
|
||||||
Sqrt(Tensor),
|
Sqrt(Tensor),
|
||||||
ToDevice(Tensor),
|
ToDevice(Tensor),
|
||||||
|
151
src/tensor.rs
151
src/tensor.rs
@ -54,15 +54,7 @@ macro_rules! unary_op {
|
|||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
let tensor_ = Tensor_ {
|
Ok(from_storage(storage, shape.clone(), op, false))
|
||||||
id: TensorId::new(),
|
|
||||||
storage,
|
|
||||||
shape: shape.clone(),
|
|
||||||
stride: shape.stride_contiguous(),
|
|
||||||
op,
|
|
||||||
is_variable: false,
|
|
||||||
};
|
|
||||||
Ok(Self(Arc::new(tensor_)))
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -82,15 +74,7 @@ macro_rules! binary_op {
|
|||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
let tensor_ = Tensor_ {
|
Ok(from_storage(storage, shape.clone(), op, false))
|
||||||
id: TensorId::new(),
|
|
||||||
storage,
|
|
||||||
shape: shape.clone(),
|
|
||||||
stride: shape.stride_contiguous(),
|
|
||||||
op,
|
|
||||||
is_variable: false,
|
|
||||||
};
|
|
||||||
Ok(Self(Arc::new(tensor_)))
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -110,19 +94,25 @@ macro_rules! broadcast_binary_op {
|
|||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
let tensor_ = Tensor_ {
|
Ok(from_storage(storage, shape.clone(), op, false))
|
||||||
id: TensorId::new(),
|
|
||||||
storage,
|
|
||||||
shape: shape.clone(),
|
|
||||||
stride: shape.stride_contiguous(),
|
|
||||||
op,
|
|
||||||
is_variable: false,
|
|
||||||
};
|
|
||||||
Ok(Self(Arc::new(tensor_)))
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Creates a fresh tensor structure based on a storage and a shape, this uses contiguous strides.
|
||||||
|
fn from_storage(storage: Storage, shape: Shape, op: Option<Op>, is_variable: bool) -> Tensor {
|
||||||
|
let stride = shape.stride_contiguous();
|
||||||
|
let tensor_ = Tensor_ {
|
||||||
|
id: TensorId::new(),
|
||||||
|
storage,
|
||||||
|
shape,
|
||||||
|
stride,
|
||||||
|
op,
|
||||||
|
is_variable,
|
||||||
|
};
|
||||||
|
Tensor(Arc::new(tensor_))
|
||||||
|
}
|
||||||
|
|
||||||
impl Tensor {
|
impl Tensor {
|
||||||
fn ones_impl<S: Into<Shape>>(
|
fn ones_impl<S: Into<Shape>>(
|
||||||
shape: S,
|
shape: S,
|
||||||
@ -132,16 +122,7 @@ impl Tensor {
|
|||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let shape = shape.into();
|
let shape = shape.into();
|
||||||
let storage = device.ones(&shape, dtype)?;
|
let storage = device.ones(&shape, dtype)?;
|
||||||
let stride = shape.stride_contiguous();
|
Ok(from_storage(storage, shape, None, is_variable))
|
||||||
let tensor_ = Tensor_ {
|
|
||||||
id: TensorId::new(),
|
|
||||||
storage,
|
|
||||||
shape,
|
|
||||||
stride,
|
|
||||||
op: None,
|
|
||||||
is_variable,
|
|
||||||
};
|
|
||||||
Ok(Self(Arc::new(tensor_)))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
|
pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
|
||||||
@ -164,16 +145,7 @@ impl Tensor {
|
|||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let shape = shape.into();
|
let shape = shape.into();
|
||||||
let storage = device.zeros(&shape, dtype)?;
|
let storage = device.zeros(&shape, dtype)?;
|
||||||
let stride = shape.stride_contiguous();
|
Ok(from_storage(storage, shape, None, is_variable))
|
||||||
let tensor_ = Tensor_ {
|
|
||||||
id: TensorId::new(),
|
|
||||||
storage,
|
|
||||||
shape,
|
|
||||||
stride,
|
|
||||||
op: None,
|
|
||||||
is_variable,
|
|
||||||
};
|
|
||||||
Ok(Self(Arc::new(tensor_)))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
|
pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
|
||||||
@ -200,16 +172,7 @@ impl Tensor {
|
|||||||
return Err(Error::ShapeMismatch { buffer_size, shape });
|
return Err(Error::ShapeMismatch { buffer_size, shape });
|
||||||
}
|
}
|
||||||
let storage = device.storage(array)?;
|
let storage = device.storage(array)?;
|
||||||
let stride = shape.stride_contiguous();
|
Ok(from_storage(storage, shape, None, is_variable))
|
||||||
let tensor_ = Tensor_ {
|
|
||||||
id: TensorId::new(),
|
|
||||||
storage,
|
|
||||||
shape,
|
|
||||||
stride,
|
|
||||||
op: None,
|
|
||||||
is_variable,
|
|
||||||
};
|
|
||||||
Ok(Self(Arc::new(tensor_)))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
|
pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
|
||||||
@ -314,9 +277,7 @@ impl Tensor {
|
|||||||
|
|
||||||
pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
|
pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
|
||||||
let shape = self.shape();
|
let shape = self.shape();
|
||||||
let storage = self
|
let storage = self.storage.affine_impl(shape, self.stride(), mul, add)?;
|
||||||
.storage
|
|
||||||
.affine_impl(self.shape(), self.stride(), mul, add)?;
|
|
||||||
let op = if self.track_op() {
|
let op = if self.track_op() {
|
||||||
Some(Op::Affine {
|
Some(Op::Affine {
|
||||||
arg: self.clone(),
|
arg: self.clone(),
|
||||||
@ -326,15 +287,7 @@ impl Tensor {
|
|||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
let tensor_ = Tensor_ {
|
Ok(from_storage(storage, shape.clone(), op, false))
|
||||||
id: TensorId::new(),
|
|
||||||
storage,
|
|
||||||
shape: shape.clone(),
|
|
||||||
stride: shape.stride_contiguous(),
|
|
||||||
op,
|
|
||||||
is_variable: false,
|
|
||||||
};
|
|
||||||
Ok(Self(Arc::new(tensor_)))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn matmul(&self, rhs: &Self) -> Result<Self> {
|
pub fn matmul(&self, rhs: &Self) -> Result<Self> {
|
||||||
@ -373,7 +326,6 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
|
let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
|
||||||
let c_stride = c_shape.stride_contiguous();
|
|
||||||
let batching: usize = a_dims[..dim - 2].iter().product();
|
let batching: usize = a_dims[..dim - 2].iter().product();
|
||||||
|
|
||||||
let storage = self.storage.matmul_impl(
|
let storage = self.storage.matmul_impl(
|
||||||
@ -387,15 +339,7 @@ impl Tensor {
|
|||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
let tensor_ = Tensor_ {
|
Ok(from_storage(storage, c_shape, op, false))
|
||||||
id: TensorId::new(),
|
|
||||||
storage,
|
|
||||||
shape: c_shape,
|
|
||||||
stride: c_stride,
|
|
||||||
op,
|
|
||||||
is_variable: false,
|
|
||||||
};
|
|
||||||
Ok(Self(Arc::new(tensor_)))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embedding(ids: &Self, rhs: &Self) -> Result<Self> {
|
pub fn embedding(ids: &Self, rhs: &Self) -> Result<Self> {
|
||||||
@ -419,15 +363,7 @@ impl Tensor {
|
|||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
let tensor_ = Tensor_ {
|
Ok(from_storage(storage, shape, op, false))
|
||||||
id: TensorId::new(),
|
|
||||||
storage,
|
|
||||||
shape: shape.clone(),
|
|
||||||
stride: shape.stride_contiguous(),
|
|
||||||
op,
|
|
||||||
is_variable: false,
|
|
||||||
};
|
|
||||||
Ok(Self(Arc::new(tensor_)))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn strided_index(&self) -> crate::StridedIndex {
|
pub(crate) fn strided_index(&self) -> crate::StridedIndex {
|
||||||
@ -669,15 +605,12 @@ impl Tensor {
|
|||||||
let mut storage = self.device().zeros(shape, self.dtype())?;
|
let mut storage = self.device().zeros(shape, self.dtype())?;
|
||||||
self.storage
|
self.storage
|
||||||
.copy_strided_src(&mut storage, &self.shape, &self.stride, 0)?;
|
.copy_strided_src(&mut storage, &self.shape, &self.stride, 0)?;
|
||||||
let tensor_ = Tensor_ {
|
Ok(from_storage(
|
||||||
id: TensorId::new(),
|
|
||||||
storage,
|
storage,
|
||||||
shape: shape.clone(),
|
shape.clone(),
|
||||||
stride: shape.stride_contiguous(),
|
self.op.clone(),
|
||||||
op: self.op.clone(),
|
self.is_variable,
|
||||||
is_variable: self.is_variable,
|
))
|
||||||
};
|
|
||||||
Ok(Tensor(Arc::new(tensor_)))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -702,16 +635,7 @@ impl Tensor {
|
|||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
let stride = shape.stride_contiguous();
|
Ok(from_storage(storage, shape, op, false))
|
||||||
let tensor_ = Tensor_ {
|
|
||||||
id: TensorId::new(),
|
|
||||||
storage,
|
|
||||||
shape,
|
|
||||||
stride,
|
|
||||||
op,
|
|
||||||
is_variable: false,
|
|
||||||
};
|
|
||||||
Ok(Tensor(Arc::new(tensor_)))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn cat(args: &[Self], dim: usize) -> Result<Self> {
|
pub fn cat(args: &[Self], dim: usize) -> Result<Self> {
|
||||||
@ -778,7 +702,6 @@ impl Tensor {
|
|||||||
offsets.push(next_offset);
|
offsets.push(next_offset);
|
||||||
}
|
}
|
||||||
let shape = Shape::from(cat_dims);
|
let shape = Shape::from(cat_dims);
|
||||||
let stride = shape.stride_contiguous();
|
|
||||||
let op = if args.iter().any(|arg| arg.track_op()) {
|
let op = if args.iter().any(|arg| arg.track_op()) {
|
||||||
Some(Op::Cat(args.to_vec(), dim))
|
Some(Op::Cat(args.to_vec(), dim))
|
||||||
} else {
|
} else {
|
||||||
@ -789,15 +712,7 @@ impl Tensor {
|
|||||||
arg.storage
|
arg.storage
|
||||||
.copy_strided_src(&mut storage, &arg.shape, &arg.stride, offset)?
|
.copy_strided_src(&mut storage, &arg.shape, &arg.stride, offset)?
|
||||||
}
|
}
|
||||||
let tensor_ = Tensor_ {
|
Ok(from_storage(storage, shape, op, false))
|
||||||
id: TensorId::new(),
|
|
||||||
storage,
|
|
||||||
shape,
|
|
||||||
stride,
|
|
||||||
op,
|
|
||||||
is_variable: false,
|
|
||||||
};
|
|
||||||
Ok(Tensor(Arc::new(tensor_)))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return all the nodes that lead to this value in a topologically sorted vec, the first
|
/// Return all the nodes that lead to this value in a topologically sorted vec, the first
|
||||||
@ -855,6 +770,7 @@ impl Tensor {
|
|||||||
Op::Reshape(node)
|
Op::Reshape(node)
|
||||||
| Op::ToDevice(node)
|
| Op::ToDevice(node)
|
||||||
| Op::Transpose(node, _, _)
|
| Op::Transpose(node, _, _)
|
||||||
|
| Op::Softmax(node, _)
|
||||||
| Op::Sqr(node)
|
| Op::Sqr(node)
|
||||||
| Op::Sqrt(node)
|
| Op::Sqrt(node)
|
||||||
| Op::Gelu(node)
|
| Op::Gelu(node)
|
||||||
@ -975,6 +891,9 @@ impl Tensor {
|
|||||||
let sum_grad = grads.or_insert(arg)?;
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
*sum_grad = sum_grad.sub(&grad)?
|
*sum_grad = sum_grad.sub(&grad)?
|
||||||
}
|
}
|
||||||
|
Op::Softmax(_arg, _) => {
|
||||||
|
return Err(Error::BackwardNotSupported { op: "softmax" })
|
||||||
|
}
|
||||||
Op::Reshape(_arg) => return Err(Error::BackwardNotSupported { op: "reshape" }),
|
Op::Reshape(_arg) => return Err(Error::BackwardNotSupported { op: "reshape" }),
|
||||||
Op::Gelu(_) => return Err(Error::BackwardNotSupported { op: "gelu" }),
|
Op::Gelu(_) => return Err(Error::BackwardNotSupported { op: "gelu" }),
|
||||||
Op::Sqr(arg) => {
|
Op::Sqr(arg) => {
|
||||||
|
Reference in New Issue
Block a user