mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
fix: negative axis (#1296)
* fix: negative axis * Use normalize_axis. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -298,14 +298,7 @@ pub fn simple_eval(
|
|||||||
let output = match get_attr_opt::<i64>(node, "axis")? {
|
let output = match get_attr_opt::<i64>(node, "axis")? {
|
||||||
None => candle_nn::ops::softmax_last_dim(input)?,
|
None => candle_nn::ops::softmax_last_dim(input)?,
|
||||||
Some(&axis) => {
|
Some(&axis) => {
|
||||||
let num_axis = input.rank() as i64;
|
let axis = input.normalize_axis(axis)?;
|
||||||
let axis = if axis >= 0 {
|
|
||||||
axis as usize
|
|
||||||
} else if axis < -num_axis {
|
|
||||||
bail!("wrong axis in concat {axis} for shape {:?}", input.shape())
|
|
||||||
} else {
|
|
||||||
(num_axis - axis) as usize
|
|
||||||
};
|
|
||||||
candle_nn::ops::log_softmax(input, axis)?
|
candle_nn::ops::log_softmax(input, axis)?
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -316,14 +309,7 @@ pub fn simple_eval(
|
|||||||
let output = match get_attr_opt::<i64>(node, "axis")? {
|
let output = match get_attr_opt::<i64>(node, "axis")? {
|
||||||
None => candle_nn::ops::softmax_last_dim(input)?,
|
None => candle_nn::ops::softmax_last_dim(input)?,
|
||||||
Some(&axis) => {
|
Some(&axis) => {
|
||||||
let num_axis = input.rank() as i64;
|
let axis = input.normalize_axis(axis)?;
|
||||||
let axis = if axis >= 0 {
|
|
||||||
axis as usize
|
|
||||||
} else if axis < -num_axis {
|
|
||||||
bail!("wrong axis in concat {axis} for shape {:?}", input.shape())
|
|
||||||
} else {
|
|
||||||
(num_axis - axis) as usize
|
|
||||||
};
|
|
||||||
candle_nn::ops::softmax(input, axis)?
|
candle_nn::ops::softmax(input, axis)?
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -666,21 +652,10 @@ pub fn simple_eval(
|
|||||||
.map(|n| Ok(get(n.as_str())?.clone()))
|
.map(|n| Ok(get(n.as_str())?.clone()))
|
||||||
.collect::<Result<Vec<Value>>>()?;
|
.collect::<Result<Vec<Value>>>()?;
|
||||||
let axis: i64 = *get_attr(node, "axis")?;
|
let axis: i64 = *get_attr(node, "axis")?;
|
||||||
let num_axis = if inputs.is_empty() {
|
if inputs.is_empty() {
|
||||||
bail!("empty concat")
|
bail!("empty concat")
|
||||||
} else {
|
|
||||||
inputs[0].rank() as i64
|
|
||||||
};
|
|
||||||
let axis = if axis >= 0 {
|
|
||||||
axis as usize
|
|
||||||
} else if axis < -num_axis {
|
|
||||||
bail!(
|
|
||||||
"wrong axis in concat {axis} for shape {:?}",
|
|
||||||
inputs[0].shape()
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
(num_axis - axis) as usize
|
|
||||||
};
|
};
|
||||||
|
let axis = inputs[0].normalize_axis(axis)?;
|
||||||
let output = Tensor::cat(&inputs, axis)?;
|
let output = Tensor::cat(&inputs, axis)?;
|
||||||
values.insert(node.output[0].clone(), output);
|
values.insert(node.output[0].clone(), output);
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user