Cuda support for the mnist training. (#277)

* Cuda support for the mnist training.

* min/max fix + testing.

* Add the argmin/argmax tests.

* More cuda support for argmin/argmax.

* Cuda kernels for argmin and argmax.
This commit is contained in:
Laurent Mazare
2023-07-29 19:48:04 +01:00
committed by GitHub
parent 16c33383eb
commit c950a5c6b1
6 changed files with 453 additions and 28 deletions

View File

@ -164,6 +164,278 @@ fn sum(device: &Device) -> Result<()> {
Ok(())
}
fn min(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
assert_eq!(
tensor.min_keepdim(2)?.to_vec3::<u32>()?,
&[[[1], [1]], [[1], [2]]]
);
assert_eq!(
tensor.min_keepdim(0)?.to_vec3::<u32>()?,
&[[[2, 1, 4], [1, 2, 8]]],
);
let data: Vec<u32> = (200..4000u32).collect();
let tensor = Tensor::new(data.as_slice(), device)?;
assert_eq!(tensor.min_keepdim(0)?.to_vec1::<u32>()?, &[200]);
let tensor = tensor.reshape((1900, 2))?;
assert_eq!(
tensor.min_keepdim(0)?.min_keepdim(1)?.to_vec2::<u32>()?,
&[[200]]
);
assert_eq!(
tensor.min_keepdim(1)?.min_keepdim(0)?.to_vec2::<u32>()?,
&[[200]]
);
assert_eq!(tensor.min_keepdim(0)?.to_vec2::<u32>()?, &[[200, 201]]);
// Make the tensor non contiguous.
let tensor = tensor.t()?.contiguous()?.t()?;
assert_eq!(
tensor.min_keepdim(0)?.min_keepdim(1)?.to_vec2::<u32>()?,
&[[200]]
);
assert_eq!(
tensor.min_keepdim(1)?.min_keepdim(0)?.to_vec2::<u32>()?,
&[[200]]
);
assert_eq!(tensor.min_keepdim(0)?.to_vec2::<u32>()?, &[[200, 201]]);
let t1 = tensor.reshape((190, 5, 4))?;
let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
for tensor in [t1, t2] {
assert_eq!(
tensor
.min_keepdim(0)?
.min_keepdim(2)?
.min_keepdim(1)?
.to_vec3::<u32>()?,
&[[[200]]]
);
assert_eq!(
tensor.min_keepdim(0)?.to_vec3::<u32>()?,
&[[
[200, 201, 202, 203],
[204, 205, 206, 207],
[208, 209, 210, 211],
[212, 213, 214, 215],
[216, 217, 218, 219]
]]
);
}
Ok(())
}
fn max(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
assert_eq!(
tensor.max_keepdim(2)?.to_vec3::<u32>()?,
&[[[4], [9]], [[7], [8]]]
);
assert_eq!(
tensor.max_keepdim(0)?.to_vec3::<u32>()?,
&[[[3, 1, 7], [8, 5, 9]]],
);
let data: Vec<u32> = (200..4000u32).collect();
let tensor = Tensor::new(data.as_slice(), device)?;
assert_eq!(tensor.max_keepdim(0)?.to_vec1::<u32>()?, &[3999]);
let tensor = tensor.reshape((1900, 2))?;
assert_eq!(
tensor.max_keepdim(0)?.max_keepdim(1)?.to_vec2::<u32>()?,
&[[3999]]
);
assert_eq!(
tensor.max_keepdim(1)?.max_keepdim(0)?.to_vec2::<u32>()?,
&[[3999]]
);
assert_eq!(tensor.max_keepdim(0)?.to_vec2::<u32>()?, &[[3998, 3999]]);
// Make the tensor non contiguous.
let tensor = tensor.t()?.contiguous()?.t()?;
assert_eq!(
tensor.max_keepdim(0)?.max_keepdim(1)?.to_vec2::<u32>()?,
&[[3999]]
);
assert_eq!(
tensor.max_keepdim(1)?.max_keepdim(0)?.to_vec2::<u32>()?,
&[[3999]]
);
assert_eq!(tensor.max_keepdim(0)?.to_vec2::<u32>()?, &[[3998, 3999]]);
let t1 = tensor.reshape((190, 5, 4))?;
let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
for tensor in [t1, t2] {
assert_eq!(
tensor
.max_keepdim(0)?
.max_keepdim(2)?
.max_keepdim(1)?
.to_vec3::<u32>()?,
&[[[3999]]]
);
assert_eq!(
tensor.max_keepdim(0)?.to_vec3::<u32>()?,
&[[
[3980, 3981, 3982, 3983],
[3984, 3985, 3986, 3987],
[3988, 3989, 3990, 3991],
[3992, 3993, 3994, 3995],
[3996, 3997, 3998, 3999]
]]
);
}
Ok(())
}
fn argmin(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
assert_eq!(
tensor.argmin_keepdim(2)?.to_vec3::<u32>()?,
&[[[1], [0]], [[1], [1]]]
);
assert_eq!(
tensor.argmin_keepdim(0)?.to_vec3::<u32>()?,
&[[[1, 0, 0], [0, 1, 1]]],
);
let data: Vec<u32> = (200..4000u32).collect();
let tensor = Tensor::new(data.as_slice(), device)?;
assert_eq!(tensor.argmin_keepdim(0)?.to_vec1::<u32>()?, &[0]);
let tensor = tensor.reshape((1900, 2))?;
assert_eq!(
tensor
.argmin_keepdim(0)?
.argmin_keepdim(1)?
.to_vec2::<u32>()?,
&[[0]]
);
assert_eq!(
tensor
.argmin_keepdim(1)?
.argmin_keepdim(0)?
.to_vec2::<u32>()?,
&[[0]]
);
assert_eq!(tensor.argmin_keepdim(0)?.to_vec2::<u32>()?, &[[0, 0]]);
// Make the tensor non contiguous.
let tensor = tensor.t()?.contiguous()?.t()?;
assert_eq!(
tensor
.argmin_keepdim(0)?
.argmin_keepdim(1)?
.to_vec2::<u32>()?,
&[[0]]
);
assert_eq!(
tensor
.argmin_keepdim(1)?
.argmin_keepdim(0)?
.to_vec2::<u32>()?,
&[[0]]
);
assert_eq!(tensor.argmin_keepdim(0)?.to_vec2::<u32>()?, &[[0, 0]]);
let t1 = tensor.reshape((190, 5, 4))?;
let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
for tensor in [t1, t2] {
assert_eq!(
tensor
.argmin_keepdim(0)?
.argmin_keepdim(2)?
.argmin_keepdim(1)?
.to_vec3::<u32>()?,
&[[[0]]]
);
assert_eq!(
tensor.argmin_keepdim(0)?.to_vec3::<u32>()?,
&[[
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
]]
);
}
Ok(())
}
fn argmax(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
assert_eq!(
tensor.argmax_keepdim(2)?.to_vec3::<u32>()?,
&[[[2], [2]], [[2], [0]]]
);
assert_eq!(
tensor.argmax_keepdim(0)?.to_vec3::<u32>()?,
&[[[0, 0, 1], [1, 0, 0]]],
);
let data: Vec<u32> = (200..4000u32).collect();
let tensor = Tensor::new(data.as_slice(), device)?;
assert_eq!(tensor.argmax_keepdim(0)?.to_vec1::<u32>()?, &[3799]);
let tensor = tensor.reshape((1900, 2))?;
assert_eq!(
tensor
.argmax_keepdim(0)?
.argmax_keepdim(1)?
.to_vec2::<u32>()?,
&[[0]]
);
assert_eq!(
tensor
.argmax_keepdim(1)?
.argmax_keepdim(0)?
.to_vec2::<u32>()?,
&[[0]]
);
assert_eq!(tensor.argmax_keepdim(0)?.to_vec2::<u32>()?, &[[1899, 1899]]);
// Make the tensor non contiguous.
let tensor = tensor.t()?.contiguous()?.t()?;
assert_eq!(
tensor
.argmax_keepdim(0)?
.argmax_keepdim(1)?
.to_vec2::<u32>()?,
&[[0]]
);
assert_eq!(
tensor
.argmax_keepdim(1)?
.argmax_keepdim(0)?
.to_vec2::<u32>()?,
&[[0]]
);
assert_eq!(tensor.argmax_keepdim(0)?.to_vec2::<u32>()?, &[[1899, 1899]]);
let t1 = tensor.reshape((190, 5, 4))?;
let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
for tensor in [t1, t2] {
assert_eq!(
tensor
.argmax_keepdim(0)?
.argmax_keepdim(2)?
.argmax_keepdim(1)?
.to_vec3::<u32>()?,
&[[[0]]]
);
assert_eq!(
tensor.argmax_keepdim(0)?.to_vec3::<u32>()?,
&[[
[189, 189, 189, 189],
[189, 189, 189, 189],
[189, 189, 189, 189],
[189, 189, 189, 189],
[189, 189, 189, 189],
]]
);
}
Ok(())
}
fn narrow(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
let tensor = Tensor::new(data, device)?;
@ -581,6 +853,10 @@ test_device!(narrow, narrow_cpu, narrow_gpu);
test_device!(broadcast, broadcast_cpu, broadcast_gpu);
test_device!(cat, cat_cpu, cat_gpu);
test_device!(sum, sum_cpu, sum_gpu);
test_device!(min, min_cpu, min_gpu);
test_device!(max, max_cpu, max_gpu);
test_device!(argmax, argmax_cpu, argmax_gpu);
test_device!(argmin, argmin_cpu, argmin_gpu);
test_device!(transpose, transpose_cpu, transpose_gpu);
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
test_device!(embeddings, embeddings_cpu, embeddings_gpu);