mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Apply rustfmt. (#2247)
This commit is contained in:
@ -1032,11 +1032,18 @@ pub fn simple_eval(
|
||||
let axis_i64: i64 = get_attr_opt(node, "axis")?.copied().unwrap_or(0);
|
||||
let rank_i64: i64 = input.rank().try_into().unwrap();
|
||||
if axis_i64 < -rank_i64 || axis_i64 >= rank_i64 {
|
||||
bail!("axis ({}) out of accepted range [-rank, rank-1] which was [{}, {}]", axis_i64, -rank_i64, rank_i64-1)
|
||||
bail!(
|
||||
"axis ({}) out of accepted range [-rank, rank-1] which was [{}, {}]",
|
||||
axis_i64,
|
||||
-rank_i64,
|
||||
rank_i64 - 1
|
||||
)
|
||||
}
|
||||
let axis = input.normalize_axis(axis_i64)?;
|
||||
let keepdims: i64 = get_attr_opt(node, "keepdims")?.copied().unwrap_or(1);
|
||||
let select_last_index: i64 = get_attr_opt(node, "select_last_index")?.copied().unwrap_or(0);
|
||||
let select_last_index: i64 = get_attr_opt(node, "select_last_index")?
|
||||
.copied()
|
||||
.unwrap_or(0);
|
||||
if select_last_index == 1 {
|
||||
bail!("select_last_index for ArgMin is currently not supported")
|
||||
}
|
||||
@ -1044,7 +1051,8 @@ pub fn simple_eval(
|
||||
input.argmin_keepdim(axis)?
|
||||
} else {
|
||||
input.argmin(axis)?
|
||||
}.to_dtype(DType::I64)?;
|
||||
}
|
||||
.to_dtype(DType::I64)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"ArgMax" => {
|
||||
@ -1052,11 +1060,18 @@ pub fn simple_eval(
|
||||
let axis_i64: i64 = get_attr_opt(node, "axis")?.copied().unwrap_or(0);
|
||||
let rank_i64: i64 = input.rank().try_into().unwrap();
|
||||
if axis_i64 < -rank_i64 || axis_i64 >= rank_i64 {
|
||||
bail!("axis ({}) out of accepted range [-rank, rank-1] which was [{}, {}]", axis_i64, -rank_i64, rank_i64-1)
|
||||
bail!(
|
||||
"axis ({}) out of accepted range [-rank, rank-1] which was [{}, {}]",
|
||||
axis_i64,
|
||||
-rank_i64,
|
||||
rank_i64 - 1
|
||||
)
|
||||
}
|
||||
let axis = input.normalize_axis(axis_i64)?;
|
||||
let keepdims: i64 = get_attr_opt(node, "keepdims")?.copied().unwrap_or(1);
|
||||
let select_last_index: i64 = get_attr_opt(node, "select_last_index")?.copied().unwrap_or(0);
|
||||
let select_last_index: i64 = get_attr_opt(node, "select_last_index")?
|
||||
.copied()
|
||||
.unwrap_or(0);
|
||||
if select_last_index == 1 {
|
||||
bail!("select_last_index for ArgMin is currently not supported")
|
||||
}
|
||||
@ -1064,7 +1079,8 @@ pub fn simple_eval(
|
||||
input.argmax_keepdim(axis)?
|
||||
} else {
|
||||
input.argmax(axis)?
|
||||
}.to_dtype(DType::I64)?;
|
||||
}
|
||||
.to_dtype(DType::I64)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"LeakyRelu" => {
|
||||
@ -1072,7 +1088,10 @@ pub fn simple_eval(
|
||||
let dt = input.dtype();
|
||||
match dt {
|
||||
DType::U8 | DType::U32 | DType::I64 => {
|
||||
bail!("unsupported dtype {}, only float types are allowed for LeakyRelu", dt.as_str())
|
||||
bail!(
|
||||
"unsupported dtype {}, only float types are allowed for LeakyRelu",
|
||||
dt.as_str()
|
||||
)
|
||||
}
|
||||
DType::BF16 | DType::F16 | DType::F32 | DType::F64 => {}
|
||||
}
|
||||
|
Reference in New Issue
Block a user