mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Add a cuda kernel for upsampling. (#441)
* Add a cuda kernel for upsampling. * Update for the latest tokenizers version.
This commit is contained in:
@ -48,7 +48,7 @@ safetensors = "0.3.1"
|
||||
serde = { version = "1.0.171", features = ["derive"] }
|
||||
serde_json = "1.0.99"
|
||||
thiserror = "1"
|
||||
tokenizers = { version = "0.13.3", default-features = false }
|
||||
tokenizers = { version = "0.13.4", default-features = false }
|
||||
tracing = "0.1.37"
|
||||
tracing-chrome = "0.7.1"
|
||||
tracing-subscriber = "0.3.7"
|
||||
|
@ -980,7 +980,7 @@ impl Map1 for Pool2D {
|
||||
dev: &CudaDevice,
|
||||
inp_l: &Layout,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
// Kernel shape: (c_out, c_in_k, w_k, h_k)
|
||||
// Input shape: (b_size, c, h, w)
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let shape = inp_l.shape();
|
||||
let dims = shape.dims();
|
||||
@ -1018,6 +1018,39 @@ impl Map1 for Pool2D {
|
||||
}
|
||||
}
|
||||
|
||||
struct UpsampleNearest2D(usize, usize);
|
||||
impl Map1 for UpsampleNearest2D {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
inp: &CudaSlice<T>,
|
||||
dev: &CudaDevice,
|
||||
inp_l: &Layout,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
// Input shape: (b_size, c, h, w)
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let shape = inp_l.shape();
|
||||
let dims = shape.dims();
|
||||
let ds = if dims.len() == 4 {
|
||||
[dims, inp_l.stride()].concat()
|
||||
} else {
|
||||
panic!("unexpected input shape for conv1d {dims:?}")
|
||||
};
|
||||
let (out_w, out_h) = (self.0, self.1);
|
||||
let dst_el = out_w * out_h * dims[0] * dims[1];
|
||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("upsample_nearest2d"), kernels::CONV)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||
let ds = dev.htod_copy(ds).w()?;
|
||||
let scale_w = dims[2] as f64 / out_w as f64;
|
||||
let scale_h = dims[3] as f64 / out_h as f64;
|
||||
let params = (out_w, out_h, scale_w, scale_h, &ds, inp, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
struct WhereCond<'a>(&'a CudaStorage, &'a Layout);
|
||||
impl<'a> Map2 for WhereCond<'a> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
@ -1513,8 +1546,10 @@ impl BackendStorage for CudaStorage {
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
|
||||
todo!()
|
||||
fn upsample_nearest2d(&self, l: &Layout, out_w: usize, out_h: usize) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let slice = UpsampleNearest2D(out_w, out_h).map(&self.slice, &device, l)?;
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
|
@ -56,9 +56,8 @@ fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn upsample_nearest2d() -> Result<()> {
|
||||
let t = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((1, 1, 2, 3))?;
|
||||
fn upsample_nearest2d(dev: &Device) -> Result<()> {
|
||||
let t = Tensor::arange(0f32, 6f32, dev)?.reshape((1, 1, 2, 3))?;
|
||||
let upsampled = t.upsample_nearest2d(4, 6)?.i(0)?.i(0)?;
|
||||
assert_eq!(
|
||||
t.i(0)?.i(0)?.to_vec2::<f32>()?,
|
||||
@ -83,3 +82,8 @@ test_device!(
|
||||
avg_pool2d_pytorch_gpu
|
||||
);
|
||||
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu);
|
||||
test_device!(
|
||||
upsample_nearest2d,
|
||||
upsample_nearest2d_cpu,
|
||||
upsample_nearest2d_gpu
|
||||
);
|
||||
|
@ -111,7 +111,10 @@ fn main() -> Result<()> {
|
||||
let device = &model.device;
|
||||
|
||||
if let Some(prompt) = args.prompt {
|
||||
let tokenizer = tokenizer.with_padding(None).with_truncation(None);
|
||||
let tokenizer = tokenizer
|
||||
.with_padding(None)
|
||||
.with_truncation(None)
|
||||
.map_err(E::msg)?;
|
||||
let tokens = tokenizer
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
|
@ -65,10 +65,7 @@ impl TextGeneration {
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
new_tokens.push(next_token);
|
||||
let token = self
|
||||
.tokenizer
|
||||
.decode(vec![next_token], true)
|
||||
.map_err(E::msg)?;
|
||||
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
|
||||
print!("{token}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
|
@ -72,16 +72,14 @@ impl TextGeneration {
|
||||
"{} token: {} '{}'",
|
||||
index + 1,
|
||||
next_token,
|
||||
self.tokenizer
|
||||
.decode(vec![next_token], true)
|
||||
.map_err(E::msg)?
|
||||
self.tokenizer.decode(&[next_token], true).map_err(E::msg)?
|
||||
);
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
println!(
|
||||
"{sample_len} tokens generated ({} token/s)\n----\n{}\n----",
|
||||
sample_len as f64 / dt.as_secs_f64(),
|
||||
self.tokenizer.decode(new_tokens, true).map_err(E::msg)?
|
||||
self.tokenizer.decode(&new_tokens, true).map_err(E::msg)?
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
@ -223,7 +223,7 @@ fn main() -> Result<()> {
|
||||
"{} token: {} '{}'",
|
||||
index + 1,
|
||||
next_token,
|
||||
tokenizer.decode(vec![next_token], true).map_err(E::msg)?
|
||||
tokenizer.decode(&[next_token], true).map_err(E::msg)?
|
||||
);
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
@ -231,7 +231,7 @@ fn main() -> Result<()> {
|
||||
"{} tokens generated ({} token/s)\n----\n{}\n----",
|
||||
args.sample_len,
|
||||
args.sample_len as f64 / dt.as_secs_f64(),
|
||||
tokenizer.decode(new_tokens, true).map_err(E::msg)?
|
||||
tokenizer.decode(&new_tokens, true).map_err(E::msg)?
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
@ -169,10 +169,7 @@ impl Decoder {
|
||||
}
|
||||
sum_logprob += prob.ln();
|
||||
}
|
||||
let text = self
|
||||
.tokenizer
|
||||
.decode(tokens.clone(), true)
|
||||
.map_err(E::msg)?;
|
||||
let text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
|
||||
let avg_logprob = sum_logprob / tokens.len() as f64;
|
||||
|
||||
Ok(DecodingResult {
|
||||
|
@ -220,6 +220,48 @@ __device__ void max_pool2d(
|
||||
dst[dst_i] = d;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void upsample_nearest2d(
|
||||
const size_t w_out,
|
||||
const size_t h_out,
|
||||
const double w_scale,
|
||||
const double h_scale,
|
||||
const size_t *info,
|
||||
const T *src,
|
||||
T *dst
|
||||
) {
|
||||
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
// src: (b_size, c_in, w_in, h_in)
|
||||
const size_t *src_dims = info;
|
||||
const size_t *src_s = info + 4;
|
||||
|
||||
const size_t c = src_dims[1];
|
||||
const size_t w_in = src_dims[2];
|
||||
const size_t h_in = src_dims[3];
|
||||
|
||||
if (dst_i >= src_dims[0] * c * w_out * h_out) {
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: Improve this.
|
||||
const size_t b_idx = dst_i / (w_out * h_out * c);
|
||||
const size_t c_idx = (dst_i / (w_out * h_out)) % c;
|
||||
const size_t dst_w = (dst_i / h_out) % w_out;
|
||||
const size_t dst_h = dst_i % h_out;
|
||||
|
||||
size_t src_w = static_cast<size_t>(dst_w * w_scale);
|
||||
size_t src_h = static_cast<size_t>(dst_h * h_scale);
|
||||
if (src_w >= w_in) {
|
||||
src_w = w_in - 1;
|
||||
}
|
||||
if (src_h >= h_in) {
|
||||
src_h = h_in - 1;
|
||||
}
|
||||
|
||||
const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_w * src_s[2] + src_h * src_s[3];
|
||||
dst[dst_i] = src[src_i];
|
||||
}
|
||||
|
||||
|
||||
#define CONV1D_OP(TYPENAME, TYPEACC, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
@ -278,11 +320,25 @@ extern "C" __global__ void FN_NAME( \
|
||||
max_pool2d<TYPENAME>(src_numel, w_k, h_k, w_stride, h_stride, info, src, dst); \
|
||||
} \
|
||||
|
||||
#define UPSAMPLE_NEAREST2D_OP(TYPENAME, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t w_out, \
|
||||
const size_t h_out, \
|
||||
const double w_scale, \
|
||||
const double h_scale, \
|
||||
const size_t *info, \
|
||||
const TYPENAME *src, \
|
||||
TYPENAME *dst \
|
||||
) { \
|
||||
upsample_nearest2d<TYPENAME>(w_out, h_out, w_scale, h_scale, info, src, dst); \
|
||||
} \
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
CONV1D_OP(__nv_bfloat16, float, conv1d_bf16)
|
||||
CONV2D_OP(__nv_bfloat16, float, conv2d_bf16)
|
||||
AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16)
|
||||
MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
|
||||
UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
|
||||
#endif
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
@ -290,6 +346,7 @@ CONV1D_OP(__half, float, conv1d_f16)
|
||||
CONV2D_OP(__half, float, conv2d_f16)
|
||||
AVG_POOL2D_OP(__half, float, avg_pool2d_f16)
|
||||
MAX_POOL2D_OP(__half, max_pool2d_f16)
|
||||
UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
|
||||
#endif
|
||||
|
||||
CONV1D_OP(float, float, conv1d_f32)
|
||||
@ -311,3 +368,8 @@ MAX_POOL2D_OP(float, max_pool2d_f32)
|
||||
MAX_POOL2D_OP(double, max_pool2d_f64)
|
||||
MAX_POOL2D_OP(uint8_t, max_pool2d_u8)
|
||||
MAX_POOL2D_OP(uint32_t, max_pool2d_u32)
|
||||
|
||||
UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32)
|
||||
UPSAMPLE_NEAREST2D_OP(double, upsample_nearest2d_f64)
|
||||
UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8)
|
||||
UPSAMPLE_NEAREST2D_OP(uint32_t, upsample_nearest2d_u32)
|
||||
|
@ -159,10 +159,7 @@ impl Decoder {
|
||||
}
|
||||
sum_logprob += prob.ln();
|
||||
}
|
||||
let text = self
|
||||
.tokenizer
|
||||
.decode(tokens.clone(), true)
|
||||
.map_err(E::msg)?;
|
||||
let text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
|
||||
let avg_logprob = sum_logprob / tokens.len() as f64;
|
||||
|
||||
Ok(DecodingResult {
|
||||
|
Reference in New Issue
Block a user