mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Support strides in affine.
This commit is contained in:
@ -1,27 +1,42 @@
|
||||
#include "cuda_utils.cuh"
|
||||
|
||||
extern "C" __global__ void affine_f32(
|
||||
const size_t numel,
|
||||
const size_t num_dims,
|
||||
const size_t *info,
|
||||
const float *x,
|
||||
float *y,
|
||||
const float mul,
|
||||
const float add
|
||||
) {
|
||||
const size_t *dims = info;
|
||||
const size_t *strides = info + num_dims;
|
||||
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i >= numel) {
|
||||
return;
|
||||
}
|
||||
y[i] = x[i] * mul + add;
|
||||
// This is likely to be very very slow, we should either optimize the contiguous case
|
||||
// as a separate kernel, proceed by block, improve the stride computations (and probably
|
||||
// do all of these).
|
||||
unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
|
||||
y[strided_i] = x[i] * mul + add;
|
||||
}
|
||||
|
||||
extern "C" __global__ void affine_f64(
|
||||
const size_t numel,
|
||||
const size_t num_dims,
|
||||
const size_t *info,
|
||||
const double *x,
|
||||
double *y,
|
||||
const double mul,
|
||||
const double add
|
||||
) {
|
||||
const size_t *dims = info;
|
||||
const size_t *strides = info + num_dims;
|
||||
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i >= numel) {
|
||||
return;
|
||||
}
|
||||
y[i] = x[i] * mul + add;
|
||||
unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
|
||||
y[strided_i] = x[i] * mul + add;
|
||||
}
|
||||
|
@ -135,19 +135,17 @@ impl CudaStorage {
|
||||
mul: f64,
|
||||
add: f64,
|
||||
) -> Result<Self> {
|
||||
if !shape.is_contiguous(stride) {
|
||||
return Err(CudaError::RequiresContiguous { op: "affine" });
|
||||
}
|
||||
|
||||
let elem_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||
let dims = shape.dims();
|
||||
let el_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||
let dev = self.device();
|
||||
let ds = dev.0.htod_copy([dims, stride].concat())?;
|
||||
match self {
|
||||
Self::F32(arg) => {
|
||||
let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.0.alloc::<f32>(elem_count) }?;
|
||||
let params = (elem_count, arg, &out, mul as f32, add as f32);
|
||||
let out = unsafe { dev.0.alloc::<f32>(el_count) }?;
|
||||
let params = (el_count, dims.len(), &ds, arg, &out, mul as f32, add as f32);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
Ok(Self::F32(out))
|
||||
@ -155,8 +153,8 @@ impl CudaStorage {
|
||||
Self::F64(arg) => {
|
||||
let func = dev.get_or_load_func("affine_f64", kernels::AFFINE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.0.alloc::<f64>(elem_count) }?;
|
||||
let params = (elem_count, arg, &out, mul, add);
|
||||
let out = unsafe { dev.0.alloc::<f64>(el_count) }?;
|
||||
let params = (el_count, dims.len(), &ds, arg, &out, mul, add);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
Ok(Self::F64(out))
|
||||
@ -209,13 +207,12 @@ impl CudaStorage {
|
||||
let dims = shape.dims();
|
||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||
let dev = self.device();
|
||||
let dims_and_strides = [dims, lhs_stride, rhs_stride].concat();
|
||||
let dims_and_strides = dev.0.htod_copy([dims, lhs_stride, rhs_stride].concat())?;
|
||||
match (self, rhs) {
|
||||
(Self::F32(lhs), Self::F32(rhs)) => {
|
||||
let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.0.alloc::<f32>(elem_count) }?;
|
||||
let dims_and_strides = dev.0.htod_copy(dims_and_strides)?;
|
||||
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
|
||||
// SAFETY: ffi
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
@ -225,7 +222,6 @@ impl CudaStorage {
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let func = dev.get_or_load_func(B::KERNEL_F64, kernels::BINARY)?;
|
||||
let out = unsafe { dev.0.alloc::<f64>(elem_count) }?;
|
||||
let dims_and_strides = dev.0.htod_copy(dims_and_strides)?;
|
||||
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
|
||||
// SAFETY: ffi
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
|
Reference in New Issue
Block a user