mirror of https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Simplify the parameters used by sum and sum_keepdim. (#165)
@@ -7,9 +7,9 @@ use candle::{Device, Tensor};
 fn main() -> Result<()> {
     let device = Device::new_cuda(0)?;
     let t = Tensor::new(&[[1f32, 2., 3., 4.2]], &device)?;
-    let sum = t.sum_keepdim(&[0])?;
+    let sum = t.sum_keepdim(0)?;
     println!("{sum}");
-    let sum = t.sum_keepdim(&[1])?;
+    let sum = t.sum_keepdim(1)?;
     println!("{sum}");
     Ok(())
 }
@@ -27,16 +27,16 @@ fn main() -> Result<()> {
     let xys_cpu = cos_sin(n, &Device::Cpu)?;
     let xys = cos_sin(n, &device)?;
     println!("{xys_cpu:?} {xys:?}");
-    let sum_keepdim_cpu = xys_cpu.sum_keepdim(&[1])?;
+    let sum_keepdim_cpu = xys_cpu.sum_keepdim(1)?;
     println!("{sum_keepdim_cpu}");
-    let sum_keepdim = xys.sum_keepdim(&[1])?;
+    let sum_keepdim = xys.sum_keepdim(1)?;
     println!("{sum_keepdim}");
     let start = std::time::Instant::now();
     let n_iters = 100;
     let mut v = 0f32;
     for _i in 0..n_iters {
-        let sum_keepdim = xys.sum_keepdim(&[1])?;
-        let sum_keepdim = sum_keepdim.sum_keepdim(&[0])?;
+        let sum_keepdim = xys.sum_keepdim(1)?;
+        let sum_keepdim = sum_keepdim.sum_keepdim(0)?;
         let sum_keepdim: f32 = sum_keepdim.reshape(&[])?.to_scalar()?;
         v += sum_keepdim;
     }
@@ -37,6 +37,13 @@ pub enum Error {
         op: &'static str,
     },
+
+    #[error("{op}: duplicate dim index {dims:?} for shape {shape:?}")]
+    DuplicateDimIndex {
+        shape: Shape,
+        dims: Vec<usize>,
+        op: &'static str,
+    },

     // === Shape Errors ===
     #[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")]
     UnexpectedNumberOfDims {
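The new DuplicateDimIndex variant is what the `Dims::to_indexes` validation further down reports when the same axis is named twice. A minimal sketch of the expected behavior, assuming a small CPU tensor (the message text comes from the #[error] attribute above):

    use candle::{Device, Tensor};

    fn main() -> candle::Result<()> {
        let t = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
        // (0, 0) names dimension 0 twice, so the reduction should now be
        // rejected with Error::DuplicateDimIndex instead of silently running.
        assert!(t.sum_keepdim((0, 0)).is_err());
        Ok(())
    }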
@@ -256,6 +256,86 @@ impl Dim for D {
     }
 }
+
+pub trait Dims: Sized {
+    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>>;
+
+    fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
+        let dims = self.to_indexes_internal(shape, op)?;
+        for (i, &dim) in dims.iter().enumerate() {
+            if dims[..i].contains(&dim) {
+                Err(Error::DuplicateDimIndex {
+                    shape: shape.clone(),
+                    dims: dims.clone(),
+                    op,
+                })?
+            }
+            if dim >= shape.rank() {
+                Err(Error::DimOutOfRange {
+                    shape: shape.clone(),
+                    dim: dim as i32,
+                    op,
+                })?
+            }
+        }
+        Ok(dims)
+    }
+}
+
+impl Dims for Vec<usize> {
+    fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
+        Ok(self)
+    }
+}
+
+impl<const N: usize> Dims for [usize; N] {
+    fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
+        Ok(self.to_vec())
+    }
+}
+
+impl Dims for &[usize] {
+    fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
+        Ok(self.to_vec())
+    }
+}
+
+impl Dims for () {
+    fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
+        Ok(vec![])
+    }
+}
+
+impl<D: Dim + Sized> Dims for D {
+    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
+        let dim = self.to_index(shape, op)?;
+        Ok(vec![dim])
+    }
+}
+
+impl<D: Dim> Dims for (D,) {
+    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
+        let dim = self.0.to_index(shape, op)?;
+        Ok(vec![dim])
+    }
+}
+
+impl<D1: Dim, D2: Dim> Dims for (D1, D2) {
+    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
+        let d0 = self.0.to_index(shape, op)?;
+        let d1 = self.1.to_index(shape, op)?;
+        Ok(vec![d0, d1])
+    }
+}
+
+impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
+    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
+        let d0 = self.0.to_index(shape, op)?;
+        let d1 = self.1.to_index(shape, op)?;
+        let d2 = self.2.to_index(shape, op)?;
+        Ok(vec![d0, d1, d2])
+    }
+}

 #[cfg(test)]
 mod tests {
     use super::*;
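The blanket impls above are what let a call site pass its reduction axes in whatever form is most convenient: a single `usize`, a named `D` index, a tuple, a fixed-size array, a slice, or a `Vec`. A hedged sketch of the calling styles this is meant to accept, checking only the resulting shapes on a zero-filled CPU tensor:

    use candle::{DType, Device, Tensor, D};

    fn main() -> candle::Result<()> {
        let t = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;
        assert_eq!(t.sum_keepdim(0)?.dims(), [1, 3, 4]); // single usize
        assert_eq!(t.sum_keepdim(D::Minus1)?.dims(), [2, 3, 1]); // named dim
        assert_eq!(t.sum_keepdim((0, 2))?.dims(), [1, 3, 1]); // tuple of dims
        assert_eq!(t.sum_keepdim([0, 1])?.dims(), [1, 1, 4]); // fixed-size array
        assert_eq!(t.sum_keepdim(vec![0, 1, 2])?.dims(), [1, 1, 1]); // Vec<usize>
        Ok(())
    }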
@@ -1,5 +1,5 @@
 use crate::backend::{BackendDevice, BackendStorage};
-use crate::shape::Dim;
+use crate::shape::{Dim, Dims};
 use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
 use std::sync::{Arc, RwLock};

@@ -572,7 +572,7 @@ impl Tensor {
             // We do not have a cuda kernel for divide_by_sum_over_dim so split
             // the operation.
             let exp = self.exp()?;
-            let sum_exp = exp.sum_keepdim(&[dim])?;
+            let sum_exp = exp.sum_keepdim(dim)?;
             exp.broadcast_div(&sum_exp)
         } else {
             let shape = self.shape();
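The softmax split shown here relies on `sum_keepdim` keeping the reduced axis so that `broadcast_div` can line the shapes up. A minimal CPU sketch of the same three steps, assuming a softmax over the last axis of a 1x3 tensor:

    use candle::{Device, Tensor};

    fn main() -> candle::Result<()> {
        let x = Tensor::new(&[[1f32, 2., 3.]], &Device::Cpu)?;
        let exp = x.exp()?;
        // Shape (1, 3) divided by shape (1, 1): the kept axis makes this broadcast.
        let sum_exp = exp.sum_keepdim(1)?;
        let softmax = exp.broadcast_div(&sum_exp)?;
        let row = softmax.to_vec2::<f32>()?;
        assert!((row[0].iter().sum::<f32>() - 1.0).abs() < 1e-6);
        Ok(())
    }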
@@ -588,28 +588,9 @@ impl Tensor {
         }
     }

-    /// Returns the sum of all elements in the input tensor. The sum is performed over all the
-    /// input dimensions.
-    ///
-    /// The resulting tensor has a shape that is similar to the shape of the input tensor, except
-    /// that the number of elements for each dimension index in `sum_dims` is 1.
-    ///
-    /// ```rust
-    /// use candle::{Tensor, Device};
-    /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
-    /// let s = a.sum_keepdim(&[0])?;
-    /// assert_eq!(s.to_vec2::<f32>()?, &[[2., 4.]]);
-    /// let s = a.sum_keepdim(&[1])?;
-    /// assert_eq!(s.to_vec2::<f32>()?, &[[1.], [5.]]);
-    /// let s = a.sum_keepdim(&[0, 1])?;
-    /// assert_eq!(s.to_vec2::<f32>()?, &[[6.]]);
-    /// # Ok::<(), candle::Error>(())
-    /// ```
-    pub fn sum_keepdim(&self, sum_dims: &[usize]) -> Result<Self> {
-        for &dim in sum_dims {
-            self.check_dim(dim, "sum")?;
-        }
-        let storage = self.storage().sum(self.layout(), sum_dims)?;
+    pub fn sum_impl<D: Dims>(&self, sum_dims: D, keepdim: bool) -> Result<Self> {
+        let sum_dims = sum_dims.to_indexes(self.shape(), "sum")?;
+        let storage = self.storage().sum(self.layout(), &sum_dims)?;
         let op = if self.track_op() {
             Some(Op::Sum(self.clone(), sum_dims.to_vec()))
         } else {
@@ -619,15 +600,11 @@ impl Tensor {
         for &sum_dim in sum_dims.iter() {
             dims[sum_dim] = 1
         }
-        Ok(from_storage(storage, dims, op, false))
-    }
-
-    /// Returns the sum of all elements in the input tensor. The sum is performed over all the
-    /// input dimensions and compared to `sum_keepdim` these dimensions are squeezed rather than
-    /// kept.
-    pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
-        let sum = self.sum_keepdim(sum_dims)?;
-        match sum_dims {
+        let sum = from_storage(storage, dims, op, false);
+        if keepdim {
+            Ok(sum)
+        } else {
+            match sum_dims.as_slice() {
                 [] => Ok(sum),
                 [i] => sum.squeeze(*i),
                 sum_dims => {
@@ -647,6 +624,35 @@ impl Tensor {
                 }
             }
         }
+    }
+
+    /// Returns the sum of all elements in the input tensor. The sum is performed over all the
+    /// input dimensions.
+    ///
+    /// The resulting tensor has a shape that is similar to the shape of the input tensor, except
+    /// that the number of elements for each dimension index in `sum_dims` is 1.
+    ///
+    /// ```rust
+    /// use candle::{Tensor, Device};
+    /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
+    /// let s = a.sum_keepdim(0)?;
+    /// assert_eq!(s.to_vec2::<f32>()?, &[[2., 4.]]);
+    /// let s = a.sum_keepdim(1)?;
+    /// assert_eq!(s.to_vec2::<f32>()?, &[[1.], [5.]]);
+    /// let s = a.sum_keepdim((0, 1))?;
+    /// assert_eq!(s.to_vec2::<f32>()?, &[[6.]]);
+    /// # Ok::<(), candle::Error>(())
+    /// ```
+    pub fn sum_keepdim<D: Dims>(&self, sum_dims: D) -> Result<Self> {
+        self.sum_impl(sum_dims, true)
+    }
+
+    /// Returns the sum of all elements in the input tensor. The sum is performed over all the
+    /// input dimensions and compared to `sum_keepdim` these dimensions are squeezed rather than
+    /// kept.
+    pub fn sum<D: Dims>(&self, sum_dims: D) -> Result<Self> {
+        self.sum_impl(sum_dims, false)
+    }

     /// Applies a 1D convolution over the input tensor.
     pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
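Since both entry points now funnel into `sum_impl`, the only difference is the `keepdim` flag: `sum_keepdim` keeps each reduced axis with size 1, while `sum` squeezes it away. A small sketch of that contrast, reusing the tensor from the doc example above:

    use candle::{Device, Tensor};

    fn main() -> candle::Result<()> {
        let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
        // keepdim = true: the summed axis stays, with size 1.
        assert_eq!(a.sum_keepdim(1)?.to_vec2::<f32>()?, [[1.], [5.]]);
        // keepdim = false: the summed axis is squeezed away.
        assert_eq!(a.sum(1)?.to_vec1::<f32>()?, [1., 5.]);
        Ok(())
    }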
@@ -962,7 +968,7 @@ impl Tensor {
     /// ```
     pub fn sum_all(&self) -> Result<Tensor> {
         let dims: Vec<_> = (0..self.rank()).collect();
-        self.sum_keepdim(&dims)?.reshape(())
+        self.sum(dims)
     }

     fn flatten_<D1: Dim, D2: Dim>(
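With `sum` squeezing every reduced axis, `sum_all` no longer needs the trailing `reshape(())`: summing over all dimensions already yields a rank-0 tensor. A minimal sketch of the expected result:

    use candle::{Device, Tensor};

    fn main() -> candle::Result<()> {
        let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
        // Rank-0 output, so to_scalar works directly on the result.
        assert_eq!(a.sum_all()?.to_scalar::<f32>()?, 6.);
        Ok(())
    }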
@@ -19,7 +19,7 @@ fn simple_grad(device: &Device) -> Result<()> {
 fn sum_grad(device: &Device) -> Result<()> {
     let x = Var::new(&[3f32, 1., 4.], device)?;
     let x = x.as_tensor();
-    let y = (x.sqr()?.sum_keepdim(&[0])? * 2.)?;
+    let y = (x.sqr()?.sum_keepdim(0)? * 2.)?;
     let grads = y.backward()?;
     let grad_x = grads.get(x).context("no grad for x")?;
     assert_eq!(y.to_vec1::<f32>()?, [52.]);
@@ -27,7 +27,7 @@ fn sum_grad(device: &Device) -> Result<()> {
     assert_eq!(grad_x.to_vec1::<f32>()?, &[12., 4., 16.]);

     // Same test as before but squeezing on the last dimension.
-    let y = (x.sqr()?.sum_keepdim(&[0])? * 2.)?.squeeze(0)?;
+    let y = (x.sqr()?.sum_keepdim(0)? * 2.)?.squeeze(0)?;
     let grads = y.backward()?;
     let grad_x = grads.get(x).context("no grad for x")?;
     assert_eq!(y.to_scalar::<f32>()?, 52.);
@@ -108,65 +108,53 @@ fn sum(device: &Device) -> Result<()> {
     let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
     let tensor = Tensor::new(data, device)?;
     assert_eq!(
-        tensor.sum_keepdim(&[2])?.to_vec3::<u32>()?,
+        tensor.sum_keepdim(2)?.to_vec3::<u32>()?,
         &[[[8], [15]], [[10], [18]]]
     );
     assert_eq!(
-        tensor.sum_keepdim(&[0])?.to_vec3::<u32>()?,
+        tensor.sum_keepdim(0)?.to_vec3::<u32>()?,
         &[[[5, 2, 11], [9, 7, 17]]],
     );
-    assert_eq!(tensor.sum_keepdim(&[0, 2, 1])?.to_vec3::<u32>()?, &[[[51]]],);
+    assert_eq!(tensor.sum_keepdim((0, 2, 1))?.to_vec3::<u32>()?, &[[[51]]],);
     assert_eq!(
-        tensor.t()?.sum_keepdim(&[1])?.t()?.to_vec3::<u32>()?,
+        tensor.t()?.sum_keepdim(1)?.t()?.to_vec3::<u32>()?,
         &[[[8], [15]], [[10], [18]]]
     );
     assert_eq!(
-        tensor.sum_keepdim(&[2, 1])?.to_vec3::<u32>()?,
+        tensor.sum_keepdim((2, 1))?.to_vec3::<u32>()?,
         &[[[8 + 15]], [[10 + 18]]]
     );
     let data: Vec<u32> = (0..4000u32).collect();
     let tensor = Tensor::new(data.as_slice(), device)?;
-    assert_eq!(tensor.sum_keepdim(&[0])?.to_vec1::<u32>()?, &[7998000]);
+    assert_eq!(tensor.sum_keepdim(0)?.to_vec1::<u32>()?, &[7998000]);
     let tensor = tensor.reshape((2000, 2))?;
-    assert_eq!(tensor.sum_keepdim(&[0, 1])?.to_vec2::<u32>()?, &[[7998000]]);
+    assert_eq!(tensor.sum_keepdim((0, 1))?.to_vec2::<u32>()?, &[[7998000]]);
     assert_eq!(
-        tensor
-            .sum_keepdim(&[0])?
-            .sum_keepdim(&[1])?
-            .to_vec2::<u32>()?,
+        tensor.sum_keepdim(0)?.sum_keepdim(1)?.to_vec2::<u32>()?,
         &[[7998000]]
     );
     assert_eq!(
-        tensor
-            .sum_keepdim(&[1])?
-            .sum_keepdim(&[0])?
-            .to_vec2::<u32>()?,
+        tensor.sum_keepdim(1)?.sum_keepdim(0)?.to_vec2::<u32>()?,
         &[[7998000]]
     );
     assert_eq!(
-        tensor.sum_keepdim(&[0])?.to_vec2::<u32>()?,
+        tensor.sum_keepdim(0)?.to_vec2::<u32>()?,
         &[[3998000, 4000000]]
     );

     // Make the tensor non contiguous.
     let tensor = tensor.t()?.contiguous()?.t()?;
-    assert_eq!(tensor.sum_keepdim(&[0, 1])?.to_vec2::<u32>()?, &[[7998000]]);
+    assert_eq!(tensor.sum_keepdim((0, 1))?.to_vec2::<u32>()?, &[[7998000]]);
     assert_eq!(
-        tensor
-            .sum_keepdim(&[0])?
-            .sum_keepdim(&[1])?
-            .to_vec2::<u32>()?,
+        tensor.sum_keepdim(0)?.sum_keepdim(1)?.to_vec2::<u32>()?,
         &[[7998000]]
     );
     assert_eq!(
-        tensor
-            .sum_keepdim(&[1])?
-            .sum_keepdim(&[0])?
-            .to_vec2::<u32>()?,
+        tensor.sum_keepdim(1)?.sum_keepdim(0)?.to_vec2::<u32>()?,
         &[[7998000]]
     );
     assert_eq!(
-        tensor.sum_keepdim(&[0])?.to_vec2::<u32>()?,
+        tensor.sum_keepdim(0)?.to_vec2::<u32>()?,
         &[[3998000, 4000000]]
     );

@@ -174,33 +162,33 @@ fn sum(device: &Device) -> Result<()> {
     let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
     for tensor in [t1, t2] {
         assert_eq!(
-            tensor.sum_keepdim(&[0, 1, 2])?.to_vec3::<u32>()?,
+            tensor.sum_keepdim((0, 1, 2))?.to_vec3::<u32>()?,
             &[[[7998000]]]
         );
         assert_eq!(
             tensor
-                .sum_keepdim(&[0])?
-                .sum_keepdim(&[2])?
-                .sum_keepdim(&[1])?
+                .sum_keepdim(0)?
+                .sum_keepdim(2)?
+                .sum_keepdim(1)?
                 .to_vec3::<u32>()?,
             &[[[7998000]]]
         );
         assert_eq!(
             tensor
-                .sum_keepdim(&[0])?
-                .sum_keepdim(&[1, 2])?
+                .sum_keepdim(0)?
+                .sum_keepdim((1, 2))?
                 .to_vec3::<u32>()?,
             &[[[7998000]]]
         );
         assert_eq!(
             tensor
-                .sum_keepdim(&[1])?
-                .sum_keepdim(&[0, 2])?
+                .sum_keepdim(1)?
+                .sum_keepdim((0, 2))?
                 .to_vec3::<u32>()?,
             &[[[7998000]]]
         );
         assert_eq!(
-            tensor.sum_keepdim(&[0])?.to_vec3::<u32>()?,
+            tensor.sum_keepdim(0)?.to_vec3::<u32>()?,
             &[[
                 [398000, 398200, 398400, 398600],
                 [398800, 399000, 399200, 399400],
@@ -604,7 +604,7 @@ fn main() -> Result<()> {
     println!("generated embeddings {:?}", embeddings.shape());
     // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
     let (_n_sentence, n_tokens, _hidden_size) = embeddings.shape().r3()?;
-    let embeddings = (embeddings.sum_keepdim(&[1])? / (n_tokens as f64))?.squeeze(1)?;
+    let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
     println!("pooled embeddings {:?}", embeddings.shape());
     let mut similarities = vec![];
     for i in 0..n_sentences {
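Because `sum` folds the token axis away, the pooled embeddings come out with shape (n_sentence, hidden_size) directly and the old `squeeze(1)` disappears. A hedged sketch with a hypothetical toy embeddings tensor (1 sentence, 2 tokens, hidden size 3):

    use candle::{Device, Tensor};

    fn main() -> candle::Result<()> {
        let embeddings = Tensor::new(&[[[1f32, 2., 3.], [3., 4., 5.]]], &Device::Cpu)?;
        let n_tokens = 2usize;
        // Mean over the token axis; the reduced axis is squeezed by `sum`.
        let pooled = (embeddings.sum(1)? / (n_tokens as f64))?;
        assert_eq!(pooled.dims(), [1, 3]);
        assert_eq!(pooled.to_vec2::<f32>()?, [[2., 3., 4.]]);
        Ok(())
    }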
@@ -95,7 +95,7 @@ impl RmsNorm {
         // This is a no-op if x's dtype is already f32.
         let x = x.to_dtype(DType::F32)?;
         let (b_sz, seq_len, hidden_size) = x.shape().r3()?;
-        let norm_x = (x.sqr()?.sum_keepdim(&[2])? / hidden_size as f64)?;
+        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
         let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
         let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
         let size = self.scale.shape().r1()?;
@@ -70,7 +70,7 @@ pub fn conv1d_weight_norm(
 ) -> Result<Conv1d> {
     let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
     let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
-    let norm_v = weight_v.sqr()?.sum_keepdim(&[1, 2])?.sqrt()?;
+    let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
     let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
     let bias = vb.get(out_c, "bias")?;
     Ok(Conv1d::new(weight, Some(bias), config))
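Here the L2 norm of each output channel is taken over the in-channel and kernel axes together, which is exactly what the `(1, 2)` tuple expresses. A hedged numeric sketch of just that step, with a hypothetical weight_v of shape (out_c = 1, in_c = 2, kernel_size = 2):

    use candle::{Device, Tensor};

    fn main() -> candle::Result<()> {
        let weight_v = Tensor::new(&[[[3f32, 0.], [0., 4.]]], &Device::Cpu)?;
        // sqrt(3^2 + 4^2) = 5, kept at shape (1, 1, 1) so it can broadcast later.
        let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
        assert_eq!(norm_v.to_vec3::<f32>()?, [[[5.]]]);
        Ok(())
    }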
@@ -98,7 +98,7 @@ impl T5LayerNorm {
         let dtype = xs.dtype();
         let xs_f32 = xs.to_dtype(DType::F32)?;
         let xs2_f32 = (&xs_f32 * &xs_f32)?;
-        let sum_xs2_f32 = xs2_f32.sum_keepdim(&[xs.rank() - 1])?;
+        let sum_xs2_f32 = xs2_f32.sum_keepdim(D::Minus1)?;
         let variance = xs2_f32.broadcast_div(&sum_xs2_f32)?;
         let xs = (xs / (variance + self.variance_epsilon)?.sqrt()?)?;
         let xs = xs.to_dtype(dtype)?;
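Any single `Dim` also satisfies `Dims`, so the last axis can be named with `D::Minus1` instead of computing `xs.rank() - 1` by hand; both spellings should reduce the same axis. A minimal sketch of that equivalence:

    use candle::{Device, Tensor, D};

    fn main() -> candle::Result<()> {
        let xs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
        let a = xs.sum_keepdim(D::Minus1)?;
        let b = xs.sum_keepdim(xs.rank() - 1)?;
        assert_eq!(a.to_vec2::<f32>()?, b.to_vec2::<f32>()?);
        Ok(())
    }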
@@ -51,9 +51,9 @@ impl LayerNorm {
         };
         let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
         let x = x.to_dtype(internal_dtype)?;
-        let mean_x = (x.sum_keepdim(&[2])? / hidden_size as f64)?;
+        let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
         let x = x.broadcast_sub(&mean_x)?;
-        let norm_x = (x.sqr()?.sum_keepdim(&[2])? / hidden_size as f64)?;
+        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
         let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
         let x = x_normed
             .to_dtype(x_dtype)?
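The two `sum_keepdim(2)` calls above are just the per-position mean and mean-square over the hidden axis; keeping the axis lets `broadcast_sub` and `broadcast_div` apply them element-wise. A hedged numeric sketch on a 1x1x4 input:

    use candle::{Device, Tensor};

    fn main() -> candle::Result<()> {
        let hidden_size = 4usize;
        let x = Tensor::new(&[[[1f32, 2., 3., 4.]]], &Device::Cpu)?;
        let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?; // mean = 2.5
        let x = x.broadcast_sub(&mean_x)?;
        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?; // variance = 1.25
        assert_eq!(mean_x.to_vec3::<f32>()?, [[[2.5]]]);
        assert_eq!(norm_x.to_vec3::<f32>()?, [[[1.25]]]);
        Ok(())
    }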
@@ -30,10 +30,10 @@ fn layer_norm() -> Result<()> {
             [4.1742344, 0.5, -3.1742344]
         ]]
     );
-    let mean = (res.sum_keepdim(&[2])? / 3.0)?;
+    let mean = (res.sum_keepdim(2)? / 3.0)?;
     // The average value should be `b`.
     assert_eq!(mean.to_vec3::<f32>()?, [[[0.5], [0.5], [0.5]]]);
-    let std = (res.broadcast_sub(&mean)?.sqr()?.sum_keepdim(&[2])?.sqrt()? / 3.0)?;
+    let std = (res.broadcast_sub(&mean)?.sqr()?.sum_keepdim(2)?.sqrt()? / 3.0)?;
     // The standard deviation should be sqrt(`w`).
     assert_eq!(
         std.to_vec3::<f32>()?,