Add to the cuda example a reproduction of the issue. (#579)

* Add to the cuda example a reproduction of the issue.

* Tweak.

* Add a test using non-square matrixes.

* Fix the conv2d kernel.

* Display the error.

* And tweak the comment.
This commit is contained in:
Laurent Mazare
2023-08-24 12:07:31 +01:00
committed by GitHub
parent dd64465899
commit ca318a6ec7
3 changed files with 58 additions and 12 deletions

View File

@ -9,8 +9,17 @@ use candle_core::{Device, Tensor};
fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
let t = Tensor::rand(-1f32, 1f32, 96, &device)?;
println!("{t}");
let in_t = Tensor::rand(-1f32, 1f32, (1, 3, 12, 7), &device)?;
let k_t = Tensor::rand(-1f32, 1f32, (6, 3, 1, 1), &device)?;
let out_t = in_t.conv2d(&k_t, 0, 1, 1)?;
println!("{out_t}");
let in_t = in_t.to_device(&Device::Cpu)?;
let k_t = k_t.to_device(&Device::Cpu)?;
let out_t2 = in_t.conv2d(&k_t, 0, 1, 1)?;
let diff = (out_t.to_device(&Device::Cpu)? - out_t2)?
.sqr()?
.sum_all()?;
println!("{diff}");
let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;

View File

@ -183,8 +183,44 @@ fn conv2d_smaller(dev: &Device) -> Result<()> {
Ok(())
}
/* This test is based on the following script.
import torch
torch.manual_seed(4242)
t = torch.randn((1, 2, 4, 2))
w = torch.randn((1, 2, 1, 1))
print(t.flatten())
print(w.flatten())
res = torch.nn.functional.conv2d(t, w)
print(res.flatten())
*/
fn conv2d_non_square(dev: &Device) -> Result<()> {
let t = Tensor::new(
&[
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699,
],
dev,
)?;
let w = Tensor::new(&[-1.1351f32, 1.3841], dev)?;
let t = t.reshape((1, 2, 4, 2))?;
let w = w.reshape((1, 2, 1, 1))?;
let res = t.conv2d(&w, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 4, 2]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[0.2312, 5.2238, 2.3772, 1.9076, 2.0256, -0.5776, -1.6028, -1.467]
);
Ok(())
}
test_device!(conv1d, conv1d_cpu, conv1d_gpu);
test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
test_device!(conv2d, conv2d_cpu, conv2d_gpu);
test_device!(
conv2d_non_square,
conv2d_non_square_cpu,
conv2d_non_square_gpu
);
test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);

View File

@ -64,18 +64,18 @@ __device__ void conv2d(
T *dst
) {
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
// src: (b_size, c_in, w_in, h_in)
// k: (c_out, c_in, w_k, h_k)
// src: (b_size, c_in, h_in, w_in)
// k: (c_out, c_in, h_k, w_k)
const size_t *src_dims = info;
const size_t *src_s = info + 4;
const size_t *k_dims = info + 8;
const size_t *k_s = info + 12;
const size_t w_k = k_dims[2];
const size_t h_k = k_dims[3];
const size_t h_k = k_dims[2];
const size_t w_k = k_dims[3];
const size_t c_out = k_dims[0];
const size_t c_in = src_dims[1];
const size_t w_in = src_dims[2];
const size_t h_in = src_dims[3];
const size_t h_in = src_dims[2];
const size_t w_in = src_dims[3];
if (dst_i >= src_dims[0] * c_out * w_out * h_out) {
return;
}
@ -83,8 +83,9 @@ __device__ void conv2d(
// TODO
const size_t b_idx = dst_i / (w_out * h_out * c_out);
const size_t dst_c_idx = (dst_i / (w_out * h_out)) % c_out;
const size_t dst_w = (dst_i / h_out) % w_out;
const size_t dst_h = dst_i % h_out;
// NCHW layout.
const size_t dst_h = (dst_i / w_out) % h_out;
const size_t dst_w = dst_i % w_out;
const size_t src_idx0 = b_idx * src_s[0];
A d = 0;
@ -101,8 +102,8 @@ __device__ void conv2d(
}
src_h -= padding;
for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + src_w * src_s[2] + src_h * src_s[3];
const size_t k_idx = dst_c_idx * k_s[0] + src_c_idx * k_s[1] + w_offset * k_s[2] + h_offset * k_s[3];
const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + src_h * src_s[2] + src_w * src_s[3];
const size_t k_idx = dst_c_idx * k_s[0] + src_c_idx * k_s[1] + h_offset * k_s[2] + w_offset * k_s[3];
d += static_cast<A>(src[src_idx]) * static_cast<A>(kernel[k_idx]);
}
}