mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Use the same default as pytorch for sum. (#164)
This commit is contained in:
@ -7,9 +7,9 @@ use candle::{Device, Tensor};
|
|||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let device = Device::new_cuda(0)?;
|
let device = Device::new_cuda(0)?;
|
||||||
let t = Tensor::new(&[[1f32, 2., 3., 4.2]], &device)?;
|
let t = Tensor::new(&[[1f32, 2., 3., 4.2]], &device)?;
|
||||||
let sum = t.sum(&[0])?;
|
let sum = t.sum_keepdim(&[0])?;
|
||||||
println!("{sum}");
|
println!("{sum}");
|
||||||
let sum = t.sum(&[1])?;
|
let sum = t.sum_keepdim(&[1])?;
|
||||||
println!("{sum}");
|
println!("{sum}");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -27,18 +27,18 @@ fn main() -> Result<()> {
|
|||||||
let xys_cpu = cos_sin(n, &Device::Cpu)?;
|
let xys_cpu = cos_sin(n, &Device::Cpu)?;
|
||||||
let xys = cos_sin(n, &device)?;
|
let xys = cos_sin(n, &device)?;
|
||||||
println!("{xys_cpu:?} {xys:?}");
|
println!("{xys_cpu:?} {xys:?}");
|
||||||
let sum_cpu = xys_cpu.sum(&[1])?;
|
let sum_keepdim_cpu = xys_cpu.sum_keepdim(&[1])?;
|
||||||
println!("{sum_cpu}");
|
println!("{sum_keepdim_cpu}");
|
||||||
let sum = xys.sum(&[1])?;
|
let sum_keepdim = xys.sum_keepdim(&[1])?;
|
||||||
println!("{sum}");
|
println!("{sum_keepdim}");
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let n_iters = 100;
|
let n_iters = 100;
|
||||||
let mut v = 0f32;
|
let mut v = 0f32;
|
||||||
for _i in 0..n_iters {
|
for _i in 0..n_iters {
|
||||||
let sum = xys.sum(&[1])?;
|
let sum_keepdim = xys.sum_keepdim(&[1])?;
|
||||||
let sum = sum.sum(&[0])?;
|
let sum_keepdim = sum_keepdim.sum_keepdim(&[0])?;
|
||||||
let sum: f32 = sum.reshape(&[])?.to_scalar()?;
|
let sum_keepdim: f32 = sum_keepdim.reshape(&[])?.to_scalar()?;
|
||||||
v += sum;
|
v += sum_keepdim;
|
||||||
}
|
}
|
||||||
let elapsed = start.elapsed();
|
let elapsed = start.elapsed();
|
||||||
if v > 0. {
|
if v > 0. {
|
||||||
|
@ -195,11 +195,7 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut arg_grad = grad.sum(sum_dims.as_slice())?;
|
let arg_grad = grad.sum(sum_dims.as_slice())?;
|
||||||
// sum_dims has increasing values.
|
|
||||||
for &dim in sum_dims.iter().rev() {
|
|
||||||
arg_grad = arg_grad.squeeze(dim)?
|
|
||||||
}
|
|
||||||
let sum_grad = grads.or_insert(arg)?;
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
*sum_grad = sum_grad.broadcast_add(&arg_grad)?
|
*sum_grad = sum_grad.broadcast_add(&arg_grad)?
|
||||||
}
|
}
|
||||||
|
@ -572,7 +572,7 @@ impl Tensor {
|
|||||||
// We do not have a cuda kernel for divide_by_sum_over_dim so split
|
// We do not have a cuda kernel for divide_by_sum_over_dim so split
|
||||||
// the operation.
|
// the operation.
|
||||||
let exp = self.exp()?;
|
let exp = self.exp()?;
|
||||||
let sum_exp = exp.sum(&[dim])?;
|
let sum_exp = exp.sum_keepdim(&[dim])?;
|
||||||
exp.broadcast_div(&sum_exp)
|
exp.broadcast_div(&sum_exp)
|
||||||
} else {
|
} else {
|
||||||
let shape = self.shape();
|
let shape = self.shape();
|
||||||
@ -591,21 +591,21 @@ impl Tensor {
|
|||||||
/// Returns the sum of all elements in the input tensor. The sum is performed over all the
|
/// Returns the sum of all elements in the input tensor. The sum is performed over all the
|
||||||
/// input dimensions.
|
/// input dimensions.
|
||||||
///
|
///
|
||||||
/// The resulting tensor as a shape that is similar to the shape of the input tensor, except
|
/// The resulting tensor has a shape that is similar to the shape of the input tensor, except
|
||||||
/// that the number of elements for each dimension index in `sum_dims` is 1.
|
/// that the number of elements for each dimension index in `sum_dims` is 1.
|
||||||
///
|
///
|
||||||
/// ```rust
|
/// ```rust
|
||||||
/// use candle::{Tensor, Device};
|
/// use candle::{Tensor, Device};
|
||||||
/// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
|
/// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
|
||||||
/// let s = a.sum(&[0])?;
|
/// let s = a.sum_keepdim(&[0])?;
|
||||||
/// assert_eq!(s.to_vec2::<f32>()?, &[[2., 4.]]);
|
/// assert_eq!(s.to_vec2::<f32>()?, &[[2., 4.]]);
|
||||||
/// let s = a.sum(&[1])?;
|
/// let s = a.sum_keepdim(&[1])?;
|
||||||
/// assert_eq!(s.to_vec2::<f32>()?, &[[1.], [5.]]);
|
/// assert_eq!(s.to_vec2::<f32>()?, &[[1.], [5.]]);
|
||||||
/// let s = a.sum(&[0, 1])?;
|
/// let s = a.sum_keepdim(&[0, 1])?;
|
||||||
/// assert_eq!(s.to_vec2::<f32>()?, &[[6.]]);
|
/// assert_eq!(s.to_vec2::<f32>()?, &[[6.]]);
|
||||||
/// # Ok::<(), candle::Error>(())
|
/// # Ok::<(), candle::Error>(())
|
||||||
/// ```
|
/// ```
|
||||||
pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
|
pub fn sum_keepdim(&self, sum_dims: &[usize]) -> Result<Self> {
|
||||||
for &dim in sum_dims {
|
for &dim in sum_dims {
|
||||||
self.check_dim(dim, "sum")?;
|
self.check_dim(dim, "sum")?;
|
||||||
}
|
}
|
||||||
@ -622,6 +622,32 @@ impl Tensor {
|
|||||||
Ok(from_storage(storage, dims, op, false))
|
Ok(from_storage(storage, dims, op, false))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the sum of all elements in the input tensor. The sum is performed over all the
|
||||||
|
/// input dimensions and compared to `sum_keepdim` these dimensions are squeezed rather than
|
||||||
|
/// kept.
|
||||||
|
pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
|
||||||
|
let sum = self.sum_keepdim(sum_dims)?;
|
||||||
|
match sum_dims {
|
||||||
|
[] => Ok(sum),
|
||||||
|
[i] => sum.squeeze(*i),
|
||||||
|
sum_dims => {
|
||||||
|
let dims = sum
|
||||||
|
.dims()
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.filter_map(|(dim_idx, &v)| {
|
||||||
|
if sum_dims.contains(&dim_idx) {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(v)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
sum.reshape(dims)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Applies a 1D convolution over the input tensor.
|
/// Applies a 1D convolution over the input tensor.
|
||||||
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
|
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
|
||||||
let (c_out, c_in_k, k_size) = kernel.shape().r3()?;
|
let (c_out, c_in_k, k_size) = kernel.shape().r3()?;
|
||||||
@ -936,7 +962,7 @@ impl Tensor {
|
|||||||
/// ```
|
/// ```
|
||||||
pub fn sum_all(&self) -> Result<Tensor> {
|
pub fn sum_all(&self) -> Result<Tensor> {
|
||||||
let dims: Vec<_> = (0..self.rank()).collect();
|
let dims: Vec<_> = (0..self.rank()).collect();
|
||||||
self.sum(&dims)?.reshape(())
|
self.sum_keepdim(&dims)?.reshape(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn flatten_<D1: Dim, D2: Dim>(
|
fn flatten_<D1: Dim, D2: Dim>(
|
||||||
|
@ -19,7 +19,7 @@ fn simple_grad(device: &Device) -> Result<()> {
|
|||||||
fn sum_grad(device: &Device) -> Result<()> {
|
fn sum_grad(device: &Device) -> Result<()> {
|
||||||
let x = Var::new(&[3f32, 1., 4.], device)?;
|
let x = Var::new(&[3f32, 1., 4.], device)?;
|
||||||
let x = x.as_tensor();
|
let x = x.as_tensor();
|
||||||
let y = (x.sqr()?.sum(&[0])? * 2.)?;
|
let y = (x.sqr()?.sum_keepdim(&[0])? * 2.)?;
|
||||||
let grads = y.backward()?;
|
let grads = y.backward()?;
|
||||||
let grad_x = grads.get(x).context("no grad for x")?;
|
let grad_x = grads.get(x).context("no grad for x")?;
|
||||||
assert_eq!(y.to_vec1::<f32>()?, [52.]);
|
assert_eq!(y.to_vec1::<f32>()?, [52.]);
|
||||||
@ -27,7 +27,7 @@ fn sum_grad(device: &Device) -> Result<()> {
|
|||||||
assert_eq!(grad_x.to_vec1::<f32>()?, &[12., 4., 16.]);
|
assert_eq!(grad_x.to_vec1::<f32>()?, &[12., 4., 16.]);
|
||||||
|
|
||||||
// Same test as before but squeezing on the last dimension.
|
// Same test as before but squeezing on the last dimension.
|
||||||
let y = (x.sqr()?.sum(&[0])? * 2.)?.squeeze(0)?;
|
let y = (x.sqr()?.sum_keepdim(&[0])? * 2.)?.squeeze(0)?;
|
||||||
let grads = y.backward()?;
|
let grads = y.backward()?;
|
||||||
let grad_x = grads.get(x).context("no grad for x")?;
|
let grad_x = grads.get(x).context("no grad for x")?;
|
||||||
assert_eq!(y.to_scalar::<f32>()?, 52.);
|
assert_eq!(y.to_scalar::<f32>()?, 52.);
|
||||||
|
@ -108,56 +108,99 @@ fn sum(device: &Device) -> Result<()> {
|
|||||||
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
|
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
|
||||||
let tensor = Tensor::new(data, device)?;
|
let tensor = Tensor::new(data, device)?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
tensor.sum(&[2])?.to_vec3::<u32>()?,
|
tensor.sum_keepdim(&[2])?.to_vec3::<u32>()?,
|
||||||
&[[[8], [15]], [[10], [18]]]
|
&[[[8], [15]], [[10], [18]]]
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
tensor.sum(&[0])?.to_vec3::<u32>()?,
|
tensor.sum_keepdim(&[0])?.to_vec3::<u32>()?,
|
||||||
&[[[5, 2, 11], [9, 7, 17]]],
|
&[[[5, 2, 11], [9, 7, 17]]],
|
||||||
);
|
);
|
||||||
assert_eq!(tensor.sum(&[0, 2, 1])?.to_vec3::<u32>()?, &[[[51]]],);
|
assert_eq!(tensor.sum_keepdim(&[0, 2, 1])?.to_vec3::<u32>()?, &[[[51]]],);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
tensor.t()?.sum(&[1])?.t()?.to_vec3::<u32>()?,
|
tensor.t()?.sum_keepdim(&[1])?.t()?.to_vec3::<u32>()?,
|
||||||
&[[[8], [15]], [[10], [18]]]
|
&[[[8], [15]], [[10], [18]]]
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
tensor.sum(&[2, 1])?.to_vec3::<u32>()?,
|
tensor.sum_keepdim(&[2, 1])?.to_vec3::<u32>()?,
|
||||||
&[[[8 + 15]], [[10 + 18]]]
|
&[[[8 + 15]], [[10 + 18]]]
|
||||||
);
|
);
|
||||||
let data: Vec<u32> = (0..4000u32).collect();
|
let data: Vec<u32> = (0..4000u32).collect();
|
||||||
let tensor = Tensor::new(data.as_slice(), device)?;
|
let tensor = Tensor::new(data.as_slice(), device)?;
|
||||||
assert_eq!(tensor.sum(&[0])?.to_vec1::<u32>()?, &[7998000]);
|
assert_eq!(tensor.sum_keepdim(&[0])?.to_vec1::<u32>()?, &[7998000]);
|
||||||
let tensor = tensor.reshape((2000, 2))?;
|
let tensor = tensor.reshape((2000, 2))?;
|
||||||
assert_eq!(tensor.sum(&[0, 1])?.to_vec2::<u32>()?, &[[7998000]]);
|
assert_eq!(tensor.sum_keepdim(&[0, 1])?.to_vec2::<u32>()?, &[[7998000]]);
|
||||||
assert_eq!(tensor.sum(&[0])?.sum(&[1])?.to_vec2::<u32>()?, &[[7998000]]);
|
assert_eq!(
|
||||||
assert_eq!(tensor.sum(&[1])?.sum(&[0])?.to_vec2::<u32>()?, &[[7998000]]);
|
tensor
|
||||||
assert_eq!(tensor.sum(&[0])?.to_vec2::<u32>()?, &[[3998000, 4000000]]);
|
.sum_keepdim(&[0])?
|
||||||
|
.sum_keepdim(&[1])?
|
||||||
|
.to_vec2::<u32>()?,
|
||||||
|
&[[7998000]]
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
tensor
|
||||||
|
.sum_keepdim(&[1])?
|
||||||
|
.sum_keepdim(&[0])?
|
||||||
|
.to_vec2::<u32>()?,
|
||||||
|
&[[7998000]]
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
tensor.sum_keepdim(&[0])?.to_vec2::<u32>()?,
|
||||||
|
&[[3998000, 4000000]]
|
||||||
|
);
|
||||||
|
|
||||||
// Make the tensor non contiguous.
|
// Make the tensor non contiguous.
|
||||||
let tensor = tensor.t()?.contiguous()?.t()?;
|
let tensor = tensor.t()?.contiguous()?.t()?;
|
||||||
assert_eq!(tensor.sum(&[0, 1])?.to_vec2::<u32>()?, &[[7998000]]);
|
assert_eq!(tensor.sum_keepdim(&[0, 1])?.to_vec2::<u32>()?, &[[7998000]]);
|
||||||
assert_eq!(tensor.sum(&[0])?.sum(&[1])?.to_vec2::<u32>()?, &[[7998000]]);
|
assert_eq!(
|
||||||
assert_eq!(tensor.sum(&[1])?.sum(&[0])?.to_vec2::<u32>()?, &[[7998000]]);
|
tensor
|
||||||
assert_eq!(tensor.sum(&[0])?.to_vec2::<u32>()?, &[[3998000, 4000000]]);
|
.sum_keepdim(&[0])?
|
||||||
|
.sum_keepdim(&[1])?
|
||||||
|
.to_vec2::<u32>()?,
|
||||||
|
&[[7998000]]
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
tensor
|
||||||
|
.sum_keepdim(&[1])?
|
||||||
|
.sum_keepdim(&[0])?
|
||||||
|
.to_vec2::<u32>()?,
|
||||||
|
&[[7998000]]
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
tensor.sum_keepdim(&[0])?.to_vec2::<u32>()?,
|
||||||
|
&[[3998000, 4000000]]
|
||||||
|
);
|
||||||
|
|
||||||
let t1 = tensor.reshape((200, 5, 4))?;
|
let t1 = tensor.reshape((200, 5, 4))?;
|
||||||
let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
|
let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
|
||||||
for tensor in [t1, t2] {
|
for tensor in [t1, t2] {
|
||||||
assert_eq!(tensor.sum(&[0, 1, 2])?.to_vec3::<u32>()?, &[[[7998000]]]);
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
tensor.sum(&[0])?.sum(&[2])?.sum(&[1])?.to_vec3::<u32>()?,
|
tensor.sum_keepdim(&[0, 1, 2])?.to_vec3::<u32>()?,
|
||||||
&[[[7998000]]]
|
&[[[7998000]]]
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
tensor.sum(&[0])?.sum(&[1, 2])?.to_vec3::<u32>()?,
|
tensor
|
||||||
|
.sum_keepdim(&[0])?
|
||||||
|
.sum_keepdim(&[2])?
|
||||||
|
.sum_keepdim(&[1])?
|
||||||
|
.to_vec3::<u32>()?,
|
||||||
&[[[7998000]]]
|
&[[[7998000]]]
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
tensor.sum(&[1])?.sum(&[0, 2])?.to_vec3::<u32>()?,
|
tensor
|
||||||
|
.sum_keepdim(&[0])?
|
||||||
|
.sum_keepdim(&[1, 2])?
|
||||||
|
.to_vec3::<u32>()?,
|
||||||
&[[[7998000]]]
|
&[[[7998000]]]
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
tensor.sum(&[0])?.to_vec3::<u32>()?,
|
tensor
|
||||||
|
.sum_keepdim(&[1])?
|
||||||
|
.sum_keepdim(&[0, 2])?
|
||||||
|
.to_vec3::<u32>()?,
|
||||||
|
&[[[7998000]]]
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
tensor.sum_keepdim(&[0])?.to_vec3::<u32>()?,
|
||||||
&[[
|
&[[
|
||||||
[398000, 398200, 398400, 398600],
|
[398000, 398200, 398400, 398600],
|
||||||
[398800, 399000, 399200, 399400],
|
[398800, 399000, 399200, 399400],
|
||||||
|
@ -604,16 +604,16 @@ fn main() -> Result<()> {
|
|||||||
println!("generated embeddings {:?}", embeddings.shape());
|
println!("generated embeddings {:?}", embeddings.shape());
|
||||||
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
||||||
let (_n_sentence, n_tokens, _hidden_size) = embeddings.shape().r3()?;
|
let (_n_sentence, n_tokens, _hidden_size) = embeddings.shape().r3()?;
|
||||||
let embeddings = (embeddings.sum(&[1])? / (n_tokens as f64))?.squeeze(1)?;
|
let embeddings = (embeddings.sum_keepdim(&[1])? / (n_tokens as f64))?.squeeze(1)?;
|
||||||
println!("pooled embeddings {:?}", embeddings.shape());
|
println!("pooled embeddings {:?}", embeddings.shape());
|
||||||
let mut similarities = vec![];
|
let mut similarities = vec![];
|
||||||
for i in 0..n_sentences {
|
for i in 0..n_sentences {
|
||||||
let e_i = embeddings.get(i)?;
|
let e_i = embeddings.get(i)?;
|
||||||
for j in (i + 1)..n_sentences {
|
for j in (i + 1)..n_sentences {
|
||||||
let e_j = embeddings.get(j)?;
|
let e_j = embeddings.get(j)?;
|
||||||
let sum_ij = (&e_i * &e_j)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
|
let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
|
||||||
let sum_i2 = (&e_i * &e_i)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
|
let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
|
||||||
let sum_j2 = (&e_j * &e_j)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
|
let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
|
||||||
let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
|
let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
|
||||||
similarities.push((cosine_similarity, i, j))
|
similarities.push((cosine_similarity, i, j))
|
||||||
}
|
}
|
||||||
|
@ -95,7 +95,7 @@ impl RmsNorm {
|
|||||||
// This is a no-op if x's dtype is already f32.
|
// This is a no-op if x's dtype is already f32.
|
||||||
let x = x.to_dtype(DType::F32)?;
|
let x = x.to_dtype(DType::F32)?;
|
||||||
let (b_sz, seq_len, hidden_size) = x.shape().r3()?;
|
let (b_sz, seq_len, hidden_size) = x.shape().r3()?;
|
||||||
let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
|
let norm_x = (x.sqr()?.sum_keepdim(&[2])? / hidden_size as f64)?;
|
||||||
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
|
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
|
||||||
let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
|
let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
|
||||||
let size = self.scale.shape().r1()?;
|
let size = self.scale.shape().r1()?;
|
||||||
|
@ -70,7 +70,7 @@ pub fn conv1d_weight_norm(
|
|||||||
) -> Result<Conv1d> {
|
) -> Result<Conv1d> {
|
||||||
let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
|
let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
|
||||||
let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
|
let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
|
||||||
let norm_v = (&weight_v * &weight_v)?.sum(&[1, 2])?.sqrt()?;
|
let norm_v = weight_v.sqr()?.sum_keepdim(&[1, 2])?.sqrt()?;
|
||||||
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
|
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
|
||||||
let bias = vb.get(out_c, "bias")?;
|
let bias = vb.get(out_c, "bias")?;
|
||||||
Ok(Conv1d::new(weight, Some(bias), config))
|
Ok(Conv1d::new(weight, Some(bias), config))
|
||||||
|
@ -98,7 +98,7 @@ impl T5LayerNorm {
|
|||||||
let dtype = xs.dtype();
|
let dtype = xs.dtype();
|
||||||
let xs_f32 = xs.to_dtype(DType::F32)?;
|
let xs_f32 = xs.to_dtype(DType::F32)?;
|
||||||
let xs2_f32 = (&xs_f32 * &xs_f32)?;
|
let xs2_f32 = (&xs_f32 * &xs_f32)?;
|
||||||
let sum_xs2_f32 = xs2_f32.sum(&[xs.rank() - 1])?;
|
let sum_xs2_f32 = xs2_f32.sum_keepdim(&[xs.rank() - 1])?;
|
||||||
let variance = xs2_f32.broadcast_div(&sum_xs2_f32)?;
|
let variance = xs2_f32.broadcast_div(&sum_xs2_f32)?;
|
||||||
let xs = (xs / (variance + self.variance_epsilon)?.sqrt()?)?;
|
let xs = (xs / (variance + self.variance_epsilon)?.sqrt()?)?;
|
||||||
let xs = xs.to_dtype(dtype)?;
|
let xs = xs.to_dtype(dtype)?;
|
||||||
|
@ -51,9 +51,9 @@ impl LayerNorm {
|
|||||||
};
|
};
|
||||||
let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
|
let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
|
||||||
let x = x.to_dtype(internal_dtype)?;
|
let x = x.to_dtype(internal_dtype)?;
|
||||||
let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
|
let mean_x = (x.sum_keepdim(&[2])? / hidden_size as f64)?;
|
||||||
let x = x.broadcast_sub(&mean_x)?;
|
let x = x.broadcast_sub(&mean_x)?;
|
||||||
let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
|
let norm_x = (x.sqr()?.sum_keepdim(&[2])? / hidden_size as f64)?;
|
||||||
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
|
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
|
||||||
let x = x_normed
|
let x = x_normed
|
||||||
.to_dtype(x_dtype)?
|
.to_dtype(x_dtype)?
|
||||||
|
@ -30,10 +30,10 @@ fn layer_norm() -> Result<()> {
|
|||||||
[4.1742344, 0.5, -3.1742344]
|
[4.1742344, 0.5, -3.1742344]
|
||||||
]]
|
]]
|
||||||
);
|
);
|
||||||
let mean = (res.sum(&[2])? / 3.0)?;
|
let mean = (res.sum_keepdim(&[2])? / 3.0)?;
|
||||||
// The average value should be `b`.
|
// The average value should be `b`.
|
||||||
assert_eq!(mean.to_vec3::<f32>()?, [[[0.5], [0.5], [0.5]]]);
|
assert_eq!(mean.to_vec3::<f32>()?, [[[0.5], [0.5], [0.5]]]);
|
||||||
let std = (res.broadcast_sub(&mean)?.sqr()?.sum(&[2])?.sqrt()? / 3.0)?;
|
let std = (res.broadcast_sub(&mean)?.sqr()?.sum_keepdim(&[2])?.sqrt()? / 3.0)?;
|
||||||
// The standard deviation should be sqrt(`w`).
|
// The standard deviation should be sqrt(`w`).
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
std.to_vec3::<f32>()?,
|
std.to_vec3::<f32>()?,
|
||||||
|
@ -312,9 +312,11 @@ impl PyTensor {
|
|||||||
Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
|
Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sum(&self, dims: Vec<usize>) -> PyResult<Self> {
|
fn sum_keepdim(&self, dims: Vec<usize>) -> PyResult<Self> {
|
||||||
// TODO: Support a single dim as input?
|
// TODO: Support a single dim as input?
|
||||||
Ok(PyTensor(self.0.sum(dims.as_slice()).map_err(wrap_err)?))
|
Ok(PyTensor(
|
||||||
|
self.0.sum_keepdim(dims.as_slice()).map_err(wrap_err)?,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sum_all(&self) -> PyResult<Self> {
|
fn sum_all(&self) -> PyResult<Self> {
|
||||||
|
Reference in New Issue
Block a user