Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)
Cuda support for the mnist training. (#277)
* Cuda support for the mnist training.
* min/max fix + testing.
* Add the argmin/argmax tests.
* More cuda support for argmin/argmax.
* Cuda kernels for argmin and argmax.
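Not part of the diff: a minimal sketch of running the keepdim reductions exercised by the tests below on a CUDA device, assuming the crate is imported as `candle_core` (older revisions use `candle`), that the binary is built with the `cuda` feature, and that the `main` wrapper is purely illustrative.

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // Prefer the first CUDA device; fall back to the CPU if none is available
    // or the crate was built without the `cuda` feature.
    let device = Device::new_cuda(0).unwrap_or(Device::Cpu);
    let t = Tensor::new(&[[3u32, 1, 4], [1, 5, 9]], &device)?;
    // The *_keepdim reductions keep the reduced dimension with size 1.
    assert_eq!(t.min_keepdim(1)?.to_vec2::<u32>()?, &[[1], [1]]);
    assert_eq!(t.argmax_keepdim(1)?.to_vec2::<u32>()?, &[[2], [2]]);
    Ok(())
}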
@@ -164,6 +164,278 @@ fn sum(device: &Device) -> Result<()> {
    Ok(())
}

fn min(device: &Device) -> Result<()> {
    let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
    let tensor = Tensor::new(data, device)?;
    assert_eq!(
        tensor.min_keepdim(2)?.to_vec3::<u32>()?,
        &[[[1], [1]], [[1], [2]]]
    );
    assert_eq!(
        tensor.min_keepdim(0)?.to_vec3::<u32>()?,
        &[[[2, 1, 4], [1, 2, 8]]],
    );
    let data: Vec<u32> = (200..4000u32).collect();
    let tensor = Tensor::new(data.as_slice(), device)?;
    assert_eq!(tensor.min_keepdim(0)?.to_vec1::<u32>()?, &[200]);
    let tensor = tensor.reshape((1900, 2))?;
    assert_eq!(
        tensor.min_keepdim(0)?.min_keepdim(1)?.to_vec2::<u32>()?,
        &[[200]]
    );
    assert_eq!(
        tensor.min_keepdim(1)?.min_keepdim(0)?.to_vec2::<u32>()?,
        &[[200]]
    );
    assert_eq!(tensor.min_keepdim(0)?.to_vec2::<u32>()?, &[[200, 201]]);

    // Make the tensor non contiguous.
    let tensor = tensor.t()?.contiguous()?.t()?;
    assert_eq!(
        tensor.min_keepdim(0)?.min_keepdim(1)?.to_vec2::<u32>()?,
        &[[200]]
    );
    assert_eq!(
        tensor.min_keepdim(1)?.min_keepdim(0)?.to_vec2::<u32>()?,
        &[[200]]
    );
    assert_eq!(tensor.min_keepdim(0)?.to_vec2::<u32>()?, &[[200, 201]]);

    let t1 = tensor.reshape((190, 5, 4))?;
    let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
    for tensor in [t1, t2] {
        assert_eq!(
            tensor
                .min_keepdim(0)?
                .min_keepdim(2)?
                .min_keepdim(1)?
                .to_vec3::<u32>()?,
            &[[[200]]]
        );
        assert_eq!(
            tensor.min_keepdim(0)?.to_vec3::<u32>()?,
            &[[
                [200, 201, 202, 203],
                [204, 205, 206, 207],
                [208, 209, 210, 211],
                [212, 213, 214, 215],
                [216, 217, 218, 219]
            ]]
        );
    }
    Ok(())
}

fn max(device: &Device) -> Result<()> {
    let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
    let tensor = Tensor::new(data, device)?;
    assert_eq!(
        tensor.max_keepdim(2)?.to_vec3::<u32>()?,
        &[[[4], [9]], [[7], [8]]]
    );
    assert_eq!(
        tensor.max_keepdim(0)?.to_vec3::<u32>()?,
        &[[[3, 1, 7], [8, 5, 9]]],
    );
    let data: Vec<u32> = (200..4000u32).collect();
    let tensor = Tensor::new(data.as_slice(), device)?;
    assert_eq!(tensor.max_keepdim(0)?.to_vec1::<u32>()?, &[3999]);
    let tensor = tensor.reshape((1900, 2))?;
    assert_eq!(
        tensor.max_keepdim(0)?.max_keepdim(1)?.to_vec2::<u32>()?,
        &[[3999]]
    );
    assert_eq!(
        tensor.max_keepdim(1)?.max_keepdim(0)?.to_vec2::<u32>()?,
        &[[3999]]
    );
    assert_eq!(tensor.max_keepdim(0)?.to_vec2::<u32>()?, &[[3998, 3999]]);

    // Make the tensor non contiguous.
    let tensor = tensor.t()?.contiguous()?.t()?;
    assert_eq!(
        tensor.max_keepdim(0)?.max_keepdim(1)?.to_vec2::<u32>()?,
        &[[3999]]
    );
    assert_eq!(
        tensor.max_keepdim(1)?.max_keepdim(0)?.to_vec2::<u32>()?,
        &[[3999]]
    );
    assert_eq!(tensor.max_keepdim(0)?.to_vec2::<u32>()?, &[[3998, 3999]]);

    let t1 = tensor.reshape((190, 5, 4))?;
    let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
    for tensor in [t1, t2] {
        assert_eq!(
            tensor
                .max_keepdim(0)?
                .max_keepdim(2)?
                .max_keepdim(1)?
                .to_vec3::<u32>()?,
            &[[[3999]]]
        );
        assert_eq!(
            tensor.max_keepdim(0)?.to_vec3::<u32>()?,
            &[[
                [3980, 3981, 3982, 3983],
                [3984, 3985, 3986, 3987],
                [3988, 3989, 3990, 3991],
                [3992, 3993, 3994, 3995],
                [3996, 3997, 3998, 3999]
            ]]
        );
    }
    Ok(())
}

fn argmin(device: &Device) -> Result<()> {
    let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
    let tensor = Tensor::new(data, device)?;
    assert_eq!(
        tensor.argmin_keepdim(2)?.to_vec3::<u32>()?,
        &[[[1], [0]], [[1], [1]]]
    );
    assert_eq!(
        tensor.argmin_keepdim(0)?.to_vec3::<u32>()?,
        &[[[1, 0, 0], [0, 1, 1]]],
    );
    let data: Vec<u32> = (200..4000u32).collect();
    let tensor = Tensor::new(data.as_slice(), device)?;
    assert_eq!(tensor.argmin_keepdim(0)?.to_vec1::<u32>()?, &[0]);
    let tensor = tensor.reshape((1900, 2))?;
    assert_eq!(
        tensor
            .argmin_keepdim(0)?
            .argmin_keepdim(1)?
            .to_vec2::<u32>()?,
        &[[0]]
    );
    assert_eq!(
        tensor
            .argmin_keepdim(1)?
            .argmin_keepdim(0)?
            .to_vec2::<u32>()?,
        &[[0]]
    );
    assert_eq!(tensor.argmin_keepdim(0)?.to_vec2::<u32>()?, &[[0, 0]]);

    // Make the tensor non contiguous.
    let tensor = tensor.t()?.contiguous()?.t()?;
    assert_eq!(
        tensor
            .argmin_keepdim(0)?
            .argmin_keepdim(1)?
            .to_vec2::<u32>()?,
        &[[0]]
    );
    assert_eq!(
        tensor
            .argmin_keepdim(1)?
            .argmin_keepdim(0)?
            .to_vec2::<u32>()?,
        &[[0]]
    );
    assert_eq!(tensor.argmin_keepdim(0)?.to_vec2::<u32>()?, &[[0, 0]]);

    let t1 = tensor.reshape((190, 5, 4))?;
    let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
    for tensor in [t1, t2] {
        assert_eq!(
            tensor
                .argmin_keepdim(0)?
                .argmin_keepdim(2)?
                .argmin_keepdim(1)?
                .to_vec3::<u32>()?,
            &[[[0]]]
        );
        assert_eq!(
            tensor.argmin_keepdim(0)?.to_vec3::<u32>()?,
            &[[
                [0, 0, 0, 0],
                [0, 0, 0, 0],
                [0, 0, 0, 0],
                [0, 0, 0, 0],
                [0, 0, 0, 0],
            ]]
        );
    }
    Ok(())
}

fn argmax(device: &Device) -> Result<()> {
    let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
    let tensor = Tensor::new(data, device)?;
    assert_eq!(
        tensor.argmax_keepdim(2)?.to_vec3::<u32>()?,
        &[[[2], [2]], [[2], [0]]]
    );
    assert_eq!(
        tensor.argmax_keepdim(0)?.to_vec3::<u32>()?,
        &[[[0, 0, 1], [1, 0, 0]]],
    );
    let data: Vec<u32> = (200..4000u32).collect();
    let tensor = Tensor::new(data.as_slice(), device)?;
    assert_eq!(tensor.argmax_keepdim(0)?.to_vec1::<u32>()?, &[3799]);
    let tensor = tensor.reshape((1900, 2))?;
    assert_eq!(
        tensor
            .argmax_keepdim(0)?
            .argmax_keepdim(1)?
            .to_vec2::<u32>()?,
        &[[0]]
    );
    assert_eq!(
        tensor
            .argmax_keepdim(1)?
            .argmax_keepdim(0)?
            .to_vec2::<u32>()?,
        &[[0]]
    );
    assert_eq!(tensor.argmax_keepdim(0)?.to_vec2::<u32>()?, &[[1899, 1899]]);

    // Make the tensor non contiguous.
    let tensor = tensor.t()?.contiguous()?.t()?;
    assert_eq!(
        tensor
            .argmax_keepdim(0)?
            .argmax_keepdim(1)?
            .to_vec2::<u32>()?,
        &[[0]]
    );
    assert_eq!(
        tensor
            .argmax_keepdim(1)?
            .argmax_keepdim(0)?
            .to_vec2::<u32>()?,
        &[[0]]
    );
    assert_eq!(tensor.argmax_keepdim(0)?.to_vec2::<u32>()?, &[[1899, 1899]]);

    let t1 = tensor.reshape((190, 5, 4))?;
    let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
    for tensor in [t1, t2] {
        assert_eq!(
            tensor
                .argmax_keepdim(0)?
                .argmax_keepdim(2)?
                .argmax_keepdim(1)?
                .to_vec3::<u32>()?,
            &[[[0]]]
        );
        assert_eq!(
            tensor.argmax_keepdim(0)?.to_vec3::<u32>()?,
            &[[
                [189, 189, 189, 189],
                [189, 189, 189, 189],
                [189, 189, 189, 189],
                [189, 189, 189, 189],
                [189, 189, 189, 189],
            ]]
        );
    }
    Ok(())
}

fn narrow(device: &Device) -> Result<()> {
    let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
    let tensor = Tensor::new(data, device)?;
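The `// Make the tensor non contiguous.` steps above rely on `t()?.contiguous()?.t()?` producing a tensor whose values are unchanged but whose strides are transposed, which forces the reductions through the strided (non-contiguous) kernel path. A minimal sketch of that effect, not part of the diff and assuming the `candle_core` import path and the `Tensor::is_contiguous` helper:

use candle_core::{Device, Result, Tensor};

fn non_contiguous_demo() -> Result<()> {
    let t = Tensor::new(&[[1u32, 2], [3, 4]], &Device::Cpu)?;
    // Transpose, materialize, transpose back: same logical values,
    // but the storage is now laid out column-major (non-contiguous).
    let nc = t.t()?.contiguous()?.t()?;
    assert_eq!(t.to_vec2::<u32>()?, nc.to_vec2::<u32>()?);
    assert!(!nc.is_contiguous());
    // Reductions must follow the strides rather than assume a flat layout.
    assert_eq!(nc.min_keepdim(0)?.to_vec2::<u32>()?, &[[1, 2]]);
    Ok(())
}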
@@ -581,6 +853,10 @@ test_device!(narrow, narrow_cpu, narrow_gpu);
test_device!(broadcast, broadcast_cpu, broadcast_gpu);
test_device!(cat, cat_cpu, cat_gpu);
test_device!(sum, sum_cpu, sum_gpu);
test_device!(min, min_cpu, min_gpu);
test_device!(max, max_cpu, max_gpu);
test_device!(argmax, argmax_cpu, argmax_gpu);
test_device!(argmin, argmin_cpu, argmin_gpu);
test_device!(transpose, transpose_cpu, transpose_gpu);
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
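For context (the macro definition is not shown in this hunk), `test_device!` pairs each of these shared test functions with a CPU test and a feature-gated CUDA test. A rough sketch of how such a macro could be written, assuming `Device` and `Result` are in scope and the CUDA variant is gated on the `cuda` feature:

macro_rules! test_device {
    // $fn_name: the shared test body taking `&Device`;
    // $test_cpu / $test_gpu: names for the generated #[test] functions.
    ($fn_name: ident, $test_cpu: ident, $test_gpu: ident) => {
        #[test]
        fn $test_cpu() -> Result<()> {
            $fn_name(&Device::Cpu)
        }

        #[cfg(feature = "cuda")]
        #[test]
        fn $test_gpu() -> Result<()> {
            $fn_name(&Device::new_cuda(0)?)
        }
    };
}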