Rename the .r functions to .dims so as to be a bit more explicit. (#220)

This commit is contained in:
Laurent Mazare
2023-07-22 11:39:27 +02:00
committed by GitHub
parent 52c5d8c087
commit 43c7223292
18 changed files with 56 additions and 50 deletions

View File

@ -4,7 +4,7 @@ use test_utils::to_vec3_round;
fn zeros(device: &Device) -> Result<()> {
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
let (dim1, dim2) = tensor.shape().r2()?;
let (dim1, dim2) = tensor.dims2()?;
assert_eq!(dim1, 5);
assert_eq!(dim2, 2);
Ok(())
@ -12,7 +12,7 @@ fn zeros(device: &Device) -> Result<()> {
fn add_mul(device: &Device) -> Result<()> {
let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
let dim1 = tensor.shape().r1()?;
let dim1 = tensor.dims1()?;
assert_eq!(dim1, 3);
let content: Vec<f32> = tensor.to_vec1()?;
assert_eq!(content, [3., 1., 4.]);
@ -28,7 +28,7 @@ fn add_mul(device: &Device) -> Result<()> {
fn tensor_2d(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor = Tensor::new(data, device)?;
let dims = tensor.shape().r2()?;
let dims = tensor.dims2()?;
assert_eq!(dims, (2, 5));
let content: Vec<Vec<f32>> = tensor.to_vec2()?;
assert_eq!(content, data);
@ -41,7 +41,7 @@ fn binary_op(device: &Device) -> Result<()> {
let data2 = &[[5f32, 5., 5., 5., 5.], [2., 1., 7., 8., 2.]];
let tensor2 = Tensor::new(data2, device)?;
let tensor = (&tensor + (&tensor * &tensor)? / (&tensor + &tensor2))?;
let dims = tensor.shape().r2()?;
let dims = tensor.dims2()?;
assert_eq!(dims, (2, 5));
let content: Vec<Vec<f32>> = tensor.to_vec2()?;
assert_eq!(content[0], [4.125, 1.1666666, 5.7777777, 1.1666666, 7.5]);
@ -56,7 +56,7 @@ fn binary_op(device: &Device) -> Result<()> {
fn transpose(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor = Tensor::new(data, device)?.t()?;
let dims = tensor.shape().r2()?;
let dims = tensor.dims2()?;
assert_eq!(dims, (5, 2));
assert_eq!(
tensor.to_vec2::<f32>()?,