fix: negative axis (#1296)

* fix: negative axis

* Use normalize_axis.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
commit 73d02f4f57
parent f772213e84
Author: YangNianYi
Date:   2023-11-09 06:28:21 +08:00
Committed by: GitHub
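
The removed branches mishandled in-range negative axes: for a negative `axis` they computed `(num_axis - axis)`, which adds the axis magnitude instead of subtracting it. With rank 3 and axis `-1` that yields `3 - (-1) = 4`, one past the last valid dimension, while the intended index is `rank + axis = 2`. (The removed softmax branches also reused the `"wrong axis in concat"` error message.) The fix delegates to `Tensor::normalize_axis`. A minimal sketch of the assumed semantics, not copied from candle's source:

```rust
// Illustrative stand-in for Tensor::normalize_axis; the free-function
// signature and String error type are assumptions for this sketch.
fn normalize_axis(axis: i64, rank: usize) -> Result<usize, String> {
    let rank = rank as i64;
    if (-rank..rank).contains(&axis) {
        // Negative axes count from the back: -1 is the last dimension.
        Ok(if axis < 0 { (rank + axis) as usize } else { axis as usize })
    } else {
        Err(format!("axis {axis} is out of range for rank {rank}"))
    }
}
```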


@@ -298,14 +298,7 @@ pub fn simple_eval(
                 let output = match get_attr_opt::<i64>(node, "axis")? {
                     None => candle_nn::ops::softmax_last_dim(input)?,
                     Some(&axis) => {
-                        let num_axis = input.rank() as i64;
-                        let axis = if axis >= 0 {
-                            axis as usize
-                        } else if axis < -num_axis {
-                            bail!("wrong axis in concat {axis} for shape {:?}", input.shape())
-                        } else {
-                            (num_axis - axis) as usize
-                        };
+                        let axis = input.normalize_axis(axis)?;
                         candle_nn::ops::log_softmax(input, axis)?
                     }
                 };
@@ -316,14 +309,7 @@ pub fn simple_eval(
                 let output = match get_attr_opt::<i64>(node, "axis")? {
                     None => candle_nn::ops::softmax_last_dim(input)?,
                     Some(&axis) => {
-                        let num_axis = input.rank() as i64;
-                        let axis = if axis >= 0 {
-                            axis as usize
-                        } else if axis < -num_axis {
-                            bail!("wrong axis in concat {axis} for shape {:?}", input.shape())
-                        } else {
-                            (num_axis - axis) as usize
-                        };
+                        let axis = input.normalize_axis(axis)?;
                         candle_nn::ops::softmax(input, axis)?
                     }
                 };
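
Both softmax hunks apply the same one-line replacement. A hedged usage sketch of the fixed path, assuming candle's `Tensor::new`, `Tensor::normalize_axis`, and `candle_nn::ops` APIs as of this commit:

```rust
use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    let input = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    // ONNX axis -1 now resolves to 1 for this rank-2 tensor; the old
    // arithmetic produced the out-of-range index 3.
    let axis = input.normalize_axis(-1)?;
    let log_probs = candle_nn::ops::log_softmax(&input, axis)?;
    let probs = candle_nn::ops::softmax(&input, axis)?;
    println!("{log_probs}\n{probs}");
    Ok(())
}
```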
@@ -666,21 +652,10 @@ pub fn simple_eval(
                     .map(|n| Ok(get(n.as_str())?.clone()))
                     .collect::<Result<Vec<Value>>>()?;
                 let axis: i64 = *get_attr(node, "axis")?;
-                let num_axis = if inputs.is_empty() {
+                if inputs.is_empty() {
                     bail!("empty concat")
-                } else {
-                    inputs[0].rank() as i64
-                };
-                let axis = if axis >= 0 {
-                    axis as usize
-                } else if axis < -num_axis {
-                    bail!(
-                        "wrong axis in concat {axis} for shape {:?}",
-                        inputs[0].shape()
-                    )
-                } else {
-                    (num_axis - axis) as usize
                 };
+                let axis = inputs[0].normalize_axis(axis)?;
                 let output = Tensor::cat(&inputs, axis)?;
                 values.insert(node.output[0].clone(), output);
             }