mirror of https://github.com/huggingface/candle.git (synced 2025-06-18 11:37:11 +00:00)
Merge pull request #44 from LaurentMazare/check-dim
Improve how we check that the dims are in bounds.
.gitignore (vendored): 3 changes
@@ -15,3 +15,6 @@ Cargo.lock
 *tokenizer.json
 *.npz
+perf.data
+flamegraph.svg
@@ -10,6 +10,13 @@ pub enum Error {
         got: DType,
     },
 
+    #[error("{op}: dimension index {dim} out of range for {shape:?}")]
+    DimOutOfRange {
+        shape: Shape,
+        dim: usize,
+        op: &'static str,
+    },
+
     #[error("invalid args for narrow: {shape:?}, dim: {dim}, start: {start}, len:{len}")]
     NarrowInvalidArgs {
         shape: Shape,
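The `#[error(...)]` attribute means the new variant carries its own message format. As a point of reference, here is a minimal standalone sketch of how such a variant renders, assuming the `thiserror` derive (where the `#[error]` syntax comes from) and substituting a plain `Vec<usize>` for candle's `Shape` type:

use thiserror::Error;

// Local stand-in for the new variant; candle's real enum has more variants
// and uses its own Shape type instead of Vec<usize>.
#[derive(Debug, Error)]
enum DemoError {
    #[error("{op}: dimension index {dim} out of range for {shape:?}")]
    DimOutOfRange {
        shape: Vec<usize>,
        dim: usize,
        op: &'static str,
    },
}

fn main() {
    let err = DemoError::DimOutOfRange {
        shape: vec![2, 4],
        dim: 3,
        op: "narrow",
    };
    // Prints: narrow: dimension index 3 out of range for [2, 4]
    println!("{err}");
}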
@@ -348,11 +348,24 @@ impl Tensor {
         Ok(from_storage(storage, self.shape(), op, false))
     }
 
+    fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
+        if dim >= self.dims().len() {
+            Err(Error::DimOutOfRange {
+                shape: self.shape().clone(),
+                dim,
+                op,
+            })?
+        } else {
+            Ok(())
+        }
+    }
+
     /// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
     /// ranges from `start` to `start + len`.
     pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
         let dims = self.dims();
-        if dim >= dims.len() || start + len > dims[dim] {
+        self.check_dim(dim, "narrow")?;
+        if start + len > dims[dim] {
             Err(Error::NarrowInvalidArgs {
                 shape: self.shape().clone(),
                 dim,
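The new `check_dim` helper centralizes the bounds check that `narrow` previously inlined. Note the `Err(Error::DimOutOfRange { .. })?` form: the trailing `?` propagates the error immediately, so that branch never yields a value and both arms of the `if` type-check as `Result<()>`. From the caller's side, an out-of-range dimension index now surfaces before any op-specific validation. A rough usage sketch follows; the crate path, the `Tensor::zeros` signature, and the tuple-to-shape conversion are assumptions about the candle API rather than something shown in this diff:

use candle::{DType, Device, Tensor};

fn demo() -> candle::Result<()> {
    // A rank-2 tensor: only dims 0 and 1 exist.
    let t = Tensor::zeros((2, 4), DType::F32, &Device::Cpu)?;
    match t.narrow(3, 0, 1) {
        // With this change the call reports DimOutOfRange with op "narrow"
        // instead of reaching the NarrowInvalidArgs check with a bogus dim.
        Err(e) => println!("{e}"),
        Ok(_) => unreachable!("dim 3 is out of range for a rank-2 tensor"),
    }
    Ok(())
}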
@@ -380,6 +393,7 @@ impl Tensor {
     }
 
     pub fn softmax(&self, dim: usize) -> Result<Self> {
+        self.check_dim(dim, "softmax")?;
         // TODO: unify the two branches.
         if self.device().is_cuda() {
             // We do not have a cuda kernel for divide_by_sum_over_dim so split
@@ -402,6 +416,9 @@ impl Tensor {
     }
 
     pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
+        for &dim in sum_dims {
+            self.check_dim(dim, "sum")?;
+        }
         let storage = self.storage.sum(self.layout(), sum_dims)?;
         let op = if self.track_op() {
             Some(Op::Sum(self.clone(), sum_dims.to_vec()))
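`sum` validates every requested dimension before calling into storage, so a single bad entry rejects the whole call. A self-contained sketch of this validate-up-front pattern, using plain slices and `String` errors in place of candle's `Shape` and `Error` types:

// Simplified stand-ins for Tensor::check_dim and the loop in Tensor::sum.
fn check_dim(dims: &[usize], dim: usize, op: &'static str) -> Result<(), String> {
    if dim >= dims.len() {
        Err(format!("{op}: dimension index {dim} out of range for {dims:?}"))
    } else {
        Ok(())
    }
}

fn check_sum_dims(dims: &[usize], sum_dims: &[usize]) -> Result<(), String> {
    // Reject any out-of-range dim before doing any work.
    for &dim in sum_dims {
        check_dim(dims, dim, "sum")?;
    }
    Ok(())
}

fn main() {
    // A rank-3 shape: dims 0, 1 and 2 are valid, dim 3 is not.
    assert!(check_sum_dims(&[2, 3, 4], &[0, 2]).is_ok());
    assert_eq!(
        check_sum_dims(&[2, 3, 4], &[0, 3]).unwrap_err(),
        "sum: dimension index 3 out of range for [2, 3, 4]"
    );
}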
@@ -645,6 +662,8 @@ impl Tensor {
         } else {
             let start_dim = start_dim.unwrap_or(0);
             let end_dim = end_dim.unwrap_or_else(|| self.rank() - 1);
+            self.check_dim(start_dim, "flatten")?;
+            self.check_dim(end_dim, "flatten")?;
             if start_dim < end_dim {
                 let dims = self.dims();
                 let mut dst_dims = dims[..start_dim].to_vec();
@@ -689,6 +708,8 @@ impl Tensor {
     /// Returns a tensor that is a transposed version of the input, the given dimensions are
     /// swapped.
     pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Tensor> {
+        self.check_dim(dim1, "transpose")?;
+        self.check_dim(dim2, "transpose")?;
         let op = if self.track_op() {
             Some(Op::Transpose(self.clone(), dim1, dim2))
         } else {
@@ -876,6 +897,7 @@ impl Tensor {
         // The PyTorch semantics are to return the same tensor if the target dimension
         // does not have a size of 1.
         let dims = self.dims();
+        self.check_dim(index, "squeeze")?;
         if dims[index] == 1 {
             let mut dims = dims.to_vec();
             dims.remove(index);
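Without this check, an out-of-range `index` would panic at the `dims[index]` access just below; with it, `squeeze` reports the same structured `DimOutOfRange` error as the other ops.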
@@ -910,13 +932,8 @@ impl Tensor {
         if args.len() == 1 {
             return Ok(arg0.clone());
         }
-        let rank = arg0.rank();
-        if dim >= rank {
-            return Err(Error::UnexpectedNumberOfDims {
-                expected: (dim + 1),
-                got: rank,
-                shape: arg0.shape().clone(),
-            });
+        for arg in args {
+            arg.as_ref().check_dim(dim, "cat")?;
         }
         if dim == 0 {
             Self::cat0(args)
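Two things change here beyond the shared helper: the loop validates every tensor in `args` rather than only `arg0`, and an out-of-range `dim` is now reported as `DimOutOfRange` with op "cat" instead of `UnexpectedNumberOfDims`.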