mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 20:09:50 +00:00
onnx: fix pad, unsqueeze (#2317)
* onnx: fix pad, unsqueeze both implementations have off-by-one errors: - Pad 'reflect' cycle for eg `dim==3` is `[0,1,2,1]` which has length of 4 (or `dim*2 - 2`) not 5 (current code `dim*2 - 1`) - Unsqueeze(-1) for tensor with `dim==3` should be 3 (ie `dim+index+1`) not 2 (ie currently `dim+index`) in addition, Pad is incorrectly calculating the starting padding. If we want to pad out 2 elements to the start, and we have this cycle of indices of length 6, then we should skip 4 elements, but currently we skip 2. A more visual representation of what's going on is below: ``` pad_start: 2 data: [a,b,c,d] indices: [0, 1, 2, 3, 2, 1, 0, 1, 2, 3, 2, 1, 0, ..] // zigzag between 0..4 actual: skip [ c d| c b a b] expected: ~ skip ~ [ c b| a b c d] ``` The values between `[` and `|` are padding and the values between `|` and `]` in the example should match the original data being padded. * Fix clippy lints. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -570,6 +570,11 @@ fn simple_eval_(
|
||||
.map(|&i| {
|
||||
if i == xs.rank() as i64 {
|
||||
Ok(xs.rank())
|
||||
} else if i < 0 {
|
||||
// normalize_axis doesn't work correctly here
|
||||
// because we actually want normalized with respect
|
||||
// to the final size, not the current (off by one)
|
||||
Ok(xs.rank() - (-i as usize) + 1)
|
||||
} else {
|
||||
xs.normalize_axis(i)
|
||||
}
|
||||
@ -1040,8 +1045,8 @@ fn simple_eval_(
|
||||
std::iter::repeat((min..max).chain((min + 1..=max).rev())).flatten()
|
||||
}
|
||||
let idx = if dim > 1 {
|
||||
let cycle_len = dim * 2 - 1;
|
||||
let skip = (pads_pre[i] as usize) % cycle_len;
|
||||
let cycle_len = dim * 2 - 2;
|
||||
let skip = cycle_len - ((pads_pre[i] as usize) % cycle_len);
|
||||
let idx = zigzag(0, (dim - 1) as i64)
|
||||
.skip(skip)
|
||||
.take((pads_pre[i] as usize) + dim + (pads_post[i] as usize));
|
||||
|
Reference in New Issue
Block a user