mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add some helper functions.
This commit is contained in:
@ -306,12 +306,7 @@ impl CausalSelfAttention {
|
|||||||
let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?;
|
let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?;
|
||||||
let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?;
|
let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?;
|
||||||
let rope = Tensor::cat(&[&re, &im], rank - 1)?;
|
let rope = Tensor::cat(&[&re, &im], rank - 1)?;
|
||||||
// TODO: Add the flatten op.
|
let rope = rope.flatten(Some(rope.rank() - 2), None)?;
|
||||||
let mut dims = rope.dims().to_vec();
|
|
||||||
let v1 = dims.pop().unwrap();
|
|
||||||
let v2 = dims.pop().unwrap();
|
|
||||||
dims.push(v1 * v2);
|
|
||||||
let rope = rope.reshape(dims)?;
|
|
||||||
Ok(rope)
|
Ok(rope)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -652,6 +652,31 @@ impl Tensor {
|
|||||||
&self.op
|
&self.op
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn sum_all(&self) -> Result<Tensor> {
|
||||||
|
let dims: Vec<_> = (0..self.rank()).collect();
|
||||||
|
self.sum(&dims)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn flatten(&self, start_dim: Option<usize>, end_dim: Option<usize>) -> Result<Tensor> {
|
||||||
|
let start_dim = start_dim.unwrap_or(0);
|
||||||
|
let end_dim = end_dim.unwrap_or_else(|| self.rank() - 1);
|
||||||
|
if start_dim < end_dim {
|
||||||
|
let dims = self.dims();
|
||||||
|
let mut dst_dims = dims[..start_dim].to_vec();
|
||||||
|
dst_dims.push(dims[start_dim..end_dim + 1].iter().product::<usize>());
|
||||||
|
if end_dim + 1 < dims.len() {
|
||||||
|
dst_dims.extend(&dims[end_dim + 1..]);
|
||||||
|
}
|
||||||
|
self.reshape(dst_dims)
|
||||||
|
} else {
|
||||||
|
Ok(self.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn flatten_all(&self) -> Result<Tensor> {
|
||||||
|
self.flatten(None, None)
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns a tensor that is a transposed version of the input, the two last dimensions of the
|
/// Returns a tensor that is a transposed version of the input, the two last dimensions of the
|
||||||
/// input are swapped.
|
/// input are swapped.
|
||||||
pub fn t(&self) -> Result<Tensor> {
|
pub fn t(&self) -> Result<Tensor> {
|
||||||
|
Reference in New Issue
Block a user