From e7f8e72588b963843546fa8a18ca5db9707a8637 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 25 Mar 2024 09:11:20 +0100 Subject: [PATCH] Contiguous variant of the rope kernel. (#1929) * Contiguous variant of the rope kernel. * Add the cuda kernel. * Metal kernel. --- candle-kernels/src/reduce.cu | 40 +++- candle-metal-kernels/src/lib.rs | 43 +++++ candle-metal-kernels/src/reduce.metal | 35 +++- candle-nn/src/rotary_emb.rs | 252 ++++++++++++++++++++++++++ candle-nn/tests/ops.rs | 32 +++- 5 files changed, 389 insertions(+), 13 deletions(-) diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu index e6002a6f..48bbcd83 100644 --- a/candle-kernels/src/reduce.cu +++ b/candle-kernels/src/reduce.cu @@ -160,6 +160,24 @@ __device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, cons dst[2 * idx + 1] = src[2 * idx] * s + src[2 * idx + 1] * c; } +template +__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (2 * idx > bh * td) return; + + uint32_t i_bh = idx / (td / 2); + uint32_t i_td = idx - (td / 2) * i_bh; + uint32_t i_t = i_td / (d / 2); + uint32_t i_d = i_td - (d / 2) * i_t; + uint32_t i1 = i_bh * td + i_t * d + i_d; + uint32_t i2 = i1 + d / 2; + uint32_t i_cs = i_t * (d / 2) + i_d; + T c = cos[i_cs]; + T s = sin[i_cs]; + + dst[i1] = src[i1] * c - src[i2] * s; + dst[i2] = src[i1] * s + src[i2] * c; +} template __device__ void @@ -416,8 +434,8 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block, rmsnorm(src, dst, alpha, n_cols, eps); \ } \ -#define ROPEI_OP(TYPENAME, FN_NAME) \ - extern "C" __global__ void FN_NAME( \ +#define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I) \ + extern "C" __global__ void FN_NAME_I( \ const TYPENAME *src, \ const TYPENAME *cos, \ const TYPENAME *sin, \ @@ -426,11 +444,21 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block, const uint32_t td) { \ ropei(src, cos, sin, dst, bh, td); \ } \ + extern "C" __global__ void FN_NAME( \ + const TYPENAME *src, \ + const TYPENAME *cos, \ + const TYPENAME *sin, \ + TYPENAME *dst, \ + const uint32_t bh, \ + const uint32_t td, \ + const uint32_t d) { \ + rope(src, cos, sin, dst, bh, td, d); \ + } \ #if __CUDA_ARCH__ >= 800 SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16) RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16) -ROPEI_OP(__nv_bfloat16, rope_i_bf16) +ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16) SUM_OP(__nv_bfloat16, sum_bf16) FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16) #endif @@ -438,7 +466,7 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm #if __CUDA_ARCH__ >= 530 SOFTMAX_OP(__half, float, softmax_f16) RMSNORM_OP(__half, rmsnorm_f16) -ROPEI_OP(__half, rope_i_f16) +ROPE_OP(__half, rope_f16, rope_i_f16) SUM_OP(__half, sum_f16) FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16) #endif @@ -450,8 +478,8 @@ SOFTMAX_OP(float, float, softmax_f32) SOFTMAX_OP(double, double, softmax_f64) RMSNORM_OP(float, rmsnorm_f32) RMSNORM_OP(double, rmsnorm_f64) -ROPEI_OP(float, rope_i_f32) -ROPEI_OP(double, rope_i_f64) +ROPE_OP(float, rope_f32, rope_i_f32) +ROPE_OP(double, rope_f64, rope_i_f64) FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32) FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index e83814a8..449bef8f 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -849,6 +849,49 @@ pub fn call_rope_i( Ok(()) } +#[allow(clippy::too_many_arguments)] +pub fn call_rope( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: &'static str, + bh: usize, + td: usize, + d: usize, + src: &Buffer, + src_offset: usize, + cos: &Buffer, + cos_offset: usize, + sin: &Buffer, + sin_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + bh, + td, + d, + (src, src_offset), + (cos, cos_offset), + (sin, sin_offset), + output + ) + ); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, (bh * td) / 2); + encoder.use_resource(src, metal::MTLResourceUsage::Read); + encoder.use_resource(cos, metal::MTLResourceUsage::Read); + encoder.use_resource(sin, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + #[allow(clippy::too_many_arguments)] pub fn call_affine( device: &Device, diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index fa980dea..be5a0921 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -313,8 +313,8 @@ kernel void NAME( } \ } \ -#define ROPEI(FN_NAME, TYPENAME) \ -kernel void FN_NAME( \ +#define ROPEI(FN_NAME, FN_NAME_I, TYPENAME) \ +kernel void FN_NAME_I( \ constant size_t &bh, \ constant size_t &td, \ device const TYPENAME *src, \ @@ -332,6 +332,31 @@ kernel void FN_NAME( \ dst[2 * tid] = src[2 * tid] * c - src[2 * tid + 1] * s; \ dst[2 * tid + 1] = src[2 * tid] * s + src[2 * tid + 1] * c; \ }\ +kernel void FN_NAME( \ + constant size_t &bh, \ + constant size_t &td, \ + constant size_t &d, \ + device const TYPENAME *src, \ + device const TYPENAME *cos, \ + device const TYPENAME *sin, \ + device TYPENAME *dst, \ + uint idx [[ thread_position_in_grid ]] \ +) { \ + if (2 * idx >= bh * td) { \ + return; \ + } \ + size_t i_bh = idx / (td / 2); \ + size_t i_td = idx - (td / 2) * i_bh; \ + size_t i_t = i_td / (d / 2); \ + size_t i_d = i_td - (d / 2) * i_t; \ + size_t i1 = i_bh * td + i_t * d + i_d; \ + size_t i2 = i1 + d / 2; \ + size_t i_cs = i_t * (d / 2) + i_d; \ + TYPENAME c = cos[i_cs]; \ + TYPENAME s = sin[i_cs]; \ + dst[i1] = src[i1] * c - src[i2] * s; \ + dst[i2] = src[i1] * s + src[i2] * c; \ +}\ REDUCE(x + y, fast_sum_f32_strided, float, 0) REDUCE(x + y, fast_sum_u32_strided, uint, 0) @@ -361,8 +386,8 @@ SOFTMAX(softmax_f32, float) SOFTMAX(softmax_f16, half) RMSNORM(rmsnorm_f32, float) RMSNORM(rmsnorm_f16, half) -ROPEI(rope_i_f32, float) -ROPEI(rope_i_f16, half) +ROPEI(rope_f32, rope_i_f32, float) +ROPEI(rope_f16, rope_i_f16, half) #if __METAL_VERSION__ >= 220 REDUCE(x + y, fast_sum_i64_strided, int64_t, 0) @@ -381,5 +406,5 @@ ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF) ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF) SOFTMAX(softmax_bf16, bfloat) RMSNORM(rmsnorm_bf16, bfloat) -ROPEI(rope_i_bf16, bfloat) +ROPEI(rope_bf16, rope_i_bf16, bfloat) #endif diff --git a/candle-nn/src/rotary_emb.rs b/candle-nn/src/rotary_emb.rs index 20545b8d..9c5543fb 100644 --- a/candle-nn/src/rotary_emb.rs +++ b/candle-nn/src/rotary_emb.rs @@ -245,3 +245,255 @@ pub fn rope_i_slow(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result { let rope = rope.flatten_from(D::Minus2)?; Ok(rope) } + +/// Contiguous variant of rope embeddings. +#[derive(Debug, Clone)] +struct RotaryEmb; + +impl candle::CustomOp3 for RotaryEmb { + fn name(&self) -> &'static str { + "rotary-emb" + } + + fn cpu_fwd( + &self, + s1: &CpuStorage, + l1: &Layout, + s2: &CpuStorage, + l2: &Layout, + s3: &CpuStorage, + l3: &Layout, + ) -> Result<(CpuStorage, Shape)> { + fn inner( + src: &[T], + l_src: &Layout, + cos: &[T], + l_cos: &Layout, + sin: &[T], + l_sin: &Layout, + ) -> Result<(CpuStorage, Shape)> { + let src = match l_src.contiguous_offsets() { + None => candle::bail!("input src has to be contiguous"), + Some((o1, o2)) => &src[o1..o2], + }; + let cos = match l_cos.contiguous_offsets() { + None => candle::bail!("input cos has to be contiguous"), + Some((o1, o2)) => &cos[o1..o2], + }; + let sin = match l_sin.contiguous_offsets() { + None => candle::bail!("input sin has to be contiguous"), + Some((o1, o2)) => &sin[o1..o2], + }; + let (b, h, t, d) = l_src.shape().dims4()?; + let el_count = b * h * t * d; + let mut dst = vec![T::zero(); el_count]; + src.par_chunks(t * d) + .zip(dst.par_chunks_mut(t * d)) + .for_each(|(src, dst)| { + for i_t in 0..t { + for i_d in 0..d / 2 { + let i1 = i_t * d + i_d; + let i2 = i1 + d / 2; + let i_cs = i_t * (d / 2) + i_d; + dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs]; + dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs]; + } + } + }); + let storage = candle::WithDType::to_cpu_storage_owned(dst); + Ok((storage, (b, h, t, d).into())) + } + + use candle::backend::BackendStorage; + use CpuStorage::{BF16, F16, F32, F64}; + match (s1, s2, s3) { + (BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3), + (F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3), + (F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3), + (F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3), + _ => candle::bail!( + "unsupported dtype for rope {:?} {:?} {:?}", + s1.dtype(), + s2.dtype(), + s3.dtype() + ), + } + } + + #[cfg(feature = "cuda")] + fn cuda_fwd( + &self, + s1: &candle::CudaStorage, + l1: &Layout, + s2: &candle::CudaStorage, + l2: &Layout, + s3: &candle::CudaStorage, + l3: &Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + use candle::cuda_backend::cudarc::driver::{ + CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, + }; + use candle::cuda_backend::{kernel_name, kernels, WrapErr}; + use candle::{CudaDevice, WithDType}; + + fn inner( + src: &CudaSlice, + l_src: &Layout, + cos: &CudaSlice, + l_cos: &Layout, + sin: &CudaSlice, + l_sin: &Layout, + dev: &CudaDevice, + ) -> Result> { + let src = match l_src.contiguous_offsets() { + None => candle::bail!("src input has to be contiguous"), + Some((o1, o2)) => src.slice(o1..o2), + }; + let cos = match l_cos.contiguous_offsets() { + None => candle::bail!("cos input has to be contiguous"), + Some((o1, o2)) => cos.slice(o1..o2), + }; + let sin = match l_sin.contiguous_offsets() { + None => candle::bail!("sin input has to be contiguous"), + Some((o1, o2)) => sin.slice(o1..o2), + }; + let (b, h, t, d) = l_src.shape().dims4()?; + let el = b * h * t * d; + let cfg = LaunchConfig::for_num_elems((el / 2) as u32); + let func = dev.get_or_load_func(&kernel_name::("rope"), kernels::REDUCE)?; + // SAFETY: Set later by running the kernel. + let dst = unsafe { dev.alloc::(el) }.w()?; + let params = ( + &src, + &cos, + &sin, + &dst, + (b * h) as u32, + (t * d) as u32, + d as u32, + ); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(dst) + } + + use candle::backend::BackendStorage; + use candle::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64}; + let dev = s1.device(); + let slice = match (&s1.slice, &s2.slice, &s3.slice) { + (BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?), + (F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?), + (F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?), + (F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?), + _ => candle::bail!( + "unsupported dtype for rope {:?} {:?} {:?}", + s1.dtype(), + s2.dtype(), + s3.dtype() + ), + }; + let dst = candle::cuda_backend::CudaStorage { + slice, + device: dev.clone(), + }; + Ok((dst, l1.shape().clone())) + } + + #[cfg(feature = "metal")] + fn metal_fwd( + &self, + src: &candle::MetalStorage, + l_src: &Layout, + cos: &candle::MetalStorage, + l_cos: &Layout, + sin: &candle::MetalStorage, + l_sin: &Layout, + ) -> Result<(candle::MetalStorage, Shape)> { + use candle::backend::BackendStorage; + let device = src.device(); + let command_buffer = device.command_buffer()?; + let kernels = device.kernels(); + if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() { + candle::bail!( + "dtype mismatch in rope {:?} {:?} {:?}", + src.dtype(), + cos.dtype(), + sin.dtype() + ) + } + let name = match src.dtype() { + candle::DType::F32 => "rope_f32", + candle::DType::F16 => "rope_f16", + candle::DType::BF16 => "rope_bf16", + dtype => candle::bail!("rope is not implemented for {dtype:?}"), + }; + let (b, h, t, d) = l_src.shape().dims4()?; + let el = b * h * t * d; + let output = device.new_buffer(el, src.dtype(), "rope-i")?; + candle_metal_kernels::call_rope( + device.metal_device(), + &command_buffer, + kernels, + name, + b * h, + t * d, + d, + src.buffer(), + l_src.start_offset() * src.dtype().size_in_bytes(), + cos.buffer(), + l_cos.start_offset() * cos.dtype().size_in_bytes(), + sin.buffer(), + l_sin.start_offset() * sin.dtype().size_in_bytes(), + &output, + ) + .map_err(candle::Error::wrap)?; + let out = candle::MetalStorage::new(output, device.clone(), el, src.dtype()); + Ok((out, l_src.shape().clone())) + } +} + +pub fn rope(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result { + let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?; + let (cos_seq_len, cos_n_embd) = cos.dims2()?; + let (sin_seq_len, sin_n_embd) = cos.dims2()?; + if cos_n_embd * 2 != n_embd + || sin_n_embd * 2 != n_embd + || seq_len > cos_seq_len + || seq_len > sin_seq_len + { + candle::bail!( + "inconsistent last dim size in rope {:?} {:?} {:?}", + xs.shape(), + cos.shape(), + sin.shape() + ) + } + if !xs.is_contiguous() { + candle::bail!("xs has to be contiguous in rope") + } + if !cos.is_contiguous() { + candle::bail!("cos has to be contiguous in rope") + } + if !sin.is_contiguous() { + candle::bail!("sin has to be contiguous in rope") + } + xs.apply_op3_no_bwd(cos, sin, &RotaryEmb) +} + +fn rotate_half(xs: &Tensor) -> Result { + let last_dim = xs.dim(D::Minus1)?; + let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?; + let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?; + Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1) +} + +pub fn rope_slow(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result { + let (_b_sz, _h, seq_len, _n_embd) = x.dims4()?; + let cos = Tensor::cat(&[cos, cos], D::Minus1)?; + let sin = Tensor::cat(&[sin, sin], D::Minus1)?; + let cos = cos.narrow(0, 0, seq_len)?; + let sin = sin.narrow(0, 0, seq_len)?; + let cos = cos.unsqueeze(0)?.unsqueeze(0)?; + let sin = sin.unsqueeze(0)?.unsqueeze(0)?; + x.broadcast_mul(&cos)? + rotate_half(x)?.broadcast_mul(&sin)? +} diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs index af883b85..20a66e75 100644 --- a/candle-nn/tests/ops.rs +++ b/candle-nn/tests/ops.rs @@ -86,7 +86,7 @@ fn softmax_numerical_stability() -> Result<()> { Ok(()) } -fn rope(device: &Device) -> Result<()> { +fn ropei(device: &Device) -> Result<()> { use rand::{rngs::StdRng, Rng, SeedableRng}; let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16); @@ -107,12 +107,40 @@ fn rope(device: &Device) -> Result<()> { let sum_diff = (rope1 - rope2)?.abs()?.sum_all()?.to_vec0::()?; if device.is_cpu() { assert_eq!(sum_diff, 0.); - } else if device.is_cuda() { + } else { assert!(sum_diff < 1e-4); } Ok(()) } +fn rope(device: &Device) -> Result<()> { + use rand::{rngs::StdRng, Rng, SeedableRng}; + + let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16); + let el_count = b_size * num_head * seq_len * head_dim; + let mut rng = StdRng::seed_from_u64(299792458); + let src: Vec = (0..el_count).map(|_| rng.gen::()).collect(); + let cos: Vec = (0..seq_len * head_dim / 2) + .map(|_| rng.gen::()) + .collect(); + let sin: Vec = (0..seq_len * head_dim / 2) + .map(|_| rng.gen::()) + .collect(); + let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?; + let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?; + let sin = Tensor::from_vec(sin, (seq_len, head_dim / 2), device)?; + let rope1 = candle_nn::rotary_emb::rope(&src, &cos, &sin)?; + let rope2 = candle_nn::rotary_emb::rope_slow(&src, &cos, &sin)?; + let sum_diff = (rope1 - rope2)?.abs()?.sum_all()?.to_vec0::()?; + if device.is_cpu() { + assert_eq!(sum_diff, 0.); + } else { + assert!(sum_diff < 1e-4); + } + Ok(()) +} + +test_device!(ropei, ropei_cpu, ropei_gpu, ropei_metal); test_device!(rope, rope_cpu, rope_gpu, rope_metal); test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal); test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);