mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Rename the .r functions to .dims so as to be a bit more explicit. (#220)
This commit is contained in:
@ -4,7 +4,7 @@ use test_utils::to_vec3_round;
|
||||
|
||||
fn zeros(device: &Device) -> Result<()> {
|
||||
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
|
||||
let (dim1, dim2) = tensor.shape().r2()?;
|
||||
let (dim1, dim2) = tensor.dims2()?;
|
||||
assert_eq!(dim1, 5);
|
||||
assert_eq!(dim2, 2);
|
||||
Ok(())
|
||||
@ -12,7 +12,7 @@ fn zeros(device: &Device) -> Result<()> {
|
||||
|
||||
fn add_mul(device: &Device) -> Result<()> {
|
||||
let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
|
||||
let dim1 = tensor.shape().r1()?;
|
||||
let dim1 = tensor.dims1()?;
|
||||
assert_eq!(dim1, 3);
|
||||
let content: Vec<f32> = tensor.to_vec1()?;
|
||||
assert_eq!(content, [3., 1., 4.]);
|
||||
@ -28,7 +28,7 @@ fn add_mul(device: &Device) -> Result<()> {
|
||||
fn tensor_2d(device: &Device) -> Result<()> {
|
||||
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
let dims = tensor.shape().r2()?;
|
||||
let dims = tensor.dims2()?;
|
||||
assert_eq!(dims, (2, 5));
|
||||
let content: Vec<Vec<f32>> = tensor.to_vec2()?;
|
||||
assert_eq!(content, data);
|
||||
@ -41,7 +41,7 @@ fn binary_op(device: &Device) -> Result<()> {
|
||||
let data2 = &[[5f32, 5., 5., 5., 5.], [2., 1., 7., 8., 2.]];
|
||||
let tensor2 = Tensor::new(data2, device)?;
|
||||
let tensor = (&tensor + (&tensor * &tensor)? / (&tensor + &tensor2))?;
|
||||
let dims = tensor.shape().r2()?;
|
||||
let dims = tensor.dims2()?;
|
||||
assert_eq!(dims, (2, 5));
|
||||
let content: Vec<Vec<f32>> = tensor.to_vec2()?;
|
||||
assert_eq!(content[0], [4.125, 1.1666666, 5.7777777, 1.1666666, 7.5]);
|
||||
@ -56,7 +56,7 @@ fn binary_op(device: &Device) -> Result<()> {
|
||||
fn transpose(device: &Device) -> Result<()> {
|
||||
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
|
||||
let tensor = Tensor::new(data, device)?.t()?;
|
||||
let dims = tensor.shape().r2()?;
|
||||
let dims = tensor.dims2()?;
|
||||
assert_eq!(dims, (5, 2));
|
||||
assert_eq!(
|
||||
tensor.to_vec2::<f32>()?,
|
||||
|
Reference in New Issue
Block a user