diff --git a/.github/workflows/ci_cuda.yaml b/.github/workflows/ci_cuda.yaml
index ec792a25..02814ed7 100644
--- a/.github/workflows/ci_cuda.yaml
+++ b/.github/workflows/ci_cuda.yaml
@@ -8,6 +8,8 @@ jobs:
   start-runner:
     name: Start self-hosted EC2 runner
     runs-on: ubuntu-latest
+    # Don't run on forks, they won't have access to secrets anyway.
+    if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
     env:
       AWS_REGION: us-east-1
       EC2_AMI_ID: ami-03cfed9ea28f4b002
@@ -70,7 +72,7 @@ jobs:
     runs-on: ubuntu-latest
     env:
       AWS_REGION: us-east-1
-    if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs
+    if: ${{ (success() || failure()) && github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }} # required to stop the runner even if the error happened in the previous jobs
     steps:
       - name: Configure AWS credentials
         uses: aws-actions/configure-aws-credentials@v1
diff --git a/README.md b/README.md
index a03367a5..93cbccc4 100644
--- a/README.md
+++ b/README.md
@@ -158,6 +158,7 @@ And then head over to
 - [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
 - [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
 - [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
+- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.

 If you have an addition to this list, please submit a pull request.

diff --git a/candle-book/src/lib.rs b/candle-book/src/lib.rs
index a1ec1e94..f8ca510d 100644
--- a/candle-book/src/lib.rs
+++ b/candle-book/src/lib.rs
@@ -28,6 +28,7 @@ let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap()
 #[rustfmt::skip]
 #[test]
 fn book_hub_2() {
+    {
 // ANCHOR: book_hub_2
 use candle::Device;
 use hf_hub::api::sync::Api;
@@ -45,9 +46,10 @@ let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap()
 assert_eq!(weights.len(), 206);
 }

-    #[rustfmt::skip]
-    #[test]
-    fn book_hub_3() {
+    // #[rustfmt::skip]
+    // #[test]
+    // fn book_hub_3() {
+    {
 // ANCHOR: book_hub_3
 use candle::{DType, Device, Tensor};
 use hf_hub::api::sync::Api;
@@ -102,6 +104,7 @@ let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).un
 assert_eq!(view.shape(), &[768, 768]);
 assert_eq!(tp_tensor.dims(), &[192, 768]);
 }
+}

 #[rustfmt::skip]
 #[test]
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index 6bd12589..1b279999 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -46,7 +46,7 @@ accelerate = ["dep:libc", "dep:accelerate-src"]
 metal = ["dep:metal", "dep:candle-metal-kernels"]

 [[bench]]
-name = "matmul"
+name = "bench_main"
 harness = false

 [[bench]]
diff --git a/candle-core/benches/bench_main.rs b/candle-core/benches/bench_main.rs
new file mode 100644
index 00000000..4425f2fb
--- /dev/null
+++ b/candle-core/benches/bench_main.rs
@@ -0,0 +1,4 @@
+mod benchmarks;
+
+use criterion::criterion_main;
+criterion_main!(benchmarks::matmul::benches);
diff --git a/candle-core/benches/matmul.rs b/candle-core/benches/benchmarks/matmul.rs
similarity index 63%
rename from candle-core/benches/matmul.rs
rename to candle-core/benches/benchmarks/matmul.rs
index 83679771..fb173f04 100644
--- a/candle-core/benches/matmul.rs
+++ b/candle-core/benches/benchmarks/matmul.rs
@@ -1,5 +1,6 @@
-use candle_core::{DType, Device, Tensor};
-use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput};
+use crate::benchmarks::{bench_name, device, BenchDevice};
+use candle_core::{DType, Tensor};
+use criterion::{black_box, criterion_group, Criterion, Throughput};
 use std::time::Instant;

 fn run(a: &Tensor, b: &Tensor) {
@@ -12,14 +13,14 @@ fn criterion_benchmark(c: &mut Criterion) {
     let n = 2048;
     let k = 2048;

-    let device = Device::new_metal(0).unwrap();
+    let device = device().unwrap();
     let dtype = DType::F32;
     let lhs = Tensor::zeros((b, m, k), dtype, &device).unwrap();
     let rhs = Tensor::zeros((b, n, k), dtype, &device).unwrap();

     let flops = b * m * n * k;

-    let mut group = c.benchmark_group("matmul_metal");
+    let mut group = c.benchmark_group(bench_name("matmul"));
     group.throughput(Throughput::Bytes(flops as u64));
     group.bench_function("iter", move |b| {
         b.iter_custom(|iters| {
@@ -27,11 +28,7 @@ fn criterion_benchmark(c: &mut Criterion) {
             for _i in 0..iters {
                 run(black_box(&lhs), black_box(&rhs));
             }
-            if let Device::Metal(device) = &device {
-                device.wait_until_completed().unwrap();
-            } else {
-                panic!("Expected metal device");
-            }
+            device.sync().unwrap();
             start.elapsed()
         })
     });
@@ -39,4 +36,3 @@ fn criterion_benchmark(c: &mut Criterion) {
 }

 criterion_group!(benches, criterion_benchmark);
-criterion_main!(benches);
diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs
new file mode 100644
index 00000000..1344770d
--- /dev/null
+++ b/candle-core/benches/benchmarks/mod.rs
@@ -0,0 +1,55 @@
+pub(crate) mod matmul;
+
+use candle_core::{Device, Result};
+
+pub(crate) trait BenchDevice
{ + fn sync(&self) -> Result<()>; +} + +impl BenchDevice for Device { + fn sync(&self) -> Result<()> { + match self { + Device::Cpu => Ok(()), + Device::Cuda(device) => { + #[cfg(feature = "cuda")] + return Ok(device.synchronize()?); + #[cfg(not(feature = "cuda"))] + panic!("Cuda device without cuda feature enabled: {:?}", device) + } + Device::Metal(device) => { + #[cfg(feature = "metal")] + return Ok(device.wait_until_completed()?); + #[cfg(not(feature = "metal"))] + panic!("Metal device without metal feature enabled: {:?}", device) + } + } + } +} + +pub(crate) fn device() -> Result { + if cfg!(feature = "metal") { + Device::new_metal(0) + } else if cfg!(feature = "cuda") { + Device::new_cuda(0) + } else { + Ok(Device::Cpu) + } +} + +pub(crate) fn bench_name>(name: S) -> String { + format!("{}_{}", device_variant(), name.into()) +} + +const fn device_variant() -> &'static str { + if cfg!(feature = "metal") { + "metal" + } else if cfg!(feature = "cuda") { + "cuda" + } else if cfg!(feature = "accelerate") { + "accelerate" + } else if cfg!(feature = "mkl") { + "mkl" + } else { + "cpu" + } +} diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 3f6060ce..a4c36366 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -354,7 +354,7 @@ impl BackendStorage for MetalStorage { let name = match self.dtype { DType::F32 => "affine_f32", DType::F16 => "affine_f16", - dtype => crate::bail!("Affine {dtype:?}"), + dtype => crate::bail!("Metal contiguous affine {dtype:?} not implemented"), }; candle_metal_kernels::call_affine( &device.device, @@ -372,7 +372,7 @@ impl BackendStorage for MetalStorage { let name = match self.dtype { DType::F32 => "affine_f32_strided", DType::F16 => "affine_f16_strided", - dtype => crate::bail!("Affine {dtype:?}"), + dtype => crate::bail!("Metal strided affine {dtype:?} not implemented"), }; candle_metal_kernels::call_affine_strided( &device.device, @@ -405,7 +405,7 @@ impl BackendStorage for MetalStorage { let name = match self.dtype { DType::F32 => "powf_f32", DType::F16 => "powf_f16", - dtype => crate::bail!("Powf {dtype:?}"), + dtype => crate::bail!("Metal contiguous powf {dtype:?} not implemented"), }; candle_metal_kernels::call_powf( &device.device, @@ -422,7 +422,7 @@ impl BackendStorage for MetalStorage { let name = match self.dtype { DType::F32 => "powf_f32_strided", DType::F16 => "powf_f16_strided", - dtype => crate::bail!("Powf {dtype:?}"), + dtype => crate::bail!("Metal strided powf {dtype:?} not implemented"), }; candle_metal_kernels::call_powf_strided( &device.device, @@ -454,7 +454,7 @@ impl BackendStorage for MetalStorage { let name = match self.dtype { DType::F32 => "elu_f32", DType::F16 => "elu_f16", - dtype => crate::bail!("Powf {dtype:?}"), + dtype => crate::bail!("Metal contiguous elu {dtype:?} not implemented"), }; candle_metal_kernels::call_elu( &device.device, @@ -471,7 +471,7 @@ impl BackendStorage for MetalStorage { let name = match self.dtype { DType::F32 => "elu_f32_strided", DType::F16 => "elu_f16_strided", - dtype => crate::bail!("Powf {dtype:?}"), + dtype => crate::bail!("Metal strided elu {dtype:?} not implemented"), }; candle_metal_kernels::call_elu_strided( &device.device, @@ -533,7 +533,17 @@ impl BackendStorage for MetalStorage { (ReduceOp::Max, DType::BF16) => ("fast_max_bf16_strided", true, false), (ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16_strided", true, true), (ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16_strided", true, true), - (k, dtype) => 
crate::bail!("Reduce op for non float {k:?} {dtype:?}"), + (ReduceOp::Sum, DType::I64) => ("fast_sum_i64_strided", false, false), + (ReduceOp::Min, DType::I64) => ("fast_min_i64_strided", true, false), + (ReduceOp::Max, DType::I64) => ("fast_max_i64_strided", true, false), + (ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64_strided", true, true), + (ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64_strided", true, true), + (ReduceOp::Sum, DType::U8) => ("fast_sum_u8_strided", false, false), + (ReduceOp::Min, DType::U8) => ("fast_min_u8_strided", true, false), + (ReduceOp::Max, DType::U8) => ("fast_max_u8_strided", true, false), + (ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8_strided", true, true), + (ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8_strided", true, true), + (k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"), }; if check_empty && layout.shape().elem_count() == 0 { Err(crate::Error::EmptyTensor { op: "reduce" }.bt())? @@ -580,11 +590,18 @@ impl BackendStorage for MetalStorage { let kernel_name = match (self.dtype, dtype) { (DType::U32, DType::F32) => "cast_u32_f32", (DType::U32, DType::U8) => "cast_u32_u8", + (DType::U32, DType::I64) => "cast_u32_i64", (DType::U8, DType::U32) => "cast_u8_u32", (DType::U8, DType::F32) => "cast_u8_f32", + (DType::U8, DType::I64) => "cast_u8_i64", (DType::F32, DType::F16) => "cast_f32_f16", (DType::F16, DType::F32) => "cast_f16_f32", - (left, right) => crate::bail!("to dtype {left:?} - {right:?}"), + (DType::I64, DType::F32) => "cast_i64_f32", + (DType::F32, DType::BF16) => "cast_f32_bf16", + (DType::BF16, DType::F32) => "cast_bf16_f32", + (left, right) => { + crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented") + } }; candle_metal_kernels::call_cast_contiguous( &device.device, @@ -601,11 +618,18 @@ impl BackendStorage for MetalStorage { let kernel_name = match (self.dtype, dtype) { (DType::U32, DType::F32) => "cast_u32_f32_strided", (DType::U32, DType::U8) => "cast_u32_u8_strided", + (DType::U32, DType::I64) => "cast_u32_i64_strided", (DType::U8, DType::U32) => "cast_u8_u32_strided", (DType::U8, DType::F32) => "cast_u8_f32_strided", + (DType::U8, DType::I64) => "cast_u8_i64_strided", (DType::F32, DType::F16) => "cast_f32_f16_strided", (DType::F16, DType::F32) => "cast_f16_f32_strided", - (left, right) => crate::bail!("to dtype {left:?} - {right:?}"), + (DType::I64, DType::F32) => "cast_i64_f32_strided", + (DType::F32, DType::BF16) => "cast_f32_bf16_strided", + (DType::BF16, DType::F32) => "cast_bf16_f32_strided", + (left, right) => { + crate::bail!("Metal strided to_dtype {left:?} {right:?} not implemented") + } }; candle_metal_kernels::call_cast_strided( &device.device, @@ -646,9 +670,11 @@ impl BackendStorage for MetalStorage { ("ugelu", DType::F32) => contiguous::gelu::FLOAT, ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT, ("uerf", DType::F32) => contiguous::erf::FLOAT, + ("uabs", DType::F32) => contiguous::abs::FLOAT, ("uceil", DType::F32) => contiguous::ceil::FLOAT, ("ufloor", DType::F32) => contiguous::floor::FLOAT, ("uround", DType::F32) => contiguous::round::FLOAT, + ("urecip", DType::F32) => contiguous::recip::FLOAT, ("utanh", DType::F32) => contiguous::tanh::FLOAT, ("ucos", DType::F16) => contiguous::cos::HALF, ("usin", DType::F16) => contiguous::sin::HALF, @@ -660,11 +686,15 @@ impl BackendStorage for MetalStorage { ("ugelu", DType::F16) => contiguous::gelu::HALF, ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF, ("uerf", DType::F16) => contiguous::erf::HALF, 
+ ("uabs", DType::F16) => contiguous::abs::HALF, ("uceil", DType::F16) => contiguous::ceil::HALF, ("ufloor", DType::F16) => contiguous::floor::HALF, ("uround", DType::F16) => contiguous::round::HALF, + ("urecip", DType::F16) => contiguous::recip::HALF, ("utanh", DType::F16) => contiguous::tanh::HALF, - (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), + (name, dtype) => { + crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented") + } }; candle_metal_kernels::call_unary_contiguous( &device.device, @@ -689,6 +719,7 @@ impl BackendStorage for MetalStorage { ("ugelu", DType::F32) => strided::gelu::FLOAT, ("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT, ("uerf", DType::F32) => strided::erf::FLOAT, + ("uabs", DType::F32) => strided::abs::FLOAT, ("uceil", DType::F32) => strided::ceil::FLOAT, ("ufloor", DType::F32) => strided::floor::FLOAT, ("uround", DType::F32) => strided::round::FLOAT, @@ -702,10 +733,13 @@ impl BackendStorage for MetalStorage { ("ugelu", DType::F16) => strided::gelu::HALF, ("ugelu_erf", DType::F16) => strided::gelu_erf::HALF, ("uerf", DType::F16) => strided::erf::HALF, + ("uabs", DType::F16) => strided::abs::HALF, ("uceil", DType::F16) => strided::ceil::HALF, ("ufloor", DType::F16) => strided::floor::HALF, ("uround", DType::F16) => strided::round::HALF, - (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), + (name, dtype) => { + crate::bail!("Metal strided unary {name} {dtype:?} not implemented") + } }; candle_metal_kernels::call_unary_strided( &device.device, @@ -758,7 +792,10 @@ impl BackendStorage for MetalStorage { let name = match (self.dtype, t.dtype()) { (DType::U8, DType::F32) => "where_u8_f32", (DType::U8, DType::F16) => "where_u8_f16", - (left, right) => crate::bail!("where {left:?} - {right:?} not implemented"), + (DType::U8, DType::I64) => "where_u8_i64", + (DType::U8, DType::U32) => "where_u8_u32", + (DType::U8, DType::U8) => "where_u8_u8", + (left, right) => crate::bail!("Metal where_cond {left:?} {right:?} not implemented"), }; candle_metal_kernels::call_where_cond_strided( &device.device, @@ -805,7 +842,7 @@ impl BackendStorage for MetalStorage { let command_buffer = self.device.command_buffer()?; let name = match self.dtype { DType::F32 => "im2col1d_f32", - dtype => crate::bail!("conv1d metal {dtype:?} not implemented"), + dtype => crate::bail!("Metal conv1d {dtype:?} not implemented"), }; candle_metal_kernels::call_im2col1d_strided( &self.device.device, @@ -858,7 +895,7 @@ impl BackendStorage for MetalStorage { _kernel_l: &Layout, _params: &ParamsConvTranspose1D, ) -> Result { - crate::bail!("conv_transpose1d metal") + crate::bail!("Metal conv_transpose1d not implemented") } fn conv2d( @@ -889,7 +926,7 @@ impl BackendStorage for MetalStorage { let command_buffer = self.device.command_buffer()?; let name = match self.dtype { DType::F32 => "im2col_f32", - dtype => crate::bail!("conv1d metal {dtype:?} not implemented"), + dtype => crate::bail!("Metal conv2d {dtype:?} not implemented"), }; candle_metal_kernels::call_im2col_strided( &self.device.device, @@ -945,19 +982,19 @@ impl BackendStorage for MetalStorage { _kernel_l: &Layout, _params: &ParamsConvTranspose2D, ) -> Result { - crate::bail!("conv_tranpose2d metal") + crate::bail!("Metal conv_tranpose2d not implemented") } fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { - crate::bail!("avg_pool2d metal") + crate::bail!("Metal avg_pool2d not implemented") } fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { - 
crate::bail!("max_pool2d metal") + crate::bail!("Metal max_pool2d not implemented") } fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result { - crate::bail!("upsample_nearest1d metal") + crate::bail!("Metal upsample_nearest1d not implemented") } fn upsample_nearest2d(&self, inp_l: &Layout, out_w: usize, out_h: usize) -> Result { @@ -970,7 +1007,7 @@ impl BackendStorage for MetalStorage { } let name = match self.dtype { DType::F32 => "upsample_nearest2d_f32", - dtype => crate::bail!("Not implemented {dtype:?} for upsample_nearest2d, metal"), + dtype => crate::bail!("Metal upsample_nearest2d {dtype:?} not implemented"), }; let dst_el = out_w * out_h * dims[0] * dims[1]; @@ -1008,7 +1045,7 @@ impl BackendStorage for MetalStorage { let name = match (ids.dtype, self.dtype) { (DType::U32, DType::F32) => "gather_u32_f32", (DType::U32, DType::F16) => "gather_u32_f16", - (left, right) => crate::bail!("gather metal {left:?} {right:?} not implemented"), + (left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"), }; let command_buffer = self.device.command_buffer()?; candle_metal_kernels::call_gather( @@ -1081,7 +1118,7 @@ impl BackendStorage for MetalStorage { && ids_l.is_contiguous() && ids_l.start_offset() == 0) { - crate::bail!("Non contiguous index select not implemented"); + crate::bail!("Metal strided index_select not implemented"); } let left_size: usize = src_l.dims()[..dim].iter().product(); let right_size: usize = src_l.dims()[dim + 1..].iter().product(); @@ -1093,7 +1130,9 @@ impl BackendStorage for MetalStorage { let name = match (ids.dtype, self.dtype) { (DType::U32, DType::F32) => "is_u32_f32", (DType::U32, DType::F16) => "is_u32_f16", - (left, right) => crate::bail!("index select metal {left:?} {right:?}"), + (left, right) => { + crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented") + } }; let command_buffer = self.device.command_buffer()?; candle_metal_kernels::call_index_select( @@ -1134,7 +1173,7 @@ impl BackendStorage for MetalStorage { let name = match (ids.dtype, self.dtype) { (DType::U32, DType::F32) => "ia_u32_f32", _ => Err(MetalError::UnexpectedDType { - msg: "index-add ids should be u8/u32/i64", + msg: "index-add ids should be u32", expected: DType::U32, got: ids.dtype(), })?, @@ -1215,9 +1254,10 @@ impl BackendStorage for MetalStorage { DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT, DType::F16 => candle_metal_kernels::unary::strided::copy::HALF, DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT, + DType::I64 => candle_metal_kernels::unary::strided::copy::I64, DType::U32 => candle_metal_kernels::unary::strided::copy::U32, DType::U8 => candle_metal_kernels::unary::strided::copy::U8, - dtype => crate::bail!("copy_strided not implemented for {dtype:?}"), + dtype => crate::bail!("Metal copy_strided {dtype:?} not implemented"), }; candle_metal_kernels::call_unary_strided( &self.device.device, @@ -1289,7 +1329,39 @@ impl MetalStorage { ("lt", DType::F16) => (contiguous::lt::HALF, DType::U8), ("ge", DType::F16) => (contiguous::ge::HALF, DType::U8), ("gt", DType::F16) => (contiguous::gt::HALF, DType::U8), - (name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"), + ("add", DType::I64) => (contiguous::add::I64, self.dtype), + ("sub", DType::I64) => (contiguous::sub::I64, self.dtype), + ("mul", DType::I64) => (contiguous::mul::I64, self.dtype), + ("div", DType::I64) => (contiguous::div::I64, self.dtype), + ("eq", DType::I64) => (contiguous::eq::I64, DType::U8), + ("ne", 
DType::I64) => (contiguous::ne::I64, DType::U8), + ("le", DType::I64) => (contiguous::le::I64, DType::U8), + ("lt", DType::I64) => (contiguous::lt::I64, DType::U8), + ("ge", DType::I64) => (contiguous::ge::I64, DType::U8), + ("gt", DType::I64) => (contiguous::gt::I64, DType::U8), + ("add", DType::U32) => (contiguous::add::U32, self.dtype), + ("sub", DType::U32) => (contiguous::sub::U32, self.dtype), + ("mul", DType::U32) => (contiguous::mul::U32, self.dtype), + ("div", DType::U32) => (contiguous::div::U32, self.dtype), + ("eq", DType::U32) => (contiguous::eq::U32, DType::U8), + ("ne", DType::U32) => (contiguous::ne::U32, DType::U8), + ("le", DType::U32) => (contiguous::le::U32, DType::U8), + ("lt", DType::U32) => (contiguous::lt::U32, DType::U8), + ("ge", DType::U32) => (contiguous::ge::U32, DType::U8), + ("gt", DType::U32) => (contiguous::gt::U32, DType::U8), + ("add", DType::U8) => (contiguous::add::U8, self.dtype), + ("sub", DType::U8) => (contiguous::sub::U8, self.dtype), + ("mul", DType::U8) => (contiguous::mul::U8, self.dtype), + ("div", DType::U8) => (contiguous::div::U8, self.dtype), + ("eq", DType::U8) => (contiguous::eq::U8, DType::U8), + ("ne", DType::U8) => (contiguous::ne::U8, DType::U8), + ("le", DType::U8) => (contiguous::le::U8, DType::U8), + ("lt", DType::U8) => (contiguous::lt::U8, DType::U8), + ("ge", DType::U8) => (contiguous::ge::U8, DType::U8), + ("gt", DType::U8) => (contiguous::gt::U8, DType::U8), + (name, dtype) => { + crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented") + } }; let buffer = device.new_buffer(el_count, dtype, op)?; candle_metal_kernels::call_binary_contiguous( @@ -1332,7 +1404,45 @@ impl MetalStorage { ("lt", DType::F16) => (strided::lt::HALF, DType::U8), ("ge", DType::F16) => (strided::ge::HALF, DType::U8), ("gt", DType::F16) => (strided::gt::HALF, DType::U8), - (name, dtype) => crate::bail!("Binary strided {name} - {dtype:?} not implemented"), + ("badd", DType::I64) => (strided::add::I64, self.dtype), + ("bsub", DType::I64) => (strided::sub::I64, self.dtype), + ("bmul", DType::I64) => (strided::mul::I64, self.dtype), + ("bdiv", DType::I64) => (strided::div::I64, self.dtype), + ("bminimum", DType::I64) => (strided::min::I64, self.dtype), + ("bmaximum", DType::I64) => (strided::max::I64, self.dtype), + ("eq", DType::I64) => (strided::eq::I64, DType::U8), + ("ne", DType::I64) => (strided::ne::I64, DType::U8), + ("le", DType::I64) => (strided::le::I64, DType::U8), + ("lt", DType::I64) => (strided::lt::I64, DType::U8), + ("ge", DType::I64) => (strided::ge::I64, DType::U8), + ("gt", DType::I64) => (strided::gt::I64, DType::U8), + ("badd", DType::U32) => (strided::add::U32, self.dtype), + ("bsub", DType::U32) => (strided::sub::U32, self.dtype), + ("bmul", DType::U32) => (strided::mul::U32, self.dtype), + ("bdiv", DType::U32) => (strided::div::U32, self.dtype), + ("bminimum", DType::U32) => (strided::min::U32, self.dtype), + ("bmaximum", DType::U32) => (strided::max::U32, self.dtype), + ("eq", DType::U32) => (strided::eq::U32, DType::U8), + ("ne", DType::U32) => (strided::ne::U32, DType::U8), + ("le", DType::U32) => (strided::le::U32, DType::U8), + ("lt", DType::U32) => (strided::lt::U32, DType::U8), + ("ge", DType::U32) => (strided::ge::U32, DType::U8), + ("gt", DType::U32) => (strided::gt::U32, DType::U8), + ("badd", DType::U8) => (strided::add::U8, self.dtype), + ("bsub", DType::U8) => (strided::sub::U8, self.dtype), + ("bmul", DType::U8) => (strided::mul::U8, self.dtype), + ("bdiv", DType::U8) => (strided::div::U8, self.dtype), + 
("bminimum", DType::U8) => (strided::min::U8, self.dtype), + ("bmaximum", DType::U8) => (strided::max::U8, self.dtype), + ("eq", DType::U8) => (strided::eq::U8, DType::U8), + ("ne", DType::U8) => (strided::ne::U8, DType::U8), + ("le", DType::U8) => (strided::le::U8, DType::U8), + ("lt", DType::U8) => (strided::lt::U8, DType::U8), + ("ge", DType::U8) => (strided::ge::U8, DType::U8), + ("gt", DType::U8) => (strided::gt::U8, DType::U8), + (name, dtype) => { + crate::bail!("Metal strided binary {name} {dtype:?} not implemented") + } }; let buffer = device.new_buffer(el_count, dtype, op)?; candle_metal_kernels::call_binary_strided( @@ -1387,7 +1497,7 @@ impl BackendDevice for MetalDevice { } fn set_seed(&self, _seed: u64) -> Result<()> { - crate::bail!("set_seed") + crate::bail!("Metal set_seed not implemented") } fn location(&self) -> crate::DeviceLocation { diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs index c2ed0e25..251c184b 100644 --- a/candle-examples/examples/llama/main.rs +++ b/candle-examples/examples/llama/main.rs @@ -33,6 +33,8 @@ enum Which { V2, #[value(name = "solar-10.7b")] Solar10_7B, + #[value(name = "tiny-llama-1.1b-chat")] + TinyLlama1_1BChat, } #[derive(Parser, Debug)] @@ -124,6 +126,7 @@ fn main() -> Result<()> { Which::V1 => "Narsil/amall-7b".to_string(), Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(), Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(), + Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(), }); println!("loading the model weights from {model_id}"); let revision = args.revision.unwrap_or("main".to_string()); @@ -134,8 +137,12 @@ fn main() -> Result<()> { let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?; let config = config.into_config(args.use_flash_attn); - let filenames = - candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?; + let filenames = match args.which { + Which::V1 | Which::V2 | Which::Solar10_7B => { + candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")? + } + Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?], + }; println!("building the model"); let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?; @@ -158,14 +165,14 @@ fn main() -> Result<()> { let mut index_pos = 0; let mut token_generated = 0; for index in 0..args.sample_len { - let context_size = if cache.use_kv_cache && index > 0 { - 1 + let (context_size, context_index) = if cache.use_kv_cache && index > 0 { + (1, index_pos) } else { - tokens.len() + (tokens.len(), 0) }; let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?; - let logits = llama.forward(&input, index_pos)?; + let logits = llama.forward(&input, context_index)?; let logits = logits.squeeze(0)?; let logits = if args.repeat_penalty == 1. { logits diff --git a/candle-examples/examples/reinforcement-learning/README.md b/candle-examples/examples/reinforcement-learning/README.md index 2d3d14b0..28819067 100644 --- a/candle-examples/examples/reinforcement-learning/README.md +++ b/candle-examples/examples/reinforcement-learning/README.md @@ -8,9 +8,16 @@ Python package with: pip install "gymnasium[accept-rom-license]" ``` -In order to run the example, use the following command. Note the additional +In order to run the examples, use the following commands. Note the additional `--package` flag to ensure that there is no conflict with the `candle-pyo3` crate. 
+ +For the Policy Gradient example: ```bash -cargo run --example reinforcement-learning --features=pyo3 --package candle-examples +cargo run --example reinforcement-learning --features=pyo3 --package candle-examples -- pg +``` + +For the Deep Deterministic Policy Gradient example: +```bash +cargo run --example reinforcement-learning --features=pyo3 --package candle-examples -- ddpg ``` diff --git a/candle-examples/examples/reinforcement-learning/ddpg.rs b/candle-examples/examples/reinforcement-learning/ddpg.rs index c6d72fed..1ce4889e 100644 --- a/candle-examples/examples/reinforcement-learning/ddpg.rs +++ b/candle-examples/examples/reinforcement-learning/ddpg.rs @@ -8,6 +8,8 @@ use candle_nn::{ }; use rand::{distributions::Uniform, thread_rng, Rng}; +use super::gym_env::GymEnv; + pub struct OuNoise { mu: f64, theta: f64, @@ -449,3 +451,106 @@ impl DDPG<'_> { Ok(()) } } + +// The impact of the q value of the next state on the current state's q value. +const GAMMA: f64 = 0.99; +// The weight for updating the target networks. +const TAU: f64 = 0.005; +// The capacity of the replay buffer used for sampling training data. +const REPLAY_BUFFER_CAPACITY: usize = 100_000; +// The training batch size for each training iteration. +const TRAINING_BATCH_SIZE: usize = 100; +// The total number of episodes. +const MAX_EPISODES: usize = 100; +// The maximum length of an episode. +const EPISODE_LENGTH: usize = 200; +// The number of training iterations after one episode finishes. +const TRAINING_ITERATIONS: usize = 200; + +// Ornstein-Uhlenbeck process parameters. +const MU: f64 = 0.0; +const THETA: f64 = 0.15; +const SIGMA: f64 = 0.1; + +const ACTOR_LEARNING_RATE: f64 = 1e-4; +const CRITIC_LEARNING_RATE: f64 = 1e-3; + +pub fn run() -> Result<()> { + let env = GymEnv::new("Pendulum-v1")?; + println!("action space: {}", env.action_space()); + println!("observation space: {:?}", env.observation_space()); + + let size_state = env.observation_space().iter().product::(); + let size_action = env.action_space(); + + let mut agent = DDPG::new( + &Device::Cpu, + size_state, + size_action, + true, + ACTOR_LEARNING_RATE, + CRITIC_LEARNING_RATE, + GAMMA, + TAU, + REPLAY_BUFFER_CAPACITY, + OuNoise::new(MU, THETA, SIGMA, size_action)?, + )?; + + let mut rng = rand::thread_rng(); + + for episode in 0..MAX_EPISODES { + // let mut state = env.reset(episode as u64)?; + let mut state = env.reset(rng.gen::())?; + + let mut total_reward = 0.0; + for _ in 0..EPISODE_LENGTH { + let mut action = 2.0 * agent.actions(&state)?; + action = action.clamp(-2.0, 2.0); + + let step = env.step(vec![action])?; + total_reward += step.reward; + + agent.remember( + &state, + &Tensor::new(vec![action], &Device::Cpu)?, + &Tensor::new(vec![step.reward as f32], &Device::Cpu)?, + &step.state, + step.terminated, + step.truncated, + ); + + if step.terminated || step.truncated { + break; + } + state = step.state; + } + + println!("episode {episode} with total reward of {total_reward}"); + + for _ in 0..TRAINING_ITERATIONS { + agent.train(TRAINING_BATCH_SIZE)?; + } + } + + println!("Testing..."); + agent.train = false; + for episode in 0..10 { + // let mut state = env.reset(episode as u64)?; + let mut state = env.reset(rng.gen::())?; + let mut total_reward = 0.0; + for _ in 0..EPISODE_LENGTH { + let mut action = 2.0 * agent.actions(&state)?; + action = action.clamp(-2.0, 2.0); + + let step = env.step(vec![action])?; + total_reward += step.reward; + + if step.terminated || step.truncated { + break; + } + state = step.state; + } + println!("episode 
{episode} with total reward of {total_reward}"); + } + Ok(()) +} diff --git a/candle-examples/examples/reinforcement-learning/main.rs b/candle-examples/examples/reinforcement-learning/main.rs index 96d7102d..e87afae2 100644 --- a/candle-examples/examples/reinforcement-learning/main.rs +++ b/candle-examples/examples/reinforcement-learning/main.rs @@ -6,139 +6,32 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; +use candle::Result; +use clap::{Parser, Subcommand}; + mod gym_env; mod vec_gym_env; mod ddpg; +mod policy_gradient; -use candle::{Device, Result, Tensor}; -use clap::Parser; -use rand::Rng; - -// The impact of the q value of the next state on the current state's q value. -const GAMMA: f64 = 0.99; -// The weight for updating the target networks. -const TAU: f64 = 0.005; -// The capacity of the replay buffer used for sampling training data. -const REPLAY_BUFFER_CAPACITY: usize = 100_000; -// The training batch size for each training iteration. -const TRAINING_BATCH_SIZE: usize = 100; -// The total number of episodes. -const MAX_EPISODES: usize = 100; -// The maximum length of an episode. -const EPISODE_LENGTH: usize = 200; -// The number of training iterations after one episode finishes. -const TRAINING_ITERATIONS: usize = 200; - -// Ornstein-Uhlenbeck process parameters. -const MU: f64 = 0.0; -const THETA: f64 = 0.15; -const SIGMA: f64 = 0.1; - -const ACTOR_LEARNING_RATE: f64 = 1e-4; -const CRITIC_LEARNING_RATE: f64 = 1e-3; - -#[derive(Parser, Debug, Clone)] -#[command(author, version, about, long_about = None)] +#[derive(Parser)] struct Args { - /// Run on CPU rather than on GPU. - #[arg(long)] - cpu: bool, + #[command(subcommand)] + command: Command, +} - /// Enable tracing (generates a trace-timestamp.json file). 
- #[arg(long)] - tracing: bool, +#[derive(Subcommand)] +enum Command { + Pg, + Ddpg, } fn main() -> Result<()> { - use tracing_chrome::ChromeLayerBuilder; - use tracing_subscriber::prelude::*; - let args = Args::parse(); - - let _guard = if args.tracing { - let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); - tracing_subscriber::registry().with(chrome_layer).init(); - Some(guard) - } else { - None - }; - - let env = gym_env::GymEnv::new("Pendulum-v1")?; - println!("action space: {}", env.action_space()); - println!("observation space: {:?}", env.observation_space()); - - let size_state = env.observation_space().iter().product::(); - let size_action = env.action_space(); - - let mut agent = ddpg::DDPG::new( - &Device::Cpu, - size_state, - size_action, - true, - ACTOR_LEARNING_RATE, - CRITIC_LEARNING_RATE, - GAMMA, - TAU, - REPLAY_BUFFER_CAPACITY, - ddpg::OuNoise::new(MU, THETA, SIGMA, size_action)?, - )?; - - let mut rng = rand::thread_rng(); - - for episode in 0..MAX_EPISODES { - // let mut state = env.reset(episode as u64)?; - let mut state = env.reset(rng.gen::())?; - - let mut total_reward = 0.0; - for _ in 0..EPISODE_LENGTH { - let mut action = 2.0 * agent.actions(&state)?; - action = action.clamp(-2.0, 2.0); - - let step = env.step(vec![action])?; - total_reward += step.reward; - - agent.remember( - &state, - &Tensor::new(vec![action], &Device::Cpu)?, - &Tensor::new(vec![step.reward as f32], &Device::Cpu)?, - &step.state, - step.terminated, - step.truncated, - ); - - if step.terminated || step.truncated { - break; - } - state = step.state; - } - - println!("episode {episode} with total reward of {total_reward}"); - - for _ in 0..TRAINING_ITERATIONS { - agent.train(TRAINING_BATCH_SIZE)?; - } - } - - println!("Testing..."); - agent.train = false; - for episode in 0..10 { - // let mut state = env.reset(episode as u64)?; - let mut state = env.reset(rng.gen::())?; - let mut total_reward = 0.0; - for _ in 0..EPISODE_LENGTH { - let mut action = 2.0 * agent.actions(&state)?; - action = action.clamp(-2.0, 2.0); - - let step = env.step(vec![action])?; - total_reward += step.reward; - - if step.terminated || step.truncated { - break; - } - state = step.state; - } - println!("episode {episode} with total reward of {total_reward}"); + match args.command { + Command::Pg => policy_gradient::run()?, + Command::Ddpg => ddpg::run()?, } Ok(()) } diff --git a/candle-examples/examples/reinforcement-learning/policy_gradient.rs b/candle-examples/examples/reinforcement-learning/policy_gradient.rs new file mode 100644 index 00000000..044cbfcd --- /dev/null +++ b/candle-examples/examples/reinforcement-learning/policy_gradient.rs @@ -0,0 +1,146 @@ +use super::gym_env::{GymEnv, Step}; +use candle::{DType, Device, Error, Module, Result, Tensor}; +use candle_nn::{ + linear, ops::log_softmax, ops::softmax, sequential::seq, Activation, AdamW, Optimizer, + ParamsAdamW, VarBuilder, VarMap, +}; +use rand::{distributions::Distribution, rngs::ThreadRng, Rng}; + +fn new_model( + input_shape: &[usize], + num_actions: usize, + dtype: DType, + device: &Device, +) -> Result<(impl Module, VarMap)> { + let input_size = input_shape.iter().product(); + + let mut varmap = VarMap::new(); + let var_builder = VarBuilder::from_varmap(&varmap, dtype, device); + + let model = seq() + .add(linear(input_size, 32, var_builder.pp("lin1"))?) 
+ .add(Activation::Relu) + .add(linear(32, num_actions, var_builder.pp("lin2"))?); + + Ok((model, varmap)) +} + +fn accumulate_rewards(steps: &[Step]) -> Vec { + let mut rewards: Vec = steps.iter().map(|s| s.reward).collect(); + let mut acc_reward = 0f64; + for (i, reward) in rewards.iter_mut().enumerate().rev() { + if steps[i].terminated { + acc_reward = 0.0; + } + acc_reward += *reward; + *reward = acc_reward; + } + rewards +} + +fn weighted_sample(probs: Vec, rng: &mut ThreadRng) -> Result { + let distribution = rand::distributions::WeightedIndex::new(probs).map_err(Error::wrap)?; + let mut rng = rng; + Ok(distribution.sample(&mut rng)) +} + +pub fn run() -> Result<()> { + let env = GymEnv::new("CartPole-v1")?; + + println!("action space: {:?}", env.action_space()); + println!("observation space: {:?}", env.observation_space()); + + let (model, varmap) = new_model( + env.observation_space(), + env.action_space(), + DType::F32, + &Device::Cpu, + )?; + + let optimizer_params = ParamsAdamW { + lr: 0.01, + weight_decay: 0.01, + ..Default::default() + }; + + let mut optimizer = AdamW::new(varmap.all_vars(), optimizer_params)?; + + let mut rng = rand::thread_rng(); + + for epoch_idx in 0..100 { + let mut state = env.reset(rng.gen::())?; + let mut steps: Vec> = vec![]; + + loop { + let action = { + let action_probs: Vec = + softmax(&model.forward(&state.detach()?.unsqueeze(0)?)?, 1)? + .squeeze(0)? + .to_vec1()?; + weighted_sample(action_probs, &mut rng)? as i64 + }; + + let step = env.step(action)?; + steps.push(step.copy_with_obs(&state)); + + if step.terminated || step.truncated { + state = env.reset(rng.gen::())?; + if steps.len() > 5000 { + break; + } + } else { + state = step.state; + } + } + + let total_reward: f64 = steps.iter().map(|s| s.reward).sum(); + let episodes: i64 = steps + .iter() + .map(|s| (s.terminated || s.truncated) as i64) + .sum(); + println!( + "epoch: {:<3} episodes: {:<5} avg reward per episode: {:.2}", + epoch_idx, + episodes, + total_reward / episodes as f64 + ); + + let batch_size = steps.len(); + + let rewards = Tensor::from_vec(accumulate_rewards(&steps), batch_size, &Device::Cpu)? + .to_dtype(DType::F32)? + .detach()?; + + let actions_mask = { + let actions: Vec = steps.iter().map(|s| s.action).collect(); + let actions_mask: Vec = actions + .iter() + .map(|&action| { + // One-hot encoding + let mut action_mask = vec![0.0; env.action_space()]; + action_mask[action as usize] = 1.0; + + Tensor::from_vec(action_mask, env.action_space(), &Device::Cpu) + .unwrap() + .to_dtype(DType::F32) + .unwrap() + }) + .collect(); + Tensor::stack(&actions_mask, 0)?.detach()? + }; + + let states = { + let states: Vec = steps.into_iter().map(|s| s.state).collect(); + Tensor::stack(&states, 0)?.detach()? + }; + + let log_probs = actions_mask + .mul(&log_softmax(&model.forward(&states)?, 1)?)? 
+ .sum(1)?; + + let loss = rewards.mul(&log_probs)?.neg()?.mean_all()?; + optimizer.backward_step(&loss)?; + } + + Ok(()) +} diff --git a/candle-examples/examples/stable-diffusion/README.md b/candle-examples/examples/stable-diffusion/README.md index feb7ca56..1d98f580 100644 --- a/candle-examples/examples/stable-diffusion/README.md +++ b/candle-examples/examples/stable-diffusion/README.md @@ -29,7 +29,7 @@ e.g.: ```bash cargo run --example stable-diffusion --release --features=cuda,cudnn \ - -- --prompt "a cosmonaut on a horse (hd, realistic, high-def) --sd-version turbo" + -- --prompt "a cosmonaut on a horse (hd, realistic, high-def)" --sd-version turbo ``` The default scheduler for the v1.5, v2.1 and XL 1.0 version is the Denoising diff --git a/candle-examples/examples/yolo-v3/darknet.rs b/candle-examples/examples/yolo-v3/darknet.rs index 6702618e..b1dd3127 100644 --- a/candle-examples/examples/yolo-v3/darknet.rs +++ b/candle-examples/examples/yolo-v3/darknet.rs @@ -147,7 +147,7 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl) let func = candle_nn::func(move |xs| { let xs = conv.forward(xs)?; let xs = match &bn { - Some(bn) => bn.forward(&xs)?, + Some(bn) => xs.apply_t(bn, false)?, None => xs, }; let xs = if leaky { diff --git a/candle-flash-attn/kernels/alibi.h b/candle-flash-attn/kernels/alibi.h new file mode 100644 index 00000000..1afb3687 --- /dev/null +++ b/candle-flash-attn/kernels/alibi.h @@ -0,0 +1,62 @@ +#include + +#include + +#include +#include + +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void apply_alibi(Tensor &tensor, + const int col_idx_offset_, + const int max_seqlen_k, + const int row_idx_offset, + const int max_seqlen_q, + const int warp_row_stride, + const float alibi_slope) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; + } + } + } + } else { // Bias depends on both row_idx and col_idx + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx); + } + } + } + } + } +} + +} // namespace flash diff --git a/candle-flash-attn/kernels/block_info.h b/candle-flash-attn/kernels/block_info.h index 94251a41..65435e51 100644 --- a/candle-flash-attn/kernels/block_info.h +++ b/candle-flash-attn/kernels/block_info.h @@ -14,9 +14,12 @@ struct BlockInfo { template __device__ 
BlockInfo(const Params ¶ms, const int bidb) : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]) - , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb]) + , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]) , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) - , actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k) + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) + , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) { } @@ -32,8 +35,10 @@ struct BlockInfo { const int sum_s_q; const int sum_s_k; - const uint32_t actual_seqlen_q; - const uint32_t actual_seqlen_k; + const int actual_seqlen_q; + // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0. + const int seqlen_k_cache; + const int actual_seqlen_k; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/candle-flash-attn/kernels/flash.h b/candle-flash-attn/kernels/flash.h index be4ae0ca..80b517e9 100644 --- a/candle-flash-attn/kernels/flash.h +++ b/candle-flash-attn/kernels/flash.h @@ -7,15 +7,6 @@ #include #include -// #ifdef OLD_GENERATOR_PATH -// #include -// #else -// #include -// #endif -// -// #include - - constexpr int TOTAL_DIM = 0; constexpr int H_DIM = 1; constexpr int D_DIM = 2; @@ -53,6 +44,7 @@ struct Flash_fwd_params : public Qkv_params { // The O matrix (output). void * __restrict__ o_ptr; + void * __restrict__ oaccum_ptr; // The stride between rows of O. index_t o_batch_stride; @@ -64,9 +56,10 @@ struct Flash_fwd_params : public Qkv_params { // The pointer to the softmax sum. void * __restrict__ softmax_lse_ptr; + void * __restrict__ softmax_lseaccum_ptr; // The dimensions. - int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded; + int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; // The scaling factors for the kernel. float scale_softmax; @@ -76,8 +69,30 @@ struct Flash_fwd_params : public Qkv_params { int * __restrict__ cu_seqlens_q; int * __restrict__ cu_seqlens_k; + // If provided, the actual length of each k sequence. + int * __restrict__ seqused_k; + int *__restrict__ blockmask; + // The K_new and V_new matrices. + void * __restrict__ knew_ptr; + void * __restrict__ vnew_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t knew_batch_stride; + index_t vnew_batch_stride; + index_t knew_row_stride; + index_t vnew_row_stride; + index_t knew_head_stride; + index_t vnew_head_stride; + + // The cos and sin matrices for rotary embedding. + void * __restrict__ rotary_cos_ptr; + void * __restrict__ rotary_sin_ptr; + + // The indices to index into the KV cache. + int *__restrict__ cache_batch_idx; + // The dropout probability (probability of keeping an activation). 
float p_dropout; // uint32_t p_dropout_in_uint; @@ -88,11 +103,22 @@ struct Flash_fwd_params : public Qkv_params { float rp_dropout; float scale_softmax_rp_dropout; - // Random state. - // at::PhiloxCudaState philox_args; + // Local window size + int window_size_left, window_size_right; bool is_bf16; bool is_causal; + + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + bool is_seqlens_k_cumulative; + + bool is_rotary_interleaved; + + int num_splits; // For split-KV version + + void * __restrict__ alibi_slopes_ptr; + index_t alibi_slopes_batch_stride; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -132,10 +158,14 @@ struct Flash_bwd_params : public Flash_fwd_params { // The pointer to the softmax d sum. void *__restrict__ dsoftmax_sum; + + bool deterministic; + index_t dq_accum_split_stride; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure); diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu index 72991257..8113dbc7 100644 --- a/candle-flash-attn/kernels/flash_api.cu +++ b/candle-flash-attn/kernels/flash_api.cu @@ -1,17 +1,15 @@ #include "flash_fwd_launch_template.h" -// void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { -// FWD_HEADDIM_SWITCH(params.d, [&] { -// run_mha_fwd_(params, stream); -// }); -// } - -void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { - FP16_SWITCH(!params.is_bf16, [&] { - FWD_HEADDIM_SWITCH(params.d, [&] { - run_mha_fwd_(params, stream); - }); - }); +void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split_kernel=false) { + FP16_SWITCH(!params.is_bf16, [&] { + FWD_HEADDIM_SWITCH(params.d, [&] { +// if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 + run_mha_fwd_(params, stream); +// } else { +// run_mha_fwd_splitkv_dispatch(params, stream); +// } + }); + }); } extern "C" void run_mha( @@ -20,6 +18,7 @@ extern "C" void run_mha( void *v_ptr, void *o_ptr, void *softmax_lse_ptr, + void *alibi_slopes_ptr, int32_t *cu_seqlens_q_ptr, int32_t *cu_seqlens_k_ptr, @@ -28,6 +27,7 @@ extern "C" void run_mha( uint32_t k_batch_stride, uint32_t v_batch_stride, uint32_t o_batch_stride, + uint32_t alibi_slopes_batch_stride, uint32_t q_row_stride, uint32_t k_row_stride, @@ -51,8 +51,11 @@ extern "C" void run_mha( uint32_t seqlen_q_rounded, uint32_t seqlen_k_rounded, + int is_bf16, int is_causal, - int is_bf16 + + int window_size_left, + int window_size_right ) { Flash_fwd_params params; // Reset the parameters @@ -65,12 +68,14 @@ extern "C" void run_mha( params.o_ptr = o_ptr; params.softmax_lse_ptr = softmax_lse_ptr; + params.alibi_slopes_ptr = alibi_slopes_ptr; // All stride are in elements, not bytes. 
params.q_batch_stride = q_batch_stride; params.k_batch_stride = k_batch_stride; params.v_batch_stride = v_batch_stride; params.o_batch_stride = o_batch_stride; + params.alibi_slopes_batch_stride = alibi_slopes_batch_stride; params.q_row_stride = q_row_stride; params.k_row_stride = k_row_stride; @@ -92,7 +97,6 @@ extern "C" void run_mha( params.seqlen_k_rounded = seqlen_k_rounded; params.d = d; params.d_rounded = d_rounded; - params.is_causal = is_causal; // Set the different scale values. params.scale_softmax = softmax_scale; @@ -106,6 +110,14 @@ extern "C" void run_mha( params.cu_seqlens_q = cu_seqlens_q_ptr; params.cu_seqlens_k = cu_seqlens_k_ptr; params.p_ptr = nullptr; // used for `return_softmax`. + params.seqused_k = nullptr; + + params.is_causal = is_causal; + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + + params.is_seqlens_k_cumulative = true; + params.num_splits = 1; cudaStream_t stream = 0; // Use the default stream. run_mha_fwd(params, stream); diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu index 654400a7..6ffa4126 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu @@ -1,19 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::bfloat16_t; -// if (params.p_dropout == 1.f) { -// run_flash_fwd, false>(params, stream); -// } else { -// run_flash_fwd, true>(params, stream); -// } -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim128(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu index 5b7254a9..19b005ad 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu @@ -1,32 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::half_t; -// if (params.p_dropout == 1.f) { -// // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k -// run_flash_fwd, false>(params, stream); -// // run_flash_fwd, false>(params, stream); -// // run_flash_fwd, false>(params, stream); -// // run_flash_fwd, false>(params, stream); -// run_flash_fwd, false>(params, stream); -// run_flash_fwd, false>(params, stream); -// run_flash_fwd, false>(params, stream); -// // 1st ones are good for H100, A100 -// // 2nd one is good for A6000 bc we get slightly better occupancy -// } else { -// run_flash_fwd, true>(params, stream); -// run_flash_fwd, true>(params, stream); -// run_flash_fwd, true>(params, stream); -// // 1st one is good for H100, A100, A6000 -// } -// } - template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim128(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu index 6a9d60c3..f674f481 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu @@ -1,17 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::bfloat16_t; -// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -// run_flash_fwd, Is_dropout>(params, stream); -// }); -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim160(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu index 6c40a164..afd0a8a3 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu @@ -1,27 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::half_t; -// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -// run_flash_fwd, Is_dropout>(params, stream); -// run_flash_fwd, Is_dropout>(params, stream); -// run_flash_fwd, Is_dropout>(params, stream); -// run_flash_fwd, Is_dropout>(params, stream); -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // For A6000, no-causal, 1st is fastest. causal, 4th is fastest. -// // For A100, H100, 1st is fastest. 
-// }); -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim160(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu index d2f4cba7..aa91bdd6 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu @@ -1,16 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::bfloat16_t; -// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -// run_flash_fwd, Is_dropout>(params, stream); -// }); -// } -template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim192(params, stream); } diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu index 2875c926..37a96526 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu @@ -1,27 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::half_t; -// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -// run_flash_fwd, Is_dropout>(params, stream); -// run_flash_fwd, Is_dropout>(params, stream); -// run_flash_fwd, Is_dropout>(params, stream); -// // This one is slightly faster for causal? -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// }); -// // For A100 H100, 1st is faster with dropout, 3rd is faster without dropout -// // For A6000, 1st is faster when causal, 3rd is faster when not causal -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim192(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu index 982fe7ea..167a0df2 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu @@ -1,9 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" -template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim224(params, stream); } diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu index 4c083f7b..58ffe75c 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu @@ -1,9 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim224(params, stream); } diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu index cb074a95..1b370141 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu @@ -1,9 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim256(params, stream); } diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu index ddf5e132..9f35129c 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu @@ -1,9 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim256(params, stream); } diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu index 81e359e1..770de6fc 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu @@ -1,10 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim32(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu index 91e6331e..8dbf8b94 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu @@ -1,23 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::half_t; -// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -// run_flash_fwd, Is_dropout>(params, stream); -// // For dropout there might be a lot of register spilling? -// // These two are very slow due to register spilling -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // This one is slightly slower -// // run_flash_fwd>(params, stream); -// }); -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim32(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu index fffcbebb..22eac878 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu @@ -1,19 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::bfloat16_t; -// if (params.p_dropout == 1.f) { -// run_flash_fwd, false>(params, stream); -// } else { -// run_flash_fwd, true>(params, stream); -// } -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim64(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu index 01bd1716..e6da5dd2 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu @@ -1,26 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::half_t; -// if (params.p_dropout == 1.f) { -// // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower -// // Using block size (64 x 256) is 27% slower for seqlen=2k -// // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling -// run_flash_fwd, false>(params, stream); -// run_flash_fwd, false>(params, stream); -// run_flash_fwd, false>(params, stream); -// } else { -// run_flash_fwd, true>(params, stream); -// run_flash_fwd, true>(params, stream); -// run_flash_fwd, true>(params, stream); -// } -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim64(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu index b0b27db5..9c003540 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu @@ -1,17 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::bfloat16_t; -// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -// run_flash_fwd, Is_dropout>(params, stream); -// }); -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim96(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu index 820b63cb..8108696a 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu @@ -1,23 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::half_t; -// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -// run_flash_fwd, Is_dropout>(params, stream); -// run_flash_fwd, Is_dropout>(params, stream); -// // This 3rd one is good for H100, and A100, A6000 -// run_flash_fwd, Is_dropout>(params, stream); -// run_flash_fwd, Is_dropout>(params, stream); -// // These two are always slower -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// }); -// } -template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim96(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_kernel.h b/candle-flash-attn/kernels/flash_fwd_kernel.h index 232dea0d..05f5f701 100644 --- a/candle-flash-attn/kernels/flash_fwd_kernel.h +++ b/candle-flash-attn/kernels/flash_fwd_kernel.h @@ -4,20 +4,18 @@ #pragma once -#include #include -#include #include #include #include -#include #include "block_info.h" #include "kernel_traits.h" #include "utils.h" #include "softmax.h" -#include "philox.cuh" + +#include "alibi.h" namespace flash { @@ -25,49 +23,6 @@ using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -CUTE_HOST_DEVICE -auto -make_tiled_copy_A_warpcontiguousM(Copy_Atom const& copy_atom, - TiledMMA const& tiled_mma) { - using TileShape_MNK = typename TiledMMA::TiledShape_MNK; - using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; - constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value; - constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M; - constexpr int MMAStride_M = MMA_M * AtomShape_M; - auto t = make_tile(Layout, Int>, - Stride<_1, Int> >{}, - make_layout(size<2>(TileShape_MNK{}))); - // if (cute::thread0()) {printf("make_tiled_copy_A_warpcontiguousM "); print(t); printf("\n"); } - return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTE_HOST_DEVICE -auto -make_tiled_copy_C_warpcontiguousM(Copy_Atom const& copy_atom, - TiledMMA const& tiled_mma) { - using TileShape_MNK = typename TiledMMA::TiledShape_MNK; - using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; - constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value; - constexpr int kNWarps 
= decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M; - constexpr int MMAStride_M = MMA_M * AtomShape_M; - auto t = make_tile(Layout, Int>, - Stride<_1, Int> >{}, - // TODO: Shouldn't this be size<1>? - make_layout(size<2>(TileShape_MNK{}))); - // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); } - return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - template inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum, Tensor2 &acc_o, float softmax_scale_log2) { @@ -77,7 +32,7 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T flash::reduce_sum(scores, scores_sum); } else { Tensor scores_max_prev = make_fragment_like(scores_max); - copy(scores_max, scores_max_prev); + cute::copy(scores_max, scores_max_prev); flash::template reduce_max(scores, scores_max); // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); @@ -103,23 +58,22 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T template inline __device__ void write_softmax_to_gmem( - Tensor const &tOrP, Tensor &tPgP, TiledCopy gmem_thr_copy_P + Tensor const &tOrP, Tensor &tPgP, TiledCopy gmem_tiled_copy_P ) { // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N) Layout l = tOrP.layout(); Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l)))); CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{}); - // TODO(laurent): reactivate the following - // CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP)); + CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP)); #pragma unroll for (int mi = 0; mi < size<1>(tPrP); ++mi) { - copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0)); + cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0)); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { using Element = typename Kernel_traits::Element; @@ -138,16 +92,65 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi constexpr int kNWarps = Kernel_traits::kNWarps; constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value; - const BlockInfo binfo(params, bidb); - if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); - if (Is_causal) { - n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN)); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { // printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max); // } } + // We exit early and write 0 to gO and gLSE. 
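The [n_block_min, n_block_max) range computed above is what restricts a query row-block to the K/V column blocks it can actually attend to once causal or sliding-window (local) attention is enabled; when the range comes out empty, the early-exit path introduced just below writes zeroed output and INFINITY log-sum-exp instead of touching K or V at all. A host-side mirror of that arithmetic (the kv_blocks_for helper is invented for this sketch):

#include <algorithm>

// Host-side mirror of the kernel's K/V block-range computation (illustrative only).
// A query row attends to keys in [row + seqlen_k - seqlen_q - window_left,
//                                 row + seqlen_k - seqlen_q + window_right],
// so only the K/V blocks intersecting that span for this m_block are visited.
struct BlockRange { int n_block_min; int n_block_max; };

BlockRange kv_blocks_for(int m_block, int kBlockM, int kBlockN,
                         int seqlen_q, int seqlen_k,
                         bool is_causal, bool is_local,
                         int window_left, int window_right) {
    const auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };
    int n_block_min = !is_local
        ? 0
        : std::max(0, (m_block * kBlockM + seqlen_k - seqlen_q - window_left) / kBlockN);
    int n_block_max = ceil_div(seqlen_k, kBlockN);
    if (is_causal || is_local) {
        n_block_max = std::min(n_block_max,
            ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q + window_right, kBlockN));
    }
    return {n_block_min, n_block_max};   // n_block_max <= n_block_min => early exit
}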
This also covers the case where actual_seqlen_k == 0. + // Otherwise we might read OOB elements from gK and gV. + if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) { + // Save seed and offset for backward. If we don't have this here, the 0-th thread block might + // exit early and no one saves the rng state. +// if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) { +// auto seeds = at::cuda::philox::unpack(params.philox_args); +// params.rng_state[0] = std::get<0>(seeds); +// params.rng_state[1] = std::get<1>(seeds); +// params.rng_state[0] = 0; +// params.rng_state[1] = 0; +// } + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_tensor(shape(tOgO)); + clear(tOrO); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); + #pragma unroll + for (int m = 0; m < size<1>(tOgO); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; } + } + return; + } + // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); } // We iterate over the blocks in reverse order. This is because the last block is the only one // that needs masking when we read K and V from global memory. 
Moreover, iterating in reverse @@ -185,8 +188,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); - auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx); - auto gmem_thr_copy_P = typename Kernel_traits::GmemTiledCopyP{}.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P; + auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx); Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); @@ -208,16 +213,18 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Copy Atom retiling // - auto smem_thr_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); - // auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); // if (cute::thread0()) {smem_thr_copy_Q.print_all();} Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");} - auto smem_thr_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); Tensor tSsK = smem_thr_copy_K.partition_S(sK); - auto smem_thr_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma).get_thread_slice(tidx); + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); // TODO: this might need to change if we change the mma instruction in SM70 @@ -268,8 +275,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor tQrQ = make_fragment_like(tQgQ); // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, - binfo.actual_seqlen_q - m_block * kBlockM); + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } // // Copy rmem to smem @@ -285,14 +292,14 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi __syncthreads(); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M - copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view); + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); __syncthreads(); } int n_block = n_block_max - 1; // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. 
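A recurring mechanical change in this hunk is that the TiledCopy object is now kept alongside its per-thread slice instead of being built and sliced in one expression: the thread slice is what partitions and retiles tensors, while cute::copy is handed the tiled copy, which carries the copy atom. A reduced sketch of that idiom (shapes, element type, and the copy_tile helper are invented for illustration):

#include <cute/tensor.hpp>
#include <cute/algorithm/copy.hpp>
using namespace cute;

// Reduced illustration of the tiled-copy idiom used above.
template <class SrcEngine, class SrcLayout, class DstEngine, class DstLayout>
__device__ void copy_tile(Tensor<SrcEngine, SrcLayout> const& gS,  // e.g. a gmem tile
                          Tensor<DstEngine, DstLayout>&       sD,  // e.g. an smem tile
                          int tid) {
    // Build the tiled copy once: 32 threads laid out 8x4, each moving 1x8 values.
    auto tiled_copy = make_tiled_copy(Copy_Atom<DefaultCopy, float>{},
                                      Layout<Shape<_8, _4>>{},   // thread layout
                                      Layout<Shape<_1, _8>>{});  // value layout
    // The thread slice partitions the tensors for this thread...
    auto thr_copy = tiled_copy.get_thread_slice(tid);
    Tensor tSgS = thr_copy.partition_S(gS);
    Tensor tDsD = thr_copy.partition_D(sD);
    // ...while cute::copy dispatches on the tiled copy itself.
    cute::copy(tiled_copy, tSgS, tDsD);
}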
- flash::copy(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, - binfo.actual_seqlen_k - n_block * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); cute::cp_async_fence(); // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); } // __syncthreads(); @@ -302,7 +309,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi __syncthreads(); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M - copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view); + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); } // auto seeds = at::cuda::philox::unpack(params.philox_args); @@ -313,13 +320,19 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi clear(acc_o); + float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; + // For performance reason, we separate out two kinds of iterations: // those that need masking on S, and those that don't. // We need masking on S for the very last block when K and V has length not multiple of kBlockN. // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. // We will have at least 1 "masking" iteration. - constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1; + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = (!Is_causal && !Is_local) + ? 1 + : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); #pragma unroll for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) @@ -330,28 +343,42 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Advance gV if (masking_step > 0) { tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - flash::copy(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); } else { // Clear the smem tiles to account for predicated off loads - flash::copy( - gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + flash::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); } cute::cp_async_fence(); flash::gemm( - acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K ); // if (cute::thread0()) { print(acc_s); } // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - // if (cute::thread0()) { print(scores); } + // if (cute::thread0()) { print_tensor(scores); } // We don't put the masking before the matmul S = Q K^T because we don't clear sK // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul // can produce Inf / NaN. 
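The n_masking_steps expression above is the subtle part of this hunk: when the sequence lengths are not multiples of the block sizes (Is_even_MN is false), the causal/local boundary can fall in the middle of a K block, so one extra iteration has to run with masking enabled. A host-side restatement of the same rule (helper names invented):

// Host-side mirror of the kernel's choice of masked iterations.
constexpr int ceil_div_i(int a, int b) { return (a + b - 1) / b; }

constexpr int n_masking_steps(bool is_causal, bool is_local, bool is_even_mn,
                              int kBlockM, int kBlockN) {
    if (!is_causal && !is_local) return 1;                 // only the ragged last K block
    const int diag_blocks = ceil_div_i(kBlockM, kBlockN);  // K blocks crossed by the diagonal
    return (is_even_mn && is_causal) ? diag_blocks         // block-aligned: diagonal blocks only
                                     : diag_blocks + 1;    // else: plus the partially filled block
}

// e.g. kBlockM == kBlockN == 128: one masked step when everything is block-aligned,
// two when seqlen_k can end mid-block, as the comment in the hunk notes.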
- if (!Is_causal) { - if (!Is_even_N) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } + + if (Has_alibi) { + flash::apply_alibi( + scores, + n_block * kBlockN, + binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, + kNWarps * 16, + alibi_slope + ); + } + + if (!Is_causal && !Is_local) { + if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } } else { // Tensor caccS = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n) // Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N) @@ -364,20 +391,24 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Idk why it's get<1> and not get<0> of the stride. // if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); } // I can't get the stride from idx_row - flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k, - // m_block * kBlockM + get<0>(idx_row(0)), - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - kNWarps * 16); - // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16); - // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16); + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + // m_block * kBlockM + get<0>(idx_row(0)), + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right + // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16 + // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16 + ); + // if (cute::thread0()) { print_tensor(scores); } } flash::cp_async_wait<0>(); __syncthreads(); - if (n_block > 0) { + if (n_block > n_block_min) { // Advance gK tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - flash::copy(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -385,24 +416,24 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // TODO: when we have key_padding_mask we'll need to Check_inf masking_step == 0 - ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) - : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); // Convert scores from fp32 to fp16/bf16 Tensor rP = flash::convert_type(scores); // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. 
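Has_alibi is one of the new template flags: each head reads a slope from alibi_slopes_ptr (pre-divided by scale_softmax because the scores have not been scaled yet at this point), and apply_alibi folds a distance-proportional bias into the scores before masking. A scalar reference of that bias under the usual ALiBi definition (add_alibi_bias is an invented helper, not the kernel's code):

#include <cstdlib>

// Scalar reference for the ALiBi bias folded into the attention scores.
// Query row i is aligned to key column i + seqlen_k - seqlen_q (bottom-right
// alignment), and each head subtracts slope * distance from its scores.
void add_alibi_bias(float* scores, int seqlen_q, int seqlen_k, float slope) {
    for (int i = 0; i < seqlen_q; ++i) {
        const int i_aligned = i + seqlen_k - seqlen_q;
        for (int j = 0; j < seqlen_k; ++j) {
            scores[i * seqlen_k + j] -= slope * std::abs(j - i_aligned);
        }
    }
}
// Under a causal mask only j <= i_aligned survives, so the bias reduces to
// -slope * (i_aligned - j) up to a per-row constant that softmax ignores:
// nearby keys are penalized less than distant ones.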
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); - uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32; - uint32_t block_col_idx = n_block * (kBlockN / 32); + int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + int block_col_idx = n_block * (kBlockN / 32); if (Return_softmax) { Tensor tOrP_copy = make_fragment_like(tOrP); - copy(tOrP, tOrP_copy); + cute::copy(tOrP, tOrP_copy); flash::apply_dropout( tOrP_copy, params.p_dropout_in_uint8_t, seed, offset, block_row_idx, block_col_idx, kNWarps ); - flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); + flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P); tPgP.data() = tPgP.data() + (-kBlockN); } if (Is_dropout) { @@ -411,37 +442,38 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi } // if (cute::thread0()) { print(tOrP); } - flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V); + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); // if (cute::thread0()) { print(scores); } // This check is at the end of the loop since we always have at least 1 iteration - if (n_masking_steps > 1 && n_block <= 0) { + if (n_masking_steps > 1 && n_block <= n_block_min) { --n_block; break; } } // These are the iterations where we don't need masking on S - for (; n_block >= 0; --n_block) { + for (; n_block >= n_block_min; --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) clear(acc_s); flash::cp_async_wait<0>(); __syncthreads(); // Advance gV tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - flash::copy(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); cute::cp_async_fence(); flash::gemm( - acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K ); flash::cp_async_wait<0>(); __syncthreads(); - if (n_block > 0) { + if (n_block > n_block_min) { // Advance gK tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - flash::copy(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. 
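Both loops above keep the cp.async pipeline one tile ahead: issue the next K/V copy, commit it with a fence, run the matmul on data that has already landed, then wait and __syncthreads before touching the freshly copied tile; the fence for the next K tile stays inside the if so that a commit group is only created when a copy was actually issued. A stripped-down sketch of that shape, with copy_async_tile and consume_tile as invented stand-ins for flash::copy and the cute::gemm calls:

#include <cute/arch/copy_sm80.hpp>   // assumed home of cute::cp_async_fence / cp_async_wait

__device__ void copy_async_tile(int /*n_block*/) { /* stand-in for flash::copy */ }
__device__ void consume_tile()                   { /* stand-in for the gemms   */ }

__device__ void pipelined_loop(int n_block_min, int n_block_max) {
    int n_block = n_block_max - 1;
    copy_async_tile(/*K*/ n_block);          // prefetch the first K tile
    cute::cp_async_fence();                  // commit the group just issued

    for (; n_block >= n_block_min; --n_block) {
        cute::cp_async_wait<0>();            // the K tile for this step has landed
        __syncthreads();
        copy_async_tile(/*V*/ n_block);      // start fetching V for this step
        cute::cp_async_fence();

        consume_tile();                      // S = Q * K^T on the K tile in smem

        cute::cp_async_wait<0>();            // now the V tile has landed too
        __syncthreads();
        if (n_block > n_block_min) {
            copy_async_tile(/*K*/ n_block - 1);  // prefetch K for the next step
            cute::cp_async_fence();              // fence only when a copy was issued
        }
        consume_tile();                      // softmax rescale, then O += P * V
    }
}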
cute::cp_async_fence(); @@ -449,22 +481,44 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + if (Has_alibi) { + flash::apply_alibi( + scores, + n_block * kBlockN, + binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, + kNWarps * 16, + alibi_slope + ); + } + + if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right + ); + } + + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); Tensor rP = flash::convert_type(scores); // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); - uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32; - uint32_t block_col_idx = n_block * (kBlockN / 32); + int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + int block_col_idx = n_block * (kBlockN / 32); if (Return_softmax) { Tensor tOrP_copy = make_fragment_like(tOrP); - copy(tOrP, tOrP_copy); + cute::copy(tOrP, tOrP_copy); flash::apply_dropout( tOrP_copy, params.p_dropout_in_uint8_t, seed, offset, block_row_idx, block_col_idx, kNWarps ); - flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); + flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P); tPgP.data() = tPgP.data() + (-kBlockN); } if (Is_dropout) { @@ -472,7 +526,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi block_row_idx, block_col_idx, kNWarps); } - flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V); + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } // Epilogue @@ -496,15 +550,15 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor rO = flash::convert_type(acc_o); Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) // Partition sO to match the accumulator partitioning - auto smem_thr_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx); - // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx); + auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) // sO has the same size as sQ, so we don't need to sync here. 
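softmax_rescale_o is the online-softmax step: each new block of scores can raise the running row maximum, so the partial output accumulator and the running sum of exponentials are rescaled by exp(old_max - new_max) before the new block's probabilities and their P*V contribution are added. A one-row scalar sketch of the same update (online_softmax_step is an invented reference, not the kernel's fused version):

#include <algorithm>
#include <cmath>
#include <vector>

// One-row scalar version of the online softmax update done per block of scores.
// Initialize row_max = -INFINITY, row_sum = 0 and acc_o to zeros before the first call.
void online_softmax_step(std::vector<float>& acc_o, float& row_max, float& row_sum,
                         const std::vector<float>& block_scores,          // S row for this block
                         const std::vector<std::vector<float>>& block_v)  // V rows for this block
{
    float new_max = row_max;
    for (float s : block_scores) new_max = std::max(new_max, s);

    // Rescale everything accumulated so far to the new maximum.
    const float correction = std::exp(row_max - new_max);
    row_sum *= correction;
    for (float& o : acc_o) o *= correction;

    // Add this block's contribution.
    for (std::size_t j = 0; j < block_scores.size(); ++j) {
        const float p = std::exp(block_scores[j] - new_max);
        row_sum += p;
        for (std::size_t d = 0; d < acc_o.size(); ++d) acc_o[d] += p * block_v[j][d];
    }
    row_max = new_max;
    // After the last block, the normalized output row is acc_o / row_sum and the
    // log-sum-exp written out alongside it is row_max + log(row_sum).
}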
if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); } - copy(smem_thr_copy_O, taccOrO, taccOsO); + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; @@ -515,14 +569,15 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), Shape>{}, Stride<_1>{}); - auto gmem_thr_copy_O = typename Kernel_traits::GmemTiledCopyO{}.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) Tensor tOgO = gmem_thr_copy_O.partition_D(gO); __syncthreads(); Tensor tOrO = make_tensor(shape(tOgO)); - copy(gmem_thr_copy_O, tOsO, tOrO); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) @@ -548,14 +603,15 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } } // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_thr_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM ); } + //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. @@ -571,7 +627,7 @@ inline __device__ void compute_attn(const Params ¶ms) { // the attention matrix. This way, as long as we have the batch, head, and the location of // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. - flash::compute_attn_1rowblock(params, bidb, bidh, m_block); + flash::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/candle-flash-attn/kernels/flash_fwd_launch_template.h b/candle-flash-attn/kernels/flash_fwd_launch_template.h index 398ce077..66ab6206 100644 --- a/candle-flash-attn/kernels/flash_fwd_launch_template.h +++ b/candle-flash-attn/kernels/flash_fwd_launch_template.h @@ -4,15 +4,14 @@ #pragma once -// #include - #include "static_switch.h" #include "flash.h" #include "flash_fwd_kernel.h" -template +template __global__ void flash_fwd_kernel(Flash_fwd_params params) { - flash::compute_attn(params); + static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false + flash::compute_attn(params); } template @@ -26,35 +25,39 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; dim3 grid(num_m_block, params.b, params.h); - // We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check - // for cu_seqlens_q as well. 
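The launch geometry ties the epilogue and compute_attn together: every CUDA block owns one kBlockM-row slice of the queries for one (batch, head) pair, so the grid is (ceil(seqlen_q / kBlockM), batch, heads) and the kernel just reads those coordinates back off blockIdx. A sketch of the host-side launch with an invented stand-in kernel (attn_rowblock_kernel) in place of flash_fwd_kernel:

#include <cuda_runtime.h>

// Stand-in kernel: one block per (query row block, batch, head), as in run_flash_fwd.
__global__ void attn_rowblock_kernel() {
    const int m_block = blockIdx.x;  // which kBlockM-row slice of the queries
    const int bidb    = blockIdx.y;  // batch index
    const int bidh    = blockIdx.z;  // head index
    (void)m_block; (void)bidb; (void)bidh;
    // ... compute attention for this (m_block, bidb, bidh) slice ...
}

void launch(int seqlen_q, int batch, int num_heads,
            int kBlockM, int kNThreads, int smem_size, cudaStream_t stream) {
    const int num_m_block = (seqlen_q + kBlockM - 1) / kBlockM;
    dim3 grid(num_m_block, batch, num_heads);
    // Asking for more than 48 KB of dynamic shared memory requires opting in first.
    if (smem_size >= 48 * 1024) {
        cudaFuncSetAttribute(attn_rowblock_kernel,
                             cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    }
    attn_rowblock_kernel<<<grid, kNThreads, smem_size, stream>>>();
}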
- const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0; + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; const bool is_even_K = params.d == Kernel_traits::kHeadDim; const bool return_softmax = params.p_ptr != nullptr; - BOOL_SWITCH(is_even_N, IsEvenNConst, [&] { + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { - // Will only return softmax if dropout, to reduce compilation time. - auto kernel = &flash_fwd_kernel; - // auto kernel = &flash_fwd_kernel; - // if (smem_size >= 48 * 1024) { - // C10_CUDA_CHECK(cudaFuncSetAttribute( - // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - // } - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); - // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); - kernel<<>>(params); - // C10_CUDA_KERNEL_LAUNCH_CHECK(); + BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { + BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { + BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If return_softmax, set IsEvenMNConst to false to reduce number of templates + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_kernel; + // auto kernel = &flash_fwd_kernel; + // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); + // auto kernel = &flash_fwd_kernel; + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + }); + }); }); }); }); } + template void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 32; + constexpr static int Headdim = 32; BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { run_flash_fwd, Is_dropout, Is_causal>(params, stream); @@ -64,7 +67,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 64; + constexpr static int Headdim = 64; BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { if constexpr(!Is_dropout) { @@ -86,7 +89,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 96; + constexpr static int Headdim = 96; // auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0; BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { @@ -112,7 +115,7 @@ void 
run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 128; + constexpr static int Headdim = 128; // auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0; BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { @@ -149,7 +152,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 160; + constexpr static int Headdim = 160; // auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0; BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { @@ -179,7 +182,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 192; + constexpr static int Headdim = 192; BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { if constexpr(!Is_dropout) { @@ -198,7 +201,7 @@ void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 224; + constexpr static int Headdim = 224; int device; cudaGetDevice(&device); int max_smem_per_block; @@ -224,7 +227,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 256; + constexpr static int Headdim = 256; int device; cudaGetDevice(&device); int max_smem_per_sm, max_smem_per_block; diff --git a/candle-flash-attn/kernels/kernel_traits.h b/candle-flash-attn/kernels/kernel_traits.h index 3468e4bf..f000ff24 100644 --- a/candle-flash-attn/kernels/kernel_traits.h +++ b/candle-flash-attn/kernels/kernel_traits.h @@ -91,17 +91,20 @@ struct Flash_fwd_kernel_traits : public Base { SmemLayoutAtomQ{}, Shape, Int>{})); + // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 + using SmemLayoutAtomVtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; using SmemLayoutAtomVtransposed = decltype( - composition(Swizzle{}, - // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 - Layout, Int>, - Stride<_1, Int>>{})); + composition(Swizzle{}, SmemLayoutAtomVtransposedNoSwizzle{})); using SmemLayoutVtransposed = decltype(tile_to_shape( SmemLayoutAtomVtransposed{}, Shape, Int>{})); // Maybe the VtransposeNoSwizzle just needs to have the right shape // And the strides don't matter? 
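Every run_mha_fwd_hdim* launcher above, and run_flash_fwd itself, relies on BOOL_SWITCH to turn runtime flags (dropout, causal, even-MN, ALiBi, local windows) into compile-time template parameters, at the cost of instantiating one kernel per flag combination. A minimal restatement of that macro-plus-lambda pattern, assuming it matches the static_switch.h used here (run_kernel_variant and dispatch are invented):

// Minimal BOOL_SWITCH: evaluate a runtime condition once, then invoke the body
// with a constexpr bool so the body can forward it as a template argument.
#define BOOL_SWITCH(COND, CONST_NAME, ...)        \
    [&] {                                         \
        if (COND) {                               \
            constexpr bool CONST_NAME = true;     \
            return __VA_ARGS__();                 \
        } else {                                  \
            constexpr bool CONST_NAME = false;    \
            return __VA_ARGS__();                 \
        }                                         \
    }()

template <bool Is_dropout, bool Is_causal>
void run_kernel_variant() { /* one instantiation per flag combination */ }

void dispatch(float p_dropout, bool is_causal) {
    BOOL_SWITCH(p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(is_causal, Is_causal, [&] {
            run_kernel_variant<Is_dropout, Is_causal>();
        });
    });
}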
- using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); + using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomVtransposedNoSwizzle{}, + Shape, Int>{})); + // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); using SmemLayoutAtomO = decltype( composition(Swizzle{}, @@ -110,7 +113,8 @@ struct Flash_fwd_kernel_traits : public Base { using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); - using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomOaccum = Copy_Atom; static constexpr int kSmemQCount = size(SmemLayoutQ{}); static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; @@ -138,11 +142,11 @@ struct Flash_fwd_kernel_traits : public Base { DefaultCopy >; using GmemTiledCopyQKV = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; @@ -151,10 +155,30 @@ struct Flash_fwd_kernel_traits : public Base { Stride, _1>>; using GmemTiledCopyP = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom{}, GmemLayoutAtomP{}, Layout>{})); // Val layout, 8 vals per store + using GmemLayoutAtomOaccum = std::conditional_t< + kBlockKSmem == 32, + Layout, // Thread layout, 8 threads per row + Stride< _8, _1>>, + Layout, // Thread layout, 16 threads per row + Stride< _16, _1>> + >; + using GmemTiledCopyOaccum = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + using GmemLayoutAtomRotcossin = GmemLayoutAtom; + using GmemTiledCopyRotcossin = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 4 vals per load + using GmemTiledCopyRotcossinCont = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 8 vals per load }; // Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. @@ -223,16 +247,19 @@ struct Flash_bwd_kernel_traits : public Base { SmemLayoutAtomKV{}, make_shape(Int{}, Int{}))); + using SmemLayoutAtomKtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; using SmemLayoutAtomKtransposed = decltype( - composition(Swizzle{}, - Layout, Int>, - Stride<_1, Int>>{})); + composition(Swizzle{}, SmemLayoutAtomKtransposedNoSwizzle{})); using SmemLayoutKtransposed = decltype(tile_to_shape( SmemLayoutAtomKtransposed{}, make_shape(Int{}, Int{}))); // Maybe the KtransposeNoSwizzle just needs to have the right shape // And the strides don't matter? - using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn()); + using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomKtransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn()); // TODO: generalize to other values of kBlockN // TODO: what should be the Swizzle here? 
3 is faster than 1, and 1 is faster than 2 @@ -250,24 +277,30 @@ struct Flash_bwd_kernel_traits : public Base { using SmemLayoutPdS = decltype(tile_to_shape( SmemLayoutAtomPdS{}, make_shape(Int{}, Int{}))); + using SmemLayoutAtomPdStransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; using SmemLayoutAtomPdStransposed = decltype( - composition(Swizzle{}, - Layout, Int>, - Stride<_1, Int>>{})); + composition(Swizzle{}, SmemLayoutAtomPdStransposedNoSwizzle{})); using SmemLayoutPdStransposed = decltype(tile_to_shape( SmemLayoutAtomPdStransposed{}, make_shape(Int{}, Int{}))); - using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn()); + using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomPdStransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn()); using SmemCopyAtomPdS = Copy_Atom; + using SmemLayoutAtomQdOtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; using SmemLayoutAtomQdOtransposed = decltype( - composition(Swizzle{}, - Layout, Int>, - Stride<_1, Int>>{})); + composition(Swizzle{}, SmemLayoutAtomQdOtransposedNoSwizzle{})); using SmemLayoutQdOtransposed = decltype(tile_to_shape( SmemLayoutAtomQdOtransposed{}, make_shape(Int{}, Int{}))); - using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn()); + using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomQdOtransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn()); using SmemLayoutAtomdKV = decltype( composition(Swizzle{}, @@ -292,13 +325,11 @@ struct Flash_bwd_kernel_traits : public Base { static constexpr int kSmemdSCount = size(SmemLayoutPdS{}); static constexpr int kSmemPCount = size(SmemLayoutPdS{}); static constexpr int kSmemdQCount = size(SmemLayoutdQ{}); - static constexpr int kSmemdPsumCount = kBlockM; static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element); static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element); static constexpr int kSmemPSize = kSmemPCount * sizeof(Element); static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element); - static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum); static constexpr int kSmemSize = kSmemQdOSize + (!Is_V_in_regs ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) diff --git a/candle-flash-attn/kernels/kernel_traits_sm90.h b/candle-flash-attn/kernels/kernel_traits_sm90.h new file mode 100644 index 00000000..e07f3839 --- /dev/null +++ b/candle-flash-attn/kernels/kernel_traits_sm90.h @@ -0,0 +1,159 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include "cute/algorithm/copy.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" +#include + +using namespace cute; + +template +struct Flash_kernel_traits_sm90 { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using Element = elem_type; + static constexpr bool Has_cp_async = true; +#else + using Element = cutlass::half_t; + static constexpr bool Has_cp_async = false; +#endif + + using ElementAccum = float; + using index_t = uint32_t; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using MMA_Atom_Arch = std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom + >; + using ValLayoutMNK = Layout>; +#else + using MMA_Atom_Arch = MMA_Atom; + using ValLayoutMNK = Layout>; +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#else + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#endif +}; + +template > +struct Flash_fwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; + static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + using TiledMma = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using SmemLayoutAtomQ = decltype( + composition(Swizzle{}, + // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 + Layout>, + Stride, _1>>{})); + using SmemLayoutQ = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutKV = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutAtomVtransposed = decltype( + composition(Swizzle{}, + // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 + Layout, Int>, + Stride<_1, Int>>{})); + using SmemLayoutVtransposed = decltype(tile_to_shape( + SmemLayoutAtomVtransposed{}, + Shape, Int>{})); + // Maybe the VtransposeNoSwizzle just needs to have the right shape + // And the strides don't matter? 
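The derived constants in these traits all follow from two constraints: global-memory accesses are 128-bit vectors, and shared memory is tiled into 32- or 64-element pages whose swizzle keeps consecutive rows out of the same banks. A small constexpr mirror of that arithmetic (GmemSmemGeometry is an invented name; uint16_t stands in for the 2-byte half/bfloat16 element):

#include <cstdint>

// Mirror of the derived layout constants in Flash_fwd_kernel_traits.
template <int kHeadDim, int kNThreads, typename Element>
struct GmemSmemGeometry {
    // 128-bit (16-byte) vectorized loads/stores: 8 elements per access for fp16/bf16.
    static constexpr int kGmemElemsPerLoad = int(16 / sizeof(Element));
    static_assert(kHeadDim % kGmemElemsPerLoad == 0,
                  "kHeadDim must be a multiple of the vector width");

    // Shared memory is split into 32- or 64-element pages along the head dimension,
    // and the swizzle width is chosen to spread a page across the banks.
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;

    // Threads per gmem row follow the smem page width rather than the full head dim,
    // which the traits' comments attribute to fewer bank conflicts for d=128.
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0,
                  "kNThreads must be a multiple of kGmemThreadsPerRow");
};

// e.g. 128 threads, 2-byte elements, head dim 128: 8 elems per load, 64-wide pages,
// swizzle 3, and 8 threads cooperating on each gmem row.
static_assert(GmemSmemGeometry<128, 128, std::uint16_t>::kGmemThreadsPerRow == 8, "");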
+ using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); + + using SmemLayoutAtomO = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + + static constexpr int kSmemQCount = size(SmemLayoutQ{}); + static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; + static constexpr int kSmemQSize = kSmemQCount * sizeof(Element); + static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. + // For example, for d=128, smem is split into 2 "pages", each page takes care of columns + // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, + // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, + // to the same banks. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy + >; + using GmemTiledCopyQKV = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopyO = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP"); + using GmemLayoutAtomP = Layout, Int>, + Stride, _1>>; + + using GmemTiledCopyP = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomP{}, + Layout>{})); // Val layout, 8 vals per store + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/candle-flash-attn/kernels/softmax.h b/candle-flash-attn/kernels/softmax.h index 3e9a7b45..09a93f14 100644 --- a/candle-flash-attn/kernels/softmax.h +++ b/candle-flash-attn/kernels/softmax.h @@ -8,8 +8,7 @@ #include -#include -#include +#include #include "philox.cuh" #include "utils.h" @@ -117,15 +116,18 @@ inline __device__ void max_scale_exp2_sum(Tensor &tensor, Tens } template -inline __device__ void apply_mask(Tensor &tensor, const uint32_t max_seqlen_k) { +inline __device__ void apply_mask(Tensor &tensor, const int max_seqlen_k, + const int col_idx_offset_ = 0) { // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) static_assert(Layout::rank == 2, "Only support 2D Tensor"); - const uint32_t lane_id = threadIdx.x % 32; + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; #pragma unroll for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; #pragma unroll for (int j = 0; j < 
size<1, 0>(tensor); ++j) { - const uint32_t col_idx = nj * 8 + j + (lane_id % 4) * 2; + const int col_idx = col_idx_base + j; if (col_idx >= max_seqlen_k) { // Without the "make_coord" we get wrong results #pragma unroll @@ -137,30 +139,30 @@ inline __device__ void apply_mask(Tensor &tensor, const uint32_t } } -template -inline __device__ void apply_mask_causal(Tensor &tensor, const uint32_t col_idx_offset_, - const uint32_t max_seqlen_k, const uint32_t row_idx_offset_, - const uint32_t warp_row_stride) { +template +inline __device__ void apply_mask_local(Tensor &tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset, + const int max_seqlen_q, const int warp_row_stride, + const int window_size_left, const int window_size_right) { // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) static_assert(Layout::rank == 2, "Only support 2D Tensor"); - const uint32_t lane_id = threadIdx.x % 32; - // const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4; - const uint32_t row_idx_offset = row_idx_offset_; - const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; #pragma unroll for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { - const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride; + const int row_idx_base = row_idx_offset + mi * warp_row_stride; #pragma unroll for (int i = 0; i < size<0, 0>(tensor); ++i) { - const uint32_t row_idx = row_idx_base + i * 8; - const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1); + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); #pragma unroll for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const uint32_t col_idx_base = col_idx_offset + nj * 8; + const int col_idx_base = col_idx_offset + nj * 8; #pragma unroll for (int j = 0; j < size<1, 0>(tensor); ++j) { - const uint32_t col_idx = col_idx_base + j; - if (col_idx >= col_idx_limit) { + const int col_idx = col_idx_base + j; + if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; } } @@ -174,10 +176,19 @@ inline __device__ void apply_mask_causal(Tensor &tensor, const u } } +template +inline __device__ void apply_mask_causal(Tensor &tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset, + const int max_seqlen_q, const int warp_row_stride) { + // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 + apply_mask_local(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset, + max_seqlen_q, warp_row_stride, -1, 0); +} + template inline __device__ void apply_mask_causal_w_idx( Tensor &tensor, Tensor const &idx_rowcol, - const uint32_t col_idx_offset_, const uint32_t max_seqlen_k, const uint32_t row_idx_offset_) + const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset) { // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) static_assert(Layout0::rank == 2, "Only support 2D Tensor"); @@ -186,7 +197,7 @@ inline __device__ void apply_mask_causal_w_idx( CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); #pragma unroll for (int mi = 0; mi < size<0>(tensor); ++mi) { - const uint32_t col_idx_limit = std::min(max_seqlen_k, 1 + 
row_idx_offset_ + get<0>(idx_rowcol(mi, 0))); + const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0))); #pragma unroll for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { @@ -204,8 +215,8 @@ inline __device__ void apply_mask_causal_w_idx( template inline __device__ void apply_dropout(Tensor &tensor, uint8_t p_dropout_in_uint8_t, unsigned long long seed, unsigned long long offset, - uint32_t block_row_start, uint32_t block_col_start, - uint32_t block_row_stride) { + int block_row_start, int block_col_start, + int block_row_stride) { // tensor has shape (8, MMA_M, MMA_N / 2) using T = typename Engine::value_type; auto encode_dropout = [](bool keep, T val) { diff --git a/candle-flash-attn/kernels/utils.h b/candle-flash-attn/kernels/utils.h index 2221a2fa..6fb39dc4 100644 --- a/candle-flash-attn/kernels/utils.h +++ b/candle-flash-attn/kernels/utils.h @@ -87,46 +87,6 @@ inline __device__ uint32_t convert_relu2(const float2 x) { //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ float2 half2_unpack(uint32_t a); - -template <> -inline __device__ float2 half2_unpack<__half>(uint32_t a) { - return __half22float2(reinterpret_cast<__half2 (&)>(a)); -} - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -template <> -inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) { - return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a)); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Convert two half2's or bf162's into float, then take their dot product. -template -inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) { - float2 af = flash::half2_unpack(a); - float2 bf = flash::half2_unpack(b); - return af.x * bf.x + af.y * bf.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Converted two vectors of 8 half's or bf16's into float, then take their dot product. -template -inline __device__ float hmulsum8(const uint4 a, const uint4 b) { - float sum; - sum = flash::hfma2_to_float(a.x, b.x); - sum += flash::hfma2_to_float(a.y, b.y); - sum += flash::hfma2_to_float(a.z, b.z); - sum += flash::hfma2_to_float(a.w, b.w); - return sum; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - template struct MaxOp { __device__ inline T operator()(T const & x, T const & y) { return x > y ? 
x : y; } @@ -173,10 +133,12 @@ static __device__ inline T run(T x, Operator &op) { template + typename TiledMma, typename TiledCopyA, typename TiledCopyB, + typename ThrCopyA, typename ThrCopyB> inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, Tensor4 const& tCsB, TiledMma tiled_mma, - TiledCopy0 smem_thr_copy_A, TiledCopy1 smem_thr_copy_B) { + TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, + ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) { CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K @@ -184,13 +146,13 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N - if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } - if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); } + if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } + if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); } #pragma unroll for (int i = 0; i < size<2>(tCrA); ++i) { if (i < size<2>(tCrA) - 1) { - if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } - if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } + if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } + if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } } cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); } @@ -199,19 +161,20 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 //////////////////////////////////////////////////////////////////////////////////////////////////// template + typename TiledMma, typename TiledCopy, typename ThrCopy> inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, - TiledMma tiled_mma, TiledCopy smem_thr_copy_B) { + TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, + ThrCopy smem_thr_copy_B) { CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N - copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); + cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); #pragma unroll for (int i = 0; i < size<2>(tCrA); ++i) { if (i < size<2>(tCrA) - 1) { - copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); } @@ -225,7 +188,10 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { static_assert(decltype(size<0>(acc_layout))::value == 4); static_assert(decltype(rank(acc_layout))::value == 3); auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) - return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + 
// TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting + // "int_tuple.hpp(74): error: conversion to inaccessible base class" + // return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l))); }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -241,9 +207,13 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) { static_assert(mma_shape_K == 8 || mma_shape_K == 16); constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2; auto l = logical_divide(rowcol_layout, Shape>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2))) - return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), - get<0, 1>(l), - get<1, 1, 1>(l)); + // TD [2023-08-13]: Same error as above on Cutlass 3.2 + // return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), + // get<0, 1>(l), + // get<1, 1, 1>(l)); + return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))), + get<1>(get<0>(l)), + get<1>(get<1>(get<1>(l)))); }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -319,9 +289,9 @@ void cp_async_wait() { template -inline __device__ void copy(TiledCopy thr_copy, Tensor const &S, +inline __device__ void copy(TiledCopy tiled_copy, Tensor const &S, Tensor &D, Tensor const &identity_MN, - Tensor const &predicate_K, int max_MN=0) { + Tensor const &predicate_K, const int max_MN=0) { CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA @@ -335,13 +305,13 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor const & #pragma unroll for (int k = 0; k < size<2>(S); ++k) { if (Is_even_K || predicate_K(k)) { - copy(thr_copy, S(_, m, k), D(_, m, k)); + cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); } else if (Clear_OOB_K) { - clear(D(_, m, k)); + cute::clear(D(_, m, k)); } } } else if (Clear_OOB_MN) { - clear(D(_, m, _)); + cute::clear(D(_, m, _)); } } // TD [2023-04-13]: Strange that the code below can cause race condition. 
@@ -350,7 +320,7 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor const & // #pragma unroll // for (int m = 0; m < size<1>(S); ++m) { // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { - // copy(thr_copy, S(_, m, _), D(_, m, _)); + // copy(tiled_copy, S(_, m, _), D(_, m, _)); // } else if (Clear_OOB_MN) { // clear(D(_, m, _)); // } @@ -362,7 +332,7 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor const & // #pragma unroll // for (int m = 0; m < size<1>(S); ++m) { // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { - // copy(thr_copy, S(_, m, k), D(_, m, k)); + // copy(tiled_copy, S(_, m, k), D(_, m, k)); // } else if (Clear_OOB_MN) { // clear(D(_, m, k)); // } diff --git a/candle-flash-attn/src/ffi.rs b/candle-flash-attn/src/ffi.rs index 90f34e43..ca65520b 100644 --- a/candle-flash-attn/src/ffi.rs +++ b/candle-flash-attn/src/ffi.rs @@ -7,6 +7,8 @@ extern "C" { v_ptr: *const c_void, o_ptr: *const c_void, softmax_lse_ptr: *const c_void, + alibi_slopes_ptr: *const c_void, + cu_seqlens_q_ptr: *const i32, cu_seqlens_k_ptr: *const i32, @@ -14,6 +16,7 @@ extern "C" { k_batch_stride: u32, v_batch_stride: u32, o_batch_stride: u32, + alibi_slopes_batch_stride: u32, q_row_stride: u32, k_row_stride: u32, @@ -37,8 +40,11 @@ extern "C" { seqlen_q_rounded: u32, seqlen_k_rounded: u32, - is_causal: c_int, is_bf16: c_int, + is_causal: c_int, + + window_size_left: c_int, + window_size_right: c_int, ); } diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs index 3395bd0d..21a06b5e 100644 --- a/candle-flash-attn/src/lib.rs +++ b/candle-flash-attn/src/lib.rs @@ -3,12 +3,14 @@ mod ffi; use candle::backend::BackendStorage; use candle::cuda_backend::cudarc::driver::DevicePtr; use candle::cuda_backend::WrapErr; -use candle::{CpuStorage, Layout, Result, Shape, Tensor}; +use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor}; use half::{bf16, f16}; pub struct FlashAttn { pub softmax_scale: f32, - pub causal: bool, + pub alibi_slopes: Option, + pub window_size_left: Option, + pub window_size_right: Option, } fn round_multiple(x: usize, m: usize) -> usize { @@ -85,6 +87,51 @@ impl FlashAttn { candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}") } + let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes { + if alibi_slopes.dtype() != DType::F32 { + candle::bail!( + "DType mismatch alibi_slopes {:?}, expected {:?}", + alibi_slopes.dtype(), + DType::F32 + ); + } + + let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout(); + + if num_heads != alibi_slopes_layout.shape().dims1()? 
{ + candle::bail!( + "shape mismatch alibi_slopes {:?}, expected {:?}", + alibi_slopes_layout.shape(), + (num_heads) + ); + } + + let alibi_slopes = match &*alibi_slopes { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("alibi_slopes must be a cuda tensor"), + }; + + let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..); + + *alibi_slopes.device_ptr() as *const core::ffi::c_void + } else { + std::ptr::null() + }; + + // if window_size_left > self.max_seqlen_k or None => -1 + let mut window_size_left = self + .window_size_left + .filter(|v| v <= &seqlen_k) + .map(|v| v as i32) + .unwrap_or(-1); + + // if window_size_right > self.max_seqlen_k or None => -1 + let mut window_size_right = self + .window_size_right + .filter(|v| v <= &seqlen_k) + .map(|v| v as i32) + .unwrap_or(-1); + let head_size = round_multiple(head_size_og, 8); let head_size_rounded = round_multiple(head_size, 32); let seqlen_q_rounded = round_multiple(seqlen_q, 128); @@ -94,9 +141,22 @@ impl FlashAttn { let dst = unsafe { dev.alloc::(elem_count) }.w()?; let softmax_lse = dev.alloc_zeros::(b_sz * num_heads * seqlen_q).w()?; - let causal = if self.causal { 1 } else { 0 }; let is_bf16 = if is_bf16 { 1 } else { 0 }; + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. + let is_causal = if window_size_left < 0 && window_size_right == 0 { + 1 + } else { + 0 + }; + if window_size_left < 0 && window_size_right >= 0 { + window_size_left = seqlen_k as i32; + } + if window_size_left >= 0 && window_size_right < 0 { + window_size_right = seqlen_k as i32; + } + unsafe { let q_ptr = *q.device_ptr() as *const core::ffi::c_void; let k_ptr = *k.device_ptr() as *const core::ffi::c_void; @@ -109,12 +169,14 @@ impl FlashAttn { v_ptr, dst_ptr, softmax_lse_ptr, + /* alibi_slopes_ptr */ alibi_slopes_ptr, /* cu_seqlens_q_ptr */ std::ptr::null(), /* cu_seqlens_k_ptr */ std::ptr::null(), /* q_batch_stride */ q_stride[0] as u32, /* k_batch_stride */ k_stride[0] as u32, /* v_batch_stride */ v_stride[0] as u32, /* o_batch_stride */ o_stride[0] as u32, + /* alibi_slopes_batch_stride */ 0, /* q_row_stride */ q_stride[q_rank - 3] as u32, /* k_row_stride */ k_stride[k_rank - 3] as u32, /* v_row_stride */ v_stride[v_rank - 3] as u32, @@ -133,8 +195,10 @@ impl FlashAttn { /* seqlen_k */ seqlen_k as u32, /* seqlen_q_rounded */ seqlen_q_rounded as u32, /* seqlen_k_rounded */ seqlen_k_rounded as u32, - /* is_causal */ causal, /* is_bf16 */ is_bf16, + /* is_causal */ is_causal, + /* window_size_left */ window_size_left, + /* window_size_right */ window_size_right, ) } @@ -197,20 +261,137 @@ pub fn flash_attn( softmax_scale: f32, causal: bool, ) -> Result { + let window_size_left = None; + let window_size_right = if causal { Some(0) } else { None }; + let op = FlashAttn { softmax_scale, - causal, + alibi_slopes: None, + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +/// Flash-attention v2 layer. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`. 
+/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `window_size_left` - Limit left attention to value tokens. +/// * `window_size_right` - Limit right attention to value tokens. +/// +/// # Causal mask +/// +/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result +/// of `Q @ K^T` +/// +/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`. +pub fn flash_attn_windowed( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + window_size_left: Option, + window_size_right: Option, +) -> Result { + let op = FlashAttn { + softmax_scale, + alibi_slopes: None, + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +/// Flash-attention v2 layer. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`. +/// +/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`. +pub fn flash_attn_alibi( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + softmax_scale: f32, + causal: bool, +) -> Result { + let window_size_left = None; + let window_size_right = if causal { Some(0) } else { None }; + + let op = FlashAttn { + softmax_scale, + alibi_slopes: Some(alibi_slopes.clone()), + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +/// Flash-attention v2 layer. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`. +/// * `window_size_left` - Limit left attention to value tokens. +/// * `window_size_right` - Limit right attention to value tokens. +/// +/// # Causal mask +/// +/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result +/// of `Q @ K^T` +/// +/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`. 
+pub fn flash_attn_alibi_windowed( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + softmax_scale: f32, + window_size_left: Option, + window_size_right: Option, +) -> Result { + let op = FlashAttn { + softmax_scale, + alibi_slopes: Some(alibi_slopes.clone()), + window_size_left, + window_size_right, }; q.apply_op3(k, v, op) } struct FlashAttnVarLen { - softmax_scale: f32, - causal: bool, - max_seqlen_q: usize, - max_seqlen_k: usize, - seqlens_q: Tensor, - seqlens_k: Tensor, + pub softmax_scale: f32, + pub max_seqlen_q: usize, + pub max_seqlen_k: usize, + pub seqlens_q: Tensor, + pub seqlens_k: Tensor, + pub alibi_slopes: Option, + pub window_size_left: Option, + pub window_size_right: Option, } impl FlashAttnVarLen { @@ -311,7 +492,54 @@ impl FlashAttnVarLen { if nseqlens_k != nseqlens_q { candle::bail!("seqlens_q and seqlens_k should have the same number of elements {nseqlens_q} <> {nseqlens_k}") } + let batch_size = nseqlens_q - 1; + + let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes { + if alibi_slopes.dtype() != DType::F32 { + candle::bail!( + "DType mismatch alibi_slopes {:?}, expected {:?}", + alibi_slopes.dtype(), + DType::F32 + ); + } + + let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout(); + + if num_heads != alibi_slopes_layout.shape().dims1()? { + candle::bail!( + "shape mismatch alibi_slopes {:?}, expected {:?}", + alibi_slopes_layout.shape(), + (num_heads) + ); + } + + let alibi_slopes = match &*alibi_slopes { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("alibi_slopes must be a cuda tensor"), + }; + + let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..); + + *alibi_slopes.device_ptr() as *const core::ffi::c_void + } else { + std::ptr::null() + }; + + // if window_size_left > self.max_seqlen_k or None => -1 + let mut window_size_left = self + .window_size_left + .filter(|v| v <= &self.max_seqlen_k) + .map(|v| v as i32) + .unwrap_or(-1); + + // if window_size_right > self.max_seqlen_k or None => -1 + let mut window_size_right = self + .window_size_right + .filter(|v| v <= &self.max_seqlen_k) + .map(|v| v as i32) + .unwrap_or(-1); + let head_size = round_multiple(head_size_og, 8); let head_size_rounded = round_multiple(head_size, 32); let seqlen_q_rounded = round_multiple(self.max_seqlen_q, 128); @@ -323,9 +551,22 @@ impl FlashAttnVarLen { .alloc_zeros::(batch_size * num_heads * self.max_seqlen_q) .w()?; - let causal = if self.causal { 1 } else { 0 }; let is_bf16 = if is_bf16 { 1 } else { 0 }; + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
+ let is_causal = if window_size_left < 0 && window_size_right == 0 { + 1 + } else { + 0 + }; + if window_size_left < 0 && window_size_right >= 0 { + window_size_left = self.max_seqlen_k as i32; + } + if window_size_left >= 0 && window_size_right < 0 { + window_size_right = self.max_seqlen_k as i32; + } + unsafe { let q_ptr = *q.device_ptr() as *const core::ffi::c_void; let k_ptr = *k.device_ptr() as *const core::ffi::c_void; @@ -340,12 +581,14 @@ impl FlashAttnVarLen { v_ptr, dst_ptr, softmax_lse_ptr, + /* alibi_slopes_ptr */ alibi_slopes_ptr, /* cu_seqlens_q_ptr */ seqlens_q_ptr, /* cu_seqlens_k_ptr */ seqlens_k_ptr, /* q_batch_stride */ 0, /* k_batch_stride */ 0, /* v_batch_stride */ 0, /* o_batch_stride */ 0, + /* alibi_slopes_batch_stride */ 0, /* q_row_stride */ q_stride[q_rank - 3] as u32, /* k_row_stride */ k_stride[k_rank - 3] as u32, /* v_row_stride */ v_stride[v_rank - 3] as u32, @@ -364,8 +607,10 @@ impl FlashAttnVarLen { /* seqlen_k */ self.max_seqlen_k as u32, /* seqlen_q_rounded */ seqlen_q_rounded as u32, /* seqlen_k_rounded */ seqlen_k_rounded as u32, - /* is_causal */ causal, /* is_bf16 */ is_bf16, + /* is_causal */ is_causal, + /* window_size_left */ window_size_left, + /* window_size_right */ window_size_right, ) } @@ -440,13 +685,176 @@ pub fn flash_attn_varlen( softmax_scale: f32, causal: bool, ) -> Result { + let window_size_left = None; + let window_size_right = if causal { Some(0) } else { None }; + let op = FlashAttnVarLen { softmax_scale, - causal, max_seqlen_q, max_seqlen_k, seqlens_q: seqlens_q.clone(), seqlens_k: seqlens_k.clone(), + alibi_slopes: None, + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +#[allow(clippy::too_many_arguments)] +/// Flash-attention v2 layer with variable-length batching. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q. +/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v. +/// * `max_seqlen_q` - The maximum query sequence length for q in the batch. +/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch. +/// * `window_size_left` - Limit left attention to value tokens. +/// * `window_size_right` - Limit right attention to value tokens. +/// +/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`, +/// `seqlen_1 + seqlen_2`, etc. +/// +/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`. 
+/// +/// # Causal mask +/// +/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result +/// of `Q @ K^T` +pub fn flash_attn_varlen_windowed( + q: &Tensor, + k: &Tensor, + v: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + window_size_left: Option, + window_size_right: Option, +) -> Result { + let op = FlashAttnVarLen { + softmax_scale, + max_seqlen_q, + max_seqlen_k, + seqlens_q: seqlens_q.clone(), + seqlens_k: seqlens_k.clone(), + alibi_slopes: None, + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +#[allow(clippy::too_many_arguments)] +/// Flash-attention v2 layer with variable-length batching. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`. +/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q. +/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v. +/// * `max_seqlen_q` - The maximum query sequence length for q in the batch. +/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch. +/// +/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`, +/// `seqlen_1 + seqlen_2`, etc. +/// +/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`. +pub fn flash_attn_varlen_alibi( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + causal: bool, +) -> Result { + let window_size_left = None; + let window_size_right = if causal { Some(0) } else { None }; + + let op = FlashAttnVarLen { + softmax_scale, + max_seqlen_q, + max_seqlen_k, + seqlens_q: seqlens_q.clone(), + seqlens_k: seqlens_k.clone(), + alibi_slopes: Some(alibi_slopes.clone()), + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +#[allow(clippy::too_many_arguments)] +/// Flash-attention v2 layer with variable-length batching. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`. +/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q. +/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v. +/// * `max_seqlen_q` - The maximum query sequence length for q in the batch. 
+/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch. +/// * `window_size_left` - Limit left attention to value tokens. +/// * `window_size_right` - Limit right attention to value tokens. +/// +/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`, +/// `seqlen_1 + seqlen_2`, etc. +/// +/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`. +/// +/// # Causal mask +/// +/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result +/// of `Q @ K^T` +pub fn flash_attn_varlen_alibi_windowed( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + window_size_left: Option, + window_size_right: Option, +) -> Result { + let op = FlashAttnVarLen { + softmax_scale, + max_seqlen_q, + max_seqlen_k, + seqlens_q: seqlens_q.clone(), + seqlens_k: seqlens_k.clone(), + alibi_slopes: Some(alibi_slopes.clone()), + window_size_left, + window_size_right, }; q.apply_op3(k, v, op) } diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal index 8c3b4a8c..cdc8fef8 100644 --- a/candle-metal-kernels/src/binary.metal +++ b/candle-metal-kernels/src/binary.metal @@ -56,15 +56,24 @@ kernel void FN_NAME_STRIDED( \ #define BINARY_OP(FN, NAME) \ BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \ -BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); +BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); \ +BINARY(FN, uint32_t, uint32_t, NAME##_u32, NAME##_u32_strided); \ +BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); + +#define INT64_BINARY_OP(NAME, FN) \ +BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided); #define BFLOAT_BINARY_OP(FN, NAME) \ BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided); #define BINARY_OP_OUT(NAME, FN) \ BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \ -BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); +BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); \ +BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \ +BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); +#define INT64_BINARY_OP_OUT(NAME, FN) \ +BINARY(FN, int64_t, int8_t, NAME##_i64, NAME##_i64_strided); BINARY_OP(x + y, add) BINARY_OP(x - y, sub) @@ -80,6 +89,22 @@ BINARY_OP_OUT(lt, x < y) BINARY_OP_OUT(ge, x >= y) BINARY_OP_OUT(gt, x > y) +#if __METAL_VERSION__ >= 220 +INT64_BINARY_OP(add, x + y) +INT64_BINARY_OP(sub, x - y) +INT64_BINARY_OP(mul, x * y) +INT64_BINARY_OP(div, x / y) +INT64_BINARY_OP(min, MIN(x, y)) +INT64_BINARY_OP(max, MAX(x, y)) + +INT64_BINARY_OP_OUT(eq, x == y) +INT64_BINARY_OP_OUT(ne, x != y) +INT64_BINARY_OP_OUT(le, x <= y) +INT64_BINARY_OP_OUT(lt, x < y) +INT64_BINARY_OP_OUT(ge, x >= y) +INT64_BINARY_OP_OUT(gt, x > y) +#endif + #if __METAL_VERSION__ >= 310 BFLOAT_BINARY_OP(x + y, add) BFLOAT_BINARY_OP(x - y, sub) diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index 8481389d..e9ab17b1 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -52,5 +52,13 @@ CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float) CAST(cast_f16_f32, cast_f16_f32_strided, half, float) CAST(cast_f32_f16, cast_f32_f16_strided, float, half) -#if __METAL_VERSION__ >= 310 +#if __METAL_VERSION__ >= 220 +CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t) +CAST(cast_u32_i64, 
cast_u32_i64_strided, uint32_t, int64_t) +CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float) +#endif + +#if __METAL_VERSION__ >= 310 +CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float) +CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat) #endif diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 5db4c2cb..ba5ef3de 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -137,6 +137,9 @@ macro_rules! ops{ pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32")); pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16")); pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16")); + pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64")); + pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32")); + pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8")); } )+ pub mod copy { @@ -144,6 +147,7 @@ macro_rules! ops{ pub const FLOAT: Kernel = Kernel("copy_f32"); pub const HALF: Kernel = Kernel("copy_f16"); pub const BFLOAT: Kernel = Kernel("copy_bf16"); + pub const I64: Kernel = Kernel("copy_i64"); pub const U32: Kernel = Kernel("copy_u32"); pub const U8: Kernel = Kernel("copy_u8"); } @@ -157,6 +161,9 @@ macro_rules! ops{ pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_strided")); pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided")); pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided")); + pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_strided")); + pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_strided")); + pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_strided")); } )+ pub mod copy { @@ -164,6 +171,7 @@ macro_rules! ops{ pub const FLOAT: Kernel = Kernel("copy_f32_strided"); pub const HALF: Kernel = Kernel("copy_f16_strided"); pub const BFLOAT: Kernel = Kernel("copy_bf16_strided"); + pub const I64: Kernel = Kernel("copy_i64_strided"); pub const U32: Kernel = Kernel("copy_u32_strided"); pub const U8: Kernel = Kernel("copy_u8_strided"); } @@ -172,7 +180,10 @@ macro_rules! 
ops{ } pub mod unary { - ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh); + ops!( + cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, round, erf, gelu_erf, tanh, + recip + ); } pub mod binary { ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt); diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index 2d584917..83a56f0a 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -263,24 +263,38 @@ kernel void NAME( REDUCE(x + y, fast_sum_f32_strided, float, 0) REDUCE(x + y, fast_sum_u32_strided, uint, 0) REDUCE(x + y, fast_sum_f16_strided, half, 0) +REDUCE(x + y, fast_sum_u8_strided, uint8_t, 0) REDUCE(x * y, fast_mul_f32_strided, float, 1) REDUCE(x * y, fast_mul_u32_strided, uint, 1) REDUCE(x * y, fast_mul_f16_strided, half, 1) REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF) REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0) REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH) +REDUCE(MAX(x, y), fast_max_u8_strided, uint8_t, 0) REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF) REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF) REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH) +REDUCE(MIN(x, y), fast_min_u8_strided, uint8_t, 0xFF) ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF) ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH) ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF) +ARGMIN(fast_argmin_u8_strided, uint8_t, 0xFF) ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF) ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH) ARGMAX(fast_argmax_u32_strided, uint, 0) +ARGMAX(fast_argmax_u8_strided, uint8_t, 0) SOFTMAX(softmax_f32, float) SOFTMAX(softmax_f16, half) + +#if __METAL_VERSION__ >= 220 +REDUCE(x + y, fast_sum_i64_strided, int64_t, 0) +REDUCE(MIN(x, y), fast_min_i64_strided, int64_t, INT_MAX) +REDUCE(MAX(x, y), fast_max_i64_strided, int64_t, INT_MIN) +ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX) +ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN) +#endif + #if __METAL_VERSION__ >= 310 REDUCE(x + y, fast_sum_bf16, bfloat, 0) REDUCE(x * y, fast_mul_bf16, bfloat, 1) diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal index 1f9cb38a..40b4bcf4 100644 --- a/candle-metal-kernels/src/ternary.metal +++ b/candle-metal-kernels/src/ternary.metal @@ -55,6 +55,9 @@ kernel void FN_NAME( \ WHERE_OP(float, uint8_t, where_u8_f32) // WHERE_OP(double, uint8_t, where_u8_f64) -// WHERE_OP(uint8_t, uint8_t, where_u8_u8) -// WHERE_OP(uint32_t, uint8_t, where_u8_u32) -// WHERE_OP(int64_t, uint8_t, where_u8_i64) +WHERE_OP(uint8_t, uint8_t, where_u8_u8) +WHERE_OP(uint32_t, uint8_t, where_u8_u32) + +#if __METAL_VERSION__ >= 220 +WHERE_OP(int64_t, uint8_t, where_u8_i64) +#endif diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 2a88598c..19efabd3 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -18,7 +18,9 @@ METAL_FUNC uint get_strided_index( } template METAL_FUNC T sqr(T in){ return in * in; } +template METAL_FUNC T recip(T in){ return T(1.0 / in); } template METAL_FUNC T neg(T in){ return -in; } + template METAL_FUNC T erf(T in){ float x = (float) in; // constants @@ -56,8 +58,6 @@ template METAL_FUNC T gelu(T x) { return static_cast(0.5) * x * (static_cast(1.0) + T(tanh(beta))); } - - #define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \ kernel void FN_NAME( \ constant size_t &dim, \ @@ -101,17 +101,24 @@ 
UNARY_OP(neg) UNARY_OP(exp) UNARY_OP(log) UNARY_OP(gelu) +UNARY_OP(abs) UNARY_OP(ceil) UNARY_OP(floor) UNARY_OP(round) UNARY_OP(gelu_erf) UNARY_OP(erf) UNARY_OP(tanh) +UNARY_OP(recip) + UNARY(id, float, copy_f32, copy_f32_strided) UNARY(id, half, copy_f16, copy_f16_strided) UNARY(id, uint8_t, copy_u8, copy_u8_strided) UNARY(id, uint32_t, copy_u32, copy_u32_strided) +#if __METAL_VERSION__ >= 220 +UNARY(id, int64_t, copy_i64, copy_i64_strided) +#endif + #if __METAL_VERSION__ >= 310 BFLOAT_UNARY_OP(cos) BFLOAT_UNARY_OP(sin) @@ -127,6 +134,7 @@ BFLOAT_UNARY_OP(round) BFLOAT_UNARY_OP(gelu_erf) BFLOAT_UNARY_OP(erf) BFLOAT_UNARY_OP(tanh) +BFLOAT_UNARY_OP(recip) UNARY(id, bfloat, copy_bf16, copy_bf16_strided) #endif diff --git a/candle-nn/src/batch_norm.rs b/candle-nn/src/batch_norm.rs index 8cfc6740..856c2c7a 100644 --- a/candle-nn/src/batch_norm.rs +++ b/candle-nn/src/batch_norm.rs @@ -7,15 +7,21 @@ //! running stats. //! //! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167 -use candle::{DType, Result, Tensor}; +use candle::{DType, Result, Tensor, Var}; #[derive(Debug, Clone, Copy, PartialEq)] pub struct BatchNormConfig { pub eps: f64, pub remove_mean: bool, + /// The meaning of affine here is different from LayerNorm: when false there is no learnable /// parameter at all, 1 used for gamma and 0 for beta. pub affine: bool, + + /// Controls exponential moving average of running stats. Defaults to 0.1 + /// + /// `running_stat * (1.0 - momentum) + stat * momentum`. + pub momentum: f64, } impl Default for BatchNormConfig { @@ -24,6 +30,7 @@ impl Default for BatchNormConfig { eps: 1e-5, remove_mean: true, affine: true, + momentum: 0.1, } } } @@ -32,23 +39,61 @@ impl From for BatchNormConfig { fn from(eps: f64) -> Self { Self { eps, - remove_mean: true, - affine: true, + ..Default::default() } } } #[derive(Clone, Debug)] pub struct BatchNorm { - running_mean: Tensor, - running_var: Tensor, + running_mean: Var, + running_var: Var, weight_and_bias: Option<(Tensor, Tensor)>, remove_mean: bool, eps: f64, - num_features: usize, + momentum: f64, } impl BatchNorm { + fn check_validity(&self, num_features: usize) -> Result<()> { + if self.eps < 0. { + candle::bail!("batch-norm eps cannot be negative {}", self.eps) + } + if !(0.0..=1.0).contains(&self.momentum) { + candle::bail!( + "batch-norm momentum must be between 0 and 1, is {}", + self.momentum + ) + } + if self.running_mean.dims() != [num_features] { + candle::bail!( + "batch-norm running mean has unexpected shape {:?} should have shape [{num_features}]", + self.running_mean.shape(), + ) + } + if self.running_var.dims() != [num_features] { + candle::bail!( + "batch-norm running variance has unexpected shape {:?} should have shape [{num_features}]", + self.running_var.shape(), + ) + } + if let Some((ref weight, ref bias)) = self.weight_and_bias.as_ref() { + if weight.dims() != [num_features] { + candle::bail!( + "batch-norm weight has unexpected shape {:?} should have shape [{num_features}]", + weight.shape(), + ) + } + if bias.dims() != [num_features] { + candle::bail!( + "batch-norm weight has unexpected shape {:?} should have shape [{num_features}]", + bias.shape(), + ) + } + } + Ok(()) + } + pub fn new( num_features: usize, running_mean: Tensor, @@ -57,29 +102,16 @@ impl BatchNorm { bias: Tensor, eps: f64, ) -> Result { - if eps < 0. 
{ - candle::bail!("batch-norm eps cannot be negative {eps}") - } - if weight.dims() != [num_features] { - candle::bail!( - "batch-norm unexpected weight shape {:?} {num_features}", - weight.shape() - ) - } - if bias.dims() != [num_features] { - candle::bail!( - "batch-norm unexpected bias shape {:?} {num_features}", - bias.shape() - ) - } - Ok(Self { - running_mean, - running_var, + let out = Self { + running_mean: Var::from_tensor(&running_mean)?, + running_var: Var::from_tensor(&running_var)?, weight_and_bias: Some((weight, bias)), remove_mean: true, eps, - num_features, - }) + momentum: 0.1, + }; + out.check_validity(num_features)?; + Ok(out) } pub fn new_no_bias( @@ -88,25 +120,64 @@ impl BatchNorm { running_var: Tensor, eps: f64, ) -> Result { - if eps < 0. { - candle::bail!("batch-norm eps cannot be negative {eps}") - } - Ok(Self { - running_mean, - running_var, + let out = Self { + running_mean: Var::from_tensor(&running_mean)?, + running_var: Var::from_tensor(&running_var)?, weight_and_bias: None, remove_mean: true, eps, - num_features, - }) + momentum: 0.1, + }; + out.check_validity(num_features)?; + Ok(out) + } + + pub fn new_with_momentum( + num_features: usize, + running_mean: Tensor, + running_var: Tensor, + weight: Tensor, + bias: Tensor, + eps: f64, + momentum: f64, + ) -> Result { + let out = Self { + running_mean: Var::from_tensor(&running_mean)?, + running_var: Var::from_tensor(&running_var)?, + weight_and_bias: Some((weight, bias)), + remove_mean: true, + eps, + momentum, + }; + out.check_validity(num_features)?; + Ok(out) + } + + pub fn new_no_bias_with_momentum( + num_features: usize, + running_mean: Tensor, + running_var: Tensor, + eps: f64, + momentum: f64, + ) -> Result { + let out = Self { + running_mean: Var::from_tensor(&running_mean)?, + running_var: Var::from_tensor(&running_var)?, + weight_and_bias: None, + remove_mean: true, + eps, + momentum, + }; + out.check_validity(num_features)?; + Ok(out) } pub fn running_mean(&self) -> &Tensor { - &self.running_mean + self.running_mean.as_tensor() } pub fn running_var(&self) -> &Tensor { - &self.running_var + self.running_var.as_tensor() } pub fn eps(&self) -> f64 { @@ -117,7 +188,12 @@ impl BatchNorm { self.weight_and_bias.as_ref().map(|v| (&v.0, &v.1)) } - pub fn forward_learning(&self, x: &Tensor) -> Result { + pub fn momentum(&self) -> f64 { + self.momentum + } + + pub fn forward_train(&self, x: &Tensor) -> Result { + let num_features = self.running_mean.as_tensor().dim(0)?; let x_dtype = x.dtype(); let internal_dtype = match x_dtype { DType::F16 | DType::BF16 => DType::F32, @@ -129,40 +205,54 @@ impl BatchNorm { x.shape() ) } - if x.dim(1)? != self.num_features { + if x.dim(1)? != num_features { candle::bail!( "batch-norm input doesn't have the expected number of features ({:?} <> {})", x.shape(), - self.num_features + num_features ) } let x = x.to_dtype(internal_dtype)?; let x = x.transpose(0, 1)?; let x_dims_post_transpose = x.dims(); + // Flatten all the dimensions exception the channel one as this performs a Spatial Batch + // Normalization. let x = x.flatten_from(1)?.contiguous()?; let x = if self.remove_mean { + // The mean is taken over dim 1 as this is the batch dim after the transpose(0, 1) above. let mean_x = x.mean_keepdim(1)?; + let updated_running_mean = ((self.running_mean.as_tensor() * (1.0 - self.momentum))? + + (mean_x.flatten_all()? * self.momentum)?)?; + self.running_mean.set(&updated_running_mean)?; x.broadcast_sub(&mean_x)? 
} else { x }; + // The mean is taken over dim 1 as this is the batch dim after the transpose(0, 1) above. let norm_x = x.sqr()?.mean_keepdim(1)?; - let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?; - let x = x_normed.to_dtype(x_dtype)?; + let updated_running_var = { + let batch_size = x.dim(1)? as f64; + let running_var_weight = 1.0 - self.momentum; + let norm_x_weight = self.momentum * batch_size / (batch_size - 1.0); + ((self.running_var.as_tensor() * running_var_weight)? + + (&norm_x.flatten_all()? * norm_x_weight)?)? + }; + self.running_var.set(&updated_running_var)?; + let x = x + .broadcast_div(&(norm_x + self.eps)?.sqrt()?)? + .to_dtype(x_dtype)?; let x = match &self.weight_and_bias { None => x, Some((weight, bias)) => { - let weight = weight.reshape((self.num_features, 1))?; - let bias = bias.reshape((self.num_features, 1))?; + let weight = weight.reshape(((), 1))?; + let bias = bias.reshape(((), 1))?; x.broadcast_mul(&weight)?.broadcast_add(&bias)? } }; x.reshape(x_dims_post_transpose)?.transpose(0, 1) } -} -impl crate::Module for BatchNorm { - fn forward(&self, x: &Tensor) -> Result { + fn forward_eval(&self, x: &Tensor) -> Result { let target_shape: Vec = x .dims() .iter() @@ -170,9 +260,13 @@ impl crate::Module for BatchNorm { .map(|(idx, v)| if idx == 1 { *v } else { 1 }) .collect(); let target_shape = target_shape.as_slice(); + let x = x - .broadcast_sub(&self.running_mean.reshape(target_shape)?)? - .broadcast_div(&(self.running_var.reshape(target_shape)? + self.eps)?.sqrt()?)?; + .broadcast_sub(&self.running_mean.as_tensor().reshape(target_shape)?)? + .broadcast_div( + &(self.running_var.as_tensor().reshape(target_shape)? + self.eps)?.sqrt()?, + )?; + match &self.weight_and_bias { None => Ok(x), Some((weight, bias)) => { @@ -184,30 +278,41 @@ impl crate::Module for BatchNorm { } } +impl crate::ModuleT for BatchNorm { + fn forward_t(&self, x: &Tensor, train: bool) -> Result { + if train { + self.forward_train(x) + } else { + self.forward_eval(x) + } + } +} + pub fn batch_norm>( num_features: usize, config: C, vb: crate::VarBuilder, ) -> Result { + use crate::Init; let config = config.into(); if config.eps < 0. { candle::bail!("batch-norm eps cannot be negative {}", config.eps) } - let running_mean = vb.get_with_hints(num_features, "running_mean", crate::Init::Const(0.))?; - let running_var = vb.get_with_hints(num_features, "running_var", crate::Init::Const(1.))?; + let running_mean = vb.get_with_hints(num_features, "running_mean", Init::Const(0.))?; + let running_var = vb.get_with_hints(num_features, "running_var", Init::Const(1.))?; let weight_and_bias = if config.affine { - let weight = vb.get_with_hints(num_features, "weight", crate::Init::Const(1.))?; - let bias = vb.get_with_hints(num_features, "bias", crate::Init::Const(0.))?; + let weight = vb.get_with_hints(num_features, "weight", Init::Const(1.))?; + let bias = vb.get_with_hints(num_features, "bias", Init::Const(0.))?; Some((weight, bias)) } else { None }; Ok(BatchNorm { - running_mean, - running_var, + running_mean: Var::from_tensor(&running_mean)?, + running_var: Var::from_tensor(&running_var)?, weight_and_bias, remove_mean: config.remove_mean, eps: config.eps, - num_features, + momentum: config.momentum, }) } diff --git a/candle-nn/src/encoding.rs b/candle-nn/src/encoding.rs new file mode 100644 index 00000000..38e2cc3b --- /dev/null +++ b/candle-nn/src/encoding.rs @@ -0,0 +1,150 @@ +//! Encoding Utilities. 
(e.g., one-hot/cold encoding)
+
+use candle::{bail, DType, Result, Tensor, WithDType};
+
+/// One-hot/cold encoding.
+///
+/// Given an input tensor of indices, this function returns a tensor of the same shape as the input
+/// tensor with an additional dimension of the given depth size. The values in the returned tensor are
+/// all set to the `off_value` except for the positions represented by the indices, which are set to the `on_value`.
+///
+/// This method returns a tensor with a rank that is one larger than that of the input tensor.
+///
+/// As an example, the following tensor of indices:
+///
+/// `[[0i64, 2], [1, -1]]`
+///
+/// with a depth of 4 will be encoded to:
+///
+/// `[[[1, 0, 0, 0], [0, 0, 1, 0]], [[0, 1, 0, 0], [0, 0, 0, 0]]]`
+///
+/// When an index in the input tensor has a value of -1, the corresponding one-hot vector is skipped,
+/// resulting in a vector with all values set to the `off_value`.
+///
+/// This method supports one-cold encoding by setting `on_value` to `0` and `off_value` to `1`.
+/// By default `on_value` is `1` and `off_value` is `0`.
+///
+/// Other encoding values can be used by setting `on_value` and `off_value` to the desired values.
+///
+/// # Examples
+///
+/// ## One-hot encoding
+///
+/// ```rust
+/// use candle::{Shape, Tensor, Device};
+/// use candle_nn::encoding::one_hot;
+///
+/// let device = candle::Device::Cpu;
+///
+/// let indices = Tensor::new(vec![vec![0i64, 2], vec![1, -1]], &device).unwrap();
+/// let depth = 4;
+/// let one_hot = one_hot(indices, depth, 1f32, 0f32).unwrap();
+///
+/// let expected_matrix = [
+///     [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]],
+///     [[0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
+/// ];
+///
+/// assert_eq!(one_hot.shape(), &Shape::from((2, 2, depth)));
+///
+/// let matrix = one_hot.to_vec3::<f32>().unwrap();
+///
+/// assert_eq!(matrix, expected_matrix);
+/// ```
+///
+/// ## One-cold Encoding
+///
+/// ```rust
+/// use candle::{Shape, Tensor, Device};
+/// use candle_nn::encoding::one_hot;
+///
+/// let device = candle::Device::Cpu;
+/// let depth = 4;
+/// let indices = Tensor::new(vec![vec![0u8, 2], vec![1, 3]], &device).unwrap();
+/// let one_cold = one_hot(indices, depth, 0u8, 1u8).unwrap();
+///
+/// let expected_matrix = [[[0, 1, 1, 1], [1, 1, 0, 1]], [[1, 0, 1, 1], [1, 1, 1, 0]]];
+///
+/// assert_eq!(one_cold.shape(), &Shape::from((2, 2, depth)));
+///
+/// let matrix = one_cold.to_vec3::<u8>().unwrap();
+///
+/// assert_eq!(matrix, expected_matrix);
+/// ```
+///
+/// # Bails
+///
+/// This method bails if:
+/// - One of the index values is less than -1.
+/// - One of the index values is greater than or equal to the depth value.
+/// - The input data type is not `U8`, `U32`, or `I64`.
+///
+/// # API Design
+///
+/// The API design for this method is loosely based on the [TensorFlow One-Hot](https://www.tensorflow.org/api_docs/python/tf/one_hot) method.
+pub fn one_hot( + indices: Tensor, + depth: usize, + on_value: D, + off_value: D, +) -> Result { + let mut target_shape = indices.dims().to_vec(); + target_shape.push(depth); + let indices = indices.flatten_all()?; + let mut out = vec![off_value; depth * indices.elem_count()]; + match indices.dtype() { + DType::U8 => { + let indices = indices.to_vec1::()?; + for (i, &index) in indices.iter().enumerate() { + set_at_index(index, i * depth, depth, &mut out, on_value)?; + } + } + DType::U32 => { + let indices = indices.to_vec1::()?; + for (i, &index) in indices.iter().enumerate() { + set_at_index(index, i * depth, depth, &mut out, on_value)?; + } + } + DType::I64 => { + let indices = indices.to_vec1::()?; + for (i, &index) in indices.iter().enumerate() { + set_at_index(index, i * depth, depth, &mut out, on_value)?; + } + } + dtype => { + bail!("one_hot: unsupported data type {dtype:?}, expected U8, U32, or I64") + } + }; + Tensor::from_vec(out, target_shape, indices.device()) +} + +fn set_at_index>( + value: I, + offset: usize, + depth: usize, + v: &mut Vec, + on_value: D, +) -> Result<()> { + let value = value.into(); + // Skip for an entire row of off_values + if value == -1 { + return Ok(()); + } + if value < -1 { + bail!( + "one_hot: invalid negative index value {value}, expected a positive index value or -1" + ); + } + let value = value as usize; + if value >= depth { + bail!("one_hot: index value {value} exceeds depth {depth}") + } + let idx = offset + value; + if idx >= v.len() { + bail!("one_hot: index out of bounds {idx}, len {}", v.len()); + } + v[idx] = on_value; + Ok(()) +} diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs index 8f00e54c..6306c55a 100644 --- a/candle-nn/src/lib.rs +++ b/candle-nn/src/lib.rs @@ -2,6 +2,7 @@ pub mod activation; pub mod batch_norm; pub mod conv; pub mod embedding; +pub mod encoding; pub mod func; pub mod group_norm; pub mod init; diff --git a/candle-nn/tests/batch_norm.rs b/candle-nn/tests/batch_norm.rs index 5bbaf238..6fd7361a 100644 --- a/candle-nn/tests/batch_norm.rs +++ b/candle-nn/tests/batch_norm.rs @@ -16,6 +16,8 @@ input = torch.randn(2, 5, 3, 4) output = m(input) print(input.flatten()) print(output.flatten()) +print(m.running_mean) +print(m.running_var) */ #[test] fn batch_norm() -> Result<()> { @@ -37,7 +39,7 @@ fn batch_norm() -> Result<()> { 1.4252, -0.9115, -0.1093, -0.3100, -0.6734, -1.4357, 0.9205, ]; let input = Tensor::new(&input, &Device::Cpu)?.reshape((2, 5, 3, 4))?; - let output = bn.forward_learning(&input)?; + let output = bn.forward_train(&input)?; assert_eq!(output.dims(), &[2, 5, 3, 4]); let output = output.flatten_all()?; assert_eq!( @@ -65,11 +67,20 @@ fn batch_norm() -> Result<()> { Tensor::new(&[-1.5f32], &Device::Cpu)?.broadcast_as(5)?, 1e-8, )?; - let output2 = bn2.forward_learning(&input)?; + let output2 = bn2.forward_train(&input)?; assert_eq!(output2.dims(), &[2, 5, 3, 4]); let output2 = output2.flatten_all()?; let diff2 = ((output2 - (output * 0.5)?)? 
+ 1.5)?.sqr()?; let sum_diff2 = diff2.sum_keepdim(0)?; assert_eq!(test_utils::to_vec1_round(&sum_diff2, 4)?, &[0f32]); + + assert_eq!( + test_utils::to_vec1_round(bn.running_mean(), 4)?, + &[-0.0133, 0.0197, -0.0153, -0.0073, -0.0020] + ); + assert_eq!( + test_utils::to_vec1_round(bn.running_var(), 4)?, + &[0.9972, 0.9842, 0.9956, 0.9866, 0.9898] + ); Ok(()) } diff --git a/candle-nn/tests/one_hot.rs b/candle-nn/tests/one_hot.rs new file mode 100644 index 00000000..36afdf68 --- /dev/null +++ b/candle-nn/tests/one_hot.rs @@ -0,0 +1,120 @@ +use candle::{Result, Shape, Tensor}; +use candle_nn::encoding::one_hot; + +#[test] +fn test_i64_one_hot() -> Result<()> { + let device = candle::Device::Cpu; + + let indices = Tensor::new(vec![vec![0i64, 2], vec![1, -1]], &device)?; + let depth = 4; + + let on_value = 1.0; + let off_value = 0.0; + + let one_hot = one_hot::(indices, depth, on_value, off_value)?; + + let expected_matrix = [ + [[1., 0., 0., 0.], [0., 0., 1., 0.]], + [[0., 1., 0., 0.], [0., 0., 0., 0.]], + ]; + + assert_eq!(one_hot.shape(), &Shape::from((2, 2, depth))); + + let matrix = one_hot.to_vec3::()?; + + assert_eq!(matrix, expected_matrix); + + Ok(()) +} + +#[test] +fn test_rank_3_one_hot() -> Result<()> { + let device = candle::Device::Cpu; + + let indices = Tensor::new( + vec![ + vec![vec![0i64, 1], vec![2, 3]], + vec![vec![3, 1], vec![1, -1]], + ], + &device, + )?; + let depth = 4; + + let on_value = 1.0; + let off_value = 0.0; + + let one_hot = one_hot::(indices, depth, on_value, off_value)?; + + let expected_matrix = Tensor::new( + vec![ + vec![ + vec![vec![1f32, 0., 0., 0.], vec![0., 1., 0., 0.]], + vec![vec![0., 0., 1., 0.], vec![0., 0., 0., 1.]], + ], + vec![ + vec![vec![0., 0., 0., 1.], vec![0., 1., 0., 0.]], + vec![vec![0., 1., 0., 0.], vec![0., 0., 0., 0.]], + ], + ], + &device, + )?; + + assert_eq!(one_hot.shape(), expected_matrix.shape()); + assert_eq!(one_hot.dims(), expected_matrix.dims()); + + let matrix = one_hot.get(1)?.to_vec3::()?; + let expected_matrix = expected_matrix.get(1)?.to_vec3::()?; + + assert_eq!(matrix, expected_matrix); + + Ok(()) +} + +#[test] +fn test_u8_one_cold() -> Result<()> { + let device = candle::Device::Cpu; + let depth = 4; + let indices = Tensor::new(vec![vec![0i64, 2], vec![1, -1]], &device)?; + + let on_value = 0u8; + let off_value = 1; + + // Note that the method does not require the turbofish operator, as the type is inferred from the on_value. 
+ let one_cold = one_hot(indices, depth, on_value, off_value)?; + + let expected_matrix = [[[0, 1, 1, 1], [1, 1, 0, 1]], [[1, 0, 1, 1], [1, 1, 1, 1]]]; + + assert_eq!(one_cold.shape(), &Shape::from((2, 2, depth))); + + let matrix = one_cold.to_vec3::()?; + + assert_eq!(matrix, expected_matrix); + + Ok(()) +} + +#[test] +fn test_iter() -> Result<()> { + let device = candle::Device::Cpu; + let depth = 4; + let indices = Tensor::new(vec![vec![0i64, 2], vec![1, -1]], &device)?; + let matrix = indices.to_vec2::()?; + let (dim1, dim2) = indices.dims2()?; + + let iter = (0..dim1).flat_map(|i| (0..dim2).map(move |j| (i, j))); + + let mut v = vec![0; depth * dim1 * dim2]; + + for (i, j) in iter { + let idx = i * depth * dim2 + j * depth; + v[idx] = matrix[i][j]; + } + + for (i, row) in matrix.iter().enumerate() { + for (j, &value) in row.iter().enumerate() { + let idx = i * depth * dim2 + j * depth; + assert_eq!(v[idx], value); + } + } + Ok(()) +} diff --git a/candle-transformers/src/models/convmixer.rs b/candle-transformers/src/models/convmixer.rs index 76245f37..f5abfa5d 100644 --- a/candle-transformers/src/models/convmixer.rs +++ b/candle-transformers/src/models/convmixer.rs @@ -40,8 +40,8 @@ fn block(dim: usize, kernel_size: usize, vb: VarBuilder) -> Result let conv2 = candle_nn::conv2d(dim, dim, 1, Default::default(), vb.pp(1))?; let bn2 = batch_norm(dim, 1e-5, vb.pp(3))?; Ok(candle_nn::func(move |xs| { - let ys = xs.apply(&conv1)?.gelu_erf()?.apply(&bn1)?; - (xs + ys)?.apply(&conv2)?.gelu_erf()?.apply(&bn2) + let ys = xs.apply(&conv1)?.gelu_erf()?.apply_t(&bn1, false)?; + (xs + ys)?.apply(&conv2)?.gelu_erf()?.apply_t(&bn2, false) })) } @@ -64,7 +64,7 @@ fn convmixer( .collect::>>()?; let fc = candle_nn::linear(dim, nclasses, vb.pp(25))?; Ok(candle_nn::func(move |xs| { - let mut xs = xs.apply(&conv1)?.gelu_erf()?.apply(&bn1)?; + let mut xs = xs.apply(&conv1)?.gelu_erf()?.apply_t(&bn1, false)?; for block in blocks.iter() { xs = xs.apply(block)? } diff --git a/candle-transformers/src/models/efficientnet.rs b/candle-transformers/src/models/efficientnet.rs index ab51c76d..f15c9c79 100644 --- a/candle-transformers/src/models/efficientnet.rs +++ b/candle-transformers/src/models/efficientnet.rs @@ -169,8 +169,7 @@ impl ConvNormActivation { impl Module for ConvNormActivation { fn forward(&self, xs: &Tensor) -> Result { - let xs = self.conv2d.forward(xs)?; - let xs = self.bn2d.forward(&xs)?; + let xs = self.conv2d.forward(xs)?.apply_t(&self.bn2d, false)?; if self.activation { swish(&xs) } else { diff --git a/candle-transformers/src/models/resnet.rs b/candle-transformers/src/models/resnet.rs index f2588e01..30029a0b 100644 --- a/candle-transformers/src/models/resnet.rs +++ b/candle-transformers/src/models/resnet.rs @@ -25,7 +25,7 @@ fn downsample(c_in: usize, c_out: usize, stride: usize, vb: VarBuilder) -> Resul if stride != 1 || c_in != c_out { let conv = conv2d(c_in, c_out, 1, 0, stride, vb.pp(0))?; let bn = batch_norm(c_out, 1e-5, vb.pp(1))?; - Ok(Func::new(move |xs| xs.apply(&conv)?.apply(&bn))) + Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false))) } else { Ok(Func::new(|xs| Ok(xs.clone()))) } @@ -40,10 +40,10 @@ fn basic_block(c_in: usize, c_out: usize, stride: usize, vb: VarBuilder) -> Resu Ok(Func::new(move |xs| { let ys = xs .apply(&conv1)? - .apply(&bn1)? + .apply_t(&bn1, false)? .relu()? .apply(&conv2)? - .apply(&bn2)?; + .apply_t(&bn2, false)?; (xs.apply(&downsample)? + ys)?.relu() })) } @@ -94,7 +94,7 @@ fn resnet( Ok(Func::new(move |xs| { let xs = xs .apply(&conv1)? 
- .apply(&bn1)? + .apply_t(&bn1, false)? .relu()? .pad_with_same(D::Minus1, 1, 1)? .pad_with_same(D::Minus2, 1, 1)? @@ -149,13 +149,13 @@ fn bottleneck_block( Ok(Func::new(move |xs| { let ys = xs .apply(&conv1)? - .apply(&bn1)? + .apply_t(&bn1, false)? .relu()? .apply(&conv2)? - .apply(&bn2)? + .apply_t(&bn2, false)? .relu()? .apply(&conv3)? - .apply(&bn3)?; + .apply_t(&bn3, false)?; (xs.apply(&downsample)? + ys)?.relu() })) } @@ -206,7 +206,7 @@ fn bottleneck_resnet( Ok(Func::new(move |xs| { let xs = xs .apply(&conv1)? - .apply(&bn1)? + .apply_t(&bn1, false)? .relu()? .pad_with_same(D::Minus1, 1, 1)? .pad_with_same(D::Minus2, 1, 1)? diff --git a/candle-transformers/src/models/segment_anything/tiny_vit.rs b/candle-transformers/src/models/segment_anything/tiny_vit.rs index cd2936ab..d1700cc5 100644 --- a/candle-transformers/src/models/segment_anything/tiny_vit.rs +++ b/candle-transformers/src/models/segment_anything/tiny_vit.rs @@ -28,7 +28,7 @@ impl Conv2dBN { impl Module for Conv2dBN { fn forward(&self, xs: &Tensor) -> Result { let _enter = self.span.enter(); - xs.apply(&self.c)?.apply(&self.bn) + xs.apply(&self.c)?.apply_t(&self.bn, false) } } diff --git a/candle-transformers/src/models/wuerstchen/paella_vq.rs b/candle-transformers/src/models/wuerstchen/paella_vq.rs index 4a69cca0..58f795bb 100644 --- a/candle-transformers/src/models/wuerstchen/paella_vq.rs +++ b/candle-transformers/src/models/wuerstchen/paella_vq.rs @@ -185,7 +185,7 @@ impl PaellaVQ { xs = xs.apply(&down_block.1)? } xs.apply(&self.down_blocks_conv)? - .apply(&self.down_blocks_bn) + .apply_t(&self.down_blocks_bn, false) } pub fn decode(&self, xs: &Tensor) -> Result { diff --git a/candle-wasm-examples/yolo/src/model.rs b/candle-wasm-examples/yolo/src/model.rs index d49cf55f..f1d7ea20 100644 --- a/candle-wasm-examples/yolo/src/model.rs +++ b/candle-wasm-examples/yolo/src/model.rs @@ -107,8 +107,7 @@ impl ConvBlock { impl Module for ConvBlock { fn forward(&self, xs: &Tensor) -> Result { - let xs = self.conv.forward(xs)?; - let xs = self.bn.forward(&xs)?; + let xs = self.conv.forward(xs)?.apply_t(&self.bn, false)?; candle_nn::ops::silu(&xs) } }
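
The sliding-window masking added above (apply_mask_local in candle-flash-attn/kernels/softmax.h together with the window_size_left/window_size_right plumbing in candle-flash-attn/src/lib.rs) can be summarized host-side. The sketch below is illustrative only and is not part of the crate: the function name is hypothetical, and it mirrors the kernel's per-row column limits with -1 standing for an unlimited window, as in the FFI layer.

// Hypothetical reference predicate mirroring apply_mask_local; it documents which
// attention scores the kernel sets to -INFINITY for a given query row.
fn is_masked(
    row_idx: i64,           // query position in 0..max_seqlen_q
    col_idx: i64,           // key/value position in 0..max_seqlen_k
    max_seqlen_q: i64,
    max_seqlen_k: i64,
    window_size_left: i64,  // -1 => unlimited look-back
    window_size_right: i64, // -1 => unlimited look-ahead
) -> bool {
    // Query and key sequences are aligned on their last token, hence the shift.
    let shift = max_seqlen_k - max_seqlen_q;
    let limit_right = if window_size_right < 0 {
        max_seqlen_k
    } else {
        (row_idx + 1 + shift + window_size_right).min(max_seqlen_k)
    };
    let limit_left = if window_size_left < 0 {
        0
    } else {
        (row_idx + shift - window_size_left).max(0)
    };
    col_idx >= limit_right || col_idx < limit_left
}

With window_size_left < 0 and window_size_right == 0 this reduces to the causal mask, which is why the Rust wrappers set is_causal = 1 for that combination and otherwise widen any remaining negative window to seqlen_k before calling into the kernel.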
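
As a usage note for the new public entry points (flash_attn_windowed, flash_attn_alibi, flash_attn_alibi_windowed and their varlen counterparts), a minimal sketch is shown below. It assumes a CUDA device and f16 inputs; the shapes, window sizes and zero-filled tensors are placeholders chosen for illustration, not a recommended configuration.

// Minimal sketch, assuming the candle-flash-attn crate from this PR and a CUDA device.
use candle::{DType, Device, Result, Tensor};

fn sliding_window_alibi_attention() -> Result<Tensor> {
    let device = Device::new_cuda(0)?;
    let (b, seq_len, num_heads, head_size) = (1, 128, 8, 64);
    // Q/K/V use the (batch, seq_len, num_heads, head_size) layout documented above.
    let q = Tensor::zeros((b, seq_len, num_heads, head_size), DType::F16, &device)?;
    let k = Tensor::zeros((b, seq_len, num_heads, head_size), DType::F16, &device)?;
    let v = Tensor::zeros((b, seq_len, num_heads, head_size), DType::F16, &device)?;
    // Alibi slopes must be an f32 tensor with one slope per query head.
    let alibi_slopes = Tensor::zeros(num_heads, DType::F32, &device)?;
    let softmax_scale = 1.0 / (head_size as f32).sqrt();
    // Look back at most 32 tokens and never look ahead: a causal sliding window.
    candle_flash_attn::flash_attn_alibi_windowed(
        &q, &k, &v, &alibi_slopes, softmax_scale, Some(32), Some(0),
    )
}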
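
The batch-norm changes give BatchNorm a training-mode path through ModuleT/apply_t and an exponential moving average of the running statistics, running * (1 - momentum) + batch_stat * momentum. A small CPU-only sketch, with illustrative shapes and values, might look as follows.

// Minimal sketch of the train/eval split, assuming candle-nn as changed in this PR.
use candle::{DType, Device, Result, Tensor};
use candle_nn::batch_norm::BatchNorm;

fn batch_norm_running_stats() -> Result<()> {
    let device = Device::Cpu;
    let num_features = 5;
    let running_mean = Tensor::zeros(num_features, DType::F32, &device)?;
    let running_var = Tensor::ones(num_features, DType::F32, &device)?;
    let weight = Tensor::ones(num_features, DType::F32, &device)?;
    let bias = Tensor::zeros(num_features, DType::F32, &device)?;
    // momentum = 0.1 updates the running stats as running * 0.9 + batch_stat * 0.1.
    let bn = BatchNorm::new_with_momentum(
        num_features, running_mean, running_var, weight, bias, 1e-5, 0.1,
    )?;
    let xs = Tensor::randn(0f32, 1f32, (2, num_features, 3, 4), &device)?;
    // train = true normalizes with batch statistics and updates the running stats,
    let _train_out = xs.apply_t(&bn, true)?;
    // while train = false normalizes with the stored running mean/variance.
    let _eval_out = xs.apply_t(&bn, false)?;
    println!("running mean after one training step: {}", bn.running_mean());
    Ok(())
}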