Add some helper functions.

This commit is contained in:
laurent
2023-06-27 17:37:09 +01:00
parent efc39b71c5
commit c44e5346f4
2 changed files with 26 additions and 6 deletions

View File

@ -306,12 +306,7 @@ impl CausalSelfAttention {
let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?;
let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?;
let rope = Tensor::cat(&[&re, &im], rank - 1)?;
// TODO: Add the flatten op.
let mut dims = rope.dims().to_vec();
let v1 = dims.pop().unwrap();
let v2 = dims.pop().unwrap();
dims.push(v1 * v2);
let rope = rope.reshape(dims)?;
let rope = rope.flatten(Some(rope.rank() - 2), None)?;
Ok(rope)
}