mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Use shape with holes. (#771)
This commit is contained in:
@ -123,7 +123,7 @@ fn get_rel_pos(q_size: usize, k_size: usize, rel_pos: &Tensor) -> Result<Tensor>
|
||||
let relative_coords = relative_coords.to_dtype(DType::U32)?;
|
||||
rel_pos_resized
|
||||
.index_select(&relative_coords.reshape(d1 * d2)?, 0)?
|
||||
.reshape((d1, d2, rel_pos_resized.dim(1)?))
|
||||
.reshape((d1, d2, ()))
|
||||
}
|
||||
|
||||
impl Module for Attention {
|
||||
@ -243,7 +243,7 @@ fn window_unpartition(
|
||||
))?
|
||||
.transpose(2, 3)?
|
||||
.contiguous()?
|
||||
.reshape((b, h_p, w_p, windows.elem_count() / b / h_p / w_p))?;
|
||||
.reshape((b, h_p, w_p, ()))?;
|
||||
let xs = if h_p > h { xs.narrow(1, 0, h)? } else { xs };
|
||||
let xs = if w_p > w { xs.narrow(2, 0, w)? } else { xs };
|
||||
Ok(xs)
|
||||
|
Reference in New Issue
Block a user