Fix for cudnn bf16 conv2d. (#2535)
@@ -26,6 +26,7 @@ impl From<cudarc::driver::DriverError> for crate::Error {
 pub(crate) fn launch_conv2d<
     T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType,
+    Y: cudarc::cudnn::CudnnDataType,
 >(
     src: &CudaView<T>,
     src_l: &crate::Layout,
@@ -48,7 +49,7 @@ pub(crate) fn launch_conv2d<
         }
         c
     })?;
-    let conv = cudnn.create_conv2d::<T>(
+    let conv = cudnn.create_conv2d::<Y>(
         /* pad */ [params.padding as i32, params.padding as i32],
         /* stride */ [params.stride as i32, params.stride as i32],
         /* dilation */ [params.dilation as i32, params.dilation as i32],
@@ -62,18 +63,18 @@ pub(crate) fn launch_conv2d<
     ];
     // Note that `src` already starts at the proper offset.
     let x = if src_l.is_contiguous() {
-        cudnn.create_4d_tensor(
+        cudnn.create_4d_tensor::<T>(
             cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
             x_shape,
         )?
     } else {
         let s = src_l.stride();
-        cudnn.create_4d_tensor_ex(
+        cudnn.create_4d_tensor_ex::<T>(
             x_shape,
             [s[0] as i32, s[1] as i32, s[2] as i32, s[3] as i32],
         )?
     };
-    let w = cudnn.create_4d_filter(
+    let w = cudnn.create_4d_filter::<T>(
         cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
         [
             params.c_out as i32,
@@ -83,7 +84,7 @@ pub(crate) fn launch_conv2d<
         ],
     )?;
     let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32);
-    let y = cudnn.create_4d_tensor(
+    let y = cudnn.create_4d_tensor::<T>(
         cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
         [params.b_size as i32, params.c_out as i32, h_out, w_out],
     )?;
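(Not part of the commit.) The new `Y` type parameter separates the tensor element type `T` from the type handed to `create_conv2d`, i.e. the convolution descriptor's data type, which effectively selects the compute/accumulation precision while the input, filter, and output descriptors stay at `T`. As a minimal illustration of the pairing the call sites below choose, here is a hypothetical helper expressed with candle's `DType` enum (the helper name is not in the codebase):

use candle_core::DType;

// Hypothetical helper, not in the patch: mirrors the <T, Y> pairs picked below.
// bf16 tensors accumulate in f32; every other dtype computes in its own precision.
fn cudnn_compute_dtype(elem: DType) -> DType {
    match elem {
        DType::BF16 => DType::F32,
        other => other,
    }
}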
@@ -1522,7 +1522,7 @@ impl BackendStorage for CudaStorage {
                 let inp = &inp.slice(inp_l.start_offset()..);
                 let k = &k.slice(kernel_l.start_offset()..);
                 let mut out = unsafe { device.alloc::<u8>(dst_el) }.w()?;
-                crate::cudnn::launch_conv2d::<u8>(inp, inp_l, k, &mut out, params, &device)
+                crate::cudnn::launch_conv2d::<u8, u8>(inp, inp_l, k, &mut out, params, &device)
                     .map_err(crate::Error::wrap)?;
                 S::U8(out)
             }
@@ -1530,7 +1530,10 @@ impl BackendStorage for CudaStorage {
                 let inp = &inp.slice(inp_l.start_offset()..);
                 let k = &k.slice(kernel_l.start_offset()..);
                 let mut out = unsafe { device.alloc::<bf16>(dst_el) }.w()?;
-                crate::cudnn::launch_conv2d::<bf16>(inp, inp_l, k, &mut out, params, &device)
+                // Only PSEUDO_BFLOAT16_CONFIG is supported in cudnn, there is no "true bfloat16"
+                // version.
+                // https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-cnn-library.html#id88
+                crate::cudnn::launch_conv2d::<bf16, f32>(inp, inp_l, k, &mut out, params, &device)
                     .map_err(crate::Error::wrap)?;
                 S::BF16(out)
             }
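(Not part of the commit.) As the new comment notes, cudnn only offers a pseudo-bf16 convolution configuration, bf16 data with f32 compute, so the bf16 arm now instantiates `launch_conv2d::<bf16, f32>` instead of requesting a bf16 compute type that cudnn does not provide. A minimal sketch that exercises this path through candle's public API, assuming a CUDA build with the `cudnn` feature enabled:

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // Runs the cudnn-backed conv2d on bf16 inputs, the case fixed by this commit.
    let dev = Device::new_cuda(0)?;
    let inp = Tensor::randn(0f32, 1f32, (1, 3, 16, 16), &dev)?.to_dtype(DType::BF16)?;
    let kernel = Tensor::randn(0f32, 1f32, (4, 3, 3, 3), &dev)?.to_dtype(DType::BF16)?;
    // conv2d(kernel, padding, stride, dilation, groups)
    let out = inp.conv2d(&kernel, 1, 1, 1, 1)?;
    println!("{:?}", out.shape()); // expected shape: [1, 4, 16, 16]
    Ok(())
}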
@@ -1538,7 +1541,7 @@ impl BackendStorage for CudaStorage {
                 let inp = &inp.slice(inp_l.start_offset()..);
                 let k = &k.slice(kernel_l.start_offset()..);
                 let mut out = unsafe { device.alloc::<f16>(dst_el) }.w()?;
-                crate::cudnn::launch_conv2d::<f16>(inp, inp_l, k, &mut out, params, &device)
+                crate::cudnn::launch_conv2d::<f16, f16>(inp, inp_l, k, &mut out, params, &device)
                     .map_err(crate::Error::wrap)?;
                 S::F16(out)
             }
@@ -1546,7 +1549,7 @@ impl BackendStorage for CudaStorage {
                 let inp = &inp.slice(inp_l.start_offset()..);
                 let k = &k.slice(kernel_l.start_offset()..);
                 let mut out = unsafe { device.alloc::<f32>(dst_el) }.w()?;
-                crate::cudnn::launch_conv2d::<f32>(inp, inp_l, k, &mut out, params, &device)
+                crate::cudnn::launch_conv2d::<f32, f32>(inp, inp_l, k, &mut out, params, &device)
                     .map_err(crate::Error::wrap)?;
                 S::F32(out)
             }
@@ -1554,7 +1557,7 @@ impl BackendStorage for CudaStorage {
                 let inp = &inp.slice(inp_l.start_offset()..);
                 let k = &k.slice(kernel_l.start_offset()..);
                 let mut out = unsafe { device.alloc::<f64>(dst_el) }.w()?;
-                crate::cudnn::launch_conv2d::<f64>(inp, inp_l, k, &mut out, params, &device)
+                crate::cudnn::launch_conv2d::<f64, f64>(inp, inp_l, k, &mut out, params, &device)
                     .map_err(crate::Error::wrap)?;
                 S::F64(out)
             }